Diogo-V commited on
Commit
5b1443d
·
verified ·
1 Parent(s): 4b2e36b

Upload learned functions

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. fn_gen/ones_t/0/distortion.png +0 -0
  2. fn_gen/ones_t/0/expressions.txt +2 -0
  3. fn_gen/ones_t/0/fn.py +506 -0
  4. fn_gen/ones_t/0/loss.png +0 -0
  5. fn_gen/ones_t/0/quantization.png +0 -0
  6. fn_gen/ones_t/1/distortion.png +0 -0
  7. fn_gen/ones_t/1/expressions.txt +2 -0
  8. fn_gen/ones_t/1/fn.py +506 -0
  9. fn_gen/ones_t/1/loss.png +0 -0
  10. fn_gen/ones_t/1/quantization.png +0 -0
  11. fn_gen/ones_t/10/distortion.png +0 -0
  12. fn_gen/ones_t/10/expressions.txt +2 -0
  13. fn_gen/ones_t/10/fn.py +506 -0
  14. fn_gen/ones_t/10/loss.png +0 -0
  15. fn_gen/ones_t/10/quantization.png +0 -0
  16. fn_gen/ones_t/11/distortion.png +0 -0
  17. fn_gen/ones_t/11/expressions.txt +2 -0
  18. fn_gen/ones_t/11/fn.py +506 -0
  19. fn_gen/ones_t/11/loss.png +0 -0
  20. fn_gen/ones_t/11/quantization.png +0 -0
  21. fn_gen/ones_t/12/distortion.png +0 -0
  22. fn_gen/ones_t/12/expressions.txt +2 -0
  23. fn_gen/ones_t/12/fn.py +506 -0
  24. fn_gen/ones_t/12/loss.png +0 -0
  25. fn_gen/ones_t/12/quantization.png +0 -0
  26. fn_gen/ones_t/13/distortion.png +0 -0
  27. fn_gen/ones_t/13/expressions.txt +2 -0
  28. fn_gen/ones_t/13/fn.py +505 -0
  29. fn_gen/ones_t/13/loss.png +0 -0
  30. fn_gen/ones_t/13/quantization.png +0 -0
  31. fn_gen/ones_t/14/distortion.png +0 -0
  32. fn_gen/ones_t/14/expressions.txt +2 -0
  33. fn_gen/ones_t/14/fn.py +506 -0
  34. fn_gen/ones_t/14/loss.png +0 -0
  35. fn_gen/ones_t/14/quantization.png +0 -0
  36. fn_gen/ones_t/15/distortion.png +0 -0
  37. fn_gen/ones_t/15/expressions.txt +2 -0
  38. fn_gen/ones_t/15/fn.py +505 -0
  39. fn_gen/ones_t/15/loss.png +0 -0
  40. fn_gen/ones_t/15/quantization.png +0 -0
  41. fn_gen/ones_t/16/distortion.png +0 -0
  42. fn_gen/ones_t/16/expressions.txt +2 -0
  43. fn_gen/ones_t/16/fn.py +506 -0
  44. fn_gen/ones_t/16/loss.png +0 -0
  45. fn_gen/ones_t/16/quantization.png +0 -0
  46. fn_gen/ones_t/17/distortion.png +0 -0
  47. fn_gen/ones_t/17/expressions.txt +2 -0
  48. fn_gen/ones_t/17/fn.py +506 -0
  49. fn_gen/ones_t/17/loss.png +0 -0
  50. fn_gen/ones_t/17/quantization.png +0 -0
