File size: 17,273 Bytes
4685bcc |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 |
from __future__ import annotations
import torch
from torch import amin # Necessary for arcsin
import copy
import torch.nn as nn
import numpy as np
from scipy.optimize import curve_fit
from typing import Dict, Any, Tuple, List, Callable
def quantization(x, **params):
    """Forward transform: ``x ** 3`` multiplied by the reciprocal of ``params['_s']``.

    Entries of ``_s`` that are exactly 0 are swapped for 10000 before the
    reciprocal is taken, so the division never sees a zero scale.
    """
    safe_scale = replace_num(params['_s'], num=0, to=10000)
    cubed = guarded_torch_power(x, torch.tensor(3))
    return torch.div(1, safe_scale) * cubed
def dequantization(x, **params):
    """Inverse transform: cube root of ``params['_s'] * x``.

    The root goes through ``guarded_torch_power`` with a fractional exponent,
    which (as implemented in this module) zeroes negative inputs first.
    """
    scaled = params['_s'] * x
    return guarded_torch_power(scaled, 1 / 3)
def init_linear_scale(  # Symmetric scale. From the study folder
    x: torch.Tensor,
    **kwargs: Dict[str, Any],
) -> torch.Tensor:
    """Compute a symmetric per-channel scale for ``x`` after ``qtz_func``.

    ``x`` is transposed, passed through ``qtz_func`` with ``_s`` set to ones
    (plus any entries of ``params``), transposed back and reduced along
    ``dim=1``. The scale is the per-channel absolute maximum (with the range
    forced to include 0) divided by half of the signed ``bits``-wide integer
    span, clamped from below at float32 eps.

    Required kwargs: ``bits``, ``params``, ``qtz_func``.

    NOTE(review): ``init_linear_scale`` is defined twice in this module with
    the same body; the later definition is the one bound at import time.
    """
    assert "bits" in kwargs, "bits must be provided."
    assert "params" in kwargs, "params must be provided."
    assert "qtz_func" in kwargs, "qtz_func must be provided."
    n_bits = kwargs.get('bits')
    extra_params = kwargs.get('params')
    transform = kwargs.get('qtz_func')

    unit_scale = init_ones(x, **kwargs)
    transformed = transform(x=x.transpose(0, 1), **extra_params, _s=unit_scale).transpose(0, 1)

    lo, hi = get_min_max_from_bits_signed(n_bits)
    ch_min, ch_max = torch.aminmax(transformed, dim=1)
    # Widen each channel's range so it always contains zero.
    ch_min = torch.min(ch_min, torch.zeros_like(ch_min))
    ch_max = torch.max(ch_max, torch.zeros_like(ch_max))
    ch_abs_max = torch.max(-ch_min, ch_max)

    half_span = float(hi - lo) / 2
    scale = ch_abs_max / half_span
    eps = torch.finfo(torch.float32).eps
    return torch.clamp(scale, min=eps).to(dtype=torch.float32, device=ch_min.device)
def init_params(x: torch.Tensor, **kwargs: Dict[str, Any]) -> Dict[str, nn.Parameter]:
params = {
}
params['_s'] = init_linear_scale(x, params=params, qtz_func=quantization, **kwargs)
params = {k: nn.Parameter(v, requires_grad=False) for k, v in params.items()}
if 'post_init_hook' in kwargs:
kwargs['post_init_hook'](parameters=params)
params = learn_parameters(x, params,
qtz_func=quantization,
deqtz_func=dequantization,
bits=kwargs['bits'],
target_dtype=torch.int8,
epochs=500,
early_stop=False,
)
if 'post_train_hook' in kwargs:
kwargs['post_train_hook'](parameters=params)
return params
############### Numpy Qtz ###############
def np_quantization(x, _s):
    """NumPy twin of ``quantization``: ``x ** 3`` times ``1 / _s`` (zeros in ``_s`` become 10000)."""
    denominator = np_replace_num(_s, num=0, to=10000)
    cubed = np_guarded_power(x, np.array(3))
    return np.divide(1, denominator) * cubed
def np_dequantization(x, _s):
    """NumPy twin of ``dequantization``: guarded cube root of ``_s * x``."""
    product = _s * x
    return np_guarded_power(product, 1 / 3)
