import datetime
import time
from typing import Dict, List

import uvicorn
from fastapi import FastAPI
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

# Hugging Face checkpoint used for natural-language -> SQL translation. It is
# loaded once at import time so every request reuses the same instances.
MODEL_NAME = "KN123/nl2sql"
# Token limit applied both to the encoded prompt and to the generated SQL.
MAX_LENGTH = 512
# Beam width for deterministic beam-search decoding.
NUM_BEAMS = 10

model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME)
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

# Schema sent to the model by the /generate endpoint.
# NOTE(review): this is a hard-coded demo schema; presumably it should
# eventually be supplied by the caller — TODO confirm.
DEFAULT_TABLES: Dict[str, List[str]] = {
    "people_name": ["id", "name"],
    "people_age": ["people_id", "age"],
}


def get_prompt(tables: str, question: str) -> str:
    """Build the text prompt the model expects.

    Args:
        tables: Pre-rendered schema string, e.g. ``"t1(a,b), t2(c)"``.
        question: The natural-language question to translate.

    Returns:
        The prompt string fed to the tokenizer.
    """
    prompt = f"""convert question and table into SQL query. tables: {tables}. question: {question}"""
    return prompt


def prepare_input(question: str, tables: Dict[str, List[str]]):
    """Render the schema, build the prompt and tokenize it.

    Args:
        question: The natural-language question.
        tables: Mapping of table name -> list of column names.

    Returns:
        The ``input_ids`` tensor produced by the tokenizer (``return_tensors="pt"``).
    """
    # Render each table as "name(col1,col2,...)" and join the tables with ", ".
    rendered_tables = ", ".join(
        f"{table_name}({','.join(columns)})" for table_name, columns in tables.items()
    )
    prompt = get_prompt(rendered_tables, question)
    # Without truncation=True the tokenizer ignores max_length, so over-long
    # prompts would be passed to the model untruncated.
    return tokenizer(
        prompt, max_length=MAX_LENGTH, truncation=True, return_tensors="pt"
    ).input_ids


def inference(question: str, tables: Dict[str, List[str]]) -> str:
    """Translate ``question`` into SQL against the given ``tables`` schema.

    Args:
        question: The natural-language question.
        tables: Mapping of table name -> list of column names.

    Returns:
        The decoded SQL string (special tokens stripped).
    """
    input_data = prepare_input(question=question, tables=tables).to(model.device)
    # Beam search is deterministic. Sampling knobs such as top_k have no
    # effect here, so none are passed.
    outputs = model.generate(
        inputs=input_data, num_beams=NUM_BEAMS, max_length=MAX_LENGTH
    )
    return tokenizer.decode(token_ids=outputs[0], skip_special_tokens=True)


app = FastAPI()


@app.get("/")
def home():
    """Health-check / metadata endpoint."""
    return {
        "message" : "Hello there! Everything is working fine!",
        "api-version": "1.0.0",
        "role": "nl2sql",
        "description": "This api can be used to convert natural language to SQL given the human prompt, tables and the attributes."
    }


@app.get("/generate")
def generate(text: str):
    """Translate the natural-language question ``text`` into SQL.

    The question is translated against ``DEFAULT_TABLES``. The response
    includes the generated SQL, the inference time in seconds, and the
    timestamps at which the request finished.
    """
    # perf_counter is monotonic, so the elapsed time is unaffected by
    # wall-clock adjustments.
    start = time.perf_counter()
    res = inference(text, DEFAULT_TABLES)
    total_time_taken = time.perf_counter() - start

    current_utc_datetime = datetime.datetime.now(datetime.timezone.utc)
    # Derive the local date and zone name from the same instant. The zone
    # name therefore reflects whether DST is actually in effect right now
    # (time.tzname[time.daylight] only tells whether the zone has DST at all).
    local_now = current_utc_datetime.astimezone()
    current_date = local_now.date()
    timezone_name = local_now.tzname()

    print(res)
    return {
        "api_response": f"{res}",
        "time_taken(s)": f"{total_time_taken}",
        "request_details": {
            "utc_datetime": f"{current_utc_datetime}",
            "current_date": f"{current_date}",
            "timezone_name": f"{timezone_name}"
        }
    }


if __name__ == "__main__":
    uvicorn.run(app, host="127.0.0.1", port=8000)