File size: 39,564 Bytes
b9012bf 35354b2 b9012bf 35354b2 b9012bf 35354b2 b9012bf 35354b2 b9012bf 35354b2 b9012bf 35354b2 b9012bf 35354b2 b9012bf 35354b2 b9012bf 35354b2 b9012bf 35354b2 b9012bf 35354b2 b9012bf 35354b2 b9012bf 35354b2 b9012bf 35354b2 b9012bf 35354b2 b9012bf 35354b2 b9012bf 35354b2 b9012bf 35354b2 b9012bf 35354b2 b9012bf 35354b2 b9012bf 35354b2 b9012bf 35354b2 b9012bf 35354b2 b9012bf 35354b2 b9012bf 35354b2 b9012bf 35354b2 b9012bf 35354b2 b9012bf 35354b2 b9012bf 35354b2 b9012bf 35354b2 b9012bf 35354b2 b9012bf 35354b2 b9012bf 35354b2 b9012bf 35354b2 b9012bf 35354b2 b9012bf 35354b2 b9012bf 35354b2 b9012bf 35354b2 b9012bf 35354b2 b9012bf 35354b2 b9012bf 35354b2 b9012bf 35354b2 b9012bf 35354b2 b9012bf 35354b2 b9012bf 35354b2 b9012bf 35354b2 b9012bf 35354b2 b9012bf 35354b2 b9012bf 35354b2 b9012bf 35354b2 b9012bf 35354b2 b9012bf 35354b2 b9012bf 35354b2 b9012bf 35354b2 b9012bf 35354b2 b9012bf 35354b2 b9012bf 35354b2 b9012bf 35354b2 b9012bf 35354b2 b9012bf 35354b2 b9012bf 35354b2 b9012bf 35354b2 b9012bf 35354b2 b9012bf 35354b2 b9012bf 35354b2 b9012bf 35354b2 b9012bf 35354b2 b9012bf 35354b2 b9012bf 35354b2 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 525 526 527 
528 529 530 531 532 533 534 535 536 537 538 539 540 541 542 543 544 545 546 547 548 549 550 551 552 553 554 555 556 557 558 559 560 561 562 563 564 565 566 567 568 569 570 571 572 573 574 575 576 577 578 579 580 581 582 583 584 585 586 587 588 589 590 591 592 593 594 595 596 597 598 599 600 601 602 603 604 605 606 607 608 609 610 611 612 613 614 615 616 617 618 619 620 621 622 623 624 625 626 627 628 629 630 631 632 633 634 635 636 637 638 639 640 641 642 643 644 645 646 647 648 649 650 651 652 653 654 655 656 657 658 659 660 661 662 663 664 665 666 667 668 669 670 671 672 673 674 675 676 677 678 679 680 681 682 683 684 685 686 687 688 689 690 691 692 693 694 695 696 697 698 699 700 701 702 703 704 705 706 707 708 709 710 711 712 713 714 715 716 717 718 719 720 721 722 723 724 725 726 727 728 729 730 731 732 733 734 735 736 737 738 739 740 741 742 743 744 745 746 747 748 749 750 751 752 753 754 755 756 757 758 759 760 761 762 763 764 765 766 767 768 769 770 771 772 773 774 775 776 777 778 779 780 781 782 783 784 785 786 787 788 789 790 791 792 793 794 795 796 797 798 799 800 801 802 803 804 805 806 807 808 809 810 811 812 813 814 815 816 817 818 819 820 821 822 823 824 825 826 827 828 829 830 831 832 833 834 835 836 837 838 839 840 841 842 843 844 845 846 847 848 849 850 851 852 853 854 855 856 857 858 859 860 861 862 863 864 865 866 867 868 869 870 871 872 873 874 875 876 877 878 879 880 881 882 883 884 885 886 887 888 889 890 891 892 893 894 895 896 897 898 899 900 901 902 903 904 905 906 907 908 909 910 911 912 913 914 915 916 917 918 919 920 921 922 923 924 925 926 927 928 929 930 931 932 933 934 935 936 937 938 939 940 941 942 943 944 945 946 947 948 949 950 951 952 953 954 955 956 957 958 959 960 961 962 963 964 965 966 967 968 969 970 971 972 973 974 975 976 977 978 979 980 981 982 983 984 985 986 987 988 989 990 991 992 993 994 995 996 997 998 999 1000 1001 1002 1003 1004 1005 1006 1007 1008 1009 1010 1011 1012 1013 1014 1015 1016 1017 1018 1019 1020 1021 
1022 |
# See: https://huggingface.co/docs/transformers/custom_models
from typing import Optional, Tuple, Union, List
import math
import copy
import sys
from importlib import import_module
import torch
from torch import nn, Tensor
import torch.nn.init as init
from torch.nn import functional as F
from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutput, CausalLMOutputWithPast
from transformers import (
PreTrainedModel,
PretrainedConfig,
AutoConfig,
AutoModel,
AutoModelForCausalLM,
)
from transformers.utils import logging
from transformers.cache_utils import Cache, DynamicCache
from transformers.utils import (
is_flash_attn_2_available,
is_flash_attn_greater_or_equal_2_10,
)
if is_flash_attn_2_available():
from flash_attn import flash_attn_qkvpacked_func, flash_attn_func
# Module-level logger, obtained from transformers' logging utilities.
logger = logging.get_logger(__name__)
# The model type string to bind. It is used as Config.model_type and as the key
# passed to AutoConfig.register() below.
model_type = "walsh-causal-v1"
class Config(PretrainedConfig):
    """
    Configuration for HFCausalModel.

    Besides the usual size hyper-parameters, each architectural component is named by a
    `*_cls` string and receives keyword arguments from the matching `*_args` dict. The
    strings are resolved with `get_dynamic_class`: a leading "." refers to a name in this
    module, otherwise the string is a dotted import path (e.g. 'torch.nn.ReLU'). Many
    model variations can therefore be described purely through configuration.

    Notes:
        - `hidden_size` is stored as `d_embed` (see `attribute_map`).
        - `embdding_cls` (sic) is the established config key name and is kept as-is.
        - `norm_args` is stored, but HFCausalModel._make_norm in this file does not read it.
    """
    model_type = model_type
    attribute_map = {
        "hidden_size": "d_embed",
    }

    def __init__(
        # All of these MUST have defaults, even if unused.
        self,
        vocab_size=16000,
        pad_index=None,
        hidden_size=1024,
        num_attention_heads=8,
        num_hidden_layers=6,
        max_sequence_length=2048,
        dim_feedforward=4096,
        dropout=0.1,
        loss_function="causal_loss",
        # Default class to use for each of these components.
        positional_encoder_cls='.PositionalEncoder',
        attention_cls='.CausalSelfAttention',
        activation_cls='torch.nn.ReLU',
        feedforward_cls='.FeedforwardLayer',
        layer_stack_cls='.TransformerLayerStack',
        layer_cls='.PostLayerNorm',
        transformer_cls='.Transformer',
        norm_cls='torch.nn.LayerNorm',
        embdding_cls='torch.nn.Embedding',
        output_proj_cls='torch.nn.Linear',
        # Arg groups, passed to the factory classes above. None means "use the default";
        # a fresh dict is built per instance (see below).
        positional_encoder_args=None,
        transformer_args=None,
        attention_args=None,
        feedforward_args=None,
        activation_args=None,
        norm_args=None,
        layer_stack_args=None,
        layer_args=None,
        embedding_args=None,
        output_proj_args=None,
        output_attentions=False,
        output_hidden_states=False,
        use_cache=True,
        **kwargs,
    ):
        # Dict-valued arguments default to None and are replaced here by a new dict per
        # instance. With a dict literal as the default, every default-constructed config
        # would share the same dict object, and mutating one would change them all.
        if positional_encoder_args is None:
            positional_encoder_args = {
                'd_model': 1024,
                'max_seq_len': 2048,
            }
        if norm_args is None:
            norm_args = {
                'normalized_shape': 1024,
            }
        transformer_args = {} if transformer_args is None else transformer_args
        attention_args = {} if attention_args is None else attention_args
        feedforward_args = {} if feedforward_args is None else feedforward_args
        activation_args = {} if activation_args is None else activation_args
        layer_stack_args = {} if layer_stack_args is None else layer_stack_args
        layer_args = {} if layer_args is None else layer_args
        embedding_args = {} if embedding_args is None else embedding_args
        output_proj_args = {} if output_proj_args is None else output_proj_args

        self.vocab_size = vocab_size
        self.pad_index = pad_index
        self.hidden_size = hidden_size
        self.num_attention_heads = num_attention_heads
        self.num_hidden_layers = num_hidden_layers
        self.max_sequence_length = max_sequence_length
        self.loss_function = loss_function
        self.dim_feedforward = dim_feedforward
        self.dropout = dropout
        self.positional_encoder_cls = positional_encoder_cls
        self.attention_cls = attention_cls
        self.activation_cls = activation_cls
        self.feedforward_cls = feedforward_cls
        self.layer_stack_cls = layer_stack_cls
        self.layer_cls = layer_cls
        self.transformer_cls = transformer_cls
        self.norm_cls = norm_cls
        self.embdding_cls = embdding_cls
        self.output_proj_cls = output_proj_cls
        self.positional_encoder_args = positional_encoder_args
        self.transformer_args = transformer_args
        self.attention_args = attention_args
        self.feedforward_args = feedforward_args
        self.activation_args = activation_args
        self.norm_args = norm_args
        self.layer_stack_args = layer_stack_args
        self.layer_args = layer_args
        self.embedding_args = embedding_args
        self.output_proj_args = output_proj_args
        self.output_attentions = output_attentions
        self.output_hidden_states = output_hidden_states
        self.use_cache = use_cache
        super().__init__(**kwargs)
