Diogo-V commited on
Commit
4d230c0
·
verified ·
1 Parent(s): 5394f13

Upload learned functions

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. fn_gen/rnd_search_t_cos/1/distortion.png +0 -0
  2. fn_gen/rnd_search_t_cos/1/expressions.txt +2 -0
  3. fn_gen/rnd_search_t_cos/1/fn.py +600 -0
  4. fn_gen/rnd_search_t_cos/1/loss.png +0 -0
  5. fn_gen/rnd_search_t_cos/1/quantization.png +0 -0
  6. fn_gen/rnd_search_t_cos/10/distortion.png +0 -0
  7. fn_gen/rnd_search_t_cos/10/expressions.txt +2 -0
  8. fn_gen/rnd_search_t_cos/10/fn.py +600 -0
  9. fn_gen/rnd_search_t_cos/10/loss.png +0 -0
  10. fn_gen/rnd_search_t_cos/10/quantization.png +0 -0
  11. fn_gen/rnd_search_t_cos/11/distortion.png +0 -0
  12. fn_gen/rnd_search_t_cos/11/expressions.txt +2 -0
  13. fn_gen/rnd_search_t_cos/11/fn.py +600 -0
  14. fn_gen/rnd_search_t_cos/11/loss.png +0 -0
  15. fn_gen/rnd_search_t_cos/11/quantization.png +0 -0
  16. fn_gen/rnd_search_t_cos/12/distortion.png +0 -0
  17. fn_gen/rnd_search_t_cos/12/expressions.txt +2 -0
  18. fn_gen/rnd_search_t_cos/12/fn.py +600 -0
  19. fn_gen/rnd_search_t_cos/12/loss.png +0 -0
  20. fn_gen/rnd_search_t_cos/12/quantization.png +0 -0
  21. fn_gen/rnd_search_t_cos/13/distortion.png +0 -0
  22. fn_gen/rnd_search_t_cos/13/expressions.txt +2 -0
  23. fn_gen/rnd_search_t_cos/13/fn.py +600 -0
  24. fn_gen/rnd_search_t_cos/13/loss.png +0 -0
  25. fn_gen/rnd_search_t_cos/13/quantization.png +0 -0
  26. fn_gen/rnd_search_t_cos/15/distortion.png +0 -0
  27. fn_gen/rnd_search_t_cos/15/expressions.txt +2 -0
  28. fn_gen/rnd_search_t_cos/15/fn.py +514 -0
  29. fn_gen/rnd_search_t_cos/15/loss.png +0 -0
  30. fn_gen/rnd_search_t_cos/15/quantization.png +0 -0
  31. fn_gen/rnd_search_t_cos/16/distortion.png +0 -0
  32. fn_gen/rnd_search_t_cos/16/expressions.txt +2 -0
  33. fn_gen/rnd_search_t_cos/16/fn.py +600 -0
  34. fn_gen/rnd_search_t_cos/16/loss.png +0 -0
  35. fn_gen/rnd_search_t_cos/16/quantization.png +0 -0
  36. fn_gen/rnd_search_t_cos/17/distortion.png +0 -0
  37. fn_gen/rnd_search_t_cos/17/expressions.txt +2 -0
  38. fn_gen/rnd_search_t_cos/17/fn.py +514 -0
  39. fn_gen/rnd_search_t_cos/17/loss.png +0 -0
  40. fn_gen/rnd_search_t_cos/17/quantization.png +0 -0
  41. fn_gen/rnd_search_t_cos/18/distortion.png +0 -0
  42. fn_gen/rnd_search_t_cos/18/expressions.txt +2 -0
  43. fn_gen/rnd_search_t_cos/18/fn.py +600 -0
  44. fn_gen/rnd_search_t_cos/18/loss.png +0 -0
  45. fn_gen/rnd_search_t_cos/18/quantization.png +0 -0
  46. fn_gen/rnd_search_t_cos/2/distortion.png +0 -0
  47. fn_gen/rnd_search_t_cos/2/expressions.txt +2 -0
  48. fn_gen/rnd_search_t_cos/2/fn.py +600 -0
  49. fn_gen/rnd_search_t_cos/2/loss.png +0 -0
  50. fn_gen/rnd_search_t_cos/2/quantization.png +0 -0
