Diogo-V committed on
Commit
c779e0a
1 Parent(s): 8588234

Upload learned functions

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. fn_gen/nlr_t_no_sched/0/distortion.png +0 -0
  2. fn_gen/nlr_t_no_sched/0/expressions.txt +2 -0
  3. fn_gen/nlr_t_no_sched/0/fn.py +479 -0
  4. fn_gen/nlr_t_no_sched/0/loss.png +0 -0
  5. fn_gen/nlr_t_no_sched/0/quantization.png +0 -0
  6. fn_gen/nlr_t_no_sched/1/distortion.png +0 -0
  7. fn_gen/nlr_t_no_sched/1/expressions.txt +2 -0
  8. fn_gen/nlr_t_no_sched/1/fn.py +479 -0
  9. fn_gen/nlr_t_no_sched/1/loss.png +0 -0
  10. fn_gen/nlr_t_no_sched/1/quantization.png +0 -0
  11. fn_gen/nlr_t_no_sched/10/distortion.png +0 -0
  12. fn_gen/nlr_t_no_sched/10/expressions.txt +2 -0
  13. fn_gen/nlr_t_no_sched/10/fn.py +479 -0
  14. fn_gen/nlr_t_no_sched/10/loss.png +0 -0
  15. fn_gen/nlr_t_no_sched/10/quantization.png +0 -0
  16. fn_gen/nlr_t_no_sched/11/distortion.png +0 -0
  17. fn_gen/nlr_t_no_sched/11/expressions.txt +2 -0
  18. fn_gen/nlr_t_no_sched/11/fn.py +478 -0
  19. fn_gen/nlr_t_no_sched/11/loss.png +0 -0
  20. fn_gen/nlr_t_no_sched/11/quantization.png +0 -0
  21. fn_gen/nlr_t_no_sched/12/distortion.png +0 -0
  22. fn_gen/nlr_t_no_sched/12/expressions.txt +2 -0
  23. fn_gen/nlr_t_no_sched/12/fn.py +479 -0
  24. fn_gen/nlr_t_no_sched/12/loss.png +0 -0
  25. fn_gen/nlr_t_no_sched/12/quantization.png +0 -0
  26. fn_gen/nlr_t_no_sched/14/distortion.png +0 -0
  27. fn_gen/nlr_t_no_sched/14/expressions.txt +2 -0
  28. fn_gen/nlr_t_no_sched/14/fn.py +479 -0
  29. fn_gen/nlr_t_no_sched/14/loss.png +0 -0
  30. fn_gen/nlr_t_no_sched/14/quantization.png +0 -0
  31. fn_gen/nlr_t_no_sched/16/distortion.png +0 -0
  32. fn_gen/nlr_t_no_sched/16/expressions.txt +2 -0
  33. fn_gen/nlr_t_no_sched/16/fn.py +479 -0
  34. fn_gen/nlr_t_no_sched/16/loss.png +0 -0
  35. fn_gen/nlr_t_no_sched/16/quantization.png +0 -0
  36. fn_gen/nlr_t_no_sched/17/distortion.png +0 -0
  37. fn_gen/nlr_t_no_sched/17/expressions.txt +2 -0
  38. fn_gen/nlr_t_no_sched/17/fn.py +479 -0
  39. fn_gen/nlr_t_no_sched/17/loss.png +0 -0
  40. fn_gen/nlr_t_no_sched/17/quantization.png +0 -0
  41. fn_gen/nlr_t_no_sched/18/distortion.png +0 -0
  42. fn_gen/nlr_t_no_sched/18/expressions.txt +2 -0
  43. fn_gen/nlr_t_no_sched/18/fn.py +479 -0
  44. fn_gen/nlr_t_no_sched/18/loss.png +0 -0
  45. fn_gen/nlr_t_no_sched/18/quantization.png +0 -0
  46. fn_gen/nlr_t_no_sched/2/distortion.png +0 -0
  47. fn_gen/nlr_t_no_sched/2/expressions.txt +2 -0
  48. fn_gen/nlr_t_no_sched/2/fn.py +479 -0
  49. fn_gen/nlr_t_no_sched/2/loss.png +0 -0
  50. fn_gen/nlr_t_no_sched/2/quantization.png +0 -0
fn_gen/nlr_t_no_sched/0/distortion.png ADDED
fn_gen/nlr_t_no_sched/0/expressions.txt ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ log(_0*x)/_s
2
+ exp(_s*x)/_0
fn_gen/nlr_t_no_sched/0/fn.py ADDED
@@ -0,0 +1,479 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import torch
4
+ from torch import amin # Necessary for arcsin
5
+ import copy
6
+ import torch.nn as nn
7
+ import numpy as np
8
+
9
+ from scipy.optimize import curve_fit
10
+ from typing import Dict, Any, Tuple, List, Callable
11
+
12
+
13
def quantization(x, **params):
    """Apply the learned forward transform ``log(_0 * x) / _s``.

    ``params`` must contain ``'_0'`` and ``'_s'``. A ``_s`` equal to 0 is
    swapped for 10000 before the reciprocal is taken, and the argument of the
    logarithm has NaN replaced by 1e-5 and is clamped to at least 1e-5.
    """
    safe_scale = replace_num(params['_s'], num=0, to=10000)
    log_input = domain_guard(params['_0'] * x, min=1e-5, nan=1e-5)
    return torch.div(1, safe_scale) * torch.log(log_input)
15
+
16
+
17
def dequantization(x, **params):
    """Apply the learned inverse transform ``exp(_s * x) / _0``.

    ``params`` must contain ``'_0'`` and ``'_s'``. A ``_0`` equal to 0 is
    swapped for 10000 before the reciprocal is taken.
    """
    safe_divisor = replace_num(params['_0'], num=0, to=10000)
    return torch.div(1, safe_divisor) * torch.exp(params['_s'] * x)
19
+
20
+
21
def init_params(x: torch.Tensor, **kwargs: Dict[str, Any]) -> Dict[str, nn.Parameter]:
    """Initialise, fit and then train the ``_0`` / ``_s`` parameters for ``x``.

    Pipeline: random space search for ``_0``, linear scale for ``_s``,
    per-channel non-linear regression, then 500 epochs of gradient training.

    ``kwargs`` must contain ``'bits'``. The optional callables
    ``'post_init_hook'``, ``'post_method_hook'`` and ``'post_train_hook'`` are
    invoked with ``parameters=<dict>`` after the matching stage.
    """

    def _fire(hook_name, parameters):
        # Hooks are optional; only call the ones the caller supplied.
        if hook_name in kwargs:
            kwargs[hook_name](parameters=parameters)

    # Stage 1: starting point for every parameter.
    initial = {}
    initial['_0'] = init_space_search(
        x,
        qtz_func=quantization,
        deqtz_func=dequantization,
        params_list=['_0', '_s'],
        param='_0',
        **kwargs,
    )
    initial['_s'] = init_linear_scale(x, qtz_func=quantization, params=initial, **kwargs)
    _fire('post_init_hook', initial)

    # Stage 2: refine the starting point with a least-squares fit.
    fitted = init_non_linear_regression_fit(
        x,
        p0=initial,
        np_fit_func=fit_func,
        qtz_func=quantization,
        deqtz_func=dequantization,
        params_list=['_0', '_s'],
        **kwargs,
    )
    frozen = {name: nn.Parameter(value, requires_grad=False) for name, value in fitted.items()}
    _fire('post_method_hook', frozen)

    # Stage 3: gradient-based training on the quantize/dequantize round trip.
    learned = learn_parameters(
        x,
        frozen,
        qtz_func=quantization,
        deqtz_func=dequantization,
        bits=kwargs['bits'],
        target_dtype=torch.int8,
        epochs=500,
        early_stop=False,
    )
    _fire('post_train_hook', learned)

    return learned
47
+
48
+
49
+ ############### Numpy Qtz ###############
50
+
51
+
52
def np_quantization(x, _0, _s):
    """NumPy twin of ``quantization``: ``log(_0 * x) / _s`` with the same guards."""
    safe_scale = np_replace_num(_s, num=0, to=10000)
    log_input = np_domain_guard(_0 * x, min=1e-5, nan=1e-5)
    return np.divide(1, safe_scale) * np.log(log_input)
54
+
55
+
56
def np_dequantization(x, _0, _s):
    """NumPy twin of ``dequantization``: ``exp(_s * x) / _0`` with the same guard."""
    safe_divisor = np_replace_num(_0, num=0, to=10000)
    return np.divide(1, safe_divisor) * np.exp(_s * x)
58
+
59
+
60
def fit_func(x, _0, _s):
    """Round-trip ``x`` through the NumPy quantization and dequantization.

    No rounding is applied in between; this is the model function handed to
    the curve fit.
    """
    return np_dequantization(np_quantization(x, _0, _s), _0, _s)
64
+
65
+
66
+
67
+ ############### HELPERS ###############
68
+
69
def domain_guard(
    x: torch.Tensor,
    min: float = None,
    max: float = None,
    posinf: float = None,
    neginf: float = None,
    nan: float = None
) -> torch.Tensor:
    """Sanitise a tensor so it lies inside a function's valid domain.

    ``torch.nan_to_num`` is always applied (arguments left as ``None`` fall
    back to that function's own defaults). The result is clamped only when
    at least one of ``min`` / ``max`` is given.
    """
    cleaned = torch.nan_to_num(x, posinf=posinf, neginf=neginf, nan=nan)
    if min is None and max is None:
        return cleaned
    return torch.clamp(cleaned, min=min, max=max)
82
+
83
+
84
def replace_num(x: torch.Tensor, num: float, to: float) -> torch.Tensor:
    """Return ``x`` with every element equal to ``num`` swapped for ``to``.

    Args:
        x (torch.Tensor): Tensor to scan.
        num (float): Value to look for (exact equality).
        to (float): Value written wherever ``num`` is found.

    Returns:
        torch.Tensor: Result of the element-wise substitution.
    """
    is_target = x == num
    return torch.where(is_target, to, x)
96
+
97
+
98
def guarded_torch_power(x: torch.Tensor, exp: float) -> torch.Tensor:
    """Raise ``x`` to ``exp``, zeroing negative bases when ``exp`` is below 1.

    For ``exp >= 1`` the power is taken directly; otherwise ``x`` is first
    passed through ``relu`` so negative values become 0.
    """
    if exp >= 1:
        return torch.pow(x, exp)
    non_negative = torch.relu(x)
    return torch.pow(non_negative, exp)
