"""
helpers for lora embeddings
"""
def get_linear_embedding_layers(model_type: str) -> list[str]:
    """
    Returns the linear embedding layer names needed for LoRA, depending on the model architecture.
    """
    if model_type == "gpt_neox":
        return ["embed_in", "embed_out"]
    if model_type == "falcon":
        return ["word_embeddings", "lm_head"]
    # default: llama-style architectures name these modules embed_tokens/lm_head
    return ["embed_tokens", "lm_head"]