Avoid validation error on duplicated attribute
#42
by
alexis779
- opened
- instruct_pipeline.py +1 -9
instruct_pipeline.py
CHANGED
@@ -22,10 +22,8 @@ INTRO_BLURB = (
|
|
22 |
# This is the prompt that is used for generating responses using an already trained model. It ends with the response
|
23 |
# key, where the job of the model is to provide the completion that follows it (i.e. the response itself).
|
24 |
PROMPT_FOR_GENERATION_FORMAT = """{intro}
|
25 |
-
|
26 |
{instruction_key}
|
27 |
{instruction}
|
28 |
-
|
29 |
{response_key}
|
30 |
""".format(
|
31 |
intro=INTRO_BLURB,
|
@@ -37,17 +35,13 @@ PROMPT_FOR_GENERATION_FORMAT = """{intro}
|
|
37 |
|
38 |
def get_special_token_id(tokenizer: PreTrainedTokenizer, key: str) -> int:
|
39 |
"""Gets the token ID for a given string that has been added to the tokenizer as a special token.
|
40 |
-
|
41 |
When training, we configure the tokenizer so that the sequences like "### Instruction:" and "### End" are
|
42 |
treated specially and converted to a single, new token. This retrieves the token ID each of these keys map to.
|
43 |
-
|
44 |
Args:
|
45 |
tokenizer (PreTrainedTokenizer): the tokenizer
|
46 |
key (str): the key to convert to a single token
|
47 |
-
|
48 |
Raises:
|
49 |
RuntimeError: if more than one ID was generated
|
50 |
-
|
51 |
Returns:
|
52 |
int: the token ID for the given key
|
53 |
"""
|
@@ -62,7 +56,6 @@ class InstructionTextGenerationPipeline(Pipeline):
|
|
62 |
self, *args, do_sample: bool = True, max_new_tokens: int = 256, top_p: float = 0.92, top_k: int = 0, **kwargs
|
63 |
):
|
64 |
"""Initialize the pipeline
|
65 |
-
|
66 |
Args:
|
67 |
do_sample (bool, optional): Whether or not to use sampling. Defaults to True.
|
68 |
max_new_tokens (int, optional): Max new tokens after the prompt to generate. Defaults to 256.
|
@@ -132,7 +125,6 @@ class InstructionTextGenerationPipeline(Pipeline):
|
|
132 |
generated_sequence = self.model.generate(
|
133 |
input_ids=input_ids.to(self.model.device),
|
134 |
attention_mask=attention_mask.to(self.model.device) if attention_mask is not None else None,
|
135 |
-
pad_token_id=self.tokenizer.pad_token_id,
|
136 |
**generate_kwargs,
|
137 |
)
|
138 |
|
@@ -209,4 +201,4 @@ class InstructionTextGenerationPipeline(Pipeline):
|
|
209 |
|
210 |
records.append(rec)
|
211 |
|
212 |
-
return records
|
|
|
22 |
# This is the prompt that is used for generating responses using an already trained model. It ends with the response
|
23 |
# key, where the job of the model is to provide the completion that follows it (i.e. the response itself).
|
24 |
PROMPT_FOR_GENERATION_FORMAT = """{intro}
|
|
|
25 |
{instruction_key}
|
26 |
{instruction}
|
|
|
27 |
{response_key}
|
28 |
""".format(
|
29 |
intro=INTRO_BLURB,
|
|
|
35 |
|
36 |
def get_special_token_id(tokenizer: PreTrainedTokenizer, key: str) -> int:
|
37 |
"""Gets the token ID for a given string that has been added to the tokenizer as a special token.
|
|
|
38 |
When training, we configure the tokenizer so that the sequences like "### Instruction:" and "### End" are
|
39 |
treated specially and converted to a single, new token. This retrieves the token ID each of these keys map to.
|
|
|
40 |
Args:
|
41 |
tokenizer (PreTrainedTokenizer): the tokenizer
|
42 |
key (str): the key to convert to a single token
|
|
|
43 |
Raises:
|
44 |
RuntimeError: if more than one ID was generated
|
|
|
45 |
Returns:
|
46 |
int: the token ID for the given key
|
47 |
"""
|
|
|
56 |
self, *args, do_sample: bool = True, max_new_tokens: int = 256, top_p: float = 0.92, top_k: int = 0, **kwargs
|
57 |
):
|
58 |
"""Initialize the pipeline
|
|
|
59 |
Args:
|
60 |
do_sample (bool, optional): Whether or not to use sampling. Defaults to True.
|
61 |
max_new_tokens (int, optional): Max new tokens after the prompt to generate. Defaults to 256.
|
|
|
125 |
generated_sequence = self.model.generate(
|
126 |
input_ids=input_ids.to(self.model.device),
|
127 |
attention_mask=attention_mask.to(self.model.device) if attention_mask is not None else None,
|
|
|
128 |
**generate_kwargs,
|
129 |
)
|
130 |
|
|
|
201 |
|
202 |
records.append(rec)
|
203 |
|
204 |
+
return records
|