asigalov61 commited on
Commit
d34cdb3
1 Parent(s): afe740e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +15 -12
app.py CHANGED
@@ -37,20 +37,23 @@ def generate(
37
  progress(0, desc="Starting...")
38
 
39
  for i in progress.tqdm(range(seq_len)):
40
-
41
- x = out[:, -max_seq_len:]
42
-
43
- torch_in = x.tolist()[0]
44
-
45
- logits = torch.FloatTensor(session.run(None, {'input': [torch_in]})[0])[:, -1]
46
-
47
- filtered_logits = logits
48
-
49
- probs = F.softmax(filtered_logits / temperature, dim=-1)
50
 
51
- sample = torch.multinomial(probs, 1)
 
 
 
 
 
 
 
 
 
 
 
 
52
 
53
- out = torch.cat((out, sample), dim=-1)
 
54
 
55
  if return_prime:
56
  return out[:, :]
 
37
  progress(0, desc="Starting...")
38
 
39
  for i in progress.tqdm(range(seq_len)):
 
 
 
 
 
 
 
 
 
 
40
 
41
+ try:
42
+
43
+ x = out[:, -max_seq_len:]
44
+
45
+ torch_in = x.tolist()[0]
46
+
47
+ logits = torch.FloatTensor(session.run(None, {'input': [torch_in]})[0])[:, -1]
48
+
49
+ probs = F.softmax(logits / temperature, dim=-1)
50
+
51
+ sample = torch.multinomial(probs, 1)
52
+
53
+ out = torch.cat((out, sample), dim=-1)
54
 
55
+ except:
56
+ break
57
 
58
  if return_prime:
59
  return out[:, :]