Diogo-V commited on
Commit
725ad36
1 Parent(s): d91308b

Upload learned functions

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. fn_gen/ones_affine_scale/0/distortion.png +0 -0
  2. fn_gen/ones_affine_scale/0/expressions.txt +2 -0
  3. fn_gen/ones_affine_scale/0/fn.py +493 -0
  4. fn_gen/ones_affine_scale/0/loss.png +0 -0
  5. fn_gen/ones_affine_scale/0/quantization.png +0 -0
  6. fn_gen/ones_affine_scale/1/distortion.png +0 -0
  7. fn_gen/ones_affine_scale/1/expressions.txt +2 -0
  8. fn_gen/ones_affine_scale/1/fn.py +493 -0
  9. fn_gen/ones_affine_scale/1/loss.png +0 -0
  10. fn_gen/ones_affine_scale/1/quantization.png +0 -0
  11. fn_gen/ones_affine_scale/10/distortion.png +0 -0
  12. fn_gen/ones_affine_scale/10/expressions.txt +2 -0
  13. fn_gen/ones_affine_scale/10/fn.py +492 -0
  14. fn_gen/ones_affine_scale/10/loss.png +0 -0
  15. fn_gen/ones_affine_scale/10/quantization.png +0 -0
  16. fn_gen/ones_affine_scale/11/distortion.png +0 -0
  17. fn_gen/ones_affine_scale/11/expressions.txt +2 -0
  18. fn_gen/ones_affine_scale/11/fn.py +493 -0
  19. fn_gen/ones_affine_scale/11/loss.png +0 -0
  20. fn_gen/ones_affine_scale/11/quantization.png +0 -0
  21. fn_gen/ones_affine_scale/12/distortion.png +0 -0
  22. fn_gen/ones_affine_scale/12/expressions.txt +2 -0
  23. fn_gen/ones_affine_scale/12/fn.py +493 -0
  24. fn_gen/ones_affine_scale/12/loss.png +0 -0
  25. fn_gen/ones_affine_scale/12/quantization.png +0 -0
  26. fn_gen/ones_affine_scale/13/distortion.png +0 -0
  27. fn_gen/ones_affine_scale/13/expressions.txt +2 -0
  28. fn_gen/ones_affine_scale/13/fn.py +493 -0
  29. fn_gen/ones_affine_scale/13/loss.png +0 -0
  30. fn_gen/ones_affine_scale/13/quantization.png +0 -0
  31. fn_gen/ones_affine_scale/14/distortion.png +0 -0
  32. fn_gen/ones_affine_scale/14/expressions.txt +2 -0
  33. fn_gen/ones_affine_scale/14/fn.py +493 -0
  34. fn_gen/ones_affine_scale/14/loss.png +0 -0
  35. fn_gen/ones_affine_scale/14/quantization.png +0 -0
  36. fn_gen/ones_affine_scale/15/distortion.png +0 -0
  37. fn_gen/ones_affine_scale/15/expressions.txt +2 -0
  38. fn_gen/ones_affine_scale/15/fn.py +493 -0
  39. fn_gen/ones_affine_scale/15/loss.png +0 -0
  40. fn_gen/ones_affine_scale/15/quantization.png +0 -0
  41. fn_gen/ones_affine_scale/16/distortion.png +0 -0
  42. fn_gen/ones_affine_scale/16/expressions.txt +2 -0
  43. fn_gen/ones_affine_scale/16/fn.py +493 -0
  44. fn_gen/ones_affine_scale/16/loss.png +0 -0
  45. fn_gen/ones_affine_scale/16/quantization.png +0 -0
  46. fn_gen/ones_affine_scale/17/distortion.png +0 -0
  47. fn_gen/ones_affine_scale/17/expressions.txt +2 -0
  48. fn_gen/ones_affine_scale/17/fn.py +493 -0
  49. fn_gen/ones_affine_scale/17/loss.png +0 -0
  50. fn_gen/ones_affine_scale/17/quantization.png +0 -0
