|
from __future__ import annotations |
|
|
|
import torch |
|
from torch import amin |
|
import copy |
|
import torch.nn as nn |
|
import numpy as np |
|
|
|
from scipy.optimize import curve_fit |
|
from typing import Dict, Any, Tuple, List, Callable |
|
|
|
|
|
def quantization(x, **params):
    """Map ``x`` into the quantized domain with an inverse-cosh transform.

    Computes ``acosh(guard(_0 * x)) / _s`` where ``guard`` replaces NaNs by 1
    and clamps values below 1 up to 1 (the lower bound of acosh's domain).
    Zeros in ``_s`` are replaced by 10000 before taking the reciprocal.

    Args:
        x: Input tensor.
        **params: Must contain the tensors ``'_0'`` and ``'_s'``.

    Returns:
        The transformed tensor.
    """
    inverse_scale = torch.div(1, replace_num(params['_s'], num=0, to=10000))
    # acosh is only defined for arguments >= 1, hence the guard.
    guarded_arg = domain_guard(params['_0'] * x, min=1, nan=1)
    return inverse_scale * torch.acosh(guarded_arg)
|
|
|
|
|
def dequantization(x, **params):
    """Map ``x`` back from the quantized domain with a cosh transform.

    Computes ``cosh(_s * x) / _0``. Zeros in ``_0`` are replaced by 10000
    before taking the reciprocal.

    Args:
        x: Quantized-domain tensor.
        **params: Must contain the tensors ``'_0'`` and ``'_s'``.

    Returns:
        The transformed tensor.
    """
    inverse_gain = torch.div(1, replace_num(params['_0'], num=0, to=10000))
    stretched = torch.cosh(params['_s'] * x)
    return inverse_gain * stretched
|
|
|
|
|
def init_linear_scale(
    x: torch.Tensor,
    **kwargs: Dict[str, Any],
) -> torch.Tensor:
    """Derive an initial scale from the range of the transformed input.

    ``x`` is passed through ``qtz_func`` with ``_s`` set to ones, the
    min/max along dim 1 of the result (each widened to include zero) give a
    symmetric absolute maximum, and that maximum is divided by half of the
    signed integer range for ``bits``. The result is clamped to at least
    float32 epsilon and then perturbed with Gaussian noise (std 0.01).

    NOTE(review): this file defines ``init_linear_scale`` twice with the same
    body; Python binds the name to the later definition.
    NOTE(review): the noise is added after the clamp, so the returned scale
    can end up below epsilon (or negative) -- confirm this is intended.

    Args:
        x: Input tensor; it is reduced along dim 1.
        **kwargs: Must contain ``bits``, ``params`` (a dict forwarded to
            ``qtz_func``) and ``qtz_func``. All kwargs are also forwarded to
            ``init_ones``.

    Returns:
        A float32 tensor of scales.
    """
    for required, message in (
        ("bits", "bits must be provided."),
        ("params", "params must be provided."),
        ("qtz_func", "qtz_func must be provided."),
    ):
        assert required in kwargs, message

    bits = kwargs.get('bits')
    params = kwargs.get('params')
    qtz_func = kwargs.get('qtz_func')

    # Apply the quantization function with a neutral (all-ones) scale.
    transformed = qtz_func(
        x=x.transpose(0, 1), **params, _s=init_ones(x, **kwargs)
    ).transpose(0, 1)

    quant_min, quant_max = get_min_max_from_bits_signed(bits)
    lowest, highest = torch.aminmax(transformed, dim=1)
    # Make sure the observed range always contains zero.
    lowest = torch.min(lowest, torch.zeros_like(lowest))
    highest = torch.max(highest, torch.zeros_like(highest))

    abs_max_per_ch = torch.max(-lowest, highest)
    half_range = float(quant_max - quant_min) / 2
    eps = torch.finfo(torch.float32).eps

    scale = torch.clamp(abs_max_per_ch / half_range, min=eps)
    scale = scale.to(dtype=torch.float32, device=lowest.device)

    # Small random perturbation of the initial scale.
    return scale + 0.01 * torch.randn_like(scale)
