Sandroeth commited on
Commit
75b4fe7
·
verified ·
1 Parent(s): 69654ae

Update modeling_cali.py

Browse files
Files changed (1) hide show
  1. modeling_cali.py +4 -1
modeling_cali.py CHANGED
@@ -234,7 +234,7 @@ class CALIModel(CALIPreTrainedModel):
234
 
235
 
236
  class CALIForCausalLM(CALIPreTrainedModel, GenerationMixin):
237
- _tied_weights_keys = ["lm_head.weight"]
238
 
239
  def __init__(self, config: CALIConfig):
240
  super().__init__(config)
@@ -250,6 +250,9 @@ class CALIForCausalLM(CALIPreTrainedModel, GenerationMixin):
250
  def set_input_embeddings(self, value):
251
  self.model.embed = value
252
 
 
 
 
253
  def get_output_embeddings(self):
254
  return self.lm_head
255
 
 
234
 
235
 
236
  class CALIForCausalLM(CALIPreTrainedModel, GenerationMixin):
237
+
238
 
239
  def __init__(self, config: CALIConfig):
240
  super().__init__(config)
 
250
  def set_input_embeddings(self, value):
251
  self.model.embed = value
252
 
253
+ def get_tied_weights(self):
254
+ return {"lm_head.weight": "model.embed.weight"} if self.config.tie_embeddings else {}
255
+
256
  def get_output_embeddings(self):
257
  return self.lm_head
258