Балаганский Никита Николаевич commited on
Commit
8aae6e9
1 Parent(s): 4eb38bc

fix sampling

Browse files
Files changed (1) hide show
  1. app.py +8 -1
app.py CHANGED
@@ -55,6 +55,12 @@ def inference(lm_model_name: str, cls_model_name: str, prompt: str, fp16: bool =
55
  caif_sampler = load_sampler(cls_model_name=cls_model_name, lm_tokenizer=lm_tokenizer)
56
  generator.set_caif_sampler(caif_sampler)
57
  ordinary_sampler = TopKWithTemperatureSampler()
 
 
 
 
 
 
58
  generator.set_ordinary_sampler(ordinary_sampler)
59
  if device == "cpu":
60
  autocast = torch.cpu.amp.autocast
@@ -67,7 +73,8 @@ def inference(lm_model_name: str, cls_model_name: str, prompt: str, fp16: bool =
67
  max_length=20,
68
  caif_period=1,
69
  caif_tokens_num=100,
70
- entropy=3.2
 
71
  )
72
  return sequences[0]
73
 
 
55
  caif_sampler = load_sampler(cls_model_name=cls_model_name, lm_tokenizer=lm_tokenizer)
56
  generator.set_caif_sampler(caif_sampler)
57
  ordinary_sampler = TopKWithTemperatureSampler()
58
+ kwargs = {
59
+ "top_k": 20,
60
+ "temperature": 1.0,
61
+ "top_k_classifier": 100,
62
+ "classifier_weight": 5,
63
+ }
64
  generator.set_ordinary_sampler(ordinary_sampler)
65
  if device == "cpu":
66
  autocast = torch.cpu.amp.autocast
 
73
  max_length=20,
74
  caif_period=1,
75
  caif_tokens_num=100,
76
+ entropy=3.2,
77
+ **kwargs
78
  )
79
  return sequences[0]
80