|
|
|
|
|
def init_params(x: torch.Tensor, **kwargs: Dict[str, Any]) -> Dict[str, nn.Parameter]:
    """Build the initial ``'_0'`` and ``'_s'`` parameters for ``x``.

    ``'_0'`` is initialised with ``init_ones`` and ``'_s'`` with
    ``init_linear_scale`` (which receives the dict holding only ``'_0'``).
    Both are wrapped as ``nn.Parameter`` with ``requires_grad=False``.
    If ``post_init_hook`` and/or ``post_train_hook`` are present in
    ``kwargs`` they are called, in that order, with ``parameters=<dict>``.

    Args:
        x: Input tensor.
        **kwargs: Forwarded to the initialisers; ``init_linear_scale``
            requires ``bits`` to be present.

    Returns:
        Dict with keys ``'_0'`` and ``'_s'``.
    """
    raw = {}
    raw['_0'] = init_ones(
        x,
        qtz_func=quantization,
        deqtz_func=dequantization,
        param='_0',
        params_list=['_0', '_s'],
        **kwargs,
    )
    raw['_s'] = init_linear_scale(x, params=raw, qtz_func=quantization, **kwargs)

    frozen = {
        name: nn.Parameter(value, requires_grad=False)
        for name, value in raw.items()
    }

    # Optional caller-provided hooks, invoked in this fixed order.
    for hook_name in ('post_init_hook', 'post_train_hook'):
        if hook_name in kwargs:
            kwargs[hook_name](parameters=frozen)

    return frozen
|
|
|
|
|
|
|
|
|
|
|
def np_quantization(x, _0, _s):
    """NumPy version of the inverse-cosh quantization transform.

    Computes ``arccosh(guard(_0 * x)) / _s`` where ``guard`` replaces NaNs by
    1 and clips values below 1 up to 1. Zeros in ``_s`` are replaced by 10000
    before taking the reciprocal.
    """
    inverse_scale = np.divide(1, np_replace_num(_s, num=0, to=10000))
    # arccosh is only defined for arguments >= 1, hence the guard.
    guarded_arg = np_domain_guard(_0 * x, min=1, nan=1)
    return inverse_scale * np.arccosh(guarded_arg)
|
|
|
|
|
def np_dequantization(x, _0, _s):
    """NumPy version of the cosh dequantization transform.

    Computes ``cosh(_s * x) / _0``. Zeros in ``_0`` are replaced by 10000
    before taking the reciprocal.
    """
    inverse_gain = np.divide(1, np_replace_num(_0, num=0, to=10000))
    stretched = np.cosh(_s * x)
    return inverse_gain * stretched
|
|
|
|
|
def fit_func(x, _0, _s):
    """Round-trip ``x`` through ``np_quantization`` then ``np_dequantization``.

    Both steps use the same ``_0`` and ``_s``; the signature
    ``(x, *params)`` matches what ``scipy.optimize.curve_fit`` expects of a
    model function.
    """
    return np_dequantization(np_quantization(x, _0, _s), _0, _s)
|
|
|
|
|
|
|
|
|
|
|
def domain_guard(
    x: torch.Tensor,
    min: float = None,
    max: float = None,
    posinf: float = None,
    neginf: float = None,
    nan: float = None
) -> torch.Tensor:
    """Replace non-finite entries of ``x`` and optionally clamp it.

    Args:
        x: Input tensor.
        min: Lower clamp bound, or ``None`` for no lower bound.
        max: Upper clamp bound, or ``None`` for no upper bound.
        posinf: Replacement for ``+inf`` (forwarded to ``torch.nan_to_num``).
        neginf: Replacement for ``-inf`` (forwarded to ``torch.nan_to_num``).
        nan: Replacement for NaN (forwarded to ``torch.nan_to_num``).

    Returns:
        The sanitised tensor; clamping is applied only when at least one of
        ``min``/``max`` is given.
    """
    sanitized = torch.nan_to_num(x, nan=nan, posinf=posinf, neginf=neginf)
    if min is None and max is None:
        return sanitized
    return torch.clamp(sanitized, min=min, max=max)
|
|
|
|
|
def replace_num(x: torch.Tensor, num: float, to: float) -> torch.Tensor:
    """Substitute every occurrence of a value in a tensor.

    Args:
        x (torch.Tensor): The input tensor.
        num (float): The value to look for.
        to (float): The value written wherever ``x == num``.

    Returns:
        torch.Tensor: A tensor equal to ``x`` except at the matching
        positions, which hold ``to``.
    """
    matches = x == num
    return torch.where(matches, to, x)
