{ "cells": [ { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "from langchain.chains import create_sql_query_chain\n", "from transformers import AutoModelForCausalLM, AutoTokenizer,pipeline, LlamaTokenizer, LlamaForCausalLM\n", "from langchain_huggingface import HuggingFacePipeline\n", "from langchain_openai import ChatOpenAI\n", "import os\n", "from langchain_community.utilities.sql_database import SQLDatabase\n", "from operator import itemgetter\n", "from langchain_core.output_parsers import StrOutputParser\n", "from langchain_core.prompts import PromptTemplate\n", "from langchain_core.runnables import RunnablePassthrough\n", "from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder,FewShotChatMessagePromptTemplate,PromptTemplate\n", "from langchain_community.vectorstores import Chroma\n", "from langchain_core.example_selectors import SemanticSimilarityExampleSelector\n", "from langchain_openai import OpenAIEmbeddings\n", "from operator import itemgetter\n", "from langchain.chains.openai_tools import create_extraction_chain_pydantic\n", "from langchain_core.pydantic_v1 import BaseModel, Field\n", "from typing import List\n", "import pandas as pd\n", "from argparse import ArgumentParser\n", "import json\n", "from langchain.memory import ChatMessageHistory\n", "from langchain_community.tools.sql_database.tool import QuerySQLDataBaseTool\n", "import subprocess\n", "import sys\n", "from transformers import pipeline\n", "import librosa\n", "import soundfile\n", "import datasets\n", "import sounddevice as sd\n", "import numpy as np\n", "import io\n", "import gradio as gr" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.\n", "Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.\n", "Hardware accelerator e.g. GPU is available in the environment, but no `device` argument is passed to the `Pipeline` object. Model will be on CPU.\n" ] } ], "source": [ "model_id = \"avnishkanungo/whisper-small-dv\" # update with your model id\n", "pipe = pipeline(\"automatic-speech-recognition\", model=model_id)" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "def select_table(desc_path):\n", " def get_table_details():\n", " # Read the CSV file into a DataFrame\n", " table_description = pd.read_csv(desc_path) ##\"/teamspace/studios/this_studio/database_table_descriptions.csv\"\n", " table_docs = []\n", "\n", " # Iterate over the DataFrame rows to create Document objects\n", " table_details = \"\"\n", " for index, row in table_description.iterrows():\n", " table_details = table_details + \"Table Name:\" + row['Table'] + \"\\n\" + \"Table Description:\" + row['Description'] + \"\\n\\n\"\n", "\n", " return table_details\n", " \n", " class Table(BaseModel):\n", " \"\"\"Table in SQL database.\"\"\"\n", "\n", " name: str = Field(description=\"Name of table in SQL database.\")\n", " \n", " table_details_prompt = f\"\"\"Return the names of ALL the SQL tables that MIGHT be relevant to the user question. \\\n", " The tables are:\n", "\n", " {get_table_details()}\n", "\n", " Remember to include ALL POTENTIALLY RELEVANT tables, even if you're not sure that they're needed.\"\"\"\n", "\n", " table_chain = create_extraction_chain_pydantic(Table, llm, system_message=table_details_prompt)\n", " \n", " def get_tables(tables: List[Table]) -> List[str]:\n", " tables = [table.name for table in tables]\n", " return tables\n", "\n", " select_table = {\"input\": itemgetter(\"question\")} | create_extraction_chain_pydantic(Table, llm, system_message=table_details_prompt) | get_tables\n", "\n", " return select_table\n" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "def prompt_creation(example_path):\n", "\n", " with open(example_path, 'r') as file: ##'/teamspace/studios/this_studio/few_shot_samples.json'\n", " data = json.load(file)\n", "\n", " examples = data[\"examples\"]\n", "\n", " example_prompt = ChatPromptTemplate.from_messages(\n", " [\n", " (\"human\", \"{input}\\nSQLQuery:\"),\n", " (\"ai\", \"{query}\"),\n", " ]\n", " )\n", "\n", " vectorstore = Chroma()\n", " vectorstore.delete_collection()\n", " example_selector = SemanticSimilarityExampleSelector.from_examples(\n", " examples,\n", " OpenAIEmbeddings(),\n", " vectorstore,\n", " k=2,\n", " input_keys=[\"input\"],\n", " )\n", " \n", " few_shot_prompt = FewShotChatMessagePromptTemplate(\n", " example_prompt=example_prompt,\n", " example_selector=example_selector,\n", " input_variables=[\"input\",\"top_k\"],\n", " )\n", "\n", "\n", " final_prompt = ChatPromptTemplate.from_messages(\n", " [\n", " (\"system\", \"You are a MySQL expert. Given an input question, create a syntactically correct MySQL query to run. Unless otherwise specificed.\\n\\nHere is the relevant table info: {table_info}\\n\\nBelow are a number of examples of questions and their corresponding SQL queries.\"),\n", " few_shot_prompt,\n", " MessagesPlaceholder(variable_name=\"messages\"),\n", " (\"human\", \"{input}\"),\n", " ]\n", " )\n", "\n", " print(few_shot_prompt.format(input=\"How many products are there?\"))\n", " \n", " return final_prompt" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "def rephrase_answer():\n", " answer_prompt = PromptTemplate.from_template(\n", " \"\"\"Given the following user question, corresponding SQL query, and SQL result, answer the user question.\n", "\n", " Question: {question}\n", " SQL Query: {query}\n", " SQL Result: {result}\n", " Answer: \"\"\"\n", " )\n", "\n", " rephrase_answer = answer_prompt | llm | StrOutputParser()\n", "\n", " return rephrase_answer" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [], "source": [ "def is_ffmpeg_installed():\n", " try:\n", " # Run `ffmpeg -version` to check if ffmpeg is installed\n", " subprocess.run(['ffmpeg', '-version'], check=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE)\n", " return True\n", " except (subprocess.CalledProcessError, FileNotFoundError):\n", " return False\n", "\n", "def install_ffmpeg():\n", " try:\n", " if sys.platform.startswith('linux'):\n", " subprocess.run(['sudo', 'apt-get', 'update'], check=True)\n", " subprocess.run(['sudo', 'apt-get', 'install', '-y', 'ffmpeg'], check=True)\n", " elif sys.platform == 'darwin': # macOS\n", " subprocess.run(['/bin/bash', '-c', 'brew install ffmpeg'], check=True)\n", " elif sys.platform == 'win32':\n", " print(\"Please download ffmpeg from https://ffmpeg.org/download.html and install it manually.\")\n", " return False\n", " else:\n", " print(\"Unsupported OS. Please install ffmpeg manually.\")\n", " return False\n", " except subprocess.CalledProcessError as e:\n", " print(f\"Failed to install ffmpeg: {e}\")\n", " return False\n", " return True" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [], "source": [ "def transcribe_speech(filepath):\n", " output = pipe(\n", " filepath,\n", " max_new_tokens=256,\n", " generate_kwargs={\n", " \"task\": \"transcribe\",\n", " \"language\": \"english\",\n", " }, # update with the language you've fine-tuned on\n", " chunk_length_s=30,\n", " batch_size=8,\n", " )\n", " return output[\"text\"]\n", " \n", "def record_command():\n", " sample_rate = 16000 # Sample rate in Hz\n", " duration = 8 # Duration in seconds\n", "\n", " print(\"Recording...\")\n", "\n", " # Record audio\n", " audio = sd.rec(int(sample_rate * duration), samplerate=sample_rate, channels=1, dtype='float32')\n", " sd.wait() # Wait until recording is finished\n", "\n", " print(\"Recording finished\")\n", "\n", " # Convert the audio to a binary stream and save it to a variable\n", " audio_buffer = io.BytesIO()\n", " soundfile.write(audio_buffer, audio, sample_rate, format='WAV')\n", " audio_buffer.seek(0) # Reset buffer position to the beginning\n", "\n", " # The audio file is now saved in audio_buffer\n", " # You can read it again using soundfile or any other audio library\n", " audio_data, sample_rate = soundfile.read(audio_buffer)\n", "\n", " # Optional: Save the audio to a file for verification\n", " # with open('recorded_audio.wav', 'wb') as f:\n", " # f.write(audio_buffer.getbuffer())\n", "\n", " print(\"Audio saved to variable\")\n", " return audio_data" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [], "source": [ "def sql_translator(filepath,key):\n", " # Please configure your DB credentials and paths of the files for few shot learning and fine tuning\n", " db_user = \"root\"\n", " db_password = \"\"\n", " db_host = \"localhost\"\n", " db_name = \"classicmodels\"\n", "\n", " db = SQLDatabase.from_uri(f\"mysql+pymysql://{db_user}:{db_password}@{db_host}/{db_name}\")\n", " # print(db.dialect)\n", " # print(db.get_usable_table_names())\n", " # print(db.table_info)\n", " os.environ[\"OPENAI_API_KEY\"] = key\n", "\n", " history = ChatMessageHistory()\n", "\n", " llm = ChatOpenAI(model=\"gpt-3.5-turbo\", temperature=0)\n", "\n", " final_prompt = prompt_creation(os.getcwd()+\"/few_shot_samples.json\")\n", "\n", " generate_query = create_sql_query_chain(llm, db, final_prompt)\n", "\n", " execute_query = QuerySQLDataBaseTool(db=db)\n", "\n", " sql_query = transcribe_speech(filepath)\n", " chain = (\n", " RunnablePassthrough.assign(table_names_to_use=select_table(os.getcwd()+\"/database_table_descriptions.csv\")) |\n", " RunnablePassthrough.assign(query=generate_query).assign(\n", " result=itemgetter(\"query\") | execute_query\n", " )\n", " | rephrase_answer()\n", " )\n", "\n", " output = chain.invoke({\"question\": sql_query, \"messages\":history.messages})\n", " history.add_user_message(sql_query)\n", " history.add_ai_message(output)\n", "\n", " return output" ] }, { "cell_type": "code", "execution_count": 17, "metadata": {}, "outputs": [], "source": [ "def create_interface():\n", " with gr.Blocks() as interface:\n", " gr.Markdown(\"## Audio and Text Processing Interface\")\n", "\n", " # Text input component\n", " text_input = gr.Textbox(lines=2, placeholder=\"Enter text here...\", label=\"Open AI Key\")\n", " \n", " # Audio input component\n", " audio_input = gr.Audio(sources=\"microphone\", type=\"filepath\", label=\"Record or Upload Audio\")\n", "\n", " # with gr.TabbedInterface([\"Audio Input\", \"Text Input\"]) as tabs:\n", " # with gr.Tab(\"Audio Input\"):\n", " # audio_input = gr.Audio(source=\"microphone\", type=\"filepath\", label=\"Record or Upload Audio\")\n", " # selected_input = audio_input # Reference to the selected input component\n", "\n", " # with gr.Tab(\"Text Input\"):\n", " # query_input = gr.Textbox(lines=2, placeholder=\"Enter text here...\", label=\"Input Text\")\n", " # selected_input = query_input # Reference to the selected input component\n", " \n", " # Button to trigger processing\n", " process_button = gr.Button(\"Process\")\n", " \n", " # Output component\n", " output_text = gr.Textbox(label=\"Output\")\n", "\n", " # Define the action for the button click\n", " process_button.click(fn=sql_translator, inputs=[audio_input, text_input], outputs=output_text)\n", "\n", " return interface" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Just Audio UI" ] }, { "cell_type": "code", "execution_count": 18, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Running on local URL: http://127.0.0.1:7861\n", "\n", "To create a public link, set `share=True` in `launch()`.\n" ] }, { "data": { "text/html": [ "
" ], "text/plain": [ "