# Diogo-V's picture
# Upload learned functions
# 511d901 verified
from __future__ import annotations
import torch
from torch import amin # Necessary for arcsin
import copy
import torch.nn as nn
import numpy as np
from scipy.optimize import curve_fit
from typing import Dict, Any, Tuple, List, Callable
def quantization(x, **params):
    """Learned quantization map: ``tanh(_0 * x) / _s``.

    Args:
        x: Values to quantize (callers in this file pass a transposed tensor).
        **params: Must contain ``'_0'`` and ``'_s'``. Entries of ``_s`` equal to 0
            are replaced by 10000 before the reciprocal is taken, so no division by
            zero happens.

    Returns:
        torch.Tensor: The unrounded quantized values.
    """
    inverse_scale = torch.div(1, replace_num(params['_s'], num=0, to=10000))
    squashed = torch.tanh(params['_0'] * x)
    return inverse_scale * squashed
def dequantization(x, **params):
    """Learned dequantization map: ``log((1 + _s*x) / (1 - _s*x)) / _0``.

    Divisors are protected: entries of ``_0`` equal to 0, and values of
    ``_s*x - 1`` equal to 0, are replaced by 10000. The log argument goes through
    ``domain_guard`` so NaNs become 1e-5 and anything below 1e-5 is clamped to 1e-5.

    NOTE(review): ``log((1 + t) / (1 - t)) == 2 * atanh(t)``. Ignoring rounding and
    the guards, ``dequantization(quantization(x))`` therefore evaluates to ``2 * x``.
    Confirm whether the missing factor 1/2 is intended for this learned function.
    """
    inverse_shape = torch.div(1, replace_num(params['_0'], num=0, to=10000))
    denominator = replace_num((torch.tensor(-1) + (params['_s'] * x)), num=0, to=10000)
    numerator = torch.tensor(-1) + (torch.tensor(-1) * params['_s'] * x)
    ratio = torch.div(1, denominator) * numerator
    return inverse_shape * torch.log(domain_guard(ratio, min=1e-5, nan=1e-5))
def init_params(x: torch.Tensor, **kwargs: Dict[str, Any]) -> Dict[str, nn.Parameter]:
    """Initialise the ``_0`` and ``_s`` parameters of the learned quantizer for ``x``.

    Pipeline:
      1. ``_0`` from a random space search (``init_space_search``).
      2. ``_s`` from a symmetric linear scale (``init_linear_scale``); this step
         requires ``bits`` in ``kwargs``.
      3. Both are refined per channel with ``curve_fit``
         (``init_non_linear_regression_fit``).

    Optional hooks in ``kwargs`` are each called as ``hook(parameters=...)``:
    ``post_init_hook`` runs after step 2. ``post_method_hook`` and then
    ``post_train_hook`` run on the final parameters.

    Returns:
        Dict[str, nn.Parameter]: The fitted parameters, wrapped with ``requires_grad=False``.
    """
    initial = {
        '_0': init_space_search(
            x,
            qtz_func=quantization,
            deqtz_func=dequantization,
            params_list=['_0', '_s'],
            param='_0',
            **kwargs,
        ),
    }
    initial['_s'] = init_linear_scale(x, qtz_func=quantization, params=initial, **kwargs)
    if 'post_init_hook' in kwargs:
        kwargs['post_init_hook'](parameters=initial)

    fitted = init_non_linear_regression_fit(
        x,
        p0=initial,
        np_fit_func=fit_func,
        qtz_func=quantization,
        deqtz_func=dequantization,
        params_list=['_0', '_s'],
        **kwargs,
    )
    frozen = {name: nn.Parameter(value, requires_grad=False) for name, value in fitted.items()}

    for hook_name in ('post_method_hook', 'post_train_hook'):
        if hook_name in kwargs:
            kwargs[hook_name](parameters=frozen)
    return frozen
############### Numpy Qtz ###############
def np_quantization(x, _0, _s):
    """NumPy mirror of ``quantization``: ``tanh(_0 * x) / _s``.

    A zero ``_s`` is replaced by 10000 before its reciprocal is taken.
    """
    inverse_scale = np.divide(1, np_replace_num(_s, num=0, to=10000))
    return inverse_scale * np.tanh(_0 * x)
