Update modeling_cogvlm.py: remove the dependency on triton
modeling_cogvlm.py  CHANGED  (+57 -6)
@@ -5,6 +5,7 @@ from typing import TYPE_CHECKING, Optional, Tuple, List, Union, Literal, Dict, Any
 import math
 import torch
 from torch import nn
+from torch.nn import functional as F
 from torch.nn import CrossEntropyLoss
 from torchvision import transforms
 from einops import rearrange
@@ -15,7 +16,6 @@ from transformers.activations import ACT2FN
 from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast

 from .configuration_cogvlm import CogVLMConfig
-from .util import FastRotaryEmbedding
 from .visual import EVA2CLIPModel

 if TYPE_CHECKING:
@@ -144,6 +144,57 @@ def attention_fn(
     return context_layer


+class RotaryEmbedding(torch.nn.Module):
+    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
+        super().__init__()
+
+        self.dim = dim
+        self.max_position_embeddings = max_position_embeddings
+        self.base = base
+        inv_freq = self._compute_inv_freq(device)
+        self.register_buffer("inv_freq", inv_freq)
+        self.max_seq_len_cached = 0
+
+    def _compute_inv_freq(self, device=None):
+        return 1.0 / (
+            self.base
+            ** (torch.arange(0, self.dim, 2, device=device) / self.dim)
+        )
+
+    def _set_cos_sin_cache(self, seq_len, device, dtype):
+        self.max_seq_len_cached = seq_len
+        t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
+
+        freqs = torch.einsum("i,j->ij", t, self.inv_freq)
+        # Different from paper, but it uses a different permutation in order to obtain the same calculation
+        emb = torch.cat((freqs, freqs), dim=-1)
+        self.register_buffer("cos_cached", emb.cos()[:, None, :].to(dtype), persistent=False)
+        self.register_buffer("sin_cached", emb.sin()[:, None, :].to(dtype), persistent=False)
+
+    def forward(self, x, seq_len):
+        # x: [bs, num_attention_heads, seq_len, head_size]
+        if seq_len > self.max_seq_len_cached:
+            self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
+
+        return (
+            self.cos_cached[:seq_len, ...].to(dtype=x.dtype),
+            self.sin_cached[:seq_len, ...].to(dtype=x.dtype),
+        )
+
+
+def rotate_half(x):
+    x1, x2 = x[..., :x.shape[-1] // 2], x[..., x.shape[-1] // 2:]
+    return torch.cat((-x2, x1), dim=x1.ndim - 1)
+
+
+def apply_rotary_pos_emb_index_bhs(q, k, cos, sin, position_id):
+    # batch_size, num_head, seq_len, hidden_size
+    cos, sin = F.embedding(position_id, cos.squeeze(1)).unsqueeze(1), \
+               F.embedding(position_id, sin.squeeze(1)).unsqueeze(1)
+    q, k = (q * cos) + (rotate_half(q) * sin), (k * cos) + (rotate_half(k) * sin)
+    return q, k
+
+
 class VisionExpertAttention(nn.Module):
     def __init__(self, config):
         super().__init__()
@@ -153,8 +204,7 @@ class VisionExpertAttention(nn.Module):
         self.head_dim = self.hidden_size // self.num_heads
         self.max_position_embeddings = config.max_position_embeddings

-
-        self.rotary_emb = FastRotaryEmbedding(dim=self.head_dim, pos_idx_in_fp32=False)
+        self.rotary_emb = RotaryEmbedding(self.head_dim)
         self.vision_expert_query_key_value = nn.Linear(self.hidden_size, self.hidden_size * 3, bias=False)
         self.vision_expert_dense = nn.Linear(self.hidden_size, self.hidden_size, bias=False)
         self.language_expert_query_key_value = nn.Linear(self.hidden_size, self.hidden_size * 3, bias=False)
@@ -193,8 +243,8 @@
         kv_seq_len = key_states.shape[-2]
         if past_key_value is not None:
             kv_seq_len += past_key_value[0].shape[-2]
-
-        query_states, key_states =
+        cos, sin = self.rotary_emb(value_states, seq_len=position_ids.max() + 1)
+        query_states, key_states = apply_rotary_pos_emb_index_bhs(query_states, key_states, cos, sin, position_ids)

         if past_key_value is not None:
             key_states = torch.cat([past_key_value[0], key_states], dim=2)
@@ -706,7 +756,8 @@ class CogVLMForCausalLM(CogVLMPreTrainedModel):
        # update token_type_ids with last value
        if "token_type_ids" in model_kwargs:
            token_type_ids = model_kwargs["token_type_ids"]
-            new_token_type_ids = torch.ones(size=(token_type_ids.shape[0], 1), dtype=token_type_ids.dtype, device=token_type_ids.device) * LANGUAGE_TOKEN_TYPE
+            new_token_type_ids = torch.ones(size=(token_type_ids.shape[0], 1), dtype=token_type_ids.dtype,
+                                            device=token_type_ids.device) * LANGUAGE_TOKEN_TYPE
            model_kwargs["token_type_ids"] = torch.cat([token_type_ids, new_token_type_ids], dim=-1)

        if not is_encoder_decoder:
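For context, the change drops the FastRotaryEmbedding imported from .util (which, per the commit message, is what pulled in triton) and rebuilds the rotary tables with plain PyTorch ops (einsum, cos/sin, F.embedding). Below is a minimal, self-contained sketch of the same rotary scheme the diff introduces. The helper names (build_cos_sin, apply_rotary) and the tensor sizes are illustrative only, not part of the model code; the sketch assumes the [batch, heads, seq, head_dim] layout used by VisionExpertAttention and condenses the cos/sin cache into a plain function.

    # Minimal sketch (illustrative shapes): cos/sin tables indexed per token via
    # F.embedding on position_ids, then the standard rotate-half application to
    # query/key tensors laid out as [batch, heads, seq, head_dim].
    import torch
    from torch.nn import functional as F


    def rotate_half(x):
        # Split the last dimension in half and rotate: (x1, x2) -> (-x2, x1).
        x1, x2 = x[..., :x.shape[-1] // 2], x[..., x.shape[-1] // 2:]
        return torch.cat((-x2, x1), dim=-1)


    def build_cos_sin(dim, max_pos, base=10000, dtype=torch.float32):
        # Same construction as RotaryEmbedding._set_cos_sin_cache in the diff:
        # inverse frequencies over even indices, outer product with positions,
        # and the two halves concatenated so cos/sin cover the full head_dim.
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
        t = torch.arange(max_pos, dtype=inv_freq.dtype)
        freqs = torch.einsum("i,j->ij", t, inv_freq)
        emb = torch.cat((freqs, freqs), dim=-1)          # [max_pos, dim]
        return emb.cos().to(dtype), emb.sin().to(dtype)  # each [max_pos, dim]


    def apply_rotary(q, k, cos, sin, position_ids):
        # F.embedding gathers the per-token cos/sin rows; unsqueeze adds the head axis
        # so the tables broadcast against [batch, heads, seq, head_dim].
        cos = F.embedding(position_ids, cos).unsqueeze(1)  # [b, 1, s, d]
        sin = F.embedding(position_ids, sin).unsqueeze(1)
        q = (q * cos) + (rotate_half(q) * sin)
        k = (k * cos) + (rotate_half(k) * sin)
        return q, k


    if __name__ == "__main__":
        b, h, s, d = 2, 4, 16, 32                      # illustrative sizes
        q = torch.randn(b, h, s, d)
        k = torch.randn(b, h, s, d)
        position_ids = torch.arange(s).expand(b, s)    # [b, s]
        cos, sin = build_cos_sin(d, max_pos=int(position_ids.max()) + 1)
        q_rot, k_rot = apply_rotary(q, k, cos, sin, position_ids)
        print(q_rot.shape, k_rot.shape)                # torch.Size([2, 4, 16, 32]) twice

The rotated tensors keep the input shape, and everything here runs on stock torch kernels, which is consistent with the commit's goal of not requiring triton at load or inference time.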