diff --git a/fn_gen/rnd_search/0/__pycache__/fn.cpython-311.pyc b/fn_gen/rnd_search/0/__pycache__/fn.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..852493a43044f49cfcf7ed20bc1a5b9423b4f885 Binary files /dev/null and b/fn_gen/rnd_search/0/__pycache__/fn.cpython-311.pyc differ diff --git a/fn_gen/rnd_search/0/distortion.png b/fn_gen/rnd_search/0/distortion.png new file mode 100644 index 0000000000000000000000000000000000000000..66b941e2c726c9ba4360f94e81a56bb0dc748cb3 Binary files /dev/null and b/fn_gen/rnd_search/0/distortion.png differ diff --git a/fn_gen/rnd_search/0/expressions.txt b/fn_gen/rnd_search/0/expressions.txt new file mode 100644 index 0000000000000000000000000000000000000000..ed99293c42843616c361d59b23d32ae553cc0f8d --- /dev/null +++ b/fn_gen/rnd_search/0/expressions.txt @@ -0,0 +1,2 @@ +atanh(_0*x)/_s +tanh(_s*x)/_0 \ No newline at end of file diff --git a/fn_gen/rnd_search/0/fn.py b/fn_gen/rnd_search/0/fn.py new file mode 100644 index 0000000000000000000000000000000000000000..cb07d477d9bc977803a4a81e746854be61c34311 --- /dev/null +++ b/fn_gen/rnd_search/0/fn.py @@ -0,0 +1,554 @@ +from __future__ import annotations + +import torch +from torch import amin # Necessary for arcsin +import copy +import torch.nn as nn +import numpy as np + +from scipy.optimize import curve_fit +from typing import Dict, Any, Tuple, List, Callable + + +def quantization(x, **params): + return (torch.div(1, replace_num(params['_s'], num=0, to=10000)) * torch.atanh(domain_guard((params['_0'] * x), min=-0.9999, max=0.9999, nan=0))) + + +def dequantization(x, **params): + return (torch.div(1, replace_num(params['_0'], num=0, to=10000)) * torch.tanh((params['_s'] * x))) + + +def init_space_search( + x: torch.Tensor, + **kwargs: Dict[str, Any], + ) -> torch.Tensor: + + def _build_initial_param(tensor: torch.Tensor, max_initial: int, n_params: int): + """Generates the initial set of parameters. The first iteration generates 10 times more parameters.""" + for _ in range(n_params * 10): # The first iteration generates 10 times more parameters + yield init_rand(tensor) * max_initial # Generates n_params in range [-max_initial, max_initial] + + def _search_param(tensors: List[torch.tensor], n_params): + """Takes the best parameters and generates new parameters around the mean of the best parameters.""" + torch_tensors = torch.stack(tensors) + min_vals, max_vals = torch.aminmax(torch_tensors, dim=0) + abs_max_val_per_ch = torch.max(-min_vals, max_vals) + mean = torch.mean(torch_tensors, dim=0) + for _ in range(n_params): # Generates n_params around the mean of the tensors + yield torch.randn_like(min_vals) * abs_max_val_per_ch + mean + + def _calc(x, qtz_func, deqtz_func, **params): + x_ = x.transpose(0, 1) + x_ = qtz_func(x=x_, **params) + x_ = deqtz_func(x=x_, **params) + x_ = x_.transpose(0, 1) + return x_ + + assert "qtz_func" in kwargs, "qtz_func must be provided." + assert "deqtz_func" in kwargs, "deqtz_func must be provided." + assert "params_list" in kwargs, "params list must be provided." + assert "param" in kwargs, "param must be provided." 
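+    # For this file's pair (see expressions.txt): quantization(x) = atanh(_0*x)/_s with the
+    # argument _0*x clamped to [-0.9999, 0.9999], and dequantization(x) = tanh(_s*x)/_0.
+    # The search below picks the per-channel _0 that minimises the MSE of the
+    # quantize -> dequantize round trip while every other parameter is held at ones.
+    # Illustrative call, mirroring init_params (extra kwargs such as bits=8 are assumed):
+    #   _0 = init_space_search(x, qtz_func=quantization, deqtz_func=dequantization,
+    #                          param='_0', params_list=['_0', '_s'], bits=8)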
+ + qtz_func = kwargs.get('qtz_func') + deqtz_func = kwargs.get('deqtz_func') + params_list = kwargs.get('params_list') + param = kwargs.get('param') + + n_runs = 50 # Number of runs to try to find the best parameters + n_random_params = 50 # Number of random parameters to generate + n_best_to_pick = 5 # Number of best parameters to pick after each run + max_initial = 10000 # Maximum value to initialize the parameters + + # Initializes the parameters + base_params = { p: init_ones(x, **kwargs) for p in params_list if p != param } + params = _build_initial_param(x, max_initial, n_random_params) + + # Performs the search + for _ in range(n_runs): + + best_params = [] + for param_ in params: + try: + x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: param_}) + loss_ones = nn.MSELoss()(x, x_) + + if len(best_params) < n_best_to_pick: + best_params.append((param_, loss_ones.item())) + best_params = sorted(best_params, key=lambda x: x[1]) + elif loss_ones < best_params[-1][1]: + best_params[-1] = (param_, loss_ones.item()) + best_params = sorted(best_params, key=lambda x: x[1]) + + except Exception: # The parameters might not be valid for the function's domain + continue + + # Generates new parameters around the mean + params = _search_param([p for p, _ in best_params], n_random_params) + + return best_params[0][0] + + +def init_linear_scale( # Symmetric scale. From the study folder + x: torch.Tensor, + **kwargs: Dict[str, Any], + ) -> torch.Tensor: + assert "bits" in kwargs, "bits must be provided." + assert "params" in kwargs, "params must be provided." + assert "qtz_func" in kwargs, "qtz_func must be provided." + + bits = kwargs.get('bits') + params = kwargs.get('params') + qtz_func = kwargs.get('qtz_func') + + x_ = x.transpose(0, 1) + x_ = qtz_func(x=x_, **params, _s=init_ones(x, **kwargs)) + x_ = x_.transpose(0, 1) + + quant_min, quant_max = get_min_max_from_bits_signed(bits) + min_vals, max_vals = torch.aminmax(x_, dim=1) + min_vals = torch.min(min_vals, torch.zeros_like(min_vals)) + max_vals = torch.max(max_vals, torch.zeros_like(max_vals)) + + eps = torch.finfo(torch.float32).eps + + abs_max_val_per_ch = torch.max(-min_vals, max_vals) + scale = abs_max_val_per_ch / (float(quant_max - quant_min) / 2) + + scale = torch.clamp(scale, min=eps).to(dtype=torch.float32, device=min_vals.device) + + # Introduces some noise in scale + # If I don't introduce noise, the accuracy is going to be 0.0 and not learn anything + # scale = scale + 0.01 * torch.randn_like(scale) + return scale + + +def init_params(x: torch.Tensor, **kwargs: Dict[str, Any]) -> Dict[str, nn.Parameter]: + params = { + '_0': init_space_search(x, qtz_func=quantization, deqtz_func=dequantization, param='_0', params_list=['_0', '_s'], **kwargs), + } + params['_s'] = init_linear_scale(x, params=params, qtz_func=quantization, **kwargs) + params = {k: nn.Parameter(v, requires_grad=False) for k, v in params.items()} + + if 'post_init_hook' in kwargs: + kwargs['post_init_hook'](parameters=params) + + + if 'post_train_hook' in kwargs: + kwargs['post_train_hook'](parameters=params) + + return params + + +############### Numpy Qtz ############### + + +def np_quantization(x, _0, _s): + return (np.divide(1, np_replace_num(_s, num=0, to=10000)) * np.arctanh(np_domain_guard((_0 * x), min=-0.9999, max=0.9999, nan=0))) + + +def np_dequantization(x, _0, _s): + return (np.divide(1, np_replace_num(_0, num=0, to=10000)) * np.tanh((_s * x))) + + +def fit_func(x, _0, _s): + x_ = np_quantization(x, _0, _s) + x_ = np_dequantization(x_, _0, _s) 
+ return x_ + + + +############### HELPERS ############### + +def domain_guard( + x: torch.Tensor, + min: float = None, + max: float = None, + posinf: float = None, + neginf: float = None, + nan: float = None + ) -> torch.Tensor: + """Guard a tensor to a valid domain.""" + x = torch.nan_to_num(x, posinf=posinf, neginf=neginf, nan=nan) + if min is not None or max is not None: + x = torch.clamp(x, min=min, max=max) + return x + + +def replace_num(x: torch.Tensor, num: float, to: float) -> torch.Tensor: + """Replace a number in a tensor with another number. + + Args: + x (torch.Tensor): The input tensor. + num (float): The number to replace. + to (float): The number to replace with. + + Returns: + torch.Tensor: The tensor with the number replaced. + """ + return torch.where(x == num, to, x) + + +def guarded_torch_power(x: torch.Tensor, exp: float) -> torch.Tensor: + """Guard the power operation to a valid domain.""" + return torch.pow(x, exp) if exp >= 1 else torch.pow(torch.relu(x), exp) + + +def init_ones(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor: + val = torch.amin(x, dim=1) + return torch.ones_like(val, dtype=torch.float32, device=x.device) + + +def init_rand(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor: + val = torch.amin(x, dim=1) + return torch.randn_like(val, dtype=torch.float32, device=x.device) + + +def init_space_search( + x: torch.Tensor, + **kwargs: Dict[str, Any], + ) -> torch.Tensor: + + def _build_initial_param(tensor: torch.Tensor, max_initial: int, n_params: int): + """Generates the initial set of parameters. The first iteration generates 10 times more parameters.""" + for _ in range(n_params * 10): # The first iteration generates 10 times more parameters + yield init_rand(tensor) * max_initial # Generates n_params in range [-max_initial, max_initial] + + def _search_param(tensors: List[torch.tensor], n_params): + """Takes the best parameters and generates new parameters around the mean of the best parameters.""" + torch_tensors = torch.stack(tensors) + min_vals, max_vals = torch.aminmax(torch_tensors, dim=0) + abs_max_val_per_ch = torch.max(-min_vals, max_vals) + mean = torch.mean(torch_tensors, dim=0) + for _ in range(n_params): # Generates n_params around the mean of the tensors + yield torch.randn_like(min_vals) * abs_max_val_per_ch + mean + + def _calc(x, qtz_func, deqtz_func, **params): + x_ = x.transpose(0, 1) + x_ = qtz_func(x=x_, **params) + x_ = deqtz_func(x=x_, **params) + x_ = x_.transpose(0, 1) + return x_ + + assert "qtz_func" in kwargs, "qtz_func must be provided." + assert "deqtz_func" in kwargs, "deqtz_func must be provided." + assert "params_list" in kwargs, "params list must be provided." + assert "param" in kwargs, "param must be provided." 
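+    # Search outline: start from n_random_params * 10 random candidates scaled by
+    # max_initial, keep the n_best_to_pick candidates with the lowest round-trip MSE on
+    # each run, resample new candidates from a Gaussian centred on the mean of the kept
+    # ones (spread by their per-channel absolute maximum), repeat for n_runs runs, and
+    # return the single best candidate found.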
+ + qtz_func = kwargs.get('qtz_func') + deqtz_func = kwargs.get('deqtz_func') + params_list = kwargs.get('params_list') + param = kwargs.get('param') + + n_runs = 50 # Number of runs to try to find the best parameters + n_random_params = 50 # Number of random parameters to generate + n_best_to_pick = 5 # Number of best parameters to pick after each run + max_initial = 10000 # Maximum value to initialize the parameters + + # Initializes the parameters + base_params = { p: init_ones(x, **kwargs) for p in params_list if p != param } + params = _build_initial_param(x, max_initial, n_random_params) + + # Performs the search + for _ in range(n_runs): + + best_params = [] + for param_ in params: + try: + x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: param_}) + loss_ones = nn.MSELoss()(x, x_) + + if len(best_params) < n_best_to_pick: + best_params.append((param_, loss_ones.item())) + best_params = sorted(best_params, key=lambda x: x[1]) + elif loss_ones < best_params[-1][1]: + best_params[-1] = (param_, loss_ones.item()) + best_params = sorted(best_params, key=lambda x: x[1]) + + except Exception: # The parameters might not be valid for the function's domain + continue + + # Generates new parameters around the mean + params = _search_param([p for p, _ in best_params], n_random_params) + + return best_params[0][0] + + +def init_linear_scale( # Symmetric scale. From the study folder + x: torch.Tensor, + **kwargs: Dict[str, Any], + ) -> torch.Tensor: + assert "bits" in kwargs, "bits must be provided." + assert "params" in kwargs, "params must be provided." + assert "qtz_func" in kwargs, "qtz_func must be provided." + + bits = kwargs.get('bits') + params = kwargs.get('params') + qtz_func = kwargs.get('qtz_func') + + x_ = x.transpose(0, 1) + x_ = qtz_func(x=x_, **params, _s=init_ones(x, **kwargs)) + x_ = x_.transpose(0, 1) + + quant_min, quant_max = get_min_max_from_bits_signed(bits) + min_vals, max_vals = torch.aminmax(x_, dim=1) + min_vals = torch.min(min_vals, torch.zeros_like(min_vals)) + max_vals = torch.max(max_vals, torch.zeros_like(max_vals)) + + eps = torch.finfo(torch.float32).eps + + abs_max_val_per_ch = torch.max(-min_vals, max_vals) + scale = abs_max_val_per_ch / (float(quant_max - quant_min) / 2) + + scale = torch.clamp(scale, min=eps).to(dtype=torch.float32, device=min_vals.device) + + # Introduces some noise in scale + # If I don't introduce noise, the accuracy is going to be 0.0 and not learn anything + # scale = scale + 0.01 * torch.randn_like(scale) + return scale + + +def init_non_linear_regression_fit( + x: torch.Tensor, + **kwargs: Dict[str, Any], + ) -> torch.Tensor: + + assert "params_list" in kwargs, "params list must be provided." + assert "np_fit_func" in kwargs, "np_fit_func must be provided." + assert "p0" in kwargs, "p0 must be provided." + np_fit_func = kwargs.get('np_fit_func') + params_list = kwargs.get('params_list') + p0 = kwargs.get('p0') + + def _fit(xdata: np.ndarray, ydata: np.ndarray, func: Callable, p0: List[float]): + popt, _ = curve_fit( + func, + xdata, + ydata, + maxfev=1000, + p0=p0, + method='lm' + ) + return popt + + # 1. Needs to convert the torch tensor to numpy tensor + xdata = x.cpu().numpy() + + # 2. Sorts the data so that it makes it easier to fit to it + sorted_xdata = np.sort(xdata, axis=-1) + + p0 = {k: v.cpu().numpy() for k, v in p0.items()} + params_list = sorted(params_list) # We need to make sure that it matches the numpy fit func arg order + + # 3. 
Finds the best parameters for each channel + try: + params = [] + for i in range(sorted_xdata.shape[0]): + xdata_ = sorted_xdata[i] + p0_ = [p0[p][i] for p in params_list] + ch_params = _fit(xdata_, xdata_, np_fit_func, p0_) + params.append(ch_params) + + # 4. Builds the parameters + result = {} + for i, p in enumerate(params_list): + result[p] = torch.tensor([p_[i] for p_ in params], dtype=torch.float32).to(x.device) + + return result + + except ValueError as e: + print(f"Could not fit the function with error: {e}") + print(f"Using fallback result...") + return { + k: torch.tensor(v, dtype=torch.float32).to(x.device) for k, v in p0.items() + } + + +def init_zeros(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor: + val = torch.amin(x, dim=1) + return torch.zeros_like(val, dtype=torch.float32, device=x.device) + + +def init_inner_scale(tensor: torch.Tensor, _min: float = torch.inf, _max: float = torch.inf) -> torch.Tensor: + # Calculate the original minimum and maximum values + min_vals, max_vals = torch.aminmax(tensor, dim=-1) + x_min = torch.min(min_vals, torch.zeros_like(min_vals)) + x_max = torch.max(max_vals, torch.zeros_like(max_vals)) + + if _max is torch.inf: # We do not need to scale the tensor. Just need to move it + return torch.ones_like(x_min) + + # Calculate the scale factor + scale = (_max - _min) / (x_max - x_min) + return scale + + + +############## Quant ############### + +@torch.enable_grad() +def learn_parameters( + x: torch.Tensor, + params: Dict[str, nn.Parameter], + qtz_func: nn.Module, + deqtz_func: nn.Module, + bits: int, + target_dtype: torch.dtype, + epochs: int = 1000, + early_stop: bool = True, + do_report: bool = False +) -> Tuple[Dict[str, nn.Parameter], torch.Tensor]: + loss_fn = nn.MSELoss() + + # Determines the initial learning rate by computing the initial loss and multiplying it by + # the order of magnitude of the loss divided by 2 + quant = quantize(x, params, qtz_func, bits, target_dtype) + dequant = dequantize(quant, params, deqtz_func, bits, x.dtype) + loss = loss_fn(x, dequant) + + base_lr = 0.1 + exponent = int(np.floor(np.log10(loss.item()))) + lr = base_lr * (10 ** (exponent // 2)) + + # Requires gradients in the parameters + for p in params.values(): + p.requires_grad = True + p.grad = None + + param_keys = list(params.keys()) + param_values = list(params.values()) + + # Defines optimizer and loss function + optimizer = torch.optim.Adam(param_values, lr=lr) + scheduler = torch.optim.lr_scheduler.LinearLR(optimizer, start_factor=1.0, end_factor=0.01, total_iters=epochs // 10) + + # Contains the best loss and the best parameters + best_loss = float("inf") + best_params = None + + # Used to stop the search early + min_delta = 1e-7 + acc_loss = [] + percent_epochs_before_stop = 0.1 + + for i in range(epochs): + optimizer.zero_grad() + + quant = quantize(x, params, qtz_func, bits, target_dtype) + dequant = dequantize(quant, params, deqtz_func, bits, x.dtype) + loss = loss_fn(x, dequant) + + if loss.isnan() or loss.isinf(): + raise Exception("Loss is NaN or Inf. 
Stopping the search.") + + loss.backward() + optimizer.step() + scheduler.step() + + acc_loss.append(loss.item()) + + # Reports loss every 10 steps + if i % 10 == 0 and do_report: + print(f"Epoch {i}: Loss {loss.item()}") + + # Optimizes the parameter search by storing the best loss and the parameters + if loss.item() < best_loss: + best_loss = loss.item() + best_params = copy.deepcopy({ + k: v for k, v in params.items() if k in param_keys + }) + + # We also stop the search if the loss has not considerably during the last 10% epochs + if early_stop: + epochs_before_stop = int(epochs * percent_epochs_before_stop) + if i > epochs_before_stop and abs(acc_loss[i - epochs_before_stop] - acc_loss[i]) < min_delta: + break + + # No longer requires gradients in the parameters + for p in best_params.values(): + p.requires_grad = False + p.grad = None + + if do_report: + return best_params, acc_loss + else: + return best_params + + +def quantize( + x: torch.Tensor, + params: Dict[str, nn.Parameter], + func: nn.Module, + bits: int, + target_dtype: torch.dtype = torch.int8 +) -> torch.Tensor: + quant_min, quant_max = get_min_max_from_bits_signed(bits) + x = x.transpose(0, 1) # Aligns shapes + x = func(x=x, **params) + x = x.transpose(0, 1) + x = torch.clamp(round_func_BPDA(x), quant_min, quant_max).to(target_dtype) + return x + + +def dequantize( + x: torch.Tensor, + params: Dict[str, nn.Parameter], + func: nn.Module, + bits: int, + out_dtype: torch.dtype +) -> torch.Tensor: + x = x.to(dtype=out_dtype) + x = x.transpose(0, 1) + x = func(x=x, **params) + x = x.transpose(0, 1) + return x + + +def round_func_BPDA(input): + # This is equivalent to replacing round function (non-differentiable) with + # an identity function (differentiable) only when backward. + forward_value = torch.round(input) + out = input.clone() + out.data = forward_value.data + return out + + +def get_min_max_from_bits_signed(bit_width: int) -> Tuple[int, int]: + return -2 ** (bit_width - 1), 2 ** (bit_width - 1) - 1 + + + +############## Numpy ############### + +def np_domain_guard( + x: np.ndarray, + min: float = None, + max: float = None, + posinf: float = None, + neginf: float = None, + nan: float = None + ) -> np.ndarray: + """Guard a tensor to a valid domain.""" + x = np.nan_to_num(x, posinf=posinf, neginf=neginf, nan=nan) + if min is not None or max is not None: + x = np.clip(x, min, max) + return x + + +def np_replace_num(x: np.ndarray, num: float, to: float) -> np.ndarray: + """Replace a number in a tensor with another number. + + Args: + x (np.ndarray): The input tensor. + num (float): The number to replace. + to (float): The number to replace with. + + Returns: + np.ndarray: The tensor with the number replaced. 
+ """ + return np.where(x == num, to, x) + + +def np_guarded_power(x: np.ndarray, exp: float) -> np.ndarray: + """Guard the power operation to a valid domain.""" + return np.power(x, exp) if exp >= 1 else np.power(np.maximum(x, 0), exp) + diff --git a/fn_gen/rnd_search/0/loss.png b/fn_gen/rnd_search/0/loss.png new file mode 100644 index 0000000000000000000000000000000000000000..fa95ed4e3da78361f4f69b87c2d7a3f405ced811 Binary files /dev/null and b/fn_gen/rnd_search/0/loss.png differ diff --git a/fn_gen/rnd_search/0/quantization.png b/fn_gen/rnd_search/0/quantization.png new file mode 100644 index 0000000000000000000000000000000000000000..af916b7823da0ec75fb1b72c9dbd7922f08b8c95 Binary files /dev/null and b/fn_gen/rnd_search/0/quantization.png differ diff --git a/fn_gen/rnd_search/1/__pycache__/fn.cpython-311.pyc b/fn_gen/rnd_search/1/__pycache__/fn.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..eb6a6e36a718a284a3d2f580030f228f1f3a187c Binary files /dev/null and b/fn_gen/rnd_search/1/__pycache__/fn.cpython-311.pyc differ diff --git a/fn_gen/rnd_search/1/distortion.png b/fn_gen/rnd_search/1/distortion.png new file mode 100644 index 0000000000000000000000000000000000000000..0f5d3d98f2790d3ae49047c90b76db3f6b3d3e0f Binary files /dev/null and b/fn_gen/rnd_search/1/distortion.png differ diff --git a/fn_gen/rnd_search/1/expressions.txt b/fn_gen/rnd_search/1/expressions.txt new file mode 100644 index 0000000000000000000000000000000000000000..c7b68c388fdf6e1b6e2be8076f1d4b8d7bcef4f9 --- /dev/null +++ b/fn_gen/rnd_search/1/expressions.txt @@ -0,0 +1,2 @@ +(_0*x)**(1/3)/_s +_s**3*x**3/_0 \ No newline at end of file diff --git a/fn_gen/rnd_search/1/fn.py b/fn_gen/rnd_search/1/fn.py new file mode 100644 index 0000000000000000000000000000000000000000..cc74648348d22ab58b2e1e5f47e121876f63e5da --- /dev/null +++ b/fn_gen/rnd_search/1/fn.py @@ -0,0 +1,554 @@ +from __future__ import annotations + +import torch +from torch import amin # Necessary for arcsin +import copy +import torch.nn as nn +import numpy as np + +from scipy.optimize import curve_fit +from typing import Dict, Any, Tuple, List, Callable + + +def quantization(x, **params): + return (torch.div(1, replace_num(params['_s'], num=0, to=10000)) * guarded_torch_power((params['_0'] * x), 1 / 3)) + + +def dequantization(x, **params): + return (torch.div(1, replace_num(params['_0'], num=0, to=10000)) * guarded_torch_power(params['_s'], torch.tensor(3)) * guarded_torch_power(x, torch.tensor(3))) + + +def init_space_search( + x: torch.Tensor, + **kwargs: Dict[str, Any], + ) -> torch.Tensor: + + def _build_initial_param(tensor: torch.Tensor, max_initial: int, n_params: int): + """Generates the initial set of parameters. 
The first iteration generates 10 times more parameters.""" + for _ in range(n_params * 10): # The first iteration generates 10 times more parameters + yield init_rand(tensor) * max_initial # Generates n_params in range [-max_initial, max_initial] + + def _search_param(tensors: List[torch.tensor], n_params): + """Takes the best parameters and generates new parameters around the mean of the best parameters.""" + torch_tensors = torch.stack(tensors) + min_vals, max_vals = torch.aminmax(torch_tensors, dim=0) + abs_max_val_per_ch = torch.max(-min_vals, max_vals) + mean = torch.mean(torch_tensors, dim=0) + for _ in range(n_params): # Generates n_params around the mean of the tensors + yield torch.randn_like(min_vals) * abs_max_val_per_ch + mean + + def _calc(x, qtz_func, deqtz_func, **params): + x_ = x.transpose(0, 1) + x_ = qtz_func(x=x_, **params) + x_ = deqtz_func(x=x_, **params) + x_ = x_.transpose(0, 1) + return x_ + + assert "qtz_func" in kwargs, "qtz_func must be provided." + assert "deqtz_func" in kwargs, "deqtz_func must be provided." + assert "params_list" in kwargs, "params list must be provided." + assert "param" in kwargs, "param must be provided." + + qtz_func = kwargs.get('qtz_func') + deqtz_func = kwargs.get('deqtz_func') + params_list = kwargs.get('params_list') + param = kwargs.get('param') + + n_runs = 50 # Number of runs to try to find the best parameters + n_random_params = 50 # Number of random parameters to generate + n_best_to_pick = 5 # Number of best parameters to pick after each run + max_initial = 10000 # Maximum value to initialize the parameters + + # Initializes the parameters + base_params = { p: init_ones(x, **kwargs) for p in params_list if p != param } + params = _build_initial_param(x, max_initial, n_random_params) + + # Performs the search + for _ in range(n_runs): + + best_params = [] + for param_ in params: + try: + x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: param_}) + loss_ones = nn.MSELoss()(x, x_) + + if len(best_params) < n_best_to_pick: + best_params.append((param_, loss_ones.item())) + best_params = sorted(best_params, key=lambda x: x[1]) + elif loss_ones < best_params[-1][1]: + best_params[-1] = (param_, loss_ones.item()) + best_params = sorted(best_params, key=lambda x: x[1]) + + except Exception: # The parameters might not be valid for the function's domain + continue + + # Generates new parameters around the mean + params = _search_param([p for p, _ in best_params], n_random_params) + + return best_params[0][0] + + +def init_linear_scale( # Symmetric scale. From the study folder + x: torch.Tensor, + **kwargs: Dict[str, Any], + ) -> torch.Tensor: + assert "bits" in kwargs, "bits must be provided." + assert "params" in kwargs, "params must be provided." + assert "qtz_func" in kwargs, "qtz_func must be provided." 
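+    # Symmetric abs-max scale per channel: the input is first probed through qtz_func with
+    # _s fixed to ones, then scale = abs_max / ((quant_max - quant_min) / 2).
+    # For example, with bits = 8 this gives quant_min = -128, quant_max = 127 and
+    # scale = abs_max / 127.5.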
+ + bits = kwargs.get('bits') + params = kwargs.get('params') + qtz_func = kwargs.get('qtz_func') + + x_ = x.transpose(0, 1) + x_ = qtz_func(x=x_, **params, _s=init_ones(x, **kwargs)) + x_ = x_.transpose(0, 1) + + quant_min, quant_max = get_min_max_from_bits_signed(bits) + min_vals, max_vals = torch.aminmax(x_, dim=1) + min_vals = torch.min(min_vals, torch.zeros_like(min_vals)) + max_vals = torch.max(max_vals, torch.zeros_like(max_vals)) + + eps = torch.finfo(torch.float32).eps + + abs_max_val_per_ch = torch.max(-min_vals, max_vals) + scale = abs_max_val_per_ch / (float(quant_max - quant_min) / 2) + + scale = torch.clamp(scale, min=eps).to(dtype=torch.float32, device=min_vals.device) + + # Introduces some noise in scale + # If I don't introduce noise, the accuracy is going to be 0.0 and not learn anything + # scale = scale + 0.01 * torch.randn_like(scale) + return scale + + +def init_params(x: torch.Tensor, **kwargs: Dict[str, Any]) -> Dict[str, nn.Parameter]: + params = { + '_0': init_space_search(x, qtz_func=quantization, deqtz_func=dequantization, param='_0', params_list=['_0', '_s'], **kwargs), + } + params['_s'] = init_linear_scale(x, params=params, qtz_func=quantization, **kwargs) + params = {k: nn.Parameter(v, requires_grad=False) for k, v in params.items()} + + if 'post_init_hook' in kwargs: + kwargs['post_init_hook'](parameters=params) + + + if 'post_train_hook' in kwargs: + kwargs['post_train_hook'](parameters=params) + + return params + + +############### Numpy Qtz ############### + + +def np_quantization(x, _0, _s): + return (np.divide(1, np_replace_num(_s, num=0, to=10000)) * np_guarded_power((_0 * x), 1 / 3)) + + +def np_dequantization(x, _0, _s): + return (np.divide(1, np_replace_num(_0, num=0, to=10000)) * np_guarded_power(_s, np.array(3)) * np_guarded_power(x, np.array(3))) + + +def fit_func(x, _0, _s): + x_ = np_quantization(x, _0, _s) + x_ = np_dequantization(x_, _0, _s) + return x_ + + + +############### HELPERS ############### + +def domain_guard( + x: torch.Tensor, + min: float = None, + max: float = None, + posinf: float = None, + neginf: float = None, + nan: float = None + ) -> torch.Tensor: + """Guard a tensor to a valid domain.""" + x = torch.nan_to_num(x, posinf=posinf, neginf=neginf, nan=nan) + if min is not None or max is not None: + x = torch.clamp(x, min=min, max=max) + return x + + +def replace_num(x: torch.Tensor, num: float, to: float) -> torch.Tensor: + """Replace a number in a tensor with another number. + + Args: + x (torch.Tensor): The input tensor. + num (float): The number to replace. + to (float): The number to replace with. + + Returns: + torch.Tensor: The tensor with the number replaced. + """ + return torch.where(x == num, to, x) + + +def guarded_torch_power(x: torch.Tensor, exp: float) -> torch.Tensor: + """Guard the power operation to a valid domain.""" + return torch.pow(x, exp) if exp >= 1 else torch.pow(torch.relu(x), exp) + + +def init_ones(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor: + val = torch.amin(x, dim=1) + return torch.ones_like(val, dtype=torch.float32, device=x.device) + + +def init_rand(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor: + val = torch.amin(x, dim=1) + return torch.randn_like(val, dtype=torch.float32, device=x.device) + + +def init_space_search( + x: torch.Tensor, + **kwargs: Dict[str, Any], + ) -> torch.Tensor: + + def _build_initial_param(tensor: torch.Tensor, max_initial: int, n_params: int): + """Generates the initial set of parameters. 
The first iteration generates 10 times more parameters.""" + for _ in range(n_params * 10): # The first iteration generates 10 times more parameters + yield init_rand(tensor) * max_initial # Generates n_params in range [-max_initial, max_initial] + + def _search_param(tensors: List[torch.tensor], n_params): + """Takes the best parameters and generates new parameters around the mean of the best parameters.""" + torch_tensors = torch.stack(tensors) + min_vals, max_vals = torch.aminmax(torch_tensors, dim=0) + abs_max_val_per_ch = torch.max(-min_vals, max_vals) + mean = torch.mean(torch_tensors, dim=0) + for _ in range(n_params): # Generates n_params around the mean of the tensors + yield torch.randn_like(min_vals) * abs_max_val_per_ch + mean + + def _calc(x, qtz_func, deqtz_func, **params): + x_ = x.transpose(0, 1) + x_ = qtz_func(x=x_, **params) + x_ = deqtz_func(x=x_, **params) + x_ = x_.transpose(0, 1) + return x_ + + assert "qtz_func" in kwargs, "qtz_func must be provided." + assert "deqtz_func" in kwargs, "deqtz_func must be provided." + assert "params_list" in kwargs, "params list must be provided." + assert "param" in kwargs, "param must be provided." + + qtz_func = kwargs.get('qtz_func') + deqtz_func = kwargs.get('deqtz_func') + params_list = kwargs.get('params_list') + param = kwargs.get('param') + + n_runs = 50 # Number of runs to try to find the best parameters + n_random_params = 50 # Number of random parameters to generate + n_best_to_pick = 5 # Number of best parameters to pick after each run + max_initial = 10000 # Maximum value to initialize the parameters + + # Initializes the parameters + base_params = { p: init_ones(x, **kwargs) for p in params_list if p != param } + params = _build_initial_param(x, max_initial, n_random_params) + + # Performs the search + for _ in range(n_runs): + + best_params = [] + for param_ in params: + try: + x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: param_}) + loss_ones = nn.MSELoss()(x, x_) + + if len(best_params) < n_best_to_pick: + best_params.append((param_, loss_ones.item())) + best_params = sorted(best_params, key=lambda x: x[1]) + elif loss_ones < best_params[-1][1]: + best_params[-1] = (param_, loss_ones.item()) + best_params = sorted(best_params, key=lambda x: x[1]) + + except Exception: # The parameters might not be valid for the function's domain + continue + + # Generates new parameters around the mean + params = _search_param([p for p, _ in best_params], n_random_params) + + return best_params[0][0] + + +def init_linear_scale( # Symmetric scale. From the study folder + x: torch.Tensor, + **kwargs: Dict[str, Any], + ) -> torch.Tensor: + assert "bits" in kwargs, "bits must be provided." + assert "params" in kwargs, "params must be provided." + assert "qtz_func" in kwargs, "qtz_func must be provided." 
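+    # The probe below evaluates qtz_func with _s set to ones, so the observed per-channel
+    # output magnitudes map directly onto the scale. min_vals/max_vals are clamped to
+    # include zero, which keeps the symmetric range anchored at the origin.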
+ + bits = kwargs.get('bits') + params = kwargs.get('params') + qtz_func = kwargs.get('qtz_func') + + x_ = x.transpose(0, 1) + x_ = qtz_func(x=x_, **params, _s=init_ones(x, **kwargs)) + x_ = x_.transpose(0, 1) + + quant_min, quant_max = get_min_max_from_bits_signed(bits) + min_vals, max_vals = torch.aminmax(x_, dim=1) + min_vals = torch.min(min_vals, torch.zeros_like(min_vals)) + max_vals = torch.max(max_vals, torch.zeros_like(max_vals)) + + eps = torch.finfo(torch.float32).eps + + abs_max_val_per_ch = torch.max(-min_vals, max_vals) + scale = abs_max_val_per_ch / (float(quant_max - quant_min) / 2) + + scale = torch.clamp(scale, min=eps).to(dtype=torch.float32, device=min_vals.device) + + # Introduces some noise in scale + # If I don't introduce noise, the accuracy is going to be 0.0 and not learn anything + # scale = scale + 0.01 * torch.randn_like(scale) + return scale + + +def init_non_linear_regression_fit( + x: torch.Tensor, + **kwargs: Dict[str, Any], + ) -> torch.Tensor: + + assert "params_list" in kwargs, "params list must be provided." + assert "np_fit_func" in kwargs, "np_fit_func must be provided." + assert "p0" in kwargs, "p0 must be provided." + np_fit_func = kwargs.get('np_fit_func') + params_list = kwargs.get('params_list') + p0 = kwargs.get('p0') + + def _fit(xdata: np.ndarray, ydata: np.ndarray, func: Callable, p0: List[float]): + popt, _ = curve_fit( + func, + xdata, + ydata, + maxfev=1000, + p0=p0, + method='lm' + ) + return popt + + # 1. Needs to convert the torch tensor to numpy tensor + xdata = x.cpu().numpy() + + # 2. Sorts the data so that it makes it easier to fit to it + sorted_xdata = np.sort(xdata, axis=-1) + + p0 = {k: v.cpu().numpy() for k, v in p0.items()} + params_list = sorted(params_list) # We need to make sure that it matches the numpy fit func arg order + + # 3. Finds the best parameters for each channel + try: + params = [] + for i in range(sorted_xdata.shape[0]): + xdata_ = sorted_xdata[i] + p0_ = [p0[p][i] for p in params_list] + ch_params = _fit(xdata_, xdata_, np_fit_func, p0_) + params.append(ch_params) + + # 4. Builds the parameters + result = {} + for i, p in enumerate(params_list): + result[p] = torch.tensor([p_[i] for p_ in params], dtype=torch.float32).to(x.device) + + return result + + except ValueError as e: + print(f"Could not fit the function with error: {e}") + print(f"Using fallback result...") + return { + k: torch.tensor(v, dtype=torch.float32).to(x.device) for k, v in p0.items() + } + + +def init_zeros(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor: + val = torch.amin(x, dim=1) + return torch.zeros_like(val, dtype=torch.float32, device=x.device) + + +def init_inner_scale(tensor: torch.Tensor, _min: float = torch.inf, _max: float = torch.inf) -> torch.Tensor: + # Calculate the original minimum and maximum values + min_vals, max_vals = torch.aminmax(tensor, dim=-1) + x_min = torch.min(min_vals, torch.zeros_like(min_vals)) + x_max = torch.max(max_vals, torch.zeros_like(max_vals)) + + if _max is torch.inf: # We do not need to scale the tensor. 
Just need to move it + return torch.ones_like(x_min) + + # Calculate the scale factor + scale = (_max - _min) / (x_max - x_min) + return scale + + + +############## Quant ############### + +@torch.enable_grad() +def learn_parameters( + x: torch.Tensor, + params: Dict[str, nn.Parameter], + qtz_func: nn.Module, + deqtz_func: nn.Module, + bits: int, + target_dtype: torch.dtype, + epochs: int = 1000, + early_stop: bool = True, + do_report: bool = False +) -> Tuple[Dict[str, nn.Parameter], torch.Tensor]: + loss_fn = nn.MSELoss() + + # Determines the initial learning rate by computing the initial loss and multiplying it by + # the order of magnitude of the loss divided by 2 + quant = quantize(x, params, qtz_func, bits, target_dtype) + dequant = dequantize(quant, params, deqtz_func, bits, x.dtype) + loss = loss_fn(x, dequant) + + base_lr = 0.1 + exponent = int(np.floor(np.log10(loss.item()))) + lr = base_lr * (10 ** (exponent // 2)) + + # Requires gradients in the parameters + for p in params.values(): + p.requires_grad = True + p.grad = None + + param_keys = list(params.keys()) + param_values = list(params.values()) + + # Defines optimizer and loss function + optimizer = torch.optim.Adam(param_values, lr=lr) + scheduler = torch.optim.lr_scheduler.LinearLR(optimizer, start_factor=1.0, end_factor=0.01, total_iters=epochs // 10) + + # Contains the best loss and the best parameters + best_loss = float("inf") + best_params = None + + # Used to stop the search early + min_delta = 1e-7 + acc_loss = [] + percent_epochs_before_stop = 0.1 + + for i in range(epochs): + optimizer.zero_grad() + + quant = quantize(x, params, qtz_func, bits, target_dtype) + dequant = dequantize(quant, params, deqtz_func, bits, x.dtype) + loss = loss_fn(x, dequant) + + if loss.isnan() or loss.isinf(): + raise Exception("Loss is NaN or Inf. 
Stopping the search.") + + loss.backward() + optimizer.step() + scheduler.step() + + acc_loss.append(loss.item()) + + # Reports loss every 10 steps + if i % 10 == 0 and do_report: + print(f"Epoch {i}: Loss {loss.item()}") + + # Optimizes the parameter search by storing the best loss and the parameters + if loss.item() < best_loss: + best_loss = loss.item() + best_params = copy.deepcopy({ + k: v for k, v in params.items() if k in param_keys + }) + + # We also stop the search if the loss has not considerably during the last 10% epochs + if early_stop: + epochs_before_stop = int(epochs * percent_epochs_before_stop) + if i > epochs_before_stop and abs(acc_loss[i - epochs_before_stop] - acc_loss[i]) < min_delta: + break + + # No longer requires gradients in the parameters + for p in best_params.values(): + p.requires_grad = False + p.grad = None + + if do_report: + return best_params, acc_loss + else: + return best_params + + +def quantize( + x: torch.Tensor, + params: Dict[str, nn.Parameter], + func: nn.Module, + bits: int, + target_dtype: torch.dtype = torch.int8 +) -> torch.Tensor: + quant_min, quant_max = get_min_max_from_bits_signed(bits) + x = x.transpose(0, 1) # Aligns shapes + x = func(x=x, **params) + x = x.transpose(0, 1) + x = torch.clamp(round_func_BPDA(x), quant_min, quant_max).to(target_dtype) + return x + + +def dequantize( + x: torch.Tensor, + params: Dict[str, nn.Parameter], + func: nn.Module, + bits: int, + out_dtype: torch.dtype +) -> torch.Tensor: + x = x.to(dtype=out_dtype) + x = x.transpose(0, 1) + x = func(x=x, **params) + x = x.transpose(0, 1) + return x + + +def round_func_BPDA(input): + # This is equivalent to replacing round function (non-differentiable) with + # an identity function (differentiable) only when backward. + forward_value = torch.round(input) + out = input.clone() + out.data = forward_value.data + return out + + +def get_min_max_from_bits_signed(bit_width: int) -> Tuple[int, int]: + return -2 ** (bit_width - 1), 2 ** (bit_width - 1) - 1 + + + +############## Numpy ############### + +def np_domain_guard( + x: np.ndarray, + min: float = None, + max: float = None, + posinf: float = None, + neginf: float = None, + nan: float = None + ) -> np.ndarray: + """Guard a tensor to a valid domain.""" + x = np.nan_to_num(x, posinf=posinf, neginf=neginf, nan=nan) + if min is not None or max is not None: + x = np.clip(x, min, max) + return x + + +def np_replace_num(x: np.ndarray, num: float, to: float) -> np.ndarray: + """Replace a number in a tensor with another number. + + Args: + x (np.ndarray): The input tensor. + num (float): The number to replace. + to (float): The number to replace with. + + Returns: + np.ndarray: The tensor with the number replaced. 
+ """ + return np.where(x == num, to, x) + + +def np_guarded_power(x: np.ndarray, exp: float) -> np.ndarray: + """Guard the power operation to a valid domain.""" + return np.power(x, exp) if exp >= 1 else np.power(np.maximum(x, 0), exp) + diff --git a/fn_gen/rnd_search/1/loss.png b/fn_gen/rnd_search/1/loss.png new file mode 100644 index 0000000000000000000000000000000000000000..3be22b6ba72d1b59f158d4881d59c997b7484a04 Binary files /dev/null and b/fn_gen/rnd_search/1/loss.png differ diff --git a/fn_gen/rnd_search/1/quantization.png b/fn_gen/rnd_search/1/quantization.png new file mode 100644 index 0000000000000000000000000000000000000000..4afc345b5ae242a67d76ec1e9b4496042abcc4a7 Binary files /dev/null and b/fn_gen/rnd_search/1/quantization.png differ diff --git a/fn_gen/rnd_search/10/__pycache__/fn.cpython-311.pyc b/fn_gen/rnd_search/10/__pycache__/fn.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1e2fcaff981e2fa51449736bbc6ceabe2cd0ddd8 Binary files /dev/null and b/fn_gen/rnd_search/10/__pycache__/fn.cpython-311.pyc differ diff --git a/fn_gen/rnd_search/10/distortion.png b/fn_gen/rnd_search/10/distortion.png new file mode 100644 index 0000000000000000000000000000000000000000..a2ee371b1e8485c42a4c71d82570d92f84b5ccd0 Binary files /dev/null and b/fn_gen/rnd_search/10/distortion.png differ diff --git a/fn_gen/rnd_search/10/expressions.txt b/fn_gen/rnd_search/10/expressions.txt new file mode 100644 index 0000000000000000000000000000000000000000..dbb6da0fc54c6f23dc12daf2e2c3a395819e1bf4 --- /dev/null +++ b/fn_gen/rnd_search/10/expressions.txt @@ -0,0 +1,2 @@ +x**2/_s +sqrt(_s*x) \ No newline at end of file diff --git a/fn_gen/rnd_search/10/fn.py b/fn_gen/rnd_search/10/fn.py new file mode 100644 index 0000000000000000000000000000000000000000..bbd4b0e99826108ec1033cfc27c4ca2f39f05413 --- /dev/null +++ b/fn_gen/rnd_search/10/fn.py @@ -0,0 +1,483 @@ +from __future__ import annotations + +import torch +from torch import amin # Necessary for arcsin +import copy +import torch.nn as nn +import numpy as np + +from scipy.optimize import curve_fit +from typing import Dict, Any, Tuple, List, Callable + + +def quantization(x, **params): + return (torch.div(1, replace_num(params['_s'], num=0, to=10000)) * guarded_torch_power(x, torch.tensor(2))) + + +def dequantization(x, **params): + return torch.sqrt(domain_guard((params['_s'] * x), min=0.1, nan=0.1)) + + +def init_linear_scale( # Symmetric scale. From the study folder + x: torch.Tensor, + **kwargs: Dict[str, Any], + ) -> torch.Tensor: + assert "bits" in kwargs, "bits must be provided." + assert "params" in kwargs, "params must be provided." + assert "qtz_func" in kwargs, "qtz_func must be provided." 
+ + bits = kwargs.get('bits') + params = kwargs.get('params') + qtz_func = kwargs.get('qtz_func') + + x_ = x.transpose(0, 1) + x_ = qtz_func(x=x_, **params, _s=init_ones(x, **kwargs)) + x_ = x_.transpose(0, 1) + + quant_min, quant_max = get_min_max_from_bits_signed(bits) + min_vals, max_vals = torch.aminmax(x_, dim=1) + min_vals = torch.min(min_vals, torch.zeros_like(min_vals)) + max_vals = torch.max(max_vals, torch.zeros_like(max_vals)) + + eps = torch.finfo(torch.float32).eps + + abs_max_val_per_ch = torch.max(-min_vals, max_vals) + scale = abs_max_val_per_ch / (float(quant_max - quant_min) / 2) + + scale = torch.clamp(scale, min=eps).to(dtype=torch.float32, device=min_vals.device) + + # Introduces some noise in scale + # If I don't introduce noise, the accuracy is going to be 0.0 and not learn anything + # scale = scale + 0.01 * torch.randn_like(scale) + return scale + + +def init_params(x: torch.Tensor, **kwargs: Dict[str, Any]) -> Dict[str, nn.Parameter]: + params = { + } + params['_s'] = init_linear_scale(x, params=params, qtz_func=quantization, **kwargs) + params = {k: nn.Parameter(v, requires_grad=False) for k, v in params.items()} + + if 'post_init_hook' in kwargs: + kwargs['post_init_hook'](parameters=params) + + + if 'post_train_hook' in kwargs: + kwargs['post_train_hook'](parameters=params) + + return params + + +############### Numpy Qtz ############### + + +def np_quantization(x, _s): + return (np.divide(1, np_replace_num(_s, num=0, to=10000)) * np_guarded_power(x, np.array(2))) + + +def np_dequantization(x, _s): + return np.sqrt(np_domain_guard((_s * x), min=0.1, nan=0.1)) + + +def fit_func(x, _s): + x_ = np_quantization(x, _s) + x_ = np_dequantization(x_, _s) + return x_ + + + +############### HELPERS ############### + +def domain_guard( + x: torch.Tensor, + min: float = None, + max: float = None, + posinf: float = None, + neginf: float = None, + nan: float = None + ) -> torch.Tensor: + """Guard a tensor to a valid domain.""" + x = torch.nan_to_num(x, posinf=posinf, neginf=neginf, nan=nan) + if min is not None or max is not None: + x = torch.clamp(x, min=min, max=max) + return x + + +def replace_num(x: torch.Tensor, num: float, to: float) -> torch.Tensor: + """Replace a number in a tensor with another number. + + Args: + x (torch.Tensor): The input tensor. + num (float): The number to replace. + to (float): The number to replace with. + + Returns: + torch.Tensor: The tensor with the number replaced. + """ + return torch.where(x == num, to, x) + + +def guarded_torch_power(x: torch.Tensor, exp: float) -> torch.Tensor: + """Guard the power operation to a valid domain.""" + return torch.pow(x, exp) if exp >= 1 else torch.pow(torch.relu(x), exp) + + +def init_ones(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor: + val = torch.amin(x, dim=1) + return torch.ones_like(val, dtype=torch.float32, device=x.device) + + +def init_rand(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor: + val = torch.amin(x, dim=1) + return torch.randn_like(val, dtype=torch.float32, device=x.device) + + +def init_space_search( + x: torch.Tensor, + **kwargs: Dict[str, Any], + ) -> torch.Tensor: + + def _build_initial_param(tensor: torch.Tensor, max_initial: int, n_params: int): + """Generates the initial set of parameters. 
The first iteration generates 10 times more parameters.""" + for _ in range(n_params * 10): # The first iteration generates 10 times more parameters + yield init_rand(tensor) * max_initial # Generates n_params in range [-max_initial, max_initial] + + def _search_param(tensors: List[torch.tensor], n_params): + """Takes the best parameters and generates new parameters around the mean of the best parameters.""" + torch_tensors = torch.stack(tensors) + min_vals, max_vals = torch.aminmax(torch_tensors, dim=0) + abs_max_val_per_ch = torch.max(-min_vals, max_vals) + mean = torch.mean(torch_tensors, dim=0) + for _ in range(n_params): # Generates n_params around the mean of the tensors + yield torch.randn_like(min_vals) * abs_max_val_per_ch + mean + + def _calc(x, qtz_func, deqtz_func, **params): + x_ = x.transpose(0, 1) + x_ = qtz_func(x=x_, **params) + x_ = deqtz_func(x=x_, **params) + x_ = x_.transpose(0, 1) + return x_ + + assert "qtz_func" in kwargs, "qtz_func must be provided." + assert "deqtz_func" in kwargs, "deqtz_func must be provided." + assert "params_list" in kwargs, "params list must be provided." + assert "param" in kwargs, "param must be provided." + + qtz_func = kwargs.get('qtz_func') + deqtz_func = kwargs.get('deqtz_func') + params_list = kwargs.get('params_list') + param = kwargs.get('param') + + n_runs = 50 # Number of runs to try to find the best parameters + n_random_params = 50 # Number of random parameters to generate + n_best_to_pick = 5 # Number of best parameters to pick after each run + max_initial = 10000 # Maximum value to initialize the parameters + + # Initializes the parameters + base_params = { p: init_ones(x, **kwargs) for p in params_list if p != param } + params = _build_initial_param(x, max_initial, n_random_params) + + # Performs the search + for _ in range(n_runs): + + best_params = [] + for param_ in params: + try: + x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: param_}) + loss_ones = nn.MSELoss()(x, x_) + + if len(best_params) < n_best_to_pick: + best_params.append((param_, loss_ones.item())) + best_params = sorted(best_params, key=lambda x: x[1]) + elif loss_ones < best_params[-1][1]: + best_params[-1] = (param_, loss_ones.item()) + best_params = sorted(best_params, key=lambda x: x[1]) + + except Exception: # The parameters might not be valid for the function's domain + continue + + # Generates new parameters around the mean + params = _search_param([p for p, _ in best_params], n_random_params) + + return best_params[0][0] + + +def init_linear_scale( # Symmetric scale. From the study folder + x: torch.Tensor, + **kwargs: Dict[str, Any], + ) -> torch.Tensor: + assert "bits" in kwargs, "bits must be provided." + assert "params" in kwargs, "params must be provided." + assert "qtz_func" in kwargs, "qtz_func must be provided." 
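+    # The resulting scale is clamped to at least float32 eps, so a channel whose probed
+    # output is identically zero still receives a positive scale.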
+ + bits = kwargs.get('bits') + params = kwargs.get('params') + qtz_func = kwargs.get('qtz_func') + + x_ = x.transpose(0, 1) + x_ = qtz_func(x=x_, **params, _s=init_ones(x, **kwargs)) + x_ = x_.transpose(0, 1) + + quant_min, quant_max = get_min_max_from_bits_signed(bits) + min_vals, max_vals = torch.aminmax(x_, dim=1) + min_vals = torch.min(min_vals, torch.zeros_like(min_vals)) + max_vals = torch.max(max_vals, torch.zeros_like(max_vals)) + + eps = torch.finfo(torch.float32).eps + + abs_max_val_per_ch = torch.max(-min_vals, max_vals) + scale = abs_max_val_per_ch / (float(quant_max - quant_min) / 2) + + scale = torch.clamp(scale, min=eps).to(dtype=torch.float32, device=min_vals.device) + + # Introduces some noise in scale + # If I don't introduce noise, the accuracy is going to be 0.0 and not learn anything + # scale = scale + 0.01 * torch.randn_like(scale) + return scale + + +def init_non_linear_regression_fit( + x: torch.Tensor, + **kwargs: Dict[str, Any], + ) -> torch.Tensor: + + assert "params_list" in kwargs, "params list must be provided." + assert "np_fit_func" in kwargs, "np_fit_func must be provided." + assert "p0" in kwargs, "p0 must be provided." + np_fit_func = kwargs.get('np_fit_func') + params_list = kwargs.get('params_list') + p0 = kwargs.get('p0') + + def _fit(xdata: np.ndarray, ydata: np.ndarray, func: Callable, p0: List[float]): + popt, _ = curve_fit( + func, + xdata, + ydata, + maxfev=1000, + p0=p0, + method='lm' + ) + return popt + + # 1. Needs to convert the torch tensor to numpy tensor + xdata = x.cpu().numpy() + + # 2. Sorts the data so that it makes it easier to fit to it + sorted_xdata = np.sort(xdata, axis=-1) + + p0 = {k: v.cpu().numpy() for k, v in p0.items()} + params_list = sorted(params_list) # We need to make sure that it matches the numpy fit func arg order + + # 3. Finds the best parameters for each channel + try: + params = [] + for i in range(sorted_xdata.shape[0]): + xdata_ = sorted_xdata[i] + p0_ = [p0[p][i] for p in params_list] + ch_params = _fit(xdata_, xdata_, np_fit_func, p0_) + params.append(ch_params) + + # 4. Builds the parameters + result = {} + for i, p in enumerate(params_list): + result[p] = torch.tensor([p_[i] for p_ in params], dtype=torch.float32).to(x.device) + + return result + + except ValueError as e: + print(f"Could not fit the function with error: {e}") + print(f"Using fallback result...") + return { + k: torch.tensor(v, dtype=torch.float32).to(x.device) for k, v in p0.items() + } + + +def init_zeros(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor: + val = torch.amin(x, dim=1) + return torch.zeros_like(val, dtype=torch.float32, device=x.device) + + +def init_inner_scale(tensor: torch.Tensor, _min: float = torch.inf, _max: float = torch.inf) -> torch.Tensor: + # Calculate the original minimum and maximum values + min_vals, max_vals = torch.aminmax(tensor, dim=-1) + x_min = torch.min(min_vals, torch.zeros_like(min_vals)) + x_max = torch.max(max_vals, torch.zeros_like(max_vals)) + + if _max is torch.inf: # We do not need to scale the tensor. 
Just need to move it + return torch.ones_like(x_min) + + # Calculate the scale factor + scale = (_max - _min) / (x_max - x_min) + return scale + + + +############## Quant ############### + +@torch.enable_grad() +def learn_parameters( + x: torch.Tensor, + params: Dict[str, nn.Parameter], + qtz_func: nn.Module, + deqtz_func: nn.Module, + bits: int, + target_dtype: torch.dtype, + epochs: int = 1000, + early_stop: bool = True, + do_report: bool = False +) -> Tuple[Dict[str, nn.Parameter], torch.Tensor]: + loss_fn = nn.MSELoss() + + # Determines the initial learning rate by computing the initial loss and multiplying it by + # the order of magnitude of the loss divided by 2 + quant = quantize(x, params, qtz_func, bits, target_dtype) + dequant = dequantize(quant, params, deqtz_func, bits, x.dtype) + loss = loss_fn(x, dequant) + + base_lr = 0.1 + exponent = int(np.floor(np.log10(loss.item()))) + lr = base_lr * (10 ** (exponent // 2)) + + # Requires gradients in the parameters + for p in params.values(): + p.requires_grad = True + p.grad = None + + param_keys = list(params.keys()) + param_values = list(params.values()) + + # Defines optimizer and loss function + optimizer = torch.optim.Adam(param_values, lr=lr) + scheduler = torch.optim.lr_scheduler.LinearLR(optimizer, start_factor=1.0, end_factor=0.01, total_iters=epochs // 10) + + # Contains the best loss and the best parameters + best_loss = float("inf") + best_params = None + + # Used to stop the search early + min_delta = 1e-7 + acc_loss = [] + percent_epochs_before_stop = 0.1 + + for i in range(epochs): + optimizer.zero_grad() + + quant = quantize(x, params, qtz_func, bits, target_dtype) + dequant = dequantize(quant, params, deqtz_func, bits, x.dtype) + loss = loss_fn(x, dequant) + + if loss.isnan() or loss.isinf(): + raise Exception("Loss is NaN or Inf. 
Stopping the search.") + + loss.backward() + optimizer.step() + scheduler.step() + + acc_loss.append(loss.item()) + + # Reports loss every 10 steps + if i % 10 == 0 and do_report: + print(f"Epoch {i}: Loss {loss.item()}") + + # Optimizes the parameter search by storing the best loss and the parameters + if loss.item() < best_loss: + best_loss = loss.item() + best_params = copy.deepcopy({ + k: v for k, v in params.items() if k in param_keys + }) + + # We also stop the search if the loss has not considerably during the last 10% epochs + if early_stop: + epochs_before_stop = int(epochs * percent_epochs_before_stop) + if i > epochs_before_stop and abs(acc_loss[i - epochs_before_stop] - acc_loss[i]) < min_delta: + break + + # No longer requires gradients in the parameters + for p in best_params.values(): + p.requires_grad = False + p.grad = None + + if do_report: + return best_params, acc_loss + else: + return best_params + + +def quantize( + x: torch.Tensor, + params: Dict[str, nn.Parameter], + func: nn.Module, + bits: int, + target_dtype: torch.dtype = torch.int8 +) -> torch.Tensor: + quant_min, quant_max = get_min_max_from_bits_signed(bits) + x = x.transpose(0, 1) # Aligns shapes + x = func(x=x, **params) + x = x.transpose(0, 1) + x = torch.clamp(round_func_BPDA(x), quant_min, quant_max).to(target_dtype) + return x + + +def dequantize( + x: torch.Tensor, + params: Dict[str, nn.Parameter], + func: nn.Module, + bits: int, + out_dtype: torch.dtype +) -> torch.Tensor: + x = x.to(dtype=out_dtype) + x = x.transpose(0, 1) + x = func(x=x, **params) + x = x.transpose(0, 1) + return x + + +def round_func_BPDA(input): + # This is equivalent to replacing round function (non-differentiable) with + # an identity function (differentiable) only when backward. + forward_value = torch.round(input) + out = input.clone() + out.data = forward_value.data + return out + + +def get_min_max_from_bits_signed(bit_width: int) -> Tuple[int, int]: + return -2 ** (bit_width - 1), 2 ** (bit_width - 1) - 1 + + + +############## Numpy ############### + +def np_domain_guard( + x: np.ndarray, + min: float = None, + max: float = None, + posinf: float = None, + neginf: float = None, + nan: float = None + ) -> np.ndarray: + """Guard a tensor to a valid domain.""" + x = np.nan_to_num(x, posinf=posinf, neginf=neginf, nan=nan) + if min is not None or max is not None: + x = np.clip(x, min, max) + return x + + +def np_replace_num(x: np.ndarray, num: float, to: float) -> np.ndarray: + """Replace a number in a tensor with another number. + + Args: + x (np.ndarray): The input tensor. + num (float): The number to replace. + to (float): The number to replace with. + + Returns: + np.ndarray: The tensor with the number replaced. 
+ """ + return np.where(x == num, to, x) + + +def np_guarded_power(x: np.ndarray, exp: float) -> np.ndarray: + """Guard the power operation to a valid domain.""" + return np.power(x, exp) if exp >= 1 else np.power(np.maximum(x, 0), exp) + diff --git a/fn_gen/rnd_search/10/loss.png b/fn_gen/rnd_search/10/loss.png new file mode 100644 index 0000000000000000000000000000000000000000..73c843742733c2df6c34f4584ec4646db482f942 Binary files /dev/null and b/fn_gen/rnd_search/10/loss.png differ diff --git a/fn_gen/rnd_search/10/quantization.png b/fn_gen/rnd_search/10/quantization.png new file mode 100644 index 0000000000000000000000000000000000000000..114c44936bf8e8e42acc1426bb6a9100f9f30d25 Binary files /dev/null and b/fn_gen/rnd_search/10/quantization.png differ diff --git a/fn_gen/rnd_search/11/__pycache__/fn.cpython-311.pyc b/fn_gen/rnd_search/11/__pycache__/fn.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..09e83e3e04308e53067b91100e682e8d725e0038 Binary files /dev/null and b/fn_gen/rnd_search/11/__pycache__/fn.cpython-311.pyc differ diff --git a/fn_gen/rnd_search/11/distortion.png b/fn_gen/rnd_search/11/distortion.png new file mode 100644 index 0000000000000000000000000000000000000000..f1e312462bbee465ad88250d1e384a3660807940 Binary files /dev/null and b/fn_gen/rnd_search/11/distortion.png differ diff --git a/fn_gen/rnd_search/11/expressions.txt b/fn_gen/rnd_search/11/expressions.txt new file mode 100644 index 0000000000000000000000000000000000000000..c545adce8b3c320e195336b81461c79d0cc385e6 --- /dev/null +++ b/fn_gen/rnd_search/11/expressions.txt @@ -0,0 +1,2 @@ +asinh(_0*x)/_s +sinh(_s*x)/_0 \ No newline at end of file diff --git a/fn_gen/rnd_search/11/fn.py b/fn_gen/rnd_search/11/fn.py new file mode 100644 index 0000000000000000000000000000000000000000..f1393ffef69d7149a761c88434afb99b54b35198 --- /dev/null +++ b/fn_gen/rnd_search/11/fn.py @@ -0,0 +1,554 @@ +from __future__ import annotations + +import torch +from torch import amin # Necessary for arcsin +import copy +import torch.nn as nn +import numpy as np + +from scipy.optimize import curve_fit +from typing import Dict, Any, Tuple, List, Callable + + +def quantization(x, **params): + return (torch.div(1, replace_num(params['_s'], num=0, to=10000)) * torch.asinh((params['_0'] * x))) + + +def dequantization(x, **params): + return (torch.div(1, replace_num(params['_0'], num=0, to=10000)) * torch.sinh((params['_s'] * x))) + + +def init_space_search( + x: torch.Tensor, + **kwargs: Dict[str, Any], + ) -> torch.Tensor: + + def _build_initial_param(tensor: torch.Tensor, max_initial: int, n_params: int): + """Generates the initial set of parameters. 
The first iteration generates 10 times more parameters.""" + for _ in range(n_params * 10): # The first iteration generates 10 times more parameters + yield init_rand(tensor) * max_initial # Generates n_params in range [-max_initial, max_initial] + + def _search_param(tensors: List[torch.tensor], n_params): + """Takes the best parameters and generates new parameters around the mean of the best parameters.""" + torch_tensors = torch.stack(tensors) + min_vals, max_vals = torch.aminmax(torch_tensors, dim=0) + abs_max_val_per_ch = torch.max(-min_vals, max_vals) + mean = torch.mean(torch_tensors, dim=0) + for _ in range(n_params): # Generates n_params around the mean of the tensors + yield torch.randn_like(min_vals) * abs_max_val_per_ch + mean + + def _calc(x, qtz_func, deqtz_func, **params): + x_ = x.transpose(0, 1) + x_ = qtz_func(x=x_, **params) + x_ = deqtz_func(x=x_, **params) + x_ = x_.transpose(0, 1) + return x_ + + assert "qtz_func" in kwargs, "qtz_func must be provided." + assert "deqtz_func" in kwargs, "deqtz_func must be provided." + assert "params_list" in kwargs, "params list must be provided." + assert "param" in kwargs, "param must be provided." + + qtz_func = kwargs.get('qtz_func') + deqtz_func = kwargs.get('deqtz_func') + params_list = kwargs.get('params_list') + param = kwargs.get('param') + + n_runs = 50 # Number of runs to try to find the best parameters + n_random_params = 50 # Number of random parameters to generate + n_best_to_pick = 5 # Number of best parameters to pick after each run + max_initial = 10000 # Maximum value to initialize the parameters + + # Initializes the parameters + base_params = { p: init_ones(x, **kwargs) for p in params_list if p != param } + params = _build_initial_param(x, max_initial, n_random_params) + + # Performs the search + for _ in range(n_runs): + + best_params = [] + for param_ in params: + try: + x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: param_}) + loss_ones = nn.MSELoss()(x, x_) + + if len(best_params) < n_best_to_pick: + best_params.append((param_, loss_ones.item())) + best_params = sorted(best_params, key=lambda x: x[1]) + elif loss_ones < best_params[-1][1]: + best_params[-1] = (param_, loss_ones.item()) + best_params = sorted(best_params, key=lambda x: x[1]) + + except Exception: # The parameters might not be valid for the function's domain + continue + + # Generates new parameters around the mean + params = _search_param([p for p, _ in best_params], n_random_params) + + return best_params[0][0] + + +def init_linear_scale( # Symmetric scale. From the study folder + x: torch.Tensor, + **kwargs: Dict[str, Any], + ) -> torch.Tensor: + assert "bits" in kwargs, "bits must be provided." + assert "params" in kwargs, "params must be provided." + assert "qtz_func" in kwargs, "qtz_func must be provided." 
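+ # What follows: the quantization forward pass is evaluated with _s = 1, and a
+ # per-channel symmetric (abs-max) scale is derived from its output range:
+ #   scale_ch = max(|min_ch|, |max_ch|) / ((quant_max - quant_min) / 2)
+ # Illustrative numbers only: for bits=8 the signed range is [-128, 127], so the
+ # divisor is 127.5, and a channel spanning [-3.2, 5.1] gets 5.1 / 127.5 = 0.04.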
+ + bits = kwargs.get('bits') + params = kwargs.get('params') + qtz_func = kwargs.get('qtz_func') + + x_ = x.transpose(0, 1) + x_ = qtz_func(x=x_, **params, _s=init_ones(x, **kwargs)) + x_ = x_.transpose(0, 1) + + quant_min, quant_max = get_min_max_from_bits_signed(bits) + min_vals, max_vals = torch.aminmax(x_, dim=1) + min_vals = torch.min(min_vals, torch.zeros_like(min_vals)) + max_vals = torch.max(max_vals, torch.zeros_like(max_vals)) + + eps = torch.finfo(torch.float32).eps + + abs_max_val_per_ch = torch.max(-min_vals, max_vals) + scale = abs_max_val_per_ch / (float(quant_max - quant_min) / 2) + + scale = torch.clamp(scale, min=eps).to(dtype=torch.float32, device=min_vals.device) + + # Introduces some noise in scale + # If I don't introduce noise, the accuracy is going to be 0.0 and not learn anything + # scale = scale + 0.01 * torch.randn_like(scale) + return scale + + +def init_params(x: torch.Tensor, **kwargs: Dict[str, Any]) -> Dict[str, nn.Parameter]: + params = { + '_0': init_space_search(x, qtz_func=quantization, deqtz_func=dequantization, param='_0', params_list=['_0', '_s'], **kwargs), + } + params['_s'] = init_linear_scale(x, params=params, qtz_func=quantization, **kwargs) + params = {k: nn.Parameter(v, requires_grad=False) for k, v in params.items()} + + if 'post_init_hook' in kwargs: + kwargs['post_init_hook'](parameters=params) + + + if 'post_train_hook' in kwargs: + kwargs['post_train_hook'](parameters=params) + + return params + + +############### Numpy Qtz ############### + + +def np_quantization(x, _0, _s): + return (np.divide(1, np_replace_num(_s, num=0, to=10000)) * np.arcsinh((_0 * x))) + + +def np_dequantization(x, _0, _s): + return (np.divide(1, np_replace_num(_0, num=0, to=10000)) * np.sinh((_s * x))) + + +def fit_func(x, _0, _s): + x_ = np_quantization(x, _0, _s) + x_ = np_dequantization(x_, _0, _s) + return x_ + + + +############### HELPERS ############### + +def domain_guard( + x: torch.Tensor, + min: float = None, + max: float = None, + posinf: float = None, + neginf: float = None, + nan: float = None + ) -> torch.Tensor: + """Guard a tensor to a valid domain.""" + x = torch.nan_to_num(x, posinf=posinf, neginf=neginf, nan=nan) + if min is not None or max is not None: + x = torch.clamp(x, min=min, max=max) + return x + + +def replace_num(x: torch.Tensor, num: float, to: float) -> torch.Tensor: + """Replace a number in a tensor with another number. + + Args: + x (torch.Tensor): The input tensor. + num (float): The number to replace. + to (float): The number to replace with. + + Returns: + torch.Tensor: The tensor with the number replaced. + """ + return torch.where(x == num, to, x) + + +def guarded_torch_power(x: torch.Tensor, exp: float) -> torch.Tensor: + """Guard the power operation to a valid domain.""" + return torch.pow(x, exp) if exp >= 1 else torch.pow(torch.relu(x), exp) + + +def init_ones(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor: + val = torch.amin(x, dim=1) + return torch.ones_like(val, dtype=torch.float32, device=x.device) + + +def init_rand(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor: + val = torch.amin(x, dim=1) + return torch.randn_like(val, dtype=torch.float32, device=x.device) + + +def init_space_search( + x: torch.Tensor, + **kwargs: Dict[str, Any], + ) -> torch.Tensor: + + def _build_initial_param(tensor: torch.Tensor, max_initial: int, n_params: int): + """Generates the initial set of parameters. 
The first iteration generates 10 times more parameters.""" + for _ in range(n_params * 10): # The first iteration generates 10 times more parameters + yield init_rand(tensor) * max_initial # Generates n_params in range [-max_initial, max_initial] + + def _search_param(tensors: List[torch.tensor], n_params): + """Takes the best parameters and generates new parameters around the mean of the best parameters.""" + torch_tensors = torch.stack(tensors) + min_vals, max_vals = torch.aminmax(torch_tensors, dim=0) + abs_max_val_per_ch = torch.max(-min_vals, max_vals) + mean = torch.mean(torch_tensors, dim=0) + for _ in range(n_params): # Generates n_params around the mean of the tensors + yield torch.randn_like(min_vals) * abs_max_val_per_ch + mean + + def _calc(x, qtz_func, deqtz_func, **params): + x_ = x.transpose(0, 1) + x_ = qtz_func(x=x_, **params) + x_ = deqtz_func(x=x_, **params) + x_ = x_.transpose(0, 1) + return x_ + + assert "qtz_func" in kwargs, "qtz_func must be provided." + assert "deqtz_func" in kwargs, "deqtz_func must be provided." + assert "params_list" in kwargs, "params list must be provided." + assert "param" in kwargs, "param must be provided." + + qtz_func = kwargs.get('qtz_func') + deqtz_func = kwargs.get('deqtz_func') + params_list = kwargs.get('params_list') + param = kwargs.get('param') + + n_runs = 50 # Number of runs to try to find the best parameters + n_random_params = 50 # Number of random parameters to generate + n_best_to_pick = 5 # Number of best parameters to pick after each run + max_initial = 10000 # Maximum value to initialize the parameters + + # Initializes the parameters + base_params = { p: init_ones(x, **kwargs) for p in params_list if p != param } + params = _build_initial_param(x, max_initial, n_random_params) + + # Performs the search + for _ in range(n_runs): + + best_params = [] + for param_ in params: + try: + x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: param_}) + loss_ones = nn.MSELoss()(x, x_) + + if len(best_params) < n_best_to_pick: + best_params.append((param_, loss_ones.item())) + best_params = sorted(best_params, key=lambda x: x[1]) + elif loss_ones < best_params[-1][1]: + best_params[-1] = (param_, loss_ones.item()) + best_params = sorted(best_params, key=lambda x: x[1]) + + except Exception: # The parameters might not be valid for the function's domain + continue + + # Generates new parameters around the mean + params = _search_param([p for p, _ in best_params], n_random_params) + + return best_params[0][0] + + +def init_linear_scale( # Symmetric scale. From the study folder + x: torch.Tensor, + **kwargs: Dict[str, Any], + ) -> torch.Tensor: + assert "bits" in kwargs, "bits must be provided." + assert "params" in kwargs, "params must be provided." + assert "qtz_func" in kwargs, "qtz_func must be provided." 
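+ # Note: after the transpose round-trip below, torch.aminmax(x_, dim=1) reduces
+ # over the last dimension, so min/max (and the resulting scale) are computed per
+ # channel along dim 0; both extremes are clamped to include 0, keeping the scale
+ # symmetric around zero.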
+ + bits = kwargs.get('bits') + params = kwargs.get('params') + qtz_func = kwargs.get('qtz_func') + + x_ = x.transpose(0, 1) + x_ = qtz_func(x=x_, **params, _s=init_ones(x, **kwargs)) + x_ = x_.transpose(0, 1) + + quant_min, quant_max = get_min_max_from_bits_signed(bits) + min_vals, max_vals = torch.aminmax(x_, dim=1) + min_vals = torch.min(min_vals, torch.zeros_like(min_vals)) + max_vals = torch.max(max_vals, torch.zeros_like(max_vals)) + + eps = torch.finfo(torch.float32).eps + + abs_max_val_per_ch = torch.max(-min_vals, max_vals) + scale = abs_max_val_per_ch / (float(quant_max - quant_min) / 2) + + scale = torch.clamp(scale, min=eps).to(dtype=torch.float32, device=min_vals.device) + + # Introduces some noise in scale + # If I don't introduce noise, the accuracy is going to be 0.0 and not learn anything + # scale = scale + 0.01 * torch.randn_like(scale) + return scale + + +def init_non_linear_regression_fit( + x: torch.Tensor, + **kwargs: Dict[str, Any], + ) -> torch.Tensor: + + assert "params_list" in kwargs, "params list must be provided." + assert "np_fit_func" in kwargs, "np_fit_func must be provided." + assert "p0" in kwargs, "p0 must be provided." + np_fit_func = kwargs.get('np_fit_func') + params_list = kwargs.get('params_list') + p0 = kwargs.get('p0') + + def _fit(xdata: np.ndarray, ydata: np.ndarray, func: Callable, p0: List[float]): + popt, _ = curve_fit( + func, + xdata, + ydata, + maxfev=1000, + p0=p0, + method='lm' + ) + return popt + + # 1. Needs to convert the torch tensor to numpy tensor + xdata = x.cpu().numpy() + + # 2. Sorts the data so that it makes it easier to fit to it + sorted_xdata = np.sort(xdata, axis=-1) + + p0 = {k: v.cpu().numpy() for k, v in p0.items()} + params_list = sorted(params_list) # We need to make sure that it matches the numpy fit func arg order + + # 3. Finds the best parameters for each channel + try: + params = [] + for i in range(sorted_xdata.shape[0]): + xdata_ = sorted_xdata[i] + p0_ = [p0[p][i] for p in params_list] + ch_params = _fit(xdata_, xdata_, np_fit_func, p0_) + params.append(ch_params) + + # 4. Builds the parameters + result = {} + for i, p in enumerate(params_list): + result[p] = torch.tensor([p_[i] for p_ in params], dtype=torch.float32).to(x.device) + + return result + + except ValueError as e: + print(f"Could not fit the function with error: {e}") + print(f"Using fallback result...") + return { + k: torch.tensor(v, dtype=torch.float32).to(x.device) for k, v in p0.items() + } + + +def init_zeros(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor: + val = torch.amin(x, dim=1) + return torch.zeros_like(val, dtype=torch.float32, device=x.device) + + +def init_inner_scale(tensor: torch.Tensor, _min: float = torch.inf, _max: float = torch.inf) -> torch.Tensor: + # Calculate the original minimum and maximum values + min_vals, max_vals = torch.aminmax(tensor, dim=-1) + x_min = torch.min(min_vals, torch.zeros_like(min_vals)) + x_max = torch.max(max_vals, torch.zeros_like(max_vals)) + + if _max is torch.inf: # We do not need to scale the tensor. 
Just need to move it + return torch.ones_like(x_min) + + # Calculate the scale factor + scale = (_max - _min) / (x_max - x_min) + return scale + + + +############## Quant ############### + +@torch.enable_grad() +def learn_parameters( + x: torch.Tensor, + params: Dict[str, nn.Parameter], + qtz_func: nn.Module, + deqtz_func: nn.Module, + bits: int, + target_dtype: torch.dtype, + epochs: int = 1000, + early_stop: bool = True, + do_report: bool = False +) -> Tuple[Dict[str, nn.Parameter], torch.Tensor]: + loss_fn = nn.MSELoss() + + # Determines the initial learning rate by computing the initial loss and multiplying it by + # the order of magnitude of the loss divided by 2 + quant = quantize(x, params, qtz_func, bits, target_dtype) + dequant = dequantize(quant, params, deqtz_func, bits, x.dtype) + loss = loss_fn(x, dequant) + + base_lr = 0.1 + exponent = int(np.floor(np.log10(loss.item()))) + lr = base_lr * (10 ** (exponent // 2)) + + # Requires gradients in the parameters + for p in params.values(): + p.requires_grad = True + p.grad = None + + param_keys = list(params.keys()) + param_values = list(params.values()) + + # Defines optimizer and loss function + optimizer = torch.optim.Adam(param_values, lr=lr) + scheduler = torch.optim.lr_scheduler.LinearLR(optimizer, start_factor=1.0, end_factor=0.01, total_iters=epochs // 10) + + # Contains the best loss and the best parameters + best_loss = float("inf") + best_params = None + + # Used to stop the search early + min_delta = 1e-7 + acc_loss = [] + percent_epochs_before_stop = 0.1 + + for i in range(epochs): + optimizer.zero_grad() + + quant = quantize(x, params, qtz_func, bits, target_dtype) + dequant = dequantize(quant, params, deqtz_func, bits, x.dtype) + loss = loss_fn(x, dequant) + + if loss.isnan() or loss.isinf(): + raise Exception("Loss is NaN or Inf. 
Stopping the search.") + + loss.backward() + optimizer.step() + scheduler.step() + + acc_loss.append(loss.item()) + + # Reports loss every 10 steps + if i % 10 == 0 and do_report: + print(f"Epoch {i}: Loss {loss.item()}") + + # Keeps track of the best loss and the corresponding parameters + if loss.item() < best_loss: + best_loss = loss.item() + best_params = copy.deepcopy({ + k: v for k, v in params.items() if k in param_keys + }) + + # We also stop the search if the loss has not improved considerably during the last 10% of epochs + if early_stop: + epochs_before_stop = int(epochs * percent_epochs_before_stop) + if i > epochs_before_stop and abs(acc_loss[i - epochs_before_stop] - acc_loss[i]) < min_delta: + break + + # No longer requires gradients in the parameters + for p in best_params.values(): + p.requires_grad = False + p.grad = None + + if do_report: + return best_params, acc_loss + else: + return best_params + + +def quantize( + x: torch.Tensor, + params: Dict[str, nn.Parameter], + func: nn.Module, + bits: int, + target_dtype: torch.dtype = torch.int8 +) -> torch.Tensor: + quant_min, quant_max = get_min_max_from_bits_signed(bits) + x = x.transpose(0, 1) # Aligns shapes + x = func(x=x, **params) + x = x.transpose(0, 1) + x = torch.clamp(round_func_BPDA(x), quant_min, quant_max).to(target_dtype) + return x + + +def dequantize( + x: torch.Tensor, + params: Dict[str, nn.Parameter], + func: nn.Module, + bits: int, + out_dtype: torch.dtype +) -> torch.Tensor: + x = x.to(dtype=out_dtype) + x = x.transpose(0, 1) + x = func(x=x, **params) + x = x.transpose(0, 1) + return x + + +def round_func_BPDA(input): + # This is equivalent to replacing the round function (non-differentiable) with + # an identity function (differentiable) in the backward pass only. + forward_value = torch.round(input) + out = input.clone() + out.data = forward_value.data + return out + + +def get_min_max_from_bits_signed(bit_width: int) -> Tuple[int, int]: + return -2 ** (bit_width - 1), 2 ** (bit_width - 1) - 1 + + + +############## Numpy ############### + +def np_domain_guard( + x: np.ndarray, + min: float = None, + max: float = None, + posinf: float = None, + neginf: float = None, + nan: float = None + ) -> np.ndarray: + """Guard a tensor to a valid domain.""" + x = np.nan_to_num(x, posinf=posinf, neginf=neginf, nan=nan) + if min is not None or max is not None: + x = np.clip(x, min, max) + return x + + +def np_replace_num(x: np.ndarray, num: float, to: float) -> np.ndarray: + """Replace a number in a tensor with another number. + + Args: + x (np.ndarray): The input tensor. + num (float): The number to replace. + to (float): The number to replace with. + + Returns: + np.ndarray: The tensor with the number replaced.
+ """ + return np.where(x == num, to, x) + + +def np_guarded_power(x: np.ndarray, exp: float) -> np.ndarray: + """Guard the power operation to a valid domain.""" + return np.power(x, exp) if exp >= 1 else np.power(np.maximum(x, 0), exp) + diff --git a/fn_gen/rnd_search/11/loss.png b/fn_gen/rnd_search/11/loss.png new file mode 100644 index 0000000000000000000000000000000000000000..1882bb6f12f2482005987124cf62687608bd16c4 Binary files /dev/null and b/fn_gen/rnd_search/11/loss.png differ diff --git a/fn_gen/rnd_search/11/quantization.png b/fn_gen/rnd_search/11/quantization.png new file mode 100644 index 0000000000000000000000000000000000000000..2889123e850d96ec43c300ccd0a824736a445119 Binary files /dev/null and b/fn_gen/rnd_search/11/quantization.png differ diff --git a/fn_gen/rnd_search/12/__pycache__/fn.cpython-311.pyc b/fn_gen/rnd_search/12/__pycache__/fn.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..696f9b33c00cded192cd6ae14ad3bdce9cf1415e Binary files /dev/null and b/fn_gen/rnd_search/12/__pycache__/fn.cpython-311.pyc differ diff --git a/fn_gen/rnd_search/12/distortion.png b/fn_gen/rnd_search/12/distortion.png new file mode 100644 index 0000000000000000000000000000000000000000..02ff798008cc0efe82e54990612d48a6e423d74a Binary files /dev/null and b/fn_gen/rnd_search/12/distortion.png differ diff --git a/fn_gen/rnd_search/12/expressions.txt b/fn_gen/rnd_search/12/expressions.txt new file mode 100644 index 0000000000000000000000000000000000000000..576ec6a351e26f9982eb17e394804ca906d4b067 --- /dev/null +++ b/fn_gen/rnd_search/12/expressions.txt @@ -0,0 +1,2 @@ +acos(_0*x)/_s +cos(_s*x)/_0 \ No newline at end of file diff --git a/fn_gen/rnd_search/12/fn.py b/fn_gen/rnd_search/12/fn.py new file mode 100644 index 0000000000000000000000000000000000000000..b330b1fb4fc659c93e66384d22497987df92f824 --- /dev/null +++ b/fn_gen/rnd_search/12/fn.py @@ -0,0 +1,554 @@ +from __future__ import annotations + +import torch +from torch import amin # Necessary for arcsin +import copy +import torch.nn as nn +import numpy as np + +from scipy.optimize import curve_fit +from typing import Dict, Any, Tuple, List, Callable + + +def quantization(x, **params): + return (torch.div(1, replace_num(params['_s'], num=0, to=10000)) * torch.acos(domain_guard((params['_0'] * x), min=-0.99999, max=0.99999, nan=0))) + + +def dequantization(x, **params): + return (torch.div(1, replace_num(params['_0'], num=0, to=10000)) * torch.cos((params['_s'] * x))) + + +def init_space_search( + x: torch.Tensor, + **kwargs: Dict[str, Any], + ) -> torch.Tensor: + + def _build_initial_param(tensor: torch.Tensor, max_initial: int, n_params: int): + """Generates the initial set of parameters. 
The first iteration generates 10 times more parameters.""" + for _ in range(n_params * 10): # The first iteration generates 10 times more parameters + yield init_rand(tensor) * max_initial # Generates n_params in range [-max_initial, max_initial] + + def _search_param(tensors: List[torch.tensor], n_params): + """Takes the best parameters and generates new parameters around the mean of the best parameters.""" + torch_tensors = torch.stack(tensors) + min_vals, max_vals = torch.aminmax(torch_tensors, dim=0) + abs_max_val_per_ch = torch.max(-min_vals, max_vals) + mean = torch.mean(torch_tensors, dim=0) + for _ in range(n_params): # Generates n_params around the mean of the tensors + yield torch.randn_like(min_vals) * abs_max_val_per_ch + mean + + def _calc(x, qtz_func, deqtz_func, **params): + x_ = x.transpose(0, 1) + x_ = qtz_func(x=x_, **params) + x_ = deqtz_func(x=x_, **params) + x_ = x_.transpose(0, 1) + return x_ + + assert "qtz_func" in kwargs, "qtz_func must be provided." + assert "deqtz_func" in kwargs, "deqtz_func must be provided." + assert "params_list" in kwargs, "params list must be provided." + assert "param" in kwargs, "param must be provided." + + qtz_func = kwargs.get('qtz_func') + deqtz_func = kwargs.get('deqtz_func') + params_list = kwargs.get('params_list') + param = kwargs.get('param') + + n_runs = 50 # Number of runs to try to find the best parameters + n_random_params = 50 # Number of random parameters to generate + n_best_to_pick = 5 # Number of best parameters to pick after each run + max_initial = 10000 # Maximum value to initialize the parameters + + # Initializes the parameters + base_params = { p: init_ones(x, **kwargs) for p in params_list if p != param } + params = _build_initial_param(x, max_initial, n_random_params) + + # Performs the search + for _ in range(n_runs): + + best_params = [] + for param_ in params: + try: + x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: param_}) + loss_ones = nn.MSELoss()(x, x_) + + if len(best_params) < n_best_to_pick: + best_params.append((param_, loss_ones.item())) + best_params = sorted(best_params, key=lambda x: x[1]) + elif loss_ones < best_params[-1][1]: + best_params[-1] = (param_, loss_ones.item()) + best_params = sorted(best_params, key=lambda x: x[1]) + + except Exception: # The parameters might not be valid for the function's domain + continue + + # Generates new parameters around the mean + params = _search_param([p for p, _ in best_params], n_random_params) + + return best_params[0][0] + + +def init_linear_scale( # Symmetric scale. From the study folder + x: torch.Tensor, + **kwargs: Dict[str, Any], + ) -> torch.Tensor: + assert "bits" in kwargs, "bits must be provided." + assert "params" in kwargs, "params must be provided." + assert "qtz_func" in kwargs, "qtz_func must be provided." 
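+ # The scale computed below is clamped to torch.finfo(torch.float32).eps, so a
+ # channel whose quantized output is identically zero still receives a small
+ # positive scale rather than collapsing to 0.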
+ + bits = kwargs.get('bits') + params = kwargs.get('params') + qtz_func = kwargs.get('qtz_func') + + x_ = x.transpose(0, 1) + x_ = qtz_func(x=x_, **params, _s=init_ones(x, **kwargs)) + x_ = x_.transpose(0, 1) + + quant_min, quant_max = get_min_max_from_bits_signed(bits) + min_vals, max_vals = torch.aminmax(x_, dim=1) + min_vals = torch.min(min_vals, torch.zeros_like(min_vals)) + max_vals = torch.max(max_vals, torch.zeros_like(max_vals)) + + eps = torch.finfo(torch.float32).eps + + abs_max_val_per_ch = torch.max(-min_vals, max_vals) + scale = abs_max_val_per_ch / (float(quant_max - quant_min) / 2) + + scale = torch.clamp(scale, min=eps).to(dtype=torch.float32, device=min_vals.device) + + # Introduces some noise in scale + # If I don't introduce noise, the accuracy is going to be 0.0 and not learn anything + # scale = scale + 0.01 * torch.randn_like(scale) + return scale + + +def init_params(x: torch.Tensor, **kwargs: Dict[str, Any]) -> Dict[str, nn.Parameter]: + params = { + '_0': init_space_search(x, qtz_func=quantization, deqtz_func=dequantization, param='_0', params_list=['_0', '_s'], **kwargs), + } + params['_s'] = init_linear_scale(x, params=params, qtz_func=quantization, **kwargs) + params = {k: nn.Parameter(v, requires_grad=False) for k, v in params.items()} + + if 'post_init_hook' in kwargs: + kwargs['post_init_hook'](parameters=params) + + + if 'post_train_hook' in kwargs: + kwargs['post_train_hook'](parameters=params) + + return params + + +############### Numpy Qtz ############### + + +def np_quantization(x, _0, _s): + return (np.divide(1, np_replace_num(_s, num=0, to=10000)) * np.arccos(np_domain_guard((_0 * x), min=-0.99999, max=0.99999, nan=0))) + + +def np_dequantization(x, _0, _s): + return (np.divide(1, np_replace_num(_0, num=0, to=10000)) * np.cos((_s * x))) + + +def fit_func(x, _0, _s): + x_ = np_quantization(x, _0, _s) + x_ = np_dequantization(x_, _0, _s) + return x_ + + + +############### HELPERS ############### + +def domain_guard( + x: torch.Tensor, + min: float = None, + max: float = None, + posinf: float = None, + neginf: float = None, + nan: float = None + ) -> torch.Tensor: + """Guard a tensor to a valid domain.""" + x = torch.nan_to_num(x, posinf=posinf, neginf=neginf, nan=nan) + if min is not None or max is not None: + x = torch.clamp(x, min=min, max=max) + return x + + +def replace_num(x: torch.Tensor, num: float, to: float) -> torch.Tensor: + """Replace a number in a tensor with another number. + + Args: + x (torch.Tensor): The input tensor. + num (float): The number to replace. + to (float): The number to replace with. + + Returns: + torch.Tensor: The tensor with the number replaced. + """ + return torch.where(x == num, to, x) + + +def guarded_torch_power(x: torch.Tensor, exp: float) -> torch.Tensor: + """Guard the power operation to a valid domain.""" + return torch.pow(x, exp) if exp >= 1 else torch.pow(torch.relu(x), exp) + + +def init_ones(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor: + val = torch.amin(x, dim=1) + return torch.ones_like(val, dtype=torch.float32, device=x.device) + + +def init_rand(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor: + val = torch.amin(x, dim=1) + return torch.randn_like(val, dtype=torch.float32, device=x.device) + + +def init_space_search( + x: torch.Tensor, + **kwargs: Dict[str, Any], + ) -> torch.Tensor: + + def _build_initial_param(tensor: torch.Tensor, max_initial: int, n_params: int): + """Generates the initial set of parameters. 
The first iteration generates 10 times more parameters.""" + for _ in range(n_params * 10): # The first iteration generates 10 times more parameters + yield init_rand(tensor) * max_initial # Generates n_params in range [-max_initial, max_initial] + + def _search_param(tensors: List[torch.tensor], n_params): + """Takes the best parameters and generates new parameters around the mean of the best parameters.""" + torch_tensors = torch.stack(tensors) + min_vals, max_vals = torch.aminmax(torch_tensors, dim=0) + abs_max_val_per_ch = torch.max(-min_vals, max_vals) + mean = torch.mean(torch_tensors, dim=0) + for _ in range(n_params): # Generates n_params around the mean of the tensors + yield torch.randn_like(min_vals) * abs_max_val_per_ch + mean + + def _calc(x, qtz_func, deqtz_func, **params): + x_ = x.transpose(0, 1) + x_ = qtz_func(x=x_, **params) + x_ = deqtz_func(x=x_, **params) + x_ = x_.transpose(0, 1) + return x_ + + assert "qtz_func" in kwargs, "qtz_func must be provided." + assert "deqtz_func" in kwargs, "deqtz_func must be provided." + assert "params_list" in kwargs, "params list must be provided." + assert "param" in kwargs, "param must be provided." + + qtz_func = kwargs.get('qtz_func') + deqtz_func = kwargs.get('deqtz_func') + params_list = kwargs.get('params_list') + param = kwargs.get('param') + + n_runs = 50 # Number of runs to try to find the best parameters + n_random_params = 50 # Number of random parameters to generate + n_best_to_pick = 5 # Number of best parameters to pick after each run + max_initial = 10000 # Maximum value to initialize the parameters + + # Initializes the parameters + base_params = { p: init_ones(x, **kwargs) for p in params_list if p != param } + params = _build_initial_param(x, max_initial, n_random_params) + + # Performs the search + for _ in range(n_runs): + + best_params = [] + for param_ in params: + try: + x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: param_}) + loss_ones = nn.MSELoss()(x, x_) + + if len(best_params) < n_best_to_pick: + best_params.append((param_, loss_ones.item())) + best_params = sorted(best_params, key=lambda x: x[1]) + elif loss_ones < best_params[-1][1]: + best_params[-1] = (param_, loss_ones.item()) + best_params = sorted(best_params, key=lambda x: x[1]) + + except Exception: # The parameters might not be valid for the function's domain + continue + + # Generates new parameters around the mean + params = _search_param([p for p, _ in best_params], n_random_params) + + return best_params[0][0] + + +def init_linear_scale( # Symmetric scale. From the study folder + x: torch.Tensor, + **kwargs: Dict[str, Any], + ) -> torch.Tensor: + assert "bits" in kwargs, "bits must be provided." + assert "params" in kwargs, "params must be provided." + assert "qtz_func" in kwargs, "qtz_func must be provided." 
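+ # Rough sketch of how init_params (defined earlier in this file) wires these
+ # initializers together; the local names here are illustrative only:
+ #   _0 = init_space_search(x, qtz_func=quantization, deqtz_func=dequantization,
+ #                          param='_0', params_list=['_0', '_s'], bits=bits)
+ #   _s = init_linear_scale(x, params={'_0': _0}, qtz_func=quantization, bits=bits)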
+ + bits = kwargs.get('bits') + params = kwargs.get('params') + qtz_func = kwargs.get('qtz_func') + + x_ = x.transpose(0, 1) + x_ = qtz_func(x=x_, **params, _s=init_ones(x, **kwargs)) + x_ = x_.transpose(0, 1) + + quant_min, quant_max = get_min_max_from_bits_signed(bits) + min_vals, max_vals = torch.aminmax(x_, dim=1) + min_vals = torch.min(min_vals, torch.zeros_like(min_vals)) + max_vals = torch.max(max_vals, torch.zeros_like(max_vals)) + + eps = torch.finfo(torch.float32).eps + + abs_max_val_per_ch = torch.max(-min_vals, max_vals) + scale = abs_max_val_per_ch / (float(quant_max - quant_min) / 2) + + scale = torch.clamp(scale, min=eps).to(dtype=torch.float32, device=min_vals.device) + + # Introduces some noise in scale + # If I don't introduce noise, the accuracy is going to be 0.0 and not learn anything + # scale = scale + 0.01 * torch.randn_like(scale) + return scale + + +def init_non_linear_regression_fit( + x: torch.Tensor, + **kwargs: Dict[str, Any], + ) -> torch.Tensor: + + assert "params_list" in kwargs, "params list must be provided." + assert "np_fit_func" in kwargs, "np_fit_func must be provided." + assert "p0" in kwargs, "p0 must be provided." + np_fit_func = kwargs.get('np_fit_func') + params_list = kwargs.get('params_list') + p0 = kwargs.get('p0') + + def _fit(xdata: np.ndarray, ydata: np.ndarray, func: Callable, p0: List[float]): + popt, _ = curve_fit( + func, + xdata, + ydata, + maxfev=1000, + p0=p0, + method='lm' + ) + return popt + + # 1. Needs to convert the torch tensor to numpy tensor + xdata = x.cpu().numpy() + + # 2. Sorts the data so that it makes it easier to fit to it + sorted_xdata = np.sort(xdata, axis=-1) + + p0 = {k: v.cpu().numpy() for k, v in p0.items()} + params_list = sorted(params_list) # We need to make sure that it matches the numpy fit func arg order + + # 3. Finds the best parameters for each channel + try: + params = [] + for i in range(sorted_xdata.shape[0]): + xdata_ = sorted_xdata[i] + p0_ = [p0[p][i] for p in params_list] + ch_params = _fit(xdata_, xdata_, np_fit_func, p0_) + params.append(ch_params) + + # 4. Builds the parameters + result = {} + for i, p in enumerate(params_list): + result[p] = torch.tensor([p_[i] for p_ in params], dtype=torch.float32).to(x.device) + + return result + + except ValueError as e: + print(f"Could not fit the function with error: {e}") + print(f"Using fallback result...") + return { + k: torch.tensor(v, dtype=torch.float32).to(x.device) for k, v in p0.items() + } + + +def init_zeros(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor: + val = torch.amin(x, dim=1) + return torch.zeros_like(val, dtype=torch.float32, device=x.device) + + +def init_inner_scale(tensor: torch.Tensor, _min: float = torch.inf, _max: float = torch.inf) -> torch.Tensor: + # Calculate the original minimum and maximum values + min_vals, max_vals = torch.aminmax(tensor, dim=-1) + x_min = torch.min(min_vals, torch.zeros_like(min_vals)) + x_max = torch.max(max_vals, torch.zeros_like(max_vals)) + + if _max is torch.inf: # We do not need to scale the tensor. 
Just need to move it + return torch.ones_like(x_min) + + # Calculate the scale factor + scale = (_max - _min) / (x_max - x_min) + return scale + + + +############## Quant ############### + +@torch.enable_grad() +def learn_parameters( + x: torch.Tensor, + params: Dict[str, nn.Parameter], + qtz_func: nn.Module, + deqtz_func: nn.Module, + bits: int, + target_dtype: torch.dtype, + epochs: int = 1000, + early_stop: bool = True, + do_report: bool = False +) -> Tuple[Dict[str, nn.Parameter], torch.Tensor]: + loss_fn = nn.MSELoss() + + # Determines the initial learning rate by computing the initial loss and multiplying it by + # the order of magnitude of the loss divided by 2 + quant = quantize(x, params, qtz_func, bits, target_dtype) + dequant = dequantize(quant, params, deqtz_func, bits, x.dtype) + loss = loss_fn(x, dequant) + + base_lr = 0.1 + exponent = int(np.floor(np.log10(loss.item()))) + lr = base_lr * (10 ** (exponent // 2)) + + # Requires gradients in the parameters + for p in params.values(): + p.requires_grad = True + p.grad = None + + param_keys = list(params.keys()) + param_values = list(params.values()) + + # Defines optimizer and loss function + optimizer = torch.optim.Adam(param_values, lr=lr) + scheduler = torch.optim.lr_scheduler.LinearLR(optimizer, start_factor=1.0, end_factor=0.01, total_iters=epochs // 10) + + # Contains the best loss and the best parameters + best_loss = float("inf") + best_params = None + + # Used to stop the search early + min_delta = 1e-7 + acc_loss = [] + percent_epochs_before_stop = 0.1 + + for i in range(epochs): + optimizer.zero_grad() + + quant = quantize(x, params, qtz_func, bits, target_dtype) + dequant = dequantize(quant, params, deqtz_func, bits, x.dtype) + loss = loss_fn(x, dequant) + + if loss.isnan() or loss.isinf(): + raise Exception("Loss is NaN or Inf. 
Stopping the search.") + + loss.backward() + optimizer.step() + scheduler.step() + + acc_loss.append(loss.item()) + + # Reports loss every 10 steps + if i % 10 == 0 and do_report: + print(f"Epoch {i}: Loss {loss.item()}") + + # Keeps track of the best loss and the corresponding parameters + if loss.item() < best_loss: + best_loss = loss.item() + best_params = copy.deepcopy({ + k: v for k, v in params.items() if k in param_keys + }) + + # We also stop the search if the loss has not improved considerably during the last 10% of epochs + if early_stop: + epochs_before_stop = int(epochs * percent_epochs_before_stop) + if i > epochs_before_stop and abs(acc_loss[i - epochs_before_stop] - acc_loss[i]) < min_delta: + break + + # No longer requires gradients in the parameters + for p in best_params.values(): + p.requires_grad = False + p.grad = None + + if do_report: + return best_params, acc_loss + else: + return best_params + + +def quantize( + x: torch.Tensor, + params: Dict[str, nn.Parameter], + func: nn.Module, + bits: int, + target_dtype: torch.dtype = torch.int8 +) -> torch.Tensor: + quant_min, quant_max = get_min_max_from_bits_signed(bits) + x = x.transpose(0, 1) # Aligns shapes + x = func(x=x, **params) + x = x.transpose(0, 1) + x = torch.clamp(round_func_BPDA(x), quant_min, quant_max).to(target_dtype) + return x + + +def dequantize( + x: torch.Tensor, + params: Dict[str, nn.Parameter], + func: nn.Module, + bits: int, + out_dtype: torch.dtype +) -> torch.Tensor: + x = x.to(dtype=out_dtype) + x = x.transpose(0, 1) + x = func(x=x, **params) + x = x.transpose(0, 1) + return x + + +def round_func_BPDA(input): + # This is equivalent to replacing the round function (non-differentiable) with + # an identity function (differentiable) in the backward pass only. + forward_value = torch.round(input) + out = input.clone() + out.data = forward_value.data + return out + + +def get_min_max_from_bits_signed(bit_width: int) -> Tuple[int, int]: + return -2 ** (bit_width - 1), 2 ** (bit_width - 1) - 1 + + + +############## Numpy ############### + +def np_domain_guard( + x: np.ndarray, + min: float = None, + max: float = None, + posinf: float = None, + neginf: float = None, + nan: float = None + ) -> np.ndarray: + """Guard a tensor to a valid domain.""" + x = np.nan_to_num(x, posinf=posinf, neginf=neginf, nan=nan) + if min is not None or max is not None: + x = np.clip(x, min, max) + return x + + +def np_replace_num(x: np.ndarray, num: float, to: float) -> np.ndarray: + """Replace a number in a tensor with another number. + + Args: + x (np.ndarray): The input tensor. + num (float): The number to replace. + to (float): The number to replace with. + + Returns: + np.ndarray: The tensor with the number replaced.
+ """ + return np.where(x == num, to, x) + + +def np_guarded_power(x: np.ndarray, exp: float) -> np.ndarray: + """Guard the power operation to a valid domain.""" + return np.power(x, exp) if exp >= 1 else np.power(np.maximum(x, 0), exp) + diff --git a/fn_gen/rnd_search/12/loss.png b/fn_gen/rnd_search/12/loss.png new file mode 100644 index 0000000000000000000000000000000000000000..035bdd7b4b93c353f13239768ba2045d277e1c38 Binary files /dev/null and b/fn_gen/rnd_search/12/loss.png differ diff --git a/fn_gen/rnd_search/12/quantization.png b/fn_gen/rnd_search/12/quantization.png new file mode 100644 index 0000000000000000000000000000000000000000..8b045ccd1585faca422e764e9d0b9c0086d93d1d Binary files /dev/null and b/fn_gen/rnd_search/12/quantization.png differ diff --git a/fn_gen/rnd_search/13/__pycache__/fn.cpython-311.pyc b/fn_gen/rnd_search/13/__pycache__/fn.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d8f9cab57c030768877168baab4dc6dd875608bd Binary files /dev/null and b/fn_gen/rnd_search/13/__pycache__/fn.cpython-311.pyc differ diff --git a/fn_gen/rnd_search/13/distortion.png b/fn_gen/rnd_search/13/distortion.png new file mode 100644 index 0000000000000000000000000000000000000000..9059654d6a9aec3abc61d2675e781a5ca0ddfa31 Binary files /dev/null and b/fn_gen/rnd_search/13/distortion.png differ diff --git a/fn_gen/rnd_search/13/expressions.txt b/fn_gen/rnd_search/13/expressions.txt new file mode 100644 index 0000000000000000000000000000000000000000..5a7abbbdac98c7d53123fe0b9807e7644bc00acf --- /dev/null +++ b/fn_gen/rnd_search/13/expressions.txt @@ -0,0 +1,2 @@ +acosh(_0*x)/_s +cosh(_s*x)/_0 \ No newline at end of file diff --git a/fn_gen/rnd_search/13/fn.py b/fn_gen/rnd_search/13/fn.py new file mode 100644 index 0000000000000000000000000000000000000000..4acfba4d9f23b4aae9b6857edec59fb939a43500 --- /dev/null +++ b/fn_gen/rnd_search/13/fn.py @@ -0,0 +1,554 @@ +from __future__ import annotations + +import torch +from torch import amin # Necessary for arcsin +import copy +import torch.nn as nn +import numpy as np + +from scipy.optimize import curve_fit +from typing import Dict, Any, Tuple, List, Callable + + +def quantization(x, **params): + return (torch.div(1, replace_num(params['_s'], num=0, to=10000)) * torch.acosh(domain_guard((params['_0'] * x), min=1, nan=1))) + + +def dequantization(x, **params): + return (torch.div(1, replace_num(params['_0'], num=0, to=10000)) * torch.cosh((params['_s'] * x))) + + +def init_space_search( + x: torch.Tensor, + **kwargs: Dict[str, Any], + ) -> torch.Tensor: + + def _build_initial_param(tensor: torch.Tensor, max_initial: int, n_params: int): + """Generates the initial set of parameters. 
The first iteration generates 10 times more parameters.""" + for _ in range(n_params * 10): # The first iteration generates 10 times more parameters + yield init_rand(tensor) * max_initial # Generates n_params in range [-max_initial, max_initial] + + def _search_param(tensors: List[torch.tensor], n_params): + """Takes the best parameters and generates new parameters around the mean of the best parameters.""" + torch_tensors = torch.stack(tensors) + min_vals, max_vals = torch.aminmax(torch_tensors, dim=0) + abs_max_val_per_ch = torch.max(-min_vals, max_vals) + mean = torch.mean(torch_tensors, dim=0) + for _ in range(n_params): # Generates n_params around the mean of the tensors + yield torch.randn_like(min_vals) * abs_max_val_per_ch + mean + + def _calc(x, qtz_func, deqtz_func, **params): + x_ = x.transpose(0, 1) + x_ = qtz_func(x=x_, **params) + x_ = deqtz_func(x=x_, **params) + x_ = x_.transpose(0, 1) + return x_ + + assert "qtz_func" in kwargs, "qtz_func must be provided." + assert "deqtz_func" in kwargs, "deqtz_func must be provided." + assert "params_list" in kwargs, "params list must be provided." + assert "param" in kwargs, "param must be provided." + + qtz_func = kwargs.get('qtz_func') + deqtz_func = kwargs.get('deqtz_func') + params_list = kwargs.get('params_list') + param = kwargs.get('param') + + n_runs = 50 # Number of runs to try to find the best parameters + n_random_params = 50 # Number of random parameters to generate + n_best_to_pick = 5 # Number of best parameters to pick after each run + max_initial = 10000 # Maximum value to initialize the parameters + + # Initializes the parameters + base_params = { p: init_ones(x, **kwargs) for p in params_list if p != param } + params = _build_initial_param(x, max_initial, n_random_params) + + # Performs the search + for _ in range(n_runs): + + best_params = [] + for param_ in params: + try: + x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: param_}) + loss_ones = nn.MSELoss()(x, x_) + + if len(best_params) < n_best_to_pick: + best_params.append((param_, loss_ones.item())) + best_params = sorted(best_params, key=lambda x: x[1]) + elif loss_ones < best_params[-1][1]: + best_params[-1] = (param_, loss_ones.item()) + best_params = sorted(best_params, key=lambda x: x[1]) + + except Exception: # The parameters might not be valid for the function's domain + continue + + # Generates new parameters around the mean + params = _search_param([p for p, _ in best_params], n_random_params) + + return best_params[0][0] + + +def init_linear_scale( # Symmetric scale. From the study folder + x: torch.Tensor, + **kwargs: Dict[str, Any], + ) -> torch.Tensor: + assert "bits" in kwargs, "bits must be provided." + assert "params" in kwargs, "params must be provided." + assert "qtz_func" in kwargs, "qtz_func must be provided." 
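+ # Range used below: get_min_max_from_bits_signed(bits) returns
+ # (-2**(bits-1), 2**(bits-1) - 1), e.g. (-128, 127) for bits=8, so the divisor
+ # (quant_max - quant_min) / 2 works out to 127.5 rather than 128.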
+ + bits = kwargs.get('bits') + params = kwargs.get('params') + qtz_func = kwargs.get('qtz_func') + + x_ = x.transpose(0, 1) + x_ = qtz_func(x=x_, **params, _s=init_ones(x, **kwargs)) + x_ = x_.transpose(0, 1) + + quant_min, quant_max = get_min_max_from_bits_signed(bits) + min_vals, max_vals = torch.aminmax(x_, dim=1) + min_vals = torch.min(min_vals, torch.zeros_like(min_vals)) + max_vals = torch.max(max_vals, torch.zeros_like(max_vals)) + + eps = torch.finfo(torch.float32).eps + + abs_max_val_per_ch = torch.max(-min_vals, max_vals) + scale = abs_max_val_per_ch / (float(quant_max - quant_min) / 2) + + scale = torch.clamp(scale, min=eps).to(dtype=torch.float32, device=min_vals.device) + + # Introduces some noise in scale + # If I don't introduce noise, the accuracy is going to be 0.0 and not learn anything + # scale = scale + 0.01 * torch.randn_like(scale) + return scale + + +def init_params(x: torch.Tensor, **kwargs: Dict[str, Any]) -> Dict[str, nn.Parameter]: + params = { + '_0': init_space_search(x, qtz_func=quantization, deqtz_func=dequantization, param='_0', params_list=['_0', '_s'], **kwargs), + } + params['_s'] = init_linear_scale(x, params=params, qtz_func=quantization, **kwargs) + params = {k: nn.Parameter(v, requires_grad=False) for k, v in params.items()} + + if 'post_init_hook' in kwargs: + kwargs['post_init_hook'](parameters=params) + + + if 'post_train_hook' in kwargs: + kwargs['post_train_hook'](parameters=params) + + return params + + +############### Numpy Qtz ############### + + +def np_quantization(x, _0, _s): + return (np.divide(1, np_replace_num(_s, num=0, to=10000)) * np.arccosh(np_domain_guard((_0 * x), min=1, nan=1))) + + +def np_dequantization(x, _0, _s): + return (np.divide(1, np_replace_num(_0, num=0, to=10000)) * np.cosh((_s * x))) + + +def fit_func(x, _0, _s): + x_ = np_quantization(x, _0, _s) + x_ = np_dequantization(x_, _0, _s) + return x_ + + + +############### HELPERS ############### + +def domain_guard( + x: torch.Tensor, + min: float = None, + max: float = None, + posinf: float = None, + neginf: float = None, + nan: float = None + ) -> torch.Tensor: + """Guard a tensor to a valid domain.""" + x = torch.nan_to_num(x, posinf=posinf, neginf=neginf, nan=nan) + if min is not None or max is not None: + x = torch.clamp(x, min=min, max=max) + return x + + +def replace_num(x: torch.Tensor, num: float, to: float) -> torch.Tensor: + """Replace a number in a tensor with another number. + + Args: + x (torch.Tensor): The input tensor. + num (float): The number to replace. + to (float): The number to replace with. + + Returns: + torch.Tensor: The tensor with the number replaced. + """ + return torch.where(x == num, to, x) + + +def guarded_torch_power(x: torch.Tensor, exp: float) -> torch.Tensor: + """Guard the power operation to a valid domain.""" + return torch.pow(x, exp) if exp >= 1 else torch.pow(torch.relu(x), exp) + + +def init_ones(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor: + val = torch.amin(x, dim=1) + return torch.ones_like(val, dtype=torch.float32, device=x.device) + + +def init_rand(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor: + val = torch.amin(x, dim=1) + return torch.randn_like(val, dtype=torch.float32, device=x.device) + + +def init_space_search( + x: torch.Tensor, + **kwargs: Dict[str, Any], + ) -> torch.Tensor: + + def _build_initial_param(tensor: torch.Tensor, max_initial: int, n_params: int): + """Generates the initial set of parameters. 
The first iteration generates 10 times more parameters.""" + for _ in range(n_params * 10): # The first iteration generates 10 times more parameters + yield init_rand(tensor) * max_initial # Generates n_params in range [-max_initial, max_initial] + + def _search_param(tensors: List[torch.tensor], n_params): + """Takes the best parameters and generates new parameters around the mean of the best parameters.""" + torch_tensors = torch.stack(tensors) + min_vals, max_vals = torch.aminmax(torch_tensors, dim=0) + abs_max_val_per_ch = torch.max(-min_vals, max_vals) + mean = torch.mean(torch_tensors, dim=0) + for _ in range(n_params): # Generates n_params around the mean of the tensors + yield torch.randn_like(min_vals) * abs_max_val_per_ch + mean + + def _calc(x, qtz_func, deqtz_func, **params): + x_ = x.transpose(0, 1) + x_ = qtz_func(x=x_, **params) + x_ = deqtz_func(x=x_, **params) + x_ = x_.transpose(0, 1) + return x_ + + assert "qtz_func" in kwargs, "qtz_func must be provided." + assert "deqtz_func" in kwargs, "deqtz_func must be provided." + assert "params_list" in kwargs, "params list must be provided." + assert "param" in kwargs, "param must be provided." + + qtz_func = kwargs.get('qtz_func') + deqtz_func = kwargs.get('deqtz_func') + params_list = kwargs.get('params_list') + param = kwargs.get('param') + + n_runs = 50 # Number of runs to try to find the best parameters + n_random_params = 50 # Number of random parameters to generate + n_best_to_pick = 5 # Number of best parameters to pick after each run + max_initial = 10000 # Maximum value to initialize the parameters + + # Initializes the parameters + base_params = { p: init_ones(x, **kwargs) for p in params_list if p != param } + params = _build_initial_param(x, max_initial, n_random_params) + + # Performs the search + for _ in range(n_runs): + + best_params = [] + for param_ in params: + try: + x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: param_}) + loss_ones = nn.MSELoss()(x, x_) + + if len(best_params) < n_best_to_pick: + best_params.append((param_, loss_ones.item())) + best_params = sorted(best_params, key=lambda x: x[1]) + elif loss_ones < best_params[-1][1]: + best_params[-1] = (param_, loss_ones.item()) + best_params = sorted(best_params, key=lambda x: x[1]) + + except Exception: # The parameters might not be valid for the function's domain + continue + + # Generates new parameters around the mean + params = _search_param([p for p, _ in best_params], n_random_params) + + return best_params[0][0] + + +def init_linear_scale( # Symmetric scale. From the study folder + x: torch.Tensor, + **kwargs: Dict[str, Any], + ) -> torch.Tensor: + assert "bits" in kwargs, "bits must be provided." + assert "params" in kwargs, "params must be provided." + assert "qtz_func" in kwargs, "qtz_func must be provided." 
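+ # init_ones(x, **kwargs) in the call below yields one value per channel (one
+ # entry per row of x, all equal to 1.0), so the quantization is first evaluated
+ # with _s = 1 before the real per-channel scale is derived from its output.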
+ + bits = kwargs.get('bits') + params = kwargs.get('params') + qtz_func = kwargs.get('qtz_func') + + x_ = x.transpose(0, 1) + x_ = qtz_func(x=x_, **params, _s=init_ones(x, **kwargs)) + x_ = x_.transpose(0, 1) + + quant_min, quant_max = get_min_max_from_bits_signed(bits) + min_vals, max_vals = torch.aminmax(x_, dim=1) + min_vals = torch.min(min_vals, torch.zeros_like(min_vals)) + max_vals = torch.max(max_vals, torch.zeros_like(max_vals)) + + eps = torch.finfo(torch.float32).eps + + abs_max_val_per_ch = torch.max(-min_vals, max_vals) + scale = abs_max_val_per_ch / (float(quant_max - quant_min) / 2) + + scale = torch.clamp(scale, min=eps).to(dtype=torch.float32, device=min_vals.device) + + # Introduces some noise in scale + # If I don't introduce noise, the accuracy is going to be 0.0 and not learn anything + # scale = scale + 0.01 * torch.randn_like(scale) + return scale + + +def init_non_linear_regression_fit( + x: torch.Tensor, + **kwargs: Dict[str, Any], + ) -> torch.Tensor: + + assert "params_list" in kwargs, "params list must be provided." + assert "np_fit_func" in kwargs, "np_fit_func must be provided." + assert "p0" in kwargs, "p0 must be provided." + np_fit_func = kwargs.get('np_fit_func') + params_list = kwargs.get('params_list') + p0 = kwargs.get('p0') + + def _fit(xdata: np.ndarray, ydata: np.ndarray, func: Callable, p0: List[float]): + popt, _ = curve_fit( + func, + xdata, + ydata, + maxfev=1000, + p0=p0, + method='lm' + ) + return popt + + # 1. Needs to convert the torch tensor to numpy tensor + xdata = x.cpu().numpy() + + # 2. Sorts the data so that it makes it easier to fit to it + sorted_xdata = np.sort(xdata, axis=-1) + + p0 = {k: v.cpu().numpy() for k, v in p0.items()} + params_list = sorted(params_list) # We need to make sure that it matches the numpy fit func arg order + + # 3. Finds the best parameters for each channel + try: + params = [] + for i in range(sorted_xdata.shape[0]): + xdata_ = sorted_xdata[i] + p0_ = [p0[p][i] for p in params_list] + ch_params = _fit(xdata_, xdata_, np_fit_func, p0_) + params.append(ch_params) + + # 4. Builds the parameters + result = {} + for i, p in enumerate(params_list): + result[p] = torch.tensor([p_[i] for p_ in params], dtype=torch.float32).to(x.device) + + return result + + except ValueError as e: + print(f"Could not fit the function with error: {e}") + print(f"Using fallback result...") + return { + k: torch.tensor(v, dtype=torch.float32).to(x.device) for k, v in p0.items() + } + + +def init_zeros(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor: + val = torch.amin(x, dim=1) + return torch.zeros_like(val, dtype=torch.float32, device=x.device) + + +def init_inner_scale(tensor: torch.Tensor, _min: float = torch.inf, _max: float = torch.inf) -> torch.Tensor: + # Calculate the original minimum and maximum values + min_vals, max_vals = torch.aminmax(tensor, dim=-1) + x_min = torch.min(min_vals, torch.zeros_like(min_vals)) + x_max = torch.max(max_vals, torch.zeros_like(max_vals)) + + if _max is torch.inf: # We do not need to scale the tensor. 
Just need to move it + return torch.ones_like(x_min) + + # Calculate the scale factor + scale = (_max - _min) / (x_max - x_min) + return scale + + + +############## Quant ############### + +@torch.enable_grad() +def learn_parameters( + x: torch.Tensor, + params: Dict[str, nn.Parameter], + qtz_func: nn.Module, + deqtz_func: nn.Module, + bits: int, + target_dtype: torch.dtype, + epochs: int = 1000, + early_stop: bool = True, + do_report: bool = False +) -> Tuple[Dict[str, nn.Parameter], torch.Tensor]: + loss_fn = nn.MSELoss() + + # Determines the initial learning rate by computing the initial loss and multiplying it by + # the order of magnitude of the loss divided by 2 + quant = quantize(x, params, qtz_func, bits, target_dtype) + dequant = dequantize(quant, params, deqtz_func, bits, x.dtype) + loss = loss_fn(x, dequant) + + base_lr = 0.1 + exponent = int(np.floor(np.log10(loss.item()))) + lr = base_lr * (10 ** (exponent // 2)) + + # Requires gradients in the parameters + for p in params.values(): + p.requires_grad = True + p.grad = None + + param_keys = list(params.keys()) + param_values = list(params.values()) + + # Defines optimizer and loss function + optimizer = torch.optim.Adam(param_values, lr=lr) + scheduler = torch.optim.lr_scheduler.LinearLR(optimizer, start_factor=1.0, end_factor=0.01, total_iters=epochs // 10) + + # Contains the best loss and the best parameters + best_loss = float("inf") + best_params = None + + # Used to stop the search early + min_delta = 1e-7 + acc_loss = [] + percent_epochs_before_stop = 0.1 + + for i in range(epochs): + optimizer.zero_grad() + + quant = quantize(x, params, qtz_func, bits, target_dtype) + dequant = dequantize(quant, params, deqtz_func, bits, x.dtype) + loss = loss_fn(x, dequant) + + if loss.isnan() or loss.isinf(): + raise Exception("Loss is NaN or Inf. 
Stopping the search.") + + loss.backward() + optimizer.step() + scheduler.step() + + acc_loss.append(loss.item()) + + # Reports loss every 10 steps + if i % 10 == 0 and do_report: + print(f"Epoch {i}: Loss {loss.item()}") + + # Keeps track of the best loss and the corresponding parameters + if loss.item() < best_loss: + best_loss = loss.item() + best_params = copy.deepcopy({ + k: v for k, v in params.items() if k in param_keys + }) + + # We also stop the search if the loss has not improved considerably during the last 10% of epochs + if early_stop: + epochs_before_stop = int(epochs * percent_epochs_before_stop) + if i > epochs_before_stop and abs(acc_loss[i - epochs_before_stop] - acc_loss[i]) < min_delta: + break + + # No longer requires gradients in the parameters + for p in best_params.values(): + p.requires_grad = False + p.grad = None + + if do_report: + return best_params, acc_loss + else: + return best_params + + +def quantize( + x: torch.Tensor, + params: Dict[str, nn.Parameter], + func: nn.Module, + bits: int, + target_dtype: torch.dtype = torch.int8 +) -> torch.Tensor: + quant_min, quant_max = get_min_max_from_bits_signed(bits) + x = x.transpose(0, 1) # Aligns shapes + x = func(x=x, **params) + x = x.transpose(0, 1) + x = torch.clamp(round_func_BPDA(x), quant_min, quant_max).to(target_dtype) + return x + + +def dequantize( + x: torch.Tensor, + params: Dict[str, nn.Parameter], + func: nn.Module, + bits: int, + out_dtype: torch.dtype +) -> torch.Tensor: + x = x.to(dtype=out_dtype) + x = x.transpose(0, 1) + x = func(x=x, **params) + x = x.transpose(0, 1) + return x + + +def round_func_BPDA(input): + # This is equivalent to replacing the round function (non-differentiable) with + # an identity function (differentiable) in the backward pass only. + forward_value = torch.round(input) + out = input.clone() + out.data = forward_value.data + return out + + +def get_min_max_from_bits_signed(bit_width: int) -> Tuple[int, int]: + return -2 ** (bit_width - 1), 2 ** (bit_width - 1) - 1 + + + +############## Numpy ############### + +def np_domain_guard( + x: np.ndarray, + min: float = None, + max: float = None, + posinf: float = None, + neginf: float = None, + nan: float = None + ) -> np.ndarray: + """Guard a tensor to a valid domain.""" + x = np.nan_to_num(x, posinf=posinf, neginf=neginf, nan=nan) + if min is not None or max is not None: + x = np.clip(x, min, max) + return x + + +def np_replace_num(x: np.ndarray, num: float, to: float) -> np.ndarray: + """Replace a number in a tensor with another number. + + Args: + x (np.ndarray): The input tensor. + num (float): The number to replace. + to (float): The number to replace with. + + Returns: + np.ndarray: The tensor with the number replaced.
+ """ + return np.where(x == num, to, x) + + +def np_guarded_power(x: np.ndarray, exp: float) -> np.ndarray: + """Guard the power operation to a valid domain.""" + return np.power(x, exp) if exp >= 1 else np.power(np.maximum(x, 0), exp) + diff --git a/fn_gen/rnd_search/13/loss.png b/fn_gen/rnd_search/13/loss.png new file mode 100644 index 0000000000000000000000000000000000000000..bfa22852c1e17002bc2b60c1646d0f3c503f40f7 Binary files /dev/null and b/fn_gen/rnd_search/13/loss.png differ diff --git a/fn_gen/rnd_search/13/quantization.png b/fn_gen/rnd_search/13/quantization.png new file mode 100644 index 0000000000000000000000000000000000000000..379bc1bc2a941ac22102ab27631c55aa29b67489 Binary files /dev/null and b/fn_gen/rnd_search/13/quantization.png differ diff --git a/fn_gen/rnd_search/14/__pycache__/fn.cpython-311.pyc b/fn_gen/rnd_search/14/__pycache__/fn.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..18f7c86ec319d6bb18566d2f3e9b6d1935bb1372 Binary files /dev/null and b/fn_gen/rnd_search/14/__pycache__/fn.cpython-311.pyc differ diff --git a/fn_gen/rnd_search/14/distortion.png b/fn_gen/rnd_search/14/distortion.png new file mode 100644 index 0000000000000000000000000000000000000000..bcfa46879599f1d1abbc22221fddae9580c743cf Binary files /dev/null and b/fn_gen/rnd_search/14/distortion.png differ diff --git a/fn_gen/rnd_search/14/expressions.txt b/fn_gen/rnd_search/14/expressions.txt new file mode 100644 index 0000000000000000000000000000000000000000..23606e9f370f2e4adb43ed623c49d7fcaabd7355 --- /dev/null +++ b/fn_gen/rnd_search/14/expressions.txt @@ -0,0 +1,2 @@ +tan(_0*x)/_s +atan(_s*x)/_0 \ No newline at end of file diff --git a/fn_gen/rnd_search/14/fn.py b/fn_gen/rnd_search/14/fn.py new file mode 100644 index 0000000000000000000000000000000000000000..2ed2d913f3c53607a8f0024f6c1855a09ac5daaf --- /dev/null +++ b/fn_gen/rnd_search/14/fn.py @@ -0,0 +1,554 @@ +from __future__ import annotations + +import torch +from torch import amin # Necessary for arcsin +import copy +import torch.nn as nn +import numpy as np + +from scipy.optimize import curve_fit +from typing import Dict, Any, Tuple, List, Callable + + +def quantization(x, **params): + return (torch.div(1, replace_num(params['_s'], num=0, to=10000)) * torch.tan(domain_guard((params['_0'] * x), posinf=1, neginf=-1, nan=0))) + + +def dequantization(x, **params): + return (torch.div(1, replace_num(params['_0'], num=0, to=10000)) * torch.atan((params['_s'] * x))) + + +def init_space_search( + x: torch.Tensor, + **kwargs: Dict[str, Any], + ) -> torch.Tensor: + + def _build_initial_param(tensor: torch.Tensor, max_initial: int, n_params: int): + """Generates the initial set of parameters. 
The first iteration generates 10 times more parameters.""" + for _ in range(n_params * 10): # The first iteration generates 10 times more parameters + yield init_rand(tensor) * max_initial # Generates n_params in range [-max_initial, max_initial] + + def _search_param(tensors: List[torch.tensor], n_params): + """Takes the best parameters and generates new parameters around the mean of the best parameters.""" + torch_tensors = torch.stack(tensors) + min_vals, max_vals = torch.aminmax(torch_tensors, dim=0) + abs_max_val_per_ch = torch.max(-min_vals, max_vals) + mean = torch.mean(torch_tensors, dim=0) + for _ in range(n_params): # Generates n_params around the mean of the tensors + yield torch.randn_like(min_vals) * abs_max_val_per_ch + mean + + def _calc(x, qtz_func, deqtz_func, **params): + x_ = x.transpose(0, 1) + x_ = qtz_func(x=x_, **params) + x_ = deqtz_func(x=x_, **params) + x_ = x_.transpose(0, 1) + return x_ + + assert "qtz_func" in kwargs, "qtz_func must be provided." + assert "deqtz_func" in kwargs, "deqtz_func must be provided." + assert "params_list" in kwargs, "params list must be provided." + assert "param" in kwargs, "param must be provided." + + qtz_func = kwargs.get('qtz_func') + deqtz_func = kwargs.get('deqtz_func') + params_list = kwargs.get('params_list') + param = kwargs.get('param') + + n_runs = 50 # Number of runs to try to find the best parameters + n_random_params = 50 # Number of random parameters to generate + n_best_to_pick = 5 # Number of best parameters to pick after each run + max_initial = 10000 # Maximum value to initialize the parameters + + # Initializes the parameters + base_params = { p: init_ones(x, **kwargs) for p in params_list if p != param } + params = _build_initial_param(x, max_initial, n_random_params) + + # Performs the search + for _ in range(n_runs): + + best_params = [] + for param_ in params: + try: + x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: param_}) + loss_ones = nn.MSELoss()(x, x_) + + if len(best_params) < n_best_to_pick: + best_params.append((param_, loss_ones.item())) + best_params = sorted(best_params, key=lambda x: x[1]) + elif loss_ones < best_params[-1][1]: + best_params[-1] = (param_, loss_ones.item()) + best_params = sorted(best_params, key=lambda x: x[1]) + + except Exception: # The parameters might not be valid for the function's domain + continue + + # Generates new parameters around the mean + params = _search_param([p for p, _ in best_params], n_random_params) + + return best_params[0][0] + + +def init_linear_scale( # Symmetric scale. From the study folder + x: torch.Tensor, + **kwargs: Dict[str, Any], + ) -> torch.Tensor: + assert "bits" in kwargs, "bits must be provided." + assert "params" in kwargs, "params must be provided." + assert "qtz_func" in kwargs, "qtz_func must be provided." 
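+ # The scale below is derived per channel: quantize x once with a neutral _s of ones,
+ # take each channel's absolute maximum of the result, and divide by half of the signed integer range.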
+ + bits = kwargs.get('bits') + params = kwargs.get('params') + qtz_func = kwargs.get('qtz_func') + + x_ = x.transpose(0, 1) + x_ = qtz_func(x=x_, **params, _s=init_ones(x, **kwargs)) + x_ = x_.transpose(0, 1) + + quant_min, quant_max = get_min_max_from_bits_signed(bits) + min_vals, max_vals = torch.aminmax(x_, dim=1) + min_vals = torch.min(min_vals, torch.zeros_like(min_vals)) + max_vals = torch.max(max_vals, torch.zeros_like(max_vals)) + + eps = torch.finfo(torch.float32).eps + + abs_max_val_per_ch = torch.max(-min_vals, max_vals) + scale = abs_max_val_per_ch / (float(quant_max - quant_min) / 2) + + scale = torch.clamp(scale, min=eps).to(dtype=torch.float32, device=min_vals.device) + + # Introduces some noise in scale + # If I don't introduce noise, the accuracy is going to be 0.0 and not learn anything + # scale = scale + 0.01 * torch.randn_like(scale) + return scale + + +def init_params(x: torch.Tensor, **kwargs: Dict[str, Any]) -> Dict[str, nn.Parameter]: + params = { + '_0': init_space_search(x, qtz_func=quantization, deqtz_func=dequantization, param='_0', params_list=['_0', '_s'], **kwargs), + } + params['_s'] = init_linear_scale(x, params=params, qtz_func=quantization, **kwargs) + params = {k: nn.Parameter(v, requires_grad=False) for k, v in params.items()} + + if 'post_init_hook' in kwargs: + kwargs['post_init_hook'](parameters=params) + + + if 'post_train_hook' in kwargs: + kwargs['post_train_hook'](parameters=params) + + return params + + +############### Numpy Qtz ############### + + +def np_quantization(x, _0, _s): + return (np.divide(1, np_replace_num(_s, num=0, to=10000)) * np.tan(np_domain_guard((_0 * x), posinf=1, neginf=-1, nan=0))) + + +def np_dequantization(x, _0, _s): + return (np.divide(1, np_replace_num(_0, num=0, to=10000)) * np.arctan((_s * x))) + + +def fit_func(x, _0, _s): + x_ = np_quantization(x, _0, _s) + x_ = np_dequantization(x_, _0, _s) + return x_ + + + +############### HELPERS ############### + +def domain_guard( + x: torch.Tensor, + min: float = None, + max: float = None, + posinf: float = None, + neginf: float = None, + nan: float = None + ) -> torch.Tensor: + """Guard a tensor to a valid domain.""" + x = torch.nan_to_num(x, posinf=posinf, neginf=neginf, nan=nan) + if min is not None or max is not None: + x = torch.clamp(x, min=min, max=max) + return x + + +def replace_num(x: torch.Tensor, num: float, to: float) -> torch.Tensor: + """Replace a number in a tensor with another number. + + Args: + x (torch.Tensor): The input tensor. + num (float): The number to replace. + to (float): The number to replace with. + + Returns: + torch.Tensor: The tensor with the number replaced. + """ + return torch.where(x == num, to, x) + + +def guarded_torch_power(x: torch.Tensor, exp: float) -> torch.Tensor: + """Guard the power operation to a valid domain.""" + return torch.pow(x, exp) if exp >= 1 else torch.pow(torch.relu(x), exp) + + +def init_ones(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor: + val = torch.amin(x, dim=1) + return torch.ones_like(val, dtype=torch.float32, device=x.device) + + +def init_rand(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor: + val = torch.amin(x, dim=1) + return torch.randn_like(val, dtype=torch.float32, device=x.device) + + +def init_space_search( + x: torch.Tensor, + **kwargs: Dict[str, Any], + ) -> torch.Tensor: + + def _build_initial_param(tensor: torch.Tensor, max_initial: int, n_params: int): + """Generates the initial set of parameters. 
The first iteration generates 10 times more parameters.""" + for _ in range(n_params * 10): # The first iteration generates 10 times more parameters + yield init_rand(tensor) * max_initial # Generates n_params in range [-max_initial, max_initial] + + def _search_param(tensors: List[torch.tensor], n_params): + """Takes the best parameters and generates new parameters around the mean of the best parameters.""" + torch_tensors = torch.stack(tensors) + min_vals, max_vals = torch.aminmax(torch_tensors, dim=0) + abs_max_val_per_ch = torch.max(-min_vals, max_vals) + mean = torch.mean(torch_tensors, dim=0) + for _ in range(n_params): # Generates n_params around the mean of the tensors + yield torch.randn_like(min_vals) * abs_max_val_per_ch + mean + + def _calc(x, qtz_func, deqtz_func, **params): + x_ = x.transpose(0, 1) + x_ = qtz_func(x=x_, **params) + x_ = deqtz_func(x=x_, **params) + x_ = x_.transpose(0, 1) + return x_ + + assert "qtz_func" in kwargs, "qtz_func must be provided." + assert "deqtz_func" in kwargs, "deqtz_func must be provided." + assert "params_list" in kwargs, "params list must be provided." + assert "param" in kwargs, "param must be provided." + + qtz_func = kwargs.get('qtz_func') + deqtz_func = kwargs.get('deqtz_func') + params_list = kwargs.get('params_list') + param = kwargs.get('param') + + n_runs = 50 # Number of runs to try to find the best parameters + n_random_params = 50 # Number of random parameters to generate + n_best_to_pick = 5 # Number of best parameters to pick after each run + max_initial = 10000 # Maximum value to initialize the parameters + + # Initializes the parameters + base_params = { p: init_ones(x, **kwargs) for p in params_list if p != param } + params = _build_initial_param(x, max_initial, n_random_params) + + # Performs the search + for _ in range(n_runs): + + best_params = [] + for param_ in params: + try: + x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: param_}) + loss_ones = nn.MSELoss()(x, x_) + + if len(best_params) < n_best_to_pick: + best_params.append((param_, loss_ones.item())) + best_params = sorted(best_params, key=lambda x: x[1]) + elif loss_ones < best_params[-1][1]: + best_params[-1] = (param_, loss_ones.item()) + best_params = sorted(best_params, key=lambda x: x[1]) + + except Exception: # The parameters might not be valid for the function's domain + continue + + # Generates new parameters around the mean + params = _search_param([p for p, _ in best_params], n_random_params) + + return best_params[0][0] + + +def init_linear_scale( # Symmetric scale. From the study folder + x: torch.Tensor, + **kwargs: Dict[str, Any], + ) -> torch.Tensor: + assert "bits" in kwargs, "bits must be provided." + assert "params" in kwargs, "params must be provided." + assert "qtz_func" in kwargs, "qtz_func must be provided." 
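+ # min_vals/max_vals are clamped through zero below so the symmetric range always covers 0,
+ # and the final clamp to eps prevents a zero scale for all-zero channels.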
+ + bits = kwargs.get('bits') + params = kwargs.get('params') + qtz_func = kwargs.get('qtz_func') + + x_ = x.transpose(0, 1) + x_ = qtz_func(x=x_, **params, _s=init_ones(x, **kwargs)) + x_ = x_.transpose(0, 1) + + quant_min, quant_max = get_min_max_from_bits_signed(bits) + min_vals, max_vals = torch.aminmax(x_, dim=1) + min_vals = torch.min(min_vals, torch.zeros_like(min_vals)) + max_vals = torch.max(max_vals, torch.zeros_like(max_vals)) + + eps = torch.finfo(torch.float32).eps + + abs_max_val_per_ch = torch.max(-min_vals, max_vals) + scale = abs_max_val_per_ch / (float(quant_max - quant_min) / 2) + + scale = torch.clamp(scale, min=eps).to(dtype=torch.float32, device=min_vals.device) + + # Introduces some noise in scale + # If I don't introduce noise, the accuracy is going to be 0.0 and not learn anything + # scale = scale + 0.01 * torch.randn_like(scale) + return scale + + +def init_non_linear_regression_fit( + x: torch.Tensor, + **kwargs: Dict[str, Any], + ) -> torch.Tensor: + + assert "params_list" in kwargs, "params list must be provided." + assert "np_fit_func" in kwargs, "np_fit_func must be provided." + assert "p0" in kwargs, "p0 must be provided." + np_fit_func = kwargs.get('np_fit_func') + params_list = kwargs.get('params_list') + p0 = kwargs.get('p0') + + def _fit(xdata: np.ndarray, ydata: np.ndarray, func: Callable, p0: List[float]): + popt, _ = curve_fit( + func, + xdata, + ydata, + maxfev=1000, + p0=p0, + method='lm' + ) + return popt + + # 1. Needs to convert the torch tensor to numpy tensor + xdata = x.cpu().numpy() + + # 2. Sorts the data so that it makes it easier to fit to it + sorted_xdata = np.sort(xdata, axis=-1) + + p0 = {k: v.cpu().numpy() for k, v in p0.items()} + params_list = sorted(params_list) # We need to make sure that it matches the numpy fit func arg order + + # 3. Finds the best parameters for each channel + try: + params = [] + for i in range(sorted_xdata.shape[0]): + xdata_ = sorted_xdata[i] + p0_ = [p0[p][i] for p in params_list] + ch_params = _fit(xdata_, xdata_, np_fit_func, p0_) + params.append(ch_params) + + # 4. Builds the parameters + result = {} + for i, p in enumerate(params_list): + result[p] = torch.tensor([p_[i] for p_ in params], dtype=torch.float32).to(x.device) + + return result + + except ValueError as e: + print(f"Could not fit the function with error: {e}") + print(f"Using fallback result...") + return { + k: torch.tensor(v, dtype=torch.float32).to(x.device) for k, v in p0.items() + } + + +def init_zeros(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor: + val = torch.amin(x, dim=1) + return torch.zeros_like(val, dtype=torch.float32, device=x.device) + + +def init_inner_scale(tensor: torch.Tensor, _min: float = torch.inf, _max: float = torch.inf) -> torch.Tensor: + # Calculate the original minimum and maximum values + min_vals, max_vals = torch.aminmax(tensor, dim=-1) + x_min = torch.min(min_vals, torch.zeros_like(min_vals)) + x_max = torch.max(max_vals, torch.zeros_like(max_vals)) + + if _max is torch.inf: # We do not need to scale the tensor. 
Just need to move it + return torch.ones_like(x_min) + + # Calculate the scale factor + scale = (_max - _min) / (x_max - x_min) + return scale + + + +############## Quant ############### + +@torch.enable_grad() +def learn_parameters( + x: torch.Tensor, + params: Dict[str, nn.Parameter], + qtz_func: nn.Module, + deqtz_func: nn.Module, + bits: int, + target_dtype: torch.dtype, + epochs: int = 1000, + early_stop: bool = True, + do_report: bool = False +) -> Tuple[Dict[str, nn.Parameter], torch.Tensor]: + loss_fn = nn.MSELoss() + + # Determines the initial learning rate by computing the initial loss and multiplying it by + # the order of magnitude of the loss divided by 2 + quant = quantize(x, params, qtz_func, bits, target_dtype) + dequant = dequantize(quant, params, deqtz_func, bits, x.dtype) + loss = loss_fn(x, dequant) + + base_lr = 0.1 + exponent = int(np.floor(np.log10(loss.item()))) + lr = base_lr * (10 ** (exponent // 2)) + + # Requires gradients in the parameters + for p in params.values(): + p.requires_grad = True + p.grad = None + + param_keys = list(params.keys()) + param_values = list(params.values()) + + # Defines optimizer and loss function + optimizer = torch.optim.Adam(param_values, lr=lr) + scheduler = torch.optim.lr_scheduler.LinearLR(optimizer, start_factor=1.0, end_factor=0.01, total_iters=epochs // 10) + + # Contains the best loss and the best parameters + best_loss = float("inf") + best_params = None + + # Used to stop the search early + min_delta = 1e-7 + acc_loss = [] + percent_epochs_before_stop = 0.1 + + for i in range(epochs): + optimizer.zero_grad() + + quant = quantize(x, params, qtz_func, bits, target_dtype) + dequant = dequantize(quant, params, deqtz_func, bits, x.dtype) + loss = loss_fn(x, dequant) + + if loss.isnan() or loss.isinf(): + raise Exception("Loss is NaN or Inf. 
Stopping the search.") + + loss.backward() + optimizer.step() + scheduler.step() + + acc_loss.append(loss.item()) + + # Reports loss every 10 steps + if i % 10 == 0 and do_report: + print(f"Epoch {i}: Loss {loss.item()}") + + # Optimizes the parameter search by storing the best loss and the parameters + if loss.item() < best_loss: + best_loss = loss.item() + best_params = copy.deepcopy({ + k: v for k, v in params.items() if k in param_keys + }) + + # We also stop the search if the loss has not considerably during the last 10% epochs + if early_stop: + epochs_before_stop = int(epochs * percent_epochs_before_stop) + if i > epochs_before_stop and abs(acc_loss[i - epochs_before_stop] - acc_loss[i]) < min_delta: + break + + # No longer requires gradients in the parameters + for p in best_params.values(): + p.requires_grad = False + p.grad = None + + if do_report: + return best_params, acc_loss + else: + return best_params + + +def quantize( + x: torch.Tensor, + params: Dict[str, nn.Parameter], + func: nn.Module, + bits: int, + target_dtype: torch.dtype = torch.int8 +) -> torch.Tensor: + quant_min, quant_max = get_min_max_from_bits_signed(bits) + x = x.transpose(0, 1) # Aligns shapes + x = func(x=x, **params) + x = x.transpose(0, 1) + x = torch.clamp(round_func_BPDA(x), quant_min, quant_max).to(target_dtype) + return x + + +def dequantize( + x: torch.Tensor, + params: Dict[str, nn.Parameter], + func: nn.Module, + bits: int, + out_dtype: torch.dtype +) -> torch.Tensor: + x = x.to(dtype=out_dtype) + x = x.transpose(0, 1) + x = func(x=x, **params) + x = x.transpose(0, 1) + return x + + +def round_func_BPDA(input): + # This is equivalent to replacing round function (non-differentiable) with + # an identity function (differentiable) only when backward. + forward_value = torch.round(input) + out = input.clone() + out.data = forward_value.data + return out + + +def get_min_max_from_bits_signed(bit_width: int) -> Tuple[int, int]: + return -2 ** (bit_width - 1), 2 ** (bit_width - 1) - 1 + + + +############## Numpy ############### + +def np_domain_guard( + x: np.ndarray, + min: float = None, + max: float = None, + posinf: float = None, + neginf: float = None, + nan: float = None + ) -> np.ndarray: + """Guard a tensor to a valid domain.""" + x = np.nan_to_num(x, posinf=posinf, neginf=neginf, nan=nan) + if min is not None or max is not None: + x = np.clip(x, min, max) + return x + + +def np_replace_num(x: np.ndarray, num: float, to: float) -> np.ndarray: + """Replace a number in a tensor with another number. + + Args: + x (np.ndarray): The input tensor. + num (float): The number to replace. + to (float): The number to replace with. + + Returns: + np.ndarray: The tensor with the number replaced. 
+ """ + return np.where(x == num, to, x) + + +def np_guarded_power(x: np.ndarray, exp: float) -> np.ndarray: + """Guard the power operation to a valid domain.""" + return np.power(x, exp) if exp >= 1 else np.power(np.maximum(x, 0), exp) + diff --git a/fn_gen/rnd_search/14/loss.png b/fn_gen/rnd_search/14/loss.png new file mode 100644 index 0000000000000000000000000000000000000000..92e587cd36008b20f886f4de338dac8b63f712a0 Binary files /dev/null and b/fn_gen/rnd_search/14/loss.png differ diff --git a/fn_gen/rnd_search/14/quantization.png b/fn_gen/rnd_search/14/quantization.png new file mode 100644 index 0000000000000000000000000000000000000000..1eec55b0489eff49fad01e980908a1064c497cb3 Binary files /dev/null and b/fn_gen/rnd_search/14/quantization.png differ diff --git a/fn_gen/rnd_search/16/__pycache__/fn.cpython-311.pyc b/fn_gen/rnd_search/16/__pycache__/fn.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1a020a78c6ea94fbb0215ae1c23a83eae772cead Binary files /dev/null and b/fn_gen/rnd_search/16/__pycache__/fn.cpython-311.pyc differ diff --git a/fn_gen/rnd_search/16/distortion.png b/fn_gen/rnd_search/16/distortion.png new file mode 100644 index 0000000000000000000000000000000000000000..7f394e778bbd4b09164d72453881467aa46f39df Binary files /dev/null and b/fn_gen/rnd_search/16/distortion.png differ diff --git a/fn_gen/rnd_search/16/expressions.txt b/fn_gen/rnd_search/16/expressions.txt new file mode 100644 index 0000000000000000000000000000000000000000..aa32b575e8c654dbc457c94f36222e70d86dc940 --- /dev/null +++ b/fn_gen/rnd_search/16/expressions.txt @@ -0,0 +1,2 @@ +atan(_0*x)/_s +tan(_s*x)/_0 \ No newline at end of file diff --git a/fn_gen/rnd_search/16/fn.py b/fn_gen/rnd_search/16/fn.py new file mode 100644 index 0000000000000000000000000000000000000000..20171c398c6ab36038905c3322d403c62b5137f8 --- /dev/null +++ b/fn_gen/rnd_search/16/fn.py @@ -0,0 +1,554 @@ +from __future__ import annotations + +import torch +from torch import amin # Necessary for arcsin +import copy +import torch.nn as nn +import numpy as np + +from scipy.optimize import curve_fit +from typing import Dict, Any, Tuple, List, Callable + + +def quantization(x, **params): + return (torch.div(1, replace_num(params['_s'], num=0, to=10000)) * torch.atan((params['_0'] * x))) + + +def dequantization(x, **params): + return (torch.div(1, replace_num(params['_0'], num=0, to=10000)) * torch.tan(domain_guard((params['_s'] * x), posinf=1, neginf=-1, nan=0))) + + +def init_space_search( + x: torch.Tensor, + **kwargs: Dict[str, Any], + ) -> torch.Tensor: + + def _build_initial_param(tensor: torch.Tensor, max_initial: int, n_params: int): + """Generates the initial set of parameters. 
The first iteration generates 10 times more parameters.""" + for _ in range(n_params * 10): # The first iteration generates 10 times more parameters + yield init_rand(tensor) * max_initial # Generates n_params in range [-max_initial, max_initial] + + def _search_param(tensors: List[torch.tensor], n_params): + """Takes the best parameters and generates new parameters around the mean of the best parameters.""" + torch_tensors = torch.stack(tensors) + min_vals, max_vals = torch.aminmax(torch_tensors, dim=0) + abs_max_val_per_ch = torch.max(-min_vals, max_vals) + mean = torch.mean(torch_tensors, dim=0) + for _ in range(n_params): # Generates n_params around the mean of the tensors + yield torch.randn_like(min_vals) * abs_max_val_per_ch + mean + + def _calc(x, qtz_func, deqtz_func, **params): + x_ = x.transpose(0, 1) + x_ = qtz_func(x=x_, **params) + x_ = deqtz_func(x=x_, **params) + x_ = x_.transpose(0, 1) + return x_ + + assert "qtz_func" in kwargs, "qtz_func must be provided." + assert "deqtz_func" in kwargs, "deqtz_func must be provided." + assert "params_list" in kwargs, "params list must be provided." + assert "param" in kwargs, "param must be provided." + + qtz_func = kwargs.get('qtz_func') + deqtz_func = kwargs.get('deqtz_func') + params_list = kwargs.get('params_list') + param = kwargs.get('param') + + n_runs = 50 # Number of runs to try to find the best parameters + n_random_params = 50 # Number of random parameters to generate + n_best_to_pick = 5 # Number of best parameters to pick after each run + max_initial = 10000 # Maximum value to initialize the parameters + + # Initializes the parameters + base_params = { p: init_ones(x, **kwargs) for p in params_list if p != param } + params = _build_initial_param(x, max_initial, n_random_params) + + # Performs the search + for _ in range(n_runs): + + best_params = [] + for param_ in params: + try: + x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: param_}) + loss_ones = nn.MSELoss()(x, x_) + + if len(best_params) < n_best_to_pick: + best_params.append((param_, loss_ones.item())) + best_params = sorted(best_params, key=lambda x: x[1]) + elif loss_ones < best_params[-1][1]: + best_params[-1] = (param_, loss_ones.item()) + best_params = sorted(best_params, key=lambda x: x[1]) + + except Exception: # The parameters might not be valid for the function's domain + continue + + # Generates new parameters around the mean + params = _search_param([p for p, _ in best_params], n_random_params) + + return best_params[0][0] + + +def init_linear_scale( # Symmetric scale. From the study folder + x: torch.Tensor, + **kwargs: Dict[str, Any], + ) -> torch.Tensor: + assert "bits" in kwargs, "bits must be provided." + assert "params" in kwargs, "params must be provided." + assert "qtz_func" in kwargs, "qtz_func must be provided." 
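+ # Example: with bits=8, get_min_max_from_bits_signed returns (-128, 127),
+ # so each channel's absolute max is divided by (127 - (-128)) / 2 = 127.5 to obtain its scale.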
+ + bits = kwargs.get('bits') + params = kwargs.get('params') + qtz_func = kwargs.get('qtz_func') + + x_ = x.transpose(0, 1) + x_ = qtz_func(x=x_, **params, _s=init_ones(x, **kwargs)) + x_ = x_.transpose(0, 1) + + quant_min, quant_max = get_min_max_from_bits_signed(bits) + min_vals, max_vals = torch.aminmax(x_, dim=1) + min_vals = torch.min(min_vals, torch.zeros_like(min_vals)) + max_vals = torch.max(max_vals, torch.zeros_like(max_vals)) + + eps = torch.finfo(torch.float32).eps + + abs_max_val_per_ch = torch.max(-min_vals, max_vals) + scale = abs_max_val_per_ch / (float(quant_max - quant_min) / 2) + + scale = torch.clamp(scale, min=eps).to(dtype=torch.float32, device=min_vals.device) + + # Introduces some noise in scale + # If I don't introduce noise, the accuracy is going to be 0.0 and not learn anything + # scale = scale + 0.01 * torch.randn_like(scale) + return scale + + +def init_params(x: torch.Tensor, **kwargs: Dict[str, Any]) -> Dict[str, nn.Parameter]: + params = { + '_0': init_space_search(x, qtz_func=quantization, deqtz_func=dequantization, param='_0', params_list=['_0', '_s'], **kwargs), + } + params['_s'] = init_linear_scale(x, params=params, qtz_func=quantization, **kwargs) + params = {k: nn.Parameter(v, requires_grad=False) for k, v in params.items()} + + if 'post_init_hook' in kwargs: + kwargs['post_init_hook'](parameters=params) + + + if 'post_train_hook' in kwargs: + kwargs['post_train_hook'](parameters=params) + + return params + + +############### Numpy Qtz ############### + + +def np_quantization(x, _0, _s): + return (np.divide(1, np_replace_num(_s, num=0, to=10000)) * np.arctan((_0 * x))) + + +def np_dequantization(x, _0, _s): + return (np.divide(1, np_replace_num(_0, num=0, to=10000)) * np.tan(np_domain_guard((_s * x), posinf=1, neginf=-1, nan=0))) + + +def fit_func(x, _0, _s): + x_ = np_quantization(x, _0, _s) + x_ = np_dequantization(x_, _0, _s) + return x_ + + + +############### HELPERS ############### + +def domain_guard( + x: torch.Tensor, + min: float = None, + max: float = None, + posinf: float = None, + neginf: float = None, + nan: float = None + ) -> torch.Tensor: + """Guard a tensor to a valid domain.""" + x = torch.nan_to_num(x, posinf=posinf, neginf=neginf, nan=nan) + if min is not None or max is not None: + x = torch.clamp(x, min=min, max=max) + return x + + +def replace_num(x: torch.Tensor, num: float, to: float) -> torch.Tensor: + """Replace a number in a tensor with another number. + + Args: + x (torch.Tensor): The input tensor. + num (float): The number to replace. + to (float): The number to replace with. + + Returns: + torch.Tensor: The tensor with the number replaced. + """ + return torch.where(x == num, to, x) + + +def guarded_torch_power(x: torch.Tensor, exp: float) -> torch.Tensor: + """Guard the power operation to a valid domain.""" + return torch.pow(x, exp) if exp >= 1 else torch.pow(torch.relu(x), exp) + + +def init_ones(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor: + val = torch.amin(x, dim=1) + return torch.ones_like(val, dtype=torch.float32, device=x.device) + + +def init_rand(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor: + val = torch.amin(x, dim=1) + return torch.randn_like(val, dtype=torch.float32, device=x.device) + + +def init_space_search( + x: torch.Tensor, + **kwargs: Dict[str, Any], + ) -> torch.Tensor: + + def _build_initial_param(tensor: torch.Tensor, max_initial: int, n_params: int): + """Generates the initial set of parameters. 
The first iteration generates 10 times more parameters.""" + for _ in range(n_params * 10): # The first iteration generates 10 times more parameters + yield init_rand(tensor) * max_initial # Generates n_params in range [-max_initial, max_initial] + + def _search_param(tensors: List[torch.tensor], n_params): + """Takes the best parameters and generates new parameters around the mean of the best parameters.""" + torch_tensors = torch.stack(tensors) + min_vals, max_vals = torch.aminmax(torch_tensors, dim=0) + abs_max_val_per_ch = torch.max(-min_vals, max_vals) + mean = torch.mean(torch_tensors, dim=0) + for _ in range(n_params): # Generates n_params around the mean of the tensors + yield torch.randn_like(min_vals) * abs_max_val_per_ch + mean + + def _calc(x, qtz_func, deqtz_func, **params): + x_ = x.transpose(0, 1) + x_ = qtz_func(x=x_, **params) + x_ = deqtz_func(x=x_, **params) + x_ = x_.transpose(0, 1) + return x_ + + assert "qtz_func" in kwargs, "qtz_func must be provided." + assert "deqtz_func" in kwargs, "deqtz_func must be provided." + assert "params_list" in kwargs, "params list must be provided." + assert "param" in kwargs, "param must be provided." + + qtz_func = kwargs.get('qtz_func') + deqtz_func = kwargs.get('deqtz_func') + params_list = kwargs.get('params_list') + param = kwargs.get('param') + + n_runs = 50 # Number of runs to try to find the best parameters + n_random_params = 50 # Number of random parameters to generate + n_best_to_pick = 5 # Number of best parameters to pick after each run + max_initial = 10000 # Maximum value to initialize the parameters + + # Initializes the parameters + base_params = { p: init_ones(x, **kwargs) for p in params_list if p != param } + params = _build_initial_param(x, max_initial, n_random_params) + + # Performs the search + for _ in range(n_runs): + + best_params = [] + for param_ in params: + try: + x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: param_}) + loss_ones = nn.MSELoss()(x, x_) + + if len(best_params) < n_best_to_pick: + best_params.append((param_, loss_ones.item())) + best_params = sorted(best_params, key=lambda x: x[1]) + elif loss_ones < best_params[-1][1]: + best_params[-1] = (param_, loss_ones.item()) + best_params = sorted(best_params, key=lambda x: x[1]) + + except Exception: # The parameters might not be valid for the function's domain + continue + + # Generates new parameters around the mean + params = _search_param([p for p, _ in best_params], n_random_params) + + return best_params[0][0] + + +def init_linear_scale( # Symmetric scale. From the study folder + x: torch.Tensor, + **kwargs: Dict[str, Any], + ) -> torch.Tensor: + assert "bits" in kwargs, "bits must be provided." + assert "params" in kwargs, "params must be provided." + assert "qtz_func" in kwargs, "qtz_func must be provided." 
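+ # Assuming x is laid out with one channel per row (consistent with the per-channel amin/aminmax calls),
+ # the transpose lets the channel-length parameter vectors broadcast along the last dimension;
+ # transposing back and taking aminmax over dim=1 then yields one (min, max) pair per channel.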
+ + bits = kwargs.get('bits') + params = kwargs.get('params') + qtz_func = kwargs.get('qtz_func') + + x_ = x.transpose(0, 1) + x_ = qtz_func(x=x_, **params, _s=init_ones(x, **kwargs)) + x_ = x_.transpose(0, 1) + + quant_min, quant_max = get_min_max_from_bits_signed(bits) + min_vals, max_vals = torch.aminmax(x_, dim=1) + min_vals = torch.min(min_vals, torch.zeros_like(min_vals)) + max_vals = torch.max(max_vals, torch.zeros_like(max_vals)) + + eps = torch.finfo(torch.float32).eps + + abs_max_val_per_ch = torch.max(-min_vals, max_vals) + scale = abs_max_val_per_ch / (float(quant_max - quant_min) / 2) + + scale = torch.clamp(scale, min=eps).to(dtype=torch.float32, device=min_vals.device) + + # Introduces some noise in scale + # If I don't introduce noise, the accuracy is going to be 0.0 and not learn anything + # scale = scale + 0.01 * torch.randn_like(scale) + return scale + + +def init_non_linear_regression_fit( + x: torch.Tensor, + **kwargs: Dict[str, Any], + ) -> torch.Tensor: + + assert "params_list" in kwargs, "params list must be provided." + assert "np_fit_func" in kwargs, "np_fit_func must be provided." + assert "p0" in kwargs, "p0 must be provided." + np_fit_func = kwargs.get('np_fit_func') + params_list = kwargs.get('params_list') + p0 = kwargs.get('p0') + + def _fit(xdata: np.ndarray, ydata: np.ndarray, func: Callable, p0: List[float]): + popt, _ = curve_fit( + func, + xdata, + ydata, + maxfev=1000, + p0=p0, + method='lm' + ) + return popt + + # 1. Needs to convert the torch tensor to numpy tensor + xdata = x.cpu().numpy() + + # 2. Sorts the data so that it makes it easier to fit to it + sorted_xdata = np.sort(xdata, axis=-1) + + p0 = {k: v.cpu().numpy() for k, v in p0.items()} + params_list = sorted(params_list) # We need to make sure that it matches the numpy fit func arg order + + # 3. Finds the best parameters for each channel + try: + params = [] + for i in range(sorted_xdata.shape[0]): + xdata_ = sorted_xdata[i] + p0_ = [p0[p][i] for p in params_list] + ch_params = _fit(xdata_, xdata_, np_fit_func, p0_) + params.append(ch_params) + + # 4. Builds the parameters + result = {} + for i, p in enumerate(params_list): + result[p] = torch.tensor([p_[i] for p_ in params], dtype=torch.float32).to(x.device) + + return result + + except ValueError as e: + print(f"Could not fit the function with error: {e}") + print(f"Using fallback result...") + return { + k: torch.tensor(v, dtype=torch.float32).to(x.device) for k, v in p0.items() + } + + +def init_zeros(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor: + val = torch.amin(x, dim=1) + return torch.zeros_like(val, dtype=torch.float32, device=x.device) + + +def init_inner_scale(tensor: torch.Tensor, _min: float = torch.inf, _max: float = torch.inf) -> torch.Tensor: + # Calculate the original minimum and maximum values + min_vals, max_vals = torch.aminmax(tensor, dim=-1) + x_min = torch.min(min_vals, torch.zeros_like(min_vals)) + x_max = torch.max(max_vals, torch.zeros_like(max_vals)) + + if _max is torch.inf: # We do not need to scale the tensor. 
Just need to move it + return torch.ones_like(x_min) + + # Calculate the scale factor + scale = (_max - _min) / (x_max - x_min) + return scale + + + +############## Quant ############### + +@torch.enable_grad() +def learn_parameters( + x: torch.Tensor, + params: Dict[str, nn.Parameter], + qtz_func: nn.Module, + deqtz_func: nn.Module, + bits: int, + target_dtype: torch.dtype, + epochs: int = 1000, + early_stop: bool = True, + do_report: bool = False +) -> Tuple[Dict[str, nn.Parameter], torch.Tensor]: + loss_fn = nn.MSELoss() + + # Determines the initial learning rate by computing the initial loss and multiplying it by + # the order of magnitude of the loss divided by 2 + quant = quantize(x, params, qtz_func, bits, target_dtype) + dequant = dequantize(quant, params, deqtz_func, bits, x.dtype) + loss = loss_fn(x, dequant) + + base_lr = 0.1 + exponent = int(np.floor(np.log10(loss.item()))) + lr = base_lr * (10 ** (exponent // 2)) + + # Requires gradients in the parameters + for p in params.values(): + p.requires_grad = True + p.grad = None + + param_keys = list(params.keys()) + param_values = list(params.values()) + + # Defines optimizer and loss function + optimizer = torch.optim.Adam(param_values, lr=lr) + scheduler = torch.optim.lr_scheduler.LinearLR(optimizer, start_factor=1.0, end_factor=0.01, total_iters=epochs // 10) + + # Contains the best loss and the best parameters + best_loss = float("inf") + best_params = None + + # Used to stop the search early + min_delta = 1e-7 + acc_loss = [] + percent_epochs_before_stop = 0.1 + + for i in range(epochs): + optimizer.zero_grad() + + quant = quantize(x, params, qtz_func, bits, target_dtype) + dequant = dequantize(quant, params, deqtz_func, bits, x.dtype) + loss = loss_fn(x, dequant) + + if loss.isnan() or loss.isinf(): + raise Exception("Loss is NaN or Inf. 
Stopping the search.") + + loss.backward() + optimizer.step() + scheduler.step() + + acc_loss.append(loss.item()) + + # Reports loss every 10 steps + if i % 10 == 0 and do_report: + print(f"Epoch {i}: Loss {loss.item()}") + + # Optimizes the parameter search by storing the best loss and the parameters + if loss.item() < best_loss: + best_loss = loss.item() + best_params = copy.deepcopy({ + k: v for k, v in params.items() if k in param_keys + }) + + # We also stop the search if the loss has not considerably during the last 10% epochs + if early_stop: + epochs_before_stop = int(epochs * percent_epochs_before_stop) + if i > epochs_before_stop and abs(acc_loss[i - epochs_before_stop] - acc_loss[i]) < min_delta: + break + + # No longer requires gradients in the parameters + for p in best_params.values(): + p.requires_grad = False + p.grad = None + + if do_report: + return best_params, acc_loss + else: + return best_params + + +def quantize( + x: torch.Tensor, + params: Dict[str, nn.Parameter], + func: nn.Module, + bits: int, + target_dtype: torch.dtype = torch.int8 +) -> torch.Tensor: + quant_min, quant_max = get_min_max_from_bits_signed(bits) + x = x.transpose(0, 1) # Aligns shapes + x = func(x=x, **params) + x = x.transpose(0, 1) + x = torch.clamp(round_func_BPDA(x), quant_min, quant_max).to(target_dtype) + return x + + +def dequantize( + x: torch.Tensor, + params: Dict[str, nn.Parameter], + func: nn.Module, + bits: int, + out_dtype: torch.dtype +) -> torch.Tensor: + x = x.to(dtype=out_dtype) + x = x.transpose(0, 1) + x = func(x=x, **params) + x = x.transpose(0, 1) + return x + + +def round_func_BPDA(input): + # This is equivalent to replacing round function (non-differentiable) with + # an identity function (differentiable) only when backward. + forward_value = torch.round(input) + out = input.clone() + out.data = forward_value.data + return out + + +def get_min_max_from_bits_signed(bit_width: int) -> Tuple[int, int]: + return -2 ** (bit_width - 1), 2 ** (bit_width - 1) - 1 + + + +############## Numpy ############### + +def np_domain_guard( + x: np.ndarray, + min: float = None, + max: float = None, + posinf: float = None, + neginf: float = None, + nan: float = None + ) -> np.ndarray: + """Guard a tensor to a valid domain.""" + x = np.nan_to_num(x, posinf=posinf, neginf=neginf, nan=nan) + if min is not None or max is not None: + x = np.clip(x, min, max) + return x + + +def np_replace_num(x: np.ndarray, num: float, to: float) -> np.ndarray: + """Replace a number in a tensor with another number. + + Args: + x (np.ndarray): The input tensor. + num (float): The number to replace. + to (float): The number to replace with. + + Returns: + np.ndarray: The tensor with the number replaced. 
+ """ + return np.where(x == num, to, x) + + +def np_guarded_power(x: np.ndarray, exp: float) -> np.ndarray: + """Guard the power operation to a valid domain.""" + return np.power(x, exp) if exp >= 1 else np.power(np.maximum(x, 0), exp) + diff --git a/fn_gen/rnd_search/16/loss.png b/fn_gen/rnd_search/16/loss.png new file mode 100644 index 0000000000000000000000000000000000000000..a6885d1513f8da2455a42611c40c88eeeb2158b9 Binary files /dev/null and b/fn_gen/rnd_search/16/loss.png differ diff --git a/fn_gen/rnd_search/16/quantization.png b/fn_gen/rnd_search/16/quantization.png new file mode 100644 index 0000000000000000000000000000000000000000..cfb8ec9acbb2009c04c70c076a3fb5440c65746c Binary files /dev/null and b/fn_gen/rnd_search/16/quantization.png differ diff --git a/fn_gen/rnd_search/17/__pycache__/fn.cpython-311.pyc b/fn_gen/rnd_search/17/__pycache__/fn.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..942cad413a2851c41a69dcb6878397e32a863cd4 Binary files /dev/null and b/fn_gen/rnd_search/17/__pycache__/fn.cpython-311.pyc differ diff --git a/fn_gen/rnd_search/17/distortion.png b/fn_gen/rnd_search/17/distortion.png new file mode 100644 index 0000000000000000000000000000000000000000..5d9b24fc7fa3a342286b217ea711c2b8d756db41 Binary files /dev/null and b/fn_gen/rnd_search/17/distortion.png differ diff --git a/fn_gen/rnd_search/17/expressions.txt b/fn_gen/rnd_search/17/expressions.txt new file mode 100644 index 0000000000000000000000000000000000000000..ec55493201f7b2b8effaefed75e0a9258fc25c56 --- /dev/null +++ b/fn_gen/rnd_search/17/expressions.txt @@ -0,0 +1,2 @@ +tanh(_0*x)/_s +log((-_s*x - 1)/(_s*x - 1))/_0 \ No newline at end of file diff --git a/fn_gen/rnd_search/17/fn.py b/fn_gen/rnd_search/17/fn.py new file mode 100644 index 0000000000000000000000000000000000000000..53ae3822433a1c47eeefcb4de6982668ff6c4d97 --- /dev/null +++ b/fn_gen/rnd_search/17/fn.py @@ -0,0 +1,554 @@ +from __future__ import annotations + +import torch +from torch import amin # Necessary for arcsin +import copy +import torch.nn as nn +import numpy as np + +from scipy.optimize import curve_fit +from typing import Dict, Any, Tuple, List, Callable + + +def quantization(x, **params): + return (torch.div(1, replace_num(params['_s'], num=0, to=10000)) * torch.tanh((params['_0'] * x))) + + +def dequantization(x, **params): + return (torch.div(1, replace_num(params['_0'], num=0, to=10000)) * torch.log(domain_guard((torch.div(1, replace_num((torch.tensor(-1) + (params['_s'] * x)), num=0, to=10000)) * (torch.tensor(-1) + (torch.tensor(-1) * params['_s'] * x))), min=1e-5, nan=1e-5))) + + +def init_space_search( + x: torch.Tensor, + **kwargs: Dict[str, Any], + ) -> torch.Tensor: + + def _build_initial_param(tensor: torch.Tensor, max_initial: int, n_params: int): + """Generates the initial set of parameters. 
The first iteration generates 10 times more parameters.""" + for _ in range(n_params * 10): # The first iteration generates 10 times more parameters + yield init_rand(tensor) * max_initial # Generates n_params in range [-max_initial, max_initial] + + def _search_param(tensors: List[torch.tensor], n_params): + """Takes the best parameters and generates new parameters around the mean of the best parameters.""" + torch_tensors = torch.stack(tensors) + min_vals, max_vals = torch.aminmax(torch_tensors, dim=0) + abs_max_val_per_ch = torch.max(-min_vals, max_vals) + mean = torch.mean(torch_tensors, dim=0) + for _ in range(n_params): # Generates n_params around the mean of the tensors + yield torch.randn_like(min_vals) * abs_max_val_per_ch + mean + + def _calc(x, qtz_func, deqtz_func, **params): + x_ = x.transpose(0, 1) + x_ = qtz_func(x=x_, **params) + x_ = deqtz_func(x=x_, **params) + x_ = x_.transpose(0, 1) + return x_ + + assert "qtz_func" in kwargs, "qtz_func must be provided." + assert "deqtz_func" in kwargs, "deqtz_func must be provided." + assert "params_list" in kwargs, "params list must be provided." + assert "param" in kwargs, "param must be provided." + + qtz_func = kwargs.get('qtz_func') + deqtz_func = kwargs.get('deqtz_func') + params_list = kwargs.get('params_list') + param = kwargs.get('param') + + n_runs = 50 # Number of runs to try to find the best parameters + n_random_params = 50 # Number of random parameters to generate + n_best_to_pick = 5 # Number of best parameters to pick after each run + max_initial = 10000 # Maximum value to initialize the parameters + + # Initializes the parameters + base_params = { p: init_ones(x, **kwargs) for p in params_list if p != param } + params = _build_initial_param(x, max_initial, n_random_params) + + # Performs the search + for _ in range(n_runs): + + best_params = [] + for param_ in params: + try: + x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: param_}) + loss_ones = nn.MSELoss()(x, x_) + + if len(best_params) < n_best_to_pick: + best_params.append((param_, loss_ones.item())) + best_params = sorted(best_params, key=lambda x: x[1]) + elif loss_ones < best_params[-1][1]: + best_params[-1] = (param_, loss_ones.item()) + best_params = sorted(best_params, key=lambda x: x[1]) + + except Exception: # The parameters might not be valid for the function's domain + continue + + # Generates new parameters around the mean + params = _search_param([p for p, _ in best_params], n_random_params) + + return best_params[0][0] + + +def init_linear_scale( # Symmetric scale. From the study folder + x: torch.Tensor, + **kwargs: Dict[str, Any], + ) -> torch.Tensor: + assert "bits" in kwargs, "bits must be provided." + assert "params" in kwargs, "params must be provided." + assert "qtz_func" in kwargs, "qtz_func must be provided." 
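+ # init_ones supplies a neutral per-channel _s of 1.0 here, so only the search-initialized _0
+ # shapes the quantized distribution that the scale is fitted to.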
+ + bits = kwargs.get('bits') + params = kwargs.get('params') + qtz_func = kwargs.get('qtz_func') + + x_ = x.transpose(0, 1) + x_ = qtz_func(x=x_, **params, _s=init_ones(x, **kwargs)) + x_ = x_.transpose(0, 1) + + quant_min, quant_max = get_min_max_from_bits_signed(bits) + min_vals, max_vals = torch.aminmax(x_, dim=1) + min_vals = torch.min(min_vals, torch.zeros_like(min_vals)) + max_vals = torch.max(max_vals, torch.zeros_like(max_vals)) + + eps = torch.finfo(torch.float32).eps + + abs_max_val_per_ch = torch.max(-min_vals, max_vals) + scale = abs_max_val_per_ch / (float(quant_max - quant_min) / 2) + + scale = torch.clamp(scale, min=eps).to(dtype=torch.float32, device=min_vals.device) + + # Introduces some noise in scale + # If I don't introduce noise, the accuracy is going to be 0.0 and not learn anything + # scale = scale + 0.01 * torch.randn_like(scale) + return scale + + +def init_params(x: torch.Tensor, **kwargs: Dict[str, Any]) -> Dict[str, nn.Parameter]: + params = { + '_0': init_space_search(x, qtz_func=quantization, deqtz_func=dequantization, param='_0', params_list=['_0', '_s'], **kwargs), + } + params['_s'] = init_linear_scale(x, params=params, qtz_func=quantization, **kwargs) + params = {k: nn.Parameter(v, requires_grad=False) for k, v in params.items()} + + if 'post_init_hook' in kwargs: + kwargs['post_init_hook'](parameters=params) + + + if 'post_train_hook' in kwargs: + kwargs['post_train_hook'](parameters=params) + + return params + + +############### Numpy Qtz ############### + + +def np_quantization(x, _0, _s): + return (np.divide(1, np_replace_num(_s, num=0, to=10000)) * np.tanh((_0 * x))) + + +def np_dequantization(x, _0, _s): + return (np.divide(1, np_replace_num(_0, num=0, to=10000)) * np.log(np_domain_guard((np.divide(1, np_replace_num((np.array(-1) + (_s * x)), num=0, to=10000)) * (np.array(-1) + (np.array(-1) * _s * x))), min=1e-5, nan=1e-5))) + + +def fit_func(x, _0, _s): + x_ = np_quantization(x, _0, _s) + x_ = np_dequantization(x_, _0, _s) + return x_ + + + +############### HELPERS ############### + +def domain_guard( + x: torch.Tensor, + min: float = None, + max: float = None, + posinf: float = None, + neginf: float = None, + nan: float = None + ) -> torch.Tensor: + """Guard a tensor to a valid domain.""" + x = torch.nan_to_num(x, posinf=posinf, neginf=neginf, nan=nan) + if min is not None or max is not None: + x = torch.clamp(x, min=min, max=max) + return x + + +def replace_num(x: torch.Tensor, num: float, to: float) -> torch.Tensor: + """Replace a number in a tensor with another number. + + Args: + x (torch.Tensor): The input tensor. + num (float): The number to replace. + to (float): The number to replace with. + + Returns: + torch.Tensor: The tensor with the number replaced. 
+ """ + return torch.where(x == num, to, x) + + +def guarded_torch_power(x: torch.Tensor, exp: float) -> torch.Tensor: + """Guard the power operation to a valid domain.""" + return torch.pow(x, exp) if exp >= 1 else torch.pow(torch.relu(x), exp) + + +def init_ones(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor: + val = torch.amin(x, dim=1) + return torch.ones_like(val, dtype=torch.float32, device=x.device) + + +def init_rand(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor: + val = torch.amin(x, dim=1) + return torch.randn_like(val, dtype=torch.float32, device=x.device) + + +def init_space_search( + x: torch.Tensor, + **kwargs: Dict[str, Any], + ) -> torch.Tensor: + + def _build_initial_param(tensor: torch.Tensor, max_initial: int, n_params: int): + """Generates the initial set of parameters. The first iteration generates 10 times more parameters.""" + for _ in range(n_params * 10): # The first iteration generates 10 times more parameters + yield init_rand(tensor) * max_initial # Generates n_params in range [-max_initial, max_initial] + + def _search_param(tensors: List[torch.tensor], n_params): + """Takes the best parameters and generates new parameters around the mean of the best parameters.""" + torch_tensors = torch.stack(tensors) + min_vals, max_vals = torch.aminmax(torch_tensors, dim=0) + abs_max_val_per_ch = torch.max(-min_vals, max_vals) + mean = torch.mean(torch_tensors, dim=0) + for _ in range(n_params): # Generates n_params around the mean of the tensors + yield torch.randn_like(min_vals) * abs_max_val_per_ch + mean + + def _calc(x, qtz_func, deqtz_func, **params): + x_ = x.transpose(0, 1) + x_ = qtz_func(x=x_, **params) + x_ = deqtz_func(x=x_, **params) + x_ = x_.transpose(0, 1) + return x_ + + assert "qtz_func" in kwargs, "qtz_func must be provided." + assert "deqtz_func" in kwargs, "deqtz_func must be provided." + assert "params_list" in kwargs, "params list must be provided." + assert "param" in kwargs, "param must be provided." + + qtz_func = kwargs.get('qtz_func') + deqtz_func = kwargs.get('deqtz_func') + params_list = kwargs.get('params_list') + param = kwargs.get('param') + + n_runs = 50 # Number of runs to try to find the best parameters + n_random_params = 50 # Number of random parameters to generate + n_best_to_pick = 5 # Number of best parameters to pick after each run + max_initial = 10000 # Maximum value to initialize the parameters + + # Initializes the parameters + base_params = { p: init_ones(x, **kwargs) for p in params_list if p != param } + params = _build_initial_param(x, max_initial, n_random_params) + + # Performs the search + for _ in range(n_runs): + + best_params = [] + for param_ in params: + try: + x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: param_}) + loss_ones = nn.MSELoss()(x, x_) + + if len(best_params) < n_best_to_pick: + best_params.append((param_, loss_ones.item())) + best_params = sorted(best_params, key=lambda x: x[1]) + elif loss_ones < best_params[-1][1]: + best_params[-1] = (param_, loss_ones.item()) + best_params = sorted(best_params, key=lambda x: x[1]) + + except Exception: # The parameters might not be valid for the function's domain + continue + + # Generates new parameters around the mean + params = _search_param([p for p, _ in best_params], n_random_params) + + return best_params[0][0] + + +def init_linear_scale( # Symmetric scale. From the study folder + x: torch.Tensor, + **kwargs: Dict[str, Any], + ) -> torch.Tensor: + assert "bits" in kwargs, "bits must be provided." 
+ assert "params" in kwargs, "params must be provided." + assert "qtz_func" in kwargs, "qtz_func must be provided." + + bits = kwargs.get('bits') + params = kwargs.get('params') + qtz_func = kwargs.get('qtz_func') + + x_ = x.transpose(0, 1) + x_ = qtz_func(x=x_, **params, _s=init_ones(x, **kwargs)) + x_ = x_.transpose(0, 1) + + quant_min, quant_max = get_min_max_from_bits_signed(bits) + min_vals, max_vals = torch.aminmax(x_, dim=1) + min_vals = torch.min(min_vals, torch.zeros_like(min_vals)) + max_vals = torch.max(max_vals, torch.zeros_like(max_vals)) + + eps = torch.finfo(torch.float32).eps + + abs_max_val_per_ch = torch.max(-min_vals, max_vals) + scale = abs_max_val_per_ch / (float(quant_max - quant_min) / 2) + + scale = torch.clamp(scale, min=eps).to(dtype=torch.float32, device=min_vals.device) + + # Introduces some noise in scale + # If I don't introduce noise, the accuracy is going to be 0.0 and not learn anything + # scale = scale + 0.01 * torch.randn_like(scale) + return scale + + +def init_non_linear_regression_fit( + x: torch.Tensor, + **kwargs: Dict[str, Any], + ) -> torch.Tensor: + + assert "params_list" in kwargs, "params list must be provided." + assert "np_fit_func" in kwargs, "np_fit_func must be provided." + assert "p0" in kwargs, "p0 must be provided." + np_fit_func = kwargs.get('np_fit_func') + params_list = kwargs.get('params_list') + p0 = kwargs.get('p0') + + def _fit(xdata: np.ndarray, ydata: np.ndarray, func: Callable, p0: List[float]): + popt, _ = curve_fit( + func, + xdata, + ydata, + maxfev=1000, + p0=p0, + method='lm' + ) + return popt + + # 1. Needs to convert the torch tensor to numpy tensor + xdata = x.cpu().numpy() + + # 2. Sorts the data so that it makes it easier to fit to it + sorted_xdata = np.sort(xdata, axis=-1) + + p0 = {k: v.cpu().numpy() for k, v in p0.items()} + params_list = sorted(params_list) # We need to make sure that it matches the numpy fit func arg order + + # 3. Finds the best parameters for each channel + try: + params = [] + for i in range(sorted_xdata.shape[0]): + xdata_ = sorted_xdata[i] + p0_ = [p0[p][i] for p in params_list] + ch_params = _fit(xdata_, xdata_, np_fit_func, p0_) + params.append(ch_params) + + # 4. Builds the parameters + result = {} + for i, p in enumerate(params_list): + result[p] = torch.tensor([p_[i] for p_ in params], dtype=torch.float32).to(x.device) + + return result + + except ValueError as e: + print(f"Could not fit the function with error: {e}") + print(f"Using fallback result...") + return { + k: torch.tensor(v, dtype=torch.float32).to(x.device) for k, v in p0.items() + } + + +def init_zeros(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor: + val = torch.amin(x, dim=1) + return torch.zeros_like(val, dtype=torch.float32, device=x.device) + + +def init_inner_scale(tensor: torch.Tensor, _min: float = torch.inf, _max: float = torch.inf) -> torch.Tensor: + # Calculate the original minimum and maximum values + min_vals, max_vals = torch.aminmax(tensor, dim=-1) + x_min = torch.min(min_vals, torch.zeros_like(min_vals)) + x_max = torch.max(max_vals, torch.zeros_like(max_vals)) + + if _max is torch.inf: # We do not need to scale the tensor. 
Just need to move it + return torch.ones_like(x_min) + + # Calculate the scale factor + scale = (_max - _min) / (x_max - x_min) + return scale + + + +############## Quant ############### + +@torch.enable_grad() +def learn_parameters( + x: torch.Tensor, + params: Dict[str, nn.Parameter], + qtz_func: nn.Module, + deqtz_func: nn.Module, + bits: int, + target_dtype: torch.dtype, + epochs: int = 1000, + early_stop: bool = True, + do_report: bool = False +) -> Tuple[Dict[str, nn.Parameter], torch.Tensor]: + loss_fn = nn.MSELoss() + + # Determines the initial learning rate by computing the initial loss and multiplying it by + # the order of magnitude of the loss divided by 2 + quant = quantize(x, params, qtz_func, bits, target_dtype) + dequant = dequantize(quant, params, deqtz_func, bits, x.dtype) + loss = loss_fn(x, dequant) + + base_lr = 0.1 + exponent = int(np.floor(np.log10(loss.item()))) + lr = base_lr * (10 ** (exponent // 2)) + + # Requires gradients in the parameters + for p in params.values(): + p.requires_grad = True + p.grad = None + + param_keys = list(params.keys()) + param_values = list(params.values()) + + # Defines optimizer and loss function + optimizer = torch.optim.Adam(param_values, lr=lr) + scheduler = torch.optim.lr_scheduler.LinearLR(optimizer, start_factor=1.0, end_factor=0.01, total_iters=epochs // 10) + + # Contains the best loss and the best parameters + best_loss = float("inf") + best_params = None + + # Used to stop the search early + min_delta = 1e-7 + acc_loss = [] + percent_epochs_before_stop = 0.1 + + for i in range(epochs): + optimizer.zero_grad() + + quant = quantize(x, params, qtz_func, bits, target_dtype) + dequant = dequantize(quant, params, deqtz_func, bits, x.dtype) + loss = loss_fn(x, dequant) + + if loss.isnan() or loss.isinf(): + raise Exception("Loss is NaN or Inf. 
Stopping the search.") + + loss.backward() + optimizer.step() + scheduler.step() + + acc_loss.append(loss.item()) + + # Reports loss every 10 steps + if i % 10 == 0 and do_report: + print(f"Epoch {i}: Loss {loss.item()}") + + # Optimizes the parameter search by storing the best loss and the parameters + if loss.item() < best_loss: + best_loss = loss.item() + best_params = copy.deepcopy({ + k: v for k, v in params.items() if k in param_keys + }) + + # We also stop the search if the loss has not considerably during the last 10% epochs + if early_stop: + epochs_before_stop = int(epochs * percent_epochs_before_stop) + if i > epochs_before_stop and abs(acc_loss[i - epochs_before_stop] - acc_loss[i]) < min_delta: + break + + # No longer requires gradients in the parameters + for p in best_params.values(): + p.requires_grad = False + p.grad = None + + if do_report: + return best_params, acc_loss + else: + return best_params + + +def quantize( + x: torch.Tensor, + params: Dict[str, nn.Parameter], + func: nn.Module, + bits: int, + target_dtype: torch.dtype = torch.int8 +) -> torch.Tensor: + quant_min, quant_max = get_min_max_from_bits_signed(bits) + x = x.transpose(0, 1) # Aligns shapes + x = func(x=x, **params) + x = x.transpose(0, 1) + x = torch.clamp(round_func_BPDA(x), quant_min, quant_max).to(target_dtype) + return x + + +def dequantize( + x: torch.Tensor, + params: Dict[str, nn.Parameter], + func: nn.Module, + bits: int, + out_dtype: torch.dtype +) -> torch.Tensor: + x = x.to(dtype=out_dtype) + x = x.transpose(0, 1) + x = func(x=x, **params) + x = x.transpose(0, 1) + return x + + +def round_func_BPDA(input): + # This is equivalent to replacing round function (non-differentiable) with + # an identity function (differentiable) only when backward. + forward_value = torch.round(input) + out = input.clone() + out.data = forward_value.data + return out + + +def get_min_max_from_bits_signed(bit_width: int) -> Tuple[int, int]: + return -2 ** (bit_width - 1), 2 ** (bit_width - 1) - 1 + + + +############## Numpy ############### + +def np_domain_guard( + x: np.ndarray, + min: float = None, + max: float = None, + posinf: float = None, + neginf: float = None, + nan: float = None + ) -> np.ndarray: + """Guard a tensor to a valid domain.""" + x = np.nan_to_num(x, posinf=posinf, neginf=neginf, nan=nan) + if min is not None or max is not None: + x = np.clip(x, min, max) + return x + + +def np_replace_num(x: np.ndarray, num: float, to: float) -> np.ndarray: + """Replace a number in a tensor with another number. + + Args: + x (np.ndarray): The input tensor. + num (float): The number to replace. + to (float): The number to replace with. + + Returns: + np.ndarray: The tensor with the number replaced. 
+ """ + return np.where(x == num, to, x) + + +def np_guarded_power(x: np.ndarray, exp: float) -> np.ndarray: + """Guard the power operation to a valid domain.""" + return np.power(x, exp) if exp >= 1 else np.power(np.maximum(x, 0), exp) + diff --git a/fn_gen/rnd_search/17/loss.png b/fn_gen/rnd_search/17/loss.png new file mode 100644 index 0000000000000000000000000000000000000000..223d0c35ca8ce9c2d4eb6613d7f56f26ab21c18a Binary files /dev/null and b/fn_gen/rnd_search/17/loss.png differ diff --git a/fn_gen/rnd_search/17/quantization.png b/fn_gen/rnd_search/17/quantization.png new file mode 100644 index 0000000000000000000000000000000000000000..a87fe57e778e15be656948732f32bf4f3809bd88 Binary files /dev/null and b/fn_gen/rnd_search/17/quantization.png differ diff --git a/fn_gen/rnd_search/18/__pycache__/fn.cpython-311.pyc b/fn_gen/rnd_search/18/__pycache__/fn.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..507d1df1926c272590ffde3225351d4fc29df35c Binary files /dev/null and b/fn_gen/rnd_search/18/__pycache__/fn.cpython-311.pyc differ diff --git a/fn_gen/rnd_search/18/distortion.png b/fn_gen/rnd_search/18/distortion.png new file mode 100644 index 0000000000000000000000000000000000000000..79d8dc0923be28432e2f14c52257077e6cb0e60e Binary files /dev/null and b/fn_gen/rnd_search/18/distortion.png differ diff --git a/fn_gen/rnd_search/18/expressions.txt b/fn_gen/rnd_search/18/expressions.txt new file mode 100644 index 0000000000000000000000000000000000000000..03413827fa8f4c8ad49a40b543460cf31d1ce803 --- /dev/null +++ b/fn_gen/rnd_search/18/expressions.txt @@ -0,0 +1,2 @@ +asin(_0*x)/_s +sin(_s*x)/_0 \ No newline at end of file diff --git a/fn_gen/rnd_search/18/fn.py b/fn_gen/rnd_search/18/fn.py new file mode 100644 index 0000000000000000000000000000000000000000..6545d1e976b683b7cfcd8897b198f29f3dd14805 --- /dev/null +++ b/fn_gen/rnd_search/18/fn.py @@ -0,0 +1,554 @@ +from __future__ import annotations + +import torch +from torch import amin # Necessary for arcsin +import copy +import torch.nn as nn +import numpy as np + +from scipy.optimize import curve_fit +from typing import Dict, Any, Tuple, List, Callable + + +def quantization(x, **params): + return (torch.div(1, replace_num(params['_s'], num=0, to=10000)) * torch.asin(domain_guard((params['_0'] * x), min=-0.99999, max=0.99999, nan=0))) + + +def dequantization(x, **params): + return (torch.div(1, replace_num(params['_0'], num=0, to=10000)) * torch.sin((params['_s'] * x))) + + +def init_space_search( + x: torch.Tensor, + **kwargs: Dict[str, Any], + ) -> torch.Tensor: + + def _build_initial_param(tensor: torch.Tensor, max_initial: int, n_params: int): + """Generates the initial set of parameters. 
The first iteration generates 10 times more parameters.""" + for _ in range(n_params * 10): # The first iteration generates 10 times more parameters + yield init_rand(tensor) * max_initial # Generates n_params in range [-max_initial, max_initial] + + def _search_param(tensors: List[torch.tensor], n_params): + """Takes the best parameters and generates new parameters around the mean of the best parameters.""" + torch_tensors = torch.stack(tensors) + min_vals, max_vals = torch.aminmax(torch_tensors, dim=0) + abs_max_val_per_ch = torch.max(-min_vals, max_vals) + mean = torch.mean(torch_tensors, dim=0) + for _ in range(n_params): # Generates n_params around the mean of the tensors + yield torch.randn_like(min_vals) * abs_max_val_per_ch + mean + + def _calc(x, qtz_func, deqtz_func, **params): + x_ = x.transpose(0, 1) + x_ = qtz_func(x=x_, **params) + x_ = deqtz_func(x=x_, **params) + x_ = x_.transpose(0, 1) + return x_ + + assert "qtz_func" in kwargs, "qtz_func must be provided." + assert "deqtz_func" in kwargs, "deqtz_func must be provided." + assert "params_list" in kwargs, "params list must be provided." + assert "param" in kwargs, "param must be provided." + + qtz_func = kwargs.get('qtz_func') + deqtz_func = kwargs.get('deqtz_func') + params_list = kwargs.get('params_list') + param = kwargs.get('param') + + n_runs = 50 # Number of runs to try to find the best parameters + n_random_params = 50 # Number of random parameters to generate + n_best_to_pick = 5 # Number of best parameters to pick after each run + max_initial = 10000 # Maximum value to initialize the parameters + + # Initializes the parameters + base_params = { p: init_ones(x, **kwargs) for p in params_list if p != param } + params = _build_initial_param(x, max_initial, n_random_params) + + # Performs the search + for _ in range(n_runs): + + best_params = [] + for param_ in params: + try: + x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: param_}) + loss_ones = nn.MSELoss()(x, x_) + + if len(best_params) < n_best_to_pick: + best_params.append((param_, loss_ones.item())) + best_params = sorted(best_params, key=lambda x: x[1]) + elif loss_ones < best_params[-1][1]: + best_params[-1] = (param_, loss_ones.item()) + best_params = sorted(best_params, key=lambda x: x[1]) + + except Exception: # The parameters might not be valid for the function's domain + continue + + # Generates new parameters around the mean + params = _search_param([p for p, _ in best_params], n_random_params) + + return best_params[0][0] + + +def init_linear_scale( # Symmetric scale. From the study folder + x: torch.Tensor, + **kwargs: Dict[str, Any], + ) -> torch.Tensor: + assert "bits" in kwargs, "bits must be provided." + assert "params" in kwargs, "params must be provided." + assert "qtz_func" in kwargs, "qtz_func must be provided." 
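+    # The quantization function is first evaluated with the scale `_s` temporarily fixed
+    # to ones (init_ones), so the unscaled per-channel output range can be measured
+    # before the actual scale is derived from it.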
+ + bits = kwargs.get('bits') + params = kwargs.get('params') + qtz_func = kwargs.get('qtz_func') + + x_ = x.transpose(0, 1) + x_ = qtz_func(x=x_, **params, _s=init_ones(x, **kwargs)) + x_ = x_.transpose(0, 1) + + quant_min, quant_max = get_min_max_from_bits_signed(bits) + min_vals, max_vals = torch.aminmax(x_, dim=1) + min_vals = torch.min(min_vals, torch.zeros_like(min_vals)) + max_vals = torch.max(max_vals, torch.zeros_like(max_vals)) + + eps = torch.finfo(torch.float32).eps + + abs_max_val_per_ch = torch.max(-min_vals, max_vals) + scale = abs_max_val_per_ch / (float(quant_max - quant_min) / 2) + + scale = torch.clamp(scale, min=eps).to(dtype=torch.float32, device=min_vals.device) + + # Introduces some noise in scale + # If I don't introduce noise, the accuracy is going to be 0.0 and not learn anything + # scale = scale + 0.01 * torch.randn_like(scale) + return scale + + +def init_params(x: torch.Tensor, **kwargs: Dict[str, Any]) -> Dict[str, nn.Parameter]: + params = { + '_0': init_space_search(x, qtz_func=quantization, deqtz_func=dequantization, param='_0', params_list=['_0', '_s'], **kwargs), + } + params['_s'] = init_linear_scale(x, params=params, qtz_func=quantization, **kwargs) + params = {k: nn.Parameter(v, requires_grad=False) for k, v in params.items()} + + if 'post_init_hook' in kwargs: + kwargs['post_init_hook'](parameters=params) + + + if 'post_train_hook' in kwargs: + kwargs['post_train_hook'](parameters=params) + + return params + + +############### Numpy Qtz ############### + + +def np_quantization(x, _0, _s): + return (np.divide(1, np_replace_num(_s, num=0, to=10000)) * np.arcsin(np_domain_guard((_0 * x), min=-0.99999, max=0.99999, nan=0))) + + +def np_dequantization(x, _0, _s): + return (np.divide(1, np_replace_num(_0, num=0, to=10000)) * np.sin((_s * x))) + + +def fit_func(x, _0, _s): + x_ = np_quantization(x, _0, _s) + x_ = np_dequantization(x_, _0, _s) + return x_ + + + +############### HELPERS ############### + +def domain_guard( + x: torch.Tensor, + min: float = None, + max: float = None, + posinf: float = None, + neginf: float = None, + nan: float = None + ) -> torch.Tensor: + """Guard a tensor to a valid domain.""" + x = torch.nan_to_num(x, posinf=posinf, neginf=neginf, nan=nan) + if min is not None or max is not None: + x = torch.clamp(x, min=min, max=max) + return x + + +def replace_num(x: torch.Tensor, num: float, to: float) -> torch.Tensor: + """Replace a number in a tensor with another number. + + Args: + x (torch.Tensor): The input tensor. + num (float): The number to replace. + to (float): The number to replace with. + + Returns: + torch.Tensor: The tensor with the number replaced. + """ + return torch.where(x == num, to, x) + + +def guarded_torch_power(x: torch.Tensor, exp: float) -> torch.Tensor: + """Guard the power operation to a valid domain.""" + return torch.pow(x, exp) if exp >= 1 else torch.pow(torch.relu(x), exp) + + +def init_ones(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor: + val = torch.amin(x, dim=1) + return torch.ones_like(val, dtype=torch.float32, device=x.device) + + +def init_rand(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor: + val = torch.amin(x, dim=1) + return torch.randn_like(val, dtype=torch.float32, device=x.device) + + +def init_space_search( + x: torch.Tensor, + **kwargs: Dict[str, Any], + ) -> torch.Tensor: + + def _build_initial_param(tensor: torch.Tensor, max_initial: int, n_params: int): + """Generates the initial set of parameters. 
The first iteration generates 10 times more parameters.""" + for _ in range(n_params * 10): # The first iteration generates 10 times more parameters + yield init_rand(tensor) * max_initial # Generates n_params in range [-max_initial, max_initial] + + def _search_param(tensors: List[torch.tensor], n_params): + """Takes the best parameters and generates new parameters around the mean of the best parameters.""" + torch_tensors = torch.stack(tensors) + min_vals, max_vals = torch.aminmax(torch_tensors, dim=0) + abs_max_val_per_ch = torch.max(-min_vals, max_vals) + mean = torch.mean(torch_tensors, dim=0) + for _ in range(n_params): # Generates n_params around the mean of the tensors + yield torch.randn_like(min_vals) * abs_max_val_per_ch + mean + + def _calc(x, qtz_func, deqtz_func, **params): + x_ = x.transpose(0, 1) + x_ = qtz_func(x=x_, **params) + x_ = deqtz_func(x=x_, **params) + x_ = x_.transpose(0, 1) + return x_ + + assert "qtz_func" in kwargs, "qtz_func must be provided." + assert "deqtz_func" in kwargs, "deqtz_func must be provided." + assert "params_list" in kwargs, "params list must be provided." + assert "param" in kwargs, "param must be provided." + + qtz_func = kwargs.get('qtz_func') + deqtz_func = kwargs.get('deqtz_func') + params_list = kwargs.get('params_list') + param = kwargs.get('param') + + n_runs = 50 # Number of runs to try to find the best parameters + n_random_params = 50 # Number of random parameters to generate + n_best_to_pick = 5 # Number of best parameters to pick after each run + max_initial = 10000 # Maximum value to initialize the parameters + + # Initializes the parameters + base_params = { p: init_ones(x, **kwargs) for p in params_list if p != param } + params = _build_initial_param(x, max_initial, n_random_params) + + # Performs the search + for _ in range(n_runs): + + best_params = [] + for param_ in params: + try: + x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: param_}) + loss_ones = nn.MSELoss()(x, x_) + + if len(best_params) < n_best_to_pick: + best_params.append((param_, loss_ones.item())) + best_params = sorted(best_params, key=lambda x: x[1]) + elif loss_ones < best_params[-1][1]: + best_params[-1] = (param_, loss_ones.item()) + best_params = sorted(best_params, key=lambda x: x[1]) + + except Exception: # The parameters might not be valid for the function's domain + continue + + # Generates new parameters around the mean + params = _search_param([p for p, _ in best_params], n_random_params) + + return best_params[0][0] + + +def init_linear_scale( # Symmetric scale. From the study folder + x: torch.Tensor, + **kwargs: Dict[str, Any], + ) -> torch.Tensor: + assert "bits" in kwargs, "bits must be provided." + assert "params" in kwargs, "params must be provided." + assert "qtz_func" in kwargs, "qtz_func must be provided." 
+ + bits = kwargs.get('bits') + params = kwargs.get('params') + qtz_func = kwargs.get('qtz_func') + + x_ = x.transpose(0, 1) + x_ = qtz_func(x=x_, **params, _s=init_ones(x, **kwargs)) + x_ = x_.transpose(0, 1) + + quant_min, quant_max = get_min_max_from_bits_signed(bits) + min_vals, max_vals = torch.aminmax(x_, dim=1) + min_vals = torch.min(min_vals, torch.zeros_like(min_vals)) + max_vals = torch.max(max_vals, torch.zeros_like(max_vals)) + + eps = torch.finfo(torch.float32).eps + + abs_max_val_per_ch = torch.max(-min_vals, max_vals) + scale = abs_max_val_per_ch / (float(quant_max - quant_min) / 2) + + scale = torch.clamp(scale, min=eps).to(dtype=torch.float32, device=min_vals.device) + + # Introduces some noise in scale + # If I don't introduce noise, the accuracy is going to be 0.0 and not learn anything + # scale = scale + 0.01 * torch.randn_like(scale) + return scale + + +def init_non_linear_regression_fit( + x: torch.Tensor, + **kwargs: Dict[str, Any], + ) -> torch.Tensor: + + assert "params_list" in kwargs, "params list must be provided." + assert "np_fit_func" in kwargs, "np_fit_func must be provided." + assert "p0" in kwargs, "p0 must be provided." + np_fit_func = kwargs.get('np_fit_func') + params_list = kwargs.get('params_list') + p0 = kwargs.get('p0') + + def _fit(xdata: np.ndarray, ydata: np.ndarray, func: Callable, p0: List[float]): + popt, _ = curve_fit( + func, + xdata, + ydata, + maxfev=1000, + p0=p0, + method='lm' + ) + return popt + + # 1. Needs to convert the torch tensor to numpy tensor + xdata = x.cpu().numpy() + + # 2. Sorts the data so that it makes it easier to fit to it + sorted_xdata = np.sort(xdata, axis=-1) + + p0 = {k: v.cpu().numpy() for k, v in p0.items()} + params_list = sorted(params_list) # We need to make sure that it matches the numpy fit func arg order + + # 3. Finds the best parameters for each channel + try: + params = [] + for i in range(sorted_xdata.shape[0]): + xdata_ = sorted_xdata[i] + p0_ = [p0[p][i] for p in params_list] + ch_params = _fit(xdata_, xdata_, np_fit_func, p0_) + params.append(ch_params) + + # 4. Builds the parameters + result = {} + for i, p in enumerate(params_list): + result[p] = torch.tensor([p_[i] for p_ in params], dtype=torch.float32).to(x.device) + + return result + + except ValueError as e: + print(f"Could not fit the function with error: {e}") + print(f"Using fallback result...") + return { + k: torch.tensor(v, dtype=torch.float32).to(x.device) for k, v in p0.items() + } + + +def init_zeros(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor: + val = torch.amin(x, dim=1) + return torch.zeros_like(val, dtype=torch.float32, device=x.device) + + +def init_inner_scale(tensor: torch.Tensor, _min: float = torch.inf, _max: float = torch.inf) -> torch.Tensor: + # Calculate the original minimum and maximum values + min_vals, max_vals = torch.aminmax(tensor, dim=-1) + x_min = torch.min(min_vals, torch.zeros_like(min_vals)) + x_max = torch.max(max_vals, torch.zeros_like(max_vals)) + + if _max is torch.inf: # We do not need to scale the tensor. 
Just need to move it + return torch.ones_like(x_min) + + # Calculate the scale factor + scale = (_max - _min) / (x_max - x_min) + return scale + + + +############## Quant ############### + +@torch.enable_grad() +def learn_parameters( + x: torch.Tensor, + params: Dict[str, nn.Parameter], + qtz_func: nn.Module, + deqtz_func: nn.Module, + bits: int, + target_dtype: torch.dtype, + epochs: int = 1000, + early_stop: bool = True, + do_report: bool = False +) -> Tuple[Dict[str, nn.Parameter], torch.Tensor]: + loss_fn = nn.MSELoss() + + # Determines the initial learning rate by computing the initial loss and multiplying it by + # the order of magnitude of the loss divided by 2 + quant = quantize(x, params, qtz_func, bits, target_dtype) + dequant = dequantize(quant, params, deqtz_func, bits, x.dtype) + loss = loss_fn(x, dequant) + + base_lr = 0.1 + exponent = int(np.floor(np.log10(loss.item()))) + lr = base_lr * (10 ** (exponent // 2)) + + # Requires gradients in the parameters + for p in params.values(): + p.requires_grad = True + p.grad = None + + param_keys = list(params.keys()) + param_values = list(params.values()) + + # Defines optimizer and loss function + optimizer = torch.optim.Adam(param_values, lr=lr) + scheduler = torch.optim.lr_scheduler.LinearLR(optimizer, start_factor=1.0, end_factor=0.01, total_iters=epochs // 10) + + # Contains the best loss and the best parameters + best_loss = float("inf") + best_params = None + + # Used to stop the search early + min_delta = 1e-7 + acc_loss = [] + percent_epochs_before_stop = 0.1 + + for i in range(epochs): + optimizer.zero_grad() + + quant = quantize(x, params, qtz_func, bits, target_dtype) + dequant = dequantize(quant, params, deqtz_func, bits, x.dtype) + loss = loss_fn(x, dequant) + + if loss.isnan() or loss.isinf(): + raise Exception("Loss is NaN or Inf. 
Stopping the search.") + + loss.backward() + optimizer.step() + scheduler.step() + + acc_loss.append(loss.item()) + + # Reports loss every 10 steps + if i % 10 == 0 and do_report: + print(f"Epoch {i}: Loss {loss.item()}") + + # Optimizes the parameter search by storing the best loss and the parameters + if loss.item() < best_loss: + best_loss = loss.item() + best_params = copy.deepcopy({ + k: v for k, v in params.items() if k in param_keys + }) + + # We also stop the search if the loss has not considerably during the last 10% epochs + if early_stop: + epochs_before_stop = int(epochs * percent_epochs_before_stop) + if i > epochs_before_stop and abs(acc_loss[i - epochs_before_stop] - acc_loss[i]) < min_delta: + break + + # No longer requires gradients in the parameters + for p in best_params.values(): + p.requires_grad = False + p.grad = None + + if do_report: + return best_params, acc_loss + else: + return best_params + + +def quantize( + x: torch.Tensor, + params: Dict[str, nn.Parameter], + func: nn.Module, + bits: int, + target_dtype: torch.dtype = torch.int8 +) -> torch.Tensor: + quant_min, quant_max = get_min_max_from_bits_signed(bits) + x = x.transpose(0, 1) # Aligns shapes + x = func(x=x, **params) + x = x.transpose(0, 1) + x = torch.clamp(round_func_BPDA(x), quant_min, quant_max).to(target_dtype) + return x + + +def dequantize( + x: torch.Tensor, + params: Dict[str, nn.Parameter], + func: nn.Module, + bits: int, + out_dtype: torch.dtype +) -> torch.Tensor: + x = x.to(dtype=out_dtype) + x = x.transpose(0, 1) + x = func(x=x, **params) + x = x.transpose(0, 1) + return x + + +def round_func_BPDA(input): + # This is equivalent to replacing round function (non-differentiable) with + # an identity function (differentiable) only when backward. + forward_value = torch.round(input) + out = input.clone() + out.data = forward_value.data + return out + + +def get_min_max_from_bits_signed(bit_width: int) -> Tuple[int, int]: + return -2 ** (bit_width - 1), 2 ** (bit_width - 1) - 1 + + + +############## Numpy ############### + +def np_domain_guard( + x: np.ndarray, + min: float = None, + max: float = None, + posinf: float = None, + neginf: float = None, + nan: float = None + ) -> np.ndarray: + """Guard a tensor to a valid domain.""" + x = np.nan_to_num(x, posinf=posinf, neginf=neginf, nan=nan) + if min is not None or max is not None: + x = np.clip(x, min, max) + return x + + +def np_replace_num(x: np.ndarray, num: float, to: float) -> np.ndarray: + """Replace a number in a tensor with another number. + + Args: + x (np.ndarray): The input tensor. + num (float): The number to replace. + to (float): The number to replace with. + + Returns: + np.ndarray: The tensor with the number replaced. 
+ """ + return np.where(x == num, to, x) + + +def np_guarded_power(x: np.ndarray, exp: float) -> np.ndarray: + """Guard the power operation to a valid domain.""" + return np.power(x, exp) if exp >= 1 else np.power(np.maximum(x, 0), exp) + diff --git a/fn_gen/rnd_search/18/loss.png b/fn_gen/rnd_search/18/loss.png new file mode 100644 index 0000000000000000000000000000000000000000..07ebf021d026228f4d4194b8a25afffd5ee305cc Binary files /dev/null and b/fn_gen/rnd_search/18/loss.png differ diff --git a/fn_gen/rnd_search/18/quantization.png b/fn_gen/rnd_search/18/quantization.png new file mode 100644 index 0000000000000000000000000000000000000000..2477f2b07af3858ba9262bcb904ec1fac5c1cf90 Binary files /dev/null and b/fn_gen/rnd_search/18/quantization.png differ diff --git a/fn_gen/rnd_search/2/__pycache__/fn.cpython-311.pyc b/fn_gen/rnd_search/2/__pycache__/fn.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2490373cdd8aebdaeb933adbb33127e499782478 Binary files /dev/null and b/fn_gen/rnd_search/2/__pycache__/fn.cpython-311.pyc differ diff --git a/fn_gen/rnd_search/2/distortion.png b/fn_gen/rnd_search/2/distortion.png new file mode 100644 index 0000000000000000000000000000000000000000..55f649d865800d8c6b638878ea726ab6db8259d4 Binary files /dev/null and b/fn_gen/rnd_search/2/distortion.png differ diff --git a/fn_gen/rnd_search/2/expressions.txt b/fn_gen/rnd_search/2/expressions.txt new file mode 100644 index 0000000000000000000000000000000000000000..74791fc40576643d62f6366a8b4eda20eb1ad252 --- /dev/null +++ b/fn_gen/rnd_search/2/expressions.txt @@ -0,0 +1,2 @@ +x**3/_s +(_s*x)**(1/3) \ No newline at end of file diff --git a/fn_gen/rnd_search/2/fn.py b/fn_gen/rnd_search/2/fn.py new file mode 100644 index 0000000000000000000000000000000000000000..0270ee77dff2123b4a8238ad82605d175e68ccfa --- /dev/null +++ b/fn_gen/rnd_search/2/fn.py @@ -0,0 +1,483 @@ +from __future__ import annotations + +import torch +from torch import amin # Necessary for arcsin +import copy +import torch.nn as nn +import numpy as np + +from scipy.optimize import curve_fit +from typing import Dict, Any, Tuple, List, Callable + + +def quantization(x, **params): + return (torch.div(1, replace_num(params['_s'], num=0, to=10000)) * guarded_torch_power(x, torch.tensor(3))) + + +def dequantization(x, **params): + return guarded_torch_power((params['_s'] * x), 1 / 3) + + +def init_linear_scale( # Symmetric scale. From the study folder + x: torch.Tensor, + **kwargs: Dict[str, Any], + ) -> torch.Tensor: + assert "bits" in kwargs, "bits must be provided." + assert "params" in kwargs, "params must be provided." + assert "qtz_func" in kwargs, "qtz_func must be provided." 
+ + bits = kwargs.get('bits') + params = kwargs.get('params') + qtz_func = kwargs.get('qtz_func') + + x_ = x.transpose(0, 1) + x_ = qtz_func(x=x_, **params, _s=init_ones(x, **kwargs)) + x_ = x_.transpose(0, 1) + + quant_min, quant_max = get_min_max_from_bits_signed(bits) + min_vals, max_vals = torch.aminmax(x_, dim=1) + min_vals = torch.min(min_vals, torch.zeros_like(min_vals)) + max_vals = torch.max(max_vals, torch.zeros_like(max_vals)) + + eps = torch.finfo(torch.float32).eps + + abs_max_val_per_ch = torch.max(-min_vals, max_vals) + scale = abs_max_val_per_ch / (float(quant_max - quant_min) / 2) + + scale = torch.clamp(scale, min=eps).to(dtype=torch.float32, device=min_vals.device) + + # Introduces some noise in scale + # If I don't introduce noise, the accuracy is going to be 0.0 and not learn anything + # scale = scale + 0.01 * torch.randn_like(scale) + return scale + + +def init_params(x: torch.Tensor, **kwargs: Dict[str, Any]) -> Dict[str, nn.Parameter]: + params = { + } + params['_s'] = init_linear_scale(x, params=params, qtz_func=quantization, **kwargs) + params = {k: nn.Parameter(v, requires_grad=False) for k, v in params.items()} + + if 'post_init_hook' in kwargs: + kwargs['post_init_hook'](parameters=params) + + + if 'post_train_hook' in kwargs: + kwargs['post_train_hook'](parameters=params) + + return params + + +############### Numpy Qtz ############### + + +def np_quantization(x, _s): + return (np.divide(1, np_replace_num(_s, num=0, to=10000)) * np_guarded_power(x, np.array(3))) + + +def np_dequantization(x, _s): + return np_guarded_power((_s * x), 1 / 3) + + +def fit_func(x, _s): + x_ = np_quantization(x, _s) + x_ = np_dequantization(x_, _s) + return x_ + + + +############### HELPERS ############### + +def domain_guard( + x: torch.Tensor, + min: float = None, + max: float = None, + posinf: float = None, + neginf: float = None, + nan: float = None + ) -> torch.Tensor: + """Guard a tensor to a valid domain.""" + x = torch.nan_to_num(x, posinf=posinf, neginf=neginf, nan=nan) + if min is not None or max is not None: + x = torch.clamp(x, min=min, max=max) + return x + + +def replace_num(x: torch.Tensor, num: float, to: float) -> torch.Tensor: + """Replace a number in a tensor with another number. + + Args: + x (torch.Tensor): The input tensor. + num (float): The number to replace. + to (float): The number to replace with. + + Returns: + torch.Tensor: The tensor with the number replaced. + """ + return torch.where(x == num, to, x) + + +def guarded_torch_power(x: torch.Tensor, exp: float) -> torch.Tensor: + """Guard the power operation to a valid domain.""" + return torch.pow(x, exp) if exp >= 1 else torch.pow(torch.relu(x), exp) + + +def init_ones(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor: + val = torch.amin(x, dim=1) + return torch.ones_like(val, dtype=torch.float32, device=x.device) + + +def init_rand(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor: + val = torch.amin(x, dim=1) + return torch.randn_like(val, dtype=torch.float32, device=x.device) + + +def init_space_search( + x: torch.Tensor, + **kwargs: Dict[str, Any], + ) -> torch.Tensor: + + def _build_initial_param(tensor: torch.Tensor, max_initial: int, n_params: int): + """Generates the initial set of parameters. 
The first iteration generates 10 times more parameters.""" + for _ in range(n_params * 10): # The first iteration generates 10 times more parameters + yield init_rand(tensor) * max_initial # Generates n_params in range [-max_initial, max_initial] + + def _search_param(tensors: List[torch.tensor], n_params): + """Takes the best parameters and generates new parameters around the mean of the best parameters.""" + torch_tensors = torch.stack(tensors) + min_vals, max_vals = torch.aminmax(torch_tensors, dim=0) + abs_max_val_per_ch = torch.max(-min_vals, max_vals) + mean = torch.mean(torch_tensors, dim=0) + for _ in range(n_params): # Generates n_params around the mean of the tensors + yield torch.randn_like(min_vals) * abs_max_val_per_ch + mean + + def _calc(x, qtz_func, deqtz_func, **params): + x_ = x.transpose(0, 1) + x_ = qtz_func(x=x_, **params) + x_ = deqtz_func(x=x_, **params) + x_ = x_.transpose(0, 1) + return x_ + + assert "qtz_func" in kwargs, "qtz_func must be provided." + assert "deqtz_func" in kwargs, "deqtz_func must be provided." + assert "params_list" in kwargs, "params list must be provided." + assert "param" in kwargs, "param must be provided." + + qtz_func = kwargs.get('qtz_func') + deqtz_func = kwargs.get('deqtz_func') + params_list = kwargs.get('params_list') + param = kwargs.get('param') + + n_runs = 50 # Number of runs to try to find the best parameters + n_random_params = 50 # Number of random parameters to generate + n_best_to_pick = 5 # Number of best parameters to pick after each run + max_initial = 10000 # Maximum value to initialize the parameters + + # Initializes the parameters + base_params = { p: init_ones(x, **kwargs) for p in params_list if p != param } + params = _build_initial_param(x, max_initial, n_random_params) + + # Performs the search + for _ in range(n_runs): + + best_params = [] + for param_ in params: + try: + x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: param_}) + loss_ones = nn.MSELoss()(x, x_) + + if len(best_params) < n_best_to_pick: + best_params.append((param_, loss_ones.item())) + best_params = sorted(best_params, key=lambda x: x[1]) + elif loss_ones < best_params[-1][1]: + best_params[-1] = (param_, loss_ones.item()) + best_params = sorted(best_params, key=lambda x: x[1]) + + except Exception: # The parameters might not be valid for the function's domain + continue + + # Generates new parameters around the mean + params = _search_param([p for p, _ in best_params], n_random_params) + + return best_params[0][0] + + +def init_linear_scale( # Symmetric scale. From the study folder + x: torch.Tensor, + **kwargs: Dict[str, Any], + ) -> torch.Tensor: + assert "bits" in kwargs, "bits must be provided." + assert "params" in kwargs, "params must be provided." + assert "qtz_func" in kwargs, "qtz_func must be provided." 
+ + bits = kwargs.get('bits') + params = kwargs.get('params') + qtz_func = kwargs.get('qtz_func') + + x_ = x.transpose(0, 1) + x_ = qtz_func(x=x_, **params, _s=init_ones(x, **kwargs)) + x_ = x_.transpose(0, 1) + + quant_min, quant_max = get_min_max_from_bits_signed(bits) + min_vals, max_vals = torch.aminmax(x_, dim=1) + min_vals = torch.min(min_vals, torch.zeros_like(min_vals)) + max_vals = torch.max(max_vals, torch.zeros_like(max_vals)) + + eps = torch.finfo(torch.float32).eps + + abs_max_val_per_ch = torch.max(-min_vals, max_vals) + scale = abs_max_val_per_ch / (float(quant_max - quant_min) / 2) + + scale = torch.clamp(scale, min=eps).to(dtype=torch.float32, device=min_vals.device) + + # Introduces some noise in scale + # If I don't introduce noise, the accuracy is going to be 0.0 and not learn anything + # scale = scale + 0.01 * torch.randn_like(scale) + return scale + + +def init_non_linear_regression_fit( + x: torch.Tensor, + **kwargs: Dict[str, Any], + ) -> torch.Tensor: + + assert "params_list" in kwargs, "params list must be provided." + assert "np_fit_func" in kwargs, "np_fit_func must be provided." + assert "p0" in kwargs, "p0 must be provided." + np_fit_func = kwargs.get('np_fit_func') + params_list = kwargs.get('params_list') + p0 = kwargs.get('p0') + + def _fit(xdata: np.ndarray, ydata: np.ndarray, func: Callable, p0: List[float]): + popt, _ = curve_fit( + func, + xdata, + ydata, + maxfev=1000, + p0=p0, + method='lm' + ) + return popt + + # 1. Needs to convert the torch tensor to numpy tensor + xdata = x.cpu().numpy() + + # 2. Sorts the data so that it makes it easier to fit to it + sorted_xdata = np.sort(xdata, axis=-1) + + p0 = {k: v.cpu().numpy() for k, v in p0.items()} + params_list = sorted(params_list) # We need to make sure that it matches the numpy fit func arg order + + # 3. Finds the best parameters for each channel + try: + params = [] + for i in range(sorted_xdata.shape[0]): + xdata_ = sorted_xdata[i] + p0_ = [p0[p][i] for p in params_list] + ch_params = _fit(xdata_, xdata_, np_fit_func, p0_) + params.append(ch_params) + + # 4. Builds the parameters + result = {} + for i, p in enumerate(params_list): + result[p] = torch.tensor([p_[i] for p_ in params], dtype=torch.float32).to(x.device) + + return result + + except ValueError as e: + print(f"Could not fit the function with error: {e}") + print(f"Using fallback result...") + return { + k: torch.tensor(v, dtype=torch.float32).to(x.device) for k, v in p0.items() + } + + +def init_zeros(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor: + val = torch.amin(x, dim=1) + return torch.zeros_like(val, dtype=torch.float32, device=x.device) + + +def init_inner_scale(tensor: torch.Tensor, _min: float = torch.inf, _max: float = torch.inf) -> torch.Tensor: + # Calculate the original minimum and maximum values + min_vals, max_vals = torch.aminmax(tensor, dim=-1) + x_min = torch.min(min_vals, torch.zeros_like(min_vals)) + x_max = torch.max(max_vals, torch.zeros_like(max_vals)) + + if _max is torch.inf: # We do not need to scale the tensor. 
Just need to move it + return torch.ones_like(x_min) + + # Calculate the scale factor + scale = (_max - _min) / (x_max - x_min) + return scale + + + +############## Quant ############### + +@torch.enable_grad() +def learn_parameters( + x: torch.Tensor, + params: Dict[str, nn.Parameter], + qtz_func: nn.Module, + deqtz_func: nn.Module, + bits: int, + target_dtype: torch.dtype, + epochs: int = 1000, + early_stop: bool = True, + do_report: bool = False +) -> Tuple[Dict[str, nn.Parameter], torch.Tensor]: + loss_fn = nn.MSELoss() + + # Determines the initial learning rate by computing the initial loss and multiplying it by + # the order of magnitude of the loss divided by 2 + quant = quantize(x, params, qtz_func, bits, target_dtype) + dequant = dequantize(quant, params, deqtz_func, bits, x.dtype) + loss = loss_fn(x, dequant) + + base_lr = 0.1 + exponent = int(np.floor(np.log10(loss.item()))) + lr = base_lr * (10 ** (exponent // 2)) + + # Requires gradients in the parameters + for p in params.values(): + p.requires_grad = True + p.grad = None + + param_keys = list(params.keys()) + param_values = list(params.values()) + + # Defines optimizer and loss function + optimizer = torch.optim.Adam(param_values, lr=lr) + scheduler = torch.optim.lr_scheduler.LinearLR(optimizer, start_factor=1.0, end_factor=0.01, total_iters=epochs // 10) + + # Contains the best loss and the best parameters + best_loss = float("inf") + best_params = None + + # Used to stop the search early + min_delta = 1e-7 + acc_loss = [] + percent_epochs_before_stop = 0.1 + + for i in range(epochs): + optimizer.zero_grad() + + quant = quantize(x, params, qtz_func, bits, target_dtype) + dequant = dequantize(quant, params, deqtz_func, bits, x.dtype) + loss = loss_fn(x, dequant) + + if loss.isnan() or loss.isinf(): + raise Exception("Loss is NaN or Inf. 
Stopping the search.") + + loss.backward() + optimizer.step() + scheduler.step() + + acc_loss.append(loss.item()) + + # Reports loss every 10 steps + if i % 10 == 0 and do_report: + print(f"Epoch {i}: Loss {loss.item()}") + + # Optimizes the parameter search by storing the best loss and the parameters + if loss.item() < best_loss: + best_loss = loss.item() + best_params = copy.deepcopy({ + k: v for k, v in params.items() if k in param_keys + }) + + # We also stop the search if the loss has not considerably during the last 10% epochs + if early_stop: + epochs_before_stop = int(epochs * percent_epochs_before_stop) + if i > epochs_before_stop and abs(acc_loss[i - epochs_before_stop] - acc_loss[i]) < min_delta: + break + + # No longer requires gradients in the parameters + for p in best_params.values(): + p.requires_grad = False + p.grad = None + + if do_report: + return best_params, acc_loss + else: + return best_params + + +def quantize( + x: torch.Tensor, + params: Dict[str, nn.Parameter], + func: nn.Module, + bits: int, + target_dtype: torch.dtype = torch.int8 +) -> torch.Tensor: + quant_min, quant_max = get_min_max_from_bits_signed(bits) + x = x.transpose(0, 1) # Aligns shapes + x = func(x=x, **params) + x = x.transpose(0, 1) + x = torch.clamp(round_func_BPDA(x), quant_min, quant_max).to(target_dtype) + return x + + +def dequantize( + x: torch.Tensor, + params: Dict[str, nn.Parameter], + func: nn.Module, + bits: int, + out_dtype: torch.dtype +) -> torch.Tensor: + x = x.to(dtype=out_dtype) + x = x.transpose(0, 1) + x = func(x=x, **params) + x = x.transpose(0, 1) + return x + + +def round_func_BPDA(input): + # This is equivalent to replacing round function (non-differentiable) with + # an identity function (differentiable) only when backward. + forward_value = torch.round(input) + out = input.clone() + out.data = forward_value.data + return out + + +def get_min_max_from_bits_signed(bit_width: int) -> Tuple[int, int]: + return -2 ** (bit_width - 1), 2 ** (bit_width - 1) - 1 + + + +############## Numpy ############### + +def np_domain_guard( + x: np.ndarray, + min: float = None, + max: float = None, + posinf: float = None, + neginf: float = None, + nan: float = None + ) -> np.ndarray: + """Guard a tensor to a valid domain.""" + x = np.nan_to_num(x, posinf=posinf, neginf=neginf, nan=nan) + if min is not None or max is not None: + x = np.clip(x, min, max) + return x + + +def np_replace_num(x: np.ndarray, num: float, to: float) -> np.ndarray: + """Replace a number in a tensor with another number. + + Args: + x (np.ndarray): The input tensor. + num (float): The number to replace. + to (float): The number to replace with. + + Returns: + np.ndarray: The tensor with the number replaced. 
+ """ + return np.where(x == num, to, x) + + +def np_guarded_power(x: np.ndarray, exp: float) -> np.ndarray: + """Guard the power operation to a valid domain.""" + return np.power(x, exp) if exp >= 1 else np.power(np.maximum(x, 0), exp) + diff --git a/fn_gen/rnd_search/2/loss.png b/fn_gen/rnd_search/2/loss.png new file mode 100644 index 0000000000000000000000000000000000000000..55362f7b391e24b42e2ec2273da8fa6454548cd3 Binary files /dev/null and b/fn_gen/rnd_search/2/loss.png differ diff --git a/fn_gen/rnd_search/2/quantization.png b/fn_gen/rnd_search/2/quantization.png new file mode 100644 index 0000000000000000000000000000000000000000..c462294edeec80629d618b908c3ccb397aa9b238 Binary files /dev/null and b/fn_gen/rnd_search/2/quantization.png differ diff --git a/fn_gen/rnd_search/4/__pycache__/fn.cpython-311.pyc b/fn_gen/rnd_search/4/__pycache__/fn.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fcba9964463524a9f30616d282e863d63e4da69d Binary files /dev/null and b/fn_gen/rnd_search/4/__pycache__/fn.cpython-311.pyc differ diff --git a/fn_gen/rnd_search/4/distortion.png b/fn_gen/rnd_search/4/distortion.png new file mode 100644 index 0000000000000000000000000000000000000000..125e3d0823d8a3a170c88b4de02f0ddf400b13fc Binary files /dev/null and b/fn_gen/rnd_search/4/distortion.png differ diff --git a/fn_gen/rnd_search/4/expressions.txt b/fn_gen/rnd_search/4/expressions.txt new file mode 100644 index 0000000000000000000000000000000000000000..0a7e5be4566beeb4727d82f95d24241966d158dc --- /dev/null +++ b/fn_gen/rnd_search/4/expressions.txt @@ -0,0 +1,2 @@ +log(_0*x)/_s +exp(_s*x)/_0 \ No newline at end of file diff --git a/fn_gen/rnd_search/4/fn.py b/fn_gen/rnd_search/4/fn.py new file mode 100644 index 0000000000000000000000000000000000000000..2931f7b0f790699083920ddbd8c400118cc4d1c9 --- /dev/null +++ b/fn_gen/rnd_search/4/fn.py @@ -0,0 +1,554 @@ +from __future__ import annotations + +import torch +from torch import amin # Necessary for arcsin +import copy +import torch.nn as nn +import numpy as np + +from scipy.optimize import curve_fit +from typing import Dict, Any, Tuple, List, Callable + + +def quantization(x, **params): + return (torch.div(1, replace_num(params['_s'], num=0, to=10000)) * torch.log(domain_guard((params['_0'] * x), min=1e-5, nan=1e-5))) + + +def dequantization(x, **params): + return (torch.div(1, replace_num(params['_0'], num=0, to=10000)) * torch.exp((params['_s'] * x))) + + +def init_space_search( + x: torch.Tensor, + **kwargs: Dict[str, Any], + ) -> torch.Tensor: + + def _build_initial_param(tensor: torch.Tensor, max_initial: int, n_params: int): + """Generates the initial set of parameters. 
The first iteration generates 10 times more parameters.""" + for _ in range(n_params * 10): # The first iteration generates 10 times more parameters + yield init_rand(tensor) * max_initial # Generates n_params in range [-max_initial, max_initial] + + def _search_param(tensors: List[torch.tensor], n_params): + """Takes the best parameters and generates new parameters around the mean of the best parameters.""" + torch_tensors = torch.stack(tensors) + min_vals, max_vals = torch.aminmax(torch_tensors, dim=0) + abs_max_val_per_ch = torch.max(-min_vals, max_vals) + mean = torch.mean(torch_tensors, dim=0) + for _ in range(n_params): # Generates n_params around the mean of the tensors + yield torch.randn_like(min_vals) * abs_max_val_per_ch + mean + + def _calc(x, qtz_func, deqtz_func, **params): + x_ = x.transpose(0, 1) + x_ = qtz_func(x=x_, **params) + x_ = deqtz_func(x=x_, **params) + x_ = x_.transpose(0, 1) + return x_ + + assert "qtz_func" in kwargs, "qtz_func must be provided." + assert "deqtz_func" in kwargs, "deqtz_func must be provided." + assert "params_list" in kwargs, "params list must be provided." + assert "param" in kwargs, "param must be provided." + + qtz_func = kwargs.get('qtz_func') + deqtz_func = kwargs.get('deqtz_func') + params_list = kwargs.get('params_list') + param = kwargs.get('param') + + n_runs = 50 # Number of runs to try to find the best parameters + n_random_params = 50 # Number of random parameters to generate + n_best_to_pick = 5 # Number of best parameters to pick after each run + max_initial = 10000 # Maximum value to initialize the parameters + + # Initializes the parameters + base_params = { p: init_ones(x, **kwargs) for p in params_list if p != param } + params = _build_initial_param(x, max_initial, n_random_params) + + # Performs the search + for _ in range(n_runs): + + best_params = [] + for param_ in params: + try: + x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: param_}) + loss_ones = nn.MSELoss()(x, x_) + + if len(best_params) < n_best_to_pick: + best_params.append((param_, loss_ones.item())) + best_params = sorted(best_params, key=lambda x: x[1]) + elif loss_ones < best_params[-1][1]: + best_params[-1] = (param_, loss_ones.item()) + best_params = sorted(best_params, key=lambda x: x[1]) + + except Exception: # The parameters might not be valid for the function's domain + continue + + # Generates new parameters around the mean + params = _search_param([p for p, _ in best_params], n_random_params) + + return best_params[0][0] + + +def init_linear_scale( # Symmetric scale. From the study folder + x: torch.Tensor, + **kwargs: Dict[str, Any], + ) -> torch.Tensor: + assert "bits" in kwargs, "bits must be provided." + assert "params" in kwargs, "params must be provided." + assert "qtz_func" in kwargs, "qtz_func must be provided." 
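+    # The per-channel min/max measured below are clamped so that zero is always inside
+    # the range (min(min_vals, 0) and max(max_vals, 0)); the abs-max of that range then
+    # defines a scale that is symmetric around zero.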
+ + bits = kwargs.get('bits') + params = kwargs.get('params') + qtz_func = kwargs.get('qtz_func') + + x_ = x.transpose(0, 1) + x_ = qtz_func(x=x_, **params, _s=init_ones(x, **kwargs)) + x_ = x_.transpose(0, 1) + + quant_min, quant_max = get_min_max_from_bits_signed(bits) + min_vals, max_vals = torch.aminmax(x_, dim=1) + min_vals = torch.min(min_vals, torch.zeros_like(min_vals)) + max_vals = torch.max(max_vals, torch.zeros_like(max_vals)) + + eps = torch.finfo(torch.float32).eps + + abs_max_val_per_ch = torch.max(-min_vals, max_vals) + scale = abs_max_val_per_ch / (float(quant_max - quant_min) / 2) + + scale = torch.clamp(scale, min=eps).to(dtype=torch.float32, device=min_vals.device) + + # Introduces some noise in scale + # If I don't introduce noise, the accuracy is going to be 0.0 and not learn anything + # scale = scale + 0.01 * torch.randn_like(scale) + return scale + + +def init_params(x: torch.Tensor, **kwargs: Dict[str, Any]) -> Dict[str, nn.Parameter]: + params = { + '_0': init_space_search(x, qtz_func=quantization, deqtz_func=dequantization, param='_0', params_list=['_0', '_s'], **kwargs), + } + params['_s'] = init_linear_scale(x, params=params, qtz_func=quantization, **kwargs) + params = {k: nn.Parameter(v, requires_grad=False) for k, v in params.items()} + + if 'post_init_hook' in kwargs: + kwargs['post_init_hook'](parameters=params) + + + if 'post_train_hook' in kwargs: + kwargs['post_train_hook'](parameters=params) + + return params + + +############### Numpy Qtz ############### + + +def np_quantization(x, _0, _s): + return (np.divide(1, np_replace_num(_s, num=0, to=10000)) * np.log(np_domain_guard((_0 * x), min=1e-5, nan=1e-5))) + + +def np_dequantization(x, _0, _s): + return (np.divide(1, np_replace_num(_0, num=0, to=10000)) * np.exp((_s * x))) + + +def fit_func(x, _0, _s): + x_ = np_quantization(x, _0, _s) + x_ = np_dequantization(x_, _0, _s) + return x_ + + + +############### HELPERS ############### + +def domain_guard( + x: torch.Tensor, + min: float = None, + max: float = None, + posinf: float = None, + neginf: float = None, + nan: float = None + ) -> torch.Tensor: + """Guard a tensor to a valid domain.""" + x = torch.nan_to_num(x, posinf=posinf, neginf=neginf, nan=nan) + if min is not None or max is not None: + x = torch.clamp(x, min=min, max=max) + return x + + +def replace_num(x: torch.Tensor, num: float, to: float) -> torch.Tensor: + """Replace a number in a tensor with another number. + + Args: + x (torch.Tensor): The input tensor. + num (float): The number to replace. + to (float): The number to replace with. + + Returns: + torch.Tensor: The tensor with the number replaced. + """ + return torch.where(x == num, to, x) + + +def guarded_torch_power(x: torch.Tensor, exp: float) -> torch.Tensor: + """Guard the power operation to a valid domain.""" + return torch.pow(x, exp) if exp >= 1 else torch.pow(torch.relu(x), exp) + + +def init_ones(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor: + val = torch.amin(x, dim=1) + return torch.ones_like(val, dtype=torch.float32, device=x.device) + + +def init_rand(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor: + val = torch.amin(x, dim=1) + return torch.randn_like(val, dtype=torch.float32, device=x.device) + + +def init_space_search( + x: torch.Tensor, + **kwargs: Dict[str, Any], + ) -> torch.Tensor: + + def _build_initial_param(tensor: torch.Tensor, max_initial: int, n_params: int): + """Generates the initial set of parameters. 
The first iteration generates 10 times more parameters.""" + for _ in range(n_params * 10): # The first iteration generates 10 times more parameters + yield init_rand(tensor) * max_initial # Generates n_params in range [-max_initial, max_initial] + + def _search_param(tensors: List[torch.tensor], n_params): + """Takes the best parameters and generates new parameters around the mean of the best parameters.""" + torch_tensors = torch.stack(tensors) + min_vals, max_vals = torch.aminmax(torch_tensors, dim=0) + abs_max_val_per_ch = torch.max(-min_vals, max_vals) + mean = torch.mean(torch_tensors, dim=0) + for _ in range(n_params): # Generates n_params around the mean of the tensors + yield torch.randn_like(min_vals) * abs_max_val_per_ch + mean + + def _calc(x, qtz_func, deqtz_func, **params): + x_ = x.transpose(0, 1) + x_ = qtz_func(x=x_, **params) + x_ = deqtz_func(x=x_, **params) + x_ = x_.transpose(0, 1) + return x_ + + assert "qtz_func" in kwargs, "qtz_func must be provided." + assert "deqtz_func" in kwargs, "deqtz_func must be provided." + assert "params_list" in kwargs, "params list must be provided." + assert "param" in kwargs, "param must be provided." + + qtz_func = kwargs.get('qtz_func') + deqtz_func = kwargs.get('deqtz_func') + params_list = kwargs.get('params_list') + param = kwargs.get('param') + + n_runs = 50 # Number of runs to try to find the best parameters + n_random_params = 50 # Number of random parameters to generate + n_best_to_pick = 5 # Number of best parameters to pick after each run + max_initial = 10000 # Maximum value to initialize the parameters + + # Initializes the parameters + base_params = { p: init_ones(x, **kwargs) for p in params_list if p != param } + params = _build_initial_param(x, max_initial, n_random_params) + + # Performs the search + for _ in range(n_runs): + + best_params = [] + for param_ in params: + try: + x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: param_}) + loss_ones = nn.MSELoss()(x, x_) + + if len(best_params) < n_best_to_pick: + best_params.append((param_, loss_ones.item())) + best_params = sorted(best_params, key=lambda x: x[1]) + elif loss_ones < best_params[-1][1]: + best_params[-1] = (param_, loss_ones.item()) + best_params = sorted(best_params, key=lambda x: x[1]) + + except Exception: # The parameters might not be valid for the function's domain + continue + + # Generates new parameters around the mean + params = _search_param([p for p, _ in best_params], n_random_params) + + return best_params[0][0] + + +def init_linear_scale( # Symmetric scale. From the study folder + x: torch.Tensor, + **kwargs: Dict[str, Any], + ) -> torch.Tensor: + assert "bits" in kwargs, "bits must be provided." + assert "params" in kwargs, "params must be provided." + assert "qtz_func" in kwargs, "qtz_func must be provided." 
+ + bits = kwargs.get('bits') + params = kwargs.get('params') + qtz_func = kwargs.get('qtz_func') + + x_ = x.transpose(0, 1) + x_ = qtz_func(x=x_, **params, _s=init_ones(x, **kwargs)) + x_ = x_.transpose(0, 1) + + quant_min, quant_max = get_min_max_from_bits_signed(bits) + min_vals, max_vals = torch.aminmax(x_, dim=1) + min_vals = torch.min(min_vals, torch.zeros_like(min_vals)) + max_vals = torch.max(max_vals, torch.zeros_like(max_vals)) + + eps = torch.finfo(torch.float32).eps + + abs_max_val_per_ch = torch.max(-min_vals, max_vals) + scale = abs_max_val_per_ch / (float(quant_max - quant_min) / 2) + + scale = torch.clamp(scale, min=eps).to(dtype=torch.float32, device=min_vals.device) + + # Introduces some noise in scale + # If I don't introduce noise, the accuracy is going to be 0.0 and not learn anything + # scale = scale + 0.01 * torch.randn_like(scale) + return scale + + +def init_non_linear_regression_fit( + x: torch.Tensor, + **kwargs: Dict[str, Any], + ) -> torch.Tensor: + + assert "params_list" in kwargs, "params list must be provided." + assert "np_fit_func" in kwargs, "np_fit_func must be provided." + assert "p0" in kwargs, "p0 must be provided." + np_fit_func = kwargs.get('np_fit_func') + params_list = kwargs.get('params_list') + p0 = kwargs.get('p0') + + def _fit(xdata: np.ndarray, ydata: np.ndarray, func: Callable, p0: List[float]): + popt, _ = curve_fit( + func, + xdata, + ydata, + maxfev=1000, + p0=p0, + method='lm' + ) + return popt + + # 1. Needs to convert the torch tensor to numpy tensor + xdata = x.cpu().numpy() + + # 2. Sorts the data so that it makes it easier to fit to it + sorted_xdata = np.sort(xdata, axis=-1) + + p0 = {k: v.cpu().numpy() for k, v in p0.items()} + params_list = sorted(params_list) # We need to make sure that it matches the numpy fit func arg order + + # 3. Finds the best parameters for each channel + try: + params = [] + for i in range(sorted_xdata.shape[0]): + xdata_ = sorted_xdata[i] + p0_ = [p0[p][i] for p in params_list] + ch_params = _fit(xdata_, xdata_, np_fit_func, p0_) + params.append(ch_params) + + # 4. Builds the parameters + result = {} + for i, p in enumerate(params_list): + result[p] = torch.tensor([p_[i] for p_ in params], dtype=torch.float32).to(x.device) + + return result + + except ValueError as e: + print(f"Could not fit the function with error: {e}") + print(f"Using fallback result...") + return { + k: torch.tensor(v, dtype=torch.float32).to(x.device) for k, v in p0.items() + } + + +def init_zeros(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor: + val = torch.amin(x, dim=1) + return torch.zeros_like(val, dtype=torch.float32, device=x.device) + + +def init_inner_scale(tensor: torch.Tensor, _min: float = torch.inf, _max: float = torch.inf) -> torch.Tensor: + # Calculate the original minimum and maximum values + min_vals, max_vals = torch.aminmax(tensor, dim=-1) + x_min = torch.min(min_vals, torch.zeros_like(min_vals)) + x_max = torch.max(max_vals, torch.zeros_like(max_vals)) + + if _max is torch.inf: # We do not need to scale the tensor. 
Just need to move it + return torch.ones_like(x_min) + + # Calculate the scale factor + scale = (_max - _min) / (x_max - x_min) + return scale + + + +############## Quant ############### + +@torch.enable_grad() +def learn_parameters( + x: torch.Tensor, + params: Dict[str, nn.Parameter], + qtz_func: nn.Module, + deqtz_func: nn.Module, + bits: int, + target_dtype: torch.dtype, + epochs: int = 1000, + early_stop: bool = True, + do_report: bool = False +) -> Tuple[Dict[str, nn.Parameter], torch.Tensor]: + loss_fn = nn.MSELoss() + + # Determines the initial learning rate by computing the initial loss and multiplying it by + # the order of magnitude of the loss divided by 2 + quant = quantize(x, params, qtz_func, bits, target_dtype) + dequant = dequantize(quant, params, deqtz_func, bits, x.dtype) + loss = loss_fn(x, dequant) + + base_lr = 0.1 + exponent = int(np.floor(np.log10(loss.item()))) + lr = base_lr * (10 ** (exponent // 2)) + + # Requires gradients in the parameters + for p in params.values(): + p.requires_grad = True + p.grad = None + + param_keys = list(params.keys()) + param_values = list(params.values()) + + # Defines optimizer and loss function + optimizer = torch.optim.Adam(param_values, lr=lr) + scheduler = torch.optim.lr_scheduler.LinearLR(optimizer, start_factor=1.0, end_factor=0.01, total_iters=epochs // 10) + + # Contains the best loss and the best parameters + best_loss = float("inf") + best_params = None + + # Used to stop the search early + min_delta = 1e-7 + acc_loss = [] + percent_epochs_before_stop = 0.1 + + for i in range(epochs): + optimizer.zero_grad() + + quant = quantize(x, params, qtz_func, bits, target_dtype) + dequant = dequantize(quant, params, deqtz_func, bits, x.dtype) + loss = loss_fn(x, dequant) + + if loss.isnan() or loss.isinf(): + raise Exception("Loss is NaN or Inf. 
Stopping the search.") + + loss.backward() + optimizer.step() + scheduler.step() + + acc_loss.append(loss.item()) + + # Reports loss every 10 steps + if i % 10 == 0 and do_report: + print(f"Epoch {i}: Loss {loss.item()}") + + # Optimizes the parameter search by storing the best loss and the parameters + if loss.item() < best_loss: + best_loss = loss.item() + best_params = copy.deepcopy({ + k: v for k, v in params.items() if k in param_keys + }) + + # We also stop the search if the loss has not considerably during the last 10% epochs + if early_stop: + epochs_before_stop = int(epochs * percent_epochs_before_stop) + if i > epochs_before_stop and abs(acc_loss[i - epochs_before_stop] - acc_loss[i]) < min_delta: + break + + # No longer requires gradients in the parameters + for p in best_params.values(): + p.requires_grad = False + p.grad = None + + if do_report: + return best_params, acc_loss + else: + return best_params + + +def quantize( + x: torch.Tensor, + params: Dict[str, nn.Parameter], + func: nn.Module, + bits: int, + target_dtype: torch.dtype = torch.int8 +) -> torch.Tensor: + quant_min, quant_max = get_min_max_from_bits_signed(bits) + x = x.transpose(0, 1) # Aligns shapes + x = func(x=x, **params) + x = x.transpose(0, 1) + x = torch.clamp(round_func_BPDA(x), quant_min, quant_max).to(target_dtype) + return x + + +def dequantize( + x: torch.Tensor, + params: Dict[str, nn.Parameter], + func: nn.Module, + bits: int, + out_dtype: torch.dtype +) -> torch.Tensor: + x = x.to(dtype=out_dtype) + x = x.transpose(0, 1) + x = func(x=x, **params) + x = x.transpose(0, 1) + return x + + +def round_func_BPDA(input): + # This is equivalent to replacing round function (non-differentiable) with + # an identity function (differentiable) only when backward. + forward_value = torch.round(input) + out = input.clone() + out.data = forward_value.data + return out + + +def get_min_max_from_bits_signed(bit_width: int) -> Tuple[int, int]: + return -2 ** (bit_width - 1), 2 ** (bit_width - 1) - 1 + + + +############## Numpy ############### + +def np_domain_guard( + x: np.ndarray, + min: float = None, + max: float = None, + posinf: float = None, + neginf: float = None, + nan: float = None + ) -> np.ndarray: + """Guard a tensor to a valid domain.""" + x = np.nan_to_num(x, posinf=posinf, neginf=neginf, nan=nan) + if min is not None or max is not None: + x = np.clip(x, min, max) + return x + + +def np_replace_num(x: np.ndarray, num: float, to: float) -> np.ndarray: + """Replace a number in a tensor with another number. + + Args: + x (np.ndarray): The input tensor. + num (float): The number to replace. + to (float): The number to replace with. + + Returns: + np.ndarray: The tensor with the number replaced. 
+ """ + return np.where(x == num, to, x) + + +def np_guarded_power(x: np.ndarray, exp: float) -> np.ndarray: + """Guard the power operation to a valid domain.""" + return np.power(x, exp) if exp >= 1 else np.power(np.maximum(x, 0), exp) + diff --git a/fn_gen/rnd_search/4/loss.png b/fn_gen/rnd_search/4/loss.png new file mode 100644 index 0000000000000000000000000000000000000000..7a931d7f61307252402d87cc5a7bd3e8f89436cd Binary files /dev/null and b/fn_gen/rnd_search/4/loss.png differ diff --git a/fn_gen/rnd_search/4/quantization.png b/fn_gen/rnd_search/4/quantization.png new file mode 100644 index 0000000000000000000000000000000000000000..497f97b22cb44c14fba57e51307f3de05292df0a Binary files /dev/null and b/fn_gen/rnd_search/4/quantization.png differ diff --git a/fn_gen/rnd_search/5/__pycache__/fn.cpython-311.pyc b/fn_gen/rnd_search/5/__pycache__/fn.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f576b3f978d5668a675e40d8863592972ff08d70 Binary files /dev/null and b/fn_gen/rnd_search/5/__pycache__/fn.cpython-311.pyc differ diff --git a/fn_gen/rnd_search/5/distortion.png b/fn_gen/rnd_search/5/distortion.png new file mode 100644 index 0000000000000000000000000000000000000000..8294b9669ca57e9ac3f196b59416b40934a2725f Binary files /dev/null and b/fn_gen/rnd_search/5/distortion.png differ diff --git a/fn_gen/rnd_search/5/expressions.txt b/fn_gen/rnd_search/5/expressions.txt new file mode 100644 index 0000000000000000000000000000000000000000..8c0b1579c06c048d5603aa39c80e392c5906a879 --- /dev/null +++ b/fn_gen/rnd_search/5/expressions.txt @@ -0,0 +1,2 @@ +cos(_0*x)/_s +acos(_s*x)/_0 \ No newline at end of file diff --git a/fn_gen/rnd_search/5/fn.py b/fn_gen/rnd_search/5/fn.py new file mode 100644 index 0000000000000000000000000000000000000000..80d5d3ee8e59df7d248f6f1d4be038171c29392e --- /dev/null +++ b/fn_gen/rnd_search/5/fn.py @@ -0,0 +1,554 @@ +from __future__ import annotations + +import torch +from torch import amin # Necessary for arcsin +import copy +import torch.nn as nn +import numpy as np + +from scipy.optimize import curve_fit +from typing import Dict, Any, Tuple, List, Callable + + +def quantization(x, **params): + return (torch.div(1, replace_num(params['_s'], num=0, to=10000)) * torch.cos((params['_0'] * x))) + + +def dequantization(x, **params): + return (torch.div(1, replace_num(params['_0'], num=0, to=10000)) * torch.acos(domain_guard((params['_s'] * x), min=-0.99999, max=0.99999, nan=0))) + + +def init_space_search( + x: torch.Tensor, + **kwargs: Dict[str, Any], + ) -> torch.Tensor: + + def _build_initial_param(tensor: torch.Tensor, max_initial: int, n_params: int): + """Generates the initial set of parameters. 
The first iteration generates 10 times more parameters.""" + for _ in range(n_params * 10): # The first iteration generates 10 times more parameters + yield init_rand(tensor) * max_initial # Generates n_params in range [-max_initial, max_initial] + + def _search_param(tensors: List[torch.tensor], n_params): + """Takes the best parameters and generates new parameters around the mean of the best parameters.""" + torch_tensors = torch.stack(tensors) + min_vals, max_vals = torch.aminmax(torch_tensors, dim=0) + abs_max_val_per_ch = torch.max(-min_vals, max_vals) + mean = torch.mean(torch_tensors, dim=0) + for _ in range(n_params): # Generates n_params around the mean of the tensors + yield torch.randn_like(min_vals) * abs_max_val_per_ch + mean + + def _calc(x, qtz_func, deqtz_func, **params): + x_ = x.transpose(0, 1) + x_ = qtz_func(x=x_, **params) + x_ = deqtz_func(x=x_, **params) + x_ = x_.transpose(0, 1) + return x_ + + assert "qtz_func" in kwargs, "qtz_func must be provided." + assert "deqtz_func" in kwargs, "deqtz_func must be provided." + assert "params_list" in kwargs, "params list must be provided." + assert "param" in kwargs, "param must be provided." + + qtz_func = kwargs.get('qtz_func') + deqtz_func = kwargs.get('deqtz_func') + params_list = kwargs.get('params_list') + param = kwargs.get('param') + + n_runs = 50 # Number of runs to try to find the best parameters + n_random_params = 50 # Number of random parameters to generate + n_best_to_pick = 5 # Number of best parameters to pick after each run + max_initial = 10000 # Maximum value to initialize the parameters + + # Initializes the parameters + base_params = { p: init_ones(x, **kwargs) for p in params_list if p != param } + params = _build_initial_param(x, max_initial, n_random_params) + + # Performs the search + for _ in range(n_runs): + + best_params = [] + for param_ in params: + try: + x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: param_}) + loss_ones = nn.MSELoss()(x, x_) + + if len(best_params) < n_best_to_pick: + best_params.append((param_, loss_ones.item())) + best_params = sorted(best_params, key=lambda x: x[1]) + elif loss_ones < best_params[-1][1]: + best_params[-1] = (param_, loss_ones.item()) + best_params = sorted(best_params, key=lambda x: x[1]) + + except Exception: # The parameters might not be valid for the function's domain + continue + + # Generates new parameters around the mean + params = _search_param([p for p, _ in best_params], n_random_params) + + return best_params[0][0] + + +def init_linear_scale( # Symmetric scale. From the study folder + x: torch.Tensor, + **kwargs: Dict[str, Any], + ) -> torch.Tensor: + assert "bits" in kwargs, "bits must be provided." + assert "params" in kwargs, "params must be provided." + assert "qtz_func" in kwargs, "qtz_func must be provided." 
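+    # The computed scale is clamped from below to float32 eps, so a channel whose probed
+    # output collapses to zero still receives a small positive scale instead of zero.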
+ + bits = kwargs.get('bits') + params = kwargs.get('params') + qtz_func = kwargs.get('qtz_func') + + x_ = x.transpose(0, 1) + x_ = qtz_func(x=x_, **params, _s=init_ones(x, **kwargs)) + x_ = x_.transpose(0, 1) + + quant_min, quant_max = get_min_max_from_bits_signed(bits) + min_vals, max_vals = torch.aminmax(x_, dim=1) + min_vals = torch.min(min_vals, torch.zeros_like(min_vals)) + max_vals = torch.max(max_vals, torch.zeros_like(max_vals)) + + eps = torch.finfo(torch.float32).eps + + abs_max_val_per_ch = torch.max(-min_vals, max_vals) + scale = abs_max_val_per_ch / (float(quant_max - quant_min) / 2) + + scale = torch.clamp(scale, min=eps).to(dtype=torch.float32, device=min_vals.device) + + # Introduces some noise in scale + # If I don't introduce noise, the accuracy is going to be 0.0 and not learn anything + # scale = scale + 0.01 * torch.randn_like(scale) + return scale + + +def init_params(x: torch.Tensor, **kwargs: Dict[str, Any]) -> Dict[str, nn.Parameter]: + params = { + '_0': init_space_search(x, qtz_func=quantization, deqtz_func=dequantization, param='_0', params_list=['_0', '_s'], **kwargs), + } + params['_s'] = init_linear_scale(x, params=params, qtz_func=quantization, **kwargs) + params = {k: nn.Parameter(v, requires_grad=False) for k, v in params.items()} + + if 'post_init_hook' in kwargs: + kwargs['post_init_hook'](parameters=params) + + + if 'post_train_hook' in kwargs: + kwargs['post_train_hook'](parameters=params) + + return params + + +############### Numpy Qtz ############### + + +def np_quantization(x, _0, _s): + return (np.divide(1, np_replace_num(_s, num=0, to=10000)) * np.cos((_0 * x))) + + +def np_dequantization(x, _0, _s): + return (np.divide(1, np_replace_num(_0, num=0, to=10000)) * np.arccos(np_domain_guard((_s * x), min=-0.99999, max=0.99999, nan=0))) + + +def fit_func(x, _0, _s): + x_ = np_quantization(x, _0, _s) + x_ = np_dequantization(x_, _0, _s) + return x_ + + + +############### HELPERS ############### + +def domain_guard( + x: torch.Tensor, + min: float = None, + max: float = None, + posinf: float = None, + neginf: float = None, + nan: float = None + ) -> torch.Tensor: + """Guard a tensor to a valid domain.""" + x = torch.nan_to_num(x, posinf=posinf, neginf=neginf, nan=nan) + if min is not None or max is not None: + x = torch.clamp(x, min=min, max=max) + return x + + +def replace_num(x: torch.Tensor, num: float, to: float) -> torch.Tensor: + """Replace a number in a tensor with another number. + + Args: + x (torch.Tensor): The input tensor. + num (float): The number to replace. + to (float): The number to replace with. + + Returns: + torch.Tensor: The tensor with the number replaced. + """ + return torch.where(x == num, to, x) + + +def guarded_torch_power(x: torch.Tensor, exp: float) -> torch.Tensor: + """Guard the power operation to a valid domain.""" + return torch.pow(x, exp) if exp >= 1 else torch.pow(torch.relu(x), exp) + + +def init_ones(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor: + val = torch.amin(x, dim=1) + return torch.ones_like(val, dtype=torch.float32, device=x.device) + + +def init_rand(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor: + val = torch.amin(x, dim=1) + return torch.randn_like(val, dtype=torch.float32, device=x.device) + + +def init_space_search( + x: torch.Tensor, + **kwargs: Dict[str, Any], + ) -> torch.Tensor: + + def _build_initial_param(tensor: torch.Tensor, max_initial: int, n_params: int): + """Generates the initial set of parameters. 
The first iteration generates 10 times more parameters.""" + for _ in range(n_params * 10): # The first iteration generates 10 times more parameters + yield init_rand(tensor) * max_initial # Generates n_params in range [-max_initial, max_initial] + + def _search_param(tensors: List[torch.tensor], n_params): + """Takes the best parameters and generates new parameters around the mean of the best parameters.""" + torch_tensors = torch.stack(tensors) + min_vals, max_vals = torch.aminmax(torch_tensors, dim=0) + abs_max_val_per_ch = torch.max(-min_vals, max_vals) + mean = torch.mean(torch_tensors, dim=0) + for _ in range(n_params): # Generates n_params around the mean of the tensors + yield torch.randn_like(min_vals) * abs_max_val_per_ch + mean + + def _calc(x, qtz_func, deqtz_func, **params): + x_ = x.transpose(0, 1) + x_ = qtz_func(x=x_, **params) + x_ = deqtz_func(x=x_, **params) + x_ = x_.transpose(0, 1) + return x_ + + assert "qtz_func" in kwargs, "qtz_func must be provided." + assert "deqtz_func" in kwargs, "deqtz_func must be provided." + assert "params_list" in kwargs, "params list must be provided." + assert "param" in kwargs, "param must be provided." + + qtz_func = kwargs.get('qtz_func') + deqtz_func = kwargs.get('deqtz_func') + params_list = kwargs.get('params_list') + param = kwargs.get('param') + + n_runs = 50 # Number of runs to try to find the best parameters + n_random_params = 50 # Number of random parameters to generate + n_best_to_pick = 5 # Number of best parameters to pick after each run + max_initial = 10000 # Maximum value to initialize the parameters + + # Initializes the parameters + base_params = { p: init_ones(x, **kwargs) for p in params_list if p != param } + params = _build_initial_param(x, max_initial, n_random_params) + + # Performs the search + for _ in range(n_runs): + + best_params = [] + for param_ in params: + try: + x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: param_}) + loss_ones = nn.MSELoss()(x, x_) + + if len(best_params) < n_best_to_pick: + best_params.append((param_, loss_ones.item())) + best_params = sorted(best_params, key=lambda x: x[1]) + elif loss_ones < best_params[-1][1]: + best_params[-1] = (param_, loss_ones.item()) + best_params = sorted(best_params, key=lambda x: x[1]) + + except Exception: # The parameters might not be valid for the function's domain + continue + + # Generates new parameters around the mean + params = _search_param([p for p, _ in best_params], n_random_params) + + return best_params[0][0] + + +def init_linear_scale( # Symmetric scale. From the study folder + x: torch.Tensor, + **kwargs: Dict[str, Any], + ) -> torch.Tensor: + assert "bits" in kwargs, "bits must be provided." + assert "params" in kwargs, "params must be provided." + assert "qtz_func" in kwargs, "qtz_func must be provided." 
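+    # Note: this repeats the init_linear_scale defined earlier in this file. The
+    # two definitions appear identical; this later binding is the one in effect
+    # by the time init_params calls it.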
+ + bits = kwargs.get('bits') + params = kwargs.get('params') + qtz_func = kwargs.get('qtz_func') + + x_ = x.transpose(0, 1) + x_ = qtz_func(x=x_, **params, _s=init_ones(x, **kwargs)) + x_ = x_.transpose(0, 1) + + quant_min, quant_max = get_min_max_from_bits_signed(bits) + min_vals, max_vals = torch.aminmax(x_, dim=1) + min_vals = torch.min(min_vals, torch.zeros_like(min_vals)) + max_vals = torch.max(max_vals, torch.zeros_like(max_vals)) + + eps = torch.finfo(torch.float32).eps + + abs_max_val_per_ch = torch.max(-min_vals, max_vals) + scale = abs_max_val_per_ch / (float(quant_max - quant_min) / 2) + + scale = torch.clamp(scale, min=eps).to(dtype=torch.float32, device=min_vals.device) + + # Introduces some noise in scale + # If I don't introduce noise, the accuracy is going to be 0.0 and not learn anything + # scale = scale + 0.01 * torch.randn_like(scale) + return scale + + +def init_non_linear_regression_fit( + x: torch.Tensor, + **kwargs: Dict[str, Any], + ) -> torch.Tensor: + + assert "params_list" in kwargs, "params list must be provided." + assert "np_fit_func" in kwargs, "np_fit_func must be provided." + assert "p0" in kwargs, "p0 must be provided." + np_fit_func = kwargs.get('np_fit_func') + params_list = kwargs.get('params_list') + p0 = kwargs.get('p0') + + def _fit(xdata: np.ndarray, ydata: np.ndarray, func: Callable, p0: List[float]): + popt, _ = curve_fit( + func, + xdata, + ydata, + maxfev=1000, + p0=p0, + method='lm' + ) + return popt + + # 1. Needs to convert the torch tensor to numpy tensor + xdata = x.cpu().numpy() + + # 2. Sorts the data so that it makes it easier to fit to it + sorted_xdata = np.sort(xdata, axis=-1) + + p0 = {k: v.cpu().numpy() for k, v in p0.items()} + params_list = sorted(params_list) # We need to make sure that it matches the numpy fit func arg order + + # 3. Finds the best parameters for each channel + try: + params = [] + for i in range(sorted_xdata.shape[0]): + xdata_ = sorted_xdata[i] + p0_ = [p0[p][i] for p in params_list] + ch_params = _fit(xdata_, xdata_, np_fit_func, p0_) + params.append(ch_params) + + # 4. Builds the parameters + result = {} + for i, p in enumerate(params_list): + result[p] = torch.tensor([p_[i] for p_ in params], dtype=torch.float32).to(x.device) + + return result + + except ValueError as e: + print(f"Could not fit the function with error: {e}") + print(f"Using fallback result...") + return { + k: torch.tensor(v, dtype=torch.float32).to(x.device) for k, v in p0.items() + } + + +def init_zeros(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor: + val = torch.amin(x, dim=1) + return torch.zeros_like(val, dtype=torch.float32, device=x.device) + + +def init_inner_scale(tensor: torch.Tensor, _min: float = torch.inf, _max: float = torch.inf) -> torch.Tensor: + # Calculate the original minimum and maximum values + min_vals, max_vals = torch.aminmax(tensor, dim=-1) + x_min = torch.min(min_vals, torch.zeros_like(min_vals)) + x_max = torch.max(max_vals, torch.zeros_like(max_vals)) + + if _max is torch.inf: # We do not need to scale the tensor. 
Just need to move it + return torch.ones_like(x_min) + + # Calculate the scale factor + scale = (_max - _min) / (x_max - x_min) + return scale + + + +############## Quant ############### + +@torch.enable_grad() +def learn_parameters( + x: torch.Tensor, + params: Dict[str, nn.Parameter], + qtz_func: nn.Module, + deqtz_func: nn.Module, + bits: int, + target_dtype: torch.dtype, + epochs: int = 1000, + early_stop: bool = True, + do_report: bool = False +) -> Tuple[Dict[str, nn.Parameter], torch.Tensor]: + loss_fn = nn.MSELoss() + + # Determines the initial learning rate by computing the initial loss and multiplying it by + # the order of magnitude of the loss divided by 2 + quant = quantize(x, params, qtz_func, bits, target_dtype) + dequant = dequantize(quant, params, deqtz_func, bits, x.dtype) + loss = loss_fn(x, dequant) + + base_lr = 0.1 + exponent = int(np.floor(np.log10(loss.item()))) + lr = base_lr * (10 ** (exponent // 2)) + + # Requires gradients in the parameters + for p in params.values(): + p.requires_grad = True + p.grad = None + + param_keys = list(params.keys()) + param_values = list(params.values()) + + # Defines optimizer and loss function + optimizer = torch.optim.Adam(param_values, lr=lr) + scheduler = torch.optim.lr_scheduler.LinearLR(optimizer, start_factor=1.0, end_factor=0.01, total_iters=epochs // 10) + + # Contains the best loss and the best parameters + best_loss = float("inf") + best_params = None + + # Used to stop the search early + min_delta = 1e-7 + acc_loss = [] + percent_epochs_before_stop = 0.1 + + for i in range(epochs): + optimizer.zero_grad() + + quant = quantize(x, params, qtz_func, bits, target_dtype) + dequant = dequantize(quant, params, deqtz_func, bits, x.dtype) + loss = loss_fn(x, dequant) + + if loss.isnan() or loss.isinf(): + raise Exception("Loss is NaN or Inf. 
Stopping the search.")
+
+        loss.backward()
+        optimizer.step()
+        scheduler.step()
+
+        acc_loss.append(loss.item())
+
+        # Reports loss every 10 steps
+        if i % 10 == 0 and do_report:
+            print(f"Epoch {i}: Loss {loss.item()}")
+
+        # Optimizes the parameter search by storing the best loss and the parameters
+        if loss.item() < best_loss:
+            best_loss = loss.item()
+            best_params = copy.deepcopy({
+                k: v for k, v in params.items() if k in param_keys
+            })
+
+        # We also stop the search if the loss has not changed considerably during the last 10% of the epochs
+        if early_stop:
+            epochs_before_stop = int(epochs * percent_epochs_before_stop)
+            if i > epochs_before_stop and abs(acc_loss[i - epochs_before_stop] - acc_loss[i]) < min_delta:
+                break
+
+    # No longer requires gradients in the parameters
+    for p in best_params.values():
+        p.requires_grad = False
+        p.grad = None
+
+    if do_report:
+        return best_params, acc_loss
+    else:
+        return best_params
+
+
+def quantize(
+    x: torch.Tensor,
+    params: Dict[str, nn.Parameter],
+    func: nn.Module,
+    bits: int,
+    target_dtype: torch.dtype = torch.int8
+) -> torch.Tensor:
+    quant_min, quant_max = get_min_max_from_bits_signed(bits)
+    x = x.transpose(0, 1)  # Aligns shapes
+    x = func(x=x, **params)
+    x = x.transpose(0, 1)
+    x = torch.clamp(round_func_BPDA(x), quant_min, quant_max).to(target_dtype)
+    return x
+
+
+def dequantize(
+    x: torch.Tensor,
+    params: Dict[str, nn.Parameter],
+    func: nn.Module,
+    bits: int,
+    out_dtype: torch.dtype
+) -> torch.Tensor:
+    x = x.to(dtype=out_dtype)
+    x = x.transpose(0, 1)
+    x = func(x=x, **params)
+    x = x.transpose(0, 1)
+    return x
+
+
+def round_func_BPDA(input):
+    # This is equivalent to replacing the round function (non-differentiable) with
+    # an identity function (differentiable) in the backward pass only.
+    forward_value = torch.round(input)
+    out = input.clone()
+    out.data = forward_value.data
+    return out
+
+
+def get_min_max_from_bits_signed(bit_width: int) -> Tuple[int, int]:
+    return -2 ** (bit_width - 1), 2 ** (bit_width - 1) - 1
+
+
+
+############## Numpy ###############
+
+def np_domain_guard(
+        x: np.ndarray,
+        min: float = None,
+        max: float = None,
+        posinf: float = None,
+        neginf: float = None,
+        nan: float = None
+    ) -> np.ndarray:
+    """Guard a tensor to a valid domain."""
+    x = np.nan_to_num(x, posinf=posinf, neginf=neginf, nan=nan)
+    if min is not None or max is not None:
+        x = np.clip(x, min, max)
+    return x
+
+
+def np_replace_num(x: np.ndarray, num: float, to: float) -> np.ndarray:
+    """Replace a number in a tensor with another number.
+
+    Args:
+        x (np.ndarray): The input tensor.
+        num (float): The number to replace.
+        to (float): The number to replace with.
+
+    Returns:
+        np.ndarray: The tensor with the number replaced.
+ """ + return np.where(x == num, to, x) + + +def np_guarded_power(x: np.ndarray, exp: float) -> np.ndarray: + """Guard the power operation to a valid domain.""" + return np.power(x, exp) if exp >= 1 else np.power(np.maximum(x, 0), exp) + diff --git a/fn_gen/rnd_search/5/loss.png b/fn_gen/rnd_search/5/loss.png new file mode 100644 index 0000000000000000000000000000000000000000..2beb959a43cf4c7fd68594465fcab18cf71b33a5 Binary files /dev/null and b/fn_gen/rnd_search/5/loss.png differ diff --git a/fn_gen/rnd_search/5/quantization.png b/fn_gen/rnd_search/5/quantization.png new file mode 100644 index 0000000000000000000000000000000000000000..8cb5e72cc15fd2ff150a1211fdf5e29b99711322 Binary files /dev/null and b/fn_gen/rnd_search/5/quantization.png differ diff --git a/fn_gen/rnd_search/6/__pycache__/fn.cpython-311.pyc b/fn_gen/rnd_search/6/__pycache__/fn.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..47bbdaf6732d7f44302db81e3b26635166957881 Binary files /dev/null and b/fn_gen/rnd_search/6/__pycache__/fn.cpython-311.pyc differ diff --git a/fn_gen/rnd_search/6/distortion.png b/fn_gen/rnd_search/6/distortion.png new file mode 100644 index 0000000000000000000000000000000000000000..a107b42518ea5bd72604d7ab5c576706152a885f Binary files /dev/null and b/fn_gen/rnd_search/6/distortion.png differ diff --git a/fn_gen/rnd_search/6/expressions.txt b/fn_gen/rnd_search/6/expressions.txt new file mode 100644 index 0000000000000000000000000000000000000000..6d6553d091cd1d343d7aa9b52b85ef6ec88ea854 --- /dev/null +++ b/fn_gen/rnd_search/6/expressions.txt @@ -0,0 +1,2 @@ +x/_s +_s*x \ No newline at end of file diff --git a/fn_gen/rnd_search/6/fn.py b/fn_gen/rnd_search/6/fn.py new file mode 100644 index 0000000000000000000000000000000000000000..c305ce92fb2df15adfd3798830e89c3fbc044243 --- /dev/null +++ b/fn_gen/rnd_search/6/fn.py @@ -0,0 +1,483 @@ +from __future__ import annotations + +import torch +from torch import amin # Necessary for arcsin +import copy +import torch.nn as nn +import numpy as np + +from scipy.optimize import curve_fit +from typing import Dict, Any, Tuple, List, Callable + + +def quantization(x, **params): + return (x * torch.div(1, replace_num(params['_s'], num=0, to=10000))) + + +def dequantization(x, **params): + return (params['_s'] * x) + + +def init_linear_scale( # Symmetric scale. From the study folder + x: torch.Tensor, + **kwargs: Dict[str, Any], + ) -> torch.Tensor: + assert "bits" in kwargs, "bits must be provided." + assert "params" in kwargs, "params must be provided." + assert "qtz_func" in kwargs, "qtz_func must be provided." 
+ + bits = kwargs.get('bits') + params = kwargs.get('params') + qtz_func = kwargs.get('qtz_func') + + x_ = x.transpose(0, 1) + x_ = qtz_func(x=x_, **params, _s=init_ones(x, **kwargs)) + x_ = x_.transpose(0, 1) + + quant_min, quant_max = get_min_max_from_bits_signed(bits) + min_vals, max_vals = torch.aminmax(x_, dim=1) + min_vals = torch.min(min_vals, torch.zeros_like(min_vals)) + max_vals = torch.max(max_vals, torch.zeros_like(max_vals)) + + eps = torch.finfo(torch.float32).eps + + abs_max_val_per_ch = torch.max(-min_vals, max_vals) + scale = abs_max_val_per_ch / (float(quant_max - quant_min) / 2) + + scale = torch.clamp(scale, min=eps).to(dtype=torch.float32, device=min_vals.device) + + # Introduces some noise in scale + # If I don't introduce noise, the accuracy is going to be 0.0 and not learn anything + # scale = scale + 0.01 * torch.randn_like(scale) + return scale + + +def init_params(x: torch.Tensor, **kwargs: Dict[str, Any]) -> Dict[str, nn.Parameter]: + params = { + } + params['_s'] = init_linear_scale(x, params=params, qtz_func=quantization, **kwargs) + params = {k: nn.Parameter(v, requires_grad=False) for k, v in params.items()} + + if 'post_init_hook' in kwargs: + kwargs['post_init_hook'](parameters=params) + + + if 'post_train_hook' in kwargs: + kwargs['post_train_hook'](parameters=params) + + return params + + +############### Numpy Qtz ############### + + +def np_quantization(x, _s): + return (x * np.divide(1, np_replace_num(_s, num=0, to=10000))) + + +def np_dequantization(x, _s): + return (_s * x) + + +def fit_func(x, _s): + x_ = np_quantization(x, _s) + x_ = np_dequantization(x_, _s) + return x_ + + + +############### HELPERS ############### + +def domain_guard( + x: torch.Tensor, + min: float = None, + max: float = None, + posinf: float = None, + neginf: float = None, + nan: float = None + ) -> torch.Tensor: + """Guard a tensor to a valid domain.""" + x = torch.nan_to_num(x, posinf=posinf, neginf=neginf, nan=nan) + if min is not None or max is not None: + x = torch.clamp(x, min=min, max=max) + return x + + +def replace_num(x: torch.Tensor, num: float, to: float) -> torch.Tensor: + """Replace a number in a tensor with another number. + + Args: + x (torch.Tensor): The input tensor. + num (float): The number to replace. + to (float): The number to replace with. + + Returns: + torch.Tensor: The tensor with the number replaced. + """ + return torch.where(x == num, to, x) + + +def guarded_torch_power(x: torch.Tensor, exp: float) -> torch.Tensor: + """Guard the power operation to a valid domain.""" + return torch.pow(x, exp) if exp >= 1 else torch.pow(torch.relu(x), exp) + + +def init_ones(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor: + val = torch.amin(x, dim=1) + return torch.ones_like(val, dtype=torch.float32, device=x.device) + + +def init_rand(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor: + val = torch.amin(x, dim=1) + return torch.randn_like(val, dtype=torch.float32, device=x.device) + + +def init_space_search( + x: torch.Tensor, + **kwargs: Dict[str, Any], + ) -> torch.Tensor: + + def _build_initial_param(tensor: torch.Tensor, max_initial: int, n_params: int): + """Generates the initial set of parameters. 
The first iteration generates 10 times more parameters.""" + for _ in range(n_params * 10): # The first iteration generates 10 times more parameters + yield init_rand(tensor) * max_initial # Generates n_params in range [-max_initial, max_initial] + + def _search_param(tensors: List[torch.tensor], n_params): + """Takes the best parameters and generates new parameters around the mean of the best parameters.""" + torch_tensors = torch.stack(tensors) + min_vals, max_vals = torch.aminmax(torch_tensors, dim=0) + abs_max_val_per_ch = torch.max(-min_vals, max_vals) + mean = torch.mean(torch_tensors, dim=0) + for _ in range(n_params): # Generates n_params around the mean of the tensors + yield torch.randn_like(min_vals) * abs_max_val_per_ch + mean + + def _calc(x, qtz_func, deqtz_func, **params): + x_ = x.transpose(0, 1) + x_ = qtz_func(x=x_, **params) + x_ = deqtz_func(x=x_, **params) + x_ = x_.transpose(0, 1) + return x_ + + assert "qtz_func" in kwargs, "qtz_func must be provided." + assert "deqtz_func" in kwargs, "deqtz_func must be provided." + assert "params_list" in kwargs, "params list must be provided." + assert "param" in kwargs, "param must be provided." + + qtz_func = kwargs.get('qtz_func') + deqtz_func = kwargs.get('deqtz_func') + params_list = kwargs.get('params_list') + param = kwargs.get('param') + + n_runs = 50 # Number of runs to try to find the best parameters + n_random_params = 50 # Number of random parameters to generate + n_best_to_pick = 5 # Number of best parameters to pick after each run + max_initial = 10000 # Maximum value to initialize the parameters + + # Initializes the parameters + base_params = { p: init_ones(x, **kwargs) for p in params_list if p != param } + params = _build_initial_param(x, max_initial, n_random_params) + + # Performs the search + for _ in range(n_runs): + + best_params = [] + for param_ in params: + try: + x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: param_}) + loss_ones = nn.MSELoss()(x, x_) + + if len(best_params) < n_best_to_pick: + best_params.append((param_, loss_ones.item())) + best_params = sorted(best_params, key=lambda x: x[1]) + elif loss_ones < best_params[-1][1]: + best_params[-1] = (param_, loss_ones.item()) + best_params = sorted(best_params, key=lambda x: x[1]) + + except Exception: # The parameters might not be valid for the function's domain + continue + + # Generates new parameters around the mean + params = _search_param([p for p, _ in best_params], n_random_params) + + return best_params[0][0] + + +def init_linear_scale( # Symmetric scale. From the study folder + x: torch.Tensor, + **kwargs: Dict[str, Any], + ) -> torch.Tensor: + assert "bits" in kwargs, "bits must be provided." + assert "params" in kwargs, "params must be provided." + assert "qtz_func" in kwargs, "qtz_func must be provided." 
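+    # Illustrative roundtrip for this file's linear pair (quantization x / _s,
+    # dequantization _s * x), assuming bits = 8 so get_min_max_from_bits_signed
+    # gives (-128, 127) and the half-range is 127.5: a channel spanning
+    # [-1.0, 2.55] gets scale = 2.55 / 127.5 = 0.02; quantize(2.55) =
+    # round(2.55 / 0.02) = 128, clamped to 127; dequantize(127) = 127 * 0.02 = 2.54.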
+ + bits = kwargs.get('bits') + params = kwargs.get('params') + qtz_func = kwargs.get('qtz_func') + + x_ = x.transpose(0, 1) + x_ = qtz_func(x=x_, **params, _s=init_ones(x, **kwargs)) + x_ = x_.transpose(0, 1) + + quant_min, quant_max = get_min_max_from_bits_signed(bits) + min_vals, max_vals = torch.aminmax(x_, dim=1) + min_vals = torch.min(min_vals, torch.zeros_like(min_vals)) + max_vals = torch.max(max_vals, torch.zeros_like(max_vals)) + + eps = torch.finfo(torch.float32).eps + + abs_max_val_per_ch = torch.max(-min_vals, max_vals) + scale = abs_max_val_per_ch / (float(quant_max - quant_min) / 2) + + scale = torch.clamp(scale, min=eps).to(dtype=torch.float32, device=min_vals.device) + + # Introduces some noise in scale + # If I don't introduce noise, the accuracy is going to be 0.0 and not learn anything + # scale = scale + 0.01 * torch.randn_like(scale) + return scale + + +def init_non_linear_regression_fit( + x: torch.Tensor, + **kwargs: Dict[str, Any], + ) -> torch.Tensor: + + assert "params_list" in kwargs, "params list must be provided." + assert "np_fit_func" in kwargs, "np_fit_func must be provided." + assert "p0" in kwargs, "p0 must be provided." + np_fit_func = kwargs.get('np_fit_func') + params_list = kwargs.get('params_list') + p0 = kwargs.get('p0') + + def _fit(xdata: np.ndarray, ydata: np.ndarray, func: Callable, p0: List[float]): + popt, _ = curve_fit( + func, + xdata, + ydata, + maxfev=1000, + p0=p0, + method='lm' + ) + return popt + + # 1. Needs to convert the torch tensor to numpy tensor + xdata = x.cpu().numpy() + + # 2. Sorts the data so that it makes it easier to fit to it + sorted_xdata = np.sort(xdata, axis=-1) + + p0 = {k: v.cpu().numpy() for k, v in p0.items()} + params_list = sorted(params_list) # We need to make sure that it matches the numpy fit func arg order + + # 3. Finds the best parameters for each channel + try: + params = [] + for i in range(sorted_xdata.shape[0]): + xdata_ = sorted_xdata[i] + p0_ = [p0[p][i] for p in params_list] + ch_params = _fit(xdata_, xdata_, np_fit_func, p0_) + params.append(ch_params) + + # 4. Builds the parameters + result = {} + for i, p in enumerate(params_list): + result[p] = torch.tensor([p_[i] for p_ in params], dtype=torch.float32).to(x.device) + + return result + + except ValueError as e: + print(f"Could not fit the function with error: {e}") + print(f"Using fallback result...") + return { + k: torch.tensor(v, dtype=torch.float32).to(x.device) for k, v in p0.items() + } + + +def init_zeros(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor: + val = torch.amin(x, dim=1) + return torch.zeros_like(val, dtype=torch.float32, device=x.device) + + +def init_inner_scale(tensor: torch.Tensor, _min: float = torch.inf, _max: float = torch.inf) -> torch.Tensor: + # Calculate the original minimum and maximum values + min_vals, max_vals = torch.aminmax(tensor, dim=-1) + x_min = torch.min(min_vals, torch.zeros_like(min_vals)) + x_max = torch.max(max_vals, torch.zeros_like(max_vals)) + + if _max is torch.inf: # We do not need to scale the tensor. 
Just need to move it + return torch.ones_like(x_min) + + # Calculate the scale factor + scale = (_max - _min) / (x_max - x_min) + return scale + + + +############## Quant ############### + +@torch.enable_grad() +def learn_parameters( + x: torch.Tensor, + params: Dict[str, nn.Parameter], + qtz_func: nn.Module, + deqtz_func: nn.Module, + bits: int, + target_dtype: torch.dtype, + epochs: int = 1000, + early_stop: bool = True, + do_report: bool = False +) -> Tuple[Dict[str, nn.Parameter], torch.Tensor]: + loss_fn = nn.MSELoss() + + # Determines the initial learning rate by computing the initial loss and multiplying it by + # the order of magnitude of the loss divided by 2 + quant = quantize(x, params, qtz_func, bits, target_dtype) + dequant = dequantize(quant, params, deqtz_func, bits, x.dtype) + loss = loss_fn(x, dequant) + + base_lr = 0.1 + exponent = int(np.floor(np.log10(loss.item()))) + lr = base_lr * (10 ** (exponent // 2)) + + # Requires gradients in the parameters + for p in params.values(): + p.requires_grad = True + p.grad = None + + param_keys = list(params.keys()) + param_values = list(params.values()) + + # Defines optimizer and loss function + optimizer = torch.optim.Adam(param_values, lr=lr) + scheduler = torch.optim.lr_scheduler.LinearLR(optimizer, start_factor=1.0, end_factor=0.01, total_iters=epochs // 10) + + # Contains the best loss and the best parameters + best_loss = float("inf") + best_params = None + + # Used to stop the search early + min_delta = 1e-7 + acc_loss = [] + percent_epochs_before_stop = 0.1 + + for i in range(epochs): + optimizer.zero_grad() + + quant = quantize(x, params, qtz_func, bits, target_dtype) + dequant = dequantize(quant, params, deqtz_func, bits, x.dtype) + loss = loss_fn(x, dequant) + + if loss.isnan() or loss.isinf(): + raise Exception("Loss is NaN or Inf. 
Stopping the search.")
+
+        loss.backward()
+        optimizer.step()
+        scheduler.step()
+
+        acc_loss.append(loss.item())
+
+        # Reports loss every 10 steps
+        if i % 10 == 0 and do_report:
+            print(f"Epoch {i}: Loss {loss.item()}")
+
+        # Optimizes the parameter search by storing the best loss and the parameters
+        if loss.item() < best_loss:
+            best_loss = loss.item()
+            best_params = copy.deepcopy({
+                k: v for k, v in params.items() if k in param_keys
+            })
+
+        # We also stop the search if the loss has not changed considerably during the last 10% of the epochs
+        if early_stop:
+            epochs_before_stop = int(epochs * percent_epochs_before_stop)
+            if i > epochs_before_stop and abs(acc_loss[i - epochs_before_stop] - acc_loss[i]) < min_delta:
+                break
+
+    # No longer requires gradients in the parameters
+    for p in best_params.values():
+        p.requires_grad = False
+        p.grad = None
+
+    if do_report:
+        return best_params, acc_loss
+    else:
+        return best_params
+
+
+def quantize(
+    x: torch.Tensor,
+    params: Dict[str, nn.Parameter],
+    func: nn.Module,
+    bits: int,
+    target_dtype: torch.dtype = torch.int8
+) -> torch.Tensor:
+    quant_min, quant_max = get_min_max_from_bits_signed(bits)
+    x = x.transpose(0, 1)  # Aligns shapes
+    x = func(x=x, **params)
+    x = x.transpose(0, 1)
+    x = torch.clamp(round_func_BPDA(x), quant_min, quant_max).to(target_dtype)
+    return x
+
+
+def dequantize(
+    x: torch.Tensor,
+    params: Dict[str, nn.Parameter],
+    func: nn.Module,
+    bits: int,
+    out_dtype: torch.dtype
+) -> torch.Tensor:
+    x = x.to(dtype=out_dtype)
+    x = x.transpose(0, 1)
+    x = func(x=x, **params)
+    x = x.transpose(0, 1)
+    return x
+
+
+def round_func_BPDA(input):
+    # This is equivalent to replacing the round function (non-differentiable) with
+    # an identity function (differentiable) in the backward pass only.
+    forward_value = torch.round(input)
+    out = input.clone()
+    out.data = forward_value.data
+    return out
+
+
+def get_min_max_from_bits_signed(bit_width: int) -> Tuple[int, int]:
+    return -2 ** (bit_width - 1), 2 ** (bit_width - 1) - 1
+
+
+
+############## Numpy ###############
+
+def np_domain_guard(
+        x: np.ndarray,
+        min: float = None,
+        max: float = None,
+        posinf: float = None,
+        neginf: float = None,
+        nan: float = None
+    ) -> np.ndarray:
+    """Guard a tensor to a valid domain."""
+    x = np.nan_to_num(x, posinf=posinf, neginf=neginf, nan=nan)
+    if min is not None or max is not None:
+        x = np.clip(x, min, max)
+    return x
+
+
+def np_replace_num(x: np.ndarray, num: float, to: float) -> np.ndarray:
+    """Replace a number in a tensor with another number.
+
+    Args:
+        x (np.ndarray): The input tensor.
+        num (float): The number to replace.
+        to (float): The number to replace with.
+
+    Returns:
+        np.ndarray: The tensor with the number replaced.
+ """ + return np.where(x == num, to, x) + + +def np_guarded_power(x: np.ndarray, exp: float) -> np.ndarray: + """Guard the power operation to a valid domain.""" + return np.power(x, exp) if exp >= 1 else np.power(np.maximum(x, 0), exp) + diff --git a/fn_gen/rnd_search/6/loss.png b/fn_gen/rnd_search/6/loss.png new file mode 100644 index 0000000000000000000000000000000000000000..cdc7865771fde3b22dfac4186c7c496ade81ec46 Binary files /dev/null and b/fn_gen/rnd_search/6/loss.png differ diff --git a/fn_gen/rnd_search/6/quantization.png b/fn_gen/rnd_search/6/quantization.png new file mode 100644 index 0000000000000000000000000000000000000000..62c139d7154f151f266325b05429853e810fdb9c Binary files /dev/null and b/fn_gen/rnd_search/6/quantization.png differ diff --git a/fn_gen/rnd_search/7/__pycache__/fn.cpython-311.pyc b/fn_gen/rnd_search/7/__pycache__/fn.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..516ac38163827b03dc7b211ffe0466795a7908eb Binary files /dev/null and b/fn_gen/rnd_search/7/__pycache__/fn.cpython-311.pyc differ diff --git a/fn_gen/rnd_search/7/distortion.png b/fn_gen/rnd_search/7/distortion.png new file mode 100644 index 0000000000000000000000000000000000000000..b64821265cd5f5cdce7a78c9dc09531e7891fe37 Binary files /dev/null and b/fn_gen/rnd_search/7/distortion.png differ diff --git a/fn_gen/rnd_search/7/expressions.txt b/fn_gen/rnd_search/7/expressions.txt new file mode 100644 index 0000000000000000000000000000000000000000..e8458af52eb4cfce21cf8459f3c454003cd78158 --- /dev/null +++ b/fn_gen/rnd_search/7/expressions.txt @@ -0,0 +1,2 @@ +sqrt(_0*x)/_s +_s**2*x**2/_0 \ No newline at end of file diff --git a/fn_gen/rnd_search/7/fn.py b/fn_gen/rnd_search/7/fn.py new file mode 100644 index 0000000000000000000000000000000000000000..5f77e312a2be189b6db32f812ef88883d58064c6 --- /dev/null +++ b/fn_gen/rnd_search/7/fn.py @@ -0,0 +1,554 @@ +from __future__ import annotations + +import torch +from torch import amin # Necessary for arcsin +import copy +import torch.nn as nn +import numpy as np + +from scipy.optimize import curve_fit +from typing import Dict, Any, Tuple, List, Callable + + +def quantization(x, **params): + return (torch.div(1, replace_num(params['_s'], num=0, to=10000)) * torch.sqrt(domain_guard((params['_0'] * x), min=0.1, nan=0.1))) + + +def dequantization(x, **params): + return (torch.div(1, replace_num(params['_0'], num=0, to=10000)) * guarded_torch_power(params['_s'], torch.tensor(2)) * guarded_torch_power(x, torch.tensor(2))) + + +def init_space_search( + x: torch.Tensor, + **kwargs: Dict[str, Any], + ) -> torch.Tensor: + + def _build_initial_param(tensor: torch.Tensor, max_initial: int, n_params: int): + """Generates the initial set of parameters. 
The first iteration generates 10 times more parameters.""" + for _ in range(n_params * 10): # The first iteration generates 10 times more parameters + yield init_rand(tensor) * max_initial # Generates n_params in range [-max_initial, max_initial] + + def _search_param(tensors: List[torch.tensor], n_params): + """Takes the best parameters and generates new parameters around the mean of the best parameters.""" + torch_tensors = torch.stack(tensors) + min_vals, max_vals = torch.aminmax(torch_tensors, dim=0) + abs_max_val_per_ch = torch.max(-min_vals, max_vals) + mean = torch.mean(torch_tensors, dim=0) + for _ in range(n_params): # Generates n_params around the mean of the tensors + yield torch.randn_like(min_vals) * abs_max_val_per_ch + mean + + def _calc(x, qtz_func, deqtz_func, **params): + x_ = x.transpose(0, 1) + x_ = qtz_func(x=x_, **params) + x_ = deqtz_func(x=x_, **params) + x_ = x_.transpose(0, 1) + return x_ + + assert "qtz_func" in kwargs, "qtz_func must be provided." + assert "deqtz_func" in kwargs, "deqtz_func must be provided." + assert "params_list" in kwargs, "params list must be provided." + assert "param" in kwargs, "param must be provided." + + qtz_func = kwargs.get('qtz_func') + deqtz_func = kwargs.get('deqtz_func') + params_list = kwargs.get('params_list') + param = kwargs.get('param') + + n_runs = 50 # Number of runs to try to find the best parameters + n_random_params = 50 # Number of random parameters to generate + n_best_to_pick = 5 # Number of best parameters to pick after each run + max_initial = 10000 # Maximum value to initialize the parameters + + # Initializes the parameters + base_params = { p: init_ones(x, **kwargs) for p in params_list if p != param } + params = _build_initial_param(x, max_initial, n_random_params) + + # Performs the search + for _ in range(n_runs): + + best_params = [] + for param_ in params: + try: + x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: param_}) + loss_ones = nn.MSELoss()(x, x_) + + if len(best_params) < n_best_to_pick: + best_params.append((param_, loss_ones.item())) + best_params = sorted(best_params, key=lambda x: x[1]) + elif loss_ones < best_params[-1][1]: + best_params[-1] = (param_, loss_ones.item()) + best_params = sorted(best_params, key=lambda x: x[1]) + + except Exception: # The parameters might not be valid for the function's domain + continue + + # Generates new parameters around the mean + params = _search_param([p for p, _ in best_params], n_random_params) + + return best_params[0][0] + + +def init_linear_scale( # Symmetric scale. From the study folder + x: torch.Tensor, + **kwargs: Dict[str, Any], + ) -> torch.Tensor: + assert "bits" in kwargs, "bits must be provided." + assert "params" in kwargs, "params must be provided." + assert "qtz_func" in kwargs, "qtz_func must be provided." 
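+    # Shape convention (descriptive note): x appears to be laid out as
+    # (channels, elements). transpose(0, 1) puts channels last so the per-channel
+    # parameters broadcast over the trailing dimension inside qtz_func; after
+    # transposing back, aminmax(..., dim=1) reduces over the elements of each
+    # channel.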
+ + bits = kwargs.get('bits') + params = kwargs.get('params') + qtz_func = kwargs.get('qtz_func') + + x_ = x.transpose(0, 1) + x_ = qtz_func(x=x_, **params, _s=init_ones(x, **kwargs)) + x_ = x_.transpose(0, 1) + + quant_min, quant_max = get_min_max_from_bits_signed(bits) + min_vals, max_vals = torch.aminmax(x_, dim=1) + min_vals = torch.min(min_vals, torch.zeros_like(min_vals)) + max_vals = torch.max(max_vals, torch.zeros_like(max_vals)) + + eps = torch.finfo(torch.float32).eps + + abs_max_val_per_ch = torch.max(-min_vals, max_vals) + scale = abs_max_val_per_ch / (float(quant_max - quant_min) / 2) + + scale = torch.clamp(scale, min=eps).to(dtype=torch.float32, device=min_vals.device) + + # Introduces some noise in scale + # If I don't introduce noise, the accuracy is going to be 0.0 and not learn anything + # scale = scale + 0.01 * torch.randn_like(scale) + return scale + + +def init_params(x: torch.Tensor, **kwargs: Dict[str, Any]) -> Dict[str, nn.Parameter]: + params = { + '_0': init_space_search(x, qtz_func=quantization, deqtz_func=dequantization, param='_0', params_list=['_0', '_s'], **kwargs), + } + params['_s'] = init_linear_scale(x, params=params, qtz_func=quantization, **kwargs) + params = {k: nn.Parameter(v, requires_grad=False) for k, v in params.items()} + + if 'post_init_hook' in kwargs: + kwargs['post_init_hook'](parameters=params) + + + if 'post_train_hook' in kwargs: + kwargs['post_train_hook'](parameters=params) + + return params + + +############### Numpy Qtz ############### + + +def np_quantization(x, _0, _s): + return (np.divide(1, np_replace_num(_s, num=0, to=10000)) * np.sqrt(np_domain_guard((_0 * x), min=0.1, nan=0.1))) + + +def np_dequantization(x, _0, _s): + return (np.divide(1, np_replace_num(_0, num=0, to=10000)) * np_guarded_power(_s, np.array(2)) * np_guarded_power(x, np.array(2))) + + +def fit_func(x, _0, _s): + x_ = np_quantization(x, _0, _s) + x_ = np_dequantization(x_, _0, _s) + return x_ + + + +############### HELPERS ############### + +def domain_guard( + x: torch.Tensor, + min: float = None, + max: float = None, + posinf: float = None, + neginf: float = None, + nan: float = None + ) -> torch.Tensor: + """Guard a tensor to a valid domain.""" + x = torch.nan_to_num(x, posinf=posinf, neginf=neginf, nan=nan) + if min is not None or max is not None: + x = torch.clamp(x, min=min, max=max) + return x + + +def replace_num(x: torch.Tensor, num: float, to: float) -> torch.Tensor: + """Replace a number in a tensor with another number. + + Args: + x (torch.Tensor): The input tensor. + num (float): The number to replace. + to (float): The number to replace with. + + Returns: + torch.Tensor: The tensor with the number replaced. + """ + return torch.where(x == num, to, x) + + +def guarded_torch_power(x: torch.Tensor, exp: float) -> torch.Tensor: + """Guard the power operation to a valid domain.""" + return torch.pow(x, exp) if exp >= 1 else torch.pow(torch.relu(x), exp) + + +def init_ones(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor: + val = torch.amin(x, dim=1) + return torch.ones_like(val, dtype=torch.float32, device=x.device) + + +def init_rand(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor: + val = torch.amin(x, dim=1) + return torch.randn_like(val, dtype=torch.float32, device=x.device) + + +def init_space_search( + x: torch.Tensor, + **kwargs: Dict[str, Any], + ) -> torch.Tensor: + + def _build_initial_param(tensor: torch.Tensor, max_initial: int, n_params: int): + """Generates the initial set of parameters. 
The first iteration generates 10 times more parameters.""" + for _ in range(n_params * 10): # The first iteration generates 10 times more parameters + yield init_rand(tensor) * max_initial # Generates n_params in range [-max_initial, max_initial] + + def _search_param(tensors: List[torch.tensor], n_params): + """Takes the best parameters and generates new parameters around the mean of the best parameters.""" + torch_tensors = torch.stack(tensors) + min_vals, max_vals = torch.aminmax(torch_tensors, dim=0) + abs_max_val_per_ch = torch.max(-min_vals, max_vals) + mean = torch.mean(torch_tensors, dim=0) + for _ in range(n_params): # Generates n_params around the mean of the tensors + yield torch.randn_like(min_vals) * abs_max_val_per_ch + mean + + def _calc(x, qtz_func, deqtz_func, **params): + x_ = x.transpose(0, 1) + x_ = qtz_func(x=x_, **params) + x_ = deqtz_func(x=x_, **params) + x_ = x_.transpose(0, 1) + return x_ + + assert "qtz_func" in kwargs, "qtz_func must be provided." + assert "deqtz_func" in kwargs, "deqtz_func must be provided." + assert "params_list" in kwargs, "params list must be provided." + assert "param" in kwargs, "param must be provided." + + qtz_func = kwargs.get('qtz_func') + deqtz_func = kwargs.get('deqtz_func') + params_list = kwargs.get('params_list') + param = kwargs.get('param') + + n_runs = 50 # Number of runs to try to find the best parameters + n_random_params = 50 # Number of random parameters to generate + n_best_to_pick = 5 # Number of best parameters to pick after each run + max_initial = 10000 # Maximum value to initialize the parameters + + # Initializes the parameters + base_params = { p: init_ones(x, **kwargs) for p in params_list if p != param } + params = _build_initial_param(x, max_initial, n_random_params) + + # Performs the search + for _ in range(n_runs): + + best_params = [] + for param_ in params: + try: + x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: param_}) + loss_ones = nn.MSELoss()(x, x_) + + if len(best_params) < n_best_to_pick: + best_params.append((param_, loss_ones.item())) + best_params = sorted(best_params, key=lambda x: x[1]) + elif loss_ones < best_params[-1][1]: + best_params[-1] = (param_, loss_ones.item()) + best_params = sorted(best_params, key=lambda x: x[1]) + + except Exception: # The parameters might not be valid for the function's domain + continue + + # Generates new parameters around the mean + params = _search_param([p for p, _ in best_params], n_random_params) + + return best_params[0][0] + + +def init_linear_scale( # Symmetric scale. From the study folder + x: torch.Tensor, + **kwargs: Dict[str, Any], + ) -> torch.Tensor: + assert "bits" in kwargs, "bits must be provided." + assert "params" in kwargs, "params must be provided." + assert "qtz_func" in kwargs, "qtz_func must be provided." 
+ + bits = kwargs.get('bits') + params = kwargs.get('params') + qtz_func = kwargs.get('qtz_func') + + x_ = x.transpose(0, 1) + x_ = qtz_func(x=x_, **params, _s=init_ones(x, **kwargs)) + x_ = x_.transpose(0, 1) + + quant_min, quant_max = get_min_max_from_bits_signed(bits) + min_vals, max_vals = torch.aminmax(x_, dim=1) + min_vals = torch.min(min_vals, torch.zeros_like(min_vals)) + max_vals = torch.max(max_vals, torch.zeros_like(max_vals)) + + eps = torch.finfo(torch.float32).eps + + abs_max_val_per_ch = torch.max(-min_vals, max_vals) + scale = abs_max_val_per_ch / (float(quant_max - quant_min) / 2) + + scale = torch.clamp(scale, min=eps).to(dtype=torch.float32, device=min_vals.device) + + # Introduces some noise in scale + # If I don't introduce noise, the accuracy is going to be 0.0 and not learn anything + # scale = scale + 0.01 * torch.randn_like(scale) + return scale + + +def init_non_linear_regression_fit( + x: torch.Tensor, + **kwargs: Dict[str, Any], + ) -> torch.Tensor: + + assert "params_list" in kwargs, "params list must be provided." + assert "np_fit_func" in kwargs, "np_fit_func must be provided." + assert "p0" in kwargs, "p0 must be provided." + np_fit_func = kwargs.get('np_fit_func') + params_list = kwargs.get('params_list') + p0 = kwargs.get('p0') + + def _fit(xdata: np.ndarray, ydata: np.ndarray, func: Callable, p0: List[float]): + popt, _ = curve_fit( + func, + xdata, + ydata, + maxfev=1000, + p0=p0, + method='lm' + ) + return popt + + # 1. Needs to convert the torch tensor to numpy tensor + xdata = x.cpu().numpy() + + # 2. Sorts the data so that it makes it easier to fit to it + sorted_xdata = np.sort(xdata, axis=-1) + + p0 = {k: v.cpu().numpy() for k, v in p0.items()} + params_list = sorted(params_list) # We need to make sure that it matches the numpy fit func arg order + + # 3. Finds the best parameters for each channel + try: + params = [] + for i in range(sorted_xdata.shape[0]): + xdata_ = sorted_xdata[i] + p0_ = [p0[p][i] for p in params_list] + ch_params = _fit(xdata_, xdata_, np_fit_func, p0_) + params.append(ch_params) + + # 4. Builds the parameters + result = {} + for i, p in enumerate(params_list): + result[p] = torch.tensor([p_[i] for p_ in params], dtype=torch.float32).to(x.device) + + return result + + except ValueError as e: + print(f"Could not fit the function with error: {e}") + print(f"Using fallback result...") + return { + k: torch.tensor(v, dtype=torch.float32).to(x.device) for k, v in p0.items() + } + + +def init_zeros(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor: + val = torch.amin(x, dim=1) + return torch.zeros_like(val, dtype=torch.float32, device=x.device) + + +def init_inner_scale(tensor: torch.Tensor, _min: float = torch.inf, _max: float = torch.inf) -> torch.Tensor: + # Calculate the original minimum and maximum values + min_vals, max_vals = torch.aminmax(tensor, dim=-1) + x_min = torch.min(min_vals, torch.zeros_like(min_vals)) + x_max = torch.max(max_vals, torch.zeros_like(max_vals)) + + if _max is torch.inf: # We do not need to scale the tensor. 
Just need to move it + return torch.ones_like(x_min) + + # Calculate the scale factor + scale = (_max - _min) / (x_max - x_min) + return scale + + + +############## Quant ############### + +@torch.enable_grad() +def learn_parameters( + x: torch.Tensor, + params: Dict[str, nn.Parameter], + qtz_func: nn.Module, + deqtz_func: nn.Module, + bits: int, + target_dtype: torch.dtype, + epochs: int = 1000, + early_stop: bool = True, + do_report: bool = False +) -> Tuple[Dict[str, nn.Parameter], torch.Tensor]: + loss_fn = nn.MSELoss() + + # Determines the initial learning rate by computing the initial loss and multiplying it by + # the order of magnitude of the loss divided by 2 + quant = quantize(x, params, qtz_func, bits, target_dtype) + dequant = dequantize(quant, params, deqtz_func, bits, x.dtype) + loss = loss_fn(x, dequant) + + base_lr = 0.1 + exponent = int(np.floor(np.log10(loss.item()))) + lr = base_lr * (10 ** (exponent // 2)) + + # Requires gradients in the parameters + for p in params.values(): + p.requires_grad = True + p.grad = None + + param_keys = list(params.keys()) + param_values = list(params.values()) + + # Defines optimizer and loss function + optimizer = torch.optim.Adam(param_values, lr=lr) + scheduler = torch.optim.lr_scheduler.LinearLR(optimizer, start_factor=1.0, end_factor=0.01, total_iters=epochs // 10) + + # Contains the best loss and the best parameters + best_loss = float("inf") + best_params = None + + # Used to stop the search early + min_delta = 1e-7 + acc_loss = [] + percent_epochs_before_stop = 0.1 + + for i in range(epochs): + optimizer.zero_grad() + + quant = quantize(x, params, qtz_func, bits, target_dtype) + dequant = dequantize(quant, params, deqtz_func, bits, x.dtype) + loss = loss_fn(x, dequant) + + if loss.isnan() or loss.isinf(): + raise Exception("Loss is NaN or Inf. 
Stopping the search.")
+
+        loss.backward()
+        optimizer.step()
+        scheduler.step()
+
+        acc_loss.append(loss.item())
+
+        # Reports loss every 10 steps
+        if i % 10 == 0 and do_report:
+            print(f"Epoch {i}: Loss {loss.item()}")
+
+        # Optimizes the parameter search by storing the best loss and the parameters
+        if loss.item() < best_loss:
+            best_loss = loss.item()
+            best_params = copy.deepcopy({
+                k: v for k, v in params.items() if k in param_keys
+            })
+
+        # We also stop the search if the loss has not changed considerably during the last 10% of the epochs
+        if early_stop:
+            epochs_before_stop = int(epochs * percent_epochs_before_stop)
+            if i > epochs_before_stop and abs(acc_loss[i - epochs_before_stop] - acc_loss[i]) < min_delta:
+                break
+
+    # No longer requires gradients in the parameters
+    for p in best_params.values():
+        p.requires_grad = False
+        p.grad = None
+
+    if do_report:
+        return best_params, acc_loss
+    else:
+        return best_params
+
+
+def quantize(
+    x: torch.Tensor,
+    params: Dict[str, nn.Parameter],
+    func: nn.Module,
+    bits: int,
+    target_dtype: torch.dtype = torch.int8
+) -> torch.Tensor:
+    quant_min, quant_max = get_min_max_from_bits_signed(bits)
+    x = x.transpose(0, 1)  # Aligns shapes
+    x = func(x=x, **params)
+    x = x.transpose(0, 1)
+    x = torch.clamp(round_func_BPDA(x), quant_min, quant_max).to(target_dtype)
+    return x
+
+
+def dequantize(
+    x: torch.Tensor,
+    params: Dict[str, nn.Parameter],
+    func: nn.Module,
+    bits: int,
+    out_dtype: torch.dtype
+) -> torch.Tensor:
+    x = x.to(dtype=out_dtype)
+    x = x.transpose(0, 1)
+    x = func(x=x, **params)
+    x = x.transpose(0, 1)
+    return x
+
+
+def round_func_BPDA(input):
+    # This is equivalent to replacing the round function (non-differentiable) with
+    # an identity function (differentiable) in the backward pass only.
+    forward_value = torch.round(input)
+    out = input.clone()
+    out.data = forward_value.data
+    return out
+
+
+def get_min_max_from_bits_signed(bit_width: int) -> Tuple[int, int]:
+    return -2 ** (bit_width - 1), 2 ** (bit_width - 1) - 1
+
+
+
+############## Numpy ###############
+
+def np_domain_guard(
+        x: np.ndarray,
+        min: float = None,
+        max: float = None,
+        posinf: float = None,
+        neginf: float = None,
+        nan: float = None
+    ) -> np.ndarray:
+    """Guard a tensor to a valid domain."""
+    x = np.nan_to_num(x, posinf=posinf, neginf=neginf, nan=nan)
+    if min is not None or max is not None:
+        x = np.clip(x, min, max)
+    return x
+
+
+def np_replace_num(x: np.ndarray, num: float, to: float) -> np.ndarray:
+    """Replace a number in a tensor with another number.
+
+    Args:
+        x (np.ndarray): The input tensor.
+        num (float): The number to replace.
+        to (float): The number to replace with.
+
+    Returns:
+        np.ndarray: The tensor with the number replaced.
+ """ + return np.where(x == num, to, x) + + +def np_guarded_power(x: np.ndarray, exp: float) -> np.ndarray: + """Guard the power operation to a valid domain.""" + return np.power(x, exp) if exp >= 1 else np.power(np.maximum(x, 0), exp) + diff --git a/fn_gen/rnd_search/7/loss.png b/fn_gen/rnd_search/7/loss.png new file mode 100644 index 0000000000000000000000000000000000000000..06b31a5ba7dd4734d99d8edd0ad3d30fd0f51279 Binary files /dev/null and b/fn_gen/rnd_search/7/loss.png differ diff --git a/fn_gen/rnd_search/7/quantization.png b/fn_gen/rnd_search/7/quantization.png new file mode 100644 index 0000000000000000000000000000000000000000..17935b149eb337b0520be9808c3dc06303bc1034 Binary files /dev/null and b/fn_gen/rnd_search/7/quantization.png differ diff --git a/fn_gen/rnd_search/8/__pycache__/fn.cpython-311.pyc b/fn_gen/rnd_search/8/__pycache__/fn.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8539a30ced8e60947bfec29638447fb5180a8442 Binary files /dev/null and b/fn_gen/rnd_search/8/__pycache__/fn.cpython-311.pyc differ diff --git a/fn_gen/rnd_search/8/distortion.png b/fn_gen/rnd_search/8/distortion.png new file mode 100644 index 0000000000000000000000000000000000000000..58066b9c69e330d69c7d86badaf4ae89af825bc5 Binary files /dev/null and b/fn_gen/rnd_search/8/distortion.png differ diff --git a/fn_gen/rnd_search/8/expressions.txt b/fn_gen/rnd_search/8/expressions.txt new file mode 100644 index 0000000000000000000000000000000000000000..ecd6e238827dcdb95f4bcb390c1c300696f34254 --- /dev/null +++ b/fn_gen/rnd_search/8/expressions.txt @@ -0,0 +1,2 @@ +sin(_0*x)/_s +asin(_s*x)/_0 \ No newline at end of file diff --git a/fn_gen/rnd_search/8/fn.py b/fn_gen/rnd_search/8/fn.py new file mode 100644 index 0000000000000000000000000000000000000000..834605f7ed5ca96aad8058fd8f1a0181ff9bd28c --- /dev/null +++ b/fn_gen/rnd_search/8/fn.py @@ -0,0 +1,554 @@ +from __future__ import annotations + +import torch +from torch import amin # Necessary for arcsin +import copy +import torch.nn as nn +import numpy as np + +from scipy.optimize import curve_fit +from typing import Dict, Any, Tuple, List, Callable + + +def quantization(x, **params): + return (torch.div(1, replace_num(params['_s'], num=0, to=10000)) * torch.sin((params['_0'] * x))) + + +def dequantization(x, **params): + return (torch.div(1, replace_num(params['_0'], num=0, to=10000)) * torch.asin(domain_guard((params['_s'] * x), min=-0.99999, max=0.99999, nan=0))) + + +def init_space_search( + x: torch.Tensor, + **kwargs: Dict[str, Any], + ) -> torch.Tensor: + + def _build_initial_param(tensor: torch.Tensor, max_initial: int, n_params: int): + """Generates the initial set of parameters. 
The first iteration generates 10 times more parameters.""" + for _ in range(n_params * 10): # The first iteration generates 10 times more parameters + yield init_rand(tensor) * max_initial # Generates n_params in range [-max_initial, max_initial] + + def _search_param(tensors: List[torch.tensor], n_params): + """Takes the best parameters and generates new parameters around the mean of the best parameters.""" + torch_tensors = torch.stack(tensors) + min_vals, max_vals = torch.aminmax(torch_tensors, dim=0) + abs_max_val_per_ch = torch.max(-min_vals, max_vals) + mean = torch.mean(torch_tensors, dim=0) + for _ in range(n_params): # Generates n_params around the mean of the tensors + yield torch.randn_like(min_vals) * abs_max_val_per_ch + mean + + def _calc(x, qtz_func, deqtz_func, **params): + x_ = x.transpose(0, 1) + x_ = qtz_func(x=x_, **params) + x_ = deqtz_func(x=x_, **params) + x_ = x_.transpose(0, 1) + return x_ + + assert "qtz_func" in kwargs, "qtz_func must be provided." + assert "deqtz_func" in kwargs, "deqtz_func must be provided." + assert "params_list" in kwargs, "params list must be provided." + assert "param" in kwargs, "param must be provided." + + qtz_func = kwargs.get('qtz_func') + deqtz_func = kwargs.get('deqtz_func') + params_list = kwargs.get('params_list') + param = kwargs.get('param') + + n_runs = 50 # Number of runs to try to find the best parameters + n_random_params = 50 # Number of random parameters to generate + n_best_to_pick = 5 # Number of best parameters to pick after each run + max_initial = 10000 # Maximum value to initialize the parameters + + # Initializes the parameters + base_params = { p: init_ones(x, **kwargs) for p in params_list if p != param } + params = _build_initial_param(x, max_initial, n_random_params) + + # Performs the search + for _ in range(n_runs): + + best_params = [] + for param_ in params: + try: + x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: param_}) + loss_ones = nn.MSELoss()(x, x_) + + if len(best_params) < n_best_to_pick: + best_params.append((param_, loss_ones.item())) + best_params = sorted(best_params, key=lambda x: x[1]) + elif loss_ones < best_params[-1][1]: + best_params[-1] = (param_, loss_ones.item()) + best_params = sorted(best_params, key=lambda x: x[1]) + + except Exception: # The parameters might not be valid for the function's domain + continue + + # Generates new parameters around the mean + params = _search_param([p for p, _ in best_params], n_random_params) + + return best_params[0][0] + + +def init_linear_scale( # Symmetric scale. From the study folder + x: torch.Tensor, + **kwargs: Dict[str, Any], + ) -> torch.Tensor: + assert "bits" in kwargs, "bits must be provided." + assert "params" in kwargs, "params must be provided." + assert "qtz_func" in kwargs, "qtz_func must be provided." 
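+    # The clamp to eps below keeps the scale strictly positive, so a channel
+    # that collapses to zero after qtz_func gets a tiny positive scale rather
+    # than a zero one.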
+ + bits = kwargs.get('bits') + params = kwargs.get('params') + qtz_func = kwargs.get('qtz_func') + + x_ = x.transpose(0, 1) + x_ = qtz_func(x=x_, **params, _s=init_ones(x, **kwargs)) + x_ = x_.transpose(0, 1) + + quant_min, quant_max = get_min_max_from_bits_signed(bits) + min_vals, max_vals = torch.aminmax(x_, dim=1) + min_vals = torch.min(min_vals, torch.zeros_like(min_vals)) + max_vals = torch.max(max_vals, torch.zeros_like(max_vals)) + + eps = torch.finfo(torch.float32).eps + + abs_max_val_per_ch = torch.max(-min_vals, max_vals) + scale = abs_max_val_per_ch / (float(quant_max - quant_min) / 2) + + scale = torch.clamp(scale, min=eps).to(dtype=torch.float32, device=min_vals.device) + + # Introduces some noise in scale + # If I don't introduce noise, the accuracy is going to be 0.0 and not learn anything + # scale = scale + 0.01 * torch.randn_like(scale) + return scale + + +def init_params(x: torch.Tensor, **kwargs: Dict[str, Any]) -> Dict[str, nn.Parameter]: + params = { + '_0': init_space_search(x, qtz_func=quantization, deqtz_func=dequantization, param='_0', params_list=['_0', '_s'], **kwargs), + } + params['_s'] = init_linear_scale(x, params=params, qtz_func=quantization, **kwargs) + params = {k: nn.Parameter(v, requires_grad=False) for k, v in params.items()} + + if 'post_init_hook' in kwargs: + kwargs['post_init_hook'](parameters=params) + + + if 'post_train_hook' in kwargs: + kwargs['post_train_hook'](parameters=params) + + return params + + +############### Numpy Qtz ############### + + +def np_quantization(x, _0, _s): + return (np.divide(1, np_replace_num(_s, num=0, to=10000)) * np.sin((_0 * x))) + + +def np_dequantization(x, _0, _s): + return (np.divide(1, np_replace_num(_0, num=0, to=10000)) * np.arcsin(np_domain_guard((_s * x), min=-0.99999, max=0.99999, nan=0))) + + +def fit_func(x, _0, _s): + x_ = np_quantization(x, _0, _s) + x_ = np_dequantization(x_, _0, _s) + return x_ + + + +############### HELPERS ############### + +def domain_guard( + x: torch.Tensor, + min: float = None, + max: float = None, + posinf: float = None, + neginf: float = None, + nan: float = None + ) -> torch.Tensor: + """Guard a tensor to a valid domain.""" + x = torch.nan_to_num(x, posinf=posinf, neginf=neginf, nan=nan) + if min is not None or max is not None: + x = torch.clamp(x, min=min, max=max) + return x + + +def replace_num(x: torch.Tensor, num: float, to: float) -> torch.Tensor: + """Replace a number in a tensor with another number. + + Args: + x (torch.Tensor): The input tensor. + num (float): The number to replace. + to (float): The number to replace with. + + Returns: + torch.Tensor: The tensor with the number replaced. + """ + return torch.where(x == num, to, x) + + +def guarded_torch_power(x: torch.Tensor, exp: float) -> torch.Tensor: + """Guard the power operation to a valid domain.""" + return torch.pow(x, exp) if exp >= 1 else torch.pow(torch.relu(x), exp) + + +def init_ones(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor: + val = torch.amin(x, dim=1) + return torch.ones_like(val, dtype=torch.float32, device=x.device) + + +def init_rand(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor: + val = torch.amin(x, dim=1) + return torch.randn_like(val, dtype=torch.float32, device=x.device) + + +def init_space_search( + x: torch.Tensor, + **kwargs: Dict[str, Any], + ) -> torch.Tensor: + + def _build_initial_param(tensor: torch.Tensor, max_initial: int, n_params: int): + """Generates the initial set of parameters. 
The first iteration generates 10 times more parameters.""" + for _ in range(n_params * 10): # The first iteration generates 10 times more parameters + yield init_rand(tensor) * max_initial # Generates n_params in range [-max_initial, max_initial] + + def _search_param(tensors: List[torch.tensor], n_params): + """Takes the best parameters and generates new parameters around the mean of the best parameters.""" + torch_tensors = torch.stack(tensors) + min_vals, max_vals = torch.aminmax(torch_tensors, dim=0) + abs_max_val_per_ch = torch.max(-min_vals, max_vals) + mean = torch.mean(torch_tensors, dim=0) + for _ in range(n_params): # Generates n_params around the mean of the tensors + yield torch.randn_like(min_vals) * abs_max_val_per_ch + mean + + def _calc(x, qtz_func, deqtz_func, **params): + x_ = x.transpose(0, 1) + x_ = qtz_func(x=x_, **params) + x_ = deqtz_func(x=x_, **params) + x_ = x_.transpose(0, 1) + return x_ + + assert "qtz_func" in kwargs, "qtz_func must be provided." + assert "deqtz_func" in kwargs, "deqtz_func must be provided." + assert "params_list" in kwargs, "params list must be provided." + assert "param" in kwargs, "param must be provided." + + qtz_func = kwargs.get('qtz_func') + deqtz_func = kwargs.get('deqtz_func') + params_list = kwargs.get('params_list') + param = kwargs.get('param') + + n_runs = 50 # Number of runs to try to find the best parameters + n_random_params = 50 # Number of random parameters to generate + n_best_to_pick = 5 # Number of best parameters to pick after each run + max_initial = 10000 # Maximum value to initialize the parameters + + # Initializes the parameters + base_params = { p: init_ones(x, **kwargs) for p in params_list if p != param } + params = _build_initial_param(x, max_initial, n_random_params) + + # Performs the search + for _ in range(n_runs): + + best_params = [] + for param_ in params: + try: + x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: param_}) + loss_ones = nn.MSELoss()(x, x_) + + if len(best_params) < n_best_to_pick: + best_params.append((param_, loss_ones.item())) + best_params = sorted(best_params, key=lambda x: x[1]) + elif loss_ones < best_params[-1][1]: + best_params[-1] = (param_, loss_ones.item()) + best_params = sorted(best_params, key=lambda x: x[1]) + + except Exception: # The parameters might not be valid for the function's domain + continue + + # Generates new parameters around the mean + params = _search_param([p for p, _ in best_params], n_random_params) + + return best_params[0][0] + + +def init_linear_scale( # Symmetric scale. From the study folder + x: torch.Tensor, + **kwargs: Dict[str, Any], + ) -> torch.Tensor: + assert "bits" in kwargs, "bits must be provided." + assert "params" in kwargs, "params must be provided." + assert "qtz_func" in kwargs, "qtz_func must be provided." 
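+    # Reviewer note (added comment, inferred from the code below): min_vals/max_vals are
+    # clamped against zero so the symmetric range always brackets 0, and the final scale is
+    # clamped to float32 eps so a constant (all-zero) channel cannot produce a zero scale.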
+ + bits = kwargs.get('bits') + params = kwargs.get('params') + qtz_func = kwargs.get('qtz_func') + + x_ = x.transpose(0, 1) + x_ = qtz_func(x=x_, **params, _s=init_ones(x, **kwargs)) + x_ = x_.transpose(0, 1) + + quant_min, quant_max = get_min_max_from_bits_signed(bits) + min_vals, max_vals = torch.aminmax(x_, dim=1) + min_vals = torch.min(min_vals, torch.zeros_like(min_vals)) + max_vals = torch.max(max_vals, torch.zeros_like(max_vals)) + + eps = torch.finfo(torch.float32).eps + + abs_max_val_per_ch = torch.max(-min_vals, max_vals) + scale = abs_max_val_per_ch / (float(quant_max - quant_min) / 2) + + scale = torch.clamp(scale, min=eps).to(dtype=torch.float32, device=min_vals.device) + + # Introduces some noise in scale + # If I don't introduce noise, the accuracy is going to be 0.0 and not learn anything + # scale = scale + 0.01 * torch.randn_like(scale) + return scale + + +def init_non_linear_regression_fit( + x: torch.Tensor, + **kwargs: Dict[str, Any], + ) -> torch.Tensor: + + assert "params_list" in kwargs, "params list must be provided." + assert "np_fit_func" in kwargs, "np_fit_func must be provided." + assert "p0" in kwargs, "p0 must be provided." + np_fit_func = kwargs.get('np_fit_func') + params_list = kwargs.get('params_list') + p0 = kwargs.get('p0') + + def _fit(xdata: np.ndarray, ydata: np.ndarray, func: Callable, p0: List[float]): + popt, _ = curve_fit( + func, + xdata, + ydata, + maxfev=1000, + p0=p0, + method='lm' + ) + return popt + + # 1. Needs to convert the torch tensor to numpy tensor + xdata = x.cpu().numpy() + + # 2. Sorts the data so that it makes it easier to fit to it + sorted_xdata = np.sort(xdata, axis=-1) + + p0 = {k: v.cpu().numpy() for k, v in p0.items()} + params_list = sorted(params_list) # We need to make sure that it matches the numpy fit func arg order + + # 3. Finds the best parameters for each channel + try: + params = [] + for i in range(sorted_xdata.shape[0]): + xdata_ = sorted_xdata[i] + p0_ = [p0[p][i] for p in params_list] + ch_params = _fit(xdata_, xdata_, np_fit_func, p0_) + params.append(ch_params) + + # 4. Builds the parameters + result = {} + for i, p in enumerate(params_list): + result[p] = torch.tensor([p_[i] for p_ in params], dtype=torch.float32).to(x.device) + + return result + + except ValueError as e: + print(f"Could not fit the function with error: {e}") + print(f"Using fallback result...") + return { + k: torch.tensor(v, dtype=torch.float32).to(x.device) for k, v in p0.items() + } + + +def init_zeros(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor: + val = torch.amin(x, dim=1) + return torch.zeros_like(val, dtype=torch.float32, device=x.device) + + +def init_inner_scale(tensor: torch.Tensor, _min: float = torch.inf, _max: float = torch.inf) -> torch.Tensor: + # Calculate the original minimum and maximum values + min_vals, max_vals = torch.aminmax(tensor, dim=-1) + x_min = torch.min(min_vals, torch.zeros_like(min_vals)) + x_max = torch.max(max_vals, torch.zeros_like(max_vals)) + + if _max is torch.inf: # We do not need to scale the tensor. 
Just need to move it + return torch.ones_like(x_min) + + # Calculate the scale factor + scale = (_max - _min) / (x_max - x_min) + return scale + + + +############## Quant ############### + +@torch.enable_grad() +def learn_parameters( + x: torch.Tensor, + params: Dict[str, nn.Parameter], + qtz_func: nn.Module, + deqtz_func: nn.Module, + bits: int, + target_dtype: torch.dtype, + epochs: int = 1000, + early_stop: bool = True, + do_report: bool = False +) -> Tuple[Dict[str, nn.Parameter], torch.Tensor]: + loss_fn = nn.MSELoss() + + # Determines the initial learning rate by computing the initial loss and multiplying it by + # the order of magnitude of the loss divided by 2 + quant = quantize(x, params, qtz_func, bits, target_dtype) + dequant = dequantize(quant, params, deqtz_func, bits, x.dtype) + loss = loss_fn(x, dequant) + + base_lr = 0.1 + exponent = int(np.floor(np.log10(loss.item()))) + lr = base_lr * (10 ** (exponent // 2)) + + # Requires gradients in the parameters + for p in params.values(): + p.requires_grad = True + p.grad = None + + param_keys = list(params.keys()) + param_values = list(params.values()) + + # Defines optimizer and loss function + optimizer = torch.optim.Adam(param_values, lr=lr) + scheduler = torch.optim.lr_scheduler.LinearLR(optimizer, start_factor=1.0, end_factor=0.01, total_iters=epochs // 10) + + # Contains the best loss and the best parameters + best_loss = float("inf") + best_params = None + + # Used to stop the search early + min_delta = 1e-7 + acc_loss = [] + percent_epochs_before_stop = 0.1 + + for i in range(epochs): + optimizer.zero_grad() + + quant = quantize(x, params, qtz_func, bits, target_dtype) + dequant = dequantize(quant, params, deqtz_func, bits, x.dtype) + loss = loss_fn(x, dequant) + + if loss.isnan() or loss.isinf(): + raise Exception("Loss is NaN or Inf. 
Stopping the search.") + + loss.backward() + optimizer.step() + scheduler.step() + + acc_loss.append(loss.item()) + + # Reports loss every 10 steps + if i % 10 == 0 and do_report: + print(f"Epoch {i}: Loss {loss.item()}") + + # Optimizes the parameter search by storing the best loss and the parameters + if loss.item() < best_loss: + best_loss = loss.item() + best_params = copy.deepcopy({ + k: v for k, v in params.items() if k in param_keys + }) + + # We also stop the search if the loss has not considerably during the last 10% epochs + if early_stop: + epochs_before_stop = int(epochs * percent_epochs_before_stop) + if i > epochs_before_stop and abs(acc_loss[i - epochs_before_stop] - acc_loss[i]) < min_delta: + break + + # No longer requires gradients in the parameters + for p in best_params.values(): + p.requires_grad = False + p.grad = None + + if do_report: + return best_params, acc_loss + else: + return best_params + + +def quantize( + x: torch.Tensor, + params: Dict[str, nn.Parameter], + func: nn.Module, + bits: int, + target_dtype: torch.dtype = torch.int8 +) -> torch.Tensor: + quant_min, quant_max = get_min_max_from_bits_signed(bits) + x = x.transpose(0, 1) # Aligns shapes + x = func(x=x, **params) + x = x.transpose(0, 1) + x = torch.clamp(round_func_BPDA(x), quant_min, quant_max).to(target_dtype) + return x + + +def dequantize( + x: torch.Tensor, + params: Dict[str, nn.Parameter], + func: nn.Module, + bits: int, + out_dtype: torch.dtype +) -> torch.Tensor: + x = x.to(dtype=out_dtype) + x = x.transpose(0, 1) + x = func(x=x, **params) + x = x.transpose(0, 1) + return x + + +def round_func_BPDA(input): + # This is equivalent to replacing round function (non-differentiable) with + # an identity function (differentiable) only when backward. + forward_value = torch.round(input) + out = input.clone() + out.data = forward_value.data + return out + + +def get_min_max_from_bits_signed(bit_width: int) -> Tuple[int, int]: + return -2 ** (bit_width - 1), 2 ** (bit_width - 1) - 1 + + + +############## Numpy ############### + +def np_domain_guard( + x: np.ndarray, + min: float = None, + max: float = None, + posinf: float = None, + neginf: float = None, + nan: float = None + ) -> np.ndarray: + """Guard a tensor to a valid domain.""" + x = np.nan_to_num(x, posinf=posinf, neginf=neginf, nan=nan) + if min is not None or max is not None: + x = np.clip(x, min, max) + return x + + +def np_replace_num(x: np.ndarray, num: float, to: float) -> np.ndarray: + """Replace a number in a tensor with another number. + + Args: + x (np.ndarray): The input tensor. + num (float): The number to replace. + to (float): The number to replace with. + + Returns: + np.ndarray: The tensor with the number replaced. 
+ """ + return np.where(x == num, to, x) + + +def np_guarded_power(x: np.ndarray, exp: float) -> np.ndarray: + """Guard the power operation to a valid domain.""" + return np.power(x, exp) if exp >= 1 else np.power(np.maximum(x, 0), exp) + diff --git a/fn_gen/rnd_search/8/loss.png b/fn_gen/rnd_search/8/loss.png new file mode 100644 index 0000000000000000000000000000000000000000..a1c87d87c79ecdb4889409f8cbecba0c8f76b794 Binary files /dev/null and b/fn_gen/rnd_search/8/loss.png differ diff --git a/fn_gen/rnd_search/8/quantization.png b/fn_gen/rnd_search/8/quantization.png new file mode 100644 index 0000000000000000000000000000000000000000..0e6dd4d0815d9829b9d26b3eaf76e838c949c38c Binary files /dev/null and b/fn_gen/rnd_search/8/quantization.png differ diff --git a/fn_gen/rnd_search/9/__pycache__/fn.cpython-311.pyc b/fn_gen/rnd_search/9/__pycache__/fn.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..45a72fedb63fb8850fd8d0c56f36fd71acf78c74 Binary files /dev/null and b/fn_gen/rnd_search/9/__pycache__/fn.cpython-311.pyc differ diff --git a/fn_gen/rnd_search/9/distortion.png b/fn_gen/rnd_search/9/distortion.png new file mode 100644 index 0000000000000000000000000000000000000000..87dbf44b714416aa1ded5c994351d0e60964c824 Binary files /dev/null and b/fn_gen/rnd_search/9/distortion.png differ diff --git a/fn_gen/rnd_search/9/expressions.txt b/fn_gen/rnd_search/9/expressions.txt new file mode 100644 index 0000000000000000000000000000000000000000..9aa25379a9d1d5a93d60659c6609b2e24e79234d --- /dev/null +++ b/fn_gen/rnd_search/9/expressions.txt @@ -0,0 +1,2 @@ +exp(_0*x)/_s +log(_s*x)/_0 \ No newline at end of file diff --git a/fn_gen/rnd_search/9/fn.py b/fn_gen/rnd_search/9/fn.py new file mode 100644 index 0000000000000000000000000000000000000000..29d4c12348f1efbbf7ae4fc968c579c064d43604 --- /dev/null +++ b/fn_gen/rnd_search/9/fn.py @@ -0,0 +1,554 @@ +from __future__ import annotations + +import torch +from torch import amin # Necessary for arcsin +import copy +import torch.nn as nn +import numpy as np + +from scipy.optimize import curve_fit +from typing import Dict, Any, Tuple, List, Callable + + +def quantization(x, **params): + return (torch.div(1, replace_num(params['_s'], num=0, to=10000)) * torch.exp((params['_0'] * x))) + + +def dequantization(x, **params): + return (torch.div(1, replace_num(params['_0'], num=0, to=10000)) * torch.log(domain_guard((params['_s'] * x), min=1e-5, nan=1e-5))) + + +def init_space_search( + x: torch.Tensor, + **kwargs: Dict[str, Any], + ) -> torch.Tensor: + + def _build_initial_param(tensor: torch.Tensor, max_initial: int, n_params: int): + """Generates the initial set of parameters. 
The first iteration generates 10 times more parameters.""" + for _ in range(n_params * 10): # The first iteration generates 10 times more parameters + yield init_rand(tensor) * max_initial # Generates n_params in range [-max_initial, max_initial] + + def _search_param(tensors: List[torch.tensor], n_params): + """Takes the best parameters and generates new parameters around the mean of the best parameters.""" + torch_tensors = torch.stack(tensors) + min_vals, max_vals = torch.aminmax(torch_tensors, dim=0) + abs_max_val_per_ch = torch.max(-min_vals, max_vals) + mean = torch.mean(torch_tensors, dim=0) + for _ in range(n_params): # Generates n_params around the mean of the tensors + yield torch.randn_like(min_vals) * abs_max_val_per_ch + mean + + def _calc(x, qtz_func, deqtz_func, **params): + x_ = x.transpose(0, 1) + x_ = qtz_func(x=x_, **params) + x_ = deqtz_func(x=x_, **params) + x_ = x_.transpose(0, 1) + return x_ + + assert "qtz_func" in kwargs, "qtz_func must be provided." + assert "deqtz_func" in kwargs, "deqtz_func must be provided." + assert "params_list" in kwargs, "params list must be provided." + assert "param" in kwargs, "param must be provided." + + qtz_func = kwargs.get('qtz_func') + deqtz_func = kwargs.get('deqtz_func') + params_list = kwargs.get('params_list') + param = kwargs.get('param') + + n_runs = 50 # Number of runs to try to find the best parameters + n_random_params = 50 # Number of random parameters to generate + n_best_to_pick = 5 # Number of best parameters to pick after each run + max_initial = 10000 # Maximum value to initialize the parameters + + # Initializes the parameters + base_params = { p: init_ones(x, **kwargs) for p in params_list if p != param } + params = _build_initial_param(x, max_initial, n_random_params) + + # Performs the search + for _ in range(n_runs): + + best_params = [] + for param_ in params: + try: + x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: param_}) + loss_ones = nn.MSELoss()(x, x_) + + if len(best_params) < n_best_to_pick: + best_params.append((param_, loss_ones.item())) + best_params = sorted(best_params, key=lambda x: x[1]) + elif loss_ones < best_params[-1][1]: + best_params[-1] = (param_, loss_ones.item()) + best_params = sorted(best_params, key=lambda x: x[1]) + + except Exception: # The parameters might not be valid for the function's domain + continue + + # Generates new parameters around the mean + params = _search_param([p for p, _ in best_params], n_random_params) + + return best_params[0][0] + + +def init_linear_scale( # Symmetric scale. From the study folder + x: torch.Tensor, + **kwargs: Dict[str, Any], + ) -> torch.Tensor: + assert "bits" in kwargs, "bits must be provided." + assert "params" in kwargs, "params must be provided." + assert "qtz_func" in kwargs, "qtz_func must be provided." 
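+    # Reviewer annotation (illustrative worked example, not generated): for the default signed
+    # mapping, get_min_max_from_bits_signed(8) gives (-128, 127), so the divisor below is
+    # (127 - (-128)) / 2 = 127.5 and scale = abs_max_val_per_ch / 127.5.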
+ + bits = kwargs.get('bits') + params = kwargs.get('params') + qtz_func = kwargs.get('qtz_func') + + x_ = x.transpose(0, 1) + x_ = qtz_func(x=x_, **params, _s=init_ones(x, **kwargs)) + x_ = x_.transpose(0, 1) + + quant_min, quant_max = get_min_max_from_bits_signed(bits) + min_vals, max_vals = torch.aminmax(x_, dim=1) + min_vals = torch.min(min_vals, torch.zeros_like(min_vals)) + max_vals = torch.max(max_vals, torch.zeros_like(max_vals)) + + eps = torch.finfo(torch.float32).eps + + abs_max_val_per_ch = torch.max(-min_vals, max_vals) + scale = abs_max_val_per_ch / (float(quant_max - quant_min) / 2) + + scale = torch.clamp(scale, min=eps).to(dtype=torch.float32, device=min_vals.device) + + # Introduces some noise in scale + # If I don't introduce noise, the accuracy is going to be 0.0 and not learn anything + # scale = scale + 0.01 * torch.randn_like(scale) + return scale + + +def init_params(x: torch.Tensor, **kwargs: Dict[str, Any]) -> Dict[str, nn.Parameter]: + params = { + '_0': init_space_search(x, qtz_func=quantization, deqtz_func=dequantization, param='_0', params_list=['_0', '_s'], **kwargs), + } + params['_s'] = init_linear_scale(x, params=params, qtz_func=quantization, **kwargs) + params = {k: nn.Parameter(v, requires_grad=False) for k, v in params.items()} + + if 'post_init_hook' in kwargs: + kwargs['post_init_hook'](parameters=params) + + + if 'post_train_hook' in kwargs: + kwargs['post_train_hook'](parameters=params) + + return params + + +############### Numpy Qtz ############### + + +def np_quantization(x, _0, _s): + return (np.divide(1, np_replace_num(_s, num=0, to=10000)) * np.exp((_0 * x))) + + +def np_dequantization(x, _0, _s): + return (np.divide(1, np_replace_num(_0, num=0, to=10000)) * np.log(np_domain_guard((_s * x), min=1e-5, nan=1e-5))) + + +def fit_func(x, _0, _s): + x_ = np_quantization(x, _0, _s) + x_ = np_dequantization(x_, _0, _s) + return x_ + + + +############### HELPERS ############### + +def domain_guard( + x: torch.Tensor, + min: float = None, + max: float = None, + posinf: float = None, + neginf: float = None, + nan: float = None + ) -> torch.Tensor: + """Guard a tensor to a valid domain.""" + x = torch.nan_to_num(x, posinf=posinf, neginf=neginf, nan=nan) + if min is not None or max is not None: + x = torch.clamp(x, min=min, max=max) + return x + + +def replace_num(x: torch.Tensor, num: float, to: float) -> torch.Tensor: + """Replace a number in a tensor with another number. + + Args: + x (torch.Tensor): The input tensor. + num (float): The number to replace. + to (float): The number to replace with. + + Returns: + torch.Tensor: The tensor with the number replaced. + """ + return torch.where(x == num, to, x) + + +def guarded_torch_power(x: torch.Tensor, exp: float) -> torch.Tensor: + """Guard the power operation to a valid domain.""" + return torch.pow(x, exp) if exp >= 1 else torch.pow(torch.relu(x), exp) + + +def init_ones(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor: + val = torch.amin(x, dim=1) + return torch.ones_like(val, dtype=torch.float32, device=x.device) + + +def init_rand(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor: + val = torch.amin(x, dim=1) + return torch.randn_like(val, dtype=torch.float32, device=x.device) + + +def init_space_search( + x: torch.Tensor, + **kwargs: Dict[str, Any], + ) -> torch.Tensor: + + def _build_initial_param(tensor: torch.Tensor, max_initial: int, n_params: int): + """Generates the initial set of parameters. 
The first iteration generates 10 times more parameters.""" + for _ in range(n_params * 10): # The first iteration generates 10 times more parameters + yield init_rand(tensor) * max_initial # Generates n_params in range [-max_initial, max_initial] + + def _search_param(tensors: List[torch.tensor], n_params): + """Takes the best parameters and generates new parameters around the mean of the best parameters.""" + torch_tensors = torch.stack(tensors) + min_vals, max_vals = torch.aminmax(torch_tensors, dim=0) + abs_max_val_per_ch = torch.max(-min_vals, max_vals) + mean = torch.mean(torch_tensors, dim=0) + for _ in range(n_params): # Generates n_params around the mean of the tensors + yield torch.randn_like(min_vals) * abs_max_val_per_ch + mean + + def _calc(x, qtz_func, deqtz_func, **params): + x_ = x.transpose(0, 1) + x_ = qtz_func(x=x_, **params) + x_ = deqtz_func(x=x_, **params) + x_ = x_.transpose(0, 1) + return x_ + + assert "qtz_func" in kwargs, "qtz_func must be provided." + assert "deqtz_func" in kwargs, "deqtz_func must be provided." + assert "params_list" in kwargs, "params list must be provided." + assert "param" in kwargs, "param must be provided." + + qtz_func = kwargs.get('qtz_func') + deqtz_func = kwargs.get('deqtz_func') + params_list = kwargs.get('params_list') + param = kwargs.get('param') + + n_runs = 50 # Number of runs to try to find the best parameters + n_random_params = 50 # Number of random parameters to generate + n_best_to_pick = 5 # Number of best parameters to pick after each run + max_initial = 10000 # Maximum value to initialize the parameters + + # Initializes the parameters + base_params = { p: init_ones(x, **kwargs) for p in params_list if p != param } + params = _build_initial_param(x, max_initial, n_random_params) + + # Performs the search + for _ in range(n_runs): + + best_params = [] + for param_ in params: + try: + x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: param_}) + loss_ones = nn.MSELoss()(x, x_) + + if len(best_params) < n_best_to_pick: + best_params.append((param_, loss_ones.item())) + best_params = sorted(best_params, key=lambda x: x[1]) + elif loss_ones < best_params[-1][1]: + best_params[-1] = (param_, loss_ones.item()) + best_params = sorted(best_params, key=lambda x: x[1]) + + except Exception: # The parameters might not be valid for the function's domain + continue + + # Generates new parameters around the mean + params = _search_param([p for p, _ in best_params], n_random_params) + + return best_params[0][0] + + +def init_linear_scale( # Symmetric scale. From the study folder + x: torch.Tensor, + **kwargs: Dict[str, Any], + ) -> torch.Tensor: + assert "bits" in kwargs, "bits must be provided." + assert "params" in kwargs, "params must be provided." + assert "qtz_func" in kwargs, "qtz_func must be provided." 
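+    # Reviewer annotation (assumption about the tensor layout, not generated): x is treated as
+    # [channels, elements]; the dim=1 reductions below therefore yield one scale per channel,
+    # matching the transpose(0, 1) convention used when calling qtz_func.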
+ + bits = kwargs.get('bits') + params = kwargs.get('params') + qtz_func = kwargs.get('qtz_func') + + x_ = x.transpose(0, 1) + x_ = qtz_func(x=x_, **params, _s=init_ones(x, **kwargs)) + x_ = x_.transpose(0, 1) + + quant_min, quant_max = get_min_max_from_bits_signed(bits) + min_vals, max_vals = torch.aminmax(x_, dim=1) + min_vals = torch.min(min_vals, torch.zeros_like(min_vals)) + max_vals = torch.max(max_vals, torch.zeros_like(max_vals)) + + eps = torch.finfo(torch.float32).eps + + abs_max_val_per_ch = torch.max(-min_vals, max_vals) + scale = abs_max_val_per_ch / (float(quant_max - quant_min) / 2) + + scale = torch.clamp(scale, min=eps).to(dtype=torch.float32, device=min_vals.device) + + # Introduces some noise in scale + # If I don't introduce noise, the accuracy is going to be 0.0 and not learn anything + # scale = scale + 0.01 * torch.randn_like(scale) + return scale + + +def init_non_linear_regression_fit( + x: torch.Tensor, + **kwargs: Dict[str, Any], + ) -> torch.Tensor: + + assert "params_list" in kwargs, "params list must be provided." + assert "np_fit_func" in kwargs, "np_fit_func must be provided." + assert "p0" in kwargs, "p0 must be provided." + np_fit_func = kwargs.get('np_fit_func') + params_list = kwargs.get('params_list') + p0 = kwargs.get('p0') + + def _fit(xdata: np.ndarray, ydata: np.ndarray, func: Callable, p0: List[float]): + popt, _ = curve_fit( + func, + xdata, + ydata, + maxfev=1000, + p0=p0, + method='lm' + ) + return popt + + # 1. Needs to convert the torch tensor to numpy tensor + xdata = x.cpu().numpy() + + # 2. Sorts the data so that it makes it easier to fit to it + sorted_xdata = np.sort(xdata, axis=-1) + + p0 = {k: v.cpu().numpy() for k, v in p0.items()} + params_list = sorted(params_list) # We need to make sure that it matches the numpy fit func arg order + + # 3. Finds the best parameters for each channel + try: + params = [] + for i in range(sorted_xdata.shape[0]): + xdata_ = sorted_xdata[i] + p0_ = [p0[p][i] for p in params_list] + ch_params = _fit(xdata_, xdata_, np_fit_func, p0_) + params.append(ch_params) + + # 4. Builds the parameters + result = {} + for i, p in enumerate(params_list): + result[p] = torch.tensor([p_[i] for p_ in params], dtype=torch.float32).to(x.device) + + return result + + except ValueError as e: + print(f"Could not fit the function with error: {e}") + print(f"Using fallback result...") + return { + k: torch.tensor(v, dtype=torch.float32).to(x.device) for k, v in p0.items() + } + + +def init_zeros(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor: + val = torch.amin(x, dim=1) + return torch.zeros_like(val, dtype=torch.float32, device=x.device) + + +def init_inner_scale(tensor: torch.Tensor, _min: float = torch.inf, _max: float = torch.inf) -> torch.Tensor: + # Calculate the original minimum and maximum values + min_vals, max_vals = torch.aminmax(tensor, dim=-1) + x_min = torch.min(min_vals, torch.zeros_like(min_vals)) + x_max = torch.max(max_vals, torch.zeros_like(max_vals)) + + if _max is torch.inf: # We do not need to scale the tensor. 
Just need to move it + return torch.ones_like(x_min) + + # Calculate the scale factor + scale = (_max - _min) / (x_max - x_min) + return scale + + + +############## Quant ############### + +@torch.enable_grad() +def learn_parameters( + x: torch.Tensor, + params: Dict[str, nn.Parameter], + qtz_func: nn.Module, + deqtz_func: nn.Module, + bits: int, + target_dtype: torch.dtype, + epochs: int = 1000, + early_stop: bool = True, + do_report: bool = False +) -> Tuple[Dict[str, nn.Parameter], torch.Tensor]: + loss_fn = nn.MSELoss() + + # Determines the initial learning rate by computing the initial loss and multiplying it by + # the order of magnitude of the loss divided by 2 + quant = quantize(x, params, qtz_func, bits, target_dtype) + dequant = dequantize(quant, params, deqtz_func, bits, x.dtype) + loss = loss_fn(x, dequant) + + base_lr = 0.1 + exponent = int(np.floor(np.log10(loss.item()))) + lr = base_lr * (10 ** (exponent // 2)) + + # Requires gradients in the parameters + for p in params.values(): + p.requires_grad = True + p.grad = None + + param_keys = list(params.keys()) + param_values = list(params.values()) + + # Defines optimizer and loss function + optimizer = torch.optim.Adam(param_values, lr=lr) + scheduler = torch.optim.lr_scheduler.LinearLR(optimizer, start_factor=1.0, end_factor=0.01, total_iters=epochs // 10) + + # Contains the best loss and the best parameters + best_loss = float("inf") + best_params = None + + # Used to stop the search early + min_delta = 1e-7 + acc_loss = [] + percent_epochs_before_stop = 0.1 + + for i in range(epochs): + optimizer.zero_grad() + + quant = quantize(x, params, qtz_func, bits, target_dtype) + dequant = dequantize(quant, params, deqtz_func, bits, x.dtype) + loss = loss_fn(x, dequant) + + if loss.isnan() or loss.isinf(): + raise Exception("Loss is NaN or Inf. 
Stopping the search.") + + loss.backward() + optimizer.step() + scheduler.step() + + acc_loss.append(loss.item()) + + # Reports loss every 10 steps + if i % 10 == 0 and do_report: + print(f"Epoch {i}: Loss {loss.item()}") + + # Optimizes the parameter search by storing the best loss and the parameters + if loss.item() < best_loss: + best_loss = loss.item() + best_params = copy.deepcopy({ + k: v for k, v in params.items() if k in param_keys + }) + + # We also stop the search if the loss has not considerably during the last 10% epochs + if early_stop: + epochs_before_stop = int(epochs * percent_epochs_before_stop) + if i > epochs_before_stop and abs(acc_loss[i - epochs_before_stop] - acc_loss[i]) < min_delta: + break + + # No longer requires gradients in the parameters + for p in best_params.values(): + p.requires_grad = False + p.grad = None + + if do_report: + return best_params, acc_loss + else: + return best_params + + +def quantize( + x: torch.Tensor, + params: Dict[str, nn.Parameter], + func: nn.Module, + bits: int, + target_dtype: torch.dtype = torch.int8 +) -> torch.Tensor: + quant_min, quant_max = get_min_max_from_bits_signed(bits) + x = x.transpose(0, 1) # Aligns shapes + x = func(x=x, **params) + x = x.transpose(0, 1) + x = torch.clamp(round_func_BPDA(x), quant_min, quant_max).to(target_dtype) + return x + + +def dequantize( + x: torch.Tensor, + params: Dict[str, nn.Parameter], + func: nn.Module, + bits: int, + out_dtype: torch.dtype +) -> torch.Tensor: + x = x.to(dtype=out_dtype) + x = x.transpose(0, 1) + x = func(x=x, **params) + x = x.transpose(0, 1) + return x + + +def round_func_BPDA(input): + # This is equivalent to replacing round function (non-differentiable) with + # an identity function (differentiable) only when backward. + forward_value = torch.round(input) + out = input.clone() + out.data = forward_value.data + return out + + +def get_min_max_from_bits_signed(bit_width: int) -> Tuple[int, int]: + return -2 ** (bit_width - 1), 2 ** (bit_width - 1) - 1 + + + +############## Numpy ############### + +def np_domain_guard( + x: np.ndarray, + min: float = None, + max: float = None, + posinf: float = None, + neginf: float = None, + nan: float = None + ) -> np.ndarray: + """Guard a tensor to a valid domain.""" + x = np.nan_to_num(x, posinf=posinf, neginf=neginf, nan=nan) + if min is not None or max is not None: + x = np.clip(x, min, max) + return x + + +def np_replace_num(x: np.ndarray, num: float, to: float) -> np.ndarray: + """Replace a number in a tensor with another number. + + Args: + x (np.ndarray): The input tensor. + num (float): The number to replace. + to (float): The number to replace with. + + Returns: + np.ndarray: The tensor with the number replaced. + """ + return np.where(x == num, to, x) + + +def np_guarded_power(x: np.ndarray, exp: float) -> np.ndarray: + """Guard the power operation to a valid domain.""" + return np.power(x, exp) if exp >= 1 else np.power(np.maximum(x, 0), exp) + diff --git a/fn_gen/rnd_search/9/loss.png b/fn_gen/rnd_search/9/loss.png new file mode 100644 index 0000000000000000000000000000000000000000..69fef40cdd1ab8425f9007cd1991984f7e25efdf Binary files /dev/null and b/fn_gen/rnd_search/9/loss.png differ diff --git a/fn_gen/rnd_search/9/quantization.png b/fn_gen/rnd_search/9/quantization.png new file mode 100644 index 0000000000000000000000000000000000000000..bb079b33a0d5c44e97715d6c341793fe2a92ebf9 Binary files /dev/null and b/fn_gen/rnd_search/9/quantization.png differ