def np_dequantization(x, _0, _s):
    """NumPy mirror of ``dequantization``: ``log((1 + _s*x) / (1 - _s*x)) / _0``.

    Zero divisors (``_0`` and ``_s*x - 1``) are replaced by 10000. The log argument
    is guarded with ``np_domain_guard`` (NaN -> 1e-5, clamp below at 1e-5).

    NOTE(review): the log ratio equals ``2 * atanh(_s*x)``, so ignoring the guards
    ``np_dequantization(np_quantization(x))`` evaluates to ``2 * x`` -- confirm.
    """
    inverse_shape = np.divide(1, np_replace_num(_0, num=0, to=10000))
    denominator = np_replace_num((np.array(-1) + (_s * x)), num=0, to=10000)
    numerator = np.array(-1) + (np.array(-1) * _s * x)
    ratio = np.divide(1, denominator) * numerator
    return inverse_shape * np.log(np_domain_guard(ratio, min=1e-5, nan=1e-5))
def fit_func(x, _0, _s):
    """Model fitted by ``curve_fit``: round-trip ``x`` through the NumPy quantizer.

    Unlike the torch ``quantize`` path, no rounding or clamping happens here.
    """
    return np_dequantization(np_quantization(x, _0, _s), _0, _s)
############### HELPERS ###############
def domain_guard(
    x: torch.Tensor,
    min: float = None,
    max: float = None,
    posinf: float = None,
    neginf: float = None,
    nan: float = None
) -> torch.Tensor:
    """Replace non-finite values of ``x`` and optionally clamp it.

    NaN/+Inf/-Inf are replaced via ``torch.nan_to_num`` using ``nan``/``posinf``/
    ``neginf`` (``None`` keeps torch's defaults). When ``min`` or ``max`` is given,
    the result is clamped to that range.
    """
    cleaned = torch.nan_to_num(x, nan=nan, posinf=posinf, neginf=neginf)
    needs_clamp = not (min is None and max is None)
    return torch.clamp(cleaned, min=min, max=max) if needs_clamp else cleaned
def replace_num(x: torch.Tensor, num: float, to: float) -> torch.Tensor:
    """Return ``x`` with every element equal to ``num`` set to ``to``.

    Args:
        x (torch.Tensor): The input tensor (not modified).
        num (float): The value to look for.
        to (float): The replacement value.

    Returns:
        torch.Tensor: A new tensor; its dtype follows ``torch.where`` type promotion.
    """
    matches = x == num
    return torch.where(matches, to, x)
def guarded_torch_power(x: torch.Tensor, exp: float) -> torch.Tensor:
    """Compute ``x ** exp``, clipping negative bases to 0 when ``exp`` is below 1.

    Fractional exponents of negative numbers would give NaN, so for ``exp < 1`` the
    base is passed through ``relu`` first.
    """
    if exp >= 1:
        return torch.pow(x, exp)
    return torch.pow(torch.relu(x), exp)
