# NEOX / tests / model / test_fused_kernels.py
# Hosting-page metadata: uploaded by akswelh ("Upload 251 files"), commit d90b3a8 (verified), 8.07 kB.
# Copyright (c) 2024, EleutherAI
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
import pytest
import torch
from transformers import BertTokenizer
from transformers.models.bert.modeling_bert import BertModel
from transformers import BertTokenizer, GPT2Tokenizer
from transformers.models.bert.modeling_bert import BertModel
from transformers.models.gpt2.modeling_gpt2 import GPT2Model
from megatron.fused_kernels import load
import transformers
# Raise the transformers log threshold to FATAL so that lower-severity
# messages from the library are suppressed while these tests run.
transformers.logging.set_verbosity(transformers.logging.FATAL)
@pytest.mark.xfail(reason="SystemExit: None")
def test_load_fused_kernels():
    """Smoke-test that the fused-kernel extension modules import after ``load()``.

    Calls ``load()`` and then imports each extension module by name. A success
    banner is printed when every import works; otherwise a failure banner is
    printed and the ``ImportError`` is re-raised so the test does not pass.
    """
    load()
    try:
        import scaled_masked_softmax_cuda
        import scaled_upper_triang_masked_softmax_cuda
        import fused_rotary_positional_embedding
        import torch
    except ImportError as import_error:
        print("[Fail] load_fused_kernels")
        raise import_error
    else:
        # Reached only when none of the imports above raised ImportError.
        print("[Success] load_fused_kernels")
@pytest.mark.xfail(reason="SystemExit: None")
def test_fused_softmax():
    """Compare the ``general`` fused scale-mask-softmax with the unfused path.

    Attention scores are produced from the first encoder layer of
    ``bert-base-cased`` (moved to CUDA and cast to fp16) for a batch of four
    copies of one sentence. The same scores and boolean mask are fed to two
    ``FusedScaleMaskSoftmax`` modules that differ only in ``fusion_type``
    (``SoftmaxFusionTypes.general`` vs ``SoftmaxFusionTypes.none``), and the
    mean absolute difference of their outputs is checked against ``1e-3``.

    Raises:
        AssertionError: if the mean absolute difference exceeds ``1e-3``
            (or is NaN). Previously this case only printed a "[Fail]" banner
            and the test still passed.
    """
    load()

    # Imported only after load() has run, as in the original ordering
    # (presumably the compiled extensions must exist first -- confirm).
    from megatron.model.fused_softmax import FusedScaleMaskSoftmax, SoftmaxFusionTypes
    from megatron.model.gpt2_model import (
        gpt2_attention_mask_func as attention_mask_func,
    )

    bert = BertModel.from_pretrained("bert-base-cased").cuda().half()
    tokenizer = BertTokenizer.from_pretrained("bert-base-cased")
    test_text = (
        "Hello. How are you? I am fine thank you and you? yes Good. "
        "hi hi hi hi hi hi hi hi hi hi hi hi hi"  # 32
    )
    tokens = tokenizer(
        [test_text] * 4,
        return_tensors="pt",
    )

    embedding_output = bert.embeddings(
        input_ids=tokens["input_ids"].cuda(),
        position_ids=None,
        token_type_ids=tokens["token_type_ids"].cuda(),
        inputs_embeds=None,
        past_key_values_length=0,
    )

    # (bsz, 1, 1, seq_len)
    mask = bert.get_extended_attention_mask(
        attention_mask=tokens["attention_mask"].cuda(),
        input_shape=tokens["input_ids"].shape,
        device=bert.device,
    )
    # (bsz, 1, seq_len, seq_len)
    mask = mask.repeat(1, 1, mask.size()[-1], 1)
    # Both softmax modules receive the mask as a boolean tensor that is True
    # wherever the extended mask is non-zero.
    bool_mask = mask != 0

    # Build raw attention scores from the first layer's query/key projections,
    # scaled by the square root of the key tensor's last dimension.
    attention = bert.encoder.layer[0].attention.self
    key_layer = attention.transpose_for_scores(attention.key(embedding_output))
    query_layer = attention.transpose_for_scores(attention.query(embedding_output))
    attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
    attention_scores /= math.sqrt(key_layer.size()[-1])

    fused_softmax = (
        FusedScaleMaskSoftmax(
            input_in_fp16=True,
            input_in_bf16=False,
            fusion_type=SoftmaxFusionTypes.general,
            mask_func=attention_mask_func,
            scale=None,
            softmax_in_fp32=False,
        )
        .cuda()
        .half()
    )
    fused_softmax_output = fused_softmax(attention_scores, bool_mask)

    # Reference module: identical configuration except fusion is disabled.
    torch_softmax = (
        FusedScaleMaskSoftmax(
            input_in_fp16=True,
            input_in_bf16=False,
            mask_func=attention_mask_func,
            fusion_type=SoftmaxFusionTypes.none,
            scale=None,
            softmax_in_fp32=False,
        )
        .cuda()
        .half()
    )
    torch_softmax_output = torch_softmax(attention_scores, bool_mask)

    # Reduce the absolute error one trailing dimension at a time, then take a
    # final mean so that `diff` is a single-element tensor.
    test_result = (fused_softmax_output - torch_softmax_output).abs()
    while test_result.dim() != 1:
        test_result = test_result.mean(dim=-1)
    diff = test_result.mean(dim=-1)

    # First five values of the last row, shown in the report for eyeballing.
    fused_values = fused_softmax_output[-1][-1][-1][:5].tolist()
    torch_values = torch_softmax_output[-1][-1][-1][:5].tolist()

    if diff <= 1e-3:
        print(
            f"\n[Success] test_fused_softmax"
            f"\n > mean_difference={diff}"
            f"\n > fused_values={fused_values}"
            f"\n > torch_values={torch_values}"
        )
    else:
        print(
            f"\n[Fail] test_fused_softmax"
            f"\n > mean_difference={diff}, "
            f"\n > fused_values={fused_values}, "
            f"\n > torch_values={torch_values}"
        )

    # Without this assertion a mismatch was only printed and the test passed.
    assert diff <= 1e-3, (
        f"fused softmax deviates from the unfused softmax: mean_difference={diff}"
    )
