adi2606 commited on
Commit
22209ca
·
verified ·
1 Parent(s): db93709

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +13 -4
app.py CHANGED
@@ -2,13 +2,22 @@ import streamlit as st
2
  import torch
3
  from transformers import AutoModelForCausalLM, AutoTokenizer
4
 
5
- # Load model and tokenizer
6
- model = AutoModelForCausalLM.from_pretrained("adi2606/MenstrualQA")
 
 
 
7
  tokenizer = AutoTokenizer.from_pretrained("adi2606/MenstrualQA")
8
 
9
  # Function to generate a response from the chatbot
10
- def generate_response(message: str, temperature: float = 0.4, repetition_penalty: float = 1.1) -> str:
11
- inputs = tokenizer(message, return_tensors="pt", padding=True, truncation=True)
 
 
 
 
 
 
12
 
13
  # Generate the response
14
  output = model.generate(
 
2
  import torch
3
  from transformers import AutoModelForCausalLM, AutoTokenizer
4
 
5
+ # Set up the device to use CPU only
6
+ device = torch.device("cpu")
7
+
8
+ # Load model and tokenizer, then move the model to the appropriate device
9
+ model = AutoModelForCausalLM.from_pretrained("adi2606/MenstrualQA").to(device)
10
  tokenizer = AutoTokenizer.from_pretrained("adi2606/MenstrualQA")
11
 
12
  # Function to generate a response from the chatbot
13
+ def generate_response(message: str, temperature: float = 0.4, repetition_penalty: float = 1.1, max_input_length: int = 256) -> str:
14
+ inputs = tokenizer(
15
+ message,
16
+ return_tensors="pt",
17
+ padding=True,
18
+ truncation=True,
19
+ max_length=max_input_length
20
+ ).to(device)
21
 
22
  # Generate the response
23
  output = model.generate(