from fastapi import FastAPI
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
from typing import List, Dict
import time
import datetime
import uvicorn

model = AutoModelForSeq2SeqLM.from_pretrained("KN123/nl2sql")
tokenizer = AutoTokenizer.from_pretrained("KN123/nl2sql")

def get_prompt(tables, question):
    prompt = f"""convert question and table into SQL query. tables: {tables}. question: {question}"""
    # print(prompt)
    return prompt

def prepare_input(question: str, tables: Dict[str, List[str]]):
    tables = [f"""{table_name}({",".join(tables[table_name])})""" for table_name in tables]
    # print(tables)
    tables = ", ".join(tables)
    # print(tables)
    prompt = get_prompt(tables, question)
    # print(prompt)
    # truncation=True is required for max_length to actually cap the input length.
    input_ids = tokenizer(prompt, max_length=512, truncation=True, return_tensors="pt").input_ids
    # print(input_ids)
    return input_ids
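
# Illustrative sketch (not part of the original app): the same schema
# serialization that prepare_input performs, isolated so the prompt format can
# be inspected without loading the model. The name format_schema is a
# hypothetical helper introduced here for illustration only.
def format_schema(tables: Dict[str, List[str]]) -> str:
    # Each table becomes "name(col1,col2)"; tables are joined with ", ".
    return ", ".join(
        f"{name}({','.join(columns)})" for name, columns in tables.items()
    )

# Example: format_schema({"people_name": ["id", "name"]}) gives
# "people_name(id,name)".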

def inference(question: str, tables: Dict[str, List[str]]) -> str:
    input_data = prepare_input(question=question, tables=tables)
    input_data = input_data.to(model.device)
    # Beam search decoding; top_k was dropped because it only applies when do_sample=True.
    outputs = model.generate(inputs=input_data, num_beams=10, max_length=512)
    # print("Outputs", outputs)
    result = tokenizer.decode(token_ids=outputs[0], skip_special_tokens=True)
    return result

app = FastAPI()

@app.get("/")
def home():
    return {
        "message" : "Hello there! Everything is working fine!",
        "api-version": "1.0.0",
        "role": "nl2sql",
        "description": "This api can be used to convert natural language to SQL given the human prompt, tables and the attributes."
    }

@app.get("/generate")
def generate(text: str):
    start = time.time()
    # Use the caller's question; the table schema is still a hard-coded demo.
    res = inference(text, {
        "people_name": ["id", "name"], "people_age": ["people_id", "age"]
    })
    end = time.time()
    total_time_taken = end - start
    current_utc_datetime = datetime.datetime.now(datetime.timezone.utc)
    current_date = datetime.date.today()
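    # time.daylight only says whether a DST zone is defined, not whether DST is
    # active now, so pick the name from the current local time instead.
    timezone_name = time.tzname[1 if time.localtime().tm_isdst > 0 else 0]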
    print(res)
    return {
        "api_response": f"{res}",
        "time_taken(s)": f"{total_time_taken}",
        "request_details": {
            "utc_datetime": f"{current_utc_datetime}",
            "current_date": f"{current_date}",
            "timezone_name": f"{timezone_name}"
        }
    }


if __name__ == "__main__":
    uvicorn.run(app, host="127.0.0.1", port=8000)