# reviewer-arena / models.py
# Source: Hugging Face Hub, uploader "openreviewer".
# Commit 9342d22 (verified): "Upload folder using huggingface_hub".
import os
import logging
import openai
import tiktoken
import re
import anthropic
import cohere
import google.generativeai as genai
import time
from file_utils import read_file
from openai import OpenAI
class Paper:
    """A paper to be reviewed.

    Attributes:
        arxiv_id: Identifier of the paper; it is only used in log messages
            within this file.
        tex_file: The paper's content. ``PaperProcessor.prepare_base_prompt``
            returns it unchanged as the base prompt, so it is expected to be
            text (presumably LaTeX source -- confirm with the caller).
    """

    def __init__(self, arxiv_id, tex_file):
        self.arxiv_id = arxiv_id
        self.tex_file = tex_file

    def __repr__(self):
        # Report only the size of the content: the full paper text can be
        # very large and would drown log output.
        tex = self.tex_file
        tex_len = len(tex) if isinstance(tex, str) else None
        return f"{type(self).__name__}(arxiv_id={self.arxiv_id!r}, tex_len={tex_len})"
class PaperProcessor:
    """Produce a sectioned review of a paper by querying one LLM backend.

    The review is built from eleven prompt files (``question1.txt`` ..
    ``question11.txt``) plus ``systemrole.txt``, all located in
    ``prompt_dir``. Question 1 is asked together with the paper text; every
    later question is asked together with the answers collected so far.
    """

    # Upper bound on the size of a single prompt, in tokens of ``encoding``.
    MAX_TOKENS = 127192
    # Tokenizer used for counting and truncating, whichever backend is called.
    encoding = tiktoken.encoding_for_model("gpt-4")

    # Section labels; entry k-1 belongs to question k.
    _HEADERS = (
        'Summary:', 'Soundness:', 'Presentation:', 'Contribution:',
        'Strengths:', 'Weaknesses:', 'Questions:', 'Flag For Ethics Review:',
        'Rating:', 'Confidence:', 'Code Of Conduct:',
    )
    # Questions whose answer is reduced to "<n>/5" and "<n>/10" respectively.
    _FIVE_POINT_QUESTIONS = frozenset({2, 3, 4, 10})
    _TEN_POINT_QUESTIONS = frozenset({9})
    # Model names served through the OpenAI chat-completions endpoint.
    _OPENAI_MODELS = frozenset({'gpt-4-turbo-2024-04-09', 'gpt-4o'})
    _FIRST_INT = re.compile(r'\b\d+\b')
    # Tokens held back for the delimiters wrapped around the paper text and
    # for token-boundary drift when strings are concatenated.
    _PROMPT_MARGIN = 16

    def __init__(self, prompt_dir, model, openai_api_key=None, claude_api_key=None,
                 gemini_api_key=None, commandr_api_key=None):
        """Store configuration and resolve API keys.

        Args:
            prompt_dir: Directory holding ``systemrole.txt`` and the
                ``question<N>.txt`` files.
            model: Model identifier passed to ``call_model`` for every question.
            openai_api_key: OpenAI key; falls back to ``OPENAI_API_KEY``.
            claude_api_key: Anthropic key; falls back to ``ANTHROPIC_API_KEY``.
            gemini_api_key: Gemini key; falls back to ``GEMINI_API_KEY``.
            commandr_api_key: Cohere key; falls back to ``COMMANDR_API_KEY``.
        """
        self.prompt_dir = prompt_dir
        self.model = model
        # An explicitly supplied key wins; the environment is only a fallback.
        # (Previously the arguments were accepted but silently ignored.)
        self.openai_api_key = openai_api_key or os.environ.get('OPENAI_API_KEY')
        self.claude_api_key = claude_api_key or os.environ.get('ANTHROPIC_API_KEY')
        self.gemini_api_key = gemini_api_key or os.environ.get('GEMINI_API_KEY')
        self.commandr_api_key = commandr_api_key or os.environ.get('COMMANDR_API_KEY')

    def _encode(self, text):
        """Tokenize ``text``, treating special-token strings as ordinary text."""
        # With tiktoken's defaults, text containing e.g. "<|endoftext|>" raises
        # ValueError; paper content is arbitrary, so disable that check.
        return self.encoding.encode(text, disallowed_special=())

    def count_tokens(self, text):
        """Return the number of tokens ``text`` occupies under ``encoding``."""
        return len(self._encode(text))

    def truncate_content(self, content, max_tokens=None):
        """Return ``content`` cut down to at most ``max_tokens`` tokens.

        Args:
            content: Text to limit.
            max_tokens: Token budget; defaults to ``MAX_TOKENS``. Negative
                values are treated as zero.

        Returns:
            ``content`` itself when it already fits, otherwise the decoded
            prefix of its token sequence.
        """
        limit = self.MAX_TOKENS if max_tokens is None else max(max_tokens, 0)
        tokens = self._encode(content)  # encode once and reuse for the slice
        logging.debug("Token count before truncation: %d", len(tokens))
        if len(tokens) <= limit:
            return content
        logging.debug("Content truncated to at most %d tokens", limit)
        return self.encoding.decode(tokens[:limit])

    def prepare_base_prompt(self, paper):
        """Return the paper text used as the base prompt, or None if missing."""
        logging.debug("Preparing base prompt for paper: %s", paper.arxiv_id)
        content = paper.tex_file
        if content is None:
            # Checked before slicing so the caller's None check is reachable.
            logging.error("Paper %s has no content", paper.arxiv_id)
            return None
        logging.debug("Paper content: %s", content[:500])  # first 500 characters only
        return content

    def call_model(self, prompt, model_type):
        """Send ``prompt`` to the backend named by ``model_type``.

        The system role is read from ``systemrole.txt`` in ``prompt_dir``.

        Returns:
            The model's text answer, or None when the system-role file is
            missing, ``model_type`` is not supported, or the backend call
            raised (the exception is logged, not propagated).
        """
        system_role_file_path = os.path.join(self.prompt_dir, "systemrole.txt")
        if not os.path.exists(system_role_file_path):
            logging.error("System role file not found: %s", system_role_file_path)
            return None
        system_role = read_file(system_role_file_path)

        # Tokenizing a full paper is expensive; only do it when it will be logged.
        if logging.getLogger().isEnabledFor(logging.DEBUG):
            logging.debug("Token count of full prompt: %d", self.count_tokens(prompt))
            logging.debug("Prompt preview: %s", prompt[:500])
        # The prompt can contain an entire paper, so INFO records only the target.
        logging.info("Sending prompt to %s", model_type)

        try:
            if model_type in self._OPENAI_MODELS:
                return self._call_openai(model_type, system_role, prompt)
            if model_type == 'claude-3-opus-20240229':
                return self._call_claude(system_role, prompt)
            if model_type == 'command-r-plus':
                return self._call_cohere(system_role, prompt)
            if model_type == 'gemini-pro':
                return self._call_gemini(prompt)
        except Exception:
            # Best effort by design: any backend failure yields None.
            logging.exception("Exception occurred while calling %s", model_type)
            return None

        logging.error("Unsupported model type: %s", model_type)
        return None

    def _call_openai(self, model_name, system_role, prompt):
        """Query an OpenAI chat model and return the stripped answer text."""
        client = OpenAI(api_key=self.openai_api_key)
        completion = client.chat.completions.create(
            model=model_name,
            messages=[
                {"role": "system", "content": system_role},
                {"role": "user", "content": prompt},
            ],
            temperature=1,
        )
        logging.debug("OpenAI response: %s", completion)
        return completion.choices[0].message.content.strip()

    def _call_claude(self, system_role, prompt):
        """Query Claude 3 Opus and return the text of the first content block."""
        client = anthropic.Anthropic(api_key=self.claude_api_key)
        response = client.messages.create(
            model='claude-3-opus-20240229',
            max_tokens=4096,
            system=system_role,
            temperature=0.5,
            messages=[{"role": "user", "content": prompt}],
        )
        logging.debug("Anthropic response: %s", response)
        return response.content[0].text

    def _call_cohere(self, system_role, prompt):
        """Query Cohere Command R+ with the system role as preamble."""
        co = cohere.Client(self.commandr_api_key)
        response = co.chat(
            model="command-r-plus",
            message=prompt,
            preamble=system_role,
        )
        logging.debug("Cohere response: %s", response)
        return response.text

    def _call_gemini(self, prompt):
        """Query Gemini Pro and return the first part of the first candidate."""
        # NOTE(review): unlike the other backends, the system role is not sent
        # to Gemini -- confirm whether that is intentional.
        genai.configure(api_key=self.gemini_api_key)
        gemini = genai.GenerativeModel('gemini-pro')
        response = gemini.generate_content(prompt)
        logging.debug("Gemini response: %s", response)
        return response.candidates[0].content.parts[0].text

    def is_content_appropriate(self, content):
        """Return False if OpenAI moderation flags ``content``, else True.

        Fails open: if the moderation call itself raises, the error is logged
        and the content is treated as appropriate.
        """
        try:
            response = openai.moderations.create(input=content)
            # Attribute access, as in process_paper: subscripting the response
            # object raised, which made this check always return True.
            return not response.results[0].flagged
        except Exception as e:
            logging.error("Exception occurred while checking content appropriateness: %s", e)
            return True  # In case of an error, default to content being appropriate

    @staticmethod
    def _question_sort_key(filename):
        """Order prompt files by their first embedded number, then by name."""
        digits = re.search(r'\d+', filename)
        return (int(digits.group(0)) if digits else float('inf'), filename)

    def get_prompt_files(self, prompt_dir):
        """List ``question*.txt`` files in ``prompt_dir``, in question order."""
        names = [
            f for f in os.listdir(prompt_dir)
            if f.startswith('question') and f.endswith('.txt')
        ]
        # os.listdir order is arbitrary; sort numerically so that
        # question10.txt follows question9.txt.
        return sorted(names, key=self._question_sort_key)

    def _normalize_score(self, question_number, response):
        """Reduce a scored answer to ``"<n>/<scale>"``, clamped to the scale.

        Answers to unscored questions, and answers containing no integer, are
        returned unchanged.
        """
        if question_number in self._FIVE_POINT_QUESTIONS:
            scale = 5
        elif question_number in self._TEN_POINT_QUESTIONS:
            scale = 10
        else:
            return response
        match = self._FIRST_INT.search(response)
        if match is None:
            return response
        return f"{min(int(match.group(0)), scale)}/{scale}"

    def process_paper(self, paper):
        """Run moderation and all review questions for ``paper``.

        Returns:
            * the string ``"Error: Base prompt could not be prepared."`` when
              the paper has no content;
            * ``["Desk Rejected", <reason>]`` when moderation flags the paper;
            * otherwise a list with one ``"<Header> <answer>"`` string per
              question, where a failed model call is recorded as ``"N/A"``.

        Exceptions raised by the moderation request or by reading a question
        file are not caught here.
        """
        # The module-level moderation call below reads the key from this global.
        openai.api_key = self.openai_api_key
        start_time = time.perf_counter()

        base_prompt = self.prepare_base_prompt(paper)
        if base_prompt is None:
            return "Error: Base prompt could not be prepared."

        moderation_response = openai.moderations.create(input=base_prompt)
        if moderation_response.results[0].flagged:
            return ["Desk Rejected", "The paper contains inappropriate or harmful content."]

        review_output = []
        previous_responses = []
        for i, header in enumerate(self._HEADERS, start=1):
            question_file = os.path.join(self.prompt_dir, f"question{i}.txt")
            question_text = read_file(question_file)

            if i == 1:
                # Trim the paper itself so that the instructions in front of
                # it and the closing delimiter after it both survive.
                budget = self.MAX_TOKENS - self.count_tokens(question_text) - self._PROMPT_MARGIN
                paper_text = self.truncate_content(base_prompt, max_tokens=budget)
                prompt = f"{question_text}\n\n####\n{paper_text}\n####"
            else:
                review_so_far = ' '.join(previous_responses)
                prompt = (
                    f"\nHere is your review so far:\n{review_so_far}\n\n"
                    f"Here are your reviewer instructions. Please answer the following question:\n{question_text}"
                )

            logging.info("Processing prompt for question %d", i)
            # Hard cap as a final safety net for every prompt.
            response = self.call_model(self.truncate_content(prompt), self.model)
            if response is None:
                response = "N/A"
            response = self._normalize_score(i, response)

            review_output.append(f"{header} {response}")
            previous_responses.append(response)

        elapsed_time = time.perf_counter() - start_time
        logging.info("Time taken to process paper: %.2f seconds", elapsed_time)
        return review_output