| import gradio as gr |
| from datasets import load_dataset |
| from transformers import AutoTokenizer, AutoModel |
| import torch |
| import pandas as pd |
| import os |
|
|
# NOTE(review): an empty CURL_CA_BUNDLE is a widely used workaround for SSL
# errors when downloading from the Hugging Face Hub; with some `requests`
# versions it disables TLS certificate verification for every download below
# (man-in-the-middle risk). Prefer fixing the CA bundle (e.g. via
# REQUESTS_CA_BUNDLE) and remove this line once downloads work without it.
os.environ['CURL_CA_BUNDLE'] = ''

# Knowledge base for the bot; later in this script rows are read through their
# "subject" (embedded and searched) and "details" (returned as answer) columns.
issues_dataset = load_dataset("gvozdev/subspace-info-v2", split="train")

# Sentence-embedding model used both to index the dataset and to embed questions.
model_ckpt = "sentence-transformers/all-MiniLM-L12-v1"
tokenizer = AutoTokenizer.from_pretrained(model_ckpt)
# This checkpoint uses an architecture built into transformers, so
# trust_remote_code is not needed; leaving it off avoids executing arbitrary
# Python shipped in the model repository.
model = AutoModel.from_pretrained(model_ckpt)

# All inference runs on CPU; inputs are moved to this device before the
# forward pass.
device = torch.device("cpu")
model.to(device)
|
|
|
|
| |
def cls_pooling(model_output):
    """Pool each sequence in a batch down to its first ([CLS]) token.

    Args:
        model_output: Transformer output exposing ``last_hidden_state``
            (presumably shaped ``(batch, seq_len, hidden)`` — TODO confirm).

    Returns:
        The hidden state at sequence position 0 for every batch item.
    """
    # NOTE(review): sentence-transformers "all-*" checkpoints are documented
    # with attention-masked mean pooling; confirm CLS pooling is intended.
    hidden_states = model_output.last_hidden_state
    return hidden_states[:, 0, ...]
|
|
|
|
| |
| |
def mean_pooling(model_output, attention_mask):
    """Average token embeddings per sequence, ignoring padding positions.

    Args:
        model_output: Transformer output exposing ``last_hidden_state``
            (assumed shape ``(batch, seq_len, hidden)``).
        attention_mask: ``(batch, seq_len)`` tensor from the tokenizer, 1 for
            real tokens and 0 for padding.

    Returns:
        A ``(batch, hidden)`` tensor of mean-pooled embeddings.
    """
    token_embeddings = model_output.last_hidden_state
    mask = attention_mask.unsqueeze(-1).to(token_embeddings.dtype)
    summed = (token_embeddings * mask).sum(dim=1)
    # Clamp guards against division by zero for an all-padding row.
    counts = mask.sum(dim=1).clamp(min=1e-9)
    return summed / counts


def get_embeddings(text_list):
    """Embed a string or a list of strings with the module-level model.

    Tokenizes with padding/truncation, runs the model on ``device`` without
    gradient tracking, mean-pools the token embeddings (the pooling documented
    for sentence-transformers all-* checkpoints) and L2-normalises the result,
    so that L2 nearest-neighbour search ranks by cosine similarity.

    Args:
        text_list: A single string or a list of strings.

    Returns:
        A ``(batch, hidden)`` float tensor of unit-length embeddings, one row
        per input text.
    """
    encoded_input = tokenizer(
        text_list, padding=True, truncation=True, return_tensors="pt"
    )
    encoded_input = {k: v.to(device) for k, v in encoded_input.items()}
    # Inference only: skip building an autograd graph.
    with torch.no_grad():
        model_output = model(**encoded_input)
    embeddings = mean_pooling(model_output, encoded_input["attention_mask"])
    return torch.nn.functional.normalize(embeddings, p=2, dim=1)
|
|
| |
| |
| |
|
|
|
|
| |
| |
# Embed every row's "subject" once at startup and index the vectors with FAISS
# so questions can be matched by nearest-neighbour search.
# batched=True embeds many subjects per forward pass instead of one row at a
# time; no_grad avoids building autograd graphs while indexing.
with torch.no_grad():
    embeddings_dataset = issues_dataset.map(
        lambda batch: {
            "embeddings": get_embeddings(batch["subject"]).detach().cpu().numpy()
        },
        batched=True,
        batch_size=32,
    )

embeddings_dataset.add_faiss_index(column="embeddings")
|
|
|
|
| |
def answer_question(question):
    """Return the stored answer whose subject is closest to ``question``.

    The question is embedded with ``get_embeddings`` and looked up in the FAISS
    index on the ``"embeddings"`` column of ``embeddings_dataset``; the
    ``"details"`` value of the single nearest example is returned.

    Args:
        question: The user's question text.

    Returns:
        The ``"details"`` field of the best-matching dataset row.
    """
    # Inference only: no autograd graph needed for the query embedding.
    with torch.no_grad():
        question_embedding = get_embeddings([question]).cpu().detach().numpy()

    # k=1: only the single best match is shown to the user.
    _scores, samples = embeddings_dataset.get_nearest_examples(
        "embeddings", question_embedding, k=1
    )

    # `samples` maps column name -> list of values for the retrieved rows, so
    # read the column directly instead of building a DataFrame per query.
    return samples["details"][0]
|
|
|
|
| |
# Text shown at the top of the Gradio page (description is rendered as HTML).
title = "Subspace Docs bot"
description = (
    '<p style="text-align: center;">This is a bot trained on Subspace Network documentation '
    'to answer the most common questions about the project</p>'
)
|
|
|
|
def chat(message, history):
    """Handle one chat turn for the Gradio interface.

    Args:
        message: The user's question text.
        history: Conversation so far as a list of ``(message, response)``
            tuples, or a falsy value (e.g. ``None``) on the first turn. A
            non-empty list is extended in place.

    Returns:
        The updated history twice: once for the ``chatbot`` output and once
        for the ``state`` output.
    """
    conversation = history if history else []
    reply = answer_question(message)
    conversation.append((message, reply))
    return conversation, conversation
|
|
|
|
# Gradio UI: the "state" component carries the chat history between turns so
# chat() can extend it, and the "chatbot" output renders that history.
example_questions = [
    "What is Subspace Network?",
    "Do you have a token?",
    "System requirements",
]

iface = gr.Interface(
    fn=chat,
    inputs=["text", "state"],
    outputs=["chatbot", "state"],
    examples=example_questions,
    title=title,
    description=description,
    theme="Monochrome",
    allow_flagging="never",
)

iface.launch(share=False)
|
|