Pretrain-GPT2 / app.py
import streamlit as st
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
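# Load the tokenizer and model weights from the Hugging Face Hub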
tokenizer = AutoTokenizer.from_pretrained("ahmadmac/Pretrained-GPT2")
model = AutoModelForCausalLM.from_pretrained("ahmadmac/Pretrained-GPT2")
def generate_text(prompt, max_length=50, num_return_sequences=1, temperature=0.7):
    # Tokenize the prompt and generate a continuation with sampling enabled,
    # so the temperature argument actually takes effect.
    input_ids = tokenizer(prompt, return_tensors="pt").input_ids
    output = model.generate(
        input_ids,
        max_length=max_length,
        num_return_sequences=num_return_sequences,
        do_sample=True,
        temperature=temperature,
        pad_token_id=tokenizer.eos_token_id,  # GPT-2 has no pad token; reuse EOS
    )
    return tokenizer.decode(output[0], skip_special_tokens=True)
def main():
    st.title("Text Generator")
    prompt = st.text_input("Enter your prompt:")
    if st.button("Generate"):
        generated_text = generate_text(prompt)
        st.text_area("Generated Text:", generated_text)

if __name__ == "__main__":
    main()