# Hugging Face Spaces page header captured with the file (status: "Sleeping");
# not part of the program.
# src/app.py | |
from flask import Flask, request, render_template | |
import torch | |
from model import TransformerModel | |
from utils import load_vocab, tokenize | |
import time | |
import random | |
import os | |
# Flask application object; templates are looked up in the 'templates' folder.
app = Flask(__name__, template_folder='templates')

# ---------------------------------------------------------------------------
# Settings
# ---------------------------------------------------------------------------
# Checkpoint to serve -- point this at the most recent training output.
MODEL_PATH = 'models/3ed0k4_model_epoch10.pth'
VOCAB_PATH = 'vocab.json'

# Constructor arguments handed to TransformerModel below.
# NOTE(review): presumably these have to match the values used when the
# checkpoint was trained, otherwise load_state_dict will reject it -- confirm.
EMBED_SIZE = 256
NUM_HEADS = 8
HIDDEN_DIM = 512
NUM_LAYERS = 4
DROPOUT = 0.1

# Upper bound on the number of new tokens produced per request.
MAX_LENGTH = 100

# ---------------------------------------------------------------------------
# Vocabulary and model (built once, at import time)
# ---------------------------------------------------------------------------
vocab = load_vocab(VOCAB_PATH)
vocab_size = len(vocab)

model = TransformerModel(
    vocab_size=vocab_size,
    embed_size=EMBED_SIZE,
    num_heads=NUM_HEADS,
    hidden_dim=HIDDEN_DIM,
    num_layers=NUM_LAYERS,
    dropout=DROPOUT,
)

# Fail early with an actionable message when no checkpoint has been produced.
if not os.path.exists(MODEL_PATH):
    raise FileNotFoundError(
        f"Model file not found at {MODEL_PATH}. Please train the model first."
    )

# Weights are mapped onto the CPU regardless of the device they were saved from.
# NOTE(review): torch.load unpickles the file; only load checkpoints from a
# trusted source.
model.load_state_dict(
    torch.load(MODEL_PATH, map_location=torch.device('cpu'))
)
# Switch to inference mode.
model.eval()
def generate_text(prompt, max_length=MAX_LENGTH):
    """Greedily extend *prompt* with up to *max_length* model-predicted tokens.

    The prompt is split with ``tokenize`` and mapped to ids through the
    module-level ``vocab``; tokens missing from the vocabulary map to the
    ``<UNK>`` id. At every step the whole sequence so far is fed to ``model``
    and the highest-scoring token at the last position is appended.
    Generation stops early as soon as the model predicts ``<PAD>``.

    Args:
        prompt: Text to continue.
        max_length: Maximum number of new tokens to generate.

    Returns:
        The prompt tokens followed by the generated tokens, joined by single
        spaces. An empty string is returned when the prompt yields no tokens.
    """
    # Look the special ids up once instead of on every token / every step.
    unk_id = vocab['<UNK>']
    pad_id = vocab['<PAD>']

    generated = [vocab.get(token, unk_id) for token in tokenize(prompt)]
    if not generated:
        # With no tokens the input would have a zero-length sequence axis and
        # there is no "last position" to read logits from, so bail out rather
        # than crash inside the loop.
        return ''

    with torch.no_grad():
        for _ in range(max_length):
            # Rebuild the (1, seq_len) batch from everything generated so far.
            input_seq = torch.tensor(generated, dtype=torch.long).unsqueeze(0)
            src_mask = model.generate_square_subsequent_mask(
                input_seq.size(1)
            ).to(input_seq.device)
            outputs = model(input_seq, src_mask)
            # Greedy decoding: pick the arg-max over the last position's scores.
            next_token = torch.argmax(outputs[0, -1, :]).item()
            if next_token == pad_id:
                break
            generated.append(next_token)

    # Convert numerical tokens back to words.
    inv_vocab = {idx: word for word, idx in vocab.items()}
    return ' '.join(inv_vocab.get(tok, '<UNK>') for tok in generated)
# Register the view with the app: without a route Flask never dispatches to
# this function and every request ends in a 404.
@app.route('/')
def index():
    """Render the chat page (``index.html``) with no message or response."""
    return render_template('index.html')
# Register the view with the app: without a route Flask never dispatches to
# this function. POST is required because the message is read from form data;
# GET is accepted too and simply renders the empty page.
# NOTE(review): the '/chat' URL is assumed from the function name -- confirm it
# matches the form action in templates/index.html.
@app.route('/chat', methods=['GET', 'POST'])
def chat():
    """Handle a chat submission and render the page with the model's reply.

    Reads the ``message`` form field. When it is missing or empty the bare
    page is rendered. Otherwise the handler sleeps for a random 1-10 seconds
    (a deliberate "thinking" delay), generates a reply with ``generate_text``
    and renders ``index.html`` with both ``message`` and ``response``.
    """
    message = request.form.get('message')
    if not message:
        return render_template('index.html')

    # Simulate thinking delay
    delay = random.randint(1, 10)
    print(f"Thinking for {delay} seconds...")
    time.sleep(delay)

    response = generate_text(message)
    return render_template('index.html', message=message, response=response)
# Script entry point: when this file is executed directly, start Flask's
# built-in server listening on all interfaces at port 5000.
if __name__ == '__main__':
    app.run(port=5000, host='0.0.0.0')