teaevo commited on
Commit
46920ac
·
1 Parent(s): 8b84431

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +25 -22
app.py CHANGED
@@ -42,30 +42,33 @@ def predict(input, history=[]):
42
  # Check if the user input is a question
43
  is_question = "?" in input
44
 
45
- if is_question:
46
- sql_encoding = sql_tokenizer(table=table, query=input + sql_tokenizer.eos_token, return_tensors="pt")
47
- #sql_outputs = sql_model.generate(**sql_encoding)
48
- #sql_response = sql_tokenizer.batch_decode(sql_outputs, skip_special_tokens=True)
49
-
50
- bot_input_ids = torch.cat([torch.LongTensor(history), sql_encoding], dim=-1)
51
- history = sql_model.generate(bot_input_ids, max_length=1000, pad_token_id=sql_tokenizer.eos_token_id).tolist()
52
- response = sql_tokenizer.decode(history[0]).split("<|endoftext|>")
53
- response = [(response[i], response[i + 1]) for i in range(0, len(response) - 1, 2)]
54
-
55
- else:
56
- # tokenize the new input sentence
57
- new_user_input_ids = tokenizer.encode(input + tokenizer.eos_token, return_tensors='pt')
 
 
 
 
 
 
 
 
 
 
 
 
58
 
59
- # append the new user input tokens to the chat history
60
- bot_input_ids = torch.cat([torch.LongTensor(history), new_user_input_ids], dim=-1)
61
-
62
- # generate a response
63
- history = model.generate(bot_input_ids, max_length=1000, pad_token_id=tokenizer.eos_token_id).tolist()
64
 
65
- # convert the tokens to text, and then split the responses into the right format
66
- response = tokenizer.decode(history[0]).split("<|endoftext|>")
67
- response = [(response[i], response[i + 1]) for i in range(0, len(response) - 1, 2)] # convert to tuples of list
68
-
69
  return response, history
70
 
71
 
 
42
  # Check if the user input is a question
43
  is_question = "?" in input
44
 
45
+
46
+ # tokenize the new input sentence
47
+ new_user_input_ids = tokenizer.encode(input + tokenizer.eos_token, return_tensors='pt')
48
+
49
+ # append the new user input tokens to the chat history
50
+ bot_input_ids = torch.cat([torch.LongTensor(history), new_user_input_ids], dim=-1)
51
+
52
+ # generate a response
53
+ history = model.generate(bot_input_ids, max_length=1000, pad_token_id=tokenizer.eos_token_id).tolist()
54
+ response_dialog = tokenizer.decode(history[0])
55
+
56
+ # Use the SQL model to generate a response
57
+ encoding = sql_tokenizer(table=table, query=response_dialog, return_tensors="pt")
58
+ outputs = sql_model.generate(**encoding)
59
+ response_sql = sql_tokenizer.batch_decode(outputs, skip_special_tokens=True)
60
+
61
+ # Add the SQL model's response to the chat history
62
+ history.extend(response_sql)
63
+
64
+
65
+ # convert the tokens to text, and then split the responses into the right format
66
+ response = tokenizer.decode(history[0]).split("<|endoftext|>")
67
+ response = [(response[i], response[i + 1]) for i in range(0, len(response) - 1, 2)] # convert to tuples of list
68
+
69
+
70
 
 
 
 
 
 
71
 
 
 
 
 
72
  return response, history
73
 
74