import math
from typing import Optional, Tuple, Union

import torch
from torch import nn
from transformers.cache_utils import Cache
from transformers.modeling_outputs import BaseModelOutputWithPast

# The EXAONE building blocks (ExaoneSelfAttention, ExaoneAttention, ExaoneBlock,
# ExaoneForCausalLM, repeat_kv) and the Beagle helpers (BeagleAttentionMixin,
# BeagleMixin, recycle_vram) are pulled in via the star imports below.
from .modeling_exaone import *
from beagle.mixin import *


class ExaoneBeagleAttention_(ExaoneSelfAttention, BeagleAttentionMixin):
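    """EXAONE self-attention whose Q/K/V path is provided by the Beagle mixin.

    Apart from `qkv_transform`, the eager attention computation below is kept
    identical to the original EXAONE implementation.
    """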
    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = False,
        cache_position: Optional[torch.LongTensor] = None,
        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        **kwargs,  # ['encoder_hidden_states', 'encoder_position_embeddings']
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:

        bsz, q_len, _ = hidden_states.size()
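        # The Beagle mixin takes over the Q/K/V projection, rotary embeddings
        # and KV-cache update; extra kwargs such as `encoder_hidden_states`
        # are consumed there.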
        query_states, key_states, value_states = self.qkv_transform(
            hidden_states, past_key_value, use_cache, position_embeddings, **kwargs)

        ################################################
        ### everything below is kept as in the original
        ################################################

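        # Expand the shared key/value heads so every query head gets its own
        # copy (grouped-query attention).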
        key_states = repeat_kv(key_states, self.num_key_value_groups)
        value_states = repeat_kv(value_states, self.num_key_value_groups)

        attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)

        if attention_mask is not None:
            causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
            attn_weights = attn_weights + causal_mask

        attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
        attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout_rate, training=self.training)
        attn_output = torch.matmul(attn_weights, value_states)

        if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
            raise ValueError(
                f"Attention outputs should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
                f" {attn_output.size()}"
            )

        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.reshape(bsz, q_len, self.embed_dim).contiguous()

        attn_output = self.out_proj(attn_output)

        if not output_attentions:
            attn_weights = None

        return attn_output, attn_weights, past_key_value


class ExaoneBeagleAttention(ExaoneAttention):
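    """Drop-in replacement for `ExaoneAttention` that builds the Beagle self-attention."""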
    def __init__(self, config, layer_id=0):
        super().__init__(config, layer_id)
        self.attention = ExaoneBeagleAttention_(config, self.layer_id)


class ExaoneBeagleLayer(ExaoneBlock):
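    """EXAONE decoder block used as the Beagle draft layer.

    Unless `config.beagle_use_fc_eagle` is set, the block's attention module is
    replaced with `ExaoneBeagleAttention`.
    """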
    def __init__(self, config, layer_id):
        super().__init__(config, layer_id)

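        # Drop the block's original attention module, free its VRAM, and
        # rebuild it as a Beagle attention (layer_id pinned to 0).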
        if not config.beagle_use_fc_eagle:
            delattr(self, 'attn')
            recycle_vram()
            self.attn = ExaoneBeagleAttention(
                config=config, layer_id=0
            )


class ExaoneForSpeculativeCausalLM(ExaoneForCausalLM, BeagleMixin):
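    """EXAONE causal LM with a single-layer Beagle speculative decoder attached."""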
    _no_split_modules = ["ExaoneBlock", "ExaoneBeagleLayer"]

    def __init__(self, config):
        super().__init__(config)

        BeagleMixin.__init__(self, config)
        self.speculative_decoder = ExaoneBeagleLayer(config, layer_id=0)

        self.post_init()

    def forward(self, *args, **kwargs) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPast]:
        return self.beagle_forward(*args, **kwargs)
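

# ---------------------------------------------------------------------------
# Hypothetical usage sketch (not part of the original module): it assumes an
# `ExaoneConfig` class is exported by `.modeling_exaone` and that the config
# carries the `beagle_use_fc_eagle` flag referenced above; the checkpoint path
# is a placeholder.
#
#   config = ExaoneConfig.from_pretrained("path/to/exaone-checkpoint")
#   config.beagle_use_fc_eagle = False      # swap in ExaoneBeagleAttention
#   model = ExaoneForSpeculativeCausalLM(config)
#   print(model.speculative_decoder)        # the single Beagle draft block
# ---------------------------------------------------------------------------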