# src/app.py
"""Flask front end that serves greedy text generation from a Transformer model.

On import, the module loads the vocabulary and the model weights from disk.
It exposes two routes: ``/`` renders the chat page and ``/chat`` generates a
reply to the submitted message.
"""

import os
import random
import time

import torch
from flask import Flask, render_template, request

from model import TransformerModel
from utils import load_vocab, tokenize

app = Flask(__name__, template_folder='templates')

# Configuration
MODEL_PATH = 'models/3ed0k4_model_epoch10.pth'  # Update this path based on the latest model
VOCAB_PATH = 'vocab.json'
EMBED_SIZE = 256
NUM_HEADS = 8
HIDDEN_DIM = 512
NUM_LAYERS = 4
DROPOUT = 0.1
MAX_LENGTH = 100  # Maximum tokens to generate

# Special-token keys looked up in the vocabulary.
# NOTE(review): both keys are the empty string, exactly as in the original
# source. They look like angle-bracket markers (e.g. an unknown-word and an
# end-of-sequence token) that were stripped somewhere upstream -- confirm
# against vocab.json and update these two constants if so.
UNK_TOKEN = ''
EOS_TOKEN = ''

# Load vocabulary
vocab = load_vocab(VOCAB_PATH)
vocab_size = len(vocab)

# Reverse mapping (index -> word). Built once here instead of on every request.
inv_vocab = {idx: word for word, idx in vocab.items()}

# Initialize model
model = TransformerModel(
    vocab_size=vocab_size,
    embed_size=EMBED_SIZE,
    num_heads=NUM_HEADS,
    hidden_dim=HIDDEN_DIM,
    num_layers=NUM_LAYERS,
    dropout=DROPOUT,
)

# Load model weights
if not os.path.exists(MODEL_PATH):
    raise FileNotFoundError(f"Model file not found at {MODEL_PATH}. Please train the model first.")

# SECURITY NOTE: torch.load unpickles the checkpoint, so only load files from
# a trusted source. Consider passing weights_only=True on PyTorch versions
# that support it.
model.load_state_dict(torch.load(MODEL_PATH, map_location=torch.device('cpu')))
model.eval()


def generate_text(prompt: str, max_length: int = MAX_LENGTH) -> str:
    """Greedily extend *prompt* and return prompt plus continuation as text.

    The prompt is tokenized and mapped to vocabulary indices, with unknown
    tokens mapped to the index of ``UNK_TOKEN``. The model is then run
    repeatedly; at each step the highest-scoring token at the last position
    is appended, until the ``EOS_TOKEN`` index is predicted or ``max_length``
    tokens have been generated.

    Args:
        prompt: Raw user text.
        max_length: Upper bound on the number of generated tokens.

    Returns:
        The prompt tokens followed by the generated tokens, joined by single
        spaces. An empty string is returned when the prompt yields no tokens.

    Raises:
        KeyError: If ``UNK_TOKEN`` / ``EOS_TOKEN`` is missing from the
            vocabulary.
    """
    tokens = list(tokenize(prompt))
    if not tokens:
        # Nothing to condition on: the loop below reads the logits of the
        # last position, which does not exist for an empty sequence.
        return ''

    # Look the special ids up once rather than per token / per step.
    unk_id = vocab[UNK_TOKEN]
    eos_id = vocab[EOS_TOKEN]

    generated = [vocab.get(token, unk_id) for token in tokens]
    input_seq = torch.tensor(generated, dtype=torch.long).unsqueeze(0)  # Batch size 1

    with torch.no_grad():
        for _ in range(max_length):
            seq_len = input_seq.size(1)
            # Mask sized to the current sequence length (presumably a causal
            # mask, judging by the helper's name -- see model.py).
            src_mask = model.generate_square_subsequent_mask(seq_len).to(input_seq.device)
            outputs = model(input_seq, src_mask)

            # Greedy decoding: pick the arg-max over the last position's logits.
            next_token = torch.argmax(outputs[0, -1, :]).item()
            if next_token == eos_id:
                break

            generated.append(next_token)
            input_seq = torch.tensor(generated, dtype=torch.long).unsqueeze(0)

    # Convert numerical tokens back to words
    return ' '.join(inv_vocab.get(tok, UNK_TOKEN) for tok in generated)


@app.route('/', methods=['GET'])
def index() -> str:
    """Render the chat page."""
    return render_template('index.html')


@app.route('/chat', methods=['POST'])
def chat() -> str:
    """Generate a reply to the posted ``message`` form field and render it.

    If the field is missing or empty, the page is re-rendered with no
    message or response.
    """
    message = request.form.get('message')
    if not message:
        return render_template('index.html')

    # Simulate thinking delay
    delay = random.randint(1, 10)
    print(f"Thinking for {delay} seconds...")
    time.sleep(delay)

    response = generate_text(message)
    return render_template('index.html', message=message, response=response)


if __name__ == '__main__':
    app.run(host='0.0.0.0', port=5000)