Spaces:
Sleeping
Sleeping
import ast | |
import gradio as gr | |
import os | |
import re | |
import json | |
import logging | |
import torch | |
from datetime import datetime | |
from threading import Thread | |
from typing import Optional | |
from transformers import TextIteratorStreamer | |
from functools import partial | |
from huggingface_hub import CommitScheduler | |
from uuid import uuid4 | |
from pathlib import Path | |
from code_interpreter.JupyterClient import JupyterNotebook | |
# Maximum number of prompt tokens kept when tokenizing; can be overridden
# through the MAX_INPUT_TOKEN_LENGTH environment variable (default 4096).
MAX_INPUT_TOKEN_LENGTH = int(os.environ.get("MAX_INPUT_TOKEN_LENGTH", "4096"))

import warnings

# Suppress UserWarning messages raised from the transformers package.
warnings.filterwarnings(action="ignore", category=UserWarning, module="transformers")

# NOTE(review): presumably lowers TensorFlow's native log verbosity -- confirm
# it is set early enough (before TensorFlow is loaded) to have any effect.
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"
from code_interpreter.OpenCodeInterpreter import OpenCodeInterpreter | |
# Local folder that receives one JSON-lines file per conversation.
JSON_DATASET_DIR = Path("json_dataset")
JSON_DATASET_DIR.mkdir(parents=True, exist_ok=True)

# Captions of the feedback buttons. They double as identifiers: the click
# handler compares the clicked button's value against these constants.
# The emoji were mis-encoded (both showed up as a bare Greek "pi"); restored.
upvote_button_value = "👍 Upvote Conversation"
downvote_button_value = "👎 Downvote Conversation"

# Scheduler configured to upload JSON_DATASET_DIR to the "data" path of a
# private dataset repository. Its lock is taken while log files are written.
scheduler = CommitScheduler(
    repo_id="opencodeinterpreter_user_data",
    repo_type="dataset",
    folder_path=JSON_DATASET_DIR,
    path_in_repo="data",
    private=True,
)

logging.basicConfig(level=logging.INFO)
class StreamingOpenCodeInterpreter(OpenCodeInterpreter):
    """OpenCodeInterpreter whose ``generate`` streams text instead of returning it.

    ``generate`` starts the model on a background thread and publishes the
    produced text through ``self.streamer``; callers iterate over that
    streamer to receive the reply piece by piece.
    """

    # Streamer of the most recent ``generate`` call; None before the first call.
    streamer: Optional[TextIteratorStreamer] = None

    # Overrides the parent's generate function.
    def generate(
        self,
        prompt: str = "",
        max_new_tokens: int = 1024,
        do_sample: bool = False,
        top_p: float = 0.95,
        top_k: int = 50,
    ) -> str:
        """Start streaming generation for ``prompt`` on a background thread.

        Args:
            prompt: Full prompt text; tokenized with truncation to
                ``MAX_INPUT_TOKEN_LENGTH`` tokens.
            max_new_tokens: Forwarded to ``self.model.generate``.
            do_sample: Forwarded to ``self.model.generate``.
            top_p: Forwarded to ``self.model.generate``.
            top_k: Forwarded to ``self.model.generate``.

        Returns:
            Always the empty string; the generated text is obtained by
            iterating over ``self.streamer``.
        """
        # The keyword is lowercase ``timeout``. The previous ``Timeout=5`` was
        # not recognised as the queue timeout, so a stalled or crashed
        # generation thread left the consumer blocked forever. With the
        # timeout active, iterating the streamer raises if no text arrives
        # within 5 seconds.
        self.streamer = TextIteratorStreamer(
            self.tokenizer, skip_prompt=True, timeout=5
        )

        inputs = self.tokenizer(
            [prompt],
            return_tensors="pt",
            truncation=True,
            max_length=MAX_INPUT_TOKEN_LENGTH,
        )
        inputs = inputs.to(self.model.device)

        kwargs = dict(
            **inputs,
            streamer=self.streamer,
            max_new_tokens=max_new_tokens,
            do_sample=do_sample,
            top_p=top_p,
            top_k=top_k,
            eos_token_id=self.tokenizer.eos_token_id,
        )

        # Generation runs in its own thread so the caller can consume the
        # streamer while tokens are still being produced.
        thread = Thread(target=self.model.generate, kwargs=kwargs)
        thread.start()

        return ""
