# nl2sql-api / app.py
# Uploaded by KN123 ("Upload app.py", commit b2c5d32, verified).
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
from typing import List, Dict
import time
import datetime
import uvicorn
# Checkpoint identifier handed to both ``from_pretrained`` calls below.
_MODEL_ID = "KN123/nl2sql"

# Load the seq2seq model and its tokenizer once, at import time; the request
# handlers further down reuse these module-level objects.
model = AutoModelForSeq2SeqLM.from_pretrained(_MODEL_ID)
tokenizer = AutoTokenizer.from_pretrained(_MODEL_ID)
def get_prompt(tables, question):
    """Build the text prompt for the NL->SQL model.

    Args:
        tables: Already-serialized schema text, e.g. ``"people(id,name), orders(id,total)"``.
        question: The natural-language question to translate.

    Returns:
        The instruction-prefixed prompt string.
    """
    template = "convert question and table into SQL query. tables: {}. question: {}"
    return template.format(tables, question)
def prepare_input(question: str, tables: Dict[str, List[str]]):
    """Serialize the schema, build the prompt and tokenize it for the model.

    Args:
        question: Natural-language question to translate into SQL.
        tables: Mapping of table name -> column names. For example
            ``{"people": ["id", "name"]}`` is serialized as ``"people(id,name)"``,
            and multiple tables are joined with ``", "``.

    Returns:
        The ``input_ids`` tensor (PyTorch, batch of one) for the prompt, truncated
        to at most 512 tokens.
    """
    schema = ", ".join(
        f"{table_name}({','.join(columns)})" for table_name, columns in tables.items()
    )
    prompt = get_prompt(schema, question)
    # Truncation is requested explicitly. Passing ``max_length`` on its own makes
    # transformers log a warning and fall back to 'longest_first' truncation,
    # which is exactly what ``truncation=True`` selects.
    encoded = tokenizer(prompt, max_length=512, truncation=True, return_tensors="pt")
    return encoded.input_ids
def inference(
    question: str,
    tables: Dict[str, List[str]],
    *,
    num_beams: int = 10,
    top_k: int = 10,
    max_length: int = 512,
) -> str:
    """Translate a natural-language question into SQL for the given schema.

    Args:
        question: Natural-language question.
        tables: Mapping of table name -> column names (see ``prepare_input``).
        num_beams: Beam width passed to ``model.generate``.
        top_k: Passed to ``model.generate`` unchanged (see note below).
        max_length: Maximum length of the generated token sequence.

    Returns:
        The first generated sequence decoded to text, with special tokens removed.
    """
    input_ids = prepare_input(question=question, tables=tables).to(model.device)
    # NOTE(review): ``top_k`` is only consulted by sampling-based decoding. With
    # beam search and sampling disabled (the transformers default, unless this
    # model's generation config turns ``do_sample`` on) it has no effect and
    # transformers warns about it. Confirm whether sampling was intended.
    outputs = model.generate(
        inputs=input_ids,
        num_beams=num_beams,
        top_k=top_k,
        max_length=max_length,
    )
    return tokenizer.decode(token_ids=outputs[0], skip_special_tokens=True)
app = FastAPI()

# Cross-origin (browser) access policy for every route of this app.
# NOTE(review): a wildcard origin combined with allow_credentials=True is a risky
# pairing. The CORS spec forbids a literal "*" on credentialed responses, and
# Starlette's middleware handles this by echoing the caller's Origin instead,
# so any site could make credentialed requests. None of the routes in this file
# read cookies or auth headers, so allow_credentials=False is probably what is
# wanted. Confirm no client relies on it before changing.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # any origin
    allow_credentials=True,
    # Explicit subset, not every method: e.g. PATCH and HEAD are not listed.
    allow_methods=["GET", "POST", "PUT", "DELETE", "OPTIONS"],
    allow_headers=["*"],  # any request header
)
@app.get("/")
def home():
    """Root endpoint: return a fixed status message plus static API metadata."""
    status = {"message": "Hello there! Everything is working fine!"}
    status["api-version"] = "1.0.0"
    status["role"] = "nl2sql"
    status["description"] = (
        "This api can be used to convert natural language to SQL given the human prompt, tables and the attributes."
    )
    return status
# Fixed example served by the GET /test-generate smoke-test endpoint.
_DEMO_QUESTION = "how many people with name jui and age less than 25"
_DEMO_TABLES = {"people_name": ["id", "name"], "people_age": ["people_id", "age"]}


def _local_timezone_name() -> str:
    """Return the name of the local time zone as currently in effect (standard vs DST).

    ``time.daylight`` only reports whether the zone *defines* DST, so indexing
    ``time.tzname`` with it yields the DST name all year round. ``tm_isdst``
    describes the current moment; -1 (unknown) is treated as standard time.
    """
    return time.tzname[1 if time.localtime().tm_isdst > 0 else 0]


def _is_valid_schema(tables) -> bool:
    """Return True if *tables* is a dict whose values are lists of strings."""
    return isinstance(tables, dict) and all(
        isinstance(columns, list) and all(isinstance(column, str) for column in columns)
        for columns in tables.values()
    )


def _run_and_describe(question: str, tables: Dict[str, List[str]]) -> dict:
    """Run ``inference`` and wrap its output in the response envelope of both endpoints."""
    start = time.perf_counter()  # monotonic clock: immune to wall-clock adjustments
    res = inference(question, tables)
    total_time_taken = time.perf_counter() - start
    print(res)
    return {
        "api_response": f"{res}",
        "time_taken(s)": f"{total_time_taken}",
        "request_details": {
            "utc_datetime": f"{datetime.datetime.now(datetime.timezone.utc)}",
            "current_date": f"{datetime.date.today()}",  # local date, like timezone_name
            "timezone_name": f"{_local_timezone_name()}",
        },
    }


# NOTE: both route handlers below are named ``generate``, so the module attribute
# ends up bound to the POST handler. FastAPI registers each route when its
# decorator runs, so both routes are still served.
@app.get("/test-generate")
def generate(text: str = ""):
    """Smoke test: translate a fixed demo question against a fixed two-table schema.

    Args:
        text: Accepted for backward compatibility but ignored; the demo question is
            always used. It used to be a required query parameter, so requests
            without ``?text=`` were rejected (422) although the value was never read.
    """
    return _run_and_describe(_DEMO_QUESTION, _DEMO_TABLES)


@app.post("/generate")
def generate(request_body: Dict):
    """Translate ``request_body["text"]`` into SQL over the schema in ``request_body["tables"]``.

    Expects a JSON object such as
    ``{"text": "how many people ...", "tables": {"people": ["id", "name"]}}``.

    Raises:
        HTTPException: 400 if either key is missing, ``text`` is not a string, or
            ``tables`` is not an object mapping table names to lists of column names.
    """
    if "text" not in request_body or "tables" not in request_body:
        raise HTTPException(status_code=400, detail="Missing 'text' or 'tables' in request body")
    prompt = request_body["text"]
    tables = request_body["tables"]
    # Validate types up front. Otherwise a non-object ``tables`` crashes inside
    # prepare_input (500), and a string column list is silently split into characters.
    if not isinstance(prompt, str):
        raise HTTPException(status_code=400, detail="'text' must be a string")
    if not _is_valid_schema(tables):
        raise HTTPException(
            status_code=400,
            detail="'tables' must map each table name to a list of column-name strings",
        )
    return _run_and_describe(prompt, tables)
if __name__ == "__main__":
    # Local development entry point: serve the app on the loopback interface only.
    bind_host, bind_port = "127.0.0.1", 8000
    uvicorn.run(app, host=bind_host, port=bind_port)