khaled5321 commited on
Commit
f45a763
1 Parent(s): 155c20a

Add application file

Browse files
Files changed (6) hide show
  1. .gitignore +1 -0
  2. app.py +170 -0
  3. arabic_vocab.pth +3 -0
  4. english_vocab.pth +3 -0
  5. model.pt +3 -0
  6. requirements.txt +0 -0
.gitignore ADDED
@@ -0,0 +1 @@
 
 
1
+ /venv
app.py ADDED
@@ -0,0 +1,170 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import random
3
+ import re
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
+ import unicodedata
8
+ import nltk
9
+ from nltk.tokenize.treebank import TreebankWordDetokenizer
10
+
11
+ nltk.download('punkt')
12
+
13
+
14
+ class Encoder(nn.Module):
15
+ def __init__(self, input_dim, emb_dim, enc_hid_dim, dec_hid_dim, dropout):
16
+ super().__init__()
17
+ self.embedding = nn.Embedding(input_dim, emb_dim)
18
+ self.rnn = nn.GRU(emb_dim, enc_hid_dim, bidirectional = True)
19
+ self.fc = nn.Linear(enc_hid_dim * 2, dec_hid_dim)
20
+ self.dropout = nn.Dropout(dropout)
21
+
22
+ def forward(self, src):
23
+ embedded = self.dropout(self.embedding(src))
24
+ outputs, hidden = self.rnn(embedded)
25
+ hidden = torch.tanh(self.fc(torch.cat((hidden[-2,:,:], hidden[-1,:,:]), dim = 1)))
26
+
27
+ return outputs, hidden
28
+
29
+
30
+ class Attention(nn.Module):
31
+ def __init__(self, enc_hid_dim, dec_hid_dim):
32
+ super().__init__()
33
+ self.attn = nn.Linear((enc_hid_dim * 2) + dec_hid_dim, dec_hid_dim)
34
+ self.v = nn.Linear(dec_hid_dim, 1, bias = False)
35
+
36
+ def forward(self, hidden, encoder_outputs):
37
+ batch_size = encoder_outputs.shape[1]
38
+ src_len = encoder_outputs.shape[0]
39
+
40
+ hidden = hidden.unsqueeze(1).repeat(1, src_len, 1)
41
+ encoder_outputs = encoder_outputs.permute(1, 0, 2)
42
+
43
+ energy = torch.tanh(self.attn(torch.cat((hidden, encoder_outputs), dim = 2)))
44
+ attention = self.v(energy).squeeze(2)
45
+
46
+ return F.softmax(attention, dim=1)
47
+
48
+
49
+ class Decoder(nn.Module):
50
+ def __init__(self, output_dim, emb_dim, enc_hid_dim, dec_hid_dim, dropout, attention):
51
+ super().__init__()
52
+ self.output_dim = output_dim
53
+ self.attention = attention
54
+ self.embedding = nn.Embedding(output_dim, emb_dim)
55
+ self.rnn = nn.GRU((enc_hid_dim * 2) + emb_dim, dec_hid_dim)
56
+ self.fc_out = nn.Linear((enc_hid_dim * 2) + dec_hid_dim + emb_dim, output_dim)
57
+ self.dropout = nn.Dropout(dropout)
58
+
59
+ def forward(self, input, hidden, encoder_outputs):
60
+ input = input.unsqueeze(0)
61
+ embedded = self.dropout(self.embedding(input))
62
+ a = self.attention(hidden, encoder_outputs)
63
+ a = a.unsqueeze(1)
64
+ encoder_outputs = encoder_outputs.permute(1, 0, 2)
65
+ weighted = torch.bmm(a, encoder_outputs)
66
+ weighted = weighted.permute(1, 0, 2)
67
+ rnn_input = torch.cat((embedded, weighted), dim = 2)
68
+ output, hidden = self.rnn(rnn_input, hidden.unsqueeze(0))
69
+
70
+ assert (output == hidden).all()
71
+
72
+ embedded = embedded.squeeze(0)
73
+ output = output.squeeze(0)
74
+ weighted = weighted.squeeze(0)
75
+
76
+ prediction = self.fc_out(torch.cat((output, weighted, embedded), dim = 1))
77
+
78
+ return prediction, hidden.squeeze(0)
79
+
80
+ class Seq2Seq(nn.Module):
81
+ def __init__(self, encoder, decoder, device):
82
+ super().__init__()
83
+
84
+ self.encoder = encoder
85
+ self.decoder = decoder
86
+ self.device = device
87
+
88
+ def forward(self, src, trg, teacher_forcing_ratio = 0.5):
89
+ batch_size = trg.shape[1]
90
+ trg_len = trg.shape[0]
91
+ trg_vocab_size = self.decoder.output_dim
92
+ outputs = torch.zeros(trg_len, batch_size, trg_vocab_size).to(self.device)
93
+
94
+ encoder_outputs, hidden = self.encoder(src)
95
+
96
+ input = trg[0,:]
97
+
98
+ for t in range(1, trg_len):
99
+ output, hidden = self.decoder(input, hidden, encoder_outputs)
100
+
101
+ outputs[t] = output
102
+
103
+ teacher_force = random.random() < teacher_forcing_ratio
104
+ top1 = output.argmax(1)
105
+ input = trg[t] if teacher_force else top1
106
+
107
+ return outputs
108
+
109
+
110
+ def unicodeToAscii(s):
111
+ return ''.join(
112
+ c for c in unicodedata.normalize('NFD', s)
113
+ if unicodedata.category(c) != 'Mn'
114
+ )
115
+
116
+ def tokenize_ar(text):
117
+ """
118
+ Tokenizes Arabic text from a string into a list of strings (tokens) and reverses it
119
+ """
120
+ return [tok for tok in nltk.tokenize.wordpunct_tokenize(unicodeToAscii(text))]
121
+
122
+ src_vocab = torch.load("arabic_vocab.pth")
123
+ trg_vocab = torch.load("english_vocab.pth")
124
+
125
+ INPUT_DIM = 9790
126
+ OUTPUT_DIM = 5682
127
+ ENC_EMB_DIM = 256
128
+ DEC_EMB_DIM = 256
129
+ ENC_HID_DIM = 512
130
+ DEC_HID_DIM = 512
131
+ ENC_DROPOUT = 0.5
132
+ DEC_DROPOUT = 0.5
133
+
134
+ attn = Attention(ENC_HID_DIM, DEC_HID_DIM)
135
+ enc = Encoder(INPUT_DIM, ENC_EMB_DIM, ENC_HID_DIM, DEC_HID_DIM, ENC_DROPOUT)
136
+ dec = Decoder(OUTPUT_DIM, DEC_EMB_DIM, ENC_HID_DIM, DEC_HID_DIM, DEC_DROPOUT, attn)
137
+
138
+ model = Seq2Seq(enc, dec, "cpu")
139
+ model.load_state_dict(torch.load('model.pt', map_location=torch.device('cpu')))
140
+
141
+
142
+ def infer(text, max_length=50):
143
+ text = tokenize_ar(text)
144
+ sequence = []
145
+ sequence.append(src_vocab['<sos>'])
146
+ sequence.extend([src_vocab[token] for token in text])
147
+ sequence.append(src_vocab['<eos>'])
148
+
149
+ sequence = torch.Tensor(sequence)
150
+ sequence = sequence[:, None].to(torch.int64)
151
+ target = torch.zeros(max_length, 1).to(torch.int64)
152
+
153
+ with torch.no_grad():
154
+ model.eval()
155
+ output = model(sequence, target, 0)
156
+ output_dim = output.shape[-1]
157
+ output = output[1:].view(-1, output_dim)
158
+
159
+ prediction = []
160
+ for i in output:
161
+ prediction.append(torch.argmax(i).item())
162
+
163
+ tokens = trg_vocab.lookup_tokens(prediction)
164
+ en = TreebankWordDetokenizer().detokenize(tokens).replace('<eos>', "")
165
+
166
+ return re.sub(r'[^\w\s]','',en).strip()
167
+
168
+ iface = gr.Interface(fn=infer, inputs="text", outputs="text")
169
+
170
+ iface.launch()
arabic_vocab.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:be6f7c887496c6d29ce95ce3a7a8946924164e59bbfc2d6172e147304d182525
3
+ size 197049
english_vocab.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:79987ed2b3823a5bfc0715b460e58427a98aea877fe2fb8e6d7eab896c620c2e
3
+ size 93115
model.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:58170b33f60b6aae11efed74ae6f46f8de6c8909a1f17abbc3c3089415101144
3
+ size 82332843
requirements.txt ADDED
Binary file (2.6 kB). View file