Балаганский Никита Николаевич
commited on
Commit
•
8aae6e9
1
Parent(s):
4eb38bc
fix sampling
Browse files
app.py
CHANGED
@@ -55,6 +55,12 @@ def inference(lm_model_name: str, cls_model_name: str, prompt: str, fp16: bool =
|
|
55 |
caif_sampler = load_sampler(cls_model_name=cls_model_name, lm_tokenizer=lm_tokenizer)
|
56 |
generator.set_caif_sampler(caif_sampler)
|
57 |
ordinary_sampler = TopKWithTemperatureSampler()
|
|
|
|
|
|
|
|
|
|
|
|
|
58 |
generator.set_ordinary_sampler(ordinary_sampler)
|
59 |
if device == "cpu":
|
60 |
autocast = torch.cpu.amp.autocast
|
@@ -67,7 +73,8 @@ def inference(lm_model_name: str, cls_model_name: str, prompt: str, fp16: bool =
|
|
67 |
max_length=20,
|
68 |
caif_period=1,
|
69 |
caif_tokens_num=100,
|
70 |
-
entropy=3.2
|
|
|
71 |
)
|
72 |
return sequences[0]
|
73 |
|
|
|
55 |
caif_sampler = load_sampler(cls_model_name=cls_model_name, lm_tokenizer=lm_tokenizer)
|
56 |
generator.set_caif_sampler(caif_sampler)
|
57 |
ordinary_sampler = TopKWithTemperatureSampler()
|
58 |
+
kwargs = {
|
59 |
+
"top_k": 20,
|
60 |
+
"temperature": 1.0,
|
61 |
+
"top_k_classifier": 100,
|
62 |
+
"classifier_weight": 5,
|
63 |
+
}
|
64 |
generator.set_ordinary_sampler(ordinary_sampler)
|
65 |
if device == "cpu":
|
66 |
autocast = torch.cpu.amp.autocast
|
|
|
73 |
max_length=20,
|
74 |
caif_period=1,
|
75 |
caif_tokens_num=100,
|
76 |
+
entropy=3.2,
|
77 |
+
**kwargs
|
78 |
)
|
79 |
return sequences[0]
|
80 |
|