lhoestq HF staff commited on
Commit
2d4d597
1 Parent(s): 3642076

disable batching

Browse files
Files changed (1) hide show
  1. generate.py +2 -2
generate.py CHANGED
@@ -26,11 +26,11 @@ logger.warning("Loading model...")
26
  if torch.backends.mps.is_available():
27
  device = "mps"
28
  model_id = "Qwen/Qwen1.5-0.5B-Chat"
29
- batch_size = 4
30
  else:
31
  device = "cuda"
32
  model_id = "google/gemma-2b-it"
33
- batch_size = 4
34
 
35
  model = models.transformers(model_id, device=device)
36
 
 
26
  if torch.backends.mps.is_available():
27
  device = "mps"
28
  model_id = "Qwen/Qwen1.5-0.5B-Chat"
29
+ batch_size = 1 # batching generates duplicates
30
  else:
31
  device = "cuda"
32
  model_id = "google/gemma-2b-it"
33
+ batch_size = 1 # batching generates duplicates
34
 
35
  model = models.transformers(model_id, device=device)
36