# model_3ed0k4 / src / app.py
# 3ed0k4's picture
# Upload 12 files
# 65224b2 verified
# src/app.py
from flask import Flask, request, render_template
import torch
from model import TransformerModel
from utils import load_vocab, tokenize
import time
import random
import os
# Flask application; HTML templates are looked up in the 'templates' folder.
app = Flask(__name__, template_folder='templates')

# --- Configuration -----------------------------------------------------------
# File locations.
MODEL_PATH = 'models/3ed0k4_model_epoch10.pth'  # Update this path based on the latest model
VOCAB_PATH = 'vocab.json'

# Hyperparameters handed to TransformerModel below.
# NOTE(review): presumably these must match the values used when the checkpoint
# at MODEL_PATH was trained, otherwise load_state_dict will reject it — confirm.
EMBED_SIZE = 256
NUM_HEADS = 8
HIDDEN_DIM = 512
NUM_LAYERS = 4
DROPOUT = 0.1

# Default cap on the number of tokens generate_text appends to a prompt.
MAX_LENGTH = 100

# --- Vocabulary --------------------------------------------------------------
vocab = load_vocab(VOCAB_PATH)
vocab_size = len(vocab)

# --- Model -------------------------------------------------------------------
model = TransformerModel(
    vocab_size=vocab_size,
    embed_size=EMBED_SIZE,
    num_heads=NUM_HEADS,
    hidden_dim=HIDDEN_DIM,
    num_layers=NUM_LAYERS,
    dropout=DROPOUT,
)

# Check for the checkpoint up front so a missing file produces an actionable
# message instead of whatever torch.load would raise.
if not os.path.exists(MODEL_PATH):
    raise FileNotFoundError(f"Model file not found at {MODEL_PATH}. Please train the model first.")

# Weights are mapped onto the CPU regardless of the device they were saved from.
model.load_state_dict(torch.load(MODEL_PATH, map_location='cpu'))

# Inference only: switch the module to evaluation mode.
model.eval()
def generate_text(prompt, max_length=MAX_LENGTH):
    """Greedily extend ``prompt`` with tokens predicted by the loaded model.

    The prompt is split with ``tokenize`` and mapped to ids through ``vocab``
    (tokens missing from the vocabulary map to the ``<UNK>`` id). On every
    step the whole sequence produced so far is fed to the model together with
    a square subsequent mask, and the arg-max token at the last position is
    appended. Generation stops after ``max_length`` new tokens, or as soon as
    the model predicts ``<PAD>``.

    Args:
        prompt: Text to continue.
        max_length: Maximum number of new tokens to append to the prompt.

    Returns:
        The prompt tokens followed by the generated tokens, joined by single
        spaces. Ids with no vocabulary entry are rendered as ``<UNK>``.
        Returns an empty string when the prompt yields no tokens.
    """
    tokens = tokenize(prompt)
    if not tokens:
        # A zero-length sequence gives the model nothing to condition on and
        # has no "last position" to read logits from, so bail out early
        # instead of failing inside the model / on the [-1] index below.
        return ''

    # Loop-invariant vocabulary lookups.
    unk_id = vocab['<UNK>']
    pad_id = vocab['<PAD>']

    generated = [vocab.get(token, unk_id) for token in tokens]

    with torch.no_grad():
        for _ in range(max_length):
            # Batch of one containing everything generated so far.
            input_seq = torch.tensor(generated, dtype=torch.long).unsqueeze(0)
            src_mask = model.generate_square_subsequent_mask(input_seq.size(1)).to(input_seq.device)
            outputs = model(input_seq, src_mask)

            # The indexing takes the scores at the last position of the first
            # batch item (assumes a batch-first (batch, seq, vocab) output —
            # TODO confirm against TransformerModel.forward).
            next_token = torch.argmax(outputs[0, -1, :]).item()

            # <PAD> is the only stop signal this loop recognizes.
            if next_token == pad_id:
                break
            generated.append(next_token)

    # Convert numerical tokens back to words.
    inv_vocab = {idx: word for word, idx in vocab.items()}
    return ' '.join(inv_vocab.get(tok, '<UNK>') for tok in generated)
@app.route('/', methods=['GET'])
def index():
    """Serve the chat page with no message or response passed to the template."""
    page = render_template('index.html')
    return page
@app.route('/chat', methods=['POST'])
def chat():
    """Handle a submitted chat message and render the model's reply.

    Reads the ``message`` form field. When it is missing or empty the plain
    chat page is rendered. Otherwise the handler sleeps for a random 1-10
    seconds to simulate "thinking", generates a continuation of the message
    with ``generate_text``, and renders the page with both the message and
    the response.
    """
    user_message = request.form.get('message')

    if user_message:
        # Simulate thinking delay (blocks this request for the whole duration).
        delay = random.randint(1, 10)
        print(f"Thinking for {delay} seconds...")
        time.sleep(delay)

        return render_template(
            'index.html',
            message=user_message,
            response=generate_text(user_message),
        )

    # Nothing submitted: show the page without a conversation.
    return render_template('index.html')
if __name__ == '__main__':
    # Only when executed as a script: serve on port 5000, bound to all
    # network interfaces (0.0.0.0).
    app.run(port=5000, host='0.0.0.0')