bol20162021's picture
Update app.py
93087d7 verified
"""Streamlit app for demoing SambaCoder-nsql-llama-2-70b."""
import json
import os
import pandas as pd
import requests
import streamlit as st
from manifest import Manifest, Response
from manifest.connections.client_pool import ClientConnection
# Sequences that presumably mark the end of a generated SQL answer.
# NOTE(review): nothing in the visible code reads this list -- generate_sql
# does not stop or truncate output on these sequences; confirm whether it should.
STOP_TOKENS = [
    "###",
    ";",
    "--",
    "```",
]
def generate_prompt(question, schema):
    """Build the text-to-SQL prompt sent to the model.

    Layout: the schema text, two empty lines, a fixed SQLite instruction
    written as an SQL comment, one empty line, then the question as an SQL
    comment, terminated by a newline.

    Args:
        question: Natural-language question to answer.
        schema: Table definitions the question refers to.

    Returns:
        str: The complete prompt.
    """
    instruction = (
        "-- Using valid SQLite, answer the following questions "
        "for the tables provided above."
    )
    pieces = [f"{schema}", "", "", instruction, "", f"-- {question}", ""]
    return "\n".join(pieces)
def generate_sql(question, schema, *, timeout=(10, 300)):
    """Stream SQL text generated by the backend model for *question*.

    Builds the prompt with ``generate_prompt``, POSTs it to the URL in
    ``st.secrets["backend_url"]`` (sending ``st.secrets["key"]`` in the
    ``key`` header) and yields every non-empty ``stream_token`` found in the
    ``data: ``-prefixed lines of the streamed response.

    Args:
        question: Natural-language question to translate into SQL.
        schema: Table definitions the question refers to.
        timeout: ``requests`` timeout in seconds, either one number or a
            ``(connect, read)`` tuple. With streaming, the read timeout bounds
            the wait between received chunks, not the total generation time.

    Yields:
        str: Generated text fragments, in arrival order.

    Raises:
        requests.HTTPError: If the backend answers with a 4xx/5xx status.
        requests.RequestException: On connection failures or timeouts.
    """
    prompt = generate_prompt(question, schema)
    url = st.secrets["backend_url"]
    headers = {
        "Content-Type": "application/json",
        "key": st.secrets["key"],
    }
    # Each generation parameter is sent as a {"type", "value"} pair whose
    # value is encoded as a string.
    data = {
        "inputs": [prompt],
        "params": {
            "do_sample": {"type": "bool", "value": "false"},
            "max_tokens_to_generate": {"type": "int", "value": "2048"},
            "repetition_penalty": {"type": "float", "value": "1"},
            "temperature": {"type": "float", "value": "1"},
            "top_k": {"type": "int", "value": "50"},
            "top_p": {"type": "float", "value": "1"},
            "process_prompt": {"type": "bool", "value": "false"},
            "select_expert": {"type": "str", "value": "SambaCoder-nsql-llama-2-70b"},
        },
    }
    # The context manager releases the streamed connection even when the
    # caller stops iterating early (generator close triggers __exit__).
    with requests.post(
        url,
        headers=headers,
        data=json.dumps(data),
        stream=True,
        timeout=timeout,
    ) as r:
        # Surface error responses instead of silently yielding nothing.
        r.raise_for_status()
        if r.encoding is None:
            r.encoding = "utf-8"
        prefix = "data: "
        for line in r.iter_lines(decode_unicode=True):
            # Only server-sent "data: " lines carry payloads; skip keep-alive
            # blank lines and anything else.
            if not line or not line.startswith(prefix):
                continue
            output = json.loads(line[len(prefix):])
            token = output.get("stream_token", "")
            # Truthiness also tolerates a null stream_token.
            if token:
                yield token
st.title("SambaCoder-nsql-llama-2-70b Demo")

# Collapsible editor for the database schema that is embedded in the prompt.
expander = st.expander("Database Schema")
# TODO(Bo Li): update this with the new example
default_schema = """CREATE TABLE stadium (
stadium_id number,
location text,
name text,
capacity number,
highest number,
lowest number,
average number
)
CREATE TABLE singer (
singer_id number,
name text,
country text,
song_name text,
song_release_year text,
age number,
is_male others
)
CREATE TABLE concert (
concert_id number,
concert_name text,
theme text,
stadium_id text,
year text
)
CREATE TABLE singer_in_concert (
concert_id number,
singer_id text
)"""
schema = expander.text_area("Current schema:", value=default_schema, height=500)

# Input field for the natural-language question.
text_prompt = st.text_input(
    "Please let me know what question do you want to ask?",
    value="What is the maximum, the average, and the minimum capacity of stadiums ?",
)

if st.button("Generate SQL"):
    # Don't spend a backend call on input the model cannot answer.
    if not text_prompt.strip():
        st.warning("Please enter a question first.")
    elif not schema.strip():
        st.warning("Please provide a database schema first.")
    else:
        # Render the growing answer as an SQL code block; plain Markdown
        # rendering would treat characters such as `*` as formatting.
        output = st.empty()
        sql_query = ""
        try:
            for token in generate_sql(text_prompt, schema):
                sql_query += token
                output.code(sql_query, language="sql")
        except requests.RequestException as err:
            st.error(f"Failed to generate SQL: {err}")
        else:
            if not sql_query:
                st.info("The model returned no output.")