def causal_loss(logits: Tensor, labels: Tensor, input_ids: Tensor, ignore_index=-100) -> Tensor:
    """
    Mean next-token cross-entropy loss.

    The prediction at position t is scored against the label at position t + 1, so the
    last prediction and the first label are dropped before flattening.

    Args:
        logits: Scores whose last dimension is the vocabulary.
        labels: Target token ids aligned with `logits` (same leading dimensions).
        input_ids: Unused; accepted so all loss functions share one signature.
        ignore_index: Label value excluded from the loss.

    Returns:
        Scalar loss tensor. A NaN (e.g. every label ignored) is mapped to 0 by nan_to_num().
    """
    vocab_size = logits.size(-1)
    predictions = logits[..., :-1, :].reshape(-1, vocab_size)
    targets = labels[..., 1:].reshape(-1)
    mean_loss = torch.nn.functional.cross_entropy(
        predictions,
        targets,
        ignore_index=ignore_index,
        reduction='mean',
    )
    return mean_loss.nan_to_num()
# Learning to Break the Loop: Analyzing and Mitigating Repetitions for Neural Text Generation
# https://arxiv.org/abs/2206.02369
def ditto_loss(logits: Tensor, labels: Tensor, input_ids: Tensor) -> Tensor:
    """
    DITTO repetition-penalty loss combined with causal loss on a context prefix.

    For each batch row, `labels[row, :3]` is read as (context_len, sentence_len, n_repeats).
    The row is laid out as [context][sentence][target ...]. The loss is the sum of:
      - causal_loss over the first context_len tokens, and
      - a repetition term that, at every target position, compares the full predicted
        distribution with gamma times the (detached) distribution at the matching position
        of the first sentence occurrence. That occurrence is tiled n_repeats times and
        trimmed to the target length: -mean(log(clamp(1 - |p - gamma * p_base|, 1e-20))),
        scaled by a weight of 1e5.

    Args:
        logits: Unpacked below as (batch_size, seq_len, vocab_size).
        labels: Per-row layout metadata (first three columns), not token targets.
        input_ids: Token ids; the context prefix is used as causal-loss targets.

    Returns:
        The per-row losses summed and divided by batch_size.
    """
    batch_size, _, _ = logits.shape
    gamma = 0.5
    weight = 1.0e5
    probs = torch.softmax(logits, dim=-1)
    total_loss = None
    for row in range(batch_size):
        prefix_len = labels[row, 0].item()
        sentence_len = labels[row, 1].item()
        repeats = labels[row, 2].item()
        first_sentence = slice(prefix_len, prefix_len + sentence_len)

        # Ordinary next-token loss over the context prefix only.
        prefix_ids = input_ids[row:row + 1, :prefix_len]
        context_loss = causal_loss(
            logits=logits[row:row + 1, :prefix_len],
            labels=prefix_ids,
            input_ids=prefix_ids,
        )

        # Everything after the first sentence is the repeated region.
        target = probs[row, prefix_len + sentence_len:, :]
        # Baseline: first sentence's distributions, detached so no gradient flows into them,
        # tiled n_repeats times, then trimmed to the target length.
        baseline = probs[row, first_sentence, :].detach().repeat(repeats, 1)
        baseline = baseline[:target.size(0), :]

        closeness = 1.0 - torch.abs(target - baseline * gamma)
        rep_loss = torch.log(closeness.clamp(min=1e-20)).mean().neg() * weight

        row_loss = context_loss + rep_loss
        total_loss = row_loss if total_loss is None else total_loss + row_loss
    return total_loss / batch_size
