import json
import os
import threading
import time
import uuid
from concurrent.futures import ThreadPoolExecutor
from pathlib import Path
from typing import List, Union
import gradio as gr
from dotenv import load_dotenv
from huggingface_hub import Repository
from langchain import ConversationChain
from langchain.chains.conversation.memory import ConversationBufferMemory
from langchain.llms import HuggingFaceHub
from langchain.prompts import PromptTemplate
from utils import force_git_push
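

# A Gradio Space that collects human preferences over model responses: each
# user message is answered with NUM_RESPONSES candidate replies from a model
# hosted on the Hugging Face Hub, the user picks the best one, and the picked
# conversations are appended to a JSONL file that a background timer pushes
# to a dataset repo.
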
def replace_template(template: str, data: dict) -> str:
    """Replace `{key}` placeholders in `template` with the values in `data`."""
    for key, value in data.items():
        template = template.replace(f"{{{key}}}", value)
    return template


def json_to_dict(json_file: Union[str, Path]) -> dict:
    """Load a JSON file into a dictionary."""
    with open(json_file, "r") as f:
        return json.load(f)


def generate_response(chatbot: ConversationChain, text: str, count: int = 1) -> List[str]:
    """Generates `count` responses from a `langchain` chatbot for one input."""
    return [chatbot.predict(input=text) for _ in range(count)]
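

# NOTE: `generate_responses` below has no call site in this file; the
# single-chatbot flow only calls `generate_response` directly.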
def generate_responses(chatbots: List[ConversationChain], inputs: List[str]) -> List[str]:
    """Generates responses for a list of `langchain` chatbots in parallel."""
    results = []
    with ThreadPoolExecutor(max_workers=100) as executor:
        for result in executor.map(
            generate_response,
            chatbots,
            inputs,
            [NUM_RESPONSES] * len(inputs),
        ):
            results += result
    return results


if Path(".env").is_file():
load_dotenv(".env")
DATASET_REPO_URL = os.getenv("DATASET_REPO_URL")
FORCE_PUSH = os.getenv("FORCE_PUSH")  # Set to "yes" to force-push to the dataset repo
HF_TOKEN = os.getenv("HF_TOKEN")
PROMPT_TEMPLATES = Path("prompt_templates")
NUM_RESPONSES = 3 # Number of responses to generate per interaction
DATA_FILENAME = "data.jsonl"
DATA_FILE = os.path.join("data", DATA_FILENAME)
repo = Repository(local_dir="data", clone_from=DATASET_REPO_URL, use_auth_token=HF_TOKEN)
TOTAL_CNT = 3 # How many user inputs to collect
PUSH_FREQUENCY = 60
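

# Push the dataset to the Hub every PUSH_FREQUENCY seconds from a background
# timer; the push is skipped whenever the local clone has no changes to commit.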
def asynchronous_push(f_stop):
    if repo.is_repo_clean():
        print("Repo currently clean. Ignoring push_to_hub")
    else:
        repo.git_add(auto_lfs_track=True)
        repo.git_commit("Auto commit by space")
        if FORCE_PUSH == "yes":
            force_git_push(repo)
        else:
            repo.git_push()
    if not f_stop.is_set():
        # Schedule the next push in PUSH_FREQUENCY seconds.
        threading.Timer(PUSH_FREQUENCY, asynchronous_push, [f_stop]).start()


f_stop = threading.Event()
asynchronous_push(f_stop)
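
# Load the prompt template. A minimal sketch of the assumed file layout (the
# key names here are illustrative; only the order of the two values in
# prompt_01.json matters for the unpacking below):
#   prompt_templates/prompt_01.json -> {"input_variables": [...], "template": "..."}
#   prompt_templates/data_01.json   -> {"placeholder_name": "value", ...}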
input_vars, prompt_tpl = json_to_dict(PROMPT_TEMPLATES / "prompt_01.json").values()
prompt_data = json_to_dict(PROMPT_TEMPLATES / "data_01.json")
prompt_tpl = replace_template(prompt_tpl, prompt_data)
prompt = PromptTemplate(template=prompt_tpl, input_variables=input_vars)
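
# `ConversationChain` validates the prompt against its memory, so the template
# is expected to expose the default `history` and `input` variables.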
chatbot = ConversationChain(
    llm=HuggingFaceHub(
        repo_id="Open-Orca/Mistral-7B-OpenOrca",
        model_kwargs={"temperature": 1},
        huggingfacehub_api_token=HF_TOKEN,
    ),
    prompt=prompt,
    verbose=False,
    memory=ConversationBufferMemory(ai_prefix="Assistant"),
)
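
# Build the Gradio UI. Conversation state lives in a hidden JSON component so
# that each browser session works on its own copy.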
demo = gr.Blocks()
with demo:
    # We keep track of state as a JSON blob per session.
    state_dict = {
        "conversation_id": str(uuid.uuid4()),
        "assignmentId": "",  # Read back when the HIT data is written out below.
        "cnt": 0,
        "data": [],
        "past_user_inputs": [],
        "generated_responses": [],
    }
    state = gr.JSON(state_dict, visible=False)

    gr.Markdown("# Talk to the assistant")
    state_display = gr.Markdown(f"Your messages: 0/{TOTAL_CNT}")
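
    # Generate NUM_RESPONSES model predictions for the submitted message and
    # ask the user to pick the most helpful and honest one.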
    def _predict(txt, state):
        start = time.time()
        responses = generate_response(chatbot, txt, count=NUM_RESPONSES)
        print(f"Time taken to generate {len(responses)} responses: {time.time() - start:.2f} seconds")

        state["cnt"] += 1

        metadata = {"cnt": state["cnt"], "text": txt}
        for idx, response in enumerate(responses):
            metadata[f"response_{idx + 1}"] = response
        state["data"].append(metadata)
        state["past_user_inputs"].append(txt)

        # `generated_responses` is one turn behind `past_user_inputs` until the
        # user picks a response, so pad it with an empty string for the zip.
        past_conversation_string = "\n".join(
            [
                "\n".join(["Human 😃: " + user_input, "Assistant 🤖: " + model_response])
                for user_input, model_response in zip(state["past_user_inputs"], state["generated_responses"] + [""])
            ]
        )
        return (
            gr.update(visible=False),  # text_input
            gr.update(visible=True),  # select_response_button
            gr.update(visible=True, choices=responses, interactive=True, value=responses[0]),  # select_response
            gr.update(value=past_conversation_string),  # past_conversation
            state,
            gr.update(visible=False),  # example_submit
            gr.update(visible=False),  # final_submit
            gr.update(visible=False),  # state_display
        )
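
    # Record the chosen response and, once TOTAL_CNT turns are complete, flush
    # the whole conversation to DATA_FILE for the next push to the Hub.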
    def _select_response(selected_response, state):
        done = state["cnt"] == TOTAL_CNT
        state["generated_responses"].append(selected_response)
        state["data"][-1]["selected_response"] = selected_response
        if done:
            # The user has submitted everything, so write the HIT data to our
            # local dataset clone (pushed to the Hub by the timer above).
            with open(DATA_FILE, "a") as jsonlfile:
                json_data_with_assignment_id = [
                    json.dumps(
                        dict(
                            {
                                "assignmentId": state["assignmentId"],
                                "conversation_id": state["conversation_id"],
                            },
                            **datum,
                        )
                    )
                    for datum in state["data"]
                ]
                jsonlfile.write("\n".join(json_data_with_assignment_id) + "\n")
        toggle_example_submit = gr.update(visible=not done)
        past_conversation_string = "\n".join(
            [
                "\n".join(["😃: " + user_input, "🤖: " + model_response])
                for user_input, model_response in zip(state["past_user_inputs"], state["generated_responses"])
            ]
        )
        # Reveal the final "Submit HIT" button only once all turns are done.
        toggle_final_submit = gr.update(visible=done)
        if done:
            # Wipe the memory so the next conversation starts fresh.
            chatbot.memory = ConversationBufferMemory(ai_prefix="Assistant")
        else:
            # Sync the model's memory with the conversation path that was
            # actually taken: rebuild it from the selected responses only,
            # discarding the rejected candidates that `predict` also stored.
            memory = ConversationBufferMemory(ai_prefix="Assistant")
            for user_input, model_response in zip(state["past_user_inputs"], state["generated_responses"]):
                memory.save_context({"input": user_input}, {"response": model_response})
            chatbot.memory = memory
        text_input = gr.update(visible=False) if done else gr.update(visible=True)
        return (
            gr.update(visible=False),  # select_response
            gr.update(visible=True, value=f"Your messages: {state['cnt']}/{TOTAL_CNT}"),  # state_display
            text_input,
            gr.update(visible=False),  # select_response_button
            state,
            gr.update(value=past_conversation_string),  # past_conversation
            toggle_example_submit,  # example_submit
            toggle_final_submit,  # final_submit
        )
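
    # UI components. `select_response` starts with dummy choices that are
    # replaced by the generated candidates on each `_predict` call.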
    past_conversation = gr.Markdown()
    text_input = gr.Textbox(placeholder="Enter a statement", show_label=False)
    select_response = gr.Radio(
        choices=[None, None],
        visible=False,
        label="Choose the most helpful and honest response",
    )
    select_response_button = gr.Button("Select Response", visible=False)
    with gr.Column() as example_submit:
        submit_ex_button = gr.Button("Submit")
    with gr.Column(visible=False) as final_submit:
        submit_hit_button = gr.Button("Submit HIT")
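
    # Event wiring: each `outputs` list below must match, in order, the tuple
    # returned by its handler.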
    select_response_button.click(
        _select_response,
        inputs=[select_response, state],
        outputs=[
            select_response,
            state_display,  # Re-shown with the updated message count.
            text_input,
            select_response_button,
            state,
            past_conversation,
            example_submit,
            final_submit,
        ],
    )
    submit_ex_button.click(
        _predict,
        inputs=[text_input, state],
        outputs=[
            text_input,
            select_response_button,
            select_response,
            past_conversation,
            state,
            example_submit,
            final_submit,
            state_display,
        ],
    )
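
    # Submitting the HIT is currently a no-op passthrough; the conversation
    # has already been written to disk in `_select_response`.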
    submit_hit_button.click(
        lambda state: state,
        inputs=[state],
        outputs=[state],
    )
demo.launch()