fn_gen/rnd_search_t_cos/1/distortion.png ADDED
fn_gen/rnd_search_t_cos/1/expressions.txt ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ asin(_0*x)/_s
2
+ sin(_s*x)/_0
fn_gen/rnd_search_t_cos/1/fn.py ADDED
@@ -0,0 +1,600 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import torch
4
+ from torch import amin # Necessary for arcsin
5
+ import copy
6
+ import torch.nn as nn
7
+ import numpy as np
8
+
9
+ from scipy.optimize import curve_fit
10
+ from typing import Dict, Any, Tuple, List, Callable
11
+
12
+
13
+ def quantization(x, **params):
14
+ return (torch.div(1, replace_num(params['_s'], num=0, to=10000)) * torch.asin(domain_guard((params['_0'] * x), min=-0.99999, max=0.99999, nan=0)))
15
+
16
+
17
+ def dequantization(x, **params):
18
+ return (torch.div(1, replace_num(params['_0'], num=0, to=10000)) * torch.sin((params['_s'] * x)))
19
+
20
+
21
+ def init_space_search(
22
+ x: torch.Tensor,
23
+ **kwargs: Dict[str, Any],
24
+ ) -> torch.Tensor:
25
+
26
+ def _build_initial_param(tensor: torch.Tensor, max_initial: int, n_params: int):
27
+ """Generates the initial set of parameters. The first iteration generates 10 times more parameters."""
28
+ for _ in range(n_params * 10): # The first iteration generates 10 times more parameters
29
+ yield init_rand(tensor) * max_initial # Generates n_params in range [-max_initial, max_initial]
30
+
31
+ def _search_param(tensors: List[torch.tensor], n_params):
32
+ """Takes the best parameters and generates new parameters around the mean of the best parameters."""
33
+ torch_tensors = torch.stack(tensors)
34
+ min_vals, max_vals = torch.aminmax(torch_tensors, dim=0)
35
+ abs_max_val_per_ch = torch.max(-min_vals, max_vals)
36
+ mean = torch.mean(torch_tensors, dim=0)
37
+ for _ in range(n_params): # Generates n_params around the mean of the tensors
38
+ yield torch.randn_like(min_vals) * abs_max_val_per_ch + mean
39
+
40
+ def _calc(x, qtz_func, deqtz_func, **params):
41
+ x_ = x.transpose(0, 1)
42
+ x_ = qtz_func(x=x_, **params)
43
+ x_ = deqtz_func(x=x_, **params)
44
+ x_ = x_.transpose(0, 1)
45
+ return x_
46
+
47
+ assert "qtz_func" in kwargs, "qtz_func must be provided."
48
+ assert "deqtz_func" in kwargs, "deqtz_func must be provided."
49
+ assert "params_list" in kwargs, "params list must be provided."
50
+ assert "param" in kwargs, "param must be provided."
51
+
52
+ qtz_func = kwargs.get('qtz_func')
53
+ deqtz_func = kwargs.get('deqtz_func')
54
+ params_list = kwargs.get('params_list')
55
+ param = kwargs.get('param')
56
+
57
+ n_runs = 50 # Number of runs to try to find the best parameters
58
+ n_random_params = 50 # Number of random parameters to generate
59
+ n_best_to_pick = 5 # Number of best parameters to pick after each run
60
+ max_initial = 10000 # Maximum value to initialize the parameters
61
+
62
+ # Initializes the parameters
63
+ base_params = { p: init_ones(x, **kwargs) for p in params_list if p != param }
64
+ params = _build_initial_param(x, max_initial, n_random_params)
65
+
66
+ # Performs the search
67
+ for _ in range(n_runs):
68
+
69
+ best_params = []
70
+ for param_ in params:
71
+ try:
72
+ x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: param_})
73
+ loss_ones = nn.MSELoss()(x, x_)
74
+
75
+ if len(best_params) < n_best_to_pick:
76
+ best_params.append((param_, loss_ones.item()))
77
+ best_params = sorted(best_params, key=lambda x: x[1])
78
+ elif loss_ones < best_params[-1][1]:
79
+ best_params[-1] = (param_, loss_ones.item())
80
+ best_params = sorted(best_params, key=lambda x: x[1])
81
+
82
+ except Exception: # The parameters might not be valid for the function's domain
83
+ continue
84
+
85
+ # Generates new parameters around the mean
86
+ params = _search_param([p for p, _ in best_params], n_random_params)
87
+
88
+ # Checks if the best parameter is better than the init_ones
89
+ p_ones = init_ones(x, **kwargs)
90
+ x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: p_ones})
91
+ loss_ones = nn.MSELoss()(x, x_)
92
+
93
+ # Checks if the best parameter is better than the init_rand
94
+ p_rand = init_rand(x, **kwargs)
95
+ x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: p_rand})
96
+ loss_rand = nn.MSELoss()(x, x_)
97
+
98
+ if loss_rand < best_params[0][1] and loss_rand < loss_ones:
99
+ return p_rand
100
+ elif loss_ones < best_params[0][1] and loss_ones < loss_rand:
101
+ return p_ones
102
+ else:
103
+ return best_params[0][0]
104
+
105
+
106
+ def init_linear_scale( # Symmetric scale. From the study folder
107
+ x: torch.Tensor,
108
+ **kwargs: Dict[str, Any],
109
+ ) -> torch.Tensor:
110
+ assert "bits" in kwargs, "bits must be provided."
111
+ assert "params" in kwargs, "params must be provided."
112
+ assert "qtz_func" in kwargs, "qtz_func must be provided."
113
+
114
+ bits = kwargs.get('bits')
115
+ params = kwargs.get('params')
116
+ qtz_func = kwargs.get('qtz_func')
117
+
118
+ x_ = x.transpose(0, 1)
119
+ x_ = qtz_func(x=x_, **params, _s=init_ones(x, **kwargs))
120
+ x_ = x_.transpose(0, 1)
121
+
122
+ quant_min, quant_max = get_min_max_from_bits_signed(bits)
123
+ min_vals, max_vals = torch.aminmax(x_, dim=1)
124
+ min_vals = torch.min(min_vals, torch.zeros_like(min_vals))
125
+ max_vals = torch.max(max_vals, torch.zeros_like(max_vals))
126
+
127
+ eps = torch.finfo(torch.float32).eps
128
+
129
+ abs_max_val_per_ch = torch.max(-min_vals, max_vals)
130
+ scale = abs_max_val_per_ch / (float(quant_max - quant_min) / 2)
131
+
132
+ scale = torch.clamp(scale, min=eps).to(dtype=torch.float32, device=min_vals.device)
133
+
134
+ # Introduces some noise in scale
135
+ # If I don't introduce noise, the accuracy is going to be 0.0 and not learn anything
136
+ # If I don't introduce noise, the accuracy is going to be 0.0 and not learn anything.
137
+ # NOTE(diogo): This has been disproven. The noise does not help the learning process but I still
138
+ # left it here for future reference. Will be removed later.
139
+ # scale = scale + 0.01 * torch.randn_like(scale)
140
+
141
+ return scale
142
+
143
+
144
+ def init_params(x: torch.Tensor, **kwargs: Dict[str, Any]) -> Dict[str, nn.Parameter]:
145
+ params = {
146
+ '_0': init_space_search(x, qtz_func=quantization, deqtz_func=dequantization, param='_0', params_list=['_0', '_s'], **kwargs),
147
+ }
148
+ params['_s'] = init_linear_scale(x, params=params, qtz_func=quantization, **kwargs)
149
+ params = {k: nn.Parameter(v, requires_grad=False) for k, v in params.items()}
150
+
151
+ if 'post_init_hook' in kwargs:
152
+ kwargs['post_init_hook'](parameters=params)
153
+
154
+ params = learn_parameters(x, params,
155
+ qtz_func=quantization,
156
+ deqtz_func=dequantization,
157
+ bits=kwargs['bits'],
158
+ target_dtype=torch.int8,
159
+ epochs=500,
160
+ early_stop=False,
161
+ )
162
+ if 'post_train_hook' in kwargs:
163
+ kwargs['post_train_hook'](parameters=params)
164
+
165
+ return params
166
+
167
+
168
+ ############### Numpy Qtz ###############
169
+
170
+
171
+ def np_quantization(x, _0, _s):
172
+ return (np.divide(1, np_replace_num(_s, num=0, to=10000)) * np.arcsin(np_domain_guard((_0 * x), min=-0.99999, max=0.99999, nan=0)))
173
+
174
+
175
+ def np_dequantization(x, _0, _s):
176
+ return (np.divide(1, np_replace_num(_0, num=0, to=10000)) * np.sin((_s * x)))
177
+
178
+
179
+ def fit_func(x, _0, _s):
180
+ x_ = np_quantization(x, _0, _s)
181
+ x_ = np_dequantization(x_, _0, _s)
182
+ return x_
183
+
184
+
185
+
186
+ ############### HELPERS ###############
187
+
188
+ def domain_guard(
189
+ x: torch.Tensor,
190
+ min: float = None,
191
+ max: float = None,
192
+ posinf: float = None,
193
+ neginf: float = None,
194
+ nan: float = None
195
+ ) -> torch.Tensor:
196
+ """Guard a tensor to a valid domain."""
197
+ x = torch.nan_to_num(x, posinf=posinf, neginf=neginf, nan=nan)
198
+ if min is not None or max is not None:
199
+ x = torch.clamp(x, min=min, max=max)
200
+ return x
201
+
202
+
203
+ def replace_num(x: torch.Tensor, num: float, to: float) -> torch.Tensor:
204
+ """Replace a number in a tensor with another number.
205
+
206
+ Args:
207
+ x (torch.Tensor): The input tensor.
208
+ num (float): The number to replace.
209
+ to (float): The number to replace with.
210
+
211
+ Returns:
212
+ torch.Tensor: The tensor with the number replaced.
213
+ """
214
+ return torch.where(x == num, to, x)
215
+
216
+
217
+ def guarded_torch_power(x: torch.Tensor, exp: float) -> torch.Tensor:
218
+ """Guard the power operation to a valid domain."""
219
+ return torch.pow(x, exp) if exp >= 1 else torch.pow(torch.relu(x), exp)
220
+
221
+
222
+ def init_ones(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor:
223
+ val = torch.amin(x, dim=1)
224
+ return torch.ones_like(val, dtype=torch.float32, device=x.device)
225
+
226
+
227
+ def init_rand(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor:
228
+ val = torch.amin(x, dim=1)
229
+ return torch.randn_like(val, dtype=torch.float32, device=x.device)
230
+
231
+
232
+ def init_space_search(
233
+ x: torch.Tensor,
234
+ **kwargs: Dict[str, Any],
235
+ ) -> torch.Tensor:
236
+
237
+ def _build_initial_param(tensor: torch.Tensor, max_initial: int, n_params: int):
238
+ """Generates the initial set of parameters. The first iteration generates 10 times more parameters."""
239
+ for _ in range(n_params * 10): # The first iteration generates 10 times more parameters
240
+ yield init_rand(tensor) * max_initial # Generates n_params in range [-max_initial, max_initial]
241
+
242
+ def _search_param(tensors: List[torch.tensor], n_params):
243
+ """Takes the best parameters and generates new parameters around the mean of the best parameters."""
244
+ torch_tensors = torch.stack(tensors)
245
+ min_vals, max_vals = torch.aminmax(torch_tensors, dim=0)
246
+ abs_max_val_per_ch = torch.max(-min_vals, max_vals)
247
+ mean = torch.mean(torch_tensors, dim=0)
248
+ for _ in range(n_params): # Generates n_params around the mean of the tensors
249
+ yield torch.randn_like(min_vals) * abs_max_val_per_ch + mean
250
+
251
+ def _calc(x, qtz_func, deqtz_func, **params):
252
+ x_ = x.transpose(0, 1)
253
+ x_ = qtz_func(x=x_, **params)
254
+ x_ = deqtz_func(x=x_, **params)
255
+ x_ = x_.transpose(0, 1)
256
+ return x_
257
+
258
+ assert "qtz_func" in kwargs, "qtz_func must be provided."
259
+ assert "deqtz_func" in kwargs, "deqtz_func must be provided."
260
+ assert "params_list" in kwargs, "params list must be provided."
261
+ assert "param" in kwargs, "param must be provided."
262
+
263
+ qtz_func = kwargs.get('qtz_func')
264
+ deqtz_func = kwargs.get('deqtz_func')
265
+ params_list = kwargs.get('params_list')
266
+ param = kwargs.get('param')
267
+
268
+ n_runs = 50 # Number of runs to try to find the best parameters
269
+ n_random_params = 50 # Number of random parameters to generate
270
+ n_best_to_pick = 5 # Number of best parameters to pick after each run
271
+ max_initial = 10000 # Maximum value to initialize the parameters
272
+
273
+ # Initializes the parameters
274
+ base_params = { p: init_ones(x, **kwargs) for p in params_list if p != param }
275
+ params = _build_initial_param(x, max_initial, n_random_params)
276
+
277
+ # Performs the search
278
+ for _ in range(n_runs):
279
+
280
+ best_params = []
281
+ for param_ in params:
282
+ try:
283
+ x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: param_})
284
+ loss_ones = nn.MSELoss()(x, x_)
285
+
286
+ if len(best_params) < n_best_to_pick:
287
+ best_params.append((param_, loss_ones.item()))
288
+ best_params = sorted(best_params, key=lambda x: x[1])
289
+ elif loss_ones < best_params[-1][1]:
290
+ best_params[-1] = (param_, loss_ones.item())
291
+ best_params = sorted(best_params, key=lambda x: x[1])
292
+
293
+ except Exception: # The parameters might not be valid for the function's domain
294
+ continue
295
+
296
+ # Generates new parameters around the mean
297
+ params = _search_param([p for p, _ in best_params], n_random_params)
298
+
299
+ # Checks if the best parameter is better than the init_ones
300
+ p_ones = init_ones(x, **kwargs)
301
+ x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: p_ones})
302
+ loss_ones = nn.MSELoss()(x, x_)
303
+
304
+ # Checks if the best parameter is better than the init_rand
305
+ p_rand = init_rand(x, **kwargs)
306
+ x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: p_rand})
307
+ loss_rand = nn.MSELoss()(x, x_)
308
+
309
+ if loss_rand < best_params[0][1] and loss_rand < loss_ones:
310
+ return p_rand
311
+ elif loss_ones < best_params[0][1] and loss_ones < loss_rand:
312
+ return p_ones
313
+ else:
314
+ return best_params[0][0]
315
+
316
+
317
+ def init_linear_scale( # Symmetric scale. From the study folder
318
+ x: torch.Tensor,
319
+ **kwargs: Dict[str, Any],
320
+ ) -> torch.Tensor:
321
+ assert "bits" in kwargs, "bits must be provided."
322
+ assert "params" in kwargs, "params must be provided."
323
+ assert "qtz_func" in kwargs, "qtz_func must be provided."
324
+
325
+ bits = kwargs.get('bits')
326
+ params = kwargs.get('params')
327
+ qtz_func = kwargs.get('qtz_func')
328
+
329
+ x_ = x.transpose(0, 1)
330
+ x_ = qtz_func(x=x_, **params, _s=init_ones(x, **kwargs))
331
+ x_ = x_.transpose(0, 1)
332
+
333
+ quant_min, quant_max = get_min_max_from_bits_signed(bits)
334
+ min_vals, max_vals = torch.aminmax(x_, dim=1)
335
+ min_vals = torch.min(min_vals, torch.zeros_like(min_vals))
336
+ max_vals = torch.max(max_vals, torch.zeros_like(max_vals))
337
+
338
+ eps = torch.finfo(torch.float32).eps
339
+
340
+ abs_max_val_per_ch = torch.max(-min_vals, max_vals)
341
+ scale = abs_max_val_per_ch / (float(quant_max - quant_min) / 2)
342
+
343
+ scale = torch.clamp(scale, min=eps).to(dtype=torch.float32, device=min_vals.device)
344
+
345
+ # Introduces some noise in scale
346
+ # If I don't introduce noise, the accuracy is going to be 0.0 and not learn anything
347
+ # If I don't introduce noise, the accuracy is going to be 0.0 and not learn anything.
348
+ # NOTE(diogo): This has been disproven. The noise does not help the learning process but I still
349
+ # left it here for future reference. Will be removed later.
350
+ # scale = scale + 0.01 * torch.randn_like(scale)
351
+
352
+ return scale
353
+
354
+
355
+ def init_non_linear_regression_fit(
356
+ x: torch.Tensor,
357
+ **kwargs: Dict[str, Any],
358
+ ) -> torch.Tensor:
359
+
360
+ assert "params_list" in kwargs, "params list must be provided."
361
+ assert "np_fit_func" in kwargs, "np_fit_func must be provided."
362
+ assert "p0" in kwargs, "p0 must be provided."
363
+ np_fit_func = kwargs.get('np_fit_func')
364
+ params_list = kwargs.get('params_list')
365
+ p0 = kwargs.get('p0')
366
+
367
+ def _fit(xdata: np.ndarray, ydata: np.ndarray, func: Callable, p0: List[float]):
368
+ popt, _ = curve_fit(
369
+ func,
370
+ xdata,
371
+ ydata,
372
+ maxfev=1000,
373
+ p0=p0,
374
+ method='lm'
375
+ )
376
+ return popt
377
+
378
+ # 1. Needs to convert the torch tensor to numpy tensor
379
+ xdata = x.cpu().numpy()
380
+
381
+ # 2. Sorts the data so that it makes it easier to fit to it
382
+ sorted_xdata = np.sort(xdata, axis=-1)
383
+
384
+ p0 = {k: v.cpu().numpy() for k, v in p0.items()}
385
+ params_list = sorted(params_list) # We need to make sure that it matches the numpy fit func arg order
386
+
387
+ # 3. Finds the best parameters for each channel
388
+ try:
389
+ params = []
390
+ for i in range(sorted_xdata.shape[0]):
391
+ xdata_ = sorted_xdata[i]
392
+ p0_ = [p0[p][i] for p in params_list]
393
+ ch_params = _fit(xdata_, xdata_, np_fit_func, p0_)
394
+ params.append(ch_params)
395
+
396
+ # 4. Builds the parameters
397
+ result = {}
398
+ for i, p in enumerate(params_list):
399
+ result[p] = torch.tensor([p_[i] for p_ in params], dtype=torch.float32).to(x.device)
400
+
401
+ return result
402
+
403
+ except ValueError as e:
404
+ print(f"Could not fit the function with error: {e}")
405
+ print(f"Using fallback result...")
406
+ return {
407
+ k: torch.tensor(v, dtype=torch.float32).to(x.device) for k, v in p0.items()
408
+ }
409
+
410
+
411
+ def init_zeros(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor:
412
+ val = torch.amin(x, dim=1)
413
+ return torch.zeros_like(val, dtype=torch.float32, device=x.device)
414
+
415
+
416
+ def init_inner_scale(tensor: torch.Tensor, _min: float = torch.inf, _max: float = torch.inf) -> torch.Tensor:
417
+ # Calculate the original minimum and maximum values
418
+ min_vals, max_vals = torch.aminmax(tensor, dim=-1)
419
+ x_min = torch.min(min_vals, torch.zeros_like(min_vals))
420
+ x_max = torch.max(max_vals, torch.zeros_like(max_vals))
421
+
422
+ if _max is torch.inf: # We do not need to scale the tensor. Just need to move it
423
+ return torch.ones_like(x_min)
424
+
425
+ # Calculate the scale factor
426
+ scale = (_max - _min) / (x_max - x_min)
427
+ return scale
428
+
429
+
430
+
431
+ ############## Quant ###############
432
+
433
+ @torch.enable_grad()
434
+ def learn_parameters(
435
+ x: torch.Tensor,
436
+ params: Dict[str, nn.Parameter],
437
+ qtz_func: nn.Module,
438
+ deqtz_func: nn.Module,
439
+ bits: int,
440
+ target_dtype: torch.dtype,
441
+ epochs: int = 1000,
442
+ early_stop: bool = True,
443
+ do_report: bool = False
444
+ ) -> Tuple[Dict[str, nn.Parameter], torch.Tensor]:
445
+ loss_fn = nn.MSELoss()
446
+
447
+ # Determines the initial learning rate by computing the initial loss and multiplying it by
448
+ # the order of magnitude of the loss divided by 2
449
+ quant = quantize(x, params, qtz_func, bits, target_dtype)
450
+ dequant = dequantize(quant, params, deqtz_func, bits, x.dtype)
451
+ loss = loss_fn(x, dequant)
452
+
453
+ base_lr = 0.1
454
+ exponent = int(np.floor(np.log10(loss.item())))
455
+ lr = base_lr * (10 ** (exponent // 2))
456
+
457
+ # Requires gradients in the parameters
458
+ for p in params.values():
459
+ p.requires_grad = True
460
+ p.grad = None
461
+
462
+ param_keys = list(params.keys())
463
+ param_values = list(params.values())
464
+
465
+ # Defines optimizer and loss function
466
+ optimizer = torch.optim.Adam(param_values, lr=lr)
467
+ scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=epochs // 10, T_mult=1, eta_min=lr * 0.1, last_epoch=-1)
468
+
469
+ # Contains the best loss and the best parameters
470
+ best_loss = float("inf")
471
+ best_params = None
472
+
473
+ # Used to stop the search early
474
+ min_delta = 1e-7
475
+ acc_loss = []
476
+ percent_epochs_before_stop = 0.1
477
+
478
+ for i in range(epochs):
479
+ optimizer.zero_grad()
480
+
481
+ quant = quantize(x, params, qtz_func, bits, target_dtype)
482
+ dequant = dequantize(quant, params, deqtz_func, bits, x.dtype)
483
+ loss = loss_fn(x, dequant)
484
+
485
+ if loss.isnan() or loss.isinf():
486
+ raise Exception("Loss is NaN or Inf. Stopping the search.")
487
+
488
+ loss.backward()
489
+ optimizer.step()
490
+ scheduler.step()
491
+
492
+ acc_loss.append(loss.item())
493
+
494
+ # Reports loss every 10 steps
495
+ if i % 10 == 0 and do_report:
496
+ print(f"Epoch {i}: Loss {loss.item()}")
497
+
498
+ # Optimizes the parameter search by storing the best loss and the parameters
499
+ if loss.item() < best_loss:
500
+ best_loss = loss.item()
501
+ best_params = copy.deepcopy({
502
+ k: v for k, v in params.items() if k in param_keys
503
+ })
504
+
505
+ # We also stop the search if the loss has not considerably during the last 10% epochs
506
+ if early_stop:
507
+ epochs_before_stop = int(epochs * percent_epochs_before_stop)
508
+ if i > epochs_before_stop and abs(acc_loss[i - epochs_before_stop] - acc_loss[i]) < min_delta:
509
+ break
510
+
511
+ # No longer requires gradients in the parameters
512
+ for p in best_params.values():
513
+ p.requires_grad = False
514
+ p.grad = None
515
+
516
+ if do_report:
517
+ print(f"Best loss: {best_loss}")
518
+ return best_params, acc_loss
519
+ else:
520
+ return best_params
521
+
522
+
523
+ def quantize(
524
+ x: torch.Tensor,
525
+ params: Dict[str, nn.Parameter],
526
+ func: nn.Module,
527
+ bits: int,
528
+ target_dtype: torch.dtype = torch.int8
529
+ ) -> torch.Tensor:
530
+ quant_min, quant_max = get_min_max_from_bits_signed(bits)
531
+ x = x.transpose(0, 1) # Aligns shapes
532
+ x = func(x=x, **params)
533
+ x = x.transpose(0, 1)
534
+ x = torch.clamp(round_func_BPDA(x), quant_min, quant_max).to(target_dtype)
535
+ return x
536
+
537
+
538
+ def dequantize(
539
+ x: torch.Tensor,
540
+ params: Dict[str, nn.Parameter],
541
+ func: nn.Module,
542
+ bits: int,
543
+ out_dtype: torch.dtype
544
+ ) -> torch.Tensor:
545
+ x = x.to(dtype=out_dtype)
546
+ x = x.transpose(0, 1)
547
+ x = func(x=x, **params)
548
+ x = x.transpose(0, 1)
549
+ return x
550
+
551
+
552
+ def round_func_BPDA(input):
553
+ # This is equivalent to replacing round function (non-differentiable) with
554
+ # an identity function (differentiable) only when backward.
555
+ forward_value = torch.round(input)
556
+ out = input.clone()
557
+ out.data = forward_value.data
558
+ return out
559
+
560
+
561
+ def get_min_max_from_bits_signed(bit_width: int) -> Tuple[int, int]:
562
+ return -2 ** (bit_width - 1), 2 ** (bit_width - 1) - 1
563
+
564
+
565
+
566
+ ############## Numpy ###############
567
+
568
+ def np_domain_guard(
569
+ x: np.ndarray,
570
+ min: float = None,
571
+ max: float = None,
572
+ posinf: float = None,
573
+ neginf: float = None,
574
+ nan: float = None
575
+ ) -> np.ndarray:
576
+ """Guard a tensor to a valid domain."""
577
+ x = np.nan_to_num(x, posinf=posinf, neginf=neginf, nan=nan)
578
+ if min is not None or max is not None:
579
+ x = np.clip(x, min, max)
580
+ return x
581
+
582
+
583
+ def np_replace_num(x: np.ndarray, num: float, to: float) -> np.ndarray:
584
+ """Replace a number in a tensor with another number.
585
+
586
+ Args:
587
+ x (np.ndarray): The input tensor.
588
+ num (float): The number to replace.
589
+ to (float): The number to replace with.
590
+
591
+ Returns:
592
+ np.ndarray: The tensor with the number replaced.
593
+ """
594
+ return np.where(x == num, to, x)
595
+
596
+
597
+ def np_guarded_power(x: np.ndarray, exp: float) -> np.ndarray:
598
+ """Guard the power operation to a valid domain."""
599
+ return np.power(x, exp) if exp >= 1 else np.power(np.maximum(x, 0), exp)
600
+
fn_gen/rnd_search_t_cos/1/loss.png ADDED
fn_gen/rnd_search_t_cos/1/quantization.png ADDED
fn_gen/rnd_search_t_cos/10/distortion.png ADDED
fn_gen/rnd_search_t_cos/10/expressions.txt ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ log(_0*x)/_s
2
+ exp(_s*x)/_0
fn_gen/rnd_search_t_cos/10/fn.py ADDED
@@ -0,0 +1,600 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import torch
4
+ from torch import amin # Necessary for arcsin
5
+ import copy
6
+ import torch.nn as nn
7
+ import numpy as np
8
+
9
+ from scipy.optimize import curve_fit
10
+ from typing import Dict, Any, Tuple, List, Callable
11
+
12
+
13
+ def quantization(x, **params):
14
+ return (torch.div(1, replace_num(params['_s'], num=0, to=10000)) * torch.log(domain_guard((params['_0'] * x), min=1e-5, nan=1e-5)))
15
+
16
+
17
+ def dequantization(x, **params):
18
+ return (torch.div(1, replace_num(params['_0'], num=0, to=10000)) * torch.exp((params['_s'] * x)))
19
+
20
+
21
+ def init_space_search(
22
+ x: torch.Tensor,
23
+ **kwargs: Dict[str, Any],
24
+ ) -> torch.Tensor:
25
+
26
+ def _build_initial_param(tensor: torch.Tensor, max_initial: int, n_params: int):
27
+ """Generates the initial set of parameters. The first iteration generates 10 times more parameters."""
28
+ for _ in range(n_params * 10): # The first iteration generates 10 times more parameters
29
+ yield init_rand(tensor) * max_initial # Generates n_params in range [-max_initial, max_initial]
30
+
31
+ def _search_param(tensors: List[torch.tensor], n_params):
32
+ """Takes the best parameters and generates new parameters around the mean of the best parameters."""
33
+ torch_tensors = torch.stack(tensors)
34
+ min_vals, max_vals = torch.aminmax(torch_tensors, dim=0)
35
+ abs_max_val_per_ch = torch.max(-min_vals, max_vals)
36
+ mean = torch.mean(torch_tensors, dim=0)
37
+ for _ in range(n_params): # Generates n_params around the mean of the tensors
38
+ yield torch.randn_like(min_vals) * abs_max_val_per_ch + mean
39
+
40
+ def _calc(x, qtz_func, deqtz_func, **params):
41
+ x_ = x.transpose(0, 1)
42
+ x_ = qtz_func(x=x_, **params)
43
+ x_ = deqtz_func(x=x_, **params)
44
+ x_ = x_.transpose(0, 1)
45
+ return x_
46
+
47
+ assert "qtz_func" in kwargs, "qtz_func must be provided."
48
+ assert "deqtz_func" in kwargs, "deqtz_func must be provided."
49
+ assert "params_list" in kwargs, "params list must be provided."
50
+ assert "param" in kwargs, "param must be provided."
51
+
52
+ qtz_func = kwargs.get('qtz_func')
53
+ deqtz_func = kwargs.get('deqtz_func')
54
+ params_list = kwargs.get('params_list')
55
+ param = kwargs.get('param')
56
+
57
+ n_runs = 50 # Number of runs to try to find the best parameters
58
+ n_random_params = 50 # Number of random parameters to generate
59
+ n_best_to_pick = 5 # Number of best parameters to pick after each run
60
+ max_initial = 10000 # Maximum value to initialize the parameters
61
+
62
+ # Initializes the parameters
63
+ base_params = { p: init_ones(x, **kwargs) for p in params_list if p != param }
64
+ params = _build_initial_param(x, max_initial, n_random_params)
65
+
66
+ # Performs the search
67
+ for _ in range(n_runs):
68
+
69
+ best_params = []
70
+ for param_ in params:
71
+ try:
72
+ x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: param_})
73
+ loss_ones = nn.MSELoss()(x, x_)
74
+
75
+ if len(best_params) < n_best_to_pick:
76
+ best_params.append((param_, loss_ones.item()))
77
+ best_params = sorted(best_params, key=lambda x: x[1])
78
+ elif loss_ones < best_params[-1][1]:
79
+ best_params[-1] = (param_, loss_ones.item())
80
+ best_params = sorted(best_params, key=lambda x: x[1])
81
+
82
+ except Exception: # The parameters might not be valid for the function's domain
83
+ continue
84
+
85
+ # Generates new parameters around the mean
86
+ params = _search_param([p for p, _ in best_params], n_random_params)
87
+
88
+ # Checks if the best parameter is better than the init_ones
89
+ p_ones = init_ones(x, **kwargs)
90
+ x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: p_ones})
91
+ loss_ones = nn.MSELoss()(x, x_)
92
+
93
+ # Checks if the best parameter is better than the init_rand
94
+ p_rand = init_rand(x, **kwargs)
95
+ x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: p_rand})
96
+ loss_rand = nn.MSELoss()(x, x_)
97
+
98
+ if loss_rand < best_params[0][1] and loss_rand < loss_ones:
99
+ return p_rand
100
+ elif loss_ones < best_params[0][1] and loss_ones < loss_rand:
101
+ return p_ones
102
+ else:
103
+ return best_params[0][0]
104
+
105
+
106
+ def init_linear_scale( # Symmetric scale. From the study folder
107
+ x: torch.Tensor,
108
+ **kwargs: Dict[str, Any],
109
+ ) -> torch.Tensor:
110
+ assert "bits" in kwargs, "bits must be provided."
111
+ assert "params" in kwargs, "params must be provided."
112
+ assert "qtz_func" in kwargs, "qtz_func must be provided."
113
+
114
+ bits = kwargs.get('bits')
115
+ params = kwargs.get('params')
116
+ qtz_func = kwargs.get('qtz_func')
117
+
118
+ x_ = x.transpose(0, 1)
119
+ x_ = qtz_func(x=x_, **params, _s=init_ones(x, **kwargs))
120
+ x_ = x_.transpose(0, 1)
121
+
122
+ quant_min, quant_max = get_min_max_from_bits_signed(bits)
123
+ min_vals, max_vals = torch.aminmax(x_, dim=1)
124
+ min_vals = torch.min(min_vals, torch.zeros_like(min_vals))
125
+ max_vals = torch.max(max_vals, torch.zeros_like(max_vals))
126
+
127
+ eps = torch.finfo(torch.float32).eps
128
+
129
+ abs_max_val_per_ch = torch.max(-min_vals, max_vals)
130
+ scale = abs_max_val_per_ch / (float(quant_max - quant_min) / 2)
131
+
132
+ scale = torch.clamp(scale, min=eps).to(dtype=torch.float32, device=min_vals.device)
133
+
134
+ # Introduces some noise in scale
135
+ # If I don't introduce noise, the accuracy is going to be 0.0 and not learn anything
136
+ # If I don't introduce noise, the accuracy is going to be 0.0 and not learn anything.
137
+ # NOTE(diogo): This has been disproven. The noise does not help the learning process but I still
138
+ # left it here for future reference. Will be removed later.
139
+ # scale = scale + 0.01 * torch.randn_like(scale)
140
+
141
+ return scale
142
+
143
+
144
+ def init_params(x: torch.Tensor, **kwargs: Dict[str, Any]) -> Dict[str, nn.Parameter]:
145
+ params = {
146
+ '_0': init_space_search(x, qtz_func=quantization, deqtz_func=dequantization, param='_0', params_list=['_0', '_s'], **kwargs),
147
+ }
148
+ params['_s'] = init_linear_scale(x, params=params, qtz_func=quantization, **kwargs)
149
+ params = {k: nn.Parameter(v, requires_grad=False) for k, v in params.items()}
150
+
151
+ if 'post_init_hook' in kwargs:
152
+ kwargs['post_init_hook'](parameters=params)
153
+
154
+ params = learn_parameters(x, params,
155
+ qtz_func=quantization,
156
+ deqtz_func=dequantization,
157
+ bits=kwargs['bits'],
158
+ target_dtype=torch.int8,
159
+ epochs=500,
160
+ early_stop=False,
161
+ )
162
+ if 'post_train_hook' in kwargs:
163
+ kwargs['post_train_hook'](parameters=params)
164
+
165
+ return params
166
+
167
+
168
+ ############### Numpy Qtz ###############
169
+
170
+
171
+ def np_quantization(x, _0, _s):
172
+ return (np.divide(1, np_replace_num(_s, num=0, to=10000)) * np.log(np_domain_guard((_0 * x), min=1e-5, nan=1e-5)))
173
+
174
+
175
+ def np_dequantization(x, _0, _s):
176
+ return (np.divide(1, np_replace_num(_0, num=0, to=10000)) * np.exp((_s * x)))
177
+
178
+
179
+ def fit_func(x, _0, _s):
180
+ x_ = np_quantization(x, _0, _s)
181
+ x_ = np_dequantization(x_, _0, _s)
182
+ return x_
183
+
184
+
185
+
186
+ ############### HELPERS ###############
187
+
188
+ def domain_guard(
189
+ x: torch.Tensor,
190
+ min: float = None,
191
+ max: float = None,
192
+ posinf: float = None,
193
+ neginf: float = None,
194
+ nan: float = None
195
+ ) -> torch.Tensor:
196
+ """Guard a tensor to a valid domain."""
197
+ x = torch.nan_to_num(x, posinf=posinf, neginf=neginf, nan=nan)
198
+ if min is not None or max is not None:
199
+ x = torch.clamp(x, min=min, max=max)
200
+ return x
201
+
202
+
203
+ def replace_num(x: torch.Tensor, num: float, to: float) -> torch.Tensor:
204
+ """Replace a number in a tensor with another number.
205
+
206
+ Args:
207
+ x (torch.Tensor): The input tensor.
208
+ num (float): The number to replace.
209
+ to (float): The number to replace with.
210
+
211
+ Returns:
212
+ torch.Tensor: The tensor with the number replaced.
213
+ """
214
+ return torch.where(x == num, to, x)
215
+
216
+
217
+ def guarded_torch_power(x: torch.Tensor, exp: float) -> torch.Tensor:
218
+ """Guard the power operation to a valid domain."""
219
+ return torch.pow(x, exp) if exp >= 1 else torch.pow(torch.relu(x), exp)
220
+
221
+
222
+ def init_ones(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor:
223
+ val = torch.amin(x, dim=1)
224
+ return torch.ones_like(val, dtype=torch.float32, device=x.device)
225
+
226
+
227
+ def init_rand(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor:
228
+ val = torch.amin(x, dim=1)
229
+ return torch.randn_like(val, dtype=torch.float32, device=x.device)
230
+
231
+
232
+ def init_space_search(
233
+ x: torch.Tensor,
234
+ **kwargs: Dict[str, Any],
235
+ ) -> torch.Tensor:
236
+
237
+ def _build_initial_param(tensor: torch.Tensor, max_initial: int, n_params: int):
238
+ """Generates the initial set of parameters. The first iteration generates 10 times more parameters."""
239
+ for _ in range(n_params * 10): # The first iteration generates 10 times more parameters
240
+ yield init_rand(tensor) * max_initial # Generates n_params in range [-max_initial, max_initial]
241
+
242
+ def _search_param(tensors: List[torch.tensor], n_params):
243
+ """Takes the best parameters and generates new parameters around the mean of the best parameters."""
244
+ torch_tensors = torch.stack(tensors)
245
+ min_vals, max_vals = torch.aminmax(torch_tensors, dim=0)
246
+ abs_max_val_per_ch = torch.max(-min_vals, max_vals)
247
+ mean = torch.mean(torch_tensors, dim=0)
248
+ for _ in range(n_params): # Generates n_params around the mean of the tensors
249
+ yield torch.randn_like(min_vals) * abs_max_val_per_ch + mean
250
+
251
+ def _calc(x, qtz_func, deqtz_func, **params):
252
+ x_ = x.transpose(0, 1)
253
+ x_ = qtz_func(x=x_, **params)
254
+ x_ = deqtz_func(x=x_, **params)
255
+ x_ = x_.transpose(0, 1)
256
+ return x_
257
+
258
+ assert "qtz_func" in kwargs, "qtz_func must be provided."
259
+ assert "deqtz_func" in kwargs, "deqtz_func must be provided."
260
+ assert "params_list" in kwargs, "params list must be provided."
261
+ assert "param" in kwargs, "param must be provided."
262
+
263
+ qtz_func = kwargs.get('qtz_func')
264
+ deqtz_func = kwargs.get('deqtz_func')
265
+ params_list = kwargs.get('params_list')
266
+ param = kwargs.get('param')
267
+
268
+ n_runs = 50 # Number of runs to try to find the best parameters
269
+ n_random_params = 50 # Number of random parameters to generate
270
+ n_best_to_pick = 5 # Number of best parameters to pick after each run
271
+ max_initial = 10000 # Maximum value to initialize the parameters
272
+
273
+ # Initializes the parameters
274
+ base_params = { p: init_ones(x, **kwargs) for p in params_list if p != param }
275
+ params = _build_initial_param(x, max_initial, n_random_params)
276
+
277
+ # Performs the search
278
+ for _ in range(n_runs):
279
+
280
+ best_params = []
281
+ for param_ in params:
282
+ try:
283
+ x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: param_})
284
+ loss_ones = nn.MSELoss()(x, x_)
285
+
286
+ if len(best_params) < n_best_to_pick:
287
+ best_params.append((param_, loss_ones.item()))
288
+ best_params = sorted(best_params, key=lambda x: x[1])
289
+ elif loss_ones < best_params[-1][1]:
290
+ best_params[-1] = (param_, loss_ones.item())
291
+ best_params = sorted(best_params, key=lambda x: x[1])
292
+
293
+ except Exception: # The parameters might not be valid for the function's domain
294
+ continue
295
+
296
+ # Generates new parameters around the mean
297
+ params = _search_param([p for p, _ in best_params], n_random_params)
298
+
299
+ # Checks if the best parameter is better than the init_ones
300
+ p_ones = init_ones(x, **kwargs)
301
+ x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: p_ones})
302
+ loss_ones = nn.MSELoss()(x, x_)
303
+
304
+ # Checks if the best parameter is better than the init_rand
305
+ p_rand = init_rand(x, **kwargs)
306
+ x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: p_rand})
307
+ loss_rand = nn.MSELoss()(x, x_)
308
+
309
+ if loss_rand < best_params[0][1] and loss_rand < loss_ones:
310
+ return p_rand
311
+ elif loss_ones < best_params[0][1] and loss_ones < loss_rand:
312
+ return p_ones
313
+ else:
314
+ return best_params[0][0]
315
+
316
+
317
+ def init_linear_scale( # Symmetric scale. From the study folder
318
+ x: torch.Tensor,
319
+ **kwargs: Dict[str, Any],
320
+ ) -> torch.Tensor:
321
+ assert "bits" in kwargs, "bits must be provided."
322
+ assert "params" in kwargs, "params must be provided."
323
+ assert "qtz_func" in kwargs, "qtz_func must be provided."
324
+
325
+ bits = kwargs.get('bits')
326
+ params = kwargs.get('params')
327
+ qtz_func = kwargs.get('qtz_func')
328
+
329
+ x_ = x.transpose(0, 1)
330
+ x_ = qtz_func(x=x_, **params, _s=init_ones(x, **kwargs))
331
+ x_ = x_.transpose(0, 1)
332
+
333
+ quant_min, quant_max = get_min_max_from_bits_signed(bits)
334
+ min_vals, max_vals = torch.aminmax(x_, dim=1)
335
+ min_vals = torch.min(min_vals, torch.zeros_like(min_vals))
336
+ max_vals = torch.max(max_vals, torch.zeros_like(max_vals))
337
+
338
+ eps = torch.finfo(torch.float32).eps
339
+
340
+ abs_max_val_per_ch = torch.max(-min_vals, max_vals)
341
+ scale = abs_max_val_per_ch / (float(quant_max - quant_min) / 2)
342
+
343
+ scale = torch.clamp(scale, min=eps).to(dtype=torch.float32, device=min_vals.device)
344
+
345
+ # Introduces some noise in scale
346
+ # If I don't introduce noise, the accuracy is going to be 0.0 and not learn anything
347
+ # If I don't introduce noise, the accuracy is going to be 0.0 and not learn anything.
348
+ # NOTE(diogo): This has been disproven. The noise does not help the learning process but I still
349
+ # left it here for future reference. Will be removed later.
350
+ # scale = scale + 0.01 * torch.randn_like(scale)
351
+
352
+ return scale
353
+
354
+
355
+ def init_non_linear_regression_fit(
356
+ x: torch.Tensor,
357
+ **kwargs: Dict[str, Any],
358
+ ) -> torch.Tensor:
359
+
360
+ assert "params_list" in kwargs, "params list must be provided."
361
+ assert "np_fit_func" in kwargs, "np_fit_func must be provided."
362
+ assert "p0" in kwargs, "p0 must be provided."
363
+ np_fit_func = kwargs.get('np_fit_func')
364
+ params_list = kwargs.get('params_list')
365
+ p0 = kwargs.get('p0')
366
+
367
+ def _fit(xdata: np.ndarray, ydata: np.ndarray, func: Callable, p0: List[float]):
368
+ popt, _ = curve_fit(
369
+ func,
370
+ xdata,
371
+ ydata,
372
+ maxfev=1000,
373
+ p0=p0,
374
+ method='lm'
375
+ )
376
+ return popt
377
+
378
+ # 1. Needs to convert the torch tensor to numpy tensor
379
+ xdata = x.cpu().numpy()
380
+
381
+ # 2. Sorts the data so that it makes it easier to fit to it
382
+ sorted_xdata = np.sort(xdata, axis=-1)
383
+
384
+ p0 = {k: v.cpu().numpy() for k, v in p0.items()}
385
+ params_list = sorted(params_list) # We need to make sure that it matches the numpy fit func arg order
386
+
387
+ # 3. Finds the best parameters for each channel
388
+ try:
389
+ params = []
390
+ for i in range(sorted_xdata.shape[0]):
391
+ xdata_ = sorted_xdata[i]
392
+ p0_ = [p0[p][i] for p in params_list]
393
+ ch_params = _fit(xdata_, xdata_, np_fit_func, p0_)
394
+ params.append(ch_params)
395
+
396
+ # 4. Builds the parameters
397
+ result = {}
398
+ for i, p in enumerate(params_list):
399
+ result[p] = torch.tensor([p_[i] for p_ in params], dtype=torch.float32).to(x.device)
400
+
401
+ return result
402
+
403
+ except ValueError as e:
404
+ print(f"Could not fit the function with error: {e}")
405
+ print(f"Using fallback result...")
406
+ return {
407
+ k: torch.tensor(v, dtype=torch.float32).to(x.device) for k, v in p0.items()
408
+ }
409
+
410
+
411
+ def init_zeros(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor:
412
+ val = torch.amin(x, dim=1)
413
+ return torch.zeros_like(val, dtype=torch.float32, device=x.device)
414
+
415
+
416
+ def init_inner_scale(tensor: torch.Tensor, _min: float = torch.inf, _max: float = torch.inf) -> torch.Tensor:
417
+ # Calculate the original minimum and maximum values
418
+ min_vals, max_vals = torch.aminmax(tensor, dim=-1)
419
+ x_min = torch.min(min_vals, torch.zeros_like(min_vals))
420
+ x_max = torch.max(max_vals, torch.zeros_like(max_vals))
421
+
422
+ if _max is torch.inf: # We do not need to scale the tensor. Just need to move it
423
+ return torch.ones_like(x_min)
424
+
425
+ # Calculate the scale factor
426
+ scale = (_max - _min) / (x_max - x_min)
427
+ return scale
428
+
429
+
430
+
431
+ ############## Quant ###############
432
+
433
+ @torch.enable_grad()
434
+ def learn_parameters(
435
+ x: torch.Tensor,
436
+ params: Dict[str, nn.Parameter],
437
+ qtz_func: nn.Module,
438
+ deqtz_func: nn.Module,
439
+ bits: int,
440
+ target_dtype: torch.dtype,
441
+ epochs: int = 1000,
442
+ early_stop: bool = True,
443
+ do_report: bool = False
444
+ ) -> Tuple[Dict[str, nn.Parameter], torch.Tensor]:
445
+ loss_fn = nn.MSELoss()
446
+
447
+ # Determines the initial learning rate by computing the initial loss and multiplying it by
448
+ # the order of magnitude of the loss divided by 2
449
+ quant = quantize(x, params, qtz_func, bits, target_dtype)
450
+ dequant = dequantize(quant, params, deqtz_func, bits, x.dtype)
451
+ loss = loss_fn(x, dequant)
452
+
453
+ base_lr = 0.1
454
+ exponent = int(np.floor(np.log10(loss.item())))
455
+ lr = base_lr * (10 ** (exponent // 2))
456
+
457
+ # Requires gradients in the parameters
458
+ for p in params.values():
459
+ p.requires_grad = True
460
+ p.grad = None
461
+
462
+ param_keys = list(params.keys())
463
+ param_values = list(params.values())
464
+
465
+ # Defines optimizer and loss function
466
+ optimizer = torch.optim.Adam(param_values, lr=lr)
467
+ scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=epochs // 10, T_mult=1, eta_min=lr * 0.1, last_epoch=-1)
468
+
469
+ # Contains the best loss and the best parameters
470
+ best_loss = float("inf")
471
+ best_params = None
472
+
473
+ # Used to stop the search early
474
+ min_delta = 1e-7
475
+ acc_loss = []
476
+ percent_epochs_before_stop = 0.1
477
+
478
+ for i in range(epochs):
479
+ optimizer.zero_grad()
480
+
481
+ quant = quantize(x, params, qtz_func, bits, target_dtype)
482
+ dequant = dequantize(quant, params, deqtz_func, bits, x.dtype)
483
+ loss = loss_fn(x, dequant)
484
+
485
+ if loss.isnan() or loss.isinf():
486
+ raise Exception("Loss is NaN or Inf. Stopping the search.")
487
+
488
+ loss.backward()
489
+ optimizer.step()
490
+ scheduler.step()
491
+
492
+ acc_loss.append(loss.item())
493
+
494
+ # Reports loss every 10 steps
495
+ if i % 10 == 0 and do_report:
496
+ print(f"Epoch {i}: Loss {loss.item()}")
497
+
498
+ # Optimizes the parameter search by storing the best loss and the parameters
499
+ if loss.item() < best_loss:
500
+ best_loss = loss.item()
501
+ best_params = copy.deepcopy({
502
+ k: v for k, v in params.items() if k in param_keys
503
+ })
504
+
505
+ # We also stop the search if the loss has not considerably during the last 10% epochs
506
+ if early_stop:
507
+ epochs_before_stop = int(epochs * percent_epochs_before_stop)
508
+ if i > epochs_before_stop and abs(acc_loss[i - epochs_before_stop] - acc_loss[i]) < min_delta:
509
+ break
510
+
511
+ # No longer requires gradients in the parameters
512
+ for p in best_params.values():
513
+ p.requires_grad = False
514
+ p.grad = None
515
+
516
+ if do_report:
517
+ print(f"Best loss: {best_loss}")
518
+ return best_params, acc_loss
519
+ else:
520
+ return best_params
521
+
522
+
523
+ def quantize(
524
+ x: torch.Tensor,
525
+ params: Dict[str, nn.Parameter],
526
+ func: nn.Module,
527
+ bits: int,
528
+ target_dtype: torch.dtype = torch.int8
529
+ ) -> torch.Tensor:
530
+ quant_min, quant_max = get_min_max_from_bits_signed(bits)
531
+ x = x.transpose(0, 1) # Aligns shapes
532
+ x = func(x=x, **params)
533
+ x = x.transpose(0, 1)
534
+ x = torch.clamp(round_func_BPDA(x), quant_min, quant_max).to(target_dtype)
535
+ return x
536
+
537
+
538
+ def dequantize(
539
+ x: torch.Tensor,
540
+ params: Dict[str, nn.Parameter],
541
+ func: nn.Module,
542
+ bits: int,
543
+ out_dtype: torch.dtype
544
+ ) -> torch.Tensor:
545
+ x = x.to(dtype=out_dtype)
546
+ x = x.transpose(0, 1)
547
+ x = func(x=x, **params)
548
+ x = x.transpose(0, 1)
549
+ return x
550
+
551
+
552
+ def round_func_BPDA(input):
553
+ # This is equivalent to replacing round function (non-differentiable) with
554
+ # an identity function (differentiable) only when backward.
555
+ forward_value = torch.round(input)
556
+ out = input.clone()
557
+ out.data = forward_value.data
558
+ return out
559
+
560
+
561
+ def get_min_max_from_bits_signed(bit_width: int) -> Tuple[int, int]:
562
+ return -2 ** (bit_width - 1), 2 ** (bit_width - 1) - 1
563
+
564
+
565
+
566
+ ############## Numpy ###############
567
+
568
+ def np_domain_guard(
569
+ x: np.ndarray,
570
+ min: float = None,
571
+ max: float = None,
572
+ posinf: float = None,
573
+ neginf: float = None,
574
+ nan: float = None
575
+ ) -> np.ndarray:
576
+ """Guard a tensor to a valid domain."""
577
+ x = np.nan_to_num(x, posinf=posinf, neginf=neginf, nan=nan)
578
+ if min is not None or max is not None:
579
+ x = np.clip(x, min, max)
580
+ return x
581
+
582
+
583
+ def np_replace_num(x: np.ndarray, num: float, to: float) -> np.ndarray:
584
+ """Replace a number in a tensor with another number.
585
+
586
+ Args:
587
+ x (np.ndarray): The input tensor.
588
+ num (float): The number to replace.
589
+ to (float): The number to replace with.
590
+
591
+ Returns:
592
+ np.ndarray: The tensor with the number replaced.
593
+ """
594
+ return np.where(x == num, to, x)
595
+
596
+
597
+ def np_guarded_power(x: np.ndarray, exp: float) -> np.ndarray:
598
+ """Guard the power operation to a valid domain."""
599
+ return np.power(x, exp) if exp >= 1 else np.power(np.maximum(x, 0), exp)
600
+
fn_gen/rnd_search_t_cos/10/loss.png ADDED
fn_gen/rnd_search_t_cos/10/quantization.png ADDED
fn_gen/rnd_search_t_cos/11/distortion.png ADDED
fn_gen/rnd_search_t_cos/11/expressions.txt ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ sqrt(_0*x)/_s
2
+ _s**2*x**2/_0
fn_gen/rnd_search_t_cos/11/fn.py ADDED
@@ -0,0 +1,600 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import torch
4
+ from torch import amin # Necessary for arcsin
5
+ import copy
6
+ import torch.nn as nn
7
+ import numpy as np
8
+
9
+ from scipy.optimize import curve_fit
10
+ from typing import Dict, Any, Tuple, List, Callable
11
+
12
+
13
+ def quantization(x, **params):
14
+ return (torch.div(1, replace_num(params['_s'], num=0, to=10000)) * torch.sqrt(domain_guard((params['_0'] * x), min=0.1, nan=0.1)))
15
+
16
+
17
+ def dequantization(x, **params):
18
+ return (torch.div(1, replace_num(params['_0'], num=0, to=10000)) * guarded_torch_power(params['_s'], torch.tensor(2)) * guarded_torch_power(x, torch.tensor(2)))
19
+
20
+
21
+ def init_space_search(
22
+ x: torch.Tensor,
23
+ **kwargs: Dict[str, Any],
24
+ ) -> torch.Tensor:
25
+
26
+ def _build_initial_param(tensor: torch.Tensor, max_initial: int, n_params: int):
27
+ """Generates the initial set of parameters. The first iteration generates 10 times more parameters."""
28
+ for _ in range(n_params * 10): # The first iteration generates 10 times more parameters
29
+ yield init_rand(tensor) * max_initial # Generates n_params in range [-max_initial, max_initial]
30
+
31
+ def _search_param(tensors: List[torch.tensor], n_params):
32
+ """Takes the best parameters and generates new parameters around the mean of the best parameters."""
33
+ torch_tensors = torch.stack(tensors)
34
+ min_vals, max_vals = torch.aminmax(torch_tensors, dim=0)
35
+ abs_max_val_per_ch = torch.max(-min_vals, max_vals)
36
+ mean = torch.mean(torch_tensors, dim=0)
37
+ for _ in range(n_params): # Generates n_params around the mean of the tensors
38
+ yield torch.randn_like(min_vals) * abs_max_val_per_ch + mean
39
+
40
+ def _calc(x, qtz_func, deqtz_func, **params):
41
+ x_ = x.transpose(0, 1)
42
+ x_ = qtz_func(x=x_, **params)
43
+ x_ = deqtz_func(x=x_, **params)
44
+ x_ = x_.transpose(0, 1)
45
+ return x_
46
+
47
+ assert "qtz_func" in kwargs, "qtz_func must be provided."
48
+ assert "deqtz_func" in kwargs, "deqtz_func must be provided."
49
+ assert "params_list" in kwargs, "params list must be provided."
50
+ assert "param" in kwargs, "param must be provided."
51
+
52
+ qtz_func = kwargs.get('qtz_func')
53
+ deqtz_func = kwargs.get('deqtz_func')
54
+ params_list = kwargs.get('params_list')
55
+ param = kwargs.get('param')
56
+
57
+ n_runs = 50 # Number of runs to try to find the best parameters
58
+ n_random_params = 50 # Number of random parameters to generate
59
+ n_best_to_pick = 5 # Number of best parameters to pick after each run
60
+ max_initial = 10000 # Maximum value to initialize the parameters
61
+
62
+ # Initializes the parameters
63
+ base_params = { p: init_ones(x, **kwargs) for p in params_list if p != param }
64
+ params = _build_initial_param(x, max_initial, n_random_params)
65
+
66
+ # Performs the search
67
+ for _ in range(n_runs):
68
+
69
+ best_params = []
70
+ for param_ in params:
71
+ try:
72
+ x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: param_})
73
+ loss_ones = nn.MSELoss()(x, x_)
74
+
75
+ if len(best_params) < n_best_to_pick:
76
+ best_params.append((param_, loss_ones.item()))
77
+ best_params = sorted(best_params, key=lambda x: x[1])
78
+ elif loss_ones < best_params[-1][1]:
79
+ best_params[-1] = (param_, loss_ones.item())
80
+ best_params = sorted(best_params, key=lambda x: x[1])
81
+
82
+ except Exception: # The parameters might not be valid for the function's domain
83
+ continue
84
+
85
+ # Generates new parameters around the mean
86
+ params = _search_param([p for p, _ in best_params], n_random_params)
87
+
88
+ # Checks if the best parameter is better than the init_ones
89
+ p_ones = init_ones(x, **kwargs)
90
+ x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: p_ones})
91
+ loss_ones = nn.MSELoss()(x, x_)
92
+
93
+ # Checks if the best parameter is better than the init_rand
94
+ p_rand = init_rand(x, **kwargs)
95
+ x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: p_rand})
96
+ loss_rand = nn.MSELoss()(x, x_)
97
+
98
+ if loss_rand < best_params[0][1] and loss_rand < loss_ones:
99
+ return p_rand
100
+ elif loss_ones < best_params[0][1] and loss_ones < loss_rand:
101
+ return p_ones
102
+ else:
103
+ return best_params[0][0]
104
+
105
+
106
+ def init_linear_scale( # Symmetric scale. From the study folder
107
+ x: torch.Tensor,
108
+ **kwargs: Dict[str, Any],
109
+ ) -> torch.Tensor:
110
+ assert "bits" in kwargs, "bits must be provided."
111
+ assert "params" in kwargs, "params must be provided."
112
+ assert "qtz_func" in kwargs, "qtz_func must be provided."
113
+
114
+ bits = kwargs.get('bits')
115
+ params = kwargs.get('params')
116
+ qtz_func = kwargs.get('qtz_func')
117
+
118
+ x_ = x.transpose(0, 1)
119
+ x_ = qtz_func(x=x_, **params, _s=init_ones(x, **kwargs))
120
+ x_ = x_.transpose(0, 1)
121
+
122
+ quant_min, quant_max = get_min_max_from_bits_signed(bits)
123
+ min_vals, max_vals = torch.aminmax(x_, dim=1)
124
+ min_vals = torch.min(min_vals, torch.zeros_like(min_vals))
125
+ max_vals = torch.max(max_vals, torch.zeros_like(max_vals))
126
+
127
+ eps = torch.finfo(torch.float32).eps
128
+
129
+ abs_max_val_per_ch = torch.max(-min_vals, max_vals)
130
+ scale = abs_max_val_per_ch / (float(quant_max - quant_min) / 2)
131
+
132
+ scale = torch.clamp(scale, min=eps).to(dtype=torch.float32, device=min_vals.device)
133
+
134
+ # Introduces some noise in scale
135
+ # If I don't introduce noise, the accuracy is going to be 0.0 and not learn anything
136
+ # If I don't introduce noise, the accuracy is going to be 0.0 and not learn anything.
137
+ # NOTE(diogo): This has been disproven. The noise does not help the learning process but I still
138
+ # left it here for future reference. Will be removed later.
139
+ # scale = scale + 0.01 * torch.randn_like(scale)
140
+
141
+ return scale
142
+
143
+
144
+ def init_params(x: torch.Tensor, **kwargs: Dict[str, Any]) -> Dict[str, nn.Parameter]:
145
+ params = {
146
+ '_0': init_space_search(x, qtz_func=quantization, deqtz_func=dequantization, param='_0', params_list=['_0', '_s'], **kwargs),
147
+ }
148
+ params['_s'] = init_linear_scale(x, params=params, qtz_func=quantization, **kwargs)
149
+ params = {k: nn.Parameter(v, requires_grad=False) for k, v in params.items()}
150
+
151
+ if 'post_init_hook' in kwargs:
152
+ kwargs['post_init_hook'](parameters=params)
153
+
154
+ params = learn_parameters(x, params,
155
+ qtz_func=quantization,
156
+ deqtz_func=dequantization,
157
+ bits=kwargs['bits'],
158
+ target_dtype=torch.int8,
159
+ epochs=500,
160
+ early_stop=False,
161
+ )
162
+ if 'post_train_hook' in kwargs:
163
+ kwargs['post_train_hook'](parameters=params)
164
+
165
+ return params
166
+
167
+
168
+ ############### Numpy Qtz ###############
169
+
170
+
171
+ def np_quantization(x, _0, _s):
172
+ return (np.divide(1, np_replace_num(_s, num=0, to=10000)) * np.sqrt(np_domain_guard((_0 * x), min=0.1, nan=0.1)))
173
+
174
+
175
+ def np_dequantization(x, _0, _s):
176
+ return (np.divide(1, np_replace_num(_0, num=0, to=10000)) * np_guarded_power(_s, np.array(2)) * np_guarded_power(x, np.array(2)))
177
+
178
+
179
+ def fit_func(x, _0, _s):
180
+ x_ = np_quantization(x, _0, _s)
181
+ x_ = np_dequantization(x_, _0, _s)
182
+ return x_
183
+
184
+
185
+
186
+ ############### HELPERS ###############
187
+
188
+ def domain_guard(
189
+ x: torch.Tensor,
190
+ min: float = None,
191
+ max: float = None,
192
+ posinf: float = None,
193
+ neginf: float = None,
194
+ nan: float = None
195
+ ) -> torch.Tensor:
196
+ """Guard a tensor to a valid domain."""
197
+ x = torch.nan_to_num(x, posinf=posinf, neginf=neginf, nan=nan)
198
+ if min is not None or max is not None:
199
+ x = torch.clamp(x, min=min, max=max)
200
+ return x
201
+
202
+
203
+ def replace_num(x: torch.Tensor, num: float, to: float) -> torch.Tensor:
204
+ """Replace a number in a tensor with another number.
205
+
206
+ Args:
207
+ x (torch.Tensor): The input tensor.
208
+ num (float): The number to replace.
209
+ to (float): The number to replace with.
210
+
211
+ Returns:
212
+ torch.Tensor: The tensor with the number replaced.
213
+ """
214
+ return torch.where(x == num, to, x)
215
+
216
+
217
+ def guarded_torch_power(x: torch.Tensor, exp: float) -> torch.Tensor:
218
+ """Guard the power operation to a valid domain."""
219
+ return torch.pow(x, exp) if exp >= 1 else torch.pow(torch.relu(x), exp)
220
+
221
+
222
+ def init_ones(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor:
223
+ val = torch.amin(x, dim=1)
224
+ return torch.ones_like(val, dtype=torch.float32, device=x.device)
225
+
226
+
227
+ def init_rand(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor:
228
+ val = torch.amin(x, dim=1)
229
+ return torch.randn_like(val, dtype=torch.float32, device=x.device)
230
+
231
+
232
+ def init_space_search(
233
+ x: torch.Tensor,
234
+ **kwargs: Dict[str, Any],
235
+ ) -> torch.Tensor:
236
+
237
+ def _build_initial_param(tensor: torch.Tensor, max_initial: int, n_params: int):
238
+ """Generates the initial set of parameters. The first iteration generates 10 times more parameters."""
239
+ for _ in range(n_params * 10): # The first iteration generates 10 times more parameters
240
+ yield init_rand(tensor) * max_initial # Generates n_params in range [-max_initial, max_initial]
241
+
242
+ def _search_param(tensors: List[torch.tensor], n_params):
243
+ """Takes the best parameters and generates new parameters around the mean of the best parameters."""
244
+ torch_tensors = torch.stack(tensors)
245
+ min_vals, max_vals = torch.aminmax(torch_tensors, dim=0)
246
+ abs_max_val_per_ch = torch.max(-min_vals, max_vals)
247
+ mean = torch.mean(torch_tensors, dim=0)
248
+ for _ in range(n_params): # Generates n_params around the mean of the tensors
249
+ yield torch.randn_like(min_vals) * abs_max_val_per_ch + mean
250
+
251
+ def _calc(x, qtz_func, deqtz_func, **params):
252
+ x_ = x.transpose(0, 1)
253
+ x_ = qtz_func(x=x_, **params)
254
+ x_ = deqtz_func(x=x_, **params)
255
+ x_ = x_.transpose(0, 1)
256
+ return x_
257
+
258
+ assert "qtz_func" in kwargs, "qtz_func must be provided."
259
+ assert "deqtz_func" in kwargs, "deqtz_func must be provided."
260
+ assert "params_list" in kwargs, "params list must be provided."
261
+ assert "param" in kwargs, "param must be provided."
262
+
263
+ qtz_func = kwargs.get('qtz_func')
264
+ deqtz_func = kwargs.get('deqtz_func')
265
+ params_list = kwargs.get('params_list')
266
+ param = kwargs.get('param')
267
+
268
+ n_runs = 50 # Number of runs to try to find the best parameters
269
+ n_random_params = 50 # Number of random parameters to generate
270
+ n_best_to_pick = 5 # Number of best parameters to pick after each run
271
+ max_initial = 10000 # Maximum value to initialize the parameters
272
+
273
+ # Initializes the parameters
274
+ base_params = { p: init_ones(x, **kwargs) for p in params_list if p != param }
275
+ params = _build_initial_param(x, max_initial, n_random_params)
276
+
277
+ # Performs the search
278
+ for _ in range(n_runs):
279
+
280
+ best_params = []
281
+ for param_ in params:
282
+ try:
283
+ x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: param_})
284
+ loss_ones = nn.MSELoss()(x, x_)
285
+
286
+ if len(best_params) < n_best_to_pick:
287
+ best_params.append((param_, loss_ones.item()))
288
+ best_params = sorted(best_params, key=lambda x: x[1])
289
+ elif loss_ones < best_params[-1][1]:
290
+ best_params[-1] = (param_, loss_ones.item())
291
+ best_params = sorted(best_params, key=lambda x: x[1])
292
+
293
+ except Exception: # The parameters might not be valid for the function's domain
294
+ continue
295
+
296
+ # Generates new parameters around the mean
297
+ params = _search_param([p for p, _ in best_params], n_random_params)
298
+
299
+ # Checks if the best parameter is better than the init_ones
300
+ p_ones = init_ones(x, **kwargs)
301
+ x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: p_ones})
302
+ loss_ones = nn.MSELoss()(x, x_)
303
+
304
+ # Checks if the best parameter is better than the init_rand
305
+ p_rand = init_rand(x, **kwargs)
306
+ x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: p_rand})
307
+ loss_rand = nn.MSELoss()(x, x_)
308
+
309
+ if loss_rand < best_params[0][1] and loss_rand < loss_ones:
310
+ return p_rand
311
+ elif loss_ones < best_params[0][1] and loss_ones < loss_rand:
312
+ return p_ones
313
+ else:
314
+ return best_params[0][0]
315
+
316
+
317
+ def init_linear_scale( # Symmetric scale. From the study folder
318
+ x: torch.Tensor,
319
+ **kwargs: Dict[str, Any],
320
+ ) -> torch.Tensor:
321
+ assert "bits" in kwargs, "bits must be provided."
322
+ assert "params" in kwargs, "params must be provided."
323
+ assert "qtz_func" in kwargs, "qtz_func must be provided."
324
+
325
+ bits = kwargs.get('bits')
326
+ params = kwargs.get('params')
327
+ qtz_func = kwargs.get('qtz_func')
328
+
329
+ x_ = x.transpose(0, 1)
330
+ x_ = qtz_func(x=x_, **params, _s=init_ones(x, **kwargs))
331
+ x_ = x_.transpose(0, 1)
332
+
333
+ quant_min, quant_max = get_min_max_from_bits_signed(bits)
334
+ min_vals, max_vals = torch.aminmax(x_, dim=1)
335
+ min_vals = torch.min(min_vals, torch.zeros_like(min_vals))
336
+ max_vals = torch.max(max_vals, torch.zeros_like(max_vals))
337
+
338
+ eps = torch.finfo(torch.float32).eps
339
+
340
+ abs_max_val_per_ch = torch.max(-min_vals, max_vals)
341
+ scale = abs_max_val_per_ch / (float(quant_max - quant_min) / 2)
342
+
343
+ scale = torch.clamp(scale, min=eps).to(dtype=torch.float32, device=min_vals.device)
344
+
345
+ # Introduces some noise in scale
346
+ # If I don't introduce noise, the accuracy is going to be 0.0 and not learn anything
347
+ # If I don't introduce noise, the accuracy is going to be 0.0 and not learn anything.
348
+ # NOTE(diogo): This has been disproven. The noise does not help the learning process but I still
349
+ # left it here for future reference. Will be removed later.
350
+ # scale = scale + 0.01 * torch.randn_like(scale)
351
+
352
+ return scale
353
+
354
+
355
+ def init_non_linear_regression_fit(
356
+ x: torch.Tensor,
357
+ **kwargs: Dict[str, Any],
358
+ ) -> torch.Tensor:
359
+
360
+ assert "params_list" in kwargs, "params list must be provided."
361
+ assert "np_fit_func" in kwargs, "np_fit_func must be provided."
362
+ assert "p0" in kwargs, "p0 must be provided."
363
+ np_fit_func = kwargs.get('np_fit_func')
364
+ params_list = kwargs.get('params_list')
365
+ p0 = kwargs.get('p0')
366
+
367
+ def _fit(xdata: np.ndarray, ydata: np.ndarray, func: Callable, p0: List[float]):
368
+ popt, _ = curve_fit(
369
+ func,
370
+ xdata,
371
+ ydata,
372
+ maxfev=1000,
373
+ p0=p0,
374
+ method='lm'
375
+ )
376
+ return popt
377
+
378
+ # 1. Needs to convert the torch tensor to numpy tensor
379
+ xdata = x.cpu().numpy()
380
+
381
+ # 2. Sorts the data so that it makes it easier to fit to it
382
+ sorted_xdata = np.sort(xdata, axis=-1)
383
+
384
+ p0 = {k: v.cpu().numpy() for k, v in p0.items()}
385
+ params_list = sorted(params_list) # We need to make sure that it matches the numpy fit func arg order
386
+
387
+ # 3. Finds the best parameters for each channel
388
+ try:
389
+ params = []
390
+ for i in range(sorted_xdata.shape[0]):
391
+ xdata_ = sorted_xdata[i]
392
+ p0_ = [p0[p][i] for p in params_list]
393
+ ch_params = _fit(xdata_, xdata_, np_fit_func, p0_)
394
+ params.append(ch_params)
395
+
396
+ # 4. Builds the parameters
397
+ result = {}
398
+ for i, p in enumerate(params_list):
399
+ result[p] = torch.tensor([p_[i] for p_ in params], dtype=torch.float32).to(x.device)
400
+
401
+ return result
402
+
403
+ except ValueError as e:
404
+ print(f"Could not fit the function with error: {e}")
405
+ print(f"Using fallback result...")
406
+ return {
407
+ k: torch.tensor(v, dtype=torch.float32).to(x.device) for k, v in p0.items()
408
+ }
409
+
410
+
411
+ def init_zeros(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor:
412
+ val = torch.amin(x, dim=1)
413
+ return torch.zeros_like(val, dtype=torch.float32, device=x.device)
414
+
415
+
416
+ def init_inner_scale(tensor: torch.Tensor, _min: float = torch.inf, _max: float = torch.inf) -> torch.Tensor:
417
+ # Calculate the original minimum and maximum values
418
+ min_vals, max_vals = torch.aminmax(tensor, dim=-1)
419
+ x_min = torch.min(min_vals, torch.zeros_like(min_vals))
420
+ x_max = torch.max(max_vals, torch.zeros_like(max_vals))
421
+
422
+ if _max is torch.inf: # We do not need to scale the tensor. Just need to move it
423
+ return torch.ones_like(x_min)
424
+
425
+ # Calculate the scale factor
426
+ scale = (_max - _min) / (x_max - x_min)
427
+ return scale
428
+
429
+
430
+
431
+ ############## Quant ###############
432
+
433
+ @torch.enable_grad()
434
+ def learn_parameters(
435
+ x: torch.Tensor,
436
+ params: Dict[str, nn.Parameter],
437
+ qtz_func: nn.Module,
438
+ deqtz_func: nn.Module,
439
+ bits: int,
440
+ target_dtype: torch.dtype,
441
+ epochs: int = 1000,
442
+ early_stop: bool = True,
443
+ do_report: bool = False
444
+ ) -> Tuple[Dict[str, nn.Parameter], torch.Tensor]:
445
+ loss_fn = nn.MSELoss()
446
+
447
+ # Determines the initial learning rate by computing the initial loss and multiplying it by
448
+ # the order of magnitude of the loss divided by 2
449
+ quant = quantize(x, params, qtz_func, bits, target_dtype)
450
+ dequant = dequantize(quant, params, deqtz_func, bits, x.dtype)
451
+ loss = loss_fn(x, dequant)
452
+
453
+ base_lr = 0.1
454
+ exponent = int(np.floor(np.log10(loss.item())))
455
+ lr = base_lr * (10 ** (exponent // 2))
456
+
457
+ # Requires gradients in the parameters
458
+ for p in params.values():
459
+ p.requires_grad = True
460
+ p.grad = None
461
+
462
+ param_keys = list(params.keys())
463
+ param_values = list(params.values())
464
+
465
+ # Defines optimizer and loss function
466
+ optimizer = torch.optim.Adam(param_values, lr=lr)
467
+ scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=epochs // 10, T_mult=1, eta_min=lr * 0.1, last_epoch=-1)
468
+
469
+ # Contains the best loss and the best parameters
470
+ best_loss = float("inf")
471
+ best_params = None
472
+
473
+ # Used to stop the search early
474
+ min_delta = 1e-7
475
+ acc_loss = []
476
+ percent_epochs_before_stop = 0.1
477
+
478
+ for i in range(epochs):
479
+ optimizer.zero_grad()
480
+
481
+ quant = quantize(x, params, qtz_func, bits, target_dtype)
482
+ dequant = dequantize(quant, params, deqtz_func, bits, x.dtype)
483
+ loss = loss_fn(x, dequant)
484
+
485
+ if loss.isnan() or loss.isinf():
486
+ raise Exception("Loss is NaN or Inf. Stopping the search.")
487
+
488
+ loss.backward()
489
+ optimizer.step()
490
+ scheduler.step()
491
+
492
+ acc_loss.append(loss.item())
493
+
494
+ # Reports loss every 10 steps
495
+ if i % 10 == 0 and do_report:
496
+ print(f"Epoch {i}: Loss {loss.item()}")
497
+
498
+ # Optimizes the parameter search by storing the best loss and the parameters
499
+ if loss.item() < best_loss:
500
+ best_loss = loss.item()
501
+ best_params = copy.deepcopy({
502
+ k: v for k, v in params.items() if k in param_keys
503
+ })
504
+
505
+ # We also stop the search if the loss has not considerably during the last 10% epochs
506
+ if early_stop:
507
+ epochs_before_stop = int(epochs * percent_epochs_before_stop)
508
+ if i > epochs_before_stop and abs(acc_loss[i - epochs_before_stop] - acc_loss[i]) < min_delta:
509
+ break
510
+
511
+ # No longer requires gradients in the parameters
512
+ for p in best_params.values():
513
+ p.requires_grad = False
514
+ p.grad = None
515
+
516
+ if do_report:
517
+ print(f"Best loss: {best_loss}")
518
+ return best_params, acc_loss
519
+ else:
520
+ return best_params
521
+
522
+
523
+ def quantize(
524
+ x: torch.Tensor,
525
+ params: Dict[str, nn.Parameter],
526
+ func: nn.Module,
527
+ bits: int,
528
+ target_dtype: torch.dtype = torch.int8
529
+ ) -> torch.Tensor:
530
+ quant_min, quant_max = get_min_max_from_bits_signed(bits)
531
+ x = x.transpose(0, 1) # Aligns shapes
532
+ x = func(x=x, **params)
533
+ x = x.transpose(0, 1)
534
+ x = torch.clamp(round_func_BPDA(x), quant_min, quant_max).to(target_dtype)
535
+ return x
536
+
537
+
538
+ def dequantize(
539
+ x: torch.Tensor,
540
+ params: Dict[str, nn.Parameter],
541
+ func: nn.Module,
542
+ bits: int,
543
+ out_dtype: torch.dtype
544
+ ) -> torch.Tensor:
545
+ x = x.to(dtype=out_dtype)
546
+ x = x.transpose(0, 1)
547
+ x = func(x=x, **params)
548
+ x = x.transpose(0, 1)
549
+ return x
550
+
551
+
552
+ def round_func_BPDA(input):
553
+ # This is equivalent to replacing round function (non-differentiable) with
554
+ # an identity function (differentiable) only when backward.
555
+ forward_value = torch.round(input)
556
+ out = input.clone()
557
+ out.data = forward_value.data
558
+ return out
559
+
560
+
561
+ def get_min_max_from_bits_signed(bit_width: int) -> Tuple[int, int]:
562
+ return -2 ** (bit_width - 1), 2 ** (bit_width - 1) - 1
563
+
564
+
565
+
566
+ ############## Numpy ###############
567
+
568
+ def np_domain_guard(
569
+ x: np.ndarray,
570
+ min: float = None,
571
+ max: float = None,
572
+ posinf: float = None,
573
+ neginf: float = None,
574
+ nan: float = None
575
+ ) -> np.ndarray:
576
+ """Guard a tensor to a valid domain."""
577
+ x = np.nan_to_num(x, posinf=posinf, neginf=neginf, nan=nan)
578
+ if min is not None or max is not None:
579
+ x = np.clip(x, min, max)
580
+ return x
581
+
582
+
583
+ def np_replace_num(x: np.ndarray, num: float, to: float) -> np.ndarray:
584
+ """Replace a number in a tensor with another number.
585
+
586
+ Args:
587
+ x (np.ndarray): The input tensor.
588
+ num (float): The number to replace.
589
+ to (float): The number to replace with.
590
+
591
+ Returns:
592
+ np.ndarray: The tensor with the number replaced.
593
+ """
594
+ return np.where(x == num, to, x)
595
+
596
+
597
+ def np_guarded_power(x: np.ndarray, exp: float) -> np.ndarray:
598
+ """Guard the power operation to a valid domain."""
599
+ return np.power(x, exp) if exp >= 1 else np.power(np.maximum(x, 0), exp)
600
+
fn_gen/rnd_search_t_cos/11/loss.png ADDED
fn_gen/rnd_search_t_cos/11/quantization.png ADDED
fn_gen/rnd_search_t_cos/12/distortion.png ADDED
fn_gen/rnd_search_t_cos/12/expressions.txt ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ acos(_0*x)/_s
2
+ cos(_s*x)/_0
fn_gen/rnd_search_t_cos/12/fn.py ADDED
@@ -0,0 +1,600 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import torch
4
+ from torch import amin # Necessary for arcsin
5
+ import copy
6
+ import torch.nn as nn
7
+ import numpy as np
8
+
9
+ from scipy.optimize import curve_fit
10
+ from typing import Dict, Any, Tuple, List, Callable
11
+
12
+
13
+ def quantization(x, **params):
14
+ return (torch.div(1, replace_num(params['_s'], num=0, to=10000)) * torch.acos(domain_guard((params['_0'] * x), min=-0.99999, max=0.99999, nan=0)))
15
+
16
+
17
+ def dequantization(x, **params):
18
+ return (torch.div(1, replace_num(params['_0'], num=0, to=10000)) * torch.cos((params['_s'] * x)))
19
+
20
+
21
+ def init_space_search(
22
+ x: torch.Tensor,
23
+ **kwargs: Dict[str, Any],
24
+ ) -> torch.Tensor:
25
+
26
+ def _build_initial_param(tensor: torch.Tensor, max_initial: int, n_params: int):
27
+ """Generates the initial set of parameters. The first iteration generates 10 times more parameters."""
28
+ for _ in range(n_params * 10): # The first iteration generates 10 times more parameters
29
+ yield init_rand(tensor) * max_initial # Generates n_params in range [-max_initial, max_initial]
30
+
31
+ def _search_param(tensors: List[torch.tensor], n_params):
32
+ """Takes the best parameters and generates new parameters around the mean of the best parameters."""
33
+ torch_tensors = torch.stack(tensors)
34
+ min_vals, max_vals = torch.aminmax(torch_tensors, dim=0)
35
+ abs_max_val_per_ch = torch.max(-min_vals, max_vals)
36
+ mean = torch.mean(torch_tensors, dim=0)
37
+ for _ in range(n_params): # Generates n_params around the mean of the tensors
38
+ yield torch.randn_like(min_vals) * abs_max_val_per_ch + mean
39
+
40
+ def _calc(x, qtz_func, deqtz_func, **params):
41
+ x_ = x.transpose(0, 1)
42
+ x_ = qtz_func(x=x_, **params)
43
+ x_ = deqtz_func(x=x_, **params)
44
+ x_ = x_.transpose(0, 1)
45
+ return x_
46
+
47
+ assert "qtz_func" in kwargs, "qtz_func must be provided."
48
+ assert "deqtz_func" in kwargs, "deqtz_func must be provided."
49
+ assert "params_list" in kwargs, "params list must be provided."
50
+ assert "param" in kwargs, "param must be provided."
51
+
52
+ qtz_func = kwargs.get('qtz_func')
53
+ deqtz_func = kwargs.get('deqtz_func')
54
+ params_list = kwargs.get('params_list')
55
+ param = kwargs.get('param')
56
+
57
+ n_runs = 50 # Number of runs to try to find the best parameters
58
+ n_random_params = 50 # Number of random parameters to generate
59
+ n_best_to_pick = 5 # Number of best parameters to pick after each run
60
+ max_initial = 10000 # Maximum value to initialize the parameters
61
+
62
+ # Initializes the parameters
63
+ base_params = { p: init_ones(x, **kwargs) for p in params_list if p != param }
64
+ params = _build_initial_param(x, max_initial, n_random_params)
65
+
66
+ # Performs the search
67
+ for _ in range(n_runs):
68
+
69
+ best_params = []
70
+ for param_ in params:
71
+ try:
72
+ x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: param_})
73
+ loss_ones = nn.MSELoss()(x, x_)
74
+
75
+ if len(best_params) < n_best_to_pick:
76
+ best_params.append((param_, loss_ones.item()))
77
+ best_params = sorted(best_params, key=lambda x: x[1])
78
+ elif loss_ones < best_params[-1][1]:
79
+ best_params[-1] = (param_, loss_ones.item())
80
+ best_params = sorted(best_params, key=lambda x: x[1])
81
+
82
+ except Exception: # The parameters might not be valid for the function's domain
83
+ continue
84
+
85
+ # Generates new parameters around the mean
86
+ params = _search_param([p for p, _ in best_params], n_random_params)
87
+
88
+ # Checks if the best parameter is better than the init_ones
89
+ p_ones = init_ones(x, **kwargs)
90
+ x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: p_ones})
91
+ loss_ones = nn.MSELoss()(x, x_)
92
+
93
+ # Checks if the best parameter is better than the init_rand
94
+ p_rand = init_rand(x, **kwargs)
95
+ x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: p_rand})
96
+ loss_rand = nn.MSELoss()(x, x_)
97
+
98
+ if loss_rand < best_params[0][1] and loss_rand < loss_ones:
99
+ return p_rand
100
+ elif loss_ones < best_params[0][1] and loss_ones < loss_rand:
101
+ return p_ones
102
+ else:
103
+ return best_params[0][0]
104
+
105
+
106
+ def init_linear_scale( # Symmetric scale. From the study folder
107
+ x: torch.Tensor,
108
+ **kwargs: Dict[str, Any],
109
+ ) -> torch.Tensor:
110
+ assert "bits" in kwargs, "bits must be provided."
111
+ assert "params" in kwargs, "params must be provided."
112
+ assert "qtz_func" in kwargs, "qtz_func must be provided."
113
+
114
+ bits = kwargs.get('bits')
115
+ params = kwargs.get('params')
116
+ qtz_func = kwargs.get('qtz_func')
117
+
118
+ x_ = x.transpose(0, 1)
119
+ x_ = qtz_func(x=x_, **params, _s=init_ones(x, **kwargs))
120
+ x_ = x_.transpose(0, 1)
121
+
122
+ quant_min, quant_max = get_min_max_from_bits_signed(bits)
123
+ min_vals, max_vals = torch.aminmax(x_, dim=1)
124
+ min_vals = torch.min(min_vals, torch.zeros_like(min_vals))
125
+ max_vals = torch.max(max_vals, torch.zeros_like(max_vals))
126
+
127
+ eps = torch.finfo(torch.float32).eps
128
+
129
+ abs_max_val_per_ch = torch.max(-min_vals, max_vals)
130
+ scale = abs_max_val_per_ch / (float(quant_max - quant_min) / 2)
131
+
132
+ scale = torch.clamp(scale, min=eps).to(dtype=torch.float32, device=min_vals.device)
133
+
134
+ # Introduces some noise in scale
135
+ # If I don't introduce noise, the accuracy is going to be 0.0 and not learn anything
136
+ # If I don't introduce noise, the accuracy is going to be 0.0 and not learn anything.
137
+ # NOTE(diogo): This has been disproven. The noise does not help the learning process but I still
138
+ # left it here for future reference. Will be removed later.
139
+ # scale = scale + 0.01 * torch.randn_like(scale)
140
+
141
+ return scale
142
+
143
+
144
+ def init_params(x: torch.Tensor, **kwargs: Dict[str, Any]) -> Dict[str, nn.Parameter]:
145
+ params = {
146
+ '_0': init_space_search(x, qtz_func=quantization, deqtz_func=dequantization, param='_0', params_list=['_0', '_s'], **kwargs),
147
+ }
148
+ params['_s'] = init_linear_scale(x, params=params, qtz_func=quantization, **kwargs)
149
+ params = {k: nn.Parameter(v, requires_grad=False) for k, v in params.items()}
150
+
151
+ if 'post_init_hook' in kwargs:
152
+ kwargs['post_init_hook'](parameters=params)
153
+
154
+ params = learn_parameters(x, params,
155
+ qtz_func=quantization,
156
+ deqtz_func=dequantization,
157
+ bits=kwargs['bits'],
158
+ target_dtype=torch.int8,
159
+ epochs=500,
160
+ early_stop=False,
161
+ )
162
+ if 'post_train_hook' in kwargs:
163
+ kwargs['post_train_hook'](parameters=params)
164
+
165
+ return params
166
+
167
+
168
+ ############### Numpy Qtz ###############
169
+
170
+
171
+ def np_quantization(x, _0, _s):
172
+ return (np.divide(1, np_replace_num(_s, num=0, to=10000)) * np.arccos(np_domain_guard((_0 * x), min=-0.99999, max=0.99999, nan=0)))
173
+
174
+
175
+ def np_dequantization(x, _0, _s):
176
+ return (np.divide(1, np_replace_num(_0, num=0, to=10000)) * np.cos((_s * x)))
177
+
178
+
179
+ def fit_func(x, _0, _s):
180
+ x_ = np_quantization(x, _0, _s)
181
+ x_ = np_dequantization(x_, _0, _s)
182
+ return x_
183
+
184
+
185
+
186
+ ############### HELPERS ###############
187
+
188
+ def domain_guard(
189
+ x: torch.Tensor,
190
+ min: float = None,
191
+ max: float = None,
192
+ posinf: float = None,
193
+ neginf: float = None,
194
+ nan: float = None
195
+ ) -> torch.Tensor:
196
+ """Guard a tensor to a valid domain."""
197
+ x = torch.nan_to_num(x, posinf=posinf, neginf=neginf, nan=nan)
198
+ if min is not None or max is not None:
199
+ x = torch.clamp(x, min=min, max=max)
200
+ return x
201
+
202
+
203
+ def replace_num(x: torch.Tensor, num: float, to: float) -> torch.Tensor:
204
+ """Replace a number in a tensor with another number.
205
+
206
+ Args:
207
+ x (torch.Tensor): The input tensor.
208
+ num (float): The number to replace.
209
+ to (float): The number to replace with.
210
+
211
+ Returns:
212
+ torch.Tensor: The tensor with the number replaced.
213
+ """
214
+ return torch.where(x == num, to, x)
215
+
216
+
217
+ def guarded_torch_power(x: torch.Tensor, exp: float) -> torch.Tensor:
218
+ """Guard the power operation to a valid domain."""
219
+ return torch.pow(x, exp) if exp >= 1 else torch.pow(torch.relu(x), exp)
220
+
221
+
222
+ def init_ones(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor:
223
+ val = torch.amin(x, dim=1)
224
+ return torch.ones_like(val, dtype=torch.float32, device=x.device)
225
+
226
+
227
+ def init_rand(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor:
228
+ val = torch.amin(x, dim=1)
229
+ return torch.randn_like(val, dtype=torch.float32, device=x.device)
230
+
231
+
232
+ def init_space_search(
233
+ x: torch.Tensor,
234
+ **kwargs: Dict[str, Any],
235
+ ) -> torch.Tensor:
236
+
237
+ def _build_initial_param(tensor: torch.Tensor, max_initial: int, n_params: int):
238
+ """Generates the initial set of parameters. The first iteration generates 10 times more parameters."""
239
+ for _ in range(n_params * 10): # The first iteration generates 10 times more parameters
240
+ yield init_rand(tensor) * max_initial # Generates n_params in range [-max_initial, max_initial]
241
+
242
+ def _search_param(tensors: List[torch.tensor], n_params):
243
+ """Takes the best parameters and generates new parameters around the mean of the best parameters."""
244
+ torch_tensors = torch.stack(tensors)
245
+ min_vals, max_vals = torch.aminmax(torch_tensors, dim=0)
246
+ abs_max_val_per_ch = torch.max(-min_vals, max_vals)
247
+ mean = torch.mean(torch_tensors, dim=0)
248
+ for _ in range(n_params): # Generates n_params around the mean of the tensors
249
+ yield torch.randn_like(min_vals) * abs_max_val_per_ch + mean
250
+
251
+ def _calc(x, qtz_func, deqtz_func, **params):
252
+ x_ = x.transpose(0, 1)
253
+ x_ = qtz_func(x=x_, **params)
254
+ x_ = deqtz_func(x=x_, **params)
255
+ x_ = x_.transpose(0, 1)
256
+ return x_
257
+
258
+ assert "qtz_func" in kwargs, "qtz_func must be provided."
259
+ assert "deqtz_func" in kwargs, "deqtz_func must be provided."
260
+ assert "params_list" in kwargs, "params list must be provided."
261
+ assert "param" in kwargs, "param must be provided."
262
+
263
+ qtz_func = kwargs.get('qtz_func')
264
+ deqtz_func = kwargs.get('deqtz_func')
265
+ params_list = kwargs.get('params_list')
266
+ param = kwargs.get('param')
267
+
268
+ n_runs = 50 # Number of runs to try to find the best parameters
269
+ n_random_params = 50 # Number of random parameters to generate
270
+ n_best_to_pick = 5 # Number of best parameters to pick after each run
271
+ max_initial = 10000 # Maximum value to initialize the parameters
272
+
273
+ # Initializes the parameters
274
+ base_params = { p: init_ones(x, **kwargs) for p in params_list if p != param }
275
+ params = _build_initial_param(x, max_initial, n_random_params)
276
+
277
+ # Performs the search
278
+ for _ in range(n_runs):
279
+
280
+ best_params = []
281
+ for param_ in params:
282
+ try:
283
+ x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: param_})
284
+ loss_ones = nn.MSELoss()(x, x_)
285
+
286
+ if len(best_params) < n_best_to_pick:
287
+ best_params.append((param_, loss_ones.item()))
288
+ best_params = sorted(best_params, key=lambda x: x[1])
289
+ elif loss_ones < best_params[-1][1]:
290
+ best_params[-1] = (param_, loss_ones.item())
291
+ best_params = sorted(best_params, key=lambda x: x[1])
292
+
293
+ except Exception: # The parameters might not be valid for the function's domain
294
+ continue
295
+
296
+ # Generates new parameters around the mean
297
+ params = _search_param([p for p, _ in best_params], n_random_params)
298
+
299
+ # Checks if the best parameter is better than the init_ones
300
+ p_ones = init_ones(x, **kwargs)
301
+ x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: p_ones})
302
+ loss_ones = nn.MSELoss()(x, x_)
303
+
304
+ # Checks if the best parameter is better than the init_rand
305
+ p_rand = init_rand(x, **kwargs)
306
+ x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: p_rand})
307
+ loss_rand = nn.MSELoss()(x, x_)
308
+
309
+ if loss_rand < best_params[0][1] and loss_rand < loss_ones:
310
+ return p_rand
311
+ elif loss_ones < best_params[0][1] and loss_ones < loss_rand:
312
+ return p_ones
313
+ else:
314
+ return best_params[0][0]
315
+
316
+
317
+ def init_linear_scale( # Symmetric scale. From the study folder
318
+ x: torch.Tensor,
319
+ **kwargs: Dict[str, Any],
320
+ ) -> torch.Tensor:
321
+ assert "bits" in kwargs, "bits must be provided."
322
+ assert "params" in kwargs, "params must be provided."
323
+ assert "qtz_func" in kwargs, "qtz_func must be provided."
324
+
325
+ bits = kwargs.get('bits')
326
+ params = kwargs.get('params')
327
+ qtz_func = kwargs.get('qtz_func')
328
+
329
+ x_ = x.transpose(0, 1)
330
+ x_ = qtz_func(x=x_, **params, _s=init_ones(x, **kwargs))
331
+ x_ = x_.transpose(0, 1)
332
+
333
+ quant_min, quant_max = get_min_max_from_bits_signed(bits)
334
+ min_vals, max_vals = torch.aminmax(x_, dim=1)
335
+ min_vals = torch.min(min_vals, torch.zeros_like(min_vals))
336
+ max_vals = torch.max(max_vals, torch.zeros_like(max_vals))
337
+
338
+ eps = torch.finfo(torch.float32).eps
339
+
340
+ abs_max_val_per_ch = torch.max(-min_vals, max_vals)
341
+ scale = abs_max_val_per_ch / (float(quant_max - quant_min) / 2)
342
+
343
+ scale = torch.clamp(scale, min=eps).to(dtype=torch.float32, device=min_vals.device)
344
+
345
+ # Introduces some noise in scale
346
+ # If I don't introduce noise, the accuracy is going to be 0.0 and not learn anything
347
+ # If I don't introduce noise, the accuracy is going to be 0.0 and not learn anything.
348
+ # NOTE(diogo): This has been disproven. The noise does not help the learning process but I still
349
+ # left it here for future reference. Will be removed later.
350
+ # scale = scale + 0.01 * torch.randn_like(scale)
351
+
352
+ return scale
353
+
354
+
355
+ def init_non_linear_regression_fit(
356
+ x: torch.Tensor,
357
+ **kwargs: Dict[str, Any],
358
+ ) -> torch.Tensor:
359
+
360
+ assert "params_list" in kwargs, "params list must be provided."
361
+ assert "np_fit_func" in kwargs, "np_fit_func must be provided."
362
+ assert "p0" in kwargs, "p0 must be provided."
363
+ np_fit_func = kwargs.get('np_fit_func')
364
+ params_list = kwargs.get('params_list')
365
+ p0 = kwargs.get('p0')
366
+
367
+ def _fit(xdata: np.ndarray, ydata: np.ndarray, func: Callable, p0: List[float]):
368
+ popt, _ = curve_fit(
369
+ func,
370
+ xdata,
371
+ ydata,
372
+ maxfev=1000,
373
+ p0=p0,
374
+ method='lm'
375
+ )
376
+ return popt
377
+
378
+ # 1. Needs to convert the torch tensor to numpy tensor
379
+ xdata = x.cpu().numpy()
380
+
381
+ # 2. Sorts the data so that it makes it easier to fit to it
382
+ sorted_xdata = np.sort(xdata, axis=-1)
383
+
384
+ p0 = {k: v.cpu().numpy() for k, v in p0.items()}
385
+ params_list = sorted(params_list) # We need to make sure that it matches the numpy fit func arg order
386
+
387
+ # 3. Finds the best parameters for each channel
388
+ try:
389
+ params = []
390
+ for i in range(sorted_xdata.shape[0]):
391
+ xdata_ = sorted_xdata[i]
392
+ p0_ = [p0[p][i] for p in params_list]
393
+ ch_params = _fit(xdata_, xdata_, np_fit_func, p0_)
394
+ params.append(ch_params)
395
+
396
+ # 4. Builds the parameters
397
+ result = {}
398
+ for i, p in enumerate(params_list):
399
+ result[p] = torch.tensor([p_[i] for p_ in params], dtype=torch.float32).to(x.device)
400
+
401
+ return result
402
+
403
+ except ValueError as e:
404
+ print(f"Could not fit the function with error: {e}")
405
+ print(f"Using fallback result...")
406
+ return {
407
+ k: torch.tensor(v, dtype=torch.float32).to(x.device) for k, v in p0.items()
408
+ }
409
+
410
+
411
+ def init_zeros(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor:
412
+ val = torch.amin(x, dim=1)
413
+ return torch.zeros_like(val, dtype=torch.float32, device=x.device)
414
+
415
+
416
+ def init_inner_scale(tensor: torch.Tensor, _min: float = torch.inf, _max: float = torch.inf) -> torch.Tensor:
417
+ # Calculate the original minimum and maximum values
418
+ min_vals, max_vals = torch.aminmax(tensor, dim=-1)
419
+ x_min = torch.min(min_vals, torch.zeros_like(min_vals))
420
+ x_max = torch.max(max_vals, torch.zeros_like(max_vals))
421
+
422
+ if _max is torch.inf: # We do not need to scale the tensor. Just need to move it
423
+ return torch.ones_like(x_min)
424
+
425
+ # Calculate the scale factor
426
+ scale = (_max - _min) / (x_max - x_min)
427
+ return scale
428
+
429
+
430
+
431
+ ############## Quant ###############
432
+
433
+ @torch.enable_grad()
434
+ def learn_parameters(
435
+ x: torch.Tensor,
436
+ params: Dict[str, nn.Parameter],
437
+ qtz_func: nn.Module,
438
+ deqtz_func: nn.Module,
439
+ bits: int,
440
+ target_dtype: torch.dtype,
441
+ epochs: int = 1000,
442
+ early_stop: bool = True,
443
+ do_report: bool = False
444
+ ) -> Tuple[Dict[str, nn.Parameter], torch.Tensor]:
445
+ loss_fn = nn.MSELoss()
446
+
447
+ # Determines the initial learning rate by computing the initial loss and multiplying it by
448
+ # the order of magnitude of the loss divided by 2
449
+ quant = quantize(x, params, qtz_func, bits, target_dtype)
450
+ dequant = dequantize(quant, params, deqtz_func, bits, x.dtype)
451
+ loss = loss_fn(x, dequant)
452
+
453
+ base_lr = 0.1
454
+ exponent = int(np.floor(np.log10(loss.item())))
455
+ lr = base_lr * (10 ** (exponent // 2))
456
+
457
+ # Requires gradients in the parameters
458
+ for p in params.values():
459
+ p.requires_grad = True
460
+ p.grad = None
461
+
462
+ param_keys = list(params.keys())
463
+ param_values = list(params.values())
464
+
465
+ # Defines optimizer and loss function
466
+ optimizer = torch.optim.Adam(param_values, lr=lr)
467
+ scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=epochs // 10, T_mult=1, eta_min=lr * 0.1, last_epoch=-1)
468
+
469
+ # Contains the best loss and the best parameters
470
+ best_loss = float("inf")
471
+ best_params = None
472
+
473
+ # Used to stop the search early
474
+ min_delta = 1e-7
475
+ acc_loss = []
476
+ percent_epochs_before_stop = 0.1
477
+
478
+ for i in range(epochs):
479
+ optimizer.zero_grad()
480
+
481
+ quant = quantize(x, params, qtz_func, bits, target_dtype)
482
+ dequant = dequantize(quant, params, deqtz_func, bits, x.dtype)
483
+ loss = loss_fn(x, dequant)
484
+
485
+ if loss.isnan() or loss.isinf():
486
+ raise Exception("Loss is NaN or Inf. Stopping the search.")
487
+
488
+ loss.backward()
489
+ optimizer.step()
490
+ scheduler.step()
491
+
492
+ acc_loss.append(loss.item())
493
+
494
+ # Reports loss every 10 steps
495
+ if i % 10 == 0 and do_report:
496
+ print(f"Epoch {i}: Loss {loss.item()}")
497
+
498
+ # Optimizes the parameter search by storing the best loss and the parameters
499
+ if loss.item() < best_loss:
500
+ best_loss = loss.item()
501
+ best_params = copy.deepcopy({
502
+ k: v for k, v in params.items() if k in param_keys
503
+ })
504
+
505
+ # We also stop the search if the loss has not considerably during the last 10% epochs
506
+ if early_stop:
507
+ epochs_before_stop = int(epochs * percent_epochs_before_stop)
508
+ if i > epochs_before_stop and abs(acc_loss[i - epochs_before_stop] - acc_loss[i]) < min_delta:
509
+ break
510
+
511
+ # No longer requires gradients in the parameters
512
+ for p in best_params.values():
513
+ p.requires_grad = False
514
+ p.grad = None
515
+
516
+ if do_report:
517
+ print(f"Best loss: {best_loss}")
518
+ return best_params, acc_loss
519
+ else:
520
+ return best_params
521
+
522
+
523
+ def quantize(
524
+ x: torch.Tensor,
525
+ params: Dict[str, nn.Parameter],
526
+ func: nn.Module,
527
+ bits: int,
528
+ target_dtype: torch.dtype = torch.int8
529
+ ) -> torch.Tensor:
530
+ quant_min, quant_max = get_min_max_from_bits_signed(bits)
531
+ x = x.transpose(0, 1) # Aligns shapes
532
+ x = func(x=x, **params)
533
+ x = x.transpose(0, 1)
534
+ x = torch.clamp(round_func_BPDA(x), quant_min, quant_max).to(target_dtype)
535
+ return x
536
+
537
+
538
+ def dequantize(
539
+ x: torch.Tensor,
540
+ params: Dict[str, nn.Parameter],
541
+ func: nn.Module,
542
+ bits: int,
543
+ out_dtype: torch.dtype
544
+ ) -> torch.Tensor:
545
+ x = x.to(dtype=out_dtype)
546
+ x = x.transpose(0, 1)
547
+ x = func(x=x, **params)
548
+ x = x.transpose(0, 1)
549
+ return x
550
+
551
+
552
+ def round_func_BPDA(input):
553
+ # This is equivalent to replacing round function (non-differentiable) with
554
+ # an identity function (differentiable) only when backward.
555
+ forward_value = torch.round(input)
556
+ out = input.clone()
557
+ out.data = forward_value.data
558
+ return out
559
+
560
+
561
+ def get_min_max_from_bits_signed(bit_width: int) -> Tuple[int, int]:
562
+ return -2 ** (bit_width - 1), 2 ** (bit_width - 1) - 1
563
+
564
+
565
+
566
+ ############## Numpy ###############
567
+
568
+ def np_domain_guard(
569
+ x: np.ndarray,
570
+ min: float = None,
571
+ max: float = None,
572
+ posinf: float = None,
573
+ neginf: float = None,
574
+ nan: float = None
575
+ ) -> np.ndarray:
576
+ """Guard a tensor to a valid domain."""
577
+ x = np.nan_to_num(x, posinf=posinf, neginf=neginf, nan=nan)
578
+ if min is not None or max is not None:
579
+ x = np.clip(x, min, max)
580
+ return x
581
+
582
+
583
+ def np_replace_num(x: np.ndarray, num: float, to: float) -> np.ndarray:
584
+ """Replace a number in a tensor with another number.
585
+
586
+ Args:
587
+ x (np.ndarray): The input tensor.
588
+ num (float): The number to replace.
589
+ to (float): The number to replace with.
590
+
591
+ Returns:
592
+ np.ndarray: The tensor with the number replaced.
593
+ """
594
+ return np.where(x == num, to, x)
595
+
596
+
597
+ def np_guarded_power(x: np.ndarray, exp: float) -> np.ndarray:
598
+ """Guard the power operation to a valid domain."""
599
+ return np.power(x, exp) if exp >= 1 else np.power(np.maximum(x, 0), exp)
600
+
fn_gen/rnd_search_t_cos/12/loss.png ADDED
fn_gen/rnd_search_t_cos/12/quantization.png ADDED
fn_gen/rnd_search_t_cos/13/distortion.png ADDED
fn_gen/rnd_search_t_cos/13/expressions.txt ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ tanh(_0*x)/_s
2
+ log((-_s*x - 1)/(_s*x - 1))/_0
fn_gen/rnd_search_t_cos/13/fn.py ADDED
@@ -0,0 +1,600 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import torch
4
+ from torch import amin # Necessary for arcsin
5
+ import copy
6
+ import torch.nn as nn
7
+ import numpy as np
8
+
9
+ from scipy.optimize import curve_fit
10
+ from typing import Dict, Any, Tuple, List, Callable
11
+
12
+
13
+ def quantization(x, **params):
14
+ return (torch.div(1, replace_num(params['_s'], num=0, to=10000)) * torch.tanh((params['_0'] * x)))
15
+
16
+
17
+ def dequantization(x, **params):
18
+ return (torch.div(1, replace_num(params['_0'], num=0, to=10000)) * torch.log(domain_guard((torch.div(1, replace_num((torch.tensor(-1) + (params['_s'] * x)), num=0, to=10000)) * (torch.tensor(-1) + (torch.tensor(-1) * params['_s'] * x))), min=1e-5, nan=1e-5)))
19
+
20
+
21
+ def init_space_search(
22
+ x: torch.Tensor,
23
+ **kwargs: Dict[str, Any],
24
+ ) -> torch.Tensor:
25
+
26
+ def _build_initial_param(tensor: torch.Tensor, max_initial: int, n_params: int):
27
+ """Generates the initial set of parameters. The first iteration generates 10 times more parameters."""
28
+ for _ in range(n_params * 10): # The first iteration generates 10 times more parameters
29
+ yield init_rand(tensor) * max_initial # Generates n_params in range [-max_initial, max_initial]
30
+
31
+ def _search_param(tensors: List[torch.tensor], n_params):
32
+ """Takes the best parameters and generates new parameters around the mean of the best parameters."""
33
+ torch_tensors = torch.stack(tensors)
34
+ min_vals, max_vals = torch.aminmax(torch_tensors, dim=0)
35
+ abs_max_val_per_ch = torch.max(-min_vals, max_vals)
36
+ mean = torch.mean(torch_tensors, dim=0)
37
+ for _ in range(n_params): # Generates n_params around the mean of the tensors
38
+ yield torch.randn_like(min_vals) * abs_max_val_per_ch + mean
39
+
40
+ def _calc(x, qtz_func, deqtz_func, **params):
41
+ x_ = x.transpose(0, 1)
42
+ x_ = qtz_func(x=x_, **params)
43
+ x_ = deqtz_func(x=x_, **params)
44
+ x_ = x_.transpose(0, 1)
45
+ return x_
46
+
47
+ assert "qtz_func" in kwargs, "qtz_func must be provided."
48
+ assert "deqtz_func" in kwargs, "deqtz_func must be provided."
49
+ assert "params_list" in kwargs, "params list must be provided."
50
+ assert "param" in kwargs, "param must be provided."
51
+
52
+ qtz_func = kwargs.get('qtz_func')
53
+ deqtz_func = kwargs.get('deqtz_func')
54
+ params_list = kwargs.get('params_list')
55
+ param = kwargs.get('param')
56
+
57
+ n_runs = 50 # Number of runs to try to find the best parameters
58
+ n_random_params = 50 # Number of random parameters to generate
59
+ n_best_to_pick = 5 # Number of best parameters to pick after each run
60
+ max_initial = 10000 # Maximum value to initialize the parameters
61
+
62
+ # Initializes the parameters
63
+ base_params = { p: init_ones(x, **kwargs) for p in params_list if p != param }
64
+ params = _build_initial_param(x, max_initial, n_random_params)
65
+
66
+ # Performs the search
67
+ for _ in range(n_runs):
68
+
69
+ best_params = []
70
+ for param_ in params:
71
+ try:
72
+ x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: param_})
73
+ loss_ones = nn.MSELoss()(x, x_)
74
+
75
+ if len(best_params) < n_best_to_pick:
76
+ best_params.append((param_, loss_ones.item()))
77
+ best_params = sorted(best_params, key=lambda x: x[1])
78
+ elif loss_ones < best_params[-1][1]:
79
+ best_params[-1] = (param_, loss_ones.item())
80
+ best_params = sorted(best_params, key=lambda x: x[1])
81
+
82
+ except Exception: # The parameters might not be valid for the function's domain
83
+ continue
84
+
85
+ # Generates new parameters around the mean
86
+ params = _search_param([p for p, _ in best_params], n_random_params)
87
+
88
+ # Checks if the best parameter is better than the init_ones
89
+ p_ones = init_ones(x, **kwargs)
90
+ x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: p_ones})
91
+ loss_ones = nn.MSELoss()(x, x_)
92
+
93
+ # Checks if the best parameter is better than the init_rand
94
+ p_rand = init_rand(x, **kwargs)
95
+ x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: p_rand})
96
+ loss_rand = nn.MSELoss()(x, x_)
97
+
98
+ if loss_rand < best_params[0][1] and loss_rand < loss_ones:
99
+ return p_rand
100
+ elif loss_ones < best_params[0][1] and loss_ones < loss_rand:
101
+ return p_ones
102
+ else:
103
+ return best_params[0][0]
104
+
105
+
106
+ def init_linear_scale( # Symmetric scale. From the study folder
107
+ x: torch.Tensor,
108
+ **kwargs: Dict[str, Any],
109
+ ) -> torch.Tensor:
110
+ assert "bits" in kwargs, "bits must be provided."
111
+ assert "params" in kwargs, "params must be provided."
112
+ assert "qtz_func" in kwargs, "qtz_func must be provided."
113
+
114
+ bits = kwargs.get('bits')
115
+ params = kwargs.get('params')
116
+ qtz_func = kwargs.get('qtz_func')
117
+
118
+ x_ = x.transpose(0, 1)
119
+ x_ = qtz_func(x=x_, **params, _s=init_ones(x, **kwargs))
120
+ x_ = x_.transpose(0, 1)
121
+
122
+ quant_min, quant_max = get_min_max_from_bits_signed(bits)
123
+ min_vals, max_vals = torch.aminmax(x_, dim=1)
124
+ min_vals = torch.min(min_vals, torch.zeros_like(min_vals))
125
+ max_vals = torch.max(max_vals, torch.zeros_like(max_vals))
126
+
127
+ eps = torch.finfo(torch.float32).eps
128
+
129
+ abs_max_val_per_ch = torch.max(-min_vals, max_vals)
130
+ scale = abs_max_val_per_ch / (float(quant_max - quant_min) / 2)
131
+
132
+ scale = torch.clamp(scale, min=eps).to(dtype=torch.float32, device=min_vals.device)
133
+
134
+ # Introduces some noise in scale
135
+ # If I don't introduce noise, the accuracy is going to be 0.0 and not learn anything
136
+ # If I don't introduce noise, the accuracy is going to be 0.0 and not learn anything.
137
+ # NOTE(diogo): This has been disproven. The noise does not help the learning process but I still
138
+ # left it here for future reference. Will be removed later.
139
+ # scale = scale + 0.01 * torch.randn_like(scale)
140
+
141
+ return scale
142
+
143
+
144
+ def init_params(x: torch.Tensor, **kwargs: Dict[str, Any]) -> Dict[str, nn.Parameter]:
145
+ params = {
146
+ '_0': init_space_search(x, qtz_func=quantization, deqtz_func=dequantization, param='_0', params_list=['_0', '_s'], **kwargs),
147
+ }
148
+ params['_s'] = init_linear_scale(x, params=params, qtz_func=quantization, **kwargs)
149
+ params = {k: nn.Parameter(v, requires_grad=False) for k, v in params.items()}
150
+
151
+ if 'post_init_hook' in kwargs:
152
+ kwargs['post_init_hook'](parameters=params)
153
+
154
+ params = learn_parameters(x, params,
155
+ qtz_func=quantization,
156
+ deqtz_func=dequantization,
157
+ bits=kwargs['bits'],
158
+ target_dtype=torch.int8,
159
+ epochs=500,
160
+ early_stop=False,
161
+ )
162
+ if 'post_train_hook' in kwargs:
163
+ kwargs['post_train_hook'](parameters=params)
164
+
165
+ return params
166
+
167
+
168
+ ############### Numpy Qtz ###############
169
+
170
+
171
+ def np_quantization(x, _0, _s):
172
+ return (np.divide(1, np_replace_num(_s, num=0, to=10000)) * np.tanh((_0 * x)))
173
+
174
+
175
+ def np_dequantization(x, _0, _s):
176
+ return (np.divide(1, np_replace_num(_0, num=0, to=10000)) * np.log(np_domain_guard((np.divide(1, np_replace_num((np.array(-1) + (_s * x)), num=0, to=10000)) * (np.array(-1) + (np.array(-1) * _s * x))), min=1e-5, nan=1e-5)))
177
+
178
+
179
+ def fit_func(x, _0, _s):
180
+ x_ = np_quantization(x, _0, _s)
181
+ x_ = np_dequantization(x_, _0, _s)
182
+ return x_
183
+
184
+
185
+
186
+ ############### HELPERS ###############
187
+
188
+ def domain_guard(
189
+ x: torch.Tensor,
190
+ min: float = None,
191
+ max: float = None,
192
+ posinf: float = None,
193
+ neginf: float = None,
194
+ nan: float = None
195
+ ) -> torch.Tensor:
196
+ """Guard a tensor to a valid domain."""
197
+ x = torch.nan_to_num(x, posinf=posinf, neginf=neginf, nan=nan)
198
+ if min is not None or max is not None:
199
+ x = torch.clamp(x, min=min, max=max)
200
+ return x
201
+
202
+
203
+ def replace_num(x: torch.Tensor, num: float, to: float) -> torch.Tensor:
204
+ """Replace a number in a tensor with another number.
205
+
206
+ Args:
207
+ x (torch.Tensor): The input tensor.
208
+ num (float): The number to replace.
209
+ to (float): The number to replace with.
210
+
211
+ Returns:
212
+ torch.Tensor: The tensor with the number replaced.
213
+ """
214
+ return torch.where(x == num, to, x)
215
+
216
+
217
+ def guarded_torch_power(x: torch.Tensor, exp: float) -> torch.Tensor:
218
+ """Guard the power operation to a valid domain."""
219
+ return torch.pow(x, exp) if exp >= 1 else torch.pow(torch.relu(x), exp)
220
+
221
+
222
+ def init_ones(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor:
223
+ val = torch.amin(x, dim=1)
224
+ return torch.ones_like(val, dtype=torch.float32, device=x.device)
225
+
226
+
227
+ def init_rand(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor:
228
+ val = torch.amin(x, dim=1)
229
+ return torch.randn_like(val, dtype=torch.float32, device=x.device)
230
+
231
+
232
+ def init_space_search(
233
+ x: torch.Tensor,
234
+ **kwargs: Dict[str, Any],
235
+ ) -> torch.Tensor:
236
+
237
+ def _build_initial_param(tensor: torch.Tensor, max_initial: int, n_params: int):
238
+ """Generates the initial set of parameters. The first iteration generates 10 times more parameters."""
239
+ for _ in range(n_params * 10): # The first iteration generates 10 times more parameters
240
+ yield init_rand(tensor) * max_initial # Generates n_params in range [-max_initial, max_initial]
241
+
242
+ def _search_param(tensors: List[torch.tensor], n_params):
243
+ """Takes the best parameters and generates new parameters around the mean of the best parameters."""
244
+ torch_tensors = torch.stack(tensors)
245
+ min_vals, max_vals = torch.aminmax(torch_tensors, dim=0)
246
+ abs_max_val_per_ch = torch.max(-min_vals, max_vals)
247
+ mean = torch.mean(torch_tensors, dim=0)
248
+ for _ in range(n_params): # Generates n_params around the mean of the tensors
249
+ yield torch.randn_like(min_vals) * abs_max_val_per_ch + mean
250
+
251
+ def _calc(x, qtz_func, deqtz_func, **params):
252
+ x_ = x.transpose(0, 1)
253
+ x_ = qtz_func(x=x_, **params)
254
+ x_ = deqtz_func(x=x_, **params)
255
+ x_ = x_.transpose(0, 1)
256
+ return x_
257
+
258
+ assert "qtz_func" in kwargs, "qtz_func must be provided."
259
+ assert "deqtz_func" in kwargs, "deqtz_func must be provided."
260
+ assert "params_list" in kwargs, "params list must be provided."
261
+ assert "param" in kwargs, "param must be provided."
262
+
263
+ qtz_func = kwargs.get('qtz_func')
264
+ deqtz_func = kwargs.get('deqtz_func')
265
+ params_list = kwargs.get('params_list')
266
+ param = kwargs.get('param')
267
+
268
+ n_runs = 50 # Number of runs to try to find the best parameters
269
+ n_random_params = 50 # Number of random parameters to generate
270
+ n_best_to_pick = 5 # Number of best parameters to pick after each run
271
+ max_initial = 10000 # Maximum value to initialize the parameters
272
+
273
+ # Initializes the parameters
274
+ base_params = { p: init_ones(x, **kwargs) for p in params_list if p != param }
275
+ params = _build_initial_param(x, max_initial, n_random_params)
276
+
277
+ # Performs the search
278
+ for _ in range(n_runs):
279
+
280
+ best_params = []
281
+ for param_ in params:
282
+ try:
283
+ x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: param_})
284
+ loss_ones = nn.MSELoss()(x, x_)
285
+
286
+ if len(best_params) < n_best_to_pick:
287
+ best_params.append((param_, loss_ones.item()))
288
+ best_params = sorted(best_params, key=lambda x: x[1])
289
+ elif loss_ones < best_params[-1][1]:
290
+ best_params[-1] = (param_, loss_ones.item())
291
+ best_params = sorted(best_params, key=lambda x: x[1])
292
+
293
+ except Exception: # The parameters might not be valid for the function's domain
294
+ continue
295
+
296
+ # Generates new parameters around the mean
297
+ params = _search_param([p for p, _ in best_params], n_random_params)
298
+
299
+ # Checks if the best parameter is better than the init_ones
300
+ p_ones = init_ones(x, **kwargs)
301
+ x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: p_ones})
302
+ loss_ones = nn.MSELoss()(x, x_)
303
+
304
+ # Checks if the best parameter is better than the init_rand
305
+ p_rand = init_rand(x, **kwargs)
306
+ x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: p_rand})
307
+ loss_rand = nn.MSELoss()(x, x_)
308
+
309
+ if loss_rand < best_params[0][1] and loss_rand < loss_ones:
310
+ return p_rand
311
+ elif loss_ones < best_params[0][1] and loss_ones < loss_rand:
312
+ return p_ones
313
+ else:
314
+ return best_params[0][0]
315
+
316
+
317
+ def init_linear_scale( # Symmetric scale. From the study folder
318
+ x: torch.Tensor,
319
+ **kwargs: Dict[str, Any],
320
+ ) -> torch.Tensor:
321
+ assert "bits" in kwargs, "bits must be provided."
322
+ assert "params" in kwargs, "params must be provided."
323
+ assert "qtz_func" in kwargs, "qtz_func must be provided."
324
+
325
+ bits = kwargs.get('bits')
326
+ params = kwargs.get('params')
327
+ qtz_func = kwargs.get('qtz_func')
328
+
329
+ x_ = x.transpose(0, 1)
330
+ x_ = qtz_func(x=x_, **params, _s=init_ones(x, **kwargs))
331
+ x_ = x_.transpose(0, 1)
332
+
333
+ quant_min, quant_max = get_min_max_from_bits_signed(bits)
334
+ min_vals, max_vals = torch.aminmax(x_, dim=1)
335
+ min_vals = torch.min(min_vals, torch.zeros_like(min_vals))
336
+ max_vals = torch.max(max_vals, torch.zeros_like(max_vals))
337
+
338
+ eps = torch.finfo(torch.float32).eps
339
+
340
+ abs_max_val_per_ch = torch.max(-min_vals, max_vals)
341
+ scale = abs_max_val_per_ch / (float(quant_max - quant_min) / 2)
342
+
343
+ scale = torch.clamp(scale, min=eps).to(dtype=torch.float32, device=min_vals.device)
344
+
345
+ # Introduces some noise in scale
346
+ # If I don't introduce noise, the accuracy is going to be 0.0 and not learn anything
347
+ # If I don't introduce noise, the accuracy is going to be 0.0 and not learn anything.
348
+ # NOTE(diogo): This has been disproven. The noise does not help the learning process but I still
349
+ # left it here for future reference. Will be removed later.
350
+ # scale = scale + 0.01 * torch.randn_like(scale)
351
+
352
+ return scale
353
+
354
+
355
+ def init_non_linear_regression_fit(
356
+ x: torch.Tensor,
357
+ **kwargs: Dict[str, Any],
358
+ ) -> torch.Tensor:
359
+
360
+ assert "params_list" in kwargs, "params list must be provided."
361
+ assert "np_fit_func" in kwargs, "np_fit_func must be provided."
362
+ assert "p0" in kwargs, "p0 must be provided."
363
+ np_fit_func = kwargs.get('np_fit_func')
364
+ params_list = kwargs.get('params_list')
365
+ p0 = kwargs.get('p0')
366
+
367
+ def _fit(xdata: np.ndarray, ydata: np.ndarray, func: Callable, p0: List[float]):
368
+ popt, _ = curve_fit(
369
+ func,
370
+ xdata,
371
+ ydata,
372
+ maxfev=1000,
373
+ p0=p0,
374
+ method='lm'
375
+ )
376
+ return popt
377
+
378
+ # 1. Needs to convert the torch tensor to numpy tensor
379
+ xdata = x.cpu().numpy()
380
+
381
+ # 2. Sorts the data so that it makes it easier to fit to it
382
+ sorted_xdata = np.sort(xdata, axis=-1)
383
+
384
+ p0 = {k: v.cpu().numpy() for k, v in p0.items()}
385
+ params_list = sorted(params_list) # We need to make sure that it matches the numpy fit func arg order
386
+
387
+ # 3. Finds the best parameters for each channel
388
+ try:
389
+ params = []
390
+ for i in range(sorted_xdata.shape[0]):
391
+ xdata_ = sorted_xdata[i]
392
+ p0_ = [p0[p][i] for p in params_list]
393
+ ch_params = _fit(xdata_, xdata_, np_fit_func, p0_)
394
+ params.append(ch_params)
395
+
396
+ # 4. Builds the parameters
397
+ result = {}
398
+ for i, p in enumerate(params_list):
399
+ result[p] = torch.tensor([p_[i] for p_ in params], dtype=torch.float32).to(x.device)
400
+
401
+ return result
402
+
403
+ except ValueError as e:
404
+ print(f"Could not fit the function with error: {e}")
405
+ print(f"Using fallback result...")
406
+ return {
407
+ k: torch.tensor(v, dtype=torch.float32).to(x.device) for k, v in p0.items()
408
+ }
409
+
410
+
411
+ def init_zeros(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor:
412
+ val = torch.amin(x, dim=1)
413
+ return torch.zeros_like(val, dtype=torch.float32, device=x.device)
414
+
415
+
416
+ def init_inner_scale(tensor: torch.Tensor, _min: float = torch.inf, _max: float = torch.inf) -> torch.Tensor:
417
+ # Calculate the original minimum and maximum values
418
+ min_vals, max_vals = torch.aminmax(tensor, dim=-1)
419
+ x_min = torch.min(min_vals, torch.zeros_like(min_vals))
420
+ x_max = torch.max(max_vals, torch.zeros_like(max_vals))
421
+
422
+ if _max is torch.inf: # We do not need to scale the tensor. Just need to move it
423
+ return torch.ones_like(x_min)
424
+
425
+ # Calculate the scale factor
426
+ scale = (_max - _min) / (x_max - x_min)
427
+ return scale
428
+
429
+
430
+
431
+ ############## Quant ###############
432
+
433
+ @torch.enable_grad()
434
+ def learn_parameters(
435
+ x: torch.Tensor,
436
+ params: Dict[str, nn.Parameter],
437
+ qtz_func: nn.Module,
438
+ deqtz_func: nn.Module,
439
+ bits: int,
440
+ target_dtype: torch.dtype,
441
+ epochs: int = 1000,
442
+ early_stop: bool = True,
443
+ do_report: bool = False
444
+ ) -> Tuple[Dict[str, nn.Parameter], torch.Tensor]:
445
+ loss_fn = nn.MSELoss()
446
+
447
+ # Determines the initial learning rate by computing the initial loss and multiplying it by
448
+ # the order of magnitude of the loss divided by 2
449
+ quant = quantize(x, params, qtz_func, bits, target_dtype)
450
+ dequant = dequantize(quant, params, deqtz_func, bits, x.dtype)
451
+ loss = loss_fn(x, dequant)
452
+
453
+ base_lr = 0.1
454
+ exponent = int(np.floor(np.log10(loss.item())))
455
+ lr = base_lr * (10 ** (exponent // 2))
456
+
457
+ # Requires gradients in the parameters
458
+ for p in params.values():
459
+ p.requires_grad = True
460
+ p.grad = None
461
+
462
+ param_keys = list(params.keys())
463
+ param_values = list(params.values())
464
+
465
+ # Defines optimizer and loss function
466
+ optimizer = torch.optim.Adam(param_values, lr=lr)
467
+ scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=epochs // 10, T_mult=1, eta_min=lr * 0.1, last_epoch=-1)
468
+
469
+ # Contains the best loss and the best parameters
470
+ best_loss = float("inf")
471
+ best_params = None
472
+
473
+ # Used to stop the search early
474
+ min_delta = 1e-7
475
+ acc_loss = []
476
+ percent_epochs_before_stop = 0.1
477
+
478
+ for i in range(epochs):
479
+ optimizer.zero_grad()
480
+
481
+ quant = quantize(x, params, qtz_func, bits, target_dtype)
482
+ dequant = dequantize(quant, params, deqtz_func, bits, x.dtype)
483
+ loss = loss_fn(x, dequant)
484
+
485
+ if loss.isnan() or loss.isinf():
486
+ raise Exception("Loss is NaN or Inf. Stopping the search.")
487
+
488
+ loss.backward()
489
+ optimizer.step()
490
+ scheduler.step()
491
+
492
+ acc_loss.append(loss.item())
493
+
494
+ # Reports loss every 10 steps
495
+ if i % 10 == 0 and do_report:
496
+ print(f"Epoch {i}: Loss {loss.item()}")
497
+
498
+ # Optimizes the parameter search by storing the best loss and the parameters
499
+ if loss.item() < best_loss:
500
+ best_loss = loss.item()
501
+ best_params = copy.deepcopy({
502
+ k: v for k, v in params.items() if k in param_keys
503
+ })
504
+
505
+ # We also stop the search if the loss has not considerably during the last 10% epochs
506
+ if early_stop:
507
+ epochs_before_stop = int(epochs * percent_epochs_before_stop)
508
+ if i > epochs_before_stop and abs(acc_loss[i - epochs_before_stop] - acc_loss[i]) < min_delta:
509
+ break
510
+
511
+ # No longer requires gradients in the parameters
512
+ for p in best_params.values():
513
+ p.requires_grad = False
514
+ p.grad = None
515
+
516
+ if do_report:
517
+ print(f"Best loss: {best_loss}")
518
+ return best_params, acc_loss
519
+ else:
520
+ return best_params
521
+
522
+
523
+ def quantize(
524
+ x: torch.Tensor,
525
+ params: Dict[str, nn.Parameter],
526
+ func: nn.Module,
527
+ bits: int,
528
+ target_dtype: torch.dtype = torch.int8
529
+ ) -> torch.Tensor:
530
+ quant_min, quant_max = get_min_max_from_bits_signed(bits)
531
+ x = x.transpose(0, 1) # Aligns shapes
532
+ x = func(x=x, **params)
533
+ x = x.transpose(0, 1)
534
+ x = torch.clamp(round_func_BPDA(x), quant_min, quant_max).to(target_dtype)
535
+ return x
536
+
537
+
538
+ def dequantize(
539
+ x: torch.Tensor,
540
+ params: Dict[str, nn.Parameter],
541
+ func: nn.Module,
542
+ bits: int,
543
+ out_dtype: torch.dtype
544
+ ) -> torch.Tensor:
545
+ x = x.to(dtype=out_dtype)
546
+ x = x.transpose(0, 1)
547
+ x = func(x=x, **params)
548
+ x = x.transpose(0, 1)
549
+ return x
550
+
551
+
552
+ def round_func_BPDA(input):
553
+ # This is equivalent to replacing round function (non-differentiable) with
554
+ # an identity function (differentiable) only when backward.
555
+ forward_value = torch.round(input)
556
+ out = input.clone()
557
+ out.data = forward_value.data
558
+ return out
559
+
560
+
561
+ def get_min_max_from_bits_signed(bit_width: int) -> Tuple[int, int]:
562
+ return -2 ** (bit_width - 1), 2 ** (bit_width - 1) - 1
563
+
564
+
565
+
566
+ ############## Numpy ###############
567
+
568
+ def np_domain_guard(
569
+ x: np.ndarray,
570
+ min: float = None,
571
+ max: float = None,
572
+ posinf: float = None,
573
+ neginf: float = None,
574
+ nan: float = None
575
+ ) -> np.ndarray:
576
+ """Guard a tensor to a valid domain."""
577
+ x = np.nan_to_num(x, posinf=posinf, neginf=neginf, nan=nan)
578
+ if min is not None or max is not None:
579
+ x = np.clip(x, min, max)
580
+ return x
581
+
582
+
583
+ def np_replace_num(x: np.ndarray, num: float, to: float) -> np.ndarray:
584
+ """Replace a number in a tensor with another number.
585
+
586
+ Args:
587
+ x (np.ndarray): The input tensor.
588
+ num (float): The number to replace.
589
+ to (float): The number to replace with.
590
+
591
+ Returns:
592
+ np.ndarray: The tensor with the number replaced.
593
+ """
594
+ return np.where(x == num, to, x)
595
+
596
+
597
+ def np_guarded_power(x: np.ndarray, exp: float) -> np.ndarray:
598
+ """Guard the power operation to a valid domain."""
599
+ return np.power(x, exp) if exp >= 1 else np.power(np.maximum(x, 0), exp)
600
+
fn_gen/rnd_search_t_cos/13/loss.png ADDED
fn_gen/rnd_search_t_cos/13/quantization.png ADDED
fn_gen/rnd_search_t_cos/15/distortion.png ADDED
fn_gen/rnd_search_t_cos/15/expressions.txt ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ x/_s
2
+ _s*x
fn_gen/rnd_search_t_cos/15/fn.py ADDED
@@ -0,0 +1,514 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import torch
4
+ from torch import amin # Necessary for arcsin
5
+ import copy
6
+ import torch.nn as nn
7
+ import numpy as np
8
+
9
+ from scipy.optimize import curve_fit
10
+ from typing import Dict, Any, Tuple, List, Callable
11
+
12
+
13
+ def quantization(x, **params):
14
+ return (x * torch.div(1, replace_num(params['_s'], num=0, to=10000)))
15
+
16
+
17
+ def dequantization(x, **params):
18
+ return (params['_s'] * x)
19
+
20
+
21
+ def init_linear_scale( # Symmetric scale. From the study folder
22
+ x: torch.Tensor,
23
+ **kwargs: Dict[str, Any],
24
+ ) -> torch.Tensor:
25
+ assert "bits" in kwargs, "bits must be provided."
26
+ assert "params" in kwargs, "params must be provided."
27
+ assert "qtz_func" in kwargs, "qtz_func must be provided."
28
+
29
+ bits = kwargs.get('bits')
30
+ params = kwargs.get('params')
31
+ qtz_func = kwargs.get('qtz_func')
32
+
33
+ x_ = x.transpose(0, 1)
34
+ x_ = qtz_func(x=x_, **params, _s=init_ones(x, **kwargs))
35
+ x_ = x_.transpose(0, 1)
36
+
37
+ quant_min, quant_max = get_min_max_from_bits_signed(bits)
38
+ min_vals, max_vals = torch.aminmax(x_, dim=1)
39
+ min_vals = torch.min(min_vals, torch.zeros_like(min_vals))
40
+ max_vals = torch.max(max_vals, torch.zeros_like(max_vals))
41
+
42
+ eps = torch.finfo(torch.float32).eps
43
+
44
+ abs_max_val_per_ch = torch.max(-min_vals, max_vals)
45
+ scale = abs_max_val_per_ch / (float(quant_max - quant_min) / 2)
46
+
47
+ scale = torch.clamp(scale, min=eps).to(dtype=torch.float32, device=min_vals.device)
48
+
49
+ # Introduces some noise in scale
50
+ # If I don't introduce noise, the accuracy is going to be 0.0 and not learn anything
51
+ # If I don't introduce noise, the accuracy is going to be 0.0 and not learn anything.
52
+ # NOTE(diogo): This has been disproven. The noise does not help the learning process but I still
53
+ # left it here for future reference. Will be removed later.
54
+ # scale = scale + 0.01 * torch.randn_like(scale)
55
+
56
+ return scale
57
+
58
+
59
+ def init_params(x: torch.Tensor, **kwargs: Dict[str, Any]) -> Dict[str, nn.Parameter]:
60
+ params = {
61
+ }
62
+ params['_s'] = init_linear_scale(x, params=params, qtz_func=quantization, **kwargs)
63
+ params = {k: nn.Parameter(v, requires_grad=False) for k, v in params.items()}
64
+
65
+ if 'post_init_hook' in kwargs:
66
+ kwargs['post_init_hook'](parameters=params)
67
+
68
+ params = learn_parameters(x, params,
69
+ qtz_func=quantization,
70
+ deqtz_func=dequantization,
71
+ bits=kwargs['bits'],
72
+ target_dtype=torch.int8,
73
+ epochs=500,
74
+ early_stop=False,
75
+ )
76
+ if 'post_train_hook' in kwargs:
77
+ kwargs['post_train_hook'](parameters=params)
78
+
79
+ return params
80
+
81
+
82
+ ############### Numpy Qtz ###############
83
+
84
+
85
+ def np_quantization(x, _s):
86
+ return (x * np.divide(1, np_replace_num(_s, num=0, to=10000)))
87
+
88
+
89
+ def np_dequantization(x, _s):
90
+ return (_s * x)
91
+
92
+
93
+ def fit_func(x, _s):
94
+ x_ = np_quantization(x, _s)
95
+ x_ = np_dequantization(x_, _s)
96
+ return x_
97
+
98
+
99
+
100
+ ############### HELPERS ###############
101
+
102
+ def domain_guard(
103
+ x: torch.Tensor,
104
+ min: float = None,
105
+ max: float = None,
106
+ posinf: float = None,
107
+ neginf: float = None,
108
+ nan: float = None
109
+ ) -> torch.Tensor:
110
+ """Guard a tensor to a valid domain."""
111
+ x = torch.nan_to_num(x, posinf=posinf, neginf=neginf, nan=nan)
112
+ if min is not None or max is not None:
113
+ x = torch.clamp(x, min=min, max=max)
114
+ return x
115
+
116
+
117
+ def replace_num(x: torch.Tensor, num: float, to: float) -> torch.Tensor:
118
+ """Replace a number in a tensor with another number.
119
+
120
+ Args:
121
+ x (torch.Tensor): The input tensor.
122
+ num (float): The number to replace.
123
+ to (float): The number to replace with.
124
+
125
+ Returns:
126
+ torch.Tensor: The tensor with the number replaced.
127
+ """
128
+ return torch.where(x == num, to, x)
129
+
130
+
131
+ def guarded_torch_power(x: torch.Tensor, exp: float) -> torch.Tensor:
132
+ """Guard the power operation to a valid domain."""
133
+ return torch.pow(x, exp) if exp >= 1 else torch.pow(torch.relu(x), exp)
134
+
135
+
136
+ def init_ones(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor:
137
+ val = torch.amin(x, dim=1)
138
+ return torch.ones_like(val, dtype=torch.float32, device=x.device)
139
+
140
+
141
+ def init_rand(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor:
142
+ val = torch.amin(x, dim=1)
143
+ return torch.randn_like(val, dtype=torch.float32, device=x.device)
144
+
145
+
146
+ def init_space_search(
147
+ x: torch.Tensor,
148
+ **kwargs: Dict[str, Any],
149
+ ) -> torch.Tensor:
150
+
151
+ def _build_initial_param(tensor: torch.Tensor, max_initial: int, n_params: int):
152
+ """Generates the initial set of parameters. The first iteration generates 10 times more parameters."""
153
+ for _ in range(n_params * 10): # The first iteration generates 10 times more parameters
154
+ yield init_rand(tensor) * max_initial # Generates n_params in range [-max_initial, max_initial]
155
+
156
+ def _search_param(tensors: List[torch.tensor], n_params):
157
+ """Takes the best parameters and generates new parameters around the mean of the best parameters."""
158
+ torch_tensors = torch.stack(tensors)
159
+ min_vals, max_vals = torch.aminmax(torch_tensors, dim=0)
160
+ abs_max_val_per_ch = torch.max(-min_vals, max_vals)
161
+ mean = torch.mean(torch_tensors, dim=0)
162
+ for _ in range(n_params): # Generates n_params around the mean of the tensors
163
+ yield torch.randn_like(min_vals) * abs_max_val_per_ch + mean
164
+
165
+ def _calc(x, qtz_func, deqtz_func, **params):
166
+ x_ = x.transpose(0, 1)
167
+ x_ = qtz_func(x=x_, **params)
168
+ x_ = deqtz_func(x=x_, **params)
169
+ x_ = x_.transpose(0, 1)
170
+ return x_
171
+
172
+ assert "qtz_func" in kwargs, "qtz_func must be provided."
173
+ assert "deqtz_func" in kwargs, "deqtz_func must be provided."
174
+ assert "params_list" in kwargs, "params list must be provided."
175
+ assert "param" in kwargs, "param must be provided."
176
+
177
+ qtz_func = kwargs.get('qtz_func')
178
+ deqtz_func = kwargs.get('deqtz_func')
179
+ params_list = kwargs.get('params_list')
180
+ param = kwargs.get('param')
181
+
182
+ n_runs = 50 # Number of runs to try to find the best parameters
183
+ n_random_params = 50 # Number of random parameters to generate
184
+ n_best_to_pick = 5 # Number of best parameters to pick after each run
185
+ max_initial = 10000 # Maximum value to initialize the parameters
186
+
187
+ # Initializes the parameters
188
+ base_params = { p: init_ones(x, **kwargs) for p in params_list if p != param }
189
+ params = _build_initial_param(x, max_initial, n_random_params)
190
+
191
+ # Performs the search
192
+ for _ in range(n_runs):
193
+
194
+ best_params = []
195
+ for param_ in params:
196
+ try:
197
+ x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: param_})
198
+ loss_ones = nn.MSELoss()(x, x_)
199
+
200
+ if len(best_params) < n_best_to_pick:
201
+ best_params.append((param_, loss_ones.item()))
202
+ best_params = sorted(best_params, key=lambda x: x[1])
203
+ elif loss_ones < best_params[-1][1]:
204
+ best_params[-1] = (param_, loss_ones.item())
205
+ best_params = sorted(best_params, key=lambda x: x[1])
206
+
207
+ except Exception: # The parameters might not be valid for the function's domain
208
+ continue
209
+
210
+ # Generates new parameters around the mean
211
+ params = _search_param([p for p, _ in best_params], n_random_params)
212
+
213
+ # Checks if the best parameter is better than the init_ones
214
+ p_ones = init_ones(x, **kwargs)
215
+ x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: p_ones})
216
+ loss_ones = nn.MSELoss()(x, x_)
217
+
218
+ # Checks if the best parameter is better than the init_rand
219
+ p_rand = init_rand(x, **kwargs)
220
+ x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: p_rand})
221
+ loss_rand = nn.MSELoss()(x, x_)
222
+
223
+ if loss_rand < best_params[0][1] and loss_rand < loss_ones:
224
+ return p_rand
225
+ elif loss_ones < best_params[0][1] and loss_ones < loss_rand:
226
+ return p_ones
227
+ else:
228
+ return best_params[0][0]
229
+
230
+
231
+ def init_linear_scale( # Symmetric scale. From the study folder
232
+ x: torch.Tensor,
233
+ **kwargs: Dict[str, Any],
234
+ ) -> torch.Tensor:
235
+ assert "bits" in kwargs, "bits must be provided."
236
+ assert "params" in kwargs, "params must be provided."
237
+ assert "qtz_func" in kwargs, "qtz_func must be provided."
238
+
239
+ bits = kwargs.get('bits')
240
+ params = kwargs.get('params')
241
+ qtz_func = kwargs.get('qtz_func')
242
+
243
+ x_ = x.transpose(0, 1)
244
+ x_ = qtz_func(x=x_, **params, _s=init_ones(x, **kwargs))
245
+ x_ = x_.transpose(0, 1)
246
+
247
+ quant_min, quant_max = get_min_max_from_bits_signed(bits)
248
+ min_vals, max_vals = torch.aminmax(x_, dim=1)
249
+ min_vals = torch.min(min_vals, torch.zeros_like(min_vals))
250
+ max_vals = torch.max(max_vals, torch.zeros_like(max_vals))
251
+
252
+ eps = torch.finfo(torch.float32).eps
253
+
254
+ abs_max_val_per_ch = torch.max(-min_vals, max_vals)
255
+ scale = abs_max_val_per_ch / (float(quant_max - quant_min) / 2)
256
+
257
+ scale = torch.clamp(scale, min=eps).to(dtype=torch.float32, device=min_vals.device)
258
+
259
+ # Introduces some noise in scale
260
+ # If I don't introduce noise, the accuracy is going to be 0.0 and not learn anything
261
+ # If I don't introduce noise, the accuracy is going to be 0.0 and not learn anything.
262
+ # NOTE(diogo): This has been disproven. The noise does not help the learning process but I still
263
+ # left it here for future reference. Will be removed later.
264
+ # scale = scale + 0.01 * torch.randn_like(scale)
265
+
266
+ return scale
267
+
268
+
269
+ def init_non_linear_regression_fit(
270
+ x: torch.Tensor,
271
+ **kwargs: Dict[str, Any],
272
+ ) -> torch.Tensor:
273
+
274
+ assert "params_list" in kwargs, "params list must be provided."
275
+ assert "np_fit_func" in kwargs, "np_fit_func must be provided."
276
+ assert "p0" in kwargs, "p0 must be provided."
277
+ np_fit_func = kwargs.get('np_fit_func')
278
+ params_list = kwargs.get('params_list')
279
+ p0 = kwargs.get('p0')
280
+
281
+ def _fit(xdata: np.ndarray, ydata: np.ndarray, func: Callable, p0: List[float]):
282
+ popt, _ = curve_fit(
283
+ func,
284
+ xdata,
285
+ ydata,
286
+ maxfev=1000,
287
+ p0=p0,
288
+ method='lm'
289
+ )
290
+ return popt
291
+
292
+ # 1. Needs to convert the torch tensor to numpy tensor
293
+ xdata = x.cpu().numpy()
294
+
295
+ # 2. Sorts the data so that it makes it easier to fit to it
296
+ sorted_xdata = np.sort(xdata, axis=-1)
297
+
298
+ p0 = {k: v.cpu().numpy() for k, v in p0.items()}
299
+ params_list = sorted(params_list) # We need to make sure that it matches the numpy fit func arg order
300
+
301
+ # 3. Finds the best parameters for each channel
302
+ try:
303
+ params = []
304
+ for i in range(sorted_xdata.shape[0]):
305
+ xdata_ = sorted_xdata[i]
306
+ p0_ = [p0[p][i] for p in params_list]
307
+ ch_params = _fit(xdata_, xdata_, np_fit_func, p0_)
308
+ params.append(ch_params)
309
+
310
+ # 4. Builds the parameters
311
+ result = {}
312
+ for i, p in enumerate(params_list):
313
+ result[p] = torch.tensor([p_[i] for p_ in params], dtype=torch.float32).to(x.device)
314
+
315
+ return result
316
+
317
+ except ValueError as e:
318
+ print(f"Could not fit the function with error: {e}")
319
+ print(f"Using fallback result...")
320
+ return {
321
+ k: torch.tensor(v, dtype=torch.float32).to(x.device) for k, v in p0.items()
322
+ }
323
+
324
+
325
+ def init_zeros(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor:
326
+ val = torch.amin(x, dim=1)
327
+ return torch.zeros_like(val, dtype=torch.float32, device=x.device)
328
+
329
+
330
+ def init_inner_scale(tensor: torch.Tensor, _min: float = torch.inf, _max: float = torch.inf) -> torch.Tensor:
331
+ # Calculate the original minimum and maximum values
332
+ min_vals, max_vals = torch.aminmax(tensor, dim=-1)
333
+ x_min = torch.min(min_vals, torch.zeros_like(min_vals))
334
+ x_max = torch.max(max_vals, torch.zeros_like(max_vals))
335
+
336
+ if _max is torch.inf: # We do not need to scale the tensor. Just need to move it
337
+ return torch.ones_like(x_min)
338
+
339
+ # Calculate the scale factor
340
+ scale = (_max - _min) / (x_max - x_min)
341
+ return scale
342
+
343
+
344
+
345
+ ############## Quant ###############
346
+
347
+ @torch.enable_grad()
348
+ def learn_parameters(
349
+ x: torch.Tensor,
350
+ params: Dict[str, nn.Parameter],
351
+ qtz_func: nn.Module,
352
+ deqtz_func: nn.Module,
353
+ bits: int,
354
+ target_dtype: torch.dtype,
355
+ epochs: int = 1000,
356
+ early_stop: bool = True,
357
+ do_report: bool = False
358
+ ) -> Tuple[Dict[str, nn.Parameter], torch.Tensor]:
359
+ loss_fn = nn.MSELoss()
360
+
361
+ # Determines the initial learning rate by computing the initial loss and multiplying it by
362
+ # the order of magnitude of the loss divided by 2
363
+ quant = quantize(x, params, qtz_func, bits, target_dtype)
364
+ dequant = dequantize(quant, params, deqtz_func, bits, x.dtype)
365
+ loss = loss_fn(x, dequant)
366
+
367
+ base_lr = 0.1
368
+ exponent = int(np.floor(np.log10(loss.item())))
369
+ lr = base_lr * (10 ** (exponent // 2))
370
+
371
+ # Requires gradients in the parameters
372
+ for p in params.values():
373
+ p.requires_grad = True
374
+ p.grad = None
375
+
376
+ param_keys = list(params.keys())
377
+ param_values = list(params.values())
378
+
379
+ # Defines optimizer and loss function
380
+ optimizer = torch.optim.Adam(param_values, lr=lr)
381
+ scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=epochs // 10, T_mult=1, eta_min=lr * 0.1, last_epoch=-1)
382
+
383
+ # Contains the best loss and the best parameters
384
+ best_loss = float("inf")
385
+ best_params = None
386
+
387
+ # Used to stop the search early
388
+ min_delta = 1e-7
389
+ acc_loss = []
390
+ percent_epochs_before_stop = 0.1
391
+
392
+ for i in range(epochs):
393
+ optimizer.zero_grad()
394
+
395
+ quant = quantize(x, params, qtz_func, bits, target_dtype)
396
+ dequant = dequantize(quant, params, deqtz_func, bits, x.dtype)
397
+ loss = loss_fn(x, dequant)
398
+
399
+ if loss.isnan() or loss.isinf():
400
+ raise Exception("Loss is NaN or Inf. Stopping the search.")
401
+
402
+ loss.backward()
403
+ optimizer.step()
404
+ scheduler.step()
405
+
406
+ acc_loss.append(loss.item())
407
+
408
+ # Reports loss every 10 steps
409
+ if i % 10 == 0 and do_report:
410
+ print(f"Epoch {i}: Loss {loss.item()}")
411
+
412
+ # Optimizes the parameter search by storing the best loss and the parameters
413
+ if loss.item() < best_loss:
414
+ best_loss = loss.item()
415
+ best_params = copy.deepcopy({
416
+ k: v for k, v in params.items() if k in param_keys
417
+ })
418
+
419
+ # We also stop the search if the loss has not considerably during the last 10% epochs
420
+ if early_stop:
421
+ epochs_before_stop = int(epochs * percent_epochs_before_stop)
422
+ if i > epochs_before_stop and abs(acc_loss[i - epochs_before_stop] - acc_loss[i]) < min_delta:
423
+ break
424
+
425
+ # No longer requires gradients in the parameters
426
+ for p in best_params.values():
427
+ p.requires_grad = False
428
+ p.grad = None
429
+
430
+ if do_report:
431
+ print(f"Best loss: {best_loss}")
432
+ return best_params, acc_loss
433
+ else:
434
+ return best_params
435
+
436
+
437
+ def quantize(
438
+ x: torch.Tensor,
439
+ params: Dict[str, nn.Parameter],
440
+ func: nn.Module,
441
+ bits: int,
442
+ target_dtype: torch.dtype = torch.int8
443
+ ) -> torch.Tensor:
444
+ quant_min, quant_max = get_min_max_from_bits_signed(bits)
445
+ x = x.transpose(0, 1) # Aligns shapes
446
+ x = func(x=x, **params)
447
+ x = x.transpose(0, 1)
448
+ x = torch.clamp(round_func_BPDA(x), quant_min, quant_max).to(target_dtype)
449
+ return x
450
+
451
+
452
+ def dequantize(
453
+ x: torch.Tensor,
454
+ params: Dict[str, nn.Parameter],
455
+ func: nn.Module,
456
+ bits: int,
457
+ out_dtype: torch.dtype
458
+ ) -> torch.Tensor:
459
+ x = x.to(dtype=out_dtype)
460
+ x = x.transpose(0, 1)
461
+ x = func(x=x, **params)
462
+ x = x.transpose(0, 1)
463
+ return x
464
+
465
+
466
+ def round_func_BPDA(input):
467
+ # This is equivalent to replacing round function (non-differentiable) with
468
+ # an identity function (differentiable) only when backward.
469
+ forward_value = torch.round(input)
470
+ out = input.clone()
471
+ out.data = forward_value.data
472
+ return out
473
+
474
+
475
+ def get_min_max_from_bits_signed(bit_width: int) -> Tuple[int, int]:
476
+ return -2 ** (bit_width - 1), 2 ** (bit_width - 1) - 1
477
+
478
+
479
+
480
+ ############## Numpy ###############
481
+
482
+ def np_domain_guard(
483
+ x: np.ndarray,
484
+ min: float = None,
485
+ max: float = None,
486
+ posinf: float = None,
487
+ neginf: float = None,
488
+ nan: float = None
489
+ ) -> np.ndarray:
490
+ """Guard a tensor to a valid domain."""
491
+ x = np.nan_to_num(x, posinf=posinf, neginf=neginf, nan=nan)
492
+ if min is not None or max is not None:
493
+ x = np.clip(x, min, max)
494
+ return x
495
+
496
+
497
+ def np_replace_num(x: np.ndarray, num: float, to: float) -> np.ndarray:
498
+ """Replace a number in a tensor with another number.
499
+
500
+ Args:
501
+ x (np.ndarray): The input tensor.
502
+ num (float): The number to replace.
503
+ to (float): The number to replace with.
504
+
505
+ Returns:
506
+ np.ndarray: The tensor with the number replaced.
507
+ """
508
+ return np.where(x == num, to, x)
509
+
510
+
511
+ def np_guarded_power(x: np.ndarray, exp: float) -> np.ndarray:
512
+ """Guard the power operation to a valid domain."""
513
+ return np.power(x, exp) if exp >= 1 else np.power(np.maximum(x, 0), exp)
514
+
fn_gen/rnd_search_t_cos/15/loss.png ADDED
fn_gen/rnd_search_t_cos/15/quantization.png ADDED
fn_gen/rnd_search_t_cos/16/distortion.png ADDED
fn_gen/rnd_search_t_cos/16/expressions.txt ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ atan(_0*x)/_s
2
+ tan(_s*x)/_0
fn_gen/rnd_search_t_cos/16/fn.py ADDED
@@ -0,0 +1,600 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import torch
4
+ from torch import amin # Necessary for arcsin
5
+ import copy
6
+ import torch.nn as nn
7
+ import numpy as np
8
+
9
+ from scipy.optimize import curve_fit
10
+ from typing import Dict, Any, Tuple, List, Callable
11
+
12
+
13
+ def quantization(x, **params):
14
+ return (torch.div(1, replace_num(params['_s'], num=0, to=10000)) * torch.atan((params['_0'] * x)))
15
+
16
+
17
+ def dequantization(x, **params):
18
+ return (torch.div(1, replace_num(params['_0'], num=0, to=10000)) * torch.tan(domain_guard((params['_s'] * x), posinf=1, neginf=-1, nan=0)))
19
+
20
+
21
+ def init_space_search(
22
+ x: torch.Tensor,
23
+ **kwargs: Dict[str, Any],
24
+ ) -> torch.Tensor:
25
+
26
+ def _build_initial_param(tensor: torch.Tensor, max_initial: int, n_params: int):
27
+ """Generates the initial set of parameters. The first iteration generates 10 times more parameters."""
28
+ for _ in range(n_params * 10): # The first iteration generates 10 times more parameters
29
+ yield init_rand(tensor) * max_initial # Generates n_params in range [-max_initial, max_initial]
30
+
31
+ def _search_param(tensors: List[torch.tensor], n_params):
32
+ """Takes the best parameters and generates new parameters around the mean of the best parameters."""
33
+ torch_tensors = torch.stack(tensors)
34
+ min_vals, max_vals = torch.aminmax(torch_tensors, dim=0)
35
+ abs_max_val_per_ch = torch.max(-min_vals, max_vals)
36
+ mean = torch.mean(torch_tensors, dim=0)
37
+ for _ in range(n_params): # Generates n_params around the mean of the tensors
38
+ yield torch.randn_like(min_vals) * abs_max_val_per_ch + mean
39
+
40
+ def _calc(x, qtz_func, deqtz_func, **params):
41
+ x_ = x.transpose(0, 1)
42
+ x_ = qtz_func(x=x_, **params)
43
+ x_ = deqtz_func(x=x_, **params)
44
+ x_ = x_.transpose(0, 1)
45
+ return x_
46
+
47
+ assert "qtz_func" in kwargs, "qtz_func must be provided."
48
+ assert "deqtz_func" in kwargs, "deqtz_func must be provided."
49
+ assert "params_list" in kwargs, "params list must be provided."
50
+ assert "param" in kwargs, "param must be provided."
51
+
52
+ qtz_func = kwargs.get('qtz_func')
53
+ deqtz_func = kwargs.get('deqtz_func')
54
+ params_list = kwargs.get('params_list')
55
+ param = kwargs.get('param')
56
+
57
+ n_runs = 50 # Number of runs to try to find the best parameters
58
+ n_random_params = 50 # Number of random parameters to generate
59
+ n_best_to_pick = 5 # Number of best parameters to pick after each run
60
+ max_initial = 10000 # Maximum value to initialize the parameters
61
+
62
+ # Initializes the parameters
63
+ base_params = { p: init_ones(x, **kwargs) for p in params_list if p != param }
64
+ params = _build_initial_param(x, max_initial, n_random_params)
65
+
66
+ # Performs the search
67
+ for _ in range(n_runs):
68
+
69
+ best_params = []
70
+ for param_ in params:
71
+ try:
72
+ x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: param_})
73
+ loss_ones = nn.MSELoss()(x, x_)
74
+
75
+ if len(best_params) < n_best_to_pick:
76
+ best_params.append((param_, loss_ones.item()))
77
+ best_params = sorted(best_params, key=lambda x: x[1])
78
+ elif loss_ones < best_params[-1][1]:
79
+ best_params[-1] = (param_, loss_ones.item())
80
+ best_params = sorted(best_params, key=lambda x: x[1])
81
+
82
+ except Exception: # The parameters might not be valid for the function's domain
83
+ continue
84
+
85
+ # Generates new parameters around the mean
86
+ params = _search_param([p for p, _ in best_params], n_random_params)
87
+
88
+ # Checks if the best parameter is better than the init_ones
89
+ p_ones = init_ones(x, **kwargs)
90
+ x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: p_ones})
91
+ loss_ones = nn.MSELoss()(x, x_)
92
+
93
+ # Checks if the best parameter is better than the init_rand
94
+ p_rand = init_rand(x, **kwargs)
95
+ x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: p_rand})
96
+ loss_rand = nn.MSELoss()(x, x_)
97
+
98
+ if loss_rand < best_params[0][1] and loss_rand < loss_ones:
99
+ return p_rand
100
+ elif loss_ones < best_params[0][1] and loss_ones < loss_rand:
101
+ return p_ones
102
+ else:
103
+ return best_params[0][0]
104
+
105
+
106
+ def init_linear_scale( # Symmetric scale. From the study folder
107
+ x: torch.Tensor,
108
+ **kwargs: Dict[str, Any],
109
+ ) -> torch.Tensor:
110
+ assert "bits" in kwargs, "bits must be provided."
111
+ assert "params" in kwargs, "params must be provided."
112
+ assert "qtz_func" in kwargs, "qtz_func must be provided."
113
+
114
+ bits = kwargs.get('bits')
115
+ params = kwargs.get('params')
116
+ qtz_func = kwargs.get('qtz_func')
117
+
118
+ x_ = x.transpose(0, 1)
119
+ x_ = qtz_func(x=x_, **params, _s=init_ones(x, **kwargs))
120
+ x_ = x_.transpose(0, 1)
121
+
122
+ quant_min, quant_max = get_min_max_from_bits_signed(bits)
123
+ min_vals, max_vals = torch.aminmax(x_, dim=1)
124
+ min_vals = torch.min(min_vals, torch.zeros_like(min_vals))
125
+ max_vals = torch.max(max_vals, torch.zeros_like(max_vals))
126
+
127
+ eps = torch.finfo(torch.float32).eps
128
+
129
+ abs_max_val_per_ch = torch.max(-min_vals, max_vals)
130
+ scale = abs_max_val_per_ch / (float(quant_max - quant_min) / 2)
131
+
132
+ scale = torch.clamp(scale, min=eps).to(dtype=torch.float32, device=min_vals.device)
133
+
134
+ # Introduces some noise in scale
135
+ # If I don't introduce noise, the accuracy is going to be 0.0 and not learn anything
136
+ # If I don't introduce noise, the accuracy is going to be 0.0 and not learn anything.
137
+ # NOTE(diogo): This has been disproven. The noise does not help the learning process but I still
138
+ # left it here for future reference. Will be removed later.
139
+ # scale = scale + 0.01 * torch.randn_like(scale)
140
+
141
+ return scale
142
+
143
+
144
+ def init_params(x: torch.Tensor, **kwargs: Dict[str, Any]) -> Dict[str, nn.Parameter]:
145
+ params = {
146
+ '_0': init_space_search(x, qtz_func=quantization, deqtz_func=dequantization, param='_0', params_list=['_0', '_s'], **kwargs),
147
+ }
148
+ params['_s'] = init_linear_scale(x, params=params, qtz_func=quantization, **kwargs)
149
+ params = {k: nn.Parameter(v, requires_grad=False) for k, v in params.items()}
150
+
151
+ if 'post_init_hook' in kwargs:
152
+ kwargs['post_init_hook'](parameters=params)
153
+
154
+ params = learn_parameters(x, params,
155
+ qtz_func=quantization,
156
+ deqtz_func=dequantization,
157
+ bits=kwargs['bits'],
158
+ target_dtype=torch.int8,
159
+ epochs=500,
160
+ early_stop=False,
161
+ )
162
+ if 'post_train_hook' in kwargs:
163
+ kwargs['post_train_hook'](parameters=params)
164
+
165
+ return params
166
+
167
+
168
+ ############### Numpy Qtz ###############
169
+
170
+
171
+ def np_quantization(x, _0, _s):
172
+ return (np.divide(1, np_replace_num(_s, num=0, to=10000)) * np.arctan((_0 * x)))
173
+
174
+
175
+ def np_dequantization(x, _0, _s):
176
+ return (np.divide(1, np_replace_num(_0, num=0, to=10000)) * np.tan(np_domain_guard((_s * x), posinf=1, neginf=-1, nan=0)))
177
+
178
+
179
+ def fit_func(x, _0, _s):
180
+ x_ = np_quantization(x, _0, _s)
181
+ x_ = np_dequantization(x_, _0, _s)
182
+ return x_
183
+
184
+
185
+
186
+ ############### HELPERS ###############
187
+
188
+ def domain_guard(
189
+ x: torch.Tensor,
190
+ min: float = None,
191
+ max: float = None,
192
+ posinf: float = None,
193
+ neginf: float = None,
194
+ nan: float = None
195
+ ) -> torch.Tensor:
196
+ """Guard a tensor to a valid domain."""
197
+ x = torch.nan_to_num(x, posinf=posinf, neginf=neginf, nan=nan)
198
+ if min is not None or max is not None:
199
+ x = torch.clamp(x, min=min, max=max)
200
+ return x
201
+
202
+
203
+ def replace_num(x: torch.Tensor, num: float, to: float) -> torch.Tensor:
204
+ """Replace a number in a tensor with another number.
205
+
206
+ Args:
207
+ x (torch.Tensor): The input tensor.
208
+ num (float): The number to replace.
209
+ to (float): The number to replace with.
210
+
211
+ Returns:
212
+ torch.Tensor: The tensor with the number replaced.
213
+ """
214
+ return torch.where(x == num, to, x)
215
+
216
+
217
+ def guarded_torch_power(x: torch.Tensor, exp: float) -> torch.Tensor:
218
+ """Guard the power operation to a valid domain."""
219
+ return torch.pow(x, exp) if exp >= 1 else torch.pow(torch.relu(x), exp)
220
+
221
+
222
+ def init_ones(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor:
223
+ val = torch.amin(x, dim=1)
224
+ return torch.ones_like(val, dtype=torch.float32, device=x.device)
225
+
226
+
227
+ def init_rand(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor:
228
+ val = torch.amin(x, dim=1)
229
+ return torch.randn_like(val, dtype=torch.float32, device=x.device)
230
+
231
+
232
+ def init_space_search(
233
+ x: torch.Tensor,
234
+ **kwargs: Dict[str, Any],
235
+ ) -> torch.Tensor:
236
+
237
+ def _build_initial_param(tensor: torch.Tensor, max_initial: int, n_params: int):
238
+ """Generates the initial set of parameters. The first iteration generates 10 times more parameters."""
239
+ for _ in range(n_params * 10): # The first iteration generates 10 times more parameters
240
+ yield init_rand(tensor) * max_initial # Generates n_params in range [-max_initial, max_initial]
241
+
242
+ def _search_param(tensors: List[torch.tensor], n_params):
243
+ """Takes the best parameters and generates new parameters around the mean of the best parameters."""
244
+ torch_tensors = torch.stack(tensors)
245
+ min_vals, max_vals = torch.aminmax(torch_tensors, dim=0)
246
+ abs_max_val_per_ch = torch.max(-min_vals, max_vals)
247
+ mean = torch.mean(torch_tensors, dim=0)
248
+ for _ in range(n_params): # Generates n_params around the mean of the tensors
249
+ yield torch.randn_like(min_vals) * abs_max_val_per_ch + mean
250
+
251
+ def _calc(x, qtz_func, deqtz_func, **params):
252
+ x_ = x.transpose(0, 1)
253
+ x_ = qtz_func(x=x_, **params)
254
+ x_ = deqtz_func(x=x_, **params)
255
+ x_ = x_.transpose(0, 1)
256
+ return x_
257
+
258
+ assert "qtz_func" in kwargs, "qtz_func must be provided."
259
+ assert "deqtz_func" in kwargs, "deqtz_func must be provided."
260
+ assert "params_list" in kwargs, "params list must be provided."
261
+ assert "param" in kwargs, "param must be provided."
262
+
263
+ qtz_func = kwargs.get('qtz_func')
264
+ deqtz_func = kwargs.get('deqtz_func')
265
+ params_list = kwargs.get('params_list')
266
+ param = kwargs.get('param')
267
+
268
+ n_runs = 50 # Number of runs to try to find the best parameters
269
+ n_random_params = 50 # Number of random parameters to generate
270
+ n_best_to_pick = 5 # Number of best parameters to pick after each run
271
+ max_initial = 10000 # Maximum value to initialize the parameters
272
+
273
+ # Initializes the parameters
274
+ base_params = { p: init_ones(x, **kwargs) for p in params_list if p != param }
275
+ params = _build_initial_param(x, max_initial, n_random_params)
276
+
277
+ # Performs the search
278
+ for _ in range(n_runs):
279
+
280
+ best_params = []
281
+ for param_ in params:
282
+ try:
283
+ x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: param_})
284
+ loss_ones = nn.MSELoss()(x, x_)
285
+
286
+ if len(best_params) < n_best_to_pick:
287
+ best_params.append((param_, loss_ones.item()))
288
+ best_params = sorted(best_params, key=lambda x: x[1])
289
+ elif loss_ones < best_params[-1][1]:
290
+ best_params[-1] = (param_, loss_ones.item())
291
+ best_params = sorted(best_params, key=lambda x: x[1])
292
+
293
+ except Exception: # The parameters might not be valid for the function's domain
294
+ continue
295
+
296
+ # Generates new parameters around the mean
297
+ params = _search_param([p for p, _ in best_params], n_random_params)
298
+
299
+ # Checks if the best parameter is better than the init_ones
300
+ p_ones = init_ones(x, **kwargs)
301
+ x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: p_ones})
302
+ loss_ones = nn.MSELoss()(x, x_)
303
+
304
+ # Checks if the best parameter is better than the init_rand
305
+ p_rand = init_rand(x, **kwargs)
306
+ x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: p_rand})
307
+ loss_rand = nn.MSELoss()(x, x_)
308
+
309
+ if loss_rand < best_params[0][1] and loss_rand < loss_ones:
310
+ return p_rand
311
+ elif loss_ones < best_params[0][1] and loss_ones < loss_rand:
312
+ return p_ones
313
+ else:
314
+ return best_params[0][0]
315
+
316
+
317
+ def init_linear_scale( # Symmetric scale. From the study folder
318
+ x: torch.Tensor,
319
+ **kwargs: Dict[str, Any],
320
+ ) -> torch.Tensor:
321
+ assert "bits" in kwargs, "bits must be provided."
322
+ assert "params" in kwargs, "params must be provided."
323
+ assert "qtz_func" in kwargs, "qtz_func must be provided."
324
+
325
+ bits = kwargs.get('bits')
326
+ params = kwargs.get('params')
327
+ qtz_func = kwargs.get('qtz_func')
328
+
329
+ x_ = x.transpose(0, 1)
330
+ x_ = qtz_func(x=x_, **params, _s=init_ones(x, **kwargs))
331
+ x_ = x_.transpose(0, 1)
332
+
333
+ quant_min, quant_max = get_min_max_from_bits_signed(bits)
334
+ min_vals, max_vals = torch.aminmax(x_, dim=1)
335
+ min_vals = torch.min(min_vals, torch.zeros_like(min_vals))
336
+ max_vals = torch.max(max_vals, torch.zeros_like(max_vals))
337
+
338
+ eps = torch.finfo(torch.float32).eps
339
+
340
+ abs_max_val_per_ch = torch.max(-min_vals, max_vals)
341
+ scale = abs_max_val_per_ch / (float(quant_max - quant_min) / 2)
342
+
343
+ scale = torch.clamp(scale, min=eps).to(dtype=torch.float32, device=min_vals.device)
344
+
345
+ # Introduces some noise in scale
346
+ # If I don't introduce noise, the accuracy is going to be 0.0 and not learn anything
347
+ # If I don't introduce noise, the accuracy is going to be 0.0 and not learn anything.
348
+ # NOTE(diogo): This has been disproven. The noise does not help the learning process but I still
349
+ # left it here for future reference. Will be removed later.
350
+ # scale = scale + 0.01 * torch.randn_like(scale)
351
+
352
+ return scale
353
+
354
+
355
+ def init_non_linear_regression_fit(
356
+ x: torch.Tensor,
357
+ **kwargs: Dict[str, Any],
358
+ ) -> torch.Tensor:
359
+
360
+ assert "params_list" in kwargs, "params list must be provided."
361
+ assert "np_fit_func" in kwargs, "np_fit_func must be provided."
362
+ assert "p0" in kwargs, "p0 must be provided."
363
+ np_fit_func = kwargs.get('np_fit_func')
364
+ params_list = kwargs.get('params_list')
365
+ p0 = kwargs.get('p0')
366
+
367
+ def _fit(xdata: np.ndarray, ydata: np.ndarray, func: Callable, p0: List[float]):
368
+ popt, _ = curve_fit(
369
+ func,
370
+ xdata,
371
+ ydata,
372
+ maxfev=1000,
373
+ p0=p0,
374
+ method='lm'
375
+ )
376
+ return popt
377
+
378
+ # 1. Needs to convert the torch tensor to numpy tensor
379
+ xdata = x.cpu().numpy()
380
+
381
+ # 2. Sorts the data so that it makes it easier to fit to it
382
+ sorted_xdata = np.sort(xdata, axis=-1)
383
+
384
+ p0 = {k: v.cpu().numpy() for k, v in p0.items()}
385
+ params_list = sorted(params_list) # We need to make sure that it matches the numpy fit func arg order
386
+
387
+ # 3. Finds the best parameters for each channel
388
+ try:
389
+ params = []
390
+ for i in range(sorted_xdata.shape[0]):
391
+ xdata_ = sorted_xdata[i]
392
+ p0_ = [p0[p][i] for p in params_list]
393
+ ch_params = _fit(xdata_, xdata_, np_fit_func, p0_)
394
+ params.append(ch_params)
395
+
396
+ # 4. Builds the parameters
397
+ result = {}
398
+ for i, p in enumerate(params_list):
399
+ result[p] = torch.tensor([p_[i] for p_ in params], dtype=torch.float32).to(x.device)
400
+
401
+ return result
402
+
403
+ except ValueError as e:
404
+ print(f"Could not fit the function with error: {e}")
405
+ print(f"Using fallback result...")
406
+ return {
407
+ k: torch.tensor(v, dtype=torch.float32).to(x.device) for k, v in p0.items()
408
+ }
409
+
410
+
411
+ def init_zeros(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor:
412
+ val = torch.amin(x, dim=1)
413
+ return torch.zeros_like(val, dtype=torch.float32, device=x.device)
414
+
415
+
416
+ def init_inner_scale(tensor: torch.Tensor, _min: float = torch.inf, _max: float = torch.inf) -> torch.Tensor:
417
+ # Calculate the original minimum and maximum values
418
+ min_vals, max_vals = torch.aminmax(tensor, dim=-1)
419
+ x_min = torch.min(min_vals, torch.zeros_like(min_vals))
420
+ x_max = torch.max(max_vals, torch.zeros_like(max_vals))
421
+
422
+ if _max is torch.inf: # We do not need to scale the tensor. Just need to move it
423
+ return torch.ones_like(x_min)
424
+
425
+ # Calculate the scale factor
426
+ scale = (_max - _min) / (x_max - x_min)
427
+ return scale
428
+
429
+
430
+
431
+ ############## Quant ###############
432
+
433
+ @torch.enable_grad()
434
+ def learn_parameters(
435
+ x: torch.Tensor,
436
+ params: Dict[str, nn.Parameter],
437
+ qtz_func: nn.Module,
438
+ deqtz_func: nn.Module,
439
+ bits: int,
440
+ target_dtype: torch.dtype,
441
+ epochs: int = 1000,
442
+ early_stop: bool = True,
443
+ do_report: bool = False
444
+ ) -> Tuple[Dict[str, nn.Parameter], torch.Tensor]:
445
+ loss_fn = nn.MSELoss()
446
+
447
+ # Determines the initial learning rate by computing the initial loss and multiplying it by
448
+ # the order of magnitude of the loss divided by 2
449
+ quant = quantize(x, params, qtz_func, bits, target_dtype)
450
+ dequant = dequantize(quant, params, deqtz_func, bits, x.dtype)
451
+ loss = loss_fn(x, dequant)
452
+
453
+ base_lr = 0.1
454
+ exponent = int(np.floor(np.log10(loss.item())))
455
+ lr = base_lr * (10 ** (exponent // 2))
456
+
457
+ # Requires gradients in the parameters
458
+ for p in params.values():
459
+ p.requires_grad = True
460
+ p.grad = None
461
+
462
+ param_keys = list(params.keys())
463
+ param_values = list(params.values())
464
+
465
+ # Defines optimizer and loss function
466
+ optimizer = torch.optim.Adam(param_values, lr=lr)
467
+ scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=epochs // 10, T_mult=1, eta_min=lr * 0.1, last_epoch=-1)
468
+
469
+ # Contains the best loss and the best parameters
470
+ best_loss = float("inf")
471
+ best_params = None
472
+
473
+ # Used to stop the search early
474
+ min_delta = 1e-7
475
+ acc_loss = []
476
+ percent_epochs_before_stop = 0.1
477
+
478
+ for i in range(epochs):
479
+ optimizer.zero_grad()
480
+
481
+ quant = quantize(x, params, qtz_func, bits, target_dtype)
482
+ dequant = dequantize(quant, params, deqtz_func, bits, x.dtype)
483
+ loss = loss_fn(x, dequant)
484
+
485
+ if loss.isnan() or loss.isinf():
486
+ raise Exception("Loss is NaN or Inf. Stopping the search.")
487
+
488
+ loss.backward()
489
+ optimizer.step()
490
+ scheduler.step()
491
+
492
+ acc_loss.append(loss.item())
493
+
494
+ # Reports loss every 10 steps
495
+ if i % 10 == 0 and do_report:
496
+ print(f"Epoch {i}: Loss {loss.item()}")
497
+
498
+ # Optimizes the parameter search by storing the best loss and the parameters
499
+ if loss.item() < best_loss:
500
+ best_loss = loss.item()
501
+ best_params = copy.deepcopy({
502
+ k: v for k, v in params.items() if k in param_keys
503
+ })
504
+
505
+ # We also stop the search if the loss has not considerably during the last 10% epochs
506
+ if early_stop:
507
+ epochs_before_stop = int(epochs * percent_epochs_before_stop)
508
+ if i > epochs_before_stop and abs(acc_loss[i - epochs_before_stop] - acc_loss[i]) < min_delta:
509
+ break
510
+
511
+ # No longer requires gradients in the parameters
512
+ for p in best_params.values():
513
+ p.requires_grad = False
514
+ p.grad = None
515
+
516
+ if do_report:
517
+ print(f"Best loss: {best_loss}")
518
+ return best_params, acc_loss
519
+ else:
520
+ return best_params
521
+
522
+
523
+ def quantize(
524
+ x: torch.Tensor,
525
+ params: Dict[str, nn.Parameter],
526
+ func: nn.Module,
527
+ bits: int,
528
+ target_dtype: torch.dtype = torch.int8
529
+ ) -> torch.Tensor:
530
+ quant_min, quant_max = get_min_max_from_bits_signed(bits)
531
+ x = x.transpose(0, 1) # Aligns shapes
532
+ x = func(x=x, **params)
533
+ x = x.transpose(0, 1)
534
+ x = torch.clamp(round_func_BPDA(x), quant_min, quant_max).to(target_dtype)
535
+ return x
536
+
537
+
538
+ def dequantize(
539
+ x: torch.Tensor,
540
+ params: Dict[str, nn.Parameter],
541
+ func: nn.Module,
542
+ bits: int,
543
+ out_dtype: torch.dtype
544
+ ) -> torch.Tensor:
545
+ x = x.to(dtype=out_dtype)
546
+ x = x.transpose(0, 1)
547
+ x = func(x=x, **params)
548
+ x = x.transpose(0, 1)
549
+ return x
550
+
551
+
552
+ def round_func_BPDA(input):
553
+ # This is equivalent to replacing round function (non-differentiable) with
554
+ # an identity function (differentiable) only when backward.
555
+ forward_value = torch.round(input)
556
+ out = input.clone()
557
+ out.data = forward_value.data
558
+ return out
559
+
560
+
561
+ def get_min_max_from_bits_signed(bit_width: int) -> Tuple[int, int]:
562
+ return -2 ** (bit_width - 1), 2 ** (bit_width - 1) - 1
563
+
564
+
565
+
566
+ ############## Numpy ###############
567
+
568
+ def np_domain_guard(
569
+ x: np.ndarray,
570
+ min: float = None,
571
+ max: float = None,
572
+ posinf: float = None,
573
+ neginf: float = None,
574
+ nan: float = None
575
+ ) -> np.ndarray:
576
+ """Guard a tensor to a valid domain."""
577
+ x = np.nan_to_num(x, posinf=posinf, neginf=neginf, nan=nan)
578
+ if min is not None or max is not None:
579
+ x = np.clip(x, min, max)
580
+ return x
581
+
582
+
583
+ def np_replace_num(x: np.ndarray, num: float, to: float) -> np.ndarray:
584
+ """Replace a number in a tensor with another number.
585
+
586
+ Args:
587
+ x (np.ndarray): The input tensor.
588
+ num (float): The number to replace.
589
+ to (float): The number to replace with.
590
+
591
+ Returns:
592
+ np.ndarray: The tensor with the number replaced.
593
+ """
594
+ return np.where(x == num, to, x)
595
+
596
+
597
+ def np_guarded_power(x: np.ndarray, exp: float) -> np.ndarray:
598
+ """Guard the power operation to a valid domain."""
599
+ return np.power(x, exp) if exp >= 1 else np.power(np.maximum(x, 0), exp)
600
+
fn_gen/rnd_search_t_cos/16/loss.png ADDED
fn_gen/rnd_search_t_cos/16/quantization.png ADDED
fn_gen/rnd_search_t_cos/17/distortion.png ADDED
fn_gen/rnd_search_t_cos/17/expressions.txt ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ x**3/_s
2
+ (_s*x)**(1/3)
fn_gen/rnd_search_t_cos/17/fn.py ADDED
@@ -0,0 +1,514 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import torch
4
+ from torch import amin # Necessary for arcsin
5
+ import copy
6
+ import torch.nn as nn
7
+ import numpy as np
8
+
9
+ from scipy.optimize import curve_fit
10
+ from typing import Dict, Any, Tuple, List, Callable
11
+
12
+
13
+ def quantization(x, **params):
14
+ return (torch.div(1, replace_num(params['_s'], num=0, to=10000)) * guarded_torch_power(x, torch.tensor(3)))
15
+
16
+
17
+ def dequantization(x, **params):
18
+ return guarded_torch_power((params['_s'] * x), 1 / 3)
19
+
20
+
21
+ def init_linear_scale( # Symmetric scale. From the study folder
22
+ x: torch.Tensor,
23
+ **kwargs: Dict[str, Any],
24
+ ) -> torch.Tensor:
25
+ assert "bits" in kwargs, "bits must be provided."
26
+ assert "params" in kwargs, "params must be provided."
27
+ assert "qtz_func" in kwargs, "qtz_func must be provided."
28
+
29
+ bits = kwargs.get('bits')
30
+ params = kwargs.get('params')
31
+ qtz_func = kwargs.get('qtz_func')
32
+
33
+ x_ = x.transpose(0, 1)
34
+ x_ = qtz_func(x=x_, **params, _s=init_ones(x, **kwargs))
35
+ x_ = x_.transpose(0, 1)
36
+
37
+ quant_min, quant_max = get_min_max_from_bits_signed(bits)
38
+ min_vals, max_vals = torch.aminmax(x_, dim=1)
39
+ min_vals = torch.min(min_vals, torch.zeros_like(min_vals))
40
+ max_vals = torch.max(max_vals, torch.zeros_like(max_vals))
41
+
42
+ eps = torch.finfo(torch.float32).eps
43
+
44
+ abs_max_val_per_ch = torch.max(-min_vals, max_vals)
45
+ scale = abs_max_val_per_ch / (float(quant_max - quant_min) / 2)
46
+
47
+ scale = torch.clamp(scale, min=eps).to(dtype=torch.float32, device=min_vals.device)
48
+
49
+ # Introduces some noise in scale
50
+ # If I don't introduce noise, the accuracy is going to be 0.0 and not learn anything
51
+ # If I don't introduce noise, the accuracy is going to be 0.0 and not learn anything.
52
+ # NOTE(diogo): This has been disproven. The noise does not help the learning process but I still
53
+ # left it here for future reference. Will be removed later.
54
+ # scale = scale + 0.01 * torch.randn_like(scale)
55
+
56
+ return scale
57
+
58
+
59
+ def init_params(x: torch.Tensor, **kwargs: Dict[str, Any]) -> Dict[str, nn.Parameter]:
60
+ params = {
61
+ }
62
+ params['_s'] = init_linear_scale(x, params=params, qtz_func=quantization, **kwargs)
63
+ params = {k: nn.Parameter(v, requires_grad=False) for k, v in params.items()}
64
+
65
+ if 'post_init_hook' in kwargs:
66
+ kwargs['post_init_hook'](parameters=params)
67
+
68
+ params = learn_parameters(x, params,
69
+ qtz_func=quantization,
70
+ deqtz_func=dequantization,
71
+ bits=kwargs['bits'],
72
+ target_dtype=torch.int8,
73
+ epochs=500,
74
+ early_stop=False,
75
+ )
76
+ if 'post_train_hook' in kwargs:
77
+ kwargs['post_train_hook'](parameters=params)
78
+
79
+ return params
80
+
81
+
82
+ ############### Numpy Qtz ###############
83
+
84
+
85
+ def np_quantization(x, _s):
86
+ return (np.divide(1, np_replace_num(_s, num=0, to=10000)) * np_guarded_power(x, np.array(3)))
87
+
88
+
89
+ def np_dequantization(x, _s):
90
+ return np_guarded_power((_s * x), 1 / 3)
91
+
92
+
93
+ def fit_func(x, _s):
94
+ x_ = np_quantization(x, _s)
95
+ x_ = np_dequantization(x_, _s)
96
+ return x_
97
+
98
+
99
+
100
+ ############### HELPERS ###############
101
+
102
+ def domain_guard(
103
+ x: torch.Tensor,
104
+ min: float = None,
105
+ max: float = None,
106
+ posinf: float = None,
107
+ neginf: float = None,
108
+ nan: float = None
109
+ ) -> torch.Tensor:
110
+ """Guard a tensor to a valid domain."""
111
+ x = torch.nan_to_num(x, posinf=posinf, neginf=neginf, nan=nan)
112
+ if min is not None or max is not None:
113
+ x = torch.clamp(x, min=min, max=max)
114
+ return x
115
+
116
+
117
+ def replace_num(x: torch.Tensor, num: float, to: float) -> torch.Tensor:
118
+ """Replace a number in a tensor with another number.
119
+
120
+ Args:
121
+ x (torch.Tensor): The input tensor.
122
+ num (float): The number to replace.
123
+ to (float): The number to replace with.
124
+
125
+ Returns:
126
+ torch.Tensor: The tensor with the number replaced.
127
+ """
128
+ return torch.where(x == num, to, x)
129
+
130
+
131
+ def guarded_torch_power(x: torch.Tensor, exp: float) -> torch.Tensor:
132
+ """Guard the power operation to a valid domain."""
133
+ return torch.pow(x, exp) if exp >= 1 else torch.pow(torch.relu(x), exp)
134
+
135
+
136
+ def init_ones(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor:
137
+ val = torch.amin(x, dim=1)
138
+ return torch.ones_like(val, dtype=torch.float32, device=x.device)
139
+
140
+
141
+ def init_rand(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor:
142
+ val = torch.amin(x, dim=1)
143
+ return torch.randn_like(val, dtype=torch.float32, device=x.device)
144
+
145
+
146
+ def init_space_search(
147
+ x: torch.Tensor,
148
+ **kwargs: Dict[str, Any],
149
+ ) -> torch.Tensor:
150
+
151
+ def _build_initial_param(tensor: torch.Tensor, max_initial: int, n_params: int):
152
+ """Generates the initial set of parameters. The first iteration generates 10 times more parameters."""
153
+ for _ in range(n_params * 10): # The first iteration generates 10 times more parameters
154
+ yield init_rand(tensor) * max_initial # Generates n_params in range [-max_initial, max_initial]
155
+
156
+ def _search_param(tensors: List[torch.tensor], n_params):
157
+ """Takes the best parameters and generates new parameters around the mean of the best parameters."""
158
+ torch_tensors = torch.stack(tensors)
159
+ min_vals, max_vals = torch.aminmax(torch_tensors, dim=0)
160
+ abs_max_val_per_ch = torch.max(-min_vals, max_vals)
161
+ mean = torch.mean(torch_tensors, dim=0)
162
+ for _ in range(n_params): # Generates n_params around the mean of the tensors
163
+ yield torch.randn_like(min_vals) * abs_max_val_per_ch + mean
164
+
165
+ def _calc(x, qtz_func, deqtz_func, **params):
166
+ x_ = x.transpose(0, 1)
167
+ x_ = qtz_func(x=x_, **params)
168
+ x_ = deqtz_func(x=x_, **params)
169
+ x_ = x_.transpose(0, 1)
170
+ return x_
171
+
172
+ assert "qtz_func" in kwargs, "qtz_func must be provided."
173
+ assert "deqtz_func" in kwargs, "deqtz_func must be provided."
174
+ assert "params_list" in kwargs, "params list must be provided."
175
+ assert "param" in kwargs, "param must be provided."
176
+
177
+ qtz_func = kwargs.get('qtz_func')
178
+ deqtz_func = kwargs.get('deqtz_func')
179
+ params_list = kwargs.get('params_list')
180
+ param = kwargs.get('param')
181
+
182
+ n_runs = 50 # Number of runs to try to find the best parameters
183
+ n_random_params = 50 # Number of random parameters to generate
184
+ n_best_to_pick = 5 # Number of best parameters to pick after each run
185
+ max_initial = 10000 # Maximum value to initialize the parameters
186
+
187
+ # Initializes the parameters
188
+ base_params = { p: init_ones(x, **kwargs) for p in params_list if p != param }
189
+ params = _build_initial_param(x, max_initial, n_random_params)
190
+
191
+ # Performs the search
192
+ for _ in range(n_runs):
193
+
194
+ best_params = []
195
+ for param_ in params:
196
+ try:
197
+ x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: param_})
198
+ loss_ones = nn.MSELoss()(x, x_)
199
+
200
+ if len(best_params) < n_best_to_pick:
201
+ best_params.append((param_, loss_ones.item()))
202
+ best_params = sorted(best_params, key=lambda x: x[1])
203
+ elif loss_ones < best_params[-1][1]:
204
+ best_params[-1] = (param_, loss_ones.item())
205
+ best_params = sorted(best_params, key=lambda x: x[1])
206
+
207
+ except Exception: # The parameters might not be valid for the function's domain
208
+ continue
209
+
210
+ # Generates new parameters around the mean
211
+ params = _search_param([p for p, _ in best_params], n_random_params)
212
+
213
+ # Checks if the best parameter is better than the init_ones
214
+ p_ones = init_ones(x, **kwargs)
215
+ x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: p_ones})
216
+ loss_ones = nn.MSELoss()(x, x_)
217
+
218
+ # Checks if the best parameter is better than the init_rand
219
+ p_rand = init_rand(x, **kwargs)
220
+ x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: p_rand})
221
+ loss_rand = nn.MSELoss()(x, x_)
222
+
223
+ if loss_rand < best_params[0][1] and loss_rand < loss_ones:
224
+ return p_rand
225
+ elif loss_ones < best_params[0][1] and loss_ones < loss_rand:
226
+ return p_ones
227
+ else:
228
+ return best_params[0][0]
229
+
230
+
231
+ def init_linear_scale( # Symmetric scale. From the study folder
232
+ x: torch.Tensor,
233
+ **kwargs: Dict[str, Any],
234
+ ) -> torch.Tensor:
235
+ assert "bits" in kwargs, "bits must be provided."
236
+ assert "params" in kwargs, "params must be provided."
237
+ assert "qtz_func" in kwargs, "qtz_func must be provided."
238
+
239
+ bits = kwargs.get('bits')
240
+ params = kwargs.get('params')
241
+ qtz_func = kwargs.get('qtz_func')
242
+
243
+ x_ = x.transpose(0, 1)
244
+ x_ = qtz_func(x=x_, **params, _s=init_ones(x, **kwargs))
245
+ x_ = x_.transpose(0, 1)
246
+
247
+ quant_min, quant_max = get_min_max_from_bits_signed(bits)
248
+ min_vals, max_vals = torch.aminmax(x_, dim=1)
249
+ min_vals = torch.min(min_vals, torch.zeros_like(min_vals))
250
+ max_vals = torch.max(max_vals, torch.zeros_like(max_vals))
251
+
252
+ eps = torch.finfo(torch.float32).eps
253
+
254
+ abs_max_val_per_ch = torch.max(-min_vals, max_vals)
255
+ scale = abs_max_val_per_ch / (float(quant_max - quant_min) / 2)
256
+
257
+ scale = torch.clamp(scale, min=eps).to(dtype=torch.float32, device=min_vals.device)
258
+
259
+ # Introduces some noise in scale
260
+ # If I don't introduce noise, the accuracy is going to be 0.0 and not learn anything
261
+ # If I don't introduce noise, the accuracy is going to be 0.0 and not learn anything.
262
+ # NOTE(diogo): This has been disproven. The noise does not help the learning process but I still
263
+ # left it here for future reference. Will be removed later.
264
+ # scale = scale + 0.01 * torch.randn_like(scale)
265
+
266
+ return scale
267
+
268
+
269
+ def init_non_linear_regression_fit(
270
+ x: torch.Tensor,
271
+ **kwargs: Dict[str, Any],
272
+ ) -> torch.Tensor:
273
+
274
+ assert "params_list" in kwargs, "params list must be provided."
275
+ assert "np_fit_func" in kwargs, "np_fit_func must be provided."
276
+ assert "p0" in kwargs, "p0 must be provided."
277
+ np_fit_func = kwargs.get('np_fit_func')
278
+ params_list = kwargs.get('params_list')
279
+ p0 = kwargs.get('p0')
280
+
281
+ def _fit(xdata: np.ndarray, ydata: np.ndarray, func: Callable, p0: List[float]):
282
+ popt, _ = curve_fit(
283
+ func,
284
+ xdata,
285
+ ydata,
286
+ maxfev=1000,
287
+ p0=p0,
288
+ method='lm'
289
+ )
290
+ return popt
291
+
292
+ # 1. Needs to convert the torch tensor to numpy tensor
293
+ xdata = x.cpu().numpy()
294
+
295
+ # 2. Sorts the data so that it makes it easier to fit to it
296
+ sorted_xdata = np.sort(xdata, axis=-1)
297
+
298
+ p0 = {k: v.cpu().numpy() for k, v in p0.items()}
299
+ params_list = sorted(params_list) # We need to make sure that it matches the numpy fit func arg order
300
+
301
+ # 3. Finds the best parameters for each channel
302
+ try:
303
+ params = []
304
+ for i in range(sorted_xdata.shape[0]):
305
+ xdata_ = sorted_xdata[i]
306
+ p0_ = [p0[p][i] for p in params_list]
307
+ ch_params = _fit(xdata_, xdata_, np_fit_func, p0_)
308
+ params.append(ch_params)
309
+
310
+ # 4. Builds the parameters
311
+ result = {}
312
+ for i, p in enumerate(params_list):
313
+ result[p] = torch.tensor([p_[i] for p_ in params], dtype=torch.float32).to(x.device)
314
+
315
+ return result
316
+
317
+ except ValueError as e:
318
+ print(f"Could not fit the function with error: {e}")
319
+ print(f"Using fallback result...")
320
+ return {
321
+ k: torch.tensor(v, dtype=torch.float32).to(x.device) for k, v in p0.items()
322
+ }
323
+
324
+
325
+ def init_zeros(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor:
326
+ val = torch.amin(x, dim=1)
327
+ return torch.zeros_like(val, dtype=torch.float32, device=x.device)
328
+
329
+
330
+ def init_inner_scale(tensor: torch.Tensor, _min: float = torch.inf, _max: float = torch.inf) -> torch.Tensor:
331
+ # Calculate the original minimum and maximum values
332
+ min_vals, max_vals = torch.aminmax(tensor, dim=-1)
333
+ x_min = torch.min(min_vals, torch.zeros_like(min_vals))
334
+ x_max = torch.max(max_vals, torch.zeros_like(max_vals))
335
+
336
+ if _max is torch.inf: # We do not need to scale the tensor. Just need to move it
337
+ return torch.ones_like(x_min)
338
+
339
+ # Calculate the scale factor
340
+ scale = (_max - _min) / (x_max - x_min)
341
+ return scale
342
+
343
+
344
+
345
+ ############## Quant ###############
346
+
347
+ @torch.enable_grad()
348
+ def learn_parameters(
349
+ x: torch.Tensor,
350
+ params: Dict[str, nn.Parameter],
351
+ qtz_func: nn.Module,
352
+ deqtz_func: nn.Module,
353
+ bits: int,
354
+ target_dtype: torch.dtype,
355
+ epochs: int = 1000,
356
+ early_stop: bool = True,
357
+ do_report: bool = False
358
+ ) -> Tuple[Dict[str, nn.Parameter], torch.Tensor]:
359
+ loss_fn = nn.MSELoss()
360
+
361
+ # Determines the initial learning rate by computing the initial loss and multiplying it by
362
+ # the order of magnitude of the loss divided by 2
363
+ quant = quantize(x, params, qtz_func, bits, target_dtype)
364
+ dequant = dequantize(quant, params, deqtz_func, bits, x.dtype)
365
+ loss = loss_fn(x, dequant)
366
+
367
+ base_lr = 0.1
368
+ exponent = int(np.floor(np.log10(loss.item())))
369
+ lr = base_lr * (10 ** (exponent // 2))
370
+
371
+ # Requires gradients in the parameters
372
+ for p in params.values():
373
+ p.requires_grad = True
374
+ p.grad = None
375
+
376
+ param_keys = list(params.keys())
377
+ param_values = list(params.values())
378
+
379
+ # Defines optimizer and loss function
380
+ optimizer = torch.optim.Adam(param_values, lr=lr)
381
+ scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=epochs // 10, T_mult=1, eta_min=lr * 0.1, last_epoch=-1)
382
+
383
+ # Contains the best loss and the best parameters
384
+ best_loss = float("inf")
385
+ best_params = None
386
+
387
+ # Used to stop the search early
388
+ min_delta = 1e-7
389
+ acc_loss = []
390
+ percent_epochs_before_stop = 0.1
391
+
392
+ for i in range(epochs):
393
+ optimizer.zero_grad()
394
+
395
+ quant = quantize(x, params, qtz_func, bits, target_dtype)
396
+ dequant = dequantize(quant, params, deqtz_func, bits, x.dtype)
397
+ loss = loss_fn(x, dequant)
398
+
399
+ if loss.isnan() or loss.isinf():
400
+ raise Exception("Loss is NaN or Inf. Stopping the search.")
401
+
402
+ loss.backward()
403
+ optimizer.step()
404
+ scheduler.step()
405
+
406
+ acc_loss.append(loss.item())
407
+
408
+ # Reports loss every 10 steps
409
+ if i % 10 == 0 and do_report:
410
+ print(f"Epoch {i}: Loss {loss.item()}")
411
+
412
+ # Optimizes the parameter search by storing the best loss and the parameters
413
+ if loss.item() < best_loss:
414
+ best_loss = loss.item()
415
+ best_params = copy.deepcopy({
416
+ k: v for k, v in params.items() if k in param_keys
417
+ })
418
+
419
+ # We also stop the search if the loss has not considerably during the last 10% epochs
420
+ if early_stop:
421
+ epochs_before_stop = int(epochs * percent_epochs_before_stop)
422
+ if i > epochs_before_stop and abs(acc_loss[i - epochs_before_stop] - acc_loss[i]) < min_delta:
423
+ break
424
+
425
+ # No longer requires gradients in the parameters
426
+ for p in best_params.values():
427
+ p.requires_grad = False
428
+ p.grad = None
429
+
430
+ if do_report:
431
+ print(f"Best loss: {best_loss}")
432
+ return best_params, acc_loss
433
+ else:
434
+ return best_params
435
+
436
+
437
+ def quantize(
438
+ x: torch.Tensor,
439
+ params: Dict[str, nn.Parameter],
440
+ func: nn.Module,
441
+ bits: int,
442
+ target_dtype: torch.dtype = torch.int8
443
+ ) -> torch.Tensor:
444
+ quant_min, quant_max = get_min_max_from_bits_signed(bits)
445
+ x = x.transpose(0, 1) # Aligns shapes
446
+ x = func(x=x, **params)
447
+ x = x.transpose(0, 1)
448
+ x = torch.clamp(round_func_BPDA(x), quant_min, quant_max).to(target_dtype)
449
+ return x
450
+
451
+
452
+ def dequantize(
453
+ x: torch.Tensor,
454
+ params: Dict[str, nn.Parameter],
455
+ func: nn.Module,
456
+ bits: int,
457
+ out_dtype: torch.dtype
458
+ ) -> torch.Tensor:
459
+ x = x.to(dtype=out_dtype)
460
+ x = x.transpose(0, 1)
461
+ x = func(x=x, **params)
462
+ x = x.transpose(0, 1)
463
+ return x
464
+
465
+
466
+ def round_func_BPDA(input):
467
+ # This is equivalent to replacing round function (non-differentiable) with
468
+ # an identity function (differentiable) only when backward.
469
+ forward_value = torch.round(input)
470
+ out = input.clone()
471
+ out.data = forward_value.data
472
+ return out
473
+
474
+
475
+ def get_min_max_from_bits_signed(bit_width: int) -> Tuple[int, int]:
476
+ return -2 ** (bit_width - 1), 2 ** (bit_width - 1) - 1
477
+
478
+
479
+
480
+ ############## Numpy ###############
481
+
482
+ def np_domain_guard(
483
+ x: np.ndarray,
484
+ min: float = None,
485
+ max: float = None,
486
+ posinf: float = None,
487
+ neginf: float = None,
488
+ nan: float = None
489
+ ) -> np.ndarray:
490
+ """Guard a tensor to a valid domain."""
491
+ x = np.nan_to_num(x, posinf=posinf, neginf=neginf, nan=nan)
492
+ if min is not None or max is not None:
493
+ x = np.clip(x, min, max)
494
+ return x
495
+
496
+
497
+ def np_replace_num(x: np.ndarray, num: float, to: float) -> np.ndarray:
498
+ """Replace a number in a tensor with another number.
499
+
500
+ Args:
501
+ x (np.ndarray): The input tensor.
502
+ num (float): The number to replace.
503
+ to (float): The number to replace with.
504
+
505
+ Returns:
506
+ np.ndarray: The tensor with the number replaced.
507
+ """
508
+ return np.where(x == num, to, x)
509
+
510
+
511
+ def np_guarded_power(x: np.ndarray, exp: float) -> np.ndarray:
512
+ """Guard the power operation to a valid domain."""
513
+ return np.power(x, exp) if exp >= 1 else np.power(np.maximum(x, 0), exp)
514
+
fn_gen/rnd_search_t_cos/17/loss.png ADDED
fn_gen/rnd_search_t_cos/17/quantization.png ADDED
fn_gen/rnd_search_t_cos/18/distortion.png ADDED
fn_gen/rnd_search_t_cos/18/expressions.txt ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ acosh(_0*x)/_s
2
+ cosh(_s*x)/_0
fn_gen/rnd_search_t_cos/18/fn.py ADDED
@@ -0,0 +1,600 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import torch
4
+ from torch import amin # Necessary for arcsin
5
+ import copy
6
+ import torch.nn as nn
7
+ import numpy as np
8
+
9
+ from scipy.optimize import curve_fit
10
+ from typing import Dict, Any, Tuple, List, Callable
11
+
12
+
13
+ def quantization(x, **params):
14
+ return (torch.div(1, replace_num(params['_s'], num=0, to=10000)) * torch.acosh(domain_guard((params['_0'] * x), min=1, nan=1)))
15
+
16
+
17
+ def dequantization(x, **params):
18
+ return (torch.div(1, replace_num(params['_0'], num=0, to=10000)) * torch.cosh((params['_s'] * x)))
19
+
20
+
21
+ def init_space_search(
22
+ x: torch.Tensor,
23
+ **kwargs: Dict[str, Any],
24
+ ) -> torch.Tensor:
25
+
26
+ def _build_initial_param(tensor: torch.Tensor, max_initial: int, n_params: int):
27
+ """Generates the initial set of parameters. The first iteration generates 10 times more parameters."""
28
+ for _ in range(n_params * 10): # The first iteration generates 10 times more parameters
29
+ yield init_rand(tensor) * max_initial # Generates n_params in range [-max_initial, max_initial]
30
+
31
+ def _search_param(tensors: List[torch.tensor], n_params):
32
+ """Takes the best parameters and generates new parameters around the mean of the best parameters."""
33
+ torch_tensors = torch.stack(tensors)
34
+ min_vals, max_vals = torch.aminmax(torch_tensors, dim=0)
35
+ abs_max_val_per_ch = torch.max(-min_vals, max_vals)
36
+ mean = torch.mean(torch_tensors, dim=0)
37
+ for _ in range(n_params): # Generates n_params around the mean of the tensors
38
+ yield torch.randn_like(min_vals) * abs_max_val_per_ch + mean
39
+
40
+ def _calc(x, qtz_func, deqtz_func, **params):
41
+ x_ = x.transpose(0, 1)
42
+ x_ = qtz_func(x=x_, **params)
43
+ x_ = deqtz_func(x=x_, **params)
44
+ x_ = x_.transpose(0, 1)
45
+ return x_
46
+
47
+ assert "qtz_func" in kwargs, "qtz_func must be provided."
48
+ assert "deqtz_func" in kwargs, "deqtz_func must be provided."
49
+ assert "params_list" in kwargs, "params list must be provided."
50
+ assert "param" in kwargs, "param must be provided."
51
+
52
+ qtz_func = kwargs.get('qtz_func')
53
+ deqtz_func = kwargs.get('deqtz_func')
54
+ params_list = kwargs.get('params_list')
55
+ param = kwargs.get('param')
56
+
57
+ n_runs = 50 # Number of runs to try to find the best parameters
58
+ n_random_params = 50 # Number of random parameters to generate
59
+ n_best_to_pick = 5 # Number of best parameters to pick after each run
60
+ max_initial = 10000 # Maximum value to initialize the parameters
61
+
62
+ # Initializes the parameters
63
+ base_params = { p: init_ones(x, **kwargs) for p in params_list if p != param }
64
+ params = _build_initial_param(x, max_initial, n_random_params)
65
+
66
+ # Performs the search
67
+ for _ in range(n_runs):
68
+
69
+ best_params = []
70
+ for param_ in params:
71
+ try:
72
+ x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: param_})
73
+ loss_ones = nn.MSELoss()(x, x_)
74
+
75
+ if len(best_params) < n_best_to_pick:
76
+ best_params.append((param_, loss_ones.item()))
77
+ best_params = sorted(best_params, key=lambda x: x[1])
78
+ elif loss_ones < best_params[-1][1]:
79
+ best_params[-1] = (param_, loss_ones.item())
80
+ best_params = sorted(best_params, key=lambda x: x[1])
81
+
82
+ except Exception: # The parameters might not be valid for the function's domain
83
+ continue
84
+
85
+ # Generates new parameters around the mean
86
+ params = _search_param([p for p, _ in best_params], n_random_params)
87
+
88
+ # Checks if the best parameter is better than the init_ones
89
+ p_ones = init_ones(x, **kwargs)
90
+ x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: p_ones})
91
+ loss_ones = nn.MSELoss()(x, x_)
92
+
93
+ # Checks if the best parameter is better than the init_rand
94
+ p_rand = init_rand(x, **kwargs)
95
+ x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: p_rand})
96
+ loss_rand = nn.MSELoss()(x, x_)
97
+
98
+ if loss_rand < best_params[0][1] and loss_rand < loss_ones:
99
+ return p_rand
100
+ elif loss_ones < best_params[0][1] and loss_ones < loss_rand:
101
+ return p_ones
102
+ else:
103
+ return best_params[0][0]
104
+
105
+
106
+ def init_linear_scale( # Symmetric scale. From the study folder
107
+ x: torch.Tensor,
108
+ **kwargs: Dict[str, Any],
109
+ ) -> torch.Tensor:
110
+ assert "bits" in kwargs, "bits must be provided."
111
+ assert "params" in kwargs, "params must be provided."
112
+ assert "qtz_func" in kwargs, "qtz_func must be provided."
113
+
114
+ bits = kwargs.get('bits')
115
+ params = kwargs.get('params')
116
+ qtz_func = kwargs.get('qtz_func')
117
+
118
+ x_ = x.transpose(0, 1)
119
+ x_ = qtz_func(x=x_, **params, _s=init_ones(x, **kwargs))
120
+ x_ = x_.transpose(0, 1)
121
+
122
+ quant_min, quant_max = get_min_max_from_bits_signed(bits)
123
+ min_vals, max_vals = torch.aminmax(x_, dim=1)
124
+ min_vals = torch.min(min_vals, torch.zeros_like(min_vals))
125
+ max_vals = torch.max(max_vals, torch.zeros_like(max_vals))
126
+
127
+ eps = torch.finfo(torch.float32).eps
128
+
129
+ abs_max_val_per_ch = torch.max(-min_vals, max_vals)
130
+ scale = abs_max_val_per_ch / (float(quant_max - quant_min) / 2)
131
+
132
+ scale = torch.clamp(scale, min=eps).to(dtype=torch.float32, device=min_vals.device)
133
+
134
+ # Introduces some noise in scale
135
+ # If I don't introduce noise, the accuracy is going to be 0.0 and not learn anything
136
+ # If I don't introduce noise, the accuracy is going to be 0.0 and not learn anything.
137
+ # NOTE(diogo): This has been disproven. The noise does not help the learning process but I still
138
+ # left it here for future reference. Will be removed later.
139
+ # scale = scale + 0.01 * torch.randn_like(scale)
140
+
141
+ return scale
142
+
143
+
144
+ def init_params(x: torch.Tensor, **kwargs: Dict[str, Any]) -> Dict[str, nn.Parameter]:
145
+ params = {
146
+ '_0': init_space_search(x, qtz_func=quantization, deqtz_func=dequantization, param='_0', params_list=['_0', '_s'], **kwargs),
147
+ }
148
+ params['_s'] = init_linear_scale(x, params=params, qtz_func=quantization, **kwargs)
149
+ params = {k: nn.Parameter(v, requires_grad=False) for k, v in params.items()}
150
+
151
+ if 'post_init_hook' in kwargs:
152
+ kwargs['post_init_hook'](parameters=params)
153
+
154
+ params = learn_parameters(x, params,
155
+ qtz_func=quantization,
156
+ deqtz_func=dequantization,
157
+ bits=kwargs['bits'],
158
+ target_dtype=torch.int8,
159
+ epochs=500,
160
+ early_stop=False,
161
+ )
162
+ if 'post_train_hook' in kwargs:
163
+ kwargs['post_train_hook'](parameters=params)
164
+
165
+ return params
166
+
167
+
168
+ ############### Numpy Qtz ###############
169
+
170
+
171
+ def np_quantization(x, _0, _s):
172
+ return (np.divide(1, np_replace_num(_s, num=0, to=10000)) * np.arccosh(np_domain_guard((_0 * x), min=1, nan=1)))
173
+
174
+
175
+ def np_dequantization(x, _0, _s):
176
+ return (np.divide(1, np_replace_num(_0, num=0, to=10000)) * np.cosh((_s * x)))
177
+
178
+
179
+ def fit_func(x, _0, _s):
180
+ x_ = np_quantization(x, _0, _s)
181
+ x_ = np_dequantization(x_, _0, _s)
182
+ return x_
183
+
184
+
185
+
186
+ ############### HELPERS ###############
187
+
188
+ def domain_guard(
189
+ x: torch.Tensor,
190
+ min: float = None,
191
+ max: float = None,
192
+ posinf: float = None,
193
+ neginf: float = None,
194
+ nan: float = None
195
+ ) -> torch.Tensor:
196
+ """Guard a tensor to a valid domain."""
197
+ x = torch.nan_to_num(x, posinf=posinf, neginf=neginf, nan=nan)
198
+ if min is not None or max is not None:
199
+ x = torch.clamp(x, min=min, max=max)
200
+ return x
201
+
202
+
203
+ def replace_num(x: torch.Tensor, num: float, to: float) -> torch.Tensor:
204
+ """Replace a number in a tensor with another number.
205
+
206
+ Args:
207
+ x (torch.Tensor): The input tensor.
208
+ num (float): The number to replace.
209
+ to (float): The number to replace with.
210
+
211
+ Returns:
212
+ torch.Tensor: The tensor with the number replaced.
213
+ """
214
+ return torch.where(x == num, to, x)
215
+
216
+
217
+ def guarded_torch_power(x: torch.Tensor, exp: float) -> torch.Tensor:
218
+ """Guard the power operation to a valid domain."""
219
+ return torch.pow(x, exp) if exp >= 1 else torch.pow(torch.relu(x), exp)
220
+
221
+
222
+ def init_ones(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor:
223
+ val = torch.amin(x, dim=1)
224
+ return torch.ones_like(val, dtype=torch.float32, device=x.device)
225
+
226
+
227
+ def init_rand(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor:
228
+ val = torch.amin(x, dim=1)
229
+ return torch.randn_like(val, dtype=torch.float32, device=x.device)
230
+
231
+
232
+ def init_space_search(
233
+ x: torch.Tensor,
234
+ **kwargs: Dict[str, Any],
235
+ ) -> torch.Tensor:
236
+
237
+ def _build_initial_param(tensor: torch.Tensor, max_initial: int, n_params: int):
238
+ """Generates the initial set of parameters. The first iteration generates 10 times more parameters."""
239
+ for _ in range(n_params * 10): # The first iteration generates 10 times more parameters
240
+ yield init_rand(tensor) * max_initial # Generates n_params in range [-max_initial, max_initial]
241
+
242
+ def _search_param(tensors: List[torch.tensor], n_params):
243
+ """Takes the best parameters and generates new parameters around the mean of the best parameters."""
244
+ torch_tensors = torch.stack(tensors)
245
+ min_vals, max_vals = torch.aminmax(torch_tensors, dim=0)
246
+ abs_max_val_per_ch = torch.max(-min_vals, max_vals)
247
+ mean = torch.mean(torch_tensors, dim=0)
248
+ for _ in range(n_params): # Generates n_params around the mean of the tensors
249
+ yield torch.randn_like(min_vals) * abs_max_val_per_ch + mean
250
+
251
+ def _calc(x, qtz_func, deqtz_func, **params):
252
+ x_ = x.transpose(0, 1)
253
+ x_ = qtz_func(x=x_, **params)
254
+ x_ = deqtz_func(x=x_, **params)
255
+ x_ = x_.transpose(0, 1)
256
+ return x_
257
+
258
+ assert "qtz_func" in kwargs, "qtz_func must be provided."
259
+ assert "deqtz_func" in kwargs, "deqtz_func must be provided."
260
+ assert "params_list" in kwargs, "params list must be provided."
261
+ assert "param" in kwargs, "param must be provided."
262
+
263
+ qtz_func = kwargs.get('qtz_func')
264
+ deqtz_func = kwargs.get('deqtz_func')
265
+ params_list = kwargs.get('params_list')
266
+ param = kwargs.get('param')
267
+
268
+ n_runs = 50 # Number of runs to try to find the best parameters
269
+ n_random_params = 50 # Number of random parameters to generate
270
+ n_best_to_pick = 5 # Number of best parameters to pick after each run
271
+ max_initial = 10000 # Maximum value to initialize the parameters
272
+
273
+ # Initializes the parameters
274
+ base_params = { p: init_ones(x, **kwargs) for p in params_list if p != param }
275
+ params = _build_initial_param(x, max_initial, n_random_params)
276
+
277
+ # Performs the search
278
+ for _ in range(n_runs):
279
+
280
+ best_params = []
281
+ for param_ in params:
282
+ try:
283
+ x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: param_})
284
+ loss_ones = nn.MSELoss()(x, x_)
285
+
286
+ if len(best_params) < n_best_to_pick:
287
+ best_params.append((param_, loss_ones.item()))
288
+ best_params = sorted(best_params, key=lambda x: x[1])
289
+ elif loss_ones < best_params[-1][1]:
290
+ best_params[-1] = (param_, loss_ones.item())
291
+ best_params = sorted(best_params, key=lambda x: x[1])
292
+
293
+ except Exception: # The parameters might not be valid for the function's domain
294
+ continue
295
+
296
+ # Generates new parameters around the mean
297
+ params = _search_param([p for p, _ in best_params], n_random_params)
298
+
299
+ # Checks if the best parameter is better than the init_ones
300
+ p_ones = init_ones(x, **kwargs)
301
+ x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: p_ones})
302
+ loss_ones = nn.MSELoss()(x, x_)
303
+
304
+ # Checks if the best parameter is better than the init_rand
305
+ p_rand = init_rand(x, **kwargs)
306
+ x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: p_rand})
307
+ loss_rand = nn.MSELoss()(x, x_)
308
+
309
+ if loss_rand < best_params[0][1] and loss_rand < loss_ones:
310
+ return p_rand
311
+ elif loss_ones < best_params[0][1] and loss_ones < loss_rand:
312
+ return p_ones
313
+ else:
314
+ return best_params[0][0]
315
+
316
+
317
+ def init_linear_scale( # Symmetric scale. From the study folder
318
+ x: torch.Tensor,
319
+ **kwargs: Dict[str, Any],
320
+ ) -> torch.Tensor:
321
+ assert "bits" in kwargs, "bits must be provided."
322
+ assert "params" in kwargs, "params must be provided."
323
+ assert "qtz_func" in kwargs, "qtz_func must be provided."
324
+
325
+ bits = kwargs.get('bits')
326
+ params = kwargs.get('params')
327
+ qtz_func = kwargs.get('qtz_func')
328
+
329
+ x_ = x.transpose(0, 1)
330
+ x_ = qtz_func(x=x_, **params, _s=init_ones(x, **kwargs))
331
+ x_ = x_.transpose(0, 1)
332
+
333
+ quant_min, quant_max = get_min_max_from_bits_signed(bits)
334
+ min_vals, max_vals = torch.aminmax(x_, dim=1)
335
+ min_vals = torch.min(min_vals, torch.zeros_like(min_vals))
336
+ max_vals = torch.max(max_vals, torch.zeros_like(max_vals))
337
+
338
+ eps = torch.finfo(torch.float32).eps
339
+
340
+ abs_max_val_per_ch = torch.max(-min_vals, max_vals)
341
+ scale = abs_max_val_per_ch / (float(quant_max - quant_min) / 2)
342
+
343
+ scale = torch.clamp(scale, min=eps).to(dtype=torch.float32, device=min_vals.device)
344
+
345
+ # Introduces some noise in scale
346
+ # If I don't introduce noise, the accuracy is going to be 0.0 and not learn anything
347
+ # If I don't introduce noise, the accuracy is going to be 0.0 and not learn anything.
348
+ # NOTE(diogo): This has been disproven. The noise does not help the learning process but I still
349
+ # left it here for future reference. Will be removed later.
350
+ # scale = scale + 0.01 * torch.randn_like(scale)
351
+
352
+ return scale
353
+
354
+
355
+ def init_non_linear_regression_fit(
356
+ x: torch.Tensor,
357
+ **kwargs: Dict[str, Any],
358
+ ) -> torch.Tensor:
359
+
360
+ assert "params_list" in kwargs, "params list must be provided."
361
+ assert "np_fit_func" in kwargs, "np_fit_func must be provided."
362
+ assert "p0" in kwargs, "p0 must be provided."
363
+ np_fit_func = kwargs.get('np_fit_func')
364
+ params_list = kwargs.get('params_list')
365
+ p0 = kwargs.get('p0')
366
+
367
+ def _fit(xdata: np.ndarray, ydata: np.ndarray, func: Callable, p0: List[float]):
368
+ popt, _ = curve_fit(
369
+ func,
370
+ xdata,
371
+ ydata,
372
+ maxfev=1000,
373
+ p0=p0,
374
+ method='lm'
375
+ )
376
+ return popt
377
+
378
+ # 1. Needs to convert the torch tensor to numpy tensor
379
+ xdata = x.cpu().numpy()
380
+
381
+ # 2. Sorts the data so that it makes it easier to fit to it
382
+ sorted_xdata = np.sort(xdata, axis=-1)
383
+
384
+ p0 = {k: v.cpu().numpy() for k, v in p0.items()}
385
+ params_list = sorted(params_list) # We need to make sure that it matches the numpy fit func arg order
386
+
387
+ # 3. Finds the best parameters for each channel
388
+ try:
389
+ params = []
390
+ for i in range(sorted_xdata.shape[0]):
391
+ xdata_ = sorted_xdata[i]
392
+ p0_ = [p0[p][i] for p in params_list]
393
+ ch_params = _fit(xdata_, xdata_, np_fit_func, p0_)
394
+ params.append(ch_params)
395
+
396
+ # 4. Builds the parameters
397
+ result = {}
398
+ for i, p in enumerate(params_list):
399
+ result[p] = torch.tensor([p_[i] for p_ in params], dtype=torch.float32).to(x.device)
400
+
401
+ return result
402
+
403
+ except ValueError as e:
404
+ print(f"Could not fit the function with error: {e}")
405
+ print(f"Using fallback result...")
406
+ return {
407
+ k: torch.tensor(v, dtype=torch.float32).to(x.device) for k, v in p0.items()
408
+ }
409
+
410
+
411
+ def init_zeros(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor:
412
+ val = torch.amin(x, dim=1)
413
+ return torch.zeros_like(val, dtype=torch.float32, device=x.device)
414
+
415
+
416
+ def init_inner_scale(tensor: torch.Tensor, _min: float = torch.inf, _max: float = torch.inf) -> torch.Tensor:
417
+ # Calculate the original minimum and maximum values
418
+ min_vals, max_vals = torch.aminmax(tensor, dim=-1)
419
+ x_min = torch.min(min_vals, torch.zeros_like(min_vals))
420
+ x_max = torch.max(max_vals, torch.zeros_like(max_vals))
421
+
422
+ if _max is torch.inf: # We do not need to scale the tensor. Just need to move it
423
+ return torch.ones_like(x_min)
424
+
425
+ # Calculate the scale factor
426
+ scale = (_max - _min) / (x_max - x_min)
427
+ return scale
428
+
429
+
430
+
431
+ ############## Quant ###############
432
+
433
+ @torch.enable_grad()
434
+ def learn_parameters(
435
+ x: torch.Tensor,
436
+ params: Dict[str, nn.Parameter],
437
+ qtz_func: nn.Module,
438
+ deqtz_func: nn.Module,
439
+ bits: int,
440
+ target_dtype: torch.dtype,
441
+ epochs: int = 1000,
442
+ early_stop: bool = True,
443
+ do_report: bool = False
444
+ ) -> Tuple[Dict[str, nn.Parameter], torch.Tensor]:
445
+ loss_fn = nn.MSELoss()
446
+
447
+ # Determines the initial learning rate by computing the initial loss and multiplying it by
448
+ # the order of magnitude of the loss divided by 2
449
+ quant = quantize(x, params, qtz_func, bits, target_dtype)
450
+ dequant = dequantize(quant, params, deqtz_func, bits, x.dtype)
451
+ loss = loss_fn(x, dequant)
452
+
453
+ base_lr = 0.1
454
+ exponent = int(np.floor(np.log10(loss.item())))
455
+ lr = base_lr * (10 ** (exponent // 2))
456
+
457
+ # Requires gradients in the parameters
458
+ for p in params.values():
459
+ p.requires_grad = True
460
+ p.grad = None
461
+
462
+ param_keys = list(params.keys())
463
+ param_values = list(params.values())
464
+
465
+ # Defines optimizer and loss function
466
+ optimizer = torch.optim.Adam(param_values, lr=lr)
467
+ scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=epochs // 10, T_mult=1, eta_min=lr * 0.1, last_epoch=-1)
468
+
469
+ # Contains the best loss and the best parameters
470
+ best_loss = float("inf")
471
+ best_params = None
472
+
473
+ # Used to stop the search early
474
+ min_delta = 1e-7
475
+ acc_loss = []
476
+ percent_epochs_before_stop = 0.1
477
+
478
+ for i in range(epochs):
479
+ optimizer.zero_grad()
480
+
481
+ quant = quantize(x, params, qtz_func, bits, target_dtype)
482
+ dequant = dequantize(quant, params, deqtz_func, bits, x.dtype)
483
+ loss = loss_fn(x, dequant)
484
+
485
+ if loss.isnan() or loss.isinf():
486
+ raise Exception("Loss is NaN or Inf. Stopping the search.")
487
+
488
+ loss.backward()
489
+ optimizer.step()
490
+ scheduler.step()
491
+
492
+ acc_loss.append(loss.item())
493
+
494
+ # Reports loss every 10 steps
495
+ if i % 10 == 0 and do_report:
496
+ print(f"Epoch {i}: Loss {loss.item()}")
497
+
498
+ # Optimizes the parameter search by storing the best loss and the parameters
499
+ if loss.item() < best_loss:
500
+ best_loss = loss.item()
501
+ best_params = copy.deepcopy({
502
+ k: v for k, v in params.items() if k in param_keys
503
+ })
504
+
505
+ # We also stop the search if the loss has not considerably during the last 10% epochs
506
+ if early_stop:
507
+ epochs_before_stop = int(epochs * percent_epochs_before_stop)
508
+ if i > epochs_before_stop and abs(acc_loss[i - epochs_before_stop] - acc_loss[i]) < min_delta:
509
+ break
510
+
511
+ # No longer requires gradients in the parameters
512
+ for p in best_params.values():
513
+ p.requires_grad = False
514
+ p.grad = None
515
+
516
+ if do_report:
517
+ print(f"Best loss: {best_loss}")
518
+ return best_params, acc_loss
519
+ else:
520
+ return best_params
521
+
522
+
523
+ def quantize(
524
+ x: torch.Tensor,
525
+ params: Dict[str, nn.Parameter],
526
+ func: nn.Module,
527
+ bits: int,
528
+ target_dtype: torch.dtype = torch.int8
529
+ ) -> torch.Tensor:
530
+ quant_min, quant_max = get_min_max_from_bits_signed(bits)
531
+ x = x.transpose(0, 1) # Aligns shapes
532
+ x = func(x=x, **params)
533
+ x = x.transpose(0, 1)
534
+ x = torch.clamp(round_func_BPDA(x), quant_min, quant_max).to(target_dtype)
535
+ return x
536
+
537
+
538
+ def dequantize(
539
+ x: torch.Tensor,
540
+ params: Dict[str, nn.Parameter],
541
+ func: nn.Module,
542
+ bits: int,
543
+ out_dtype: torch.dtype
544
+ ) -> torch.Tensor:
545
+ x = x.to(dtype=out_dtype)
546
+ x = x.transpose(0, 1)
547
+ x = func(x=x, **params)
548
+ x = x.transpose(0, 1)
549
+ return x
550
+
551
+
552
+ def round_func_BPDA(input):
553
+ # This is equivalent to replacing round function (non-differentiable) with
554
+ # an identity function (differentiable) only when backward.
555
+ forward_value = torch.round(input)
556
+ out = input.clone()
557
+ out.data = forward_value.data
558
+ return out
559
+
560
+
561
+ def get_min_max_from_bits_signed(bit_width: int) -> Tuple[int, int]:
562
+ return -2 ** (bit_width - 1), 2 ** (bit_width - 1) - 1
563
+
564
+
565
+
566
+ ############## Numpy ###############
567
+
568
+ def np_domain_guard(
569
+ x: np.ndarray,
570
+ min: float = None,
571
+ max: float = None,
572
+ posinf: float = None,
573
+ neginf: float = None,
574
+ nan: float = None
575
+ ) -> np.ndarray:
576
+ """Guard a tensor to a valid domain."""
577
+ x = np.nan_to_num(x, posinf=posinf, neginf=neginf, nan=nan)
578
+ if min is not None or max is not None:
579
+ x = np.clip(x, min, max)
580
+ return x
581
+
582
+
583
+ def np_replace_num(x: np.ndarray, num: float, to: float) -> np.ndarray:
584
+ """Replace a number in a tensor with another number.
585
+
586
+ Args:
587
+ x (np.ndarray): The input tensor.
588
+ num (float): The number to replace.
589
+ to (float): The number to replace with.
590
+
591
+ Returns:
592
+ np.ndarray: The tensor with the number replaced.
593
+ """
594
+ return np.where(x == num, to, x)
595
+
596
+
597
+ def np_guarded_power(x: np.ndarray, exp: float) -> np.ndarray:
598
+ """Guard the power operation to a valid domain."""
599
+ return np.power(x, exp) if exp >= 1 else np.power(np.maximum(x, 0), exp)
600
+
fn_gen/rnd_search_t_cos/18/loss.png ADDED
fn_gen/rnd_search_t_cos/18/quantization.png ADDED
fn_gen/rnd_search_t_cos/2/distortion.png ADDED
fn_gen/rnd_search_t_cos/2/expressions.txt ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ cos(_0*x)/_s
2
+ acos(_s*x)/_0
fn_gen/rnd_search_t_cos/2/fn.py ADDED
@@ -0,0 +1,600 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import torch
4
+ from torch import amin # Necessary for arcsin
5
+ import copy
6
+ import torch.nn as nn
7
+ import numpy as np
8
+
9
+ from scipy.optimize import curve_fit
10
+ from typing import Dict, Any, Tuple, List, Callable
11
+
12
+
13
+ def quantization(x, **params):
14
+ return (torch.div(1, replace_num(params['_s'], num=0, to=10000)) * torch.cos((params['_0'] * x)))
15
+
16
+
17
+ def dequantization(x, **params):
18
+ return (torch.div(1, replace_num(params['_0'], num=0, to=10000)) * torch.acos(domain_guard((params['_s'] * x), min=-0.99999, max=0.99999, nan=0)))
19
+
20
+
21
+ def init_space_search(
22
+ x: torch.Tensor,
23
+ **kwargs: Dict[str, Any],
24
+ ) -> torch.Tensor:
25
+
26
+ def _build_initial_param(tensor: torch.Tensor, max_initial: int, n_params: int):
27
+ """Generates the initial set of parameters. The first iteration generates 10 times more parameters."""
28
+ for _ in range(n_params * 10): # The first iteration generates 10 times more parameters
29
+ yield init_rand(tensor) * max_initial # Generates n_params in range [-max_initial, max_initial]
30
+
31
+ def _search_param(tensors: List[torch.tensor], n_params):
32
+ """Takes the best parameters and generates new parameters around the mean of the best parameters."""
33
+ torch_tensors = torch.stack(tensors)
34
+ min_vals, max_vals = torch.aminmax(torch_tensors, dim=0)
35
+ abs_max_val_per_ch = torch.max(-min_vals, max_vals)
36
+ mean = torch.mean(torch_tensors, dim=0)
37
+ for _ in range(n_params): # Generates n_params around the mean of the tensors
38
+ yield torch.randn_like(min_vals) * abs_max_val_per_ch + mean
39
+
40
+ def _calc(x, qtz_func, deqtz_func, **params):
41
+ x_ = x.transpose(0, 1)
42
+ x_ = qtz_func(x=x_, **params)
43
+ x_ = deqtz_func(x=x_, **params)
44
+ x_ = x_.transpose(0, 1)
45
+ return x_
46
+
47
+ assert "qtz_func" in kwargs, "qtz_func must be provided."
48
+ assert "deqtz_func" in kwargs, "deqtz_func must be provided."
49
+ assert "params_list" in kwargs, "params list must be provided."
50
+ assert "param" in kwargs, "param must be provided."
51
+
52
+ qtz_func = kwargs.get('qtz_func')
53
+ deqtz_func = kwargs.get('deqtz_func')
54
+ params_list = kwargs.get('params_list')
55
+ param = kwargs.get('param')
56
+
57
+ n_runs = 50 # Number of runs to try to find the best parameters
58
+ n_random_params = 50 # Number of random parameters to generate
59
+ n_best_to_pick = 5 # Number of best parameters to pick after each run
60
+ max_initial = 10000 # Maximum value to initialize the parameters
61
+
62
+ # Initializes the parameters
63
+ base_params = { p: init_ones(x, **kwargs) for p in params_list if p != param }
64
+ params = _build_initial_param(x, max_initial, n_random_params)
65
+
66
+ # Performs the search
67
+ for _ in range(n_runs):
68
+
69
+ best_params = []
70
+ for param_ in params:
71
+ try:
72
+ x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: param_})
73
+ loss_ones = nn.MSELoss()(x, x_)
74
+
75
+ if len(best_params) < n_best_to_pick:
76
+ best_params.append((param_, loss_ones.item()))
77
+ best_params = sorted(best_params, key=lambda x: x[1])
78
+ elif loss_ones < best_params[-1][1]:
79
+ best_params[-1] = (param_, loss_ones.item())
80
+ best_params = sorted(best_params, key=lambda x: x[1])
81
+
82
+ except Exception: # The parameters might not be valid for the function's domain
83
+ continue
84
+
85
+ # Generates new parameters around the mean
86
+ params = _search_param([p for p, _ in best_params], n_random_params)
87
+
88
+ # Checks if the best parameter is better than the init_ones
89
+ p_ones = init_ones(x, **kwargs)
90
+ x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: p_ones})
91
+ loss_ones = nn.MSELoss()(x, x_)
92
+
93
+ # Checks if the best parameter is better than the init_rand
94
+ p_rand = init_rand(x, **kwargs)
95
+ x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: p_rand})
96
+ loss_rand = nn.MSELoss()(x, x_)
97
+
98
+ if loss_rand < best_params[0][1] and loss_rand < loss_ones:
99
+ return p_rand
100
+ elif loss_ones < best_params[0][1] and loss_ones < loss_rand:
101
+ return p_ones
102
+ else:
103
+ return best_params[0][0]
104
+
105
+
106
+ def init_linear_scale( # Symmetric scale. From the study folder
107
+ x: torch.Tensor,
108
+ **kwargs: Dict[str, Any],
109
+ ) -> torch.Tensor:
110
+ assert "bits" in kwargs, "bits must be provided."
111
+ assert "params" in kwargs, "params must be provided."
112
+ assert "qtz_func" in kwargs, "qtz_func must be provided."
113
+
114
+ bits = kwargs.get('bits')
115
+ params = kwargs.get('params')
116
+ qtz_func = kwargs.get('qtz_func')
117
+
118
+ x_ = x.transpose(0, 1)
119
+ x_ = qtz_func(x=x_, **params, _s=init_ones(x, **kwargs))
120
+ x_ = x_.transpose(0, 1)
121
+
122
+ quant_min, quant_max = get_min_max_from_bits_signed(bits)
123
+ min_vals, max_vals = torch.aminmax(x_, dim=1)
124
+ min_vals = torch.min(min_vals, torch.zeros_like(min_vals))
125
+ max_vals = torch.max(max_vals, torch.zeros_like(max_vals))
126
+
127
+ eps = torch.finfo(torch.float32).eps
128
+
129
+ abs_max_val_per_ch = torch.max(-min_vals, max_vals)
130
+ scale = abs_max_val_per_ch / (float(quant_max - quant_min) / 2)
131
+
132
+ scale = torch.clamp(scale, min=eps).to(dtype=torch.float32, device=min_vals.device)
133
+
134
+ # Introduces some noise in scale
135
+ # If I don't introduce noise, the accuracy is going to be 0.0 and not learn anything
136
+ # If I don't introduce noise, the accuracy is going to be 0.0 and not learn anything.
137
+ # NOTE(diogo): This has been disproven. The noise does not help the learning process but I still
138
+ # left it here for future reference. Will be removed later.
139
+ # scale = scale + 0.01 * torch.randn_like(scale)
140
+
141
+ return scale
142
+
143
+
144
+ def init_params(x: torch.Tensor, **kwargs: Dict[str, Any]) -> Dict[str, nn.Parameter]:
145
+ params = {
146
+ '_0': init_space_search(x, qtz_func=quantization, deqtz_func=dequantization, param='_0', params_list=['_0', '_s'], **kwargs),
147
+ }
148
+ params['_s'] = init_linear_scale(x, params=params, qtz_func=quantization, **kwargs)
149
+ params = {k: nn.Parameter(v, requires_grad=False) for k, v in params.items()}
150
+
151
+ if 'post_init_hook' in kwargs:
152
+ kwargs['post_init_hook'](parameters=params)
153
+
154
+ params = learn_parameters(x, params,
155
+ qtz_func=quantization,
156
+ deqtz_func=dequantization,
157
+ bits=kwargs['bits'],
158
+ target_dtype=torch.int8,
159
+ epochs=500,
160
+ early_stop=False,
161
+ )
162
+ if 'post_train_hook' in kwargs:
163
+ kwargs['post_train_hook'](parameters=params)
164
+
165
+ return params
166
+
167
+
168
+ ############### Numpy Qtz ###############
169
+
170
+
171
+ def np_quantization(x, _0, _s):
172
+ return (np.divide(1, np_replace_num(_s, num=0, to=10000)) * np.cos((_0 * x)))
173
+
174
+
175
+ def np_dequantization(x, _0, _s):
176
+ return (np.divide(1, np_replace_num(_0, num=0, to=10000)) * np.arccos(np_domain_guard((_s * x), min=-0.99999, max=0.99999, nan=0)))
177
+
178
+
179
+ def fit_func(x, _0, _s):
180
+ x_ = np_quantization(x, _0, _s)
181
+ x_ = np_dequantization(x_, _0, _s)
182
+ return x_
183
+
184
+
185
+
186
+ ############### HELPERS ###############
187
+
188
+ def domain_guard(
189
+ x: torch.Tensor,
190
+ min: float = None,
191
+ max: float = None,
192
+ posinf: float = None,
193
+ neginf: float = None,
194
+ nan: float = None
195
+ ) -> torch.Tensor:
196
+ """Guard a tensor to a valid domain."""
197
+ x = torch.nan_to_num(x, posinf=posinf, neginf=neginf, nan=nan)
198
+ if min is not None or max is not None:
199
+ x = torch.clamp(x, min=min, max=max)
200
+ return x
201
+
202
+
203
+ def replace_num(x: torch.Tensor, num: float, to: float) -> torch.Tensor:
204
+ """Replace a number in a tensor with another number.
205
+
206
+ Args:
207
+ x (torch.Tensor): The input tensor.
208
+ num (float): The number to replace.
209
+ to (float): The number to replace with.
210
+
211
+ Returns:
212
+ torch.Tensor: The tensor with the number replaced.
213
+ """
214
+ return torch.where(x == num, to, x)
215
+
216
+
217
+ def guarded_torch_power(x: torch.Tensor, exp: float) -> torch.Tensor:
218
+ """Guard the power operation to a valid domain."""
219
+ return torch.pow(x, exp) if exp >= 1 else torch.pow(torch.relu(x), exp)
220
+
221
+
222
+ def init_ones(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor:
223
+ val = torch.amin(x, dim=1)
224
+ return torch.ones_like(val, dtype=torch.float32, device=x.device)
225
+
226
+
227
+ def init_rand(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor:
228
+ val = torch.amin(x, dim=1)
229
+ return torch.randn_like(val, dtype=torch.float32, device=x.device)
230
+
231
+
232
+ def init_space_search(
233
+ x: torch.Tensor,
234
+ **kwargs: Dict[str, Any],
235
+ ) -> torch.Tensor:
236
+
237
+ def _build_initial_param(tensor: torch.Tensor, max_initial: int, n_params: int):
238
+ """Generates the initial set of parameters. The first iteration generates 10 times more parameters."""
239
+ for _ in range(n_params * 10): # The first iteration generates 10 times more parameters
240
+ yield init_rand(tensor) * max_initial # Generates n_params in range [-max_initial, max_initial]
241
+
242
+ def _search_param(tensors: List[torch.tensor], n_params):
243
+ """Takes the best parameters and generates new parameters around the mean of the best parameters."""
244
+ torch_tensors = torch.stack(tensors)
245
+ min_vals, max_vals = torch.aminmax(torch_tensors, dim=0)
246
+ abs_max_val_per_ch = torch.max(-min_vals, max_vals)
247
+ mean = torch.mean(torch_tensors, dim=0)
248
+ for _ in range(n_params): # Generates n_params around the mean of the tensors
249
+ yield torch.randn_like(min_vals) * abs_max_val_per_ch + mean
250
+
251
+ def _calc(x, qtz_func, deqtz_func, **params):
252
+ x_ = x.transpose(0, 1)
253
+ x_ = qtz_func(x=x_, **params)
254
+ x_ = deqtz_func(x=x_, **params)
255
+ x_ = x_.transpose(0, 1)
256
+ return x_
257
+
258
+ assert "qtz_func" in kwargs, "qtz_func must be provided."
259
+ assert "deqtz_func" in kwargs, "deqtz_func must be provided."
260
+ assert "params_list" in kwargs, "params list must be provided."
261
+ assert "param" in kwargs, "param must be provided."
262
+
263
+ qtz_func = kwargs.get('qtz_func')
264
+ deqtz_func = kwargs.get('deqtz_func')
265
+ params_list = kwargs.get('params_list')
266
+ param = kwargs.get('param')
267
+
268
+ n_runs = 50 # Number of runs to try to find the best parameters
269
+ n_random_params = 50 # Number of random parameters to generate
270
+ n_best_to_pick = 5 # Number of best parameters to pick after each run
271
+ max_initial = 10000 # Maximum value to initialize the parameters
272
+
273
+ # Initializes the parameters
274
+ base_params = { p: init_ones(x, **kwargs) for p in params_list if p != param }
275
+ params = _build_initial_param(x, max_initial, n_random_params)
276
+
277
+ # Performs the search
278
+ for _ in range(n_runs):
279
+
280
+ best_params = []
281
+ for param_ in params:
282
+ try:
283
+ x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: param_})
284
+ loss_ones = nn.MSELoss()(x, x_)
285
+
286
+ if len(best_params) < n_best_to_pick:
287
+ best_params.append((param_, loss_ones.item()))
288
+ best_params = sorted(best_params, key=lambda x: x[1])
289
+ elif loss_ones < best_params[-1][1]:
290
+ best_params[-1] = (param_, loss_ones.item())
291
+ best_params = sorted(best_params, key=lambda x: x[1])
292
+
293
+ except Exception: # The parameters might not be valid for the function's domain
294
+ continue
295
+
296
+ # Generates new parameters around the mean
297
+ params = _search_param([p for p, _ in best_params], n_random_params)
298
+
299
+ # Checks if the best parameter is better than the init_ones
300
+ p_ones = init_ones(x, **kwargs)
301
+ x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: p_ones})
302
+ loss_ones = nn.MSELoss()(x, x_)
303
+
304
+ # Checks if the best parameter is better than the init_rand
305
+ p_rand = init_rand(x, **kwargs)
306
+ x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: p_rand})
307
+ loss_rand = nn.MSELoss()(x, x_)
308
+
309
+ if loss_rand < best_params[0][1] and loss_rand < loss_ones:
310
+ return p_rand
311
+ elif loss_ones < best_params[0][1] and loss_ones < loss_rand:
312
+ return p_ones
313
+ else:
314
+ return best_params[0][0]
315
+
316
+
317
+ def init_linear_scale( # Symmetric scale. From the study folder
318
+ x: torch.Tensor,
319
+ **kwargs: Dict[str, Any],
320
+ ) -> torch.Tensor:
321
+ assert "bits" in kwargs, "bits must be provided."
322
+ assert "params" in kwargs, "params must be provided."
323
+ assert "qtz_func" in kwargs, "qtz_func must be provided."
324
+
325
+ bits = kwargs.get('bits')
326
+ params = kwargs.get('params')
327
+ qtz_func = kwargs.get('qtz_func')
328
+
329
+ x_ = x.transpose(0, 1)
330
+ x_ = qtz_func(x=x_, **params, _s=init_ones(x, **kwargs))
331
+ x_ = x_.transpose(0, 1)
332
+
333
+ quant_min, quant_max = get_min_max_from_bits_signed(bits)
334
+ min_vals, max_vals = torch.aminmax(x_, dim=1)
335
+ min_vals = torch.min(min_vals, torch.zeros_like(min_vals))
336
+ max_vals = torch.max(max_vals, torch.zeros_like(max_vals))
337
+
338
+ eps = torch.finfo(torch.float32).eps
339
+
340
+ abs_max_val_per_ch = torch.max(-min_vals, max_vals)
341
+ scale = abs_max_val_per_ch / (float(quant_max - quant_min) / 2)
342
+
343
+ scale = torch.clamp(scale, min=eps).to(dtype=torch.float32, device=min_vals.device)
344
+
345
+ # Introduces some noise in scale
346
+ # If I don't introduce noise, the accuracy is going to be 0.0 and not learn anything
347
+ # If I don't introduce noise, the accuracy is going to be 0.0 and not learn anything.
348
+ # NOTE(diogo): This has been disproven. The noise does not help the learning process but I still
349
+ # left it here for future reference. Will be removed later.
350
+ # scale = scale + 0.01 * torch.randn_like(scale)
351
+
352
+ return scale
353
+
354
+
355
+ def init_non_linear_regression_fit(
356
+ x: torch.Tensor,
357
+ **kwargs: Dict[str, Any],
358
+ ) -> torch.Tensor:
359
+
360
+ assert "params_list" in kwargs, "params list must be provided."
361
+ assert "np_fit_func" in kwargs, "np_fit_func must be provided."
362
+ assert "p0" in kwargs, "p0 must be provided."
363
+ np_fit_func = kwargs.get('np_fit_func')
364
+ params_list = kwargs.get('params_list')
365
+ p0 = kwargs.get('p0')
366
+
367
+ def _fit(xdata: np.ndarray, ydata: np.ndarray, func: Callable, p0: List[float]):
368
+ popt, _ = curve_fit(
369
+ func,
370
+ xdata,
371
+ ydata,
372
+ maxfev=1000,
373
+ p0=p0,
374
+ method='lm'
375
+ )
376
+ return popt
377
+
378
+ # 1. Needs to convert the torch tensor to numpy tensor
379
+ xdata = x.cpu().numpy()
380
+
381
+ # 2. Sorts the data so that it makes it easier to fit to it
382
+ sorted_xdata = np.sort(xdata, axis=-1)
383
+
384
+ p0 = {k: v.cpu().numpy() for k, v in p0.items()}
385
+ params_list = sorted(params_list) # We need to make sure that it matches the numpy fit func arg order
386
+
387
+ # 3. Finds the best parameters for each channel
388
+ try:
389
+ params = []
390
+ for i in range(sorted_xdata.shape[0]):
391
+ xdata_ = sorted_xdata[i]
392
+ p0_ = [p0[p][i] for p in params_list]
393
+ ch_params = _fit(xdata_, xdata_, np_fit_func, p0_)
394
+ params.append(ch_params)
395
+
396
+ # 4. Builds the parameters
397
+ result = {}
398
+ for i, p in enumerate(params_list):
399
+ result[p] = torch.tensor([p_[i] for p_ in params], dtype=torch.float32).to(x.device)
400
+
401
+ return result
402
+
403
+ except ValueError as e:
404
+ print(f"Could not fit the function with error: {e}")
405
+ print(f"Using fallback result...")
406
+ return {
407
+ k: torch.tensor(v, dtype=torch.float32).to(x.device) for k, v in p0.items()
408
+ }
409
+
410
+
411
+ def init_zeros(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor:
412
+ val = torch.amin(x, dim=1)
413
+ return torch.zeros_like(val, dtype=torch.float32, device=x.device)
414
+
415
+
416
+ def init_inner_scale(tensor: torch.Tensor, _min: float = torch.inf, _max: float = torch.inf) -> torch.Tensor:
417
+ # Calculate the original minimum and maximum values
418
+ min_vals, max_vals = torch.aminmax(tensor, dim=-1)
419
+ x_min = torch.min(min_vals, torch.zeros_like(min_vals))
420
+ x_max = torch.max(max_vals, torch.zeros_like(max_vals))
421
+
422
+ if _max is torch.inf: # We do not need to scale the tensor. Just need to move it
423
+ return torch.ones_like(x_min)
424
+
425
+ # Calculate the scale factor
426
+ scale = (_max - _min) / (x_max - x_min)
427
+ return scale
428
+
429
+
430
+
431
+ ############## Quant ###############
432
+
433
+ @torch.enable_grad()
434
+ def learn_parameters(
435
+ x: torch.Tensor,
436
+ params: Dict[str, nn.Parameter],
437
+ qtz_func: nn.Module,
438
+ deqtz_func: nn.Module,
439
+ bits: int,
440
+ target_dtype: torch.dtype,
441
+ epochs: int = 1000,
442
+ early_stop: bool = True,
443
+ do_report: bool = False
444
+ ) -> Tuple[Dict[str, nn.Parameter], torch.Tensor]:
445
+ loss_fn = nn.MSELoss()
446
+
447
+ # Determines the initial learning rate by computing the initial loss and multiplying it by
448
+ # the order of magnitude of the loss divided by 2
449
+ quant = quantize(x, params, qtz_func, bits, target_dtype)
450
+ dequant = dequantize(quant, params, deqtz_func, bits, x.dtype)
451
+ loss = loss_fn(x, dequant)
452
+
453
+ base_lr = 0.1
454
+ exponent = int(np.floor(np.log10(loss.item())))
455
+ lr = base_lr * (10 ** (exponent // 2))
456
+
457
+ # Requires gradients in the parameters
458
+ for p in params.values():
459
+ p.requires_grad = True
460
+ p.grad = None
461
+
462
+ param_keys = list(params.keys())
463
+ param_values = list(params.values())
464
+
465
+ # Defines optimizer and loss function
466
+ optimizer = torch.optim.Adam(param_values, lr=lr)
467
+ scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=epochs // 10, T_mult=1, eta_min=lr * 0.1, last_epoch=-1)
468
+
469
+ # Contains the best loss and the best parameters
470
+ best_loss = float("inf")
471
+ best_params = None
472
+
473
+ # Used to stop the search early
474
+ min_delta = 1e-7
475
+ acc_loss = []
476
+ percent_epochs_before_stop = 0.1
477
+
478
+ for i in range(epochs):
479
+ optimizer.zero_grad()
480
+
481
+ quant = quantize(x, params, qtz_func, bits, target_dtype)
482
+ dequant = dequantize(quant, params, deqtz_func, bits, x.dtype)
483
+ loss = loss_fn(x, dequant)
484
+
485
+ if loss.isnan() or loss.isinf():
486
+ raise Exception("Loss is NaN or Inf. Stopping the search.")
487
+
488
+ loss.backward()
489
+ optimizer.step()
490
+ scheduler.step()
491
+
492
+ acc_loss.append(loss.item())
493
+
494
+ # Reports loss every 10 steps
495
+ if i % 10 == 0 and do_report:
496
+ print(f"Epoch {i}: Loss {loss.item()}")
497
+
498
+ # Optimizes the parameter search by storing the best loss and the parameters
499
+ if loss.item() < best_loss:
500
+ best_loss = loss.item()
501
+ best_params = copy.deepcopy({
502
+ k: v for k, v in params.items() if k in param_keys
503
+ })
504
+
505
+ # We also stop the search if the loss has not considerably during the last 10% epochs
506
+ if early_stop:
507
+ epochs_before_stop = int(epochs * percent_epochs_before_stop)
508
+ if i > epochs_before_stop and abs(acc_loss[i - epochs_before_stop] - acc_loss[i]) < min_delta:
509
+ break
510
+
511
+ # No longer requires gradients in the parameters
512
+ for p in best_params.values():
513
+ p.requires_grad = False
514
+ p.grad = None
515
+
516
+ if do_report:
517
+ print(f"Best loss: {best_loss}")
518
+ return best_params, acc_loss
519
+ else:
520
+ return best_params
521
+
522
+
523
+ def quantize(
524
+ x: torch.Tensor,
525
+ params: Dict[str, nn.Parameter],
526
+ func: nn.Module,
527
+ bits: int,
528
+ target_dtype: torch.dtype = torch.int8
529
+ ) -> torch.Tensor:
530
+ quant_min, quant_max = get_min_max_from_bits_signed(bits)
531
+ x = x.transpose(0, 1) # Aligns shapes
532
+ x = func(x=x, **params)
533
+ x = x.transpose(0, 1)
534
+ x = torch.clamp(round_func_BPDA(x), quant_min, quant_max).to(target_dtype)
535
+ return x
536
+
537
+
538
+ def dequantize(
539
+ x: torch.Tensor,
540
+ params: Dict[str, nn.Parameter],
541
+ func: nn.Module,
542
+ bits: int,
543
+ out_dtype: torch.dtype
544
+ ) -> torch.Tensor:
545
+ x = x.to(dtype=out_dtype)
546
+ x = x.transpose(0, 1)
547
+ x = func(x=x, **params)
548
+ x = x.transpose(0, 1)
549
+ return x
550
+
551
+
552
+ def round_func_BPDA(input):
553
+ # This is equivalent to replacing round function (non-differentiable) with
554
+ # an identity function (differentiable) only when backward.
555
+ forward_value = torch.round(input)
556
+ out = input.clone()
557
+ out.data = forward_value.data
558
+ return out
559
+
560
+
561
+ def get_min_max_from_bits_signed(bit_width: int) -> Tuple[int, int]:
562
+ return -2 ** (bit_width - 1), 2 ** (bit_width - 1) - 1
563
+
564
+
565
+
566
+ ############## Numpy ###############
567
+
568
+ def np_domain_guard(
569
+ x: np.ndarray,
570
+ min: float = None,
571
+ max: float = None,
572
+ posinf: float = None,
573
+ neginf: float = None,
574
+ nan: float = None
575
+ ) -> np.ndarray:
576
+ """Guard a tensor to a valid domain."""
577
+ x = np.nan_to_num(x, posinf=posinf, neginf=neginf, nan=nan)
578
+ if min is not None or max is not None:
579
+ x = np.clip(x, min, max)
580
+ return x
581
+
582
+
583
+ def np_replace_num(x: np.ndarray, num: float, to: float) -> np.ndarray:
584
+ """Replace a number in a tensor with another number.
585
+
586
+ Args:
587
+ x (np.ndarray): The input tensor.
588
+ num (float): The number to replace.
589
+ to (float): The number to replace with.
590
+
591
+ Returns:
592
+ np.ndarray: The tensor with the number replaced.
593
+ """
594
+ return np.where(x == num, to, x)
595
+
596
+
597
+ def np_guarded_power(x: np.ndarray, exp: float) -> np.ndarray:
598
+ """Guard the power operation to a valid domain."""
599
+ return np.power(x, exp) if exp >= 1 else np.power(np.maximum(x, 0), exp)
600
+
fn_gen/rnd_search_t_cos/2/loss.png ADDED
fn_gen/rnd_search_t_cos/2/quantization.png ADDED