# Dynamically lookup class name and return factory for class.
def get_dynamic_class(name):
    """
    Resolve a class or function from its name.

    Accepted forms:
        "pkg.module.Name" -> attribute `Name` of the imported module `pkg.module`.
        ".Name" or "Name" -> attribute `Name` of this module.

    Raises:
        ImportError: The module cannot be imported or does not define the attribute.
            The original exception is chained as the cause.
    """
    # rpartition (unlike rsplit + 2-tuple unpacking) also handles names without any dot,
    # such as the Config default loss_function="causal_loss", which previously raised
    # an unhelpful ValueError.
    module_path, _, attr_name = name.rpartition('.')
    try:
        if module_path == "":
            return getattr(sys.modules[__name__], attr_name)
        module = import_module(module_path)
        return getattr(module, attr_name)
    except (ImportError, AttributeError) as e:
        raise ImportError(name) from e
# An easily extensible dynamic transformer class
# Many variations can be specified entirely in the configuration, without touching this code.
class HFCausalModel(PreTrainedModel):
config_class = Config
model_type = 'Transformer'
supports_gradient_checkpointing = True
# Presently needs to be manually set to match transformer layer class...
_no_split_modules = ["DeepNetLayer"]
_supports_flash_attn_2 = True
_supports_sdpa = True
_supports_cache_class = True
_skip_keys_device_placement = "past_key_values"
def __init__(self, config):
super().__init__(config)
self.d_model = config.hidden_size
self.transformer_head = self._make_transformer(config)
self.loss_function = get_dynamic_class(config.loss_function)
self.gradient_checkpointing = False
self.post_init()
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
token_type_ids: Optional[torch.LongTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
labels: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
**kwargs,
) -> (Tensor, dict[str, Tensor]):
batch_size, seq_len = input_ids.shape
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
use_cache = use_cache if use_cache is not None else self.config.use_cache
if use_cache:
# If legacy cache, convert to DynamicCache
use_legacy_cache = not isinstance(past_key_values, Cache)
if use_legacy_cache:
past_key_values = DynamicCache.from_legacy_cache(past_key_values)
if self.gradient_checkpointing and self.training:
if use_cache:
logger.warning_once(
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
)
use_cache = False
gradient_checkpointing_func = self._gradient_checkpointing_func
else:
gradient_checkpointing_func = None
outputs = self.transformer_head(
input_ids=input_ids,
position_ids=position_ids,
output_attentions=output_attentions,
gradient_checkpointing_func=gradient_checkpointing_func,
past_key_values=past_key_values,
use_cache=use_cache,
output_hidden_states=output_hidden_states,
)
logits = outputs["logits"].float()
attentions = outputs["attentions"]
# Compute loss.
if labels is not None:
loss = self.loss_function(logits=logits, labels=labels, input_ids=input_ids)
else:
loss = None
# Convert back to legacy cache, if that's what we received
new_cache = outputs["past_key_values"]
if use_cache and new_cache is not None and use_legacy_cache:
new_cache = new_cache.to_legacy_cache()
return CausalLMOutputWithPast(
loss=loss,
logits=logits,
past_key_values=new_cache,
hidden_states=outputs["hidden_states"],
attentions=outputs["attentions"],
)
# Implementation from Huggingface Transformers,
# https://github.com/huggingface/transformers/blob/main/src/transformers/models/mistral/modeling_mistral.py
# Note: We do not implement attention mask at present, so some of this code is not applicable
# TODO: Reenable attention mask support for batch inference..
def prepare_inputs_for_generation(
self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
):
# Omit tokens covered by past_key_values
if past_key_values is not None:
if isinstance(past_key_values, Cache):
cache_length = past_key_values.get_seq_length()
past_length = past_key_values.seen_tokens
max_cache_length = past_key_values.get_max_length()
else:
cache_length = past_length = past_key_values[0][0].shape[2]
max_cache_length = None
# Keep only the unprocessed tokens:
# 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
# some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as
# input)
if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
# 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
# input_ids based on the past_length.
elif past_length < input_ids.shape[1]:
input_ids = input_ids[:, past_length:]
# 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.
# If we are about to go beyond the maximum cache length, we need to crop the input attention mask.
if (
max_cache_length is not None
and attention_mask is not None
and cache_length + input_ids.shape[1] > max_cache_length
):
attention_mask = attention_mask[:, -max_cache_length:]
position_ids = kwargs.get("position_ids", None)
if attention_mask is not None and position_ids is None:
# create position_ids on the fly for batch generation
position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)
if past_key_values:
position_ids = position_ids[:, -input_ids.shape[1] :]
# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
# NOTE: Injecting positional embeddings is not yet supported.
if inputs_embeds is not None and past_key_values is None:
model_inputs = {"inputs_embeds": inputs_embeds}
else:
model_inputs = {"input_ids": input_ids}
model_inputs.update(
{
"position_ids": position_ids,
"past_key_values": past_key_values,
"use_cache": kwargs.get("use_cache"),
"attention_mask": attention_mask,
}
)
return model_inputs
@staticmethod
def _reorder_cache(past_key_values, beam_idx):
reordered_past = ()
for layer_past in past_key_values:
reordered_past += (
tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
)
return reordered_past
def _make_embedding(self, config):
embedding_cls = get_dynamic_class(config.embdding_cls)
return embedding_cls(config.vocab_size, self.d_model, config.pad_index, **config.embedding_args)
def _make_pos_encoder(self, config):
pos_enc_cls = get_dynamic_class(config.positional_encoder_cls)
return pos_enc_cls(**config.positional_encoder_args)
def _make_output_projection(self, config):
output_proj_cls = get_dynamic_class(config.output_proj_cls)
return output_proj_cls(self.d_model, config.vocab_size, **config.output_proj_args)
def _make_dropout(self, config):
return nn.Dropout(config.dropout)
def _make_activation(self, config):
activation_cls = get_dynamic_class(config.activation_cls)
return activation_cls(**config.activation_args)
def _make_norm(self, config):
norm_cls = get_dynamic_class(config.norm_cls)
return norm_cls(self.d_model)
def _make_self_attention(self, layer_idx, config):
attention_cls = get_dynamic_class(config.attention_cls)
# Map HF _attn_implementation to attn_type
match config._attn_implementation:
case "flash_attention_2":
if is_flash_attn_2_available():
if not is_flash_attn_greater_or_equal_2_10():
raise Exception("flash_attn_2 >= 2.10 is required")
attn_type = "flash2"
else:
attn_type = "torch"
case "sdpa":
attn_type = "torch"
case "eager":
attn_type = "native"
case _:
raise Exception(f"Unimplemented attention type '{config._attn_implementation}'")
return attention_cls(
d_model=self.d_model,
num_heads=config.num_attention_heads,
attn_type=attn_type,
layer_idx=layer_idx,
config=config,
**config.attention_args,
)
def _make_feedforward(self, layer_idx, config):
feedforward_cls = get_dynamic_class(config.feedforward_cls)
return feedforward_cls(
d_model=self.d_model,
feedforward_dim=config.dim_feedforward,
dropout=config.dropout,
activation=self._make_activation(config),
layer_idx=layer_idx,
**config.feedforward_args,
)
def _make_layer(self, layer_idx, config):
layer_cls = get_dynamic_class(config.layer_cls)
return layer_cls(
d_model=self.d_model,
dropout=self._make_dropout(config),
attention=self._make_self_attention(layer_idx, config),
feedforward=self._make_feedforward(layer_idx, config),
norm1=self._make_norm(config),
norm2=self._make_norm(config),
layer_idx=layer_idx,
**config.layer_args,
)
def _make_layer_stack(self, config):
layer_stack_cls = get_dynamic_class(config.layer_stack_cls)
return layer_stack_cls(
layers=nn.ModuleList([
self._make_layer(layer_idx, config) for layer_idx in range(config.num_hidden_layers)
]),
**config.layer_stack_args,
)
def _make_transformer(self, config):
transformer_cls = get_dynamic_class(config.transformer_cls)
return transformer_cls(
d_model=self.d_model,
embedding=self._make_embedding(config),
positional_encoder=self._make_pos_encoder(config),
layer_stack=self._make_layer_stack(config),
output_projection=self._make_output_projection(config),
**config.transformer_args,
)
@torch.no_grad()
def _init_weights(self, module):
pass
# Register model type and configuration:
# - AutoConfig maps the `model_type` string to Config.
# - AutoModelForCausalLM maps Config to HFCausalModel.
AutoConfig.register(model_type, Config)
AutoModelForCausalLM.register(Config, HFCausalModel)
# A generic container class for standard transformer components.
class Transformer(nn.Module):
    """
    Generic container wiring standard transformer components together.

    Pipeline: embedding(input_ids) * sqrt(d_model) -> positional_encoder -> layer_stack
    -> output_projection (logits). Extra keyword arguments are accepted and ignored.
    """
    def __init__(self, d_model, embedding, positional_encoder, layer_stack, output_projection, **kwargs):
        super().__init__()
        self.embedding = embedding
        self.positional_encoder = positional_encoder
        self.layer_stack = layer_stack
        self.output_projection = output_projection
        self.d_model = d_model
        self.sqrt_d_model = d_model**0.5
        self.reset_parameters()

    def forward(
        self,
        input_ids,
        position_ids,
        output_attentions,
        gradient_checkpointing_func,
        past_key_values,
        use_cache,
        output_hidden_states,
    ):
        """
        Return the layer stack's output dict, with "last_hidden_state" replaced by "logits".
        """
        outputs = self.layer_stack(
            self.positional_encoder(self.embedding(input_ids) * self.sqrt_d_model, position_ids),
            output_attentions=output_attentions,
            gradient_checkpointing_func=gradient_checkpointing_func,
            past_key_values=past_key_values,
            use_cache=use_cache,
            output_hidden_states=output_hidden_states,
        )
        # Translate output states to logits.
        outputs["logits"] = self.output_projection(outputs.pop("last_hidden_state"))
        return outputs

    def reset_parameters(self):
        """
        Initialize the output projection (Xavier-uniform weight, zero bias if present) and
        the embedding (normal, std = d_model ** -0.5).
        """
        init.xavier_uniform_(self.output_projection.weight)
        # The projection may be configured without a bias (e.g. bias=False).
        if getattr(self.output_projection, "bias", None) is not None:
            init.constant_(self.output_projection.bias, 0.)
        init.normal_(self.embedding.weight, std=self.d_model**-0.5)
        # Re-randomizing the whole table would overwrite the padding row, which an
        # nn.Embedding with padding_idx keeps at zero and never updates during training.
        padding_idx = getattr(self.embedding, "padding_idx", None)
        if padding_idx is not None:
            with torch.no_grad():
                self.embedding.weight[padding_idx].zero_()
