Crystalcareai
commited on
Commit
•
6a35495
1
Parent(s):
beb979f
Update modeling_quiet.py
Browse files- modeling_quiet.py +2 -27
modeling_quiet.py
CHANGED
@@ -1111,21 +1111,6 @@ class QuietForCausalLM(QuietPreTrainedModel, GenerationMixin):
|
|
1111 |
logits = self.lm_head(mixed_hidden_states)
|
1112 |
return logits
|
1113 |
|
1114 |
-
def generate_with_callback(self, input_ids: torch.LongTensor = torch.LongTensor(), attention_mask: Optional[torch.Tensor] = None, max_new_tokens: Optional[int] = None, temperature: float = 1.1, callback=None, **kwargs):
|
1115 |
-
if isinstance(input_ids, str):
|
1116 |
-
input_ids = self.tokenizer(input_ids, return_tensors="pt").input_ids
|
1117 |
-
|
1118 |
-
if attention_mask is None:
|
1119 |
-
attention_mask = torch.ones_like(input_ids)
|
1120 |
-
|
1121 |
-
from .generate import generate
|
1122 |
-
generated_token_ids, generated_text = generate(self, input_ids, attention_mask=attention_mask, max_new_tokens=max_new_tokens, temperature=temperature, **kwargs)
|
1123 |
-
|
1124 |
-
if callback is not None:
|
1125 |
-
callback(generated_text)
|
1126 |
-
|
1127 |
-
return generated_text
|
1128 |
-
|
1129 |
@torch.no_grad()
|
1130 |
def generate(
|
1131 |
self,
|
@@ -1143,16 +1128,7 @@ class QuietForCausalLM(QuietPreTrainedModel, GenerationMixin):
|
|
1143 |
attention_mask = torch.ones_like(input_ids)
|
1144 |
|
1145 |
from .generate import generate
|
1146 |
-
|
1147 |
-
|
1148 |
-
# Convert the generated token IDs to a tensor
|
1149 |
-
generated_token_ids = torch.tensor(generated_token_ids)
|
1150 |
-
|
1151 |
-
# Return the generated text if it's a string, otherwise return the token IDs
|
1152 |
-
if isinstance(generated_text, str):
|
1153 |
-
return generated_text
|
1154 |
-
else:
|
1155 |
-
return generated_token_ids
|
1156 |
|
1157 |
@add_start_docstrings_to_model_forward(QUIET_INPUTS_DOCSTRING)
|
1158 |
@replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
|
@@ -2084,5 +2060,4 @@ class QuietForSequenceClassification(QuietPreTrainedModel):
|
|
2084 |
past_key_values=transformer_outputs.past_key_values,
|
2085 |
hidden_states=transformer_outputs.hidden_states,
|
2086 |
attentions=transformer_outputs.attentions,
|
2087 |
-
)
|
2088 |
-
|
|
|
1111 |
logits = self.lm_head(mixed_hidden_states)
|
1112 |
return logits
|
1113 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1114 |
@torch.no_grad()
|
1115 |
def generate(
|
1116 |
self,
|
|
|
1128 |
attention_mask = torch.ones_like(input_ids)
|
1129 |
|
1130 |
from .generate import generate
|
1131 |
+
return generate(self, input_ids, attention_mask=attention_mask, max_new_tokens=max_new_tokens, temperature=temperature, **kwargs)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1132 |
|
1133 |
@add_start_docstrings_to_model_forward(QUIET_INPUTS_DOCSTRING)
|
1134 |
@replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
|
|
|
2060 |
past_key_values=transformer_outputs.past_key_values,
|
2061 |
hidden_states=transformer_outputs.hidden_states,
|
2062 |
attentions=transformer_outputs.attentions,
|
2063 |
+
)
|
|