KN123 commited on
Commit
e1558a9
1 Parent(s): e3fc3b8

Upload 2 files

Browse files
Files changed (2) hide show
  1. Dockerfile +1 -1
  2. main.py +70 -0
Dockerfile CHANGED
@@ -10,4 +10,4 @@ ENV HOME = /home/user \
10
  WORKDIR $HOME/app
11
  COPY --chown==user . $HOME/app
12
 
13
- CMD [ "uvicorn", "app:app", "--host", "0.0.0.0", "--port", "7860" ]
 
10
  WORKDIR $HOME/app
11
  COPY --chown==user . $HOME/app
12
 
13
+ CMD [ "uvicorn", "main:app", "--host", "0.0.0.0", "--port", "7860" ]
main.py ADDED
@@ -0,0 +1,70 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import FastAPI
2
+ from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
3
+ from typing import List, Dict
4
+ import time
5
+ import datetime
6
+ import uvicorn
7
+
8
+ model = AutoModelForSeq2SeqLM.from_pretrained("KN123/nl2sql")
9
+ tokenizer = AutoTokenizer.from_pretrained("KN123/nl2sql")
10
+
11
+ def get_prompt(tables, question):
12
+ prompt = f"""convert question and table into SQL query. tables: {tables}. question: {question}"""
13
+ # print(prompt)
14
+ return prompt
15
+
16
+ def prepare_input(question: str, tables: Dict[str, List[str]]):
17
+ tables = [f"""{table_name}({",".join(tables[table_name])})""" for table_name in tables]
18
+ # print(tables)
19
+ tables = ", ".join(tables)
20
+ # print(tables)
21
+ prompt = get_prompt(tables, question)
22
+ # print(prompt)
23
+ input_ids = tokenizer(prompt, max_length=512, return_tensors="pt").input_ids
24
+ # print(input_ids)
25
+ return input_ids
26
+
27
+ def inference(question: str, tables: Dict[str, List[str]]) -> str:
28
+ input_data = prepare_input(question=question, tables=tables)
29
+ input_data = input_data.to(model.device)
30
+ outputs = model.generate(inputs=input_data, num_beams=10, top_k=10, max_length=512)
31
+ # print("Outputs", outputs)
32
+ result = tokenizer.decode(token_ids=outputs[0], skip_special_tokens=True)
33
+ return result
34
+
35
+ app = FastAPI()
36
+
37
+ @app.get("/")
38
+ def home():
39
+ return {
40
+ "message" : "Hello there! Everything is working fine!",
41
+ "api-version": "1.0.0",
42
+ "role": "nl2sql",
43
+ "description": "This api can be used to convert natural language to SQL given the human prompt, tables and the attributes."
44
+ }
45
+
46
+ @app.get("/generate")
47
+ def generate(text:str):
48
+ start = time.time()
49
+ res = inference("how many people with name jui and age less than 25", {
50
+ "people_name":["id","name"], "people_age": ["people_id","age"]
51
+ })
52
+ end = time.time()
53
+ total_time_taken = end - start
54
+ current_utc_datetime = datetime.datetime.now(datetime.timezone.utc)
55
+ current_date = datetime.date.today()
56
+ timezone_name = time.tzname[time.daylight]
57
+ print(res)
58
+ return {
59
+ "api_response": f"{res}",
60
+ "time_taken(s)": f"{total_time_taken}",
61
+ "request_details": {
62
+ "utc_datetime": f"{current_utc_datetime}",
63
+ "current_date": f"{current_date}",
64
+ "timezone_name": f"{timezone_name}"
65
+ }
66
+ }
67
+
68
+
69
+ if __name__ == "__main__":
70
+ uvicorn.run(app, host="127.0.0.1", port=8000)