fn_gen/ones_t/0/distortion.png ADDED
fn_gen/ones_t/0/expressions.txt ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ cosh(_0*x)/_s
2
+ log(_s*x - sqrt(_s**2*x**2 - 1))/_0
fn_gen/ones_t/0/fn.py ADDED
@@ -0,0 +1,506 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import torch
4
+ from torch import amin # Necessary for arcsin
5
+ import copy
6
+ import torch.nn as nn
7
+ import numpy as np
8
+
9
+ from scipy.optimize import curve_fit
10
+ from typing import Dict, Any, Tuple, List, Callable
11
+
12
+
13
+ def quantization(x, **params):
14
+ return (torch.div(1, replace_num(params['_s'], num=0, to=10000)) * torch.cosh((params['_0'] * x)))
15
+
16
+
17
+ def dequantization(x, **params):
18
+ return (torch.div(1, replace_num(params['_0'], num=0, to=10000)) * torch.log(domain_guard(((torch.tensor(-1) * torch.sqrt(domain_guard((torch.tensor(-1) + (guarded_torch_power(params['_s'], torch.tensor(2)) * guarded_torch_power(x, torch.tensor(2)))), min=0.1, nan=0.1))) + (params['_s'] * x)), min=1e-5, nan=1e-5)))
19
+
20
+
21
+ def init_linear_scale( # Symmetric scale. From the study folder
22
+ x: torch.Tensor,
23
+ **kwargs: Dict[str, Any],
24
+ ) -> torch.Tensor:
25
+ assert "bits" in kwargs, "bits must be provided."
26
+ assert "params" in kwargs, "params must be provided."
27
+ assert "qtz_func" in kwargs, "qtz_func must be provided."
28
+
29
+ bits = kwargs.get('bits')
30
+ params = kwargs.get('params')
31
+ qtz_func = kwargs.get('qtz_func')
32
+
33
+ x_ = x.transpose(0, 1)
34
+ x_ = qtz_func(x=x_, **params, _s=init_ones(x, **kwargs))
35
+ x_ = x_.transpose(0, 1)
36
+
37
+ quant_min, quant_max = get_min_max_from_bits_signed(bits)
38
+ min_vals, max_vals = torch.aminmax(x_, dim=1)
39
+ min_vals = torch.min(min_vals, torch.zeros_like(min_vals))
40
+ max_vals = torch.max(max_vals, torch.zeros_like(max_vals))
41
+
42
+ eps = torch.finfo(torch.float32).eps
43
+
44
+ abs_max_val_per_ch = torch.max(-min_vals, max_vals)
45
+ scale = abs_max_val_per_ch / (float(quant_max - quant_min) / 2)
46
+
47
+ scale = torch.clamp(scale, min=eps).to(dtype=torch.float32, device=min_vals.device)
48
+
49
+ # Introduces some noise in scale
50
+ # If I don't introduce noise, the accuracy is going to be 0.0 and not learn anything
51
+ # scale = scale + 0.01 * torch.randn_like(scale)
52
+ return scale
53
+
54
+
55
+ def init_params(x: torch.Tensor, **kwargs: Dict[str, Any]) -> Dict[str, nn.Parameter]:
56
+ params = {
57
+ '_0': init_ones(x, qtz_func=quantization, deqtz_func=dequantization, param='_0', params_list=['_0', '_s'], **kwargs),
58
+ }
59
+ params['_s'] = init_linear_scale(x, params=params, qtz_func=quantization, **kwargs)
60
+ params = {k: nn.Parameter(v, requires_grad=False) for k, v in params.items()}
61
+
62
+ if 'post_init_hook' in kwargs:
63
+ kwargs['post_init_hook'](parameters=params)
64
+
65
+ params = learn_parameters(x, params,
66
+ qtz_func=quantization,
67
+ deqtz_func=dequantization,
68
+ bits=kwargs['bits'],
69
+ target_dtype=torch.int8,
70
+ epochs=500,
71
+ early_stop=False,
72
+ )
73
+ if 'post_train_hook' in kwargs:
74
+ kwargs['post_train_hook'](parameters=params)
75
+
76
+ return params
77
+
78
+
79
+ ############### Numpy Qtz ###############
80
+
81
+
82
+ def np_quantization(x, _0, _s):
83
+ return (np.divide(1, np_replace_num(_s, num=0, to=10000)) * np.cosh((_0 * x)))
84
+
85
+
86
+ def np_dequantization(x, _0, _s):
87
+ return (np.divide(1, np_replace_num(_0, num=0, to=10000)) * np.log(np_domain_guard(((np.array(-1) * np.sqrt(np_domain_guard((np.array(-1) + (np_guarded_power(_s, np.array(2)) * np_guarded_power(x, np.array(2)))), min=0.1, nan=0.1))) + (_s * x)), min=1e-5, nan=1e-5)))
88
+
89
+
90
+ def fit_func(x, _0, _s):
91
+ x_ = np_quantization(x, _0, _s)
92
+ x_ = np_dequantization(x_, _0, _s)
93
+ return x_
94
+
95
+
96
+
97
+ ############### HELPERS ###############
98
+
99
+ def domain_guard(
100
+ x: torch.Tensor,
101
+ min: float = None,
102
+ max: float = None,
103
+ posinf: float = None,
104
+ neginf: float = None,
105
+ nan: float = None
106
+ ) -> torch.Tensor:
107
+ """Guard a tensor to a valid domain."""
108
+ x = torch.nan_to_num(x, posinf=posinf, neginf=neginf, nan=nan)
109
+ if min is not None or max is not None:
110
+ x = torch.clamp(x, min=min, max=max)
111
+ return x
112
+
113
+
114
+ def replace_num(x: torch.Tensor, num: float, to: float) -> torch.Tensor:
115
+ """Replace a number in a tensor with another number.
116
+
117
+ Args:
118
+ x (torch.Tensor): The input tensor.
119
+ num (float): The number to replace.
120
+ to (float): The number to replace with.
121
+
122
+ Returns:
123
+ torch.Tensor: The tensor with the number replaced.
124
+ """
125
+ return torch.where(x == num, to, x)
126
+
127
+
128
+ def guarded_torch_power(x: torch.Tensor, exp: float) -> torch.Tensor:
129
+ """Guard the power operation to a valid domain."""
130
+ return torch.pow(x, exp) if exp >= 1 else torch.pow(torch.relu(x), exp)
131
+
132
+
133
+ def init_ones(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor:
134
+ val = torch.amin(x, dim=1)
135
+ return torch.ones_like(val, dtype=torch.float32, device=x.device)
136
+
137
+
138
+ def init_rand(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor:
139
+ val = torch.amin(x, dim=1)
140
+ return torch.randn_like(val, dtype=torch.float32, device=x.device)
141
+
142
+
143
+ def init_space_search(
144
+ x: torch.Tensor,
145
+ **kwargs: Dict[str, Any],
146
+ ) -> torch.Tensor:
147
+
148
+ def _build_initial_param(tensor: torch.Tensor, max_initial: int, n_params: int):
149
+ """Generates the initial set of parameters. The first iteration generates 10 times more parameters."""
150
+ for _ in range(n_params * 10): # The first iteration generates 10 times more parameters
151
+ yield init_rand(tensor) * max_initial # Generates n_params in range [-max_initial, max_initial]
152
+
153
+ def _search_param(tensors: List[torch.tensor], n_params):
154
+ """Takes the best parameters and generates new parameters around the mean of the best parameters."""
155
+ torch_tensors = torch.stack(tensors)
156
+ min_vals, max_vals = torch.aminmax(torch_tensors, dim=0)
157
+ abs_max_val_per_ch = torch.max(-min_vals, max_vals)
158
+ mean = torch.mean(torch_tensors, dim=0)
159
+ for _ in range(n_params): # Generates n_params around the mean of the tensors
160
+ yield torch.randn_like(min_vals) * abs_max_val_per_ch + mean
161
+
162
+ def _calc(x, qtz_func, deqtz_func, **params):
163
+ x_ = x.transpose(0, 1)
164
+ x_ = qtz_func(x=x_, **params)
165
+ x_ = deqtz_func(x=x_, **params)
166
+ x_ = x_.transpose(0, 1)
167
+ return x_
168
+
169
+ assert "qtz_func" in kwargs, "qtz_func must be provided."
170
+ assert "deqtz_func" in kwargs, "deqtz_func must be provided."
171
+ assert "params_list" in kwargs, "params list must be provided."
172
+ assert "param" in kwargs, "param must be provided."
173
+
174
+ qtz_func = kwargs.get('qtz_func')
175
+ deqtz_func = kwargs.get('deqtz_func')
176
+ params_list = kwargs.get('params_list')
177
+ param = kwargs.get('param')
178
+
179
+ n_runs = 50 # Number of runs to try to find the best parameters
180
+ n_random_params = 50 # Number of random parameters to generate
181
+ n_best_to_pick = 5 # Number of best parameters to pick after each run
182
+ max_initial = 10000 # Maximum value to initialize the parameters
183
+
184
+ # Initializes the parameters
185
+ base_params = { p: init_ones(x, **kwargs) for p in params_list if p != param }
186
+ params = _build_initial_param(x, max_initial, n_random_params)
187
+
188
+ # Performs the search
189
+ for _ in range(n_runs):
190
+
191
+ best_params = []
192
+ for param_ in params:
193
+ try:
194
+ x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: param_})
195
+ loss_ones = nn.MSELoss()(x, x_)
196
+
197
+ if len(best_params) < n_best_to_pick:
198
+ best_params.append((param_, loss_ones.item()))
199
+ best_params = sorted(best_params, key=lambda x: x[1])
200
+ elif loss_ones < best_params[-1][1]:
201
+ best_params[-1] = (param_, loss_ones.item())
202
+ best_params = sorted(best_params, key=lambda x: x[1])
203
+
204
+ except Exception: # The parameters might not be valid for the function's domain
205
+ continue
206
+
207
+ # Generates new parameters around the mean
208
+ params = _search_param([p for p, _ in best_params], n_random_params)
209
+
210
+ # Checks if the best parameter is better than the init_ones
211
+ p_ones = init_ones(x, **kwargs)
212
+ x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: p_ones})
213
+ loss_ones = nn.MSELoss()(x, x_)
214
+
215
+ # Checks if the best parameter is better than the init_rand
216
+ p_rand = init_rand(x, **kwargs)
217
+ x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: p_rand})
218
+ loss_rand = nn.MSELoss()(x, x_)
219
+
220
+ if loss_rand < best_params[0][1] and loss_rand < loss_ones:
221
+ return p_rand
222
+ elif loss_ones < best_params[0][1] and loss_ones < loss_rand:
223
+ return p_ones
224
+ else:
225
+ return best_params[0][0]
226
+
227
+
228
+ def init_linear_scale( # Symmetric scale. From the study folder
229
+ x: torch.Tensor,
230
+ **kwargs: Dict[str, Any],
231
+ ) -> torch.Tensor:
232
+ assert "bits" in kwargs, "bits must be provided."
233
+ assert "params" in kwargs, "params must be provided."
234
+ assert "qtz_func" in kwargs, "qtz_func must be provided."
235
+
236
+ bits = kwargs.get('bits')
237
+ params = kwargs.get('params')
238
+ qtz_func = kwargs.get('qtz_func')
239
+
240
+ x_ = x.transpose(0, 1)
241
+ x_ = qtz_func(x=x_, **params, _s=init_ones(x, **kwargs))
242
+ x_ = x_.transpose(0, 1)
243
+
244
+ quant_min, quant_max = get_min_max_from_bits_signed(bits)
245
+ min_vals, max_vals = torch.aminmax(x_, dim=1)
246
+ min_vals = torch.min(min_vals, torch.zeros_like(min_vals))
247
+ max_vals = torch.max(max_vals, torch.zeros_like(max_vals))
248
+
249
+ eps = torch.finfo(torch.float32).eps
250
+
251
+ abs_max_val_per_ch = torch.max(-min_vals, max_vals)
252
+ scale = abs_max_val_per_ch / (float(quant_max - quant_min) / 2)
253
+
254
+ scale = torch.clamp(scale, min=eps).to(dtype=torch.float32, device=min_vals.device)
255
+
256
+ # Introduces some noise in scale
257
+ # If I don't introduce noise, the accuracy is going to be 0.0 and not learn anything
258
+ # scale = scale + 0.01 * torch.randn_like(scale)
259
+ return scale
260
+
261
+
262
+ def init_non_linear_regression_fit(
263
+ x: torch.Tensor,
264
+ **kwargs: Dict[str, Any],
265
+ ) -> torch.Tensor:
266
+
267
+ assert "params_list" in kwargs, "params list must be provided."
268
+ assert "np_fit_func" in kwargs, "np_fit_func must be provided."
269
+ assert "p0" in kwargs, "p0 must be provided."
270
+ np_fit_func = kwargs.get('np_fit_func')
271
+ params_list = kwargs.get('params_list')
272
+ p0 = kwargs.get('p0')
273
+
274
+ def _fit(xdata: np.ndarray, ydata: np.ndarray, func: Callable, p0: List[float]):
275
+ popt, _ = curve_fit(
276
+ func,
277
+ xdata,
278
+ ydata,
279
+ maxfev=1000,
280
+ p0=p0,
281
+ method='lm'
282
+ )
283
+ return popt
284
+
285
+ # 1. Needs to convert the torch tensor to numpy tensor
286
+ xdata = x.cpu().numpy()
287
+
288
+ # 2. Sorts the data so that it makes it easier to fit to it
289
+ sorted_xdata = np.sort(xdata, axis=-1)
290
+
291
+ p0 = {k: v.cpu().numpy() for k, v in p0.items()}
292
+ params_list = sorted(params_list) # We need to make sure that it matches the numpy fit func arg order
293
+
294
+ # 3. Finds the best parameters for each channel
295
+ try:
296
+ params = []
297
+ for i in range(sorted_xdata.shape[0]):
298
+ xdata_ = sorted_xdata[i]
299
+ p0_ = [p0[p][i] for p in params_list]
300
+ ch_params = _fit(xdata_, xdata_, np_fit_func, p0_)
301
+ params.append(ch_params)
302
+
303
+ # 4. Builds the parameters
304
+ result = {}
305
+ for i, p in enumerate(params_list):
306
+ result[p] = torch.tensor([p_[i] for p_ in params], dtype=torch.float32).to(x.device)
307
+
308
+ return result
309
+
310
+ except ValueError as e:
311
+ print(f"Could not fit the function with error: {e}")
312
+ print(f"Using fallback result...")
313
+ return {
314
+ k: torch.tensor(v, dtype=torch.float32).to(x.device) for k, v in p0.items()
315
+ }
316
+
317
+
318
+ def init_zeros(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor:
319
+ val = torch.amin(x, dim=1)
320
+ return torch.zeros_like(val, dtype=torch.float32, device=x.device)
321
+
322
+
323
+ def init_inner_scale(tensor: torch.Tensor, _min: float = torch.inf, _max: float = torch.inf) -> torch.Tensor:
324
+ # Calculate the original minimum and maximum values
325
+ min_vals, max_vals = torch.aminmax(tensor, dim=-1)
326
+ x_min = torch.min(min_vals, torch.zeros_like(min_vals))
327
+ x_max = torch.max(max_vals, torch.zeros_like(max_vals))
328
+
329
+ if _max is torch.inf: # We do not need to scale the tensor. Just need to move it
330
+ return torch.ones_like(x_min)
331
+
332
+ # Calculate the scale factor
333
+ scale = (_max - _min) / (x_max - x_min)
334
+ return scale
335
+
336
+
337
+
338
+ ############## Quant ###############
339
+
340
+ @torch.enable_grad()
341
+ def learn_parameters(
342
+ x: torch.Tensor,
343
+ params: Dict[str, nn.Parameter],
344
+ qtz_func: nn.Module,
345
+ deqtz_func: nn.Module,
346
+ bits: int,
347
+ target_dtype: torch.dtype,
348
+ epochs: int = 1000,
349
+ early_stop: bool = True,
350
+ do_report: bool = False
351
+ ) -> Tuple[Dict[str, nn.Parameter], torch.Tensor]:
352
+ loss_fn = nn.MSELoss()
353
+
354
+ # Determines the initial learning rate by computing the initial loss and multiplying it by
355
+ # the order of magnitude of the loss divided by 2
356
+ quant = quantize(x, params, qtz_func, bits, target_dtype)
357
+ dequant = dequantize(quant, params, deqtz_func, bits, x.dtype)
358
+ loss = loss_fn(x, dequant)
359
+
360
+ base_lr = 0.1
361
+ exponent = int(np.floor(np.log10(loss.item())))
362
+ lr = base_lr * (10 ** (exponent // 2))
363
+
364
+ # Requires gradients in the parameters
365
+ for p in params.values():
366
+ p.requires_grad = True
367
+ p.grad = None
368
+
369
+ param_keys = list(params.keys())
370
+ param_values = list(params.values())
371
+
372
+ # Defines optimizer and loss function
373
+ optimizer = torch.optim.Adam(param_values, lr=lr)
374
+ scheduler = torch.optim.lr_scheduler.LinearLR(optimizer, start_factor=1.0, end_factor=0.01, total_iters=epochs // 10)
375
+
376
+ # Contains the best loss and the best parameters
377
+ best_loss = float("inf")
378
+ best_params = None
379
+
380
+ # Used to stop the search early
381
+ min_delta = 1e-7
382
+ acc_loss = []
383
+ percent_epochs_before_stop = 0.1
384
+
385
+ for i in range(epochs):
386
+ optimizer.zero_grad()
387
+
388
+ quant = quantize(x, params, qtz_func, bits, target_dtype)
389
+ dequant = dequantize(quant, params, deqtz_func, bits, x.dtype)
390
+ loss = loss_fn(x, dequant)
391
+
392
+ if loss.isnan() or loss.isinf():
393
+ raise Exception("Loss is NaN or Inf. Stopping the search.")
394
+
395
+ loss.backward()
396
+ optimizer.step()
397
+ scheduler.step()
398
+
399
+ acc_loss.append(loss.item())
400
+
401
+ # Reports loss every 10 steps
402
+ if i % 10 == 0 and do_report:
403
+ print(f"Epoch {i}: Loss {loss.item()}")
404
+
405
+ # Optimizes the parameter search by storing the best loss and the parameters
406
+ if loss.item() < best_loss:
407
+ best_loss = loss.item()
408
+ best_params = copy.deepcopy({
409
+ k: v for k, v in params.items() if k in param_keys
410
+ })
411
+
412
+ # We also stop the search if the loss has not considerably during the last 10% epochs
413
+ if early_stop:
414
+ epochs_before_stop = int(epochs * percent_epochs_before_stop)
415
+ if i > epochs_before_stop and abs(acc_loss[i - epochs_before_stop] - acc_loss[i]) < min_delta:
416
+ break
417
+
418
+ # No longer requires gradients in the parameters
419
+ for p in best_params.values():
420
+ p.requires_grad = False
421
+ p.grad = None
422
+
423
+ if do_report:
424
+ return best_params, acc_loss
425
+ else:
426
+ return best_params
427
+
428
+
429
+ def quantize(
430
+ x: torch.Tensor,
431
+ params: Dict[str, nn.Parameter],
432
+ func: nn.Module,
433
+ bits: int,
434
+ target_dtype: torch.dtype = torch.int8
435
+ ) -> torch.Tensor:
436
+ quant_min, quant_max = get_min_max_from_bits_signed(bits)
437
+ x = x.transpose(0, 1) # Aligns shapes
438
+ x = func(x=x, **params)
439
+ x = x.transpose(0, 1)
440
+ x = torch.clamp(round_func_BPDA(x), quant_min, quant_max).to(target_dtype)
441
+ return x
442
+
443
+
444
+ def dequantize(
445
+ x: torch.Tensor,
446
+ params: Dict[str, nn.Parameter],
447
+ func: nn.Module,
448
+ bits: int,
449
+ out_dtype: torch.dtype
450
+ ) -> torch.Tensor:
451
+ x = x.to(dtype=out_dtype)
452
+ x = x.transpose(0, 1)
453
+ x = func(x=x, **params)
454
+ x = x.transpose(0, 1)
455
+ return x
456
+
457
+
458
+ def round_func_BPDA(input):
459
+ # This is equivalent to replacing round function (non-differentiable) with
460
+ # an identity function (differentiable) only when backward.
461
+ forward_value = torch.round(input)
462
+ out = input.clone()
463
+ out.data = forward_value.data
464
+ return out
465
+
466
+
467
+ def get_min_max_from_bits_signed(bit_width: int) -> Tuple[int, int]:
468
+ return -2 ** (bit_width - 1), 2 ** (bit_width - 1) - 1
469
+
470
+
471
+
472
+ ############## Numpy ###############
473
+
474
+ def np_domain_guard(
475
+ x: np.ndarray,
476
+ min: float = None,
477
+ max: float = None,
478
+ posinf: float = None,
479
+ neginf: float = None,
480
+ nan: float = None
481
+ ) -> np.ndarray:
482
+ """Guard a tensor to a valid domain."""
483
+ x = np.nan_to_num(x, posinf=posinf, neginf=neginf, nan=nan)
484
+ if min is not None or max is not None:
485
+ x = np.clip(x, min, max)
486
+ return x
487
+
488
+
489
+ def np_replace_num(x: np.ndarray, num: float, to: float) -> np.ndarray:
490
+ """Replace a number in a tensor with another number.
491
+
492
+ Args:
493
+ x (np.ndarray): The input tensor.
494
+ num (float): The number to replace.
495
+ to (float): The number to replace with.
496
+
497
+ Returns:
498
+ np.ndarray: The tensor with the number replaced.
499
+ """
500
+ return np.where(x == num, to, x)
501
+
502
+
503
+ def np_guarded_power(x: np.ndarray, exp: float) -> np.ndarray:
504
+ """Guard the power operation to a valid domain."""
505
+ return np.power(x, exp) if exp >= 1 else np.power(np.maximum(x, 0), exp)
506
+
fn_gen/ones_t/0/loss.png ADDED
fn_gen/ones_t/0/quantization.png ADDED
fn_gen/ones_t/1/distortion.png ADDED
fn_gen/ones_t/1/expressions.txt ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ (_0*x)**(1/3)/_s
2
+ _s**3*x**3/_0
fn_gen/ones_t/1/fn.py ADDED
@@ -0,0 +1,506 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import torch
4
+ from torch import amin # Necessary for arcsin
5
+ import copy
6
+ import torch.nn as nn
7
+ import numpy as np
8
+
9
+ from scipy.optimize import curve_fit
10
+ from typing import Dict, Any, Tuple, List, Callable
11
+
12
+
13
+ def quantization(x, **params):
14
+ return (torch.div(1, replace_num(params['_s'], num=0, to=10000)) * guarded_torch_power((params['_0'] * x), 1 / 3))
15
+
16
+
17
+ def dequantization(x, **params):
18
+ return (torch.div(1, replace_num(params['_0'], num=0, to=10000)) * guarded_torch_power(params['_s'], torch.tensor(3)) * guarded_torch_power(x, torch.tensor(3)))
19
+
20
+
21
+ def init_linear_scale( # Symmetric scale. From the study folder
22
+ x: torch.Tensor,
23
+ **kwargs: Dict[str, Any],
24
+ ) -> torch.Tensor:
25
+ assert "bits" in kwargs, "bits must be provided."
26
+ assert "params" in kwargs, "params must be provided."
27
+ assert "qtz_func" in kwargs, "qtz_func must be provided."
28
+
29
+ bits = kwargs.get('bits')
30
+ params = kwargs.get('params')
31
+ qtz_func = kwargs.get('qtz_func')
32
+
33
+ x_ = x.transpose(0, 1)
34
+ x_ = qtz_func(x=x_, **params, _s=init_ones(x, **kwargs))
35
+ x_ = x_.transpose(0, 1)
36
+
37
+ quant_min, quant_max = get_min_max_from_bits_signed(bits)
38
+ min_vals, max_vals = torch.aminmax(x_, dim=1)
39
+ min_vals = torch.min(min_vals, torch.zeros_like(min_vals))
40
+ max_vals = torch.max(max_vals, torch.zeros_like(max_vals))
41
+
42
+ eps = torch.finfo(torch.float32).eps
43
+
44
+ abs_max_val_per_ch = torch.max(-min_vals, max_vals)
45
+ scale = abs_max_val_per_ch / (float(quant_max - quant_min) / 2)
46
+
47
+ scale = torch.clamp(scale, min=eps).to(dtype=torch.float32, device=min_vals.device)
48
+
49
+ # Introduces some noise in scale
50
+ # If I don't introduce noise, the accuracy is going to be 0.0 and not learn anything
51
+ # scale = scale + 0.01 * torch.randn_like(scale)
52
+ return scale
53
+
54
+
55
+ def init_params(x: torch.Tensor, **kwargs: Dict[str, Any]) -> Dict[str, nn.Parameter]:
56
+ params = {
57
+ '_0': init_ones(x, qtz_func=quantization, deqtz_func=dequantization, param='_0', params_list=['_0', '_s'], **kwargs),
58
+ }
59
+ params['_s'] = init_linear_scale(x, params=params, qtz_func=quantization, **kwargs)
60
+ params = {k: nn.Parameter(v, requires_grad=False) for k, v in params.items()}
61
+
62
+ if 'post_init_hook' in kwargs:
63
+ kwargs['post_init_hook'](parameters=params)
64
+
65
+ params = learn_parameters(x, params,
66
+ qtz_func=quantization,
67
+ deqtz_func=dequantization,
68
+ bits=kwargs['bits'],
69
+ target_dtype=torch.int8,
70
+ epochs=500,
71
+ early_stop=False,
72
+ )
73
+ if 'post_train_hook' in kwargs:
74
+ kwargs['post_train_hook'](parameters=params)
75
+
76
+ return params
77
+
78
+
79
+ ############### Numpy Qtz ###############
80
+
81
+
82
+ def np_quantization(x, _0, _s):
83
+ return (np.divide(1, np_replace_num(_s, num=0, to=10000)) * np_guarded_power((_0 * x), 1 / 3))
84
+
85
+
86
+ def np_dequantization(x, _0, _s):
87
+ return (np.divide(1, np_replace_num(_0, num=0, to=10000)) * np_guarded_power(_s, np.array(3)) * np_guarded_power(x, np.array(3)))
88
+
89
+
90
+ def fit_func(x, _0, _s):
91
+ x_ = np_quantization(x, _0, _s)
92
+ x_ = np_dequantization(x_, _0, _s)
93
+ return x_
94
+
95
+
96
+
97
+ ############### HELPERS ###############
98
+
99
+ def domain_guard(
100
+ x: torch.Tensor,
101
+ min: float = None,
102
+ max: float = None,
103
+ posinf: float = None,
104
+ neginf: float = None,
105
+ nan: float = None
106
+ ) -> torch.Tensor:
107
+ """Guard a tensor to a valid domain."""
108
+ x = torch.nan_to_num(x, posinf=posinf, neginf=neginf, nan=nan)
109
+ if min is not None or max is not None:
110
+ x = torch.clamp(x, min=min, max=max)
111
+ return x
112
+
113
+
114
+ def replace_num(x: torch.Tensor, num: float, to: float) -> torch.Tensor:
115
+ """Replace a number in a tensor with another number.
116
+
117
+ Args:
118
+ x (torch.Tensor): The input tensor.
119
+ num (float): The number to replace.
120
+ to (float): The number to replace with.
121
+
122
+ Returns:
123
+ torch.Tensor: The tensor with the number replaced.
124
+ """
125
+ return torch.where(x == num, to, x)
126
+
127
+
128
+ def guarded_torch_power(x: torch.Tensor, exp: float) -> torch.Tensor:
129
+ """Guard the power operation to a valid domain."""
130
+ return torch.pow(x, exp) if exp >= 1 else torch.pow(torch.relu(x), exp)
131
+
132
+
133
+ def init_ones(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor:
134
+ val = torch.amin(x, dim=1)
135
+ return torch.ones_like(val, dtype=torch.float32, device=x.device)
136
+
137
+
138
+ def init_rand(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor:
139
+ val = torch.amin(x, dim=1)
140
+ return torch.randn_like(val, dtype=torch.float32, device=x.device)
141
+
142
+
143
+ def init_space_search(
144
+ x: torch.Tensor,
145
+ **kwargs: Dict[str, Any],
146
+ ) -> torch.Tensor:
147
+
148
+ def _build_initial_param(tensor: torch.Tensor, max_initial: int, n_params: int):
149
+ """Generates the initial set of parameters. The first iteration generates 10 times more parameters."""
150
+ for _ in range(n_params * 10): # The first iteration generates 10 times more parameters
151
+ yield init_rand(tensor) * max_initial # Generates n_params in range [-max_initial, max_initial]
152
+
153
+ def _search_param(tensors: List[torch.tensor], n_params):
154
+ """Takes the best parameters and generates new parameters around the mean of the best parameters."""
155
+ torch_tensors = torch.stack(tensors)
156
+ min_vals, max_vals = torch.aminmax(torch_tensors, dim=0)
157
+ abs_max_val_per_ch = torch.max(-min_vals, max_vals)
158
+ mean = torch.mean(torch_tensors, dim=0)
159
+ for _ in range(n_params): # Generates n_params around the mean of the tensors
160
+ yield torch.randn_like(min_vals) * abs_max_val_per_ch + mean
161
+
162
+ def _calc(x, qtz_func, deqtz_func, **params):
163
+ x_ = x.transpose(0, 1)
164
+ x_ = qtz_func(x=x_, **params)
165
+ x_ = deqtz_func(x=x_, **params)
166
+ x_ = x_.transpose(0, 1)
167
+ return x_
168
+
169
+ assert "qtz_func" in kwargs, "qtz_func must be provided."
170
+ assert "deqtz_func" in kwargs, "deqtz_func must be provided."
171
+ assert "params_list" in kwargs, "params list must be provided."
172
+ assert "param" in kwargs, "param must be provided."
173
+
174
+ qtz_func = kwargs.get('qtz_func')
175
+ deqtz_func = kwargs.get('deqtz_func')
176
+ params_list = kwargs.get('params_list')
177
+ param = kwargs.get('param')
178
+
179
+ n_runs = 50 # Number of runs to try to find the best parameters
180
+ n_random_params = 50 # Number of random parameters to generate
181
+ n_best_to_pick = 5 # Number of best parameters to pick after each run
182
+ max_initial = 10000 # Maximum value to initialize the parameters
183
+
184
+ # Initializes the parameters
185
+ base_params = { p: init_ones(x, **kwargs) for p in params_list if p != param }
186
+ params = _build_initial_param(x, max_initial, n_random_params)
187
+
188
+ # Performs the search
189
+ for _ in range(n_runs):
190
+
191
+ best_params = []
192
+ for param_ in params:
193
+ try:
194
+ x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: param_})
195
+ loss_ones = nn.MSELoss()(x, x_)
196
+
197
+ if len(best_params) < n_best_to_pick:
198
+ best_params.append((param_, loss_ones.item()))
199
+ best_params = sorted(best_params, key=lambda x: x[1])
200
+ elif loss_ones < best_params[-1][1]:
201
+ best_params[-1] = (param_, loss_ones.item())
202
+ best_params = sorted(best_params, key=lambda x: x[1])
203
+
204
+ except Exception: # The parameters might not be valid for the function's domain
205
+ continue
206
+
207
+ # Generates new parameters around the mean
208
+ params = _search_param([p for p, _ in best_params], n_random_params)
209
+
210
+ # Checks if the best parameter is better than the init_ones
211
+ p_ones = init_ones(x, **kwargs)
212
+ x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: p_ones})
213
+ loss_ones = nn.MSELoss()(x, x_)
214
+
215
+ # Checks if the best parameter is better than the init_rand
216
+ p_rand = init_rand(x, **kwargs)
217
+ x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: p_rand})
218
+ loss_rand = nn.MSELoss()(x, x_)
219
+
220
+ if loss_rand < best_params[0][1] and loss_rand < loss_ones:
221
+ return p_rand
222
+ elif loss_ones < best_params[0][1] and loss_ones < loss_rand:
223
+ return p_ones
224
+ else:
225
+ return best_params[0][0]
226
+
227
+
228
+ def init_linear_scale( # Symmetric scale. From the study folder
229
+ x: torch.Tensor,
230
+ **kwargs: Dict[str, Any],
231
+ ) -> torch.Tensor:
232
+ assert "bits" in kwargs, "bits must be provided."
233
+ assert "params" in kwargs, "params must be provided."
234
+ assert "qtz_func" in kwargs, "qtz_func must be provided."
235
+
236
+ bits = kwargs.get('bits')
237
+ params = kwargs.get('params')
238
+ qtz_func = kwargs.get('qtz_func')
239
+
240
+ x_ = x.transpose(0, 1)
241
+ x_ = qtz_func(x=x_, **params, _s=init_ones(x, **kwargs))
242
+ x_ = x_.transpose(0, 1)
243
+
244
+ quant_min, quant_max = get_min_max_from_bits_signed(bits)
245
+ min_vals, max_vals = torch.aminmax(x_, dim=1)
246
+ min_vals = torch.min(min_vals, torch.zeros_like(min_vals))
247
+ max_vals = torch.max(max_vals, torch.zeros_like(max_vals))
248
+
249
+ eps = torch.finfo(torch.float32).eps
250
+
251
+ abs_max_val_per_ch = torch.max(-min_vals, max_vals)
252
+ scale = abs_max_val_per_ch / (float(quant_max - quant_min) / 2)
253
+
254
+ scale = torch.clamp(scale, min=eps).to(dtype=torch.float32, device=min_vals.device)
255
+
256
+ # Introduces some noise in scale
257
+ # If I don't introduce noise, the accuracy is going to be 0.0 and not learn anything
258
+ # scale = scale + 0.01 * torch.randn_like(scale)
259
+ return scale
260
+
261
+
262
+ def init_non_linear_regression_fit(
263
+ x: torch.Tensor,
264
+ **kwargs: Dict[str, Any],
265
+ ) -> torch.Tensor:
266
+
267
+ assert "params_list" in kwargs, "params list must be provided."
268
+ assert "np_fit_func" in kwargs, "np_fit_func must be provided."
269
+ assert "p0" in kwargs, "p0 must be provided."
270
+ np_fit_func = kwargs.get('np_fit_func')
271
+ params_list = kwargs.get('params_list')
272
+ p0 = kwargs.get('p0')
273
+
274
+ def _fit(xdata: np.ndarray, ydata: np.ndarray, func: Callable, p0: List[float]):
275
+ popt, _ = curve_fit(
276
+ func,
277
+ xdata,
278
+ ydata,
279
+ maxfev=1000,
280
+ p0=p0,
281
+ method='lm'
282
+ )
283
+ return popt
284
+
285
+ # 1. Needs to convert the torch tensor to numpy tensor
286
+ xdata = x.cpu().numpy()
287
+
288
+ # 2. Sorts the data so that it makes it easier to fit to it
289
+ sorted_xdata = np.sort(xdata, axis=-1)
290
+
291
+ p0 = {k: v.cpu().numpy() for k, v in p0.items()}
292
+ params_list = sorted(params_list) # We need to make sure that it matches the numpy fit func arg order
293
+
294
+ # 3. Finds the best parameters for each channel
295
+ try:
296
+ params = []
297
+ for i in range(sorted_xdata.shape[0]):
298
+ xdata_ = sorted_xdata[i]
299
+ p0_ = [p0[p][i] for p in params_list]
300
+ ch_params = _fit(xdata_, xdata_, np_fit_func, p0_)
301
+ params.append(ch_params)
302
+
303
+ # 4. Builds the parameters
304
+ result = {}
305
+ for i, p in enumerate(params_list):
306
+ result[p] = torch.tensor([p_[i] for p_ in params], dtype=torch.float32).to(x.device)
307
+
308
+ return result
309
+
310
+ except ValueError as e:
311
+ print(f"Could not fit the function with error: {e}")
312
+ print(f"Using fallback result...")
313
+ return {
314
+ k: torch.tensor(v, dtype=torch.float32).to(x.device) for k, v in p0.items()
315
+ }
316
+
317
+
318
+ def init_zeros(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor:
319
+ val = torch.amin(x, dim=1)
320
+ return torch.zeros_like(val, dtype=torch.float32, device=x.device)
321
+
322
+
323
+ def init_inner_scale(tensor: torch.Tensor, _min: float = torch.inf, _max: float = torch.inf) -> torch.Tensor:
324
+ # Calculate the original minimum and maximum values
325
+ min_vals, max_vals = torch.aminmax(tensor, dim=-1)
326
+ x_min = torch.min(min_vals, torch.zeros_like(min_vals))
327
+ x_max = torch.max(max_vals, torch.zeros_like(max_vals))
328
+
329
+ if _max is torch.inf: # We do not need to scale the tensor. Just need to move it
330
+ return torch.ones_like(x_min)
331
+
332
+ # Calculate the scale factor
333
+ scale = (_max - _min) / (x_max - x_min)
334
+ return scale
335
+
336
+
337
+
338
+ ############## Quant ###############
339
+
340
+ @torch.enable_grad()
341
+ def learn_parameters(
342
+ x: torch.Tensor,
343
+ params: Dict[str, nn.Parameter],
344
+ qtz_func: nn.Module,
345
+ deqtz_func: nn.Module,
346
+ bits: int,
347
+ target_dtype: torch.dtype,
348
+ epochs: int = 1000,
349
+ early_stop: bool = True,
350
+ do_report: bool = False
351
+ ) -> Tuple[Dict[str, nn.Parameter], torch.Tensor]:
352
+ loss_fn = nn.MSELoss()
353
+
354
+ # Determines the initial learning rate by computing the initial loss and multiplying it by
355
+ # the order of magnitude of the loss divided by 2
356
+ quant = quantize(x, params, qtz_func, bits, target_dtype)
357
+ dequant = dequantize(quant, params, deqtz_func, bits, x.dtype)
358
+ loss = loss_fn(x, dequant)
359
+
360
+ base_lr = 0.1
361
+ exponent = int(np.floor(np.log10(loss.item())))
362
+ lr = base_lr * (10 ** (exponent // 2))
363
+
364
+ # Requires gradients in the parameters
365
+ for p in params.values():
366
+ p.requires_grad = True
367
+ p.grad = None
368
+
369
+ param_keys = list(params.keys())
370
+ param_values = list(params.values())
371
+
372
+ # Defines optimizer and loss function
373
+ optimizer = torch.optim.Adam(param_values, lr=lr)
374
+ scheduler = torch.optim.lr_scheduler.LinearLR(optimizer, start_factor=1.0, end_factor=0.01, total_iters=epochs // 10)
375
+
376
+ # Contains the best loss and the best parameters
377
+ best_loss = float("inf")
378
+ best_params = None
379
+
380
+ # Used to stop the search early
381
+ min_delta = 1e-7
382
+ acc_loss = []
383
+ percent_epochs_before_stop = 0.1
384
+
385
+ for i in range(epochs):
386
+ optimizer.zero_grad()
387
+
388
+ quant = quantize(x, params, qtz_func, bits, target_dtype)
389
+ dequant = dequantize(quant, params, deqtz_func, bits, x.dtype)
390
+ loss = loss_fn(x, dequant)
391
+
392
+ if loss.isnan() or loss.isinf():
393
+ raise Exception("Loss is NaN or Inf. Stopping the search.")
394
+
395
+ loss.backward()
396
+ optimizer.step()
397
+ scheduler.step()
398
+
399
+ acc_loss.append(loss.item())
400
+
401
+ # Reports loss every 10 steps
402
+ if i % 10 == 0 and do_report:
403
+ print(f"Epoch {i}: Loss {loss.item()}")
404
+
405
+ # Optimizes the parameter search by storing the best loss and the parameters
406
+ if loss.item() < best_loss:
407
+ best_loss = loss.item()
408
+ best_params = copy.deepcopy({
409
+ k: v for k, v in params.items() if k in param_keys
410
+ })
411
+
412
+ # We also stop the search if the loss has not considerably during the last 10% epochs
413
+ if early_stop:
414
+ epochs_before_stop = int(epochs * percent_epochs_before_stop)
415
+ if i > epochs_before_stop and abs(acc_loss[i - epochs_before_stop] - acc_loss[i]) < min_delta:
416
+ break
417
+
418
+ # No longer requires gradients in the parameters
419
+ for p in best_params.values():
420
+ p.requires_grad = False
421
+ p.grad = None
422
+
423
+ if do_report:
424
+ return best_params, acc_loss
425
+ else:
426
+ return best_params
427
+
428
+
429
+ def quantize(
430
+ x: torch.Tensor,
431
+ params: Dict[str, nn.Parameter],
432
+ func: nn.Module,
433
+ bits: int,
434
+ target_dtype: torch.dtype = torch.int8
435
+ ) -> torch.Tensor:
436
+ quant_min, quant_max = get_min_max_from_bits_signed(bits)
437
+ x = x.transpose(0, 1) # Aligns shapes
438
+ x = func(x=x, **params)
439
+ x = x.transpose(0, 1)
440
+ x = torch.clamp(round_func_BPDA(x), quant_min, quant_max).to(target_dtype)
441
+ return x
442
+
443
+
444
+ def dequantize(
445
+ x: torch.Tensor,
446
+ params: Dict[str, nn.Parameter],
447
+ func: nn.Module,
448
+ bits: int,
449
+ out_dtype: torch.dtype
450
+ ) -> torch.Tensor:
451
+ x = x.to(dtype=out_dtype)
452
+ x = x.transpose(0, 1)
453
+ x = func(x=x, **params)
454
+ x = x.transpose(0, 1)
455
+ return x
456
+
457
+
458
+ def round_func_BPDA(input):
459
+ # This is equivalent to replacing round function (non-differentiable) with
460
+ # an identity function (differentiable) only when backward.
461
+ forward_value = torch.round(input)
462
+ out = input.clone()
463
+ out.data = forward_value.data
464
+ return out
465
+
466
+
467
+ def get_min_max_from_bits_signed(bit_width: int) -> Tuple[int, int]:
468
+ return -2 ** (bit_width - 1), 2 ** (bit_width - 1) - 1
469
+
470
+
471
+
472
+ ############## Numpy ###############
473
+
474
+ def np_domain_guard(
475
+ x: np.ndarray,
476
+ min: float = None,
477
+ max: float = None,
478
+ posinf: float = None,
479
+ neginf: float = None,
480
+ nan: float = None
481
+ ) -> np.ndarray:
482
+ """Guard a tensor to a valid domain."""
483
+ x = np.nan_to_num(x, posinf=posinf, neginf=neginf, nan=nan)
484
+ if min is not None or max is not None:
485
+ x = np.clip(x, min, max)
486
+ return x
487
+
488
+
489
+ def np_replace_num(x: np.ndarray, num: float, to: float) -> np.ndarray:
490
+ """Replace a number in a tensor with another number.
491
+
492
+ Args:
493
+ x (np.ndarray): The input tensor.
494
+ num (float): The number to replace.
495
+ to (float): The number to replace with.
496
+
497
+ Returns:
498
+ np.ndarray: The tensor with the number replaced.
499
+ """
500
+ return np.where(x == num, to, x)
501
+
502
+
503
+ def np_guarded_power(x: np.ndarray, exp: float) -> np.ndarray:
504
+ """Guard the power operation to a valid domain."""
505
+ return np.power(x, exp) if exp >= 1 else np.power(np.maximum(x, 0), exp)
506
+
fn_gen/ones_t/1/loss.png ADDED
fn_gen/ones_t/1/quantization.png ADDED
fn_gen/ones_t/10/distortion.png ADDED
fn_gen/ones_t/10/expressions.txt ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ acos(_0*x)/_s
2
+ cos(_s*x)/_0
fn_gen/ones_t/10/fn.py ADDED
@@ -0,0 +1,506 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import torch
4
+ from torch import amin # Necessary for arcsin
5
+ import copy
6
+ import torch.nn as nn
7
+ import numpy as np
8
+
9
+ from scipy.optimize import curve_fit
10
+ from typing import Dict, Any, Tuple, List, Callable
11
+
12
+
13
+ def quantization(x, **params):
14
+ return (torch.div(1, replace_num(params['_s'], num=0, to=10000)) * torch.acos(domain_guard((params['_0'] * x), min=-0.99999, max=0.99999, nan=0)))
15
+
16
+
17
+ def dequantization(x, **params):
18
+ return (torch.div(1, replace_num(params['_0'], num=0, to=10000)) * torch.cos((params['_s'] * x)))
19
+
20
+
21
+ def init_linear_scale( # Symmetric scale. From the study folder
22
+ x: torch.Tensor,
23
+ **kwargs: Dict[str, Any],
24
+ ) -> torch.Tensor:
25
+ assert "bits" in kwargs, "bits must be provided."
26
+ assert "params" in kwargs, "params must be provided."
27
+ assert "qtz_func" in kwargs, "qtz_func must be provided."
28
+
29
+ bits = kwargs.get('bits')
30
+ params = kwargs.get('params')
31
+ qtz_func = kwargs.get('qtz_func')
32
+
33
+ x_ = x.transpose(0, 1)
34
+ x_ = qtz_func(x=x_, **params, _s=init_ones(x, **kwargs))
35
+ x_ = x_.transpose(0, 1)
36
+
37
+ quant_min, quant_max = get_min_max_from_bits_signed(bits)
38
+ min_vals, max_vals = torch.aminmax(x_, dim=1)
39
+ min_vals = torch.min(min_vals, torch.zeros_like(min_vals))
40
+ max_vals = torch.max(max_vals, torch.zeros_like(max_vals))
41
+
42
+ eps = torch.finfo(torch.float32).eps
43
+
44
+ abs_max_val_per_ch = torch.max(-min_vals, max_vals)
45
+ scale = abs_max_val_per_ch / (float(quant_max - quant_min) / 2)
46
+
47
+ scale = torch.clamp(scale, min=eps).to(dtype=torch.float32, device=min_vals.device)
48
+
49
+ # Introduces some noise in scale
50
+ # If I don't introduce noise, the accuracy is going to be 0.0 and not learn anything
51
+ # scale = scale + 0.01 * torch.randn_like(scale)
52
+ return scale
53
+
54
+
55
+ def init_params(x: torch.Tensor, **kwargs: Dict[str, Any]) -> Dict[str, nn.Parameter]:
56
+ params = {
57
+ '_0': init_ones(x, qtz_func=quantization, deqtz_func=dequantization, param='_0', params_list=['_0', '_s'], **kwargs),
58
+ }
59
+ params['_s'] = init_linear_scale(x, params=params, qtz_func=quantization, **kwargs)
60
+ params = {k: nn.Parameter(v, requires_grad=False) for k, v in params.items()}
61
+
62
+ if 'post_init_hook' in kwargs:
63
+ kwargs['post_init_hook'](parameters=params)
64
+
65
+ params = learn_parameters(x, params,
66
+ qtz_func=quantization,
67
+ deqtz_func=dequantization,
68
+ bits=kwargs['bits'],
69
+ target_dtype=torch.int8,
70
+ epochs=500,
71
+ early_stop=False,
72
+ )
73
+ if 'post_train_hook' in kwargs:
74
+ kwargs['post_train_hook'](parameters=params)
75
+
76
+ return params
77
+
78
+
79
+ ############### Numpy Qtz ###############
80
+
81
+
82
+ def np_quantization(x, _0, _s):
83
+ return (np.divide(1, np_replace_num(_s, num=0, to=10000)) * np.arccos(np_domain_guard((_0 * x), min=-0.99999, max=0.99999, nan=0)))
84
+
85
+
86
+ def np_dequantization(x, _0, _s):
87
+ return (np.divide(1, np_replace_num(_0, num=0, to=10000)) * np.cos((_s * x)))
88
+
89
+
90
+ def fit_func(x, _0, _s):
91
+ x_ = np_quantization(x, _0, _s)
92
+ x_ = np_dequantization(x_, _0, _s)
93
+ return x_
94
+
95
+
96
+
97
+ ############### HELPERS ###############
98
+
99
+ def domain_guard(
100
+ x: torch.Tensor,
101
+ min: float = None,
102
+ max: float = None,
103
+ posinf: float = None,
104
+ neginf: float = None,
105
+ nan: float = None
106
+ ) -> torch.Tensor:
107
+ """Guard a tensor to a valid domain."""
108
+ x = torch.nan_to_num(x, posinf=posinf, neginf=neginf, nan=nan)
109
+ if min is not None or max is not None:
110
+ x = torch.clamp(x, min=min, max=max)
111
+ return x
112
+
113
+
114
+ def replace_num(x: torch.Tensor, num: float, to: float) -> torch.Tensor:
115
+ """Replace a number in a tensor with another number.
116
+
117
+ Args:
118
+ x (torch.Tensor): The input tensor.
119
+ num (float): The number to replace.
120
+ to (float): The number to replace with.
121
+
122
+ Returns:
123
+ torch.Tensor: The tensor with the number replaced.
124
+ """
125
+ return torch.where(x == num, to, x)
126
+
127
+
128
+ def guarded_torch_power(x: torch.Tensor, exp: float) -> torch.Tensor:
129
+ """Guard the power operation to a valid domain."""
130
+ return torch.pow(x, exp) if exp >= 1 else torch.pow(torch.relu(x), exp)
131
+
132
+
133
+ def init_ones(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor:
134
+ val = torch.amin(x, dim=1)
135
+ return torch.ones_like(val, dtype=torch.float32, device=x.device)
136
+
137
+
138
+ def init_rand(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor:
139
+ val = torch.amin(x, dim=1)
140
+ return torch.randn_like(val, dtype=torch.float32, device=x.device)
141
+
142
+
143
+ def init_space_search(
144
+ x: torch.Tensor,
145
+ **kwargs: Dict[str, Any],
146
+ ) -> torch.Tensor:
147
+
148
+ def _build_initial_param(tensor: torch.Tensor, max_initial: int, n_params: int):
149
+ """Generates the initial set of parameters. The first iteration generates 10 times more parameters."""
150
+ for _ in range(n_params * 10): # The first iteration generates 10 times more parameters
151
+ yield init_rand(tensor) * max_initial # Generates n_params in range [-max_initial, max_initial]
152
+
153
+ def _search_param(tensors: List[torch.tensor], n_params):
154
+ """Takes the best parameters and generates new parameters around the mean of the best parameters."""
155
+ torch_tensors = torch.stack(tensors)
156
+ min_vals, max_vals = torch.aminmax(torch_tensors, dim=0)
157
+ abs_max_val_per_ch = torch.max(-min_vals, max_vals)
158
+ mean = torch.mean(torch_tensors, dim=0)
159
+ for _ in range(n_params): # Generates n_params around the mean of the tensors
160
+ yield torch.randn_like(min_vals) * abs_max_val_per_ch + mean
161
+
162
+ def _calc(x, qtz_func, deqtz_func, **params):
163
+ x_ = x.transpose(0, 1)
164
+ x_ = qtz_func(x=x_, **params)
165
+ x_ = deqtz_func(x=x_, **params)
166
+ x_ = x_.transpose(0, 1)
167
+ return x_
168
+
169
+ assert "qtz_func" in kwargs, "qtz_func must be provided."
170
+ assert "deqtz_func" in kwargs, "deqtz_func must be provided."
171
+ assert "params_list" in kwargs, "params list must be provided."
172
+ assert "param" in kwargs, "param must be provided."
173
+
174
+ qtz_func = kwargs.get('qtz_func')
175
+ deqtz_func = kwargs.get('deqtz_func')
176
+ params_list = kwargs.get('params_list')
177
+ param = kwargs.get('param')
178
+
179
+ n_runs = 50 # Number of runs to try to find the best parameters
180
+ n_random_params = 50 # Number of random parameters to generate
181
+ n_best_to_pick = 5 # Number of best parameters to pick after each run
182
+ max_initial = 10000 # Maximum value to initialize the parameters
183
+
184
+ # Initializes the parameters
185
+ base_params = { p: init_ones(x, **kwargs) for p in params_list if p != param }
186
+ params = _build_initial_param(x, max_initial, n_random_params)
187
+
188
+ # Performs the search
189
+ for _ in range(n_runs):
190
+
191
+ best_params = []
192
+ for param_ in params:
193
+ try:
194
+ x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: param_})
195
+ loss_ones = nn.MSELoss()(x, x_)
196
+
197
+ if len(best_params) < n_best_to_pick:
198
+ best_params.append((param_, loss_ones.item()))
199
+ best_params = sorted(best_params, key=lambda x: x[1])
200
+ elif loss_ones < best_params[-1][1]:
201
+ best_params[-1] = (param_, loss_ones.item())
202
+ best_params = sorted(best_params, key=lambda x: x[1])
203
+
204
+ except Exception: # The parameters might not be valid for the function's domain
205
+ continue
206
+
207
+ # Generates new parameters around the mean
208
+ params = _search_param([p for p, _ in best_params], n_random_params)
209
+
210
+ # Checks if the best parameter is better than the init_ones
211
+ p_ones = init_ones(x, **kwargs)
212
+ x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: p_ones})
213
+ loss_ones = nn.MSELoss()(x, x_)
214
+
215
+ # Checks if the best parameter is better than the init_rand
216
+ p_rand = init_rand(x, **kwargs)
217
+ x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: p_rand})
218
+ loss_rand = nn.MSELoss()(x, x_)
219
+
220
+ if loss_rand < best_params[0][1] and loss_rand < loss_ones:
221
+ return p_rand
222
+ elif loss_ones < best_params[0][1] and loss_ones < loss_rand:
223
+ return p_ones
224
+ else:
225
+ return best_params[0][0]
226
+
227
+
228
+ def init_linear_scale( # Symmetric scale. From the study folder
229
+ x: torch.Tensor,
230
+ **kwargs: Dict[str, Any],
231
+ ) -> torch.Tensor:
232
+ assert "bits" in kwargs, "bits must be provided."
233
+ assert "params" in kwargs, "params must be provided."
234
+ assert "qtz_func" in kwargs, "qtz_func must be provided."
235
+
236
+ bits = kwargs.get('bits')
237
+ params = kwargs.get('params')
238
+ qtz_func = kwargs.get('qtz_func')
239
+
240
+ x_ = x.transpose(0, 1)
241
+ x_ = qtz_func(x=x_, **params, _s=init_ones(x, **kwargs))
242
+ x_ = x_.transpose(0, 1)
243
+
244
+ quant_min, quant_max = get_min_max_from_bits_signed(bits)
245
+ min_vals, max_vals = torch.aminmax(x_, dim=1)
246
+ min_vals = torch.min(min_vals, torch.zeros_like(min_vals))
247
+ max_vals = torch.max(max_vals, torch.zeros_like(max_vals))
248
+
249
+ eps = torch.finfo(torch.float32).eps
250
+
251
+ abs_max_val_per_ch = torch.max(-min_vals, max_vals)
252
+ scale = abs_max_val_per_ch / (float(quant_max - quant_min) / 2)
253
+
254
+ scale = torch.clamp(scale, min=eps).to(dtype=torch.float32, device=min_vals.device)
255
+
256
+ # Introduces some noise in scale
257
+ # If I don't introduce noise, the accuracy is going to be 0.0 and not learn anything
258
+ # scale = scale + 0.01 * torch.randn_like(scale)
259
+ return scale
260
+
261
+
262
+ def init_non_linear_regression_fit(
263
+ x: torch.Tensor,
264
+ **kwargs: Dict[str, Any],
265
+ ) -> torch.Tensor:
266
+
267
+ assert "params_list" in kwargs, "params list must be provided."
268
+ assert "np_fit_func" in kwargs, "np_fit_func must be provided."
269
+ assert "p0" in kwargs, "p0 must be provided."
270
+ np_fit_func = kwargs.get('np_fit_func')
271
+ params_list = kwargs.get('params_list')
272
+ p0 = kwargs.get('p0')
273
+
274
+ def _fit(xdata: np.ndarray, ydata: np.ndarray, func: Callable, p0: List[float]):
275
+ popt, _ = curve_fit(
276
+ func,
277
+ xdata,
278
+ ydata,
279
+ maxfev=1000,
280
+ p0=p0,
281
+ method='lm'
282
+ )
283
+ return popt
284
+
285
+ # 1. Needs to convert the torch tensor to numpy tensor
286
+ xdata = x.cpu().numpy()
287
+
288
+ # 2. Sorts the data so that it makes it easier to fit to it
289
+ sorted_xdata = np.sort(xdata, axis=-1)
290
+
291
+ p0 = {k: v.cpu().numpy() for k, v in p0.items()}
292
+ params_list = sorted(params_list) # We need to make sure that it matches the numpy fit func arg order
293
+
294
+ # 3. Finds the best parameters for each channel
295
+ try:
296
+ params = []
297
+ for i in range(sorted_xdata.shape[0]):
298
+ xdata_ = sorted_xdata[i]
299
+ p0_ = [p0[p][i] for p in params_list]
300
+ ch_params = _fit(xdata_, xdata_, np_fit_func, p0_)
301
+ params.append(ch_params)
302
+
303
+ # 4. Builds the parameters
304
+ result = {}
305
+ for i, p in enumerate(params_list):
306
+ result[p] = torch.tensor([p_[i] for p_ in params], dtype=torch.float32).to(x.device)
307
+
308
+ return result
309
+
310
+ except ValueError as e:
311
+ print(f"Could not fit the function with error: {e}")
312
+ print(f"Using fallback result...")
313
+ return {
314
+ k: torch.tensor(v, dtype=torch.float32).to(x.device) for k, v in p0.items()
315
+ }
316
+
317
+
318
+ def init_zeros(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor:
319
+ val = torch.amin(x, dim=1)
320
+ return torch.zeros_like(val, dtype=torch.float32, device=x.device)
321
+
322
+
323
+ def init_inner_scale(tensor: torch.Tensor, _min: float = torch.inf, _max: float = torch.inf) -> torch.Tensor:
324
+ # Calculate the original minimum and maximum values
325
+ min_vals, max_vals = torch.aminmax(tensor, dim=-1)
326
+ x_min = torch.min(min_vals, torch.zeros_like(min_vals))
327
+ x_max = torch.max(max_vals, torch.zeros_like(max_vals))
328
+
329
+ if _max is torch.inf: # We do not need to scale the tensor. Just need to move it
330
+ return torch.ones_like(x_min)
331
+
332
+ # Calculate the scale factor
333
+ scale = (_max - _min) / (x_max - x_min)
334
+ return scale
335
+
336
+
337
+
338
+ ############## Quant ###############
339
+
340
+ @torch.enable_grad()
341
+ def learn_parameters(
342
+ x: torch.Tensor,
343
+ params: Dict[str, nn.Parameter],
344
+ qtz_func: nn.Module,
345
+ deqtz_func: nn.Module,
346
+ bits: int,
347
+ target_dtype: torch.dtype,
348
+ epochs: int = 1000,
349
+ early_stop: bool = True,
350
+ do_report: bool = False
351
+ ) -> Tuple[Dict[str, nn.Parameter], torch.Tensor]:
352
+ loss_fn = nn.MSELoss()
353
+
354
+ # Determines the initial learning rate by computing the initial loss and multiplying it by
355
+ # the order of magnitude of the loss divided by 2
356
+ quant = quantize(x, params, qtz_func, bits, target_dtype)
357
+ dequant = dequantize(quant, params, deqtz_func, bits, x.dtype)
358
+ loss = loss_fn(x, dequant)
359
+
360
+ base_lr = 0.1
361
+ exponent = int(np.floor(np.log10(loss.item())))
362
+ lr = base_lr * (10 ** (exponent // 2))
363
+
364
+ # Requires gradients in the parameters
365
+ for p in params.values():
366
+ p.requires_grad = True
367
+ p.grad = None
368
+
369
+ param_keys = list(params.keys())
370
+ param_values = list(params.values())
371
+
372
+ # Defines optimizer and loss function
373
+ optimizer = torch.optim.Adam(param_values, lr=lr)
374
+ scheduler = torch.optim.lr_scheduler.LinearLR(optimizer, start_factor=1.0, end_factor=0.01, total_iters=epochs // 10)
375
+
376
+ # Contains the best loss and the best parameters
377
+ best_loss = float("inf")
378
+ best_params = None
379
+
380
+ # Used to stop the search early
381
+ min_delta = 1e-7
382
+ acc_loss = []
383
+ percent_epochs_before_stop = 0.1
384
+
385
+ for i in range(epochs):
386
+ optimizer.zero_grad()
387
+
388
+ quant = quantize(x, params, qtz_func, bits, target_dtype)
389
+ dequant = dequantize(quant, params, deqtz_func, bits, x.dtype)
390
+ loss = loss_fn(x, dequant)
391
+
392
+ if loss.isnan() or loss.isinf():
393
+ raise Exception("Loss is NaN or Inf. Stopping the search.")
394
+
395
+ loss.backward()
396
+ optimizer.step()
397
+ scheduler.step()
398
+
399
+ acc_loss.append(loss.item())
400
+
401
+ # Reports loss every 10 steps
402
+ if i % 10 == 0 and do_report:
403
+ print(f"Epoch {i}: Loss {loss.item()}")
404
+
405
+ # Optimizes the parameter search by storing the best loss and the parameters
406
+ if loss.item() < best_loss:
407
+ best_loss = loss.item()
408
+ best_params = copy.deepcopy({
409
+ k: v for k, v in params.items() if k in param_keys
410
+ })
411
+
412
+ # We also stop the search if the loss has not considerably during the last 10% epochs
413
+ if early_stop:
414
+ epochs_before_stop = int(epochs * percent_epochs_before_stop)
415
+ if i > epochs_before_stop and abs(acc_loss[i - epochs_before_stop] - acc_loss[i]) < min_delta:
416
+ break
417
+
418
+ # No longer requires gradients in the parameters
419
+ for p in best_params.values():
420
+ p.requires_grad = False
421
+ p.grad = None
422
+
423
+ if do_report:
424
+ return best_params, acc_loss
425
+ else:
426
+ return best_params
427
+
428
+
429
+ def quantize(
430
+ x: torch.Tensor,
431
+ params: Dict[str, nn.Parameter],
432
+ func: nn.Module,
433
+ bits: int,
434
+ target_dtype: torch.dtype = torch.int8
435
+ ) -> torch.Tensor:
436
+ quant_min, quant_max = get_min_max_from_bits_signed(bits)
437
+ x = x.transpose(0, 1) # Aligns shapes
438
+ x = func(x=x, **params)
439
+ x = x.transpose(0, 1)
440
+ x = torch.clamp(round_func_BPDA(x), quant_min, quant_max).to(target_dtype)
441
+ return x
442
+
443
+
444
+ def dequantize(
445
+ x: torch.Tensor,
446
+ params: Dict[str, nn.Parameter],
447
+ func: nn.Module,
448
+ bits: int,
449
+ out_dtype: torch.dtype
450
+ ) -> torch.Tensor:
451
+ x = x.to(dtype=out_dtype)
452
+ x = x.transpose(0, 1)
453
+ x = func(x=x, **params)
454
+ x = x.transpose(0, 1)
455
+ return x
456
+
457
+
458
+ def round_func_BPDA(input):
459
+ # This is equivalent to replacing round function (non-differentiable) with
460
+ # an identity function (differentiable) only when backward.
461
+ forward_value = torch.round(input)
462
+ out = input.clone()
463
+ out.data = forward_value.data
464
+ return out
465
+
466
+
467
+ def get_min_max_from_bits_signed(bit_width: int) -> Tuple[int, int]:
468
+ return -2 ** (bit_width - 1), 2 ** (bit_width - 1) - 1
469
+
470
+
471
+
472
+ ############## Numpy ###############
473
+
474
+ def np_domain_guard(
475
+ x: np.ndarray,
476
+ min: float = None,
477
+ max: float = None,
478
+ posinf: float = None,
479
+ neginf: float = None,
480
+ nan: float = None
481
+ ) -> np.ndarray:
482
+ """Guard a tensor to a valid domain."""
483
+ x = np.nan_to_num(x, posinf=posinf, neginf=neginf, nan=nan)
484
+ if min is not None or max is not None:
485
+ x = np.clip(x, min, max)
486
+ return x
487
+
488
+
489
+ def np_replace_num(x: np.ndarray, num: float, to: float) -> np.ndarray:
490
+ """Replace a number in a tensor with another number.
491
+
492
+ Args:
493
+ x (np.ndarray): The input tensor.
494
+ num (float): The number to replace.
495
+ to (float): The number to replace with.
496
+
497
+ Returns:
498
+ np.ndarray: The tensor with the number replaced.
499
+ """
500
+ return np.where(x == num, to, x)
501
+
502
+
503
+ def np_guarded_power(x: np.ndarray, exp: float) -> np.ndarray:
504
+ """Guard the power operation to a valid domain."""
505
+ return np.power(x, exp) if exp >= 1 else np.power(np.maximum(x, 0), exp)
506
+
fn_gen/ones_t/10/loss.png ADDED
fn_gen/ones_t/10/quantization.png ADDED
fn_gen/ones_t/11/distortion.png ADDED
fn_gen/ones_t/11/expressions.txt ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ sqrt(_0*x)/_s
2
+ _s**2*x**2/_0
fn_gen/ones_t/11/fn.py ADDED
@@ -0,0 +1,506 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import torch
4
+ from torch import amin # Necessary for arcsin
5
+ import copy
6
+ import torch.nn as nn
7
+ import numpy as np
8
+
9
+ from scipy.optimize import curve_fit
10
+ from typing import Dict, Any, Tuple, List, Callable
11
+
12
+
13
+ def quantization(x, **params):
14
+ return (torch.div(1, replace_num(params['_s'], num=0, to=10000)) * torch.sqrt(domain_guard((params['_0'] * x), min=0.1, nan=0.1)))
15
+
16
+
17
+ def dequantization(x, **params):
18
+ return (torch.div(1, replace_num(params['_0'], num=0, to=10000)) * guarded_torch_power(params['_s'], torch.tensor(2)) * guarded_torch_power(x, torch.tensor(2)))
19
+
20
+
21
+ def init_linear_scale( # Symmetric scale. From the study folder
22
+ x: torch.Tensor,
23
+ **kwargs: Dict[str, Any],
24
+ ) -> torch.Tensor:
25
+ assert "bits" in kwargs, "bits must be provided."
26
+ assert "params" in kwargs, "params must be provided."
27
+ assert "qtz_func" in kwargs, "qtz_func must be provided."
28
+
29
+ bits = kwargs.get('bits')
30
+ params = kwargs.get('params')
31
+ qtz_func = kwargs.get('qtz_func')
32
+
33
+ x_ = x.transpose(0, 1)
34
+ x_ = qtz_func(x=x_, **params, _s=init_ones(x, **kwargs))
35
+ x_ = x_.transpose(0, 1)
36
+
37
+ quant_min, quant_max = get_min_max_from_bits_signed(bits)
38
+ min_vals, max_vals = torch.aminmax(x_, dim=1)
39
+ min_vals = torch.min(min_vals, torch.zeros_like(min_vals))
40
+ max_vals = torch.max(max_vals, torch.zeros_like(max_vals))
41
+
42
+ eps = torch.finfo(torch.float32).eps
43
+
44
+ abs_max_val_per_ch = torch.max(-min_vals, max_vals)
45
+ scale = abs_max_val_per_ch / (float(quant_max - quant_min) / 2)
46
+
47
+ scale = torch.clamp(scale, min=eps).to(dtype=torch.float32, device=min_vals.device)
48
+
49
+ # Introduces some noise in scale
50
+ # If I don't introduce noise, the accuracy is going to be 0.0 and not learn anything
51
+ # scale = scale + 0.01 * torch.randn_like(scale)
52
+ return scale
53
+
54
+
55
+ def init_params(x: torch.Tensor, **kwargs: Dict[str, Any]) -> Dict[str, nn.Parameter]:
56
+ params = {
57
+ '_0': init_ones(x, qtz_func=quantization, deqtz_func=dequantization, param='_0', params_list=['_0', '_s'], **kwargs),
58
+ }
59
+ params['_s'] = init_linear_scale(x, params=params, qtz_func=quantization, **kwargs)
60
+ params = {k: nn.Parameter(v, requires_grad=False) for k, v in params.items()}
61
+
62
+ if 'post_init_hook' in kwargs:
63
+ kwargs['post_init_hook'](parameters=params)
64
+
65
+ params = learn_parameters(x, params,
66
+ qtz_func=quantization,
67
+ deqtz_func=dequantization,
68
+ bits=kwargs['bits'],
69
+ target_dtype=torch.int8,
70
+ epochs=500,
71
+ early_stop=False,
72
+ )
73
+ if 'post_train_hook' in kwargs:
74
+ kwargs['post_train_hook'](parameters=params)
75
+
76
+ return params
77
+
78
+
79
+ ############### Numpy Qtz ###############
80
+
81
+
82
+ def np_quantization(x, _0, _s):
83
+ return (np.divide(1, np_replace_num(_s, num=0, to=10000)) * np.sqrt(np_domain_guard((_0 * x), min=0.1, nan=0.1)))
84
+
85
+
86
+ def np_dequantization(x, _0, _s):
87
+ return (np.divide(1, np_replace_num(_0, num=0, to=10000)) * np_guarded_power(_s, np.array(2)) * np_guarded_power(x, np.array(2)))
88
+
89
+
90
+ def fit_func(x, _0, _s):
91
+ x_ = np_quantization(x, _0, _s)
92
+ x_ = np_dequantization(x_, _0, _s)
93
+ return x_
94
+
95
+
96
+
97
+ ############### HELPERS ###############
98
+
99
+ def domain_guard(
100
+ x: torch.Tensor,
101
+ min: float = None,
102
+ max: float = None,
103
+ posinf: float = None,
104
+ neginf: float = None,
105
+ nan: float = None
106
+ ) -> torch.Tensor:
107
+ """Guard a tensor to a valid domain."""
108
+ x = torch.nan_to_num(x, posinf=posinf, neginf=neginf, nan=nan)
109
+ if min is not None or max is not None:
110
+ x = torch.clamp(x, min=min, max=max)
111
+ return x
112
+
113
+
114
+ def replace_num(x: torch.Tensor, num: float, to: float) -> torch.Tensor:
115
+ """Replace a number in a tensor with another number.
116
+
117
+ Args:
118
+ x (torch.Tensor): The input tensor.
119
+ num (float): The number to replace.
120
+ to (float): The number to replace with.
121
+
122
+ Returns:
123
+ torch.Tensor: The tensor with the number replaced.
124
+ """
125
+ return torch.where(x == num, to, x)
126
+
127
+
128
+ def guarded_torch_power(x: torch.Tensor, exp: float) -> torch.Tensor:
129
+ """Guard the power operation to a valid domain."""
130
+ return torch.pow(x, exp) if exp >= 1 else torch.pow(torch.relu(x), exp)
131
+
132
+
133
+ def init_ones(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor:
134
+ val = torch.amin(x, dim=1)
135
+ return torch.ones_like(val, dtype=torch.float32, device=x.device)
136
+
137
+
138
+ def init_rand(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor:
139
+ val = torch.amin(x, dim=1)
140
+ return torch.randn_like(val, dtype=torch.float32, device=x.device)
141
+
142
+
143
+ def init_space_search(
144
+ x: torch.Tensor,
145
+ **kwargs: Dict[str, Any],
146
+ ) -> torch.Tensor:
147
+
148
+ def _build_initial_param(tensor: torch.Tensor, max_initial: int, n_params: int):
149
+ """Generates the initial set of parameters. The first iteration generates 10 times more parameters."""
150
+ for _ in range(n_params * 10): # The first iteration generates 10 times more parameters
151
+ yield init_rand(tensor) * max_initial # Generates n_params in range [-max_initial, max_initial]
152
+
153
+ def _search_param(tensors: List[torch.tensor], n_params):
154
+ """Takes the best parameters and generates new parameters around the mean of the best parameters."""
155
+ torch_tensors = torch.stack(tensors)
156
+ min_vals, max_vals = torch.aminmax(torch_tensors, dim=0)
157
+ abs_max_val_per_ch = torch.max(-min_vals, max_vals)
158
+ mean = torch.mean(torch_tensors, dim=0)
159
+ for _ in range(n_params): # Generates n_params around the mean of the tensors
160
+ yield torch.randn_like(min_vals) * abs_max_val_per_ch + mean
161
+
162
+ def _calc(x, qtz_func, deqtz_func, **params):
163
+ x_ = x.transpose(0, 1)
164
+ x_ = qtz_func(x=x_, **params)
165
+ x_ = deqtz_func(x=x_, **params)
166
+ x_ = x_.transpose(0, 1)
167
+ return x_
168
+
169
+ assert "qtz_func" in kwargs, "qtz_func must be provided."
170
+ assert "deqtz_func" in kwargs, "deqtz_func must be provided."
171
+ assert "params_list" in kwargs, "params list must be provided."
172
+ assert "param" in kwargs, "param must be provided."
173
+
174
+ qtz_func = kwargs.get('qtz_func')
175
+ deqtz_func = kwargs.get('deqtz_func')
176
+ params_list = kwargs.get('params_list')
177
+ param = kwargs.get('param')
178
+
179
+ n_runs = 50 # Number of runs to try to find the best parameters
180
+ n_random_params = 50 # Number of random parameters to generate
181
+ n_best_to_pick = 5 # Number of best parameters to pick after each run
182
+ max_initial = 10000 # Maximum value to initialize the parameters
183
+
184
+ # Initializes the parameters
185
+ base_params = { p: init_ones(x, **kwargs) for p in params_list if p != param }
186
+ params = _build_initial_param(x, max_initial, n_random_params)
187
+
188
+ # Performs the search
189
+ for _ in range(n_runs):
190
+
191
+ best_params = []
192
+ for param_ in params:
193
+ try:
194
+ x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: param_})
195
+ loss_ones = nn.MSELoss()(x, x_)
196
+
197
+ if len(best_params) < n_best_to_pick:
198
+ best_params.append((param_, loss_ones.item()))
199
+ best_params = sorted(best_params, key=lambda x: x[1])
200
+ elif loss_ones < best_params[-1][1]:
201
+ best_params[-1] = (param_, loss_ones.item())
202
+ best_params = sorted(best_params, key=lambda x: x[1])
203
+
204
+ except Exception: # The parameters might not be valid for the function's domain
205
+ continue
206
+
207
+ # Generates new parameters around the mean
208
+ params = _search_param([p for p, _ in best_params], n_random_params)
209
+
210
+ # Checks if the best parameter is better than the init_ones
211
+ p_ones = init_ones(x, **kwargs)
212
+ x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: p_ones})
213
+ loss_ones = nn.MSELoss()(x, x_)
214
+
215
+ # Checks if the best parameter is better than the init_rand
216
+ p_rand = init_rand(x, **kwargs)
217
+ x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: p_rand})
218
+ loss_rand = nn.MSELoss()(x, x_)
219
+
220
+ if loss_rand < best_params[0][1] and loss_rand < loss_ones:
221
+ return p_rand
222
+ elif loss_ones < best_params[0][1] and loss_ones < loss_rand:
223
+ return p_ones
224
+ else:
225
+ return best_params[0][0]
226
+
227
+
228
+ def init_linear_scale( # Symmetric scale. From the study folder
229
+ x: torch.Tensor,
230
+ **kwargs: Dict[str, Any],
231
+ ) -> torch.Tensor:
232
+ assert "bits" in kwargs, "bits must be provided."
233
+ assert "params" in kwargs, "params must be provided."
234
+ assert "qtz_func" in kwargs, "qtz_func must be provided."
235
+
236
+ bits = kwargs.get('bits')
237
+ params = kwargs.get('params')
238
+ qtz_func = kwargs.get('qtz_func')
239
+
240
+ x_ = x.transpose(0, 1)
241
+ x_ = qtz_func(x=x_, **params, _s=init_ones(x, **kwargs))
242
+ x_ = x_.transpose(0, 1)
243
+
244
+ quant_min, quant_max = get_min_max_from_bits_signed(bits)
245
+ min_vals, max_vals = torch.aminmax(x_, dim=1)
246
+ min_vals = torch.min(min_vals, torch.zeros_like(min_vals))
247
+ max_vals = torch.max(max_vals, torch.zeros_like(max_vals))
248
+
249
+ eps = torch.finfo(torch.float32).eps
250
+
251
+ abs_max_val_per_ch = torch.max(-min_vals, max_vals)
252
+ scale = abs_max_val_per_ch / (float(quant_max - quant_min) / 2)
253
+
254
+ scale = torch.clamp(scale, min=eps).to(dtype=torch.float32, device=min_vals.device)
255
+
256
+ # Introduces some noise in scale
257
+ # If I don't introduce noise, the accuracy is going to be 0.0 and not learn anything
258
+ # scale = scale + 0.01 * torch.randn_like(scale)
259
+ return scale
260
+
261
+
262
+ def init_non_linear_regression_fit(
263
+ x: torch.Tensor,
264
+ **kwargs: Dict[str, Any],
265
+ ) -> torch.Tensor:
266
+
267
+ assert "params_list" in kwargs, "params list must be provided."
268
+ assert "np_fit_func" in kwargs, "np_fit_func must be provided."
269
+ assert "p0" in kwargs, "p0 must be provided."
270
+ np_fit_func = kwargs.get('np_fit_func')
271
+ params_list = kwargs.get('params_list')
272
+ p0 = kwargs.get('p0')
273
+
274
+ def _fit(xdata: np.ndarray, ydata: np.ndarray, func: Callable, p0: List[float]):
275
+ popt, _ = curve_fit(
276
+ func,
277
+ xdata,
278
+ ydata,
279
+ maxfev=1000,
280
+ p0=p0,
281
+ method='lm'
282
+ )
283
+ return popt
284
+
285
+ # 1. Needs to convert the torch tensor to numpy tensor
286
+ xdata = x.cpu().numpy()
287
+
288
+ # 2. Sorts the data so that it makes it easier to fit to it
289
+ sorted_xdata = np.sort(xdata, axis=-1)
290
+
291
+ p0 = {k: v.cpu().numpy() for k, v in p0.items()}
292
+ params_list = sorted(params_list) # We need to make sure that it matches the numpy fit func arg order
293
+
294
+ # 3. Finds the best parameters for each channel
295
+ try:
296
+ params = []
297
+ for i in range(sorted_xdata.shape[0]):
298
+ xdata_ = sorted_xdata[i]
299
+ p0_ = [p0[p][i] for p in params_list]
300
+ ch_params = _fit(xdata_, xdata_, np_fit_func, p0_)
301
+ params.append(ch_params)
302
+
303
+ # 4. Builds the parameters
304
+ result = {}
305
+ for i, p in enumerate(params_list):
306
+ result[p] = torch.tensor([p_[i] for p_ in params], dtype=torch.float32).to(x.device)
307
+
308
+ return result
309
+
310
+ except ValueError as e:
311
+ print(f"Could not fit the function with error: {e}")
312
+ print(f"Using fallback result...")
313
+ return {
314
+ k: torch.tensor(v, dtype=torch.float32).to(x.device) for k, v in p0.items()
315
+ }
316
+
317
+
318
+ def init_zeros(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor:
319
+ val = torch.amin(x, dim=1)
320
+ return torch.zeros_like(val, dtype=torch.float32, device=x.device)
321
+
322
+
323
+ def init_inner_scale(tensor: torch.Tensor, _min: float = torch.inf, _max: float = torch.inf) -> torch.Tensor:
324
+ # Calculate the original minimum and maximum values
325
+ min_vals, max_vals = torch.aminmax(tensor, dim=-1)
326
+ x_min = torch.min(min_vals, torch.zeros_like(min_vals))
327
+ x_max = torch.max(max_vals, torch.zeros_like(max_vals))
328
+
329
+ if _max is torch.inf: # We do not need to scale the tensor. Just need to move it
330
+ return torch.ones_like(x_min)
331
+
332
+ # Calculate the scale factor
333
+ scale = (_max - _min) / (x_max - x_min)
334
+ return scale
335
+
336
+
337
+
338
+ ############## Quant ###############
339
+
340
+ @torch.enable_grad()
341
+ def learn_parameters(
342
+ x: torch.Tensor,
343
+ params: Dict[str, nn.Parameter],
344
+ qtz_func: nn.Module,
345
+ deqtz_func: nn.Module,
346
+ bits: int,
347
+ target_dtype: torch.dtype,
348
+ epochs: int = 1000,
349
+ early_stop: bool = True,
350
+ do_report: bool = False
351
+ ) -> Tuple[Dict[str, nn.Parameter], torch.Tensor]:
352
+ loss_fn = nn.MSELoss()
353
+
354
+ # Determines the initial learning rate by computing the initial loss and multiplying it by
355
+ # the order of magnitude of the loss divided by 2
356
+ quant = quantize(x, params, qtz_func, bits, target_dtype)
357
+ dequant = dequantize(quant, params, deqtz_func, bits, x.dtype)
358
+ loss = loss_fn(x, dequant)
359
+
360
+ base_lr = 0.1
361
+ exponent = int(np.floor(np.log10(loss.item())))
362
+ lr = base_lr * (10 ** (exponent // 2))
363
+
364
+ # Requires gradients in the parameters
365
+ for p in params.values():
366
+ p.requires_grad = True
367
+ p.grad = None
368
+
369
+ param_keys = list(params.keys())
370
+ param_values = list(params.values())
371
+
372
+ # Defines optimizer and loss function
373
+ optimizer = torch.optim.Adam(param_values, lr=lr)
374
+ scheduler = torch.optim.lr_scheduler.LinearLR(optimizer, start_factor=1.0, end_factor=0.01, total_iters=epochs // 10)
375
+
376
+ # Contains the best loss and the best parameters
377
+ best_loss = float("inf")
378
+ best_params = None
379
+
380
+ # Used to stop the search early
381
+ min_delta = 1e-7
382
+ acc_loss = []
383
+ percent_epochs_before_stop = 0.1
384
+
385
+ for i in range(epochs):
386
+ optimizer.zero_grad()
387
+
388
+ quant = quantize(x, params, qtz_func, bits, target_dtype)
389
+ dequant = dequantize(quant, params, deqtz_func, bits, x.dtype)
390
+ loss = loss_fn(x, dequant)
391
+
392
+ if loss.isnan() or loss.isinf():
393
+ raise Exception("Loss is NaN or Inf. Stopping the search.")
394
+
395
+ loss.backward()
396
+ optimizer.step()
397
+ scheduler.step()
398
+
399
+ acc_loss.append(loss.item())
400
+
401
+ # Reports loss every 10 steps
402
+ if i % 10 == 0 and do_report:
403
+ print(f"Epoch {i}: Loss {loss.item()}")
404
+
405
+ # Optimizes the parameter search by storing the best loss and the parameters
406
+ if loss.item() < best_loss:
407
+ best_loss = loss.item()
408
+ best_params = copy.deepcopy({
409
+ k: v for k, v in params.items() if k in param_keys
410
+ })
411
+
412
+ # We also stop the search if the loss has not considerably during the last 10% epochs
413
+ if early_stop:
414
+ epochs_before_stop = int(epochs * percent_epochs_before_stop)
415
+ if i > epochs_before_stop and abs(acc_loss[i - epochs_before_stop] - acc_loss[i]) < min_delta:
416
+ break
417
+
418
+ # No longer requires gradients in the parameters
419
+ for p in best_params.values():
420
+ p.requires_grad = False
421
+ p.grad = None
422
+
423
+ if do_report:
424
+ return best_params, acc_loss
425
+ else:
426
+ return best_params
427
+
428
+
429
+ def quantize(
430
+ x: torch.Tensor,
431
+ params: Dict[str, nn.Parameter],
432
+ func: nn.Module,
433
+ bits: int,
434
+ target_dtype: torch.dtype = torch.int8
435
+ ) -> torch.Tensor:
436
+ quant_min, quant_max = get_min_max_from_bits_signed(bits)
437
+ x = x.transpose(0, 1) # Aligns shapes
438
+ x = func(x=x, **params)
439
+ x = x.transpose(0, 1)
440
+ x = torch.clamp(round_func_BPDA(x), quant_min, quant_max).to(target_dtype)
441
+ return x
442
+
443
+
444
+ def dequantize(
445
+ x: torch.Tensor,
446
+ params: Dict[str, nn.Parameter],
447
+ func: nn.Module,
448
+ bits: int,
449
+ out_dtype: torch.dtype
450
+ ) -> torch.Tensor:
451
+ x = x.to(dtype=out_dtype)
452
+ x = x.transpose(0, 1)
453
+ x = func(x=x, **params)
454
+ x = x.transpose(0, 1)
455
+ return x
456
+
457
+
458
+ def round_func_BPDA(input):
459
+ # This is equivalent to replacing round function (non-differentiable) with
460
+ # an identity function (differentiable) only when backward.
461
+ forward_value = torch.round(input)
462
+ out = input.clone()
463
+ out.data = forward_value.data
464
+ return out
465
+
466
+
467
+ def get_min_max_from_bits_signed(bit_width: int) -> Tuple[int, int]:
468
+ return -2 ** (bit_width - 1), 2 ** (bit_width - 1) - 1
469
+
470
+
471
+
472
+ ############## Numpy ###############
473
+
474
+ def np_domain_guard(
475
+ x: np.ndarray,
476
+ min: float = None,
477
+ max: float = None,
478
+ posinf: float = None,
479
+ neginf: float = None,
480
+ nan: float = None
481
+ ) -> np.ndarray:
482
+ """Guard a tensor to a valid domain."""
483
+ x = np.nan_to_num(x, posinf=posinf, neginf=neginf, nan=nan)
484
+ if min is not None or max is not None:
485
+ x = np.clip(x, min, max)
486
+ return x
487
+
488
+
489
+ def np_replace_num(x: np.ndarray, num: float, to: float) -> np.ndarray:
490
+ """Replace a number in a tensor with another number.
491
+
492
+ Args:
493
+ x (np.ndarray): The input tensor.
494
+ num (float): The number to replace.
495
+ to (float): The number to replace with.
496
+
497
+ Returns:
498
+ np.ndarray: The tensor with the number replaced.
499
+ """
500
+ return np.where(x == num, to, x)
501
+
502
+
503
+ def np_guarded_power(x: np.ndarray, exp: float) -> np.ndarray:
504
+ """Guard the power operation to a valid domain."""
505
+ return np.power(x, exp) if exp >= 1 else np.power(np.maximum(x, 0), exp)
506
+
fn_gen/ones_t/11/loss.png ADDED
fn_gen/ones_t/11/quantization.png ADDED
fn_gen/ones_t/12/distortion.png ADDED
fn_gen/ones_t/12/expressions.txt ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ sinh(_0*x)/_s
2
+ log(_s*x - sqrt(_s**2*x**2 + 1))/_0
fn_gen/ones_t/12/fn.py ADDED
@@ -0,0 +1,506 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import torch
4
+ from torch import amin # Necessary for arcsin
5
+ import copy
6
+ import torch.nn as nn
7
+ import numpy as np
8
+
9
+ from scipy.optimize import curve_fit
10
+ from typing import Dict, Any, Tuple, List, Callable
11
+
12
+
13
+ def quantization(x, **params):
14
+ return (torch.div(1, replace_num(params['_s'], num=0, to=10000)) * torch.sinh((params['_0'] * x)))
15
+
16
+
17
+ def dequantization(x, **params):
18
+ return (torch.div(1, replace_num(params['_0'], num=0, to=10000)) * torch.log(domain_guard(((torch.tensor(-1) * torch.sqrt(domain_guard((torch.tensor(1) + (guarded_torch_power(params['_s'], torch.tensor(2)) * guarded_torch_power(x, torch.tensor(2)))), min=0.1, nan=0.1))) + (params['_s'] * x)), min=1e-5, nan=1e-5)))
19
+
20
+
21
+ def init_linear_scale( # Symmetric scale. From the study folder
22
+ x: torch.Tensor,
23
+ **kwargs: Dict[str, Any],
24
+ ) -> torch.Tensor:
25
+ assert "bits" in kwargs, "bits must be provided."
26
+ assert "params" in kwargs, "params must be provided."
27
+ assert "qtz_func" in kwargs, "qtz_func must be provided."
28
+
29
+ bits = kwargs.get('bits')
30
+ params = kwargs.get('params')
31
+ qtz_func = kwargs.get('qtz_func')
32
+
33
+ x_ = x.transpose(0, 1)
34
+ x_ = qtz_func(x=x_, **params, _s=init_ones(x, **kwargs))
35
+ x_ = x_.transpose(0, 1)
36
+
37
+ quant_min, quant_max = get_min_max_from_bits_signed(bits)
38
+ min_vals, max_vals = torch.aminmax(x_, dim=1)
39
+ min_vals = torch.min(min_vals, torch.zeros_like(min_vals))
40
+ max_vals = torch.max(max_vals, torch.zeros_like(max_vals))
41
+
42
+ eps = torch.finfo(torch.float32).eps
43
+
44
+ abs_max_val_per_ch = torch.max(-min_vals, max_vals)
45
+ scale = abs_max_val_per_ch / (float(quant_max - quant_min) / 2)
46
+
47
+ scale = torch.clamp(scale, min=eps).to(dtype=torch.float32, device=min_vals.device)
48
+
49
+ # Introduces some noise in scale
50
+ # If I don't introduce noise, the accuracy is going to be 0.0 and not learn anything
51
+ # scale = scale + 0.01 * torch.randn_like(scale)
52
+ return scale
53
+
54
+
55
+ def init_params(x: torch.Tensor, **kwargs: Dict[str, Any]) -> Dict[str, nn.Parameter]:
56
+ params = {
57
+ '_0': init_ones(x, qtz_func=quantization, deqtz_func=dequantization, param='_0', params_list=['_0', '_s'], **kwargs),
58
+ }
59
+ params['_s'] = init_linear_scale(x, params=params, qtz_func=quantization, **kwargs)
60
+ params = {k: nn.Parameter(v, requires_grad=False) for k, v in params.items()}
61
+
62
+ if 'post_init_hook' in kwargs:
63
+ kwargs['post_init_hook'](parameters=params)
64
+
65
+ params = learn_parameters(x, params,
66
+ qtz_func=quantization,
67
+ deqtz_func=dequantization,
68
+ bits=kwargs['bits'],
69
+ target_dtype=torch.int8,
70
+ epochs=500,
71
+ early_stop=False,
72
+ )
73
+ if 'post_train_hook' in kwargs:
74
+ kwargs['post_train_hook'](parameters=params)
75
+
76
+ return params
77
+
78
+
79
+ ############### Numpy Qtz ###############
80
+
81
+
82
+ def np_quantization(x, _0, _s):
83
+ return (np.divide(1, np_replace_num(_s, num=0, to=10000)) * np.sinh((_0 * x)))
84
+
85
+
86
+ def np_dequantization(x, _0, _s):
87
+ return (np.divide(1, np_replace_num(_0, num=0, to=10000)) * np.log(np_domain_guard(((np.array(-1) * np.sqrt(np_domain_guard((np.array(1) + (np_guarded_power(_s, np.array(2)) * np_guarded_power(x, np.array(2)))), min=0.1, nan=0.1))) + (_s * x)), min=1e-5, nan=1e-5)))
88
+
89
+
90
+ def fit_func(x, _0, _s):
91
+ x_ = np_quantization(x, _0, _s)
92
+ x_ = np_dequantization(x_, _0, _s)
93
+ return x_
94
+
95
+
96
+
97
+ ############### HELPERS ###############
98
+
99
+ def domain_guard(
100
+ x: torch.Tensor,
101
+ min: float = None,
102
+ max: float = None,
103
+ posinf: float = None,
104
+ neginf: float = None,
105
+ nan: float = None
106
+ ) -> torch.Tensor:
107
+ """Guard a tensor to a valid domain."""
108
+ x = torch.nan_to_num(x, posinf=posinf, neginf=neginf, nan=nan)
109
+ if min is not None or max is not None:
110
+ x = torch.clamp(x, min=min, max=max)
111
+ return x
112
+
113
+
114
+ def replace_num(x: torch.Tensor, num: float, to: float) -> torch.Tensor:
115
+ """Replace a number in a tensor with another number.
116
+
117
+ Args:
118
+ x (torch.Tensor): The input tensor.
119
+ num (float): The number to replace.
120
+ to (float): The number to replace with.
121
+
122
+ Returns:
123
+ torch.Tensor: The tensor with the number replaced.
124
+ """
125
+ return torch.where(x == num, to, x)
126
+
127
+
128
+ def guarded_torch_power(x: torch.Tensor, exp: float) -> torch.Tensor:
129
+ """Guard the power operation to a valid domain."""
130
+ return torch.pow(x, exp) if exp >= 1 else torch.pow(torch.relu(x), exp)
131
+
132
+
133
+ def init_ones(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor:
134
+ val = torch.amin(x, dim=1)
135
+ return torch.ones_like(val, dtype=torch.float32, device=x.device)
136
+
137
+
138
+ def init_rand(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor:
139
+ val = torch.amin(x, dim=1)
140
+ return torch.randn_like(val, dtype=torch.float32, device=x.device)
141
+
142
+
143
+ def init_space_search(
144
+ x: torch.Tensor,
145
+ **kwargs: Dict[str, Any],
146
+ ) -> torch.Tensor:
147
+
148
+ def _build_initial_param(tensor: torch.Tensor, max_initial: int, n_params: int):
149
+ """Generates the initial set of parameters. The first iteration generates 10 times more parameters."""
150
+ for _ in range(n_params * 10): # The first iteration generates 10 times more parameters
151
+ yield init_rand(tensor) * max_initial # Generates n_params in range [-max_initial, max_initial]
152
+
153
+ def _search_param(tensors: List[torch.tensor], n_params):
154
+ """Takes the best parameters and generates new parameters around the mean of the best parameters."""
155
+ torch_tensors = torch.stack(tensors)
156
+ min_vals, max_vals = torch.aminmax(torch_tensors, dim=0)
157
+ abs_max_val_per_ch = torch.max(-min_vals, max_vals)
158
+ mean = torch.mean(torch_tensors, dim=0)
159
+ for _ in range(n_params): # Generates n_params around the mean of the tensors
160
+ yield torch.randn_like(min_vals) * abs_max_val_per_ch + mean
161
+
162
+ def _calc(x, qtz_func, deqtz_func, **params):
163
+ x_ = x.transpose(0, 1)
164
+ x_ = qtz_func(x=x_, **params)
165
+ x_ = deqtz_func(x=x_, **params)
166
+ x_ = x_.transpose(0, 1)
167
+ return x_
168
+
169
+ assert "qtz_func" in kwargs, "qtz_func must be provided."
170
+ assert "deqtz_func" in kwargs, "deqtz_func must be provided."
171
+ assert "params_list" in kwargs, "params list must be provided."
172
+ assert "param" in kwargs, "param must be provided."
173
+
174
+ qtz_func = kwargs.get('qtz_func')
175
+ deqtz_func = kwargs.get('deqtz_func')
176
+ params_list = kwargs.get('params_list')
177
+ param = kwargs.get('param')
178
+
179
+ n_runs = 50 # Number of runs to try to find the best parameters
180
+ n_random_params = 50 # Number of random parameters to generate
181
+ n_best_to_pick = 5 # Number of best parameters to pick after each run
182
+ max_initial = 10000 # Maximum value to initialize the parameters
183
+
184
+ # Initializes the parameters
185
+ base_params = { p: init_ones(x, **kwargs) for p in params_list if p != param }
186
+ params = _build_initial_param(x, max_initial, n_random_params)
187
+
188
+ # Performs the search
189
+ for _ in range(n_runs):
190
+
191
+ best_params = []
192
+ for param_ in params:
193
+ try:
194
+ x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: param_})
195
+ loss_ones = nn.MSELoss()(x, x_)
196
+
197
+ if len(best_params) < n_best_to_pick:
198
+ best_params.append((param_, loss_ones.item()))
199
+ best_params = sorted(best_params, key=lambda x: x[1])
200
+ elif loss_ones < best_params[-1][1]:
201
+ best_params[-1] = (param_, loss_ones.item())
202
+ best_params = sorted(best_params, key=lambda x: x[1])
203
+
204
+ except Exception: # The parameters might not be valid for the function's domain
205
+ continue
206
+
207
+ # Generates new parameters around the mean
208
+ params = _search_param([p for p, _ in best_params], n_random_params)
209
+
210
+ # Checks if the best parameter is better than the init_ones
211
+ p_ones = init_ones(x, **kwargs)
212
+ x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: p_ones})
213
+ loss_ones = nn.MSELoss()(x, x_)
214
+
215
+ # Checks if the best parameter is better than the init_rand
216
+ p_rand = init_rand(x, **kwargs)
217
+ x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: p_rand})
218
+ loss_rand = nn.MSELoss()(x, x_)
219
+
220
+ if loss_rand < best_params[0][1] and loss_rand < loss_ones:
221
+ return p_rand
222
+ elif loss_ones < best_params[0][1] and loss_ones < loss_rand:
223
+ return p_ones
224
+ else:
225
+ return best_params[0][0]
226
+
227
+
228
+ def init_linear_scale( # Symmetric scale. From the study folder
229
+ x: torch.Tensor,
230
+ **kwargs: Dict[str, Any],
231
+ ) -> torch.Tensor:
232
+ assert "bits" in kwargs, "bits must be provided."
233
+ assert "params" in kwargs, "params must be provided."
234
+ assert "qtz_func" in kwargs, "qtz_func must be provided."
235
+
236
+ bits = kwargs.get('bits')
237
+ params = kwargs.get('params')
238
+ qtz_func = kwargs.get('qtz_func')
239
+
240
+ x_ = x.transpose(0, 1)
241
+ x_ = qtz_func(x=x_, **params, _s=init_ones(x, **kwargs))
242
+ x_ = x_.transpose(0, 1)
243
+
244
+ quant_min, quant_max = get_min_max_from_bits_signed(bits)
245
+ min_vals, max_vals = torch.aminmax(x_, dim=1)
246
+ min_vals = torch.min(min_vals, torch.zeros_like(min_vals))
247
+ max_vals = torch.max(max_vals, torch.zeros_like(max_vals))
248
+
249
+ eps = torch.finfo(torch.float32).eps
250
+
251
+ abs_max_val_per_ch = torch.max(-min_vals, max_vals)
252
+ scale = abs_max_val_per_ch / (float(quant_max - quant_min) / 2)
253
+
254
+ scale = torch.clamp(scale, min=eps).to(dtype=torch.float32, device=min_vals.device)
255
+
256
+ # Introduces some noise in scale
257
+ # If I don't introduce noise, the accuracy is going to be 0.0 and not learn anything
258
+ # scale = scale + 0.01 * torch.randn_like(scale)
259
+ return scale
260
+
261
+
262
+ def init_non_linear_regression_fit(
263
+ x: torch.Tensor,
264
+ **kwargs: Dict[str, Any],
265
+ ) -> torch.Tensor:
266
+
267
+ assert "params_list" in kwargs, "params list must be provided."
268
+ assert "np_fit_func" in kwargs, "np_fit_func must be provided."
269
+ assert "p0" in kwargs, "p0 must be provided."
270
+ np_fit_func = kwargs.get('np_fit_func')
271
+ params_list = kwargs.get('params_list')
272
+ p0 = kwargs.get('p0')
273
+
274
+ def _fit(xdata: np.ndarray, ydata: np.ndarray, func: Callable, p0: List[float]):
275
+ popt, _ = curve_fit(
276
+ func,
277
+ xdata,
278
+ ydata,
279
+ maxfev=1000,
280
+ p0=p0,
281
+ method='lm'
282
+ )
283
+ return popt
284
+
285
+ # 1. Needs to convert the torch tensor to numpy tensor
286
+ xdata = x.cpu().numpy()
287
+
288
+ # 2. Sorts the data so that it makes it easier to fit to it
289
+ sorted_xdata = np.sort(xdata, axis=-1)
290
+
291
+ p0 = {k: v.cpu().numpy() for k, v in p0.items()}
292
+ params_list = sorted(params_list) # We need to make sure that it matches the numpy fit func arg order
293
+
294
+ # 3. Finds the best parameters for each channel
295
+ try:
296
+ params = []
297
+ for i in range(sorted_xdata.shape[0]):
298
+ xdata_ = sorted_xdata[i]
299
+ p0_ = [p0[p][i] for p in params_list]
300
+ ch_params = _fit(xdata_, xdata_, np_fit_func, p0_)
301
+ params.append(ch_params)
302
+
303
+ # 4. Builds the parameters
304
+ result = {}
305
+ for i, p in enumerate(params_list):
306
+ result[p] = torch.tensor([p_[i] for p_ in params], dtype=torch.float32).to(x.device)
307
+
308
+ return result
309
+
310
+ except ValueError as e:
311
+ print(f"Could not fit the function with error: {e}")
312
+ print(f"Using fallback result...")
313
+ return {
314
+ k: torch.tensor(v, dtype=torch.float32).to(x.device) for k, v in p0.items()
315
+ }
316
+
317
+
318
+ def init_zeros(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor:
319
+ val = torch.amin(x, dim=1)
320
+ return torch.zeros_like(val, dtype=torch.float32, device=x.device)
321
+
322
+
323
+ def init_inner_scale(tensor: torch.Tensor, _min: float = torch.inf, _max: float = torch.inf) -> torch.Tensor:
324
+ # Calculate the original minimum and maximum values
325
+ min_vals, max_vals = torch.aminmax(tensor, dim=-1)
326
+ x_min = torch.min(min_vals, torch.zeros_like(min_vals))
327
+ x_max = torch.max(max_vals, torch.zeros_like(max_vals))
328
+
329
+ if _max is torch.inf: # We do not need to scale the tensor. Just need to move it
330
+ return torch.ones_like(x_min)
331
+
332
+ # Calculate the scale factor
333
+ scale = (_max - _min) / (x_max - x_min)
334
+ return scale
335
+
336
+
337
+
338
+ ############## Quant ###############
339
+
340
+ @torch.enable_grad()
341
+ def learn_parameters(
342
+ x: torch.Tensor,
343
+ params: Dict[str, nn.Parameter],
344
+ qtz_func: nn.Module,
345
+ deqtz_func: nn.Module,
346
+ bits: int,
347
+ target_dtype: torch.dtype,
348
+ epochs: int = 1000,
349
+ early_stop: bool = True,
350
+ do_report: bool = False
351
+ ) -> Tuple[Dict[str, nn.Parameter], torch.Tensor]:
352
+ loss_fn = nn.MSELoss()
353
+
354
+ # Determines the initial learning rate by computing the initial loss and multiplying it by
355
+ # the order of magnitude of the loss divided by 2
356
+ quant = quantize(x, params, qtz_func, bits, target_dtype)
357
+ dequant = dequantize(quant, params, deqtz_func, bits, x.dtype)
358
+ loss = loss_fn(x, dequant)
359
+
360
+ base_lr = 0.1
361
+ exponent = int(np.floor(np.log10(loss.item())))
362
+ lr = base_lr * (10 ** (exponent // 2))
363
+
364
+ # Requires gradients in the parameters
365
+ for p in params.values():
366
+ p.requires_grad = True
367
+ p.grad = None
368
+
369
+ param_keys = list(params.keys())
370
+ param_values = list(params.values())
371
+
372
+ # Defines optimizer and loss function
373
+ optimizer = torch.optim.Adam(param_values, lr=lr)
374
+ scheduler = torch.optim.lr_scheduler.LinearLR(optimizer, start_factor=1.0, end_factor=0.01, total_iters=epochs // 10)
375
+
376
+ # Contains the best loss and the best parameters
377
+ best_loss = float("inf")
378
+ best_params = None
379
+
380
+ # Used to stop the search early
381
+ min_delta = 1e-7
382
+ acc_loss = []
383
+ percent_epochs_before_stop = 0.1
384
+
385
+ for i in range(epochs):
386
+ optimizer.zero_grad()
387
+
388
+ quant = quantize(x, params, qtz_func, bits, target_dtype)
389
+ dequant = dequantize(quant, params, deqtz_func, bits, x.dtype)
390
+ loss = loss_fn(x, dequant)
391
+
392
+ if loss.isnan() or loss.isinf():
393
+ raise Exception("Loss is NaN or Inf. Stopping the search.")
394
+
395
+ loss.backward()
396
+ optimizer.step()
397
+ scheduler.step()
398
+
399
+ acc_loss.append(loss.item())
400
+
401
+ # Reports loss every 10 steps
402
+ if i % 10 == 0 and do_report:
403
+ print(f"Epoch {i}: Loss {loss.item()}")
404
+
405
+ # Optimizes the parameter search by storing the best loss and the parameters
406
+ if loss.item() < best_loss:
407
+ best_loss = loss.item()
408
+ best_params = copy.deepcopy({
409
+ k: v for k, v in params.items() if k in param_keys
410
+ })
411
+
412
+ # We also stop the search if the loss has not considerably during the last 10% epochs
413
+ if early_stop:
414
+ epochs_before_stop = int(epochs * percent_epochs_before_stop)
415
+ if i > epochs_before_stop and abs(acc_loss[i - epochs_before_stop] - acc_loss[i]) < min_delta:
416
+ break
417
+
418
+ # No longer requires gradients in the parameters
419
+ for p in best_params.values():
420
+ p.requires_grad = False
421
+ p.grad = None
422
+
423
+ if do_report:
424
+ return best_params, acc_loss
425
+ else:
426
+ return best_params
427
+
428
+
429
+ def quantize(
430
+ x: torch.Tensor,
431
+ params: Dict[str, nn.Parameter],
432
+ func: nn.Module,
433
+ bits: int,
434
+ target_dtype: torch.dtype = torch.int8
435
+ ) -> torch.Tensor:
436
+ quant_min, quant_max = get_min_max_from_bits_signed(bits)
437
+ x = x.transpose(0, 1) # Aligns shapes
438
+ x = func(x=x, **params)
439
+ x = x.transpose(0, 1)
440
+ x = torch.clamp(round_func_BPDA(x), quant_min, quant_max).to(target_dtype)
441
+ return x
442
+
443
+
444
+ def dequantize(
445
+ x: torch.Tensor,
446
+ params: Dict[str, nn.Parameter],
447
+ func: nn.Module,
448
+ bits: int,
449
+ out_dtype: torch.dtype
450
+ ) -> torch.Tensor:
451
+ x = x.to(dtype=out_dtype)
452
+ x = x.transpose(0, 1)
453
+ x = func(x=x, **params)
454
+ x = x.transpose(0, 1)
455
+ return x
456
+
457
+
458
+ def round_func_BPDA(input):
459
+ # This is equivalent to replacing round function (non-differentiable) with
460
+ # an identity function (differentiable) only when backward.
461
+ forward_value = torch.round(input)
462
+ out = input.clone()
463
+ out.data = forward_value.data
464
+ return out
465
+
466
+
467
+ def get_min_max_from_bits_signed(bit_width: int) -> Tuple[int, int]:
468
+ return -2 ** (bit_width - 1), 2 ** (bit_width - 1) - 1
469
+
470
+
471
+
472
+ ############## Numpy ###############
473
+
474
+ def np_domain_guard(
475
+ x: np.ndarray,
476
+ min: float = None,
477
+ max: float = None,
478
+ posinf: float = None,
479
+ neginf: float = None,
480
+ nan: float = None
481
+ ) -> np.ndarray:
482
+ """Guard a tensor to a valid domain."""
483
+ x = np.nan_to_num(x, posinf=posinf, neginf=neginf, nan=nan)
484
+ if min is not None or max is not None:
485
+ x = np.clip(x, min, max)
486
+ return x
487
+
488
+
489
+ def np_replace_num(x: np.ndarray, num: float, to: float) -> np.ndarray:
490
+ """Replace a number in a tensor with another number.
491
+
492
+ Args:
493
+ x (np.ndarray): The input tensor.
494
+ num (float): The number to replace.
495
+ to (float): The number to replace with.
496
+
497
+ Returns:
498
+ np.ndarray: The tensor with the number replaced.
499
+ """
500
+ return np.where(x == num, to, x)
501
+
502
+
503
+ def np_guarded_power(x: np.ndarray, exp: float) -> np.ndarray:
504
+ """Guard the power operation to a valid domain."""
505
+ return np.power(x, exp) if exp >= 1 else np.power(np.maximum(x, 0), exp)
506
+
fn_gen/ones_t/12/loss.png ADDED
fn_gen/ones_t/12/quantization.png ADDED
fn_gen/ones_t/13/distortion.png ADDED
fn_gen/ones_t/13/expressions.txt ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ x/_s
2
+ _s*x
fn_gen/ones_t/13/fn.py ADDED
@@ -0,0 +1,505 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import torch
4
+ from torch import amin # Necessary for arcsin
5
+ import copy
6
+ import torch.nn as nn
7
+ import numpy as np
8
+
9
+ from scipy.optimize import curve_fit
10
+ from typing import Dict, Any, Tuple, List, Callable
11
+
12
+
13
+ def quantization(x, **params):
14
+ return (x * torch.div(1, replace_num(params['_s'], num=0, to=10000)))
15
+
16
+
17
+ def dequantization(x, **params):
18
+ return (params['_s'] * x)
19
+
20
+
21
+ def init_linear_scale( # Symmetric scale. From the study folder
22
+ x: torch.Tensor,
23
+ **kwargs: Dict[str, Any],
24
+ ) -> torch.Tensor:
25
+ assert "bits" in kwargs, "bits must be provided."
26
+ assert "params" in kwargs, "params must be provided."
27
+ assert "qtz_func" in kwargs, "qtz_func must be provided."
28
+
29
+ bits = kwargs.get('bits')
30
+ params = kwargs.get('params')
31
+ qtz_func = kwargs.get('qtz_func')
32
+
33
+ x_ = x.transpose(0, 1)
34
+ x_ = qtz_func(x=x_, **params, _s=init_ones(x, **kwargs))
35
+ x_ = x_.transpose(0, 1)
36
+
37
+ quant_min, quant_max = get_min_max_from_bits_signed(bits)
38
+ min_vals, max_vals = torch.aminmax(x_, dim=1)
39
+ min_vals = torch.min(min_vals, torch.zeros_like(min_vals))
40
+ max_vals = torch.max(max_vals, torch.zeros_like(max_vals))
41
+
42
+ eps = torch.finfo(torch.float32).eps
43
+
44
+ abs_max_val_per_ch = torch.max(-min_vals, max_vals)
45
+ scale = abs_max_val_per_ch / (float(quant_max - quant_min) / 2)
46
+
47
+ scale = torch.clamp(scale, min=eps).to(dtype=torch.float32, device=min_vals.device)
48
+
49
+ # Introduces some noise in scale
50
+ # If I don't introduce noise, the accuracy is going to be 0.0 and not learn anything
51
+ # scale = scale + 0.01 * torch.randn_like(scale)
52
+ return scale
53
+
54
+
55
+ def init_params(x: torch.Tensor, **kwargs: Dict[str, Any]) -> Dict[str, nn.Parameter]:
56
+ params = {
57
+ }
58
+ params['_s'] = init_linear_scale(x, params=params, qtz_func=quantization, **kwargs)
59
+ params = {k: nn.Parameter(v, requires_grad=False) for k, v in params.items()}
60
+
61
+ if 'post_init_hook' in kwargs:
62
+ kwargs['post_init_hook'](parameters=params)
63
+
64
+ params = learn_parameters(x, params,
65
+ qtz_func=quantization,
66
+ deqtz_func=dequantization,
67
+ bits=kwargs['bits'],
68
+ target_dtype=torch.int8,
69
+ epochs=500,
70
+ early_stop=False,
71
+ )
72
+ if 'post_train_hook' in kwargs:
73
+ kwargs['post_train_hook'](parameters=params)
74
+
75
+ return params
76
+
77
+
78
+ ############### Numpy Qtz ###############
79
+
80
+
81
+ def np_quantization(x, _s):
82
+ return (x * np.divide(1, np_replace_num(_s, num=0, to=10000)))
83
+
84
+
85
+ def np_dequantization(x, _s):
86
+ return (_s * x)
87
+
88
+
89
+ def fit_func(x, _s):
90
+ x_ = np_quantization(x, _s)
91
+ x_ = np_dequantization(x_, _s)
92
+ return x_
93
+
94
+
95
+
96
+ ############### HELPERS ###############
97
+
98
+ def domain_guard(
99
+ x: torch.Tensor,
100
+ min: float = None,
101
+ max: float = None,
102
+ posinf: float = None,
103
+ neginf: float = None,
104
+ nan: float = None
105
+ ) -> torch.Tensor:
106
+ """Guard a tensor to a valid domain."""
107
+ x = torch.nan_to_num(x, posinf=posinf, neginf=neginf, nan=nan)
108
+ if min is not None or max is not None:
109
+ x = torch.clamp(x, min=min, max=max)
110
+ return x
111
+
112
+
113
+ def replace_num(x: torch.Tensor, num: float, to: float) -> torch.Tensor:
114
+ """Replace a number in a tensor with another number.
115
+
116
+ Args:
117
+ x (torch.Tensor): The input tensor.
118
+ num (float): The number to replace.
119
+ to (float): The number to replace with.
120
+
121
+ Returns:
122
+ torch.Tensor: The tensor with the number replaced.
123
+ """
124
+ return torch.where(x == num, to, x)
125
+
126
+
127
+ def guarded_torch_power(x: torch.Tensor, exp: float) -> torch.Tensor:
128
+ """Guard the power operation to a valid domain."""
129
+ return torch.pow(x, exp) if exp >= 1 else torch.pow(torch.relu(x), exp)
130
+
131
+
132
+ def init_ones(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor:
133
+ val = torch.amin(x, dim=1)
134
+ return torch.ones_like(val, dtype=torch.float32, device=x.device)
135
+
136
+
137
+ def init_rand(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor:
138
+ val = torch.amin(x, dim=1)
139
+ return torch.randn_like(val, dtype=torch.float32, device=x.device)
140
+
141
+
142
+ def init_space_search(
143
+ x: torch.Tensor,
144
+ **kwargs: Dict[str, Any],
145
+ ) -> torch.Tensor:
146
+
147
+ def _build_initial_param(tensor: torch.Tensor, max_initial: int, n_params: int):
148
+ """Generates the initial set of parameters. The first iteration generates 10 times more parameters."""
149
+ for _ in range(n_params * 10): # The first iteration generates 10 times more parameters
150
+ yield init_rand(tensor) * max_initial # Generates n_params in range [-max_initial, max_initial]
151
+
152
+ def _search_param(tensors: List[torch.tensor], n_params):
153
+ """Takes the best parameters and generates new parameters around the mean of the best parameters."""
154
+ torch_tensors = torch.stack(tensors)
155
+ min_vals, max_vals = torch.aminmax(torch_tensors, dim=0)
156
+ abs_max_val_per_ch = torch.max(-min_vals, max_vals)
157
+ mean = torch.mean(torch_tensors, dim=0)
158
+ for _ in range(n_params): # Generates n_params around the mean of the tensors
159
+ yield torch.randn_like(min_vals) * abs_max_val_per_ch + mean
160
+
161
+ def _calc(x, qtz_func, deqtz_func, **params):
162
+ x_ = x.transpose(0, 1)
163
+ x_ = qtz_func(x=x_, **params)
164
+ x_ = deqtz_func(x=x_, **params)
165
+ x_ = x_.transpose(0, 1)
166
+ return x_
167
+
168
+ assert "qtz_func" in kwargs, "qtz_func must be provided."
169
+ assert "deqtz_func" in kwargs, "deqtz_func must be provided."
170
+ assert "params_list" in kwargs, "params list must be provided."
171
+ assert "param" in kwargs, "param must be provided."
172
+
173
+ qtz_func = kwargs.get('qtz_func')
174
+ deqtz_func = kwargs.get('deqtz_func')
175
+ params_list = kwargs.get('params_list')
176
+ param = kwargs.get('param')
177
+
178
+ n_runs = 50 # Number of runs to try to find the best parameters
179
+ n_random_params = 50 # Number of random parameters to generate
180
+ n_best_to_pick = 5 # Number of best parameters to pick after each run
181
+ max_initial = 10000 # Maximum value to initialize the parameters
182
+
183
+ # Initializes the parameters
184
+ base_params = { p: init_ones(x, **kwargs) for p in params_list if p != param }
185
+ params = _build_initial_param(x, max_initial, n_random_params)
186
+
187
+ # Performs the search
188
+ for _ in range(n_runs):
189
+
190
+ best_params = []
191
+ for param_ in params:
192
+ try:
193
+ x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: param_})
194
+ loss_ones = nn.MSELoss()(x, x_)
195
+
196
+ if len(best_params) < n_best_to_pick:
197
+ best_params.append((param_, loss_ones.item()))
198
+ best_params = sorted(best_params, key=lambda x: x[1])
199
+ elif loss_ones < best_params[-1][1]:
200
+ best_params[-1] = (param_, loss_ones.item())
201
+ best_params = sorted(best_params, key=lambda x: x[1])
202
+
203
+ except Exception: # The parameters might not be valid for the function's domain
204
+ continue
205
+
206
+ # Generates new parameters around the mean
207
+ params = _search_param([p for p, _ in best_params], n_random_params)
208
+
209
+ # Checks if the best parameter is better than the init_ones
210
+ p_ones = init_ones(x, **kwargs)
211
+ x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: p_ones})
212
+ loss_ones = nn.MSELoss()(x, x_)
213
+
214
+ # Checks if the best parameter is better than the init_rand
215
+ p_rand = init_rand(x, **kwargs)
216
+ x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: p_rand})
217
+ loss_rand = nn.MSELoss()(x, x_)
218
+
219
+ if loss_rand < best_params[0][1] and loss_rand < loss_ones:
220
+ return p_rand
221
+ elif loss_ones < best_params[0][1] and loss_ones < loss_rand:
222
+ return p_ones
223
+ else:
224
+ return best_params[0][0]
225
+
226
+
227
+ def init_linear_scale( # Symmetric scale. From the study folder
228
+ x: torch.Tensor,
229
+ **kwargs: Dict[str, Any],
230
+ ) -> torch.Tensor:
231
+ assert "bits" in kwargs, "bits must be provided."
232
+ assert "params" in kwargs, "params must be provided."
233
+ assert "qtz_func" in kwargs, "qtz_func must be provided."
234
+
235
+ bits = kwargs.get('bits')
236
+ params = kwargs.get('params')
237
+ qtz_func = kwargs.get('qtz_func')
238
+
239
+ x_ = x.transpose(0, 1)
240
+ x_ = qtz_func(x=x_, **params, _s=init_ones(x, **kwargs))
241
+ x_ = x_.transpose(0, 1)
242
+
243
+ quant_min, quant_max = get_min_max_from_bits_signed(bits)
244
+ min_vals, max_vals = torch.aminmax(x_, dim=1)
245
+ min_vals = torch.min(min_vals, torch.zeros_like(min_vals))
246
+ max_vals = torch.max(max_vals, torch.zeros_like(max_vals))
247
+
248
+ eps = torch.finfo(torch.float32).eps
249
+
250
+ abs_max_val_per_ch = torch.max(-min_vals, max_vals)
251
+ scale = abs_max_val_per_ch / (float(quant_max - quant_min) / 2)
252
+
253
+ scale = torch.clamp(scale, min=eps).to(dtype=torch.float32, device=min_vals.device)
254
+
255
+ # Introduces some noise in scale
256
+ # If I don't introduce noise, the accuracy is going to be 0.0 and not learn anything
257
+ # scale = scale + 0.01 * torch.randn_like(scale)
258
+ return scale
259
+
260
+
261
+ def init_non_linear_regression_fit(
262
+ x: torch.Tensor,
263
+ **kwargs: Dict[str, Any],
264
+ ) -> torch.Tensor:
265
+
266
+ assert "params_list" in kwargs, "params list must be provided."
267
+ assert "np_fit_func" in kwargs, "np_fit_func must be provided."
268
+ assert "p0" in kwargs, "p0 must be provided."
269
+ np_fit_func = kwargs.get('np_fit_func')
270
+ params_list = kwargs.get('params_list')
271
+ p0 = kwargs.get('p0')
272
+
273
+ def _fit(xdata: np.ndarray, ydata: np.ndarray, func: Callable, p0: List[float]):
274
+ popt, _ = curve_fit(
275
+ func,
276
+ xdata,
277
+ ydata,
278
+ maxfev=1000,
279
+ p0=p0,
280
+ method='lm'
281
+ )
282
+ return popt
283
+
284
+ # 1. Needs to convert the torch tensor to numpy tensor
285
+ xdata = x.cpu().numpy()
286
+
287
+ # 2. Sorts the data so that it makes it easier to fit to it
288
+ sorted_xdata = np.sort(xdata, axis=-1)
289
+
290
+ p0 = {k: v.cpu().numpy() for k, v in p0.items()}
291
+ params_list = sorted(params_list) # We need to make sure that it matches the numpy fit func arg order
292
+
293
+ # 3. Finds the best parameters for each channel
294
+ try:
295
+ params = []
296
+ for i in range(sorted_xdata.shape[0]):
297
+ xdata_ = sorted_xdata[i]
298
+ p0_ = [p0[p][i] for p in params_list]
299
+ ch_params = _fit(xdata_, xdata_, np_fit_func, p0_)
300
+ params.append(ch_params)
301
+
302
+ # 4. Builds the parameters
303
+ result = {}
304
+ for i, p in enumerate(params_list):
305
+ result[p] = torch.tensor([p_[i] for p_ in params], dtype=torch.float32).to(x.device)
306
+
307
+ return result
308
+
309
+ except ValueError as e:
310
+ print(f"Could not fit the function with error: {e}")
311
+ print(f"Using fallback result...")
312
+ return {
313
+ k: torch.tensor(v, dtype=torch.float32).to(x.device) for k, v in p0.items()
314
+ }
315
+
316
+
317
+ def init_zeros(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor:
318
+ val = torch.amin(x, dim=1)
319
+ return torch.zeros_like(val, dtype=torch.float32, device=x.device)
320
+
321
+
322
+ def init_inner_scale(tensor: torch.Tensor, _min: float = torch.inf, _max: float = torch.inf) -> torch.Tensor:
323
+ # Calculate the original minimum and maximum values
324
+ min_vals, max_vals = torch.aminmax(tensor, dim=-1)
325
+ x_min = torch.min(min_vals, torch.zeros_like(min_vals))
326
+ x_max = torch.max(max_vals, torch.zeros_like(max_vals))
327
+
328
+ if _max is torch.inf: # We do not need to scale the tensor. Just need to move it
329
+ return torch.ones_like(x_min)
330
+
331
+ # Calculate the scale factor
332
+ scale = (_max - _min) / (x_max - x_min)
333
+ return scale
334
+
335
+
336
+
337
+ ############## Quant ###############
338
+
339
+ @torch.enable_grad()
340
+ def learn_parameters(
341
+ x: torch.Tensor,
342
+ params: Dict[str, nn.Parameter],
343
+ qtz_func: nn.Module,
344
+ deqtz_func: nn.Module,
345
+ bits: int,
346
+ target_dtype: torch.dtype,
347
+ epochs: int = 1000,
348
+ early_stop: bool = True,
349
+ do_report: bool = False
350
+ ) -> Tuple[Dict[str, nn.Parameter], torch.Tensor]:
351
+ loss_fn = nn.MSELoss()
352
+
353
+ # Determines the initial learning rate by computing the initial loss and multiplying it by
354
+ # the order of magnitude of the loss divided by 2
355
+ quant = quantize(x, params, qtz_func, bits, target_dtype)
356
+ dequant = dequantize(quant, params, deqtz_func, bits, x.dtype)
357
+ loss = loss_fn(x, dequant)
358
+
359
+ base_lr = 0.1
360
+ exponent = int(np.floor(np.log10(loss.item())))
361
+ lr = base_lr * (10 ** (exponent // 2))
362
+
363
+ # Requires gradients in the parameters
364
+ for p in params.values():
365
+ p.requires_grad = True
366
+ p.grad = None
367
+
368
+ param_keys = list(params.keys())
369
+ param_values = list(params.values())
370
+
371
+ # Defines optimizer and loss function
372
+ optimizer = torch.optim.Adam(param_values, lr=lr)
373
+ scheduler = torch.optim.lr_scheduler.LinearLR(optimizer, start_factor=1.0, end_factor=0.01, total_iters=epochs // 10)
374
+
375
+ # Contains the best loss and the best parameters
376
+ best_loss = float("inf")
377
+ best_params = None
378
+
379
+ # Used to stop the search early
380
+ min_delta = 1e-7
381
+ acc_loss = []
382
+ percent_epochs_before_stop = 0.1
383
+
384
+ for i in range(epochs):
385
+ optimizer.zero_grad()
386
+
387
+ quant = quantize(x, params, qtz_func, bits, target_dtype)
388
+ dequant = dequantize(quant, params, deqtz_func, bits, x.dtype)
389
+ loss = loss_fn(x, dequant)
390
+
391
+ if loss.isnan() or loss.isinf():
392
+ raise Exception("Loss is NaN or Inf. Stopping the search.")
393
+
394
+ loss.backward()
395
+ optimizer.step()
396
+ scheduler.step()
397
+
398
+ acc_loss.append(loss.item())
399
+
400
+ # Reports loss every 10 steps
401
+ if i % 10 == 0 and do_report:
402
+ print(f"Epoch {i}: Loss {loss.item()}")
403
+
404
+ # Optimizes the parameter search by storing the best loss and the parameters
405
+ if loss.item() < best_loss:
406
+ best_loss = loss.item()
407
+ best_params = copy.deepcopy({
408
+ k: v for k, v in params.items() if k in param_keys
409
+ })
410
+
411
+ # We also stop the search if the loss has not considerably during the last 10% epochs
412
+ if early_stop:
413
+ epochs_before_stop = int(epochs * percent_epochs_before_stop)
414
+ if i > epochs_before_stop and abs(acc_loss[i - epochs_before_stop] - acc_loss[i]) < min_delta:
415
+ break
416
+
417
+ # No longer requires gradients in the parameters
418
+ for p in best_params.values():
419
+ p.requires_grad = False
420
+ p.grad = None
421
+
422
+ if do_report:
423
+ return best_params, acc_loss
424
+ else:
425
+ return best_params
426
+
427
+
428
+ def quantize(
429
+ x: torch.Tensor,
430
+ params: Dict[str, nn.Parameter],
431
+ func: nn.Module,
432
+ bits: int,
433
+ target_dtype: torch.dtype = torch.int8
434
+ ) -> torch.Tensor:
435
+ quant_min, quant_max = get_min_max_from_bits_signed(bits)
436
+ x = x.transpose(0, 1) # Aligns shapes
437
+ x = func(x=x, **params)
438
+ x = x.transpose(0, 1)
439
+ x = torch.clamp(round_func_BPDA(x), quant_min, quant_max).to(target_dtype)
440
+ return x
441
+
442
+
443
+ def dequantize(
444
+ x: torch.Tensor,
445
+ params: Dict[str, nn.Parameter],
446
+ func: nn.Module,
447
+ bits: int,
448
+ out_dtype: torch.dtype
449
+ ) -> torch.Tensor:
450
+ x = x.to(dtype=out_dtype)
451
+ x = x.transpose(0, 1)
452
+ x = func(x=x, **params)
453
+ x = x.transpose(0, 1)
454
+ return x
455
+
456
+
457
+ def round_func_BPDA(input):
458
+ # This is equivalent to replacing round function (non-differentiable) with
459
+ # an identity function (differentiable) only when backward.
460
+ forward_value = torch.round(input)
461
+ out = input.clone()
462
+ out.data = forward_value.data
463
+ return out
464
+
465
+
466
+ def get_min_max_from_bits_signed(bit_width: int) -> Tuple[int, int]:
467
+ return -2 ** (bit_width - 1), 2 ** (bit_width - 1) - 1
468
+
469
+
470
+
471
+ ############## Numpy ###############
472
+
473
+ def np_domain_guard(
474
+ x: np.ndarray,
475
+ min: float = None,
476
+ max: float = None,
477
+ posinf: float = None,
478
+ neginf: float = None,
479
+ nan: float = None
480
+ ) -> np.ndarray:
481
+ """Guard a tensor to a valid domain."""
482
+ x = np.nan_to_num(x, posinf=posinf, neginf=neginf, nan=nan)
483
+ if min is not None or max is not None:
484
+ x = np.clip(x, min, max)
485
+ return x
486
+
487
+
488
+ def np_replace_num(x: np.ndarray, num: float, to: float) -> np.ndarray:
489
+ """Replace a number in a tensor with another number.
490
+
491
+ Args:
492
+ x (np.ndarray): The input tensor.
493
+ num (float): The number to replace.
494
+ to (float): The number to replace with.
495
+
496
+ Returns:
497
+ np.ndarray: The tensor with the number replaced.
498
+ """
499
+ return np.where(x == num, to, x)
500
+
501
+
502
+ def np_guarded_power(x: np.ndarray, exp: float) -> np.ndarray:
503
+ """Guard the power operation to a valid domain."""
504
+ return np.power(x, exp) if exp >= 1 else np.power(np.maximum(x, 0), exp)
505
+
fn_gen/ones_t/13/loss.png ADDED
fn_gen/ones_t/13/quantization.png ADDED
fn_gen/ones_t/14/distortion.png ADDED
fn_gen/ones_t/14/expressions.txt ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ atanh(_0*x)/_s
2
+ tanh(_s*x)/_0
fn_gen/ones_t/14/fn.py ADDED
@@ -0,0 +1,506 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import torch
4
+ from torch import amin # Necessary for arcsin
5
+ import copy
6
+ import torch.nn as nn
7
+ import numpy as np
8
+
9
+ from scipy.optimize import curve_fit
10
+ from typing import Dict, Any, Tuple, List, Callable
11
+
12
+
13
+ def quantization(x, **params):
14
+ return (torch.div(1, replace_num(params['_s'], num=0, to=10000)) * torch.atanh(domain_guard((params['_0'] * x), min=-0.9999, max=0.9999, nan=0)))
15
+
16
+
17
+ def dequantization(x, **params):
18
+ return (torch.div(1, replace_num(params['_0'], num=0, to=10000)) * torch.tanh((params['_s'] * x)))
19
+
20
+
21
+ def init_linear_scale( # Symmetric scale. From the study folder
22
+ x: torch.Tensor,
23
+ **kwargs: Dict[str, Any],
24
+ ) -> torch.Tensor:
25
+ assert "bits" in kwargs, "bits must be provided."
26
+ assert "params" in kwargs, "params must be provided."
27
+ assert "qtz_func" in kwargs, "qtz_func must be provided."
28
+
29
+ bits = kwargs.get('bits')
30
+ params = kwargs.get('params')
31
+ qtz_func = kwargs.get('qtz_func')
32
+
33
+ x_ = x.transpose(0, 1)
34
+ x_ = qtz_func(x=x_, **params, _s=init_ones(x, **kwargs))
35
+ x_ = x_.transpose(0, 1)
36
+
37
+ quant_min, quant_max = get_min_max_from_bits_signed(bits)
38
+ min_vals, max_vals = torch.aminmax(x_, dim=1)
39
+ min_vals = torch.min(min_vals, torch.zeros_like(min_vals))
40
+ max_vals = torch.max(max_vals, torch.zeros_like(max_vals))
41
+
42
+ eps = torch.finfo(torch.float32).eps
43
+
44
+ abs_max_val_per_ch = torch.max(-min_vals, max_vals)
45
+ scale = abs_max_val_per_ch / (float(quant_max - quant_min) / 2)
46
+
47
+ scale = torch.clamp(scale, min=eps).to(dtype=torch.float32, device=min_vals.device)
48
+
49
+ # Introduces some noise in scale
50
+ # If I don't introduce noise, the accuracy is going to be 0.0 and not learn anything
51
+ # scale = scale + 0.01 * torch.randn_like(scale)
52
+ return scale
53
+
54
+
55
+ def init_params(x: torch.Tensor, **kwargs: Dict[str, Any]) -> Dict[str, nn.Parameter]:
56
+ params = {
57
+ '_0': init_ones(x, qtz_func=quantization, deqtz_func=dequantization, param='_0', params_list=['_0', '_s'], **kwargs),
58
+ }
59
+ params['_s'] = init_linear_scale(x, params=params, qtz_func=quantization, **kwargs)
60
+ params = {k: nn.Parameter(v, requires_grad=False) for k, v in params.items()}
61
+
62
+ if 'post_init_hook' in kwargs:
63
+ kwargs['post_init_hook'](parameters=params)
64
+
65
+ params = learn_parameters(x, params,
66
+ qtz_func=quantization,
67
+ deqtz_func=dequantization,
68
+ bits=kwargs['bits'],
69
+ target_dtype=torch.int8,
70
+ epochs=500,
71
+ early_stop=False,
72
+ )
73
+ if 'post_train_hook' in kwargs:
74
+ kwargs['post_train_hook'](parameters=params)
75
+
76
+ return params
77
+
78
+
79
+ ############### Numpy Qtz ###############
80
+
81
+
82
+ def np_quantization(x, _0, _s):
83
+ return (np.divide(1, np_replace_num(_s, num=0, to=10000)) * np.arctanh(np_domain_guard((_0 * x), min=-0.9999, max=0.9999, nan=0)))
84
+
85
+
86
+ def np_dequantization(x, _0, _s):
87
+ return (np.divide(1, np_replace_num(_0, num=0, to=10000)) * np.tanh((_s * x)))
88
+
89
+
90
+ def fit_func(x, _0, _s):
91
+ x_ = np_quantization(x, _0, _s)
92
+ x_ = np_dequantization(x_, _0, _s)
93
+ return x_
94
+
95
+
96
+
97
+ ############### HELPERS ###############
98
+
99
+ def domain_guard(
100
+ x: torch.Tensor,
101
+ min: float = None,
102
+ max: float = None,
103
+ posinf: float = None,
104
+ neginf: float = None,
105
+ nan: float = None
106
+ ) -> torch.Tensor:
107
+ """Guard a tensor to a valid domain."""
108
+ x = torch.nan_to_num(x, posinf=posinf, neginf=neginf, nan=nan)
109
+ if min is not None or max is not None:
110
+ x = torch.clamp(x, min=min, max=max)
111
+ return x
112
+
113
+
114
+ def replace_num(x: torch.Tensor, num: float, to: float) -> torch.Tensor:
115
+ """Replace a number in a tensor with another number.
116
+
117
+ Args:
118
+ x (torch.Tensor): The input tensor.
119
+ num (float): The number to replace.
120
+ to (float): The number to replace with.
121
+
122
+ Returns:
123
+ torch.Tensor: The tensor with the number replaced.
124
+ """
125
+ return torch.where(x == num, to, x)
126
+
127
+
128
+ def guarded_torch_power(x: torch.Tensor, exp: float) -> torch.Tensor:
129
+ """Guard the power operation to a valid domain."""
130
+ return torch.pow(x, exp) if exp >= 1 else torch.pow(torch.relu(x), exp)
131
+
132
+
133
+ def init_ones(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor:
134
+ val = torch.amin(x, dim=1)
135
+ return torch.ones_like(val, dtype=torch.float32, device=x.device)
136
+
137
+
138
+ def init_rand(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor:
139
+ val = torch.amin(x, dim=1)
140
+ return torch.randn_like(val, dtype=torch.float32, device=x.device)
141
+
142
+
143
+ def init_space_search(
144
+ x: torch.Tensor,
145
+ **kwargs: Dict[str, Any],
146
+ ) -> torch.Tensor:
147
+
148
+ def _build_initial_param(tensor: torch.Tensor, max_initial: int, n_params: int):
149
+ """Generates the initial set of parameters. The first iteration generates 10 times more parameters."""
150
+ for _ in range(n_params * 10): # The first iteration generates 10 times more parameters
151
+ yield init_rand(tensor) * max_initial # Generates n_params in range [-max_initial, max_initial]
152
+
153
+ def _search_param(tensors: List[torch.tensor], n_params):
154
+ """Takes the best parameters and generates new parameters around the mean of the best parameters."""
155
+ torch_tensors = torch.stack(tensors)
156
+ min_vals, max_vals = torch.aminmax(torch_tensors, dim=0)
157
+ abs_max_val_per_ch = torch.max(-min_vals, max_vals)
158
+ mean = torch.mean(torch_tensors, dim=0)
159
+ for _ in range(n_params): # Generates n_params around the mean of the tensors
160
+ yield torch.randn_like(min_vals) * abs_max_val_per_ch + mean
161
+
162
+ def _calc(x, qtz_func, deqtz_func, **params):
163
+ x_ = x.transpose(0, 1)
164
+ x_ = qtz_func(x=x_, **params)
165
+ x_ = deqtz_func(x=x_, **params)
166
+ x_ = x_.transpose(0, 1)
167
+ return x_
168
+
169
+ assert "qtz_func" in kwargs, "qtz_func must be provided."
170
+ assert "deqtz_func" in kwargs, "deqtz_func must be provided."
171
+ assert "params_list" in kwargs, "params list must be provided."
172
+ assert "param" in kwargs, "param must be provided."
173
+
174
+ qtz_func = kwargs.get('qtz_func')
175
+ deqtz_func = kwargs.get('deqtz_func')
176
+ params_list = kwargs.get('params_list')
177
+ param = kwargs.get('param')
178
+
179
+ n_runs = 50 # Number of runs to try to find the best parameters
180
+ n_random_params = 50 # Number of random parameters to generate
181
+ n_best_to_pick = 5 # Number of best parameters to pick after each run
182
+ max_initial = 10000 # Maximum value to initialize the parameters
183
+
184
+ # Initializes the parameters
185
+ base_params = { p: init_ones(x, **kwargs) for p in params_list if p != param }
186
+ params = _build_initial_param(x, max_initial, n_random_params)
187
+
188
+ # Performs the search
189
+ for _ in range(n_runs):
190
+
191
+ best_params = []
192
+ for param_ in params:
193
+ try:
194
+ x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: param_})
195
+ loss_ones = nn.MSELoss()(x, x_)
196
+
197
+ if len(best_params) < n_best_to_pick:
198
+ best_params.append((param_, loss_ones.item()))
199
+ best_params = sorted(best_params, key=lambda x: x[1])
200
+ elif loss_ones < best_params[-1][1]:
201
+ best_params[-1] = (param_, loss_ones.item())
202
+ best_params = sorted(best_params, key=lambda x: x[1])
203
+
204
+ except Exception: # The parameters might not be valid for the function's domain
205
+ continue
206
+
207
+ # Generates new parameters around the mean
208
+ params = _search_param([p for p, _ in best_params], n_random_params)
209
+
210
+ # Checks if the best parameter is better than the init_ones
211
+ p_ones = init_ones(x, **kwargs)
212
+ x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: p_ones})
213
+ loss_ones = nn.MSELoss()(x, x_)
214
+
215
+ # Checks if the best parameter is better than the init_rand
216
+ p_rand = init_rand(x, **kwargs)
217
+ x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: p_rand})
218
+ loss_rand = nn.MSELoss()(x, x_)
219
+
220
+ if loss_rand < best_params[0][1] and loss_rand < loss_ones:
221
+ return p_rand
222
+ elif loss_ones < best_params[0][1] and loss_ones < loss_rand:
223
+ return p_ones
224
+ else:
225
+ return best_params[0][0]
226
+
227
+
228
+ def init_linear_scale( # Symmetric scale. From the study folder
229
+ x: torch.Tensor,
230
+ **kwargs: Dict[str, Any],
231
+ ) -> torch.Tensor:
232
+ assert "bits" in kwargs, "bits must be provided."
233
+ assert "params" in kwargs, "params must be provided."
234
+ assert "qtz_func" in kwargs, "qtz_func must be provided."
235
+
236
+ bits = kwargs.get('bits')
237
+ params = kwargs.get('params')
238
+ qtz_func = kwargs.get('qtz_func')
239
+
240
+ x_ = x.transpose(0, 1)
241
+ x_ = qtz_func(x=x_, **params, _s=init_ones(x, **kwargs))
242
+ x_ = x_.transpose(0, 1)
243
+
244
+ quant_min, quant_max = get_min_max_from_bits_signed(bits)
245
+ min_vals, max_vals = torch.aminmax(x_, dim=1)
246
+ min_vals = torch.min(min_vals, torch.zeros_like(min_vals))
247
+ max_vals = torch.max(max_vals, torch.zeros_like(max_vals))
248
+
249
+ eps = torch.finfo(torch.float32).eps
250
+
251
+ abs_max_val_per_ch = torch.max(-min_vals, max_vals)
252
+ scale = abs_max_val_per_ch / (float(quant_max - quant_min) / 2)
253
+
254
+ scale = torch.clamp(scale, min=eps).to(dtype=torch.float32, device=min_vals.device)
255
+
256
+ # Introduces some noise in scale
257
+ # If I don't introduce noise, the accuracy is going to be 0.0 and not learn anything
258
+ # scale = scale + 0.01 * torch.randn_like(scale)
259
+ return scale
260
+
261
+
262
+ def init_non_linear_regression_fit(
263
+ x: torch.Tensor,
264
+ **kwargs: Dict[str, Any],
265
+ ) -> torch.Tensor:
266
+
267
+ assert "params_list" in kwargs, "params list must be provided."
268
+ assert "np_fit_func" in kwargs, "np_fit_func must be provided."
269
+ assert "p0" in kwargs, "p0 must be provided."
270
+ np_fit_func = kwargs.get('np_fit_func')
271
+ params_list = kwargs.get('params_list')
272
+ p0 = kwargs.get('p0')
273
+
274
+ def _fit(xdata: np.ndarray, ydata: np.ndarray, func: Callable, p0: List[float]):
275
+ popt, _ = curve_fit(
276
+ func,
277
+ xdata,
278
+ ydata,
279
+ maxfev=1000,
280
+ p0=p0,
281
+ method='lm'
282
+ )
283
+ return popt
284
+
285
+ # 1. Needs to convert the torch tensor to numpy tensor
286
+ xdata = x.cpu().numpy()
287
+
288
+ # 2. Sorts the data so that it makes it easier to fit to it
289
+ sorted_xdata = np.sort(xdata, axis=-1)
290
+
291
+ p0 = {k: v.cpu().numpy() for k, v in p0.items()}
292
+ params_list = sorted(params_list) # We need to make sure that it matches the numpy fit func arg order
293
+
294
+ # 3. Finds the best parameters for each channel
295
+ try:
296
+ params = []
297
+ for i in range(sorted_xdata.shape[0]):
298
+ xdata_ = sorted_xdata[i]
299
+ p0_ = [p0[p][i] for p in params_list]
300
+ ch_params = _fit(xdata_, xdata_, np_fit_func, p0_)
301
+ params.append(ch_params)
302
+
303
+ # 4. Builds the parameters
304
+ result = {}
305
+ for i, p in enumerate(params_list):
306
+ result[p] = torch.tensor([p_[i] for p_ in params], dtype=torch.float32).to(x.device)
307
+
308
+ return result
309
+
310
+ except ValueError as e:
311
+ print(f"Could not fit the function with error: {e}")
312
+ print(f"Using fallback result...")
313
+ return {
314
+ k: torch.tensor(v, dtype=torch.float32).to(x.device) for k, v in p0.items()
315
+ }
316
+
317
+
318
+ def init_zeros(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor:
319
+ val = torch.amin(x, dim=1)
320
+ return torch.zeros_like(val, dtype=torch.float32, device=x.device)
321
+
322
+
323
+ def init_inner_scale(tensor: torch.Tensor, _min: float = torch.inf, _max: float = torch.inf) -> torch.Tensor:
324
+ # Calculate the original minimum and maximum values
325
+ min_vals, max_vals = torch.aminmax(tensor, dim=-1)
326
+ x_min = torch.min(min_vals, torch.zeros_like(min_vals))
327
+ x_max = torch.max(max_vals, torch.zeros_like(max_vals))
328
+
329
+ if _max is torch.inf: # We do not need to scale the tensor. Just need to move it
330
+ return torch.ones_like(x_min)
331
+
332
+ # Calculate the scale factor
333
+ scale = (_max - _min) / (x_max - x_min)
334
+ return scale
335
+
336
+
337
+
338
+ ############## Quant ###############
339
+
340
+ @torch.enable_grad()
341
+ def learn_parameters(
342
+ x: torch.Tensor,
343
+ params: Dict[str, nn.Parameter],
344
+ qtz_func: nn.Module,
345
+ deqtz_func: nn.Module,
346
+ bits: int,
347
+ target_dtype: torch.dtype,
348
+ epochs: int = 1000,
349
+ early_stop: bool = True,
350
+ do_report: bool = False
351
+ ) -> Tuple[Dict[str, nn.Parameter], torch.Tensor]:
352
+ loss_fn = nn.MSELoss()
353
+
354
+ # Determines the initial learning rate by computing the initial loss and multiplying it by
355
+ # the order of magnitude of the loss divided by 2
356
+ quant = quantize(x, params, qtz_func, bits, target_dtype)
357
+ dequant = dequantize(quant, params, deqtz_func, bits, x.dtype)
358
+ loss = loss_fn(x, dequant)
359
+
360
+ base_lr = 0.1
361
+ exponent = int(np.floor(np.log10(loss.item())))
362
+ lr = base_lr * (10 ** (exponent // 2))
363
+
364
+ # Requires gradients in the parameters
365
+ for p in params.values():
366
+ p.requires_grad = True
367
+ p.grad = None
368
+
369
+ param_keys = list(params.keys())
370
+ param_values = list(params.values())
371
+
372
+ # Defines optimizer and loss function
373
+ optimizer = torch.optim.Adam(param_values, lr=lr)
374
+ scheduler = torch.optim.lr_scheduler.LinearLR(optimizer, start_factor=1.0, end_factor=0.01, total_iters=epochs // 10)
375
+
376
+ # Contains the best loss and the best parameters
377
+ best_loss = float("inf")
378
+ best_params = None
379
+
380
+ # Used to stop the search early
381
+ min_delta = 1e-7
382
+ acc_loss = []
383
+ percent_epochs_before_stop = 0.1
384
+
385
+ for i in range(epochs):
386
+ optimizer.zero_grad()
387
+
388
+ quant = quantize(x, params, qtz_func, bits, target_dtype)
389
+ dequant = dequantize(quant, params, deqtz_func, bits, x.dtype)
390
+ loss = loss_fn(x, dequant)
391
+
392
+ if loss.isnan() or loss.isinf():
393
+ raise Exception("Loss is NaN or Inf. Stopping the search.")
394
+
395
+ loss.backward()
396
+ optimizer.step()
397
+ scheduler.step()
398
+
399
+ acc_loss.append(loss.item())
400
+
401
+ # Reports loss every 10 steps
402
+ if i % 10 == 0 and do_report:
403
+ print(f"Epoch {i}: Loss {loss.item()}")
404
+
405
+ # Optimizes the parameter search by storing the best loss and the parameters
406
+ if loss.item() < best_loss:
407
+ best_loss = loss.item()
408
+ best_params = copy.deepcopy({
409
+ k: v for k, v in params.items() if k in param_keys
410
+ })
411
+
412
+ # We also stop the search if the loss has not considerably during the last 10% epochs
413
+ if early_stop:
414
+ epochs_before_stop = int(epochs * percent_epochs_before_stop)
415
+ if i > epochs_before_stop and abs(acc_loss[i - epochs_before_stop] - acc_loss[i]) < min_delta:
416
+ break
417
+
418
+ # No longer requires gradients in the parameters
419
+ for p in best_params.values():
420
+ p.requires_grad = False
421
+ p.grad = None
422
+
423
+ if do_report:
424
+ return best_params, acc_loss
425
+ else:
426
+ return best_params
427
+
428
+
429
+ def quantize(
430
+ x: torch.Tensor,
431
+ params: Dict[str, nn.Parameter],
432
+ func: nn.Module,
433
+ bits: int,
434
+ target_dtype: torch.dtype = torch.int8
435
+ ) -> torch.Tensor:
436
+ quant_min, quant_max = get_min_max_from_bits_signed(bits)
437
+ x = x.transpose(0, 1) # Aligns shapes
438
+ x = func(x=x, **params)
439
+ x = x.transpose(0, 1)
440
+ x = torch.clamp(round_func_BPDA(x), quant_min, quant_max).to(target_dtype)
441
+ return x
442
+
443
+
444
+ def dequantize(
445
+ x: torch.Tensor,
446
+ params: Dict[str, nn.Parameter],
447
+ func: nn.Module,
448
+ bits: int,
449
+ out_dtype: torch.dtype
450
+ ) -> torch.Tensor:
451
+ x = x.to(dtype=out_dtype)
452
+ x = x.transpose(0, 1)
453
+ x = func(x=x, **params)
454
+ x = x.transpose(0, 1)
455
+ return x
456
+
457
+
458
+ def round_func_BPDA(input):
459
+ # This is equivalent to replacing round function (non-differentiable) with
460
+ # an identity function (differentiable) only when backward.
461
+ forward_value = torch.round(input)
462
+ out = input.clone()
463
+ out.data = forward_value.data
464
+ return out
465
+
466
+
467
+ def get_min_max_from_bits_signed(bit_width: int) -> Tuple[int, int]:
468
+ return -2 ** (bit_width - 1), 2 ** (bit_width - 1) - 1
469
+
470
+
471
+
472
+ ############## Numpy ###############
473
+
474
+ def np_domain_guard(
475
+ x: np.ndarray,
476
+ min: float = None,
477
+ max: float = None,
478
+ posinf: float = None,
479
+ neginf: float = None,
480
+ nan: float = None
481
+ ) -> np.ndarray:
482
+ """Guard a tensor to a valid domain."""
483
+ x = np.nan_to_num(x, posinf=posinf, neginf=neginf, nan=nan)
484
+ if min is not None or max is not None:
485
+ x = np.clip(x, min, max)
486
+ return x
487
+
488
+
489
+ def np_replace_num(x: np.ndarray, num: float, to: float) -> np.ndarray:
490
+ """Replace a number in a tensor with another number.
491
+
492
+ Args:
493
+ x (np.ndarray): The input tensor.
494
+ num (float): The number to replace.
495
+ to (float): The number to replace with.
496
+
497
+ Returns:
498
+ np.ndarray: The tensor with the number replaced.
499
+ """
500
+ return np.where(x == num, to, x)
501
+
502
+
503
+ def np_guarded_power(x: np.ndarray, exp: float) -> np.ndarray:
504
+ """Guard the power operation to a valid domain."""
505
+ return np.power(x, exp) if exp >= 1 else np.power(np.maximum(x, 0), exp)
506
+
fn_gen/ones_t/14/loss.png ADDED
fn_gen/ones_t/14/quantization.png ADDED
fn_gen/ones_t/15/distortion.png ADDED
fn_gen/ones_t/15/expressions.txt ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ x**3/_s
2
+ (_s*x)**(1/3)
fn_gen/ones_t/15/fn.py ADDED
@@ -0,0 +1,505 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import torch
4
+ from torch import amin # Necessary for arcsin
5
+ import copy
6
+ import torch.nn as nn
7
+ import numpy as np
8
+
9
+ from scipy.optimize import curve_fit
10
+ from typing import Dict, Any, Tuple, List, Callable
11
+
12
+
13
+ def quantization(x, **params):
14
+ return (torch.div(1, replace_num(params['_s'], num=0, to=10000)) * guarded_torch_power(x, torch.tensor(3)))
15
+
16
+
17
+ def dequantization(x, **params):
18
+ return guarded_torch_power((params['_s'] * x), 1 / 3)
19
+
20
+
21
+ def init_linear_scale( # Symmetric scale. From the study folder
22
+ x: torch.Tensor,
23
+ **kwargs: Dict[str, Any],
24
+ ) -> torch.Tensor:
25
+ assert "bits" in kwargs, "bits must be provided."
26
+ assert "params" in kwargs, "params must be provided."
27
+ assert "qtz_func" in kwargs, "qtz_func must be provided."
28
+
29
+ bits = kwargs.get('bits')
30
+ params = kwargs.get('params')
31
+ qtz_func = kwargs.get('qtz_func')
32
+
33
+ x_ = x.transpose(0, 1)
34
+ x_ = qtz_func(x=x_, **params, _s=init_ones(x, **kwargs))
35
+ x_ = x_.transpose(0, 1)
36
+
37
+ quant_min, quant_max = get_min_max_from_bits_signed(bits)
38
+ min_vals, max_vals = torch.aminmax(x_, dim=1)
39
+ min_vals = torch.min(min_vals, torch.zeros_like(min_vals))
40
+ max_vals = torch.max(max_vals, torch.zeros_like(max_vals))
41
+
42
+ eps = torch.finfo(torch.float32).eps
43
+
44
+ abs_max_val_per_ch = torch.max(-min_vals, max_vals)
45
+ scale = abs_max_val_per_ch / (float(quant_max - quant_min) / 2)
46
+
47
+ scale = torch.clamp(scale, min=eps).to(dtype=torch.float32, device=min_vals.device)
48
+
49
+ # Introduces some noise in scale
50
+ # If I don't introduce noise, the accuracy is going to be 0.0 and not learn anything
51
+ # scale = scale + 0.01 * torch.randn_like(scale)
52
+ return scale
53
+
54
+
55
+ def init_params(x: torch.Tensor, **kwargs: Dict[str, Any]) -> Dict[str, nn.Parameter]:
56
+ params = {
57
+ }
58
+ params['_s'] = init_linear_scale(x, params=params, qtz_func=quantization, **kwargs)
59
+ params = {k: nn.Parameter(v, requires_grad=False) for k, v in params.items()}
60
+
61
+ if 'post_init_hook' in kwargs:
62
+ kwargs['post_init_hook'](parameters=params)
63
+
64
+ params = learn_parameters(x, params,
65
+ qtz_func=quantization,
66
+ deqtz_func=dequantization,
67
+ bits=kwargs['bits'],
68
+ target_dtype=torch.int8,
69
+ epochs=500,
70
+ early_stop=False,
71
+ )
72
+ if 'post_train_hook' in kwargs:
73
+ kwargs['post_train_hook'](parameters=params)
74
+
75
+ return params
76
+
77
+
78
+ ############### Numpy Qtz ###############
79
+
80
+
81
+ def np_quantization(x, _s):
82
+ return (np.divide(1, np_replace_num(_s, num=0, to=10000)) * np_guarded_power(x, np.array(3)))
83
+
84
+
85
+ def np_dequantization(x, _s):
86
+ return np_guarded_power((_s * x), 1 / 3)
87
+
88
+
89
+ def fit_func(x, _s):
90
+ x_ = np_quantization(x, _s)
91
+ x_ = np_dequantization(x_, _s)
92
+ return x_
93
+
94
+
95
+
96
+ ############### HELPERS ###############
97
+
98
+ def domain_guard(
99
+ x: torch.Tensor,
100
+ min: float = None,
101
+ max: float = None,
102
+ posinf: float = None,
103
+ neginf: float = None,
104
+ nan: float = None
105
+ ) -> torch.Tensor:
106
+ """Guard a tensor to a valid domain."""
107
+ x = torch.nan_to_num(x, posinf=posinf, neginf=neginf, nan=nan)
108
+ if min is not None or max is not None:
109
+ x = torch.clamp(x, min=min, max=max)
110
+ return x
111
+
112
+
113
+ def replace_num(x: torch.Tensor, num: float, to: float) -> torch.Tensor:
114
+ """Replace a number in a tensor with another number.
115
+
116
+ Args:
117
+ x (torch.Tensor): The input tensor.
118
+ num (float): The number to replace.
119
+ to (float): The number to replace with.
120
+
121
+ Returns:
122
+ torch.Tensor: The tensor with the number replaced.
123
+ """
124
+ return torch.where(x == num, to, x)
125
+
126
+
127
+ def guarded_torch_power(x: torch.Tensor, exp: float) -> torch.Tensor:
128
+ """Guard the power operation to a valid domain."""
129
+ return torch.pow(x, exp) if exp >= 1 else torch.pow(torch.relu(x), exp)
130
+
131
+
132
+ def init_ones(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor:
133
+ val = torch.amin(x, dim=1)
134
+ return torch.ones_like(val, dtype=torch.float32, device=x.device)
135
+
136
+
137
+ def init_rand(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor:
138
+ val = torch.amin(x, dim=1)
139
+ return torch.randn_like(val, dtype=torch.float32, device=x.device)
140
+
141
+
142
+ def init_space_search(
143
+ x: torch.Tensor,
144
+ **kwargs: Dict[str, Any],
145
+ ) -> torch.Tensor:
146
+
147
+ def _build_initial_param(tensor: torch.Tensor, max_initial: int, n_params: int):
148
+ """Generates the initial set of parameters. The first iteration generates 10 times more parameters."""
149
+ for _ in range(n_params * 10): # The first iteration generates 10 times more parameters
150
+ yield init_rand(tensor) * max_initial # Generates n_params in range [-max_initial, max_initial]
151
+
152
+ def _search_param(tensors: List[torch.tensor], n_params):
153
+ """Takes the best parameters and generates new parameters around the mean of the best parameters."""
154
+ torch_tensors = torch.stack(tensors)
155
+ min_vals, max_vals = torch.aminmax(torch_tensors, dim=0)
156
+ abs_max_val_per_ch = torch.max(-min_vals, max_vals)
157
+ mean = torch.mean(torch_tensors, dim=0)
158
+ for _ in range(n_params): # Generates n_params around the mean of the tensors
159
+ yield torch.randn_like(min_vals) * abs_max_val_per_ch + mean
160
+
161
+ def _calc(x, qtz_func, deqtz_func, **params):
162
+ x_ = x.transpose(0, 1)
163
+ x_ = qtz_func(x=x_, **params)
164
+ x_ = deqtz_func(x=x_, **params)
165
+ x_ = x_.transpose(0, 1)
166
+ return x_
167
+
168
+ assert "qtz_func" in kwargs, "qtz_func must be provided."
169
+ assert "deqtz_func" in kwargs, "deqtz_func must be provided."
170
+ assert "params_list" in kwargs, "params list must be provided."
171
+ assert "param" in kwargs, "param must be provided."
172
+
173
+ qtz_func = kwargs.get('qtz_func')
174
+ deqtz_func = kwargs.get('deqtz_func')
175
+ params_list = kwargs.get('params_list')
176
+ param = kwargs.get('param')
177
+
178
+ n_runs = 50 # Number of runs to try to find the best parameters
179
+ n_random_params = 50 # Number of random parameters to generate
180
+ n_best_to_pick = 5 # Number of best parameters to pick after each run
181
+ max_initial = 10000 # Maximum value to initialize the parameters
182
+
183
+ # Initializes the parameters
184
+ base_params = { p: init_ones(x, **kwargs) for p in params_list if p != param }
185
+ params = _build_initial_param(x, max_initial, n_random_params)
186
+
187
+ # Performs the search
188
+ for _ in range(n_runs):
189
+
190
+ best_params = []
191
+ for param_ in params:
192
+ try:
193
+ x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: param_})
194
+ loss_ones = nn.MSELoss()(x, x_)
195
+
196
+ if len(best_params) < n_best_to_pick:
197
+ best_params.append((param_, loss_ones.item()))
198
+ best_params = sorted(best_params, key=lambda x: x[1])
199
+ elif loss_ones < best_params[-1][1]:
200
+ best_params[-1] = (param_, loss_ones.item())
201
+ best_params = sorted(best_params, key=lambda x: x[1])
202
+
203
+ except Exception: # The parameters might not be valid for the function's domain
204
+ continue
205
+
206
+ # Generates new parameters around the mean
207
+ params = _search_param([p for p, _ in best_params], n_random_params)
208
+
209
+ # Checks if the best parameter is better than the init_ones
210
+ p_ones = init_ones(x, **kwargs)
211
+ x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: p_ones})
212
+ loss_ones = nn.MSELoss()(x, x_)
213
+
214
+ # Checks if the best parameter is better than the init_rand
215
+ p_rand = init_rand(x, **kwargs)
216
+ x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: p_rand})
217
+ loss_rand = nn.MSELoss()(x, x_)
218
+
219
+ if loss_rand < best_params[0][1] and loss_rand < loss_ones:
220
+ return p_rand
221
+ elif loss_ones < best_params[0][1] and loss_ones < loss_rand:
222
+ return p_ones
223
+ else:
224
+ return best_params[0][0]
225
+
226
+
227
+ def init_linear_scale( # Symmetric scale. From the study folder
228
+ x: torch.Tensor,
229
+ **kwargs: Dict[str, Any],
230
+ ) -> torch.Tensor:
231
+ assert "bits" in kwargs, "bits must be provided."
232
+ assert "params" in kwargs, "params must be provided."
233
+ assert "qtz_func" in kwargs, "qtz_func must be provided."
234
+
235
+ bits = kwargs.get('bits')
236
+ params = kwargs.get('params')
237
+ qtz_func = kwargs.get('qtz_func')
238
+
239
+ x_ = x.transpose(0, 1)
240
+ x_ = qtz_func(x=x_, **params, _s=init_ones(x, **kwargs))
241
+ x_ = x_.transpose(0, 1)
242
+
243
+ quant_min, quant_max = get_min_max_from_bits_signed(bits)
244
+ min_vals, max_vals = torch.aminmax(x_, dim=1)
245
+ min_vals = torch.min(min_vals, torch.zeros_like(min_vals))
246
+ max_vals = torch.max(max_vals, torch.zeros_like(max_vals))
247
+
248
+ eps = torch.finfo(torch.float32).eps
249
+
250
+ abs_max_val_per_ch = torch.max(-min_vals, max_vals)
251
+ scale = abs_max_val_per_ch / (float(quant_max - quant_min) / 2)
252
+
253
+ scale = torch.clamp(scale, min=eps).to(dtype=torch.float32, device=min_vals.device)
254
+
255
+ # Introduces some noise in scale
256
+ # If I don't introduce noise, the accuracy is going to be 0.0 and not learn anything
257
+ # scale = scale + 0.01 * torch.randn_like(scale)
258
+ return scale
259
+
260
+
261
+ def init_non_linear_regression_fit(
262
+ x: torch.Tensor,
263
+ **kwargs: Dict[str, Any],
264
+ ) -> torch.Tensor:
265
+
266
+ assert "params_list" in kwargs, "params list must be provided."
267
+ assert "np_fit_func" in kwargs, "np_fit_func must be provided."
268
+ assert "p0" in kwargs, "p0 must be provided."
269
+ np_fit_func = kwargs.get('np_fit_func')
270
+ params_list = kwargs.get('params_list')
271
+ p0 = kwargs.get('p0')
272
+
273
+ def _fit(xdata: np.ndarray, ydata: np.ndarray, func: Callable, p0: List[float]):
274
+ popt, _ = curve_fit(
275
+ func,
276
+ xdata,
277
+ ydata,
278
+ maxfev=1000,
279
+ p0=p0,
280
+ method='lm'
281
+ )
282
+ return popt
283
+
284
+ # 1. Needs to convert the torch tensor to numpy tensor
285
+ xdata = x.cpu().numpy()
286
+
287
+ # 2. Sorts the data so that it makes it easier to fit to it
288
+ sorted_xdata = np.sort(xdata, axis=-1)
289
+
290
+ p0 = {k: v.cpu().numpy() for k, v in p0.items()}
291
+ params_list = sorted(params_list) # We need to make sure that it matches the numpy fit func arg order
292
+
293
+ # 3. Finds the best parameters for each channel
294
+ try:
295
+ params = []
296
+ for i in range(sorted_xdata.shape[0]):
297
+ xdata_ = sorted_xdata[i]
298
+ p0_ = [p0[p][i] for p in params_list]
299
+ ch_params = _fit(xdata_, xdata_, np_fit_func, p0_)
300
+ params.append(ch_params)
301
+
302
+ # 4. Builds the parameters
303
+ result = {}
304
+ for i, p in enumerate(params_list):
305
+ result[p] = torch.tensor([p_[i] for p_ in params], dtype=torch.float32).to(x.device)
306
+
307
+ return result
308
+
309
+ except ValueError as e:
310
+ print(f"Could not fit the function with error: {e}")
311
+ print(f"Using fallback result...")
312
+ return {
313
+ k: torch.tensor(v, dtype=torch.float32).to(x.device) for k, v in p0.items()
314
+ }
315
+
316
+
317
+ def init_zeros(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor:
318
+ val = torch.amin(x, dim=1)
319
+ return torch.zeros_like(val, dtype=torch.float32, device=x.device)
320
+
321
+
322
+ def init_inner_scale(tensor: torch.Tensor, _min: float = torch.inf, _max: float = torch.inf) -> torch.Tensor:
323
+ # Calculate the original minimum and maximum values
324
+ min_vals, max_vals = torch.aminmax(tensor, dim=-1)
325
+ x_min = torch.min(min_vals, torch.zeros_like(min_vals))
326
+ x_max = torch.max(max_vals, torch.zeros_like(max_vals))
327
+
328
+ if _max is torch.inf: # We do not need to scale the tensor. Just need to move it
329
+ return torch.ones_like(x_min)
330
+
331
+ # Calculate the scale factor
332
+ scale = (_max - _min) / (x_max - x_min)
333
+ return scale
334
+
335
+
336
+
337
+ ############## Quant ###############
338
+
339
+ @torch.enable_grad()
340
+ def learn_parameters(
341
+ x: torch.Tensor,
342
+ params: Dict[str, nn.Parameter],
343
+ qtz_func: nn.Module,
344
+ deqtz_func: nn.Module,
345
+ bits: int,
346
+ target_dtype: torch.dtype,
347
+ epochs: int = 1000,
348
+ early_stop: bool = True,
349
+ do_report: bool = False
350
+ ) -> Tuple[Dict[str, nn.Parameter], torch.Tensor]:
351
+ loss_fn = nn.MSELoss()
352
+
353
+ # Determines the initial learning rate by computing the initial loss and multiplying it by
354
+ # the order of magnitude of the loss divided by 2
355
+ quant = quantize(x, params, qtz_func, bits, target_dtype)
356
+ dequant = dequantize(quant, params, deqtz_func, bits, x.dtype)
357
+ loss = loss_fn(x, dequant)
358
+
359
+ base_lr = 0.1
360
+ exponent = int(np.floor(np.log10(loss.item())))
361
+ lr = base_lr * (10 ** (exponent // 2))
362
+
363
+ # Requires gradients in the parameters
364
+ for p in params.values():
365
+ p.requires_grad = True
366
+ p.grad = None
367
+
368
+ param_keys = list(params.keys())
369
+ param_values = list(params.values())
370
+
371
+ # Defines optimizer and loss function
372
+ optimizer = torch.optim.Adam(param_values, lr=lr)
373
+ scheduler = torch.optim.lr_scheduler.LinearLR(optimizer, start_factor=1.0, end_factor=0.01, total_iters=epochs // 10)
374
+
375
+ # Contains the best loss and the best parameters
376
+ best_loss = float("inf")
377
+ best_params = None
378
+
379
+ # Used to stop the search early
380
+ min_delta = 1e-7
381
+ acc_loss = []
382
+ percent_epochs_before_stop = 0.1
383
+
384
+ for i in range(epochs):
385
+ optimizer.zero_grad()
386
+
387
+ quant = quantize(x, params, qtz_func, bits, target_dtype)
388
+ dequant = dequantize(quant, params, deqtz_func, bits, x.dtype)
389
+ loss = loss_fn(x, dequant)
390
+
391
+ if loss.isnan() or loss.isinf():
392
+ raise Exception("Loss is NaN or Inf. Stopping the search.")
393
+
394
+ loss.backward()
395
+ optimizer.step()
396
+ scheduler.step()
397
+
398
+ acc_loss.append(loss.item())
399
+
400
+ # Reports loss every 10 steps
401
+ if i % 10 == 0 and do_report:
402
+ print(f"Epoch {i}: Loss {loss.item()}")
403
+
404
+ # Optimizes the parameter search by storing the best loss and the parameters
405
+ if loss.item() < best_loss:
406
+ best_loss = loss.item()
407
+ best_params = copy.deepcopy({
408
+ k: v for k, v in params.items() if k in param_keys
409
+ })
410
+
411
+ # We also stop the search if the loss has not considerably during the last 10% epochs
412
+ if early_stop:
413
+ epochs_before_stop = int(epochs * percent_epochs_before_stop)
414
+ if i > epochs_before_stop and abs(acc_loss[i - epochs_before_stop] - acc_loss[i]) < min_delta:
415
+ break
416
+
417
+ # No longer requires gradients in the parameters
418
+ for p in best_params.values():
419
+ p.requires_grad = False
420
+ p.grad = None
421
+
422
+ if do_report:
423
+ return best_params, acc_loss
424
+ else:
425
+ return best_params
426
+
427
+
428
+ def quantize(
429
+ x: torch.Tensor,
430
+ params: Dict[str, nn.Parameter],
431
+ func: nn.Module,
432
+ bits: int,
433
+ target_dtype: torch.dtype = torch.int8
434
+ ) -> torch.Tensor:
435
+ quant_min, quant_max = get_min_max_from_bits_signed(bits)
436
+ x = x.transpose(0, 1) # Aligns shapes
437
+ x = func(x=x, **params)
438
+ x = x.transpose(0, 1)
439
+ x = torch.clamp(round_func_BPDA(x), quant_min, quant_max).to(target_dtype)
440
+ return x
441
+
442
+
443
+ def dequantize(
444
+ x: torch.Tensor,
445
+ params: Dict[str, nn.Parameter],
446
+ func: nn.Module,
447
+ bits: int,
448
+ out_dtype: torch.dtype
449
+ ) -> torch.Tensor:
450
+ x = x.to(dtype=out_dtype)
451
+ x = x.transpose(0, 1)
452
+ x = func(x=x, **params)
453
+ x = x.transpose(0, 1)
454
+ return x
455
+
456
+
457
+ def round_func_BPDA(input):
458
+ # This is equivalent to replacing round function (non-differentiable) with
459
+ # an identity function (differentiable) only when backward.
460
+ forward_value = torch.round(input)
461
+ out = input.clone()
462
+ out.data = forward_value.data
463
+ return out
464
+
465
+
466
+ def get_min_max_from_bits_signed(bit_width: int) -> Tuple[int, int]:
467
+ return -2 ** (bit_width - 1), 2 ** (bit_width - 1) - 1
468
+
469
+
470
+
471
+ ############## Numpy ###############
472
+
473
+ def np_domain_guard(
474
+ x: np.ndarray,
475
+ min: float = None,
476
+ max: float = None,
477
+ posinf: float = None,
478
+ neginf: float = None,
479
+ nan: float = None
480
+ ) -> np.ndarray:
481
+ """Guard a tensor to a valid domain."""
482
+ x = np.nan_to_num(x, posinf=posinf, neginf=neginf, nan=nan)
483
+ if min is not None or max is not None:
484
+ x = np.clip(x, min, max)
485
+ return x
486
+
487
+
488
+ def np_replace_num(x: np.ndarray, num: float, to: float) -> np.ndarray:
489
+ """Replace a number in a tensor with another number.
490
+
491
+ Args:
492
+ x (np.ndarray): The input tensor.
493
+ num (float): The number to replace.
494
+ to (float): The number to replace with.
495
+
496
+ Returns:
497
+ np.ndarray: The tensor with the number replaced.
498
+ """
499
+ return np.where(x == num, to, x)
500
+
501
+
502
+ def np_guarded_power(x: np.ndarray, exp: float) -> np.ndarray:
503
+ """Guard the power operation to a valid domain."""
504
+ return np.power(x, exp) if exp >= 1 else np.power(np.maximum(x, 0), exp)
505
+
fn_gen/ones_t/15/loss.png ADDED
fn_gen/ones_t/15/quantization.png ADDED
fn_gen/ones_t/16/distortion.png ADDED
fn_gen/ones_t/16/expressions.txt ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ tan(_0*x)/_s
2
+ atan(_s*x)/_0
fn_gen/ones_t/16/fn.py ADDED
@@ -0,0 +1,506 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import torch
4
+ from torch import amin # Necessary for arcsin
5
+ import copy
6
+ import torch.nn as nn
7
+ import numpy as np
8
+
9
+ from scipy.optimize import curve_fit
10
+ from typing import Dict, Any, Tuple, List, Callable
11
+
12
+
13
+ def quantization(x, **params):
14
+ return (torch.div(1, replace_num(params['_s'], num=0, to=10000)) * torch.tan(domain_guard((params['_0'] * x), posinf=1, neginf=-1, nan=0)))
15
+
16
+
17
+ def dequantization(x, **params):
18
+ return (torch.div(1, replace_num(params['_0'], num=0, to=10000)) * torch.atan((params['_s'] * x)))
19
+
20
+
21
+ def init_linear_scale( # Symmetric scale. From the study folder
22
+ x: torch.Tensor,
23
+ **kwargs: Dict[str, Any],
24
+ ) -> torch.Tensor:
25
+ assert "bits" in kwargs, "bits must be provided."
26
+ assert "params" in kwargs, "params must be provided."
27
+ assert "qtz_func" in kwargs, "qtz_func must be provided."
28
+
29
+ bits = kwargs.get('bits')
30
+ params = kwargs.get('params')
31
+ qtz_func = kwargs.get('qtz_func')
32
+
33
+ x_ = x.transpose(0, 1)
34
+ x_ = qtz_func(x=x_, **params, _s=init_ones(x, **kwargs))
35
+ x_ = x_.transpose(0, 1)
36
+
37
+ quant_min, quant_max = get_min_max_from_bits_signed(bits)
38
+ min_vals, max_vals = torch.aminmax(x_, dim=1)
39
+ min_vals = torch.min(min_vals, torch.zeros_like(min_vals))
40
+ max_vals = torch.max(max_vals, torch.zeros_like(max_vals))
41
+
42
+ eps = torch.finfo(torch.float32).eps
43
+
44
+ abs_max_val_per_ch = torch.max(-min_vals, max_vals)
45
+ scale = abs_max_val_per_ch / (float(quant_max - quant_min) / 2)
46
+
47
+ scale = torch.clamp(scale, min=eps).to(dtype=torch.float32, device=min_vals.device)
48
+
49
+ # Introduces some noise in scale
50
+ # If I don't introduce noise, the accuracy is going to be 0.0 and not learn anything
51
+ # scale = scale + 0.01 * torch.randn_like(scale)
52
+ return scale
53
+
54
+
55
+ def init_params(x: torch.Tensor, **kwargs: Dict[str, Any]) -> Dict[str, nn.Parameter]:
56
+ params = {
57
+ '_0': init_ones(x, qtz_func=quantization, deqtz_func=dequantization, param='_0', params_list=['_0', '_s'], **kwargs),
58
+ }
59
+ params['_s'] = init_linear_scale(x, params=params, qtz_func=quantization, **kwargs)
60
+ params = {k: nn.Parameter(v, requires_grad=False) for k, v in params.items()}
61
+
62
+ if 'post_init_hook' in kwargs:
63
+ kwargs['post_init_hook'](parameters=params)
64
+
65
+ params = learn_parameters(x, params,
66
+ qtz_func=quantization,
67
+ deqtz_func=dequantization,
68
+ bits=kwargs['bits'],
69
+ target_dtype=torch.int8,
70
+ epochs=500,
71
+ early_stop=False,
72
+ )
73
+ if 'post_train_hook' in kwargs:
74
+ kwargs['post_train_hook'](parameters=params)
75
+
76
+ return params
77
+
78
+
79
+ ############### Numpy Qtz ###############
80
+
81
+
82
+ def np_quantization(x, _0, _s):
83
+ return (np.divide(1, np_replace_num(_s, num=0, to=10000)) * np.tan(np_domain_guard((_0 * x), posinf=1, neginf=-1, nan=0)))
84
+
85
+
86
+ def np_dequantization(x, _0, _s):
87
+ return (np.divide(1, np_replace_num(_0, num=0, to=10000)) * np.arctan((_s * x)))
88
+
89
+
90
+ def fit_func(x, _0, _s):
91
+ x_ = np_quantization(x, _0, _s)
92
+ x_ = np_dequantization(x_, _0, _s)
93
+ return x_
94
+
95
+
96
+
97
+ ############### HELPERS ###############
98
+
99
+ def domain_guard(
100
+ x: torch.Tensor,
101
+ min: float = None,
102
+ max: float = None,
103
+ posinf: float = None,
104
+ neginf: float = None,
105
+ nan: float = None
106
+ ) -> torch.Tensor:
107
+ """Guard a tensor to a valid domain."""
108
+ x = torch.nan_to_num(x, posinf=posinf, neginf=neginf, nan=nan)
109
+ if min is not None or max is not None:
110
+ x = torch.clamp(x, min=min, max=max)
111
+ return x
112
+
113
+
114
+ def replace_num(x: torch.Tensor, num: float, to: float) -> torch.Tensor:
115
+ """Replace a number in a tensor with another number.
116
+
117
+ Args:
118
+ x (torch.Tensor): The input tensor.
119
+ num (float): The number to replace.
120
+ to (float): The number to replace with.
121
+
122
+ Returns:
123
+ torch.Tensor: The tensor with the number replaced.
124
+ """
125
+ return torch.where(x == num, to, x)
126
+
127
+
128
+ def guarded_torch_power(x: torch.Tensor, exp: float) -> torch.Tensor:
129
+ """Guard the power operation to a valid domain."""
130
+ return torch.pow(x, exp) if exp >= 1 else torch.pow(torch.relu(x), exp)
131
+
132
+
133
+ def init_ones(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor:
134
+ val = torch.amin(x, dim=1)
135
+ return torch.ones_like(val, dtype=torch.float32, device=x.device)
136
+
137
+
138
+ def init_rand(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor:
139
+ val = torch.amin(x, dim=1)
140
+ return torch.randn_like(val, dtype=torch.float32, device=x.device)
141
+
142
+
143
+ def init_space_search(
144
+ x: torch.Tensor,
145
+ **kwargs: Dict[str, Any],
146
+ ) -> torch.Tensor:
147
+
148
+ def _build_initial_param(tensor: torch.Tensor, max_initial: int, n_params: int):
149
+ """Generates the initial set of parameters. The first iteration generates 10 times more parameters."""
150
+ for _ in range(n_params * 10): # The first iteration generates 10 times more parameters
151
+ yield init_rand(tensor) * max_initial # Generates n_params in range [-max_initial, max_initial]
152
+
153
+ def _search_param(tensors: List[torch.tensor], n_params):
154
+ """Takes the best parameters and generates new parameters around the mean of the best parameters."""
155
+ torch_tensors = torch.stack(tensors)
156
+ min_vals, max_vals = torch.aminmax(torch_tensors, dim=0)
157
+ abs_max_val_per_ch = torch.max(-min_vals, max_vals)
158
+ mean = torch.mean(torch_tensors, dim=0)
159
+ for _ in range(n_params): # Generates n_params around the mean of the tensors
160
+ yield torch.randn_like(min_vals) * abs_max_val_per_ch + mean
161
+
162
+ def _calc(x, qtz_func, deqtz_func, **params):
163
+ x_ = x.transpose(0, 1)
164
+ x_ = qtz_func(x=x_, **params)
165
+ x_ = deqtz_func(x=x_, **params)
166
+ x_ = x_.transpose(0, 1)
167
+ return x_
168
+
169
+ assert "qtz_func" in kwargs, "qtz_func must be provided."
170
+ assert "deqtz_func" in kwargs, "deqtz_func must be provided."
171
+ assert "params_list" in kwargs, "params list must be provided."
172
+ assert "param" in kwargs, "param must be provided."
173
+
174
+ qtz_func = kwargs.get('qtz_func')
175
+ deqtz_func = kwargs.get('deqtz_func')
176
+ params_list = kwargs.get('params_list')
177
+ param = kwargs.get('param')
178
+
179
+ n_runs = 50 # Number of runs to try to find the best parameters
180
+ n_random_params = 50 # Number of random parameters to generate
181
+ n_best_to_pick = 5 # Number of best parameters to pick after each run
182
+ max_initial = 10000 # Maximum value to initialize the parameters
183
+
184
+ # Initializes the parameters
185
+ base_params = { p: init_ones(x, **kwargs) for p in params_list if p != param }
186
+ params = _build_initial_param(x, max_initial, n_random_params)
187
+
188
+ # Performs the search
189
+ for _ in range(n_runs):
190
+
191
+ best_params = []
192
+ for param_ in params:
193
+ try:
194
+ x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: param_})
195
+ loss_ones = nn.MSELoss()(x, x_)
196
+
197
+ if len(best_params) < n_best_to_pick:
198
+ best_params.append((param_, loss_ones.item()))
199
+ best_params = sorted(best_params, key=lambda x: x[1])
200
+ elif loss_ones < best_params[-1][1]:
201
+ best_params[-1] = (param_, loss_ones.item())
202
+ best_params = sorted(best_params, key=lambda x: x[1])
203
+
204
+ except Exception: # The parameters might not be valid for the function's domain
205
+ continue
206
+
207
+ # Generates new parameters around the mean
208
+ params = _search_param([p for p, _ in best_params], n_random_params)
209
+
210
+ # Checks if the best parameter is better than the init_ones
211
+ p_ones = init_ones(x, **kwargs)
212
+ x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: p_ones})
213
+ loss_ones = nn.MSELoss()(x, x_)
214
+
215
+ # Checks if the best parameter is better than the init_rand
216
+ p_rand = init_rand(x, **kwargs)
217
+ x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: p_rand})
218
+ loss_rand = nn.MSELoss()(x, x_)
219
+
220
+ if loss_rand < best_params[0][1] and loss_rand < loss_ones:
221
+ return p_rand
222
+ elif loss_ones < best_params[0][1] and loss_ones < loss_rand:
223
+ return p_ones
224
+ else:
225
+ return best_params[0][0]
226
+
227
+
228
+ def init_linear_scale( # Symmetric scale. From the study folder
229
+ x: torch.Tensor,
230
+ **kwargs: Dict[str, Any],
231
+ ) -> torch.Tensor:
232
+ assert "bits" in kwargs, "bits must be provided."
233
+ assert "params" in kwargs, "params must be provided."
234
+ assert "qtz_func" in kwargs, "qtz_func must be provided."
235
+
236
+ bits = kwargs.get('bits')
237
+ params = kwargs.get('params')
238
+ qtz_func = kwargs.get('qtz_func')
239
+
240
+ x_ = x.transpose(0, 1)
241
+ x_ = qtz_func(x=x_, **params, _s=init_ones(x, **kwargs))
242
+ x_ = x_.transpose(0, 1)
243
+
244
+ quant_min, quant_max = get_min_max_from_bits_signed(bits)
245
+ min_vals, max_vals = torch.aminmax(x_, dim=1)
246
+ min_vals = torch.min(min_vals, torch.zeros_like(min_vals))
247
+ max_vals = torch.max(max_vals, torch.zeros_like(max_vals))
248
+
249
+ eps = torch.finfo(torch.float32).eps
250
+
251
+ abs_max_val_per_ch = torch.max(-min_vals, max_vals)
252
+ scale = abs_max_val_per_ch / (float(quant_max - quant_min) / 2)
253
+
254
+ scale = torch.clamp(scale, min=eps).to(dtype=torch.float32, device=min_vals.device)
255
+
256
+ # Introduces some noise in scale
257
+ # If I don't introduce noise, the accuracy is going to be 0.0 and not learn anything
258
+ # scale = scale + 0.01 * torch.randn_like(scale)
259
+ return scale
260
+
261
+
262
+ def init_non_linear_regression_fit(
263
+ x: torch.Tensor,
264
+ **kwargs: Dict[str, Any],
265
+ ) -> torch.Tensor:
266
+
267
+ assert "params_list" in kwargs, "params list must be provided."
268
+ assert "np_fit_func" in kwargs, "np_fit_func must be provided."
269
+ assert "p0" in kwargs, "p0 must be provided."
270
+ np_fit_func = kwargs.get('np_fit_func')
271
+ params_list = kwargs.get('params_list')
272
+ p0 = kwargs.get('p0')
273
+
274
+ def _fit(xdata: np.ndarray, ydata: np.ndarray, func: Callable, p0: List[float]):
275
+ popt, _ = curve_fit(
276
+ func,
277
+ xdata,
278
+ ydata,
279
+ maxfev=1000,
280
+ p0=p0,
281
+ method='lm'
282
+ )
283
+ return popt
284
+
285
+ # 1. Needs to convert the torch tensor to numpy tensor
286
+ xdata = x.cpu().numpy()
287
+
288
+ # 2. Sorts the data so that it makes it easier to fit to it
289
+ sorted_xdata = np.sort(xdata, axis=-1)
290
+
291
+ p0 = {k: v.cpu().numpy() for k, v in p0.items()}
292
+ params_list = sorted(params_list) # We need to make sure that it matches the numpy fit func arg order
293
+
294
+ # 3. Finds the best parameters for each channel
295
+ try:
296
+ params = []
297
+ for i in range(sorted_xdata.shape[0]):
298
+ xdata_ = sorted_xdata[i]
299
+ p0_ = [p0[p][i] for p in params_list]
300
+ ch_params = _fit(xdata_, xdata_, np_fit_func, p0_)
301
+ params.append(ch_params)
302
+
303
+ # 4. Builds the parameters
304
+ result = {}
305
+ for i, p in enumerate(params_list):
306
+ result[p] = torch.tensor([p_[i] for p_ in params], dtype=torch.float32).to(x.device)
307
+
308
+ return result
309
+
310
+ except ValueError as e:
311
+ print(f"Could not fit the function with error: {e}")
312
+ print(f"Using fallback result...")
313
+ return {
314
+ k: torch.tensor(v, dtype=torch.float32).to(x.device) for k, v in p0.items()
315
+ }
316
+
317
+
318
+ def init_zeros(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor:
319
+ val = torch.amin(x, dim=1)
320
+ return torch.zeros_like(val, dtype=torch.float32, device=x.device)
321
+
322
+
323
+ def init_inner_scale(tensor: torch.Tensor, _min: float = torch.inf, _max: float = torch.inf) -> torch.Tensor:
324
+ # Calculate the original minimum and maximum values
325
+ min_vals, max_vals = torch.aminmax(tensor, dim=-1)
326
+ x_min = torch.min(min_vals, torch.zeros_like(min_vals))
327
+ x_max = torch.max(max_vals, torch.zeros_like(max_vals))
328
+
329
+ if _max is torch.inf: # We do not need to scale the tensor. Just need to move it
330
+ return torch.ones_like(x_min)
331
+
332
+ # Calculate the scale factor
333
+ scale = (_max - _min) / (x_max - x_min)
334
+ return scale
335
+
336
+
337
+
338
+ ############## Quant ###############
339
+
340
+ @torch.enable_grad()
341
+ def learn_parameters(
342
+ x: torch.Tensor,
343
+ params: Dict[str, nn.Parameter],
344
+ qtz_func: nn.Module,
345
+ deqtz_func: nn.Module,
346
+ bits: int,
347
+ target_dtype: torch.dtype,
348
+ epochs: int = 1000,
349
+ early_stop: bool = True,
350
+ do_report: bool = False
351
+ ) -> Tuple[Dict[str, nn.Parameter], torch.Tensor]:
352
+ loss_fn = nn.MSELoss()
353
+
354
+ # Determines the initial learning rate by computing the initial loss and multiplying it by
355
+ # the order of magnitude of the loss divided by 2
356
+ quant = quantize(x, params, qtz_func, bits, target_dtype)
357
+ dequant = dequantize(quant, params, deqtz_func, bits, x.dtype)
358
+ loss = loss_fn(x, dequant)
359
+
360
+ base_lr = 0.1
361
+ exponent = int(np.floor(np.log10(loss.item())))
362
+ lr = base_lr * (10 ** (exponent // 2))
363
+
364
+ # Requires gradients in the parameters
365
+ for p in params.values():
366
+ p.requires_grad = True
367
+ p.grad = None
368
+
369
+ param_keys = list(params.keys())
370
+ param_values = list(params.values())
371
+
372
+ # Defines optimizer and loss function
373
+ optimizer = torch.optim.Adam(param_values, lr=lr)
374
+ scheduler = torch.optim.lr_scheduler.LinearLR(optimizer, start_factor=1.0, end_factor=0.01, total_iters=epochs // 10)
375
+
376
+ # Contains the best loss and the best parameters
377
+ best_loss = float("inf")
378
+ best_params = None
379
+
380
+ # Used to stop the search early
381
+ min_delta = 1e-7
382
+ acc_loss = []
383
+ percent_epochs_before_stop = 0.1
384
+
385
+ for i in range(epochs):
386
+ optimizer.zero_grad()
387
+
388
+ quant = quantize(x, params, qtz_func, bits, target_dtype)
389
+ dequant = dequantize(quant, params, deqtz_func, bits, x.dtype)
390
+ loss = loss_fn(x, dequant)
391
+
392
+ if loss.isnan() or loss.isinf():
393
+ raise Exception("Loss is NaN or Inf. Stopping the search.")
394
+
395
+ loss.backward()
396
+ optimizer.step()
397
+ scheduler.step()
398
+
399
+ acc_loss.append(loss.item())
400
+
401
+ # Reports loss every 10 steps
402
+ if i % 10 == 0 and do_report:
403
+ print(f"Epoch {i}: Loss {loss.item()}")
404
+
405
+ # Optimizes the parameter search by storing the best loss and the parameters
406
+ if loss.item() < best_loss:
407
+ best_loss = loss.item()
408
+ best_params = copy.deepcopy({
409
+ k: v for k, v in params.items() if k in param_keys
410
+ })
411
+
412
+ # We also stop the search if the loss has not considerably during the last 10% epochs
413
+ if early_stop:
414
+ epochs_before_stop = int(epochs * percent_epochs_before_stop)
415
+ if i > epochs_before_stop and abs(acc_loss[i - epochs_before_stop] - acc_loss[i]) < min_delta:
416
+ break
417
+
418
+ # No longer requires gradients in the parameters
419
+ for p in best_params.values():
420
+ p.requires_grad = False
421
+ p.grad = None
422
+
423
+ if do_report:
424
+ return best_params, acc_loss
425
+ else:
426
+ return best_params
427
+
428
+
429
+ def quantize(
430
+ x: torch.Tensor,
431
+ params: Dict[str, nn.Parameter],
432
+ func: nn.Module,
433
+ bits: int,
434
+ target_dtype: torch.dtype = torch.int8
435
+ ) -> torch.Tensor:
436
+ quant_min, quant_max = get_min_max_from_bits_signed(bits)
437
+ x = x.transpose(0, 1) # Aligns shapes
438
+ x = func(x=x, **params)
439
+ x = x.transpose(0, 1)
440
+ x = torch.clamp(round_func_BPDA(x), quant_min, quant_max).to(target_dtype)
441
+ return x
442
+
443
+
444
+ def dequantize(
445
+ x: torch.Tensor,
446
+ params: Dict[str, nn.Parameter],
447
+ func: nn.Module,
448
+ bits: int,
449
+ out_dtype: torch.dtype
450
+ ) -> torch.Tensor:
451
+ x = x.to(dtype=out_dtype)
452
+ x = x.transpose(0, 1)
453
+ x = func(x=x, **params)
454
+ x = x.transpose(0, 1)
455
+ return x
456
+
457
+
458
+ def round_func_BPDA(input):
459
+ # This is equivalent to replacing round function (non-differentiable) with
460
+ # an identity function (differentiable) only when backward.
461
+ forward_value = torch.round(input)
462
+ out = input.clone()
463
+ out.data = forward_value.data
464
+ return out
465
+
466
+
467
+ def get_min_max_from_bits_signed(bit_width: int) -> Tuple[int, int]:
468
+ return -2 ** (bit_width - 1), 2 ** (bit_width - 1) - 1
469
+
470
+
471
+
472
+ ############## Numpy ###############
473
+
474
+ def np_domain_guard(
475
+ x: np.ndarray,
476
+ min: float = None,
477
+ max: float = None,
478
+ posinf: float = None,
479
+ neginf: float = None,
480
+ nan: float = None
481
+ ) -> np.ndarray:
482
+ """Guard a tensor to a valid domain."""
483
+ x = np.nan_to_num(x, posinf=posinf, neginf=neginf, nan=nan)
484
+ if min is not None or max is not None:
485
+ x = np.clip(x, min, max)
486
+ return x
487
+
488
+
489
+ def np_replace_num(x: np.ndarray, num: float, to: float) -> np.ndarray:
490
+ """Replace a number in a tensor with another number.
491
+
492
+ Args:
493
+ x (np.ndarray): The input tensor.
494
+ num (float): The number to replace.
495
+ to (float): The number to replace with.
496
+
497
+ Returns:
498
+ np.ndarray: The tensor with the number replaced.
499
+ """
500
+ return np.where(x == num, to, x)
501
+
502
+
503
+ def np_guarded_power(x: np.ndarray, exp: float) -> np.ndarray:
504
+ """Guard the power operation to a valid domain."""
505
+ return np.power(x, exp) if exp >= 1 else np.power(np.maximum(x, 0), exp)
506
+
fn_gen/ones_t/16/loss.png ADDED
fn_gen/ones_t/16/quantization.png ADDED
fn_gen/ones_t/17/distortion.png ADDED
fn_gen/ones_t/17/expressions.txt ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ tanh(_0*x)/_s
2
+ log((-_s*x - 1)/(_s*x - 1))/_0
fn_gen/ones_t/17/fn.py ADDED
@@ -0,0 +1,506 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import torch
4
+ from torch import amin # Necessary for arcsin
5
+ import copy
6
+ import torch.nn as nn
7
+ import numpy as np
8
+
9
+ from scipy.optimize import curve_fit
10
+ from typing import Dict, Any, Tuple, List, Callable
11
+
12
+
13
+ def quantization(x, **params):
14
+ return (torch.div(1, replace_num(params['_s'], num=0, to=10000)) * torch.tanh((params['_0'] * x)))
15
+
16
+
17
+ def dequantization(x, **params):
18
+ return (torch.div(1, replace_num(params['_0'], num=0, to=10000)) * torch.log(domain_guard((torch.div(1, replace_num((torch.tensor(-1) + (params['_s'] * x)), num=0, to=10000)) * (torch.tensor(-1) + (torch.tensor(-1) * params['_s'] * x))), min=1e-5, nan=1e-5)))
19
+
20
+
21
+ def init_linear_scale( # Symmetric scale. From the study folder
22
+ x: torch.Tensor,
23
+ **kwargs: Dict[str, Any],
24
+ ) -> torch.Tensor:
25
+ assert "bits" in kwargs, "bits must be provided."
26
+ assert "params" in kwargs, "params must be provided."
27
+ assert "qtz_func" in kwargs, "qtz_func must be provided."
28
+
29
+ bits = kwargs.get('bits')
30
+ params = kwargs.get('params')
31
+ qtz_func = kwargs.get('qtz_func')
32
+
33
+ x_ = x.transpose(0, 1)
34
+ x_ = qtz_func(x=x_, **params, _s=init_ones(x, **kwargs))
35
+ x_ = x_.transpose(0, 1)
36
+
37
+ quant_min, quant_max = get_min_max_from_bits_signed(bits)
38
+ min_vals, max_vals = torch.aminmax(x_, dim=1)
39
+ min_vals = torch.min(min_vals, torch.zeros_like(min_vals))
40
+ max_vals = torch.max(max_vals, torch.zeros_like(max_vals))
41
+
42
+ eps = torch.finfo(torch.float32).eps
43
+
44
+ abs_max_val_per_ch = torch.max(-min_vals, max_vals)
45
+ scale = abs_max_val_per_ch / (float(quant_max - quant_min) / 2)
46
+
47
+ scale = torch.clamp(scale, min=eps).to(dtype=torch.float32, device=min_vals.device)
48
+
49
+ # Introduces some noise in scale
50
+ # If I don't introduce noise, the accuracy is going to be 0.0 and not learn anything
51
+ # scale = scale + 0.01 * torch.randn_like(scale)
52
+ return scale
53
+
54
+
55
+ def init_params(x: torch.Tensor, **kwargs: Dict[str, Any]) -> Dict[str, nn.Parameter]:
56
+ params = {
57
+ '_0': init_ones(x, qtz_func=quantization, deqtz_func=dequantization, param='_0', params_list=['_0', '_s'], **kwargs),
58
+ }
59
+ params['_s'] = init_linear_scale(x, params=params, qtz_func=quantization, **kwargs)
60
+ params = {k: nn.Parameter(v, requires_grad=False) for k, v in params.items()}
61
+
62
+ if 'post_init_hook' in kwargs:
63
+ kwargs['post_init_hook'](parameters=params)
64
+
65
+ params = learn_parameters(x, params,
66
+ qtz_func=quantization,
67
+ deqtz_func=dequantization,
68
+ bits=kwargs['bits'],
69
+ target_dtype=torch.int8,
70
+ epochs=500,
71
+ early_stop=False,
72
+ )
73
+ if 'post_train_hook' in kwargs:
74
+ kwargs['post_train_hook'](parameters=params)
75
+
76
+ return params
77
+
78
+
79
+ ############### Numpy Qtz ###############
80
+
81
+
82
+ def np_quantization(x, _0, _s):
83
+ return (np.divide(1, np_replace_num(_s, num=0, to=10000)) * np.tanh((_0 * x)))
84
+
85
+
86
+ def np_dequantization(x, _0, _s):
87
+ return (np.divide(1, np_replace_num(_0, num=0, to=10000)) * np.log(np_domain_guard((np.divide(1, np_replace_num((np.array(-1) + (_s * x)), num=0, to=10000)) * (np.array(-1) + (np.array(-1) * _s * x))), min=1e-5, nan=1e-5)))
88
+
89
+
90
+ def fit_func(x, _0, _s):
91
+ x_ = np_quantization(x, _0, _s)
92
+ x_ = np_dequantization(x_, _0, _s)
93
+ return x_
94
+
95
+
96
+
97
+ ############### HELPERS ###############
98
+
99
+ def domain_guard(
100
+ x: torch.Tensor,
101
+ min: float = None,
102
+ max: float = None,
103
+ posinf: float = None,
104
+ neginf: float = None,
105
+ nan: float = None
106
+ ) -> torch.Tensor:
107
+ """Guard a tensor to a valid domain."""
108
+ x = torch.nan_to_num(x, posinf=posinf, neginf=neginf, nan=nan)
109
+ if min is not None or max is not None:
110
+ x = torch.clamp(x, min=min, max=max)
111
+ return x
112
+
113
+
114
+ def replace_num(x: torch.Tensor, num: float, to: float) -> torch.Tensor:
115
+ """Replace a number in a tensor with another number.
116
+
117
+ Args:
118
+ x (torch.Tensor): The input tensor.
119
+ num (float): The number to replace.
120
+ to (float): The number to replace with.
121
+
122
+ Returns:
123
+ torch.Tensor: The tensor with the number replaced.
124
+ """
125
+ return torch.where(x == num, to, x)
126
+
127
+
128
+ def guarded_torch_power(x: torch.Tensor, exp: float) -> torch.Tensor:
129
+ """Guard the power operation to a valid domain."""
130
+ return torch.pow(x, exp) if exp >= 1 else torch.pow(torch.relu(x), exp)
131
+
132
+
133
+ def init_ones(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor:
134
+ val = torch.amin(x, dim=1)
135
+ return torch.ones_like(val, dtype=torch.float32, device=x.device)
136
+
137
+
138
+ def init_rand(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor:
139
+ val = torch.amin(x, dim=1)
140
+ return torch.randn_like(val, dtype=torch.float32, device=x.device)
141
+
142
+
143
+ def init_space_search(
144
+ x: torch.Tensor,
145
+ **kwargs: Dict[str, Any],
146
+ ) -> torch.Tensor:
147
+
148
+ def _build_initial_param(tensor: torch.Tensor, max_initial: int, n_params: int):
149
+ """Generates the initial set of parameters. The first iteration generates 10 times more parameters."""
150
+ for _ in range(n_params * 10): # The first iteration generates 10 times more parameters
151
+ yield init_rand(tensor) * max_initial # Generates n_params in range [-max_initial, max_initial]
152
+
153
+ def _search_param(tensors: List[torch.tensor], n_params):
154
+ """Takes the best parameters and generates new parameters around the mean of the best parameters."""
155
+ torch_tensors = torch.stack(tensors)
156
+ min_vals, max_vals = torch.aminmax(torch_tensors, dim=0)
157
+ abs_max_val_per_ch = torch.max(-min_vals, max_vals)
158
+ mean = torch.mean(torch_tensors, dim=0)
159
+ for _ in range(n_params): # Generates n_params around the mean of the tensors
160
+ yield torch.randn_like(min_vals) * abs_max_val_per_ch + mean
161
+
162
+ def _calc(x, qtz_func, deqtz_func, **params):
163
+ x_ = x.transpose(0, 1)
164
+ x_ = qtz_func(x=x_, **params)
165
+ x_ = deqtz_func(x=x_, **params)
166
+ x_ = x_.transpose(0, 1)
167
+ return x_
168
+
169
+ assert "qtz_func" in kwargs, "qtz_func must be provided."
170
+ assert "deqtz_func" in kwargs, "deqtz_func must be provided."
171
+ assert "params_list" in kwargs, "params list must be provided."
172
+ assert "param" in kwargs, "param must be provided."
173
+
174
+ qtz_func = kwargs.get('qtz_func')
175
+ deqtz_func = kwargs.get('deqtz_func')
176
+ params_list = kwargs.get('params_list')
177
+ param = kwargs.get('param')
178
+
179
+ n_runs = 50 # Number of runs to try to find the best parameters
180
+ n_random_params = 50 # Number of random parameters to generate
181
+ n_best_to_pick = 5 # Number of best parameters to pick after each run
182
+ max_initial = 10000 # Maximum value to initialize the parameters
183
+
184
+ # Initializes the parameters
185
+ base_params = { p: init_ones(x, **kwargs) for p in params_list if p != param }
186
+ params = _build_initial_param(x, max_initial, n_random_params)
187
+
188
+ # Performs the search
189
+ for _ in range(n_runs):
190
+
191
+ best_params = []
192
+ for param_ in params:
193
+ try:
194
+ x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: param_})
195
+ loss_ones = nn.MSELoss()(x, x_)
196
+
197
+ if len(best_params) < n_best_to_pick:
198
+ best_params.append((param_, loss_ones.item()))
199
+ best_params = sorted(best_params, key=lambda x: x[1])
200
+ elif loss_ones < best_params[-1][1]:
201
+ best_params[-1] = (param_, loss_ones.item())
202
+ best_params = sorted(best_params, key=lambda x: x[1])
203
+
204
+ except Exception: # The parameters might not be valid for the function's domain
205
+ continue
206
+
207
+ # Generates new parameters around the mean
208
+ params = _search_param([p for p, _ in best_params], n_random_params)
209
+
210
+ # Checks if the best parameter is better than the init_ones
211
+ p_ones = init_ones(x, **kwargs)
212
+ x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: p_ones})
213
+ loss_ones = nn.MSELoss()(x, x_)
214
+
215
+ # Checks if the best parameter is better than the init_rand
216
+ p_rand = init_rand(x, **kwargs)
217
+ x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: p_rand})
218
+ loss_rand = nn.MSELoss()(x, x_)
219
+
220
+ if loss_rand < best_params[0][1] and loss_rand < loss_ones:
221
+ return p_rand
222
+ elif loss_ones < best_params[0][1] and loss_ones < loss_rand:
223
+ return p_ones
224
+ else:
225
+ return best_params[0][0]
226
+
227
+
228
+ def init_linear_scale( # Symmetric scale. From the study folder
229
+ x: torch.Tensor,
230
+ **kwargs: Dict[str, Any],
231
+ ) -> torch.Tensor:
232
+ assert "bits" in kwargs, "bits must be provided."
233
+ assert "params" in kwargs, "params must be provided."
234
+ assert "qtz_func" in kwargs, "qtz_func must be provided."
235
+
236
+ bits = kwargs.get('bits')
237
+ params = kwargs.get('params')
238
+ qtz_func = kwargs.get('qtz_func')
239
+
240
+ x_ = x.transpose(0, 1)
241
+ x_ = qtz_func(x=x_, **params, _s=init_ones(x, **kwargs))
242
+ x_ = x_.transpose(0, 1)
243
+
244
+ quant_min, quant_max = get_min_max_from_bits_signed(bits)
245
+ min_vals, max_vals = torch.aminmax(x_, dim=1)
246
+ min_vals = torch.min(min_vals, torch.zeros_like(min_vals))
247
+ max_vals = torch.max(max_vals, torch.zeros_like(max_vals))
248
+
249
+ eps = torch.finfo(torch.float32).eps
250
+
251
+ abs_max_val_per_ch = torch.max(-min_vals, max_vals)
252
+ scale = abs_max_val_per_ch / (float(quant_max - quant_min) / 2)
253
+
254
+ scale = torch.clamp(scale, min=eps).to(dtype=torch.float32, device=min_vals.device)
255
+
256
+ # Introduces some noise in scale
257
+ # If I don't introduce noise, the accuracy is going to be 0.0 and not learn anything
258
+ # scale = scale + 0.01 * torch.randn_like(scale)
259
+ return scale
260
+
261
+
262
+ def init_non_linear_regression_fit(
263
+ x: torch.Tensor,
264
+ **kwargs: Dict[str, Any],
265
+ ) -> torch.Tensor:
266
+
267
+ assert "params_list" in kwargs, "params list must be provided."
268
+ assert "np_fit_func" in kwargs, "np_fit_func must be provided."
269
+ assert "p0" in kwargs, "p0 must be provided."
270
+ np_fit_func = kwargs.get('np_fit_func')
271
+ params_list = kwargs.get('params_list')
272
+ p0 = kwargs.get('p0')
273
+
274
+ def _fit(xdata: np.ndarray, ydata: np.ndarray, func: Callable, p0: List[float]):
275
+ popt, _ = curve_fit(
276
+ func,
277
+ xdata,
278
+ ydata,
279
+ maxfev=1000,
280
+ p0=p0,
281
+ method='lm'
282
+ )
283
+ return popt
284
+
285
+ # 1. Needs to convert the torch tensor to numpy tensor
286
+ xdata = x.cpu().numpy()
287
+
288
+ # 2. Sorts the data so that it makes it easier to fit to it
289
+ sorted_xdata = np.sort(xdata, axis=-1)
290
+
291
+ p0 = {k: v.cpu().numpy() for k, v in p0.items()}
292
+ params_list = sorted(params_list) # We need to make sure that it matches the numpy fit func arg order
293
+
294
+ # 3. Finds the best parameters for each channel
295
+ try:
296
+ params = []
297
+ for i in range(sorted_xdata.shape[0]):
298
+ xdata_ = sorted_xdata[i]
299
+ p0_ = [p0[p][i] for p in params_list]
300
+ ch_params = _fit(xdata_, xdata_, np_fit_func, p0_)
301
+ params.append(ch_params)
302
+
303
+ # 4. Builds the parameters
304
+ result = {}
305
+ for i, p in enumerate(params_list):
306
+ result[p] = torch.tensor([p_[i] for p_ in params], dtype=torch.float32).to(x.device)
307
+
308
+ return result
309
+
310
+ except ValueError as e:
311
+ print(f"Could not fit the function with error: {e}")
312
+ print(f"Using fallback result...")
313
+ return {
314
+ k: torch.tensor(v, dtype=torch.float32).to(x.device) for k, v in p0.items()
315
+ }
316
+
317
+
318
+ def init_zeros(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor:
319
+ val = torch.amin(x, dim=1)
320
+ return torch.zeros_like(val, dtype=torch.float32, device=x.device)
321
+
322
+
323
+ def init_inner_scale(tensor: torch.Tensor, _min: float = torch.inf, _max: float = torch.inf) -> torch.Tensor:
324
+ # Calculate the original minimum and maximum values
325
+ min_vals, max_vals = torch.aminmax(tensor, dim=-1)
326
+ x_min = torch.min(min_vals, torch.zeros_like(min_vals))
327
+ x_max = torch.max(max_vals, torch.zeros_like(max_vals))
328
+
329
+ if _max is torch.inf: # We do not need to scale the tensor. Just need to move it
330
+ return torch.ones_like(x_min)
331
+
332
+ # Calculate the scale factor
333
+ scale = (_max - _min) / (x_max - x_min)
334
+ return scale
335
+
336
+
337
+
338
+ ############## Quant ###############
339
+
340
+ @torch.enable_grad()
341
+ def learn_parameters(
342
+ x: torch.Tensor,
343
+ params: Dict[str, nn.Parameter],
344
+ qtz_func: nn.Module,
345
+ deqtz_func: nn.Module,
346
+ bits: int,
347
+ target_dtype: torch.dtype,
348
+ epochs: int = 1000,
349
+ early_stop: bool = True,
350
+ do_report: bool = False
351
+ ) -> Tuple[Dict[str, nn.Parameter], torch.Tensor]:
352
+ loss_fn = nn.MSELoss()
353
+
354
+ # Determines the initial learning rate by computing the initial loss and multiplying it by
355
+ # the order of magnitude of the loss divided by 2
356
+ quant = quantize(x, params, qtz_func, bits, target_dtype)
357
+ dequant = dequantize(quant, params, deqtz_func, bits, x.dtype)
358
+ loss = loss_fn(x, dequant)
359
+
360
+ base_lr = 0.1
361
+ exponent = int(np.floor(np.log10(loss.item())))
362
+ lr = base_lr * (10 ** (exponent // 2))
363
+
364
+ # Requires gradients in the parameters
365
+ for p in params.values():
366
+ p.requires_grad = True
367
+ p.grad = None
368
+
369
+ param_keys = list(params.keys())
370
+ param_values = list(params.values())
371
+
372
+ # Defines optimizer and loss function
373
+ optimizer = torch.optim.Adam(param_values, lr=lr)
374
+ scheduler = torch.optim.lr_scheduler.LinearLR(optimizer, start_factor=1.0, end_factor=0.01, total_iters=epochs // 10)
375
+
376
+ # Contains the best loss and the best parameters
377
+ best_loss = float("inf")
378
+ best_params = None
379
+
380
+ # Used to stop the search early
381
+ min_delta = 1e-7
382
+ acc_loss = []
383
+ percent_epochs_before_stop = 0.1
384
+
385
+ for i in range(epochs):
386
+ optimizer.zero_grad()
387
+
388
+ quant = quantize(x, params, qtz_func, bits, target_dtype)
389
+ dequant = dequantize(quant, params, deqtz_func, bits, x.dtype)
390
+ loss = loss_fn(x, dequant)
391
+
392
+ if loss.isnan() or loss.isinf():
393
+ raise Exception("Loss is NaN or Inf. Stopping the search.")
394
+
395
+ loss.backward()
396
+ optimizer.step()
397
+ scheduler.step()
398
+
399
+ acc_loss.append(loss.item())
400
+
401
+ # Reports loss every 10 steps
402
+ if i % 10 == 0 and do_report:
403
+ print(f"Epoch {i}: Loss {loss.item()}")
404
+
405
+ # Optimizes the parameter search by storing the best loss and the parameters
406
+ if loss.item() < best_loss:
407
+ best_loss = loss.item()
408
+ best_params = copy.deepcopy({
409
+ k: v for k, v in params.items() if k in param_keys
410
+ })
411
+
412
+ # We also stop the search if the loss has not considerably during the last 10% epochs
413
+ if early_stop:
414
+ epochs_before_stop = int(epochs * percent_epochs_before_stop)
415
+ if i > epochs_before_stop and abs(acc_loss[i - epochs_before_stop] - acc_loss[i]) < min_delta:
416
+ break
417
+
418
+ # No longer requires gradients in the parameters
419
+ for p in best_params.values():
420
+ p.requires_grad = False
421
+ p.grad = None
422
+
423
+ if do_report:
424
+ return best_params, acc_loss
425
+ else:
426
+ return best_params
427
+
428
+
429
+ def quantize(
430
+ x: torch.Tensor,
431
+ params: Dict[str, nn.Parameter],
432
+ func: nn.Module,
433
+ bits: int,
434
+ target_dtype: torch.dtype = torch.int8
435
+ ) -> torch.Tensor:
436
+ quant_min, quant_max = get_min_max_from_bits_signed(bits)
437
+ x = x.transpose(0, 1) # Aligns shapes
438
+ x = func(x=x, **params)
439
+ x = x.transpose(0, 1)
440
+ x = torch.clamp(round_func_BPDA(x), quant_min, quant_max).to(target_dtype)
441
+ return x
442
+
443
+
444
+ def dequantize(
445
+ x: torch.Tensor,
446
+ params: Dict[str, nn.Parameter],
447
+ func: nn.Module,
448
+ bits: int,
449
+ out_dtype: torch.dtype
450
+ ) -> torch.Tensor:
451
+ x = x.to(dtype=out_dtype)
452
+ x = x.transpose(0, 1)
453
+ x = func(x=x, **params)
454
+ x = x.transpose(0, 1)
455
+ return x
456
+
457
+
458
+ def round_func_BPDA(input):
459
+ # This is equivalent to replacing round function (non-differentiable) with
460
+ # an identity function (differentiable) only when backward.
461
+ forward_value = torch.round(input)
462
+ out = input.clone()
463
+ out.data = forward_value.data
464
+ return out
465
+
466
+
467
+ def get_min_max_from_bits_signed(bit_width: int) -> Tuple[int, int]:
468
+ return -2 ** (bit_width - 1), 2 ** (bit_width - 1) - 1
469
+
470
+
471
+
472
+ ############## Numpy ###############
473
+
474
+ def np_domain_guard(
475
+ x: np.ndarray,
476
+ min: float = None,
477
+ max: float = None,
478
+ posinf: float = None,
479
+ neginf: float = None,
480
+ nan: float = None
481
+ ) -> np.ndarray:
482
+ """Guard a tensor to a valid domain."""
483
+ x = np.nan_to_num(x, posinf=posinf, neginf=neginf, nan=nan)
484
+ if min is not None or max is not None:
485
+ x = np.clip(x, min, max)
486
+ return x
487
+
488
+
489
+ def np_replace_num(x: np.ndarray, num: float, to: float) -> np.ndarray:
490
+ """Replace a number in a tensor with another number.
491
+
492
+ Args:
493
+ x (np.ndarray): The input tensor.
494
+ num (float): The number to replace.
495
+ to (float): The number to replace with.
496
+
497
+ Returns:
498
+ np.ndarray: The tensor with the number replaced.
499
+ """
500
+ return np.where(x == num, to, x)
501
+
502
+
503
+ def np_guarded_power(x: np.ndarray, exp: float) -> np.ndarray:
504
+ """Guard the power operation to a valid domain."""
505
+ return np.power(x, exp) if exp >= 1 else np.power(np.maximum(x, 0), exp)
506
+
fn_gen/ones_t/17/loss.png ADDED
fn_gen/ones_t/17/quantization.png ADDED