def init_ones(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor:
    """Return float32 ones shaped like ``x`` reduced over dim 1, on ``x.device``.

    ``kwargs`` is accepted for a uniform initialiser signature and ignored.
    """
    per_channel = x.amin(dim=1)
    return torch.ones_like(per_channel, dtype=torch.float32, device=x.device)
def init_rand(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor:
    """Return standard-normal float32 values shaped like ``x`` reduced over dim 1.

    The result lives on ``x.device``; ``kwargs`` is ignored.
    """
    per_channel = x.amin(dim=1)
    return torch.randn_like(per_channel, dtype=torch.float32, device=x.device)
def init_space_search(
    x: torch.Tensor,
    **kwargs: Dict[str, Any],
) -> torch.Tensor:
    """Randomly search for a good per-channel initial value of one parameter.

    Each candidate for ``kwargs['param']`` is scored by round-tripping ``x`` through
    ``qtz_func``/``deqtz_func`` (no rounding) and taking the MSE against ``x``. All
    other parameters in ``params_list`` are fixed to ones.

    The first run evaluates 10 * 50 random candidates. Each run keeps its 5 best
    candidates, and the next run samples 50 new candidates around their mean. The
    best candidate seen in *any* run is finally compared with an all-ones and a
    random initialisation, and the one with the lowest loss is returned.

    Required kwargs: ``qtz_func``, ``deqtz_func``, ``params_list``, ``param``.

    Returns:
        torch.Tensor: The selected value for ``param``.
    """

    def _build_initial_param(tensor: torch.Tensor, max_initial: int, n_params: int):
        """Yield ``10 * n_params`` random candidates for the first run.

        Candidates are standard-normal samples scaled by ``max_initial``, so they are
        not strictly bounded by ``max_initial``.
        """
        for _ in range(n_params * 10):
            yield init_rand(tensor) * max_initial

    def _search_param(tensors: List[torch.Tensor], n_params: int):
        """Yield ``n_params`` candidates sampled around the mean of ``tensors``.

        The per-channel spread is the largest absolute value among ``tensors``.
        """
        torch_tensors = torch.stack(tensors)
        min_vals, max_vals = torch.aminmax(torch_tensors, dim=0)
        abs_max_val_per_ch = torch.max(-min_vals, max_vals)
        mean = torch.mean(torch_tensors, dim=0)
        for _ in range(n_params):
            yield torch.randn_like(min_vals) * abs_max_val_per_ch + mean

    def _calc(x, qtz_func, deqtz_func, **params):
        """Round-trip ``x`` through quantization and dequantization (channel dim swapped)."""
        x_ = x.transpose(0, 1)
        x_ = qtz_func(x=x_, **params)
        x_ = deqtz_func(x=x_, **params)
        return x_.transpose(0, 1)

    assert "qtz_func" in kwargs, "qtz_func must be provided."
    assert "deqtz_func" in kwargs, "deqtz_func must be provided."
    assert "params_list" in kwargs, "params list must be provided."
    assert "param" in kwargs, "param must be provided."
    qtz_func = kwargs.get('qtz_func')
    deqtz_func = kwargs.get('deqtz_func')
    params_list = kwargs.get('params_list')
    param = kwargs.get('param')

    n_runs = 50  # Number of runs to try to find the best parameters
    n_random_params = 50  # Number of random parameters to generate per run
    n_best_to_pick = 5  # Number of best parameters kept after each run
    max_initial = 10000  # Scale of the first run's random candidates

    mse = nn.MSELoss()
    base_params = {p: init_ones(x, **kwargs) for p in params_list if p != param}
    candidates = _build_initial_param(x, max_initial, n_random_params)

    # Best candidate over ALL runs. Each run only keeps its own top-k, so without
    # this an early, better candidate could be lost in a later, worse run.
    best_param: Any = None
    best_loss = float("inf")

    for _ in range(n_runs):
        run_best: List[Tuple[torch.Tensor, float]] = []
        for candidate in candidates:
            try:
                x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: candidate})
                loss = mse(x, x_).item()
            except Exception:  # The candidate might not be valid for the function's domain
                continue
            if not np.isfinite(loss):
                continue  # NaN/Inf losses cannot be ranked and would corrupt the sort
            if len(run_best) < n_best_to_pick:
                run_best.append((candidate, loss))
                run_best.sort(key=lambda item: item[1])
            elif loss < run_best[-1][1]:
                run_best[-1] = (candidate, loss)
                run_best.sort(key=lambda item: item[1])

        if run_best and run_best[0][1] < best_loss:
            best_param, best_loss = run_best[0]

        if not run_best:
            # No valid candidate in this run: resample around the best so far, or
            # stop if there is nothing to sample around (torch.stack([]) would fail).
            if best_param is None:
                break
            run_best = [(best_param, best_loss)]

        # Generates new candidates around the mean of this run's best ones
        candidates = _search_param([p for p, _ in run_best], n_random_params)

    # Baselines: all-ones and a random initialisation
    p_ones = init_ones(x, **kwargs)
    loss_ones = mse(x, _calc(x, qtz_func, deqtz_func, **base_params, **{param: p_ones})).item()
    p_rand = init_rand(x, **kwargs)
    loss_rand = mse(x, _calc(x, qtz_func, deqtz_func, **base_params, **{param: p_rand})).item()

    if loss_rand < best_loss and loss_rand < loss_ones:
        return p_rand
    if loss_ones < best_loss and loss_ones < loss_rand:
        return p_ones
    if best_param is None:
        # The search found nothing finite and the baselines tie (or are NaN).
        return p_ones
    return best_param
