support_generation #3
by shashwat1002 - opened

modeling_backpack_gpt2.py  +42 -2

modeling_backpack_gpt2.py  CHANGED
@@ -153,7 +153,7 @@ class BackpackGPT2Model(BackpackGPT2PreTrainedModel):
     def get_sense_network(self):
         return self.sense_network
 
-    def forward(self, input_ids, position_ids):
+    def forward(self, input_ids, position_ids, **kwargs):
         # Compute senses
         sense_input_embeds = self.word_embeddings(input_ids)
         senses = self.sense_network(sense_input_embeds)  # (bs, nv, s, d)
@@ -205,8 +205,48 @@ class BackpackGPT2LMHeadModel(BackpackGPT2PreTrainedModel):
 
     def get_lm_head(self):
         return self.lm_head
+
+    def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, attention_mask=None, **kwargs):
+        # prepare_inputs_for_generation needs to be overridden to support generation
+        # this is inspired by the one in GPT2LMHeadModel: https://github.com/huggingface/transformers/blob/d533465150532b0c5de167b574e59f64c68b1154/src/transformers/models/gpt2/modeling_gpt2.py#L1007C4-L1007C4
+
+        token_type_ids = kwargs.get("token_type_ids", None)
+        # only keep the last token of input_ids if past_key_values is defined
+        if past_key_values:
+            input_ids = input_ids[:, -1].unsqueeze(-1)
+            if token_type_ids is not None:
+                token_type_ids = token_type_ids[:, -1].unsqueeze(-1)
+
+        attention_mask = kwargs.get("attention_mask", None)
+        position_ids = kwargs.get("position_ids", None)
+
+        if attention_mask is not None and position_ids is None:
+            # create position_ids on the fly for batch generation
+            position_ids = attention_mask.long().cumsum(-1) - 1
+            position_ids.masked_fill_(attention_mask == 0, 1)
+            if past_key_values:
+                position_ids = position_ids[:, -1].unsqueeze(-1)
+        else:
+            position_ids = None
+
+        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
+        if inputs_embeds is not None and past_key_values is None:
+            model_inputs = {"inputs_embeds": inputs_embeds}
+        else:
+            model_inputs = {"input_ids": input_ids}
+
+        model_inputs.update(
+            {
+                "past_key_values": past_key_values,
+                "use_cache": kwargs.get("use_cache"),
+                "position_ids": position_ids,
+                "attention_mask": attention_mask,
+                "token_type_ids": token_type_ids,
+            }
+        )
+        return model_inputs
 
-    def forward(self, input_ids, position_ids=None):
+    def forward(self, input_ids, position_ids=None, **kwargs):
         outputs = self.backpack(input_ids, position_ids=position_ids)
         hidden_states, contextualization = outputs.hidden_states, outputs.contextualization
         lm_logits = self.lm_head(hidden_states)  # (bs, s, V)