Mikhil-jivus commited on
Commit
b7d6aa3
1 Parent(s): c2a0993

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +11 -14
app.py CHANGED
@@ -36,19 +36,16 @@ def respond(
36
 
37
  messages.append({"role": "user", "content": message})
38
 
39
- # Tokenize the input messages with dynamic padding and truncation
40
  input_text = system_message + " ".join([f"{msg['role']}: {msg['content']}" for msg in messages])
41
- inputs = tokenizer(
42
- input_text,
43
- return_tensors="pt",
44
- padding=True, # Dynamically pad to the longest sequence in the batch
45
- truncation=True, # Truncate if exceeds max length
46
- max_length=max_tokens # Ensure max length is respected
47
- )
48
-
49
- input_ids = inputs["input_ids"]
50
- attention_mask = inputs["attention_mask"]
51
-
52
  # Generate a response
53
  chat_history_ids = model.generate(
54
  input_ids,
@@ -57,9 +54,9 @@ def respond(
57
  top_p=top_p,
58
  pad_token_id=tokenizer.eos_token_id,
59
  do_sample=True,
60
- attention_mask=attention_mask, # Use the dynamically generated attention mask
61
  )
62
-
63
  # Decode the response
64
  response = tokenizer.decode(chat_history_ids[:, input_ids.shape[-1]:][0], skip_special_tokens=True)
65
 
 
36
 
37
  messages.append({"role": "user", "content": message})
38
 
39
+ # Tokenize the input messages
40
  input_text = system_message + " ".join([f"{msg['role']}: {msg['content']}" for msg in messages])
41
+ input_ids = tokenizer.encode(input_text, return_tensors="pt")
42
+
43
+ # Move input_ids to the GPU
44
+ input_ids = input_ids.to("cuda")
45
+
46
+ # Create attention mask and move to GPU
47
+ attention_mask = input_ids.ne(tokenizer.pad_token_id).long().to("cuda")
48
+
 
 
 
49
  # Generate a response
50
  chat_history_ids = model.generate(
51
  input_ids,
 
54
  top_p=top_p,
55
  pad_token_id=tokenizer.eos_token_id,
56
  do_sample=True,
57
+ attention_mask=attention_mask,
58
  )
59
+
60
  # Decode the response
61
  response = tokenizer.decode(chat_history_ids[:, input_ids.shape[-1]:][0], skip_special_tokens=True)
62