def init_linear_scale(  # Symmetric scale. From the study folder
    x: torch.Tensor,
    **kwargs: Dict[str, Any],
) -> torch.Tensor:
    """Compute a per-channel symmetric scale for ``_s``.

    ``qtz_func`` is applied to ``x`` with ``_s`` set to ones. For each channel (row)
    of the result, the largest absolute value is taken, with 0 always included in
    the range. It is divided by half the width of the signed ``bits``-bit integer
    range. The scale is clamped below at float32 eps and returned as float32.

    Required kwargs: ``bits``, ``params`` (the other parameters passed to
    ``qtz_func``), and ``qtz_func``.
    """
    for required in ("bits", "params", "qtz_func"):
        assert required in kwargs, f"{required} must be provided."
    bits = kwargs.get('bits')
    params = kwargs.get('params')
    qtz_func = kwargs.get('qtz_func')

    unit_scale = init_ones(x, **kwargs)
    quantized = qtz_func(x=x.transpose(0, 1), **params, _s=unit_scale).transpose(0, 1)

    quant_min, quant_max = get_min_max_from_bits_signed(bits)
    lowest, highest = torch.aminmax(quantized, dim=1)
    # Widen each channel's range so that it always contains 0.
    lowest = torch.min(lowest, torch.zeros_like(lowest))
    highest = torch.max(highest, torch.zeros_like(highest))

    half_width = float(quant_max - quant_min) / 2
    scale = torch.max(-lowest, highest) / half_width
    eps = torch.finfo(torch.float32).eps
    # NOTE: noise injection (scale + 0.01 * randn_like(scale)) is disabled. The
    # original author noted that without it accuracy was 0.0 and nothing was learned.
    return torch.clamp(scale, min=eps).to(dtype=torch.float32, device=lowest.device)
def init_non_linear_regression_fit(
    x: torch.Tensor,
    **kwargs: Dict[str, Any],
) -> Dict[str, torch.Tensor]:
    """Fit the quantizer parameters of every channel with ``scipy.optimize.curve_fit``.

    Each channel (row) of ``x`` is sorted and used as both ``xdata`` and ``ydata``.
    ``np_fit_func`` is therefore fitted to reproduce its input, starting from the
    initial guess ``p0``.

    Required kwargs:
        params_list: Names of the parameters to fit. They are passed to
            ``np_fit_func`` positionally in sorted (alphabetical) order.
        np_fit_func: NumPy model ``f(x, *params)``.
        p0: Dict mapping each parameter name to a per-channel tensor of initial values.

    Returns:
        Dict mapping each name in ``params_list`` to a float32 tensor (one value per
        channel) on ``x.device``. A channel whose fit fails (non-finite data, no
        convergence within ``maxfev``) or yields non-finite values keeps its
        ``p0`` values. The other channels still use their fitted values.
    """
    assert "params_list" in kwargs, "params list must be provided."
    assert "np_fit_func" in kwargs, "np_fit_func must be provided."
    assert "p0" in kwargs, "p0 must be provided."
    np_fit_func = kwargs.get('np_fit_func')
    params_list = kwargs.get('params_list')
    p0 = kwargs.get('p0')

    def _fit(xdata: np.ndarray, ydata: np.ndarray, func: Callable, p0: List[float]) -> np.ndarray:
        """Run Levenberg-Marquardt ``curve_fit`` and return the optimal parameters."""
        popt, _ = curve_fit(
            func,
            xdata,
            ydata,
            maxfev=1000,
            p0=p0,
            method='lm'
        )
        return popt

    # 1. Convert to NumPy. curve_fit works in float64 anyway. detach() allows tensors
    #    that require grad, and the cast handles dtypes NumPy lacks (e.g. bfloat16).
    xdata = x.detach().cpu().to(torch.float64).numpy()
    # 2. Sort each channel so that it is easier to fit.
    sorted_xdata = np.sort(xdata, axis=-1)
    p0 = {k: v.detach().cpu().numpy() for k, v in p0.items()}
    params_list = sorted(params_list)  # Must match the numpy fit func positional arg order

    # 3. Fit every channel independently. A failing channel falls back to its
    #    initial guess instead of discarding the fits of all other channels.
    channel_params = []
    failures = []
    for i, channel in enumerate(sorted_xdata):
        guess = [p0[p][i] for p in params_list]
        try:
            popt = _fit(channel, channel, np_fit_func, guess)
        except (ValueError, RuntimeError) as e:
            # ValueError: non-finite input data. RuntimeError: no convergence within maxfev.
            failures.append(e)
            popt = guess
        else:
            if not np.all(np.isfinite(popt)):
                failures.append(ValueError("fit produced non-finite parameters"))
                popt = guess
        channel_params.append(popt)

    if failures:
        print(f"Could not fit the function with error: {failures[0]}")
        print(f"Using fallback result for {len(failures)} of {len(channel_params)} channels...")

    # 4. Build one tensor per parameter, ordered by channel.
    table = np.asarray(channel_params, dtype=np.float64).reshape(len(channel_params), len(params_list))
    return {
        p: torch.tensor(table[:, j], dtype=torch.float32).to(x.device)
        for j, p in enumerate(params_list)
    }