@pytest.mark.xfail(reason="SystemExit: None")
def test_fused_upper_triangle_mask_softmax():
    """Compare the ``upper_triang`` fused softmax with the unfused path.

    Attention weights are produced from the first block of the ``gpt2``
    checkpoint (moved to CUDA and cast to fp16) for a batch of four copies of
    one sentence. The same weights and boolean mask are fed to two
    ``FusedScaleMaskSoftmax`` modules that differ only in ``fusion_type``
    (``SoftmaxFusionTypes.upper_triang`` vs ``SoftmaxFusionTypes.none``), and
    the mean absolute difference of their outputs is checked against ``1e-3``.

    Raises:
        AssertionError: if the mean absolute difference exceeds ``1e-3``
            (or is NaN). Previously this case only printed a "[Fail]" banner
            and the test still passed.
    """
    load()

    # Imported only after load() has run, as in the original ordering
    # (presumably the compiled extensions must exist first -- confirm).
    from megatron.model.gpt2_model import (
        gpt2_attention_mask_func as attention_mask_func,
    )
    from megatron.model.fused_softmax import FusedScaleMaskSoftmax, SoftmaxFusionTypes

    gpt = GPT2Model.from_pretrained("gpt2").cuda().half()
    tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
    test_text = (
        "Hello. How are you? I am fine thank you and you? yes Good. "
        "hi hi hi hi hi hi hi"  # 24
    )
    tokens = tokenizer(
        [test_text] * 4,
        return_tensors="pt",
    )

    # Turn the tokenizer's 0/1 attention mask into an additive mask
    # (0 where attended, -10000 elsewhere) and repeat it along the
    # second-to-last dimension so it is square in its last two dimensions.
    attention_mask = tokens["attention_mask"].cuda()
    attention_mask = attention_mask.view(attention_mask.size(0), -1)
    attention_mask = attention_mask[:, None, None, :]
    attention_mask = (1.0 - attention_mask) * -10000.0
    attention_mask = attention_mask.repeat(1, 1, attention_mask.size()[-1], 1)

    # Raw (unscaled) attention weights from the first block's q/k projections.
    attn = gpt.h[0]
    hidden_states = gpt.wte(tokens["input_ids"].cuda())
    # c_attn output is split into three 768-wide chunks; the value chunk is
    # not needed for this comparison.
    q, k, _ = attn.attn.c_attn(hidden_states).split(768, dim=-1)
    q = attn.attn._split_heads(q, attn.attn.num_heads, attn.attn.head_dim)
    k = attn.attn._split_heads(k, attn.attn.num_heads, attn.attn.head_dim)
    attn_weights = torch.matmul(q, k.transpose(-1, -2))

    sq, sk = q.size(-2), k.size(-2)
    causal_mask = attn.attn.bias[:, :, sk - sq : sk, :sk].bool()
    # True wherever the position is outside `causal_mask` or the additive
    # attention mask is non-zero. Expected layout:
    #
    # tensor([[[[False, True, True, ..., True, True, True],
    #           [False, False, True, ..., True, True, True],
    #           [False, False, False, ..., True, True, True],
    #           ...,
    #           [False, False, False, ..., False, True, True],
    #           [False, False, False, ..., False, False, True],
    #           [False, False, False, ..., False, False, False]]]
    total_mask = ~(causal_mask & (attention_mask == 0))

    fused_softmax = (
        FusedScaleMaskSoftmax(
            input_in_fp16=True,
            input_in_bf16=False,
            mask_func=attention_mask_func,
            fusion_type=SoftmaxFusionTypes.upper_triang,
            scale=None,
            softmax_in_fp32=False,
        )
        .cuda()
        .half()
    )
    fused_softmax_output = fused_softmax(
        attn_weights,
        total_mask,
    )

    # Reference module: identical configuration except fusion is disabled.
    torch_softmax = (
        FusedScaleMaskSoftmax(
            input_in_fp16=True,
            input_in_bf16=False,
            fusion_type=SoftmaxFusionTypes.none,
            mask_func=attention_mask_func,
            scale=None,
            softmax_in_fp32=False,
        )
        .cuda()
        .half()
    )
    torch_softmax_output = torch_softmax(
        attn_weights,
        total_mask,
    )

    # Reduce the absolute error one trailing dimension at a time, then take a
    # final mean so that `diff` is a single-element tensor.
    test_result = (fused_softmax_output - torch_softmax_output).abs()
    while test_result.dim() != 1:
        test_result = test_result.mean(dim=-1)
    diff = test_result.mean(dim=-1)

    # First five values of the last row, shown in the report for eyeballing.
    fused_values = fused_softmax_output[-1][-1][-1][:5].tolist()
    torch_values = torch_softmax_output[-1][-1][-1][:5].tolist()

    if diff <= 1e-3:
        print(
            f"\n[Success] test_fused_upper_triangle_mask_softmax"
            f"\n > mean_difference={diff}"
            f"\n > fused_values={fused_values}"
            f"\n > torch_values={torch_values}"
        )
    else:
        print(
            f"\n[Fail] test_fused_upper_triangle_mask_softmax"
            f"\n > mean_difference={diff}, "
            f"\n > fused_values={fused_values}, "
            f"\n > torch_values={torch_values}"
        )

    # Without this assertion a mismatch was only printed and the test passed.
    assert diff <= 1e-3, (
        "upper-triangular fused softmax deviates from the unfused softmax: "
        f"mean_difference={diff}"
    )