def save_json(dialog, mode, json_file_path, flag, dialog_id) -> None:
    """Append one snapshot of a conversation to a JSON-lines file.

    Args:
        dialog: Conversation content stored under the ``"dialog"`` key.
        mode: Value stored under the ``"mode"`` key.
        json_file_path: Path object of the file to append to.
        flag: User feedback value stored under the ``"flag"`` key.
        dialog_id: Conversation identifier stored under the ``"id"`` key.
    """
    record = {
        "id": dialog_id,
        "dialog": dialog,
        "mode": mode,
        "flag": flag,
        "datetime": datetime.now().isoformat(),
    }
    # The write happens under the commit scheduler's lock.
    with scheduler.lock:
        # ensure_ascii=False emits raw non-ASCII characters, so the file must
        # be opened as UTF-8 explicitly rather than with the locale default.
        with json_file_path.open("a", encoding="utf-8") as f:
            json.dump(record, f, ensure_ascii=False)
            f.write("\n")
def convert_history(gradio_history: list[list], interpreter_history: list[dict]):
    """Rebuild the interpreter dialog from the Gradio chat history.

    Only a leading ``"system"`` message of ``interpreter_history`` is kept;
    everything after it is regenerated from ``gradio_history``, whose items
    are indexed as ``[user_text, assistant_text]``. Entries that are ``None``
    are skipped.

    Returns:
        A new list of ``{"role": ..., "content": ...}`` dictionaries.
    """
    starts_with_system = (
        bool(interpreter_history) and interpreter_history[0]["role"] == "system"
    )
    rebuilt = [interpreter_history[0]] if starts_with_system else []

    # An empty or missing Gradio history contributes no turns.
    for turn in gradio_history or ():
        for position, role in enumerate(("user", "assistant")):
            content = turn[position]
            if content is not None:
                rebuilt.append({"role": role, "content": content})

    return rebuilt
def reset_dialog_info(dialog_info):
    """Return fresh dialog info: ``[new conversation uuid, None]``.

    The incoming ``dialog_info`` is not read; the second slot (the feedback
    flag) is reset to ``None``.
    """
    conversation_id = f"{uuid4()}"
    logging.info(f"allocating new uuid {conversation_id} for conversation...")
    fresh_info = [conversation_id, None]
    return fresh_info
def is_valid_python_code(code):
    """Return True if ``code`` parses as Python source, False otherwise.

    Note that an empty string parses successfully and therefore counts as
    valid.
    """
    try:
        ast.parse(code)
    except (SyntaxError, ValueError):
        # ValueError covers text the parser rejects before parsing, e.g.
        # source containing null bytes on some Python versions.
        return False
    return True
class InputFunctionVisitor(ast.NodeVisitor):
    """AST visitor that notes whether a function named ``input`` is called.

    Only calls through a bare name (``input(...)``) are matched; attribute
    calls such as ``obj.input()`` are not.
    """

    def __init__(self):
        # Set to True once a matching call is seen; never reset afterwards.
        self.found_input = False

    def visit_Call(self, node):
        """Flag ``input(...)`` calls, then continue into the call's children."""
        callee = node.func
        if isinstance(callee, ast.Name):
            if callee.id == "input":
                self.found_input = True
        # Keep descending so nested calls (e.g. int(input())) are inspected too.
        self.generic_visit(node)