|
|
|
|
|
def guarded_torch_power(x: torch.Tensor, exp: float) -> torch.Tensor:
    """Raise ``x`` to ``exp``, zeroing negative bases when ``exp < 1``.

    For ``exp >= 1`` this is a plain ``torch.pow``. For smaller exponents the
    base is first passed through ``relu`` so negative entries become 0
    before the power is taken.
    """
    if exp >= 1:
        return torch.pow(x, exp)
    non_negative = torch.relu(x)
    return torch.pow(non_negative, exp)
|
|
|
|
|
def init_ones(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor:
    """Return a float32 tensor of ones shaped like ``x`` reduced over dim 1.

    The reduction (``amin`` over dim 1) is only used to obtain the target
    shape; ``kwargs`` are accepted and ignored.
    """
    template = torch.amin(x, dim=1)
    return torch.ones_like(template, dtype=torch.float32, device=x.device)
|
|
|
|
|
def init_rand(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor:
    """Return standard-normal float32 noise shaped like ``x`` reduced over dim 1.

    The reduction (``amin`` over dim 1) is only used to obtain the target
    shape; ``kwargs`` are accepted and ignored.
    """
    template = torch.amin(x, dim=1)
    return torch.randn_like(template, dtype=torch.float32, device=x.device)
|
|
|
|
|
def init_space_search(
    x: torch.Tensor,
    **kwargs: Dict[str, Any],
) -> torch.Tensor:
    """Randomly search for a value of one parameter that minimises round-trip MSE.

    The parameter named ``param`` is searched while every other name in
    ``params_list`` is fixed to ``init_ones``. The first generation draws
    ``10 * n_random_params`` candidates as ``init_rand(x) * max_initial``.
    Each later generation samples around the mean of the best candidates of
    the previous one. Finally the best search result is compared with an
    all-ones and a standard-normal baseline and the lowest-loss one is
    returned.

    Differences from the previous implementation:
      * Candidates whose loss is NaN/inf are discarded instead of entering
        the top-k list (NaN breaks the ordering used for sorting/eviction).
      * A generation without any usable candidate stops the search and keeps
        the previous generation's result instead of failing in
        ``torch.stack`` / indexing an empty list.
      * When the two baselines tie and both beat the search result, the ones
        baseline is returned instead of the (worse) search result.

    Args:
        x: Input tensor.
        **kwargs: Must contain ``qtz_func``, ``deqtz_func``, ``params_list``
            and ``param``. All kwargs are forwarded to ``init_ones`` and
            ``init_rand``.

    Returns:
        The selected tensor for ``param``.
    """
    import math

    def _build_initial_param(tensor: torch.Tensor, max_initial: int, n_params: int):
        """Generates the initial set of parameters. The first iteration generates 10 times more parameters."""
        for _ in range(n_params * 10):
            yield init_rand(tensor) * max_initial

    def _search_param(tensors: List[torch.Tensor], n_params):
        """Takes the best parameters and generates new parameters around the mean of the best parameters."""
        stacked = torch.stack(tensors)
        min_vals, max_vals = torch.aminmax(stacked, dim=0)
        abs_max_val_per_ch = torch.max(-min_vals, max_vals)
        mean = torch.mean(stacked, dim=0)
        for _ in range(n_params):
            yield torch.randn_like(min_vals) * abs_max_val_per_ch + mean

    def _calc(x, qtz_func, deqtz_func, **params):
        """Quantize then dequantize ``x`` with dims 0 and 1 swapped around the calls."""
        x_ = x.transpose(0, 1)
        x_ = qtz_func(x=x_, **params)
        x_ = deqtz_func(x=x_, **params)
        return x_.transpose(0, 1)

    def _loss(candidate: torch.Tensor) -> float:
        """Round-trip MSE of ``x`` when ``param`` is set to ``candidate``."""
        x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: candidate})
        return nn.MSELoss()(x, x_).item()

    assert "qtz_func" in kwargs, "qtz_func must be provided."
    assert "deqtz_func" in kwargs, "deqtz_func must be provided."
    assert "params_list" in kwargs, "params list must be provided."
    assert "param" in kwargs, "param must be provided."

    qtz_func = kwargs.get('qtz_func')
    deqtz_func = kwargs.get('deqtz_func')
    params_list = kwargs.get('params_list')
    param = kwargs.get('param')

    n_runs = 50
    n_random_params = 50
    n_best_to_pick = 5
    max_initial = 10000

    # Every parameter other than the searched one stays at its neutral value.
    base_params = {p: init_ones(x, **kwargs) for p in params_list if p != param}
    params = _build_initial_param(x, max_initial, n_random_params)

    best_params: List[Tuple[torch.Tensor, float]] = []
    for _ in range(n_runs):
        run_best: List[Tuple[torch.Tensor, float]] = []
        for candidate in params:
            try:
                loss_value = _loss(candidate)
            except Exception:
                # Deliberately best-effort: a candidate that cannot be
                # evaluated is simply not a usable candidate.
                continue

            # NaN/inf losses cannot be ranked meaningfully; drop them.
            if not math.isfinite(loss_value):
                continue

            if len(run_best) < n_best_to_pick:
                run_best.append((candidate, loss_value))
            elif loss_value < run_best[-1][1]:
                run_best[-1] = (candidate, loss_value)
            else:
                continue
            run_best.sort(key=lambda item: item[1])

        if not run_best:
            # Nothing usable in this generation: keep what we already have.
            break

        best_params = run_best
        params = _search_param([p for p, _ in best_params], n_random_params)

    # Baselines the search result has to beat.
    p_ones = init_ones(x, **kwargs)
    loss_ones = _loss(p_ones)

    p_rand = init_rand(x, **kwargs)
    loss_rand = _loss(p_rand)

    # Order encodes tie-breaking: search result first, then ones, then rand.
    finalists = list(best_params[:1]) + [(p_ones, loss_ones), (p_rand, loss_rand)]
    finalists = [entry for entry in finalists if not math.isnan(entry[1])]
    if not finalists:
        # No comparable loss at all; fall back to the neutral initialisation.
        return p_ones
    return min(finalists, key=lambda item: item[1])[0]
