File size: 2,691 Bytes
65224b2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
# src/app.py
from flask import Flask, request, render_template
import torch
from model import TransformerModel
from utils import load_vocab, tokenize
import time
import random
import os

# Flask application object used by the route decorators below. Templates
# (e.g. index.html) are looked up in the 'templates' folder, which Flask
# resolves relative to the application root (this module's directory).
app = Flask(__name__, template_folder='templates')

# Configuration
# Artifact paths can be overridden through the environment, so a newer
# checkpoint or vocabulary can be served without editing this file. The
# defaults are the previously hard-coded values.
MODEL_PATH = os.environ.get('MODEL_PATH', 'models/3ed0k4_model_epoch10.pth')
VOCAB_PATH = os.environ.get('VOCAB_PATH', 'vocab.json')

# Architecture hyperparameters passed to TransformerModel when it is built.
# NOTE(review): these presumably must match the values the checkpoint at
# MODEL_PATH was trained with, or load_state_dict will reject the
# mismatched tensor shapes. Confirm against the training script.
EMBED_SIZE = 256
NUM_HEADS = 8
HIDDEN_DIM = 512
NUM_LAYERS = 4
DROPOUT = 0.1  # presumably inert when serving, since the model is put in eval mode
MAX_LENGTH = 100  # Maximum number of new tokens generate_text appends to a prompt

# Load vocabulary
# NOTE(review): vocab is indexed by token strings elsewhere in this module
# (e.g. vocab['<UNK>'], vocab['<PAD>']) and inverted via vocab.items(), so
# load_vocab presumably returns a token -> integer-id dict that contains
# those special tokens. Confirm in utils.load_vocab.
vocab = load_vocab(VOCAB_PATH)
# Passed to TransformerModel as vocab_size.
vocab_size = len(vocab)

# Initialize model
# Built from the configuration constants. They presumably must agree with
# the tensor shapes stored in the checkpoint for load_state_dict to succeed.
model = TransformerModel(
    vocab_size=vocab_size,
    embed_size=EMBED_SIZE,
    num_heads=NUM_HEADS,
    hidden_dim=HIDDEN_DIM,
    num_layers=NUM_LAYERS,
    dropout=DROPOUT
)

# Load model weights
# Fail fast at import time with an actionable message when no checkpoint
# exists yet, instead of letting torch.load raise a bare file error.
if not os.path.exists(MODEL_PATH):
    raise FileNotFoundError(f"Model file not found at {MODEL_PATH}. Please train the model first.")
# map_location remaps every stored tensor onto the CPU, so a checkpoint saved
# from a GPU run can be served on a CPU-only host.
# NOTE(review): torch.load unpickles the file. Only load checkpoints you
# trust; newer PyTorch releases accept weights_only=True for plain state dicts.
model.load_state_dict(torch.load(MODEL_PATH, map_location=torch.device('cpu')))
# Put the module in evaluation mode (affects layers such as Dropout) before serving.
model.eval()

def generate_text(prompt, max_length=MAX_LENGTH):
    """Greedily extend *prompt* with up to *max_length* model-predicted tokens.

    The prompt is split with ``tokenize`` and mapped to ids through the
    module-level ``vocab``; tokens missing from the vocabulary map to the
    ``'<UNK>'`` id. At each step the model is run on the whole sequence so
    far and the highest-scoring next token (argmax, no sampling) is appended.
    Generation stops early when the predicted token is ``'<PAD>'``.

    Args:
        prompt: Raw user text.
        max_length: Maximum number of new tokens to generate.

    Returns:
        The prompt tokens followed by the generated tokens, joined with single
        spaces (ids with no vocabulary entry are rendered as ``'<UNK>'``).
        Returns ``''`` when the prompt produces no tokens.

    Raises:
        KeyError: if ``vocab`` has no ``'<UNK>'`` or ``'<PAD>'`` entry.
    """
    # Resolve the special ids once instead of once per token / per step.
    unk_id = vocab['<UNK>']
    generated = [vocab.get(token, unk_id) for token in tokenize(prompt)]
    if not generated:
        # E.g. a whitespace-only prompt, depending on utils.tokenize. A
        # zero-length sequence has no last position to predict from, so
        # there is nothing to generate.
        return ''

    pad_id = vocab['<PAD>']
    input_seq = torch.tensor(generated, dtype=torch.long).unsqueeze(0)  # Batch size 1

    with torch.no_grad():
        for _ in range(max_length):
            # NOTE(review): the whole sequence is re-encoded every step and
            # grows without truncation. If TransformerModel has a fixed
            # maximum positional length, long prompts may exceed it. Confirm
            # in model.py.
            src_mask = model.generate_square_subsequent_mask(input_seq.size(1)).to(input_seq.device)
            outputs = model(input_seq, src_mask)
            # Greedy decoding from the scores at the last position.
            # Presumably outputs is (batch, seq_len, vocab); confirm in model.py.
            next_token = torch.argmax(outputs[0, -1, :]).item()

            if next_token == pad_id:
                break

            generated.append(next_token)
            # Extend the existing tensor on its own device rather than
            # rebuilding the full sequence on the CPU each step.
            step = torch.tensor([[next_token]], dtype=torch.long, device=input_seq.device)
            input_seq = torch.cat((input_seq, step), dim=1)

    # Convert numerical tokens back to words.
    inv_vocab = {idx: word for word, idx in vocab.items()}
    return ' '.join(inv_vocab.get(tok, '<UNK>') for tok in generated)

@app.route('/', methods=['GET'])
def index():
    """Serve the chat page with no message or response filled in."""
    page = 'index.html'
    return render_template(page)

@app.route('/chat', methods=['POST'])
def chat():
    """Handle a submitted chat message and render the model's reply.

    Reads ``message`` from the POSTed form. If it is missing, empty or
    whitespace-only, the bare page is re-rendered. Otherwise, after an
    artificial random delay of 1-10 seconds, the page is rendered with the
    original message and the text returned by ``generate_text``.
    """
    message = request.form.get('message')
    # A whitespace-only message passes a plain truthiness check but may
    # tokenize to nothing, which would reach the model as an empty sequence
    # (and only after the sleep below).
    if not message or not message.strip():
        return render_template('index.html')

    # Simulate thinking delay
    # NOTE(review): time.sleep blocks the worker serving this request for up
    # to 10 s, so a handful of concurrent users can tie up every worker.
    # Consider dropping this or moving the "typing" effect to the client.
    delay = random.randint(1, 10)
    print(f"Thinking for {delay} seconds...")
    time.sleep(delay)

    response = generate_text(message)
    return render_template('index.html', message=message, response=response)

if __name__ == '__main__':
    # Host and port default to the previous hard-coded values but can be set
    # through the same variables the `flask run` CLI reads, so one
    # environment configures both ways of launching the app.
    # NOTE(review): 0.0.0.0 listens on every network interface, exposing the
    # development server; set FLASK_RUN_HOST=127.0.0.1 for local-only use.
    host = os.environ.get('FLASK_RUN_HOST', '0.0.0.0')
    port = int(os.environ.get('FLASK_RUN_PORT', '5000'))
    app.run(host=host, port=port)