Merge branch 'master'
- app.py +19 -9
- assets/demo.png +0 -0
- autotab.py +43 -37
- data/ch_patent_input.xlsx +0 -0
- data/ch_patent_output.xlsx +0 -0
- data/en_qa_input.xlsx +0 -0
- data/en_qa_output.xlsx +0 -0
- demo.ipynb +111 -0
app.py
CHANGED
@@ -15,10 +15,10 @@ def auto_tabulator_completion(
     generation_config: dict,
     request_interval: float,
     save_every: int,
-    api_key: str,
+    str_api_keys: str,
     base_url: str,
 ) -> tuple[str, str, str, pd.DataFrame]:
-    output_file_name = "
+    output_file_name = f"output_{time.strftime('%Y%m%d%H%M%S')}.xlsx"
     autotab = AutoTab(
         in_file_path=in_file_path,
         out_file_path=output_file_name,
@@ -28,14 +28,23 @@ def auto_tabulator_completion(
         generation_config=json.loads(generation_config),
         request_interval=request_interval,
         save_every=save_every,
-        api_key=api_key,
+        api_keys=str_api_keys.split(),
         base_url=base_url,
     )
     start = time.time()
     autotab.run()
-    time_taken = time.
+    time_taken = time.time() - start
 
-
+    report = f"Total data points: {autotab.num_data}\n" + \
+        f"Total missing (before): {autotab.num_missing}\n" + \
+        f"Total missing (after): {autotab.failed_count}\n" + \
+        f"Total queries made: {autotab.request_count}\n" + \
+        f"Time taken: {time.strftime('%H:%M:%S', time.gmtime(time.time() - start))}\n" + \
+        f"Prediction per second: {autotab.num_missing / time_taken:.2f}\n" + \
+        f"Query per second: {autotab.request_count / time_taken:.2f}"
+
+    query_example = autotab.query_example if autotab.request_count > 0 else "No queries made."
+    return report, output_file_name, query_example, autotab.data[:15]
 
 
 # Gradio interface
@@ -45,7 +54,7 @@ inputs = [
         value="You are a helpful assistant. Help me finish the task.",
         label="Instruction",
     ),
-    gr.Slider(value=
+    gr.Slider(value=4, minimum=1, maximum=50, step=1, label="Max Examples"),
     gr.Textbox(value="Qwen/Qwen2-7B-Instruct", label="Model Name"),
     gr.Textbox(
         value='{"temperature": 0, "max_tokens": 128}',
@@ -54,13 +63,14 @@ inputs = [
     gr.Slider(value=0.1, minimum=0, maximum=10, label="Request Interval in Seconds"),
     gr.Slider(value=100, minimum=1, maximum=1000, step=1, label="Save Every N Steps"),
     gr.Textbox(
-        value="sk-exhahhjfqyanmwewndukcqtrpegfdbwszkjucvcpajdufiah",
+        value="sk-exhahhjfqyanmwewndukcqtrpegfdbwszkjucvcpajdufiah",
+        label="API Key(s). One per line.",
     ),
     gr.Textbox(value="https://public-beta-api.siliconflow.cn/v1", label="Base URL"),
 ]
 
 outputs = [
-    gr.Textbox(label="
+    gr.Textbox(label="Report"),
     gr.File(label="Output Excel File"),
     gr.Textbox(label="Query Example"),
     gr.Dataframe(label="First 15 rows."),
@@ -71,5 +81,5 @@ gr.Interface(
     inputs=inputs,
     outputs=outputs,
     title="Auto Tabulator Completion",
-    description="Automatically complete missing output values in tabular data based on in-context learning.",
+    description="Automatically complete missing output values in tabular data based on in-context learning. Check https://github.com/Ki-Seki/autotab.",
 ).launch()
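
A note on the two behaviors these hunks rely on (a sketch, not part of the commit; the run statistics below are hypothetical): str.split() with no arguments splits on any whitespace, so the "API Key(s)" textbox accepts one key per line or several per line, and the report divides raw counters by wall-clock seconds.

# Sketch only: illustrates str_api_keys.split() and the report arithmetic.
str_api_keys = "sk-key-one\nsk-key-two sk-key-three"  # hypothetical keys
api_keys = str_api_keys.split()
assert api_keys == ["sk-key-one", "sk-key-two", "sk-key-three"]

# Hypothetical counters, formatted the same way as the report above.
num_missing, request_count, time_taken = 120, 150, 60.0
print(f"Prediction per second: {num_missing / time_taken:.2f}")  # 2.00
print(f"Query per second: {request_count / time_taken:.2f}")     # 2.50
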
assets/demo.png
ADDED
autotab.py
CHANGED
@@ -19,7 +19,7 @@ class AutoTab:
         generation_config: dict,
         request_interval: float,
         save_every: int,
-        api_key: str,
+        api_keys: list[str],
         base_url: str,
     ):
         self.in_file_path = in_file_path
@@ -30,9 +30,17 @@ class AutoTab:
         self.generation_config = generation_config
         self.request_interval = request_interval
         self.save_every = save_every
-        self.api_key = api_key
+        self.api_keys = api_keys
         self.base_url = base_url
 
+        self.request_count = 0
+        self.failed_count = 0
+        self.data, self.input_fields, self.output_fields = self.load_excel()
+        self.in_context = self.derive_incontext()
+        self.num_data = len(self.data)
+        self.num_example = len(self.data.dropna(subset=self.output_fields))
+        self.num_missing = self.num_data - self.num_example
+
     # ─── IO ───────────────────────────────────────────────────────────────
 
     def load_excel(self) -> tuple[pd.DataFrame, list, list]:
@@ -47,8 +55,15 @@ class AutoTab:
     @retry(wait=wait_random_exponential(min=20, max=60), stop=stop_after_attempt(6))
     def openai_request(self, query: str) -> str:
         """Make a request to an OpenAI-format API."""
+
+        # Wait for the request interval
         time.sleep(self.request_interval)
-        client = openai.OpenAI(api_key=self.api_key, base_url=self.base_url)
+
+        # Increment the request count
+        api_key = self.api_keys[self.request_count % len(self.api_keys)]
+        self.request_count += 1
+
+        client = openai.OpenAI(api_key=api_key, base_url=self.base_url)
         response = client.chat.completions.create(
             model=self.model_name,
             messages=[{"role": "user", "content": query}],
@@ -59,61 +74,60 @@ class AutoTab:
 
     # ─── In-Context Learning ──────────────────────────────────────────────
 
-    def derive_incontext(
-        self, data: pd.DataFrame, input_columns: list[str], output_columns: list[str]
-    ) -> str:
+    def derive_incontext(self) -> str:
         """Derive the in-context prompt with angle brackets."""
-
+        examples = self.data.dropna(subset=self.output_fields)[: self.max_examples]
         in_context = ""
-        for i in range(
+        for i in range(len(examples)):
             in_context += "".join(
-                f"<{col.replace('[Input] ', '')}>{data[col].iloc[i]}</{col.replace('[Input] ', '')}>\n"
-                for col in input_columns
+                f"<{col.replace('[Input] ', '')}>{self.data[col].iloc[i]}</{col.replace('[Input] ', '')}>\n"
+                for col in self.input_fields
             )
             in_context += "".join(
-                f"<{col.replace('[Output] ', '')}>{data[col].iloc[i]}</{col.replace('[Output] ', '')}>\n"
-                for col in output_columns
+                f"<{col.replace('[Output] ', '')}>{self.data[col].iloc[i]}</{col.replace('[Output] ', '')}>\n"
+                for col in self.output_fields
             )
             in_context += "\n"
         return in_context
 
-    def predict_output(
-        self, in_context: str, input_data: pd.DataFrame, input_fields: str
-    ):
+    def predict_output(self, input_data: pd.DataFrame):
         """Predict the output values for the given input data using the API."""
         query = (
             self.instruction
             + "\n\n"
-            + in_context
+            + self.in_context
             + "".join(
                 f"<{col.replace('[Input] ', '')}>{input_data[col]}</{col.replace('[Input] ', '')}>\n"
-                for col in input_fields
+                for col in self.input_fields
             )
         )
         self.query_example = query
         output = self.openai_request(query)
         return output
 
-    def extract_fields(
-        self, response: str, output_columns: list[str]
-    ) -> dict[str, str]:
+    def extract_fields(self, response: str) -> dict[str, str]:
         """Extract fields from the response text based on output columns."""
         extracted = {}
-        for col in output_columns:
+        for col in self.output_fields:
             field = col.replace("[Output] ", "")
             match = re.search(f"<{field}>(.*?)</{field}>", response)
             extracted[col] = match.group(1) if match else ""
+        if any(extracted[col] == "" for col in self.output_fields):
+            self.failed_count += 1
         return extracted
 
     # ─── Engine ───────────────────────────────────────────────────────────
 
-    def _predict_and_extract(self, row: int):
+    def _predict_and_extract(self, row: int) -> dict[str, str]:
         """Helper function to predict and extract fields for a single row."""
-        prediction = self.predict_output(
-            self.in_context, self.data.iloc[row], self.input_fields
-        )
-        extracted_fields = self.extract_fields(prediction, self.output_fields)
-        return extracted_fields
+
+        # If any output field is empty, predict the output
+        if any(pd.isnull(self.data.at[row, col]) for col in self.output_fields):
+            prediction = self.predict_output(self.data.iloc[row])
+            extracted_fields = self.extract_fields(prediction)
+            return extracted_fields
+        else:
+            return {col: self.data.at[row, col] for col in self.output_fields}
 
     def batch_prediction(self, start_index: int, end_index: int):
         """Process a batch of predictions asynchronously."""
@@ -126,16 +140,8 @@ class AutoTab:
                 self.data.at[i, field_name] = extracted_fields.get(field_name, "")
 
     def run(self):
-        self.data, self.input_fields, self.output_fields = self.load_excel()
-        self.in_context = self.derive_incontext(
-            self.data, self.input_fields, self.output_fields
-        )
-
-        self.num_data = len(self.data)
-        self.num_examples = len(self.data.dropna(subset=self.output_fields))
-
-        tqdm_bar = tqdm(total=self.num_data - self.num_examples, leave=False)
-        for start in range(self.num_examples, self.num_data, self.save_every):
+        tqdm_bar = tqdm(total=self.num_data, leave=False)
+        for start in range(0, self.num_data, self.save_every):
             tqdm_bar.update(min(self.save_every, self.num_data - start))
             end = min(start + self.save_every, self.num_data)
             try:
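
The heart of this change is the round-robin key rotation in openai_request: request N uses api_keys[N % len(api_keys)], so traffic spreads evenly across the key pool, and a fresh openai.OpenAI client is built per request with the rotated key. A minimal standalone sketch of that rotation (dummy keys and a hypothetical next_key helper, no network call):

api_keys = ["sk-a", "sk-b", "sk-c"]  # dummy keys for illustration
request_count = 0

def next_key() -> str:
    """Return the next key in round-robin order, mirroring openai_request."""
    global request_count
    key = api_keys[request_count % len(api_keys)]
    request_count += 1
    return key

print([next_key() for _ in range(5)])  # ['sk-a', 'sk-b', 'sk-c', 'sk-a', 'sk-b']
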
data/ch_patent_input.xlsx
ADDED
Binary file (33 kB)

data/ch_patent_output.xlsx
ADDED
Binary file (41.5 kB)

data/en_qa_input.xlsx
ADDED
Binary file (6.77 kB)

data/en_qa_output.xlsx
ADDED
Binary file (5.37 kB)
demo.ipynb
ADDED
@@ -0,0 +1,111 @@
+{
+ "cells": [
+  {
+   "cell_type": "code",
+   "execution_count": 1,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      " "
+     ]
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Results saved to data/en_qa_output.xlsx\n"
+     ]
+    },
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "\r"
+     ]
+    }
+   ],
+   "source": [
+    "from autotab import AutoTab\n",
+    "\n",
+    "\n",
+    "autotab = AutoTab(\n",
+    "    in_file_path=\"data/en_qa_input.xlsx\",\n",
+    "    out_file_path=\"data/en_qa_output.xlsx\",\n",
+    "    instruction=\"You should help me classify the questions and answer them.\",\n",
+    "    max_examples=5,\n",
+    "    model_name=\"Qwen/Qwen2-7B-Instruct\",\n",
+    "    generation_config={\"temperature\": 0, \"max_tokens\": 128},\n",
+    "    request_interval=0.01,\n",
+    "    api_keys=[\"sk-exhahhjfqyanmwewndukcqtrpegfdbwszkjucvcpajdufiah\"],\n",
+    "    base_url=\"https://public-beta-api.siliconflow.cn/v1\",\n",
+    "    save_every=10,\n",
+    ")\n",
+    "autotab.run()"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 2,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "You should help me classify the questions and answer them.\n",
+      "\n",
+      "<Question>What is the capital of France?</Question>\n",
+      "<Category>Geography</Category>\n",
+      "<Answer>Paris</Answer>\n",
+      "\n",
+      "<Question>Who wrote '1984'?</Question>\n",
+      "<Category>Literature</Category>\n",
+      "<Answer>George Orwell</Answer>\n",
+      "\n",
+      "<Question>What is the largest planet in the solar system?</Question>\n",
+      "<Category>Astronomy</Category>\n",
+      "<Answer>Jupiter</Answer>\n",
+      "\n",
+      "<Question>Who painted the Mona Lisa?</Question>\n",
+      "<Category>Art</Category>\n",
+      "<Answer>Leonardo da Vinci</Answer>\n",
+      "\n",
+      "<Question>What is the currency of Japan?</Question>\n",
+      "<Category>Economics</Category>\n",
+      "<Answer>Yen</Answer>\n",
+      "\n",
+      "<Question>Who is the first president of the United States?</Question>\n",
+      "\n"
+     ]
+    }
+   ],
+   "source": [
+    "print(autotab.query_example)"
+   ]
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "common",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.12.2"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
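
The query example printed in the notebook shows the tag scheme AutoTab relies on: each column becomes an angle-bracket field, and extract_fields recovers values with a non-greedy regex. A small sketch of that round trip (the response string is made up; the field names come from the notebook's query example):

import re

# Hypothetical model response using the notebook's <Category>/<Answer> fields.
response = "<Category>History</Category>\n<Answer>George Washington</Answer>"
extracted = {}
for field in ["Category", "Answer"]:
    # Non-greedy match, same pattern shape as extract_fields in autotab.py.
    match = re.search(f"<{field}>(.*?)</{field}>", response)
    extracted[field] = match.group(1) if match else ""
print(extracted)  # {'Category': 'History', 'Answer': 'George Washington'}
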