# Converts a tensor of integers into their binary codes: a new last dimension of length `bits` holding 0/1 (uint8) values, least-significant bit first.
def binary_tensor(x, bits):
    """
    Expand integers into bit vectors.

    Returns a uint8 tensor of shape x.shape + (bits,); entry [..., i] is bit i of the
    corresponding element (index 0 is the least-significant bit).
    """
    place_values = torch.pow(2, torch.arange(bits)).to(x.device, x.dtype)
    set_bits = x.unsqueeze(-1) & place_values
    return (set_bits != 0).to(torch.uint8)
def hadamard_walsh_matrix(k: int):
    """
    Build the 2^k x 2^k Hadamard-Walsh matrix (entries +/-1, float32), in natural order.

    Starting from H_1 = [[1, 1], [1, -1]], each step computes H_{n+1} = kron(H_1, H_n).

    Args:
        k: log2 of the matrix dimension; must be positive.
    """
    assert k > 0
    base = torch.tensor([[1, 1], [1, -1]], dtype=torch.float)
    result = base
    remaining = k - 1
    while remaining > 0:
        result = torch.kron(base, result)
        remaining -= 1
    return result
# This positional encoder adds absolute binary positions to the embedding, encoded via
# Hadamard-Walsh matrix.
# See: https://en.wikipedia.org/wiki/Hadamard_code
# Each bit in the binary code word is encoded via a row of the Hadamard-Walsh matrix, with a
# 1 being encoded by the presence of the row and a 0 by its absence. While training, the base
# sequence offset is randomly selected, which appears to allow the model to generalize to
# sequences longer than it was trained on. This is similar to what is described here:
# https://arxiv.org/pdf/2305.16843.pdf
# I have tried the approach from that paper and found that the approach used here generalizes better.
#
# Note: Without random shifting, the early performance of this encoder is exceptionally good.
# The drawback is that the model can't generalize to sequences longer than it was trained on
# and can't easily accommodate additional bits later in the training process.
class RSWalshPositionalEncoder(nn.Module):
    """
    Adds binary-coded absolute positions to the input, with each bit carried by a scaled
    row of a Hadamard-Walsh matrix ("RS" = random shift of the base offset while training).

    Args:
        d_embed: Embedding width; must equal the input's last dimension.
        max_seq: Number of positions encodable; ceil(log2(max_seq)) bits are used.
        gain: Scale applied to the Walsh rows.
    """
    def __init__(self, d_embed, max_seq, gain=0.333):
        super().__init__()
        self.max_seq = max_seq
        self.d_embed = d_embed
        # Hadamard-Walsh k, where the dimension of the matrix is 2^k
        k = math.ceil(math.log2(d_embed))
        # The number of bits required to encode max_seq
        bits = math.ceil(math.log2(max_seq))
        # Gain controls the weight given to the encodings.
        # When a trainable parameter, the value appears to settle at around 0.333.
        self.gain = gain
        assert bits <= d_embed, "max_seq exceeds n-bits available for d_embed"
        # Generate sequential binary codes for absolute positions: (max_seq, bits).
        # The implementation originally used Gray codes, where successive symbols
        # differ by only one bit. See: https://en.wikipedia.org/wiki/Gray_code
        # This, along with a few other coding schemes, was tested; a simple
        # binary code had the best performance.
        binary_code = binary_tensor(torch.arange(0, max_seq, 1), bits)
        self.register_buffer('binary_code', binary_code, persistent=False)
        # Each bit is encoded via a row of a Hadamard-Walsh matrix: (bits, d_embed).
        # We slice off the unused rows and columns -- ideally, d_embed should be
        # the same dimension as the matrix.
        walsh = hadamard_walsh_matrix(k)[:bits, :d_embed] * self.gain
        # This alternative appears superior to the original.
        # If starting from scratch, use this instead:
        # walsh = (hadamard_walsh_matrix(k)[:bits,:d_embed] -0.5) * self.gain
        self.register_buffer('walsh', walsh, persistent=False)

    def forward(self, x, position_ids=None):
        """
        Add positional encodings to x, whose second-to-last dimension is the sequence.

        Position selection:
          - training: a random base offset in [0, max_seq - seq_len];
          - otherwise, if position_ids is given: those positions (used with a cache);
          - otherwise: positions 0 .. seq_len - 1.
        """
        seq_len = x.size(-2)
        # We use a random base offset when training.
        # This results in slower initial gains, but appears to allow the model to generalize to
        # the value of max_seq, even if never trained with sequences of this length. I also have
        # a suspicion that this has a regularizing effect on training, similar to dropout. Models with
        # random base offset shifting, despite slower initial improvement, appear to perform better in the long-run.
        # TODO: Setup a controlled experiment to test this hypothesis.
        if self.training:
            shift = torch.randint(self.max_seq - seq_len + 1, (1,)).item()
            seq = self.binary_code[shift:seq_len + shift, :]
        # When the cache is used for generation, after the first call, we are only passed a single token at a time,
        # with the remaining tokens being in the cache. We need to make sure that the newly injected tokens have the
        # correct relative position by indexing the codes with the position_ids.
        elif position_ids is not None:
            seq = self.binary_code[position_ids, :]
        # Disable shifting when not training. This does not appear to change the evaluation loss, but
        # it does make predictions easier to analyse when the attention weights are not shifting with each step.
        else:
            seq = self.binary_code[:seq_len, :]
        # For reasons I have yet to identify, when the model is running in Textgenwebui, the matrix appears
        # to evade conversion to bfloat16, despite everything else having been converted.
        # This is a work-around for this.
        self.walsh = self.walsh.to(dtype=x.dtype)
        # Encode binary sequence with Hadamard-Walsh codes and apply to embeddings.
        # If nothing else, the Walsh encodings make the positional information exceptionally
        # robust with respect to dropout and other adversities. They can still be easily detected
        # at the final layer.
        return x + (seq.to(dtype=x.dtype) @ self.walsh)
