# AutoTab — autotab.py
# (source-page artifact preserved: Ki-Seki, "fix: optimize tqdm bar update in AutoTab class", commit 414e9a1)
import re
import threading
import time
from concurrent.futures import ThreadPoolExecutor

import openai
import pandas as pd
from tenacity import retry, stop_after_attempt, wait_random_exponential
from tqdm import tqdm
class AutoTab:
    """Fill missing output columns of an Excel sheet via in-context LLM prompting.

    Columns whose names start with ``"[Input] "`` are treated as inputs and
    columns starting with ``"[Output] "`` as outputs. Rows where every output
    column is already populated are used as few-shot examples; for every other
    row the outputs are predicted through an OpenAI-format chat API and written
    back, with periodic saves to ``out_file_path``.
    """

    def __init__(
        self,
        in_file_path: str,
        out_file_path: str,
        instruction: str,
        max_examples: int,
        model_name: str,
        generation_config: dict,
        request_interval: float,
        save_every: int,
        api_keys: list[str],
        base_url: str,
    ):
        self.in_file_path = in_file_path
        self.out_file_path = out_file_path
        self.instruction = instruction
        self.max_examples = max_examples
        self.model_name = model_name
        self.generation_config = generation_config
        self.request_interval = request_interval
        self.save_every = save_every
        self.api_keys = api_keys
        self.base_url = base_url
        self.request_count = 0  # total API requests issued; also drives key rotation
        self.failed_count = 0  # responses missing at least one output field
        # batch_prediction mutates the two counters above from multiple
        # ThreadPoolExecutor workers, and `x += 1` is not atomic in Python,
        # so guard them with a lock.
        self._counter_lock = threading.Lock()
        self.data, self.input_fields, self.output_fields = self.load_excel()
        self.in_context = self.derive_incontext()
        self.num_data = len(self.data)
        self.num_example = len(self.data.dropna(subset=self.output_fields))
        self.num_missing = self.num_data - self.num_example

    # ─── IO ───────────────────────────────────────────────────────────────

    def load_excel(self) -> tuple[pd.DataFrame, list, list]:
        """Load the Excel file and identify input and output fields.

        Returns:
            (dataframe, input column names, output column names), where the
            input/output split is determined purely by the column-name prefix.
        """
        df = pd.read_excel(self.in_file_path)
        input_fields = [col for col in df.columns if col.startswith("[Input] ")]
        output_fields = [col for col in df.columns if col.startswith("[Output] ")]
        return df, input_fields, output_fields

    # ─── LLM ──────────────────────────────────────────────────────────────

    @retry(wait=wait_random_exponential(min=20, max=60), stop=stop_after_attempt(6))
    def openai_request(self, query: str) -> str:
        """Make a request to an OpenAI-format API and return the reply text.

        Retries up to 6 times with random exponential backoff on any failure.
        """
        # Throttle: every request waits the configured interval first.
        time.sleep(self.request_interval)
        # Rotate round-robin across the configured API keys. The lock keeps
        # the read-then-increment of request_count consistent across threads.
        with self._counter_lock:
            api_key = self.api_keys[self.request_count % len(self.api_keys)]
            self.request_count += 1
        client = openai.OpenAI(api_key=api_key, base_url=self.base_url)
        response = client.chat.completions.create(
            model=self.model_name,
            messages=[{"role": "user", "content": query}],
            **self.generation_config,
        )
        return response.choices[0].message.content.strip()

    # ─── In-Context Learning ──────────────────────────────────────────────

    def derive_incontext(self) -> str:
        """Derive the few-shot prompt from rows whose outputs are all present.

        Each example row is rendered as angle-bracket tagged fields, inputs
        first, then outputs, separated by a blank line between examples.
        """
        examples = self.data.dropna(subset=self.output_fields)[: self.max_examples]
        in_context = ""
        for i in range(len(examples)):
            # BUGFIX: read from `examples`, not `self.data`. After dropna the
            # positional index i no longer lines up with the original frame,
            # so indexing self.data here pulled values from the wrong rows
            # (including rows with missing outputs, i.e. literal "nan").
            in_context += "".join(
                f"<{col.replace('[Input] ', '')}>{examples[col].iloc[i]}</{col.replace('[Input] ', '')}>\n"
                for col in self.input_fields
            )
            in_context += "".join(
                f"<{col.replace('[Output] ', '')}>{examples[col].iloc[i]}</{col.replace('[Output] ', '')}>\n"
                for col in self.output_fields
            )
            in_context += "\n"
        return in_context

    def predict_output(self, input_data: pd.Series):
        """Predict the output values for one row using the API.

        Builds instruction + few-shot examples + the row's tagged inputs,
        then sends the prompt through :meth:`openai_request`.
        """
        query = (
            self.instruction
            + "\n\n"
            + self.in_context
            + "".join(
                f"<{col.replace('[Input] ', '')}>{input_data[col]}</{col.replace('[Input] ', '')}>\n"
                for col in self.input_fields
            )
        )
        self.query_example = query  # kept for debugging/inspection of the last prompt
        output = self.openai_request(query)
        return output

    def extract_fields(self, response: str) -> dict[str, str]:
        """Extract fields from the response text based on output columns.

        A field that cannot be found yields ``""``; any such miss counts the
        whole response as a failure.
        """
        extracted = {}
        for col in self.output_fields:
            field = col.replace("[Output] ", "")
            # re.DOTALL so multi-line field contents are still captured.
            match = re.search(f"<{field}>(.*?)</{field}>", response, re.DOTALL)
            extracted[col] = match.group(1) if match else ""
        if any(extracted[col] == "" for col in self.output_fields):
            with self._counter_lock:  # called concurrently from worker threads
                self.failed_count += 1
        return extracted

    # ─── Engine ───────────────────────────────────────────────────────────

    def _predict_and_extract(self, row: int) -> dict[str, str]:
        """Predict and extract output fields for a single row (by position).

        Rows whose outputs are already complete are returned untouched.
        """
        if any(pd.isnull(self.data.at[row, col]) for col in self.output_fields):
            prediction = self.predict_output(self.data.iloc[row])
            return self.extract_fields(prediction)
        return {col: self.data.at[row, col] for col in self.output_fields}

    def batch_prediction(self, start_index: int, end_index: int):
        """Predict rows ``[start_index, end_index)`` concurrently, then write back."""
        with ThreadPoolExecutor() as executor:
            results = list(
                executor.map(self._predict_and_extract, range(start_index, end_index))
            )
        # Write back on the main thread only, so the DataFrame itself is
        # never mutated concurrently.
        for i, extracted_fields in zip(range(start_index, end_index), results):
            for field_name in self.output_fields:
                self.data.at[i, field_name] = extracted_fields.get(field_name, "")

    def run(self):
        """Process all rows in chunks of ``save_every``, saving after each chunk."""
        tqdm_bar = tqdm(total=self.num_data, leave=False)
        try:
            for start in range(0, self.num_data, self.save_every):
                end = min(start + self.save_every, self.num_data)
                try:
                    self.batch_prediction(start, end)
                except Exception as e:
                    # Best-effort: a failed batch is skipped; progress made so
                    # far is still written out below.
                    print(e)
                self.data.to_excel(self.out_file_path, index=False)
                tqdm_bar.update(end - start)
        finally:
            tqdm_bar.close()  # original leaked the bar; always release it
        self.data.to_excel(self.out_file_path, index=False)
        print(f"Results saved to {self.out_file_path}")