from __future__ import annotations

import contextlib
import functools
import os
from typing import Callable, Iterator, Optional, Sequence

import torch
import torch.nn.functional as F

ellipsis = type(...)


def get_mean_nbits_by_codebook(codes: torch.IntTensor, huffman_group_size: int = 2):
    """
    Calculates the average code length (in bits) per codebook group after Huffman coding.
    :param codes: codebook codes
    :param huffman_group_size: number of codebooks that are Huffman-compressed together
    """

    import huffman

    _, codebook_size, num_codebooks = codes.shape
    flat_codes_by_codebook = codes.permute(2, 0, 1).flatten(1, 2)
    code_counts = torch.zeros(
        num_codebooks, codebook_size, device=flat_codes_by_codebook.device, dtype=flat_codes_by_codebook.dtype
    ).scatter_add(
        -1, flat_codes_by_codebook, torch.ones_like(flat_codes_by_codebook)
    )  # shape: [num_codebooks, codebook_size]
    code_probs = code_counts / code_counts.sum(dim=-1, keepdim=True).float()
    code_probs = code_probs.cpu().numpy()
    assert num_codebooks % huffman_group_size == 0

    mean_code_lengths = []
    for group_index in range(num_codebooks // huffman_group_size):
        group_code_probs = {(): 1}

        for codebook_index in range(group_index * huffman_group_size, (group_index + 1) * huffman_group_size):
            new_group_code_probs = {}
            # extend each partial code tuple with every code of the current codebook
            for group, group_prob in group_code_probs.items():
                for code, code_prob in tuple(enumerate(code_probs[codebook_index])):
                    new_group_code_probs[group + (code,)] = group_prob * code_prob
            group_code_probs = new_group_code_probs

        huffman_codebook_i = huffman.codebook(list(group_code_probs.items()))
        codebook_mean_code_length_i = sum(
            len(huffman_codebook_i[code]) * prob for code, prob in group_code_probs.items()
        )
        mean_code_lengths.append(codebook_mean_code_length_i)
    return mean_code_lengths


def get_int_dtype(nbits: int) -> torch.dtype:
    if nbits <= 8:
        return torch.int8
    if nbits <= 16:
        return torch.int16
    if nbits <= 32:
        return torch.int32
    if nbits <= 64:
        return torch.int64
    raise ValueError(f"No dtype available for {nbits}-bit codebooks")


@torch.inference_mode()
def pack_int_data(data: torch.IntTensor, nbits: int) -> torch.IntTensor:
    # reinterpret non-negative codes as two's complement so they fit the smallest signed dtype
    data[data >= 2 ** (nbits - 1)] -= 2**nbits
    return data.to(get_int_dtype(nbits))


@torch.inference_mode()
def unpack_int_data(data: torch.IntTensor, nbits: int) -> torch.IntTensor:
    # map signed storage back to non-negative codes in [0, 2**nbits)
    return data.to(torch.int64) % (2**nbits)
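# A minimal usage sketch for pack_int_data / unpack_int_data: the nbits value and code
# range below are illustrative assumptions, not values required by these helpers.
def _example_pack_unpack_roundtrip() -> None:
    nbits = 12  # assumed codebook with 2**12 = 4096 entries
    codes = torch.randint(0, 2**nbits, (1024,), dtype=torch.int64)
    packed = pack_int_data(codes.clone(), nbits)  # clone: pack_int_data modifies its input in place
    assert packed.dtype == torch.int16  # smallest signed dtype that can hold 12-bit codes
    restored = unpack_int_data(packed, nbits)
    assert torch.equal(restored, codes)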
@functools.lru_cache()
def maybe_script(fn: callable) -> callable:
    """Apply torch.jit.script to a function unless running on TPU; TPU does not support torch.jit.script."""
    using_tpu = bool(os.environ.get("TPU_NAME"))
    # TPU_NAME is a reserved variable that must be set to the TPU address
    # (e.g. grpc://11.22.33.44:1337) for TPU to function
    should_script = int(os.environ.get("AQ_USE_JIT", not using_tpu))
    return torch.jit.script(fn) if should_script else fn


@contextlib.contextmanager
def using_tf32(enabled: bool):
    was_cudnn = torch.backends.cudnn.allow_tf32
    was_matmul = torch.backends.cuda.matmul.allow_tf32
    torch.backends.cudnn.allow_tf32 = enabled
    torch.backends.cuda.matmul.allow_tf32 = enabled
    try:
        yield
    finally:
        # restore the previous TF32 settings even if the body raises
        torch.backends.cudnn.allow_tf32 = was_cudnn
        torch.backends.cuda.matmul.allow_tf32 = was_matmul


def iterate_minibatches(
    *tensors: torch.Tensor,
    batch_size: int,
    allow_incomplete: bool = True,
    device: Optional[torch.device] = None,
    callback: Callable[[Sequence[torch.Tensor]], Sequence[torch.Tensor]] = lambda x: x,
) -> Iterator[Sequence[torch.Tensor]]:
    """
    Samples data points *forever*, in random order, with less overhead than DataLoader;
    Adapted from https://github.com/stanis-morozov/unq/blob/master/lib/utils.py
    probably implemented over9000 times in transformers, torch, etc
    :param tensors: one or more tensors with the same 0-th dimension
    :param batch_size: sample this many points with each yield
    :param allow_incomplete: if True and the dataset size is not divisible by batch_size, the last batch
        may have fewer than :batch_size: samples so that the entire dataset is covered. If False, the last batch is dropped
    :param device: if specified, move each batch to this device before it is yielded
    :param callback: optional function applied to each batch of tensors before it is yielded to the user
    :returns: yields a tuple of minibatches, one from each tensor, with the same length as the input *tensors
        If only one tensor is given, this function yields that tensor directly (not a tuple/list with one tensor)
    """
    num_samples = len(tensors[0])
    assert all(len(x) == num_samples for x in tensors)
    indices = torch.randperm(num_samples, device=tensors[0].device)
    while True:
        prev_batch = None
        for batch_start in range(0, len(indices), batch_size):
            if not allow_incomplete and batch_start + batch_size > len(indices):
                break
            batch_ix = indices[batch_start : batch_start + batch_size]
            batch = callback(tuple(tensor[batch_ix].to(device, non_blocking=True) for tensor in tensors))
            if prev_batch is not None:
                yield prev_batch
            prev_batch = batch if isinstance(batch, (list, tuple)) and len(tensors) > 1 else batch[0]
            del batch
        yield prev_batch
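# A minimal usage sketch for using_tf32 and iterate_minibatches: the tensor sizes, batch size
# and step count below are illustrative assumptions.
def _example_calibration_loop(num_steps: int = 10) -> None:
    xs, ys = torch.randn(1000, 16), torch.randn(1000, 1)
    batch_iterator = iterate_minibatches(xs, ys, batch_size=64)  # yields (x_batch, y_batch) tuples forever
    with using_tf32(enabled=False):  # temporarily force full-precision fp32 matmuls
        for _step, (x_batch, y_batch) in zip(range(num_steps), batch_iterator):
            assert x_batch.shape == (64, 16) and y_batch.shape == (64, 1)
            _ = x_batch @ torch.randn(16, 1)  # stand-in for a calibration forward pass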
@maybe_script
def _dequantize_weight(
    codes: torch.Tensor, codebooks: torch.Tensor, scales: Optional[torch.Tensor] = None
) -> torch.Tensor:
    """
    Decode float weights from quantization codes. Differentiable.
    :param codes: tensor of integer quantization codes, shape [*dims, num_out_groups, num_in_groups, num_codebooks]
    :param codebooks: tensor of vectors for each quantization code,
        shape [num_codebooks, codebook_size, out_group_size, in_group_size]
    :param scales: weight will be multiplied by this factor, must be broadcastable with
        [*dims, num_out_groups, num_in_groups, out_group_size, in_group_size]
    :return: reconstructed weight tensor of shape [*dims, num_out_groups * out_group_size, num_in_groups * in_group_size]
    """
    num_out_groups, num_in_groups, num_codebooks = codes.shape[-3:]
    num_codebooks, codebook_size, out_group_size, in_group_size = codebooks.shape
    out_features = num_out_groups * out_group_size
    in_features = num_in_groups * in_group_size
    codebook_offsets = torch.arange(
        0, num_codebooks * codebook_size, codebook_size, device=codes.device
    )  # shape: [num_codebooks]
    reconstructed_weight_flat = F.embedding_bag(
        codes.flatten(0, -2) + codebook_offsets, codebooks.flatten(0, 1).flatten(-2, -1), mode="sum"
    )  # [prod(dims) * num_out_groups * num_in_groups, out_group_size * in_group_size]

    reconstructed_weight_groupwise = reconstructed_weight_flat.view(
        list(codes.shape[:-3]) + [num_out_groups, num_in_groups, out_group_size, in_group_size]
    )
    if scales is not None:
        reconstructed_weight_groupwise = reconstructed_weight_groupwise.mul(scales)
    return reconstructed_weight_groupwise.swapaxes(-3, -2).reshape(list(codes.shape[:-3]) + [out_features, in_features])
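# A minimal usage sketch for _dequantize_weight: all sizes below (2 codebooks of 256 vectors,
# 1x8 groups, a 64x128 weight) are illustrative assumptions, not fixed by this module.
def _example_dequantize_weight() -> torch.Tensor:
    num_codebooks, codebook_size, out_group_size, in_group_size = 2, 256, 1, 8
    num_out_groups, num_in_groups = 64, 16
    codes = torch.randint(0, codebook_size, (num_out_groups, num_in_groups, num_codebooks))
    codebooks = torch.randn(num_codebooks, codebook_size, out_group_size, in_group_size)
    scales = torch.ones(num_out_groups, 1, 1, 1)  # one scale per output group, broadcastable as documented
    weight = _dequantize_weight(codes, codebooks, scales)
    assert weight.shape == (num_out_groups * out_group_size, num_in_groups * in_group_size)  # [64, 128]
    return weight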