def fit_func(x, _s):
    """Round-trip ``x`` through ``np_quantization`` then ``np_dequantization`` with scale ``_s``.

    Presumably intended as the ``np_fit_func`` model for
    ``init_non_linear_regression_fit`` — TODO confirm with callers.
    """
    return np_dequantization(np_quantization(x, _s), _s)
############### HELPERS ###############
def domain_guard(
    x: torch.Tensor,
    min: float = None,
    max: float = None,
    posinf: float = None,
    neginf: float = None,
    nan: float = None
) -> torch.Tensor:
    """Replace non-finite values of ``x`` and optionally clamp the result.

    NaN / +inf / -inf are replaced through ``torch.nan_to_num`` (``None``
    keeps torch's defaults). When at least one of ``min`` / ``max`` is given,
    the cleaned tensor is then clamped to that range.
    """
    cleaned = torch.nan_to_num(x, posinf=posinf, neginf=neginf, nan=nan)
    if min is None and max is None:
        return cleaned
    return torch.clamp(cleaned, min=min, max=max)
def replace_num(x: torch.Tensor, num: float, to: float) -> torch.Tensor:
    """Return ``x`` with every element equal to ``num`` replaced by ``to``.

    Args:
        x (torch.Tensor): The input tensor.
        num (float): The value to look for.
        to (float): The substitute value.

    Returns:
        torch.Tensor: A new tensor; ``x`` itself is not modified.
    """
    mask = x == num
    return torch.where(mask, to, x)
def guarded_torch_power(x: torch.Tensor, exp: float) -> torch.Tensor:
    """Raise ``x`` to ``exp``; when ``exp`` is not ``>= 1``, negatives are zeroed first."""
    if exp >= 1:
        return torch.pow(x, exp)
    # Non-integer powers of negative values are NaN, so clamp them to 0 via relu.
    return torch.pow(torch.relu(x), exp)
