import streamlit as st
import torch
from transformers import AutoTokenizer
from peft import AutoPeftModelForCausalLM

# Load the fine-tuned PEFT adapter and its tokenizer from the Hugging Face Hub.
# 4-bit loading requires the bitsandbytes package and a CUDA GPU.
model_name = "amiguel/itemClassification_Alpaca_Mistral"
model = AutoPeftModelForCausalLM.from_pretrained(model_name, load_in_4bit=True)
tokenizer = AutoTokenizer.from_pretrained(model_name)
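
# Optional refactor (a sketch, not wired into the load above): Streamlit
# reruns this whole script on every widget interaction, so wrapping the
# heavy load in st.cache_resource keeps the model in memory across reruns.
# The helper name below is illustrative.
@st.cache_resource
def load_model_and_tokenizer(name=model_name):
    cached_model = AutoPeftModelForCausalLM.from_pretrained(name, load_in_4bit=True)
    cached_tokenizer = AutoTokenizer.from_pretrained(name)
    return cached_model, cached_tokenizer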

# Use the GPU when one is available. A model loaded in 4-bit is already
# dispatched to the GPU by bitsandbytes, and calling .to() on a quantized
# model raises an error, so only move it when it is not quantized.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if not getattr(model, "is_loaded_in_4bit", False):
    model.to(device)

st.title("Mistral Text Generation App")

# Prompt plus two generation controls.
prompt = st.text_input("Enter the prompt:")
max_length = st.number_input("Enter the maximum length of the generated text:", min_value=1, value=100)
num_beams = st.number_input("Enter the number of beams:", min_value=1, value=4)

# Placeholder that infer() fills with the generated text.
output_field = st.empty()

def infer():
    # Tokenize the prompt, truncating it to fit within max_length tokens.
    inputs = tokenizer([prompt], return_tensors="pt", max_length=max_length, truncation=True)
    # Move the input tensors to the same device as the model.
    inputs = {k: v.to(device) for k, v in inputs.items()}
    # Beam-search generation; max_length caps prompt and generated tokens together.
    outputs = model.generate(**inputs, max_length=max_length, num_beams=num_beams)
    output_str = tokenizer.decode(outputs[0], skip_special_tokens=True)
    output_field.text(output_str)
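
# Note (a possible variant): because max_length above counts the prompt
# tokens as well, a prompt near the limit leaves little room for new text.
# transformers' max_new_tokens caps only the generated tokens instead, e.g.:
#
#     outputs = model.generate(**inputs, max_new_tokens=max_length, num_beams=num_beams)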

infer_button = st.button("Generate Text")
if infer_button:
    infer()
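
# To launch the app locally (assuming this file is saved as app.py):
#   streamlit run app.py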