Crystalcareai committed
Commit 969568c (1 parent: 7b223b3)

Update modeling_quiet.py

Files changed (1):
  modeling_quiet.py (+121, -113)
modeling_quiet.py CHANGED
@@ -1024,16 +1024,14 @@ class QuietForCausalLM(QuietPreTrainedModel, GenerationMixin):
         # Update the attention mask
         if attention_mask is not None:
             attention_mask = torch.cat([attention_mask, torch.ones((batch_size, 1)).to(attention_mask.device)], dim=-1)
-        else:
-            attention_mask = torch.ones((batch_size, seq_len)).to(input_ids.device)
 
         # Generate the continuation
         continuation_length = self.n_ahead - 2
         new_key_values = past_key_values
-
+
         # Initialize next_token_id with a default value
         next_token_id = torch.zeros(batch_size, dtype=torch.long).to(input_ids.device)
-
+
         start_time = time.time()
         for continuation_idx in range(continuation_length):
             outputs = self.model(
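Reviewer note on the hunk above: the `else` branch that synthesized a fresh all-ones mask is dropped, so only the concatenation path remains when a mask is supplied. A minimal, self-contained sketch of that surviving path (toy batch_size/seq_len values, dtypes chosen only so the snippet runs on its own, no model involved):

    import torch

    # Toy shapes chosen purely for illustration.
    batch_size, seq_len = 2, 5
    attention_mask = torch.ones((batch_size, seq_len), dtype=torch.long)

    # The surviving branch: append one column of ones for each newly generated token.
    if attention_mask is not None:
        attention_mask = torch.cat(
            [attention_mask, torch.ones((batch_size, 1), dtype=torch.long)], dim=-1
        )

    print(attention_mask.shape)  # torch.Size([2, 6])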
@@ -1059,106 +1057,79 @@ class QuietForCausalLM(QuietPreTrainedModel, GenerationMixin):
             next_token_id = torch.argmax(next_token_logits, dim=-1)
 
             # Append the generated token to the input sequence
-            input_ids = torch.cat([input_ids, next_token_id.unsqueeze(-1).to(input_ids.device)], dim=-1)
+            # input_ids = torch.cat([input_ids, next_token_id.unsqueeze(-1).to(input_ids.device)], dim=-1)
             seq_len += 1
 
             # Update the attention mask
-            attention_mask = torch.cat([attention_mask, torch.ones((batch_size, 1)).to(attention_mask.device)], dim=-1)
+            if attention_mask is not None:
+                attention_mask = torch.cat([attention_mask, torch.ones((batch_size, 1)).to(attention_mask.device)], dim=-1)
 
         # Append the end thought token to the input sequence
-        end_thought_token_id = self.tokenizer.convert_tokens_to_ids("<|endthought|>")
-        input_ids = torch.cat([input_ids, torch.tensor([[end_thought_token_id]] * batch_size).to(input_ids.device)], dim=-1)
-        seq_len += 1
+        end_thought_token_id = self.tokenizer.convert_tokens_to_ids("<|endthought|>")
+        input_ids = torch.cat([input_ids, torch.tensor([[end_thought_token_id]] * batch_size).to(input_ids.device)], dim=-1)
+        seq_len += 1
 
-        # Update the attention mask
+        # Update the attention mask
+        if attention_mask is not None:
             attention_mask = torch.cat([attention_mask, torch.ones((batch_size, 1)).to(attention_mask.device)], dim=-1)
 
-        # Get the hidden states before and after the thought
-        outputs_before = self.model(
-            input_ids=original_input_ids,
-            attention_mask=original_attention_mask,
-            position_ids=position_ids,
-            past_key_values=past_key_values,
-            inputs_embeds=inputs_embeds,
-            use_cache=use_cache,
-            output_attentions=output_attentions,
-            output_hidden_states=output_hidden_states,
-            return_dict=return_dict,
-        )
-        hidden_states_before = outputs_before[0][:, -1:, :]
+        # Get the hidden states before and after the thought
+        outputs_before = self.model(
+            input_ids=original_input_ids,
+            attention_mask=original_attention_mask,
+            position_ids=position_ids,
+            past_key_values=past_key_values,
+            inputs_embeds=inputs_embeds,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+        )
+        hidden_states_before = outputs_before[0][:, -1:, :]
 
-        # two new tokens: last continuation token and end thought token
-        outputs_after = self.model(
-            input_ids=torch.cat([next_token_id.unsqueeze(-1).to(input_ids.device), torch.tensor([[end_thought_token_id]] * batch_size).to(input_ids.device)], dim=-1),
-            attention_mask=attention_mask[:, -2:],
-            position_ids=position_ids,
-            past_key_values=new_key_values,
-            inputs_embeds=inputs_embeds,
-            use_cache=use_cache,
-            output_attentions=output_attentions,
-            output_hidden_states=output_hidden_states,
-            return_dict=return_dict,
-        )
-        hidden_states_after = outputs_after[0][:, -1:, :]
+        # two new tokens: last continuation token and end thought token
+        outputs_after = self.model(
+            input_ids=torch.cat([next_token_id.unsqueeze(-1).to(input_ids.device), torch.tensor([[end_thought_token_id]] * batch_size).to(input_ids.device)], dim=-1),
+            attention_mask=torch.cat([attention_mask[:, -1:], torch.ones((batch_size, 1)).to(attention_mask.device)], dim=-1),
+            position_ids=position_ids,
+            past_key_values=new_key_values,
+            inputs_embeds=inputs_embeds,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+        )
+        hidden_states_after = outputs_after[0][:, -1:, :]
 
-        # Apply the talk head to get the mixing weight
-        mixing_weight = self.talk_head[0](torch.cat([hidden_states_before, hidden_states_after], dim=-1))
+        # Apply the talk head to get the mixing weight
+        mixing_weight = self.talk_head[0](torch.cat([hidden_states_before, hidden_states_after], dim=-1))
 
-        # Apply the mixing weight to the hidden states
-        mixed_hidden_states = (1 - mixing_weight) * hidden_states_before + mixing_weight * hidden_states_after
+        # Apply the mixing weight to the hidden states
+        mixed_hidden_states = (1 - mixing_weight) * hidden_states_before + mixing_weight * hidden_states_after
 
-        # Apply the language model head to get the final logits
-        logits = self.lm_head(mixed_hidden_states)
-        return logits
+        # Apply the language model head to get the final logits
+        logits = self.lm_head(mixed_hidden_states)
+        return logits
 
-    # @torch.no_grad()
-    # def generate(
-    #     self,
-    #     input_ids: torch.LongTensor,
-    #     attention_mask: Optional[torch.Tensor] = None,
-    #     max_new_tokens: Optional[int] = None,
-    #     temperature: float = 1.0,
-    #     **kwargs,
-    # ):
-    #     if isinstance(input_ids, str):
-    #         input_ids = self.tokenizer(input_ids, return_tensors="pt").input_ids
-
-    #     if attention_mask is None:
-    #         attention_mask = torch.ones_like(input_ids)
-
-    #     batch_size, seq_len = input_ids.shape
-    #     max_length = seq_len + max_new_tokens if max_new_tokens is not None else self.config.max_length
-
-    #     position_ids = torch.arange(seq_len, dtype=torch.long, device=input_ids.device)
-    #     position_ids = position_ids.unsqueeze(0).expand(batch_size, -1)
-
-    #     past_key_values = None
-    #     hidden_states = None
-    #     all_hidden_states = ()
-
-    #     for _ in range(max_length - seq_len):
-    #         logits = self.infer(
-    #             input_ids=input_ids,
-    #             attention_mask=attention_mask,
-    #             position_ids=position_ids,
-    #             past_key_values=past_key_values,
-    #             inputs_embeds=hidden_states,
-    #             use_cache=True,
-    #             output_attentions=False,
-    #             output_hidden_states=False,
-    #             return_dict=False,
-    #         )
-
-    #         next_token_logits = logits[:, -1, :] / temperature
-    #         next_token_id = torch.argmax(next_token_logits, dim=-1)
-
-    #         input_ids = torch.cat([input_ids, next_token_id.unsqueeze(-1)], dim=-1)
-    #         attention_mask = torch.cat([attention_mask, torch.ones((batch_size, 1), device=attention_mask.device)], dim=-1)
-    #         position_ids = torch.cat([position_ids, (position_ids[:, -1] + 1).unsqueeze(-1)], dim=-1)
-
-    #         all_hidden_states = all_hidden_states + (hidden_states,)
-
-    #     return input_ids, all_hidden_states
+    @torch.no_grad()
+    def generate(
+        self,
+        input_ids: torch.LongTensor = torch.LongTensor(),
+        attention_mask: Optional[torch.Tensor] = None,
+        max_new_tokens: Optional[int] = None,
+        temperature: float = 1.1,
+        **kwargs,
+    ):
+        if isinstance(input_ids, str):
+            input_ids = self.tokenizer(input_ids, return_tensors="pt").input_ids
+
+        if attention_mask is None:
+            # Create a default attention mask if not provided
+            attention_mask = torch.ones_like(input_ids)
+
+        from .generate import generate
+        return generate(self, input_ids, attention_mask=attention_mask, max_new_tokens=max_new_tokens, temperature=temperature, **kwargs)
+
     @add_start_docstrings_to_model_forward(QUIET_INPUTS_DOCSTRING)
     @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
     def forward(
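Reviewer note: the commented-out sampling loop is replaced by a thin `generate` wrapper that defers to `generate()` in the repository's sibling generate.py module. A hedged usage sketch — the model id below is a placeholder, `trust_remote_code=True` is assumed because this class ships as custom modeling code, and the wrapper's return type is not pinned down in this diff, so it is handled defensively:

    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer

    model_id = "Crystalcareai/quiet-star"  # placeholder repo id, adjust to the actual checkpoint
    tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained(model_id, trust_remote_code=True)

    inputs = tokenizer("The capital of France is", return_tensors="pt")

    # The new wrapper builds a default attention mask when none is given and then
    # forwards everything to generate() from generate.py.
    with torch.no_grad():
        output = model.generate(
            input_ids=inputs.input_ids,
            attention_mask=inputs.attention_mask,
            max_new_tokens=32,
            temperature=1.1,
        )

    # Depending on generate.py, the result may be token ids alone or a
    # (token_ids, hidden_states) tuple; handle both.
    token_ids = output[0] if isinstance(output, tuple) else output
    print(tokenizer.decode(token_ids[0], skip_special_tokens=True))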
@@ -1891,16 +1862,12 @@ class QuietForCausalLM(QuietPreTrainedModel, GenerationMixin):
         torch.cuda.empty_cache()
 
 
-        return self.infer(
-            input_ids=input_ids,
-            attention_mask=attention_mask,
-            position_ids=position_ids,
-            past_key_values=past_key_values,
-            inputs_embeds=inputs_embeds,
-            use_cache=use_cache,
-            output_attentions=output_attentions,
-            output_hidden_states=output_hidden_states,
-            return_dict=return_dict,
+        return CausalLMOutputWithPast(
+            loss=loss if loss is not None else None,
+            logits=(rm_logits if self.n_ahead > 1 else logits) if not self.output_logits_at_the_end else logits,
+            past_key_values=outputs.past_key_values,
+            hidden_states=outputs.hidden_states,
+            attentions=outputs.attentions,
         )
 
 
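Reviewer note: `forward` now returns a `CausalLMOutputWithPast` directly instead of delegating to `self.infer`. The nested conditional that picks the logits field is easy to misread; it is equivalent to the plain if/else below (attribute names taken from the diff, no new behavior implied):

    # Equivalent to:
    #   logits = (rm_logits if self.n_ahead > 1 else logits) if not self.output_logits_at_the_end else logits
    def select_output_logits(logits, rm_logits, n_ahead, output_logits_at_the_end):
        if output_logits_at_the_end:
            return logits       # always expose the final language-model logits
        if n_ahead > 1:
            return rm_logits    # thought-ahead enabled: return rm_logits instead
        return logits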
@@ -1908,18 +1875,59 @@ class QuietForCausalLM(QuietPreTrainedModel, GenerationMixin):
     def prepare_inputs_for_generation(
         self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
     ):
-        if attention_mask is None:
-            attention_mask = input_ids.new_ones(input_ids.shape)
-
-        if past_key_values:
-            input_ids = input_ids[:, -1:]
-
-        return {
-            "input_ids": input_ids,
-            "attention_mask": attention_mask,
-            "past_key_values": past_key_values,
-            "inputs_embeds": inputs_embeds,
-        }
+        # Omit tokens covered by past_key_values
+        if past_key_values is not None:
+            if isinstance(past_key_values, Cache):
+                cache_length = past_key_values.get_seq_length()
+                past_length = past_key_values.seen_tokens
+                max_cache_length = past_key_values.get_max_length()
+            else:
+                cache_length = past_length = past_key_values[0][0].shape[2]
+                max_cache_length = None
+
+            # Keep only the unprocessed tokens:
+            # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
+            # some of the inputs are exclusively passed as part of the cache (e.g. when passing inputs_embeds as
+            # input)
+            if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
+                input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
+            # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
+            # input_ids based on the past_length.
+            elif past_length < input_ids.shape[1]:
+                input_ids = input_ids[:, past_length:]
+            # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.
+
+            # If we are about to go beyond the maximum cache length, we need to crop the input attention mask.
+            if (
+                max_cache_length is not None
+                and attention_mask is not None
+                and cache_length + input_ids.shape[1] > max_cache_length
+            ):
+                attention_mask = attention_mask[:, -max_cache_length:]
+
+        position_ids = kwargs.get("position_ids", None)
+        if attention_mask is not None and position_ids is None:
+            # create position_ids on the fly for batch generation
+            position_ids = attention_mask.long().cumsum(-1) - 1
+            position_ids.masked_fill_(attention_mask == 0, 1)
+            if past_key_values:
+                position_ids = position_ids[:, -input_ids.shape[1] :]
+
+        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
+        if inputs_embeds is not None and past_key_values is None:
+            model_inputs = {"inputs_embeds": inputs_embeds}
+        else:
+            model_inputs = {"input_ids": input_ids}
+
+        model_inputs.update(
+            {
+                "position_ids": position_ids,
+                "past_key_values": past_key_values,
+                "use_cache": kwargs.get("use_cache"),
+                "attention_mask": attention_mask,
+            }
+        )
+        return model_inputs
 
     @staticmethod
     def _reorder_cache(past_key_values, beam_idx):
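Reviewer note: `prepare_inputs_for_generation` is rewritten to the stock cache-aware Hugging Face pattern — drop tokens already covered by the KV cache and derive `position_ids` from the attention mask. A small self-contained sketch of just that slicing arithmetic, with toy values and no model required:

    import torch

    # Toy state: the cache already holds 7 tokens; generate() hands in all 8 ids.
    past_length = 7
    input_ids = torch.arange(8).unsqueeze(0)               # shape (1, 8)
    attention_mask = torch.ones((1, 8), dtype=torch.long)

    # Mirrors cases 1/2 in the comments above: keep only the unprocessed suffix.
    if attention_mask.shape[1] > input_ids.shape[1]:
        input_ids = input_ids[:, -(attention_mask.shape[1] - past_length):]
    elif past_length < input_ids.shape[1]:
        input_ids = input_ids[:, past_length:]

    # position_ids derived on the fly from the mask, then cropped to match.
    position_ids = attention_mask.long().cumsum(-1) - 1
    position_ids.masked_fill_(attention_mask == 0, 1)
    position_ids = position_ids[:, -input_ids.shape[1]:]

    print(input_ids)     # tensor([[7]])  -> only the newest token is fed forward
    print(position_ids)  # tensor([[7]])  -> its position continues the sequence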
 