# A generic stack of transformer layers.
class TransformerLayerStack(nn.Module):
    """
    Runs a sequence of transformer layers and collects optional per-layer outputs.

    Each layer is called as layer(hidden_states, output_attentions, past_key_values,
    use_cache) and must return a dict with "hidden_states", "attentions" and
    "past_key_values".
    """
    def __init__(self, layers):
        super().__init__()
        self.layers = layers

    def forward(
        self,
        hidden_states,
        output_attentions,
        past_key_values,
        use_cache,
        output_hidden_states,
        gradient_checkpointing_func=None,
    ):
        """
        Returns:
            dict with:
              last_hidden_state: output of the final layer;
              past_key_values: cache returned by the last layer when use_cache, else None;
              hidden_states: tuple (input, then each layer's output) when
                  output_hidden_states, else None;
              attentions: tuple of per-layer attentions when output_attentions, else None.
        """
        present_key_value = None
        all_attentions = [] if output_attentions else None
        all_hidden_states = [hidden_states] if output_hidden_states else None
        for layer in self.layers:
            if gradient_checkpointing_func is not None:
                layer_outputs = gradient_checkpointing_func(
                    layer.__call__,
                    hidden_states,
                    output_attentions,
                    past_key_values,
                    use_cache,
                    use_reentrant=False,
                )
            else:
                layer_outputs = layer(
                    hidden_states,
                    output_attentions,
                    past_key_values,
                    use_cache,
                )
            hidden_states = layer_outputs["hidden_states"]
            if output_hidden_states:
                all_hidden_states.append(hidden_states)
            if use_cache:
                present_key_value = layer_outputs["past_key_values"]
            if output_attentions:
                all_attentions.append(layer_outputs["attentions"])
        return dict(
            last_hidden_state=hidden_states,
            past_key_values=present_key_value,
            # Previously the final tensor was returned here, and the collected per-layer
            # states were silently dropped.
            hidden_states=tuple(all_hidden_states) if output_hidden_states else None,
            attentions=tuple(all_attentions) if output_attentions else None,
        )
