adi2606 committed
Commit 2d0ce20 · verified · 1 Parent(s): 8f9a9d9

Update app.py

Files changed (1):
  app.py +5 -6
app.py CHANGED
@@ -2,24 +2,23 @@ import streamlit as st
 import torch
 from transformers import AutoModelForCausalLM, AutoTokenizer
 
-# Set up the device
-device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-
 # Load model and tokenizer
-model = AutoModelForCausalLM.from_pretrained("adi2606/MenstrualQA").to(device)
+model = AutoModelForCausalLM.from_pretrained("adi2606/MenstrualQA")
 tokenizer = AutoTokenizer.from_pretrained("adi2606/MenstrualQA")
 
 # Function to generate a response from the chatbot
 def generate_response(message: str, temperature: float = 0.4, repetition_penalty: float = 1.1) -> str:
-    inputs = tokenizer(message, return_tensors="pt").to(device)
+    inputs = tokenizer(message, return_tensors="pt", padding=True, truncation=True).to(device)
 
     # Generate the response
     output = model.generate(
         inputs['input_ids'],
+        attention_mask=inputs['attention_mask'],
         max_length=512,
         temperature=temperature,
         repetition_penalty=repetition_penalty,
-        do_sample=True
+        do_sample=True,
+        pad_token_id=tokenizer.eos_token_id
     )
 
     # Decode the generated output
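
For reference, a minimal sketch of the post-commit generation flow as a standalone script. A few details here are assumptions rather than part of the commit: the hunk deletes the `device` definition while the tokenizer call still references it, so the sketch re-defines it and also moves the model onto that device so inputs and weights sit on the same hardware; and the decode step, truncated in the hunk above, is taken to be the standard tokenizer.decode with special tokens skipped.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Assumption: re-define device, since the hunk removes this line but
# generate_response still calls .to(device) on the tokenized inputs.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load model and tokenizer; moving the model to `device` is also an
# assumption (the commit leaves it on CPU), added so the input tensors
# and the model weights end up on the same device.
model = AutoModelForCausalLM.from_pretrained("adi2606/MenstrualQA").to(device)
tokenizer = AutoTokenizer.from_pretrained("adi2606/MenstrualQA")

def generate_response(message: str, temperature: float = 0.4,
                      repetition_penalty: float = 1.1) -> str:
    # Tokenize with padding/truncation, as added in this commit
    inputs = tokenizer(message, return_tensors="pt",
                       padding=True, truncation=True).to(device)

    # Pass an explicit attention mask and pad token id (both added in
    # this commit) so generate() does not warn about open-ended input
    output = model.generate(
        inputs["input_ids"],
        attention_mask=inputs["attention_mask"],
        max_length=512,
        temperature=temperature,
        repetition_penalty=repetition_penalty,
        do_sample=True,
        pad_token_id=tokenizer.eos_token_id,
    )

    # Assumed decode step (the line is cut off in the diff above)
    return tokenizer.decode(output[0], skip_special_tokens=True)

Setting pad_token_id=tokenizer.eos_token_id is the usual workaround for GPT-style tokenizers that ship without a dedicated pad token; inside the Streamlit app, the returned string would then be rendered with the usual st.write/st.markdown calls.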