File size: 6,498 Bytes
25f01d1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
import os
import logging
import openai
import tiktoken
import re
import anthropic
import cohere
import google.generativeai as genai
import time
from file_utils import read_file
from openai import OpenAI

class Paper:
    """Container pairing a paper's identifier with its TeX text.

    Both values are stored exactly as given; nothing is validated or copied.
    ``tex_file`` holds the paper content itself (``PaperProcessor`` embeds it
    verbatim in its first prompt), not a path. ``arxiv_id`` is presumably an
    arXiv identifier, judging by its name; it is not read by ``PaperProcessor``.
    """

    def __init__(self, arxiv_id, tex_file):
        self.arxiv_id, self.tex_file = arxiv_id, tex_file

class PaperProcessor:
    """Produce a section-by-section LLM review of a :class:`Paper`.

    For each ``question{i}.txt`` file (i = 1 .. len(REVIEW_HEADERS)) in
    ``prompt_dir`` the configured backend is asked one question, with
    ``systemrole.txt`` from the same directory used as the system prompt.
    The answers are collected into a list of ``"<Header> <answer>"`` strings.
    """

    # Maximum number of tokens allowed in one user prompt. Counting always
    # uses the GPT-4 tokenizer below, whatever backend is selected, so for
    # the non-OpenAI backends this is only an approximation.
    MAX_TOKENS = 127192
    encoding = tiktoken.encoding_for_model("gpt-4-0125-preview")

    # Review section labels, in question order: the answer to question1.txt
    # goes under REVIEW_HEADERS[0], question2.txt under REVIEW_HEADERS[1], ...
    REVIEW_HEADERS = (
        'Summary:', 'Soundness:', 'Presentation:', 'Contribution:',
        'Strengths:', 'Weaknesses:', 'Questions:', 'Flag For Ethics Review:',
        'Rating:', 'Confidence:', 'Code Of Conduct:',
    )

    # Question numbers whose free-text answers are reduced to a numeric
    # score, mapped to the top of that score's scale.
    SCORE_SCALES = {2: 5, 3: 5, 4: 5, 9: 10, 10: 5}

    def __init__(self, prompt_dir, model, openai_api_key, claude_api_key, gemini_api_key, commandr_api_key):
        """Store configuration; no API clients are created here.

        Args:
            prompt_dir: Directory containing ``systemrole.txt`` and the
                ``question{i}.txt`` prompt files.
            model: Backend used by :meth:`process_paper`; one of ``'gpt'``,
                ``'claude'``, ``'commandr'`` or ``'gemini'``.
            openai_api_key, claude_api_key, gemini_api_key, commandr_api_key:
                Credentials for the corresponding backends.
        """
        self.prompt_dir = prompt_dir
        self.model = model
        self.openai_api_key = openai_api_key
        self.claude_api_key = claude_api_key
        self.gemini_api_key = gemini_api_key
        self.commandr_api_key = commandr_api_key

    def _encode(self, text):
        """Tokenize *text*, treating special-token strings as plain text.

        tiktoken's default (``disallowed_special="all"``) raises ValueError
        when the text contains e.g. ``<|endoftext|>``, which a paper about
        language models may well quote.
        """
        return self.encoding.encode(text, disallowed_special=())

    def count_tokens(self, text):
        """Return the number of GPT-4 tokens in *text*."""
        return len(self._encode(text))

    def truncate_content(self, content):
        """Return *content* cut down to at most ``MAX_TOKENS`` tokens.

        Text that already fits is returned unchanged. Truncation keeps the
        leading tokens, so the end of *content* is what gets dropped.
        """
        tokens = self._encode(content)
        logging.debug(f"Token count before truncation: {len(tokens)}")
        if len(tokens) <= self.MAX_TOKENS:
            return content
        truncated_content = self.encoding.decode(tokens[:self.MAX_TOKENS])
        # Re-encoding ~MAX_TOKENS tokens is costly; only do it when the
        # debug message will actually be emitted.
        if logging.getLogger().isEnabledFor(logging.DEBUG):
            logging.debug(f"Content truncated. Token count after truncation: {self.count_tokens(truncated_content)}")
        return truncated_content

    def prepare_base_prompt(self, paper):
        """Return the paper text that forms the body of the first prompt."""
        return paper.tex_file

    def call_model(self, prompt, model_type):
        """Send *prompt* to the backend named by *model_type*; return its reply.

        Args:
            prompt: User message text.
            model_type: ``'gpt'``, ``'claude'``, ``'commandr'`` or ``'gemini'``.

        Returns:
            The model's text response, or ``None`` if ``systemrole.txt`` is
            missing, *model_type* is unknown, or the API call raises (the
            exception is logged, not propagated).
        """
        system_role_file_path = os.path.join(self.prompt_dir, "systemrole.txt")
        if not os.path.exists(system_role_file_path):
            logging.error(f"System role file not found: {system_role_file_path}")
            return None

        system_role = read_file(system_role_file_path)
        logging.debug(f"Token count of full prompt: {self.count_tokens(prompt)}")
        logging.info(f"Sending the following prompt to {model_type}: {prompt}")

        try:
            if model_type == 'gpt':
                client = OpenAI(api_key=self.openai_api_key)
                messages = [{"role": "system", "content": system_role}, {"role": "user", "content": prompt}]
                completion = client.chat.completions.create(
                    model="gpt-4-turbo-2024-04-09",
                    messages=messages,
                    temperature=1
                )
                return completion.choices[0].message.content.strip()

            if model_type == 'claude':
                client = anthropic.Anthropic(api_key=self.claude_api_key)
                response = client.messages.create(
                    model='claude-3-opus-20240229',
                    max_tokens=4096,
                    system=system_role,
                    temperature=0.5,
                    messages=[{"role": "user", "content": prompt}]
                )
                return response.content[0].text

            if model_type == 'commandr':
                co = cohere.Client(self.commandr_api_key)
                response = co.chat(
                    model="command-r-plus",
                    message=prompt,
                    preamble=system_role
                )
                return response.text

            if model_type == 'gemini':
                genai.configure(api_key=self.gemini_api_key)
                # NOTE(review): system_role is not passed to Gemini, so this
                # backend answers without the reviewer system prompt.
                gemini_model = genai.GenerativeModel('gemini-pro')
                response = gemini_model.generate_content(prompt)
                return response.candidates[0].content.parts[0].text

        except Exception as e:
            # Best-effort: any API failure becomes None ("N/A" in the review).
            logging.exception(f"Exception occurred: {e}")
            return None

        logging.error(f"Unknown model type: {model_type}")
        return None

    def is_content_appropriate(self, content):
        """Return False if the OpenAI moderation endpoint flags *content*.

        Fails open: if the moderation call raises, the error is logged and
        the content is treated as appropriate (True).
        """
        try:
            client = OpenAI(api_key=self.openai_api_key)
            response = client.moderations.create(input=content)
            # openai>=1.0 returns typed objects, not dicts; subscripting them
            # raised here and made this method always return True.
            return not response.results[0].flagged
        except Exception as e:
            logging.error(f"Exception occurred while checking content appropriateness: {e}")
            return True  # In case of an error, default to content being appropriate

    def get_prompt_files(self, prompt_dir):
        """Return names (not paths) of ``question*.txt`` files in *prompt_dir*.

        The order is whatever ``os.listdir`` yields, i.e. unspecified.
        """
        return [f for f in os.listdir(prompt_dir) if f.endswith('.txt') and f.startswith('question')]

    @staticmethod
    def _format_score(response, scale):
        """Reduce *response* to ``"<n>/<scale>"`` using its first integer.

        Values above *scale* are clamped to *scale*. If *response* contains
        no integer it is returned unchanged (e.g. ``"N/A"``).
        """
        number_match = re.search(r'\b\d+\b', response)
        if number_match is None:
            return response
        number = min(int(number_match.group(0)), scale)
        return f"{number}/{scale}"

    def process_paper(self, paper):
        """Run the full question-by-question review of *paper*.

        Returns:
            * a list with one ``"<Header> <answer>"`` string per entry of
              ``REVIEW_HEADERS`` (answer ``"N/A"`` where the model call
              returned None);
            * ``["Desk Rejected", <reason>]`` if moderation flags the paper;
            * the plain string ``"Error: Base prompt could not be prepared."``
              if :meth:`prepare_base_prompt` returns None (kept as a string
              for compatibility with existing callers).
        """
        # Preserved side effect: sets the module-level OpenAI key, which
        # other code may rely on.
        openai.api_key = self.openai_api_key
        start_time = time.perf_counter()

        base_prompt = self.prepare_base_prompt(paper)
        if base_prompt is None:
            return "Error: Base prompt could not be prepared."

        # Shares the moderation logic (and its fail-open policy) with
        # is_content_appropriate instead of crashing on a moderation error.
        if not self.is_content_appropriate(base_prompt):
            return ["Desk Rejected", "The paper contains inappropriate or harmful content."]

        review_output = []
        previous_responses = []
        for i, header in enumerate(self.REVIEW_HEADERS, start=1):
            question_file = os.path.join(self.prompt_dir, f"question{i}.txt")
            question_text = read_file(question_file)

            if i == 1:
                prompt = f"{question_text}\n\n####\n{base_prompt}\n####"
            else:
                # NOTE(review): only question 1 includes the paper. Each
                # call_model request is stateless, so later sections are
                # judged from the earlier answers alone. Confirm intended.
                prompt = f"\nHere is your review so far:\n{' '.join(previous_responses)}\n\nHere are your reviewer instructions. Please answer the following question:\n{question_text}"

            # Truncation keeps the start of the prompt, so an over-long paper
            # loses its tail (including the closing "####").
            truncated_prompt = self.truncate_content(prompt)
            logging.info(f"Processing prompt for question {i}")

            response = self.call_model(truncated_prompt, self.model)
            if response is None:
                response = "N/A"

            scale = self.SCORE_SCALES.get(i)
            if scale is not None:
                response = self._format_score(response, scale)

            review_output.append(f"{header} {response}")
            previous_responses.append(response)

        elapsed_time = time.perf_counter() - start_time
        print(f"Time taken to process paper: {elapsed_time:.2f} seconds")
        return review_output