# DeepNet: Scaling Transformers to 1,000 Layers
# https://arxiv.org/abs/2203.00555
# Note: This is a Post-Layer-Norm Transformer layer: each sublayer's output is added to the alpha-scaled residual and the sum is then normalized.
class DeepnetLayer(nn.Module):
    """
    Post-LayerNorm transformer layer with DeepNet residual scaling.

    Computes
        h = norm1(alpha * x + dropout(attention(x)))
        y = norm2(alpha * h + dropout(feedforward(h)))

    Args:
        d_model: Model width (stored only).
        attention: Module returning a dict with "hidden_states", "attentions",
            "past_key_values".
        feedforward, norm1, norm2, dropout: Sub-modules applied as shown above.
        layer_idx: Index of this layer in the stack (stored only).
        alpha: DeepNet residual scale; 1.0 gives a plain Post-LN layer.
    """
    def __init__(
        self,
        d_model,
        attention,
        feedforward,
        norm1,
        norm2,
        dropout,
        layer_idx,
        alpha=1.0,
    ):
        super().__init__()
        self.d_model = d_model
        self.attention = attention
        self.feedforward = feedforward
        self.norm1 = norm1
        self.norm2 = norm2
        self.dropout = dropout
        self.alpha = alpha
        self.layer_idx = layer_idx

    def forward(
        self,
        hidden_states,
        output_attentions,
        past_key_values,
        use_cache,
    ):
        # Attention sublayer, with the scaled residual taken before attention runs.
        scaled_input = hidden_states * self.alpha
        attn = self.attention(
            hidden_states,
            past_key_values=past_key_values,
            use_cache=use_cache,
            output_attentions=output_attentions
        )
        mixed = self.norm1(scaled_input + self.dropout(attn["hidden_states"]))

        # Feedforward sublayer.
        scaled_mixed = mixed * self.alpha
        ff_out = self.dropout(self.feedforward(mixed))
        return {
            "hidden_states": self.norm2(scaled_mixed + ff_out),
            "attentions": attn["attentions"],
            "past_key_values": attn["past_key_values"],
        }
