Commit 4bc02fa
Ki-Seki committed
2 parents: bc989cb c2c0cff

Merge branch 'master'

app.py CHANGED
@@ -15,10 +15,10 @@ def auto_tabulator_completion(
     generation_config: dict,
     request_interval: float,
     save_every: int,
-    api_key: str,
+    str_api_keys: str,
     base_url: str,
 ) -> tuple[str, str, str, pd.DataFrame]:
-    output_file_name = "ouput.xlsx"
+    output_file_name = f"output_{time.strftime('%Y%m%d%H%M%S')}.xlsx"
     autotab = AutoTab(
         in_file_path=in_file_path,
         out_file_path=output_file_name,
@@ -28,14 +28,23 @@ def auto_tabulator_completion(
         generation_config=json.loads(generation_config),
         request_interval=request_interval,
         save_every=save_every,
-        api_key=api_key,
+        api_keys=str_api_keys.split(),
         base_url=base_url,
     )
     start = time.time()
     autotab.run()
-    time_taken = time.strftime("%H:%M:%S", time.gmtime(time.time() - start))
+    time_taken = time.time() - start
 
-    return time_taken, output_file_name, autotab.query_example, autotab.data[:15]
+    report = f"Total data points: {autotab.num_data}\n" + \
+        f"Total missing (before): {autotab.num_missing}\n" + \
+        f"Total missing (after): {autotab.failed_count}\n" + \
+        f"Total queries made: {autotab.request_count}\n" + \
+        f"Time taken: {time.strftime('%H:%M:%S', time.gmtime(time.time() - start))}\n" + \
+        f"Prediction per second: {autotab.num_missing / time_taken:.2f}\n" + \
+        f"Query per second: {autotab.request_count / time_taken:.2f}"
+
+    query_example = autotab.query_example if autotab.request_count > 0 else "No queries made."
+    return report, output_file_name, query_example, autotab.data[:15]
 
 
 # Gradio interface
@@ -45,7 +54,7 @@ inputs = [
         value="You are a helpful assistant. Help me finish the task.",
         label="Instruction",
     ),
-    gr.Slider(value=5, minimum=1, maximum=50, step=1, label="Max Examples"),
+    gr.Slider(value=4, minimum=1, maximum=50, step=1, label="Max Examples"),
     gr.Textbox(value="Qwen/Qwen2-7B-Instruct", label="Model Name"),
     gr.Textbox(
         value='{"temperature": 0, "max_tokens": 128}',
@@ -54,13 +63,14 @@ inputs = [
     gr.Slider(value=0.1, minimum=0, maximum=10, label="Request Interval in Seconds"),
     gr.Slider(value=100, minimum=1, maximum=1000, step=1, label="Save Every N Steps"),
     gr.Textbox(
-        value="sk-exhahhjfqyanmwewndukcqtrpegfdbwszkjucvcpajdufiah", label="API Key"
+        value="sk-exhahhjfqyanmwewndukcqtrpegfdbwszkjucvcpajdufiah",
+        label="API Key(s). One per line.",
     ),
     gr.Textbox(value="https://public-beta-api.siliconflow.cn/v1", label="Base URL"),
 ]
 
 outputs = [
-    gr.Textbox(label="Time Taken"),
+    gr.Textbox(label="Report"),
     gr.File(label="Output Excel File"),
     gr.Textbox(label="Query Example"),
     gr.Dataframe(label="First 15 rows."),
@@ -71,5 +81,5 @@ gr.Interface(
     inputs=inputs,
     outputs=outputs,
     title="Auto Tabulator Completion",
-    description="Automatically complete missing output values in tabular data based on in-context learning.",
+    description="Automatically complete missing output values in tabular data based on in-context learning. Check https://github.com/Ki-Seki/autotab.",
 ).launch()
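
A note on the new multi-key input: the textbox label says "One per line," and the handler calls str_api_keys.split(), which with no separator argument splits on any run of whitespace, so newline-separated keys (and incidentally space-separated ones) both parse. A minimal sketch of that parsing, with hypothetical keys:

    # str.split() with no argument splits on any whitespace run,
    # so keys pasted one per line come back as a clean list.
    str_api_keys = "sk-key-one\nsk-key-two"        # value pasted into the textbox
    api_keys = str_api_keys.split()
    assert api_keys == ["sk-key-one", "sk-key-two"]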
assets/demo.png ADDED
autotab.py CHANGED
@@ -19,7 +19,7 @@ class AutoTab:
         generation_config: dict,
         request_interval: float,
         save_every: int,
-        api_key: str,
+        api_keys: list[str],
         base_url: str,
     ):
         self.in_file_path = in_file_path
@@ -30,9 +30,17 @@ class AutoTab:
         self.generation_config = generation_config
         self.request_interval = request_interval
         self.save_every = save_every
-        self.api_key = api_key
+        self.api_keys = api_keys
         self.base_url = base_url
 
+        self.request_count = 0
+        self.failed_count = 0
+        self.data, self.input_fields, self.output_fields = self.load_excel()
+        self.in_context = self.derive_incontext()
+        self.num_data = len(self.data)
+        self.num_example = len(self.data.dropna(subset=self.output_fields))
+        self.num_missing = self.num_data - self.num_example
+
     # ─── IO ───────────────────────────────────────────────────────────────
 
     def load_excel(self) -> tuple[pd.DataFrame, list, list]:
@@ -47,8 +55,15 @@ class AutoTab:
     @retry(wait=wait_random_exponential(min=20, max=60), stop=stop_after_attempt(6))
     def openai_request(self, query: str) -> str:
         """Make a request to an OpenAI-format API."""
+
+        # Wait for the request interval
         time.sleep(self.request_interval)
-        client = openai.OpenAI(api_key=self.api_key, base_url=self.base_url)
+
+        # Increment the request count
+        api_key = self.api_keys[self.request_count % len(self.api_keys)]
+        self.request_count += 1
+
+        client = openai.OpenAI(api_key=api_key, base_url=self.base_url)
         response = client.chat.completions.create(
             model=self.model_name,
             messages=[{"role": "user", "content": query}],
@@ -59,61 +74,60 @@ class AutoTab:
 
     # ─── In-Context Learning ──────────────────────────────────────────────
 
-    def derive_incontext(
-        self, data: pd.DataFrame, input_columns: list[str], output_columns: list[str]
-    ) -> str:
+    def derive_incontext(self) -> str:
         """Derive the in-context prompt with angle brackets."""
-        n = min(self.max_examples, len(data.dropna(subset=output_columns)))
+        examples = self.data.dropna(subset=self.output_fields)[: self.max_examples]
         in_context = ""
-        for i in range(n):
+        for i in range(len(examples)):
             in_context += "".join(
-                f"<{col.replace('[Input] ', '')}>{data[col].iloc[i]}</{col.replace('[Input] ', '')}>\n"
-                for col in input_columns
+                f"<{col.replace('[Input] ', '')}>{self.data[col].iloc[i]}</{col.replace('[Input] ', '')}>\n"
+                for col in self.input_fields
             )
             in_context += "".join(
-                f"<{col.replace('[Output] ', '')}>{data[col].iloc[i]}</{col.replace('[Output] ', '')}>\n"
-                for col in output_columns
+                f"<{col.replace('[Output] ', '')}>{self.data[col].iloc[i]}</{col.replace('[Output] ', '')}>\n"
+                for col in self.output_fields
             )
             in_context += "\n"
         return in_context
 
-    def predict_output(
-        self, in_context: str, input_data: pd.DataFrame, input_fields: str
-    ):
+    def predict_output(self, input_data: pd.DataFrame):
         """Predict the output values for the given input data using the API."""
         query = (
             self.instruction
             + "\n\n"
-            + in_context
+            + self.in_context
             + "".join(
                 f"<{col.replace('[Input] ', '')}>{input_data[col]}</{col.replace('[Input] ', '')}>\n"
-                for col in input_fields
+                for col in self.input_fields
             )
         )
         self.query_example = query
         output = self.openai_request(query)
         return output
 
-    def extract_fields(
-        self, response: str, output_columns: list[str]
-    ) -> dict[str, str]:
+    def extract_fields(self, response: str) -> dict[str, str]:
         """Extract fields from the response text based on output columns."""
         extracted = {}
-        for col in output_columns:
+        for col in self.output_fields:
             field = col.replace("[Output] ", "")
             match = re.search(f"<{field}>(.*?)</{field}>", response)
             extracted[col] = match.group(1) if match else ""
+        if any(extracted[col] == "" for col in self.output_fields):
+            self.failed_count += 1
         return extracted
 
     # ─── Engine ───────────────────────────────────────────────────────────
 
-    def _predict_and_extract(self, i: int) -> dict[str, str]:
+    def _predict_and_extract(self, row: int) -> dict[str, str]:
         """Helper function to predict and extract fields for a single row."""
-        prediction = self.predict_output(
-            self.in_context, self.data.iloc[i], self.input_fields
-        )
-        extracted_fields = self.extract_fields(prediction, self.output_fields)
-        return extracted_fields
+
+        # If any output field is empty, predict the output
+        if any(pd.isnull(self.data.at[row, col]) for col in self.output_fields):
+            prediction = self.predict_output(self.data.iloc[row])
+            extracted_fields = self.extract_fields(prediction)
+            return extracted_fields
+        else:
+            return {col: self.data.at[row, col] for col in self.output_fields}
 
     def batch_prediction(self, start_index: int, end_index: int):
         """Process a batch of predictions asynchronously."""
@@ -126,16 +140,8 @@ class AutoTab:
             self.data.at[i, field_name] = extracted_fields.get(field_name, "")
 
     def run(self):
-        self.data, self.input_fields, self.output_fields = self.load_excel()
-        self.in_context = self.derive_incontext(
-            self.data, self.input_fields, self.output_fields
-        )
-
-        self.num_data = len(self.data)
-        self.num_examples = len(self.data.dropna(subset=self.output_fields))
-
-        tqdm_bar = tqdm(total=self.num_data - self.num_examples, leave=False)
-        for start in range(self.num_examples, self.num_data, self.save_every):
+        tqdm_bar = tqdm(total=self.num_data, leave=False)
+        for start in range(0, self.num_data, self.save_every):
             tqdm_bar.update(min(self.save_every, self.num_data - start))
             end = min(start + self.save_every, self.num_data)
             try:
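
The heart of the multi-key change is the round-robin selection in openai_request: indexing the key list by the request counter modulo its length cycles through the keys evenly. A self-contained sketch of the same technique, with hypothetical keys and no network calls:

    # Round-robin over a key list: count % len(keys) cycles 0, 1, 2, 0, 1, ...
    api_keys = ["key-a", "key-b", "key-c"]  # hypothetical keys
    request_count = 0

    def next_key() -> str:
        global request_count
        key = api_keys[request_count % len(api_keys)]
        request_count += 1
        return key

    print([next_key() for _ in range(5)])
    # ['key-a', 'key-b', 'key-c', 'key-a', 'key-b']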
data/ch_patent_input.xlsx ADDED
Binary file (33 kB)
 
data/ch_patent_output.xlsx ADDED
Binary file (41.5 kB)
 
data/en_qa_input.xlsx ADDED
Binary file (6.77 kB)
 
data/en_qa_output.xlsx ADDED
Binary file (5.37 kB)
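
The added spreadsheets follow the column-name convention the code relies on (inferred from the "[Input] "/"[Output] " prefixes in autotab.py): "[Input] " columns feed the prompt, "[Output] " columns with missing cells are what AutoTab fills in, and rows where the outputs are already present become in-context examples. A sketch of building such a file; the column names and path are hypothetical:

    import pandas as pd

    # "[Input] " columns are prompt fields; the None cell in an
    # "[Output] " column is what AutoTab predicts.
    df = pd.DataFrame({
        "[Input] Question": ["What is the capital of France?", "Who wrote '1984'?"],
        "[Output] Answer": ["Paris", None],
    })
    df.to_excel("data/example_input.xlsx", index=False)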
 
demo.ipynb ADDED
@@ -0,0 +1,111 @@
+{
+ "cells": [
+  {
+   "cell_type": "code",
+   "execution_count": 1,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      " "
+     ]
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Results saved to data/en_qa_output.xlsx\n"
+     ]
+    },
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "\r"
+     ]
+    }
+   ],
+   "source": [
+    "from autotab import AutoTab\n",
+    "\n",
+    "\n",
+    "autotab = AutoTab(\n",
+    "    in_file_path=\"data/en_qa_input.xlsx\",\n",
+    "    out_file_path=\"data/en_qa_output.xlsx\",\n",
+    "    instruction=\"You should help me classify the questions and answer them.\",\n",
+    "    max_examples=5,\n",
+    "    model_name=\"Qwen/Qwen2-7B-Instruct\",\n",
+    "    generation_config={\"temperature\": 0, \"max_tokens\": 128},\n",
+    "    request_interval=0.01,\n",
+    "    api_keys=[\"sk-exhahhjfqyanmwewndukcqtrpegfdbwszkjucvcpajdufiah\"],\n",
+    "    base_url=\"https://public-beta-api.siliconflow.cn/v1\",\n",
+    "    save_every=10,\n",
+    ")\n",
+    "autotab.run()"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 2,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "You should help me classify the questions and answer them.\n",
+      "\n",
+      "<Question>What is the capital of France?</Question>\n",
+      "<Category>Geography</Category>\n",
+      "<Answer>Paris</Answer>\n",
+      "\n",
+      "<Question>Who wrote '1984'?</Question>\n",
+      "<Category>Literature</Category>\n",
+      "<Answer>George Orwell</Answer>\n",
+      "\n",
+      "<Question>What is the largest planet in the solar system?</Question>\n",
+      "<Category>Astronomy</Category>\n",
+      "<Answer>Jupiter</Answer>\n",
+      "\n",
+      "<Question>Who painted the Mona Lisa?</Question>\n",
+      "<Category>Art</Category>\n",
+      "<Answer>Leonardo da Vinci</Answer>\n",
+      "\n",
+      "<Question>What is the currency of Japan?</Question>\n",
+      "<Category>Economics</Category>\n",
+      "<Answer>Yen</Answer>\n",
+      "\n",
+      "<Question>Who is the first president of the United States?</Question>\n",
+      "\n"
+     ]
+    }
+   ],
+   "source": [
+    "print(autotab.query_example)"
+   ]
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "common",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.12.2"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
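
The printed query above ends with an open <Question> tag; the model is expected to answer in the same angle-bracket format, and extract_fields recovers each field with a non-greedy regex. A minimal sketch of that extraction step, with a hypothetical response:

    import re

    response = "<Category>History</Category>\n<Answer>George Washington</Answer>"
    for field in ["Category", "Answer"]:
        # Non-greedy (.*?) stops at the first closing tag for the field.
        match = re.search(f"<{field}>(.*?)</{field}>", response)
        print(field, "->", match.group(1) if match else "")
    # Category -> History
    # Answer -> George Washington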