Jackmin108 committed
Commit 814cbbb
1 Parent(s): 65e9690

some fixes and suggestions


Signed-off-by: Meow <ongjackm@gmail.com>

Files changed (5):
  1. embedding.py +2 -2
  2. mha.py +6 -3
  3. mlp.py +2 -2
  4. modeling_lora.py +5 -3
  5. modeling_xlm_roberta.py +2 -1
embedding.py CHANGED
@@ -48,7 +48,7 @@ class XLMRobertaEmbeddings(nn.Module):
         """
         batch_size, seqlen = input_ids.shape
         if adapter_mask is not None:
-            unique_tasks = torch.unique(adapter_mask).tolist()
+            unique_tasks = torch.unique(adapter_mask)
             embedding_dtype = next(self.word_embeddings.parameters()).dtype
             embeddings = torch.empty(*input_ids.shape, self.word_embeddings.embedding_dim,
                                      dtype=embedding_dtype, device=input_ids.device)
@@ -71,7 +71,7 @@ class XLMRobertaEmbeddings(nn.Module):
                 token_type_ids = torch.zeros(seqlen, dtype=torch.long, device=input_ids.device)
 
             if adapter_mask is not None:
-                unique_tasks = torch.unique(adapter_mask).tolist()
+                unique_tasks = torch.unique(adapter_mask)
                 for task_id in unique_tasks:
                     task_token_type_embeddings = self.token_type_embeddings(token_type_ids, task_id=task_id)
                     task_indices = (adapter_mask == task_id).nonzero(as_tuple=True)[0]
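
Note on the `.tolist()` removals across this commit: iterating directly over the tensor returned by torch.unique(...) keeps each task id as a 0-d tensor and avoids an extra device-to-host copy. Below is a minimal, self-contained sketch of the per-task dispatch pattern these hunks use; a plain nn.Embedding stands in for the repo's LoRA-parametrized embedding, which additionally receives task_id= to select the adapter.

    import torch
    import torch.nn as nn

    def per_task_embed(word_embeddings: nn.Embedding, input_ids: torch.Tensor,
                       adapter_mask: torch.Tensor) -> torch.Tensor:
        emb_dtype = next(word_embeddings.parameters()).dtype
        embeddings = torch.empty(*input_ids.shape, word_embeddings.embedding_dim,
                                 dtype=emb_dtype, device=input_ids.device)
        # Tensor iteration (no .tolist()): each task_id stays a 0-d tensor and the
        # comparison below runs on the input's device.
        for task_id in torch.unique(adapter_mask):
            task_indices = (adapter_mask == task_id).nonzero(as_tuple=True)[0]
            # In the repo this call also passes task_id=task_id so the matching
            # LoRA adapter is applied per row group.
            embeddings[task_indices] = word_embeddings(input_ids[task_indices])
        return embeddings

    # toy usage
    emb = nn.Embedding(100, 8)
    ids = torch.randint(0, 100, (4, 6))
    mask = torch.tensor([0, 0, 1, 1], dtype=torch.int32)
    out = per_task_embed(emb, ids, mask)   # shape (4, 6, 8)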
mha.py CHANGED
@@ -647,7 +647,7 @@ class MHA(nn.Module):
             assert x_kv is None and mixer_subset is None
 
             if cu_adapter_mask is not None:
-                unique_tasks = torch.unique(cu_adapter_mask).tolist()
+                unique_tasks = torch.unique(cu_adapter_mask)
                 qkv_dtype = next(self.Wqkv.parameters()).dtype
                 qkv = torch.empty(x.shape[0], self.Wqkv.out_features,
                                   dtype=qkv_dtype, device=x.device)
@@ -663,7 +663,10 @@ class MHA(nn.Module):
                 if not self.return_residual:
                     qkv = self.Wqkv(x)
                 else:
-                    qkv, x = self.Wqkv(x)
+                    if hasattr(self.Wqkv, 'parametrizations'):
+                        qkv, x = self.Wqkv(x, residual=True)
+                    else:
+                        qkv, x = self.Wqkv(x)
 
             if self.dwconv:
                 qkv = rearrange(
@@ -752,7 +755,7 @@ class MHA(nn.Module):
 
         inp = rearrange(context, "... h d -> ... (h d)")
         if cu_adapter_mask is not None:
-            unique_tasks = torch.unique(cu_adapter_mask).tolist()
+            unique_tasks = torch.unique(cu_adapter_mask)
             out_dtype = next(self.out_proj.parameters()).dtype
             out = torch.empty(inp.shape[0], self.out_proj.out_features,
                               dtype=out_dtype, device=inp.device)
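
On the new branch in the second hunk: torch.nn.utils.parametrize attaches a `parametrizations` attribute to any module it wraps, so hasattr(self.Wqkv, 'parametrizations') is a cheap way to tell a LoRA-parametrized Wqkv (which, per this diff, takes residual=True and returns the input alongside the projection) apart from the stock residual-returning linear. A small sketch of just the detection part, with a toy parametrization for illustration:

    import torch
    import torch.nn as nn
    from torch.nn.utils import parametrize

    class Identity(nn.Module):
        # toy weight parametrization, purely for illustration
        def forward(self, weight: torch.Tensor) -> torch.Tensor:
            return weight

    wqkv = nn.Linear(16, 48)
    assert not hasattr(wqkv, "parametrizations")

    parametrize.register_parametrization(wqkv, "weight", Identity())
    assert hasattr(wqkv, "parametrizations")   # set once any parametrization is registered

    # The diff's branch then reads roughly:
    # qkv, x = wqkv(x, residual=True) if hasattr(wqkv, "parametrizations") else wqkv(x)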
mlp.py CHANGED
@@ -49,7 +49,7 @@ class Mlp(nn.Module):
 
     def forward(self, x, cu_adapter_mask=None):
         if cu_adapter_mask is not None:
-            unique_tasks = torch.unique(cu_adapter_mask).tolist()
+            unique_tasks = torch.unique(cu_adapter_mask)
             fc1_dtype = next(self.fc1.parameters()).dtype
             y = torch.empty(x.shape[0], self.fc1.out_features,
                             dtype=fc1_dtype, device=x.device)
@@ -64,7 +64,7 @@ class Mlp(nn.Module):
         y = self.activation(y)
 
         if cu_adapter_mask is not None:
-            unique_tasks = torch.unique(cu_adapter_mask).tolist()
+            unique_tasks = torch.unique(cu_adapter_mask)
             fc2_dtype = next(self.fc2.parameters()).dtype
             out = torch.empty(y.shape[0], self.fc2.out_features,
                               dtype=fc2_dtype, device=y.device)
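
The same pattern as in embedding.py, seen from the Mlp side: preallocate the output in the layer's parameter dtype, then fill it one task group at a time. A short sketch with a plain nn.Linear standing in for the LoRA-parametrized fc1/fc2:

    import torch
    import torch.nn as nn

    def per_task_linear(fc: nn.Linear, x: torch.Tensor, cu_adapter_mask: torch.Tensor) -> torch.Tensor:
        out = torch.empty(x.shape[0], fc.out_features,
                          dtype=next(fc.parameters()).dtype, device=x.device)
        for task_id in torch.unique(cu_adapter_mask):    # tensor iteration, no .tolist()
            task_indices = (cu_adapter_mask == task_id).nonzero(as_tuple=True)[0]
            out[task_indices] = fc(x[task_indices])       # repo variant also passes the task id
        return out

    fc1 = nn.Linear(8, 32)
    tokens = torch.randn(10, 8)                           # unpadded (cu_*) layout: one row per token
    cu_mask = torch.tensor([0] * 6 + [1] * 4, dtype=torch.int32)
    y = per_task_linear(fc1, tokens, cu_mask)             # shape (10, 32)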
modeling_lora.py CHANGED
@@ -355,7 +355,9 @@ class XLMRobertaLoRA(XLMRobertaPreTrainedModel):
                 f"Supported tasks are: {', '.join(self.config.lora_adaptations)}."
                 f"Alternatively, don't pass the `task_type` argument to disable LoRA."
             )
-        task_id = self._adaptation_map[task_type]
-        num_examples = 1 if isinstance(sentences, str) else len(sentences)
-        adapter_mask = torch.full((num_examples,), task_id, dtype=torch.int32)
+        adapter_mask = None
+        if task_type:
+            task_id = self._adaptation_map[task_type]
+            num_examples = 1 if isinstance(sentences, str) else len(sentences)
+            adapter_mask = torch.full((num_examples,), task_id, dtype=torch.int32, device=self.device)
         return self.roberta.encode(sentences, *args, adapter_mask=adapter_mask, **kwargs)
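
For clarity, what the new guard produces: adapter_mask stays None (LoRA disabled) when no task_type is passed, and otherwise holds one task id per input sentence. A standalone sketch; the _adaptation_map contents below are made up for illustration:

    import torch

    _adaptation_map = {"retrieval.query": 0, "retrieval.passage": 1}   # hypothetical task-to-id mapping
    sentences = ["a query", "another query"]
    task_type = "retrieval.query"

    adapter_mask = None                       # stays None when task_type is falsy, disabling LoRA
    if task_type:
        task_id = _adaptation_map[task_type]
        num_examples = 1 if isinstance(sentences, str) else len(sentences)
        # the diff also pins the tensor to self.device; omitted here since there is no model
        adapter_mask = torch.full((num_examples,), task_id, dtype=torch.int32)

    print(adapter_mask)                       # tensor([0, 0], dtype=torch.int32)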
modeling_xlm_roberta.py CHANGED
@@ -314,7 +314,7 @@ class XLMRobertaPooler(nn.Module):
         # to the first token.
         first_token_tensor = hidden_states[:, 0] if pool else hidden_states
         if adapter_mask is not None:
-            unique_tasks = torch.unique(adapter_mask).tolist()
+            unique_tasks = torch.unique(adapter_mask)
             pool_dtype = next(self.dense.parameters()).dtype
             pooled_output = torch.empty(first_token_tensor.shape[0], self.dense.out_features,
                                         dtype=pool_dtype, device=first_token_tensor.device)
@@ -465,6 +465,7 @@ class XLMRobertaModel(XLMRobertaPreTrainedModel):
         normalize_embeddings: bool = False,
         truncate_dim: Optional[int] = None,
         adapter_mask: Optional[torch.Tensor] = None,
+        task_type: Optional[str] = None,
         **tokenizer_kwargs,
     ) -> Union[List[torch.Tensor], np.ndarray, torch.Tensor]:
         """