def init_ones(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor:
    """Return float32 ones shaped like ``x`` reduced over ``dim=1``, on ``x``'s device.

    Extra ``kwargs`` are accepted and ignored.
    """
    template = x.amin(dim=1)
    return torch.ones_like(template, dtype=torch.float32, device=x.device)
def init_rand(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor:
    """Return standard-normal float32 samples shaped like ``x`` reduced over ``dim=1``.

    Extra ``kwargs`` are accepted and ignored.
    """
    template = x.amin(dim=1)
    return torch.randn_like(template, dtype=torch.float32, device=x.device)
def init_space_search(
    x: torch.Tensor,
    **kwargs: Any,
) -> torch.Tensor:
    """Random search for an initial value of a single quantizer parameter.

    The parameter named ``param`` is searched while every other name in
    ``params_list`` is fixed to ``init_ones(x)``. A candidate is scored by the
    MSE between ``x`` and its quantize -> dequantize round trip (computed on
    ``x.transpose(0, 1)``). The first run scores ``10 * n_random_params``
    candidates drawn as ``init_rand(x) * max_initial``; each later run scores
    ``n_random_params`` candidates sampled around the mean of the previous
    run's ``n_best_to_pick`` best. Candidates that raise, or whose loss is
    NaN/Inf, are discarded. If a run yields no valid candidate the search
    stops and keeps the previous run's best.

    Required kwargs:
        qtz_func, deqtz_func: callables taking ``x=`` plus the params as kwargs.
        params_list: names of every parameter the callables expect.
        param: name of the parameter being searched.

    Returns:
        Whichever of (best searched value, ``init_ones(x)``, ``init_rand(x)``)
        has the lowest loss; ties favour that order. Falls back to
        ``init_ones(x)`` when nothing produced a finite loss.
    """
    def _build_initial_param(tensor: torch.Tensor, max_initial: int, n_params: int):
        """Yield ``n_params * 10`` candidates, each ``init_rand(tensor) * max_initial``."""
        for _ in range(n_params * 10):
            yield init_rand(tensor) * max_initial

    def _search_param(tensors: List[torch.Tensor], n_params: int):
        """Yield ``n_params`` candidates sampled around the mean of ``tensors``."""
        torch_tensors = torch.stack(tensors)
        min_vals, max_vals = torch.aminmax(torch_tensors, dim=0)
        abs_max_val_per_ch = torch.max(-min_vals, max_vals)
        mean = torch.mean(torch_tensors, dim=0)
        for _ in range(n_params):
            yield torch.randn_like(min_vals) * abs_max_val_per_ch + mean

    def _calc(x, qtz_func, deqtz_func, **params):
        """Quantize then dequantize ``x`` along its transposed layout."""
        x_ = x.transpose(0, 1)
        x_ = qtz_func(x=x_, **params)
        x_ = deqtz_func(x=x_, **params)
        return x_.transpose(0, 1)

    def _score(candidate: torch.Tensor) -> float:
        """Round-trip MSE for ``candidate``; ``inf`` when it is invalid or non-finite."""
        try:
            x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: candidate})
            value = loss_fn(x, x_).item()
        except Exception:  # The candidate might not be valid for the function's domain
            return float("inf")
        return value if np.isfinite(value) else float("inf")

    assert "qtz_func" in kwargs, "qtz_func must be provided."
    assert "deqtz_func" in kwargs, "deqtz_func must be provided."
    assert "params_list" in kwargs, "params list must be provided."
    assert "param" in kwargs, "param must be provided."
    qtz_func = kwargs.get('qtz_func')
    deqtz_func = kwargs.get('deqtz_func')
    params_list = kwargs.get('params_list')
    param = kwargs.get('param')

    n_runs = 50  # Number of runs to try to find the best parameters
    n_random_params = 50  # Number of random parameters to generate
    n_best_to_pick = 5  # Number of best parameters to pick after each run
    max_initial = 10000  # Maximum value to initialize the parameters
    loss_fn = nn.MSELoss()

    # Every other parameter is held at ones while ``param`` is searched.
    base_params = {p: init_ones(x, **kwargs) for p in params_list if p != param}
    candidates = _build_initial_param(x, max_initial, n_random_params)

    best_params: List[Tuple[torch.Tensor, float]] = []
    for _ in range(n_runs):
        run_best: List[Tuple[torch.Tensor, float]] = []
        for candidate in candidates:
            loss = _score(candidate)
            if loss == float("inf"):
                continue  # invalid or non-finite; never let it into the ranking
            if len(run_best) < n_best_to_pick:
                run_best.append((candidate, loss))
            elif loss < run_best[-1][1]:
                run_best[-1] = (candidate, loss)
            else:
                continue
            run_best.sort(key=lambda item: item[1])
        if not run_best:
            break  # nothing valid to resample around; keep the previous run's best
        best_params = run_best
        # Generates new parameters around the mean of this run's best
        candidates = _search_param([p for p, _ in best_params], n_random_params)

    # Compare the search result against the plain ones / random initializations.
    p_ones = init_ones(x, **kwargs)
    loss_ones = _score(p_ones)
    p_rand = init_rand(x, **kwargs)
    loss_rand = _score(p_rand)

    finalists = list(best_params[:1])
    finalists.append((p_ones, loss_ones))
    finalists.append((p_rand, loss_rand))
    # ``min`` keeps the first minimum, so ties favour search result -> ones -> rand.
    return min(finalists, key=lambda item: item[1])[0]
def init_linear_scale(  # Symmetric scale. From the study folder
    x: torch.Tensor,
    **kwargs: Dict[str, Any],
) -> torch.Tensor:
    """Compute a symmetric per-channel scale for ``x`` after ``qtz_func``.

    ``x`` is transposed, passed through ``qtz_func`` with ``_s`` set to ones
    (plus any entries of ``params``), transposed back and reduced along
    ``dim=1``. The scale is the per-channel absolute maximum (with the range
    forced to include 0) divided by half of the signed ``bits``-wide integer
    span, clamped from below at float32 eps.

    Required kwargs: ``bits``, ``params``, ``qtz_func``.

    NOTE(review): ``init_linear_scale`` is defined twice in this module with
    the same body; this later definition is the one bound at import time.
    """
    assert "bits" in kwargs, "bits must be provided."
    assert "params" in kwargs, "params must be provided."
    assert "qtz_func" in kwargs, "qtz_func must be provided."
    n_bits = kwargs.get('bits')
    extra_params = kwargs.get('params')
    transform = kwargs.get('qtz_func')

    unit_scale = init_ones(x, **kwargs)
    transformed = transform(x=x.transpose(0, 1), **extra_params, _s=unit_scale).transpose(0, 1)

    lo, hi = get_min_max_from_bits_signed(n_bits)
    ch_min, ch_max = torch.aminmax(transformed, dim=1)
    # Widen each channel's range so it always contains zero.
    ch_min = torch.min(ch_min, torch.zeros_like(ch_min))
    ch_max = torch.max(ch_max, torch.zeros_like(ch_max))
    ch_abs_max = torch.max(-ch_min, ch_max)

    half_span = float(hi - lo) / 2
    scale = ch_abs_max / half_span
    eps = torch.finfo(torch.float32).eps
    return torch.clamp(scale, min=eps).to(dtype=torch.float32, device=ch_min.device)
