nkanungo commited on
Commit
6744c0a
·
1 Parent(s): f77e58d

Update nano_gpt_inferencing.py

Browse files
Files changed (1) hide show
  1. nano_gpt_inferencing.py +8 -1
nano_gpt_inferencing.py CHANGED
@@ -193,5 +193,12 @@ def generate_paragraph(initial_text,max_token=50):
193
  final_model = model.to(device)
194
  encoded_text= encode(initial_text)
195
  encoded_text_tensor = torch.tensor(encoded_text).view(1, -1)
196
- return decode(final_model.generate(encoded_text_tensor, max_new_tokens=int(max_token))[0].tolist())
 
 
 
 
 
 
 
197
 
 
193
  final_model = model.to(device)
194
  encoded_text= encode(initial_text)
195
  encoded_text_tensor = torch.tensor(encoded_text).view(1, -1)
196
+
197
+ try:
198
+ max_token = int(max_token)
199
+ return decode(final_model.generate(encoded_text_tensor, max_new_tokens=max_token)[0].tolist())
200
+
201
+ except ValueError as e:
202
+ return 'You have entered an invalid value in Max number of tokens, Please enter an integer value to proceed'
203
+
204