Text Generation
Transformers
Safetensors
llama_hydra
tweety
custom_code
FremyCompany committed
Commit 646e144
1 parent: 36d92d3

Add custom code

Files changed (1)
  1. modeling_llama_hydra.py +227 -0
modeling_llama_hydra.py ADDED
@@ -0,0 +1,227 @@
import math
import warnings

import torch
import torch.nn.functional as F
import torch.utils.checkpoint
from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

from typing import List, Optional, Tuple, Union

import transformers
from transformers import LlamaConfig
from transformers.cache_utils import Cache
from transformers.modeling_outputs import (
    BaseModelOutputWithPast,
    CausalLMOutputWithPast,
    QuestionAnsweringModelOutput,
    SequenceClassifierOutputWithPast,
)

class LlamaHydraConfig(LlamaConfig):
    model_type = "llama_hydra"

    def __init__(self, **kwargs):
        if 'vocab_size' not in kwargs:
            if 'output_vocab_size' in kwargs:
                kwargs['vocab_size'] = kwargs['output_vocab_size']
            else:
                kwargs['vocab_size'] = 32000
        self.input_vocab_size = kwargs['input_vocab_size'] if 'input_vocab_size' in kwargs else kwargs['vocab_size']
        self.output_vocab_size = kwargs['output_vocab_size'] if 'output_vocab_size' in kwargs else kwargs['vocab_size']
        super().__init__(**kwargs)

class LlamaHydraForCausalLM(transformers.LlamaPreTrainedModel):
    config_class = LlamaHydraConfig
    _tied_weights_keys = ["lm_head.weight"]

    def __init__(self, config):
        hydra_config = LlamaHydraConfig(**config.__dict__)
        encoder_config = LlamaConfig(**config.__dict__)
        encoder_config.vocab_size = hydra_config.input_vocab_size
        super().__init__(hydra_config)
        self.model = transformers.LlamaModel(encoder_config)
        self.input_vocab_size = hydra_config.input_vocab_size
        self.output_vocab_size = hydra_config.output_vocab_size
        self.vocab_size = hydra_config.vocab_size
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.model.embed_tokens

    def set_input_embeddings(self, value):
        self.model.embed_tokens = value

    def get_output_embeddings(self):
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

    def set_decoder(self, decoder):
        self.model = decoder

    def get_decoder(self):
        return self.model

    #@add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
    #@replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        r"""
        Args:
            labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
                Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
                config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
                (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

        Returns:

        Example:

        ```python
        >>> from transformers import AutoTokenizer, LlamaForCausalLM

        >>> model = LlamaForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf")
        >>> tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")

        >>> prompt = "Hey, are you conscious? Can you talk to me?"
        >>> inputs = tokenizer(prompt, return_tensors="pt")

        >>> # Generate
        >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
        >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
        "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
        ```"""
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        hidden_states = outputs[0]
        if self.config.pretraining_tp > 1:
            lm_head_slices = self.lm_head.weight.split(self.vocab_size // self.config.pretraining_tp, dim=0)
            logits = [F.linear(hidden_states, lm_head_slices[i]) for i in range(self.config.pretraining_tp)]
            logits = torch.cat(logits, dim=-1)
        else:
            logits = self.lm_head(hidden_states)
        logits = logits.float()

        loss = None
        if labels is not None:
            # Shift so that tokens < n predict n
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            # Flatten the tokens
            loss_fct = CrossEntropyLoss()
            shift_logits = shift_logits.view(-1, self.config.vocab_size)
            shift_labels = shift_labels.view(-1)
            # Enable model parallelism
            shift_labels = shift_labels.to(shift_logits.device)
            loss = loss_fct(shift_logits, shift_labels)

        if not return_dict:
            output = (logits,) + outputs[1:]
            return (loss,) + output if loss is not None else output

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    def prepare_inputs_for_generation(
        self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
    ):
        if past_key_values is not None:
            if isinstance(past_key_values, Cache):
                cache_length = past_key_values.get_seq_length()
                past_length = past_key_values.seen_tokens
                max_cache_length = past_key_values.get_max_length()
            else:
                cache_length = past_length = past_key_values[0][0].shape[2]
                max_cache_length = None

            # Keep only the unprocessed tokens:
            # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
            #     some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as
            #     input)
            if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
                input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
            # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
            #     input_ids based on the past_length.
            elif past_length < input_ids.shape[1]:
                input_ids = input_ids[:, past_length:]
            # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.

            # If we are about to go beyond the maximum cache length, we need to crop the input attention mask.
            if (
                max_cache_length is not None
                and attention_mask is not None
                and cache_length + input_ids.shape[1] > max_cache_length
            ):
                attention_mask = attention_mask[:, -max_cache_length:]

        position_ids = kwargs.get("position_ids", None)
        if attention_mask is not None and position_ids is None:
            # create position_ids on the fly for batch generation
            position_ids = attention_mask.long().cumsum(-1) - 1
            position_ids.masked_fill_(attention_mask == 0, 1)
            if past_key_values:
                position_ids = position_ids[:, -input_ids.shape[1] :]

        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
        if inputs_embeds is not None and past_key_values is None:
            model_inputs = {"inputs_embeds": inputs_embeds}
        else:
            model_inputs = {"input_ids": input_ids}

        model_inputs.update(
            {
                "position_ids": position_ids,
                "past_key_values": past_key_values,
                "use_cache": kwargs.get("use_cache"),
                "attention_mask": attention_mask,
            }
        )
        return model_inputs

    @staticmethod
    def _reorder_cache(past_key_values, beam_idx):
        reordered_past = ()
        for layer_past in past_key_values:
            reordered_past += (
                tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
            )
        return reordered_past
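
For reference, a minimal usage sketch of the classes this file adds. Everything below is illustrative, not part of the commit: the hyperparameters are toy values, the import assumes the script sits next to modeling_llama_hydra.py, and "user/some-tweety-checkpoint" is a placeholder repository id. Loading a published checkpoint through the Auto classes further assumes that its config.json maps the llama_hydra architecture to this module via auto_map, which is why trust_remote_code=True is required.

# Sketch only: toy hyperparameters, placeholder repository id.
from transformers import AutoModelForCausalLM, AutoTokenizer

from modeling_llama_hydra import LlamaHydraConfig, LlamaHydraForCausalLM

# 1) Direct instantiation: the embedding table and the lm_head may use
#    different vocabulary sizes (the "hydra" part of the model).
config = LlamaHydraConfig(
    input_vocab_size=32000,   # vocabulary of model.embed_tokens
    output_vocab_size=40000,  # vocabulary of lm_head (becomes config.vocab_size)
    hidden_size=256,          # toy sizes, for illustration only
    intermediate_size=688,
    num_hidden_layers=2,
    num_attention_heads=4,
    num_key_value_heads=4,
)
model = LlamaHydraForCausalLM(config)
assert model.get_input_embeddings().num_embeddings == 32000
assert model.lm_head.out_features == 40000

# 2) Loading a published checkpoint that ships this file as custom code
#    (placeholder repo id; assumes auto_map is set in the checkpoint's config.json).
tokenizer = AutoTokenizer.from_pretrained("user/some-tweety-checkpoint")
model = AutoModelForCausalLM.from_pretrained("user/some-tweety-checkpoint", trust_remote_code=True)

The point the sketch illustrates is visible in LlamaHydraForCausalLM.__init__: the inner LlamaModel is built with input_vocab_size while lm_head is sized to output_vocab_size, so a checkpoint can consume one tokenizer's vocabulary while predicting over another.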