def init_non_linear_regression_fit(
    x: torch.Tensor,
    **kwargs: Any,
) -> Dict[str, torch.Tensor]:
    """Fit per-row parameters of ``np_fit_func`` so it reproduces each row of ``x``.

    Each row of ``x`` (rows are indexed along dim 0; presumably ``x`` is 2-D —
    TODO confirm) is sorted and used as both xdata and ydata for
    ``scipy.optimize.curve_fit`` (Levenberg-Marquardt, ``maxfev=1000``),
    starting from ``p0[name][row]``. Parameters are passed to the fit function
    positionally in sorted name order.

    Required kwargs:
        params_list: names of the parameters to fit.
        np_fit_func: NumPy model ``f(xdata, *params)``, params in sorted name order.
        p0: dict mapping each name to a per-row tensor of initial guesses.

    Returns:
        Dict mapping each name in ``params_list`` to a float32 tensor with one
        value per row, on ``x``'s device. If any row fails to fit
        (``ValueError`` or ``RuntimeError`` from ``curve_fit``), a message is
        printed and every entry of ``p0`` is returned instead.
    """
    assert "params_list" in kwargs, "params list must be provided."
    assert "np_fit_func" in kwargs, "np_fit_func must be provided."
    assert "p0" in kwargs, "p0 must be provided."
    np_fit_func = kwargs.get('np_fit_func')
    params_list = kwargs.get('params_list')
    p0 = kwargs.get('p0')

    def _fit(xdata: np.ndarray, ydata: np.ndarray, func: Callable, p0: List[float]):
        """Run curve_fit and return only the optimal parameters."""
        popt, _ = curve_fit(
            func,
            xdata,
            ydata,
            maxfev=1000,
            p0=p0,
            method='lm'
        )
        return popt

    # 1. Convert to NumPy; detach first so tensors that require grad
    #    (e.g. nn.Parameter) can be converted.
    xdata = x.detach().cpu().numpy()
    # 2. Sort each row so that it is easier to fit to it
    sorted_xdata = np.sort(xdata, axis=-1)
    p0 = {k: v.detach().cpu().numpy() for k, v in p0.items()}
    params_list = sorted(params_list)  # Must match the numpy fit func arg order

    # 3. Find the best parameters for each row
    try:
        fitted = [
            _fit(row, row, np_fit_func, [p0[p][i] for p in params_list])
            for i, row in enumerate(sorted_xdata)
        ]
    except (ValueError, RuntimeError) as e:
        # curve_fit raises RuntimeError when maxfev is reached without converging.
        print(f"Could not fit the function with error: {e}")
        print("Using fallback result...")
        return {
            k: torch.tensor(v, dtype=torch.float32).to(x.device) for k, v in p0.items()
        }

    # 4. Regroup the per-row solutions into one tensor per parameter name
    return {
        p: torch.tensor([popt[j] for popt in fitted], dtype=torch.float32).to(x.device)
        for j, p in enumerate(params_list)
    }