# A vanilla MLP (feedforward) transformer sublayer.
class FeedforwardLayer(nn.Module):
    """
    Position-wise MLP: linear2(dropout(activation(linear1(x)))).

    Args:
        d_model: Input/output width.
        feedforward_dim: Hidden width.
        dropout: Dropout probability applied after the activation.
        layer_idx: Index of the layer in the stack (accepted, not used).
        activation: Activation module; a new nn.ReLU() is created when None.
        beta: Gain for Xavier-uniform initialization of both linear layers.
        bias: Whether the linear layers have biases.
    """
    def __init__(
        self,
        d_model: int,
        feedforward_dim: int,
        dropout,
        layer_idx,
        activation=None,
        beta=1.0,
        bias=True,
    ):
        super().__init__()
        self.d_model = d_model
        self.beta = beta
        # Create a ReLU per layer rather than using a module instance as a default argument,
        # which would be one object shared by every layer built with the default.
        self.activation = nn.ReLU() if activation is None else activation
        self.linear1 = nn.Linear(d_model, feedforward_dim, bias=bias)
        self.linear2 = nn.Linear(feedforward_dim, d_model, bias=bias)
        self.dropout = nn.Dropout(dropout)
        self.reset_parameters()

    def forward(self, x):
        return self.linear2(self.dropout(self.activation(self.linear1(x))))

    def reset_parameters(self):
        init.xavier_uniform_(self.linear1.weight, gain=self.beta)
        init.xavier_uniform_(self.linear2.weight, gain=self.beta)
        # With bias=False the biases are None and must be skipped.
        if self.linear1.bias is not None:
            init.constant_(self.linear1.bias, 0.)
        if self.linear2.bias is not None:
            init.constant_(self.linear2.bias, 0.)
class CausalSelfAttention(nn.Module):
def __init__(
self,
d_model,
num_heads,
# values:
# native: Use local impementation; slowest option; good for debugging; useful when experimenting with non-standard stuff.
# torch: Use pytorch "scaled_dot_product_attention()"; faster; generally good compatibility; does not support returning attn weights.
# flash2: Use Flash-Attention2 implementation; fastest; limited to int16 and bfloat16 types; least memory usage.
attn_type,
layer_idx,
config,
beta=1.0,
dropout=0.1,
):
super().__init__()
self.d_model = d_model
self.num_heads = num_heads
self.beta = beta
self.attn_type = attn_type
self.layer_idx = layer_idx
self.config = config
assert d_model % num_heads == 0, "d_model must be evenly divisible by num_heads"
# The dimension of each head.
self.d_head = d_model // num_heads
# We scale the attention scores by the inverse-square-root of the head dimension
# this shifts the temerature of softmax.
self.dot_product_scale = 1.0 / math.sqrt(self.d_head)
self.in_proj = nn.Linear(self.d_model, 3 * self.d_model, bias=True)
self.output_linear = nn.Linear(self.d_model, self.d_model, bias=True)
self.dropout = nn.Dropout(dropout)
self.reset_parameters()
def extra_repr(self) -> str:
    # Summary shown inside the module's repr().
    fields = (
        ("d_model", self.d_model),
        ("num_heads", self.num_heads),
        ("beta", self.beta),
        ("attn_type", self.attn_type),
        ("dropout", self.dropout),
    )
    return ", ".join(f"{key}={value}" for key, value in fields)
def reset_parameters(self):
    """
    DeepNet initialization (https://arxiv.org/pdf/2203.00555.pdf).

    Query and key weights use Xavier-uniform with gain 1; value and output weights use
    gain `beta`; all biases are zeroed. The three in_proj chunks are views, so the
    fused weight is initialized in place.
    """
    gains = (1.0, 1.0, self.beta)
    for block, gain in zip(self.in_proj.weight.chunk(3), gains):
        init.xavier_uniform_(block, gain=gain)
    init.xavier_uniform_(self.output_linear.weight, gain=self.beta)
    for bias in (self.in_proj.bias, self.output_linear.bias):
        init.constant_(bias, 0.)
# Project the input through the fused QKV matrix, split into heads as (batch_size, num_heads, seq_len, d_head), and, if a cache is given, pass the new keys/values through its update().
def project_input(self, qkv, past_key_values):
    """
    Returns (query, key, value), each viewed as (batch, num_heads, seq, d_head).

    When `past_key_values` is not None, key and value are replaced by the result of
    `past_key_values.update(key, value, self.layer_idx)`.
    """
    bsz, n_tokens, _ = qkv.shape
    head_shape = (bsz, n_tokens, self.num_heads, self.d_head)
    # Split the fused projection into Q, K, V, then move heads ahead of the sequence axis.
    query, key, value = (
        part.view(head_shape).transpose(1, 2)
        for part in self.in_proj(qkv).chunk(chunks=3, dim=-1)
    )
    if past_key_values is not None:
        key, value = past_key_values.update(key, value, self.layer_idx)
    return query, key, value