def init_zeros(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor:
    """Return float32 zeros shaped like ``x`` reduced over dim 1, on ``x.device``.

    ``kwargs`` is accepted for a uniform initialiser signature and ignored.
    """
    per_channel = x.amin(dim=1)
    return torch.zeros_like(per_channel, dtype=torch.float32, device=x.device)
def init_inner_scale(tensor: torch.Tensor, _min: float = torch.inf, _max: float = torch.inf) -> torch.Tensor:
    """Per-row factor that maps the tensor's range onto ``[_min, _max]``.

    Each row's range is widened to include 0 (``[min(row, 0), max(row, 0)]``), and
    the factor is ``(_max - _min) / range``.

    Returns:
        torch.Tensor: One factor per row. All ones when ``_max`` is infinite (no
        rescaling requested). Rows whose range is empty (all zeros) also get 1,
        instead of a division by zero.

    NOTE(review): with a finite ``_max`` and the default ``_min=inf`` the factor is
    ``-inf``. Callers presumably always pass ``_min`` together with ``_max``.
    """
    min_vals, max_vals = torch.aminmax(tensor, dim=-1)
    x_min = torch.min(min_vals, torch.zeros_like(min_vals))
    x_max = torch.max(max_vals, torch.zeros_like(max_vals))
    # Compare by value: an identity check (`is`) only matches the torch.inf object,
    # so e.g. float("inf") would slip through and produce NaN.
    if _max == torch.inf:  # We do not need to scale the tensor. Just need to move it
        return torch.ones_like(x_min)
    value_range = x_max - x_min
    ones = torch.ones_like(x_min)
    non_empty = value_range != 0  # NaN ranges stay NaN rather than being masked
    safe_range = torch.where(non_empty, value_range, ones)
    return torch.where(non_empty, (_max - _min) / safe_range, ones)
############## Quant ###############
@torch.enable_grad()
def learn_parameters(
    x: torch.Tensor,
    params: Dict[str, nn.Parameter],
    qtz_func: nn.Module,
    deqtz_func: nn.Module,
    bits: int,
    target_dtype: torch.dtype,
    epochs: int = 1000,
    early_stop: bool = True,
    do_report: bool = False
) -> Tuple[Dict[str, nn.Parameter], torch.Tensor]:
    """Fine-tune ``params`` with Adam to minimise the quantize/dequantize MSE on ``x``.

    The learning rate is ``0.1 * 10 ** (e // 2)``, where ``e`` is the base-10 order of
    magnitude of the initial loss. It is plain 0.1 when that loss is 0 or not finite.
    The rate decays linearly to 1% over the first ``epochs // 10`` steps.

    Args:
        x: Tensor to reconstruct.
        params: Parameters passed as keyword arguments to ``qtz_func`` and
            ``deqtz_func``. They are optimised in place.
        qtz_func, deqtz_func: Quantization / dequantization callables.
        bits: Bit width of the signed integer range used by ``quantize``.
        target_dtype: dtype of the quantized tensor.
        epochs: Maximum number of optimisation steps.
        early_stop: Stop when the loss changed by less than 1e-7 over the last 10%
            of ``epochs``. This only applies when that window is at least one step.
        do_report: Print the loss every 10 steps and also return the loss history.

    Returns:
        Deep copies of the parameters that produced the lowest loss (the initial
        parameters included), with ``requires_grad=False``. When ``do_report`` is
        true, returns the tuple ``(best_params, losses)``, where ``losses`` is a
        list of floats.

    Raises:
        Exception: If the loss becomes NaN or Inf during training.
    """
    loss_fn = nn.MSELoss()

    # Loss of the starting parameters. It sets the learning rate and is the
    # baseline that training has to beat.
    quant = quantize(x, params, qtz_func, bits, target_dtype)
    dequant = dequantize(quant, params, deqtz_func, bits, x.dtype)
    initial_loss = loss_fn(x, dequant).item()

    base_lr = 0.1
    if np.isfinite(initial_loss) and initial_loss > 0:
        exponent = int(np.floor(np.log10(initial_loss)))
        lr = base_lr * (10 ** (exponent // 2))
    else:
        # log10 is undefined for 0 / NaN / Inf and int() of it would raise.
        lr = base_lr

    # Start from the initial parameters so that a result always exists (e.g. epochs=0).
    best_loss = initial_loss if np.isfinite(initial_loss) else float("inf")
    best_params = copy.deepcopy(dict(params))

    # Requires gradients in the parameters
    for p in params.values():
        p.requires_grad = True
        p.grad = None

    optimizer = torch.optim.Adam(list(params.values()), lr=lr)
    scheduler = torch.optim.lr_scheduler.LinearLR(optimizer, start_factor=1.0, end_factor=0.01, total_iters=epochs // 10)

    # Early-stop settings
    min_delta = 1e-7
    percent_epochs_before_stop = 0.1
    epochs_before_stop = int(epochs * percent_epochs_before_stop)
    acc_loss = []

    for i in range(epochs):
        optimizer.zero_grad()
        quant = quantize(x, params, qtz_func, bits, target_dtype)
        dequant = dequantize(quant, params, deqtz_func, bits, x.dtype)
        loss = loss_fn(x, dequant)
        if loss.isnan() or loss.isinf():
            raise Exception("Loss is NaN or Inf. Stopping the search.")
        loss_value = loss.item()

        # Snapshot BEFORE optimizer.step(): these are the exact values that produced
        # `loss`. A snapshot after the step would store the next, unevaluated params.
        if loss_value < best_loss:
            best_loss = loss_value
            best_params = copy.deepcopy(dict(params))

        loss.backward()
        optimizer.step()
        scheduler.step()
        acc_loss.append(loss_value)

        # Reports loss every 10 steps
        if i % 10 == 0 and do_report:
            print(f"Epoch {i}: Loss {loss_value}")

        # Stop when the loss barely changed over the last 10% of the epochs. A window
        # of 0 steps would compare a loss with itself and stop at once, so require >= 1.
        if (
            early_stop
            and epochs_before_stop > 0
            and i > epochs_before_stop
            and abs(acc_loss[i - epochs_before_stop] - acc_loss[i]) < min_delta
        ):
            break

    # No longer requires gradients in the parameters
    for p in best_params.values():
        p.requires_grad = False
        p.grad = None

    if do_report:
        return best_params, acc_loss
    return best_params
def quantize(
    x: torch.Tensor,
    params: Dict[str, nn.Parameter],
    func: nn.Module,
    bits: int,
    target_dtype: torch.dtype = torch.int8
) -> torch.Tensor:
    """Quantize ``x`` with ``func`` onto a signed ``bits``-bit integer grid.

    ``func`` is applied to ``x.transpose(0, 1)`` and the result is transposed back.
    Presumably this lets per-channel params broadcast over the last axis (TODO
    confirm). Values are rounded with ``round_func_BPDA`` (rounded forward,
    identity gradient), clamped to ``[-2**(bits-1), 2**(bits-1) - 1]``, and cast
    to ``target_dtype``.

    NOTE(review): casting to an integer ``target_dtype`` (the default) most likely
    cuts the autograd graph, so gradients only flow here for float dtypes.
    """
    lower, upper = get_min_max_from_bits_signed(bits)
    mapped = func(x=x.transpose(0, 1), **params).transpose(0, 1)
    rounded = round_func_BPDA(mapped)
    return torch.clamp(rounded, lower, upper).to(target_dtype)
def dequantize(
    x: torch.Tensor,
    params: Dict[str, nn.Parameter],
    func: nn.Module,
    bits: int,
    out_dtype: torch.dtype
) -> torch.Tensor:
    """Map quantized values back with ``func``.

    ``x`` is cast to ``out_dtype``, and ``func`` is applied to the transposed tensor
    (dims 0 and 1 swapped, as in ``quantize``). The result is transposed back.
    ``bits`` is accepted for symmetry with ``quantize`` but is not used.
    """
    as_float = x.to(dtype=out_dtype)
    return func(x=as_float.transpose(0, 1), **params).transpose(0, 1)
def round_func_BPDA(input):
    """Round in the forward pass while acting as the identity in the backward pass.

    This is a straight-through estimator (BPDA). The values seen downstream are
    ``torch.round(input)``, but gradients flow as if no rounding happened.
    """
    rounded = torch.round(input)
    passthrough = input.clone()
    # Swap only the storage: autograd still records the clone (identity gradient),
    # while the forward value becomes the rounded tensor.
    passthrough.data = rounded.data
    return passthrough
def get_min_max_from_bits_signed(bit_width: int) -> Tuple[int, int]:
    """Return ``(min, max)`` of a two's-complement signed integer of ``bit_width`` bits."""
    half = 2 ** (bit_width - 1)
    return -half, half - 1
############## Numpy ###############
def np_domain_guard(
    x: np.ndarray,
    min: float = None,
    max: float = None,
    posinf: float = None,
    neginf: float = None,
    nan: float = None
) -> np.ndarray:
    """Replace non-finite values of ``x`` and optionally clip it (NumPy ``domain_guard``).

    NaN becomes ``nan`` (0.0 when ``None``). +Inf/-Inf become ``posinf``/``neginf``
    (the dtype's largest/smallest finite value when ``None``). When ``min`` or
    ``max`` is given, the result is clipped to that range.
    """
    # np.nan_to_num cannot copy None into a float array, so map nan=None to 0.0.
    # 0.0 is NumPy's own default and matches torch.nan_to_num's behaviour for None.
    x = np.nan_to_num(x, nan=0.0 if nan is None else nan, posinf=posinf, neginf=neginf)
    if min is not None or max is not None:
        x = np.clip(x, min, max)
    return x
def np_replace_num(x: np.ndarray, num: float, to: float) -> np.ndarray:
    """Return ``x`` with every element equal to ``num`` set to ``to``.

    Args:
        x (np.ndarray): The input array (not modified).
        num (float): The value to look for.
        to (float): The replacement value.

    Returns:
        np.ndarray: A new array; its dtype follows ``np.where`` type promotion.
    """
    matches = x == num
    return np.where(matches, to, x)
def np_guarded_power(x: np.ndarray, exp: float) -> np.ndarray:
    """Compute ``x ** exp``, clipping negative bases to 0 when ``exp`` is below 1.

    This is the NumPy counterpart of ``guarded_torch_power``.
    """
    if exp >= 1:
        return np.power(x, exp)
    return np.power(np.maximum(x, 0), exp)