duzx16
commited on
Commit
•
189e5df
1
Parent(s):
f2191d0
Add get_input_embeddings
Browse files- modeling_chatglm.py +3 -0
modeling_chatglm.py
CHANGED
@@ -702,6 +702,9 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
|
|
702 |
dtype=config.torch_dtype, **init_kwargs)
|
703 |
self.gradient_checkpointing = False
|
704 |
|
|
|
|
|
|
|
705 |
def forward(
|
706 |
self,
|
707 |
input_ids,
|
|
|
702 |
dtype=config.torch_dtype, **init_kwargs)
|
703 |
self.gradient_checkpointing = False
|
704 |
|
705 |
+
def get_input_embeddings(self):
|
706 |
+
return self.embedding.word_embeddings
|
707 |
+
|
708 |
def forward(
|
709 |
self,
|
710 |
input_ids,
|