def init_zeros(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor:
    """Return float32 zeros shaped like ``x`` reduced over ``dim=1``, on ``x``'s device.

    Extra ``kwargs`` are accepted and ignored.
    """
    template = x.amin(dim=1)
    return torch.zeros_like(template, dtype=torch.float32, device=x.device)
def init_inner_scale(tensor: torch.Tensor, _min: float = torch.inf, _max: float = torch.inf) -> torch.Tensor:
    """Per-row factor mapping the zero-inclusive range of ``tensor`` onto ``[_min, _max]``.

    Rows are reduced along the last dim and each row's ``[min, max]`` is
    widened to include 0. The result is ``(_max - _min) / (x_max - x_min)``.
    When ``_max`` is +inf (the default), no rescaling is requested and ones
    are returned.

    NOTE(review): an all-zero row has a zero denominator and yields inf/nan.
    """
    # Calculate the original minimum and maximum values
    min_vals, max_vals = torch.aminmax(tensor, dim=-1)
    x_min = torch.min(min_vals, torch.zeros_like(min_vals))
    x_max = torch.max(max_vals, torch.zeros_like(max_vals))
    # Compare by value, not identity: ``float("inf") is torch.inf`` is False.
    if isinstance(_max, (int, float)) and _max == torch.inf:
        # We do not need to scale the tensor. Just need to move it
        return torch.ones_like(x_min)
    return (_max - _min) / (x_max - x_min)
############## Quant ###############
@torch.enable_grad()
def learn_parameters(
    x: torch.Tensor,
    params: Dict[str, nn.Parameter],
    qtz_func: Callable,
    deqtz_func: Callable,
    bits: int,
    target_dtype: torch.dtype,
    epochs: int = 1000,
    early_stop: bool = True,
    do_report: bool = False
) -> Dict[str, nn.Parameter] | Tuple[Dict[str, nn.Parameter], List[float]]:
    """Optimize ``params`` with Adam to minimize the quantize/dequantize MSE on ``x``.

    The learning rate starts at ``0.1 * 10 ** (floor(log10(initial_loss)) // 2)``
    (or ``0.1`` when the initial loss is zero or non-finite) and decays
    linearly to 10% over the first ``epochs // 10`` steps. With
    ``early_stop``, training stops once the loss changed by less than ``1e-7``
    over the last 10% of ``epochs``. ``params`` are modified in place and
    left with ``requires_grad=True``.

    Returns:
        A deep copy of the parameters that achieved the lowest observed loss
        (``requires_grad=False``). With ``do_report=True``, the tuple
        ``(best_params, per-epoch losses)``.

    Raises:
        Exception: if the loss becomes NaN or Inf during training.
    """
    loss_fn = nn.MSELoss()

    def _round_trip_loss() -> torch.Tensor:
        """MSE between ``x`` and its quantize -> dequantize reconstruction."""
        quant = quantize(x, params, qtz_func, bits, target_dtype)
        dequant = dequantize(quant, params, deqtz_func, bits, x.dtype)
        return loss_fn(x, dequant)

    # Scale the initial lr by half the order of magnitude of the initial loss;
    # log10 is undefined for a zero/NaN/Inf loss, so fall back to base_lr.
    base_lr = 0.1
    initial_loss = _round_trip_loss().item()
    if initial_loss > 0 and np.isfinite(initial_loss):
        exponent = int(np.floor(np.log10(initial_loss)))
        lr = base_lr * (10 ** (exponent // 2))
    else:
        lr = base_lr

    # Requires gradients in the parameters
    for p in params.values():
        p.requires_grad = True
        p.grad = None

    optimizer = torch.optim.Adam(list(params.values()), lr=lr)
    scheduler = torch.optim.lr_scheduler.LinearLR(
        optimizer, start_factor=1.0, end_factor=0.1, total_iters=epochs // 10
    )

    best_loss = float("inf")
    best_params = None

    # Early stopping compares against the loss this many epochs ago. At least
    # 1, otherwise a window of 0 compares a loss with itself and always stops.
    min_delta = 1e-7
    percent_epochs_before_stop = 0.1
    epochs_before_stop = max(1, int(epochs * percent_epochs_before_stop))
    acc_loss = []

    for i in range(epochs):
        optimizer.zero_grad()
        loss = _round_trip_loss()
        if loss.isnan() or loss.isinf():
            raise Exception("Loss is NaN or Inf. Stopping the search.")
        loss_value = loss.item()
        acc_loss.append(loss_value)

        # Reports loss every 10 steps
        if i % 10 == 0 and do_report:
            print(f"Epoch {i}: Loss {loss_value}")

        # Snapshot BEFORE the optimizer step: ``loss`` was computed with the
        # current values, so these are the parameters that achieved it.
        if loss_value < best_loss:
            best_loss = loss_value
            best_params = copy.deepcopy(dict(params))

        loss.backward()
        optimizer.step()
        scheduler.step()

        if (
            early_stop
            and i > epochs_before_stop
            and abs(acc_loss[i - epochs_before_stop] - loss_value) < min_delta
        ):
            break

    if best_params is None:  # epochs == 0: the loop never evaluated anything
        best_loss = initial_loss
        best_params = copy.deepcopy(dict(params))

    # No longer requires gradients in the returned parameters
    for p in best_params.values():
        p.requires_grad = False
        p.grad = None

    if do_report:
        print(f"Best loss: {best_loss}")
        return best_params, acc_loss
    return best_params
def quantize(
    x: torch.Tensor,
    params: Dict[str, nn.Parameter],
    func: nn.Module,
    bits: int,
    target_dtype: torch.dtype = torch.int8
) -> torch.Tensor:
    """Apply ``func`` to the transposed ``x``, round, clamp to the signed range, then cast.

    ``func`` receives ``x.transpose(0, 1)`` (presumably so per-channel params
    broadcast against the last dim — TODO confirm) and the result is
    transposed back. Rounding uses ``round_func_BPDA``; values are clamped to
    the signed ``bits``-wide range before casting to ``target_dtype``.
    """
    lower, upper = get_min_max_from_bits_signed(bits)
    transformed = func(x=x.transpose(0, 1), **params).transpose(0, 1)
    rounded = round_func_BPDA(transformed)
    return torch.clamp(rounded, lower, upper).to(target_dtype)
def dequantize(
    x: torch.Tensor,
    params: Dict[str, nn.Parameter],
    func: nn.Module,
    bits: int,
    out_dtype: torch.dtype
) -> torch.Tensor:
    """Cast ``x`` to ``out_dtype`` and apply ``func`` on its transposed layout.

    ``bits`` is accepted for signature symmetry with ``quantize`` but unused.
    """
    as_float = x.to(dtype=out_dtype)
    return func(x=as_float.transpose(0, 1), **params).transpose(0, 1)
def round_func_BPDA(input):
    """Round in the forward pass while backpropagating as the identity.

    A clone of ``input`` keeps the autograd link; only its ``.data`` is swapped
    for the rounded values, so gradients flow straight through.
    """
    result = input.clone()
    result.data = torch.round(input).data
    return result
def get_min_max_from_bits_signed(bit_width: int) -> Tuple[int, int]:
    """Return ``(min, max)`` of a two's-complement signed integer with ``bit_width`` bits."""
    half = 2 ** (bit_width - 1)
    return -half, half - 1
############## Numpy ###############
def np_domain_guard(
    x: np.ndarray,
    min: float = None,
    max: float = None,
    posinf: float = None,
    neginf: float = None,
    nan: float = None
) -> np.ndarray:
    """Replace non-finite values of ``x`` and optionally clip it (NumPy twin of ``domain_guard``).

    ``nan=None`` means "replace NaN with 0.0", matching ``torch.nan_to_num``.
    ``posinf`` / ``neginf`` of ``None`` keep NumPy's defaults (dtype max/min).
    When at least one of ``min`` / ``max`` is given, the result is clipped.
    """
    # np.nan_to_num writes ``nan`` into the array, so passing None raises a
    # TypeError; translate it to the documented default instead.
    x = np.nan_to_num(x, nan=0.0 if nan is None else nan, posinf=posinf, neginf=neginf)
    if min is not None or max is not None:
        x = np.clip(x, min, max)
    return x
def np_replace_num(x: np.ndarray, num: float, to: float) -> np.ndarray:
    """Return ``x`` with every element equal to ``num`` replaced by ``to``.

    Args:
        x (np.ndarray): The input array.
        num (float): The value to look for.
        to (float): The substitute value.

    Returns:
        np.ndarray: A new array; ``x`` itself is not modified.
    """
    hits = x == num
    return np.where(hits, to, x)
def np_guarded_power(x: np.ndarray, exp: float) -> np.ndarray:
    """Raise ``x`` to ``exp``; when ``exp`` is not ``>= 1``, negatives are zeroed first."""
    if exp >= 1:
        return np.power(x, exp)
    # Non-integer powers of negative values are NaN, so clamp them to 0 first.
    return np.power(np.maximum(x, 0), exp)
|