Spaces:
Running
Running
# app.py | |
import streamlit as st | |
import torch | |
from src.model import TransformerModel | |
from src.utils import load_vocab, tokenize | |
import time | |
import random | |
import os | |
# --- Configuration ---------------------------------------------------------
# Checkpoint to load; point this at the most recently trained model file.
MODEL_PATH = "models/3ed0k4_model_epoch9.pth"
# Vocabulary file handed to load_vocab().
VOCAB_PATH = "vocab.json"

# Hyperparameters passed to TransformerModel when the network is built.
EMBED_SIZE = 256
NUM_HEADS = 8
HIDDEN_DIM = 512
NUM_LAYERS = 4
DROPOUT = 0.1

# Upper bound on the number of tokens generated per request.
MAX_LENGTH = 100
# Page header: the app title followed by a short usage note.
st.title("3ed0k4 NLP Text Generation Model π")
st.write(
    "Enter a prompt, and the model will generate text based on your input. "
    "It will take 1 to 10 seconds to respond to simulate 'thinking'."
)
def load_resources():
    """Load the vocabulary stored at ``VOCAB_PATH``.

    Returns:
        The object produced by ``load_vocab``; the rest of this script
        uses it as a token -> id mapping.
    """
    return load_vocab(VOCAB_PATH)


# Loaded at module level because the model is sized from the vocabulary.
vocab = load_resources()
vocab_size = len(vocab)
@st.cache_resource(max_entries=1)
def _load_trained_model(num_tokens, model_path, checkpoint_mtime):
    """Build the network and load its weights, cached across script reruns.

    Streamlit re-executes the whole script on every widget interaction;
    without caching, the checkpoint would be re-read from disk each time.

    Args:
        num_tokens: Vocabulary size used to construct the model.
        model_path: Path of the checkpoint holding the state dict.
        checkpoint_mtime: Modification time of the checkpoint. Not used in
            the body; it is part of the cache key so that an overwritten
            checkpoint is reloaded instead of served stale.

    Returns:
        The model with weights loaded, switched to evaluation mode.
    """
    del checkpoint_mtime  # cache-key only
    model = TransformerModel(
        vocab_size=num_tokens,
        embed_size=EMBED_SIZE,
        num_heads=NUM_HEADS,
        hidden_dim=HIDDEN_DIM,
        num_layers=NUM_LAYERS,
        dropout=DROPOUT,
    )
    # NOTE(security): torch.load unpickles the file, so only load
    # checkpoints from a trusted source. If the installed PyTorch supports
    # it, consider passing weights_only=True.
    state_dict = torch.load(model_path, map_location=torch.device('cpu'))
    model.load_state_dict(state_dict)
    model.eval()
    return model


def load_model():
    """Return the trained model, or ``None`` if the checkpoint is missing.

    When ``MODEL_PATH`` does not exist, an error is shown in the page and
    ``None`` is returned. The existence check runs before the network is
    constructed so no work is wasted when the weights are absent.
    """
    if not os.path.exists(MODEL_PATH):
        st.error(f"Model file not found at {MODEL_PATH}. Please ensure the model is trained and the path is correct.")
        return None
    return _load_trained_model(vocab_size, MODEL_PATH, os.path.getmtime(MODEL_PATH))


model = load_model()
def generate_text(prompt, max_length=MAX_LENGTH):
    """Greedily extend ``prompt`` with up to ``max_length`` predicted tokens.

    The prompt is tokenized and mapped to ids through the module-level
    ``vocab`` (unknown tokens map to the ``'<UNK>'`` id). At each step the
    module-level ``model`` is run on the whole sequence so far and the
    arg-max token at the last position is appended. Generation stops early
    when the predicted id equals the ``'<PAD>'`` id (0 if the vocabulary
    has no such entry).

    Args:
        prompt: Text to continue.
        max_length: Maximum number of tokens to append.

    Returns:
        The prompt tokens followed by the generated tokens, joined with
        single spaces. Returns ``''`` when the prompt yields no tokens.

    Raises:
        KeyError: If the prompt has tokens and ``vocab`` has no ``'<UNK>'``.
    """
    tokens = list(tokenize(prompt))
    if not tokens:
        # A zero-length sequence cannot be fed to the model or indexed
        # at position -1, so there is nothing to continue.
        return ''

    # Loop-invariant lookups, resolved once.
    unk_id = vocab['<UNK>']
    pad_id = vocab.get('<PAD>', 0)

    generated = [vocab.get(token, unk_id) for token in tokens]
    with torch.no_grad():
        for _ in range(max_length):
            # Batch of one holding everything produced so far.
            input_seq = torch.tensor(generated, dtype=torch.long).unsqueeze(0)
            src_mask = model.generate_square_subsequent_mask(input_seq.size(1)).to(input_seq.device)
            outputs = model(input_seq, src_mask)
            next_token = torch.argmax(outputs[0, -1, :]).item()
            if next_token == pad_id:
                break
            generated.append(next_token)

    # Convert ids back to words; ids missing from the vocabulary render
    # as the literal '<UNK>'.
    inv_vocab = {idx: word for word, idx in vocab.items()}
    return ' '.join(inv_vocab.get(tok, '<UNK>') for tok in generated)
# --- User inputs -----------------------------------------------------------
prompt = st.text_input("Enter your prompt:", "")
delay = st.slider("Select thinking delay (seconds):", min_value=1, max_value=10, value=3)

if st.button("Generate"):
    # load_model() returns None when the checkpoint is missing; test for
    # that explicitly instead of relying on the model object's truthiness.
    if model is None:
        st.error("Model is not loaded. Please check the model path.")
    elif not prompt.strip():
        st.warning("Please enter a prompt to generate text.")
    else:
        with st.spinner("Thinking..."):
            # Deliberate pause chosen by the user to simulate "thinking".
            time.sleep(delay)
            response = generate_text(prompt)
        st.success("Here's the generated text:")
        st.write(response)