|
|
|
|
|
def init_linear_scale(
    x: torch.Tensor,
    **kwargs: Dict[str, Any],
) -> torch.Tensor:
    """Derive an initial scale from the range of the transformed input.

    ``x`` is passed through ``qtz_func`` with ``_s`` set to ones, the
    min/max along dim 1 of the result (each widened to include zero) give a
    symmetric absolute maximum, and that maximum is divided by half of the
    signed integer range for ``bits``. The result is clamped to at least
    float32 epsilon and then perturbed with Gaussian noise (std 0.01).

    NOTE(review): this file defines ``init_linear_scale`` twice with the same
    body; this later definition is the one bound at import time.
    NOTE(review): the noise is added after the clamp, so the returned scale
    can end up below epsilon (or negative) -- confirm this is intended.

    Args:
        x: Input tensor; it is reduced along dim 1.
        **kwargs: Must contain ``bits``, ``params`` (a dict forwarded to
            ``qtz_func``) and ``qtz_func``. All kwargs are also forwarded to
            ``init_ones``.

    Returns:
        A float32 tensor of scales.
    """
    for required, message in (
        ("bits", "bits must be provided."),
        ("params", "params must be provided."),
        ("qtz_func", "qtz_func must be provided."),
    ):
        assert required in kwargs, message

    bits = kwargs.get('bits')
    params = kwargs.get('params')
    qtz_func = kwargs.get('qtz_func')

    # Apply the quantization function with a neutral (all-ones) scale.
    transformed = qtz_func(
        x=x.transpose(0, 1), **params, _s=init_ones(x, **kwargs)
    ).transpose(0, 1)

    quant_min, quant_max = get_min_max_from_bits_signed(bits)
    lowest, highest = torch.aminmax(transformed, dim=1)
    # Make sure the observed range always contains zero.
    lowest = torch.min(lowest, torch.zeros_like(lowest))
    highest = torch.max(highest, torch.zeros_like(highest))

    abs_max_per_ch = torch.max(-lowest, highest)
    half_range = float(quant_max - quant_min) / 2
    eps = torch.finfo(torch.float32).eps

    scale = torch.clamp(abs_max_per_ch / half_range, min=eps)
    scale = scale.to(dtype=torch.float32, device=lowest.device)

    # Small random perturbation of the initial scale.
    return scale + 0.01 * torch.randn_like(scale)
