mwalol commited on
Commit
c4766c9
1 Parent(s): 811403e

Upload h2oai_pipeline.py

Browse files
Files changed (1) hide show
  1. h2oai_pipeline.py +42 -0
h2oai_pipeline.py ADDED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import TextGenerationPipeline
2
+ from transformers.pipelines.text_generation import ReturnType
3
+
4
+ STYLE = "<|prompt|>{instruction}</s><|answer|>"
5
+
6
+
7
+ class H2OTextGenerationPipeline(TextGenerationPipeline):
8
+ def __init__(self, *args, **kwargs):
9
+ super().__init__(*args, **kwargs)
10
+ self.prompt = STYLE
11
+
12
+ def preprocess(
13
+ self, prompt_text, prefix="", handle_long_generation=None, **generate_kwargs
14
+ ):
15
+ prompt_text = self.prompt.format(instruction=prompt_text)
16
+ return super().preprocess(
17
+ prompt_text,
18
+ prefix=prefix,
19
+ handle_long_generation=handle_long_generation,
20
+ **generate_kwargs,
21
+ )
22
+
23
+ def postprocess(
24
+ self,
25
+ model_outputs,
26
+ return_type=ReturnType.FULL_TEXT,
27
+ clean_up_tokenization_spaces=True,
28
+ ):
29
+ records = super().postprocess(
30
+ model_outputs,
31
+ return_type=return_type,
32
+ clean_up_tokenization_spaces=clean_up_tokenization_spaces,
33
+ )
34
+ for rec in records:
35
+ rec["generated_text"] = (
36
+ rec["generated_text"]
37
+ .split("<|answer|>")[1]
38
+ .strip()
39
+ .split("<|prompt|>")[0]
40
+ .strip()
41
+ )
42
+ return records