test / app.py
goendalf666's picture
Update app.py
37fb80b
raw
history blame contribute delete
No virus
1.59 kB
import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
# Load the fine-tuned sales model and its tokenizer once, at import time, so
# every call of the Gradio callback below reuses the same objects.
#
# `cuda` holds the target device string; it falls back to "cpu" when no GPU
# is visible (the name is kept because the rest of the module refers to it).
if torch.cuda.is_available():
    cuda = "cuda:0"
else:
    cuda = "cpu"

# NOTE(review): trust_remote_code=True allows this Hub repository to execute
# its own Python code while loading. Consider pinning `revision=` to an
# audited commit so upstream changes cannot alter what runs here.
model = AutoModelForCausalLM.from_pretrained(
    "goendalf666/salesGPT_v2",
    trust_remote_code=True,
)
model = model.to(cuda)

# The tokenizer comes from the base phi-1_5 checkpoint rather than from the
# fine-tuned repo — presumably salesGPT_v2 was fine-tuned from phi-1_5 and
# shares its vocabulary; TODO confirm.
tokenizer = AutoTokenizer.from_pretrained("microsoft/phi-1_5")
def interact_with_model(user_input, max_new_tokens=512):
    """Generate a salesman reply to a single customer message.

    The customer message is embedded in a fixed role-play prompt. The model
    continues that prompt, and only the salesman's first turn of the
    continuation is returned.

    Args:
        user_input: The customer's message (the Gradio text box value).
        max_new_tokens: Upper bound on the number of tokens generated for the
            reply. It is counted separately from the prompt, so a long
            customer message no longer eats into the reply budget.

    Returns:
        The generated reply, cut before any model-written "Customer:" turn and
        with a leading "Salesman:" label removed. May be an empty string.
    """
    # Treat a missing value like an empty message instead of prompting the
    # model with the literal text "None".
    if user_input is None:
        user_input = ""

    # Construct conversation text for the model
    conversation_text = (
        "You are in the role of a Salesman. "
        "Here is a conversation: "
        f"Customer: {user_input} Salesman: "
    )

    inputs = tokenizer(conversation_text, return_tensors="pt").to(cuda)
    prompt_length = inputs["input_ids"].shape[1]

    outputs = model.generate(
        **inputs,
        max_new_tokens=max_new_tokens,
        # Use EOS as the pad id explicitly. This avoids transformers'
        # fallback warning when the tokenizer defines no pad token
        # (presumably the case for phi-1_5 — verify).
        pad_token_id=tokenizer.eos_token_id,
    )

    # Decode only the tokens generated after the prompt. Slicing by token
    # count is exact, whereas slicing the decoded string by len(prompt)
    # breaks whenever decoding does not reproduce the prompt character for
    # character. skip_special_tokens keeps markers such as end-of-text out of
    # the reply.
    new_tokens = outputs[0][prompt_length:]
    new_generated_text = tokenizer.decode(
        new_tokens, skip_special_tokens=True
    ).strip()

    # Drop everything from a model-written "Customer:" turn onward so only the
    # salesman's reply remains.
    new_generated_text = new_generated_text.partition("Customer:")[0].strip()

    # Ignore a "Salesman:" label the model may repeat at the start.
    if new_generated_text.startswith("Salesman:"):
        new_generated_text = new_generated_text[len("Salesman:"):].strip()

    return new_generated_text
# Build the web UI: one text box in (the customer's message) and one text box
# out (the reply produced by interact_with_model).
iface = gr.Interface(fn=interact_with_model, inputs="text", outputs="text")

# Start the server only when this file is run as a script (e.g.
# `python app.py`), so importing the module does not start a web server as a
# side effect.
if __name__ == "__main__":
    iface.launch()