101
+
102
+
103
def init_ones(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor:
    """Return a float32 tensor of ones shaped like ``x`` reduced over dim 1.

    Extra keyword arguments are accepted and ignored.
    """
    per_row = torch.amin(x, dim=1)  # only used as a shape template
    return torch.ones_like(per_row, dtype=torch.float32, device=x.device)
106
+
107
+
108
def init_rand(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor:
    """Return float32 standard-normal samples shaped like ``x`` reduced over dim 1.

    Extra keyword arguments are accepted and ignored.
    """
    per_row = torch.amin(x, dim=1)  # only used as a shape template
    return torch.randn_like(per_row, dtype=torch.float32, device=x.device)
111
+
112
+
113
def init_space_search(
    x: torch.Tensor,
    **kwargs: Dict[str, Any],
) -> torch.Tensor:
    """Randomly search for a starting value of one parameter.

    Required ``kwargs``: ``qtz_func``, ``deqtz_func``, ``params_list`` (all
    parameter names) and ``param`` (the name being searched). Every other
    parameter is held at ones. Candidates are scored by the MSE between ``x``
    and its quantize -> dequantize round trip (without rounding).

    The search runs 50 generations. The first draws 500 normal samples scaled
    by 10000. Each later one draws 50 samples around the mean of the previous
    generation's five best candidates. The winner of the last generation is
    finally compared against an all-ones and a fresh random candidate.

    NOTE(review): the ranking is rebuilt from scratch every generation, so a
    better candidate from an earlier generation is not retained. Confirm this
    is intended.
    NOTE(review): a NaN loss can enter the ranking while it holds fewer than
    five entries. Sorting with NaN keys is not well defined.
    """

    def _initial_candidates(tensor: torch.Tensor, spread: int, count: int):
        """Yield the first generation, which is ten times larger than later ones."""
        for _ in range(count * 10):
            yield init_rand(tensor) * spread

    def _resample(survivors: List[torch.tensor], count):
        """Yield ``count`` samples centred on the mean of ``survivors``."""
        stacked = torch.stack(survivors)
        lows, highs = torch.aminmax(stacked, dim=0)
        radius = torch.max(-lows, highs)
        centre = torch.mean(stacked, dim=0)
        for _ in range(count):
            yield torch.randn_like(lows) * radius + centre

    def _round_trip(data, forward, backward, **fn_params):
        """Quantize then dequantize ``data`` (transposed so shapes align)."""
        out = forward(x=data.transpose(0, 1), **fn_params)
        out = backward(x=out, **fn_params)
        return out.transpose(0, 1)

    required = (
        ("qtz_func", "qtz_func must be provided."),
        ("deqtz_func", "deqtz_func must be provided."),
        ("params_list", "params list must be provided."),
        ("param", "param must be provided."),
    )
    for key, message in required:
        assert key in kwargs, message

    qtz_func = kwargs.get('qtz_func')
    deqtz_func = kwargs.get('deqtz_func')
    params_list = kwargs.get('params_list')
    param = kwargs.get('param')

    n_runs = 50           # generations
    n_random_params = 50  # candidates per generation (x10 for the first)
    n_best_to_pick = 5    # survivors kept per generation
    max_initial = 10000   # spread of the first generation

    mse = nn.MSELoss()

    def _by_loss(entry):
        return entry[1]

    # Every parameter other than the searched one stays fixed at ones.
    fixed = {name: init_ones(x, **kwargs) for name in params_list if name != param}
    candidates = _initial_candidates(x, max_initial, n_random_params)

    for _ in range(n_runs):
        ranked = []
        for candidate in candidates:
            try:
                restored = _round_trip(x, qtz_func, deqtz_func, **fixed, **{param: candidate})
                loss = mse(x, restored)

                if len(ranked) < n_best_to_pick:
                    ranked.append((candidate, loss.item()))
                    ranked.sort(key=_by_loss)
                elif loss < ranked[-1][1]:
                    # Evict the current worst survivor.
                    ranked[-1] = (candidate, loss.item())
                    ranked.sort(key=_by_loss)

            except Exception:  # The candidate might be outside the function's domain
                continue

        # Next generation is drawn around this generation's survivors.
        candidates = _resample([tensor for tensor, _ in ranked], n_random_params)

    # Baseline 1: all ones.
    ones_candidate = init_ones(x, **kwargs)
    loss_ones = mse(x, _round_trip(x, qtz_func, deqtz_func, **fixed, **{param: ones_candidate}))

    # Baseline 2: a fresh random draw.
    rand_candidate = init_rand(x, **kwargs)
    loss_rand = mse(x, _round_trip(x, qtz_func, deqtz_func, **fixed, **{param: rand_candidate}))

    best_candidate, best_loss = ranked[0]
    if loss_rand < best_loss and loss_rand < loss_ones:
        return rand_candidate
    if loss_ones < best_loss and loss_ones < loss_rand:
        return ones_candidate
    return best_candidate
196
+
197
+
198
def init_linear_scale(  # Symmetric scale. From the study folder
    x: torch.Tensor,
    **kwargs: Dict[str, Any],
) -> torch.Tensor:
    """Compute a symmetric per-row scale for the quantized representation.

    Required ``kwargs``: ``bits``, ``params`` (already-initialised parameters,
    without ``_s``) and ``qtz_func``. ``x`` is pushed through ``qtz_func``
    with ``_s`` set to ones. The largest absolute value per row is then
    divided by half the width of the signed integer range for ``bits``. The
    result is clamped to at least float32 epsilon.
    """
    for key in ("bits", "params", "qtz_func"):
        assert key in kwargs, f"{key} must be provided."

    bits = kwargs.get('bits')
    params = kwargs.get('params')
    qtz_func = kwargs.get('qtz_func')

    # Unscaled forward transform (scale fixed at 1).
    unit_scale = init_ones(x, **kwargs)
    transformed = qtz_func(x=x.transpose(0, 1), **params, _s=unit_scale).transpose(0, 1)

    quant_min, quant_max = get_min_max_from_bits_signed(bits)

    # The range always includes zero.
    lows, highs = torch.aminmax(transformed, dim=1)
    lows = torch.min(lows, torch.zeros_like(lows))
    highs = torch.max(highs, torch.zeros_like(highs))

    peak = torch.max(-lows, highs)
    half_levels = float(quant_max - quant_min) / 2
    scale = peak / half_levels

    # Adding random noise to the scale was tried and found not to help
    # learning, so none is applied.
    eps = torch.finfo(torch.float32).eps
    return torch.clamp(scale, min=eps).to(dtype=torch.float32, device=lows.device)
234
+
235
+
236
def init_non_linear_regression_fit(
    x: torch.Tensor,
    **kwargs: Dict[str, Any],
) -> Dict[str, torch.Tensor]:
    """Refine initial parameters with a per-row Levenberg-Marquardt fit.

    Required ``kwargs``:
        params_list: parameter names. They are sorted before use, so the
            positional parameters of ``np_fit_func`` must be in sorted order.
        np_fit_func: NumPy model ``f(x, *params)``; fitted so that
            ``f(row) ~= row`` for every row of ``x``.
        p0: dict mapping each name to a tensor of initial values, one per row.

    Returns:
        Dict mapping each name to a float32 tensor (one value per row) on
        ``x.device``. If the fit fails for any row, the initial values ``p0``
        are returned for all rows.
    """

    assert "params_list" in kwargs, "params list must be provided."
    assert "np_fit_func" in kwargs, "np_fit_func must be provided."
    assert "p0" in kwargs, "p0 must be provided."
    np_fit_func = kwargs.get('np_fit_func')
    params_list = kwargs.get('params_list')
    p0 = kwargs.get('p0')

    def _fit(xdata: np.ndarray, ydata: np.ndarray, func: Callable, p0: List[float]):
        popt, _ = curve_fit(
            func,
            xdata,
            ydata,
            maxfev=1000,
            p0=p0,
            method='lm'
        )
        return popt

    # 1. Convert the torch tensor to a NumPy array
    xdata = x.cpu().numpy()

    # 2. Sort each row so the data is easier to fit
    sorted_xdata = np.sort(xdata, axis=-1)

    p0 = {k: v.cpu().numpy() for k, v in p0.items()}
    params_list = sorted(params_list)  # Must match the positional order of np_fit_func

    # 3. Fit every row independently
    try:
        per_row = [
            _fit(row, row, np_fit_func, [p0[name][i] for name in params_list])
            for i, row in enumerate(sorted_xdata)
        ]

        # 4. Regroup the fitted values by parameter name
        return {
            name: torch.tensor([fit[i] for fit in per_row], dtype=torch.float32).to(x.device)
            for i, name in enumerate(params_list)
        }

    # curve_fit raises ValueError for NaN/invalid input and RuntimeError when
    # the least-squares minimisation fails (e.g. maxfev exhausted). Both must
    # fall back to p0, otherwise a non-converging row aborts initialisation.
    except (ValueError, RuntimeError) as e:
        print(f"Could not fit the function with error: {e}")
        print("Using fallback result...")
        return {
            k: torch.tensor(v, dtype=torch.float32).to(x.device) for k, v in p0.items()
        }
290
+
291
+
292
def init_zeros(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor:
    """Return a float32 tensor of zeros shaped like ``x`` reduced over dim 1.

    Extra keyword arguments are accepted and ignored.
    """
    per_row = torch.amin(x, dim=1)  # only used as a shape template
    return torch.zeros_like(per_row, dtype=torch.float32, device=x.device)
295
+
296
+
297
def init_inner_scale(tensor: torch.Tensor, _min: float = torch.inf, _max: float = torch.inf) -> torch.Tensor:
    """Return the factor that maps each row's value range onto ``[_min, _max]``.

    The row range is taken over the last dimension and always includes zero.
    When ``_max`` is infinite no scaling is needed and ones are returned.

    NOTE(review): a row whose range is exactly zero yields a division by
    zero. Confirm callers never pass such data.
    """
    # Original minimum and maximum per row, widened to include zero
    min_vals, max_vals = torch.aminmax(tensor, dim=-1)
    x_min = torch.min(min_vals, torch.zeros_like(min_vals))
    x_max = torch.max(max_vals, torch.zeros_like(max_vals))

    # Compare by value, not identity: float('inf') is a different object from
    # torch.inf but must be treated the same. Tensors skip this shortcut, as
    # they always did.
    if not isinstance(_max, torch.Tensor) and _max == torch.inf:
        return torch.ones_like(x_min)

    # Scale factor: target width divided by source width
    return (_max - _min) / (x_max - x_min)
309
+
310
+
311
+
312
+ ############## Quant ###############
313
+
314
@torch.enable_grad()
def learn_parameters(
    x: torch.Tensor,
    params: Dict[str, nn.Parameter],
    qtz_func: nn.Module,
    deqtz_func: nn.Module,
    bits: int,
    target_dtype: torch.dtype,
    epochs: int = 1000,
    early_stop: bool = True,
    do_report: bool = False
) -> Tuple[Dict[str, nn.Parameter], torch.Tensor]:
    """Train ``params`` with Adam to minimise the quantize/dequantize MSE.

    Args:
        x: Data to reconstruct.
        params: Parameters to train. They are modified in place and left with
            ``requires_grad=True``.
        qtz_func / deqtz_func: Forward and inverse transforms.
        bits: Bit width of the signed integer range.
        target_dtype: Integer dtype of the quantized tensor.
        epochs: Maximum number of optimisation steps.
        early_stop: Stop once the loss changed by less than 1e-7 over the
            last 10% of ``epochs``.
        do_report: Print progress and also return the loss history.

    Returns:
        A deep copy of the parameters that produced the lowest loss (gradients
        disabled). With ``do_report=True`` a ``(params, loss_history)`` tuple
        is returned instead.

    Raises:
        Exception: If the loss becomes NaN or infinite during training.
    """
    loss_fn = nn.MSELoss()

    # The initial learning rate is 0.1 scaled by half the order of magnitude
    # of the initial loss.
    quant = quantize(x, params, qtz_func, bits, target_dtype)
    dequant = dequantize(quant, params, deqtz_func, bits, x.dtype)
    initial_loss = loss_fn(x, dequant).item()

    base_lr = 0.1
    if np.isfinite(initial_loss) and initial_loss > 0:
        exponent = int(np.floor(np.log10(initial_loss)))
    else:
        # log10 is -inf / NaN here and int() would raise. Use the base rate;
        # a non-finite loss is still reported by the check inside the loop.
        exponent = 0
    lr = base_lr * (10 ** (exponent // 2))

    # Enable gradients on the parameters
    for p in params.values():
        p.requires_grad = True
        p.grad = None

    optimizer = torch.optim.Adam(list(params.values()), lr=lr)

    # Best loss seen so far and the parameters that produced it. Starting from
    # a copy of the inputs keeps the result defined when epochs == 0.
    best_loss = float("inf")
    best_params = copy.deepcopy(dict(params))

    # Early-stopping configuration
    min_delta = 1e-7
    acc_loss = []
    percent_epochs_before_stop = 0.1
    epochs_before_stop = int(epochs * percent_epochs_before_stop)

    for i in range(epochs):
        optimizer.zero_grad()

        quant = quantize(x, params, qtz_func, bits, target_dtype)
        dequant = dequantize(quant, params, deqtz_func, bits, x.dtype)
        loss = loss_fn(x, dequant)

        if loss.isnan() or loss.isinf():
            raise Exception("Loss is NaN or Inf. Stopping the search.")

        current_loss = loss.item()

        # Snapshot BEFORE the optimiser step: `loss` was computed with the
        # current parameter values, so these are the values that achieve it.
        if current_loss < best_loss:
            best_loss = current_loss
            best_params = copy.deepcopy(dict(params))

        loss.backward()
        optimizer.step()

        acc_loss.append(current_loss)

        # Reports loss every 10 steps
        if i % 10 == 0 and do_report:
            print(f"Epoch {i}: Loss {current_loss}")

        # Stop if the loss has not changed considerably during the last 10% of epochs
        if early_stop:
            if i > epochs_before_stop and abs(acc_loss[i - epochs_before_stop] - acc_loss[i]) < min_delta:
                break

    # The returned copies no longer require gradients
    for p in best_params.values():
        p.requires_grad = False
        p.grad = None

    if do_report:
        print(f"Best loss: {best_loss}")
        return best_params, acc_loss
    else:
        return best_params
400
+
401
+
402
def quantize(
    x: torch.Tensor,
    params: Dict[str, nn.Parameter],
    func: nn.Module,
    bits: int,
    target_dtype: torch.dtype = torch.int8
) -> torch.Tensor:
    """Transform ``x`` with ``func``, round it and cast it to ``target_dtype``.

    ``x`` is transposed (dims 0 and 1) around the call to ``func`` so shapes
    align with the parameters. The rounded values are clamped to the signed
    integer range for ``bits`` before the cast.
    """
    quant_min, quant_max = get_min_max_from_bits_signed(bits)
    transformed = func(x=x.transpose(0, 1), **params).transpose(0, 1)
    rounded = round_func_BPDA(transformed)
    return torch.clamp(rounded, quant_min, quant_max).to(target_dtype)
415
+
416
+
417
def dequantize(
    x: torch.Tensor,
    params: Dict[str, nn.Parameter],
    func: nn.Module,
    bits: int,
    out_dtype: torch.dtype
) -> torch.Tensor:
    """Cast ``x`` to ``out_dtype`` and map it back with ``func``.

    ``x`` is transposed (dims 0 and 1) around the call to ``func`` so shapes
    align with the parameters. ``bits`` is accepted but not used.
    """
    as_float = x.to(dtype=out_dtype)
    restored = func(x=as_float.transpose(0, 1), **params)
    return restored.transpose(0, 1)
429
+
430
+
431
def round_func_BPDA(input):
    """Round ``input`` in the forward pass while keeping an identity gradient.

    The returned tensor is a clone of ``input`` (differentiable) whose
    underlying data is overwritten with the rounded values, so the backward
    pass sees an identity function instead of the non-differentiable round.
    """
    rounded = torch.round(input)
    passthrough = input.clone()
    passthrough.data = rounded.data
    return passthrough
438
+
439
+
440
def get_min_max_from_bits_signed(bit_width: int) -> Tuple[int, int]:
    """Return ``(min, max)`` of a signed two's-complement integer of ``bit_width`` bits."""
    half_range = 2 ** (bit_width - 1)
    return -half_range, half_range - 1
442
+
443
+
444
+
445
+ ############## Numpy ###############
446
+
447
def np_domain_guard(
    x: np.ndarray,
    min: float = None,
    max: float = None,
    posinf: float = None,
    neginf: float = None,
    nan: float = None
) -> np.ndarray:
    """Sanitise an array so it lies inside a function's valid domain.

    NaN and infinities are always replaced through ``np.nan_to_num``. The
    result is clipped only when at least one of ``min`` / ``max`` is given.

    Args:
        x: Input array.
        min, max: Optional clip bounds.
        posinf, neginf: Replacements for +/-inf (``None`` keeps NumPy's
            default of the largest / smallest finite value).
        nan: Replacement for NaN. ``None`` means 0.0, matching the torch
            version of this helper.
    """
    # np.nan_to_num documents None only for posinf/neginf. Its `nan` argument
    # must be a number, so the "not specified" case is translated here.
    nan_fill = 0.0 if nan is None else nan
    x = np.nan_to_num(x, posinf=posinf, neginf=neginf, nan=nan_fill)
    if min is not None or max is not None:
        x = np.clip(x, min, max)
    return x
460
+
461
+
462
def np_replace_num(x: np.ndarray, num: float, to: float) -> np.ndarray:
    """Return ``x`` with every element equal to ``num`` swapped for ``to``.

    Args:
        x (np.ndarray): Array to scan.
        num (float): Value to look for (exact equality).
        to (float): Value written wherever ``num`` is found.

    Returns:
        np.ndarray: Result of the element-wise substitution.
    """
    is_target = x == num
    return np.where(is_target, to, x)
474
+
475
+
476
def np_guarded_power(x: np.ndarray, exp: float) -> np.ndarray:
    """Raise ``x`` to ``exp``, zeroing negative bases when ``exp`` is below 1.

    For ``exp >= 1`` the power is taken directly; otherwise negative entries
    of ``x`` are first replaced by 0.
    """
    if exp >= 1:
        return np.power(x, exp)
    non_negative = np.maximum(x, 0)
    return np.power(non_negative, exp)
479
+
fn_gen/nlr_t_no_sched/0/loss.png ADDED
fn_gen/nlr_t_no_sched/0/quantization.png ADDED
fn_gen/nlr_t_no_sched/1/distortion.png ADDED
fn_gen/nlr_t_no_sched/1/expressions.txt ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ sqrt(_0*x)/_s
2
+ _s**2*x**2/_0
fn_gen/nlr_t_no_sched/1/fn.py ADDED
@@ -0,0 +1,479 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import torch
4
+ from torch import amin # Necessary for arcsin
5
+ import copy
6
+ import torch.nn as nn
7
+ import numpy as np
8
+
9
+ from scipy.optimize import curve_fit
10
+ from typing import Dict, Any, Tuple, List, Callable
11
+
12
+
13
def quantization(x, **params):
    """Apply the learned forward transform ``sqrt(_0 * x) / _s``.

    ``params`` must contain ``'_0'`` and ``'_s'``. A ``_s`` equal to 0 is
    swapped for 10000 before the reciprocal is taken, and the argument of the
    square root has NaN replaced by 0.1 and is clamped to at least 0.1.
    """
    safe_scale = replace_num(params['_s'], num=0, to=10000)
    sqrt_input = domain_guard(params['_0'] * x, min=0.1, nan=0.1)
    return torch.div(1, safe_scale) * torch.sqrt(sqrt_input)
15
+
16
+
17
def dequantization(x, **params):
    """Apply the learned inverse transform ``_s**2 * x**2 / _0``.

    ``params`` must contain ``'_0'`` and ``'_s'``. A ``_0`` equal to 0 is
    swapped for 10000 before the reciprocal is taken.
    """
    inverse_0 = torch.div(1, replace_num(params['_0'], num=0, to=10000))
    scale_squared = guarded_torch_power(params['_s'], torch.tensor(2))
    x_squared = guarded_torch_power(x, torch.tensor(2))
    return inverse_0 * scale_squared * x_squared
19
+
20
+
21
def init_params(x: torch.Tensor, **kwargs: Dict[str, Any]) -> Dict[str, nn.Parameter]:
    """Initialise, fit and then train the ``_0`` / ``_s`` parameters for ``x``.

    Pipeline: random space search for ``_0``, linear scale for ``_s``,
    per-channel non-linear regression, then 500 epochs of gradient training.

    ``kwargs`` must contain ``'bits'``. The optional callables
    ``'post_init_hook'``, ``'post_method_hook'`` and ``'post_train_hook'`` are
    invoked with ``parameters=<dict>`` after the matching stage.
    """

    def _fire(hook_name, parameters):
        # Hooks are optional; only call the ones the caller supplied.
        if hook_name in kwargs:
            kwargs[hook_name](parameters=parameters)

    # Stage 1: starting point for every parameter.
    initial = {}
    initial['_0'] = init_space_search(
        x,
        qtz_func=quantization,
        deqtz_func=dequantization,
        params_list=['_0', '_s'],
        param='_0',
        **kwargs,
    )
    initial['_s'] = init_linear_scale(x, qtz_func=quantization, params=initial, **kwargs)
    _fire('post_init_hook', initial)

    # Stage 2: refine the starting point with a least-squares fit.
    fitted = init_non_linear_regression_fit(
        x,
        p0=initial,
        np_fit_func=fit_func,
        qtz_func=quantization,
        deqtz_func=dequantization,
        params_list=['_0', '_s'],
        **kwargs,
    )
    frozen = {name: nn.Parameter(value, requires_grad=False) for name, value in fitted.items()}
    _fire('post_method_hook', frozen)

    # Stage 3: gradient-based training on the quantize/dequantize round trip.
    learned = learn_parameters(
        x,
        frozen,
        qtz_func=quantization,
        deqtz_func=dequantization,
        bits=kwargs['bits'],
        target_dtype=torch.int8,
        epochs=500,
        early_stop=False,
    )
    _fire('post_train_hook', learned)

    return learned
47
+
48
+
49
+ ############### Numpy Qtz ###############
50
+
51
+
52
def np_quantization(x, _0, _s):
    """NumPy twin of ``quantization``: ``sqrt(_0 * x) / _s`` with the same guards."""
    safe_scale = np_replace_num(_s, num=0, to=10000)
    sqrt_input = np_domain_guard(_0 * x, min=0.1, nan=0.1)
    return np.divide(1, safe_scale) * np.sqrt(sqrt_input)
54
+
55
+
56
def np_dequantization(x, _0, _s):
    """NumPy twin of ``dequantization``: ``_s**2 * x**2 / _0`` with the same guard."""
    inverse_0 = np.divide(1, np_replace_num(_0, num=0, to=10000))
    scale_squared = np_guarded_power(_s, np.array(2))
    x_squared = np_guarded_power(x, np.array(2))
    return inverse_0 * scale_squared * x_squared
58
+
59
+
60
def fit_func(x, _0, _s):
    """Round-trip ``x`` through the NumPy quantization and dequantization.

    No rounding is applied in between; this is the model function handed to
    the curve fit.
    """
    return np_dequantization(np_quantization(x, _0, _s), _0, _s)
64
+
65
+
66
+
67
+ ############### HELPERS ###############
68
+
69
def domain_guard(
    x: torch.Tensor,
    min: float = None,
    max: float = None,
    posinf: float = None,
    neginf: float = None,
    nan: float = None
) -> torch.Tensor:
    """Sanitise a tensor so it lies inside a function's valid domain.

    ``torch.nan_to_num`` is always applied (arguments left as ``None`` fall
    back to that function's own defaults). The result is clamped only when
    at least one of ``min`` / ``max`` is given.
    """
    cleaned = torch.nan_to_num(x, posinf=posinf, neginf=neginf, nan=nan)
    if min is None and max is None:
        return cleaned
    return torch.clamp(cleaned, min=min, max=max)
82
+
83
+
84
def replace_num(x: torch.Tensor, num: float, to: float) -> torch.Tensor:
    """Return ``x`` with every element equal to ``num`` swapped for ``to``.

    Args:
        x (torch.Tensor): Tensor to scan.
        num (float): Value to look for (exact equality).
        to (float): Value written wherever ``num`` is found.

    Returns:
        torch.Tensor: Result of the element-wise substitution.
    """
    is_target = x == num
    return torch.where(is_target, to, x)
96
+
97
+
98
def guarded_torch_power(x: torch.Tensor, exp: float) -> torch.Tensor:
    """Raise ``x`` to ``exp``, zeroing negative bases when ``exp`` is below 1.

    For ``exp >= 1`` the power is taken directly; otherwise ``x`` is first
    passed through ``relu`` so negative values become 0.
    """
    if exp >= 1:
        return torch.pow(x, exp)
    non_negative = torch.relu(x)
    return torch.pow(non_negative, exp)
101
+
102
+
103
def init_ones(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor:
    """Return a float32 tensor of ones shaped like ``x`` reduced over dim 1.

    Extra keyword arguments are accepted and ignored.
    """
    per_row = torch.amin(x, dim=1)  # only used as a shape template
    return torch.ones_like(per_row, dtype=torch.float32, device=x.device)
106
+
107
+
108
def init_rand(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor:
    """Return float32 standard-normal samples shaped like ``x`` reduced over dim 1.

    Extra keyword arguments are accepted and ignored.
    """
    per_row = torch.amin(x, dim=1)  # only used as a shape template
    return torch.randn_like(per_row, dtype=torch.float32, device=x.device)
111
+
112
+
113
def init_space_search(
    x: torch.Tensor,
    **kwargs: Dict[str, Any],
) -> torch.Tensor:
    """Randomly search for a starting value of one parameter.

    Required ``kwargs``: ``qtz_func``, ``deqtz_func``, ``params_list`` (all
    parameter names) and ``param`` (the name being searched). Every other
    parameter is held at ones. Candidates are scored by the MSE between ``x``
    and its quantize -> dequantize round trip (without rounding).

    The search runs 50 generations. The first draws 500 normal samples scaled
    by 10000. Each later one draws 50 samples around the mean of the previous
    generation's five best candidates. The winner of the last generation is
    finally compared against an all-ones and a fresh random candidate.

    NOTE(review): the ranking is rebuilt from scratch every generation, so a
    better candidate from an earlier generation is not retained. Confirm this
    is intended.
    NOTE(review): a NaN loss can enter the ranking while it holds fewer than
    five entries. Sorting with NaN keys is not well defined.
    """

    def _initial_candidates(tensor: torch.Tensor, spread: int, count: int):
        """Yield the first generation, which is ten times larger than later ones."""
        for _ in range(count * 10):
            yield init_rand(tensor) * spread

    def _resample(survivors: List[torch.tensor], count):
        """Yield ``count`` samples centred on the mean of ``survivors``."""
        stacked = torch.stack(survivors)
        lows, highs = torch.aminmax(stacked, dim=0)
        radius = torch.max(-lows, highs)
        centre = torch.mean(stacked, dim=0)
        for _ in range(count):
            yield torch.randn_like(lows) * radius + centre

    def _round_trip(data, forward, backward, **fn_params):
        """Quantize then dequantize ``data`` (transposed so shapes align)."""
        out = forward(x=data.transpose(0, 1), **fn_params)
        out = backward(x=out, **fn_params)
        return out.transpose(0, 1)

    required = (
        ("qtz_func", "qtz_func must be provided."),
        ("deqtz_func", "deqtz_func must be provided."),
        ("params_list", "params list must be provided."),
        ("param", "param must be provided."),
    )
    for key, message in required:
        assert key in kwargs, message

    qtz_func = kwargs.get('qtz_func')
    deqtz_func = kwargs.get('deqtz_func')
    params_list = kwargs.get('params_list')
    param = kwargs.get('param')

    n_runs = 50           # generations
    n_random_params = 50  # candidates per generation (x10 for the first)
    n_best_to_pick = 5    # survivors kept per generation
    max_initial = 10000   # spread of the first generation

    mse = nn.MSELoss()

    def _by_loss(entry):
        return entry[1]

    # Every parameter other than the searched one stays fixed at ones.
    fixed = {name: init_ones(x, **kwargs) for name in params_list if name != param}
    candidates = _initial_candidates(x, max_initial, n_random_params)

    for _ in range(n_runs):
        ranked = []
        for candidate in candidates:
            try:
                restored = _round_trip(x, qtz_func, deqtz_func, **fixed, **{param: candidate})
                loss = mse(x, restored)

                if len(ranked) < n_best_to_pick:
                    ranked.append((candidate, loss.item()))
                    ranked.sort(key=_by_loss)
                elif loss < ranked[-1][1]:
                    # Evict the current worst survivor.
                    ranked[-1] = (candidate, loss.item())
                    ranked.sort(key=_by_loss)

            except Exception:  # The candidate might be outside the function's domain
                continue

        # Next generation is drawn around this generation's survivors.
        candidates = _resample([tensor for tensor, _ in ranked], n_random_params)

    # Baseline 1: all ones.
    ones_candidate = init_ones(x, **kwargs)
    loss_ones = mse(x, _round_trip(x, qtz_func, deqtz_func, **fixed, **{param: ones_candidate}))

    # Baseline 2: a fresh random draw.
    rand_candidate = init_rand(x, **kwargs)
    loss_rand = mse(x, _round_trip(x, qtz_func, deqtz_func, **fixed, **{param: rand_candidate}))

    best_candidate, best_loss = ranked[0]
    if loss_rand < best_loss and loss_rand < loss_ones:
        return rand_candidate
    if loss_ones < best_loss and loss_ones < loss_rand:
        return ones_candidate
    return best_candidate
196
+
197
+
198
def init_linear_scale(  # Symmetric scale. From the study folder
    x: torch.Tensor,
    **kwargs: Dict[str, Any],
) -> torch.Tensor:
    """Compute a symmetric per-row scale for the quantized representation.

    Required ``kwargs``: ``bits``, ``params`` (already-initialised parameters,
    without ``_s``) and ``qtz_func``. ``x`` is pushed through ``qtz_func``
    with ``_s`` set to ones. The largest absolute value per row is then
    divided by half the width of the signed integer range for ``bits``. The
    result is clamped to at least float32 epsilon.
    """
    for key in ("bits", "params", "qtz_func"):
        assert key in kwargs, f"{key} must be provided."

    bits = kwargs.get('bits')
    params = kwargs.get('params')
    qtz_func = kwargs.get('qtz_func')

    # Unscaled forward transform (scale fixed at 1).
    unit_scale = init_ones(x, **kwargs)
    transformed = qtz_func(x=x.transpose(0, 1), **params, _s=unit_scale).transpose(0, 1)

    quant_min, quant_max = get_min_max_from_bits_signed(bits)

    # The range always includes zero.
    lows, highs = torch.aminmax(transformed, dim=1)
    lows = torch.min(lows, torch.zeros_like(lows))
    highs = torch.max(highs, torch.zeros_like(highs))

    peak = torch.max(-lows, highs)
    half_levels = float(quant_max - quant_min) / 2
    scale = peak / half_levels

    # Adding random noise to the scale was tried and found not to help
    # learning, so none is applied.
    eps = torch.finfo(torch.float32).eps
    return torch.clamp(scale, min=eps).to(dtype=torch.float32, device=lows.device)
234
+
235
+
236
def init_non_linear_regression_fit(
    x: torch.Tensor,
    **kwargs: Dict[str, Any],
) -> Dict[str, torch.Tensor]:
    """Refine initial parameters with a per-row Levenberg-Marquardt fit.

    Required ``kwargs``:
        params_list: parameter names. They are sorted before use, so the
            positional parameters of ``np_fit_func`` must be in sorted order.
        np_fit_func: NumPy model ``f(x, *params)``; fitted so that
            ``f(row) ~= row`` for every row of ``x``.
        p0: dict mapping each name to a tensor of initial values, one per row.

    Returns:
        Dict mapping each name to a float32 tensor (one value per row) on
        ``x.device``. If the fit fails for any row, the initial values ``p0``
        are returned for all rows.
    """

    assert "params_list" in kwargs, "params list must be provided."
    assert "np_fit_func" in kwargs, "np_fit_func must be provided."
    assert "p0" in kwargs, "p0 must be provided."
    np_fit_func = kwargs.get('np_fit_func')
    params_list = kwargs.get('params_list')
    p0 = kwargs.get('p0')

    def _fit(xdata: np.ndarray, ydata: np.ndarray, func: Callable, p0: List[float]):
        popt, _ = curve_fit(
            func,
            xdata,
            ydata,
            maxfev=1000,
            p0=p0,
            method='lm'
        )
        return popt

    # 1. Convert the torch tensor to a NumPy array
    xdata = x.cpu().numpy()

    # 2. Sort each row so the data is easier to fit
    sorted_xdata = np.sort(xdata, axis=-1)

    p0 = {k: v.cpu().numpy() for k, v in p0.items()}
    params_list = sorted(params_list)  # Must match the positional order of np_fit_func

    # 3. Fit every row independently
    try:
        per_row = [
            _fit(row, row, np_fit_func, [p0[name][i] for name in params_list])
            for i, row in enumerate(sorted_xdata)
        ]

        # 4. Regroup the fitted values by parameter name
        return {
            name: torch.tensor([fit[i] for fit in per_row], dtype=torch.float32).to(x.device)
            for i, name in enumerate(params_list)
        }

    # curve_fit raises ValueError for NaN/invalid input and RuntimeError when
    # the least-squares minimisation fails (e.g. maxfev exhausted). Both must
    # fall back to p0, otherwise a non-converging row aborts initialisation.
    except (ValueError, RuntimeError) as e:
        print(f"Could not fit the function with error: {e}")
        print("Using fallback result...")
        return {
            k: torch.tensor(v, dtype=torch.float32).to(x.device) for k, v in p0.items()
        }
290
+
291
+
292
def init_zeros(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor:
    """Return float32 zeros shaped like ``x`` reduced over dimension 1."""
    # The reduction is only used to obtain the output shape.
    reduced = x.amin(dim=1)
    return torch.zeros_like(reduced, dtype=torch.float32, device=x.device)
295
+
296
+
297
def init_inner_scale(tensor: torch.Tensor, _min: float = torch.inf, _max: float = torch.inf) -> torch.Tensor:
    """Compute the factor that rescales each row's range to ``[_min, _max]``.

    The range of every row is taken along the last dimension and widened so
    that it always contains zero.

    Args:
        tensor: Input values.
        _min: Lower bound of the target range.
        _max: Upper bound of the target range. An infinite value means no
            rescaling is needed.

    Returns:
        A tensor of ones when ``_max`` is infinite, otherwise
        ``(_max - _min) / (x_max - x_min)`` per row.
    """
    min_vals, max_vals = torch.aminmax(tensor, dim=-1)
    x_min = torch.min(min_vals, torch.zeros_like(min_vals))
    x_max = torch.max(max_vals, torch.zeros_like(max_vals))

    # Compare by value, not identity: an infinity built elsewhere
    # (e.g. float("inf")) is a different object than torch.inf.
    if isinstance(_max, (int, float)) and _max == torch.inf:
        return torch.ones_like(x_min)

    return (_max - _min) / (x_max - x_min)
309
+
310
+
311
+
312
+ ############## Quant ###############
313
+
314
@torch.enable_grad()
def learn_parameters(
    x: torch.Tensor,
    params: Dict[str, nn.Parameter],
    qtz_func: nn.Module,
    deqtz_func: nn.Module,
    bits: int,
    target_dtype: torch.dtype,
    epochs: int = 1000,
    early_stop: bool = True,
    do_report: bool = False
) -> Tuple[Dict[str, nn.Parameter], torch.Tensor]:
    """Refine ``params`` with Adam by minimising the quantize/dequantize MSE.

    Args:
        x: Tensor to quantize and reconstruct.
        params: Parameters to optimise; they are updated in place.
        qtz_func: Quantization function passed to ``quantize``.
        deqtz_func: Dequantization function passed to ``dequantize``.
        bits: Bit width passed to ``quantize`` and ``dequantize``.
        target_dtype: Dtype of the quantized tensor.
        epochs: Maximum number of optimisation steps.
        early_stop: Stop when the loss changed by less than ``1e-7`` over the
            last 10% of ``epochs``.
        do_report: Print progress and also return the loss history.

    Returns:
        A copy of the parameters that produced the lowest loss, with gradients
        disabled. When ``do_report`` is true, the tuple
        ``(best_params, losses)`` where ``losses`` is a list of floats.

    Raises:
        Exception: If the loss becomes NaN or infinite during training.
    """
    loss_fn = nn.MSELoss()

    def _round_trip_loss() -> torch.Tensor:
        quant = quantize(x, params, qtz_func, bits, target_dtype)
        dequant = dequantize(quant, params, deqtz_func, bits, x.dtype)
        return loss_fn(x, dequant)

    # The learning rate is scaled by half the order of magnitude of the
    # initial loss. A zero or non-finite loss has no usable order of
    # magnitude, so the base rate is used instead.
    base_lr = 0.1
    initial_loss = _round_trip_loss().item()
    if np.isfinite(initial_loss) and initial_loss > 0:
        exponent = int(np.floor(np.log10(initial_loss)))
    else:
        exponent = 0
    lr = base_lr * (10 ** (exponent // 2))

    for p in params.values():
        p.requires_grad = True
        p.grad = None

    optimizer = torch.optim.Adam(list(params.values()), lr=lr)

    # Start from a copy of the incoming parameters so there is always
    # something to return, even when no epoch runs.
    best_loss = float("inf")
    best_params = copy.deepcopy(dict(params))

    # Early-stopping bookkeeping.
    min_delta = 1e-7
    acc_loss = []
    percent_epochs_before_stop = 0.1
    epochs_before_stop = int(epochs * percent_epochs_before_stop)

    for i in range(epochs):
        optimizer.zero_grad()

        loss = _round_trip_loss()
        if loss.isnan() or loss.isinf():
            raise Exception("Loss is NaN or Inf. Stopping the search.")
        current_loss = loss.item()

        # Snapshot before the optimiser step: the loss was measured with the
        # current values, not with the values after the update.
        if current_loss < best_loss:
            best_loss = current_loss
            best_params = copy.deepcopy(dict(params))

        loss.backward()
        optimizer.step()

        acc_loss.append(current_loss)

        if i % 10 == 0 and do_report:
            print(f"Epoch {i}: Loss {current_loss}")

        if (
            early_stop
            and i > epochs_before_stop
            and abs(acc_loss[i - epochs_before_stop] - acc_loss[i]) < min_delta
        ):
            break

    for p in best_params.values():
        p.requires_grad = False
        p.grad = None

    if do_report:
        print(f"Best loss: {best_loss}")
        return best_params, acc_loss
    return best_params
400
+
401
+
402
def quantize(
    x: torch.Tensor,
    params: Dict[str, nn.Parameter],
    func: nn.Module,
    bits: int,
    target_dtype: torch.dtype = torch.int8
) -> torch.Tensor:
    """Apply ``func`` to ``x``, round, clamp to the signed ``bits`` range and cast.

    ``func`` is called on ``x`` with its first two dimensions swapped, and the
    result is swapped back before rounding.
    """
    lower, upper = get_min_max_from_bits_signed(bits)
    mapped = func(x=x.transpose(0, 1), **params).transpose(0, 1)
    rounded = round_func_BPDA(mapped)
    return torch.clamp(rounded, lower, upper).to(target_dtype)
415
+
416
+
417
def dequantize(
    x: torch.Tensor,
    params: Dict[str, nn.Parameter],
    func: nn.Module,
    bits: int,
    out_dtype: torch.dtype
) -> torch.Tensor:
    """Cast ``x`` to ``out_dtype`` and apply ``func`` to reconstruct the values.

    ``func`` is called on the cast tensor with its first two dimensions
    swapped; the result is swapped back. ``bits`` is accepted but not used.
    """
    as_float = x.to(dtype=out_dtype)
    restored = func(x=as_float.transpose(0, 1), **params)
    return restored.transpose(0, 1)
429
+
430
+
431
def round_func_BPDA(input):
    """Round ``input`` in the forward pass while keeping an identity gradient.

    The result is a clone of ``input`` whose data is overwritten with the
    rounded values, so autograd still sees the (differentiable) clone.
    """
    rounded = torch.round(input)
    passthrough = input.clone()
    passthrough.data = rounded.data
    return passthrough
438
+
439
+
440
def get_min_max_from_bits_signed(bit_width: int) -> Tuple[int, int]:
    """Return the ``(min, max)`` values of a signed integer with ``bit_width`` bits."""
    half_range = 2 ** (bit_width - 1)
    return -half_range, half_range - 1
442
+
443
+
444
+
445
+ ############## Numpy ###############
446
+
447
def np_domain_guard(
    x: np.ndarray,
    min: float = None,
    max: float = None,
    posinf: float = None,
    neginf: float = None,
    nan: float = None
) -> np.ndarray:
    """Replace non-finite values in ``x`` and optionally clip it.

    Args:
        x: Input array.
        min: Lower clipping bound, or ``None`` for no lower bound.
        max: Upper clipping bound, or ``None`` for no upper bound.
        posinf: Replacement for positive infinity (NumPy default when ``None``).
        neginf: Replacement for negative infinity (NumPy default when ``None``).
        nan: Replacement for NaN; ``None`` means 0.0.

    Returns:
        The guarded array.
    """
    # np.nan_to_num expects a number for `nan`; map the None default to 0.0,
    # which is NumPy's own default for that argument.
    nan_value = 0.0 if nan is None else nan
    x = np.nan_to_num(x, posinf=posinf, neginf=neginf, nan=nan_value)
    if min is not None or max is not None:
        x = np.clip(x, min, max)
    return x
460
+
461
+
462
def np_replace_num(x: np.ndarray, num: float, to: float) -> np.ndarray:
    """Return ``x`` with every element equal to ``num`` replaced by ``to``.

    Args:
        x (np.ndarray): The input array.
        num (float): The value to look for.
        to (float): The value written in its place.

    Returns:
        np.ndarray: The array with the replacements applied.
    """
    matches = x == num
    return np.where(matches, to, x)
474
+
475
+
476
def np_guarded_power(x: np.ndarray, exp: float) -> np.ndarray:
    """Raise ``x`` to ``exp``; for exponents below 1, negative inputs are first set to 0."""
    if exp >= 1:
        return np.power(x, exp)
    non_negative = np.maximum(x, 0)
    return np.power(non_negative, exp)
479
+
fn_gen/nlr_t_no_sched/1/loss.png ADDED
fn_gen/nlr_t_no_sched/1/quantization.png ADDED
fn_gen/nlr_t_no_sched/10/distortion.png ADDED
fn_gen/nlr_t_no_sched/10/expressions.txt ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ exp(_0*x)/_s
2
+ log(_s*x)/_0
fn_gen/nlr_t_no_sched/10/fn.py ADDED
@@ -0,0 +1,479 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import torch
4
+ from torch import amin # Necessary for arcsin
5
+ import copy
6
+ import torch.nn as nn
7
+ import numpy as np
8
+
9
+ from scipy.optimize import curve_fit
10
+ from typing import Dict, Any, Tuple, List, Callable
11
+
12
+
13
def quantization(x, **params):
    """Compute ``exp(_0 * x) / _s``; zeros in ``_s`` are replaced by 10000 first."""
    safe_scale = replace_num(params['_s'], num=0, to=10000)
    inverse_scale = torch.div(1, safe_scale)
    return inverse_scale * torch.exp(params['_0'] * x)
15
+
16
+
17
def dequantization(x, **params):
    """Compute ``log(_s * x) / _0`` with guards on both operations.

    Zeros in ``_0`` are replaced by 10000, and the logarithm's argument is
    clamped to at least ``1e-5`` (NaNs also become ``1e-5``).
    """
    safe_rate = replace_num(params['_0'], num=0, to=10000)
    log_argument = domain_guard(params['_s'] * x, min=1e-5, nan=1e-5)
    return torch.div(1, safe_rate) * torch.log(log_argument)
19
+
20
+
21
def init_params(x: torch.Tensor, **kwargs: Dict[str, Any]) -> Dict[str, nn.Parameter]:
    """Initialise, fit and train the ``_0`` and ``_s`` parameters for ``x``.

    The stages run in order: space search for ``_0``, linear scale for ``_s``,
    non-linear regression fit, then gradient-based learning. After each stage
    the matching hook in ``kwargs`` (``post_init_hook``, ``post_method_hook``,
    ``post_train_hook``) is called when present. ``kwargs['bits']`` is required.
    """
    param_names = ['_0', '_s']

    # Stage 1: initial guesses.
    initial = {
        '_0': init_space_search(
            x,
            qtz_func=quantization,
            deqtz_func=dequantization,
            params_list=param_names,
            param='_0',
            **kwargs,
        ),
    }
    initial['_s'] = init_linear_scale(x, qtz_func=quantization, params=initial, **kwargs)
    if 'post_init_hook' in kwargs:
        kwargs['post_init_hook'](parameters=initial)

    # Stage 2: curve fit starting from the initial guesses.
    fitted = init_non_linear_regression_fit(
        x,
        p0=initial,
        np_fit_func=fit_func,
        qtz_func=quantization,
        deqtz_func=dequantization,
        params_list=param_names,
        **kwargs,
    )
    fitted = {name: nn.Parameter(value, requires_grad=False) for name, value in fitted.items()}
    if 'post_method_hook' in kwargs:
        kwargs['post_method_hook'](parameters=fitted)

    # Stage 3: gradient-based refinement.
    learned = learn_parameters(
        x,
        fitted,
        qtz_func=quantization,
        deqtz_func=dequantization,
        bits=kwargs['bits'],
        target_dtype=torch.int8,
        epochs=500,
        early_stop=False,
    )
    if 'post_train_hook' in kwargs:
        kwargs['post_train_hook'](parameters=learned)

    return learned
47
+
48
+
49
+ ############### Numpy Qtz ###############
50
+
51
+
52
def np_quantization(x, _0, _s):
    """NumPy version of ``exp(_0 * x) / _s``; a zero ``_s`` is replaced by 10000."""
    safe_scale = np_replace_num(_s, num=0, to=10000)
    inverse_scale = np.divide(1, safe_scale)
    return inverse_scale * np.exp(_0 * x)
54
+
55
+
56
def np_dequantization(x, _0, _s):
    """NumPy version of ``log(_s * x) / _0`` with guards on both operations.

    A zero ``_0`` is replaced by 10000, and the logarithm's argument is
    clipped to at least ``1e-5`` (NaNs also become ``1e-5``).
    """
    safe_rate = np_replace_num(_0, num=0, to=10000)
    log_argument = np_domain_guard(_s * x, min=1e-5, nan=1e-5)
    return np.divide(1, safe_rate) * np.log(log_argument)
58
+
59
+
60
def fit_func(x, _0, _s):
    """Pass ``x`` through the NumPy quantization and then the dequantization."""
    quantized = np_quantization(x, _0, _s)
    return np_dequantization(quantized, _0, _s)
64
+
65
+
66
+
67
+ ############### HELPERS ###############
68
+
69
def domain_guard(
    x: torch.Tensor,
    min: float = None,
    max: float = None,
    posinf: float = None,
    neginf: float = None,
    nan: float = None
) -> torch.Tensor:
    """Replace non-finite values in ``x`` and optionally clamp it to ``[min, max]``."""
    cleaned = torch.nan_to_num(x, nan=nan, posinf=posinf, neginf=neginf)
    if min is None and max is None:
        # No bounds requested: only the non-finite replacement applies.
        return cleaned
    return torch.clamp(cleaned, min=min, max=max)
82
+
83
+
84
def replace_num(x: torch.Tensor, num: float, to: float) -> torch.Tensor:
    """Return ``x`` with every element equal to ``num`` replaced by ``to``.

    Args:
        x (torch.Tensor): The input tensor.
        num (float): The value to look for.
        to (float): The value written in its place.

    Returns:
        torch.Tensor: The tensor with the replacements applied.
    """
    matches = x == num
    return torch.where(matches, to, x)
96
+
97
+
98
def guarded_torch_power(x: torch.Tensor, exp: float) -> torch.Tensor:
    """Raise ``x`` to ``exp``; for exponents below 1, negative inputs are first set to 0."""
    if exp >= 1:
        return torch.pow(x, exp)
    non_negative = torch.relu(x)
    return torch.pow(non_negative, exp)
101
+
102
+
103
def init_ones(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor:
    """Return float32 ones shaped like ``x`` reduced over dimension 1."""
    # The reduction is only used to obtain the output shape.
    reduced = x.amin(dim=1)
    return torch.ones_like(reduced, dtype=torch.float32, device=x.device)
106
+
107
+
108
def init_rand(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor:
    """Return float32 standard-normal samples shaped like ``x`` reduced over dimension 1."""
    # The reduction is only used to obtain the output shape.
    reduced = x.amin(dim=1)
    return torch.randn_like(reduced, dtype=torch.float32, device=x.device)
111
+
112
+
113
def init_space_search(
    x: torch.Tensor,
    **kwargs: Dict[str, Any],
) -> torch.Tensor:
    """Randomly search for a good initial value of one parameter.

    Candidates are scored by the MSE between ``x`` and
    ``deqtz_func(qtz_func(x))``, with every other parameter fixed to ones.
    Each run keeps the best few candidates and samples the next generation
    around them. Candidates that raise or produce a non-finite loss are
    discarded. The winner of the last successful run is finally compared
    against an all-ones and a random candidate.

    Required keyword arguments:
        qtz_func: Quantization function.
        deqtz_func: Dequantization function.
        params_list: Names of all parameters the two functions take.
        param: Name of the parameter being searched.

    Returns:
        The candidate with the lowest finite loss.
    """

    def _build_initial_param(tensor: torch.Tensor, max_initial: int, n_params: int):
        """Yield the first generation, which is ten times larger than later ones."""
        for _ in range(n_params * 10):
            # Normal samples scaled by max_initial; the values are not bounded.
            yield init_rand(tensor) * max_initial

    def _search_param(tensors: List[torch.Tensor], n_params):
        """Yield ``n_params`` samples around the mean of the given candidates."""
        stacked = torch.stack(tensors)
        min_vals, max_vals = torch.aminmax(stacked, dim=0)
        abs_max_val_per_ch = torch.max(-min_vals, max_vals)
        mean = torch.mean(stacked, dim=0)
        for _ in range(n_params):
            yield torch.randn_like(min_vals) * abs_max_val_per_ch + mean

    def _calc(x, qtz_func, deqtz_func, **params):
        x_ = x.transpose(0, 1)
        x_ = qtz_func(x=x_, **params)
        x_ = deqtz_func(x=x_, **params)
        x_ = x_.transpose(0, 1)
        return x_

    assert "qtz_func" in kwargs, "qtz_func must be provided."
    assert "deqtz_func" in kwargs, "deqtz_func must be provided."
    assert "params_list" in kwargs, "params list must be provided."
    assert "param" in kwargs, "param must be provided."

    qtz_func = kwargs.get('qtz_func')
    deqtz_func = kwargs.get('deqtz_func')
    params_list = kwargs.get('params_list')
    param = kwargs.get('param')

    n_runs = 50  # Number of search generations
    n_random_params = 50  # Candidates per generation (10x in the first one)
    n_best_to_pick = 5  # Candidates kept after each generation
    max_initial = 10000  # Scale of the first generation

    base_params = {p: init_ones(x, **kwargs) for p in params_list if p != param}
    loss_fn = nn.MSELoss()

    def _round_trip_loss(candidate: torch.Tensor) -> float:
        x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: candidate})
        return loss_fn(x, x_).item()

    candidates = _build_initial_param(x, max_initial, n_random_params)
    best_params = []

    for _ in range(n_runs):
        run_best = []
        for candidate in candidates:
            try:
                loss = _round_trip_loss(candidate)
            except Exception:  # The candidate may be outside the function's domain
                continue

            # NaN compares false against everything, so a non-finite loss
            # would get stuck in the ranking; drop it instead.
            if not np.isfinite(loss):
                continue

            if len(run_best) < n_best_to_pick:
                run_best.append((candidate, loss))
            elif loss < run_best[-1][1]:
                run_best[-1] = (candidate, loss)
            else:
                continue
            run_best.sort(key=lambda entry: entry[1])

        if not run_best:
            # Nothing usable in this generation; keep the previous winners.
            break

        best_params = run_best
        candidates = _search_param([p for p, _ in best_params], n_random_params)

    # Compare the search result against the two simple baselines.
    p_ones = init_ones(x, **kwargs)
    loss_ones = _round_trip_loss(p_ones)

    p_rand = init_rand(x, **kwargs)
    loss_rand = _round_trip_loss(p_rand)

    # Order matters for ties: the search result wins, then ones, then random.
    finalists = best_params[:1] + [(p_ones, loss_ones), (p_rand, loss_rand)]
    finite = [entry for entry in finalists if np.isfinite(entry[1])]
    if not finite:
        # No candidate gives a finite loss; return the first one available.
        return finalists[0][0]
    return min(finite, key=lambda entry: entry[1])[0]
196
+
197
+
198
def init_linear_scale(  # Symmetric scale. From the study folder
    x: torch.Tensor,
    **kwargs: Dict[str, Any],
) -> torch.Tensor:
    """Compute a symmetric scale from the output range of the quantization function.

    ``qtz_func`` is applied to ``x`` with a unit ``_s``. The largest absolute
    value along dimension 1 is then divided by half of the signed integer span
    for ``bits``.

    Required keyword arguments:
        bits: Bit width of the signed integer range.
        params: Already initialised parameters forwarded to ``qtz_func``.
        qtz_func: Quantization function.

    Returns:
        A float32 scale tensor, clamped to be at least float32 epsilon.
    """
    assert "bits" in kwargs, "bits must be provided."
    assert "params" in kwargs, "params must be provided."
    assert "qtz_func" in kwargs, "qtz_func must be provided."

    bits = kwargs.get('bits')
    params = kwargs.get('params')
    qtz_func = kwargs.get('qtz_func')

    # Evaluate the quantization function with a scale of one.
    unit_scale = init_ones(x, **kwargs)
    mapped = qtz_func(x=x.transpose(0, 1), **params, _s=unit_scale).transpose(0, 1)

    quant_min, quant_max = get_min_max_from_bits_signed(bits)

    # Range along dimension 1, widened so that it always contains zero.
    lows, highs = torch.aminmax(mapped, dim=1)
    lows = torch.min(lows, torch.zeros_like(lows))
    highs = torch.max(highs, torch.zeros_like(highs))

    half_span = float(quant_max - quant_min) / 2
    scale = torch.max(-lows, highs) / half_span

    # NOTE(diogo): adding random noise to the scale was tried and found not to
    # help the learning process, so none is added.
    eps = torch.finfo(torch.float32).eps
    return torch.clamp(scale, min=eps).to(dtype=torch.float32, device=lows.device)
234
+
235
+
236
def init_non_linear_regression_fit(
    x: torch.Tensor,
    **kwargs: Dict[str, Any],
) -> Dict[str, torch.Tensor]:
    """Fit the parameters of ``np_fit_func`` to each row of ``x`` with ``curve_fit``.

    Every row of ``x`` is sorted and fitted independently, using the row itself
    as both the input and the target of the fit.

    Required keyword arguments:
        params_list: Names of the parameters to fit. The names are sorted
            before use, so ``np_fit_func`` must accept them in sorted order
            after its first argument.
        np_fit_func: NumPy callable ``f(x, *params)`` handed to ``curve_fit``.
        p0: Mapping from parameter name to a tensor holding one initial guess
            per row of ``x``.

    Returns:
        Mapping from parameter name to a float32 tensor on ``x.device`` with
        one fitted value per row. When the fit fails, the initial guesses in
        ``p0`` are returned instead.
    """
    assert "params_list" in kwargs, "params list must be provided."
    assert "np_fit_func" in kwargs, "np_fit_func must be provided."
    assert "p0" in kwargs, "p0 must be provided."
    np_fit_func = kwargs.get('np_fit_func')
    # Sorted so the order matches the positional arguments of the fit function.
    params_list = sorted(kwargs.get('params_list'))
    p0 = {k: v.cpu().numpy() for k, v in kwargs.get('p0').items()}

    # Sorting each row makes the curve easier to fit.
    sorted_xdata = np.sort(x.cpu().numpy(), axis=-1)

    try:
        fitted = []
        for i in range(sorted_xdata.shape[0]):
            row = sorted_xdata[i]
            guess = [p0[name][i] for name in params_list]
            popt, _ = curve_fit(
                np_fit_func,
                row,
                row,
                maxfev=1000,
                p0=guess,
                method='lm',
            )
            fitted.append(popt)
    except (ValueError, RuntimeError) as e:
        # curve_fit raises ValueError for invalid data and RuntimeError when
        # the least-squares minimisation does not converge; both fall back to
        # the initial guesses.
        print(f"Could not fit the function with error: {e}")
        print("Using fallback result...")
        return {
            k: torch.tensor(v, dtype=torch.float32).to(x.device) for k, v in p0.items()
        }

    return {
        name: torch.tensor([row_params[i] for row_params in fitted], dtype=torch.float32).to(x.device)
        for i, name in enumerate(params_list)
    }
290
+
291
+
292
def init_zeros(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor:
    """Return float32 zeros shaped like ``x`` reduced over dimension 1."""
    # The reduction is only used to obtain the output shape.
    reduced = x.amin(dim=1)
    return torch.zeros_like(reduced, dtype=torch.float32, device=x.device)
295
+
296
+
297
def init_inner_scale(tensor: torch.Tensor, _min: float = torch.inf, _max: float = torch.inf) -> torch.Tensor:
    """Compute the factor that rescales each row's range to ``[_min, _max]``.

    The range of every row is taken along the last dimension and widened so
    that it always contains zero.

    Args:
        tensor: Input values.
        _min: Lower bound of the target range.
        _max: Upper bound of the target range. An infinite value means no
            rescaling is needed.

    Returns:
        A tensor of ones when ``_max`` is infinite, otherwise
        ``(_max - _min) / (x_max - x_min)`` per row.
    """
    min_vals, max_vals = torch.aminmax(tensor, dim=-1)
    x_min = torch.min(min_vals, torch.zeros_like(min_vals))
    x_max = torch.max(max_vals, torch.zeros_like(max_vals))

    # Compare by value, not identity: an infinity built elsewhere
    # (e.g. float("inf")) is a different object than torch.inf.
    if isinstance(_max, (int, float)) and _max == torch.inf:
        return torch.ones_like(x_min)

    return (_max - _min) / (x_max - x_min)
309
+
310
+
311
+
312
+ ############## Quant ###############
313
+
314
@torch.enable_grad()
def learn_parameters(
    x: torch.Tensor,
    params: Dict[str, nn.Parameter],
    qtz_func: nn.Module,
    deqtz_func: nn.Module,
    bits: int,
    target_dtype: torch.dtype,
    epochs: int = 1000,
    early_stop: bool = True,
    do_report: bool = False
) -> Tuple[Dict[str, nn.Parameter], torch.Tensor]:
    """Refine ``params`` with Adam by minimising the quantize/dequantize MSE.

    Args:
        x: Tensor to quantize and reconstruct.
        params: Parameters to optimise; they are updated in place.
        qtz_func: Quantization function passed to ``quantize``.
        deqtz_func: Dequantization function passed to ``dequantize``.
        bits: Bit width passed to ``quantize`` and ``dequantize``.
        target_dtype: Dtype of the quantized tensor.
        epochs: Maximum number of optimisation steps.
        early_stop: Stop when the loss changed by less than ``1e-7`` over the
            last 10% of ``epochs``.
        do_report: Print progress and also return the loss history.

    Returns:
        A copy of the parameters that produced the lowest loss, with gradients
        disabled. When ``do_report`` is true, the tuple
        ``(best_params, losses)`` where ``losses`` is a list of floats.

    Raises:
        Exception: If the loss becomes NaN or infinite during training.
    """
    loss_fn = nn.MSELoss()

    def _round_trip_loss() -> torch.Tensor:
        quant = quantize(x, params, qtz_func, bits, target_dtype)
        dequant = dequantize(quant, params, deqtz_func, bits, x.dtype)
        return loss_fn(x, dequant)

    # The learning rate is scaled by half the order of magnitude of the
    # initial loss. A zero or non-finite loss has no usable order of
    # magnitude, so the base rate is used instead.
    base_lr = 0.1
    initial_loss = _round_trip_loss().item()
    if np.isfinite(initial_loss) and initial_loss > 0:
        exponent = int(np.floor(np.log10(initial_loss)))
    else:
        exponent = 0
    lr = base_lr * (10 ** (exponent // 2))

    for p in params.values():
        p.requires_grad = True
        p.grad = None

    optimizer = torch.optim.Adam(list(params.values()), lr=lr)

    # Start from a copy of the incoming parameters so there is always
    # something to return, even when no epoch runs.
    best_loss = float("inf")
    best_params = copy.deepcopy(dict(params))

    # Early-stopping bookkeeping.
    min_delta = 1e-7
    acc_loss = []
    percent_epochs_before_stop = 0.1
    epochs_before_stop = int(epochs * percent_epochs_before_stop)

    for i in range(epochs):
        optimizer.zero_grad()

        loss = _round_trip_loss()
        if loss.isnan() or loss.isinf():
            raise Exception("Loss is NaN or Inf. Stopping the search.")
        current_loss = loss.item()

        # Snapshot before the optimiser step: the loss was measured with the
        # current values, not with the values after the update.
        if current_loss < best_loss:
            best_loss = current_loss
            best_params = copy.deepcopy(dict(params))

        loss.backward()
        optimizer.step()

        acc_loss.append(current_loss)

        if i % 10 == 0 and do_report:
            print(f"Epoch {i}: Loss {current_loss}")

        if (
            early_stop
            and i > epochs_before_stop
            and abs(acc_loss[i - epochs_before_stop] - acc_loss[i]) < min_delta
        ):
            break

    for p in best_params.values():
        p.requires_grad = False
        p.grad = None

    if do_report:
        print(f"Best loss: {best_loss}")
        return best_params, acc_loss
    return best_params
400
+
401
+
402
def quantize(
    x: torch.Tensor,
    params: Dict[str, nn.Parameter],
    func: nn.Module,
    bits: int,
    target_dtype: torch.dtype = torch.int8
) -> torch.Tensor:
    """Apply ``func`` to ``x``, round, clamp to the signed ``bits`` range and cast.

    ``func`` is called on ``x`` with its first two dimensions swapped, and the
    result is swapped back before rounding.
    """
    lower, upper = get_min_max_from_bits_signed(bits)
    mapped = func(x=x.transpose(0, 1), **params).transpose(0, 1)
    rounded = round_func_BPDA(mapped)
    return torch.clamp(rounded, lower, upper).to(target_dtype)
415
+
416
+
417
def dequantize(
    x: torch.Tensor,
    params: Dict[str, nn.Parameter],
    func: nn.Module,
    bits: int,
    out_dtype: torch.dtype
) -> torch.Tensor:
    """Cast ``x`` to ``out_dtype`` and apply ``func`` to reconstruct the values.

    ``func`` is called on the cast tensor with its first two dimensions
    swapped; the result is swapped back. ``bits`` is accepted but not used.
    """
    as_float = x.to(dtype=out_dtype)
    restored = func(x=as_float.transpose(0, 1), **params)
    return restored.transpose(0, 1)
429
+
430
+
431
def round_func_BPDA(input):
    """Round ``input`` in the forward pass while keeping an identity gradient.

    The result is a clone of ``input`` whose data is overwritten with the
    rounded values, so autograd still sees the (differentiable) clone.
    """
    rounded = torch.round(input)
    passthrough = input.clone()
    passthrough.data = rounded.data
    return passthrough
438
+
439
+
440
def get_min_max_from_bits_signed(bit_width: int) -> Tuple[int, int]:
    """Return the ``(min, max)`` values of a signed integer with ``bit_width`` bits."""
    half_range = 2 ** (bit_width - 1)
    return -half_range, half_range - 1
442
+
443
+
444
+
445
+ ############## Numpy ###############
446
+
447
def np_domain_guard(
    x: np.ndarray,
    min: float = None,
    max: float = None,
    posinf: float = None,
    neginf: float = None,
    nan: float = None
) -> np.ndarray:
    """Replace non-finite values in ``x`` and optionally clip it.

    Args:
        x: Input array.
        min: Lower clipping bound, or ``None`` for no lower bound.
        max: Upper clipping bound, or ``None`` for no upper bound.
        posinf: Replacement for positive infinity (NumPy default when ``None``).
        neginf: Replacement for negative infinity (NumPy default when ``None``).
        nan: Replacement for NaN; ``None`` means 0.0.

    Returns:
        The guarded array.
    """
    # np.nan_to_num expects a number for `nan`; map the None default to 0.0,
    # which is NumPy's own default for that argument.
    nan_value = 0.0 if nan is None else nan
    x = np.nan_to_num(x, posinf=posinf, neginf=neginf, nan=nan_value)
    if min is not None or max is not None:
        x = np.clip(x, min, max)
    return x
460
+
461
+
462
def np_replace_num(x: np.ndarray, num: float, to: float) -> np.ndarray:
    """Return ``x`` with every element equal to ``num`` replaced by ``to``.

    Args:
        x (np.ndarray): The input array.
        num (float): The value to look for.
        to (float): The value written in its place.

    Returns:
        np.ndarray: The array with the replacements applied.
    """
    matches = x == num
    return np.where(matches, to, x)
474
+
475
+
476
def np_guarded_power(x: np.ndarray, exp: float) -> np.ndarray:
    """Raise ``x`` to ``exp``; for exponents below 1, negative inputs are first set to 0."""
    if exp >= 1:
        return np.power(x, exp)
    non_negative = np.maximum(x, 0)
    return np.power(non_negative, exp)
479
+
fn_gen/nlr_t_no_sched/10/loss.png ADDED
fn_gen/nlr_t_no_sched/10/quantization.png ADDED
fn_gen/nlr_t_no_sched/11/distortion.png ADDED
fn_gen/nlr_t_no_sched/11/expressions.txt ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ x**2/_s
2
+ sqrt(_s*x)
fn_gen/nlr_t_no_sched/11/fn.py ADDED
@@ -0,0 +1,478 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import torch
4
+ from torch import amin # Necessary for arcsin
5
+ import copy
6
+ import torch.nn as nn
7
+ import numpy as np
8
+
9
+ from scipy.optimize import curve_fit
10
+ from typing import Dict, Any, Tuple, List, Callable
11
+
12
+
13
def quantization(x, **params):
    """Compute ``x**2 / _s``; zeros in ``_s`` are replaced by 10000 first."""
    safe_scale = replace_num(params['_s'], num=0, to=10000)
    inverse_scale = torch.div(1, safe_scale)
    return inverse_scale * guarded_torch_power(x, torch.tensor(2))
15
+
16
+
17
def dequantization(x, **params):
    """Compute ``sqrt(_s * x)``; the argument is clamped to at least 0.1 (NaNs become 0.1)."""
    sqrt_argument = domain_guard(params['_s'] * x, min=0.1, nan=0.1)
    return torch.sqrt(sqrt_argument)
19
+
20
+
21
def init_params(x: torch.Tensor, **kwargs: Dict[str, Any]) -> Dict[str, nn.Parameter]:
    """Initialise, fit and train the ``_s`` parameter for ``x``.

    The stages run in order: linear scale for ``_s``, non-linear regression
    fit, then gradient-based learning. After each stage the matching hook in
    ``kwargs`` (``post_init_hook``, ``post_method_hook``, ``post_train_hook``)
    is called when present. ``kwargs['bits']`` is required.
    """
    # Stage 1: initial guess.
    initial = {}
    initial['_s'] = init_linear_scale(x, qtz_func=quantization, params=initial, **kwargs)
    if 'post_init_hook' in kwargs:
        kwargs['post_init_hook'](parameters=initial)

    # Stage 2: curve fit starting from the initial guess.
    fitted = init_non_linear_regression_fit(
        x,
        p0=initial,
        np_fit_func=fit_func,
        qtz_func=quantization,
        deqtz_func=dequantization,
        params_list=['_s'],
        **kwargs,
    )
    fitted = {name: nn.Parameter(value, requires_grad=False) for name, value in fitted.items()}
    if 'post_method_hook' in kwargs:
        kwargs['post_method_hook'](parameters=fitted)

    # Stage 3: gradient-based refinement.
    learned = learn_parameters(
        x,
        fitted,
        qtz_func=quantization,
        deqtz_func=dequantization,
        bits=kwargs['bits'],
        target_dtype=torch.int8,
        epochs=500,
        early_stop=False,
    )
    if 'post_train_hook' in kwargs:
        kwargs['post_train_hook'](parameters=learned)

    return learned
46
+
47
+
48
+ ############### Numpy Qtz ###############
49
+
50
+
51
def np_quantization(x, _s):
    """NumPy version of ``x**2 / _s``; a zero ``_s`` is replaced by 10000."""
    safe_scale = np_replace_num(_s, num=0, to=10000)
    inverse_scale = np.divide(1, safe_scale)
    return inverse_scale * np_guarded_power(x, np.array(2))
53
+
54
+
55
def np_dequantization(x, _s):
    """NumPy version of ``sqrt(_s * x)``; the argument is clipped to at least 0.1 (NaNs become 0.1)."""
    sqrt_argument = np_domain_guard(_s * x, min=0.1, nan=0.1)
    return np.sqrt(sqrt_argument)
57
+
58
+
59
def fit_func(x, _s):
    """Pass ``x`` through the NumPy quantization and then the dequantization."""
    quantized = np_quantization(x, _s)
    return np_dequantization(quantized, _s)
63
+
64
+
65
+
66
+ ############### HELPERS ###############
67
+
68
def domain_guard(
    x: torch.Tensor,
    min: float = None,
    max: float = None,
    posinf: float = None,
    neginf: float = None,
    nan: float = None
) -> torch.Tensor:
    """Replace non-finite values in ``x`` and optionally clamp it to ``[min, max]``."""
    cleaned = torch.nan_to_num(x, nan=nan, posinf=posinf, neginf=neginf)
    if min is None and max is None:
        # No bounds requested: only the non-finite replacement applies.
        return cleaned
    return torch.clamp(cleaned, min=min, max=max)
81
+
82
+
83
def replace_num(x: torch.Tensor, num: float, to: float) -> torch.Tensor:
    """Return ``x`` with every element equal to ``num`` replaced by ``to``.

    Args:
        x (torch.Tensor): The input tensor.
        num (float): The value to look for.
        to (float): The value written in its place.

    Returns:
        torch.Tensor: The tensor with the replacements applied.
    """
    matches = x == num
    return torch.where(matches, to, x)
95
+
96
+
97
def guarded_torch_power(x: torch.Tensor, exp: float) -> torch.Tensor:
    """Raise ``x`` to ``exp``; for exponents below 1, negative inputs are first set to 0."""
    if exp >= 1:
        return torch.pow(x, exp)
    non_negative = torch.relu(x)
    return torch.pow(non_negative, exp)
100
+
101
+
102
def init_ones(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor:
    """Return float32 ones shaped like ``x`` reduced over dimension 1."""
    # The reduction is only used to obtain the output shape.
    reduced = x.amin(dim=1)
    return torch.ones_like(reduced, dtype=torch.float32, device=x.device)
105
+
106
+
107
def init_rand(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor:
    """Return float32 standard-normal samples shaped like ``x`` reduced over dimension 1."""
    # The reduction is only used to obtain the output shape.
    reduced = x.amin(dim=1)
    return torch.randn_like(reduced, dtype=torch.float32, device=x.device)
110
+
111
+
112
def init_space_search(
    x: torch.Tensor,
    **kwargs: Dict[str, Any],
) -> torch.Tensor:
    """Randomly search for a good initial value of one parameter.

    Candidates are scored by the MSE between ``x`` and
    ``deqtz_func(qtz_func(x))``, with every other parameter fixed to ones.
    Each run keeps the best few candidates and samples the next generation
    around them. Candidates that raise or produce a non-finite loss are
    discarded. The winner of the last successful run is finally compared
    against an all-ones and a random candidate.

    Required keyword arguments:
        qtz_func: Quantization function.
        deqtz_func: Dequantization function.
        params_list: Names of all parameters the two functions take.
        param: Name of the parameter being searched.

    Returns:
        The candidate with the lowest finite loss.
    """

    def _build_initial_param(tensor: torch.Tensor, max_initial: int, n_params: int):
        """Yield the first generation, which is ten times larger than later ones."""
        for _ in range(n_params * 10):
            # Normal samples scaled by max_initial; the values are not bounded.
            yield init_rand(tensor) * max_initial

    def _search_param(tensors: List[torch.Tensor], n_params):
        """Yield ``n_params`` samples around the mean of the given candidates."""
        stacked = torch.stack(tensors)
        min_vals, max_vals = torch.aminmax(stacked, dim=0)
        abs_max_val_per_ch = torch.max(-min_vals, max_vals)
        mean = torch.mean(stacked, dim=0)
        for _ in range(n_params):
            yield torch.randn_like(min_vals) * abs_max_val_per_ch + mean

    def _calc(x, qtz_func, deqtz_func, **params):
        x_ = x.transpose(0, 1)
        x_ = qtz_func(x=x_, **params)
        x_ = deqtz_func(x=x_, **params)
        x_ = x_.transpose(0, 1)
        return x_

    assert "qtz_func" in kwargs, "qtz_func must be provided."
    assert "deqtz_func" in kwargs, "deqtz_func must be provided."
    assert "params_list" in kwargs, "params list must be provided."
    assert "param" in kwargs, "param must be provided."

    qtz_func = kwargs.get('qtz_func')
    deqtz_func = kwargs.get('deqtz_func')
    params_list = kwargs.get('params_list')
    param = kwargs.get('param')

    n_runs = 50  # Number of search generations
    n_random_params = 50  # Candidates per generation (10x in the first one)
    n_best_to_pick = 5  # Candidates kept after each generation
    max_initial = 10000  # Scale of the first generation

    base_params = {p: init_ones(x, **kwargs) for p in params_list if p != param}
    loss_fn = nn.MSELoss()

    def _round_trip_loss(candidate: torch.Tensor) -> float:
        x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: candidate})
        return loss_fn(x, x_).item()

    candidates = _build_initial_param(x, max_initial, n_random_params)
    best_params = []

    for _ in range(n_runs):
        run_best = []
        for candidate in candidates:
            try:
                loss = _round_trip_loss(candidate)
            except Exception:  # The candidate may be outside the function's domain
                continue

            # NaN compares false against everything, so a non-finite loss
            # would get stuck in the ranking; drop it instead.
            if not np.isfinite(loss):
                continue

            if len(run_best) < n_best_to_pick:
                run_best.append((candidate, loss))
            elif loss < run_best[-1][1]:
                run_best[-1] = (candidate, loss)
            else:
                continue
            run_best.sort(key=lambda entry: entry[1])

        if not run_best:
            # Nothing usable in this generation; keep the previous winners.
            break

        best_params = run_best
        candidates = _search_param([p for p, _ in best_params], n_random_params)

    # Compare the search result against the two simple baselines.
    p_ones = init_ones(x, **kwargs)
    loss_ones = _round_trip_loss(p_ones)

    p_rand = init_rand(x, **kwargs)
    loss_rand = _round_trip_loss(p_rand)

    # Order matters for ties: the search result wins, then ones, then random.
    finalists = best_params[:1] + [(p_ones, loss_ones), (p_rand, loss_rand)]
    finite = [entry for entry in finalists if np.isfinite(entry[1])]
    if not finite:
        # No candidate gives a finite loss; return the first one available.
        return finalists[0][0]
    return min(finite, key=lambda entry: entry[1])[0]
195
+
196
+
197
def init_linear_scale(  # Symmetric scale. From the study folder
    x: torch.Tensor,
    **kwargs: Dict[str, Any],
) -> torch.Tensor:
    """Compute a symmetric linear scale from the quantizer's output range.

    The quantization function is run once with `_s` set to all ones, and the largest
    absolute output along dim 1 is divided by half the signed integer range.

    Required kwargs:
        bits: Bit width of the signed integer target.
        params: Parameters forwarded to `qtz_func` (any `_s` entry is overridden).
        qtz_func: Quantization function called as `qtz_func(x=..., **params)`.

    Returns:
        A float32 tensor of scales, clamped from below to float32 machine epsilon.
    """
    assert "bits" in kwargs, "bits must be provided."
    assert "params" in kwargs, "params must be provided."
    assert "qtz_func" in kwargs, "qtz_func must be provided."

    bits = kwargs.get('bits')
    params = kwargs.get('params')
    qtz_func = kwargs.get('qtz_func')

    # Merge through a dict so the neutral `_s` overrides any `_s` already in `params`
    # instead of raising "got multiple values for keyword argument '_s'".
    qtz_params = {**params, '_s': init_ones(x, **kwargs)}

    x_ = x.transpose(0, 1)
    x_ = qtz_func(x=x_, **qtz_params)
    x_ = x_.transpose(0, 1)

    quant_min, quant_max = get_min_max_from_bits_signed(bits)
    min_vals, max_vals = torch.aminmax(x_, dim=1)
    # The range used for the scale always contains zero.
    min_vals = torch.min(min_vals, torch.zeros_like(min_vals))
    max_vals = torch.max(max_vals, torch.zeros_like(max_vals))

    abs_max_val_per_ch = torch.max(-min_vals, max_vals)
    scale = abs_max_val_per_ch / (float(quant_max - quant_min) / 2)

    # A zero scale would later be divided by, so keep it strictly positive.
    eps = torch.finfo(torch.float32).eps
    return torch.clamp(scale, min=eps).to(dtype=torch.float32, device=min_vals.device)
233
+
234
+
235
def init_non_linear_regression_fit(
    x: torch.Tensor,
    **kwargs: Dict[str, Any],
) -> Dict[str, torch.Tensor]:
    """Refine initial parameters with a per-row non-linear least-squares fit.

    Each row of `x` is sorted and `np_fit_func` is fitted so that it maps the row onto
    itself, starting from the matching entries of `p0`.

    Required kwargs:
        params_list: Parameter names; sorted here and assumed to match the argument
            order of `np_fit_func` after its first (data) argument.
        np_fit_func: NumPy model called as `np_fit_func(xdata, *params)`.
        p0: Dict of name -> tensor holding one starting value per row of `x`.

    Returns:
        Dict of name -> float32 tensor on `x.device`. If the fit fails, the starting
        values from `p0` are returned instead.
    """
    assert "params_list" in kwargs, "params list must be provided."
    assert "np_fit_func" in kwargs, "np_fit_func must be provided."
    assert "p0" in kwargs, "p0 must be provided."
    np_fit_func = kwargs.get('np_fit_func')
    # We need to make sure that it matches the numpy fit func arg order
    params_list = sorted(kwargs.get('params_list'))
    # detach() so tensors that track gradients can still be converted to numpy.
    p0 = {k: v.detach().cpu().numpy() for k, v in kwargs.get('p0').items()}

    # Sorting each row makes the identity target monotonic, which is easier to fit.
    sorted_xdata = np.sort(x.detach().cpu().numpy(), axis=-1)

    try:
        per_channel = []
        for i, row in enumerate(sorted_xdata):
            start = [p0[p][i] for p in params_list]
            popt, _ = curve_fit(np_fit_func, row, row, maxfev=1000, p0=start, method='lm')
            per_channel.append(popt)
    except (RuntimeError, ValueError) as e:
        # curve_fit raises RuntimeError when the minimization does not converge and
        # ValueError for invalid data; both fall back to the starting values.
        print(f"Could not fit the function with error: {e}")
        print("Using fallback result...")
        return {
            k: torch.tensor(v, dtype=torch.float32).to(x.device) for k, v in p0.items()
        }

    return {
        p: torch.tensor([fitted[i] for fitted in per_channel], dtype=torch.float32).to(x.device)
        for i, p in enumerate(params_list)
    }
289
+
290
+
291
def init_zeros(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor:
    """Return a float32 zero tensor shaped like `x` reduced over dim 1, on `x.device`."""
    # The reduction is only used as a shape template.
    template = torch.amin(x, dim=1)
    return torch.zeros_like(template, dtype=torch.float32, device=x.device)
294
+
295
+
296
def init_inner_scale(tensor: torch.Tensor, _min: float = torch.inf, _max: float = torch.inf) -> torch.Tensor:
    """Return the factor that rescales each row's value range to `[_min, _max]`.

    The row range is taken along the last dim and always includes zero.

    Args:
        tensor: Input values.
        _min: Lower bound of the target range.
        _max: Upper bound of the target range; positive infinity means "no rescaling".

    Returns:
        A tensor of ones when `_max` is infinite, otherwise
        `(_max - _min) / (row_max - row_min)`, with 1 for rows whose range is zero.
    """
    # Calculate the original minimum and maximum values
    min_vals, max_vals = torch.aminmax(tensor, dim=-1)
    x_min = torch.min(min_vals, torch.zeros_like(min_vals))
    x_max = torch.max(max_vals, torch.zeros_like(max_vals))

    # Compare by value: `is` only matches the exact `torch.inf` object, so an
    # equivalent `float("inf")` passed by a caller would not be recognised.
    if _max == torch.inf:  # We do not need to scale the tensor. Just need to move it
        return torch.ones_like(x_min)

    # Calculate the scale factor; an all-zero row has no range to rescale, so use 1
    # there instead of dividing by zero.
    span = x_max - x_min
    scale = (_max - _min) / span
    return torch.where(span == 0, torch.ones_like(scale), scale)
308
+
309
+
310
+
311
+ ############## Quant ###############
312
+
313
@torch.enable_grad()
def learn_parameters(
    x: torch.Tensor,
    params: Dict[str, nn.Parameter],
    qtz_func: nn.Module,
    deqtz_func: nn.Module,
    bits: int,
    target_dtype: torch.dtype,
    epochs: int = 1000,
    early_stop: bool = True,
    do_report: bool = False
) -> Tuple[Dict[str, nn.Parameter], torch.Tensor]:
    """Tune the quantization parameters with Adam on the round-trip MSE.

    Args:
        x: Tensor to quantize and reconstruct.
        params: Parameters passed to both functions; they are switched to
            `requires_grad=True` and updated in place by the optimizer.
        qtz_func: Quantization function.
        deqtz_func: Dequantization function.
        bits: Bit width of the signed integer range.
        target_dtype: Dtype of the quantized tensor.
        epochs: Maximum number of optimization steps.
        early_stop: Stop once the loss changed by less than 1e-7 over the last
            10% of `epochs`.
        do_report: Print the loss every 10 steps and also return the loss history.

    Returns:
        A deep copy of the parameters that produced the lowest loss (gradients
        disabled); when `do_report` is true, a `(best_params, losses)` tuple.

    Raises:
        Exception: If the loss becomes NaN or Inf during the search.
    """
    loss_fn = nn.MSELoss()

    def _round_trip_loss() -> torch.Tensor:
        """Quantize and dequantize `x` with the current parameters and return the MSE."""
        quant = quantize(x, params, qtz_func, bits, target_dtype)
        dequant = dequantize(quant, params, deqtz_func, bits, x.dtype)
        return loss_fn(x, dequant)

    # Determines the initial learning rate from the order of magnitude of the initial
    # loss. log10 is undefined for 0/NaN/Inf, so those cases keep the base rate (a
    # NaN/Inf loss is then reported by the check inside the loop).
    base_lr = 0.1
    initial_loss = _round_trip_loss().item()
    if np.isfinite(initial_loss) and initial_loss > 0:
        exponent = int(np.floor(np.log10(initial_loss)))
        lr = base_lr * (10 ** (exponent // 2))
    else:
        lr = base_lr

    # Requires gradients in the parameters
    for p in params.values():
        p.requires_grad = True
        p.grad = None

    optimizer = torch.optim.Adam(list(params.values()), lr=lr)

    # Contains the best loss and the best parameters
    best_loss = float("inf")
    best_params = None

    # Used to stop the search early
    min_delta = 1e-7
    acc_loss = []
    percent_epochs_before_stop = 0.1
    epochs_before_stop = int(epochs * percent_epochs_before_stop)

    for i in range(epochs):
        optimizer.zero_grad()
        loss = _round_trip_loss()

        if loss.isnan() or loss.isinf():
            raise Exception("Loss is NaN or Inf. Stopping the search.")

        loss_value = loss.item()

        # Snapshot BEFORE the optimizer step: `loss` was computed with the current
        # values, so copying after the step would pair `best_loss` with different
        # parameters than the ones that produced it.
        if loss_value < best_loss:
            best_loss = loss_value
            best_params = copy.deepcopy(dict(params))

        loss.backward()
        optimizer.step()

        acc_loss.append(loss_value)

        # Reports loss every 10 steps
        if i % 10 == 0 and do_report:
            print(f"Epoch {i}: Loss {loss_value}")

        # Stop if the loss has not changed considerably during the last 10% epochs
        if (
            early_stop
            and i > epochs_before_stop
            and abs(acc_loss[i - epochs_before_stop] - acc_loss[i]) < min_delta
        ):
            break

    if best_params is None:  # No epoch ran (epochs <= 0): return the inputs unchanged
        best_params = copy.deepcopy(dict(params))

    # No longer requires gradients in the parameters
    for p in best_params.values():
        p.requires_grad = False
        p.grad = None

    if do_report:
        print(f"Best loss: {best_loss}")
        return best_params, acc_loss
    return best_params
399
+
400
+
401
def quantize(
    x: torch.Tensor,
    params: Dict[str, nn.Parameter],
    func: nn.Module,
    bits: int,
    target_dtype: torch.dtype = torch.int8
) -> torch.Tensor:
    """Apply `func` to `x`, round, clamp to the signed `bits` range and cast.

    `x` is transposed on dims 0 and 1 around the call to `func` to align shapes.
    NOTE(review): casting to an integer `target_dtype` presumably detaches the result
    from autograd, so the straight-through rounding may have no effect - confirm.
    """
    lower, upper = get_min_max_from_bits_signed(bits)
    mapped = func(x=x.transpose(0, 1), **params).transpose(0, 1)
    rounded = round_func_BPDA(mapped)
    return torch.clamp(rounded, lower, upper).to(target_dtype)
414
+
415
+
416
def dequantize(
    x: torch.Tensor,
    params: Dict[str, nn.Parameter],
    func: nn.Module,
    bits: int,
    out_dtype: torch.dtype
) -> torch.Tensor:
    """Cast `x` to `out_dtype` and apply `func` to it.

    `x` is transposed on dims 0 and 1 around the call to `func`. `bits` is accepted
    for signature parity with `quantize` but is not used here.
    """
    as_float = x.to(dtype=out_dtype)
    restored = func(x=as_float.transpose(0, 1), **params)
    return restored.transpose(0, 1)
428
+
429
+
430
def round_func_BPDA(input):
    """Round `input` in the forward pass while keeping an identity gradient.

    The result is a clone of `input` (so autograd sees a differentiable copy) whose
    stored values are overwritten with the rounded ones.
    """
    result = input.clone()
    result.data = torch.round(input).data
    return result
437
+
438
+
439
def get_min_max_from_bits_signed(bit_width: int) -> Tuple[int, int]:
    """Return `(min, max)` of a signed two's-complement integer with `bit_width` bits."""
    half_range = 2 ** (bit_width - 1)
    return -half_range, half_range - 1
441
+
442
+
443
+
444
+ ############## Numpy ###############
445
+
446
def np_domain_guard(
    x: np.ndarray,
    min: float = None,
    max: float = None,
    posinf: float = None,
    neginf: float = None,
    nan: float = None
) -> np.ndarray:
    """Guard a tensor to a valid domain.

    Args:
        x: Input array.
        min: Optional lower clipping bound.
        max: Optional upper clipping bound.
        posinf: Replacement for +inf (NumPy's default when None).
        neginf: Replacement for -inf (NumPy's default when None).
        nan: Replacement for NaN; None means 0.0.

    Returns:
        The array with non-finite values replaced and, if a bound is given, clipped.
    """
    # np.nan_to_num expects a number for `nan` (its own default is 0.0), so an
    # omitted value must not be forwarded as None.
    nan_value = 0.0 if nan is None else nan
    x = np.nan_to_num(x, posinf=posinf, neginf=neginf, nan=nan_value)
    if min is not None or max is not None:
        x = np.clip(x, min, max)
    return x
459
+
460
+
461
def np_replace_num(x: np.ndarray, num: float, to: float) -> np.ndarray:
    """Swap every occurrence of one value in an array for another.

    Args:
        x (np.ndarray): The input tensor.
        num (float): The value to look for.
        to (float): The value written wherever `num` is found.

    Returns:
        np.ndarray: A new array with the matches replaced.
    """
    matches = x == num
    return np.where(matches, to, x)
473
+
474
+
475
def np_guarded_power(x: np.ndarray, exp: float) -> np.ndarray:
    """Raise `x` to `exp`; for exponents below 1, negative inputs are first set to 0."""
    if exp >= 1:
        return np.power(x, exp)
    non_negative = np.maximum(x, 0)
    return np.power(non_negative, exp)
478
+
fn_gen/nlr_t_no_sched/11/loss.png ADDED
fn_gen/nlr_t_no_sched/11/quantization.png ADDED
fn_gen/nlr_t_no_sched/12/distortion.png ADDED
fn_gen/nlr_t_no_sched/12/expressions.txt ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ cos(_0*x)/_s
2
+ acos(_s*x)/_0
fn_gen/nlr_t_no_sched/12/fn.py ADDED
@@ -0,0 +1,479 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import torch
4
+ from torch import amin # Necessary for arcsin
5
+ import copy
6
+ import torch.nn as nn
7
+ import numpy as np
8
+
9
+ from scipy.optimize import curve_fit
10
+ from typing import Dict, Any, Tuple, List, Callable
11
+
12
+
13
def quantization(x, **params):
    """Compute `cos(_0 * x) / _s`, using 10000 in place of any zero in `_s`."""
    safe_scale = replace_num(params['_s'], num=0, to=10000)
    curved = torch.cos(params['_0'] * x)
    return torch.div(1, safe_scale) * curved
15
+
16
+
17
def dequantization(x, **params):
    """Compute `acos(_s * x) / _0`, using 10000 in place of any zero in `_0`.

    The acos argument is clamped to [-0.99999, 0.99999] with NaN mapped to 0.
    """
    safe_inner = replace_num(params['_0'], num=0, to=10000)
    bounded = domain_guard(params['_s'] * x, min=-0.99999, max=0.99999, nan=0)
    return torch.div(1, safe_inner) * torch.acos(bounded)
19
+
20
+
21
def init_params(x: torch.Tensor, **kwargs: Dict[str, Any]) -> Dict[str, nn.Parameter]:
    """Initialise and train the `_0` / `_s` quantization parameters for `x`.

    Steps: random search for `_0`, linear scale for `_s`, non-linear regression fit of
    both, then gradient training. The optional callables `post_init_hook`,
    `post_method_hook` and `post_train_hook` in `kwargs` are called with
    `parameters=` after the matching step. `kwargs['bits']` is required.
    """
    def _fire(hook_name, parameters):
        # Hooks are optional; only call the ones the caller supplied.
        if hook_name in kwargs:
            kwargs[hook_name](parameters=parameters)

    seed = {
        '_0': init_space_search(
            x,
            qtz_func=quantization,
            deqtz_func=dequantization,
            params_list=['_0', '_s'],
            param='_0',
            **kwargs,
        ),
    }
    seed['_s'] = init_linear_scale(x, qtz_func=quantization, params=seed, **kwargs)
    _fire('post_init_hook', seed)

    fitted = init_non_linear_regression_fit(
        x,
        p0=seed,
        np_fit_func=fit_func,
        qtz_func=quantization,
        deqtz_func=dequantization,
        params_list=['_0', '_s'],
        **kwargs,
    )
    fitted = {name: nn.Parameter(value, requires_grad=False) for name, value in fitted.items()}
    _fire('post_method_hook', fitted)

    learned = learn_parameters(
        x,
        fitted,
        qtz_func=quantization,
        deqtz_func=dequantization,
        bits=kwargs['bits'],
        target_dtype=torch.int8,
        epochs=500,
        early_stop=False,
    )
    _fire('post_train_hook', learned)

    return learned
47
+
48
+
49
+ ############### Numpy Qtz ###############
50
+
51
+
52
def np_quantization(x, _0, _s):
    """NumPy form of the quantizer: `cos(_0 * x) / _s`, with a zero `_s` replaced by 10000."""
    safe_scale = np_replace_num(_s, num=0, to=10000)
    return np.divide(1, safe_scale) * np.cos(_0 * x)
54
+
55
+
56
def np_dequantization(x, _0, _s):
    """NumPy form of the dequantizer: `arccos(_s * x) / _0`, with a zero `_0` replaced by 10000.

    The arccos argument is clipped to [-0.99999, 0.99999] with NaN mapped to 0.
    """
    safe_inner = np_replace_num(_0, num=0, to=10000)
    bounded = np_domain_guard(_s * x, min=-0.99999, max=0.99999, nan=0)
    return np.divide(1, safe_inner) * np.arccos(bounded)
58
+
59
+
60
def fit_func(x, _0, _s):
    """Round-trip `x` through the NumPy quantizer and dequantizer (model for the curve fit)."""
    return np_dequantization(np_quantization(x, _0, _s), _0, _s)
64
+
65
+
66
+
67
+ ############### HELPERS ###############
68
+
69
def domain_guard(
    x: torch.Tensor,
    min: float = None,
    max: float = None,
    posinf: float = None,
    neginf: float = None,
    nan: float = None
) -> torch.Tensor:
    """Replace non-finite values in `x` and optionally clamp it to `[min, max]`.

    `posinf`, `neginf` and `nan` are forwarded to `torch.nan_to_num`; clamping only
    happens when at least one bound is given.
    """
    cleaned = torch.nan_to_num(x, posinf=posinf, neginf=neginf, nan=nan)
    if min is None and max is None:
        return cleaned
    return torch.clamp(cleaned, min=min, max=max)
82
+
83
+
84
def replace_num(x: torch.Tensor, num: float, to: float) -> torch.Tensor:
    """Swap every occurrence of one value in a tensor for another.

    Args:
        x (torch.Tensor): The input tensor.
        num (float): The value to look for.
        to (float): The value written wherever `num` is found.

    Returns:
        torch.Tensor: A new tensor with the matches replaced.
    """
    matches = x == num
    return torch.where(matches, to, x)
96
+
97
+
98
def guarded_torch_power(x: torch.Tensor, exp: float) -> torch.Tensor:
    """Raise `x` to `exp`; for exponents below 1, negative inputs are first set to 0."""
    if exp >= 1:
        return torch.pow(x, exp)
    non_negative = torch.relu(x)
    return torch.pow(non_negative, exp)
101
+
102
+
103
def init_ones(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor:
    """Return a float32 tensor of ones shaped like `x` reduced over dim 1, on `x.device`."""
    # The reduction is only used as a shape template.
    template = torch.amin(x, dim=1)
    return torch.ones_like(template, dtype=torch.float32, device=x.device)
106
+
107
+
108
def init_rand(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor:
    """Return float32 standard-normal samples shaped like `x` reduced over dim 1, on `x.device`."""
    # The reduction is only used as a shape template.
    template = torch.amin(x, dim=1)
    return torch.randn_like(template, dtype=torch.float32, device=x.device)
111
+
112
+
113
def init_space_search(
    x: torch.Tensor,
    **kwargs: Dict[str, Any],
) -> torch.Tensor:
    """Randomly search for a good starting value of one quantization parameter.

    Candidates are scored by the MSE between `x` and its quantize/dequantize round
    trip, with every other parameter fixed to ones. The best few candidates seed the
    next round of sampling, and the final winner is compared against an all-ones and
    a random-normal initialisation.

    Required kwargs:
        qtz_func: Quantization function, called as `qtz_func(x=..., **params)`.
        deqtz_func: Dequantization function, called the same way.
        params_list: Names of all parameters the two functions expect.
        param: Name of the parameter being searched.

    Returns:
        The candidate tensor with the lowest finite loss; all ones if no candidate
        could be evaluated.
    """
    assert "qtz_func" in kwargs, "qtz_func must be provided."
    assert "deqtz_func" in kwargs, "deqtz_func must be provided."
    assert "params_list" in kwargs, "params list must be provided."
    assert "param" in kwargs, "param must be provided."

    qtz_func = kwargs.get('qtz_func')
    deqtz_func = kwargs.get('deqtz_func')
    params_list = kwargs.get('params_list')
    param = kwargs.get('param')

    n_runs = 50           # Number of runs to try to find the best parameters
    n_random_params = 50  # Number of random parameters to generate
    n_best_to_pick = 5    # Number of best parameters to pick after each run
    max_initial = 10000   # Multiplier applied to the initial standard-normal samples

    loss_fn = nn.MSELoss()
    base_params = {p: init_ones(x, **kwargs) for p in params_list if p != param}

    def _loss_of(candidate):
        """Return the round-trip MSE for `candidate`, or None if it is unusable."""
        try:
            x_ = x.transpose(0, 1)
            x_ = qtz_func(x=x_, **base_params, **{param: candidate})
            x_ = deqtz_func(x=x_, **base_params, **{param: candidate})
            loss = loss_fn(x, x_.transpose(0, 1))
        except Exception:  # The parameters might not be valid for the function's domain
            return None
        # NaN/Inf losses cannot be ranked (comparisons with NaN are always False),
        # so treat them like any other invalid candidate.
        return loss.item() if torch.isfinite(loss) else None

    def _initial_candidates():
        """Yield the first generation: 10x more samples, scaled by `max_initial`."""
        for _ in range(n_random_params * 10):
            yield init_rand(x) * max_initial

    def _sample_around(tensors: List[torch.Tensor], n_params: int):
        """Yield samples around the mean of `tensors`, spread by their largest magnitude."""
        stacked = torch.stack(tensors)
        min_vals, max_vals = torch.aminmax(stacked, dim=0)
        abs_max_val_per_ch = torch.max(-min_vals, max_vals)
        mean = torch.mean(stacked, dim=0)
        for _ in range(n_params):
            yield torch.randn_like(min_vals) * abs_max_val_per_ch + mean

    # Performs the search
    candidates = _initial_candidates()
    elite = []  # (candidate, loss) pairs, ascending by loss
    for _ in range(n_runs):
        # Carry the previous elite over so a good candidate is never lost when a new
        # generation happens to be worse.
        scored = list(elite)
        for candidate in candidates:
            loss = _loss_of(candidate)
            if loss is not None:
                scored.append((candidate, loss))
        scored.sort(key=lambda pair: pair[1])
        elite = scored[:n_best_to_pick]

        if not elite:  # Nothing could be evaluated; there is nothing to sample around
            break

        # Generates new parameters around the mean
        candidates = _sample_around([p for p, _ in elite], n_random_params)

    # Compare the search winner against the two plain initialisations. The search
    # result is listed first so it wins ties.
    finalists = elite[:1]
    for fallback in (init_ones(x, **kwargs), init_rand(x, **kwargs)):
        loss = _loss_of(fallback)
        if loss is not None:
            finalists.append((fallback, loss))

    if not finalists:
        return init_ones(x, **kwargs)
    return min(finalists, key=lambda pair: pair[1])[0]
196
+
197
+
198
def init_linear_scale(  # Symmetric scale. From the study folder
    x: torch.Tensor,
    **kwargs: Dict[str, Any],
) -> torch.Tensor:
    """Compute a symmetric linear scale from the quantizer's output range.

    The quantization function is run once with `_s` set to all ones, and the largest
    absolute output along dim 1 is divided by half the signed integer range.

    Required kwargs:
        bits: Bit width of the signed integer target.
        params: Parameters forwarded to `qtz_func` (any `_s` entry is overridden).
        qtz_func: Quantization function called as `qtz_func(x=..., **params)`.

    Returns:
        A float32 tensor of scales, clamped from below to float32 machine epsilon.
    """
    assert "bits" in kwargs, "bits must be provided."
    assert "params" in kwargs, "params must be provided."
    assert "qtz_func" in kwargs, "qtz_func must be provided."

    bits = kwargs.get('bits')
    params = kwargs.get('params')
    qtz_func = kwargs.get('qtz_func')

    # Merge through a dict so the neutral `_s` overrides any `_s` already in `params`
    # instead of raising "got multiple values for keyword argument '_s'".
    qtz_params = {**params, '_s': init_ones(x, **kwargs)}

    x_ = x.transpose(0, 1)
    x_ = qtz_func(x=x_, **qtz_params)
    x_ = x_.transpose(0, 1)

    quant_min, quant_max = get_min_max_from_bits_signed(bits)
    min_vals, max_vals = torch.aminmax(x_, dim=1)
    # The range used for the scale always contains zero.
    min_vals = torch.min(min_vals, torch.zeros_like(min_vals))
    max_vals = torch.max(max_vals, torch.zeros_like(max_vals))

    abs_max_val_per_ch = torch.max(-min_vals, max_vals)
    scale = abs_max_val_per_ch / (float(quant_max - quant_min) / 2)

    # A zero scale would later be divided by, so keep it strictly positive.
    eps = torch.finfo(torch.float32).eps
    return torch.clamp(scale, min=eps).to(dtype=torch.float32, device=min_vals.device)
234
+
235
+
236
def init_non_linear_regression_fit(
    x: torch.Tensor,
    **kwargs: Dict[str, Any],
) -> Dict[str, torch.Tensor]:
    """Refine initial parameters with a per-row non-linear least-squares fit.

    Each row of `x` is sorted and `np_fit_func` is fitted so that it maps the row onto
    itself, starting from the matching entries of `p0`.

    Required kwargs:
        params_list: Parameter names; sorted here and assumed to match the argument
            order of `np_fit_func` after its first (data) argument.
        np_fit_func: NumPy model called as `np_fit_func(xdata, *params)`.
        p0: Dict of name -> tensor holding one starting value per row of `x`.

    Returns:
        Dict of name -> float32 tensor on `x.device`. If the fit fails, the starting
        values from `p0` are returned instead.
    """
    assert "params_list" in kwargs, "params list must be provided."
    assert "np_fit_func" in kwargs, "np_fit_func must be provided."
    assert "p0" in kwargs, "p0 must be provided."
    np_fit_func = kwargs.get('np_fit_func')
    # We need to make sure that it matches the numpy fit func arg order
    params_list = sorted(kwargs.get('params_list'))
    # detach() so tensors that track gradients can still be converted to numpy.
    p0 = {k: v.detach().cpu().numpy() for k, v in kwargs.get('p0').items()}

    # Sorting each row makes the identity target monotonic, which is easier to fit.
    sorted_xdata = np.sort(x.detach().cpu().numpy(), axis=-1)

    try:
        per_channel = []
        for i, row in enumerate(sorted_xdata):
            start = [p0[p][i] for p in params_list]
            popt, _ = curve_fit(np_fit_func, row, row, maxfev=1000, p0=start, method='lm')
            per_channel.append(popt)
    except (RuntimeError, ValueError) as e:
        # curve_fit raises RuntimeError when the minimization does not converge and
        # ValueError for invalid data; both fall back to the starting values.
        print(f"Could not fit the function with error: {e}")
        print("Using fallback result...")
        return {
            k: torch.tensor(v, dtype=torch.float32).to(x.device) for k, v in p0.items()
        }

    return {
        p: torch.tensor([fitted[i] for fitted in per_channel], dtype=torch.float32).to(x.device)
        for i, p in enumerate(params_list)
    }
290
+
291
+
292
def init_zeros(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor:
    """Return a float32 zero tensor shaped like `x` reduced over dim 1, on `x.device`."""
    # The reduction is only used as a shape template.
    template = torch.amin(x, dim=1)
    return torch.zeros_like(template, dtype=torch.float32, device=x.device)
295
+
296
+
297
def init_inner_scale(tensor: torch.Tensor, _min: float = torch.inf, _max: float = torch.inf) -> torch.Tensor:
    """Return the factor that rescales each row's value range to `[_min, _max]`.

    The row range is taken along the last dim and always includes zero.

    Args:
        tensor: Input values.
        _min: Lower bound of the target range.
        _max: Upper bound of the target range; positive infinity means "no rescaling".

    Returns:
        A tensor of ones when `_max` is infinite, otherwise
        `(_max - _min) / (row_max - row_min)`, with 1 for rows whose range is zero.
    """
    # Calculate the original minimum and maximum values
    min_vals, max_vals = torch.aminmax(tensor, dim=-1)
    x_min = torch.min(min_vals, torch.zeros_like(min_vals))
    x_max = torch.max(max_vals, torch.zeros_like(max_vals))

    # Compare by value: `is` only matches the exact `torch.inf` object, so an
    # equivalent `float("inf")` passed by a caller would not be recognised.
    if _max == torch.inf:  # We do not need to scale the tensor. Just need to move it
        return torch.ones_like(x_min)

    # Calculate the scale factor; an all-zero row has no range to rescale, so use 1
    # there instead of dividing by zero.
    span = x_max - x_min
    scale = (_max - _min) / span
    return torch.where(span == 0, torch.ones_like(scale), scale)
309
+
310
+
311
+
312
+ ############## Quant ###############
313
+
314
@torch.enable_grad()
def learn_parameters(
    x: torch.Tensor,
    params: Dict[str, nn.Parameter],
    qtz_func: nn.Module,
    deqtz_func: nn.Module,
    bits: int,
    target_dtype: torch.dtype,
    epochs: int = 1000,
    early_stop: bool = True,
    do_report: bool = False
) -> Tuple[Dict[str, nn.Parameter], torch.Tensor]:
    """Tune the quantization parameters with Adam on the round-trip MSE.

    Args:
        x: Tensor to quantize and reconstruct.
        params: Parameters passed to both functions; they are switched to
            `requires_grad=True` and updated in place by the optimizer.
        qtz_func: Quantization function.
        deqtz_func: Dequantization function.
        bits: Bit width of the signed integer range.
        target_dtype: Dtype of the quantized tensor.
        epochs: Maximum number of optimization steps.
        early_stop: Stop once the loss changed by less than 1e-7 over the last
            10% of `epochs`.
        do_report: Print the loss every 10 steps and also return the loss history.

    Returns:
        A deep copy of the parameters that produced the lowest loss (gradients
        disabled); when `do_report` is true, a `(best_params, losses)` tuple.

    Raises:
        Exception: If the loss becomes NaN or Inf during the search.
    """
    loss_fn = nn.MSELoss()

    def _round_trip_loss() -> torch.Tensor:
        """Quantize and dequantize `x` with the current parameters and return the MSE."""
        quant = quantize(x, params, qtz_func, bits, target_dtype)
        dequant = dequantize(quant, params, deqtz_func, bits, x.dtype)
        return loss_fn(x, dequant)

    # Determines the initial learning rate from the order of magnitude of the initial
    # loss. log10 is undefined for 0/NaN/Inf, so those cases keep the base rate (a
    # NaN/Inf loss is then reported by the check inside the loop).
    base_lr = 0.1
    initial_loss = _round_trip_loss().item()
    if np.isfinite(initial_loss) and initial_loss > 0:
        exponent = int(np.floor(np.log10(initial_loss)))
        lr = base_lr * (10 ** (exponent // 2))
    else:
        lr = base_lr

    # Requires gradients in the parameters
    for p in params.values():
        p.requires_grad = True
        p.grad = None

    optimizer = torch.optim.Adam(list(params.values()), lr=lr)

    # Contains the best loss and the best parameters
    best_loss = float("inf")
    best_params = None

    # Used to stop the search early
    min_delta = 1e-7
    acc_loss = []
    percent_epochs_before_stop = 0.1
    epochs_before_stop = int(epochs * percent_epochs_before_stop)

    for i in range(epochs):
        optimizer.zero_grad()
        loss = _round_trip_loss()

        if loss.isnan() or loss.isinf():
            raise Exception("Loss is NaN or Inf. Stopping the search.")

        loss_value = loss.item()

        # Snapshot BEFORE the optimizer step: `loss` was computed with the current
        # values, so copying after the step would pair `best_loss` with different
        # parameters than the ones that produced it.
        if loss_value < best_loss:
            best_loss = loss_value
            best_params = copy.deepcopy(dict(params))

        loss.backward()
        optimizer.step()

        acc_loss.append(loss_value)

        # Reports loss every 10 steps
        if i % 10 == 0 and do_report:
            print(f"Epoch {i}: Loss {loss_value}")

        # Stop if the loss has not changed considerably during the last 10% epochs
        if (
            early_stop
            and i > epochs_before_stop
            and abs(acc_loss[i - epochs_before_stop] - acc_loss[i]) < min_delta
        ):
            break

    if best_params is None:  # No epoch ran (epochs <= 0): return the inputs unchanged
        best_params = copy.deepcopy(dict(params))

    # No longer requires gradients in the parameters
    for p in best_params.values():
        p.requires_grad = False
        p.grad = None

    if do_report:
        print(f"Best loss: {best_loss}")
        return best_params, acc_loss
    return best_params
400
+
401
+
402
def quantize(
    x: torch.Tensor,
    params: Dict[str, nn.Parameter],
    func: nn.Module,
    bits: int,
    target_dtype: torch.dtype = torch.int8
) -> torch.Tensor:
    """Apply `func` to `x`, round, clamp to the signed `bits` range and cast.

    `x` is transposed on dims 0 and 1 around the call to `func` to align shapes.
    NOTE(review): casting to an integer `target_dtype` presumably detaches the result
    from autograd, so the straight-through rounding may have no effect - confirm.
    """
    lower, upper = get_min_max_from_bits_signed(bits)
    mapped = func(x=x.transpose(0, 1), **params).transpose(0, 1)
    rounded = round_func_BPDA(mapped)
    return torch.clamp(rounded, lower, upper).to(target_dtype)
415
+
416
+
417
def dequantize(
    x: torch.Tensor,
    params: Dict[str, nn.Parameter],
    func: nn.Module,
    bits: int,
    out_dtype: torch.dtype
) -> torch.Tensor:
    """Cast `x` to `out_dtype` and apply `func` to it.

    `x` is transposed on dims 0 and 1 around the call to `func`. `bits` is accepted
    for signature parity with `quantize` but is not used here.
    """
    as_float = x.to(dtype=out_dtype)
    restored = func(x=as_float.transpose(0, 1), **params)
    return restored.transpose(0, 1)
429
+
430
+
431
def round_func_BPDA(input):
    """Round `input` in the forward pass while keeping an identity gradient.

    The result is a clone of `input` (so autograd sees a differentiable copy) whose
    stored values are overwritten with the rounded ones.
    """
    result = input.clone()
    result.data = torch.round(input).data
    return result
438
+
439
+
440
def get_min_max_from_bits_signed(bit_width: int) -> Tuple[int, int]:
    """Return `(min, max)` of a signed two's-complement integer with `bit_width` bits."""
    half_range = 2 ** (bit_width - 1)
    return -half_range, half_range - 1
442
+
443
+
444
+
445
+ ############## Numpy ###############
446
+
447
def np_domain_guard(
    x: np.ndarray,
    min: float = None,
    max: float = None,
    posinf: float = None,
    neginf: float = None,
    nan: float = None
) -> np.ndarray:
    """Guard a tensor to a valid domain.

    Args:
        x: Input array.
        min: Optional lower clipping bound.
        max: Optional upper clipping bound.
        posinf: Replacement for +inf (NumPy's default when None).
        neginf: Replacement for -inf (NumPy's default when None).
        nan: Replacement for NaN; None means 0.0.

    Returns:
        The array with non-finite values replaced and, if a bound is given, clipped.
    """
    # np.nan_to_num expects a number for `nan` (its own default is 0.0), so an
    # omitted value must not be forwarded as None.
    nan_value = 0.0 if nan is None else nan
    x = np.nan_to_num(x, posinf=posinf, neginf=neginf, nan=nan_value)
    if min is not None or max is not None:
        x = np.clip(x, min, max)
    return x
460
+
461
+
462
def np_replace_num(x: np.ndarray, num: float, to: float) -> np.ndarray:
    """Swap every occurrence of one value in an array for another.

    Args:
        x (np.ndarray): The input tensor.
        num (float): The value to look for.
        to (float): The value written wherever `num` is found.

    Returns:
        np.ndarray: A new array with the matches replaced.
    """
    matches = x == num
    return np.where(matches, to, x)
474
+
475
+
476
def np_guarded_power(x: np.ndarray, exp: float) -> np.ndarray:
    """Raise `x` to `exp`; for exponents below 1, negative inputs are first set to 0."""
    if exp >= 1:
        return np.power(x, exp)
    non_negative = np.maximum(x, 0)
    return np.power(non_negative, exp)
479
+
fn_gen/nlr_t_no_sched/12/loss.png ADDED
fn_gen/nlr_t_no_sched/12/quantization.png ADDED
fn_gen/nlr_t_no_sched/14/distortion.png ADDED
fn_gen/nlr_t_no_sched/14/expressions.txt ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ asinh(_0*x)/_s
2
+ sinh(_s*x)/_0
fn_gen/nlr_t_no_sched/14/fn.py ADDED
@@ -0,0 +1,479 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import torch
4
+ from torch import amin # Necessary for arcsin
5
+ import copy
6
+ import torch.nn as nn
7
+ import numpy as np
8
+
9
+ from scipy.optimize import curve_fit
10
+ from typing import Dict, Any, Tuple, List, Callable
11
+
12
+
13
def quantization(x, **params):
    """Compute `asinh(_0 * x) / _s`, using 10000 in place of any zero in `_s`."""
    safe_scale = replace_num(params['_s'], num=0, to=10000)
    curved = torch.asinh(params['_0'] * x)
    return torch.div(1, safe_scale) * curved
15
+
16
+
17
def dequantization(x, **params):
    """Compute `sinh(_s * x) / _0`, using 10000 in place of any zero in `_0`."""
    safe_inner = replace_num(params['_0'], num=0, to=10000)
    curved = torch.sinh(params['_s'] * x)
    return torch.div(1, safe_inner) * curved
19
+
20
+
21
def init_params(x: torch.Tensor, **kwargs: Dict[str, Any]) -> Dict[str, nn.Parameter]:
    """Initialise and train the `_0` / `_s` quantization parameters for `x`.

    Steps: random search for `_0`, linear scale for `_s`, non-linear regression fit of
    both, then gradient training. The optional callables `post_init_hook`,
    `post_method_hook` and `post_train_hook` in `kwargs` are called with
    `parameters=` after the matching step. `kwargs['bits']` is required.
    """
    def _fire(hook_name, parameters):
        # Hooks are optional; only call the ones the caller supplied.
        if hook_name in kwargs:
            kwargs[hook_name](parameters=parameters)

    seed = {
        '_0': init_space_search(
            x,
            qtz_func=quantization,
            deqtz_func=dequantization,
            params_list=['_0', '_s'],
            param='_0',
            **kwargs,
        ),
    }
    seed['_s'] = init_linear_scale(x, qtz_func=quantization, params=seed, **kwargs)
    _fire('post_init_hook', seed)

    fitted = init_non_linear_regression_fit(
        x,
        p0=seed,
        np_fit_func=fit_func,
        qtz_func=quantization,
        deqtz_func=dequantization,
        params_list=['_0', '_s'],
        **kwargs,
    )
    fitted = {name: nn.Parameter(value, requires_grad=False) for name, value in fitted.items()}
    _fire('post_method_hook', fitted)

    learned = learn_parameters(
        x,
        fitted,
        qtz_func=quantization,
        deqtz_func=dequantization,
        bits=kwargs['bits'],
        target_dtype=torch.int8,
        epochs=500,
        early_stop=False,
    )
    _fire('post_train_hook', learned)

    return learned
47
+
48
+
49
+ ############### Numpy Qtz ###############
50
+
51
+
52
def np_quantization(x, _0, _s):
    """NumPy form of the quantizer: `arcsinh(_0 * x) / _s`, with a zero `_s` replaced by 10000."""
    safe_scale = np_replace_num(_s, num=0, to=10000)
    return np.divide(1, safe_scale) * np.arcsinh(_0 * x)
54
+
55
+
56
def np_dequantization(x, _0, _s):
    """NumPy form of the dequantizer: `sinh(_s * x) / _0`, with a zero `_0` replaced by 10000."""
    safe_inner = np_replace_num(_0, num=0, to=10000)
    return np.divide(1, safe_inner) * np.sinh(_s * x)
58
+
59
+
60
def fit_func(x, _0, _s):
    """Round-trip ``x`` through the NumPy quantizer and dequantizer.

    Used as the model function for curve fitting: ideal parameters make the
    output reproduce ``x``.
    """
    quantized = np_quantization(x, _0, _s)
    return np_dequantization(quantized, _0, _s)
64
+
65
+
66
+
67
+ ############### HELPERS ###############
68
+
69
def domain_guard(
    x: torch.Tensor,
    min: float = None,
    max: float = None,
    posinf: float = None,
    neginf: float = None,
    nan: float = None
) -> torch.Tensor:
    """Sanitize ``x`` so it lies in a function's valid domain.

    NaN, +inf and -inf are first replaced via ``torch.nan_to_num`` using
    ``nan``, ``posinf`` and ``neginf``; the result is then clamped to
    ``[min, max]`` when at least one bound is given.
    """
    cleaned = torch.nan_to_num(x, posinf=posinf, neginf=neginf, nan=nan)
    if min is None and max is None:
        return cleaned
    return torch.clamp(cleaned, min=min, max=max)
82
+
83
+
84
def replace_num(x: torch.Tensor, num: float, to: float) -> torch.Tensor:
    """Return a copy of ``x`` where every element equal to ``num`` becomes ``to``.

    Args:
        x (torch.Tensor): Tensor to scan.
        num (float): Value to look for (exact equality).
        to (float): Replacement value.

    Returns:
        torch.Tensor: ``x`` with the matching elements substituted.
    """
    hits = x == num
    return torch.where(hits, to, x)
96
+
97
+
98
def guarded_torch_power(x: torch.Tensor, exp: float) -> torch.Tensor:
    """Raise ``x`` to ``exp``, zeroing negative bases when ``exp`` is below 1.

    For ``exp < 1`` the base is passed through ``relu`` first, so negative
    inputs become 0 before the power is taken.
    """
    if exp >= 1:
        return torch.pow(x, exp)
    return torch.pow(torch.relu(x), exp)
101
+
102
+
103
def init_ones(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor:
    """Return a float32 tensor of ones with the shape of ``x`` reduced over dim 1."""
    # The reduction is only used to obtain the output shape.
    template = x.amin(dim=1)
    return torch.ones_like(template, dtype=torch.float32, device=x.device)
106
+
107
+
108
def init_rand(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor:
    """Return float32 standard-normal samples with the shape of ``x`` reduced over dim 1."""
    # The reduction is only used to obtain the output shape.
    template = x.amin(dim=1)
    return torch.randn_like(template, dtype=torch.float32, device=x.device)
111
+
112
+
113
def init_space_search(
    x: torch.Tensor,
    **kwargs: Dict[str, Any],
) -> torch.Tensor:
    """Randomly search for a good initial value of one parameter.

    Candidates for ``kwargs['param']`` are scored by the MSE between ``x`` and
    ``deqtz_func(qtz_func(x))`` while every other name in ``params_list`` is
    held at ones. Each run keeps the ``n_best_to_pick`` lowest-loss candidates
    and samples the next batch around their mean. The winner of the last run
    is finally compared against an all-ones and a fresh random candidate.

    Required kwargs: ``qtz_func``, ``deqtz_func``, ``params_list``, ``param``.

    Candidates whose evaluation raises, or whose loss is NaN/Inf, are ignored.
    If a run yields no valid candidate the search stops early and the result
    of the last successful run (if any) is used.

    Returns:
        torch.Tensor: The selected value for ``param``.
    """

    def _build_initial_param(tensor: torch.Tensor, max_initial: int, n_params: int):
        """Yield the first candidate batch, which is 10 times larger than later ones."""
        for _ in range(n_params * 10):
            # Standard-normal draws scaled by max_initial: this sets the spread
            # (std) of the candidates, it is not a hard [-max, max] bound.
            yield init_rand(tensor) * max_initial

    def _search_param(tensors: List[torch.tensor], n_params):
        """Yield new candidates sampled around the mean of the best ones."""
        torch_tensors = torch.stack(tensors)
        min_vals, max_vals = torch.aminmax(torch_tensors, dim=0)
        abs_max_val_per_ch = torch.max(-min_vals, max_vals)
        mean = torch.mean(torch_tensors, dim=0)
        for _ in range(n_params):
            yield torch.randn_like(min_vals) * abs_max_val_per_ch + mean

    def _calc(x, qtz_func, deqtz_func, **params):
        """Round-trip ``x`` through quantization and dequantization."""
        x_ = x.transpose(0, 1)
        x_ = qtz_func(x=x_, **params)
        x_ = deqtz_func(x=x_, **params)
        x_ = x_.transpose(0, 1)
        return x_

    assert "qtz_func" in kwargs, "qtz_func must be provided."
    assert "deqtz_func" in kwargs, "deqtz_func must be provided."
    assert "params_list" in kwargs, "params list must be provided."
    assert "param" in kwargs, "param must be provided."

    qtz_func = kwargs.get('qtz_func')
    deqtz_func = kwargs.get('deqtz_func')
    params_list = kwargs.get('params_list')
    param = kwargs.get('param')

    n_runs = 50            # Number of refinement runs
    n_random_params = 50   # Candidates generated per run (x10 on the first run)
    n_best_to_pick = 5     # Survivors kept after each run
    max_initial = 10000    # Spread of the very first candidate batch

    loss_fn = nn.MSELoss()

    # Every parameter except the searched one is fixed at ones.
    base_params = {p: init_ones(x, **kwargs) for p in params_list if p != param}
    params = _build_initial_param(x, max_initial, n_random_params)

    best_params = []
    for _ in range(n_runs):
        run_best = []
        for param_ in params:
            try:
                x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: param_})
                loss = loss_fn(x, x_).item()
            except Exception:  # The parameters might not be valid for the function's domain
                continue

            # A NaN/Inf loss cannot be ranked and must never win the search.
            if not np.isfinite(loss):
                continue

            if len(run_best) < n_best_to_pick:
                run_best.append((param_, loss))
            elif loss < run_best[-1][1]:
                run_best[-1] = (param_, loss)  # Replace the current worst survivor
            else:
                continue
            run_best.sort(key=lambda entry: entry[1])

        if not run_best:
            # Nothing valid to refine around; keep the previous run's result.
            break

        best_params = run_best
        # Generates new parameters around the mean of the survivors
        params = _search_param([p for p, _ in best_params], n_random_params)

    # Baseline 1: all ones
    p_ones = init_ones(x, **kwargs)
    x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: p_ones})
    loss_ones = loss_fn(x, x_).item()

    # Baseline 2: a fresh standard-normal sample
    p_rand = init_rand(x, **kwargs)
    x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: p_rand})
    loss_rand = loss_fn(x, x_).item()

    loss_best = best_params[0][1] if best_params else float("inf")

    if loss_rand < loss_best and loss_rand < loss_ones:
        return p_rand
    if loss_ones < loss_best and loss_ones < loss_rand:
        return p_ones
    if best_params:
        return best_params[0][0]
    # No searched candidate was usable and the baselines tie (or are NaN).
    return p_ones
196
+
197
+
198
def init_linear_scale(  # Symmetric scale. From the study folder
    x: torch.Tensor,
    **kwargs: Dict[str, Any],
) -> torch.Tensor:
    """Compute a symmetric linear scale for the quantized representation of ``x``.

    ``x`` is passed through ``qtz_func`` with ``_s`` fixed to ones; the largest
    absolute value along dim 1 of the result is divided by half of the signed
    integer range for ``bits``. The scale is floored at float32 epsilon.

    Required kwargs: ``bits``, ``params`` (must not already contain ``_s``),
    ``qtz_func``.
    """
    assert "bits" in kwargs, "bits must be provided."
    assert "params" in kwargs, "params must be provided."
    assert "qtz_func" in kwargs, "qtz_func must be provided."

    bits = kwargs.get('bits')
    params = kwargs.get('params')
    qtz_func = kwargs.get('qtz_func')

    # Quantize with a unit scale to see the natural range of the output.
    unit_scale = init_ones(x, **kwargs)
    projected = qtz_func(x=x.transpose(0, 1), **params, _s=unit_scale).transpose(0, 1)

    quant_min, quant_max = get_min_max_from_bits_signed(bits)

    # The range always includes zero.
    lows, highs = torch.aminmax(projected, dim=1)
    lows = torch.min(lows, torch.zeros_like(lows))
    highs = torch.max(highs, torch.zeros_like(highs))

    eps = torch.finfo(torch.float32).eps

    half_levels = float(quant_max - quant_min) / 2
    scale = torch.max(-lows, highs) / half_levels

    # NOTE(diogo): Adding noise to the scale was tried and does not help the
    # learning process, so the scale is returned without noise.
    return torch.clamp(scale, min=eps).to(dtype=torch.float32, device=lows.device)
234
+
235
+
236
def init_non_linear_regression_fit(
    x: torch.Tensor,
    **kwargs: Dict[str, Any],
) -> Dict[str, torch.Tensor]:
    """Refine initial parameters with a Levenberg-Marquardt curve fit.

    Each row of ``x`` is sorted and ``np_fit_func`` is fitted so that it maps
    the row onto itself, starting from that row's entry in ``p0``.

    Required kwargs:
        params_list: Parameter names. They are sorted alphabetically and
            passed positionally, so this order must match the argument order
            of ``np_fit_func`` after ``x``.
        np_fit_func: Callable ``f(x, *params)`` operating on NumPy arrays.
        p0: Dict mapping each name to a tensor of initial values, one per row.

    Returns:
        Dict mapping each name in ``params_list`` to a float32 tensor on
        ``x.device``. If the fit fails for any row (``ValueError`` or
        ``RuntimeError`` from ``curve_fit``, e.g. no convergence within
        ``maxfev``), every entry of ``p0`` is returned unchanged as float32.
    """
    assert "params_list" in kwargs, "params list must be provided."
    assert "np_fit_func" in kwargs, "np_fit_func must be provided."
    assert "p0" in kwargs, "p0 must be provided."
    np_fit_func = kwargs.get('np_fit_func')
    # Sorted so that it matches the numpy fit func positional argument order
    params_list = sorted(kwargs.get('params_list'))
    # detach() so tensors that track gradients can still be converted
    p0 = {k: v.detach().cpu().numpy() for k, v in kwargs.get('p0').items()}

    def _fit(xdata: np.ndarray, ydata: np.ndarray, func: Callable, p0: List[float]):
        """Fit ``func`` to one row and return the optimal parameters."""
        popt, _ = curve_fit(
            func,
            xdata,
            ydata,
            maxfev=1000,
            p0=p0,
            method='lm'
        )
        return popt

    # Sorting each row makes the target monotonic, which is easier to fit
    sorted_xdata = np.sort(x.detach().cpu().numpy(), axis=-1)

    try:
        params = [
            _fit(row, row, np_fit_func, [p0[p][i] for p in params_list])
            for i, row in enumerate(sorted_xdata)
        ]
    except (ValueError, RuntimeError) as e:
        # curve_fit raises RuntimeError when the least-squares minimization
        # fails and ValueError on invalid data; both fall back to p0.
        print(f"Could not fit the function with error: {e}")
        print("Using fallback result...")
        return {
            k: torch.tensor(v, dtype=torch.float32).to(x.device) for k, v in p0.items()
        }

    # Regroup the per-row results into one tensor per parameter name
    return {
        p: torch.tensor([row_params[i] for row_params in params], dtype=torch.float32).to(x.device)
        for i, p in enumerate(params_list)
    }
290
+
291
+
292
def init_zeros(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor:
    """Return a float32 tensor of zeros with the shape of ``x`` reduced over dim 1."""
    # The reduction is only used to obtain the output shape.
    template = x.amin(dim=1)
    return torch.zeros_like(template, dtype=torch.float32, device=x.device)
295
+
296
+
297
def init_inner_scale(tensor: torch.Tensor, _min: float = torch.inf, _max: float = torch.inf) -> torch.Tensor:
    """Return the factor that rescales ``tensor``'s range to ``[_min, _max]``.

    The range is taken along the last dimension and always includes zero.
    When ``_max`` is positive infinity no rescaling is needed and ones are
    returned. The factor is infinite where the range has zero width.

    NOTE(review): ``_min`` defaults to +inf, so passing only a finite ``_max``
    produces an infinite result; confirm whether -inf was intended.
    """
    # Calculate the original minimum and maximum values
    min_vals, max_vals = torch.aminmax(tensor, dim=-1)
    x_min = torch.min(min_vals, torch.zeros_like(min_vals))
    x_max = torch.max(max_vals, torch.zeros_like(max_vals))

    # Compare by value: an identity check (`is`) only matches the very same
    # float object and would miss e.g. float("inf") supplied by a caller.
    if _max == torch.inf:  # We do not need to scale the tensor. Just need to move it
        return torch.ones_like(x_min)

    # Calculate the scale factor
    return (_max - _min) / (x_max - x_min)
309
+
310
+
311
+
312
+ ############## Quant ###############
313
+
314
@torch.enable_grad()
def learn_parameters(
    x: torch.Tensor,
    params: Dict[str, nn.Parameter],
    qtz_func: nn.Module,
    deqtz_func: nn.Module,
    bits: int,
    target_dtype: torch.dtype,
    epochs: int = 1000,
    early_stop: bool = True,
    do_report: bool = False
) -> Tuple[Dict[str, nn.Parameter], torch.Tensor]:
    """Train ``params`` with Adam to minimize the quantize/dequantize MSE on ``x``.

    Args:
        x: Tensor to reconstruct.
        params: Parameters to optimize. They are updated in place and left
            with ``requires_grad=True``; the returned dict is a separate copy.
        qtz_func: Quantization function, called as ``f(x=..., **params)``.
        deqtz_func: Dequantization function, called the same way.
        bits: Bit width of the signed integer range used by ``quantize``.
        target_dtype: Dtype of the quantized tensor.
        epochs: Maximum number of optimization steps.
        early_stop: Stop when the loss changed by less than ``1e-7`` compared
            with its value 10% of ``epochs`` steps earlier.
        do_report: Print progress and also return the loss history.

    Returns:
        A deep copy of the parameters that produced the lowest observed loss
        (``requires_grad=False``). When ``do_report`` is true, a tuple of that
        dict and the list of per-epoch losses.

    Raises:
        Exception: If the loss becomes NaN or Inf.

    NOTE(review): ``quantize`` casts to ``target_dtype``; if that is an integer
    dtype, gradients appear to reach the parameters only through
    ``deqtz_func`` -- confirm this is intended.
    """
    loss_fn = nn.MSELoss()

    # The initial learning rate is 0.1 scaled by ten to the power of half the
    # order of magnitude of the initial loss.
    quant = quantize(x, params, qtz_func, bits, target_dtype)
    dequant = dequantize(quant, params, deqtz_func, bits, x.dtype)
    initial_loss = loss_fn(x, dequant).item()

    if not np.isfinite(initial_loss):
        raise Exception("Loss is NaN or Inf. Stopping the search.")

    base_lr = 0.1
    # log10(0) is -inf and cannot be converted to an int, so a perfect initial
    # reconstruction is floored at the smallest positive normal float32.
    tiny = torch.finfo(torch.float32).tiny
    exponent = int(np.floor(np.log10(max(initial_loss, tiny))))
    lr = base_lr * (10 ** (exponent // 2))

    # Requires gradients in the parameters
    for p in params.values():
        p.requires_grad = True
        p.grad = None

    optimizer = torch.optim.Adam(list(params.values()), lr=lr)

    # Best loss seen so far and the parameters that produced it. Starting from
    # a copy of the input keeps the result well-defined when epochs == 0.
    best_loss = float("inf")
    best_params = copy.deepcopy(dict(params))

    # Used to stop the search early
    min_delta = 1e-7
    acc_loss = []
    percent_epochs_before_stop = 0.1
    epochs_before_stop = int(epochs * percent_epochs_before_stop)

    for i in range(epochs):
        optimizer.zero_grad()

        quant = quantize(x, params, qtz_func, bits, target_dtype)
        dequant = dequantize(quant, params, deqtz_func, bits, x.dtype)
        loss = loss_fn(x, dequant)

        if loss.isnan() or loss.isinf():
            raise Exception("Loss is NaN or Inf. Stopping the search.")

        loss_value = loss.item()

        # Snapshot BEFORE the optimizer step: the current parameters are the
        # ones this loss was computed with. After step() they have moved on.
        if loss_value < best_loss:
            best_loss = loss_value
            best_params = copy.deepcopy(dict(params))

        loss.backward()
        optimizer.step()

        acc_loss.append(loss_value)

        # Reports loss every 10 steps
        if i % 10 == 0 and do_report:
            print(f"Epoch {i}: Loss {loss_value}")

        # Stop if the loss has not changed considerably during the last 10% of epochs
        if early_stop:
            if i > epochs_before_stop and abs(acc_loss[i - epochs_before_stop] - acc_loss[i]) < min_delta:
                break

    # The returned copies no longer require gradients
    for p in best_params.values():
        p.requires_grad = False
        p.grad = None

    if do_report:
        print(f"Best loss: {best_loss}")
        return best_params, acc_loss
    else:
        return best_params
400
+
401
+
402
def quantize(
    x: torch.Tensor,
    params: Dict[str, nn.Parameter],
    func: nn.Module,
    bits: int,
    target_dtype: torch.dtype = torch.int8
) -> torch.Tensor:
    """Apply ``func`` to ``x``, round, clamp to the signed ``bits`` range and cast.

    ``x`` is transposed on dims 0/1 around the call to ``func`` so that its
    shape lines up with the parameters, then transposed back.
    """
    quant_min, quant_max = get_min_max_from_bits_signed(bits)
    projected = func(x=x.transpose(0, 1), **params).transpose(0, 1)
    rounded = round_func_BPDA(projected)
    return torch.clamp(rounded, quant_min, quant_max).to(target_dtype)
415
+
416
+
417
def dequantize(
    x: torch.Tensor,
    params: Dict[str, nn.Parameter],
    func: nn.Module,
    bits: int,
    out_dtype: torch.dtype
) -> torch.Tensor:
    """Cast ``x`` to ``out_dtype`` and apply ``func`` to recover real values.

    ``x`` is transposed on dims 0/1 around the call to ``func`` so that its
    shape lines up with the parameters, then transposed back. ``bits`` is
    accepted for signature symmetry with ``quantize`` and is not used.
    """
    as_float = x.to(dtype=out_dtype)
    restored = func(x=as_float.transpose(0, 1), **params)
    return restored.transpose(0, 1)
429
+
430
+
431
def round_func_BPDA(input):
    """Round ``input`` in the forward pass while backpropagating as identity.

    The returned tensor is a clone of ``input`` (so it stays in the autograd
    graph) whose underlying data is overwritten with the rounded values.
    """
    rounded = input.clone()
    rounded.data = torch.round(input).data
    return rounded
438
+
439
+
440
def get_min_max_from_bits_signed(bit_width: int) -> Tuple[int, int]:
    """Return ``(min, max)`` of a signed two's-complement integer of ``bit_width`` bits."""
    half_range = 2 ** (bit_width - 1)
    return -half_range, half_range - 1
442
+
443
+
444
+
445
+ ############## Numpy ###############
446
+
447
def np_domain_guard(
    x: np.ndarray,
    min: float = None,
    max: float = None,
    posinf: float = None,
    neginf: float = None,
    nan: float = None
) -> np.ndarray:
    """Sanitize ``x`` so it lies in a function's valid domain (NumPy version).

    NaN, +inf and -inf are replaced via ``np.nan_to_num``; the result is then
    clipped to ``[min, max]`` when at least one bound is given.

    ``nan=None`` means "replace NaN with 0.0", matching the torch
    ``domain_guard``. ``np.nan_to_num`` documents ``None`` only for
    ``posinf``/``neginf``, so the default is resolved here instead of being
    forwarded.
    """
    nan_value = 0.0 if nan is None else nan
    x = np.nan_to_num(x, posinf=posinf, neginf=neginf, nan=nan_value)
    if min is not None or max is not None:
        x = np.clip(x, min, max)
    return x
460
+
461
+
462
def np_replace_num(x: np.ndarray, num: float, to: float) -> np.ndarray:
    """Return a copy of ``x`` where every element equal to ``num`` becomes ``to``.

    Args:
        x (np.ndarray): Array to scan.
        num (float): Value to look for (exact equality).
        to (float): Replacement value.

    Returns:
        np.ndarray: ``x`` with the matching elements substituted.
    """
    hits = x == num
    return np.where(hits, to, x)
474
+
475
+
476
def np_guarded_power(x: np.ndarray, exp: float) -> np.ndarray:
    """Raise ``x`` to ``exp``, zeroing negative bases when ``exp`` is below 1.

    For ``exp < 1`` negative inputs are replaced by 0 before the power is
    taken.
    """
    if exp >= 1:
        return np.power(x, exp)
    return np.power(np.maximum(x, 0), exp)
479
+
fn_gen/nlr_t_no_sched/14/loss.png ADDED
fn_gen/nlr_t_no_sched/14/quantization.png ADDED
fn_gen/nlr_t_no_sched/16/distortion.png ADDED
fn_gen/nlr_t_no_sched/16/expressions.txt ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ tan(_0*x)/_s
2
+ atan(_s*x)/_0
fn_gen/nlr_t_no_sched/16/fn.py ADDED
@@ -0,0 +1,479 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import torch
4
+ from torch import amin # Necessary for arcsin
5
+ import copy
6
+ import torch.nn as nn
7
+ import numpy as np
8
+
9
+ from scipy.optimize import curve_fit
10
+ from typing import Dict, Any, Tuple, List, Callable
11
+
12
+
13
def quantization(x, **params):
    """Project ``x`` into the quantized domain as ``tan(_0 * x) / _s``.

    Before the tangent is taken, NaN, +inf and -inf in ``_0 * x`` are replaced
    by 0, 1 and -1. Entries of ``_s`` that are exactly zero are swapped for
    10000 before the reciprocal is taken.
    """
    inv_scale = torch.div(1, replace_num(params['_s'], num=0, to=10000))
    angle = domain_guard((params['_0'] * x), posinf=1, neginf=-1, nan=0)
    return (inv_scale * torch.tan(angle))
15
+
16
+
17
def dequantization(x, **params):
    """Map quantized values back as ``atan(_s * x) / _0``.

    Entries of ``_0`` that are exactly zero are swapped for 10000 before the
    reciprocal is taken, so the division never produces an infinity.
    """
    inv_inner = torch.div(1, replace_num(params['_0'], num=0, to=10000))
    unwarped = torch.atan((params['_s'] * x))
    return (inv_inner * unwarped)
19
+
20
+
21
+ def init_params(x: torch.Tensor, **kwargs: Dict[str, Any]) -> Dict[str, nn.Parameter]:
22
+ base_p0 = {
23
+ '_0': init_space_search(x, qtz_func=quantization, deqtz_func=dequantization, params_list=['_0', '_s'], param='_0', **kwargs),
24
+ }
25
+
26
+ base_p0['_s'] = init_linear_scale(x, qtz_func=quantization, params=base_p0, **kwargs)
27
+ if 'post_init_hook' in kwargs:
28
+ kwargs['post_init_hook'](parameters=base_p0)
29
+
30
+ params = init_non_linear_regression_fit(x, p0=base_p0, np_fit_func=fit_func, qtz_func=quantization, deqtz_func=dequantization, params_list=['_0', '_s'], **kwargs)
31
+ params = {k: nn.Parameter(v, requires_grad=False) for k, v in params.items()}
32
+ if 'post_method_hook' in kwargs:
33
+ kwargs['post_method_hook'](parameters=params)
34
+
35
+ params = learn_parameters(x, params,
36
+ qtz_func=quantization,
37
+ deqtz_func=dequantization,
38
+ bits=kwargs['bits'],
39
+ target_dtype=torch.int8,
40
+ epochs=500,
41
+ early_stop=False,
42
+ )
43
+ if 'post_train_hook' in kwargs:
44
+ kwargs['post_train_hook'](parameters=params)
45
+
46
+ return params
47
+
48
+
49
+ ############### Numpy Qtz ###############
50
+
51
+
52
def np_quantization(x, _0, _s):
    """NumPy counterpart of ``quantization``: ``tan(_0 * x) / _s``.

    NaN, +inf and -inf in ``_0 * x`` are replaced by 0, 1 and -1 before the
    tangent; a zero ``_s`` is replaced by 10000 before taking the reciprocal.
    """
    inv_scale = np.divide(1, np_replace_num(_s, num=0, to=10000))
    angle = np_domain_guard((_0 * x), posinf=1, neginf=-1, nan=0)
    return (inv_scale * np.tan(angle))
54
+
55
+
56
def np_dequantization(x, _0, _s):
    """NumPy counterpart of ``dequantization``: ``arctan(_s * x) / _0``.

    A zero ``_0`` is replaced by 10000 before taking the reciprocal.
    """
    inv_inner = np.divide(1, np_replace_num(_0, num=0, to=10000))
    unwarped = np.arctan((_s * x))
    return (inv_inner * unwarped)
58
+
59
+
60
+ def fit_func(x, _0, _s):
61
+ x_ = np_quantization(x, _0, _s)
62
+ x_ = np_dequantization(x_, _0, _s)
63
+ return x_
64
+
65
+
66
+
67
+ ############### HELPERS ###############
68
+
69
+ def domain_guard(
70
+ x: torch.Tensor,
71
+ min: float = None,
72
+ max: float = None,
73
+ posinf: float = None,
74
+ neginf: float = None,
75
+ nan: float = None
76
+ ) -> torch.Tensor:
77
+ """Guard a tensor to a valid domain."""
78
+ x = torch.nan_to_num(x, posinf=posinf, neginf=neginf, nan=nan)
79
+ if min is not None or max is not None:
80
+ x = torch.clamp(x, min=min, max=max)
81
+ return x
82
+
83
+
84
+ def replace_num(x: torch.Tensor, num: float, to: float) -> torch.Tensor:
85
+ """Replace a number in a tensor with another number.
86
+
87
+ Args:
88
+ x (torch.Tensor): The input tensor.
89
+ num (float): The number to replace.
90
+ to (float): The number to replace with.
91
+
92
+ Returns:
93
+ torch.Tensor: The tensor with the number replaced.
94
+ """
95
+ return torch.where(x == num, to, x)
96
+
97
+
98
+ def guarded_torch_power(x: torch.Tensor, exp: float) -> torch.Tensor:
99
+ """Guard the power operation to a valid domain."""
100
+ return torch.pow(x, exp) if exp >= 1 else torch.pow(torch.relu(x), exp)
101
+
102
+
103
+ def init_ones(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor:
104
+ val = torch.amin(x, dim=1)
105
+ return torch.ones_like(val, dtype=torch.float32, device=x.device)
106
+
107
+
108
+ def init_rand(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor:
109
+ val = torch.amin(x, dim=1)
110
+ return torch.randn_like(val, dtype=torch.float32, device=x.device)
111
+
112
+
113
+ def init_space_search(
114
+ x: torch.Tensor,
115
+ **kwargs: Dict[str, Any],
116
+ ) -> torch.Tensor:
117
+
118
+ def _build_initial_param(tensor: torch.Tensor, max_initial: int, n_params: int):
119
+ """Generates the initial set of parameters. The first iteration generates 10 times more parameters."""
120
+ for _ in range(n_params * 10): # The first iteration generates 10 times more parameters
121
+ yield init_rand(tensor) * max_initial # Generates n_params in range [-max_initial, max_initial]
122
+
123
+ def _search_param(tensors: List[torch.tensor], n_params):
124
+ """Takes the best parameters and generates new parameters around the mean of the best parameters."""
125
+ torch_tensors = torch.stack(tensors)
126
+ min_vals, max_vals = torch.aminmax(torch_tensors, dim=0)
127
+ abs_max_val_per_ch = torch.max(-min_vals, max_vals)
128
+ mean = torch.mean(torch_tensors, dim=0)
129
+ for _ in range(n_params): # Generates n_params around the mean of the tensors
130
+ yield torch.randn_like(min_vals) * abs_max_val_per_ch + mean
131
+
132
+ def _calc(x, qtz_func, deqtz_func, **params):
133
+ x_ = x.transpose(0, 1)
134
+ x_ = qtz_func(x=x_, **params)
135
+ x_ = deqtz_func(x=x_, **params)
136
+ x_ = x_.transpose(0, 1)
137
+ return x_
138
+
139
+ assert "qtz_func" in kwargs, "qtz_func must be provided."
140
+ assert "deqtz_func" in kwargs, "deqtz_func must be provided."
141
+ assert "params_list" in kwargs, "params list must be provided."
142
+ assert "param" in kwargs, "param must be provided."
143
+
144
+ qtz_func = kwargs.get('qtz_func')
145
+ deqtz_func = kwargs.get('deqtz_func')
146
+ params_list = kwargs.get('params_list')
147
+ param = kwargs.get('param')
148
+
149
+ n_runs = 50 # Number of runs to try to find the best parameters
150
+ n_random_params = 50 # Number of random parameters to generate
151
+ n_best_to_pick = 5 # Number of best parameters to pick after each run
152
+ max_initial = 10000 # Maximum value to initialize the parameters
153
+
154
+ # Initializes the parameters
155
+ base_params = { p: init_ones(x, **kwargs) for p in params_list if p != param }
156
+ params = _build_initial_param(x, max_initial, n_random_params)
157
+
158
+ # Performs the search
159
+ for _ in range(n_runs):
160
+
161
+ best_params = []
162
+ for param_ in params:
163
+ try:
164
+ x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: param_})
165
+ loss_ones = nn.MSELoss()(x, x_)
166
+
167
+ if len(best_params) < n_best_to_pick:
168
+ best_params.append((param_, loss_ones.item()))
169
+ best_params = sorted(best_params, key=lambda x: x[1])
170
+ elif loss_ones < best_params[-1][1]:
171
+ best_params[-1] = (param_, loss_ones.item())
172
+ best_params = sorted(best_params, key=lambda x: x[1])
173
+
174
+ except Exception: # The parameters might not be valid for the function's domain
175
+ continue
176
+
177
+ # Generates new parameters around the mean
178
+ params = _search_param([p for p, _ in best_params], n_random_params)
179
+
180
+ # Checks if the best parameter is better than the init_ones
181
+ p_ones = init_ones(x, **kwargs)
182
+ x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: p_ones})
183
+ loss_ones = nn.MSELoss()(x, x_)
184
+
185
+ # Checks if the best parameter is better than the init_rand
186
+ p_rand = init_rand(x, **kwargs)
187
+ x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: p_rand})
188
+ loss_rand = nn.MSELoss()(x, x_)
189
+
190
+ if loss_rand < best_params[0][1] and loss_rand < loss_ones:
191
+ return p_rand
192
+ elif loss_ones < best_params[0][1] and loss_ones < loss_rand:
193
+ return p_ones
194
+ else:
195
+ return best_params[0][0]
196
+
197
+
198
+ def init_linear_scale( # Symmetric scale. From the study folder
199
+ x: torch.Tensor,
200
+ **kwargs: Dict[str, Any],
201
+ ) -> torch.Tensor:
202
+ assert "bits" in kwargs, "bits must be provided."
203
+ assert "params" in kwargs, "params must be provided."
204
+ assert "qtz_func" in kwargs, "qtz_func must be provided."
205
+
206
+ bits = kwargs.get('bits')
207
+ params = kwargs.get('params')
208
+ qtz_func = kwargs.get('qtz_func')
209
+
210
+ x_ = x.transpose(0, 1)
211
+ x_ = qtz_func(x=x_, **params, _s=init_ones(x, **kwargs))
212
+ x_ = x_.transpose(0, 1)
213
+
214
+ quant_min, quant_max = get_min_max_from_bits_signed(bits)
215
+ min_vals, max_vals = torch.aminmax(x_, dim=1)
216
+ min_vals = torch.min(min_vals, torch.zeros_like(min_vals))
217
+ max_vals = torch.max(max_vals, torch.zeros_like(max_vals))
218
+
219
+ eps = torch.finfo(torch.float32).eps
220
+
221
+ abs_max_val_per_ch = torch.max(-min_vals, max_vals)
222
+ scale = abs_max_val_per_ch / (float(quant_max - quant_min) / 2)
223
+
224
+ scale = torch.clamp(scale, min=eps).to(dtype=torch.float32, device=min_vals.device)
225
+
226
+ # Introduces some noise in scale
227
+ # If I don't introduce noise, the accuracy is going to be 0.0 and not learn anything
228
+ # If I don't introduce noise, the accuracy is going to be 0.0 and not learn anything.
229
+ # NOTE(diogo): This has been disproven. The noise does not help the learning process but I still
230
+ # left it here for future reference. Will be removed later.
231
+ # scale = scale + 0.01 * torch.randn_like(scale)
232
+
233
+ return scale
234
+
235
+
236
def init_non_linear_regression_fit(
    x: torch.Tensor,
    **kwargs: Dict[str, Any],
) -> Dict[str, torch.Tensor]:
    """Refine initial parameters with a Levenberg-Marquardt curve fit.

    Each row of ``x`` is sorted and ``np_fit_func`` is fitted so that it maps
    the row onto itself, starting from that row's entry in ``p0``.

    Required kwargs:
        params_list: Parameter names. They are sorted alphabetically and
            passed positionally, so this order must match the argument order
            of ``np_fit_func`` after ``x``.
        np_fit_func: Callable ``f(x, *params)`` operating on NumPy arrays.
        p0: Dict mapping each name to a tensor of initial values, one per row.

    Returns:
        Dict mapping each name in ``params_list`` to a float32 tensor on
        ``x.device``. If the fit fails for any row (``ValueError`` or
        ``RuntimeError`` from ``curve_fit``, e.g. no convergence within
        ``maxfev``), every entry of ``p0`` is returned unchanged as float32.
    """
    assert "params_list" in kwargs, "params list must be provided."
    assert "np_fit_func" in kwargs, "np_fit_func must be provided."
    assert "p0" in kwargs, "p0 must be provided."
    np_fit_func = kwargs.get('np_fit_func')
    # Sorted so that it matches the numpy fit func positional argument order
    params_list = sorted(kwargs.get('params_list'))
    # detach() so tensors that track gradients can still be converted
    p0 = {k: v.detach().cpu().numpy() for k, v in kwargs.get('p0').items()}

    def _fit(xdata: np.ndarray, ydata: np.ndarray, func: Callable, p0: List[float]):
        """Fit ``func`` to one row and return the optimal parameters."""
        popt, _ = curve_fit(
            func,
            xdata,
            ydata,
            maxfev=1000,
            p0=p0,
            method='lm'
        )
        return popt

    # Sorting each row makes the target monotonic, which is easier to fit
    sorted_xdata = np.sort(x.detach().cpu().numpy(), axis=-1)

    try:
        params = [
            _fit(row, row, np_fit_func, [p0[p][i] for p in params_list])
            for i, row in enumerate(sorted_xdata)
        ]
    except (ValueError, RuntimeError) as e:
        # curve_fit raises RuntimeError when the least-squares minimization
        # fails and ValueError on invalid data; both fall back to p0.
        print(f"Could not fit the function with error: {e}")
        print("Using fallback result...")
        return {
            k: torch.tensor(v, dtype=torch.float32).to(x.device) for k, v in p0.items()
        }

    # Regroup the per-row results into one tensor per parameter name
    return {
        p: torch.tensor([row_params[i] for row_params in params], dtype=torch.float32).to(x.device)
        for i, p in enumerate(params_list)
    }
290
+
291
+
292
+ def init_zeros(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor:
293
+ val = torch.amin(x, dim=1)
294
+ return torch.zeros_like(val, dtype=torch.float32, device=x.device)
295
+
296
+
297
def init_inner_scale(tensor: torch.Tensor, _min: float = torch.inf, _max: float = torch.inf) -> torch.Tensor:
    """Return the factor that rescales ``tensor``'s range to ``[_min, _max]``.

    The range is taken along the last dimension and always includes zero.
    When ``_max`` is positive infinity no rescaling is needed and ones are
    returned. The factor is infinite where the range has zero width.

    NOTE(review): ``_min`` defaults to +inf, so passing only a finite ``_max``
    produces an infinite result; confirm whether -inf was intended.
    """
    # Calculate the original minimum and maximum values
    min_vals, max_vals = torch.aminmax(tensor, dim=-1)
    x_min = torch.min(min_vals, torch.zeros_like(min_vals))
    x_max = torch.max(max_vals, torch.zeros_like(max_vals))

    # Compare by value: an identity check (`is`) only matches the very same
    # float object and would miss e.g. float("inf") supplied by a caller.
    if _max == torch.inf:  # We do not need to scale the tensor. Just need to move it
        return torch.ones_like(x_min)

    # Calculate the scale factor
    return (_max - _min) / (x_max - x_min)
309
+
310
+
311
+
312
+ ############## Quant ###############
313
+
314
+ @torch.enable_grad()
315
+ def learn_parameters(
316
+ x: torch.Tensor,
317
+ params: Dict[str, nn.Parameter],
318
+ qtz_func: nn.Module,
319
+ deqtz_func: nn.Module,
320
+ bits: int,
321
+ target_dtype: torch.dtype,
322
+ epochs: int = 1000,
323
+ early_stop: bool = True,
324
+ do_report: bool = False
325
+ ) -> Tuple[Dict[str, nn.Parameter], torch.Tensor]:
326
+ loss_fn = nn.MSELoss()
327
+
328
+ # Determines the initial learning rate by computing the initial loss and multiplying it by
329
+ # the order of magnitude of the loss divided by 2
330
+ quant = quantize(x, params, qtz_func, bits, target_dtype)
331
+ dequant = dequantize(quant, params, deqtz_func, bits, x.dtype)
332
+ loss = loss_fn(x, dequant)
333
+
334
+ base_lr = 0.1
335
+ exponent = int(np.floor(np.log10(loss.item())))
336
+ lr = base_lr * (10 ** (exponent // 2))
337
+
338
+ # Requires gradients in the parameters
339
+ for p in params.values():
340
+ p.requires_grad = True
341
+ p.grad = None
342
+
343
+ param_keys = list(params.keys())
344
+ param_values = list(params.values())
345
+
346
+ # Defines optimizer and loss function
347
+ optimizer = torch.optim.Adam(param_values, lr=lr)
348
+
349
+ # Contains the best loss and the best parameters
350
+ best_loss = float("inf")
351
+ best_params = None
352
+
353
+ # Used to stop the search early
354
+ min_delta = 1e-7
355
+ acc_loss = []
356
+ percent_epochs_before_stop = 0.1
357
+
358
+ for i in range(epochs):
359
+ optimizer.zero_grad()
360
+
361
+ quant = quantize(x, params, qtz_func, bits, target_dtype)
362
+ dequant = dequantize(quant, params, deqtz_func, bits, x.dtype)
363
+ loss = loss_fn(x, dequant)
364
+
365
+ if loss.isnan() or loss.isinf():
366
+ raise Exception("Loss is NaN or Inf. Stopping the search.")
367
+
368
+ loss.backward()
369
+ optimizer.step()
370
+
371
+ acc_loss.append(loss.item())
372
+
373
+ # Reports loss every 10 steps
374
+ if i % 10 == 0 and do_report:
375
+ print(f"Epoch {i}: Loss {loss.item()}")
376
+
377
+ # Optimizes the parameter search by storing the best loss and the parameters
378
+ if loss.item() < best_loss:
379
+ best_loss = loss.item()
380
+ best_params = copy.deepcopy({
381
+ k: v for k, v in params.items() if k in param_keys
382
+ })
383
+
384
+ # We also stop the search if the loss has not considerably during the last 10% epochs
385
+ if early_stop:
386
+ epochs_before_stop = int(epochs * percent_epochs_before_stop)
387
+ if i > epochs_before_stop and abs(acc_loss[i - epochs_before_stop] - acc_loss[i]) < min_delta:
388
+ break
389
+
390
+ # No longer requires gradients in the parameters
391
+ for p in best_params.values():
392
+ p.requires_grad = False
393
+ p.grad = None
394
+
395
+ if do_report:
396
+ print(f"Best loss: {best_loss}")
397
+ return best_params, acc_loss
398
+ else:
399
+ return best_params
400
+
401
+
402
+ def quantize(
403
+ x: torch.Tensor,
404
+ params: Dict[str, nn.Parameter],
405
+ func: nn.Module,
406
+ bits: int,
407
+ target_dtype: torch.dtype = torch.int8
408
+ ) -> torch.Tensor:
409
+ quant_min, quant_max = get_min_max_from_bits_signed(bits)
410
+ x = x.transpose(0, 1) # Aligns shapes
411
+ x = func(x=x, **params)
412
+ x = x.transpose(0, 1)
413
+ x = torch.clamp(round_func_BPDA(x), quant_min, quant_max).to(target_dtype)
414
+ return x
415
+
416
+
417
+ def dequantize(
418
+ x: torch.Tensor,
419
+ params: Dict[str, nn.Parameter],
420
+ func: nn.Module,
421
+ bits: int,
422
+ out_dtype: torch.dtype
423
+ ) -> torch.Tensor:
424
+ x = x.to(dtype=out_dtype)
425
+ x = x.transpose(0, 1)
426
+ x = func(x=x, **params)
427
+ x = x.transpose(0, 1)
428
+ return x
429
+
430
+
431
+ def round_func_BPDA(input):
432
+ # This is equivalent to replacing round function (non-differentiable) with
433
+ # an identity function (differentiable) only when backward.
434
+ forward_value = torch.round(input)
435
+ out = input.clone()
436
+ out.data = forward_value.data
437
+ return out
438
+
439
+
440
+ def get_min_max_from_bits_signed(bit_width: int) -> Tuple[int, int]:
441
+ return -2 ** (bit_width - 1), 2 ** (bit_width - 1) - 1
442
+
443
+
444
+
445
+ ############## Numpy ###############
446
+
447
def np_domain_guard(
    x: np.ndarray,
    min: float = None,
    max: float = None,
    posinf: float = None,
    neginf: float = None,
    nan: float = None
) -> np.ndarray:
    """Sanitize ``x`` so it lies in a function's valid domain (NumPy version).

    NaN, +inf and -inf are replaced via ``np.nan_to_num``; the result is then
    clipped to ``[min, max]`` when at least one bound is given.

    ``nan=None`` means "replace NaN with 0.0", matching the torch
    ``domain_guard``. ``np.nan_to_num`` documents ``None`` only for
    ``posinf``/``neginf``, so the default is resolved here instead of being
    forwarded.
    """
    nan_value = 0.0 if nan is None else nan
    x = np.nan_to_num(x, posinf=posinf, neginf=neginf, nan=nan_value)
    if min is not None or max is not None:
        x = np.clip(x, min, max)
    return x
460
+
461
+
462
+ def np_replace_num(x: np.ndarray, num: float, to: float) -> np.ndarray:
463
+ """Replace a number in a tensor with another number.
464
+
465
+ Args:
466
+ x (np.ndarray): The input tensor.
467
+ num (float): The number to replace.
468
+ to (float): The number to replace with.
469
+
470
+ Returns:
471
+ np.ndarray: The tensor with the number replaced.
472
+ """
473
+ return np.where(x == num, to, x)
474
+
475
+
476
+ def np_guarded_power(x: np.ndarray, exp: float) -> np.ndarray:
477
+ """Guard the power operation to a valid domain."""
478
+ return np.power(x, exp) if exp >= 1 else np.power(np.maximum(x, 0), exp)
479
+
fn_gen/nlr_t_no_sched/16/loss.png ADDED
fn_gen/nlr_t_no_sched/16/quantization.png ADDED
fn_gen/nlr_t_no_sched/17/distortion.png ADDED
fn_gen/nlr_t_no_sched/17/expressions.txt ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ atanh(_0*x)/_s
2
+ tanh(_s*x)/_0
fn_gen/nlr_t_no_sched/17/fn.py ADDED
@@ -0,0 +1,479 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import torch
4
+ from torch import amin # Necessary for arcsin
5
+ import copy
6
+ import torch.nn as nn
7
+ import numpy as np
8
+
9
+ from scipy.optimize import curve_fit
10
+ from typing import Dict, Any, Tuple, List, Callable
11
+
12
+
13
def quantization(x, **params):
    """Evaluate the learned quantization expression ``atanh(_0 * x) / _s``.

    ``params`` must provide ``'_0'`` and ``'_s'``.
    """
    # A zero scale would divide by zero, so zeros are swapped for 10000.
    inv_scale = torch.div(1, replace_num(params['_s'], num=0, to=10000))
    # Keep the atanh argument strictly inside (-1, 1) and map NaN to 0.
    guarded = domain_guard((params['_0'] * x), min=-0.9999, max=0.9999, nan=0)
    return (inv_scale * torch.atanh(guarded))
15
+
16
+
17
def dequantization(x, **params):
    """Evaluate the learned dequantization expression ``tanh(_s * x) / _0``.

    ``params`` must provide ``'_0'`` and ``'_s'``.
    """
    # A zero ``_0`` would divide by zero, so zeros are swapped for 10000.
    inv_inner = torch.div(1, replace_num(params['_0'], num=0, to=10000))
    return (inv_inner * torch.tanh((params['_s'] * x)))
19
+
20
+
21
def init_params(x: torch.Tensor, **kwargs: Dict[str, Any]) -> Dict[str, nn.Parameter]:
    """Initialise and train the ``_0`` / ``_s`` parameters for ``x``.

    Stages: random space search for ``_0``, linear scale for ``_s``, a
    non-linear regression fit of both, then gradient-based refinement.
    ``kwargs`` must contain ``'bits'``. The optional callables
    ``'post_init_hook'``, ``'post_method_hook'`` and ``'post_train_hook'`` are
    invoked with ``parameters=<current dict>`` after the matching stage.
    """
    names = ['_0', '_s']

    # Stage 1: search '_0', then derive the linear scale '_s' from it.
    searched_0 = init_space_search(x, qtz_func=quantization, deqtz_func=dequantization, params_list=names, param='_0', **kwargs)
    base_p0 = {'_0': searched_0}
    base_p0['_s'] = init_linear_scale(x, qtz_func=quantization, params=base_p0, **kwargs)
    if 'post_init_hook' in kwargs:
        kwargs['post_init_hook'](parameters=base_p0)

    # Stage 2: regression fit, then wrap the results as frozen parameters.
    fitted = init_non_linear_regression_fit(x, p0=base_p0, np_fit_func=fit_func, qtz_func=quantization, deqtz_func=dequantization, params_list=names, **kwargs)
    params = {}
    for name, value in fitted.items():
        params[name] = nn.Parameter(value, requires_grad=False)
    if 'post_method_hook' in kwargs:
        kwargs['post_method_hook'](parameters=params)

    # Stage 3: gradient-based refinement.
    params = learn_parameters(
        x,
        params,
        qtz_func=quantization,
        deqtz_func=dequantization,
        bits=kwargs['bits'],
        target_dtype=torch.int8,
        epochs=500,
        early_stop=False,
    )
    if 'post_train_hook' in kwargs:
        kwargs['post_train_hook'](parameters=params)

    return params
47
+
48
+
49
+ ############### Numpy Qtz ###############
50
+
51
+
52
def np_quantization(x, _0, _s):
    """NumPy counterpart of ``quantization``: ``arctanh(_0 * x) / _s``."""
    # Zero scales are swapped for 10000 to avoid dividing by zero.
    inv_scale = np.divide(1, np_replace_num(_s, num=0, to=10000))
    # Keep the arctanh argument strictly inside (-1, 1) and map NaN to 0.
    guarded = np_domain_guard((_0 * x), min=-0.9999, max=0.9999, nan=0)
    return (inv_scale * np.arctanh(guarded))
54
+
55
+
56
def np_dequantization(x, _0, _s):
    """NumPy counterpart of ``dequantization``: ``tanh(_s * x) / _0``."""
    # A zero ``_0`` is swapped for 10000 to avoid dividing by zero.
    inv_inner = np.divide(1, np_replace_num(_0, num=0, to=10000))
    return (inv_inner * np.tanh((_s * x)))
58
+
59
+
60
def fit_func(x, _0, _s):
    """Quantize then dequantize ``x`` with NumPy; used as the curve-fit model."""
    quantized = np_quantization(x, _0, _s)
    return np_dequantization(quantized, _0, _s)
64
+
65
+
66
+
67
+ ############### HELPERS ###############
68
+
69
def domain_guard(
    x: torch.Tensor,
    min: float = None,
    max: float = None,
    posinf: float = None,
    neginf: float = None,
    nan: float = None
) -> torch.Tensor:
    """Replace non-finite entries of ``x`` and optionally clamp it to ``[min, max]``.

    ``nan``, ``posinf`` and ``neginf`` are forwarded to ``torch.nan_to_num``;
    clamping only happens when at least one bound is given.
    """
    cleaned = torch.nan_to_num(x, posinf=posinf, neginf=neginf, nan=nan)
    if min is None and max is None:
        return cleaned
    return torch.clamp(cleaned, min=min, max=max)
82
+
83
+
84
def replace_num(x: torch.Tensor, num: float, to: float) -> torch.Tensor:
    """Swap every occurrence of one value in a tensor for another.

    Args:
        x (torch.Tensor): The input tensor.
        num (float): The value to look for.
        to (float): The value written wherever ``num`` is found.

    Returns:
        torch.Tensor: ``x`` with each entry equal to ``num`` replaced by ``to``.
    """
    matches = x == num
    return torch.where(matches, to, x)
96
+
97
+
98
def guarded_torch_power(x: torch.Tensor, exp: float) -> torch.Tensor:
    """Raise ``x`` to ``exp``, clamping negative bases to zero when ``exp < 1``."""
    if exp >= 1:
        return torch.pow(x, exp)
    # For exponents below 1 the base goes through ReLU before the power.
    return torch.pow(torch.relu(x), exp)
101
+
102
+
103
def init_ones(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor:
    """Return float32 ones shaped like ``x`` with dimension 1 reduced away.

    Extra keyword arguments are accepted and ignored.
    """
    template = torch.amin(x, dim=1)  # only its shape is used
    return torch.ones_like(template, dtype=torch.float32, device=x.device)
106
+
107
+
108
def init_rand(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor:
    """Return float32 standard-normal samples shaped like ``x`` with dimension 1 reduced away.

    Extra keyword arguments are accepted and ignored.
    """
    template = torch.amin(x, dim=1)  # only its shape is used
    return torch.randn_like(template, dtype=torch.float32, device=x.device)
111
+
112
+
113
def init_space_search(
    x: torch.Tensor,
    **kwargs: Dict[str, Any],
) -> torch.Tensor:
    """Randomly search an initial value for one parameter of a quantizer pair.

    Candidates for ``param`` are scored by the MSE between ``x`` and
    ``deqtz_func(qtz_func(x))`` while every other parameter is held at ones.
    Each run keeps its ``n_best_to_pick`` best candidates and samples the next
    batch around their mean. The best candidate seen in any run is finally
    compared against an all-ones and a fresh random candidate.

    Args:
        x: Input tensor; it is transposed on dims 0 and 1 before the functions
            are applied and transposed back afterwards.
        **kwargs: Must contain ``qtz_func``, ``deqtz_func``, ``params_list``
            (all parameter names) and ``param`` (the name being searched).
            All entries are also forwarded to ``init_ones`` / ``init_rand``.

    Returns:
        The candidate tensor with the lowest finite loss.

    Raises:
        RuntimeError: If no candidate produced a finite loss.
    """

    def _build_initial_param(tensor: torch.Tensor, max_initial: int, n_params: int):
        """Yield the first batch of candidates, 10 times larger than later batches."""
        for _ in range(n_params * 10):
            # Standard-normal samples scaled by max_initial (not bounded by it).
            yield init_rand(tensor) * max_initial

    def _search_param(tensors: List[torch.tensor], n_params):
        """Yield new candidates sampled around the mean of the given tensors."""
        torch_tensors = torch.stack(tensors)
        min_vals, max_vals = torch.aminmax(torch_tensors, dim=0)
        abs_max_val_per_ch = torch.max(-min_vals, max_vals)
        mean = torch.mean(torch_tensors, dim=0)
        for _ in range(n_params):
            yield torch.randn_like(min_vals) * abs_max_val_per_ch + mean

    def _calc(x, qtz_func, deqtz_func, **params):
        """Quantize then dequantize ``x`` with the given parameters."""
        x_ = x.transpose(0, 1)
        x_ = qtz_func(x=x_, **params)
        x_ = deqtz_func(x=x_, **params)
        x_ = x_.transpose(0, 1)
        return x_

    assert "qtz_func" in kwargs, "qtz_func must be provided."
    assert "deqtz_func" in kwargs, "deqtz_func must be provided."
    assert "params_list" in kwargs, "params list must be provided."
    assert "param" in kwargs, "param must be provided."

    qtz_func = kwargs.get('qtz_func')
    deqtz_func = kwargs.get('deqtz_func')
    params_list = kwargs.get('params_list')
    param = kwargs.get('param')

    n_runs = 50  # Number of runs to try to find the best parameters
    n_random_params = 50  # Number of random parameters to generate
    n_best_to_pick = 5  # Number of best parameters to pick after each run
    max_initial = 10000  # Scale applied to the initial random candidates

    # Every parameter other than the searched one is fixed at ones.
    base_params = {p: init_ones(x, **kwargs) for p in params_list if p != param}

    def _loss_of(candidate: torch.Tensor) -> float:
        """Return the reconstruction MSE obtained with ``candidate`` as ``param``."""
        x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: candidate})
        return nn.MSELoss()(x, x_).item()

    params = _build_initial_param(x, max_initial, n_random_params)
    elite = []  # best candidates of the latest run that produced any
    best_overall = None  # best (candidate, loss) pair seen in any run

    # Performs the search
    for _ in range(n_runs):
        run_best = []
        for param_ in params:
            try:
                loss = _loss_of(param_)
            except Exception:  # The parameters might not be valid for the function's domain
                continue
            # NaN/Inf losses cannot be ranked reliably, so they never enter the elite.
            if not np.isfinite(loss):
                continue
            run_best.append((param_, loss))
            run_best.sort(key=lambda entry: entry[1])
            del run_best[n_best_to_pick:]

        if run_best:
            elite = run_best
            # Remember the best candidate across runs so a later, worse run
            # cannot discard it.
            if best_overall is None or run_best[0][1] < best_overall[1]:
                best_overall = run_best[0]
        if not elite:
            break  # nothing valid to sample around

        # Generates new parameters around the mean
        params = _search_param([p for p, _ in elite], n_random_params)

    # The searched candidate goes first so it wins ties against the baselines.
    finalists = []
    if best_overall is not None:
        finalists.append(best_overall)
    for baseline in (init_ones(x, **kwargs), init_rand(x, **kwargs)):
        baseline_loss = _loss_of(baseline)
        if np.isfinite(baseline_loss):
            finalists.append((baseline, baseline_loss))

    if not finalists:
        raise RuntimeError("init_space_search found no candidate with a finite loss.")
    return min(finalists, key=lambda entry: entry[1])[0]
196
+
197
+
198
def init_linear_scale(  # Symmetric scale. From the study folder
    x: torch.Tensor,
    **kwargs: Dict[str, Any],
) -> torch.Tensor:
    """Compute a symmetric linear scale for the output of ``qtz_func``.

    ``qtz_func`` is applied to ``x`` (transposed on dims 0/1 and back) with
    ``params`` and ``_s`` set to ones. The largest magnitude along dim 1 is
    divided by half of the signed integer range for ``bits``.

    Args:
        x: Input tensor.
        **kwargs: Must contain ``bits``, ``params`` and ``qtz_func``.

    Returns:
        A float32 scale tensor, floored at float32 machine epsilon.
    """
    assert "bits" in kwargs, "bits must be provided."
    assert "params" in kwargs, "params must be provided."
    assert "qtz_func" in kwargs, "qtz_func must be provided."

    bits = kwargs.get('bits')
    params = kwargs.get('params')
    qtz_func = kwargs.get('qtz_func')

    transformed = qtz_func(x=x.transpose(0, 1), **params, _s=init_ones(x, **kwargs)).transpose(0, 1)

    quant_min, quant_max = get_min_max_from_bits_signed(bits)
    lowest, highest = torch.aminmax(transformed, dim=1)
    # Widen the observed range so that it always includes zero.
    lowest = torch.min(lowest, torch.zeros_like(lowest))
    highest = torch.max(highest, torch.zeros_like(highest))

    largest_magnitude = torch.max(-lowest, highest)
    half_levels = float(quant_max - quant_min) / 2
    scale = largest_magnitude / half_levels

    # NOTE(diogo): adding noise to the scale was tried and found not to help
    # the learning process, so none is added here.
    eps = torch.finfo(torch.float32).eps
    return torch.clamp(scale, min=eps).to(dtype=torch.float32, device=lowest.device)
234
+
235
+
236
def init_non_linear_regression_fit(
    x: torch.Tensor,
    **kwargs: Dict[str, Any],
) -> Dict[str, torch.Tensor]:
    """Refine initial parameters channel by channel with a least-squares fit.

    Each entry along dim 0 of ``x`` is sorted and ``np_fit_func`` is fitted so
    that it maps that data onto itself, starting from the matching entry of
    every tensor in ``p0``.

    Args:
        x: Input tensor; dim 0 indexes the independently fitted channels.
        **kwargs: Must contain ``params_list`` (parameter names),
            ``np_fit_func`` (NumPy callable ``f(x, *params)`` taking the
            parameters in sorted-name order) and ``p0`` (dict mapping each name
            to a tensor with one initial value per channel). Other entries are
            ignored.

    Returns:
        Dict mapping each parameter name to a float32 tensor on ``x.device``.
        If the fit fails for any channel, the initial values from ``p0`` are
        returned for all channels.
    """
    assert "params_list" in kwargs, "params list must be provided."
    assert "np_fit_func" in kwargs, "np_fit_func must be provided."
    assert "p0" in kwargs, "p0 must be provided."
    np_fit_func = kwargs.get('np_fit_func')
    params_list = kwargs.get('params_list')
    p0 = kwargs.get('p0')

    def _fit(xdata: np.ndarray, ydata: np.ndarray, func: Callable, p0: List[float]):
        """Run one Levenberg-Marquardt curve fit and return the fitted parameters."""
        popt, _ = curve_fit(
            func,
            xdata,
            ydata,
            maxfev=1000,
            p0=p0,
            method='lm'
        )
        return popt

    # 1. Needs to convert the torch tensor to numpy tensor
    xdata = x.cpu().numpy()

    # 2. Sorts the data so that it makes it easier to fit to it
    sorted_xdata = np.sort(xdata, axis=-1)

    p0 = {k: v.cpu().numpy() for k, v in p0.items()}
    params_list = sorted(params_list)  # We need to make sure that it matches the numpy fit func arg order

    # 3. Finds the best parameters for each channel
    try:
        params = [
            _fit(row, row, np_fit_func, [p0[p][i] for p in params_list])
            for i, row in enumerate(sorted_xdata)
        ]
    except (ValueError, RuntimeError) as e:
        # curve_fit raises ValueError for invalid data and RuntimeError when the
        # least-squares minimisation fails; both fall back to the initial guess.
        print(f"Could not fit the function with error: {e}")
        print("Using fallback result...")
        return {
            k: torch.tensor(v, dtype=torch.float32).to(x.device) for k, v in p0.items()
        }

    # 4. Builds the parameters
    return {
        p: torch.tensor([ch_params[i] for ch_params in params], dtype=torch.float32).to(x.device)
        for i, p in enumerate(params_list)
    }
290
+
291
+
292
def init_zeros(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor:
    """Return float32 zeros shaped like ``x`` with dimension 1 reduced away.

    Extra keyword arguments are accepted and ignored.
    """
    template = torch.amin(x, dim=1)  # only its shape is used
    return torch.zeros_like(template, dtype=torch.float32, device=x.device)
295
+
296
+
297
def init_inner_scale(tensor: torch.Tensor, _min: float = torch.inf, _max: float = torch.inf) -> torch.Tensor:
    """Return the factor that rescales ``tensor``'s range to ``[_min, _max]``.

    The observed range along the last dimension is first widened to include
    zero. When ``_max`` is infinite no scaling is needed and ones are returned.

    Args:
        tensor: Input tensor.
        _min: Lower end of the target range.
        _max: Upper end of the target range; infinity disables scaling.

    Returns:
        ``(_max - _min) / (x_max - x_min)`` per entry of the reduced tensor, or
        ones when ``_max`` is infinite.
    """
    # Calculate the original minimum and maximum values
    min_vals, max_vals = torch.aminmax(tensor, dim=-1)
    x_min = torch.min(min_vals, torch.zeros_like(min_vals))
    x_max = torch.max(max_vals, torch.zeros_like(max_vals))

    # Compare by value: an identity check only matches the exact ``torch.inf``
    # object, so ``float("inf")`` would fall through and yield an infinite scale.
    if _max == torch.inf:  # We do not need to scale the tensor. Just need to move it
        return torch.ones_like(x_min)

    # Calculate the scale factor
    scale = (_max - _min) / (x_max - x_min)
    return scale
309
+
310
+
311
+
312
+ ############## Quant ###############
313
+
314
@torch.enable_grad()
def learn_parameters(
    x: torch.Tensor,
    params: Dict[str, nn.Parameter],
    qtz_func: nn.Module,
    deqtz_func: nn.Module,
    bits: int,
    target_dtype: torch.dtype,
    epochs: int = 1000,
    early_stop: bool = True,
    do_report: bool = False
) -> "Dict[str, nn.Parameter] | Tuple[Dict[str, nn.Parameter], List[float]]":
    """Refine ``params`` with Adam to minimise the quantize/dequantize MSE.

    Args:
        x: Tensor to reconstruct.
        params: Parameters to optimise; ``requires_grad`` is switched on for
            them in place.
        qtz_func: Quantization function passed to ``quantize``.
        deqtz_func: Dequantization function passed to ``dequantize``.
        bits: Bit width of the quantized values.
        target_dtype: Dtype of the quantized tensor.
        epochs: Maximum number of optimisation steps.
        early_stop: Stop once the loss changed by less than ``1e-7`` over the
            last 10% of ``epochs``.
        do_report: Print progress and also return the loss history.

    Returns:
        A deep copy of the parameters that achieved the lowest loss (with
        ``requires_grad=False``). When ``do_report`` is true, a tuple of those
        parameters and the list of per-epoch losses.

    Raises:
        ValueError: If the initial loss is NaN or infinite.
        Exception: If the loss becomes NaN or infinite during training.
    """
    loss_fn = nn.MSELoss()

    def _reconstruction_loss() -> torch.Tensor:
        """Return the MSE between ``x`` and its quantize/dequantize round trip."""
        quant = quantize(x, params, qtz_func, bits, target_dtype)
        dequant = dequantize(quant, params, deqtz_func, bits, x.dtype)
        return loss_fn(x, dequant)

    # The learning rate is 0.1 * 10 ** (floor(log10(initial loss)) // 2).
    initial_loss = _reconstruction_loss().item()
    if not np.isfinite(initial_loss):
        raise ValueError("Initial loss is NaN or Inf. Cannot derive a learning rate.")

    base_lr = 0.1
    # log10(0) is -inf and cannot be converted to int, so an exactly-zero loss
    # is floored at the smallest positive normal float32.
    safe_loss = max(initial_loss, float(np.finfo(np.float32).tiny))
    exponent = int(np.floor(np.log10(safe_loss)))
    lr = base_lr * (10 ** (exponent // 2))

    # Requires gradients in the parameters
    for p in params.values():
        p.requires_grad = True
        p.grad = None

    optimizer = torch.optim.Adam(list(params.values()), lr=lr)

    # Start from the incoming values so that ``epochs=0`` still returns
    # parameters and an update is only kept when it strictly improves the loss.
    best_loss = initial_loss
    best_params = copy.deepcopy(dict(params))

    # Used to stop the search early
    min_delta = 1e-7
    acc_loss = []
    percent_epochs_before_stop = 0.1
    epochs_before_stop = int(epochs * percent_epochs_before_stop)

    for i in range(epochs):
        optimizer.zero_grad()

        loss = _reconstruction_loss()
        if loss.isnan() or loss.isinf():
            raise Exception("Loss is NaN or Inf. Stopping the search.")
        loss_value = loss.item()

        # Snapshot BEFORE the optimizer step: ``loss`` was measured with the
        # current values, and stepping first would pair it with the next ones.
        if loss_value < best_loss:
            best_loss = loss_value
            best_params = copy.deepcopy(dict(params))

        loss.backward()
        optimizer.step()

        acc_loss.append(loss_value)

        # Reports loss every 10 steps
        if i % 10 == 0 and do_report:
            print(f"Epoch {i}: Loss {loss.item()}")

        # Stop if the loss has not changed considerably during the last 10% of epochs
        if early_stop and i > epochs_before_stop:
            if abs(acc_loss[i - epochs_before_stop] - acc_loss[i]) < min_delta:
                break

    # No longer requires gradients in the parameters
    for p in best_params.values():
        p.requires_grad = False
        p.grad = None

    if do_report:
        print(f"Best loss: {best_loss}")
        return best_params, acc_loss
    return best_params
400
+
401
+
402
def quantize(
    x: torch.Tensor,
    params: Dict[str, nn.Parameter],
    func: nn.Module,
    bits: int,
    target_dtype: torch.dtype = torch.int8
) -> torch.Tensor:
    """Apply ``func`` to ``x``, round, clamp to the signed ``bits`` range and cast.

    ``x`` is transposed on dims 0/1 before ``func`` is called (so the
    parameters line up with it) and transposed back afterwards.
    """
    quant_min, quant_max = get_min_max_from_bits_signed(bits)
    mapped = func(x=x.transpose(0, 1), **params).transpose(0, 1)  # Aligns shapes
    rounded = round_func_BPDA(mapped)
    # NOTE(review): if target_dtype is an integer dtype, autograd cannot track
    # the cast result, so gradients presumably do not flow back through
    # ``func`` from here. Confirm this is intended.
    return torch.clamp(rounded, quant_min, quant_max).to(target_dtype)
415
+
416
+
417
def dequantize(
    x: torch.Tensor,
    params: Dict[str, nn.Parameter],
    func: nn.Module,
    bits: int,
    out_dtype: torch.dtype
) -> torch.Tensor:
    """Cast ``x`` to ``out_dtype`` and apply ``func`` to it.

    ``x`` is transposed on dims 0/1 before ``func`` is called and transposed
    back afterwards. ``bits`` is accepted but not used here.
    """
    as_float = x.to(dtype=out_dtype)
    return func(x=as_float.transpose(0, 1), **params).transpose(0, 1)
429
+
430
+
431
def round_func_BPDA(input):
    """Round ``input`` to the nearest integer while keeping an identity gradient.

    The returned tensor holds ``torch.round(input)``. Autograd only records the
    ``clone`` call, so during the backward pass this behaves like the identity
    function instead of the (zero-gradient) rounding operation.
    """
    rounded = torch.round(input)
    # Clone first so the graph contains a differentiable op, then swap the
    # rounded values into the clone's storage without autograd tracking it.
    result = input.clone()
    result.data = rounded.data
    return result
438
+
439
+
440
def get_min_max_from_bits_signed(bit_width: int) -> Tuple[int, int]:
    """Return the ``(min, max)`` values of a signed integer with ``bit_width`` bits."""
    half_range = 2 ** (bit_width - 1)
    return -half_range, half_range - 1
442
+
443
+
444
+
445
+ ############## Numpy ###############
446
+
447
def np_domain_guard(
    x: np.ndarray,
    min: float = None,
    max: float = None,
    posinf: float = None,
    neginf: float = None,
    nan: float = None
) -> np.ndarray:
    """Replace non-finite entries of ``x`` and optionally clip it to ``[min, max]``.

    Args:
        x: Input array (or scalar).
        min: Lower clipping bound; ``None`` leaves the lower side unbounded.
        max: Upper clipping bound; ``None`` leaves the upper side unbounded.
        posinf: Replacement for ``+inf``; ``None`` keeps NumPy's default.
        neginf: Replacement for ``-inf``; ``None`` keeps NumPy's default.
        nan: Replacement for NaN; ``None`` means 0.0.

    Returns:
        The guarded array.
    """
    # np.nan_to_num documents ``nan`` as a number (default 0.0); unlike
    # posinf/neginf it has no ``None`` sentinel, so map our default explicitly.
    nan_fill = 0.0 if nan is None else nan
    x = np.nan_to_num(x, posinf=posinf, neginf=neginf, nan=nan_fill)
    if min is not None or max is not None:
        x = np.clip(x, min, max)
    return x
460
+
461
+
462
def np_replace_num(x: np.ndarray, num: float, to: float) -> np.ndarray:
    """Swap every occurrence of one value in an array for another.

    Args:
        x (np.ndarray): The input array.
        num (float): The value to look for.
        to (float): The value written wherever ``num`` is found.

    Returns:
        np.ndarray: ``x`` with each entry equal to ``num`` replaced by ``to``.
    """
    matches = x == num
    return np.where(matches, to, x)
474
+
475
+
476
def np_guarded_power(x: np.ndarray, exp: float) -> np.ndarray:
    """Raise ``x`` to ``exp``, clamping negative bases to zero when ``exp < 1``."""
    if exp >= 1:
        return np.power(x, exp)
    # For exponents below 1 the base is floored at zero before the power.
    return np.power(np.maximum(x, 0), exp)
479
+
fn_gen/nlr_t_no_sched/17/loss.png ADDED
fn_gen/nlr_t_no_sched/17/quantization.png ADDED
fn_gen/nlr_t_no_sched/18/distortion.png ADDED
fn_gen/nlr_t_no_sched/18/expressions.txt ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ acos(_0*x)/_s
2
+ cos(_s*x)/_0
fn_gen/nlr_t_no_sched/18/fn.py ADDED
@@ -0,0 +1,479 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import torch
4
+ from torch import amin # Necessary for arcsin
5
+ import copy
6
+ import torch.nn as nn
7
+ import numpy as np
8
+
9
+ from scipy.optimize import curve_fit
10
+ from typing import Dict, Any, Tuple, List, Callable
11
+
12
+
13
def quantization(x, **params):
    """Evaluate the learned quantization expression ``acos(_0 * x) / _s``.

    ``params`` must provide ``'_0'`` and ``'_s'``.
    """
    # A zero scale would divide by zero, so zeros are swapped for 10000.
    inv_scale = torch.div(1, replace_num(params['_s'], num=0, to=10000))
    # Keep the acos argument just inside [-1, 1] and map NaN to 0.
    guarded = domain_guard((params['_0'] * x), min=-0.99999, max=0.99999, nan=0)
    return (inv_scale * torch.acos(guarded))
15
+
16
+
17
def dequantization(x, **params):
    """Evaluate the learned dequantization expression ``cos(_s * x) / _0``.

    ``params`` must provide ``'_0'`` and ``'_s'``.
    """
    # A zero ``_0`` would divide by zero, so zeros are swapped for 10000.
    inv_inner = torch.div(1, replace_num(params['_0'], num=0, to=10000))
    return (inv_inner * torch.cos((params['_s'] * x)))
19
+
20
+
21
def init_params(x: torch.Tensor, **kwargs: Dict[str, Any]) -> Dict[str, nn.Parameter]:
    """Initialise and train the ``_0`` / ``_s`` parameters for ``x``.

    Stages: random space search for ``_0``, linear scale for ``_s``, a
    non-linear regression fit of both, then gradient-based refinement.
    ``kwargs`` must contain ``'bits'``. The optional callables
    ``'post_init_hook'``, ``'post_method_hook'`` and ``'post_train_hook'`` are
    invoked with ``parameters=<current dict>`` after the matching stage.
    """
    names = ['_0', '_s']

    # Stage 1: search '_0', then derive the linear scale '_s' from it.
    searched_0 = init_space_search(x, qtz_func=quantization, deqtz_func=dequantization, params_list=names, param='_0', **kwargs)
    base_p0 = {'_0': searched_0}
    base_p0['_s'] = init_linear_scale(x, qtz_func=quantization, params=base_p0, **kwargs)
    if 'post_init_hook' in kwargs:
        kwargs['post_init_hook'](parameters=base_p0)

    # Stage 2: regression fit, then wrap the results as frozen parameters.
    fitted = init_non_linear_regression_fit(x, p0=base_p0, np_fit_func=fit_func, qtz_func=quantization, deqtz_func=dequantization, params_list=names, **kwargs)
    params = {}
    for name, value in fitted.items():
        params[name] = nn.Parameter(value, requires_grad=False)
    if 'post_method_hook' in kwargs:
        kwargs['post_method_hook'](parameters=params)

    # Stage 3: gradient-based refinement.
    params = learn_parameters(
        x,
        params,
        qtz_func=quantization,
        deqtz_func=dequantization,
        bits=kwargs['bits'],
        target_dtype=torch.int8,
        epochs=500,
        early_stop=False,
    )
    if 'post_train_hook' in kwargs:
        kwargs['post_train_hook'](parameters=params)

    return params
47
+
48
+
49
+ ############### Numpy Qtz ###############
50
+
51
+
52
def np_quantization(x, _0, _s):
    """NumPy counterpart of ``quantization``: ``arccos(_0 * x) / _s``."""
    # Zero scales are swapped for 10000 to avoid dividing by zero.
    inv_scale = np.divide(1, np_replace_num(_s, num=0, to=10000))
    # Keep the arccos argument just inside [-1, 1] and map NaN to 0.
    guarded = np_domain_guard((_0 * x), min=-0.99999, max=0.99999, nan=0)
    return (inv_scale * np.arccos(guarded))
54
+
55
+
56
def np_dequantization(x, _0, _s):
    """NumPy counterpart of ``dequantization``: ``cos(_s * x) / _0``."""
    # A zero ``_0`` is swapped for 10000 to avoid dividing by zero.
    inv_inner = np.divide(1, np_replace_num(_0, num=0, to=10000))
    return (inv_inner * np.cos((_s * x)))
58
+
59
+
60
def fit_func(x, _0, _s):
    """Quantize then dequantize ``x`` with NumPy; used as the curve-fit model."""
    quantized = np_quantization(x, _0, _s)
    return np_dequantization(quantized, _0, _s)
64
+
65
+
66
+
67
+ ############### HELPERS ###############
68
+
69
def domain_guard(
    x: torch.Tensor,
    min: float = None,
    max: float = None,
    posinf: float = None,
    neginf: float = None,
    nan: float = None
) -> torch.Tensor:
    """Replace non-finite entries of ``x`` and optionally clamp it to ``[min, max]``.

    ``nan``, ``posinf`` and ``neginf`` are forwarded to ``torch.nan_to_num``;
    clamping only happens when at least one bound is given.
    """
    cleaned = torch.nan_to_num(x, posinf=posinf, neginf=neginf, nan=nan)
    if min is None and max is None:
        return cleaned
    return torch.clamp(cleaned, min=min, max=max)
82
+
83
+
84
def replace_num(x: torch.Tensor, num: float, to: float) -> torch.Tensor:
    """Swap every occurrence of one value in a tensor for another.

    Args:
        x (torch.Tensor): The input tensor.
        num (float): The value to look for.
        to (float): The value written wherever ``num`` is found.

    Returns:
        torch.Tensor: ``x`` with each entry equal to ``num`` replaced by ``to``.
    """
    matches = x == num
    return torch.where(matches, to, x)
96
+
97
+
98
def guarded_torch_power(x: torch.Tensor, exp: float) -> torch.Tensor:
    """Raise ``x`` to ``exp``, clamping negative bases to zero when ``exp < 1``."""
    if exp >= 1:
        return torch.pow(x, exp)
    # For exponents below 1 the base goes through ReLU before the power.
    return torch.pow(torch.relu(x), exp)
101
+
102
+
103
def init_ones(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor:
    """Return float32 ones shaped like ``x`` with dimension 1 reduced away.

    Extra keyword arguments are accepted and ignored.
    """
    template = torch.amin(x, dim=1)  # only its shape is used
    return torch.ones_like(template, dtype=torch.float32, device=x.device)
106
+
107
+
108
def init_rand(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor:
    """Return float32 standard-normal samples shaped like ``x`` with dimension 1 reduced away.

    Extra keyword arguments are accepted and ignored.
    """
    template = torch.amin(x, dim=1)  # only its shape is used
    return torch.randn_like(template, dtype=torch.float32, device=x.device)
111
+
112
+
113
def init_space_search(
    x: torch.Tensor,
    **kwargs: Dict[str, Any],
) -> torch.Tensor:
    """Randomly search an initial value for one parameter of a quantizer pair.

    Candidates for ``param`` are scored by the MSE between ``x`` and
    ``deqtz_func(qtz_func(x))`` while every other parameter is held at ones.
    Each run keeps its ``n_best_to_pick`` best candidates and samples the next
    batch around their mean. The best candidate seen in any run is finally
    compared against an all-ones and a fresh random candidate.

    Args:
        x: Input tensor; it is transposed on dims 0 and 1 before the functions
            are applied and transposed back afterwards.
        **kwargs: Must contain ``qtz_func``, ``deqtz_func``, ``params_list``
            (all parameter names) and ``param`` (the name being searched).
            All entries are also forwarded to ``init_ones`` / ``init_rand``.

    Returns:
        The candidate tensor with the lowest finite loss.

    Raises:
        RuntimeError: If no candidate produced a finite loss.
    """

    def _build_initial_param(tensor: torch.Tensor, max_initial: int, n_params: int):
        """Yield the first batch of candidates, 10 times larger than later batches."""
        for _ in range(n_params * 10):
            # Standard-normal samples scaled by max_initial (not bounded by it).
            yield init_rand(tensor) * max_initial

    def _search_param(tensors: List[torch.tensor], n_params):
        """Yield new candidates sampled around the mean of the given tensors."""
        torch_tensors = torch.stack(tensors)
        min_vals, max_vals = torch.aminmax(torch_tensors, dim=0)
        abs_max_val_per_ch = torch.max(-min_vals, max_vals)
        mean = torch.mean(torch_tensors, dim=0)
        for _ in range(n_params):
            yield torch.randn_like(min_vals) * abs_max_val_per_ch + mean

    def _calc(x, qtz_func, deqtz_func, **params):
        """Quantize then dequantize ``x`` with the given parameters."""
        x_ = x.transpose(0, 1)
        x_ = qtz_func(x=x_, **params)
        x_ = deqtz_func(x=x_, **params)
        x_ = x_.transpose(0, 1)
        return x_

    assert "qtz_func" in kwargs, "qtz_func must be provided."
    assert "deqtz_func" in kwargs, "deqtz_func must be provided."
    assert "params_list" in kwargs, "params list must be provided."
    assert "param" in kwargs, "param must be provided."

    qtz_func = kwargs.get('qtz_func')
    deqtz_func = kwargs.get('deqtz_func')
    params_list = kwargs.get('params_list')
    param = kwargs.get('param')

    n_runs = 50  # Number of runs to try to find the best parameters
    n_random_params = 50  # Number of random parameters to generate
    n_best_to_pick = 5  # Number of best parameters to pick after each run
    max_initial = 10000  # Scale applied to the initial random candidates

    # Every parameter other than the searched one is fixed at ones.
    base_params = {p: init_ones(x, **kwargs) for p in params_list if p != param}

    def _loss_of(candidate: torch.Tensor) -> float:
        """Return the reconstruction MSE obtained with ``candidate`` as ``param``."""
        x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: candidate})
        return nn.MSELoss()(x, x_).item()

    params = _build_initial_param(x, max_initial, n_random_params)
    elite = []  # best candidates of the latest run that produced any
    best_overall = None  # best (candidate, loss) pair seen in any run

    # Performs the search
    for _ in range(n_runs):
        run_best = []
        for param_ in params:
            try:
                loss = _loss_of(param_)
            except Exception:  # The parameters might not be valid for the function's domain
                continue
            # NaN/Inf losses cannot be ranked reliably, so they never enter the elite.
            if not np.isfinite(loss):
                continue
            run_best.append((param_, loss))
            run_best.sort(key=lambda entry: entry[1])
            del run_best[n_best_to_pick:]

        if run_best:
            elite = run_best
            # Remember the best candidate across runs so a later, worse run
            # cannot discard it.
            if best_overall is None or run_best[0][1] < best_overall[1]:
                best_overall = run_best[0]
        if not elite:
            break  # nothing valid to sample around

        # Generates new parameters around the mean
        params = _search_param([p for p, _ in elite], n_random_params)

    # The searched candidate goes first so it wins ties against the baselines.
    finalists = []
    if best_overall is not None:
        finalists.append(best_overall)
    for baseline in (init_ones(x, **kwargs), init_rand(x, **kwargs)):
        baseline_loss = _loss_of(baseline)
        if np.isfinite(baseline_loss):
            finalists.append((baseline, baseline_loss))

    if not finalists:
        raise RuntimeError("init_space_search found no candidate with a finite loss.")
    return min(finalists, key=lambda entry: entry[1])[0]
196
+
197
+
198
def init_linear_scale(  # Symmetric scale. From the study folder
    x: torch.Tensor,
    **kwargs: Dict[str, Any],
) -> torch.Tensor:
    """Compute a symmetric linear scale for the output of ``qtz_func``.

    ``qtz_func`` is applied to ``x`` (transposed on dims 0/1 and back) with
    ``params`` and ``_s`` set to ones. The largest magnitude along dim 1 is
    divided by half of the signed integer range for ``bits``.

    Args:
        x: Input tensor.
        **kwargs: Must contain ``bits``, ``params`` and ``qtz_func``.

    Returns:
        A float32 scale tensor, floored at float32 machine epsilon.
    """
    assert "bits" in kwargs, "bits must be provided."
    assert "params" in kwargs, "params must be provided."
    assert "qtz_func" in kwargs, "qtz_func must be provided."

    bits = kwargs.get('bits')
    params = kwargs.get('params')
    qtz_func = kwargs.get('qtz_func')

    transformed = qtz_func(x=x.transpose(0, 1), **params, _s=init_ones(x, **kwargs)).transpose(0, 1)

    quant_min, quant_max = get_min_max_from_bits_signed(bits)
    lowest, highest = torch.aminmax(transformed, dim=1)
    # Widen the observed range so that it always includes zero.
    lowest = torch.min(lowest, torch.zeros_like(lowest))
    highest = torch.max(highest, torch.zeros_like(highest))

    largest_magnitude = torch.max(-lowest, highest)
    half_levels = float(quant_max - quant_min) / 2
    scale = largest_magnitude / half_levels

    # NOTE(diogo): adding noise to the scale was tried and found not to help
    # the learning process, so none is added here.
    eps = torch.finfo(torch.float32).eps
    return torch.clamp(scale, min=eps).to(dtype=torch.float32, device=lowest.device)
234
+
235
+
236
def init_non_linear_regression_fit(
    x: torch.Tensor,
    **kwargs: Dict[str, Any],
) -> Dict[str, torch.Tensor]:
    """Refine initial parameters channel by channel with a least-squares fit.

    Each entry along dim 0 of ``x`` is sorted and ``np_fit_func`` is fitted so
    that it maps that data onto itself, starting from the matching entry of
    every tensor in ``p0``.

    Args:
        x: Input tensor; dim 0 indexes the independently fitted channels.
        **kwargs: Must contain ``params_list`` (parameter names),
            ``np_fit_func`` (NumPy callable ``f(x, *params)`` taking the
            parameters in sorted-name order) and ``p0`` (dict mapping each name
            to a tensor with one initial value per channel). Other entries are
            ignored.

    Returns:
        Dict mapping each parameter name to a float32 tensor on ``x.device``.
        If the fit fails for any channel, the initial values from ``p0`` are
        returned for all channels.
    """
    assert "params_list" in kwargs, "params list must be provided."
    assert "np_fit_func" in kwargs, "np_fit_func must be provided."
    assert "p0" in kwargs, "p0 must be provided."
    np_fit_func = kwargs.get('np_fit_func')
    params_list = kwargs.get('params_list')
    p0 = kwargs.get('p0')

    def _fit(xdata: np.ndarray, ydata: np.ndarray, func: Callable, p0: List[float]):
        """Run one Levenberg-Marquardt curve fit and return the fitted parameters."""
        popt, _ = curve_fit(
            func,
            xdata,
            ydata,
            maxfev=1000,
            p0=p0,
            method='lm'
        )
        return popt

    # 1. Needs to convert the torch tensor to numpy tensor
    xdata = x.cpu().numpy()

    # 2. Sorts the data so that it makes it easier to fit to it
    sorted_xdata = np.sort(xdata, axis=-1)

    p0 = {k: v.cpu().numpy() for k, v in p0.items()}
    params_list = sorted(params_list)  # We need to make sure that it matches the numpy fit func arg order

    # 3. Finds the best parameters for each channel
    try:
        params = [
            _fit(row, row, np_fit_func, [p0[p][i] for p in params_list])
            for i, row in enumerate(sorted_xdata)
        ]
    except (ValueError, RuntimeError) as e:
        # curve_fit raises ValueError for invalid data and RuntimeError when the
        # least-squares minimisation fails; both fall back to the initial guess.
        print(f"Could not fit the function with error: {e}")
        print("Using fallback result...")
        return {
            k: torch.tensor(v, dtype=torch.float32).to(x.device) for k, v in p0.items()
        }

    # 4. Builds the parameters
    return {
        p: torch.tensor([ch_params[i] for ch_params in params], dtype=torch.float32).to(x.device)
        for i, p in enumerate(params_list)
    }
290
+
291
+
292
def init_zeros(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor:
    """Return float32 zeros shaped like ``x`` with dimension 1 reduced away.

    Extra keyword arguments are accepted and ignored.
    """
    template = torch.amin(x, dim=1)  # only its shape is used
    return torch.zeros_like(template, dtype=torch.float32, device=x.device)
295
+
296
+
297
def init_inner_scale(tensor: torch.Tensor, _min: float = torch.inf, _max: float = torch.inf) -> torch.Tensor:
    """Return the factor that rescales ``tensor``'s range to ``[_min, _max]``.

    The observed range along the last dimension is first widened to include
    zero. When ``_max`` is infinite no scaling is needed and ones are returned.

    Args:
        tensor: Input tensor.
        _min: Lower end of the target range.
        _max: Upper end of the target range; infinity disables scaling.

    Returns:
        ``(_max - _min) / (x_max - x_min)`` per entry of the reduced tensor, or
        ones when ``_max`` is infinite.
    """
    # Calculate the original minimum and maximum values
    min_vals, max_vals = torch.aminmax(tensor, dim=-1)
    x_min = torch.min(min_vals, torch.zeros_like(min_vals))
    x_max = torch.max(max_vals, torch.zeros_like(max_vals))

    # Compare by value: an identity check only matches the exact ``torch.inf``
    # object, so ``float("inf")`` would fall through and yield an infinite scale.
    if _max == torch.inf:  # We do not need to scale the tensor. Just need to move it
        return torch.ones_like(x_min)

    # Calculate the scale factor
    scale = (_max - _min) / (x_max - x_min)
    return scale
309
+
310
+
311
+
312
+ ############## Quant ###############
313
+
314
@torch.enable_grad()
def learn_parameters(
    x: torch.Tensor,
    params: Dict[str, nn.Parameter],
    qtz_func: nn.Module,
    deqtz_func: nn.Module,
    bits: int,
    target_dtype: torch.dtype,
    epochs: int = 1000,
    early_stop: bool = True,
    do_report: bool = False
) -> "Dict[str, nn.Parameter] | Tuple[Dict[str, nn.Parameter], List[float]]":
    """Refine ``params`` with Adam to minimise the quantize/dequantize MSE.

    Args:
        x: Tensor to reconstruct.
        params: Parameters to optimise; ``requires_grad`` is switched on for
            them in place.
        qtz_func: Quantization function passed to ``quantize``.
        deqtz_func: Dequantization function passed to ``dequantize``.
        bits: Bit width of the quantized values.
        target_dtype: Dtype of the quantized tensor.
        epochs: Maximum number of optimisation steps.
        early_stop: Stop once the loss changed by less than ``1e-7`` over the
            last 10% of ``epochs``.
        do_report: Print progress and also return the loss history.

    Returns:
        A deep copy of the parameters that achieved the lowest loss (with
        ``requires_grad=False``). When ``do_report`` is true, a tuple of those
        parameters and the list of per-epoch losses.

    Raises:
        ValueError: If the initial loss is NaN or infinite.
        Exception: If the loss becomes NaN or infinite during training.
    """
    loss_fn = nn.MSELoss()

    def _reconstruction_loss() -> torch.Tensor:
        """Return the MSE between ``x`` and its quantize/dequantize round trip."""
        quant = quantize(x, params, qtz_func, bits, target_dtype)
        dequant = dequantize(quant, params, deqtz_func, bits, x.dtype)
        return loss_fn(x, dequant)

    # The learning rate is 0.1 * 10 ** (floor(log10(initial loss)) // 2).
    initial_loss = _reconstruction_loss().item()
    if not np.isfinite(initial_loss):
        raise ValueError("Initial loss is NaN or Inf. Cannot derive a learning rate.")

    base_lr = 0.1
    # log10(0) is -inf and cannot be converted to int, so an exactly-zero loss
    # is floored at the smallest positive normal float32.
    safe_loss = max(initial_loss, float(np.finfo(np.float32).tiny))
    exponent = int(np.floor(np.log10(safe_loss)))
    lr = base_lr * (10 ** (exponent // 2))

    # Requires gradients in the parameters
    for p in params.values():
        p.requires_grad = True
        p.grad = None

    optimizer = torch.optim.Adam(list(params.values()), lr=lr)

    # Start from the incoming values so that ``epochs=0`` still returns
    # parameters and an update is only kept when it strictly improves the loss.
    best_loss = initial_loss
    best_params = copy.deepcopy(dict(params))

    # Used to stop the search early
    min_delta = 1e-7
    acc_loss = []
    percent_epochs_before_stop = 0.1
    epochs_before_stop = int(epochs * percent_epochs_before_stop)

    for i in range(epochs):
        optimizer.zero_grad()

        loss = _reconstruction_loss()
        if loss.isnan() or loss.isinf():
            raise Exception("Loss is NaN or Inf. Stopping the search.")
        loss_value = loss.item()

        # Snapshot BEFORE the optimizer step: ``loss`` was measured with the
        # current values, and stepping first would pair it with the next ones.
        if loss_value < best_loss:
            best_loss = loss_value
            best_params = copy.deepcopy(dict(params))

        loss.backward()
        optimizer.step()

        acc_loss.append(loss_value)

        # Reports loss every 10 steps
        if i % 10 == 0 and do_report:
            print(f"Epoch {i}: Loss {loss.item()}")

        # Stop if the loss has not changed considerably during the last 10% of epochs
        if early_stop and i > epochs_before_stop:
            if abs(acc_loss[i - epochs_before_stop] - acc_loss[i]) < min_delta:
                break

    # No longer requires gradients in the parameters
    for p in best_params.values():
        p.requires_grad = False
        p.grad = None

    if do_report:
        print(f"Best loss: {best_loss}")
        return best_params, acc_loss
    return best_params
400
+
401
+
402
def quantize(
    x: torch.Tensor,
    params: Dict[str, nn.Parameter],
    func: nn.Module,
    bits: int,
    target_dtype: torch.dtype = torch.int8
) -> torch.Tensor:
    """Apply ``func``, round, clamp to the signed ``bits`` range and cast.

    Args:
        x: Tensor to quantize; dims 0 and 1 are swapped around the ``func`` call.
        params: Keyword parameters forwarded to ``func``.
        func: Quantization function called as ``func(x=..., **params)``.
        bits: Bit width that determines the clamping range.
        target_dtype: Dtype of the returned tensor.

    Returns:
        The quantized tensor in ``target_dtype``.
    """
    lower, upper = get_min_max_from_bits_signed(bits)
    # func runs on the transposed layout (the original comment says this aligns shapes)
    transformed = func(x=x.transpose(0, 1), **params).transpose(0, 1)
    # Rounding that lets gradients pass through unchanged
    rounded = round_func_BPDA(transformed)
    return torch.clamp(rounded, lower, upper).to(target_dtype)
415
+
416
+
417
def dequantize(
    x: torch.Tensor,
    params: Dict[str, nn.Parameter],
    func: nn.Module,
    bits: int,
    out_dtype: torch.dtype
) -> torch.Tensor:
    """Cast ``x`` to ``out_dtype`` and apply ``func`` on the transposed layout.

    Args:
        x: Quantized tensor.
        params: Keyword parameters forwarded to ``func``.
        func: Dequantization function called as ``func(x=..., **params)``.
        bits: Not used by this function.
        out_dtype: Dtype ``x`` is cast to before ``func`` runs.

    Returns:
        The output of ``func`` with dims 0 and 1 swapped back.
    """
    as_float = x.to(dtype=out_dtype)
    # Same transpose / apply / transpose-back pattern as in quantize
    return func(x=as_float.transpose(0, 1), **params).transpose(0, 1)
429
+
430
+
431
def round_func_BPDA(input):
    """Round ``input`` in the forward pass while keeping an identity gradient.

    ``torch.round`` is non-differentiable, so the rounded values are written
    into a clone of ``input``: the clone keeps its place in the autograd graph
    and gradients flow back to ``input`` unchanged.
    """
    result = input.clone()
    # Swap only the stored values; the autograd history of the clone stays intact
    result.data = torch.round(input).data
    return result
438
+
439
+
440
def get_min_max_from_bits_signed(bit_width: int) -> Tuple[int, int]:
    """Return ``(min, max)`` of a signed two's-complement integer with ``bit_width`` bits."""
    half_range = 2 ** (bit_width - 1)
    return -half_range, half_range - 1
442
+
443
+
444
+
445
+ ############## Numpy ###############
446
+
447
def np_domain_guard(
    x: np.ndarray,
    min: float = None,
    max: float = None,
    posinf: float = None,
    neginf: float = None,
    nan: float = None
) -> np.ndarray:
    """Guard an array to a valid domain.

    Args:
        x: Input array.
        min: Optional lower clipping bound.
        max: Optional upper clipping bound.
        posinf: Replacement for ``+inf`` (``None`` keeps NumPy's default).
        neginf: Replacement for ``-inf`` (``None`` keeps NumPy's default).
        nan: Replacement for NaN; ``None`` means ``0.0``.

    Returns:
        The array with non-finite values replaced and, when a bound is given,
        clipped to ``[min, max]``.
    """
    # np.nan_to_num needs a numeric `nan`: forwarding the default None raised a
    # TypeError on float input. Map it to 0.0 so the default call works.
    nan_value = 0.0 if nan is None else nan
    x = np.nan_to_num(x, posinf=posinf, neginf=neginf, nan=nan_value)
    if min is not None or max is not None:
        x = np.clip(x, min, max)
    return x
460
+
461
+
462
def np_replace_num(x: np.ndarray, num: float, to: float) -> np.ndarray:
    """Swap every occurrence of one value in an array for another.

    Args:
        x (np.ndarray): The input array.
        num (float): The value to look for.
        to (float): The value written where ``num`` was found.

    Returns:
        np.ndarray: A new array with the matches replaced.
    """
    matches = x == num
    return np.where(matches, to, x)
474
+
475
+
476
def np_guarded_power(x: np.ndarray, exp: float) -> np.ndarray:
    """Raise ``x`` to ``exp``, clamping negative bases to zero when ``exp < 1``."""
    if exp >= 1:
        return np.power(x, exp)
    # Exponents below 1 get a non-negative base
    non_negative = np.maximum(x, 0)
    return np.power(non_negative, exp)
479
+
fn_gen/nlr_t_no_sched/18/loss.png ADDED
fn_gen/nlr_t_no_sched/18/quantization.png ADDED
fn_gen/nlr_t_no_sched/2/distortion.png ADDED
fn_gen/nlr_t_no_sched/2/expressions.txt ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ asin(_0*x)/_s
2
+ sin(_s*x)/_0
fn_gen/nlr_t_no_sched/2/fn.py ADDED
@@ -0,0 +1,479 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import torch
4
+ from torch import amin # Necessary for arcsin
5
+ import copy
6
+ import torch.nn as nn
7
+ import numpy as np
8
+
9
+ from scipy.optimize import curve_fit
10
+ from typing import Dict, Any, Tuple, List, Callable
11
+
12
+
13
def quantization(x, **params):
    """Compute ``asin(_0 * x) / _s`` with guarded inputs.

    ``params`` must contain ``'_0'`` and ``'_s'``. Zeros in ``_s`` are replaced
    by 10000 before dividing, and the ``asin`` argument is clamped to
    ``[-0.99999, 0.99999]`` with NaNs mapped to 0.
    """
    safe_scale = replace_num(params['_s'], num=0, to=10000)
    inverse_scale = torch.div(1, safe_scale)
    asin_arg = domain_guard((params['_0'] * x), min=-0.99999, max=0.99999, nan=0)
    return inverse_scale * torch.asin(asin_arg)
15
+
16
+
17
def dequantization(x, **params):
    """Compute ``sin(_s * x) / _0``.

    ``params`` must contain ``'_0'`` and ``'_s'``. Zeros in ``_0`` are replaced
    by 10000 before dividing.
    """
    safe_divisor = replace_num(params['_0'], num=0, to=10000)
    inverse_divisor = torch.div(1, safe_divisor)
    return inverse_divisor * torch.sin(params['_s'] * x)
19
+
20
+
21
def init_params(x: torch.Tensor, **kwargs: Dict[str, Any]) -> Dict[str, nn.Parameter]:
    """Initialize and train the ``_0`` / ``_s`` parameters for ``x``.

    Three stages run in order: a random search for ``_0`` plus a linear scale
    for ``_s``, a non-linear regression fit starting from those values, and a
    gradient-based refinement. After each stage the matching optional callback
    from ``kwargs`` (``post_init_hook``, ``post_method_hook``,
    ``post_train_hook``) is called with ``parameters=<current dict>``.

    Args:
        x: Tensor the parameters are fitted to.
        **kwargs: Forwarded to the initializers; ``'bits'`` is required.

    Returns:
        The parameters returned by ``learn_parameters``.
    """

    def _run_hook(hook_name, parameters):
        # Hooks are optional; a missing one is simply skipped
        if hook_name in kwargs:
            kwargs[hook_name](parameters=parameters)

    # Stage 1: search `_0`, then derive `_s` from it
    initial_0 = init_space_search(
        x,
        qtz_func=quantization,
        deqtz_func=dequantization,
        params_list=['_0', '_s'],
        param='_0',
        **kwargs,
    )
    base_p0 = {'_0': initial_0}
    base_p0['_s'] = init_linear_scale(x, qtz_func=quantization, params=base_p0, **kwargs)
    _run_hook('post_init_hook', base_p0)

    # Stage 2: regression fit, then wrap the results as frozen parameters
    fitted = init_non_linear_regression_fit(
        x,
        p0=base_p0,
        np_fit_func=fit_func,
        qtz_func=quantization,
        deqtz_func=dequantization,
        params_list=['_0', '_s'],
        **kwargs,
    )
    params = {name: nn.Parameter(value, requires_grad=False) for name, value in fitted.items()}
    _run_hook('post_method_hook', params)

    # Stage 3: gradient-based refinement
    params = learn_parameters(
        x,
        params,
        qtz_func=quantization,
        deqtz_func=dequantization,
        bits=kwargs['bits'],
        target_dtype=torch.int8,
        epochs=500,
        early_stop=False,
    )
    _run_hook('post_train_hook', params)

    return params
47
+
48
+
49
+ ############### Numpy Qtz ###############
50
+
51
+
52
def np_quantization(x, _0, _s):
    """NumPy version of ``quantization``: ``arcsin(_0 * x) / _s`` with guarded inputs."""
    safe_scale = np_replace_num(_s, num=0, to=10000)
    inverse_scale = np.divide(1, safe_scale)
    arcsin_arg = np_domain_guard((_0 * x), min=-0.99999, max=0.99999, nan=0)
    return inverse_scale * np.arcsin(arcsin_arg)
54
+
55
+
56
def np_dequantization(x, _0, _s):
    """NumPy version of ``dequantization``: ``sin(_s * x) / _0`` with a guarded divisor."""
    safe_divisor = np_replace_num(_0, num=0, to=10000)
    inverse_divisor = np.divide(1, safe_divisor)
    return inverse_divisor * np.sin(_s * x)
58
+
59
+
60
def fit_func(x, _0, _s):
    """Quantize then dequantize ``x`` with NumPy; this is the model handed to the curve fit."""
    return np_dequantization(np_quantization(x, _0, _s), _0, _s)
64
+
65
+
66
+
67
+ ############### HELPERS ###############
68
+
69
def domain_guard(
    x: torch.Tensor,
    min: float = None,
    max: float = None,
    posinf: float = None,
    neginf: float = None,
    nan: float = None
) -> torch.Tensor:
    """Replace non-finite values in ``x`` and optionally clamp it.

    Args:
        x: Input tensor.
        min: Optional lower clamping bound.
        max: Optional upper clamping bound.
        posinf: Replacement for ``+inf``, forwarded to ``torch.nan_to_num``.
        neginf: Replacement for ``-inf``, forwarded to ``torch.nan_to_num``.
        nan: Replacement for NaN, forwarded to ``torch.nan_to_num``.

    Returns:
        The guarded tensor.
    """
    guarded = torch.nan_to_num(x, nan=nan, posinf=posinf, neginf=neginf)
    # Clamping only happens when at least one bound was supplied
    if min is None and max is None:
        return guarded
    return torch.clamp(guarded, min=min, max=max)
82
+
83
+
84
def replace_num(x: torch.Tensor, num: float, to: float) -> torch.Tensor:
    """Swap every occurrence of one value in a tensor for another.

    Args:
        x (torch.Tensor): The input tensor.
        num (float): The value to look for.
        to (float): The value written where ``num`` was found.

    Returns:
        torch.Tensor: A new tensor with the matches replaced.
    """
    matches = x == num
    return torch.where(matches, to, x)
96
+
97
+
98
def guarded_torch_power(x: torch.Tensor, exp: float) -> torch.Tensor:
    """Raise ``x`` to ``exp``, passing the base through ReLU first when ``exp < 1``."""
    if exp >= 1:
        return torch.pow(x, exp)
    # Exponents below 1 get a non-negative base
    non_negative = torch.relu(x)
    return torch.pow(non_negative, exp)
101
+
102
+
103
def init_ones(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor:
    """Return a float32 tensor of ones shaped like ``x`` with dim 1 reduced away.

    ``kwargs`` is accepted for signature compatibility and ignored.
    """
    # The reduction is only used as a shape template for the result
    reduced = x.amin(dim=1)
    return torch.ones_like(reduced, device=x.device, dtype=torch.float32)
106
+
107
+
108
def init_rand(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor:
    """Return float32 standard-normal samples shaped like ``x`` with dim 1 reduced away.

    ``kwargs`` is accepted for signature compatibility and ignored.
    """
    # The reduction is only used as a shape template for the result
    reduced = x.amin(dim=1)
    return torch.randn_like(reduced, device=x.device, dtype=torch.float32)
111
+
112
+
113
def init_space_search(
    x: torch.Tensor,
    **kwargs: Dict[str, Any],
) -> torch.Tensor:
    """Randomly search for a good initial value of a single parameter.

    Candidates for ``param`` are scored by the MSE between ``x`` and its
    quantize/dequantize round trip, with every other parameter fixed to ones.
    Each round keeps the best few candidates and samples the next round around
    their mean. The final winner is also compared against an all-ones and a
    fresh random initialization.

    Args:
        x: Tensor the parameter is fitted to.
        **kwargs: Must contain ``qtz_func``, ``deqtz_func``, ``params_list``
            (names of all parameters) and ``param`` (the name being searched).

    Returns:
        The candidate tensor with the lowest reconstruction loss.
    """

    def _build_initial_param(tensor: torch.Tensor, max_initial: int, n_params: int):
        """Generates the initial set of parameters. The first iteration generates 10 times more parameters."""
        for _ in range(n_params * 10):  # The first iteration generates 10 times more parameters
            # Standard-normal samples scaled by max_initial (not a bounded range)
            yield init_rand(tensor) * max_initial

    def _search_param(tensors: List[torch.tensor], n_params):
        """Takes the best parameters and generates new parameters around the mean of the best parameters."""
        torch_tensors = torch.stack(tensors)
        min_vals, max_vals = torch.aminmax(torch_tensors, dim=0)
        abs_max_val_per_ch = torch.max(-min_vals, max_vals)
        mean = torch.mean(torch_tensors, dim=0)
        for _ in range(n_params):  # Generates n_params around the mean of the tensors
            yield torch.randn_like(min_vals) * abs_max_val_per_ch + mean

    def _calc(x, qtz_func, deqtz_func, **params):
        """Round trip of x through qtz_func and deqtz_func on the transposed layout."""
        x_ = x.transpose(0, 1)
        x_ = qtz_func(x=x_, **params)
        x_ = deqtz_func(x=x_, **params)
        x_ = x_.transpose(0, 1)
        return x_

    assert "qtz_func" in kwargs, "qtz_func must be provided."
    assert "deqtz_func" in kwargs, "deqtz_func must be provided."
    assert "params_list" in kwargs, "params list must be provided."
    assert "param" in kwargs, "param must be provided."

    qtz_func = kwargs.get('qtz_func')
    deqtz_func = kwargs.get('deqtz_func')
    params_list = kwargs.get('params_list')
    param = kwargs.get('param')

    n_runs = 50  # Number of runs to try to find the best parameters
    n_random_params = 50  # Number of random parameters to generate
    n_best_to_pick = 5  # Number of best parameters to pick after each run
    max_initial = 10000  # Maximum value to initialize the parameters

    # Initializes the parameters
    base_params = {p: init_ones(x, **kwargs) for p in params_list if p != param}
    candidates = _build_initial_param(x, max_initial, n_random_params)

    # Best candidates of the most recent round that produced any valid one
    best_params = []

    # Performs the search
    for _ in range(n_runs):

        run_best = []
        for param_ in candidates:
            try:
                x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: param_})
                loss_ = nn.MSELoss()(x, x_)
            except Exception:  # The parameters might not be valid for the function's domain
                continue

            # A NaN/inf loss cannot be ranked (NaN makes the sort unreliable),
            # so such candidates are treated like invalid ones
            if not torch.isfinite(loss_):
                continue
            loss_value = loss_.item()

            if len(run_best) < n_best_to_pick:
                run_best.append((param_, loss_value))
            elif loss_value < run_best[-1][1]:
                run_best[-1] = (param_, loss_value)
            else:
                continue
            run_best.sort(key=lambda entry: entry[1])

        if run_best:
            best_params = run_best

        if best_params:
            # Generates new parameters around the mean
            candidates = _search_param([p for p, _ in best_params], n_random_params)
        else:
            # Nothing valid so far: stacking an empty list would crash, so
            # start over from a fresh initial population
            candidates = _build_initial_param(x, max_initial, n_random_params)

    # Checks if the best parameter is better than the init_ones
    p_ones = init_ones(x, **kwargs)
    x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: p_ones})
    loss_ones = nn.MSELoss()(x, x_)

    # Checks if the best parameter is better than the init_rand
    p_rand = init_rand(x, **kwargs)
    x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: p_rand})
    loss_rand = nn.MSELoss()(x, x_)

    best_loss = best_params[0][1] if best_params else float("inf")

    if loss_rand < best_loss and loss_rand < loss_ones:
        return p_rand
    elif loss_ones < best_loss and loss_ones < loss_rand:
        return p_ones
    elif best_params:
        return best_params[0][0]
    else:
        # The search never produced a usable candidate
        return p_ones
196
+
197
+
198
def init_linear_scale(  # Symmetric scale. From the study folder
    x: torch.Tensor,
    **kwargs: Dict[str, Any],
) -> torch.Tensor:
    """Derive a symmetric linear scale from the output range of ``qtz_func``.

    ``qtz_func`` is evaluated with ``_s`` set to ones. The largest absolute
    output per slice of dim 1 (with zero always included in the range) is then
    divided by half of the signed integer span for ``bits``.

    Args:
        x: Tensor to derive the scale from.
        **kwargs: Must contain ``bits``, ``params`` (the parameters already
            initialized, without ``_s``) and ``qtz_func``.

    Returns:
        A float32 scale tensor, clamped from below to float32 epsilon.
    """
    for required in ("bits", "params", "qtz_func"):
        assert required in kwargs, f"{required} must be provided."

    bits = kwargs.get('bits')
    params = kwargs.get('params')
    qtz_func = kwargs.get('qtz_func')

    # Evaluate the quantization function on the transposed layout with a unit scale
    unit_scale = init_ones(x, **kwargs)
    transformed = qtz_func(x=x.transpose(0, 1), **params, _s=unit_scale).transpose(0, 1)

    quant_min, quant_max = get_min_max_from_bits_signed(bits)
    lowest, highest = torch.aminmax(transformed, dim=1)
    # Widen the observed range so that it always contains zero
    lowest = torch.minimum(lowest, torch.zeros_like(lowest))
    highest = torch.maximum(highest, torch.zeros_like(highest))

    eps = torch.finfo(torch.float32).eps

    largest_magnitude = torch.maximum(-lowest, highest)
    half_span = float(quant_max - quant_min) / 2
    scale = largest_magnitude / half_span

    # NOTE(diogo): adding noise to the scale was tried and did not help
    # learning, so no noise is applied here.
    return torch.clamp(scale, min=eps).to(dtype=torch.float32, device=lowest.device)
234
+
235
+
236
def init_non_linear_regression_fit(
    x: torch.Tensor,
    **kwargs: Dict[str, Any],
) -> torch.Tensor:
    """Refine initial parameters with a least-squares fit per row of ``x``.

    Each row of ``x`` is sorted and ``np_fit_func`` is fitted so that it
    reproduces the row from itself, starting from that row's entry in ``p0``.

    Args:
        x: Tensor whose rows (dim 0) are fitted independently.
        **kwargs: Must contain ``params_list`` (parameter names), ``np_fit_func``
            (NumPy model ``f(x, *params)`` whose arguments follow the sorted
            parameter names) and ``p0`` (dict of initial tensors, one value
            per row).

    Returns:
        A dict mapping each parameter name to a float32 tensor on ``x``'s
        device. If the fit fails, the initial values from ``p0`` are returned.
    """

    assert "params_list" in kwargs, "params list must be provided."
    assert "np_fit_func" in kwargs, "np_fit_func must be provided."
    assert "p0" in kwargs, "p0 must be provided."
    np_fit_func = kwargs.get('np_fit_func')
    params_list = kwargs.get('params_list')
    p0 = kwargs.get('p0')

    def _fit(xdata: np.ndarray, ydata: np.ndarray, func: Callable, p0: List[float]):
        """Levenberg-Marquardt fit of func; returns the optimal parameters."""
        popt, _ = curve_fit(
            func,
            xdata,
            ydata,
            maxfev=1000,
            p0=p0,
            method='lm'
        )
        return popt

    # 1. Needs to convert the torch tensor to numpy tensor
    #    (detach first: .numpy() refuses tensors that require grad)
    xdata = x.detach().cpu().numpy()

    # 2. Sorts the data so that it makes it easier to fit to it
    sorted_xdata = np.sort(xdata, axis=-1)

    p0 = {k: v.detach().cpu().numpy() for k, v in p0.items()}
    params_list = sorted(params_list)  # We need to make sure that it matches the numpy fit func arg order

    # 3. Finds the best parameters for each channel
    try:
        params = [
            _fit(row, row, np_fit_func, [p0[p][i] for p in params_list])
            for i, row in enumerate(sorted_xdata)
        ]

        # 4. Builds the parameters
        return {
            p: torch.tensor([ch_params[j] for ch_params in params], dtype=torch.float32).to(x.device)
            for j, p in enumerate(params_list)
        }

    # curve_fit raises RuntimeError when the minimization does not converge
    # and ValueError for invalid data; both must use the fallback.
    except (ValueError, RuntimeError) as e:
        print(f"Could not fit the function with error: {e}")
        print(f"Using fallback result...")
        return {
            k: torch.tensor(v, dtype=torch.float32).to(x.device) for k, v in p0.items()
        }
290
+
291
+
292
def init_zeros(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor:
    """Return a float32 tensor of zeros shaped like ``x`` with dim 1 reduced away.

    Args:
        x: Input tensor; it must have a dim 1, because that dim is reduced.
        **kwargs: Accepted for signature compatibility and ignored.

    Returns:
        Zeros on ``x``'s device with the shape of ``torch.amin(x, dim=1)``.
    """
    # The reduction is only used as a shape template for the result
    reduced = x.amin(dim=1)
    return torch.zeros_like(reduced, device=x.device, dtype=torch.float32)
295
+
296
+
297
def init_inner_scale(tensor: torch.Tensor, _min: float = torch.inf, _max: float = torch.inf) -> torch.Tensor:
    """Compute a scale that maps the zero-including range of ``tensor`` onto ``_max - _min``.

    Args:
        tensor: Input tensor; the min and max are taken over its last dim.
        _min: Lower bound of the target range.
        _max: Upper bound of the target range. An infinite value means
            "no scaling" and yields a tensor of ones.

    Returns:
        ``(_max - _min) / (x_max - x_min)`` per slice, where the observed range
        is widened to include zero, or ones when ``_max`` is infinite.
    """
    # Calculate the original minimum and maximum values (always including zero)
    min_vals, max_vals = torch.aminmax(tensor, dim=-1)
    x_min = torch.min(min_vals, torch.zeros_like(min_vals))
    x_max = torch.max(max_vals, torch.zeros_like(max_vals))

    # Compare by value, not identity: `is` only matches the very same float
    # object, so an equal infinity such as float("inf") used to fall through
    # and produce an inf/NaN scale.
    if _max == torch.inf:  # We do not need to scale the tensor. Just need to move it
        return torch.ones_like(x_min)

    # Calculate the scale factor
    # NOTE(review): a slice that is entirely zero gives x_max == x_min and
    # therefore a division by zero here - confirm callers never pass one.
    scale = (_max - _min) / (x_max - x_min)
    return scale
309
+
310
+
311
+
312
+ ############## Quant ###############
313
+
314
@torch.enable_grad()
def learn_parameters(
    x: torch.Tensor,
    params: Dict[str, nn.Parameter],
    qtz_func: nn.Module,
    deqtz_func: nn.Module,
    bits: int,
    target_dtype: torch.dtype,
    epochs: int = 1000,
    early_stop: bool = True,
    do_report: bool = False
) -> Tuple[Dict[str, nn.Parameter], torch.Tensor]:
    """Fine-tune the quantization parameters with Adam on the reconstruction MSE.

    Each epoch quantizes ``x``, dequantizes the result and minimizes the mean
    squared error against ``x``. The parameters that produced the lowest loss
    are returned.

    Args:
        x: Tensor to reconstruct.
        params: Named parameters forwarded to ``qtz_func`` / ``deqtz_func``.
            They are switched to ``requires_grad=True`` and updated in place.
        qtz_func: Quantization function, called through ``quantize``.
        deqtz_func: Dequantization function, called through ``dequantize``.
        bits: Bit width of the quantized representation.
        target_dtype: Dtype of the quantized tensor.
        epochs: Maximum number of optimization steps.
        early_stop: Stop once the loss has changed by less than ``1e-7`` over
            the last 10% of ``epochs``.
        do_report: Print progress and also return the loss history.

    Returns:
        A deep copy of the best parameters with gradients disabled. When
        ``do_report`` is True a ``(best_params, acc_loss)`` tuple is returned
        instead, where ``acc_loss`` is the list of per-epoch loss values.

    Raises:
        Exception: If the loss becomes NaN or infinite during the search.
    """
    loss_fn = nn.MSELoss()

    # Determines the initial learning rate from the order of magnitude of the
    # initial loss (half of its base-10 exponent)
    quant = quantize(x, params, qtz_func, bits, target_dtype)
    dequant = dequantize(quant, params, deqtz_func, bits, x.dtype)
    initial_loss = loss_fn(x, dequant).item()

    base_lr = 0.1
    # log10 is only usable for positive finite values. A zero (perfect) or
    # non-finite initial loss used to crash in int(); fall back to base_lr and
    # let the loop below report non-finite losses explicitly.
    if 0 < initial_loss < float("inf"):
        exponent = int(np.floor(np.log10(initial_loss)))
    else:
        exponent = 0
    lr = base_lr * (10 ** (exponent // 2))

    # Requires gradients in the parameters
    for p in params.values():
        p.requires_grad = True
        p.grad = None

    # Defines the optimizer
    optimizer = torch.optim.Adam(list(params.values()), lr=lr)

    # Contains the best loss and the best parameters
    best_loss = float("inf")
    best_params = None

    # Used to stop the search early
    min_delta = 1e-7
    acc_loss = []
    percent_epochs_before_stop = 0.1
    epochs_before_stop = int(epochs * percent_epochs_before_stop)

    for i in range(epochs):
        optimizer.zero_grad()

        quant = quantize(x, params, qtz_func, bits, target_dtype)
        dequant = dequantize(quant, params, deqtz_func, bits, x.dtype)
        loss = loss_fn(x, dequant)

        if loss.isnan() or loss.isinf():
            raise Exception("Loss is NaN or Inf. Stopping the search.")

        loss_value = loss.item()

        # Snapshot BEFORE the optimizer step, so the stored parameters are the
        # ones that actually produced this loss (they used to be copied after
        # the step and therefore did not match the recorded best loss).
        if loss_value < best_loss:
            best_loss = loss_value
            best_params = copy.deepcopy(dict(params))

        loss.backward()
        optimizer.step()

        acc_loss.append(loss_value)

        # Reports loss every 10 steps
        if i % 10 == 0 and do_report:
            print(f"Epoch {i}: Loss {loss_value}")

        # We also stop the search if the loss has not changed considerably
        # during the last 10% of the epochs
        if early_stop:
            if i > epochs_before_stop and abs(acc_loss[i - epochs_before_stop] - acc_loss[i]) < min_delta:
                break

    # With epochs == 0 the loop never runs: return the untouched parameters
    if best_params is None:
        best_params = copy.deepcopy(dict(params))
        best_loss = initial_loss

    # No longer requires gradients in the parameters
    for p in best_params.values():
        p.requires_grad = False
        p.grad = None

    if do_report:
        print(f"Best loss: {best_loss}")
        return best_params, acc_loss
    else:
        return best_params
400
+
401
+
402
def quantize(
    x: torch.Tensor,
    params: Dict[str, nn.Parameter],
    func: nn.Module,
    bits: int,
    target_dtype: torch.dtype = torch.int8
) -> torch.Tensor:
    """Apply ``func``, round, clamp to the signed ``bits`` range and cast.

    Args:
        x: Tensor to quantize; dims 0 and 1 are swapped around the ``func`` call.
        params: Keyword parameters forwarded to ``func``.
        func: Quantization function called as ``func(x=..., **params)``.
        bits: Bit width that determines the clamping range.
        target_dtype: Dtype of the returned tensor.

    Returns:
        The quantized tensor in ``target_dtype``.
    """
    lower, upper = get_min_max_from_bits_signed(bits)
    # func runs on the transposed layout (the original comment says this aligns shapes)
    transformed = func(x=x.transpose(0, 1), **params).transpose(0, 1)
    # Rounding that lets gradients pass through unchanged
    rounded = round_func_BPDA(transformed)
    return torch.clamp(rounded, lower, upper).to(target_dtype)
415
+
416
+
417
def dequantize(
    x: torch.Tensor,
    params: Dict[str, nn.Parameter],
    func: nn.Module,
    bits: int,
    out_dtype: torch.dtype
) -> torch.Tensor:
    """Cast ``x`` to ``out_dtype`` and apply ``func`` on the transposed layout.

    Args:
        x: Quantized tensor.
        params: Keyword parameters forwarded to ``func``.
        func: Dequantization function called as ``func(x=..., **params)``.
        bits: Not used by this function.
        out_dtype: Dtype ``x`` is cast to before ``func`` runs.

    Returns:
        The output of ``func`` with dims 0 and 1 swapped back.
    """
    as_float = x.to(dtype=out_dtype)
    # Same transpose / apply / transpose-back pattern as in quantize
    return func(x=as_float.transpose(0, 1), **params).transpose(0, 1)
429
+
430
+
431
def round_func_BPDA(input):
    """Round ``input`` in the forward pass while keeping an identity gradient.

    ``torch.round`` is non-differentiable, so the rounded values are written
    into a clone of ``input``: the clone keeps its place in the autograd graph
    and gradients flow back to ``input`` unchanged.
    """
    result = input.clone()
    # Swap only the stored values; the autograd history of the clone stays intact
    result.data = torch.round(input).data
    return result
438
+
439
+
440
def get_min_max_from_bits_signed(bit_width: int) -> Tuple[int, int]:
    """Return ``(min, max)`` of a signed two's-complement integer with ``bit_width`` bits."""
    half_range = 2 ** (bit_width - 1)
    return -half_range, half_range - 1
442
+
443
+
444
+
445
+ ############## Numpy ###############
446
+
447
def np_domain_guard(
    x: np.ndarray,
    min: float = None,
    max: float = None,
    posinf: float = None,
    neginf: float = None,
    nan: float = None
) -> np.ndarray:
    """Guard an array to a valid domain.

    Args:
        x: Input array.
        min: Optional lower clipping bound.
        max: Optional upper clipping bound.
        posinf: Replacement for ``+inf`` (``None`` keeps NumPy's default).
        neginf: Replacement for ``-inf`` (``None`` keeps NumPy's default).
        nan: Replacement for NaN; ``None`` means ``0.0``.

    Returns:
        The array with non-finite values replaced and, when a bound is given,
        clipped to ``[min, max]``.
    """
    # np.nan_to_num needs a numeric `nan`: forwarding the default None raised a
    # TypeError on float input. Map it to 0.0 so the default call works.
    nan_value = 0.0 if nan is None else nan
    x = np.nan_to_num(x, posinf=posinf, neginf=neginf, nan=nan_value)
    if min is not None or max is not None:
        x = np.clip(x, min, max)
    return x
460
+
461
+
462
def np_replace_num(x: np.ndarray, num: float, to: float) -> np.ndarray:
    """Swap every occurrence of one value in an array for another.

    Args:
        x (np.ndarray): The input array.
        num (float): The value to look for.
        to (float): The value written where ``num`` was found.

    Returns:
        np.ndarray: A new array with the matches replaced.
    """
    matches = x == num
    return np.where(matches, to, x)
474
+
475
+
476
def np_guarded_power(x: np.ndarray, exp: float) -> np.ndarray:
    """Raise ``x`` to ``exp``, clamping negative bases to zero when ``exp < 1``."""
    if exp >= 1:
        return np.power(x, exp)
    # Exponents below 1 get a non-negative base
    non_negative = np.maximum(x, 0)
    return np.power(non_negative, exp)
479
+
fn_gen/nlr_t_no_sched/2/loss.png ADDED
fn_gen/nlr_t_no_sched/2/quantization.png ADDED