|
|
|
|
|
def init_non_linear_regression_fit(
    x: torch.Tensor,
    **kwargs: Dict[str, Any],
) -> Dict[str, torch.Tensor]:
    """Fit parameters row by row with ``scipy.optimize.curve_fit``.

    Each row of ``x`` (dim 0) is sorted along its last axis and
    ``np_fit_func`` is fitted so that it reproduces that row from itself
    (``xdata == ydata``), starting from the row's entry in ``p0``.

    If any fit fails, the whole result falls back to ``p0``. ``curve_fit``
    reports invalid data with ``ValueError`` and a failed minimisation
    (e.g. ``maxfev`` exhausted) with ``RuntimeError``; both now trigger the
    fallback (previously only ``ValueError`` did).

    Args:
        x: Input tensor; detached and moved to CPU for fitting.
        **kwargs: Must contain
            ``params_list`` -- parameter names; they are sorted, and the
            sorted order is the positional order handed to ``np_fit_func``;
            ``np_fit_func`` -- callable ``f(x, *params)``;
            ``p0`` -- dict mapping each name to a tensor of initial values
            indexed by row.

    Returns:
        Dict mapping each parameter name to a float32 tensor on ``x.device``.
        On fallback the dict holds every key of ``p0``.
    """
    assert "params_list" in kwargs, "params list must be provided."
    assert "np_fit_func" in kwargs, "np_fit_func must be provided."
    assert "p0" in kwargs, "p0 must be provided."

    np_fit_func = kwargs.get('np_fit_func')
    params_list = sorted(kwargs.get('params_list'))
    # detach() so tensors that require grad can still be converted to NumPy.
    p0 = {k: v.detach().cpu().numpy() for k, v in kwargs.get('p0').items()}

    def _fit(xdata: np.ndarray, ydata: np.ndarray, func: Callable, p0: List[float]):
        """Run one Levenberg-Marquardt fit and return the optimal parameters."""
        popt, _ = curve_fit(
            func,
            xdata,
            ydata,
            maxfev=1000,
            p0=p0,
            method='lm',
        )
        return popt

    sorted_xdata = np.sort(x.detach().cpu().numpy(), axis=-1)

    try:
        fitted = [
            _fit(row, row, np_fit_func, [p0[name][i] for name in params_list])
            for i, row in enumerate(sorted_xdata)
        ]
    except (ValueError, RuntimeError) as e:
        print(f"Could not fit the function with error: {e}")
        print("Using fallback result...")
        return {
            k: torch.tensor(v, dtype=torch.float32).to(x.device) for k, v in p0.items()
        }

    # Regroup the per-row results into one tensor per parameter name.
    return {
        name: torch.tensor(
            [row_params[i] for row_params in fitted], dtype=torch.float32
        ).to(x.device)
        for i, name in enumerate(params_list)
    }