def forward(
self,
qkv,
output_attentions,
past_key_values,
use_cache,
):
attn_type = self.attn_type
if output_attentions and attn_type != "native":
logger.warning_once(
"CausalSelfAttention(output_attentions=True) and attn_type is not 'native': "
"Forcing native attention."
)
attn_type = "native"
if attn_type == "flash2":
if use_cache is None or use_cache == False:
return self.flash2_forward(qkv)
else:
return self.flash2_forward_cached(qkv, past_key_values)
# qkv: (batch_size, seq_len, d_embed)
batch_size, seq_len, d_embed = qkv.shape
# Feed the inputs through the K, Q, V matrices.
query, key, value = self.project_input(qkv, past_key_values)
kv_seq_len = key.shape[-2]
# Default to returning empty attention weights.
attentions = None
# https://github.com/pytorch/pytorch/issues/112577
if attn_type == "torch":
# This context manager can be used to force which implementation to use.
#with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False):
attended_values = F.scaled_dot_product_attention(
query,
key,
value,
attn_mask=None,
dropout_p=self.dropout.p if self.training else 0.0,
is_causal=(seq_len > 1),
scale=self.dot_product_scale
)
# "native" scaled-dot-product attention implementation.
else:
# Compute attention scores
scores = torch.matmul(query, key.transpose(-2, -1)) * self.dot_product_scale
# Mask future positions from the past
if seq_len > 1:
scores.masked_fill_(
torch.tril(
torch.ones(seq_len, kv_seq_len, dtype=torch.bool, device=qkv.device),
diagonal=0,
).logical_not(),
float('-inf'),
)
# Calculate the attention weights; avoid NANs that might emerge from zeros in softmax's denominator
attentions = self.dropout(torch.softmax(scores, dim=-1).clamp(min=1e-10))
del scores
# Use the attention weights to get a weighted combination of value vectors
attended_values = torch.matmul(attentions, value)
if not output_attentions:
del attentions
attentions = None
# Concatenate attention heads and project to original embedding size using the output linear layer
attended_values = attended_values.transpose(1, 2).contiguous().view(batch_size, seq_len, d_embed)
# Project the concatenated output through the output matrix.
attended_values = self.output_linear(attended_values)
return dict(
hidden_states=attended_values,
attentions=attentions,
past_key_values=past_key_values
)
# No cache support, but faster
def flash2_forward(
    self,
    qkv,
):
    """Run the uncached flash-attention-2 forward pass.

    Projects ``qkv`` through ``in_proj`` into one packed tensor and runs
    ``flash_attn_qkvpacked_func`` with a causal mask. No KV cache is read or
    updated.

    Args:
        qkv: input hidden states, unpacked as (batch_size, seq_len, d_embed).

    Returns:
        dict with ``hidden_states`` (output of ``output_linear``),
        ``attentions=None`` and ``past_key_values=None``.
    """
    batch_size, seq_len, d_embed = qkv.shape
    # Feed the inputs through the Q, K, V matrices:
    # (batch_size, seq_len, 3 * num_heads * d_head) -> (batch_size, seq_len, 3, num_heads, d_head)
    packed = self.in_proj(qkv).unflatten(
        -1,
        (3, self.num_heads, self.d_head)
    )
    # Unpack instead of indexing, so this works whether _downcast_to_float16
    # returns a tuple or another iterable such as a generator (which has no [0]).
    (packed,) = self._downcast_to_float16(packed)
    attended_values = flash_attn_qkvpacked_func(
        packed,
        dropout_p=self.dropout.p if self.training else 0.0,
        softmax_scale=self.dot_product_scale,
        causal=True,
    )
    # attended_values: (batch_size, seq_len, num_heads, d_head)
    # Concatenate heads back into d_embed.
    attended_values = attended_values.view(batch_size, seq_len, d_embed)
    # Project the concatenated output through the output matrix.
    attended_values = self.output_linear(attended_values)
    return dict(
        hidden_states=attended_values,
        attentions=None,
        past_key_values=None,
    )
# See https://github.com/huggingface/transformers/blob/main/src/transformers/cache_utils.py
# and https://huggingface.co/docs/transformers/internal/generation_utils
def flash2_forward_cached(
    self,
    qkv,
    past_key_values,
):
    """Run the flash-attention-2 forward pass through the KV cache.

    Unlike ``flash2_forward``, the projection goes through ``project_input``.
    That means ``past_key_values`` (if not None) is updated, and the key/value
    tensors it returns are attended over. The call uses ``causal=True``;
    flash-attn >= 2.1 documents bottom-right mask alignment when the query and
    key lengths differ. NOTE(review): confirm the installed flash-attn version.

    Args:
        qkv: input hidden states, unpacked as (batch_size, seq_len, d_embed).
        past_key_values: cache object forwarded to ``project_input``; may be None.

    Returns:
        dict with ``hidden_states`` (output of ``output_linear``),
        ``attentions=None`` and ``past_key_values`` (the object passed in).
    """
    batch_size, seq_len, d_embed = qkv.shape
    projected = self.project_input(qkv, past_key_values)
    # project_input yields (batch_size, num_heads, seq_len, d_head), but
    # flash_attn_func expects (batch_size, seqlen, nheads, headdim) for q, k and v.
    query, key, value = (
        tensor.transpose(1, 2) for tensor in self._downcast_to_float16(*projected)
    )
    attended_values = flash_attn_func(
        q=query,
        k=key,
        v=value,
        dropout_p=self.dropout.p if self.training else 0.0,
        softmax_scale=self.dot_product_scale,
        causal=True,
    )
    # (batch_size, seq_len, num_heads, d_head) -> (batch_size, seq_len, d_embed)
    merged_heads = attended_values.view(batch_size, seq_len, d_embed)
    return dict(
        hidden_states=self.output_linear(merged_heads),
        attentions=None,
        past_key_values=past_key_values,
    )
def _downcast_to_float16(self, *args):
    """Cast float32 tensors to the dtype used for flash attention.

    Only ``args[0]`` is inspected. If it is not float32, ``args`` is returned
    unchanged. Otherwise every argument is cast to the first applicable of:

    1. the autocast GPU dtype, when autocast is enabled;
    2. ``self.config._pre_quantization_dtype``, when that attribute exists;
    3. the dtype of ``self.output_linear.weight``.

    A one-time warning is logged when casting happens.

    Args:
        *args: tensors to convert; at least one is required.

    Returns:
        A tuple with one tensor per argument, on both paths, so callers may
        either index or unpack the result.
    """
    if args[0].dtype != torch.float32:
        return args
    if torch.is_autocast_enabled():
        target_dtype = torch.get_autocast_gpu_dtype()
    # Handle the case where the model is quantized
    elif hasattr(self.config, "_pre_quantization_dtype"):
        target_dtype = self.config._pre_quantization_dtype
    else:
        target_dtype = self.output_linear.weight.dtype
    logger.warning_once(
        f"The input hidden states seems to be silently casted in float32, this might be related to"
        f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
        f" {target_dtype}."
    )
    # Return a tuple rather than a generator: it matches the early-return type
    # above and supports indexing (e.g. result[0]), which a generator does not.
    return tuple(arg.to(target_dtype) for arg in args)