def has_input_function_calls(code):
    """Return True if ``code`` contains a call to a bare name ``input``.

    Source that cannot be parsed yields False. Only calls of the form
    ``input(...)`` are detected; attribute calls such as ``obj.input()`` are
    not.
    """
    try:
        tree = ast.parse(code)
    except (SyntaxError, ValueError):
        # ValueError covers text rejected before parsing (e.g. null bytes on
        # some Python versions).
        return False

    # ast.walk traverses the tree without recursion, so deeply nested
    # expressions (for instance a very long chain of ``+``) cannot exhaust the
    # interpreter's recursion limit the way a recursive visitor can.
    return any(
        isinstance(node, ast.Call)
        and isinstance(node.func, ast.Name)
        and node.func.id == "input"
        for node in ast.walk(tree)
    )
def gradio_launch(model_path: str, MAX_TRY: int = 3):
    """Build the OpenCodeInterpreter chat demo and launch it with Gradio.

    Args:
        model_path: Passed as ``model_path`` to ``StreamingOpenCodeInterpreter``.
        MAX_TRY: Upper bound on execute-and-regenerate rounds performed for a
            single user message.
    """
    with gr.Blocks() as demo:
        gr.Markdown("# The Official Demo of OpenCodeInterpreter Models")
        gr.Markdown("**We use `m-a-p/OpenCodeInterpreter-DS-6.7B` model in this demo.**")
        gr.Markdown("**NOTE: Please read the disclaimer section in [README.md](https://huggingface.co/spaces/m-a-p/OpenCodeInterpreter_demo/blob/main/README.md) before using this demo!**")
        gr.Markdown("**By using this demo, you acknowledge that you have read this disclaimer, understand its terms, and agree to be bound by them.**")
        chatbot = gr.Chatbot(
            label="OpenCodeInterpreter",
            avatar_images=["assets/user.pic.jpg", "assets/assistant.pic.jpg"],
            show_copy_button=True,
        )

        with gr.Group():
            with gr.Row():
                msg = gr.Textbox(
                    container=False,
                    show_label=False,
                    label="Message",
                    placeholder="Type a message...",
                    scale=7,
                    autofocus=True,
                )
                sub = gr.Button(
                    "Submit",
                    variant="primary",
                    scale=1,
                    min_width=150,
                )
                # stop = gr.Button(
                #     "Stop",
                #     variant="stop",
                #     visible=False,
                #     scale=1,
                #     min_width=150
                # )

        with gr.Row():
            # retry = gr.Button("🔄 Retry", variant="secondary")
            # undo = gr.Button("↩️ Undo", variant="secondary")
            upvote = gr.Button(upvote_button_value, variant="secondary")
            downvote = gr.Button(downvote_button_value, variant="secondary")
            # The emoji of this caption was mis-encoded; restored.
            clear = gr.Button("🗑️ Clear", variant="secondary")

        # Per-session values: chat history, notebook handle, and
        # [conversation id, feedback flag].
        session_state = gr.State([])
        jupyter_state = gr.State(JupyterNotebook())
        dialog_info = gr.State(["", None])
        demo.load(reset_dialog_info, dialog_info, dialog_info)

        def _strip_special_tokens(text):
            """Remove the ``<unk>_``, ``<unk>`` and ``<|EOT|>`` markers from ``text``."""
            return text.replace("<unk>_", "").replace("<unk>", "").replace("<|EOT|>", "")

        def _persist_dialog(interpreter, dialog_info):
            """Log the current dialog and append it to the conversation's JSON file."""
            logging.info(f"saving current dialog to file {dialog_info[0]}.json...")
            logging.info(f"current dialog: {interpreter.dialog}")
            save_json(
                interpreter.dialog,
                mode="openci_only",
                flag=dialog_info[1],
                json_file_path=JSON_DATASET_DIR / f"{dialog_info[0]}.json",
                dialog_id=dialog_info[0],
            )

        def _stream_reply(interpreter, history, jupyter_state, dialog_info):
            """Append streamed text to the last chat turn, yielding a UI update per piece.

            Returns (through ``yield from``) the raw concatenated text,
            including any ``<|EOT|>`` marker that is hidden from the chat turn.
            """
            raw_text = ""
            for piece in interpreter.streamer:
                history[-1][1] = (history[-1][1] + piece).replace("<|EOT|>", "")
                raw_text += piece
                yield history, history, jupyter_state, dialog_info
            return raw_text

        def bot(user_message, history, jupyter_state, dialog_info, interpreter):
            """Answer ``user_message``, running generated code between generations.

            Generator event handler. Each yielded tuple is
            ``(history, history, jupyter_state, dialog_info)``, matching the
            outputs ``[chatbot, session_state, jupyter_state, dialog_info]``.
            """
            logging.info(f"user message: {user_message}")
            interpreter.dialog = convert_history(gradio_history=history, interpreter_history=interpreter.dialog)
            history.append([user_message, None])
            interpreter.dialog.append({"role": "user", "content": user_message})

            # First generation for the new user message.
            prompt = interpreter.dialog_to_prompt(dialog=interpreter.dialog)
            _ = interpreter.generate(prompt)
            history[-1][1] = ""
            generated_text = yield from _stream_reply(interpreter, history, jupyter_state, dialog_info)

            # A reply that is itself parseable Python is wrapped in a fenced
            # block before code blocks are extracted from it.
            if is_valid_python_code(history[-1][1].strip()):
                history[-1][1] = f"```python\n{history[-1][1].strip()}\n```"
                generated_text = history[-1][1]

            HAS_CODE, generated_code_block = interpreter.extract_code_blocks(
                generated_text
            )
            interpreter.dialog.append(
                {
                    "role": "assistant",
                    "content": _strip_special_tokens(generated_text),
                }
            )
            _persist_dialog(interpreter, dialog_info)

            attempt = 1
            # Without extracted code there is nothing to execute.
            while HAS_CODE and attempt <= MAX_TRY:
                # Show the (possibly fenced) reply before running its code.
                yield history, history, jupyter_state, dialog_info

                # Drop unknown-token markers from the code to execute.
                generated_code_block = generated_code_block.replace(
                    "<unk>_", ""
                ).replace("<unk>", "")

                if has_input_function_calls(generated_code_block):
                    code_block_output = "Please directly assign the value of inputs instead of using input() function in your code."
                else:
                    (
                        code_block_output,
                        error_flag,
                    ) = interpreter.execute_code_and_return_output(
                        f"{generated_code_block}",
                        jupyter_state
                    )
                    if error_flag == "Timeout":
                        logging.info(f"{dialog_info[0]}: Restart jupyter kernel due to timeout")
                        # NOTE(review): the previous notebook is replaced
                        # without being closed -- confirm it is cleaned up.
                        jupyter_state = JupyterNotebook()
                    code_block_output = interpreter.clean_code_output(code_block_output)
                    if code_block_output.strip():
                        code_block_output = "Execution result: \n" + code_block_output
                    else:
                        code_block_output = "Code is executed, but result is empty. Please make sure that you include test case in your code."

                # The execution feedback is shown as a user turn and sent
                # back to the model as a user message.
                history.append([code_block_output, ""])
                interpreter.dialog.append({"role": "user", "content": code_block_output})
                yield history, history, jupyter_state, dialog_info

                prompt = interpreter.dialog_to_prompt(dialog=interpreter.dialog)
                logging.info(f"generating answer for dialog {dialog_info[0]}")
                _ = interpreter.generate(prompt)
                generated_text = yield from _stream_reply(interpreter, history, jupyter_state, dialog_info)
                logging.info(f"finish generating answer for dialog {dialog_info[0]}")

                HAS_CODE, generated_code_block = interpreter.extract_code_blocks(
                    history[-1][1]
                )
                interpreter.dialog.append(
                    {
                        "role": "assistant",
                        "content": _strip_special_tokens(generated_text),
                    }
                )
                attempt += 1
                _persist_dialog(interpreter, dialog_info)

            # A generator handler delivers updates through its yields; a
            # value passed to ``return`` is not sent to the outputs, so the
            # final state is yielded instead.
            yield history, history, jupyter_state, dialog_info

        def reset_textbox():
            """Clear the message textbox."""
            return gr.update(value="")

        def set_button_variant(upvote_button_variant, downvote_button_variant):
            """Return the upvote/downvote buttons rebuilt with the given variants."""
            return (
                gr.Button(upvote_button_value, variant=upvote_button_variant),
                gr.Button(downvote_button_value, variant=downvote_button_variant),
            )

        def reset_button_and_flag(dialog_info):
            """Reset both feedback buttons and clear the feedback flag, keeping the id."""
            return (*set_button_variant("secondary", "secondary"), [dialog_info[0], None])

        def clear_history(history, jupyter_state, dialog_info, interpreter):
            """Start a new conversation: empty history, new notebook, new dialog info."""
            interpreter.dialog = []
            jupyter_state.close()
            return (
                [],
                [],
                JupyterNotebook(),
                reset_dialog_info(dialog_info),
                *set_button_variant("secondary", "secondary"),
            )

        def toggle_preference(button, dialog_info):
            """Store the clicked feedback in ``dialog_info[1]`` and highlight its button.

            Raises:
                ValueError: If ``button`` matches neither feedback caption.
            """
            if button == upvote_button_value:
                dialog_info[1] = True
            elif button == downvote_button_value:
                dialog_info[1] = False
            else:
                raise ValueError(button)
            logging.info(f"{button} is clicked by {dialog_info[0]}, current flag: {dialog_info[1]}")
            # The flag is always True or False at this point.
            if dialog_info[1]:
                return (*set_button_variant("primary", "secondary"), dialog_info)
            return (*set_button_variant("secondary", "primary"), dialog_info)

        def save_preference(dialog_info, interpreter):
            """Re-save the dialog (now carrying the feedback flag) if it is non-empty."""
            if interpreter.dialog:
                save_json(
                    interpreter.dialog,
                    mode="openci_only",
                    flag=dialog_info[1],
                    json_file_path=JSON_DATASET_DIR / f"{dialog_info[0]}.json",
                    dialog_id=dialog_info[0],
                )
            return dialog_info

        # One interpreter instance is bound into every handler below.
        interpreter = StreamingOpenCodeInterpreter(model_path=model_path)

        sub.click(reset_button_and_flag, dialog_info, [upvote, downvote, dialog_info])
        sub.click(
            partial(bot, interpreter=interpreter),
            [msg, session_state, jupyter_state, dialog_info],
            [chatbot, session_state, jupyter_state, dialog_info],
        )
        sub.click(reset_textbox, [], [msg])

        clear.click(
            partial(clear_history, interpreter=interpreter),
            [session_state, jupyter_state, dialog_info],
            [chatbot, session_state, jupyter_state, dialog_info, upvote, downvote],
            queue=False
        )

        upvote.click(
            toggle_preference,
            [upvote, dialog_info],
            [upvote, downvote, dialog_info]
        ).then(
            partial(save_preference, interpreter=interpreter),
            dialog_info,
            dialog_info
        )

        downvote.click(
            toggle_preference,
            [downvote, dialog_info],
            [upvote, downvote, dialog_info]
        ).then(
            partial(save_preference, interpreter=interpreter),
            dialog_info,
            dialog_info
        )

    demo.queue(max_size=20)
    demo.launch(share=True)
if __name__ == "__main__":
    import argparse

    # Command-line entry point: the only option is the model location.
    cli = argparse.ArgumentParser()
    cli.add_argument(
        "--path",
        default="m-a-p/OpenCodeInterpreter-DS-6.7B",
        type=str,
        required=False,
        help="Path to the OpenCodeInterpreter Model.",
    )
    cli_args = cli.parse_args()

    gradio_launch(model_path=cli_args.path)