|
|
|
|
|
def init_zeros(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor:
    """Return a float32 tensor of zeros shaped like ``x`` reduced over dim 1.

    The reduction (``amin`` over dim 1) is only used to obtain the target
    shape; ``kwargs`` are accepted and ignored.
    """
    template = torch.amin(x, dim=1)
    return torch.zeros_like(template, dtype=torch.float32, device=x.device)
|
|
|
|
|
def init_inner_scale(tensor: torch.Tensor, _min: float = torch.inf, _max: float = torch.inf) -> torch.Tensor:
    """Scale that maps the observed range of ``tensor`` onto ``[_min, _max]``.

    The range is taken along the last dim and widened to include zero. With
    an infinite ``_max`` (the default) no target range is defined and a
    tensor of ones is returned.

    The infinity check now compares by value. The previous identity check
    (``_max is torch.inf``) missed equal-but-distinct floats such as
    ``float("inf")`` and then produced an infinite scale.

    NOTE(review): a row whose widened range is zero (all entries 0) divides
    by zero here -- confirm callers never pass such input.

    Args:
        tensor: Input tensor; reduced along its last dim.
        _min: Lower end of the target range.
        _max: Upper end of the target range; ``inf`` disables scaling.

    Returns:
        Tensor with the shape of ``tensor`` minus its last dim.
    """
    min_vals, max_vals = torch.aminmax(tensor, dim=-1)
    x_min = torch.min(min_vals, torch.zeros_like(min_vals))
    x_max = torch.max(max_vals, torch.zeros_like(max_vals))

    # Restrict the value comparison to plain numbers so that tensor-valued
    # bounds keep taking the scaling path, as before.
    if isinstance(_max, (int, float)) and _max == torch.inf:
        return torch.ones_like(x_min)

    return (_max - _min) / (x_max - x_min)
|
|
|
|
|
|
|
|
|
|
|
@torch.enable_grad()
def learn_parameters(
    x: torch.Tensor,
    params: Dict[str, nn.Parameter],
    qtz_func: nn.Module,
    deqtz_func: nn.Module,
    bits: int,
    target_dtype: torch.dtype,
    epochs: int = 1000,
    early_stop: bool = True,
    do_report: bool = False
) -> Tuple[Dict[str, nn.Parameter], torch.Tensor]:
    """Optimise ``params`` with Adam to minimise quantize/dequantize MSE.

    Each epoch quantizes ``x``, dequantizes the result and takes the MSE
    against ``x``. The parameter values that produced the lowest loss are
    returned as copies with ``requires_grad=False``.

    Side effects: the ``params`` passed in get ``requires_grad=True`` and are
    updated in place by the optimiser; they are not restored afterwards.

    Differences from the previous implementation:
      * The best-parameter snapshot is taken before ``optimizer.step()``.
        Previously it was taken after the step, so the stored values were the
        already-updated parameters, not the ones the best loss was measured
        with.
      * ``epochs=0`` returns a frozen copy of the input parameters instead of
        failing on ``None.values()``.

    Args:
        x: Tensor to reconstruct.
        params: Parameters forwarded to ``qtz_func``/``deqtz_func``.
        qtz_func: Quantization function, forwarded to ``quantize``.
        deqtz_func: Dequantization function, forwarded to ``dequantize``.
        bits: Bit width forwarded to ``quantize``/``dequantize``.
        target_dtype: dtype forwarded to ``quantize``.
        epochs: Maximum number of optimisation steps.
        early_stop: Stop once the loss changed by less than 1e-7 compared to
            ``int(epochs * 0.1)`` epochs earlier.
        do_report: Print the loss every 10 epochs and also return the loss
            history.

    Returns:
        ``best_params`` when ``do_report`` is False, otherwise
        ``(best_params, acc_loss)`` where ``acc_loss`` is the list of
        per-epoch loss values.

    Raises:
        Exception: If the loss becomes NaN or infinite.
    """
    for p in params.values():
        p.requires_grad = True
        p.grad = None

    optimizer = torch.optim.Adam(list(params.values()), lr=0.001)
    scheduler = torch.optim.lr_scheduler.LinearLR(
        optimizer, start_factor=1.0, end_factor=0.001, total_iters=epochs // 10
    )
    loss_fn = nn.MSELoss()

    best_loss = float("inf")
    best_params = None

    min_delta = 1e-7
    acc_loss = []
    percent_epochs_before_stop = 0.1
    # Loop-invariant look-back window for the early-stop test.
    epochs_before_stop = int(epochs * percent_epochs_before_stop)

    for i in range(epochs):
        optimizer.zero_grad()

        quant = quantize(x, params, qtz_func, bits, target_dtype)
        dequant = dequantize(quant, params, deqtz_func, bits, x.dtype)
        loss = loss_fn(x, dequant)

        if loss.isnan() or loss.isinf():
            raise Exception("Loss is NaN or Inf. Stopping the search.")

        loss_value = loss.item()

        # Snapshot BEFORE the optimiser step so the stored values are the
        # ones this loss was actually computed with.
        if loss_value < best_loss:
            best_loss = loss_value
            best_params = copy.deepcopy(dict(params))

        loss.backward()
        optimizer.step()
        scheduler.step()

        acc_loss.append(loss_value)

        if i % 10 == 0 and do_report:
            print(f"Epoch {i}: Loss {loss_value}")

        if (
            early_stop
            and i > epochs_before_stop
            and abs(acc_loss[i - epochs_before_stop] - acc_loss[i]) < min_delta
        ):
            break

    if best_params is None:
        # No epoch ran; hand back an independent copy of the inputs.
        best_params = copy.deepcopy(dict(params))

    for p in best_params.values():
        p.requires_grad = False
        p.grad = None

    if do_report:
        return best_params, acc_loss
    return best_params
|
|
|
|
|
def quantize(
    x: torch.Tensor,
    params: Dict[str, nn.Parameter],
    func: nn.Module,
    bits: int,
    target_dtype: torch.dtype = torch.int8
) -> torch.Tensor:
    """Apply ``func`` to ``x``, round, clamp to the signed range and cast.

    ``func`` is called as ``func(x=<x with dims 0 and 1 swapped>, **params)``
    and its result is swapped back before rounding with ``round_func_BPDA``.

    NOTE(review): with an integer ``target_dtype`` the final cast presumably
    detaches the result from autograd -- confirm that is intended where this
    is used for training.

    Args:
        x: Input tensor.
        params: Keyword arguments forwarded to ``func``.
        func: Quantization function.
        bits: Bit width that determines the clamp bounds.
        target_dtype: dtype of the returned tensor.

    Returns:
        The quantized tensor in ``target_dtype``.
    """
    lower, upper = get_min_max_from_bits_signed(bits)
    mapped = func(x=x.transpose(0, 1), **params).transpose(0, 1)
    rounded = round_func_BPDA(mapped)
    return torch.clamp(rounded, lower, upper).to(target_dtype)
|
|
|
|
|
def dequantize(
    x: torch.Tensor,
    params: Dict[str, nn.Parameter],
    func: nn.Module,
    bits: int,
    out_dtype: torch.dtype
) -> torch.Tensor:
    """Cast ``x`` to ``out_dtype`` and apply ``func`` to it.

    ``func`` is called as ``func(x=<x with dims 0 and 1 swapped>, **params)``
    and its result is swapped back. ``bits`` is accepted but not used here.

    Args:
        x: Quantized tensor.
        params: Keyword arguments forwarded to ``func``.
        func: Dequantization function.
        bits: Unused.
        out_dtype: dtype ``x`` is cast to before ``func`` is applied.

    Returns:
        The result of ``func`` with the original dim order restored.
    """
    swapped = x.to(dtype=out_dtype).transpose(0, 1)
    restored = func(x=swapped, **params)
    return restored.transpose(0, 1)
|
|
|
|
|
def round_func_BPDA(input):
    """Round ``input`` while keeping an identity gradient.

    The returned tensor is a clone of ``input`` (so autograd treats it as the
    identity) whose underlying data is overwritten with the rounded values.
    """
    passthrough = input.clone()
    # Overwrite only the storage; the autograd graph still sees a clone.
    passthrough.data = torch.round(input).data
    return passthrough
|
|
|
|
|
def get_min_max_from_bits_signed(bit_width: int) -> Tuple[int, int]:
    """Return ``(min, max)`` of a signed two's-complement integer of ``bit_width`` bits."""
    magnitude = 2 ** (bit_width - 1)
    return -magnitude, magnitude - 1
|
|
|
|
|
|
|
|
|
|
|
def np_domain_guard(
    x: np.ndarray,
    min: float = None,
    max: float = None,
    posinf: float = None,
    neginf: float = None,
    nan: float = None
) -> np.ndarray:
    """Replace non-finite entries of ``x`` and optionally clip it.

    ``nan=None`` now means "replace NaN by 0.0", which is what the torch
    counterpart does. Previously ``None`` was forwarded to
    ``np.nan_to_num``, whose ``nan`` argument must be a number, so calling
    this function without ``nan`` failed on floating-point input.

    Args:
        x: Input array.
        min: Lower clip bound, or ``None`` for no lower bound.
        max: Upper clip bound, or ``None`` for no upper bound.
        posinf: Replacement for ``+inf``; ``None`` keeps NumPy's default.
        neginf: Replacement for ``-inf``; ``None`` keeps NumPy's default.
        nan: Replacement for NaN; ``None`` means 0.0.

    Returns:
        The sanitised array; clipping is applied only when at least one of
        ``min``/``max`` is given.
    """
    nan_fill = 0.0 if nan is None else nan
    x = np.nan_to_num(x, posinf=posinf, neginf=neginf, nan=nan_fill)
    if min is not None or max is not None:
        x = np.clip(x, min, max)
    return x
|
|
|
|
|
def np_replace_num(x: np.ndarray, num: float, to: float) -> np.ndarray:
    """Substitute every occurrence of a value in an array.

    Args:
        x (np.ndarray): The input array.
        num (float): The value to look for.
        to (float): The value written wherever ``x == num``.

    Returns:
        np.ndarray: An array equal to ``x`` except at the matching
        positions, which hold ``to``.
    """
    matches = x == num
    return np.where(matches, to, x)
|
|
|
|
|
def np_guarded_power(x: np.ndarray, exp: float) -> np.ndarray:
    """Raise ``x`` to ``exp``, zeroing negative bases when ``exp < 1``.

    For ``exp >= 1`` this is a plain ``np.power``. For smaller exponents
    negative entries are first replaced by 0 before the power is taken.
    """
    if exp >= 1:
        return np.power(x, exp)
    non_negative = np.maximum(x, 0)
    return np.power(non_negative, exp)
|
|
|
|