Spaces:
Runtime error
Runtime error
khaled5321
commited on
Commit
•
f45a763
1
Parent(s):
155c20a
Add application file
Browse files- .gitignore +1 -0
- app.py +170 -0
- arabic_vocab.pth +3 -0
- english_vocab.pth +3 -0
- model.pt +3 -0
- requirements.txt +0 -0
.gitignore
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
/venv
|
app.py
ADDED
@@ -0,0 +1,170 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gradio as gr
|
2 |
+
import random
|
3 |
+
import re
|
4 |
+
import torch
|
5 |
+
import torch.nn as nn
|
6 |
+
import torch.nn.functional as F
|
7 |
+
import unicodedata
|
8 |
+
import nltk
|
9 |
+
from nltk.tokenize.treebank import TreebankWordDetokenizer
|
10 |
+
|
11 |
+
nltk.download('punkt')
|
12 |
+
|
13 |
+
|
14 |
+
class Encoder(nn.Module):
|
15 |
+
def __init__(self, input_dim, emb_dim, enc_hid_dim, dec_hid_dim, dropout):
|
16 |
+
super().__init__()
|
17 |
+
self.embedding = nn.Embedding(input_dim, emb_dim)
|
18 |
+
self.rnn = nn.GRU(emb_dim, enc_hid_dim, bidirectional = True)
|
19 |
+
self.fc = nn.Linear(enc_hid_dim * 2, dec_hid_dim)
|
20 |
+
self.dropout = nn.Dropout(dropout)
|
21 |
+
|
22 |
+
def forward(self, src):
|
23 |
+
embedded = self.dropout(self.embedding(src))
|
24 |
+
outputs, hidden = self.rnn(embedded)
|
25 |
+
hidden = torch.tanh(self.fc(torch.cat((hidden[-2,:,:], hidden[-1,:,:]), dim = 1)))
|
26 |
+
|
27 |
+
return outputs, hidden
|
28 |
+
|
29 |
+
|
30 |
+
class Attention(nn.Module):
|
31 |
+
def __init__(self, enc_hid_dim, dec_hid_dim):
|
32 |
+
super().__init__()
|
33 |
+
self.attn = nn.Linear((enc_hid_dim * 2) + dec_hid_dim, dec_hid_dim)
|
34 |
+
self.v = nn.Linear(dec_hid_dim, 1, bias = False)
|
35 |
+
|
36 |
+
def forward(self, hidden, encoder_outputs):
|
37 |
+
batch_size = encoder_outputs.shape[1]
|
38 |
+
src_len = encoder_outputs.shape[0]
|
39 |
+
|
40 |
+
hidden = hidden.unsqueeze(1).repeat(1, src_len, 1)
|
41 |
+
encoder_outputs = encoder_outputs.permute(1, 0, 2)
|
42 |
+
|
43 |
+
energy = torch.tanh(self.attn(torch.cat((hidden, encoder_outputs), dim = 2)))
|
44 |
+
attention = self.v(energy).squeeze(2)
|
45 |
+
|
46 |
+
return F.softmax(attention, dim=1)
|
47 |
+
|
48 |
+
|
49 |
+
class Decoder(nn.Module):
|
50 |
+
def __init__(self, output_dim, emb_dim, enc_hid_dim, dec_hid_dim, dropout, attention):
|
51 |
+
super().__init__()
|
52 |
+
self.output_dim = output_dim
|
53 |
+
self.attention = attention
|
54 |
+
self.embedding = nn.Embedding(output_dim, emb_dim)
|
55 |
+
self.rnn = nn.GRU((enc_hid_dim * 2) + emb_dim, dec_hid_dim)
|
56 |
+
self.fc_out = nn.Linear((enc_hid_dim * 2) + dec_hid_dim + emb_dim, output_dim)
|
57 |
+
self.dropout = nn.Dropout(dropout)
|
58 |
+
|
59 |
+
def forward(self, input, hidden, encoder_outputs):
|
60 |
+
input = input.unsqueeze(0)
|
61 |
+
embedded = self.dropout(self.embedding(input))
|
62 |
+
a = self.attention(hidden, encoder_outputs)
|
63 |
+
a = a.unsqueeze(1)
|
64 |
+
encoder_outputs = encoder_outputs.permute(1, 0, 2)
|
65 |
+
weighted = torch.bmm(a, encoder_outputs)
|
66 |
+
weighted = weighted.permute(1, 0, 2)
|
67 |
+
rnn_input = torch.cat((embedded, weighted), dim = 2)
|
68 |
+
output, hidden = self.rnn(rnn_input, hidden.unsqueeze(0))
|
69 |
+
|
70 |
+
assert (output == hidden).all()
|
71 |
+
|
72 |
+
embedded = embedded.squeeze(0)
|
73 |
+
output = output.squeeze(0)
|
74 |
+
weighted = weighted.squeeze(0)
|
75 |
+
|
76 |
+
prediction = self.fc_out(torch.cat((output, weighted, embedded), dim = 1))
|
77 |
+
|
78 |
+
return prediction, hidden.squeeze(0)
|
79 |
+
|
80 |
+
class Seq2Seq(nn.Module):
|
81 |
+
def __init__(self, encoder, decoder, device):
|
82 |
+
super().__init__()
|
83 |
+
|
84 |
+
self.encoder = encoder
|
85 |
+
self.decoder = decoder
|
86 |
+
self.device = device
|
87 |
+
|
88 |
+
def forward(self, src, trg, teacher_forcing_ratio = 0.5):
|
89 |
+
batch_size = trg.shape[1]
|
90 |
+
trg_len = trg.shape[0]
|
91 |
+
trg_vocab_size = self.decoder.output_dim
|
92 |
+
outputs = torch.zeros(trg_len, batch_size, trg_vocab_size).to(self.device)
|
93 |
+
|
94 |
+
encoder_outputs, hidden = self.encoder(src)
|
95 |
+
|
96 |
+
input = trg[0,:]
|
97 |
+
|
98 |
+
for t in range(1, trg_len):
|
99 |
+
output, hidden = self.decoder(input, hidden, encoder_outputs)
|
100 |
+
|
101 |
+
outputs[t] = output
|
102 |
+
|
103 |
+
teacher_force = random.random() < teacher_forcing_ratio
|
104 |
+
top1 = output.argmax(1)
|
105 |
+
input = trg[t] if teacher_force else top1
|
106 |
+
|
107 |
+
return outputs
|
108 |
+
|
109 |
+
|
110 |
+
def unicodeToAscii(s):
|
111 |
+
return ''.join(
|
112 |
+
c for c in unicodedata.normalize('NFD', s)
|
113 |
+
if unicodedata.category(c) != 'Mn'
|
114 |
+
)
|
115 |
+
|
116 |
+
def tokenize_ar(text):
|
117 |
+
"""
|
118 |
+
Tokenizes Arabic text from a string into a list of strings (tokens) and reverses it
|
119 |
+
"""
|
120 |
+
return [tok for tok in nltk.tokenize.wordpunct_tokenize(unicodeToAscii(text))]
|
121 |
+
|
122 |
+
src_vocab = torch.load("arabic_vocab.pth")
|
123 |
+
trg_vocab = torch.load("english_vocab.pth")
|
124 |
+
|
125 |
+
INPUT_DIM = 9790
|
126 |
+
OUTPUT_DIM = 5682
|
127 |
+
ENC_EMB_DIM = 256
|
128 |
+
DEC_EMB_DIM = 256
|
129 |
+
ENC_HID_DIM = 512
|
130 |
+
DEC_HID_DIM = 512
|
131 |
+
ENC_DROPOUT = 0.5
|
132 |
+
DEC_DROPOUT = 0.5
|
133 |
+
|
134 |
+
attn = Attention(ENC_HID_DIM, DEC_HID_DIM)
|
135 |
+
enc = Encoder(INPUT_DIM, ENC_EMB_DIM, ENC_HID_DIM, DEC_HID_DIM, ENC_DROPOUT)
|
136 |
+
dec = Decoder(OUTPUT_DIM, DEC_EMB_DIM, ENC_HID_DIM, DEC_HID_DIM, DEC_DROPOUT, attn)
|
137 |
+
|
138 |
+
model = Seq2Seq(enc, dec, "cpu")
|
139 |
+
model.load_state_dict(torch.load('model.pt', map_location=torch.device('cpu')))
|
140 |
+
|
141 |
+
|
142 |
+
def infer(text, max_length=50):
|
143 |
+
text = tokenize_ar(text)
|
144 |
+
sequence = []
|
145 |
+
sequence.append(src_vocab['<sos>'])
|
146 |
+
sequence.extend([src_vocab[token] for token in text])
|
147 |
+
sequence.append(src_vocab['<eos>'])
|
148 |
+
|
149 |
+
sequence = torch.Tensor(sequence)
|
150 |
+
sequence = sequence[:, None].to(torch.int64)
|
151 |
+
target = torch.zeros(max_length, 1).to(torch.int64)
|
152 |
+
|
153 |
+
with torch.no_grad():
|
154 |
+
model.eval()
|
155 |
+
output = model(sequence, target, 0)
|
156 |
+
output_dim = output.shape[-1]
|
157 |
+
output = output[1:].view(-1, output_dim)
|
158 |
+
|
159 |
+
prediction = []
|
160 |
+
for i in output:
|
161 |
+
prediction.append(torch.argmax(i).item())
|
162 |
+
|
163 |
+
tokens = trg_vocab.lookup_tokens(prediction)
|
164 |
+
en = TreebankWordDetokenizer().detokenize(tokens).replace('<eos>', "")
|
165 |
+
|
166 |
+
return re.sub(r'[^\w\s]','',en).strip()
|
167 |
+
|
168 |
+
iface = gr.Interface(fn=infer, inputs="text", outputs="text")
|
169 |
+
|
170 |
+
iface.launch()
|
arabic_vocab.pth
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:be6f7c887496c6d29ce95ce3a7a8946924164e59bbfc2d6172e147304d182525
|
3 |
+
size 197049
|
english_vocab.pth
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:79987ed2b3823a5bfc0715b460e58427a98aea877fe2fb8e6d7eab896c620c2e
|
3 |
+
size 93115
|
model.pt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:58170b33f60b6aae11efed74ae6f46f8de6c8909a1f17abbc3c3089415101144
|
3 |
+
size 82332843
|
requirements.txt
ADDED
Binary file (2.6 kB). View file
|
|