fn_gen/ones_affine_scale/0/distortion.png ADDED
fn_gen/ones_affine_scale/0/expressions.txt ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ atanh(_0*x)/_s
2
+ tanh(_s*x)/_0
fn_gen/ones_affine_scale/0/fn.py ADDED
@@ -0,0 +1,493 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import torch
4
+ from torch import amin # Necessary for arcsin
5
+ import copy
6
+ import torch.nn as nn
7
+ import numpy as np
8
+
9
+ from scipy.optimize import curve_fit
10
+ from typing import Dict, Any, Tuple, List, Callable
11
+
12
+
13
+ def quantization(x, **params):
14
+ return (torch.div(1, replace_num(params['_s'], num=0, to=10000)) * torch.atanh(domain_guard((params['_0'] * x), min=-0.9999, max=0.9999, nan=0)))
15
+
16
+
17
+ def dequantization(x, **params):
18
+ return (torch.div(1, replace_num(params['_0'], num=0, to=10000)) * torch.tanh((params['_s'] * x)))
19
+
20
+
21
+ def init_linear_scale( # Symmetric scale. From the study folder
22
+ x: torch.Tensor,
23
+ **kwargs: Dict[str, Any],
24
+ ) -> torch.Tensor:
25
+ assert "bits" in kwargs, "bits must be provided."
26
+ assert "params" in kwargs, "params must be provided."
27
+ assert "qtz_func" in kwargs, "qtz_func must be provided."
28
+
29
+ bits = kwargs.get('bits')
30
+ params = kwargs.get('params')
31
+ qtz_func = kwargs.get('qtz_func')
32
+
33
+ x_ = x.transpose(0, 1)
34
+ x_ = qtz_func(x=x_, **params, _s=init_ones(x, **kwargs))
35
+ x_ = x_.transpose(0, 1)
36
+
37
+ quant_min, quant_max = get_min_max_from_bits_signed(bits)
38
+ min_vals, max_vals = torch.aminmax(x_, dim=1)
39
+ min_vals = torch.min(min_vals, torch.zeros_like(min_vals))
40
+ max_vals = torch.max(max_vals, torch.zeros_like(max_vals))
41
+
42
+ eps = torch.finfo(torch.float32).eps
43
+
44
+ abs_max_val_per_ch = torch.max(-min_vals, max_vals)
45
+ scale = abs_max_val_per_ch / (float(quant_max - quant_min) / 2)
46
+
47
+ scale = torch.clamp(scale, min=eps).to(dtype=torch.float32, device=min_vals.device)
48
+
49
+ return scale
50
+
51
+ # Introduces some noise in scale
52
+ # If I don't introduce noise, the accuracy is going to be 0.0 and not learn anything
53
+ scale = scale + 0.01 * torch.randn_like(scale)
54
+ return scale
55
+
56
+
57
+ def init_params(x: torch.Tensor, **kwargs: Dict[str, Any]) -> Dict[str, nn.Parameter]:
58
+ params = {
59
+ '_0': init_ones(x, qtz_func=quantization, deqtz_func=dequantization, param='_0', params_list=['_0', '_s'], **kwargs),
60
+ }
61
+ params['_s'] = init_linear_scale(x, params=params, qtz_func=quantization, **kwargs)
62
+ params = {k: nn.Parameter(v, requires_grad=False) for k, v in params.items()}
63
+
64
+ if 'post_init_hook' in kwargs:
65
+ kwargs['post_init_hook'](parameters=params)
66
+
67
+
68
+ if 'post_train_hook' in kwargs:
69
+ kwargs['post_train_hook'](parameters=params)
70
+
71
+ return params
72
+
73
+
74
+ ############### Numpy Qtz ###############
75
+
76
+
77
+ def np_quantization(x, _0, _s):
78
+ return (np.divide(1, np_replace_num(_s, num=0, to=10000)) * np.arctanh(np_domain_guard((_0 * x), min=-0.9999, max=0.9999, nan=0)))
79
+
80
+
81
+ def np_dequantization(x, _0, _s):
82
+ return (np.divide(1, np_replace_num(_0, num=0, to=10000)) * np.tanh((_s * x)))
83
+
84
+
85
+ def fit_func(x, _0, _s):
86
+ x_ = np_quantization(x, _0, _s)
87
+ x_ = np_dequantization(x_, _0, _s)
88
+ return x_
89
+
90
+
91
+
92
+ ############### HELPERS ###############
93
+
94
+ def domain_guard(
95
+ x: torch.Tensor,
96
+ min: float = None,
97
+ max: float = None,
98
+ posinf: float = None,
99
+ neginf: float = None,
100
+ nan: float = None
101
+ ) -> torch.Tensor:
102
+ """Guard a tensor to a valid domain."""
103
+ x = torch.nan_to_num(x, posinf=posinf, neginf=neginf, nan=nan)
104
+ if min is not None or max is not None:
105
+ x = torch.clamp(x, min=min, max=max)
106
+ return x
107
+
108
+
109
+ def replace_num(x: torch.Tensor, num: float, to: float) -> torch.Tensor:
110
+ """Replace a number in a tensor with another number.
111
+
112
+ Args:
113
+ x (torch.Tensor): The input tensor.
114
+ num (float): The number to replace.
115
+ to (float): The number to replace with.
116
+
117
+ Returns:
118
+ torch.Tensor: The tensor with the number replaced.
119
+ """
120
+ return torch.where(x == num, to, x)
121
+
122
+
123
+ def guarded_torch_power(x: torch.Tensor, exp: float) -> torch.Tensor:
124
+ """Guard the power operation to a valid domain."""
125
+ return torch.pow(x, exp) if exp >= 1 else torch.pow(torch.relu(x), exp)
126
+
127
+
128
+ def init_ones(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor:
129
+ val = torch.amin(x, dim=1)
130
+ return torch.ones_like(val, dtype=torch.float32, device=x.device)
131
+
132
+
133
+ def init_rand(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor:
134
+ val = torch.amin(x, dim=1)
135
+ return torch.randn_like(val, dtype=torch.float32, device=x.device)
136
+
137
+
138
+ def init_space_search(
139
+ x: torch.Tensor,
140
+ **kwargs: Dict[str, Any],
141
+ ) -> torch.Tensor:
142
+
143
+ def _build_initial_param(tensor: torch.Tensor, max_initial: int, n_params: int):
144
+ """Generates the initial set of parameters. The first iteration generates 10 times more parameters."""
145
+ for _ in range(n_params * 10): # The first iteration generates 10 times more parameters
146
+ yield init_rand(tensor) * max_initial # Generates n_params in range [-max_initial, max_initial]
147
+
148
+ def _search_param(tensors: List[torch.tensor], n_params):
149
+ """Takes the best parameters and generates new parameters around the mean of the best parameters."""
150
+ torch_tensors = torch.stack(tensors)
151
+ min_vals, max_vals = torch.aminmax(torch_tensors, dim=0)
152
+ abs_max_val_per_ch = torch.max(-min_vals, max_vals)
153
+ mean = torch.mean(torch_tensors, dim=0)
154
+ for _ in range(n_params): # Generates n_params around the mean of the tensors
155
+ yield torch.randn_like(min_vals) * abs_max_val_per_ch + mean
156
+
157
+ def _calc(x, qtz_func, deqtz_func, **params):
158
+ x_ = x.transpose(0, 1)
159
+ x_ = qtz_func(x=x_, **params)
160
+ x_ = deqtz_func(x=x_, **params)
161
+ x_ = x_.transpose(0, 1)
162
+ return x_
163
+
164
+ assert "qtz_func" in kwargs, "qtz_func must be provided."
165
+ assert "deqtz_func" in kwargs, "deqtz_func must be provided."
166
+ assert "params_list" in kwargs, "params list must be provided."
167
+ assert "param" in kwargs, "param must be provided."
168
+
169
+ qtz_func = kwargs.get('qtz_func')
170
+ deqtz_func = kwargs.get('deqtz_func')
171
+ params_list = kwargs.get('params_list')
172
+ param = kwargs.get('param')
173
+
174
+ n_runs = 50 # Number of runs to try to find the best parameters
175
+ n_random_params = 50 # Number of random parameters to generate
176
+ n_best_to_pick = 5 # Number of best parameters to pick after each run
177
+ max_initial = 10000 # Maximum value to initialize the parameters
178
+
179
+ # Initializes the parameters
180
+ base_params = { p: init_ones(x, **kwargs) for p in params_list if p != param }
181
+ params = _build_initial_param(x, max_initial, n_random_params)
182
+
183
+ # Performs the search
184
+ for _ in range(n_runs):
185
+
186
+ best_params = []
187
+ for param_ in params:
188
+ try:
189
+ x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: param_})
190
+ loss_ones = nn.MSELoss()(x, x_)
191
+
192
+ if len(best_params) < n_best_to_pick:
193
+ best_params.append((param_, loss_ones.item()))
194
+ best_params = sorted(best_params, key=lambda x: x[1])
195
+ elif loss_ones < best_params[-1][1]:
196
+ best_params[-1] = (param_, loss_ones.item())
197
+ best_params = sorted(best_params, key=lambda x: x[1])
198
+
199
+ except Exception: # The parameters might not be valid for the function's domain
200
+ continue
201
+
202
+ # Generates new parameters around the mean
203
+ params = _search_param([p for p, _ in best_params], n_random_params)
204
+
205
+ # Checks if the best parameter is better than the init_ones
206
+ p_ones = init_ones(x, **kwargs)
207
+ x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: p_ones})
208
+ loss_ones = nn.MSELoss()(x, x_)
209
+
210
+ # Checks if the best parameter is better than the init_rand
211
+ p_rand = init_rand(x, **kwargs)
212
+ x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: p_rand})
213
+ loss_rand = nn.MSELoss()(x, x_)
214
+
215
+ if loss_rand < best_params[0][1] and loss_rand < loss_ones:
216
+ return p_rand
217
+ elif loss_ones < best_params[0][1] and loss_ones < loss_rand:
218
+ return p_ones
219
+ else:
220
+ return best_params[0][0]
221
+
222
+
223
+ def init_linear_scale( # Symmetric scale. From the study folder
224
+ x: torch.Tensor,
225
+ **kwargs: Dict[str, Any],
226
+ ) -> torch.Tensor:
227
+ assert "bits" in kwargs, "bits must be provided."
228
+ assert "params" in kwargs, "params must be provided."
229
+ assert "qtz_func" in kwargs, "qtz_func must be provided."
230
+
231
+ bits = kwargs.get('bits')
232
+ params = kwargs.get('params')
233
+ qtz_func = kwargs.get('qtz_func')
234
+
235
+ x_ = x.transpose(0, 1)
236
+ x_ = qtz_func(x=x_, **params, _s=init_ones(x, **kwargs))
237
+ x_ = x_.transpose(0, 1)
238
+
239
+ quant_min, quant_max = get_min_max_from_bits_signed(bits)
240
+ min_vals, max_vals = torch.aminmax(x_, dim=1)
241
+ min_vals = torch.min(min_vals, torch.zeros_like(min_vals))
242
+ max_vals = torch.max(max_vals, torch.zeros_like(max_vals))
243
+
244
+ eps = torch.finfo(torch.float32).eps
245
+
246
+ abs_max_val_per_ch = torch.max(-min_vals, max_vals)
247
+ scale = abs_max_val_per_ch / (float(quant_max - quant_min) / 2)
248
+
249
+ scale = torch.clamp(scale, min=eps).to(dtype=torch.float32, device=min_vals.device)
250
+
251
+ return scale
252
+
253
+ # Introduces some noise in scale
254
+ # If I don't introduce noise, the accuracy is going to be 0.0 and not learn anything
255
+ scale = scale + 0.01 * torch.randn_like(scale)
256
+ return scale
257
+
258
+
259
+ def init_non_linear_regression_fit(
260
+ x: torch.Tensor,
261
+ **kwargs: Dict[str, Any],
262
+ ) -> torch.Tensor:
263
+
264
+ assert "params_list" in kwargs, "params list must be provided."
265
+ assert "np_fit_func" in kwargs, "np_fit_func must be provided."
266
+ assert "p0" in kwargs, "p0 must be provided."
267
+ np_fit_func = kwargs.get('np_fit_func')
268
+ params_list = kwargs.get('params_list')
269
+ p0 = kwargs.get('p0')
270
+
271
+ def _fit(xdata: np.ndarray, ydata: np.ndarray, func: Callable, p0: List[float]):
272
+ popt, _ = curve_fit(
273
+ func,
274
+ xdata,
275
+ ydata,
276
+ maxfev=1000,
277
+ p0=p0,
278
+ method='lm'
279
+ )
280
+ return popt
281
+
282
+ # 1. Needs to convert the torch tensor to numpy tensor
283
+ xdata = x.cpu().numpy()
284
+
285
+ # 2. Sorts the data so that it makes it easier to fit to it
286
+ sorted_xdata = np.sort(xdata, axis=-1)
287
+
288
+ p0 = {k: v.cpu().numpy() for k, v in p0.items()}
289
+ params_list = sorted(params_list) # We need to make sure that it matches the numpy fit func arg order
290
+
291
+ # 3. Finds the best parameters for each channel
292
+ try:
293
+ params = []
294
+ for i in range(sorted_xdata.shape[0]):
295
+ xdata_ = sorted_xdata[i]
296
+ p0_ = [p0[p][i] for p in params_list]
297
+ ch_params = _fit(xdata_, xdata_, np_fit_func, p0_)
298
+ params.append(ch_params)
299
+
300
+ # 4. Builds the parameters
301
+ result = {}
302
+ for i, p in enumerate(params_list):
303
+ result[p] = torch.tensor([p_[i] for p_ in params], dtype=torch.float32).to(x.device)
304
+
305
+ return result
306
+
307
+ except ValueError as e:
308
+ print(f"Could not fit the function with error: {e}")
309
+ print(f"Using fallback result...")
310
+ return {
311
+ k: torch.tensor(v, dtype=torch.float32).to(x.device) for k, v in p0.items()
312
+ }
313
+
314
+
315
+ def init_zeros(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor:
316
+ val = torch.amin(x, dim=1)
317
+ return torch.zeros_like(val, dtype=torch.float32, device=x.device)
318
+
319
+
320
+ def init_inner_scale(tensor: torch.Tensor, _min: float = torch.inf, _max: float = torch.inf) -> torch.Tensor:
321
+ # Calculate the original minimum and maximum values
322
+ min_vals, max_vals = torch.aminmax(tensor, dim=-1)
323
+ x_min = torch.min(min_vals, torch.zeros_like(min_vals))
324
+ x_max = torch.max(max_vals, torch.zeros_like(max_vals))
325
+
326
+ if _max is torch.inf: # We do not need to scale the tensor. Just need to move it
327
+ return torch.ones_like(x_min)
328
+
329
+ # Calculate the scale factor
330
+ scale = (_max - _min) / (x_max - x_min)
331
+ return scale
332
+
333
+
334
+
335
+ ############## Quant ###############
336
+
337
+ @torch.enable_grad()
338
+ def learn_parameters(
339
+ x: torch.Tensor,
340
+ params: Dict[str, nn.Parameter],
341
+ qtz_func: nn.Module,
342
+ deqtz_func: nn.Module,
343
+ bits: int,
344
+ target_dtype: torch.dtype,
345
+ epochs: int = 1000,
346
+ early_stop: bool = True,
347
+ do_report: bool = False
348
+ ) -> Tuple[Dict[str, nn.Parameter], torch.Tensor]:
349
+
350
+ # Requires gradients in the parameters
351
+ for p in params.values():
352
+ p.requires_grad = True
353
+ p.grad = None
354
+
355
+ param_keys = list(params.keys())
356
+ param_values = list(params.values())
357
+
358
+ # Defines optimizer and loss function
359
+ optimizer = torch.optim.Adam(param_values, lr=0.001)
360
+ scheduler = torch.optim.lr_scheduler.LinearLR(optimizer, start_factor=1.0, end_factor=0.001, total_iters=epochs // 10)
361
+ loss_fn = nn.MSELoss()
362
+
363
+ # Contains the best loss and the best parameters
364
+ best_loss = float("inf")
365
+ best_params = None
366
+
367
+ # Used to stop the search early
368
+ min_delta = 1e-7
369
+ acc_loss = []
370
+ percent_epochs_before_stop = 0.1
371
+
372
+ for i in range(epochs):
373
+ optimizer.zero_grad()
374
+
375
+ quant = quantize(x, params, qtz_func, bits, target_dtype)
376
+ dequant = dequantize(quant, params, deqtz_func, bits, x.dtype)
377
+ loss = loss_fn(x, dequant)
378
+
379
+ if loss.isnan() or loss.isinf():
380
+ raise Exception("Loss is NaN or Inf. Stopping the search.")
381
+
382
+ loss.backward()
383
+ optimizer.step()
384
+ scheduler.step()
385
+
386
+ acc_loss.append(loss.item())
387
+
388
+ # Reports loss every 10 steps
389
+ if i % 10 == 0 and do_report:
390
+ print(f"Epoch {i}: Loss {loss.item()}")
391
+
392
+ # Optimizes the parameter search by storing the best loss and the parameters
393
+ if loss.item() < best_loss:
394
+ best_loss = loss.item()
395
+ best_params = copy.deepcopy({
396
+ k: v for k, v in params.items() if k in param_keys
397
+ })
398
+
399
+ # We also stop the search if the loss has not considerably during the last 10% epochs
400
+ if early_stop:
401
+ epochs_before_stop = int(epochs * percent_epochs_before_stop)
402
+ if i > epochs_before_stop and abs(acc_loss[i - epochs_before_stop] - acc_loss[i]) < min_delta:
403
+ break
404
+
405
+ # No longer requires gradients in the parameters
406
+ for p in best_params.values():
407
+ p.requires_grad = False
408
+ p.grad = None
409
+
410
+ if do_report:
411
+ return best_params, acc_loss
412
+ else:
413
+ return best_params
414
+
415
+
416
+ def quantize(
417
+ x: torch.Tensor,
418
+ params: Dict[str, nn.Parameter],
419
+ func: nn.Module,
420
+ bits: int,
421
+ target_dtype: torch.dtype = torch.int8
422
+ ) -> torch.Tensor:
423
+ quant_min, quant_max = get_min_max_from_bits_signed(bits)
424
+ x = x.transpose(0, 1) # Aligns shapes
425
+ x = func(x=x, **params)
426
+ x = x.transpose(0, 1)
427
+ x = torch.clamp(round_func_BPDA(x), quant_min, quant_max).to(target_dtype)
428
+ return x
429
+
430
+
431
+ def dequantize(
432
+ x: torch.Tensor,
433
+ params: Dict[str, nn.Parameter],
434
+ func: nn.Module,
435
+ bits: int,
436
+ out_dtype: torch.dtype
437
+ ) -> torch.Tensor:
438
+ x = x.to(dtype=out_dtype)
439
+ x = x.transpose(0, 1)
440
+ x = func(x=x, **params)
441
+ x = x.transpose(0, 1)
442
+ return x
443
+
444
+
445
+ def round_func_BPDA(input):
446
+ # This is equivalent to replacing round function (non-differentiable) with
447
+ # an identity function (differentiable) only when backward.
448
+ forward_value = torch.round(input)
449
+ out = input.clone()
450
+ out.data = forward_value.data
451
+ return out
452
+
453
+
454
+ def get_min_max_from_bits_signed(bit_width: int) -> Tuple[int, int]:
455
+ return -2 ** (bit_width - 1), 2 ** (bit_width - 1) - 1
456
+
457
+
458
+
459
+ ############## Numpy ###############
460
+
461
+ def np_domain_guard(
462
+ x: np.ndarray,
463
+ min: float = None,
464
+ max: float = None,
465
+ posinf: float = None,
466
+ neginf: float = None,
467
+ nan: float = None
468
+ ) -> np.ndarray:
469
+ """Guard a tensor to a valid domain."""
470
+ x = np.nan_to_num(x, posinf=posinf, neginf=neginf, nan=nan)
471
+ if min is not None or max is not None:
472
+ x = np.clip(x, min, max)
473
+ return x
474
+
475
+
476
+ def np_replace_num(x: np.ndarray, num: float, to: float) -> np.ndarray:
477
+ """Replace a number in a tensor with another number.
478
+
479
+ Args:
480
+ x (np.ndarray): The input tensor.
481
+ num (float): The number to replace.
482
+ to (float): The number to replace with.
483
+
484
+ Returns:
485
+ np.ndarray: The tensor with the number replaced.
486
+ """
487
+ return np.where(x == num, to, x)
488
+
489
+
490
+ def np_guarded_power(x: np.ndarray, exp: float) -> np.ndarray:
491
+ """Guard the power operation to a valid domain."""
492
+ return np.power(x, exp) if exp >= 1 else np.power(np.maximum(x, 0), exp)
493
+
fn_gen/ones_affine_scale/0/loss.png ADDED
fn_gen/ones_affine_scale/0/quantization.png ADDED
fn_gen/ones_affine_scale/1/distortion.png ADDED
fn_gen/ones_affine_scale/1/expressions.txt ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ (_0*x)**(1/3)/_s
2
+ _s**3*x**3/_0
fn_gen/ones_affine_scale/1/fn.py ADDED
@@ -0,0 +1,493 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import torch
4
+ from torch import amin # Necessary for arcsin
5
+ import copy
6
+ import torch.nn as nn
7
+ import numpy as np
8
+
9
+ from scipy.optimize import curve_fit
10
+ from typing import Dict, Any, Tuple, List, Callable
11
+
12
+
13
+ def quantization(x, **params):
14
+ return (torch.div(1, replace_num(params['_s'], num=0, to=10000)) * guarded_torch_power((params['_0'] * x), 1 / 3))
15
+
16
+
17
+ def dequantization(x, **params):
18
+ return (torch.div(1, replace_num(params['_0'], num=0, to=10000)) * guarded_torch_power(params['_s'], torch.tensor(3)) * guarded_torch_power(x, torch.tensor(3)))
19
+
20
+
21
+ def init_linear_scale( # Symmetric scale. From the study folder
22
+ x: torch.Tensor,
23
+ **kwargs: Dict[str, Any],
24
+ ) -> torch.Tensor:
25
+ assert "bits" in kwargs, "bits must be provided."
26
+ assert "params" in kwargs, "params must be provided."
27
+ assert "qtz_func" in kwargs, "qtz_func must be provided."
28
+
29
+ bits = kwargs.get('bits')
30
+ params = kwargs.get('params')
31
+ qtz_func = kwargs.get('qtz_func')
32
+
33
+ x_ = x.transpose(0, 1)
34
+ x_ = qtz_func(x=x_, **params, _s=init_ones(x, **kwargs))
35
+ x_ = x_.transpose(0, 1)
36
+
37
+ quant_min, quant_max = get_min_max_from_bits_signed(bits)
38
+ min_vals, max_vals = torch.aminmax(x_, dim=1)
39
+ min_vals = torch.min(min_vals, torch.zeros_like(min_vals))
40
+ max_vals = torch.max(max_vals, torch.zeros_like(max_vals))
41
+
42
+ eps = torch.finfo(torch.float32).eps
43
+
44
+ abs_max_val_per_ch = torch.max(-min_vals, max_vals)
45
+ scale = abs_max_val_per_ch / (float(quant_max - quant_min) / 2)
46
+
47
+ scale = torch.clamp(scale, min=eps).to(dtype=torch.float32, device=min_vals.device)
48
+
49
+ return scale
50
+
51
+ # Introduces some noise in scale
52
+ # If I don't introduce noise, the accuracy is going to be 0.0 and not learn anything
53
+ scale = scale + 0.01 * torch.randn_like(scale)
54
+ return scale
55
+
56
+
57
+ def init_params(x: torch.Tensor, **kwargs: Dict[str, Any]) -> Dict[str, nn.Parameter]:
58
+ params = {
59
+ '_0': init_ones(x, qtz_func=quantization, deqtz_func=dequantization, param='_0', params_list=['_0', '_s'], **kwargs),
60
+ }
61
+ params['_s'] = init_linear_scale(x, params=params, qtz_func=quantization, **kwargs)
62
+ params = {k: nn.Parameter(v, requires_grad=False) for k, v in params.items()}
63
+
64
+ if 'post_init_hook' in kwargs:
65
+ kwargs['post_init_hook'](parameters=params)
66
+
67
+
68
+ if 'post_train_hook' in kwargs:
69
+ kwargs['post_train_hook'](parameters=params)
70
+
71
+ return params
72
+
73
+
74
+ ############### Numpy Qtz ###############
75
+
76
+
77
+ def np_quantization(x, _0, _s):
78
+ return (np.divide(1, np_replace_num(_s, num=0, to=10000)) * np_guarded_power((_0 * x), 1 / 3))
79
+
80
+
81
+ def np_dequantization(x, _0, _s):
82
+ return (np.divide(1, np_replace_num(_0, num=0, to=10000)) * np_guarded_power(_s, np.array(3)) * np_guarded_power(x, np.array(3)))
83
+
84
+
85
+ def fit_func(x, _0, _s):
86
+ x_ = np_quantization(x, _0, _s)
87
+ x_ = np_dequantization(x_, _0, _s)
88
+ return x_
89
+
90
+
91
+
92
+ ############### HELPERS ###############
93
+
94
+ def domain_guard(
95
+ x: torch.Tensor,
96
+ min: float = None,
97
+ max: float = None,
98
+ posinf: float = None,
99
+ neginf: float = None,
100
+ nan: float = None
101
+ ) -> torch.Tensor:
102
+ """Guard a tensor to a valid domain."""
103
+ x = torch.nan_to_num(x, posinf=posinf, neginf=neginf, nan=nan)
104
+ if min is not None or max is not None:
105
+ x = torch.clamp(x, min=min, max=max)
106
+ return x
107
+
108
+
109
+ def replace_num(x: torch.Tensor, num: float, to: float) -> torch.Tensor:
110
+ """Replace a number in a tensor with another number.
111
+
112
+ Args:
113
+ x (torch.Tensor): The input tensor.
114
+ num (float): The number to replace.
115
+ to (float): The number to replace with.
116
+
117
+ Returns:
118
+ torch.Tensor: The tensor with the number replaced.
119
+ """
120
+ return torch.where(x == num, to, x)
121
+
122
+
123
+ def guarded_torch_power(x: torch.Tensor, exp: float) -> torch.Tensor:
124
+ """Guard the power operation to a valid domain."""
125
+ return torch.pow(x, exp) if exp >= 1 else torch.pow(torch.relu(x), exp)
126
+
127
+
128
+ def init_ones(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor:
129
+ val = torch.amin(x, dim=1)
130
+ return torch.ones_like(val, dtype=torch.float32, device=x.device)
131
+
132
+
133
+ def init_rand(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor:
134
+ val = torch.amin(x, dim=1)
135
+ return torch.randn_like(val, dtype=torch.float32, device=x.device)
136
+
137
+
138
+ def init_space_search(
139
+ x: torch.Tensor,
140
+ **kwargs: Dict[str, Any],
141
+ ) -> torch.Tensor:
142
+
143
+ def _build_initial_param(tensor: torch.Tensor, max_initial: int, n_params: int):
144
+ """Generates the initial set of parameters. The first iteration generates 10 times more parameters."""
145
+ for _ in range(n_params * 10): # The first iteration generates 10 times more parameters
146
+ yield init_rand(tensor) * max_initial # Generates n_params in range [-max_initial, max_initial]
147
+
148
+ def _search_param(tensors: List[torch.tensor], n_params):
149
+ """Takes the best parameters and generates new parameters around the mean of the best parameters."""
150
+ torch_tensors = torch.stack(tensors)
151
+ min_vals, max_vals = torch.aminmax(torch_tensors, dim=0)
152
+ abs_max_val_per_ch = torch.max(-min_vals, max_vals)
153
+ mean = torch.mean(torch_tensors, dim=0)
154
+ for _ in range(n_params): # Generates n_params around the mean of the tensors
155
+ yield torch.randn_like(min_vals) * abs_max_val_per_ch + mean
156
+
157
+ def _calc(x, qtz_func, deqtz_func, **params):
158
+ x_ = x.transpose(0, 1)
159
+ x_ = qtz_func(x=x_, **params)
160
+ x_ = deqtz_func(x=x_, **params)
161
+ x_ = x_.transpose(0, 1)
162
+ return x_
163
+
164
+ assert "qtz_func" in kwargs, "qtz_func must be provided."
165
+ assert "deqtz_func" in kwargs, "deqtz_func must be provided."
166
+ assert "params_list" in kwargs, "params list must be provided."
167
+ assert "param" in kwargs, "param must be provided."
168
+
169
+ qtz_func = kwargs.get('qtz_func')
170
+ deqtz_func = kwargs.get('deqtz_func')
171
+ params_list = kwargs.get('params_list')
172
+ param = kwargs.get('param')
173
+
174
+ n_runs = 50 # Number of runs to try to find the best parameters
175
+ n_random_params = 50 # Number of random parameters to generate
176
+ n_best_to_pick = 5 # Number of best parameters to pick after each run
177
+ max_initial = 10000 # Maximum value to initialize the parameters
178
+
179
+ # Initializes the parameters
180
+ base_params = { p: init_ones(x, **kwargs) for p in params_list if p != param }
181
+ params = _build_initial_param(x, max_initial, n_random_params)
182
+
183
+ # Performs the search
184
+ for _ in range(n_runs):
185
+
186
+ best_params = []
187
+ for param_ in params:
188
+ try:
189
+ x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: param_})
190
+ loss_ones = nn.MSELoss()(x, x_)
191
+
192
+ if len(best_params) < n_best_to_pick:
193
+ best_params.append((param_, loss_ones.item()))
194
+ best_params = sorted(best_params, key=lambda x: x[1])
195
+ elif loss_ones < best_params[-1][1]:
196
+ best_params[-1] = (param_, loss_ones.item())
197
+ best_params = sorted(best_params, key=lambda x: x[1])
198
+
199
+ except Exception: # The parameters might not be valid for the function's domain
200
+ continue
201
+
202
+ # Generates new parameters around the mean
203
+ params = _search_param([p for p, _ in best_params], n_random_params)
204
+
205
+ # Checks if the best parameter is better than the init_ones
206
+ p_ones = init_ones(x, **kwargs)
207
+ x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: p_ones})
208
+ loss_ones = nn.MSELoss()(x, x_)
209
+
210
+ # Checks if the best parameter is better than the init_rand
211
+ p_rand = init_rand(x, **kwargs)
212
+ x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: p_rand})
213
+ loss_rand = nn.MSELoss()(x, x_)
214
+
215
+ if loss_rand < best_params[0][1] and loss_rand < loss_ones:
216
+ return p_rand
217
+ elif loss_ones < best_params[0][1] and loss_ones < loss_rand:
218
+ return p_ones
219
+ else:
220
+ return best_params[0][0]
221
+
222
+
223
+ def init_linear_scale( # Symmetric scale. From the study folder
224
+ x: torch.Tensor,
225
+ **kwargs: Dict[str, Any],
226
+ ) -> torch.Tensor:
227
+ assert "bits" in kwargs, "bits must be provided."
228
+ assert "params" in kwargs, "params must be provided."
229
+ assert "qtz_func" in kwargs, "qtz_func must be provided."
230
+
231
+ bits = kwargs.get('bits')
232
+ params = kwargs.get('params')
233
+ qtz_func = kwargs.get('qtz_func')
234
+
235
+ x_ = x.transpose(0, 1)
236
+ x_ = qtz_func(x=x_, **params, _s=init_ones(x, **kwargs))
237
+ x_ = x_.transpose(0, 1)
238
+
239
+ quant_min, quant_max = get_min_max_from_bits_signed(bits)
240
+ min_vals, max_vals = torch.aminmax(x_, dim=1)
241
+ min_vals = torch.min(min_vals, torch.zeros_like(min_vals))
242
+ max_vals = torch.max(max_vals, torch.zeros_like(max_vals))
243
+
244
+ eps = torch.finfo(torch.float32).eps
245
+
246
+ abs_max_val_per_ch = torch.max(-min_vals, max_vals)
247
+ scale = abs_max_val_per_ch / (float(quant_max - quant_min) / 2)
248
+
249
+ scale = torch.clamp(scale, min=eps).to(dtype=torch.float32, device=min_vals.device)
250
+
251
+ return scale
252
+
253
+ # Introduces some noise in scale
254
+ # If I don't introduce noise, the accuracy is going to be 0.0 and not learn anything
255
+ scale = scale + 0.01 * torch.randn_like(scale)
256
+ return scale
257
+
258
+
259
+ def init_non_linear_regression_fit(
260
+ x: torch.Tensor,
261
+ **kwargs: Dict[str, Any],
262
+ ) -> torch.Tensor:
263
+
264
+ assert "params_list" in kwargs, "params list must be provided."
265
+ assert "np_fit_func" in kwargs, "np_fit_func must be provided."
266
+ assert "p0" in kwargs, "p0 must be provided."
267
+ np_fit_func = kwargs.get('np_fit_func')
268
+ params_list = kwargs.get('params_list')
269
+ p0 = kwargs.get('p0')
270
+
271
+ def _fit(xdata: np.ndarray, ydata: np.ndarray, func: Callable, p0: List[float]):
272
+ popt, _ = curve_fit(
273
+ func,
274
+ xdata,
275
+ ydata,
276
+ maxfev=1000,
277
+ p0=p0,
278
+ method='lm'
279
+ )
280
+ return popt
281
+
282
+ # 1. Needs to convert the torch tensor to numpy tensor
283
+ xdata = x.cpu().numpy()
284
+
285
+ # 2. Sorts the data so that it makes it easier to fit to it
286
+ sorted_xdata = np.sort(xdata, axis=-1)
287
+
288
+ p0 = {k: v.cpu().numpy() for k, v in p0.items()}
289
+ params_list = sorted(params_list) # We need to make sure that it matches the numpy fit func arg order
290
+
291
+ # 3. Finds the best parameters for each channel
292
+ try:
293
+ params = []
294
+ for i in range(sorted_xdata.shape[0]):
295
+ xdata_ = sorted_xdata[i]
296
+ p0_ = [p0[p][i] for p in params_list]
297
+ ch_params = _fit(xdata_, xdata_, np_fit_func, p0_)
298
+ params.append(ch_params)
299
+
300
+ # 4. Builds the parameters
301
+ result = {}
302
+ for i, p in enumerate(params_list):
303
+ result[p] = torch.tensor([p_[i] for p_ in params], dtype=torch.float32).to(x.device)
304
+
305
+ return result
306
+
307
+ except ValueError as e:
308
+ print(f"Could not fit the function with error: {e}")
309
+ print(f"Using fallback result...")
310
+ return {
311
+ k: torch.tensor(v, dtype=torch.float32).to(x.device) for k, v in p0.items()
312
+ }
313
+
314
+
315
+ def init_zeros(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor:
316
+ val = torch.amin(x, dim=1)
317
+ return torch.zeros_like(val, dtype=torch.float32, device=x.device)
318
+
319
+
320
+ def init_inner_scale(tensor: torch.Tensor, _min: float = torch.inf, _max: float = torch.inf) -> torch.Tensor:
321
+ # Calculate the original minimum and maximum values
322
+ min_vals, max_vals = torch.aminmax(tensor, dim=-1)
323
+ x_min = torch.min(min_vals, torch.zeros_like(min_vals))
324
+ x_max = torch.max(max_vals, torch.zeros_like(max_vals))
325
+
326
+ if _max is torch.inf: # We do not need to scale the tensor. Just need to move it
327
+ return torch.ones_like(x_min)
328
+
329
+ # Calculate the scale factor
330
+ scale = (_max - _min) / (x_max - x_min)
331
+ return scale
332
+
333
+
334
+
335
+ ############## Quant ###############
336
+
337
+ @torch.enable_grad()
338
+ def learn_parameters(
339
+ x: torch.Tensor,
340
+ params: Dict[str, nn.Parameter],
341
+ qtz_func: nn.Module,
342
+ deqtz_func: nn.Module,
343
+ bits: int,
344
+ target_dtype: torch.dtype,
345
+ epochs: int = 1000,
346
+ early_stop: bool = True,
347
+ do_report: bool = False
348
+ ) -> Tuple[Dict[str, nn.Parameter], torch.Tensor]:
349
+
350
+ # Requires gradients in the parameters
351
+ for p in params.values():
352
+ p.requires_grad = True
353
+ p.grad = None
354
+
355
+ param_keys = list(params.keys())
356
+ param_values = list(params.values())
357
+
358
+ # Defines optimizer and loss function
359
+ optimizer = torch.optim.Adam(param_values, lr=0.001)
360
+ scheduler = torch.optim.lr_scheduler.LinearLR(optimizer, start_factor=1.0, end_factor=0.001, total_iters=epochs // 10)
361
+ loss_fn = nn.MSELoss()
362
+
363
+ # Contains the best loss and the best parameters
364
+ best_loss = float("inf")
365
+ best_params = None
366
+
367
+ # Used to stop the search early
368
+ min_delta = 1e-7
369
+ acc_loss = []
370
+ percent_epochs_before_stop = 0.1
371
+
372
+ for i in range(epochs):
373
+ optimizer.zero_grad()
374
+
375
+ quant = quantize(x, params, qtz_func, bits, target_dtype)
376
+ dequant = dequantize(quant, params, deqtz_func, bits, x.dtype)
377
+ loss = loss_fn(x, dequant)
378
+
379
+ if loss.isnan() or loss.isinf():
380
+ raise Exception("Loss is NaN or Inf. Stopping the search.")
381
+
382
+ loss.backward()
383
+ optimizer.step()
384
+ scheduler.step()
385
+
386
+ acc_loss.append(loss.item())
387
+
388
+ # Reports loss every 10 steps
389
+ if i % 10 == 0 and do_report:
390
+ print(f"Epoch {i}: Loss {loss.item()}")
391
+
392
+ # Optimizes the parameter search by storing the best loss and the parameters
393
+ if loss.item() < best_loss:
394
+ best_loss = loss.item()
395
+ best_params = copy.deepcopy({
396
+ k: v for k, v in params.items() if k in param_keys
397
+ })
398
+
399
+ # We also stop the search if the loss has not considerably during the last 10% epochs
400
+ if early_stop:
401
+ epochs_before_stop = int(epochs * percent_epochs_before_stop)
402
+ if i > epochs_before_stop and abs(acc_loss[i - epochs_before_stop] - acc_loss[i]) < min_delta:
403
+ break
404
+
405
+ # No longer requires gradients in the parameters
406
+ for p in best_params.values():
407
+ p.requires_grad = False
408
+ p.grad = None
409
+
410
+ if do_report:
411
+ return best_params, acc_loss
412
+ else:
413
+ return best_params
414
+
415
+
416
+ def quantize(
417
+ x: torch.Tensor,
418
+ params: Dict[str, nn.Parameter],
419
+ func: nn.Module,
420
+ bits: int,
421
+ target_dtype: torch.dtype = torch.int8
422
+ ) -> torch.Tensor:
423
+ quant_min, quant_max = get_min_max_from_bits_signed(bits)
424
+ x = x.transpose(0, 1) # Aligns shapes
425
+ x = func(x=x, **params)
426
+ x = x.transpose(0, 1)
427
+ x = torch.clamp(round_func_BPDA(x), quant_min, quant_max).to(target_dtype)
428
+ return x
429
+
430
+
431
+ def dequantize(
432
+ x: torch.Tensor,
433
+ params: Dict[str, nn.Parameter],
434
+ func: nn.Module,
435
+ bits: int,
436
+ out_dtype: torch.dtype
437
+ ) -> torch.Tensor:
438
+ x = x.to(dtype=out_dtype)
439
+ x = x.transpose(0, 1)
440
+ x = func(x=x, **params)
441
+ x = x.transpose(0, 1)
442
+ return x
443
+
444
+
445
+ def round_func_BPDA(input):
446
+ # This is equivalent to replacing round function (non-differentiable) with
447
+ # an identity function (differentiable) only when backward.
448
+ forward_value = torch.round(input)
449
+ out = input.clone()
450
+ out.data = forward_value.data
451
+ return out
452
+
453
+
454
+ def get_min_max_from_bits_signed(bit_width: int) -> Tuple[int, int]:
455
+ return -2 ** (bit_width - 1), 2 ** (bit_width - 1) - 1
456
+
457
+
458
+
459
+ ############## Numpy ###############
460
+
461
+ def np_domain_guard(
462
+ x: np.ndarray,
463
+ min: float = None,
464
+ max: float = None,
465
+ posinf: float = None,
466
+ neginf: float = None,
467
+ nan: float = None
468
+ ) -> np.ndarray:
469
+ """Guard a tensor to a valid domain."""
470
+ x = np.nan_to_num(x, posinf=posinf, neginf=neginf, nan=nan)
471
+ if min is not None or max is not None:
472
+ x = np.clip(x, min, max)
473
+ return x
474
+
475
+
476
+ def np_replace_num(x: np.ndarray, num: float, to: float) -> np.ndarray:
477
+ """Replace a number in a tensor with another number.
478
+
479
+ Args:
480
+ x (np.ndarray): The input tensor.
481
+ num (float): The number to replace.
482
+ to (float): The number to replace with.
483
+
484
+ Returns:
485
+ np.ndarray: The tensor with the number replaced.
486
+ """
487
+ return np.where(x == num, to, x)
488
+
489
+
490
+ def np_guarded_power(x: np.ndarray, exp: float) -> np.ndarray:
491
+ """Guard the power operation to a valid domain."""
492
+ return np.power(x, exp) if exp >= 1 else np.power(np.maximum(x, 0), exp)
493
+
fn_gen/ones_affine_scale/1/loss.png ADDED
fn_gen/ones_affine_scale/1/quantization.png ADDED
fn_gen/ones_affine_scale/10/distortion.png ADDED
fn_gen/ones_affine_scale/10/expressions.txt ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ x**2/_s
2
+ sqrt(_s*x)
fn_gen/ones_affine_scale/10/fn.py ADDED
@@ -0,0 +1,492 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import torch
4
+ from torch import amin # Necessary for arcsin
5
+ import copy
6
+ import torch.nn as nn
7
+ import numpy as np
8
+
9
+ from scipy.optimize import curve_fit
10
+ from typing import Dict, Any, Tuple, List, Callable
11
+
12
+
13
+ def quantization(x, **params):
14
+ return (torch.div(1, replace_num(params['_s'], num=0, to=10000)) * guarded_torch_power(x, torch.tensor(2)))
15
+
16
+
17
+ def dequantization(x, **params):
18
+ return torch.sqrt(domain_guard((params['_s'] * x), min=0.1, nan=0.1))
19
+
20
+
21
+ def init_linear_scale( # Symmetric scale. From the study folder
22
+ x: torch.Tensor,
23
+ **kwargs: Dict[str, Any],
24
+ ) -> torch.Tensor:
25
+ assert "bits" in kwargs, "bits must be provided."
26
+ assert "params" in kwargs, "params must be provided."
27
+ assert "qtz_func" in kwargs, "qtz_func must be provided."
28
+
29
+ bits = kwargs.get('bits')
30
+ params = kwargs.get('params')
31
+ qtz_func = kwargs.get('qtz_func')
32
+
33
+ x_ = x.transpose(0, 1)
34
+ x_ = qtz_func(x=x_, **params, _s=init_ones(x, **kwargs))
35
+ x_ = x_.transpose(0, 1)
36
+
37
+ quant_min, quant_max = get_min_max_from_bits_signed(bits)
38
+ min_vals, max_vals = torch.aminmax(x_, dim=1)
39
+ min_vals = torch.min(min_vals, torch.zeros_like(min_vals))
40
+ max_vals = torch.max(max_vals, torch.zeros_like(max_vals))
41
+
42
+ eps = torch.finfo(torch.float32).eps
43
+
44
+ abs_max_val_per_ch = torch.max(-min_vals, max_vals)
45
+ scale = abs_max_val_per_ch / (float(quant_max - quant_min) / 2)
46
+
47
+ scale = torch.clamp(scale, min=eps).to(dtype=torch.float32, device=min_vals.device)
48
+
49
+ return scale
50
+
51
+ # Introduces some noise in scale
52
+ # If I don't introduce noise, the accuracy is going to be 0.0 and not learn anything
53
+ scale = scale + 0.01 * torch.randn_like(scale)
54
+ return scale
55
+
56
+
57
+ def init_params(x: torch.Tensor, **kwargs: Dict[str, Any]) -> Dict[str, nn.Parameter]:
58
+ params = {
59
+ }
60
+ params['_s'] = init_linear_scale(x, params=params, qtz_func=quantization, **kwargs)
61
+ params = {k: nn.Parameter(v, requires_grad=False) for k, v in params.items()}
62
+
63
+ if 'post_init_hook' in kwargs:
64
+ kwargs['post_init_hook'](parameters=params)
65
+
66
+
67
+ if 'post_train_hook' in kwargs:
68
+ kwargs['post_train_hook'](parameters=params)
69
+
70
+ return params
71
+
72
+
73
+ ############### Numpy Qtz ###############
74
+
75
+
76
+ def np_quantization(x, _s):
77
+ return (np.divide(1, np_replace_num(_s, num=0, to=10000)) * np_guarded_power(x, np.array(2)))
78
+
79
+
80
+ def np_dequantization(x, _s):
81
+ return np.sqrt(np_domain_guard((_s * x), min=0.1, nan=0.1))
82
+
83
+
84
+ def fit_func(x, _s):
85
+ x_ = np_quantization(x, _s)
86
+ x_ = np_dequantization(x_, _s)
87
+ return x_
88
+
89
+
90
+
91
+ ############### HELPERS ###############
92
+
93
+ def domain_guard(
94
+ x: torch.Tensor,
95
+ min: float = None,
96
+ max: float = None,
97
+ posinf: float = None,
98
+ neginf: float = None,
99
+ nan: float = None
100
+ ) -> torch.Tensor:
101
+ """Guard a tensor to a valid domain."""
102
+ x = torch.nan_to_num(x, posinf=posinf, neginf=neginf, nan=nan)
103
+ if min is not None or max is not None:
104
+ x = torch.clamp(x, min=min, max=max)
105
+ return x
106
+
107
+
108
+ def replace_num(x: torch.Tensor, num: float, to: float) -> torch.Tensor:
109
+ """Replace a number in a tensor with another number.
110
+
111
+ Args:
112
+ x (torch.Tensor): The input tensor.
113
+ num (float): The number to replace.
114
+ to (float): The number to replace with.
115
+
116
+ Returns:
117
+ torch.Tensor: The tensor with the number replaced.
118
+ """
119
+ return torch.where(x == num, to, x)
120
+
121
+
122
+ def guarded_torch_power(x: torch.Tensor, exp: float) -> torch.Tensor:
123
+ """Guard the power operation to a valid domain."""
124
+ return torch.pow(x, exp) if exp >= 1 else torch.pow(torch.relu(x), exp)
125
+
126
+
127
+ def init_ones(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor:
128
+ val = torch.amin(x, dim=1)
129
+ return torch.ones_like(val, dtype=torch.float32, device=x.device)
130
+
131
+
132
+ def init_rand(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor:
133
+ val = torch.amin(x, dim=1)
134
+ return torch.randn_like(val, dtype=torch.float32, device=x.device)
135
+
136
+
137
+ def init_space_search(
138
+ x: torch.Tensor,
139
+ **kwargs: Dict[str, Any],
140
+ ) -> torch.Tensor:
141
+
142
+ def _build_initial_param(tensor: torch.Tensor, max_initial: int, n_params: int):
143
+ """Generates the initial set of parameters. The first iteration generates 10 times more parameters."""
144
+ for _ in range(n_params * 10): # The first iteration generates 10 times more parameters
145
+ yield init_rand(tensor) * max_initial # Generates n_params in range [-max_initial, max_initial]
146
+
147
+ def _search_param(tensors: List[torch.tensor], n_params):
148
+ """Takes the best parameters and generates new parameters around the mean of the best parameters."""
149
+ torch_tensors = torch.stack(tensors)
150
+ min_vals, max_vals = torch.aminmax(torch_tensors, dim=0)
151
+ abs_max_val_per_ch = torch.max(-min_vals, max_vals)
152
+ mean = torch.mean(torch_tensors, dim=0)
153
+ for _ in range(n_params): # Generates n_params around the mean of the tensors
154
+ yield torch.randn_like(min_vals) * abs_max_val_per_ch + mean
155
+
156
+ def _calc(x, qtz_func, deqtz_func, **params):
157
+ x_ = x.transpose(0, 1)
158
+ x_ = qtz_func(x=x_, **params)
159
+ x_ = deqtz_func(x=x_, **params)
160
+ x_ = x_.transpose(0, 1)
161
+ return x_
162
+
163
+ assert "qtz_func" in kwargs, "qtz_func must be provided."
164
+ assert "deqtz_func" in kwargs, "deqtz_func must be provided."
165
+ assert "params_list" in kwargs, "params list must be provided."
166
+ assert "param" in kwargs, "param must be provided."
167
+
168
+ qtz_func = kwargs.get('qtz_func')
169
+ deqtz_func = kwargs.get('deqtz_func')
170
+ params_list = kwargs.get('params_list')
171
+ param = kwargs.get('param')
172
+
173
+ n_runs = 50 # Number of runs to try to find the best parameters
174
+ n_random_params = 50 # Number of random parameters to generate
175
+ n_best_to_pick = 5 # Number of best parameters to pick after each run
176
+ max_initial = 10000 # Maximum value to initialize the parameters
177
+
178
+ # Initializes the parameters
179
+ base_params = { p: init_ones(x, **kwargs) for p in params_list if p != param }
180
+ params = _build_initial_param(x, max_initial, n_random_params)
181
+
182
+ # Performs the search
183
+ for _ in range(n_runs):
184
+
185
+ best_params = []
186
+ for param_ in params:
187
+ try:
188
+ x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: param_})
189
+ loss_ones = nn.MSELoss()(x, x_)
190
+
191
+ if len(best_params) < n_best_to_pick:
192
+ best_params.append((param_, loss_ones.item()))
193
+ best_params = sorted(best_params, key=lambda x: x[1])
194
+ elif loss_ones < best_params[-1][1]:
195
+ best_params[-1] = (param_, loss_ones.item())
196
+ best_params = sorted(best_params, key=lambda x: x[1])
197
+
198
+ except Exception: # The parameters might not be valid for the function's domain
199
+ continue
200
+
201
+ # Generates new parameters around the mean
202
+ params = _search_param([p for p, _ in best_params], n_random_params)
203
+
204
+ # Checks if the best parameter is better than the init_ones
205
+ p_ones = init_ones(x, **kwargs)
206
+ x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: p_ones})
207
+ loss_ones = nn.MSELoss()(x, x_)
208
+
209
+ # Checks if the best parameter is better than the init_rand
210
+ p_rand = init_rand(x, **kwargs)
211
+ x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: p_rand})
212
+ loss_rand = nn.MSELoss()(x, x_)
213
+
214
+ if loss_rand < best_params[0][1] and loss_rand < loss_ones:
215
+ return p_rand
216
+ elif loss_ones < best_params[0][1] and loss_ones < loss_rand:
217
+ return p_ones
218
+ else:
219
+ return best_params[0][0]
220
+
221
+
222
+ def init_linear_scale( # Symmetric scale. From the study folder
223
+ x: torch.Tensor,
224
+ **kwargs: Dict[str, Any],
225
+ ) -> torch.Tensor:
226
+ assert "bits" in kwargs, "bits must be provided."
227
+ assert "params" in kwargs, "params must be provided."
228
+ assert "qtz_func" in kwargs, "qtz_func must be provided."
229
+
230
+ bits = kwargs.get('bits')
231
+ params = kwargs.get('params')
232
+ qtz_func = kwargs.get('qtz_func')
233
+
234
+ x_ = x.transpose(0, 1)
235
+ x_ = qtz_func(x=x_, **params, _s=init_ones(x, **kwargs))
236
+ x_ = x_.transpose(0, 1)
237
+
238
+ quant_min, quant_max = get_min_max_from_bits_signed(bits)
239
+ min_vals, max_vals = torch.aminmax(x_, dim=1)
240
+ min_vals = torch.min(min_vals, torch.zeros_like(min_vals))
241
+ max_vals = torch.max(max_vals, torch.zeros_like(max_vals))
242
+
243
+ eps = torch.finfo(torch.float32).eps
244
+
245
+ abs_max_val_per_ch = torch.max(-min_vals, max_vals)
246
+ scale = abs_max_val_per_ch / (float(quant_max - quant_min) / 2)
247
+
248
+ scale = torch.clamp(scale, min=eps).to(dtype=torch.float32, device=min_vals.device)
249
+
250
+ return scale
251
+
252
+ # Introduces some noise in scale
253
+ # If I don't introduce noise, the accuracy is going to be 0.0 and not learn anything
254
+ scale = scale + 0.01 * torch.randn_like(scale)
255
+ return scale
256
+
257
+
258
+ def init_non_linear_regression_fit(
259
+ x: torch.Tensor,
260
+ **kwargs: Dict[str, Any],
261
+ ) -> torch.Tensor:
262
+
263
+ assert "params_list" in kwargs, "params list must be provided."
264
+ assert "np_fit_func" in kwargs, "np_fit_func must be provided."
265
+ assert "p0" in kwargs, "p0 must be provided."
266
+ np_fit_func = kwargs.get('np_fit_func')
267
+ params_list = kwargs.get('params_list')
268
+ p0 = kwargs.get('p0')
269
+
270
+ def _fit(xdata: np.ndarray, ydata: np.ndarray, func: Callable, p0: List[float]):
271
+ popt, _ = curve_fit(
272
+ func,
273
+ xdata,
274
+ ydata,
275
+ maxfev=1000,
276
+ p0=p0,
277
+ method='lm'
278
+ )
279
+ return popt
280
+
281
+ # 1. Needs to convert the torch tensor to numpy tensor
282
+ xdata = x.cpu().numpy()
283
+
284
+ # 2. Sorts the data so that it makes it easier to fit to it
285
+ sorted_xdata = np.sort(xdata, axis=-1)
286
+
287
+ p0 = {k: v.cpu().numpy() for k, v in p0.items()}
288
+ params_list = sorted(params_list) # We need to make sure that it matches the numpy fit func arg order
289
+
290
+ # 3. Finds the best parameters for each channel
291
+ try:
292
+ params = []
293
+ for i in range(sorted_xdata.shape[0]):
294
+ xdata_ = sorted_xdata[i]
295
+ p0_ = [p0[p][i] for p in params_list]
296
+ ch_params = _fit(xdata_, xdata_, np_fit_func, p0_)
297
+ params.append(ch_params)
298
+
299
+ # 4. Builds the parameters
300
+ result = {}
301
+ for i, p in enumerate(params_list):
302
+ result[p] = torch.tensor([p_[i] for p_ in params], dtype=torch.float32).to(x.device)
303
+
304
+ return result
305
+
306
+ except ValueError as e:
307
+ print(f"Could not fit the function with error: {e}")
308
+ print(f"Using fallback result...")
309
+ return {
310
+ k: torch.tensor(v, dtype=torch.float32).to(x.device) for k, v in p0.items()
311
+ }
312
+
313
+
314
+ def init_zeros(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor:
315
+ val = torch.amin(x, dim=1)
316
+ return torch.zeros_like(val, dtype=torch.float32, device=x.device)
317
+
318
+
319
+ def init_inner_scale(tensor: torch.Tensor, _min: float = torch.inf, _max: float = torch.inf) -> torch.Tensor:
320
+ # Calculate the original minimum and maximum values
321
+ min_vals, max_vals = torch.aminmax(tensor, dim=-1)
322
+ x_min = torch.min(min_vals, torch.zeros_like(min_vals))
323
+ x_max = torch.max(max_vals, torch.zeros_like(max_vals))
324
+
325
+ if _max is torch.inf: # We do not need to scale the tensor. Just need to move it
326
+ return torch.ones_like(x_min)
327
+
328
+ # Calculate the scale factor
329
+ scale = (_max - _min) / (x_max - x_min)
330
+ return scale
331
+
332
+
333
+
334
+ ############## Quant ###############
335
+
336
+ @torch.enable_grad()
337
+ def learn_parameters(
338
+ x: torch.Tensor,
339
+ params: Dict[str, nn.Parameter],
340
+ qtz_func: nn.Module,
341
+ deqtz_func: nn.Module,
342
+ bits: int,
343
+ target_dtype: torch.dtype,
344
+ epochs: int = 1000,
345
+ early_stop: bool = True,
346
+ do_report: bool = False
347
+ ) -> Tuple[Dict[str, nn.Parameter], torch.Tensor]:
348
+
349
+ # Requires gradients in the parameters
350
+ for p in params.values():
351
+ p.requires_grad = True
352
+ p.grad = None
353
+
354
+ param_keys = list(params.keys())
355
+ param_values = list(params.values())
356
+
357
+ # Defines optimizer and loss function
358
+ optimizer = torch.optim.Adam(param_values, lr=0.001)
359
+ scheduler = torch.optim.lr_scheduler.LinearLR(optimizer, start_factor=1.0, end_factor=0.001, total_iters=epochs // 10)
360
+ loss_fn = nn.MSELoss()
361
+
362
+ # Contains the best loss and the best parameters
363
+ best_loss = float("inf")
364
+ best_params = None
365
+
366
+ # Used to stop the search early
367
+ min_delta = 1e-7
368
+ acc_loss = []
369
+ percent_epochs_before_stop = 0.1
370
+
371
+ for i in range(epochs):
372
+ optimizer.zero_grad()
373
+
374
+ quant = quantize(x, params, qtz_func, bits, target_dtype)
375
+ dequant = dequantize(quant, params, deqtz_func, bits, x.dtype)
376
+ loss = loss_fn(x, dequant)
377
+
378
+ if loss.isnan() or loss.isinf():
379
+ raise Exception("Loss is NaN or Inf. Stopping the search.")
380
+
381
+ loss.backward()
382
+ optimizer.step()
383
+ scheduler.step()
384
+
385
+ acc_loss.append(loss.item())
386
+
387
+ # Reports loss every 10 steps
388
+ if i % 10 == 0 and do_report:
389
+ print(f"Epoch {i}: Loss {loss.item()}")
390
+
391
+ # Optimizes the parameter search by storing the best loss and the parameters
392
+ if loss.item() < best_loss:
393
+ best_loss = loss.item()
394
+ best_params = copy.deepcopy({
395
+ k: v for k, v in params.items() if k in param_keys
396
+ })
397
+
398
+ # We also stop the search if the loss has not considerably during the last 10% epochs
399
+ if early_stop:
400
+ epochs_before_stop = int(epochs * percent_epochs_before_stop)
401
+ if i > epochs_before_stop and abs(acc_loss[i - epochs_before_stop] - acc_loss[i]) < min_delta:
402
+ break
403
+
404
+ # No longer requires gradients in the parameters
405
+ for p in best_params.values():
406
+ p.requires_grad = False
407
+ p.grad = None
408
+
409
+ if do_report:
410
+ return best_params, acc_loss
411
+ else:
412
+ return best_params
413
+
414
+
415
+ def quantize(
416
+ x: torch.Tensor,
417
+ params: Dict[str, nn.Parameter],
418
+ func: nn.Module,
419
+ bits: int,
420
+ target_dtype: torch.dtype = torch.int8
421
+ ) -> torch.Tensor:
422
+ quant_min, quant_max = get_min_max_from_bits_signed(bits)
423
+ x = x.transpose(0, 1) # Aligns shapes
424
+ x = func(x=x, **params)
425
+ x = x.transpose(0, 1)
426
+ x = torch.clamp(round_func_BPDA(x), quant_min, quant_max).to(target_dtype)
427
+ return x
428
+
429
+
430
+ def dequantize(
431
+ x: torch.Tensor,
432
+ params: Dict[str, nn.Parameter],
433
+ func: nn.Module,
434
+ bits: int,
435
+ out_dtype: torch.dtype
436
+ ) -> torch.Tensor:
437
+ x = x.to(dtype=out_dtype)
438
+ x = x.transpose(0, 1)
439
+ x = func(x=x, **params)
440
+ x = x.transpose(0, 1)
441
+ return x
442
+
443
+
444
+ def round_func_BPDA(input):
445
+ # This is equivalent to replacing round function (non-differentiable) with
446
+ # an identity function (differentiable) only when backward.
447
+ forward_value = torch.round(input)
448
+ out = input.clone()
449
+ out.data = forward_value.data
450
+ return out
451
+
452
+
453
+ def get_min_max_from_bits_signed(bit_width: int) -> Tuple[int, int]:
454
+ return -2 ** (bit_width - 1), 2 ** (bit_width - 1) - 1
455
+
456
+
457
+
458
+ ############## Numpy ###############
459
+
460
+ def np_domain_guard(
461
+ x: np.ndarray,
462
+ min: float = None,
463
+ max: float = None,
464
+ posinf: float = None,
465
+ neginf: float = None,
466
+ nan: float = None
467
+ ) -> np.ndarray:
468
+ """Guard a tensor to a valid domain."""
469
+ x = np.nan_to_num(x, posinf=posinf, neginf=neginf, nan=nan)
470
+ if min is not None or max is not None:
471
+ x = np.clip(x, min, max)
472
+ return x
473
+
474
+
475
+ def np_replace_num(x: np.ndarray, num: float, to: float) -> np.ndarray:
476
+ """Replace a number in a tensor with another number.
477
+
478
+ Args:
479
+ x (np.ndarray): The input tensor.
480
+ num (float): The number to replace.
481
+ to (float): The number to replace with.
482
+
483
+ Returns:
484
+ np.ndarray: The tensor with the number replaced.
485
+ """
486
+ return np.where(x == num, to, x)
487
+
488
+
489
+ def np_guarded_power(x: np.ndarray, exp: float) -> np.ndarray:
490
+ """Guard the power operation to a valid domain."""
491
+ return np.power(x, exp) if exp >= 1 else np.power(np.maximum(x, 0), exp)
492
+
fn_gen/ones_affine_scale/10/loss.png ADDED
fn_gen/ones_affine_scale/10/quantization.png ADDED
fn_gen/ones_affine_scale/11/distortion.png ADDED
fn_gen/ones_affine_scale/11/expressions.txt ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ asinh(_0*x)/_s
2
+ sinh(_s*x)/_0
fn_gen/ones_affine_scale/11/fn.py ADDED
@@ -0,0 +1,493 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import torch
4
+ from torch import amin # Necessary for arcsin
5
+ import copy
6
+ import torch.nn as nn
7
+ import numpy as np
8
+
9
+ from scipy.optimize import curve_fit
10
+ from typing import Dict, Any, Tuple, List, Callable
11
+
12
+
13
+ def quantization(x, **params):
14
+ return (torch.div(1, replace_num(params['_s'], num=0, to=10000)) * torch.asinh((params['_0'] * x)))
15
+
16
+
17
+ def dequantization(x, **params):
18
+ return (torch.div(1, replace_num(params['_0'], num=0, to=10000)) * torch.sinh((params['_s'] * x)))
19
+
20
+
21
+ def init_linear_scale( # Symmetric scale. From the study folder
22
+ x: torch.Tensor,
23
+ **kwargs: Dict[str, Any],
24
+ ) -> torch.Tensor:
25
+ assert "bits" in kwargs, "bits must be provided."
26
+ assert "params" in kwargs, "params must be provided."
27
+ assert "qtz_func" in kwargs, "qtz_func must be provided."
28
+
29
+ bits = kwargs.get('bits')
30
+ params = kwargs.get('params')
31
+ qtz_func = kwargs.get('qtz_func')
32
+
33
+ x_ = x.transpose(0, 1)
34
+ x_ = qtz_func(x=x_, **params, _s=init_ones(x, **kwargs))
35
+ x_ = x_.transpose(0, 1)
36
+
37
+ quant_min, quant_max = get_min_max_from_bits_signed(bits)
38
+ min_vals, max_vals = torch.aminmax(x_, dim=1)
39
+ min_vals = torch.min(min_vals, torch.zeros_like(min_vals))
40
+ max_vals = torch.max(max_vals, torch.zeros_like(max_vals))
41
+
42
+ eps = torch.finfo(torch.float32).eps
43
+
44
+ abs_max_val_per_ch = torch.max(-min_vals, max_vals)
45
+ scale = abs_max_val_per_ch / (float(quant_max - quant_min) / 2)
46
+
47
+ scale = torch.clamp(scale, min=eps).to(dtype=torch.float32, device=min_vals.device)
48
+
49
+ return scale
50
+
51
+ # Introduces some noise in scale
52
+ # If I don't introduce noise, the accuracy is going to be 0.0 and not learn anything
53
+ scale = scale + 0.01 * torch.randn_like(scale)
54
+ return scale
55
+
56
+
57
+ def init_params(x: torch.Tensor, **kwargs: Dict[str, Any]) -> Dict[str, nn.Parameter]:
58
+ params = {
59
+ '_0': init_ones(x, qtz_func=quantization, deqtz_func=dequantization, param='_0', params_list=['_0', '_s'], **kwargs),
60
+ }
61
+ params['_s'] = init_linear_scale(x, params=params, qtz_func=quantization, **kwargs)
62
+ params = {k: nn.Parameter(v, requires_grad=False) for k, v in params.items()}
63
+
64
+ if 'post_init_hook' in kwargs:
65
+ kwargs['post_init_hook'](parameters=params)
66
+
67
+
68
+ if 'post_train_hook' in kwargs:
69
+ kwargs['post_train_hook'](parameters=params)
70
+
71
+ return params
72
+
73
+
74
+ ############### Numpy Qtz ###############
75
+
76
+
77
+ def np_quantization(x, _0, _s):
78
+ return (np.divide(1, np_replace_num(_s, num=0, to=10000)) * np.arcsinh((_0 * x)))
79
+
80
+
81
+ def np_dequantization(x, _0, _s):
82
+ return (np.divide(1, np_replace_num(_0, num=0, to=10000)) * np.sinh((_s * x)))
83
+
84
+
85
+ def fit_func(x, _0, _s):
86
+ x_ = np_quantization(x, _0, _s)
87
+ x_ = np_dequantization(x_, _0, _s)
88
+ return x_
89
+
90
+
91
+
92
+ ############### HELPERS ###############
93
+
94
+ def domain_guard(
95
+ x: torch.Tensor,
96
+ min: float = None,
97
+ max: float = None,
98
+ posinf: float = None,
99
+ neginf: float = None,
100
+ nan: float = None
101
+ ) -> torch.Tensor:
102
+ """Guard a tensor to a valid domain."""
103
+ x = torch.nan_to_num(x, posinf=posinf, neginf=neginf, nan=nan)
104
+ if min is not None or max is not None:
105
+ x = torch.clamp(x, min=min, max=max)
106
+ return x
107
+
108
+
109
+ def replace_num(x: torch.Tensor, num: float, to: float) -> torch.Tensor:
110
+ """Replace a number in a tensor with another number.
111
+
112
+ Args:
113
+ x (torch.Tensor): The input tensor.
114
+ num (float): The number to replace.
115
+ to (float): The number to replace with.
116
+
117
+ Returns:
118
+ torch.Tensor: The tensor with the number replaced.
119
+ """
120
+ return torch.where(x == num, to, x)
121
+
122
+
123
+ def guarded_torch_power(x: torch.Tensor, exp: float) -> torch.Tensor:
124
+ """Guard the power operation to a valid domain."""
125
+ return torch.pow(x, exp) if exp >= 1 else torch.pow(torch.relu(x), exp)
126
+
127
+
128
+ def init_ones(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor:
129
+ val = torch.amin(x, dim=1)
130
+ return torch.ones_like(val, dtype=torch.float32, device=x.device)
131
+
132
+
133
+ def init_rand(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor:
134
+ val = torch.amin(x, dim=1)
135
+ return torch.randn_like(val, dtype=torch.float32, device=x.device)
136
+
137
+
138
+ def init_space_search(
139
+ x: torch.Tensor,
140
+ **kwargs: Dict[str, Any],
141
+ ) -> torch.Tensor:
142
+
143
+ def _build_initial_param(tensor: torch.Tensor, max_initial: int, n_params: int):
144
+ """Generates the initial set of parameters. The first iteration generates 10 times more parameters."""
145
+ for _ in range(n_params * 10): # The first iteration generates 10 times more parameters
146
+ yield init_rand(tensor) * max_initial # Generates n_params in range [-max_initial, max_initial]
147
+
148
+ def _search_param(tensors: List[torch.tensor], n_params):
149
+ """Takes the best parameters and generates new parameters around the mean of the best parameters."""
150
+ torch_tensors = torch.stack(tensors)
151
+ min_vals, max_vals = torch.aminmax(torch_tensors, dim=0)
152
+ abs_max_val_per_ch = torch.max(-min_vals, max_vals)
153
+ mean = torch.mean(torch_tensors, dim=0)
154
+ for _ in range(n_params): # Generates n_params around the mean of the tensors
155
+ yield torch.randn_like(min_vals) * abs_max_val_per_ch + mean
156
+
157
+ def _calc(x, qtz_func, deqtz_func, **params):
158
+ x_ = x.transpose(0, 1)
159
+ x_ = qtz_func(x=x_, **params)
160
+ x_ = deqtz_func(x=x_, **params)
161
+ x_ = x_.transpose(0, 1)
162
+ return x_
163
+
164
+ assert "qtz_func" in kwargs, "qtz_func must be provided."
165
+ assert "deqtz_func" in kwargs, "deqtz_func must be provided."
166
+ assert "params_list" in kwargs, "params list must be provided."
167
+ assert "param" in kwargs, "param must be provided."
168
+
169
+ qtz_func = kwargs.get('qtz_func')
170
+ deqtz_func = kwargs.get('deqtz_func')
171
+ params_list = kwargs.get('params_list')
172
+ param = kwargs.get('param')
173
+
174
+ n_runs = 50 # Number of runs to try to find the best parameters
175
+ n_random_params = 50 # Number of random parameters to generate
176
+ n_best_to_pick = 5 # Number of best parameters to pick after each run
177
+ max_initial = 10000 # Maximum value to initialize the parameters
178
+
179
+ # Initializes the parameters
180
+ base_params = { p: init_ones(x, **kwargs) for p in params_list if p != param }
181
+ params = _build_initial_param(x, max_initial, n_random_params)
182
+
183
+ # Performs the search
184
+ for _ in range(n_runs):
185
+
186
+ best_params = []
187
+ for param_ in params:
188
+ try:
189
+ x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: param_})
190
+ loss_ones = nn.MSELoss()(x, x_)
191
+
192
+ if len(best_params) < n_best_to_pick:
193
+ best_params.append((param_, loss_ones.item()))
194
+ best_params = sorted(best_params, key=lambda x: x[1])
195
+ elif loss_ones < best_params[-1][1]:
196
+ best_params[-1] = (param_, loss_ones.item())
197
+ best_params = sorted(best_params, key=lambda x: x[1])
198
+
199
+ except Exception: # The parameters might not be valid for the function's domain
200
+ continue
201
+
202
+ # Generates new parameters around the mean
203
+ params = _search_param([p for p, _ in best_params], n_random_params)
204
+
205
+ # Checks if the best parameter is better than the init_ones
206
+ p_ones = init_ones(x, **kwargs)
207
+ x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: p_ones})
208
+ loss_ones = nn.MSELoss()(x, x_)
209
+
210
+ # Checks if the best parameter is better than the init_rand
211
+ p_rand = init_rand(x, **kwargs)
212
+ x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: p_rand})
213
+ loss_rand = nn.MSELoss()(x, x_)
214
+
215
+ if loss_rand < best_params[0][1] and loss_rand < loss_ones:
216
+ return p_rand
217
+ elif loss_ones < best_params[0][1] and loss_ones < loss_rand:
218
+ return p_ones
219
+ else:
220
+ return best_params[0][0]
221
+
222
+
223
+ def init_linear_scale( # Symmetric scale. From the study folder
224
+ x: torch.Tensor,
225
+ **kwargs: Dict[str, Any],
226
+ ) -> torch.Tensor:
227
+ assert "bits" in kwargs, "bits must be provided."
228
+ assert "params" in kwargs, "params must be provided."
229
+ assert "qtz_func" in kwargs, "qtz_func must be provided."
230
+
231
+ bits = kwargs.get('bits')
232
+ params = kwargs.get('params')
233
+ qtz_func = kwargs.get('qtz_func')
234
+
235
+ x_ = x.transpose(0, 1)
236
+ x_ = qtz_func(x=x_, **params, _s=init_ones(x, **kwargs))
237
+ x_ = x_.transpose(0, 1)
238
+
239
+ quant_min, quant_max = get_min_max_from_bits_signed(bits)
240
+ min_vals, max_vals = torch.aminmax(x_, dim=1)
241
+ min_vals = torch.min(min_vals, torch.zeros_like(min_vals))
242
+ max_vals = torch.max(max_vals, torch.zeros_like(max_vals))
243
+
244
+ eps = torch.finfo(torch.float32).eps
245
+
246
+ abs_max_val_per_ch = torch.max(-min_vals, max_vals)
247
+ scale = abs_max_val_per_ch / (float(quant_max - quant_min) / 2)
248
+
249
+ scale = torch.clamp(scale, min=eps).to(dtype=torch.float32, device=min_vals.device)
250
+
251
+ return scale
252
+
253
+ # Introduces some noise in scale
254
+ # If I don't introduce noise, the accuracy is going to be 0.0 and not learn anything
255
+ scale = scale + 0.01 * torch.randn_like(scale)
256
+ return scale
257
+
258
+
259
+ def init_non_linear_regression_fit(
260
+ x: torch.Tensor,
261
+ **kwargs: Dict[str, Any],
262
+ ) -> torch.Tensor:
263
+
264
+ assert "params_list" in kwargs, "params list must be provided."
265
+ assert "np_fit_func" in kwargs, "np_fit_func must be provided."
266
+ assert "p0" in kwargs, "p0 must be provided."
267
+ np_fit_func = kwargs.get('np_fit_func')
268
+ params_list = kwargs.get('params_list')
269
+ p0 = kwargs.get('p0')
270
+
271
+ def _fit(xdata: np.ndarray, ydata: np.ndarray, func: Callable, p0: List[float]):
272
+ popt, _ = curve_fit(
273
+ func,
274
+ xdata,
275
+ ydata,
276
+ maxfev=1000,
277
+ p0=p0,
278
+ method='lm'
279
+ )
280
+ return popt
281
+
282
+ # 1. Needs to convert the torch tensor to numpy tensor
283
+ xdata = x.cpu().numpy()
284
+
285
+ # 2. Sorts the data so that it makes it easier to fit to it
286
+ sorted_xdata = np.sort(xdata, axis=-1)
287
+
288
+ p0 = {k: v.cpu().numpy() for k, v in p0.items()}
289
+ params_list = sorted(params_list) # We need to make sure that it matches the numpy fit func arg order
290
+
291
+ # 3. Finds the best parameters for each channel
292
+ try:
293
+ params = []
294
+ for i in range(sorted_xdata.shape[0]):
295
+ xdata_ = sorted_xdata[i]
296
+ p0_ = [p0[p][i] for p in params_list]
297
+ ch_params = _fit(xdata_, xdata_, np_fit_func, p0_)
298
+ params.append(ch_params)
299
+
300
+ # 4. Builds the parameters
301
+ result = {}
302
+ for i, p in enumerate(params_list):
303
+ result[p] = torch.tensor([p_[i] for p_ in params], dtype=torch.float32).to(x.device)
304
+
305
+ return result
306
+
307
+ except ValueError as e:
308
+ print(f"Could not fit the function with error: {e}")
309
+ print(f"Using fallback result...")
310
+ return {
311
+ k: torch.tensor(v, dtype=torch.float32).to(x.device) for k, v in p0.items()
312
+ }
313
+
314
+
315
+ def init_zeros(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor:
316
+ val = torch.amin(x, dim=1)
317
+ return torch.zeros_like(val, dtype=torch.float32, device=x.device)
318
+
319
+
320
+ def init_inner_scale(tensor: torch.Tensor, _min: float = torch.inf, _max: float = torch.inf) -> torch.Tensor:
321
+ # Calculate the original minimum and maximum values
322
+ min_vals, max_vals = torch.aminmax(tensor, dim=-1)
323
+ x_min = torch.min(min_vals, torch.zeros_like(min_vals))
324
+ x_max = torch.max(max_vals, torch.zeros_like(max_vals))
325
+
326
+ if _max is torch.inf: # We do not need to scale the tensor. Just need to move it
327
+ return torch.ones_like(x_min)
328
+
329
+ # Calculate the scale factor
330
+ scale = (_max - _min) / (x_max - x_min)
331
+ return scale
332
+
333
+
334
+
335
+ ############## Quant ###############
336
+
337
+ @torch.enable_grad()
338
+ def learn_parameters(
339
+ x: torch.Tensor,
340
+ params: Dict[str, nn.Parameter],
341
+ qtz_func: nn.Module,
342
+ deqtz_func: nn.Module,
343
+ bits: int,
344
+ target_dtype: torch.dtype,
345
+ epochs: int = 1000,
346
+ early_stop: bool = True,
347
+ do_report: bool = False
348
+ ) -> Tuple[Dict[str, nn.Parameter], torch.Tensor]:
349
+
350
+ # Requires gradients in the parameters
351
+ for p in params.values():
352
+ p.requires_grad = True
353
+ p.grad = None
354
+
355
+ param_keys = list(params.keys())
356
+ param_values = list(params.values())
357
+
358
+ # Defines optimizer and loss function
359
+ optimizer = torch.optim.Adam(param_values, lr=0.001)
360
+ scheduler = torch.optim.lr_scheduler.LinearLR(optimizer, start_factor=1.0, end_factor=0.001, total_iters=epochs // 10)
361
+ loss_fn = nn.MSELoss()
362
+
363
+ # Contains the best loss and the best parameters
364
+ best_loss = float("inf")
365
+ best_params = None
366
+
367
+ # Used to stop the search early
368
+ min_delta = 1e-7
369
+ acc_loss = []
370
+ percent_epochs_before_stop = 0.1
371
+
372
+ for i in range(epochs):
373
+ optimizer.zero_grad()
374
+
375
+ quant = quantize(x, params, qtz_func, bits, target_dtype)
376
+ dequant = dequantize(quant, params, deqtz_func, bits, x.dtype)
377
+ loss = loss_fn(x, dequant)
378
+
379
+ if loss.isnan() or loss.isinf():
380
+ raise Exception("Loss is NaN or Inf. Stopping the search.")
381
+
382
+ loss.backward()
383
+ optimizer.step()
384
+ scheduler.step()
385
+
386
+ acc_loss.append(loss.item())
387
+
388
+ # Reports loss every 10 steps
389
+ if i % 10 == 0 and do_report:
390
+ print(f"Epoch {i}: Loss {loss.item()}")
391
+
392
+ # Optimizes the parameter search by storing the best loss and the parameters
393
+ if loss.item() < best_loss:
394
+ best_loss = loss.item()
395
+ best_params = copy.deepcopy({
396
+ k: v for k, v in params.items() if k in param_keys
397
+ })
398
+
399
+ # We also stop the search if the loss has not considerably during the last 10% epochs
400
+ if early_stop:
401
+ epochs_before_stop = int(epochs * percent_epochs_before_stop)
402
+ if i > epochs_before_stop and abs(acc_loss[i - epochs_before_stop] - acc_loss[i]) < min_delta:
403
+ break
404
+
405
+ # No longer requires gradients in the parameters
406
+ for p in best_params.values():
407
+ p.requires_grad = False
408
+ p.grad = None
409
+
410
+ if do_report:
411
+ return best_params, acc_loss
412
+ else:
413
+ return best_params
414
+
415
+
416
+ def quantize(
417
+ x: torch.Tensor,
418
+ params: Dict[str, nn.Parameter],
419
+ func: nn.Module,
420
+ bits: int,
421
+ target_dtype: torch.dtype = torch.int8
422
+ ) -> torch.Tensor:
423
+ quant_min, quant_max = get_min_max_from_bits_signed(bits)
424
+ x = x.transpose(0, 1) # Aligns shapes
425
+ x = func(x=x, **params)
426
+ x = x.transpose(0, 1)
427
+ x = torch.clamp(round_func_BPDA(x), quant_min, quant_max).to(target_dtype)
428
+ return x
429
+
430
+
431
+ def dequantize(
432
+ x: torch.Tensor,
433
+ params: Dict[str, nn.Parameter],
434
+ func: nn.Module,
435
+ bits: int,
436
+ out_dtype: torch.dtype
437
+ ) -> torch.Tensor:
438
+ x = x.to(dtype=out_dtype)
439
+ x = x.transpose(0, 1)
440
+ x = func(x=x, **params)
441
+ x = x.transpose(0, 1)
442
+ return x
443
+
444
+
445
+ def round_func_BPDA(input):
446
+ # This is equivalent to replacing round function (non-differentiable) with
447
+ # an identity function (differentiable) only when backward.
448
+ forward_value = torch.round(input)
449
+ out = input.clone()
450
+ out.data = forward_value.data
451
+ return out
452
+
453
+
454
+ def get_min_max_from_bits_signed(bit_width: int) -> Tuple[int, int]:
455
+ return -2 ** (bit_width - 1), 2 ** (bit_width - 1) - 1
456
+
457
+
458
+
459
+ ############## Numpy ###############
460
+
461
+ def np_domain_guard(
462
+ x: np.ndarray,
463
+ min: float = None,
464
+ max: float = None,
465
+ posinf: float = None,
466
+ neginf: float = None,
467
+ nan: float = None
468
+ ) -> np.ndarray:
469
+ """Guard a tensor to a valid domain."""
470
+ x = np.nan_to_num(x, posinf=posinf, neginf=neginf, nan=nan)
471
+ if min is not None or max is not None:
472
+ x = np.clip(x, min, max)
473
+ return x
474
+
475
+
476
+ def np_replace_num(x: np.ndarray, num: float, to: float) -> np.ndarray:
477
+ """Replace a number in a tensor with another number.
478
+
479
+ Args:
480
+ x (np.ndarray): The input tensor.
481
+ num (float): The number to replace.
482
+ to (float): The number to replace with.
483
+
484
+ Returns:
485
+ np.ndarray: The tensor with the number replaced.
486
+ """
487
+ return np.where(x == num, to, x)
488
+
489
+
490
+ def np_guarded_power(x: np.ndarray, exp: float) -> np.ndarray:
491
+ """Guard the power operation to a valid domain."""
492
+ return np.power(x, exp) if exp >= 1 else np.power(np.maximum(x, 0), exp)
493
+
fn_gen/ones_affine_scale/11/loss.png ADDED
fn_gen/ones_affine_scale/11/quantization.png ADDED
fn_gen/ones_affine_scale/12/distortion.png ADDED
fn_gen/ones_affine_scale/12/expressions.txt ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ acos(_0*x)/_s
2
+ cos(_s*x)/_0
fn_gen/ones_affine_scale/12/fn.py ADDED
@@ -0,0 +1,493 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import torch
4
+ from torch import amin # Necessary for arcsin
5
+ import copy
6
+ import torch.nn as nn
7
+ import numpy as np
8
+
9
+ from scipy.optimize import curve_fit
10
+ from typing import Dict, Any, Tuple, List, Callable
11
+
12
+
13
+ def quantization(x, **params):
14
+ return (torch.div(1, replace_num(params['_s'], num=0, to=10000)) * torch.acos(domain_guard((params['_0'] * x), min=-0.99999, max=0.99999, nan=0)))
15
+
16
+
17
+ def dequantization(x, **params):
18
+ return (torch.div(1, replace_num(params['_0'], num=0, to=10000)) * torch.cos((params['_s'] * x)))
19
+
20
+
21
+ def init_linear_scale( # Symmetric scale. From the study folder
22
+ x: torch.Tensor,
23
+ **kwargs: Dict[str, Any],
24
+ ) -> torch.Tensor:
25
+ assert "bits" in kwargs, "bits must be provided."
26
+ assert "params" in kwargs, "params must be provided."
27
+ assert "qtz_func" in kwargs, "qtz_func must be provided."
28
+
29
+ bits = kwargs.get('bits')
30
+ params = kwargs.get('params')
31
+ qtz_func = kwargs.get('qtz_func')
32
+
33
+ x_ = x.transpose(0, 1)
34
+ x_ = qtz_func(x=x_, **params, _s=init_ones(x, **kwargs))
35
+ x_ = x_.transpose(0, 1)
36
+
37
+ quant_min, quant_max = get_min_max_from_bits_signed(bits)
38
+ min_vals, max_vals = torch.aminmax(x_, dim=1)
39
+ min_vals = torch.min(min_vals, torch.zeros_like(min_vals))
40
+ max_vals = torch.max(max_vals, torch.zeros_like(max_vals))
41
+
42
+ eps = torch.finfo(torch.float32).eps
43
+
44
+ abs_max_val_per_ch = torch.max(-min_vals, max_vals)
45
+ scale = abs_max_val_per_ch / (float(quant_max - quant_min) / 2)
46
+
47
+ scale = torch.clamp(scale, min=eps).to(dtype=torch.float32, device=min_vals.device)
48
+
49
+ return scale
50
+
51
+ # Introduces some noise in scale
52
+ # If I don't introduce noise, the accuracy is going to be 0.0 and not learn anything
53
+ scale = scale + 0.01 * torch.randn_like(scale)
54
+ return scale
55
+
56
+
57
+ def init_params(x: torch.Tensor, **kwargs: Dict[str, Any]) -> Dict[str, nn.Parameter]:
58
+ params = {
59
+ '_0': init_ones(x, qtz_func=quantization, deqtz_func=dequantization, param='_0', params_list=['_0', '_s'], **kwargs),
60
+ }
61
+ params['_s'] = init_linear_scale(x, params=params, qtz_func=quantization, **kwargs)
62
+ params = {k: nn.Parameter(v, requires_grad=False) for k, v in params.items()}
63
+
64
+ if 'post_init_hook' in kwargs:
65
+ kwargs['post_init_hook'](parameters=params)
66
+
67
+
68
+ if 'post_train_hook' in kwargs:
69
+ kwargs['post_train_hook'](parameters=params)
70
+
71
+ return params
72
+
73
+
74
+ ############### Numpy Qtz ###############
75
+
76
+
77
+ def np_quantization(x, _0, _s):
78
+ return (np.divide(1, np_replace_num(_s, num=0, to=10000)) * np.arccos(np_domain_guard((_0 * x), min=-0.99999, max=0.99999, nan=0)))
79
+
80
+
81
+ def np_dequantization(x, _0, _s):
82
+ return (np.divide(1, np_replace_num(_0, num=0, to=10000)) * np.cos((_s * x)))
83
+
84
+
85
+ def fit_func(x, _0, _s):
86
+ x_ = np_quantization(x, _0, _s)
87
+ x_ = np_dequantization(x_, _0, _s)
88
+ return x_
89
+
90
+
91
+
92
+ ############### HELPERS ###############
93
+
94
+ def domain_guard(
95
+ x: torch.Tensor,
96
+ min: float = None,
97
+ max: float = None,
98
+ posinf: float = None,
99
+ neginf: float = None,
100
+ nan: float = None
101
+ ) -> torch.Tensor:
102
+ """Guard a tensor to a valid domain."""
103
+ x = torch.nan_to_num(x, posinf=posinf, neginf=neginf, nan=nan)
104
+ if min is not None or max is not None:
105
+ x = torch.clamp(x, min=min, max=max)
106
+ return x
107
+
108
+
109
+ def replace_num(x: torch.Tensor, num: float, to: float) -> torch.Tensor:
110
+ """Replace a number in a tensor with another number.
111
+
112
+ Args:
113
+ x (torch.Tensor): The input tensor.
114
+ num (float): The number to replace.
115
+ to (float): The number to replace with.
116
+
117
+ Returns:
118
+ torch.Tensor: The tensor with the number replaced.
119
+ """
120
+ return torch.where(x == num, to, x)
121
+
122
+
123
+ def guarded_torch_power(x: torch.Tensor, exp: float) -> torch.Tensor:
124
+ """Guard the power operation to a valid domain."""
125
+ return torch.pow(x, exp) if exp >= 1 else torch.pow(torch.relu(x), exp)
126
+
127
+
128
+ def init_ones(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor:
129
+ val = torch.amin(x, dim=1)
130
+ return torch.ones_like(val, dtype=torch.float32, device=x.device)
131
+
132
+
133
+ def init_rand(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor:
134
+ val = torch.amin(x, dim=1)
135
+ return torch.randn_like(val, dtype=torch.float32, device=x.device)
136
+
137
+
138
+ def init_space_search(
139
+ x: torch.Tensor,
140
+ **kwargs: Dict[str, Any],
141
+ ) -> torch.Tensor:
142
+
143
+ def _build_initial_param(tensor: torch.Tensor, max_initial: int, n_params: int):
144
+ """Generates the initial set of parameters. The first iteration generates 10 times more parameters."""
145
+ for _ in range(n_params * 10): # The first iteration generates 10 times more parameters
146
+ yield init_rand(tensor) * max_initial # Generates n_params in range [-max_initial, max_initial]
147
+
148
+ def _search_param(tensors: List[torch.tensor], n_params):
149
+ """Takes the best parameters and generates new parameters around the mean of the best parameters."""
150
+ torch_tensors = torch.stack(tensors)
151
+ min_vals, max_vals = torch.aminmax(torch_tensors, dim=0)
152
+ abs_max_val_per_ch = torch.max(-min_vals, max_vals)
153
+ mean = torch.mean(torch_tensors, dim=0)
154
+ for _ in range(n_params): # Generates n_params around the mean of the tensors
155
+ yield torch.randn_like(min_vals) * abs_max_val_per_ch + mean
156
+
157
+ def _calc(x, qtz_func, deqtz_func, **params):
158
+ x_ = x.transpose(0, 1)
159
+ x_ = qtz_func(x=x_, **params)
160
+ x_ = deqtz_func(x=x_, **params)
161
+ x_ = x_.transpose(0, 1)
162
+ return x_
163
+
164
+ assert "qtz_func" in kwargs, "qtz_func must be provided."
165
+ assert "deqtz_func" in kwargs, "deqtz_func must be provided."
166
+ assert "params_list" in kwargs, "params list must be provided."
167
+ assert "param" in kwargs, "param must be provided."
168
+
169
+ qtz_func = kwargs.get('qtz_func')
170
+ deqtz_func = kwargs.get('deqtz_func')
171
+ params_list = kwargs.get('params_list')
172
+ param = kwargs.get('param')
173
+
174
+ n_runs = 50 # Number of runs to try to find the best parameters
175
+ n_random_params = 50 # Number of random parameters to generate
176
+ n_best_to_pick = 5 # Number of best parameters to pick after each run
177
+ max_initial = 10000 # Maximum value to initialize the parameters
178
+
179
+ # Initializes the parameters
180
+ base_params = { p: init_ones(x, **kwargs) for p in params_list if p != param }
181
+ params = _build_initial_param(x, max_initial, n_random_params)
182
+
183
+ # Performs the search
184
+ for _ in range(n_runs):
185
+
186
+ best_params = []
187
+ for param_ in params:
188
+ try:
189
+ x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: param_})
190
+ loss_ones = nn.MSELoss()(x, x_)
191
+
192
+ if len(best_params) < n_best_to_pick:
193
+ best_params.append((param_, loss_ones.item()))
194
+ best_params = sorted(best_params, key=lambda x: x[1])
195
+ elif loss_ones < best_params[-1][1]:
196
+ best_params[-1] = (param_, loss_ones.item())
197
+ best_params = sorted(best_params, key=lambda x: x[1])
198
+
199
+ except Exception: # The parameters might not be valid for the function's domain
200
+ continue
201
+
202
+ # Generates new parameters around the mean
203
+ params = _search_param([p for p, _ in best_params], n_random_params)
204
+
205
+ # Checks if the best parameter is better than the init_ones
206
+ p_ones = init_ones(x, **kwargs)
207
+ x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: p_ones})
208
+ loss_ones = nn.MSELoss()(x, x_)
209
+
210
+ # Checks if the best parameter is better than the init_rand
211
+ p_rand = init_rand(x, **kwargs)
212
+ x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: p_rand})
213
+ loss_rand = nn.MSELoss()(x, x_)
214
+
215
+ if loss_rand < best_params[0][1] and loss_rand < loss_ones:
216
+ return p_rand
217
+ elif loss_ones < best_params[0][1] and loss_ones < loss_rand:
218
+ return p_ones
219
+ else:
220
+ return best_params[0][0]
221
+
222
+
223
+ def init_linear_scale( # Symmetric scale. From the study folder
224
+ x: torch.Tensor,
225
+ **kwargs: Dict[str, Any],
226
+ ) -> torch.Tensor:
227
+ assert "bits" in kwargs, "bits must be provided."
228
+ assert "params" in kwargs, "params must be provided."
229
+ assert "qtz_func" in kwargs, "qtz_func must be provided."
230
+
231
+ bits = kwargs.get('bits')
232
+ params = kwargs.get('params')
233
+ qtz_func = kwargs.get('qtz_func')
234
+
235
+ x_ = x.transpose(0, 1)
236
+ x_ = qtz_func(x=x_, **params, _s=init_ones(x, **kwargs))
237
+ x_ = x_.transpose(0, 1)
238
+
239
+ quant_min, quant_max = get_min_max_from_bits_signed(bits)
240
+ min_vals, max_vals = torch.aminmax(x_, dim=1)
241
+ min_vals = torch.min(min_vals, torch.zeros_like(min_vals))
242
+ max_vals = torch.max(max_vals, torch.zeros_like(max_vals))
243
+
244
+ eps = torch.finfo(torch.float32).eps
245
+
246
+ abs_max_val_per_ch = torch.max(-min_vals, max_vals)
247
+ scale = abs_max_val_per_ch / (float(quant_max - quant_min) / 2)
248
+
249
+ scale = torch.clamp(scale, min=eps).to(dtype=torch.float32, device=min_vals.device)
250
+
251
+ return scale
252
+
253
+ # Introduces some noise in scale
254
+ # If I don't introduce noise, the accuracy is going to be 0.0 and not learn anything
255
+ scale = scale + 0.01 * torch.randn_like(scale)
256
+ return scale
257
+
258
+
259
+ def init_non_linear_regression_fit(
260
+ x: torch.Tensor,
261
+ **kwargs: Dict[str, Any],
262
+ ) -> torch.Tensor:
263
+
264
+ assert "params_list" in kwargs, "params list must be provided."
265
+ assert "np_fit_func" in kwargs, "np_fit_func must be provided."
266
+ assert "p0" in kwargs, "p0 must be provided."
267
+ np_fit_func = kwargs.get('np_fit_func')
268
+ params_list = kwargs.get('params_list')
269
+ p0 = kwargs.get('p0')
270
+
271
+ def _fit(xdata: np.ndarray, ydata: np.ndarray, func: Callable, p0: List[float]):
272
+ popt, _ = curve_fit(
273
+ func,
274
+ xdata,
275
+ ydata,
276
+ maxfev=1000,
277
+ p0=p0,
278
+ method='lm'
279
+ )
280
+ return popt
281
+
282
+ # 1. Needs to convert the torch tensor to numpy tensor
283
+ xdata = x.cpu().numpy()
284
+
285
+ # 2. Sorts the data so that it makes it easier to fit to it
286
+ sorted_xdata = np.sort(xdata, axis=-1)
287
+
288
+ p0 = {k: v.cpu().numpy() for k, v in p0.items()}
289
+ params_list = sorted(params_list) # We need to make sure that it matches the numpy fit func arg order
290
+
291
+ # 3. Finds the best parameters for each channel
292
+ try:
293
+ params = []
294
+ for i in range(sorted_xdata.shape[0]):
295
+ xdata_ = sorted_xdata[i]
296
+ p0_ = [p0[p][i] for p in params_list]
297
+ ch_params = _fit(xdata_, xdata_, np_fit_func, p0_)
298
+ params.append(ch_params)
299
+
300
+ # 4. Builds the parameters
301
+ result = {}
302
+ for i, p in enumerate(params_list):
303
+ result[p] = torch.tensor([p_[i] for p_ in params], dtype=torch.float32).to(x.device)
304
+
305
+ return result
306
+
307
+ except ValueError as e:
308
+ print(f"Could not fit the function with error: {e}")
309
+ print(f"Using fallback result...")
310
+ return {
311
+ k: torch.tensor(v, dtype=torch.float32).to(x.device) for k, v in p0.items()
312
+ }
313
+
314
+
315
+ def init_zeros(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor:
316
+ val = torch.amin(x, dim=1)
317
+ return torch.zeros_like(val, dtype=torch.float32, device=x.device)
318
+
319
+
320
+ def init_inner_scale(tensor: torch.Tensor, _min: float = torch.inf, _max: float = torch.inf) -> torch.Tensor:
321
+ # Calculate the original minimum and maximum values
322
+ min_vals, max_vals = torch.aminmax(tensor, dim=-1)
323
+ x_min = torch.min(min_vals, torch.zeros_like(min_vals))
324
+ x_max = torch.max(max_vals, torch.zeros_like(max_vals))
325
+
326
+ if _max is torch.inf: # We do not need to scale the tensor. Just need to move it
327
+ return torch.ones_like(x_min)
328
+
329
+ # Calculate the scale factor
330
+ scale = (_max - _min) / (x_max - x_min)
331
+ return scale
332
+
333
+
334
+
335
+ ############## Quant ###############
336
+
337
+ @torch.enable_grad()
338
+ def learn_parameters(
339
+ x: torch.Tensor,
340
+ params: Dict[str, nn.Parameter],
341
+ qtz_func: nn.Module,
342
+ deqtz_func: nn.Module,
343
+ bits: int,
344
+ target_dtype: torch.dtype,
345
+ epochs: int = 1000,
346
+ early_stop: bool = True,
347
+ do_report: bool = False
348
+ ) -> Tuple[Dict[str, nn.Parameter], torch.Tensor]:
349
+
350
+ # Requires gradients in the parameters
351
+ for p in params.values():
352
+ p.requires_grad = True
353
+ p.grad = None
354
+
355
+ param_keys = list(params.keys())
356
+ param_values = list(params.values())
357
+
358
+ # Defines optimizer and loss function
359
+ optimizer = torch.optim.Adam(param_values, lr=0.001)
360
+ scheduler = torch.optim.lr_scheduler.LinearLR(optimizer, start_factor=1.0, end_factor=0.001, total_iters=epochs // 10)
361
+ loss_fn = nn.MSELoss()
362
+
363
+ # Contains the best loss and the best parameters
364
+ best_loss = float("inf")
365
+ best_params = None
366
+
367
+ # Used to stop the search early
368
+ min_delta = 1e-7
369
+ acc_loss = []
370
+ percent_epochs_before_stop = 0.1
371
+
372
+ for i in range(epochs):
373
+ optimizer.zero_grad()
374
+
375
+ quant = quantize(x, params, qtz_func, bits, target_dtype)
376
+ dequant = dequantize(quant, params, deqtz_func, bits, x.dtype)
377
+ loss = loss_fn(x, dequant)
378
+
379
+ if loss.isnan() or loss.isinf():
380
+ raise Exception("Loss is NaN or Inf. Stopping the search.")
381
+
382
+ loss.backward()
383
+ optimizer.step()
384
+ scheduler.step()
385
+
386
+ acc_loss.append(loss.item())
387
+
388
+ # Reports loss every 10 steps
389
+ if i % 10 == 0 and do_report:
390
+ print(f"Epoch {i}: Loss {loss.item()}")
391
+
392
+ # Optimizes the parameter search by storing the best loss and the parameters
393
+ if loss.item() < best_loss:
394
+ best_loss = loss.item()
395
+ best_params = copy.deepcopy({
396
+ k: v for k, v in params.items() if k in param_keys
397
+ })
398
+
399
+ # We also stop the search if the loss has not considerably during the last 10% epochs
400
+ if early_stop:
401
+ epochs_before_stop = int(epochs * percent_epochs_before_stop)
402
+ if i > epochs_before_stop and abs(acc_loss[i - epochs_before_stop] - acc_loss[i]) < min_delta:
403
+ break
404
+
405
+ # No longer requires gradients in the parameters
406
+ for p in best_params.values():
407
+ p.requires_grad = False
408
+ p.grad = None
409
+
410
+ if do_report:
411
+ return best_params, acc_loss
412
+ else:
413
+ return best_params
414
+
415
+
416
+ def quantize(
417
+ x: torch.Tensor,
418
+ params: Dict[str, nn.Parameter],
419
+ func: nn.Module,
420
+ bits: int,
421
+ target_dtype: torch.dtype = torch.int8
422
+ ) -> torch.Tensor:
423
+ quant_min, quant_max = get_min_max_from_bits_signed(bits)
424
+ x = x.transpose(0, 1) # Aligns shapes
425
+ x = func(x=x, **params)
426
+ x = x.transpose(0, 1)
427
+ x = torch.clamp(round_func_BPDA(x), quant_min, quant_max).to(target_dtype)
428
+ return x
429
+
430
+
431
+ def dequantize(
432
+ x: torch.Tensor,
433
+ params: Dict[str, nn.Parameter],
434
+ func: nn.Module,
435
+ bits: int,
436
+ out_dtype: torch.dtype
437
+ ) -> torch.Tensor:
438
+ x = x.to(dtype=out_dtype)
439
+ x = x.transpose(0, 1)
440
+ x = func(x=x, **params)
441
+ x = x.transpose(0, 1)
442
+ return x
443
+
444
+
445
+ def round_func_BPDA(input):
446
+ # This is equivalent to replacing round function (non-differentiable) with
447
+ # an identity function (differentiable) only when backward.
448
+ forward_value = torch.round(input)
449
+ out = input.clone()
450
+ out.data = forward_value.data
451
+ return out
452
+
453
+
454
+ def get_min_max_from_bits_signed(bit_width: int) -> Tuple[int, int]:
455
+ return -2 ** (bit_width - 1), 2 ** (bit_width - 1) - 1
456
+
457
+
458
+
459
+ ############## Numpy ###############
460
+
461
+ def np_domain_guard(
462
+ x: np.ndarray,
463
+ min: float = None,
464
+ max: float = None,
465
+ posinf: float = None,
466
+ neginf: float = None,
467
+ nan: float = None
468
+ ) -> np.ndarray:
469
+ """Guard a tensor to a valid domain."""
470
+ x = np.nan_to_num(x, posinf=posinf, neginf=neginf, nan=nan)
471
+ if min is not None or max is not None:
472
+ x = np.clip(x, min, max)
473
+ return x
474
+
475
+
476
+ def np_replace_num(x: np.ndarray, num: float, to: float) -> np.ndarray:
477
+ """Replace a number in a tensor with another number.
478
+
479
+ Args:
480
+ x (np.ndarray): The input tensor.
481
+ num (float): The number to replace.
482
+ to (float): The number to replace with.
483
+
484
+ Returns:
485
+ np.ndarray: The tensor with the number replaced.
486
+ """
487
+ return np.where(x == num, to, x)
488
+
489
+
490
+ def np_guarded_power(x: np.ndarray, exp: float) -> np.ndarray:
491
+ """Guard the power operation to a valid domain."""
492
+ return np.power(x, exp) if exp >= 1 else np.power(np.maximum(x, 0), exp)
493
+
fn_gen/ones_affine_scale/12/loss.png ADDED
fn_gen/ones_affine_scale/12/quantization.png ADDED
fn_gen/ones_affine_scale/13/distortion.png ADDED
fn_gen/ones_affine_scale/13/expressions.txt ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ acosh(_0*x)/_s
2
+ cosh(_s*x)/_0
fn_gen/ones_affine_scale/13/fn.py ADDED
@@ -0,0 +1,493 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import torch
4
+ from torch import amin # Necessary for arcsin
5
+ import copy
6
+ import torch.nn as nn
7
+ import numpy as np
8
+
9
+ from scipy.optimize import curve_fit
10
+ from typing import Dict, Any, Tuple, List, Callable
11
+
12
+
13
+ def quantization(x, **params):
14
+ return (torch.div(1, replace_num(params['_s'], num=0, to=10000)) * torch.acosh(domain_guard((params['_0'] * x), min=1, nan=1)))
15
+
16
+
17
+ def dequantization(x, **params):
18
+ return (torch.div(1, replace_num(params['_0'], num=0, to=10000)) * torch.cosh((params['_s'] * x)))
19
+
20
+
21
+ def init_linear_scale( # Symmetric scale. From the study folder
22
+ x: torch.Tensor,
23
+ **kwargs: Dict[str, Any],
24
+ ) -> torch.Tensor:
25
+ assert "bits" in kwargs, "bits must be provided."
26
+ assert "params" in kwargs, "params must be provided."
27
+ assert "qtz_func" in kwargs, "qtz_func must be provided."
28
+
29
+ bits = kwargs.get('bits')
30
+ params = kwargs.get('params')
31
+ qtz_func = kwargs.get('qtz_func')
32
+
33
+ x_ = x.transpose(0, 1)
34
+ x_ = qtz_func(x=x_, **params, _s=init_ones(x, **kwargs))
35
+ x_ = x_.transpose(0, 1)
36
+
37
+ quant_min, quant_max = get_min_max_from_bits_signed(bits)
38
+ min_vals, max_vals = torch.aminmax(x_, dim=1)
39
+ min_vals = torch.min(min_vals, torch.zeros_like(min_vals))
40
+ max_vals = torch.max(max_vals, torch.zeros_like(max_vals))
41
+
42
+ eps = torch.finfo(torch.float32).eps
43
+
44
+ abs_max_val_per_ch = torch.max(-min_vals, max_vals)
45
+ scale = abs_max_val_per_ch / (float(quant_max - quant_min) / 2)
46
+
47
+ scale = torch.clamp(scale, min=eps).to(dtype=torch.float32, device=min_vals.device)
48
+
49
+ return scale
50
+
51
+ # Introduces some noise in scale
52
+ # If I don't introduce noise, the accuracy is going to be 0.0 and not learn anything
53
+ scale = scale + 0.01 * torch.randn_like(scale)
54
+ return scale
55
+
56
+
57
+ def init_params(x: torch.Tensor, **kwargs: Dict[str, Any]) -> Dict[str, nn.Parameter]:
58
+ params = {
59
+ '_0': init_ones(x, qtz_func=quantization, deqtz_func=dequantization, param='_0', params_list=['_0', '_s'], **kwargs),
60
+ }
61
+ params['_s'] = init_linear_scale(x, params=params, qtz_func=quantization, **kwargs)
62
+ params = {k: nn.Parameter(v, requires_grad=False) for k, v in params.items()}
63
+
64
+ if 'post_init_hook' in kwargs:
65
+ kwargs['post_init_hook'](parameters=params)
66
+
67
+
68
+ if 'post_train_hook' in kwargs:
69
+ kwargs['post_train_hook'](parameters=params)
70
+
71
+ return params
72
+
73
+
74
+ ############### Numpy Qtz ###############
75
+
76
+
77
+ def np_quantization(x, _0, _s):
78
+ return (np.divide(1, np_replace_num(_s, num=0, to=10000)) * np.arccosh(np_domain_guard((_0 * x), min=1, nan=1)))
79
+
80
+
81
+ def np_dequantization(x, _0, _s):
82
+ return (np.divide(1, np_replace_num(_0, num=0, to=10000)) * np.cosh((_s * x)))
83
+
84
+
85
+ def fit_func(x, _0, _s):
86
+ x_ = np_quantization(x, _0, _s)
87
+ x_ = np_dequantization(x_, _0, _s)
88
+ return x_
89
+
90
+
91
+
92
+ ############### HELPERS ###############
93
+
94
+ def domain_guard(
95
+ x: torch.Tensor,
96
+ min: float = None,
97
+ max: float = None,
98
+ posinf: float = None,
99
+ neginf: float = None,
100
+ nan: float = None
101
+ ) -> torch.Tensor:
102
+ """Guard a tensor to a valid domain."""
103
+ x = torch.nan_to_num(x, posinf=posinf, neginf=neginf, nan=nan)
104
+ if min is not None or max is not None:
105
+ x = torch.clamp(x, min=min, max=max)
106
+ return x
107
+
108
+
109
+ def replace_num(x: torch.Tensor, num: float, to: float) -> torch.Tensor:
110
+ """Replace a number in a tensor with another number.
111
+
112
+ Args:
113
+ x (torch.Tensor): The input tensor.
114
+ num (float): The number to replace.
115
+ to (float): The number to replace with.
116
+
117
+ Returns:
118
+ torch.Tensor: The tensor with the number replaced.
119
+ """
120
+ return torch.where(x == num, to, x)
121
+
122
+
123
+ def guarded_torch_power(x: torch.Tensor, exp: float) -> torch.Tensor:
124
+ """Guard the power operation to a valid domain."""
125
+ return torch.pow(x, exp) if exp >= 1 else torch.pow(torch.relu(x), exp)
126
+
127
+
128
+ def init_ones(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor:
129
+ val = torch.amin(x, dim=1)
130
+ return torch.ones_like(val, dtype=torch.float32, device=x.device)
131
+
132
+
133
+ def init_rand(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor:
134
+ val = torch.amin(x, dim=1)
135
+ return torch.randn_like(val, dtype=torch.float32, device=x.device)
136
+
137
+
138
+ def init_space_search(
139
+ x: torch.Tensor,
140
+ **kwargs: Dict[str, Any],
141
+ ) -> torch.Tensor:
142
+
143
+ def _build_initial_param(tensor: torch.Tensor, max_initial: int, n_params: int):
144
+ """Generates the initial set of parameters. The first iteration generates 10 times more parameters."""
145
+ for _ in range(n_params * 10): # The first iteration generates 10 times more parameters
146
+ yield init_rand(tensor) * max_initial # Generates n_params in range [-max_initial, max_initial]
147
+
148
+ def _search_param(tensors: List[torch.tensor], n_params):
149
+ """Takes the best parameters and generates new parameters around the mean of the best parameters."""
150
+ torch_tensors = torch.stack(tensors)
151
+ min_vals, max_vals = torch.aminmax(torch_tensors, dim=0)
152
+ abs_max_val_per_ch = torch.max(-min_vals, max_vals)
153
+ mean = torch.mean(torch_tensors, dim=0)
154
+ for _ in range(n_params): # Generates n_params around the mean of the tensors
155
+ yield torch.randn_like(min_vals) * abs_max_val_per_ch + mean
156
+
157
+ def _calc(x, qtz_func, deqtz_func, **params):
158
+ x_ = x.transpose(0, 1)
159
+ x_ = qtz_func(x=x_, **params)
160
+ x_ = deqtz_func(x=x_, **params)
161
+ x_ = x_.transpose(0, 1)
162
+ return x_
163
+
164
+ assert "qtz_func" in kwargs, "qtz_func must be provided."
165
+ assert "deqtz_func" in kwargs, "deqtz_func must be provided."
166
+ assert "params_list" in kwargs, "params list must be provided."
167
+ assert "param" in kwargs, "param must be provided."
168
+
169
+ qtz_func = kwargs.get('qtz_func')
170
+ deqtz_func = kwargs.get('deqtz_func')
171
+ params_list = kwargs.get('params_list')
172
+ param = kwargs.get('param')
173
+
174
+ n_runs = 50 # Number of runs to try to find the best parameters
175
+ n_random_params = 50 # Number of random parameters to generate
176
+ n_best_to_pick = 5 # Number of best parameters to pick after each run
177
+ max_initial = 10000 # Maximum value to initialize the parameters
178
+
179
+ # Initializes the parameters
180
+ base_params = { p: init_ones(x, **kwargs) for p in params_list if p != param }
181
+ params = _build_initial_param(x, max_initial, n_random_params)
182
+
183
+ # Performs the search
184
+ for _ in range(n_runs):
185
+
186
+ best_params = []
187
+ for param_ in params:
188
+ try:
189
+ x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: param_})
190
+ loss_ones = nn.MSELoss()(x, x_)
191
+
192
+ if len(best_params) < n_best_to_pick:
193
+ best_params.append((param_, loss_ones.item()))
194
+ best_params = sorted(best_params, key=lambda x: x[1])
195
+ elif loss_ones < best_params[-1][1]:
196
+ best_params[-1] = (param_, loss_ones.item())
197
+ best_params = sorted(best_params, key=lambda x: x[1])
198
+
199
+ except Exception: # The parameters might not be valid for the function's domain
200
+ continue
201
+
202
+ # Generates new parameters around the mean
203
+ params = _search_param([p for p, _ in best_params], n_random_params)
204
+
205
+ # Checks if the best parameter is better than the init_ones
206
+ p_ones = init_ones(x, **kwargs)
207
+ x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: p_ones})
208
+ loss_ones = nn.MSELoss()(x, x_)
209
+
210
+ # Checks if the best parameter is better than the init_rand
211
+ p_rand = init_rand(x, **kwargs)
212
+ x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: p_rand})
213
+ loss_rand = nn.MSELoss()(x, x_)
214
+
215
+ if loss_rand < best_params[0][1] and loss_rand < loss_ones:
216
+ return p_rand
217
+ elif loss_ones < best_params[0][1] and loss_ones < loss_rand:
218
+ return p_ones
219
+ else:
220
+ return best_params[0][0]
221
+
222
+
223
+ def init_linear_scale( # Symmetric scale. From the study folder
224
+ x: torch.Tensor,
225
+ **kwargs: Dict[str, Any],
226
+ ) -> torch.Tensor:
227
+ assert "bits" in kwargs, "bits must be provided."
228
+ assert "params" in kwargs, "params must be provided."
229
+ assert "qtz_func" in kwargs, "qtz_func must be provided."
230
+
231
+ bits = kwargs.get('bits')
232
+ params = kwargs.get('params')
233
+ qtz_func = kwargs.get('qtz_func')
234
+
235
+ x_ = x.transpose(0, 1)
236
+ x_ = qtz_func(x=x_, **params, _s=init_ones(x, **kwargs))
237
+ x_ = x_.transpose(0, 1)
238
+
239
+ quant_min, quant_max = get_min_max_from_bits_signed(bits)
240
+ min_vals, max_vals = torch.aminmax(x_, dim=1)
241
+ min_vals = torch.min(min_vals, torch.zeros_like(min_vals))
242
+ max_vals = torch.max(max_vals, torch.zeros_like(max_vals))
243
+
244
+ eps = torch.finfo(torch.float32).eps
245
+
246
+ abs_max_val_per_ch = torch.max(-min_vals, max_vals)
247
+ scale = abs_max_val_per_ch / (float(quant_max - quant_min) / 2)
248
+
249
+ scale = torch.clamp(scale, min=eps).to(dtype=torch.float32, device=min_vals.device)
250
+
251
+ return scale
252
+
253
+ # Introduces some noise in scale
254
+ # If I don't introduce noise, the accuracy is going to be 0.0 and not learn anything
255
+ scale = scale + 0.01 * torch.randn_like(scale)
256
+ return scale
257
+
258
+
259
+ def init_non_linear_regression_fit(
260
+ x: torch.Tensor,
261
+ **kwargs: Dict[str, Any],
262
+ ) -> torch.Tensor:
263
+
264
+ assert "params_list" in kwargs, "params list must be provided."
265
+ assert "np_fit_func" in kwargs, "np_fit_func must be provided."
266
+ assert "p0" in kwargs, "p0 must be provided."
267
+ np_fit_func = kwargs.get('np_fit_func')
268
+ params_list = kwargs.get('params_list')
269
+ p0 = kwargs.get('p0')
270
+
271
+ def _fit(xdata: np.ndarray, ydata: np.ndarray, func: Callable, p0: List[float]):
272
+ popt, _ = curve_fit(
273
+ func,
274
+ xdata,
275
+ ydata,
276
+ maxfev=1000,
277
+ p0=p0,
278
+ method='lm'
279
+ )
280
+ return popt
281
+
282
+ # 1. Needs to convert the torch tensor to numpy tensor
283
+ xdata = x.cpu().numpy()
284
+
285
+ # 2. Sorts the data so that it makes it easier to fit to it
286
+ sorted_xdata = np.sort(xdata, axis=-1)
287
+
288
+ p0 = {k: v.cpu().numpy() for k, v in p0.items()}
289
+ params_list = sorted(params_list) # We need to make sure that it matches the numpy fit func arg order
290
+
291
+ # 3. Finds the best parameters for each channel
292
+ try:
293
+ params = []
294
+ for i in range(sorted_xdata.shape[0]):
295
+ xdata_ = sorted_xdata[i]
296
+ p0_ = [p0[p][i] for p in params_list]
297
+ ch_params = _fit(xdata_, xdata_, np_fit_func, p0_)
298
+ params.append(ch_params)
299
+
300
+ # 4. Builds the parameters
301
+ result = {}
302
+ for i, p in enumerate(params_list):
303
+ result[p] = torch.tensor([p_[i] for p_ in params], dtype=torch.float32).to(x.device)
304
+
305
+ return result
306
+
307
+ except ValueError as e:
308
+ print(f"Could not fit the function with error: {e}")
309
+ print(f"Using fallback result...")
310
+ return {
311
+ k: torch.tensor(v, dtype=torch.float32).to(x.device) for k, v in p0.items()
312
+ }
313
+
314
+
315
+ def init_zeros(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor:
316
+ val = torch.amin(x, dim=1)
317
+ return torch.zeros_like(val, dtype=torch.float32, device=x.device)
318
+
319
+
320
+ def init_inner_scale(tensor: torch.Tensor, _min: float = torch.inf, _max: float = torch.inf) -> torch.Tensor:
321
+ # Calculate the original minimum and maximum values
322
+ min_vals, max_vals = torch.aminmax(tensor, dim=-1)
323
+ x_min = torch.min(min_vals, torch.zeros_like(min_vals))
324
+ x_max = torch.max(max_vals, torch.zeros_like(max_vals))
325
+
326
+ if _max is torch.inf: # We do not need to scale the tensor. Just need to move it
327
+ return torch.ones_like(x_min)
328
+
329
+ # Calculate the scale factor
330
+ scale = (_max - _min) / (x_max - x_min)
331
+ return scale
332
+
333
+
334
+
335
+ ############## Quant ###############
336
+
337
+ @torch.enable_grad()
338
+ def learn_parameters(
339
+ x: torch.Tensor,
340
+ params: Dict[str, nn.Parameter],
341
+ qtz_func: nn.Module,
342
+ deqtz_func: nn.Module,
343
+ bits: int,
344
+ target_dtype: torch.dtype,
345
+ epochs: int = 1000,
346
+ early_stop: bool = True,
347
+ do_report: bool = False
348
+ ) -> Tuple[Dict[str, nn.Parameter], torch.Tensor]:
349
+
350
+ # Requires gradients in the parameters
351
+ for p in params.values():
352
+ p.requires_grad = True
353
+ p.grad = None
354
+
355
+ param_keys = list(params.keys())
356
+ param_values = list(params.values())
357
+
358
+ # Defines optimizer and loss function
359
+ optimizer = torch.optim.Adam(param_values, lr=0.001)
360
+ scheduler = torch.optim.lr_scheduler.LinearLR(optimizer, start_factor=1.0, end_factor=0.001, total_iters=epochs // 10)
361
+ loss_fn = nn.MSELoss()
362
+
363
+ # Contains the best loss and the best parameters
364
+ best_loss = float("inf")
365
+ best_params = None
366
+
367
+ # Used to stop the search early
368
+ min_delta = 1e-7
369
+ acc_loss = []
370
+ percent_epochs_before_stop = 0.1
371
+
372
+ for i in range(epochs):
373
+ optimizer.zero_grad()
374
+
375
+ quant = quantize(x, params, qtz_func, bits, target_dtype)
376
+ dequant = dequantize(quant, params, deqtz_func, bits, x.dtype)
377
+ loss = loss_fn(x, dequant)
378
+
379
+ if loss.isnan() or loss.isinf():
380
+ raise Exception("Loss is NaN or Inf. Stopping the search.")
381
+
382
+ loss.backward()
383
+ optimizer.step()
384
+ scheduler.step()
385
+
386
+ acc_loss.append(loss.item())
387
+
388
+ # Reports loss every 10 steps
389
+ if i % 10 == 0 and do_report:
390
+ print(f"Epoch {i}: Loss {loss.item()}")
391
+
392
+ # Optimizes the parameter search by storing the best loss and the parameters
393
+ if loss.item() < best_loss:
394
+ best_loss = loss.item()
395
+ best_params = copy.deepcopy({
396
+ k: v for k, v in params.items() if k in param_keys
397
+ })
398
+
399
+ # We also stop the search if the loss has not considerably during the last 10% epochs
400
+ if early_stop:
401
+ epochs_before_stop = int(epochs * percent_epochs_before_stop)
402
+ if i > epochs_before_stop and abs(acc_loss[i - epochs_before_stop] - acc_loss[i]) < min_delta:
403
+ break
404
+
405
+ # No longer requires gradients in the parameters
406
+ for p in best_params.values():
407
+ p.requires_grad = False
408
+ p.grad = None
409
+
410
+ if do_report:
411
+ return best_params, acc_loss
412
+ else:
413
+ return best_params
414
+
415
+
416
+ def quantize(
417
+ x: torch.Tensor,
418
+ params: Dict[str, nn.Parameter],
419
+ func: nn.Module,
420
+ bits: int,
421
+ target_dtype: torch.dtype = torch.int8
422
+ ) -> torch.Tensor:
423
+ quant_min, quant_max = get_min_max_from_bits_signed(bits)
424
+ x = x.transpose(0, 1) # Aligns shapes
425
+ x = func(x=x, **params)
426
+ x = x.transpose(0, 1)
427
+ x = torch.clamp(round_func_BPDA(x), quant_min, quant_max).to(target_dtype)
428
+ return x
429
+
430
+
431
+ def dequantize(
432
+ x: torch.Tensor,
433
+ params: Dict[str, nn.Parameter],
434
+ func: nn.Module,
435
+ bits: int,
436
+ out_dtype: torch.dtype
437
+ ) -> torch.Tensor:
438
+ x = x.to(dtype=out_dtype)
439
+ x = x.transpose(0, 1)
440
+ x = func(x=x, **params)
441
+ x = x.transpose(0, 1)
442
+ return x
443
+
444
+
445
+ def round_func_BPDA(input):
446
+ # This is equivalent to replacing round function (non-differentiable) with
447
+ # an identity function (differentiable) only when backward.
448
+ forward_value = torch.round(input)
449
+ out = input.clone()
450
+ out.data = forward_value.data
451
+ return out
452
+
453
+
454
+ def get_min_max_from_bits_signed(bit_width: int) -> Tuple[int, int]:
455
+ return -2 ** (bit_width - 1), 2 ** (bit_width - 1) - 1
456
+
457
+
458
+
459
+ ############## Numpy ###############
460
+
461
+ def np_domain_guard(
462
+ x: np.ndarray,
463
+ min: float = None,
464
+ max: float = None,
465
+ posinf: float = None,
466
+ neginf: float = None,
467
+ nan: float = None
468
+ ) -> np.ndarray:
469
+ """Guard a tensor to a valid domain."""
470
+ x = np.nan_to_num(x, posinf=posinf, neginf=neginf, nan=nan)
471
+ if min is not None or max is not None:
472
+ x = np.clip(x, min, max)
473
+ return x
474
+
475
+
476
+ def np_replace_num(x: np.ndarray, num: float, to: float) -> np.ndarray:
477
+ """Replace a number in a tensor with another number.
478
+
479
+ Args:
480
+ x (np.ndarray): The input tensor.
481
+ num (float): The number to replace.
482
+ to (float): The number to replace with.
483
+
484
+ Returns:
485
+ np.ndarray: The tensor with the number replaced.
486
+ """
487
+ return np.where(x == num, to, x)
488
+
489
+
490
+ def np_guarded_power(x: np.ndarray, exp: float) -> np.ndarray:
491
+ """Guard the power operation to a valid domain."""
492
+ return np.power(x, exp) if exp >= 1 else np.power(np.maximum(x, 0), exp)
493
+
fn_gen/ones_affine_scale/13/loss.png ADDED
fn_gen/ones_affine_scale/13/quantization.png ADDED
fn_gen/ones_affine_scale/14/distortion.png ADDED
fn_gen/ones_affine_scale/14/expressions.txt ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ tan(_0*x)/_s
2
+ atan(_s*x)/_0
fn_gen/ones_affine_scale/14/fn.py ADDED
@@ -0,0 +1,493 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import torch
4
+ from torch import amin # Necessary for arcsin
5
+ import copy
6
+ import torch.nn as nn
7
+ import numpy as np
8
+
9
+ from scipy.optimize import curve_fit
10
+ from typing import Dict, Any, Tuple, List, Callable
11
+
12
+
13
+ def quantization(x, **params):
14
+ return (torch.div(1, replace_num(params['_s'], num=0, to=10000)) * torch.tan(domain_guard((params['_0'] * x), posinf=1, neginf=-1, nan=0)))
15
+
16
+
17
+ def dequantization(x, **params):
18
+ return (torch.div(1, replace_num(params['_0'], num=0, to=10000)) * torch.atan((params['_s'] * x)))
19
+
20
+
21
+ def init_linear_scale( # Symmetric scale. From the study folder
22
+ x: torch.Tensor,
23
+ **kwargs: Dict[str, Any],
24
+ ) -> torch.Tensor:
25
+ assert "bits" in kwargs, "bits must be provided."
26
+ assert "params" in kwargs, "params must be provided."
27
+ assert "qtz_func" in kwargs, "qtz_func must be provided."
28
+
29
+ bits = kwargs.get('bits')
30
+ params = kwargs.get('params')
31
+ qtz_func = kwargs.get('qtz_func')
32
+
33
+ x_ = x.transpose(0, 1)
34
+ x_ = qtz_func(x=x_, **params, _s=init_ones(x, **kwargs))
35
+ x_ = x_.transpose(0, 1)
36
+
37
+ quant_min, quant_max = get_min_max_from_bits_signed(bits)
38
+ min_vals, max_vals = torch.aminmax(x_, dim=1)
39
+ min_vals = torch.min(min_vals, torch.zeros_like(min_vals))
40
+ max_vals = torch.max(max_vals, torch.zeros_like(max_vals))
41
+
42
+ eps = torch.finfo(torch.float32).eps
43
+
44
+ abs_max_val_per_ch = torch.max(-min_vals, max_vals)
45
+ scale = abs_max_val_per_ch / (float(quant_max - quant_min) / 2)
46
+
47
+ scale = torch.clamp(scale, min=eps).to(dtype=torch.float32, device=min_vals.device)
48
+
49
+ return scale
50
+
51
+ # Introduces some noise in scale
52
+ # If I don't introduce noise, the accuracy is going to be 0.0 and not learn anything
53
+ scale = scale + 0.01 * torch.randn_like(scale)
54
+ return scale
55
+
56
+
57
+ def init_params(x: torch.Tensor, **kwargs: Dict[str, Any]) -> Dict[str, nn.Parameter]:
58
+ params = {
59
+ '_0': init_ones(x, qtz_func=quantization, deqtz_func=dequantization, param='_0', params_list=['_0', '_s'], **kwargs),
60
+ }
61
+ params['_s'] = init_linear_scale(x, params=params, qtz_func=quantization, **kwargs)
62
+ params = {k: nn.Parameter(v, requires_grad=False) for k, v in params.items()}
63
+
64
+ if 'post_init_hook' in kwargs:
65
+ kwargs['post_init_hook'](parameters=params)
66
+
67
+
68
+ if 'post_train_hook' in kwargs:
69
+ kwargs['post_train_hook'](parameters=params)
70
+
71
+ return params
72
+
73
+
74
+ ############### Numpy Qtz ###############
75
+
76
+
77
+ def np_quantization(x, _0, _s):
78
+ return (np.divide(1, np_replace_num(_s, num=0, to=10000)) * np.tan(np_domain_guard((_0 * x), posinf=1, neginf=-1, nan=0)))
79
+
80
+
81
+ def np_dequantization(x, _0, _s):
82
+ return (np.divide(1, np_replace_num(_0, num=0, to=10000)) * np.arctan((_s * x)))
83
+
84
+
85
+ def fit_func(x, _0, _s):
86
+ x_ = np_quantization(x, _0, _s)
87
+ x_ = np_dequantization(x_, _0, _s)
88
+ return x_
89
+
90
+
91
+
92
+ ############### HELPERS ###############
93
+
94
+ def domain_guard(
95
+ x: torch.Tensor,
96
+ min: float = None,
97
+ max: float = None,
98
+ posinf: float = None,
99
+ neginf: float = None,
100
+ nan: float = None
101
+ ) -> torch.Tensor:
102
+ """Guard a tensor to a valid domain."""
103
+ x = torch.nan_to_num(x, posinf=posinf, neginf=neginf, nan=nan)
104
+ if min is not None or max is not None:
105
+ x = torch.clamp(x, min=min, max=max)
106
+ return x
107
+
108
+
109
+ def replace_num(x: torch.Tensor, num: float, to: float) -> torch.Tensor:
110
+ """Replace a number in a tensor with another number.
111
+
112
+ Args:
113
+ x (torch.Tensor): The input tensor.
114
+ num (float): The number to replace.
115
+ to (float): The number to replace with.
116
+
117
+ Returns:
118
+ torch.Tensor: The tensor with the number replaced.
119
+ """
120
+ return torch.where(x == num, to, x)
121
+
122
+
123
+ def guarded_torch_power(x: torch.Tensor, exp: float) -> torch.Tensor:
124
+ """Guard the power operation to a valid domain."""
125
+ return torch.pow(x, exp) if exp >= 1 else torch.pow(torch.relu(x), exp)
126
+
127
+
128
+ def init_ones(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor:
129
+ val = torch.amin(x, dim=1)
130
+ return torch.ones_like(val, dtype=torch.float32, device=x.device)
131
+
132
+
133
+ def init_rand(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor:
134
+ val = torch.amin(x, dim=1)
135
+ return torch.randn_like(val, dtype=torch.float32, device=x.device)
136
+
137
+
138
+ def init_space_search(
139
+ x: torch.Tensor,
140
+ **kwargs: Dict[str, Any],
141
+ ) -> torch.Tensor:
142
+
143
+ def _build_initial_param(tensor: torch.Tensor, max_initial: int, n_params: int):
144
+ """Generates the initial set of parameters. The first iteration generates 10 times more parameters."""
145
+ for _ in range(n_params * 10): # The first iteration generates 10 times more parameters
146
+ yield init_rand(tensor) * max_initial # Generates n_params in range [-max_initial, max_initial]
147
+
148
+ def _search_param(tensors: List[torch.tensor], n_params):
149
+ """Takes the best parameters and generates new parameters around the mean of the best parameters."""
150
+ torch_tensors = torch.stack(tensors)
151
+ min_vals, max_vals = torch.aminmax(torch_tensors, dim=0)
152
+ abs_max_val_per_ch = torch.max(-min_vals, max_vals)
153
+ mean = torch.mean(torch_tensors, dim=0)
154
+ for _ in range(n_params): # Generates n_params around the mean of the tensors
155
+ yield torch.randn_like(min_vals) * abs_max_val_per_ch + mean
156
+
157
+ def _calc(x, qtz_func, deqtz_func, **params):
158
+ x_ = x.transpose(0, 1)
159
+ x_ = qtz_func(x=x_, **params)
160
+ x_ = deqtz_func(x=x_, **params)
161
+ x_ = x_.transpose(0, 1)
162
+ return x_
163
+
164
+ assert "qtz_func" in kwargs, "qtz_func must be provided."
165
+ assert "deqtz_func" in kwargs, "deqtz_func must be provided."
166
+ assert "params_list" in kwargs, "params list must be provided."
167
+ assert "param" in kwargs, "param must be provided."
168
+
169
+ qtz_func = kwargs.get('qtz_func')
170
+ deqtz_func = kwargs.get('deqtz_func')
171
+ params_list = kwargs.get('params_list')
172
+ param = kwargs.get('param')
173
+
174
+ n_runs = 50 # Number of runs to try to find the best parameters
175
+ n_random_params = 50 # Number of random parameters to generate
176
+ n_best_to_pick = 5 # Number of best parameters to pick after each run
177
+ max_initial = 10000 # Maximum value to initialize the parameters
178
+
179
+ # Initializes the parameters
180
+ base_params = { p: init_ones(x, **kwargs) for p in params_list if p != param }
181
+ params = _build_initial_param(x, max_initial, n_random_params)
182
+
183
+ # Performs the search
184
+ for _ in range(n_runs):
185
+
186
+ best_params = []
187
+ for param_ in params:
188
+ try:
189
+ x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: param_})
190
+ loss_ones = nn.MSELoss()(x, x_)
191
+
192
+ if len(best_params) < n_best_to_pick:
193
+ best_params.append((param_, loss_ones.item()))
194
+ best_params = sorted(best_params, key=lambda x: x[1])
195
+ elif loss_ones < best_params[-1][1]:
196
+ best_params[-1] = (param_, loss_ones.item())
197
+ best_params = sorted(best_params, key=lambda x: x[1])
198
+
199
+ except Exception: # The parameters might not be valid for the function's domain
200
+ continue
201
+
202
+ # Generates new parameters around the mean
203
+ params = _search_param([p for p, _ in best_params], n_random_params)
204
+
205
+ # Checks if the best parameter is better than the init_ones
206
+ p_ones = init_ones(x, **kwargs)
207
+ x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: p_ones})
208
+ loss_ones = nn.MSELoss()(x, x_)
209
+
210
+ # Checks if the best parameter is better than the init_rand
211
+ p_rand = init_rand(x, **kwargs)
212
+ x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: p_rand})
213
+ loss_rand = nn.MSELoss()(x, x_)
214
+
215
+ if loss_rand < best_params[0][1] and loss_rand < loss_ones:
216
+ return p_rand
217
+ elif loss_ones < best_params[0][1] and loss_ones < loss_rand:
218
+ return p_ones
219
+ else:
220
+ return best_params[0][0]
221
+
222
+
223
+ def init_linear_scale( # Symmetric scale. From the study folder
224
+ x: torch.Tensor,
225
+ **kwargs: Dict[str, Any],
226
+ ) -> torch.Tensor:
227
+ assert "bits" in kwargs, "bits must be provided."
228
+ assert "params" in kwargs, "params must be provided."
229
+ assert "qtz_func" in kwargs, "qtz_func must be provided."
230
+
231
+ bits = kwargs.get('bits')
232
+ params = kwargs.get('params')
233
+ qtz_func = kwargs.get('qtz_func')
234
+
235
+ x_ = x.transpose(0, 1)
236
+ x_ = qtz_func(x=x_, **params, _s=init_ones(x, **kwargs))
237
+ x_ = x_.transpose(0, 1)
238
+
239
+ quant_min, quant_max = get_min_max_from_bits_signed(bits)
240
+ min_vals, max_vals = torch.aminmax(x_, dim=1)
241
+ min_vals = torch.min(min_vals, torch.zeros_like(min_vals))
242
+ max_vals = torch.max(max_vals, torch.zeros_like(max_vals))
243
+
244
+ eps = torch.finfo(torch.float32).eps
245
+
246
+ abs_max_val_per_ch = torch.max(-min_vals, max_vals)
247
+ scale = abs_max_val_per_ch / (float(quant_max - quant_min) / 2)
248
+
249
+ scale = torch.clamp(scale, min=eps).to(dtype=torch.float32, device=min_vals.device)
250
+
251
+ return scale
252
+
253
+ # Introduces some noise in scale
254
+ # If I don't introduce noise, the accuracy is going to be 0.0 and not learn anything
255
+ scale = scale + 0.01 * torch.randn_like(scale)
256
+ return scale
257
+
258
+
259
+ def init_non_linear_regression_fit(
260
+ x: torch.Tensor,
261
+ **kwargs: Dict[str, Any],
262
+ ) -> torch.Tensor:
263
+
264
+ assert "params_list" in kwargs, "params list must be provided."
265
+ assert "np_fit_func" in kwargs, "np_fit_func must be provided."
266
+ assert "p0" in kwargs, "p0 must be provided."
267
+ np_fit_func = kwargs.get('np_fit_func')
268
+ params_list = kwargs.get('params_list')
269
+ p0 = kwargs.get('p0')
270
+
271
+ def _fit(xdata: np.ndarray, ydata: np.ndarray, func: Callable, p0: List[float]):
272
+ popt, _ = curve_fit(
273
+ func,
274
+ xdata,
275
+ ydata,
276
+ maxfev=1000,
277
+ p0=p0,
278
+ method='lm'
279
+ )
280
+ return popt
281
+
282
+ # 1. Needs to convert the torch tensor to numpy tensor
283
+ xdata = x.cpu().numpy()
284
+
285
+ # 2. Sorts the data so that it makes it easier to fit to it
286
+ sorted_xdata = np.sort(xdata, axis=-1)
287
+
288
+ p0 = {k: v.cpu().numpy() for k, v in p0.items()}
289
+ params_list = sorted(params_list) # We need to make sure that it matches the numpy fit func arg order
290
+
291
+ # 3. Finds the best parameters for each channel
292
+ try:
293
+ params = []
294
+ for i in range(sorted_xdata.shape[0]):
295
+ xdata_ = sorted_xdata[i]
296
+ p0_ = [p0[p][i] for p in params_list]
297
+ ch_params = _fit(xdata_, xdata_, np_fit_func, p0_)
298
+ params.append(ch_params)
299
+
300
+ # 4. Builds the parameters
301
+ result = {}
302
+ for i, p in enumerate(params_list):
303
+ result[p] = torch.tensor([p_[i] for p_ in params], dtype=torch.float32).to(x.device)
304
+
305
+ return result
306
+
307
+ except ValueError as e:
308
+ print(f"Could not fit the function with error: {e}")
309
+ print(f"Using fallback result...")
310
+ return {
311
+ k: torch.tensor(v, dtype=torch.float32).to(x.device) for k, v in p0.items()
312
+ }
313
+
314
+
315
+ def init_zeros(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor:
316
+ val = torch.amin(x, dim=1)
317
+ return torch.zeros_like(val, dtype=torch.float32, device=x.device)
318
+
319
+
320
+ def init_inner_scale(tensor: torch.Tensor, _min: float = torch.inf, _max: float = torch.inf) -> torch.Tensor:
321
+ # Calculate the original minimum and maximum values
322
+ min_vals, max_vals = torch.aminmax(tensor, dim=-1)
323
+ x_min = torch.min(min_vals, torch.zeros_like(min_vals))
324
+ x_max = torch.max(max_vals, torch.zeros_like(max_vals))
325
+
326
+ if _max is torch.inf: # We do not need to scale the tensor. Just need to move it
327
+ return torch.ones_like(x_min)
328
+
329
+ # Calculate the scale factor
330
+ scale = (_max - _min) / (x_max - x_min)
331
+ return scale
332
+
333
+
334
+
335
+ ############## Quant ###############
336
+
337
+ @torch.enable_grad()
338
+ def learn_parameters(
339
+ x: torch.Tensor,
340
+ params: Dict[str, nn.Parameter],
341
+ qtz_func: nn.Module,
342
+ deqtz_func: nn.Module,
343
+ bits: int,
344
+ target_dtype: torch.dtype,
345
+ epochs: int = 1000,
346
+ early_stop: bool = True,
347
+ do_report: bool = False
348
+ ) -> Tuple[Dict[str, nn.Parameter], torch.Tensor]:
349
+
350
+ # Requires gradients in the parameters
351
+ for p in params.values():
352
+ p.requires_grad = True
353
+ p.grad = None
354
+
355
+ param_keys = list(params.keys())
356
+ param_values = list(params.values())
357
+
358
+ # Defines optimizer and loss function
359
+ optimizer = torch.optim.Adam(param_values, lr=0.001)
360
+ scheduler = torch.optim.lr_scheduler.LinearLR(optimizer, start_factor=1.0, end_factor=0.001, total_iters=epochs // 10)
361
+ loss_fn = nn.MSELoss()
362
+
363
+ # Contains the best loss and the best parameters
364
+ best_loss = float("inf")
365
+ best_params = None
366
+
367
+ # Used to stop the search early
368
+ min_delta = 1e-7
369
+ acc_loss = []
370
+ percent_epochs_before_stop = 0.1
371
+
372
+ for i in range(epochs):
373
+ optimizer.zero_grad()
374
+
375
+ quant = quantize(x, params, qtz_func, bits, target_dtype)
376
+ dequant = dequantize(quant, params, deqtz_func, bits, x.dtype)
377
+ loss = loss_fn(x, dequant)
378
+
379
+ if loss.isnan() or loss.isinf():
380
+ raise Exception("Loss is NaN or Inf. Stopping the search.")
381
+
382
+ loss.backward()
383
+ optimizer.step()
384
+ scheduler.step()
385
+
386
+ acc_loss.append(loss.item())
387
+
388
+ # Reports loss every 10 steps
389
+ if i % 10 == 0 and do_report:
390
+ print(f"Epoch {i}: Loss {loss.item()}")
391
+
392
+ # Optimizes the parameter search by storing the best loss and the parameters
393
+ if loss.item() < best_loss:
394
+ best_loss = loss.item()
395
+ best_params = copy.deepcopy({
396
+ k: v for k, v in params.items() if k in param_keys
397
+ })
398
+
399
+ # We also stop the search if the loss has not considerably during the last 10% epochs
400
+ if early_stop:
401
+ epochs_before_stop = int(epochs * percent_epochs_before_stop)
402
+ if i > epochs_before_stop and abs(acc_loss[i - epochs_before_stop] - acc_loss[i]) < min_delta:
403
+ break
404
+
405
+ # No longer requires gradients in the parameters
406
+ for p in best_params.values():
407
+ p.requires_grad = False
408
+ p.grad = None
409
+
410
+ if do_report:
411
+ return best_params, acc_loss
412
+ else:
413
+ return best_params
414
+
415
+
416
+ def quantize(
417
+ x: torch.Tensor,
418
+ params: Dict[str, nn.Parameter],
419
+ func: nn.Module,
420
+ bits: int,
421
+ target_dtype: torch.dtype = torch.int8
422
+ ) -> torch.Tensor:
423
+ quant_min, quant_max = get_min_max_from_bits_signed(bits)
424
+ x = x.transpose(0, 1) # Aligns shapes
425
+ x = func(x=x, **params)
426
+ x = x.transpose(0, 1)
427
+ x = torch.clamp(round_func_BPDA(x), quant_min, quant_max).to(target_dtype)
428
+ return x
429
+
430
+
431
+ def dequantize(
432
+ x: torch.Tensor,
433
+ params: Dict[str, nn.Parameter],
434
+ func: nn.Module,
435
+ bits: int,
436
+ out_dtype: torch.dtype
437
+ ) -> torch.Tensor:
438
+ x = x.to(dtype=out_dtype)
439
+ x = x.transpose(0, 1)
440
+ x = func(x=x, **params)
441
+ x = x.transpose(0, 1)
442
+ return x
443
+
444
+
445
+ def round_func_BPDA(input):
446
+ # This is equivalent to replacing round function (non-differentiable) with
447
+ # an identity function (differentiable) only when backward.
448
+ forward_value = torch.round(input)
449
+ out = input.clone()
450
+ out.data = forward_value.data
451
+ return out
452
+
453
+
454
+ def get_min_max_from_bits_signed(bit_width: int) -> Tuple[int, int]:
455
+ return -2 ** (bit_width - 1), 2 ** (bit_width - 1) - 1
456
+
457
+
458
+
459
+ ############## Numpy ###############
460
+
461
+ def np_domain_guard(
462
+ x: np.ndarray,
463
+ min: float = None,
464
+ max: float = None,
465
+ posinf: float = None,
466
+ neginf: float = None,
467
+ nan: float = None
468
+ ) -> np.ndarray:
469
+ """Guard a tensor to a valid domain."""
470
+ x = np.nan_to_num(x, posinf=posinf, neginf=neginf, nan=nan)
471
+ if min is not None or max is not None:
472
+ x = np.clip(x, min, max)
473
+ return x
474
+
475
+
476
+ def np_replace_num(x: np.ndarray, num: float, to: float) -> np.ndarray:
477
+ """Replace a number in a tensor with another number.
478
+
479
+ Args:
480
+ x (np.ndarray): The input tensor.
481
+ num (float): The number to replace.
482
+ to (float): The number to replace with.
483
+
484
+ Returns:
485
+ np.ndarray: The tensor with the number replaced.
486
+ """
487
+ return np.where(x == num, to, x)
488
+
489
+
490
+ def np_guarded_power(x: np.ndarray, exp: float) -> np.ndarray:
491
+ """Guard the power operation to a valid domain."""
492
+ return np.power(x, exp) if exp >= 1 else np.power(np.maximum(x, 0), exp)
493
+
fn_gen/ones_affine_scale/14/loss.png ADDED
fn_gen/ones_affine_scale/14/quantization.png ADDED
fn_gen/ones_affine_scale/15/distortion.png ADDED
fn_gen/ones_affine_scale/15/expressions.txt ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ sinh(_0*x)/_s
2
+ log(_s*x - sqrt(_s**2*x**2 + 1))/_0
fn_gen/ones_affine_scale/15/fn.py ADDED
@@ -0,0 +1,493 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import torch
4
+ from torch import amin # Necessary for arcsin
5
+ import copy
6
+ import torch.nn as nn
7
+ import numpy as np
8
+
9
+ from scipy.optimize import curve_fit
10
+ from typing import Dict, Any, Tuple, List, Callable
11
+
12
+
13
+ def quantization(x, **params):
14
+ return (torch.div(1, replace_num(params['_s'], num=0, to=10000)) * torch.sinh((params['_0'] * x)))
15
+
16
+
17
+ def dequantization(x, **params):
18
+ return (torch.div(1, replace_num(params['_0'], num=0, to=10000)) * torch.log(domain_guard(((torch.tensor(-1) * torch.sqrt(domain_guard((torch.tensor(1) + (guarded_torch_power(params['_s'], torch.tensor(2)) * guarded_torch_power(x, torch.tensor(2)))), min=0.1, nan=0.1))) + (params['_s'] * x)), min=1e-5, nan=1e-5)))
19
+
20
+
21
+ def init_linear_scale( # Symmetric scale. From the study folder
22
+ x: torch.Tensor,
23
+ **kwargs: Dict[str, Any],
24
+ ) -> torch.Tensor:
25
+ assert "bits" in kwargs, "bits must be provided."
26
+ assert "params" in kwargs, "params must be provided."
27
+ assert "qtz_func" in kwargs, "qtz_func must be provided."
28
+
29
+ bits = kwargs.get('bits')
30
+ params = kwargs.get('params')
31
+ qtz_func = kwargs.get('qtz_func')
32
+
33
+ x_ = x.transpose(0, 1)
34
+ x_ = qtz_func(x=x_, **params, _s=init_ones(x, **kwargs))
35
+ x_ = x_.transpose(0, 1)
36
+
37
+ quant_min, quant_max = get_min_max_from_bits_signed(bits)
38
+ min_vals, max_vals = torch.aminmax(x_, dim=1)
39
+ min_vals = torch.min(min_vals, torch.zeros_like(min_vals))
40
+ max_vals = torch.max(max_vals, torch.zeros_like(max_vals))
41
+
42
+ eps = torch.finfo(torch.float32).eps
43
+
44
+ abs_max_val_per_ch = torch.max(-min_vals, max_vals)
45
+ scale = abs_max_val_per_ch / (float(quant_max - quant_min) / 2)
46
+
47
+ scale = torch.clamp(scale, min=eps).to(dtype=torch.float32, device=min_vals.device)
48
+
49
+ return scale
50
+
51
+ # Introduces some noise in scale
52
+ # If I don't introduce noise, the accuracy is going to be 0.0 and not learn anything
53
+ scale = scale + 0.01 * torch.randn_like(scale)
54
+ return scale
55
+
56
+
57
+ def init_params(x: torch.Tensor, **kwargs: Dict[str, Any]) -> Dict[str, nn.Parameter]:
58
+ params = {
59
+ '_0': init_ones(x, qtz_func=quantization, deqtz_func=dequantization, param='_0', params_list=['_0', '_s'], **kwargs),
60
+ }
61
+ params['_s'] = init_linear_scale(x, params=params, qtz_func=quantization, **kwargs)
62
+ params = {k: nn.Parameter(v, requires_grad=False) for k, v in params.items()}
63
+
64
+ if 'post_init_hook' in kwargs:
65
+ kwargs['post_init_hook'](parameters=params)
66
+
67
+
68
+ if 'post_train_hook' in kwargs:
69
+ kwargs['post_train_hook'](parameters=params)
70
+
71
+ return params
72
+
73
+
74
+ ############### Numpy Qtz ###############
75
+
76
+
77
+ def np_quantization(x, _0, _s):
78
+ return (np.divide(1, np_replace_num(_s, num=0, to=10000)) * np.sinh((_0 * x)))
79
+
80
+
81
+ def np_dequantization(x, _0, _s):
82
+ return (np.divide(1, np_replace_num(_0, num=0, to=10000)) * np.log(np_domain_guard(((np.array(-1) * np.sqrt(np_domain_guard((np.array(1) + (np_guarded_power(_s, np.array(2)) * np_guarded_power(x, np.array(2)))), min=0.1, nan=0.1))) + (_s * x)), min=1e-5, nan=1e-5)))
83
+
84
+
85
+ def fit_func(x, _0, _s):
86
+ x_ = np_quantization(x, _0, _s)
87
+ x_ = np_dequantization(x_, _0, _s)
88
+ return x_
89
+
90
+
91
+
92
+ ############### HELPERS ###############
93
+
94
+ def domain_guard(
95
+ x: torch.Tensor,
96
+ min: float = None,
97
+ max: float = None,
98
+ posinf: float = None,
99
+ neginf: float = None,
100
+ nan: float = None
101
+ ) -> torch.Tensor:
102
+ """Guard a tensor to a valid domain."""
103
+ x = torch.nan_to_num(x, posinf=posinf, neginf=neginf, nan=nan)
104
+ if min is not None or max is not None:
105
+ x = torch.clamp(x, min=min, max=max)
106
+ return x
107
+
108
+
109
+ def replace_num(x: torch.Tensor, num: float, to: float) -> torch.Tensor:
110
+ """Replace a number in a tensor with another number.
111
+
112
+ Args:
113
+ x (torch.Tensor): The input tensor.
114
+ num (float): The number to replace.
115
+ to (float): The number to replace with.
116
+
117
+ Returns:
118
+ torch.Tensor: The tensor with the number replaced.
119
+ """
120
+ return torch.where(x == num, to, x)
121
+
122
+
123
+ def guarded_torch_power(x: torch.Tensor, exp: float) -> torch.Tensor:
124
+ """Guard the power operation to a valid domain."""
125
+ return torch.pow(x, exp) if exp >= 1 else torch.pow(torch.relu(x), exp)
126
+
127
+
128
+ def init_ones(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor:
129
+ val = torch.amin(x, dim=1)
130
+ return torch.ones_like(val, dtype=torch.float32, device=x.device)
131
+
132
+
133
+ def init_rand(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor:
134
+ val = torch.amin(x, dim=1)
135
+ return torch.randn_like(val, dtype=torch.float32, device=x.device)
136
+
137
+
138
+ def init_space_search(
139
+ x: torch.Tensor,
140
+ **kwargs: Dict[str, Any],
141
+ ) -> torch.Tensor:
142
+
143
+ def _build_initial_param(tensor: torch.Tensor, max_initial: int, n_params: int):
144
+ """Generates the initial set of parameters. The first iteration generates 10 times more parameters."""
145
+ for _ in range(n_params * 10): # The first iteration generates 10 times more parameters
146
+ yield init_rand(tensor) * max_initial # Generates n_params in range [-max_initial, max_initial]
147
+
148
+ def _search_param(tensors: List[torch.tensor], n_params):
149
+ """Takes the best parameters and generates new parameters around the mean of the best parameters."""
150
+ torch_tensors = torch.stack(tensors)
151
+ min_vals, max_vals = torch.aminmax(torch_tensors, dim=0)
152
+ abs_max_val_per_ch = torch.max(-min_vals, max_vals)
153
+ mean = torch.mean(torch_tensors, dim=0)
154
+ for _ in range(n_params): # Generates n_params around the mean of the tensors
155
+ yield torch.randn_like(min_vals) * abs_max_val_per_ch + mean
156
+
157
+ def _calc(x, qtz_func, deqtz_func, **params):
158
+ x_ = x.transpose(0, 1)
159
+ x_ = qtz_func(x=x_, **params)
160
+ x_ = deqtz_func(x=x_, **params)
161
+ x_ = x_.transpose(0, 1)
162
+ return x_
163
+
164
+ assert "qtz_func" in kwargs, "qtz_func must be provided."
165
+ assert "deqtz_func" in kwargs, "deqtz_func must be provided."
166
+ assert "params_list" in kwargs, "params list must be provided."
167
+ assert "param" in kwargs, "param must be provided."
168
+
169
+ qtz_func = kwargs.get('qtz_func')
170
+ deqtz_func = kwargs.get('deqtz_func')
171
+ params_list = kwargs.get('params_list')
172
+ param = kwargs.get('param')
173
+
174
+ n_runs = 50 # Number of runs to try to find the best parameters
175
+ n_random_params = 50 # Number of random parameters to generate
176
+ n_best_to_pick = 5 # Number of best parameters to pick after each run
177
+ max_initial = 10000 # Maximum value to initialize the parameters
178
+
179
+ # Initializes the parameters
180
+ base_params = { p: init_ones(x, **kwargs) for p in params_list if p != param }
181
+ params = _build_initial_param(x, max_initial, n_random_params)
182
+
183
+ # Performs the search
184
+ for _ in range(n_runs):
185
+
186
+ best_params = []
187
+ for param_ in params:
188
+ try:
189
+ x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: param_})
190
+ loss_ones = nn.MSELoss()(x, x_)
191
+
192
+ if len(best_params) < n_best_to_pick:
193
+ best_params.append((param_, loss_ones.item()))
194
+ best_params = sorted(best_params, key=lambda x: x[1])
195
+ elif loss_ones < best_params[-1][1]:
196
+ best_params[-1] = (param_, loss_ones.item())
197
+ best_params = sorted(best_params, key=lambda x: x[1])
198
+
199
+ except Exception: # The parameters might not be valid for the function's domain
200
+ continue
201
+
202
+ # Generates new parameters around the mean
203
+ params = _search_param([p for p, _ in best_params], n_random_params)
204
+
205
+ # Checks if the best parameter is better than the init_ones
206
+ p_ones = init_ones(x, **kwargs)
207
+ x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: p_ones})
208
+ loss_ones = nn.MSELoss()(x, x_)
209
+
210
+ # Checks if the best parameter is better than the init_rand
211
+ p_rand = init_rand(x, **kwargs)
212
+ x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: p_rand})
213
+ loss_rand = nn.MSELoss()(x, x_)
214
+
215
+ if loss_rand < best_params[0][1] and loss_rand < loss_ones:
216
+ return p_rand
217
+ elif loss_ones < best_params[0][1] and loss_ones < loss_rand:
218
+ return p_ones
219
+ else:
220
+ return best_params[0][0]
221
+
222
+
223
+ def init_linear_scale( # Symmetric scale. From the study folder
224
+ x: torch.Tensor,
225
+ **kwargs: Dict[str, Any],
226
+ ) -> torch.Tensor:
227
+ assert "bits" in kwargs, "bits must be provided."
228
+ assert "params" in kwargs, "params must be provided."
229
+ assert "qtz_func" in kwargs, "qtz_func must be provided."
230
+
231
+ bits = kwargs.get('bits')
232
+ params = kwargs.get('params')
233
+ qtz_func = kwargs.get('qtz_func')
234
+
235
+ x_ = x.transpose(0, 1)
236
+ x_ = qtz_func(x=x_, **params, _s=init_ones(x, **kwargs))
237
+ x_ = x_.transpose(0, 1)
238
+
239
+ quant_min, quant_max = get_min_max_from_bits_signed(bits)
240
+ min_vals, max_vals = torch.aminmax(x_, dim=1)
241
+ min_vals = torch.min(min_vals, torch.zeros_like(min_vals))
242
+ max_vals = torch.max(max_vals, torch.zeros_like(max_vals))
243
+
244
+ eps = torch.finfo(torch.float32).eps
245
+
246
+ abs_max_val_per_ch = torch.max(-min_vals, max_vals)
247
+ scale = abs_max_val_per_ch / (float(quant_max - quant_min) / 2)
248
+
249
+ scale = torch.clamp(scale, min=eps).to(dtype=torch.float32, device=min_vals.device)
250
+
251
+ return scale
252
+
253
+ # Introduces some noise in scale
254
+ # If I don't introduce noise, the accuracy is going to be 0.0 and not learn anything
255
+ scale = scale + 0.01 * torch.randn_like(scale)
256
+ return scale
257
+
258
+
259
+ def init_non_linear_regression_fit(
260
+ x: torch.Tensor,
261
+ **kwargs: Dict[str, Any],
262
+ ) -> torch.Tensor:
263
+
264
+ assert "params_list" in kwargs, "params list must be provided."
265
+ assert "np_fit_func" in kwargs, "np_fit_func must be provided."
266
+ assert "p0" in kwargs, "p0 must be provided."
267
+ np_fit_func = kwargs.get('np_fit_func')
268
+ params_list = kwargs.get('params_list')
269
+ p0 = kwargs.get('p0')
270
+
271
+ def _fit(xdata: np.ndarray, ydata: np.ndarray, func: Callable, p0: List[float]):
272
+ popt, _ = curve_fit(
273
+ func,
274
+ xdata,
275
+ ydata,
276
+ maxfev=1000,
277
+ p0=p0,
278
+ method='lm'
279
+ )
280
+ return popt
281
+
282
+ # 1. Needs to convert the torch tensor to numpy tensor
283
+ xdata = x.cpu().numpy()
284
+
285
+ # 2. Sorts the data so that it makes it easier to fit to it
286
+ sorted_xdata = np.sort(xdata, axis=-1)
287
+
288
+ p0 = {k: v.cpu().numpy() for k, v in p0.items()}
289
+ params_list = sorted(params_list) # We need to make sure that it matches the numpy fit func arg order
290
+
291
+ # 3. Finds the best parameters for each channel
292
+ try:
293
+ params = []
294
+ for i in range(sorted_xdata.shape[0]):
295
+ xdata_ = sorted_xdata[i]
296
+ p0_ = [p0[p][i] for p in params_list]
297
+ ch_params = _fit(xdata_, xdata_, np_fit_func, p0_)
298
+ params.append(ch_params)
299
+
300
+ # 4. Builds the parameters
301
+ result = {}
302
+ for i, p in enumerate(params_list):
303
+ result[p] = torch.tensor([p_[i] for p_ in params], dtype=torch.float32).to(x.device)
304
+
305
+ return result
306
+
307
+ except ValueError as e:
308
+ print(f"Could not fit the function with error: {e}")
309
+ print(f"Using fallback result...")
310
+ return {
311
+ k: torch.tensor(v, dtype=torch.float32).to(x.device) for k, v in p0.items()
312
+ }
313
+
314
+
315
+ def init_zeros(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor:
316
+ val = torch.amin(x, dim=1)
317
+ return torch.zeros_like(val, dtype=torch.float32, device=x.device)
318
+
319
+
320
+ def init_inner_scale(tensor: torch.Tensor, _min: float = torch.inf, _max: float = torch.inf) -> torch.Tensor:
321
+ # Calculate the original minimum and maximum values
322
+ min_vals, max_vals = torch.aminmax(tensor, dim=-1)
323
+ x_min = torch.min(min_vals, torch.zeros_like(min_vals))
324
+ x_max = torch.max(max_vals, torch.zeros_like(max_vals))
325
+
326
+ if _max is torch.inf: # We do not need to scale the tensor. Just need to move it
327
+ return torch.ones_like(x_min)
328
+
329
+ # Calculate the scale factor
330
+ scale = (_max - _min) / (x_max - x_min)
331
+ return scale
332
+
333
+
334
+
335
+ ############## Quant ###############
336
+
337
+ @torch.enable_grad()
338
+ def learn_parameters(
339
+ x: torch.Tensor,
340
+ params: Dict[str, nn.Parameter],
341
+ qtz_func: nn.Module,
342
+ deqtz_func: nn.Module,
343
+ bits: int,
344
+ target_dtype: torch.dtype,
345
+ epochs: int = 1000,
346
+ early_stop: bool = True,
347
+ do_report: bool = False
348
+ ) -> Tuple[Dict[str, nn.Parameter], torch.Tensor]:
349
+
350
+ # Requires gradients in the parameters
351
+ for p in params.values():
352
+ p.requires_grad = True
353
+ p.grad = None
354
+
355
+ param_keys = list(params.keys())
356
+ param_values = list(params.values())
357
+
358
+ # Defines optimizer and loss function
359
+ optimizer = torch.optim.Adam(param_values, lr=0.001)
360
+ scheduler = torch.optim.lr_scheduler.LinearLR(optimizer, start_factor=1.0, end_factor=0.001, total_iters=epochs // 10)
361
+ loss_fn = nn.MSELoss()
362
+
363
+ # Contains the best loss and the best parameters
364
+ best_loss = float("inf")
365
+ best_params = None
366
+
367
+ # Used to stop the search early
368
+ min_delta = 1e-7
369
+ acc_loss = []
370
+ percent_epochs_before_stop = 0.1
371
+
372
+ for i in range(epochs):
373
+ optimizer.zero_grad()
374
+
375
+ quant = quantize(x, params, qtz_func, bits, target_dtype)
376
+ dequant = dequantize(quant, params, deqtz_func, bits, x.dtype)
377
+ loss = loss_fn(x, dequant)
378
+
379
+ if loss.isnan() or loss.isinf():
380
+ raise Exception("Loss is NaN or Inf. Stopping the search.")
381
+
382
+ loss.backward()
383
+ optimizer.step()
384
+ scheduler.step()
385
+
386
+ acc_loss.append(loss.item())
387
+
388
+ # Reports loss every 10 steps
389
+ if i % 10 == 0 and do_report:
390
+ print(f"Epoch {i}: Loss {loss.item()}")
391
+
392
+ # Optimizes the parameter search by storing the best loss and the parameters
393
+ if loss.item() < best_loss:
394
+ best_loss = loss.item()
395
+ best_params = copy.deepcopy({
396
+ k: v for k, v in params.items() if k in param_keys
397
+ })
398
+
399
+ # We also stop the search if the loss has not considerably during the last 10% epochs
400
+ if early_stop:
401
+ epochs_before_stop = int(epochs * percent_epochs_before_stop)
402
+ if i > epochs_before_stop and abs(acc_loss[i - epochs_before_stop] - acc_loss[i]) < min_delta:
403
+ break
404
+
405
+ # No longer requires gradients in the parameters
406
+ for p in best_params.values():
407
+ p.requires_grad = False
408
+ p.grad = None
409
+
410
+ if do_report:
411
+ return best_params, acc_loss
412
+ else:
413
+ return best_params
414
+
415
+
416
+ def quantize(
417
+ x: torch.Tensor,
418
+ params: Dict[str, nn.Parameter],
419
+ func: nn.Module,
420
+ bits: int,
421
+ target_dtype: torch.dtype = torch.int8
422
+ ) -> torch.Tensor:
423
+ quant_min, quant_max = get_min_max_from_bits_signed(bits)
424
+ x = x.transpose(0, 1) # Aligns shapes
425
+ x = func(x=x, **params)
426
+ x = x.transpose(0, 1)
427
+ x = torch.clamp(round_func_BPDA(x), quant_min, quant_max).to(target_dtype)
428
+ return x
429
+
430
+
431
+ def dequantize(
432
+ x: torch.Tensor,
433
+ params: Dict[str, nn.Parameter],
434
+ func: nn.Module,
435
+ bits: int,
436
+ out_dtype: torch.dtype
437
+ ) -> torch.Tensor:
438
+ x = x.to(dtype=out_dtype)
439
+ x = x.transpose(0, 1)
440
+ x = func(x=x, **params)
441
+ x = x.transpose(0, 1)
442
+ return x
443
+
444
+
445
+ def round_func_BPDA(input):
446
+ # This is equivalent to replacing round function (non-differentiable) with
447
+ # an identity function (differentiable) only when backward.
448
+ forward_value = torch.round(input)
449
+ out = input.clone()
450
+ out.data = forward_value.data
451
+ return out
452
+
453
+
454
+ def get_min_max_from_bits_signed(bit_width: int) -> Tuple[int, int]:
455
+ return -2 ** (bit_width - 1), 2 ** (bit_width - 1) - 1
456
+
457
+
458
+
459
+ ############## Numpy ###############
460
+
461
+ def np_domain_guard(
462
+ x: np.ndarray,
463
+ min: float = None,
464
+ max: float = None,
465
+ posinf: float = None,
466
+ neginf: float = None,
467
+ nan: float = None
468
+ ) -> np.ndarray:
469
+ """Guard a tensor to a valid domain."""
470
+ x = np.nan_to_num(x, posinf=posinf, neginf=neginf, nan=nan)
471
+ if min is not None or max is not None:
472
+ x = np.clip(x, min, max)
473
+ return x
474
+
475
+
476
+ def np_replace_num(x: np.ndarray, num: float, to: float) -> np.ndarray:
477
+ """Replace a number in a tensor with another number.
478
+
479
+ Args:
480
+ x (np.ndarray): The input tensor.
481
+ num (float): The number to replace.
482
+ to (float): The number to replace with.
483
+
484
+ Returns:
485
+ np.ndarray: The tensor with the number replaced.
486
+ """
487
+ return np.where(x == num, to, x)
488
+
489
+
490
+ def np_guarded_power(x: np.ndarray, exp: float) -> np.ndarray:
491
+ """Guard the power operation to a valid domain."""
492
+ return np.power(x, exp) if exp >= 1 else np.power(np.maximum(x, 0), exp)
493
+
fn_gen/ones_affine_scale/15/loss.png ADDED
fn_gen/ones_affine_scale/15/quantization.png ADDED
fn_gen/ones_affine_scale/16/distortion.png ADDED
fn_gen/ones_affine_scale/16/expressions.txt ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ atan(_0*x)/_s
2
+ tan(_s*x)/_0
fn_gen/ones_affine_scale/16/fn.py ADDED
@@ -0,0 +1,493 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import torch
4
+ from torch import amin # Necessary for arcsin
5
+ import copy
6
+ import torch.nn as nn
7
+ import numpy as np
8
+
9
+ from scipy.optimize import curve_fit
10
+ from typing import Dict, Any, Tuple, List, Callable
11
+
12
+
13
+ def quantization(x, **params):
14
+ return (torch.div(1, replace_num(params['_s'], num=0, to=10000)) * torch.atan((params['_0'] * x)))
15
+
16
+
17
+ def dequantization(x, **params):
18
+ return (torch.div(1, replace_num(params['_0'], num=0, to=10000)) * torch.tan(domain_guard((params['_s'] * x), posinf=1, neginf=-1, nan=0)))
19
+
20
+
21
+ def init_linear_scale( # Symmetric scale. From the study folder
22
+ x: torch.Tensor,
23
+ **kwargs: Dict[str, Any],
24
+ ) -> torch.Tensor:
25
+ assert "bits" in kwargs, "bits must be provided."
26
+ assert "params" in kwargs, "params must be provided."
27
+ assert "qtz_func" in kwargs, "qtz_func must be provided."
28
+
29
+ bits = kwargs.get('bits')
30
+ params = kwargs.get('params')
31
+ qtz_func = kwargs.get('qtz_func')
32
+
33
+ x_ = x.transpose(0, 1)
34
+ x_ = qtz_func(x=x_, **params, _s=init_ones(x, **kwargs))
35
+ x_ = x_.transpose(0, 1)
36
+
37
+ quant_min, quant_max = get_min_max_from_bits_signed(bits)
38
+ min_vals, max_vals = torch.aminmax(x_, dim=1)
39
+ min_vals = torch.min(min_vals, torch.zeros_like(min_vals))
40
+ max_vals = torch.max(max_vals, torch.zeros_like(max_vals))
41
+
42
+ eps = torch.finfo(torch.float32).eps
43
+
44
+ abs_max_val_per_ch = torch.max(-min_vals, max_vals)
45
+ scale = abs_max_val_per_ch / (float(quant_max - quant_min) / 2)
46
+
47
+ scale = torch.clamp(scale, min=eps).to(dtype=torch.float32, device=min_vals.device)
48
+
49
+ return scale
50
+
51
+ # Introduces some noise in scale
52
+ # If I don't introduce noise, the accuracy is going to be 0.0 and not learn anything
53
+ scale = scale + 0.01 * torch.randn_like(scale)
54
+ return scale
55
+
56
+
57
+ def init_params(x: torch.Tensor, **kwargs: Dict[str, Any]) -> Dict[str, nn.Parameter]:
58
+ params = {
59
+ '_0': init_ones(x, qtz_func=quantization, deqtz_func=dequantization, param='_0', params_list=['_0', '_s'], **kwargs),
60
+ }
61
+ params['_s'] = init_linear_scale(x, params=params, qtz_func=quantization, **kwargs)
62
+ params = {k: nn.Parameter(v, requires_grad=False) for k, v in params.items()}
63
+
64
+ if 'post_init_hook' in kwargs:
65
+ kwargs['post_init_hook'](parameters=params)
66
+
67
+
68
+ if 'post_train_hook' in kwargs:
69
+ kwargs['post_train_hook'](parameters=params)
70
+
71
+ return params
72
+
73
+
74
+ ############### Numpy Qtz ###############
75
+
76
+
77
+ def np_quantization(x, _0, _s):
78
+ return (np.divide(1, np_replace_num(_s, num=0, to=10000)) * np.arctan((_0 * x)))
79
+
80
+
81
+ def np_dequantization(x, _0, _s):
82
+ return (np.divide(1, np_replace_num(_0, num=0, to=10000)) * np.tan(np_domain_guard((_s * x), posinf=1, neginf=-1, nan=0)))
83
+
84
+
85
+ def fit_func(x, _0, _s):
86
+ x_ = np_quantization(x, _0, _s)
87
+ x_ = np_dequantization(x_, _0, _s)
88
+ return x_
89
+
90
+
91
+
92
+ ############### HELPERS ###############
93
+
94
+ def domain_guard(
95
+ x: torch.Tensor,
96
+ min: float = None,
97
+ max: float = None,
98
+ posinf: float = None,
99
+ neginf: float = None,
100
+ nan: float = None
101
+ ) -> torch.Tensor:
102
+ """Guard a tensor to a valid domain."""
103
+ x = torch.nan_to_num(x, posinf=posinf, neginf=neginf, nan=nan)
104
+ if min is not None or max is not None:
105
+ x = torch.clamp(x, min=min, max=max)
106
+ return x
107
+
108
+
109
+ def replace_num(x: torch.Tensor, num: float, to: float) -> torch.Tensor:
110
+ """Replace a number in a tensor with another number.
111
+
112
+ Args:
113
+ x (torch.Tensor): The input tensor.
114
+ num (float): The number to replace.
115
+ to (float): The number to replace with.
116
+
117
+ Returns:
118
+ torch.Tensor: The tensor with the number replaced.
119
+ """
120
+ return torch.where(x == num, to, x)
121
+
122
+
123
+ def guarded_torch_power(x: torch.Tensor, exp: float) -> torch.Tensor:
124
+ """Guard the power operation to a valid domain."""
125
+ return torch.pow(x, exp) if exp >= 1 else torch.pow(torch.relu(x), exp)
126
+
127
+
128
+ def init_ones(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor:
129
+ val = torch.amin(x, dim=1)
130
+ return torch.ones_like(val, dtype=torch.float32, device=x.device)
131
+
132
+
133
+ def init_rand(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor:
134
+ val = torch.amin(x, dim=1)
135
+ return torch.randn_like(val, dtype=torch.float32, device=x.device)
136
+
137
+
138
+ def init_space_search(
139
+ x: torch.Tensor,
140
+ **kwargs: Dict[str, Any],
141
+ ) -> torch.Tensor:
142
+
143
+ def _build_initial_param(tensor: torch.Tensor, max_initial: int, n_params: int):
144
+ """Generates the initial set of parameters. The first iteration generates 10 times more parameters."""
145
+ for _ in range(n_params * 10): # The first iteration generates 10 times more parameters
146
+ yield init_rand(tensor) * max_initial # Generates n_params in range [-max_initial, max_initial]
147
+
148
+ def _search_param(tensors: List[torch.tensor], n_params):
149
+ """Takes the best parameters and generates new parameters around the mean of the best parameters."""
150
+ torch_tensors = torch.stack(tensors)
151
+ min_vals, max_vals = torch.aminmax(torch_tensors, dim=0)
152
+ abs_max_val_per_ch = torch.max(-min_vals, max_vals)
153
+ mean = torch.mean(torch_tensors, dim=0)
154
+ for _ in range(n_params): # Generates n_params around the mean of the tensors
155
+ yield torch.randn_like(min_vals) * abs_max_val_per_ch + mean
156
+
157
+ def _calc(x, qtz_func, deqtz_func, **params):
158
+ x_ = x.transpose(0, 1)
159
+ x_ = qtz_func(x=x_, **params)
160
+ x_ = deqtz_func(x=x_, **params)
161
+ x_ = x_.transpose(0, 1)
162
+ return x_
163
+
164
+ assert "qtz_func" in kwargs, "qtz_func must be provided."
165
+ assert "deqtz_func" in kwargs, "deqtz_func must be provided."
166
+ assert "params_list" in kwargs, "params list must be provided."
167
+ assert "param" in kwargs, "param must be provided."
168
+
169
+ qtz_func = kwargs.get('qtz_func')
170
+ deqtz_func = kwargs.get('deqtz_func')
171
+ params_list = kwargs.get('params_list')
172
+ param = kwargs.get('param')
173
+
174
+ n_runs = 50 # Number of runs to try to find the best parameters
175
+ n_random_params = 50 # Number of random parameters to generate
176
+ n_best_to_pick = 5 # Number of best parameters to pick after each run
177
+ max_initial = 10000 # Maximum value to initialize the parameters
178
+
179
+ # Initializes the parameters
180
+ base_params = { p: init_ones(x, **kwargs) for p in params_list if p != param }
181
+ params = _build_initial_param(x, max_initial, n_random_params)
182
+
183
+ # Performs the search
184
+ for _ in range(n_runs):
185
+
186
+ best_params = []
187
+ for param_ in params:
188
+ try:
189
+ x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: param_})
190
+ loss_ones = nn.MSELoss()(x, x_)
191
+
192
+ if len(best_params) < n_best_to_pick:
193
+ best_params.append((param_, loss_ones.item()))
194
+ best_params = sorted(best_params, key=lambda x: x[1])
195
+ elif loss_ones < best_params[-1][1]:
196
+ best_params[-1] = (param_, loss_ones.item())
197
+ best_params = sorted(best_params, key=lambda x: x[1])
198
+
199
+ except Exception: # The parameters might not be valid for the function's domain
200
+ continue
201
+
202
+ # Generates new parameters around the mean
203
+ params = _search_param([p for p, _ in best_params], n_random_params)
204
+
205
+ # Checks if the best parameter is better than the init_ones
206
+ p_ones = init_ones(x, **kwargs)
207
+ x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: p_ones})
208
+ loss_ones = nn.MSELoss()(x, x_)
209
+
210
+ # Checks if the best parameter is better than the init_rand
211
+ p_rand = init_rand(x, **kwargs)
212
+ x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: p_rand})
213
+ loss_rand = nn.MSELoss()(x, x_)
214
+
215
+ if loss_rand < best_params[0][1] and loss_rand < loss_ones:
216
+ return p_rand
217
+ elif loss_ones < best_params[0][1] and loss_ones < loss_rand:
218
+ return p_ones
219
+ else:
220
+ return best_params[0][0]
221
+
222
+
223
+ def init_linear_scale( # Symmetric scale. From the study folder
224
+ x: torch.Tensor,
225
+ **kwargs: Dict[str, Any],
226
+ ) -> torch.Tensor:
227
+ assert "bits" in kwargs, "bits must be provided."
228
+ assert "params" in kwargs, "params must be provided."
229
+ assert "qtz_func" in kwargs, "qtz_func must be provided."
230
+
231
+ bits = kwargs.get('bits')
232
+ params = kwargs.get('params')
233
+ qtz_func = kwargs.get('qtz_func')
234
+
235
+ x_ = x.transpose(0, 1)
236
+ x_ = qtz_func(x=x_, **params, _s=init_ones(x, **kwargs))
237
+ x_ = x_.transpose(0, 1)
238
+
239
+ quant_min, quant_max = get_min_max_from_bits_signed(bits)
240
+ min_vals, max_vals = torch.aminmax(x_, dim=1)
241
+ min_vals = torch.min(min_vals, torch.zeros_like(min_vals))
242
+ max_vals = torch.max(max_vals, torch.zeros_like(max_vals))
243
+
244
+ eps = torch.finfo(torch.float32).eps
245
+
246
+ abs_max_val_per_ch = torch.max(-min_vals, max_vals)
247
+ scale = abs_max_val_per_ch / (float(quant_max - quant_min) / 2)
248
+
249
+ scale = torch.clamp(scale, min=eps).to(dtype=torch.float32, device=min_vals.device)
250
+
251
+ return scale
252
+
253
+ # Introduces some noise in scale
254
+ # If I don't introduce noise, the accuracy is going to be 0.0 and not learn anything
255
+ scale = scale + 0.01 * torch.randn_like(scale)
256
+ return scale
257
+
258
+
259
+ def init_non_linear_regression_fit(
260
+ x: torch.Tensor,
261
+ **kwargs: Dict[str, Any],
262
+ ) -> torch.Tensor:
263
+
264
+ assert "params_list" in kwargs, "params list must be provided."
265
+ assert "np_fit_func" in kwargs, "np_fit_func must be provided."
266
+ assert "p0" in kwargs, "p0 must be provided."
267
+ np_fit_func = kwargs.get('np_fit_func')
268
+ params_list = kwargs.get('params_list')
269
+ p0 = kwargs.get('p0')
270
+
271
+ def _fit(xdata: np.ndarray, ydata: np.ndarray, func: Callable, p0: List[float]):
272
+ popt, _ = curve_fit(
273
+ func,
274
+ xdata,
275
+ ydata,
276
+ maxfev=1000,
277
+ p0=p0,
278
+ method='lm'
279
+ )
280
+ return popt
281
+
282
+ # 1. Needs to convert the torch tensor to numpy tensor
283
+ xdata = x.cpu().numpy()
284
+
285
+ # 2. Sorts the data so that it makes it easier to fit to it
286
+ sorted_xdata = np.sort(xdata, axis=-1)
287
+
288
+ p0 = {k: v.cpu().numpy() for k, v in p0.items()}
289
+ params_list = sorted(params_list) # We need to make sure that it matches the numpy fit func arg order
290
+
291
+ # 3. Finds the best parameters for each channel
292
+ try:
293
+ params = []
294
+ for i in range(sorted_xdata.shape[0]):
295
+ xdata_ = sorted_xdata[i]
296
+ p0_ = [p0[p][i] for p in params_list]
297
+ ch_params = _fit(xdata_, xdata_, np_fit_func, p0_)
298
+ params.append(ch_params)
299
+
300
+ # 4. Builds the parameters
301
+ result = {}
302
+ for i, p in enumerate(params_list):
303
+ result[p] = torch.tensor([p_[i] for p_ in params], dtype=torch.float32).to(x.device)
304
+
305
+ return result
306
+
307
+ except ValueError as e:
308
+ print(f"Could not fit the function with error: {e}")
309
+ print(f"Using fallback result...")
310
+ return {
311
+ k: torch.tensor(v, dtype=torch.float32).to(x.device) for k, v in p0.items()
312
+ }
313
+
314
+
315
+ def init_zeros(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor:
316
+ val = torch.amin(x, dim=1)
317
+ return torch.zeros_like(val, dtype=torch.float32, device=x.device)
318
+
319
+
320
+ def init_inner_scale(tensor: torch.Tensor, _min: float = torch.inf, _max: float = torch.inf) -> torch.Tensor:
321
+ # Calculate the original minimum and maximum values
322
+ min_vals, max_vals = torch.aminmax(tensor, dim=-1)
323
+ x_min = torch.min(min_vals, torch.zeros_like(min_vals))
324
+ x_max = torch.max(max_vals, torch.zeros_like(max_vals))
325
+
326
+ if _max is torch.inf: # We do not need to scale the tensor. Just need to move it
327
+ return torch.ones_like(x_min)
328
+
329
+ # Calculate the scale factor
330
+ scale = (_max - _min) / (x_max - x_min)
331
+ return scale
332
+
333
+
334
+
335
+ ############## Quant ###############
336
+
337
+ @torch.enable_grad()
338
+ def learn_parameters(
339
+ x: torch.Tensor,
340
+ params: Dict[str, nn.Parameter],
341
+ qtz_func: nn.Module,
342
+ deqtz_func: nn.Module,
343
+ bits: int,
344
+ target_dtype: torch.dtype,
345
+ epochs: int = 1000,
346
+ early_stop: bool = True,
347
+ do_report: bool = False
348
+ ) -> Tuple[Dict[str, nn.Parameter], torch.Tensor]:
349
+
350
+ # Requires gradients in the parameters
351
+ for p in params.values():
352
+ p.requires_grad = True
353
+ p.grad = None
354
+
355
+ param_keys = list(params.keys())
356
+ param_values = list(params.values())
357
+
358
+ # Defines optimizer and loss function
359
+ optimizer = torch.optim.Adam(param_values, lr=0.001)
360
+ scheduler = torch.optim.lr_scheduler.LinearLR(optimizer, start_factor=1.0, end_factor=0.001, total_iters=epochs // 10)
361
+ loss_fn = nn.MSELoss()
362
+
363
+ # Contains the best loss and the best parameters
364
+ best_loss = float("inf")
365
+ best_params = None
366
+
367
+ # Used to stop the search early
368
+ min_delta = 1e-7
369
+ acc_loss = []
370
+ percent_epochs_before_stop = 0.1
371
+
372
+ for i in range(epochs):
373
+ optimizer.zero_grad()
374
+
375
+ quant = quantize(x, params, qtz_func, bits, target_dtype)
376
+ dequant = dequantize(quant, params, deqtz_func, bits, x.dtype)
377
+ loss = loss_fn(x, dequant)
378
+
379
+ if loss.isnan() or loss.isinf():
380
+ raise Exception("Loss is NaN or Inf. Stopping the search.")
381
+
382
+ loss.backward()
383
+ optimizer.step()
384
+ scheduler.step()
385
+
386
+ acc_loss.append(loss.item())
387
+
388
+ # Reports loss every 10 steps
389
+ if i % 10 == 0 and do_report:
390
+ print(f"Epoch {i}: Loss {loss.item()}")
391
+
392
+ # Optimizes the parameter search by storing the best loss and the parameters
393
+ if loss.item() < best_loss:
394
+ best_loss = loss.item()
395
+ best_params = copy.deepcopy({
396
+ k: v for k, v in params.items() if k in param_keys
397
+ })
398
+
399
+ # We also stop the search if the loss has not considerably during the last 10% epochs
400
+ if early_stop:
401
+ epochs_before_stop = int(epochs * percent_epochs_before_stop)
402
+ if i > epochs_before_stop and abs(acc_loss[i - epochs_before_stop] - acc_loss[i]) < min_delta:
403
+ break
404
+
405
+ # No longer requires gradients in the parameters
406
+ for p in best_params.values():
407
+ p.requires_grad = False
408
+ p.grad = None
409
+
410
+ if do_report:
411
+ return best_params, acc_loss
412
+ else:
413
+ return best_params
414
+
415
+
416
+ def quantize(
417
+ x: torch.Tensor,
418
+ params: Dict[str, nn.Parameter],
419
+ func: nn.Module,
420
+ bits: int,
421
+ target_dtype: torch.dtype = torch.int8
422
+ ) -> torch.Tensor:
423
+ quant_min, quant_max = get_min_max_from_bits_signed(bits)
424
+ x = x.transpose(0, 1) # Aligns shapes
425
+ x = func(x=x, **params)
426
+ x = x.transpose(0, 1)
427
+ x = torch.clamp(round_func_BPDA(x), quant_min, quant_max).to(target_dtype)
428
+ return x
429
+
430
+
431
+ def dequantize(
432
+ x: torch.Tensor,
433
+ params: Dict[str, nn.Parameter],
434
+ func: nn.Module,
435
+ bits: int,
436
+ out_dtype: torch.dtype
437
+ ) -> torch.Tensor:
438
+ x = x.to(dtype=out_dtype)
439
+ x = x.transpose(0, 1)
440
+ x = func(x=x, **params)
441
+ x = x.transpose(0, 1)
442
+ return x
443
+
444
+
445
+ def round_func_BPDA(input):
446
+ # This is equivalent to replacing round function (non-differentiable) with
447
+ # an identity function (differentiable) only when backward.
448
+ forward_value = torch.round(input)
449
+ out = input.clone()
450
+ out.data = forward_value.data
451
+ return out
452
+
453
+
454
+ def get_min_max_from_bits_signed(bit_width: int) -> Tuple[int, int]:
455
+ return -2 ** (bit_width - 1), 2 ** (bit_width - 1) - 1
456
+
457
+
458
+
459
+ ############## Numpy ###############
460
+
461
+ def np_domain_guard(
462
+ x: np.ndarray,
463
+ min: float = None,
464
+ max: float = None,
465
+ posinf: float = None,
466
+ neginf: float = None,
467
+ nan: float = None
468
+ ) -> np.ndarray:
469
+ """Guard a tensor to a valid domain."""
470
+ x = np.nan_to_num(x, posinf=posinf, neginf=neginf, nan=nan)
471
+ if min is not None or max is not None:
472
+ x = np.clip(x, min, max)
473
+ return x
474
+
475
+
476
+ def np_replace_num(x: np.ndarray, num: float, to: float) -> np.ndarray:
477
+ """Replace a number in a tensor with another number.
478
+
479
+ Args:
480
+ x (np.ndarray): The input tensor.
481
+ num (float): The number to replace.
482
+ to (float): The number to replace with.
483
+
484
+ Returns:
485
+ np.ndarray: The tensor with the number replaced.
486
+ """
487
+ return np.where(x == num, to, x)
488
+
489
+
490
+ def np_guarded_power(x: np.ndarray, exp: float) -> np.ndarray:
491
+ """Guard the power operation to a valid domain."""
492
+ return np.power(x, exp) if exp >= 1 else np.power(np.maximum(x, 0), exp)
493
+
fn_gen/ones_affine_scale/16/loss.png ADDED
fn_gen/ones_affine_scale/16/quantization.png ADDED
fn_gen/ones_affine_scale/17/distortion.png ADDED
fn_gen/ones_affine_scale/17/expressions.txt ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ tanh(_0*x)/_s
2
+ log((-_s*x - 1)/(_s*x - 1))/_0
fn_gen/ones_affine_scale/17/fn.py ADDED
@@ -0,0 +1,493 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import torch
4
+ from torch import amin # Necessary for arcsin
5
+ import copy
6
+ import torch.nn as nn
7
+ import numpy as np
8
+
9
+ from scipy.optimize import curve_fit
10
+ from typing import Dict, Any, Tuple, List, Callable
11
+
12
+
13
+ def quantization(x, **params):
14
+ return (torch.div(1, replace_num(params['_s'], num=0, to=10000)) * torch.tanh((params['_0'] * x)))
15
+
16
+
17
+ def dequantization(x, **params):
18
+ return (torch.div(1, replace_num(params['_0'], num=0, to=10000)) * torch.log(domain_guard((torch.div(1, replace_num((torch.tensor(-1) + (params['_s'] * x)), num=0, to=10000)) * (torch.tensor(-1) + (torch.tensor(-1) * params['_s'] * x))), min=1e-5, nan=1e-5)))
19
+
20
+
21
+ def init_linear_scale( # Symmetric scale. From the study folder
22
+ x: torch.Tensor,
23
+ **kwargs: Dict[str, Any],
24
+ ) -> torch.Tensor:
25
+ assert "bits" in kwargs, "bits must be provided."
26
+ assert "params" in kwargs, "params must be provided."
27
+ assert "qtz_func" in kwargs, "qtz_func must be provided."
28
+
29
+ bits = kwargs.get('bits')
30
+ params = kwargs.get('params')
31
+ qtz_func = kwargs.get('qtz_func')
32
+
33
+ x_ = x.transpose(0, 1)
34
+ x_ = qtz_func(x=x_, **params, _s=init_ones(x, **kwargs))
35
+ x_ = x_.transpose(0, 1)
36
+
37
+ quant_min, quant_max = get_min_max_from_bits_signed(bits)
38
+ min_vals, max_vals = torch.aminmax(x_, dim=1)
39
+ min_vals = torch.min(min_vals, torch.zeros_like(min_vals))
40
+ max_vals = torch.max(max_vals, torch.zeros_like(max_vals))
41
+
42
+ eps = torch.finfo(torch.float32).eps
43
+
44
+ abs_max_val_per_ch = torch.max(-min_vals, max_vals)
45
+ scale = abs_max_val_per_ch / (float(quant_max - quant_min) / 2)
46
+
47
+ scale = torch.clamp(scale, min=eps).to(dtype=torch.float32, device=min_vals.device)
48
+
49
+ return scale
50
+
51
+ # Introduces some noise in scale
52
+ # If I don't introduce noise, the accuracy is going to be 0.0 and not learn anything
53
+ scale = scale + 0.01 * torch.randn_like(scale)
54
+ return scale
55
+
56
+
57
+ def init_params(x: torch.Tensor, **kwargs: Dict[str, Any]) -> Dict[str, nn.Parameter]:
58
+ params = {
59
+ '_0': init_ones(x, qtz_func=quantization, deqtz_func=dequantization, param='_0', params_list=['_0', '_s'], **kwargs),
60
+ }
61
+ params['_s'] = init_linear_scale(x, params=params, qtz_func=quantization, **kwargs)
62
+ params = {k: nn.Parameter(v, requires_grad=False) for k, v in params.items()}
63
+
64
+ if 'post_init_hook' in kwargs:
65
+ kwargs['post_init_hook'](parameters=params)
66
+
67
+
68
+ if 'post_train_hook' in kwargs:
69
+ kwargs['post_train_hook'](parameters=params)
70
+
71
+ return params
72
+
73
+
74
+ ############### Numpy Qtz ###############
75
+
76
+
77
+ def np_quantization(x, _0, _s):
78
+ return (np.divide(1, np_replace_num(_s, num=0, to=10000)) * np.tanh((_0 * x)))
79
+
80
+
81
+ def np_dequantization(x, _0, _s):
82
+ return (np.divide(1, np_replace_num(_0, num=0, to=10000)) * np.log(np_domain_guard((np.divide(1, np_replace_num((np.array(-1) + (_s * x)), num=0, to=10000)) * (np.array(-1) + (np.array(-1) * _s * x))), min=1e-5, nan=1e-5)))
83
+
84
+
85
+ def fit_func(x, _0, _s):
86
+ x_ = np_quantization(x, _0, _s)
87
+ x_ = np_dequantization(x_, _0, _s)
88
+ return x_
89
+
90
+
91
+
92
+ ############### HELPERS ###############
93
+
94
+ def domain_guard(
95
+ x: torch.Tensor,
96
+ min: float = None,
97
+ max: float = None,
98
+ posinf: float = None,
99
+ neginf: float = None,
100
+ nan: float = None
101
+ ) -> torch.Tensor:
102
+ """Guard a tensor to a valid domain."""
103
+ x = torch.nan_to_num(x, posinf=posinf, neginf=neginf, nan=nan)
104
+ if min is not None or max is not None:
105
+ x = torch.clamp(x, min=min, max=max)
106
+ return x
107
+
108
+
109
+ def replace_num(x: torch.Tensor, num: float, to: float) -> torch.Tensor:
110
+ """Replace a number in a tensor with another number.
111
+
112
+ Args:
113
+ x (torch.Tensor): The input tensor.
114
+ num (float): The number to replace.
115
+ to (float): The number to replace with.
116
+
117
+ Returns:
118
+ torch.Tensor: The tensor with the number replaced.
119
+ """
120
+ return torch.where(x == num, to, x)
121
+
122
+
123
+ def guarded_torch_power(x: torch.Tensor, exp: float) -> torch.Tensor:
124
+ """Guard the power operation to a valid domain."""
125
+ return torch.pow(x, exp) if exp >= 1 else torch.pow(torch.relu(x), exp)
126
+
127
+
128
+ def init_ones(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor:
129
+ val = torch.amin(x, dim=1)
130
+ return torch.ones_like(val, dtype=torch.float32, device=x.device)
131
+
132
+
133
+ def init_rand(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor:
134
+ val = torch.amin(x, dim=1)
135
+ return torch.randn_like(val, dtype=torch.float32, device=x.device)
136
+
137
+
138
+ def init_space_search(
139
+ x: torch.Tensor,
140
+ **kwargs: Dict[str, Any],
141
+ ) -> torch.Tensor:
142
+
143
+ def _build_initial_param(tensor: torch.Tensor, max_initial: int, n_params: int):
144
+ """Generates the initial set of parameters. The first iteration generates 10 times more parameters."""
145
+ for _ in range(n_params * 10): # The first iteration generates 10 times more parameters
146
+ yield init_rand(tensor) * max_initial # Generates n_params in range [-max_initial, max_initial]
147
+
148
+ def _search_param(tensors: List[torch.tensor], n_params):
149
+ """Takes the best parameters and generates new parameters around the mean of the best parameters."""
150
+ torch_tensors = torch.stack(tensors)
151
+ min_vals, max_vals = torch.aminmax(torch_tensors, dim=0)
152
+ abs_max_val_per_ch = torch.max(-min_vals, max_vals)
153
+ mean = torch.mean(torch_tensors, dim=0)
154
+ for _ in range(n_params): # Generates n_params around the mean of the tensors
155
+ yield torch.randn_like(min_vals) * abs_max_val_per_ch + mean
156
+
157
+ def _calc(x, qtz_func, deqtz_func, **params):
158
+ x_ = x.transpose(0, 1)
159
+ x_ = qtz_func(x=x_, **params)
160
+ x_ = deqtz_func(x=x_, **params)
161
+ x_ = x_.transpose(0, 1)
162
+ return x_
163
+
164
+ assert "qtz_func" in kwargs, "qtz_func must be provided."
165
+ assert "deqtz_func" in kwargs, "deqtz_func must be provided."
166
+ assert "params_list" in kwargs, "params list must be provided."
167
+ assert "param" in kwargs, "param must be provided."
168
+
169
+ qtz_func = kwargs.get('qtz_func')
170
+ deqtz_func = kwargs.get('deqtz_func')
171
+ params_list = kwargs.get('params_list')
172
+ param = kwargs.get('param')
173
+
174
+ n_runs = 50 # Number of runs to try to find the best parameters
175
+ n_random_params = 50 # Number of random parameters to generate
176
+ n_best_to_pick = 5 # Number of best parameters to pick after each run
177
+ max_initial = 10000 # Maximum value to initialize the parameters
178
+
179
+ # Initializes the parameters
180
+ base_params = { p: init_ones(x, **kwargs) for p in params_list if p != param }
181
+ params = _build_initial_param(x, max_initial, n_random_params)
182
+
183
+ # Performs the search
184
+ for _ in range(n_runs):
185
+
186
+ best_params = []
187
+ for param_ in params:
188
+ try:
189
+ x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: param_})
190
+ loss_ones = nn.MSELoss()(x, x_)
191
+
192
+ if len(best_params) < n_best_to_pick:
193
+ best_params.append((param_, loss_ones.item()))
194
+ best_params = sorted(best_params, key=lambda x: x[1])
195
+ elif loss_ones < best_params[-1][1]:
196
+ best_params[-1] = (param_, loss_ones.item())
197
+ best_params = sorted(best_params, key=lambda x: x[1])
198
+
199
+ except Exception: # The parameters might not be valid for the function's domain
200
+ continue
201
+
202
+ # Generates new parameters around the mean
203
+ params = _search_param([p for p, _ in best_params], n_random_params)
204
+
205
+ # Checks if the best parameter is better than the init_ones
206
+ p_ones = init_ones(x, **kwargs)
207
+ x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: p_ones})
208
+ loss_ones = nn.MSELoss()(x, x_)
209
+
210
+ # Checks if the best parameter is better than the init_rand
211
+ p_rand = init_rand(x, **kwargs)
212
+ x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: p_rand})
213
+ loss_rand = nn.MSELoss()(x, x_)
214
+
215
+ if loss_rand < best_params[0][1] and loss_rand < loss_ones:
216
+ return p_rand
217
+ elif loss_ones < best_params[0][1] and loss_ones < loss_rand:
218
+ return p_ones
219
+ else:
220
+ return best_params[0][0]
221
+
222
+
223
+ def init_linear_scale( # Symmetric scale. From the study folder
224
+ x: torch.Tensor,
225
+ **kwargs: Dict[str, Any],
226
+ ) -> torch.Tensor:
227
+ assert "bits" in kwargs, "bits must be provided."
228
+ assert "params" in kwargs, "params must be provided."
229
+ assert "qtz_func" in kwargs, "qtz_func must be provided."
230
+
231
+ bits = kwargs.get('bits')
232
+ params = kwargs.get('params')
233
+ qtz_func = kwargs.get('qtz_func')
234
+
235
+ x_ = x.transpose(0, 1)
236
+ x_ = qtz_func(x=x_, **params, _s=init_ones(x, **kwargs))
237
+ x_ = x_.transpose(0, 1)
238
+
239
+ quant_min, quant_max = get_min_max_from_bits_signed(bits)
240
+ min_vals, max_vals = torch.aminmax(x_, dim=1)
241
+ min_vals = torch.min(min_vals, torch.zeros_like(min_vals))
242
+ max_vals = torch.max(max_vals, torch.zeros_like(max_vals))
243
+
244
+ eps = torch.finfo(torch.float32).eps
245
+
246
+ abs_max_val_per_ch = torch.max(-min_vals, max_vals)
247
+ scale = abs_max_val_per_ch / (float(quant_max - quant_min) / 2)
248
+
249
+ scale = torch.clamp(scale, min=eps).to(dtype=torch.float32, device=min_vals.device)
250
+
251
+ return scale
252
+
253
+ # Introduces some noise in scale
254
+ # If I don't introduce noise, the accuracy is going to be 0.0 and not learn anything
255
+ scale = scale + 0.01 * torch.randn_like(scale)
256
+ return scale
257
+
258
+
259
+ def init_non_linear_regression_fit(
260
+ x: torch.Tensor,
261
+ **kwargs: Dict[str, Any],
262
+ ) -> torch.Tensor:
263
+
264
+ assert "params_list" in kwargs, "params list must be provided."
265
+ assert "np_fit_func" in kwargs, "np_fit_func must be provided."
266
+ assert "p0" in kwargs, "p0 must be provided."
267
+ np_fit_func = kwargs.get('np_fit_func')
268
+ params_list = kwargs.get('params_list')
269
+ p0 = kwargs.get('p0')
270
+
271
+ def _fit(xdata: np.ndarray, ydata: np.ndarray, func: Callable, p0: List[float]):
272
+ popt, _ = curve_fit(
273
+ func,
274
+ xdata,
275
+ ydata,
276
+ maxfev=1000,
277
+ p0=p0,
278
+ method='lm'
279
+ )
280
+ return popt
281
+
282
+ # 1. Needs to convert the torch tensor to numpy tensor
283
+ xdata = x.cpu().numpy()
284
+
285
+ # 2. Sorts the data so that it makes it easier to fit to it
286
+ sorted_xdata = np.sort(xdata, axis=-1)
287
+
288
+ p0 = {k: v.cpu().numpy() for k, v in p0.items()}
289
+ params_list = sorted(params_list) # We need to make sure that it matches the numpy fit func arg order
290
+
291
+ # 3. Finds the best parameters for each channel
292
+ try:
293
+ params = []
294
+ for i in range(sorted_xdata.shape[0]):
295
+ xdata_ = sorted_xdata[i]
296
+ p0_ = [p0[p][i] for p in params_list]
297
+ ch_params = _fit(xdata_, xdata_, np_fit_func, p0_)
298
+ params.append(ch_params)
299
+
300
+ # 4. Builds the parameters
301
+ result = {}
302
+ for i, p in enumerate(params_list):
303
+ result[p] = torch.tensor([p_[i] for p_ in params], dtype=torch.float32).to(x.device)
304
+
305
+ return result
306
+
307
+ except ValueError as e:
308
+ print(f"Could not fit the function with error: {e}")
309
+ print(f"Using fallback result...")
310
+ return {
311
+ k: torch.tensor(v, dtype=torch.float32).to(x.device) for k, v in p0.items()
312
+ }
313
+
314
+
315
+ def init_zeros(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor:
316
+ val = torch.amin(x, dim=1)
317
+ return torch.zeros_like(val, dtype=torch.float32, device=x.device)
318
+
319
+
320
+ def init_inner_scale(tensor: torch.Tensor, _min: float = torch.inf, _max: float = torch.inf) -> torch.Tensor:
321
+ # Calculate the original minimum and maximum values
322
+ min_vals, max_vals = torch.aminmax(tensor, dim=-1)
323
+ x_min = torch.min(min_vals, torch.zeros_like(min_vals))
324
+ x_max = torch.max(max_vals, torch.zeros_like(max_vals))
325
+
326
+ if _max is torch.inf: # We do not need to scale the tensor. Just need to move it
327
+ return torch.ones_like(x_min)
328
+
329
+ # Calculate the scale factor
330
+ scale = (_max - _min) / (x_max - x_min)
331
+ return scale
332
+
333
+
334
+
335
+ ############## Quant ###############
336
+
337
+ @torch.enable_grad()
338
+ def learn_parameters(
339
+ x: torch.Tensor,
340
+ params: Dict[str, nn.Parameter],
341
+ qtz_func: nn.Module,
342
+ deqtz_func: nn.Module,
343
+ bits: int,
344
+ target_dtype: torch.dtype,
345
+ epochs: int = 1000,
346
+ early_stop: bool = True,
347
+ do_report: bool = False
348
+ ) -> Tuple[Dict[str, nn.Parameter], torch.Tensor]:
349
+
350
+ # Requires gradients in the parameters
351
+ for p in params.values():
352
+ p.requires_grad = True
353
+ p.grad = None
354
+
355
+ param_keys = list(params.keys())
356
+ param_values = list(params.values())
357
+
358
+ # Defines optimizer and loss function
359
+ optimizer = torch.optim.Adam(param_values, lr=0.001)
360
+ scheduler = torch.optim.lr_scheduler.LinearLR(optimizer, start_factor=1.0, end_factor=0.001, total_iters=epochs // 10)
361
+ loss_fn = nn.MSELoss()
362
+
363
+ # Contains the best loss and the best parameters
364
+ best_loss = float("inf")
365
+ best_params = None
366
+
367
+ # Used to stop the search early
368
+ min_delta = 1e-7
369
+ acc_loss = []
370
+ percent_epochs_before_stop = 0.1
371
+
372
+ for i in range(epochs):
373
+ optimizer.zero_grad()
374
+
375
+ quant = quantize(x, params, qtz_func, bits, target_dtype)
376
+ dequant = dequantize(quant, params, deqtz_func, bits, x.dtype)
377
+ loss = loss_fn(x, dequant)
378
+
379
+ if loss.isnan() or loss.isinf():
380
+ raise Exception("Loss is NaN or Inf. Stopping the search.")
381
+
382
+ loss.backward()
383
+ optimizer.step()
384
+ scheduler.step()
385
+
386
+ acc_loss.append(loss.item())
387
+
388
+ # Reports loss every 10 steps
389
+ if i % 10 == 0 and do_report:
390
+ print(f"Epoch {i}: Loss {loss.item()}")
391
+
392
+ # Optimizes the parameter search by storing the best loss and the parameters
393
+ if loss.item() < best_loss:
394
+ best_loss = loss.item()
395
+ best_params = copy.deepcopy({
396
+ k: v for k, v in params.items() if k in param_keys
397
+ })
398
+
399
+ # We also stop the search if the loss has not considerably during the last 10% epochs
400
+ if early_stop:
401
+ epochs_before_stop = int(epochs * percent_epochs_before_stop)
402
+ if i > epochs_before_stop and abs(acc_loss[i - epochs_before_stop] - acc_loss[i]) < min_delta:
403
+ break
404
+
405
+ # No longer requires gradients in the parameters
406
+ for p in best_params.values():
407
+ p.requires_grad = False
408
+ p.grad = None
409
+
410
+ if do_report:
411
+ return best_params, acc_loss
412
+ else:
413
+ return best_params
414
+
415
+
416
+ def quantize(
417
+ x: torch.Tensor,
418
+ params: Dict[str, nn.Parameter],
419
+ func: nn.Module,
420
+ bits: int,
421
+ target_dtype: torch.dtype = torch.int8
422
+ ) -> torch.Tensor:
423
+ quant_min, quant_max = get_min_max_from_bits_signed(bits)
424
+ x = x.transpose(0, 1) # Aligns shapes
425
+ x = func(x=x, **params)
426
+ x = x.transpose(0, 1)
427
+ x = torch.clamp(round_func_BPDA(x), quant_min, quant_max).to(target_dtype)
428
+ return x
429
+
430
+
431
+ def dequantize(
432
+ x: torch.Tensor,
433
+ params: Dict[str, nn.Parameter],
434
+ func: nn.Module,
435
+ bits: int,
436
+ out_dtype: torch.dtype
437
+ ) -> torch.Tensor:
438
+ x = x.to(dtype=out_dtype)
439
+ x = x.transpose(0, 1)
440
+ x = func(x=x, **params)
441
+ x = x.transpose(0, 1)
442
+ return x
443
+
444
+
445
+ def round_func_BPDA(input):
446
+ # This is equivalent to replacing round function (non-differentiable) with
447
+ # an identity function (differentiable) only when backward.
448
+ forward_value = torch.round(input)
449
+ out = input.clone()
450
+ out.data = forward_value.data
451
+ return out
452
+
453
+
454
+ def get_min_max_from_bits_signed(bit_width: int) -> Tuple[int, int]:
455
+ return -2 ** (bit_width - 1), 2 ** (bit_width - 1) - 1
456
+
457
+
458
+
459
+ ############## Numpy ###############
460
+
461
+ def np_domain_guard(
462
+ x: np.ndarray,
463
+ min: float = None,
464
+ max: float = None,
465
+ posinf: float = None,
466
+ neginf: float = None,
467
+ nan: float = None
468
+ ) -> np.ndarray:
469
+ """Guard a tensor to a valid domain."""
470
+ x = np.nan_to_num(x, posinf=posinf, neginf=neginf, nan=nan)
471
+ if min is not None or max is not None:
472
+ x = np.clip(x, min, max)
473
+ return x
474
+
475
+
476
+ def np_replace_num(x: np.ndarray, num: float, to: float) -> np.ndarray:
477
+ """Replace a number in a tensor with another number.
478
+
479
+ Args:
480
+ x (np.ndarray): The input tensor.
481
+ num (float): The number to replace.
482
+ to (float): The number to replace with.
483
+
484
+ Returns:
485
+ np.ndarray: The tensor with the number replaced.
486
+ """
487
+ return np.where(x == num, to, x)
488
+
489
+
490
+ def np_guarded_power(x: np.ndarray, exp: float) -> np.ndarray:
491
+ """Guard the power operation to a valid domain."""
492
+ return np.power(x, exp) if exp >= 1 else np.power(np.maximum(x, 0), exp)
493
+
fn_gen/ones_affine_scale/17/loss.png ADDED
fn_gen/ones_affine_scale/17/quantization.png ADDED