Update modeling_llama.py
Browse files · modeling_llama.py (+3 −4)
modeling_llama.py
CHANGED
@@ -1,7 +1,8 @@
 from typing import Optional, List, Union, Tuple
 
 import torch
-from
+from torch import nn
+from transformers import LlamaModel, Cache, DynamicCache, LlamaForCausalLM
 from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask_for_sdpa, \
     _prepare_4d_causal_attention_mask
 from transformers.modeling_outputs import BaseModelOutputWithPast
@@ -143,9 +144,7 @@ class MightyLlamaModel(LlamaModel):
         )
 
 
-class MightyLlamaForCausalLM(LlamaForCausalLM):
-    config_class = LlamaConfig
-
+class MightyLlamaForCausalLM(LlamaForCausalLM):
     def __init__(self, config):
         super().__init__(config)
         self.model = MightyLlamaModel(config)