|
import re |
|
import time |
|
from concurrent.futures import ThreadPoolExecutor |
|
|
|
import openai |
|
import pandas as pd |
|
from tenacity import retry, stop_after_attempt, wait_random_exponential |
|
from tqdm import tqdm |
|
|
|
|
|
class AutoTab: |
|
def __init__( |
|
self, |
|
in_file_path: str, |
|
out_file_path: str, |
|
instruction: str, |
|
max_examples: int, |
|
model_name: str, |
|
generation_config: dict, |
|
request_interval: float, |
|
save_every: int, |
|
api_keys: list[str], |
|
base_url: str, |
|
): |
|
self.in_file_path = in_file_path |
|
self.out_file_path = out_file_path |
|
self.instruction = instruction |
|
self.max_examples = max_examples |
|
self.model_name = model_name |
|
self.generation_config = generation_config |
|
self.request_interval = request_interval |
|
self.save_every = save_every |
|
self.api_keys = api_keys |
|
self.base_url = base_url |
|
|
|
self.request_count = 0 |
|
self.failed_count = 0 |
|
self.data, self.input_fields, self.output_fields = self.load_excel() |
|
self.in_context = self.derive_incontext() |
|
self.num_data = len(self.data) |
|
self.num_example = len(self.data.dropna(subset=self.output_fields)) |
|
self.num_missing = self.num_data - self.num_example |
|
|
|
|
|
|
|
def load_excel(self) -> tuple[pd.DataFrame, list, list]: |
|
"""Load the Excel file and identify input and output fields.""" |
|
df = pd.read_excel(self.in_file_path) |
|
input_fields = [col for col in df.columns if col.startswith("[Input] ")] |
|
output_fields = [col for col in df.columns if col.startswith("[Output] ")] |
|
return df, input_fields, output_fields |
|
|
|
|
|
|
|
@retry(wait=wait_random_exponential(min=20, max=60), stop=stop_after_attempt(6)) |
|
def openai_request(self, query: str) -> str: |
|
"""Make a request to an OpenAI-format API.""" |
|
|
|
|
|
time.sleep(self.request_interval) |
|
|
|
|
|
api_key = self.api_keys[self.request_count % len(self.api_keys)] |
|
self.request_count += 1 |
|
|
|
client = openai.OpenAI(api_key=api_key, base_url=self.base_url) |
|
response = client.chat.completions.create( |
|
model=self.model_name, |
|
messages=[{"role": "user", "content": query}], |
|
**self.generation_config, |
|
) |
|
str_response = response.choices[0].message.content.strip() |
|
return str_response |
|
|
|
|
|
|
|
def derive_incontext(self) -> str: |
|
"""Derive the in-context prompt with angle brackets.""" |
|
examples = self.data.dropna(subset=self.output_fields)[: self.max_examples] |
|
in_context = "" |
|
for i in range(len(examples)): |
|
in_context += "".join( |
|
f"<{col.replace('[Input] ', '')}>{self.data[col].iloc[i]}</{col.replace('[Input] ', '')}>\n" |
|
for col in self.input_fields |
|
) |
|
in_context += "".join( |
|
f"<{col.replace('[Output] ', '')}>{self.data[col].iloc[i]}</{col.replace('[Output] ', '')}>\n" |
|
for col in self.output_fields |
|
) |
|
in_context += "\n" |
|
return in_context |
|
|
|
def predict_output(self, input_data: pd.DataFrame): |
|
"""Predict the output values for the given input data using the API.""" |
|
query = ( |
|
self.instruction |
|
+ "\n\n" |
|
+ self.in_context |
|
+ "".join( |
|
f"<{col.replace('[Input] ', '')}>{input_data[col]}</{col.replace('[Input] ', '')}>\n" |
|
for col in self.input_fields |
|
) |
|
) |
|
self.query_example = query |
|
output = self.openai_request(query) |
|
return output |
|
|
|
def extract_fields(self, response: str) -> dict[str, str]: |
|
"""Extract fields from the response text based on output columns.""" |
|
extracted = {} |
|
for col in self.output_fields: |
|
field = col.replace("[Output] ", "") |
|
match = re.search(f"<{field}>(.*?)</{field}>", response) |
|
extracted[col] = match.group(1) if match else "" |
|
if any(extracted[col] == "" for col in self.output_fields): |
|
self.failed_count += 1 |
|
return extracted |
|
|
|
|
|
|
|
def _predict_and_extract(self, row: int) -> dict[str, str]: |
|
"""Helper function to predict and extract fields for a single row.""" |
|
|
|
|
|
if any(pd.isnull(self.data.at[row, col]) for col in self.output_fields): |
|
prediction = self.predict_output(self.data.iloc[row]) |
|
extracted_fields = self.extract_fields(prediction) |
|
return extracted_fields |
|
else: |
|
return {col: self.data.at[row, col] for col in self.output_fields} |
|
|
|
def batch_prediction(self, start_index: int, end_index: int): |
|
"""Process a batch of predictions asynchronously.""" |
|
with ThreadPoolExecutor() as executor: |
|
results = list( |
|
executor.map(self._predict_and_extract, range(start_index, end_index)) |
|
) |
|
for i, extracted_fields in zip(range(start_index, end_index), results): |
|
for field_name in self.output_fields: |
|
self.data.at[i, field_name] = extracted_fields.get(field_name, "") |
|
|
|
def run(self): |
|
tqdm_bar = tqdm(total=self.num_data, leave=False) |
|
for start in range(0, self.num_data, self.save_every): |
|
end = min(start + self.save_every, self.num_data) |
|
try: |
|
self.batch_prediction(start, end) |
|
except Exception as e: |
|
print(e) |
|
self.data.to_excel(self.out_file_path, index=False) |
|
tqdm_bar.update(min(self.save_every, self.num_data - start)) |
|
self.data.to_excel(self.out_file_path, index=False) |
|
print(f"Results saved to {self.out_file_path}") |
|
|