winglian committed on
Commit: 04bb993
1 Parent(s): e0668ab

Update modeling_llama.py

Files changed (1)
  1. modeling_llama.py +3 -4
modeling_llama.py CHANGED
@@ -1,7 +1,8 @@
 from typing import Optional, List, Union, Tuple
 
 import torch
-from transformers import LlamaConfig, LlamaModel, Cache, DynamicCache
+from torch import nn
+from transformers import LlamaModel, Cache, DynamicCache, LlamaForCausalLM
 from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask_for_sdpa, \
     _prepare_4d_causal_attention_mask
 from transformers.modeling_outputs import BaseModelOutputWithPast
@@ -143,9 +144,7 @@ class MightyLlamaModel(LlamaModel):
         )
 
 
-class MightyLlamaForCausalLM(LlamaForCausalLM):
-    config_class = LlamaConfig
-
+class MightyLlamaForCausalLM(LlamaForCausalLM):
     def __init__(self, config):
         super().__init__(config)
         self.model = MightyLlamaModel(config)
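
For quick reference, a minimal usage sketch of the class touched by this commit. It assumes modeling_llama.py is importable from the working directory; the tiny LlamaConfig values are illustrative only and not part of the commit.

```python
# Minimal usage sketch (not part of this commit). Assumes modeling_llama.py
# is on the Python path; the small LlamaConfig below is illustrative only.
from transformers import LlamaConfig

from modeling_llama import MightyLlamaForCausalLM

# A tiny config so the example instantiates quickly; a real model would be
# loaded from an actual checkpoint with from_pretrained instead.
config = LlamaConfig(
    hidden_size=64,
    intermediate_size=128,
    num_hidden_layers=2,
    num_attention_heads=4,
    vocab_size=1000,
)

model = MightyLlamaForCausalLM(config)
# The causal-LM wrapper now builds the customized backbone defined in this file.
print(type(model.model).__name__)  # -> MightyLlamaModel
```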