arnocandel
commited on
Commit
•
d5d0b9d
1
Parent(s):
98a65ab
Upload 10 files
Browse files
h2oai_pipeline.py
CHANGED
@@ -1,6 +1,9 @@
|
|
1 |
from transformers import TextGenerationPipeline
|
2 |
from transformers.pipelines.text_generation import ReturnType
|
3 |
|
|
|
|
|
|
|
4 |
human = "<human>:"
|
5 |
bot = "<bot>:"
|
6 |
|
@@ -28,3 +31,8 @@ class H2OTextGenerationPipeline(TextGenerationPipeline):
|
|
28 |
for rec in records:
|
29 |
rec['generated_text'] = rec['generated_text'].split(bot)[1].strip().split(human)[0].strip()
|
30 |
return records
|
|
|
|
|
|
|
|
|
|
|
|
1 |
from transformers import TextGenerationPipeline
|
2 |
from transformers.pipelines.text_generation import ReturnType
|
3 |
|
4 |
+
from stopping import get_stopping
|
5 |
+
|
6 |
+
prompt_type = "human_bot"
|
7 |
human = "<human>:"
|
8 |
bot = "<bot>:"
|
9 |
|
|
|
31 |
for rec in records:
|
32 |
rec['generated_text'] = rec['generated_text'].split(bot)[1].strip().split(human)[0].strip()
|
33 |
return records
|
34 |
+
|
35 |
+
def _forward(self, model_inputs, **generate_kwargs):
|
36 |
+
stopping_criteria = get_stopping(prompt_type, self.tokenizer, self.device, human=human, bot=bot)
|
37 |
+
generate_kwargs['stopping_criteria'] = stopping_criteria
|
38 |
+
return super()._forward(model_inputs, **generate_kwargs)
|
pytorch_model-00001-of-00003.bin
CHANGED
@@ -1,3 +1,3 @@
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
-
oid sha256:
|
3 |
size 5028171302
|
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:8d0678fa21071e8428e77f6cf089e44d4aa6999bc81968b6bf7e211013ff39c7
|
3 |
size 5028171302
|
pytorch_model-00002-of-00003.bin
CHANGED
@@ -1,3 +1,3 @@
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
-
oid sha256:
|
3 |
size 5017761129
|
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:adb145b24d55109b0973dd9f10b4f2d6d90c33c7f27f80367217a0e59ed5af50
|
3 |
size 5017761129
|
pytorch_model-00003-of-00003.bin
CHANGED
@@ -1,3 +1,3 @@
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
-
oid sha256:
|
3 |
size 3803055858
|
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:cc650723bc83675298d5451e40bd0ea0c18c9b16b87a81e44b20641695ee9900
|
3 |
size 3803055858
|