nl2sql-api / app.py
KN123's picture
Upload 3 files
e3fc3b8 verified
raw
history blame
2.48 kB
from fastapi import FastAPI
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
from typing import List, Dict
import time
import datetime
import uvicorn
# Hugging Face Hub checkpoint holding both the seq2seq weights and the
# matching tokenizer. Both are loaded once at import time, so every request
# handled by this module reuses the same objects.
_CHECKPOINT = "KN123/nl2sql"
model = AutoModelForSeq2SeqLM.from_pretrained(_CHECKPOINT)
tokenizer = AutoTokenizer.from_pretrained(_CHECKPOINT)
def get_prompt(tables, question):
    """Build the instruction prompt fed to the nl2sql model.

    Args:
        tables: The schema already rendered as text; within this module,
            ``prepare_input`` passes strings such as ``"t1(a,b), t2(c)"``.
        question: The natural-language question to translate.

    Returns:
        ``"convert question and table into SQL query. tables: <tables>.
        question: <question>"`` as a single line.
    """
    instruction = "convert question and table into SQL query"
    return f"{instruction}. tables: {tables}. question: {question}"
def prepare_input(question: str, tables: Dict[str, List[str]], *, max_length: int = 512):
    """Render the schema and question into a prompt and tokenize it.

    Each table is rendered as ``name(col1,col2,...)``, and the tables are
    joined with ``", "``. The result is wrapped by ``get_prompt`` together
    with the question and tokenized.

    Args:
        question: Natural-language question to translate into SQL.
        tables: Mapping of table name to its column names, e.g.
            ``{"people": ["id", "name"]}``.
        max_length: Maximum number of prompt tokens; longer prompts are
            truncated. Defaults to 512, as before.

    Returns:
        The ``input_ids`` PyTorch tensor produced by the tokenizer
        (``return_tensors="pt"``) for this single prompt.
    """
    # Build the schema string in a new local instead of rebinding the
    # ``tables`` parameter to a different type.
    schema = ", ".join(
        f"{table_name}({','.join(columns)})" for table_name, columns in tables.items()
    )
    prompt = get_prompt(schema, question)
    # Explicit truncation=True: with max_length alone, transformers falls back
    # to truncation anyway but logs a warning on every call.
    encoded = tokenizer(prompt, max_length=max_length, truncation=True, return_tensors="pt")
    return encoded.input_ids
def inference(
    question: str,
    tables: Dict[str, List[str]],
    *,
    num_beams: int = 10,
    top_k: int = 10,
    max_length: int = 512,
) -> str:
    """Translate a natural-language question into a SQL query string.

    Args:
        question: Natural-language question to translate.
        tables: Mapping of table name to its column names, describing the
            schema the generated query may reference.
        num_beams: Beam width forwarded to ``model.generate``.
        top_k: Forwarded to ``model.generate``. NOTE(review): top-k only
            takes effect when sampling is enabled (``do_sample``), which
            this call does not request. It is therefore likely a no-op
            unless the checkpoint's generation config enables sampling;
            confirm before relying on it.
        max_length: Maximum length of the generated sequence, in tokens.

    Returns:
        The first generated sequence, decoded with special tokens removed.
    """
    input_ids = prepare_input(question=question, tables=tables).to(model.device)
    outputs = model.generate(
        inputs=input_ids,
        num_beams=num_beams,
        top_k=top_k,
        max_length=max_length,
    )
    return tokenizer.decode(token_ids=outputs[0], skip_special_tokens=True)
app = FastAPI()

# Static description of the service, served by the root route.
_HOME_INFO = {
    "message": "Hello there! Everything is working fine!",
    "api-version": "1.0.0",
    "role": "nl2sql",
    "description": "This api can be used to convert natural language to SQL given the human prompt, tables and the attributes.",
}


@app.get("/")
def home():
    """Root route: report that the API is up and describe what it does.

    Returns a fresh copy of the static payload on each call, so the shared
    module-level dict is never handed out directly.
    """
    return dict(_HOME_INFO)
@app.get("/generate")
def generate(text: str):
    """Translate the ``text`` query parameter into SQL and report timing.

    The endpoint receives only the question, so the query is generated
    against a fixed demo schema (``people_name`` / ``people_age``).

    Args:
        text: Natural-language question to convert to SQL.

    Returns:
        A dict with these entries, all values rendered as strings:
        - ``api_response``: the generated SQL.
        - ``time_taken(s)``: the inference time in seconds.
        - ``request_details``: the UTC timestamp, the server-local date and
          the server-local timezone name.
    """
    demo_tables = {
        "people_name": ["id", "name"],
        "people_age": ["people_id", "age"],
    }

    # perf_counter is monotonic, so the measured duration is unaffected by
    # wall-clock adjustments (NTP sync, manual changes) during inference.
    start = time.perf_counter()
    res = inference(text, demo_tables)
    total_time_taken = time.perf_counter() - start

    current_utc_datetime = datetime.datetime.now(datetime.timezone.utc)
    # Derive the local date and zone name from the same instant as the UTC
    # timestamp, so all three fields agree. tzname() reflects whether DST is
    # in effect right now. time.tzname[time.daylight] instead picks the DST
    # name for any zone that observes DST at all, even outside DST.
    local_now = current_utc_datetime.astimezone()
    current_date = local_now.date()
    timezone_name = local_now.tzname()

    print(res)
    return {
        "api_response": f"{res}",
        "time_taken(s)": f"{total_time_taken}",
        "request_details": {
            "utc_datetime": f"{current_utc_datetime}",
            "current_date": f"{current_date}",
            "timezone_name": f"{timezone_name}",
        },
    }
if __name__ == "__main__":
    # Entry point for running the module directly (e.g. ``python app.py``):
    # serve the app on the loopback interface only, port 8000.
    server_options = {"host": "127.0.0.1", "port": 8000}
    uvicorn.run(app, **server_options)