hywu commited on
Commit
4ad5d70
1 Parent(s): 86f47ef

update modeling file

Browse files
Files changed (1) hide show
  1. modeling_camelidae.py +3 -1
modeling_camelidae.py CHANGED
@@ -20,6 +20,7 @@
20
  """ PyTorch LLaMA model."""
21
  import math
22
  from typing import List, Optional, Tuple, Union
 
23
 
24
  import numpy as np
25
  import copy
@@ -52,7 +53,7 @@ logger = logging.get_logger(__name__)
52
 
53
  _CONFIG_FOR_DOC = "CamelidaeConfig"
54
 
55
-
56
  class MoEModelOutputWithPast(ModelOutput):
57
  last_hidden_state: torch.FloatTensor = None
58
  past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
@@ -61,6 +62,7 @@ class MoEModelOutputWithPast(ModelOutput):
61
  router_logits: Optional[Tuple[torch.FloatTensor]] = None
62
 
63
 
 
64
  class MoECausalLMOutputWithPast(ModelOutput):
65
  loss: Optional[torch.FloatTensor] = None
66
  aux_loss: Optional[torch.FloatTensor] = None
 
20
  """ PyTorch LLaMA model."""
21
  import math
22
  from typing import List, Optional, Tuple, Union
23
+ from dataclasses import dataclass
24
 
25
  import numpy as np
26
  import copy
 
53
 
54
  _CONFIG_FOR_DOC = "CamelidaeConfig"
55
 
56
+ @dataclass
57
  class MoEModelOutputWithPast(ModelOutput):
58
  last_hidden_state: torch.FloatTensor = None
59
  past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
 
62
  router_logits: Optional[Tuple[torch.FloatTensor]] = None
63
 
64
 
65
+ @dataclass
66
  class MoECausalLMOutputWithPast(ModelOutput):
67
  loss: Optional[torch.FloatTensor] = None
68
  aux_loss: Optional[torch.FloatTensor] = None