Diogo-V commited on
Commit
f138151
1 Parent(s): f89e433

Upload learned functions

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. fn_gen/rnd_search/0/__pycache__/fn.cpython-311.pyc +0 -0
  2. fn_gen/rnd_search/0/distortion.png +0 -0
  3. fn_gen/rnd_search/0/expressions.txt +2 -0
  4. fn_gen/rnd_search/0/fn.py +554 -0
  5. fn_gen/rnd_search/0/loss.png +0 -0
  6. fn_gen/rnd_search/0/quantization.png +0 -0
  7. fn_gen/rnd_search/1/__pycache__/fn.cpython-311.pyc +0 -0
  8. fn_gen/rnd_search/1/distortion.png +0 -0
  9. fn_gen/rnd_search/1/expressions.txt +2 -0
  10. fn_gen/rnd_search/1/fn.py +554 -0
  11. fn_gen/rnd_search/1/loss.png +0 -0
  12. fn_gen/rnd_search/1/quantization.png +0 -0
  13. fn_gen/rnd_search/10/__pycache__/fn.cpython-311.pyc +0 -0
  14. fn_gen/rnd_search/10/distortion.png +0 -0
  15. fn_gen/rnd_search/10/expressions.txt +2 -0
  16. fn_gen/rnd_search/10/fn.py +483 -0
  17. fn_gen/rnd_search/10/loss.png +0 -0
  18. fn_gen/rnd_search/10/quantization.png +0 -0
  19. fn_gen/rnd_search/11/__pycache__/fn.cpython-311.pyc +0 -0
  20. fn_gen/rnd_search/11/distortion.png +0 -0
  21. fn_gen/rnd_search/11/expressions.txt +2 -0
  22. fn_gen/rnd_search/11/fn.py +554 -0
  23. fn_gen/rnd_search/11/loss.png +0 -0
  24. fn_gen/rnd_search/11/quantization.png +0 -0
  25. fn_gen/rnd_search/12/__pycache__/fn.cpython-311.pyc +0 -0
  26. fn_gen/rnd_search/12/distortion.png +0 -0
  27. fn_gen/rnd_search/12/expressions.txt +2 -0
  28. fn_gen/rnd_search/12/fn.py +554 -0
  29. fn_gen/rnd_search/12/loss.png +0 -0
  30. fn_gen/rnd_search/12/quantization.png +0 -0
  31. fn_gen/rnd_search/13/__pycache__/fn.cpython-311.pyc +0 -0
  32. fn_gen/rnd_search/13/distortion.png +0 -0
  33. fn_gen/rnd_search/13/expressions.txt +2 -0
  34. fn_gen/rnd_search/13/fn.py +554 -0
  35. fn_gen/rnd_search/13/loss.png +0 -0
  36. fn_gen/rnd_search/13/quantization.png +0 -0
  37. fn_gen/rnd_search/14/__pycache__/fn.cpython-311.pyc +0 -0
  38. fn_gen/rnd_search/14/distortion.png +0 -0
  39. fn_gen/rnd_search/14/expressions.txt +2 -0
  40. fn_gen/rnd_search/14/fn.py +554 -0
  41. fn_gen/rnd_search/14/loss.png +0 -0
  42. fn_gen/rnd_search/14/quantization.png +0 -0
  43. fn_gen/rnd_search/16/__pycache__/fn.cpython-311.pyc +0 -0
  44. fn_gen/rnd_search/16/distortion.png +0 -0
  45. fn_gen/rnd_search/16/expressions.txt +2 -0
  46. fn_gen/rnd_search/16/fn.py +554 -0
  47. fn_gen/rnd_search/16/loss.png +0 -0
  48. fn_gen/rnd_search/16/quantization.png +0 -0
  49. fn_gen/rnd_search/17/__pycache__/fn.cpython-311.pyc +0 -0
  50. fn_gen/rnd_search/17/distortion.png +0 -0
fn_gen/rnd_search/0/__pycache__/fn.cpython-311.pyc ADDED
Binary file (27.3 kB). View file
 
fn_gen/rnd_search/0/distortion.png ADDED
fn_gen/rnd_search/0/expressions.txt ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ atanh(_0*x)/_s
2
+ tanh(_s*x)/_0
fn_gen/rnd_search/0/fn.py ADDED
@@ -0,0 +1,554 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import torch
4
+ from torch import amin # Necessary for arcsin
5
+ import copy
6
+ import torch.nn as nn
7
+ import numpy as np
8
+
9
+ from scipy.optimize import curve_fit
10
+ from typing import Dict, Any, Tuple, List, Callable
11
+
12
+
13
+ def quantization(x, **params):
14
+ return (torch.div(1, replace_num(params['_s'], num=0, to=10000)) * torch.atanh(domain_guard((params['_0'] * x), min=-0.9999, max=0.9999, nan=0)))
15
+
16
+
17
+ def dequantization(x, **params):
18
+ return (torch.div(1, replace_num(params['_0'], num=0, to=10000)) * torch.tanh((params['_s'] * x)))
19
+
20
+
21
+ def init_space_search(
22
+ x: torch.Tensor,
23
+ **kwargs: Dict[str, Any],
24
+ ) -> torch.Tensor:
25
+
26
+ def _build_initial_param(tensor: torch.Tensor, max_initial: int, n_params: int):
27
+ """Generates the initial set of parameters. The first iteration generates 10 times more parameters."""
28
+ for _ in range(n_params * 10): # The first iteration generates 10 times more parameters
29
+ yield init_rand(tensor) * max_initial # Generates n_params in range [-max_initial, max_initial]
30
+
31
+ def _search_param(tensors: List[torch.tensor], n_params):
32
+ """Takes the best parameters and generates new parameters around the mean of the best parameters."""
33
+ torch_tensors = torch.stack(tensors)
34
+ min_vals, max_vals = torch.aminmax(torch_tensors, dim=0)
35
+ abs_max_val_per_ch = torch.max(-min_vals, max_vals)
36
+ mean = torch.mean(torch_tensors, dim=0)
37
+ for _ in range(n_params): # Generates n_params around the mean of the tensors
38
+ yield torch.randn_like(min_vals) * abs_max_val_per_ch + mean
39
+
40
+ def _calc(x, qtz_func, deqtz_func, **params):
41
+ x_ = x.transpose(0, 1)
42
+ x_ = qtz_func(x=x_, **params)
43
+ x_ = deqtz_func(x=x_, **params)
44
+ x_ = x_.transpose(0, 1)
45
+ return x_
46
+
47
+ assert "qtz_func" in kwargs, "qtz_func must be provided."
48
+ assert "deqtz_func" in kwargs, "deqtz_func must be provided."
49
+ assert "params_list" in kwargs, "params list must be provided."
50
+ assert "param" in kwargs, "param must be provided."
51
+
52
+ qtz_func = kwargs.get('qtz_func')
53
+ deqtz_func = kwargs.get('deqtz_func')
54
+ params_list = kwargs.get('params_list')
55
+ param = kwargs.get('param')
56
+
57
+ n_runs = 50 # Number of runs to try to find the best parameters
58
+ n_random_params = 50 # Number of random parameters to generate
59
+ n_best_to_pick = 5 # Number of best parameters to pick after each run
60
+ max_initial = 10000 # Maximum value to initialize the parameters
61
+
62
+ # Initializes the parameters
63
+ base_params = { p: init_ones(x, **kwargs) for p in params_list if p != param }
64
+ params = _build_initial_param(x, max_initial, n_random_params)
65
+
66
+ # Performs the search
67
+ for _ in range(n_runs):
68
+
69
+ best_params = []
70
+ for param_ in params:
71
+ try:
72
+ x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: param_})
73
+ loss_ones = nn.MSELoss()(x, x_)
74
+
75
+ if len(best_params) < n_best_to_pick:
76
+ best_params.append((param_, loss_ones.item()))
77
+ best_params = sorted(best_params, key=lambda x: x[1])
78
+ elif loss_ones < best_params[-1][1]:
79
+ best_params[-1] = (param_, loss_ones.item())
80
+ best_params = sorted(best_params, key=lambda x: x[1])
81
+
82
+ except Exception: # The parameters might not be valid for the function's domain
83
+ continue
84
+
85
+ # Generates new parameters around the mean
86
+ params = _search_param([p for p, _ in best_params], n_random_params)
87
+
88
+ return best_params[0][0]
89
+
90
+
91
+ def init_linear_scale( # Symmetric scale. From the study folder
92
+ x: torch.Tensor,
93
+ **kwargs: Dict[str, Any],
94
+ ) -> torch.Tensor:
95
+ assert "bits" in kwargs, "bits must be provided."
96
+ assert "params" in kwargs, "params must be provided."
97
+ assert "qtz_func" in kwargs, "qtz_func must be provided."
98
+
99
+ bits = kwargs.get('bits')
100
+ params = kwargs.get('params')
101
+ qtz_func = kwargs.get('qtz_func')
102
+
103
+ x_ = x.transpose(0, 1)
104
+ x_ = qtz_func(x=x_, **params, _s=init_ones(x, **kwargs))
105
+ x_ = x_.transpose(0, 1)
106
+
107
+ quant_min, quant_max = get_min_max_from_bits_signed(bits)
108
+ min_vals, max_vals = torch.aminmax(x_, dim=1)
109
+ min_vals = torch.min(min_vals, torch.zeros_like(min_vals))
110
+ max_vals = torch.max(max_vals, torch.zeros_like(max_vals))
111
+
112
+ eps = torch.finfo(torch.float32).eps
113
+
114
+ abs_max_val_per_ch = torch.max(-min_vals, max_vals)
115
+ scale = abs_max_val_per_ch / (float(quant_max - quant_min) / 2)
116
+
117
+ scale = torch.clamp(scale, min=eps).to(dtype=torch.float32, device=min_vals.device)
118
+
119
+ # Introduces some noise in scale
120
+ # If I don't introduce noise, the accuracy is going to be 0.0 and not learn anything
121
+ # scale = scale + 0.01 * torch.randn_like(scale)
122
+ return scale
123
+
124
+
125
+ def init_params(x: torch.Tensor, **kwargs: Dict[str, Any]) -> Dict[str, nn.Parameter]:
126
+ params = {
127
+ '_0': init_space_search(x, qtz_func=quantization, deqtz_func=dequantization, param='_0', params_list=['_0', '_s'], **kwargs),
128
+ }
129
+ params['_s'] = init_linear_scale(x, params=params, qtz_func=quantization, **kwargs)
130
+ params = {k: nn.Parameter(v, requires_grad=False) for k, v in params.items()}
131
+
132
+ if 'post_init_hook' in kwargs:
133
+ kwargs['post_init_hook'](parameters=params)
134
+
135
+
136
+ if 'post_train_hook' in kwargs:
137
+ kwargs['post_train_hook'](parameters=params)
138
+
139
+ return params
140
+
141
+
142
+ ############### Numpy Qtz ###############
143
+
144
+
145
+ def np_quantization(x, _0, _s):
146
+ return (np.divide(1, np_replace_num(_s, num=0, to=10000)) * np.arctanh(np_domain_guard((_0 * x), min=-0.9999, max=0.9999, nan=0)))
147
+
148
+
149
+ def np_dequantization(x, _0, _s):
150
+ return (np.divide(1, np_replace_num(_0, num=0, to=10000)) * np.tanh((_s * x)))
151
+
152
+
153
+ def fit_func(x, _0, _s):
154
+ x_ = np_quantization(x, _0, _s)
155
+ x_ = np_dequantization(x_, _0, _s)
156
+ return x_
157
+
158
+
159
+
160
+ ############### HELPERS ###############
161
+
162
+ def domain_guard(
163
+ x: torch.Tensor,
164
+ min: float = None,
165
+ max: float = None,
166
+ posinf: float = None,
167
+ neginf: float = None,
168
+ nan: float = None
169
+ ) -> torch.Tensor:
170
+ """Guard a tensor to a valid domain."""
171
+ x = torch.nan_to_num(x, posinf=posinf, neginf=neginf, nan=nan)
172
+ if min is not None or max is not None:
173
+ x = torch.clamp(x, min=min, max=max)
174
+ return x
175
+
176
+
177
+ def replace_num(x: torch.Tensor, num: float, to: float) -> torch.Tensor:
178
+ """Replace a number in a tensor with another number.
179
+
180
+ Args:
181
+ x (torch.Tensor): The input tensor.
182
+ num (float): The number to replace.
183
+ to (float): The number to replace with.
184
+
185
+ Returns:
186
+ torch.Tensor: The tensor with the number replaced.
187
+ """
188
+ return torch.where(x == num, to, x)
189
+
190
+
191
+ def guarded_torch_power(x: torch.Tensor, exp: float) -> torch.Tensor:
192
+ """Guard the power operation to a valid domain."""
193
+ return torch.pow(x, exp) if exp >= 1 else torch.pow(torch.relu(x), exp)
194
+
195
+
196
+ def init_ones(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor:
197
+ val = torch.amin(x, dim=1)
198
+ return torch.ones_like(val, dtype=torch.float32, device=x.device)
199
+
200
+
201
+ def init_rand(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor:
202
+ val = torch.amin(x, dim=1)
203
+ return torch.randn_like(val, dtype=torch.float32, device=x.device)
204
+
205
+
206
+ def init_space_search(
207
+ x: torch.Tensor,
208
+ **kwargs: Dict[str, Any],
209
+ ) -> torch.Tensor:
210
+
211
+ def _build_initial_param(tensor: torch.Tensor, max_initial: int, n_params: int):
212
+ """Generates the initial set of parameters. The first iteration generates 10 times more parameters."""
213
+ for _ in range(n_params * 10): # The first iteration generates 10 times more parameters
214
+ yield init_rand(tensor) * max_initial # Generates n_params in range [-max_initial, max_initial]
215
+
216
+ def _search_param(tensors: List[torch.tensor], n_params):
217
+ """Takes the best parameters and generates new parameters around the mean of the best parameters."""
218
+ torch_tensors = torch.stack(tensors)
219
+ min_vals, max_vals = torch.aminmax(torch_tensors, dim=0)
220
+ abs_max_val_per_ch = torch.max(-min_vals, max_vals)
221
+ mean = torch.mean(torch_tensors, dim=0)
222
+ for _ in range(n_params): # Generates n_params around the mean of the tensors
223
+ yield torch.randn_like(min_vals) * abs_max_val_per_ch + mean
224
+
225
+ def _calc(x, qtz_func, deqtz_func, **params):
226
+ x_ = x.transpose(0, 1)
227
+ x_ = qtz_func(x=x_, **params)
228
+ x_ = deqtz_func(x=x_, **params)
229
+ x_ = x_.transpose(0, 1)
230
+ return x_
231
+
232
+ assert "qtz_func" in kwargs, "qtz_func must be provided."
233
+ assert "deqtz_func" in kwargs, "deqtz_func must be provided."
234
+ assert "params_list" in kwargs, "params list must be provided."
235
+ assert "param" in kwargs, "param must be provided."
236
+
237
+ qtz_func = kwargs.get('qtz_func')
238
+ deqtz_func = kwargs.get('deqtz_func')
239
+ params_list = kwargs.get('params_list')
240
+ param = kwargs.get('param')
241
+
242
+ n_runs = 50 # Number of runs to try to find the best parameters
243
+ n_random_params = 50 # Number of random parameters to generate
244
+ n_best_to_pick = 5 # Number of best parameters to pick after each run
245
+ max_initial = 10000 # Maximum value to initialize the parameters
246
+
247
+ # Initializes the parameters
248
+ base_params = { p: init_ones(x, **kwargs) for p in params_list if p != param }
249
+ params = _build_initial_param(x, max_initial, n_random_params)
250
+
251
+ # Performs the search
252
+ for _ in range(n_runs):
253
+
254
+ best_params = []
255
+ for param_ in params:
256
+ try:
257
+ x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: param_})
258
+ loss_ones = nn.MSELoss()(x, x_)
259
+
260
+ if len(best_params) < n_best_to_pick:
261
+ best_params.append((param_, loss_ones.item()))
262
+ best_params = sorted(best_params, key=lambda x: x[1])
263
+ elif loss_ones < best_params[-1][1]:
264
+ best_params[-1] = (param_, loss_ones.item())
265
+ best_params = sorted(best_params, key=lambda x: x[1])
266
+
267
+ except Exception: # The parameters might not be valid for the function's domain
268
+ continue
269
+
270
+ # Generates new parameters around the mean
271
+ params = _search_param([p for p, _ in best_params], n_random_params)
272
+
273
+ return best_params[0][0]
274
+
275
+
276
+ def init_linear_scale( # Symmetric scale. From the study folder
277
+ x: torch.Tensor,
278
+ **kwargs: Dict[str, Any],
279
+ ) -> torch.Tensor:
280
+ assert "bits" in kwargs, "bits must be provided."
281
+ assert "params" in kwargs, "params must be provided."
282
+ assert "qtz_func" in kwargs, "qtz_func must be provided."
283
+
284
+ bits = kwargs.get('bits')
285
+ params = kwargs.get('params')
286
+ qtz_func = kwargs.get('qtz_func')
287
+
288
+ x_ = x.transpose(0, 1)
289
+ x_ = qtz_func(x=x_, **params, _s=init_ones(x, **kwargs))
290
+ x_ = x_.transpose(0, 1)
291
+
292
+ quant_min, quant_max = get_min_max_from_bits_signed(bits)
293
+ min_vals, max_vals = torch.aminmax(x_, dim=1)
294
+ min_vals = torch.min(min_vals, torch.zeros_like(min_vals))
295
+ max_vals = torch.max(max_vals, torch.zeros_like(max_vals))
296
+
297
+ eps = torch.finfo(torch.float32).eps
298
+
299
+ abs_max_val_per_ch = torch.max(-min_vals, max_vals)
300
+ scale = abs_max_val_per_ch / (float(quant_max - quant_min) / 2)
301
+
302
+ scale = torch.clamp(scale, min=eps).to(dtype=torch.float32, device=min_vals.device)
303
+
304
+ # Introduces some noise in scale
305
+ # If I don't introduce noise, the accuracy is going to be 0.0 and not learn anything
306
+ # scale = scale + 0.01 * torch.randn_like(scale)
307
+ return scale
308
+
309
+
310
+ def init_non_linear_regression_fit(
311
+ x: torch.Tensor,
312
+ **kwargs: Dict[str, Any],
313
+ ) -> torch.Tensor:
314
+
315
+ assert "params_list" in kwargs, "params list must be provided."
316
+ assert "np_fit_func" in kwargs, "np_fit_func must be provided."
317
+ assert "p0" in kwargs, "p0 must be provided."
318
+ np_fit_func = kwargs.get('np_fit_func')
319
+ params_list = kwargs.get('params_list')
320
+ p0 = kwargs.get('p0')
321
+
322
+ def _fit(xdata: np.ndarray, ydata: np.ndarray, func: Callable, p0: List[float]):
323
+ popt, _ = curve_fit(
324
+ func,
325
+ xdata,
326
+ ydata,
327
+ maxfev=1000,
328
+ p0=p0,
329
+ method='lm'
330
+ )
331
+ return popt
332
+
333
+ # 1. Needs to convert the torch tensor to numpy tensor
334
+ xdata = x.cpu().numpy()
335
+
336
+ # 2. Sorts the data so that it makes it easier to fit to it
337
+ sorted_xdata = np.sort(xdata, axis=-1)
338
+
339
+ p0 = {k: v.cpu().numpy() for k, v in p0.items()}
340
+ params_list = sorted(params_list) # We need to make sure that it matches the numpy fit func arg order
341
+
342
+ # 3. Finds the best parameters for each channel
343
+ try:
344
+ params = []
345
+ for i in range(sorted_xdata.shape[0]):
346
+ xdata_ = sorted_xdata[i]
347
+ p0_ = [p0[p][i] for p in params_list]
348
+ ch_params = _fit(xdata_, xdata_, np_fit_func, p0_)
349
+ params.append(ch_params)
350
+
351
+ # 4. Builds the parameters
352
+ result = {}
353
+ for i, p in enumerate(params_list):
354
+ result[p] = torch.tensor([p_[i] for p_ in params], dtype=torch.float32).to(x.device)
355
+
356
+ return result
357
+
358
+ except ValueError as e:
359
+ print(f"Could not fit the function with error: {e}")
360
+ print(f"Using fallback result...")
361
+ return {
362
+ k: torch.tensor(v, dtype=torch.float32).to(x.device) for k, v in p0.items()
363
+ }
364
+
365
+
366
+ def init_zeros(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor:
367
+ val = torch.amin(x, dim=1)
368
+ return torch.zeros_like(val, dtype=torch.float32, device=x.device)
369
+
370
+
371
+ def init_inner_scale(tensor: torch.Tensor, _min: float = torch.inf, _max: float = torch.inf) -> torch.Tensor:
372
+ # Calculate the original minimum and maximum values
373
+ min_vals, max_vals = torch.aminmax(tensor, dim=-1)
374
+ x_min = torch.min(min_vals, torch.zeros_like(min_vals))
375
+ x_max = torch.max(max_vals, torch.zeros_like(max_vals))
376
+
377
+ if _max is torch.inf: # We do not need to scale the tensor. Just need to move it
378
+ return torch.ones_like(x_min)
379
+
380
+ # Calculate the scale factor
381
+ scale = (_max - _min) / (x_max - x_min)
382
+ return scale
383
+
384
+
385
+
386
+ ############## Quant ###############
387
+
388
+ @torch.enable_grad()
389
+ def learn_parameters(
390
+ x: torch.Tensor,
391
+ params: Dict[str, nn.Parameter],
392
+ qtz_func: nn.Module,
393
+ deqtz_func: nn.Module,
394
+ bits: int,
395
+ target_dtype: torch.dtype,
396
+ epochs: int = 1000,
397
+ early_stop: bool = True,
398
+ do_report: bool = False
399
+ ) -> Tuple[Dict[str, nn.Parameter], torch.Tensor]:
400
+ loss_fn = nn.MSELoss()
401
+
402
+ # Determines the initial learning rate by computing the initial loss and multiplying it by
403
+ # the order of magnitude of the loss divided by 2
404
+ quant = quantize(x, params, qtz_func, bits, target_dtype)
405
+ dequant = dequantize(quant, params, deqtz_func, bits, x.dtype)
406
+ loss = loss_fn(x, dequant)
407
+
408
+ base_lr = 0.1
409
+ exponent = int(np.floor(np.log10(loss.item())))
410
+ lr = base_lr * (10 ** (exponent // 2))
411
+
412
+ # Requires gradients in the parameters
413
+ for p in params.values():
414
+ p.requires_grad = True
415
+ p.grad = None
416
+
417
+ param_keys = list(params.keys())
418
+ param_values = list(params.values())
419
+
420
+ # Defines optimizer and loss function
421
+ optimizer = torch.optim.Adam(param_values, lr=lr)
422
+ scheduler = torch.optim.lr_scheduler.LinearLR(optimizer, start_factor=1.0, end_factor=0.01, total_iters=epochs // 10)
423
+
424
+ # Contains the best loss and the best parameters
425
+ best_loss = float("inf")
426
+ best_params = None
427
+
428
+ # Used to stop the search early
429
+ min_delta = 1e-7
430
+ acc_loss = []
431
+ percent_epochs_before_stop = 0.1
432
+
433
+ for i in range(epochs):
434
+ optimizer.zero_grad()
435
+
436
+ quant = quantize(x, params, qtz_func, bits, target_dtype)
437
+ dequant = dequantize(quant, params, deqtz_func, bits, x.dtype)
438
+ loss = loss_fn(x, dequant)
439
+
440
+ if loss.isnan() or loss.isinf():
441
+ raise Exception("Loss is NaN or Inf. Stopping the search.")
442
+
443
+ loss.backward()
444
+ optimizer.step()
445
+ scheduler.step()
446
+
447
+ acc_loss.append(loss.item())
448
+
449
+ # Reports loss every 10 steps
450
+ if i % 10 == 0 and do_report:
451
+ print(f"Epoch {i}: Loss {loss.item()}")
452
+
453
+ # Optimizes the parameter search by storing the best loss and the parameters
454
+ if loss.item() < best_loss:
455
+ best_loss = loss.item()
456
+ best_params = copy.deepcopy({
457
+ k: v for k, v in params.items() if k in param_keys
458
+ })
459
+
460
+ # We also stop the search if the loss has not considerably during the last 10% epochs
461
+ if early_stop:
462
+ epochs_before_stop = int(epochs * percent_epochs_before_stop)
463
+ if i > epochs_before_stop and abs(acc_loss[i - epochs_before_stop] - acc_loss[i]) < min_delta:
464
+ break
465
+
466
+ # No longer requires gradients in the parameters
467
+ for p in best_params.values():
468
+ p.requires_grad = False
469
+ p.grad = None
470
+
471
+ if do_report:
472
+ return best_params, acc_loss
473
+ else:
474
+ return best_params
475
+
476
+
477
+ def quantize(
478
+ x: torch.Tensor,
479
+ params: Dict[str, nn.Parameter],
480
+ func: nn.Module,
481
+ bits: int,
482
+ target_dtype: torch.dtype = torch.int8
483
+ ) -> torch.Tensor:
484
+ quant_min, quant_max = get_min_max_from_bits_signed(bits)
485
+ x = x.transpose(0, 1) # Aligns shapes
486
+ x = func(x=x, **params)
487
+ x = x.transpose(0, 1)
488
+ x = torch.clamp(round_func_BPDA(x), quant_min, quant_max).to(target_dtype)
489
+ return x
490
+
491
+
492
+ def dequantize(
493
+ x: torch.Tensor,
494
+ params: Dict[str, nn.Parameter],
495
+ func: nn.Module,
496
+ bits: int,
497
+ out_dtype: torch.dtype
498
+ ) -> torch.Tensor:
499
+ x = x.to(dtype=out_dtype)
500
+ x = x.transpose(0, 1)
501
+ x = func(x=x, **params)
502
+ x = x.transpose(0, 1)
503
+ return x
504
+
505
+
506
+ def round_func_BPDA(input):
507
+ # This is equivalent to replacing round function (non-differentiable) with
508
+ # an identity function (differentiable) only when backward.
509
+ forward_value = torch.round(input)
510
+ out = input.clone()
511
+ out.data = forward_value.data
512
+ return out
513
+
514
+
515
+ def get_min_max_from_bits_signed(bit_width: int) -> Tuple[int, int]:
516
+ return -2 ** (bit_width - 1), 2 ** (bit_width - 1) - 1
517
+
518
+
519
+
520
+ ############## Numpy ###############
521
+
522
+ def np_domain_guard(
523
+ x: np.ndarray,
524
+ min: float = None,
525
+ max: float = None,
526
+ posinf: float = None,
527
+ neginf: float = None,
528
+ nan: float = None
529
+ ) -> np.ndarray:
530
+ """Guard a tensor to a valid domain."""
531
+ x = np.nan_to_num(x, posinf=posinf, neginf=neginf, nan=nan)
532
+ if min is not None or max is not None:
533
+ x = np.clip(x, min, max)
534
+ return x
535
+
536
+
537
+ def np_replace_num(x: np.ndarray, num: float, to: float) -> np.ndarray:
538
+ """Replace a number in a tensor with another number.
539
+
540
+ Args:
541
+ x (np.ndarray): The input tensor.
542
+ num (float): The number to replace.
543
+ to (float): The number to replace with.
544
+
545
+ Returns:
546
+ np.ndarray: The tensor with the number replaced.
547
+ """
548
+ return np.where(x == num, to, x)
549
+
550
+
551
+ def np_guarded_power(x: np.ndarray, exp: float) -> np.ndarray:
552
+ """Guard the power operation to a valid domain."""
553
+ return np.power(x, exp) if exp >= 1 else np.power(np.maximum(x, 0), exp)
554
+
fn_gen/rnd_search/0/loss.png ADDED
fn_gen/rnd_search/0/quantization.png ADDED
fn_gen/rnd_search/1/__pycache__/fn.cpython-311.pyc ADDED
Binary file (27.4 kB). View file
 
fn_gen/rnd_search/1/distortion.png ADDED
fn_gen/rnd_search/1/expressions.txt ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ (_0*x)**(1/3)/_s
2
+ _s**3*x**3/_0
fn_gen/rnd_search/1/fn.py ADDED
@@ -0,0 +1,554 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import torch
4
+ from torch import amin # Necessary for arcsin
5
+ import copy
6
+ import torch.nn as nn
7
+ import numpy as np
8
+
9
+ from scipy.optimize import curve_fit
10
+ from typing import Dict, Any, Tuple, List, Callable
11
+
12
+
13
+ def quantization(x, **params):
14
+ return (torch.div(1, replace_num(params['_s'], num=0, to=10000)) * guarded_torch_power((params['_0'] * x), 1 / 3))
15
+
16
+
17
+ def dequantization(x, **params):
18
+ return (torch.div(1, replace_num(params['_0'], num=0, to=10000)) * guarded_torch_power(params['_s'], torch.tensor(3)) * guarded_torch_power(x, torch.tensor(3)))
19
+
20
+
21
+ def init_space_search(
22
+ x: torch.Tensor,
23
+ **kwargs: Dict[str, Any],
24
+ ) -> torch.Tensor:
25
+
26
+ def _build_initial_param(tensor: torch.Tensor, max_initial: int, n_params: int):
27
+ """Generates the initial set of parameters. The first iteration generates 10 times more parameters."""
28
+ for _ in range(n_params * 10): # The first iteration generates 10 times more parameters
29
+ yield init_rand(tensor) * max_initial # Generates n_params in range [-max_initial, max_initial]
30
+
31
+ def _search_param(tensors: List[torch.tensor], n_params):
32
+ """Takes the best parameters and generates new parameters around the mean of the best parameters."""
33
+ torch_tensors = torch.stack(tensors)
34
+ min_vals, max_vals = torch.aminmax(torch_tensors, dim=0)
35
+ abs_max_val_per_ch = torch.max(-min_vals, max_vals)
36
+ mean = torch.mean(torch_tensors, dim=0)
37
+ for _ in range(n_params): # Generates n_params around the mean of the tensors
38
+ yield torch.randn_like(min_vals) * abs_max_val_per_ch + mean
39
+
40
+ def _calc(x, qtz_func, deqtz_func, **params):
41
+ x_ = x.transpose(0, 1)
42
+ x_ = qtz_func(x=x_, **params)
43
+ x_ = deqtz_func(x=x_, **params)
44
+ x_ = x_.transpose(0, 1)
45
+ return x_
46
+
47
+ assert "qtz_func" in kwargs, "qtz_func must be provided."
48
+ assert "deqtz_func" in kwargs, "deqtz_func must be provided."
49
+ assert "params_list" in kwargs, "params list must be provided."
50
+ assert "param" in kwargs, "param must be provided."
51
+
52
+ qtz_func = kwargs.get('qtz_func')
53
+ deqtz_func = kwargs.get('deqtz_func')
54
+ params_list = kwargs.get('params_list')
55
+ param = kwargs.get('param')
56
+
57
+ n_runs = 50 # Number of runs to try to find the best parameters
58
+ n_random_params = 50 # Number of random parameters to generate
59
+ n_best_to_pick = 5 # Number of best parameters to pick after each run
60
+ max_initial = 10000 # Maximum value to initialize the parameters
61
+
62
+ # Initializes the parameters
63
+ base_params = { p: init_ones(x, **kwargs) for p in params_list if p != param }
64
+ params = _build_initial_param(x, max_initial, n_random_params)
65
+
66
+ # Performs the search
67
+ for _ in range(n_runs):
68
+
69
+ best_params = []
70
+ for param_ in params:
71
+ try:
72
+ x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: param_})
73
+ loss_ones = nn.MSELoss()(x, x_)
74
+
75
+ if len(best_params) < n_best_to_pick:
76
+ best_params.append((param_, loss_ones.item()))
77
+ best_params = sorted(best_params, key=lambda x: x[1])
78
+ elif loss_ones < best_params[-1][1]:
79
+ best_params[-1] = (param_, loss_ones.item())
80
+ best_params = sorted(best_params, key=lambda x: x[1])
81
+
82
+ except Exception: # The parameters might not be valid for the function's domain
83
+ continue
84
+
85
+ # Generates new parameters around the mean
86
+ params = _search_param([p for p, _ in best_params], n_random_params)
87
+
88
+ return best_params[0][0]
89
+
90
+
91
+ def init_linear_scale( # Symmetric scale. From the study folder
92
+ x: torch.Tensor,
93
+ **kwargs: Dict[str, Any],
94
+ ) -> torch.Tensor:
95
+ assert "bits" in kwargs, "bits must be provided."
96
+ assert "params" in kwargs, "params must be provided."
97
+ assert "qtz_func" in kwargs, "qtz_func must be provided."
98
+
99
+ bits = kwargs.get('bits')
100
+ params = kwargs.get('params')
101
+ qtz_func = kwargs.get('qtz_func')
102
+
103
+ x_ = x.transpose(0, 1)
104
+ x_ = qtz_func(x=x_, **params, _s=init_ones(x, **kwargs))
105
+ x_ = x_.transpose(0, 1)
106
+
107
+ quant_min, quant_max = get_min_max_from_bits_signed(bits)
108
+ min_vals, max_vals = torch.aminmax(x_, dim=1)
109
+ min_vals = torch.min(min_vals, torch.zeros_like(min_vals))
110
+ max_vals = torch.max(max_vals, torch.zeros_like(max_vals))
111
+
112
+ eps = torch.finfo(torch.float32).eps
113
+
114
+ abs_max_val_per_ch = torch.max(-min_vals, max_vals)
115
+ scale = abs_max_val_per_ch / (float(quant_max - quant_min) / 2)
116
+
117
+ scale = torch.clamp(scale, min=eps).to(dtype=torch.float32, device=min_vals.device)
118
+
119
+ # Introduces some noise in scale
120
+ # If I don't introduce noise, the accuracy is going to be 0.0 and not learn anything
121
+ # scale = scale + 0.01 * torch.randn_like(scale)
122
+ return scale
123
+
124
+
125
+ def init_params(x: torch.Tensor, **kwargs: Dict[str, Any]) -> Dict[str, nn.Parameter]:
126
+ params = {
127
+ '_0': init_space_search(x, qtz_func=quantization, deqtz_func=dequantization, param='_0', params_list=['_0', '_s'], **kwargs),
128
+ }
129
+ params['_s'] = init_linear_scale(x, params=params, qtz_func=quantization, **kwargs)
130
+ params = {k: nn.Parameter(v, requires_grad=False) for k, v in params.items()}
131
+
132
+ if 'post_init_hook' in kwargs:
133
+ kwargs['post_init_hook'](parameters=params)
134
+
135
+
136
+ if 'post_train_hook' in kwargs:
137
+ kwargs['post_train_hook'](parameters=params)
138
+
139
+ return params
140
+
141
+
142
+ ############### Numpy Qtz ###############
143
+
144
+
145
+ def np_quantization(x, _0, _s):
146
+ return (np.divide(1, np_replace_num(_s, num=0, to=10000)) * np_guarded_power((_0 * x), 1 / 3))
147
+
148
+
149
+ def np_dequantization(x, _0, _s):
150
+ return (np.divide(1, np_replace_num(_0, num=0, to=10000)) * np_guarded_power(_s, np.array(3)) * np_guarded_power(x, np.array(3)))
151
+
152
+
153
+ def fit_func(x, _0, _s):
154
+ x_ = np_quantization(x, _0, _s)
155
+ x_ = np_dequantization(x_, _0, _s)
156
+ return x_
157
+
158
+
159
+
160
+ ############### HELPERS ###############
161
+
162
+ def domain_guard(
163
+ x: torch.Tensor,
164
+ min: float = None,
165
+ max: float = None,
166
+ posinf: float = None,
167
+ neginf: float = None,
168
+ nan: float = None
169
+ ) -> torch.Tensor:
170
+ """Guard a tensor to a valid domain."""
171
+ x = torch.nan_to_num(x, posinf=posinf, neginf=neginf, nan=nan)
172
+ if min is not None or max is not None:
173
+ x = torch.clamp(x, min=min, max=max)
174
+ return x
175
+
176
+
177
+ def replace_num(x: torch.Tensor, num: float, to: float) -> torch.Tensor:
178
+ """Replace a number in a tensor with another number.
179
+
180
+ Args:
181
+ x (torch.Tensor): The input tensor.
182
+ num (float): The number to replace.
183
+ to (float): The number to replace with.
184
+
185
+ Returns:
186
+ torch.Tensor: The tensor with the number replaced.
187
+ """
188
+ return torch.where(x == num, to, x)
189
+
190
+
191
+ def guarded_torch_power(x: torch.Tensor, exp: float) -> torch.Tensor:
192
+ """Guard the power operation to a valid domain."""
193
+ return torch.pow(x, exp) if exp >= 1 else torch.pow(torch.relu(x), exp)
194
+
195
+
196
+ def init_ones(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor:
197
+ val = torch.amin(x, dim=1)
198
+ return torch.ones_like(val, dtype=torch.float32, device=x.device)
199
+
200
+
201
+ def init_rand(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor:
202
+ val = torch.amin(x, dim=1)
203
+ return torch.randn_like(val, dtype=torch.float32, device=x.device)
204
+
205
+
206
+ def init_space_search(
207
+ x: torch.Tensor,
208
+ **kwargs: Dict[str, Any],
209
+ ) -> torch.Tensor:
210
+
211
+ def _build_initial_param(tensor: torch.Tensor, max_initial: int, n_params: int):
212
+ """Generates the initial set of parameters. The first iteration generates 10 times more parameters."""
213
+ for _ in range(n_params * 10): # The first iteration generates 10 times more parameters
214
+ yield init_rand(tensor) * max_initial # Generates n_params in range [-max_initial, max_initial]
215
+
216
+ def _search_param(tensors: List[torch.tensor], n_params):
217
+ """Takes the best parameters and generates new parameters around the mean of the best parameters."""
218
+ torch_tensors = torch.stack(tensors)
219
+ min_vals, max_vals = torch.aminmax(torch_tensors, dim=0)
220
+ abs_max_val_per_ch = torch.max(-min_vals, max_vals)
221
+ mean = torch.mean(torch_tensors, dim=0)
222
+ for _ in range(n_params): # Generates n_params around the mean of the tensors
223
+ yield torch.randn_like(min_vals) * abs_max_val_per_ch + mean
224
+
225
+ def _calc(x, qtz_func, deqtz_func, **params):
226
+ x_ = x.transpose(0, 1)
227
+ x_ = qtz_func(x=x_, **params)
228
+ x_ = deqtz_func(x=x_, **params)
229
+ x_ = x_.transpose(0, 1)
230
+ return x_
231
+
232
+ assert "qtz_func" in kwargs, "qtz_func must be provided."
233
+ assert "deqtz_func" in kwargs, "deqtz_func must be provided."
234
+ assert "params_list" in kwargs, "params list must be provided."
235
+ assert "param" in kwargs, "param must be provided."
236
+
237
+ qtz_func = kwargs.get('qtz_func')
238
+ deqtz_func = kwargs.get('deqtz_func')
239
+ params_list = kwargs.get('params_list')
240
+ param = kwargs.get('param')
241
+
242
+ n_runs = 50 # Number of runs to try to find the best parameters
243
+ n_random_params = 50 # Number of random parameters to generate
244
+ n_best_to_pick = 5 # Number of best parameters to pick after each run
245
+ max_initial = 10000 # Maximum value to initialize the parameters
246
+
247
+ # Initializes the parameters
248
+ base_params = { p: init_ones(x, **kwargs) for p in params_list if p != param }
249
+ params = _build_initial_param(x, max_initial, n_random_params)
250
+
251
+ # Performs the search
252
+ for _ in range(n_runs):
253
+
254
+ best_params = []
255
+ for param_ in params:
256
+ try:
257
+ x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: param_})
258
+ loss_ones = nn.MSELoss()(x, x_)
259
+
260
+ if len(best_params) < n_best_to_pick:
261
+ best_params.append((param_, loss_ones.item()))
262
+ best_params = sorted(best_params, key=lambda x: x[1])
263
+ elif loss_ones < best_params[-1][1]:
264
+ best_params[-1] = (param_, loss_ones.item())
265
+ best_params = sorted(best_params, key=lambda x: x[1])
266
+
267
+ except Exception: # The parameters might not be valid for the function's domain
268
+ continue
269
+
270
+ # Generates new parameters around the mean
271
+ params = _search_param([p for p, _ in best_params], n_random_params)
272
+
273
+ return best_params[0][0]
274
+
275
+
276
+ def init_linear_scale( # Symmetric scale. From the study folder
277
+ x: torch.Tensor,
278
+ **kwargs: Dict[str, Any],
279
+ ) -> torch.Tensor:
280
+ assert "bits" in kwargs, "bits must be provided."
281
+ assert "params" in kwargs, "params must be provided."
282
+ assert "qtz_func" in kwargs, "qtz_func must be provided."
283
+
284
+ bits = kwargs.get('bits')
285
+ params = kwargs.get('params')
286
+ qtz_func = kwargs.get('qtz_func')
287
+
288
+ x_ = x.transpose(0, 1)
289
+ x_ = qtz_func(x=x_, **params, _s=init_ones(x, **kwargs))
290
+ x_ = x_.transpose(0, 1)
291
+
292
+ quant_min, quant_max = get_min_max_from_bits_signed(bits)
293
+ min_vals, max_vals = torch.aminmax(x_, dim=1)
294
+ min_vals = torch.min(min_vals, torch.zeros_like(min_vals))
295
+ max_vals = torch.max(max_vals, torch.zeros_like(max_vals))
296
+
297
+ eps = torch.finfo(torch.float32).eps
298
+
299
+ abs_max_val_per_ch = torch.max(-min_vals, max_vals)
300
+ scale = abs_max_val_per_ch / (float(quant_max - quant_min) / 2)
301
+
302
+ scale = torch.clamp(scale, min=eps).to(dtype=torch.float32, device=min_vals.device)
303
+
304
+ # Introduces some noise in scale
305
+ # If I don't introduce noise, the accuracy is going to be 0.0 and not learn anything
306
+ # scale = scale + 0.01 * torch.randn_like(scale)
307
+ return scale
308
+
309
+
310
+ def init_non_linear_regression_fit(
311
+ x: torch.Tensor,
312
+ **kwargs: Dict[str, Any],
313
+ ) -> torch.Tensor:
314
+
315
+ assert "params_list" in kwargs, "params list must be provided."
316
+ assert "np_fit_func" in kwargs, "np_fit_func must be provided."
317
+ assert "p0" in kwargs, "p0 must be provided."
318
+ np_fit_func = kwargs.get('np_fit_func')
319
+ params_list = kwargs.get('params_list')
320
+ p0 = kwargs.get('p0')
321
+
322
+ def _fit(xdata: np.ndarray, ydata: np.ndarray, func: Callable, p0: List[float]):
323
+ popt, _ = curve_fit(
324
+ func,
325
+ xdata,
326
+ ydata,
327
+ maxfev=1000,
328
+ p0=p0,
329
+ method='lm'
330
+ )
331
+ return popt
332
+
333
+ # 1. Needs to convert the torch tensor to numpy tensor
334
+ xdata = x.cpu().numpy()
335
+
336
+ # 2. Sorts the data so that it makes it easier to fit to it
337
+ sorted_xdata = np.sort(xdata, axis=-1)
338
+
339
+ p0 = {k: v.cpu().numpy() for k, v in p0.items()}
340
+ params_list = sorted(params_list) # We need to make sure that it matches the numpy fit func arg order
341
+
342
+ # 3. Finds the best parameters for each channel
343
+ try:
344
+ params = []
345
+ for i in range(sorted_xdata.shape[0]):
346
+ xdata_ = sorted_xdata[i]
347
+ p0_ = [p0[p][i] for p in params_list]
348
+ ch_params = _fit(xdata_, xdata_, np_fit_func, p0_)
349
+ params.append(ch_params)
350
+
351
+ # 4. Builds the parameters
352
+ result = {}
353
+ for i, p in enumerate(params_list):
354
+ result[p] = torch.tensor([p_[i] for p_ in params], dtype=torch.float32).to(x.device)
355
+
356
+ return result
357
+
358
+ except ValueError as e:
359
+ print(f"Could not fit the function with error: {e}")
360
+ print(f"Using fallback result...")
361
+ return {
362
+ k: torch.tensor(v, dtype=torch.float32).to(x.device) for k, v in p0.items()
363
+ }
364
+
365
+
366
+ def init_zeros(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor:
367
+ val = torch.amin(x, dim=1)
368
+ return torch.zeros_like(val, dtype=torch.float32, device=x.device)
369
+
370
+
371
+ def init_inner_scale(tensor: torch.Tensor, _min: float = torch.inf, _max: float = torch.inf) -> torch.Tensor:
372
+ # Calculate the original minimum and maximum values
373
+ min_vals, max_vals = torch.aminmax(tensor, dim=-1)
374
+ x_min = torch.min(min_vals, torch.zeros_like(min_vals))
375
+ x_max = torch.max(max_vals, torch.zeros_like(max_vals))
376
+
377
+ if _max is torch.inf: # We do not need to scale the tensor. Just need to move it
378
+ return torch.ones_like(x_min)
379
+
380
+ # Calculate the scale factor
381
+ scale = (_max - _min) / (x_max - x_min)
382
+ return scale
383
+
384
+
385
+
386
+ ############## Quant ###############
387
+
388
+ @torch.enable_grad()
389
+ def learn_parameters(
390
+ x: torch.Tensor,
391
+ params: Dict[str, nn.Parameter],
392
+ qtz_func: nn.Module,
393
+ deqtz_func: nn.Module,
394
+ bits: int,
395
+ target_dtype: torch.dtype,
396
+ epochs: int = 1000,
397
+ early_stop: bool = True,
398
+ do_report: bool = False
399
+ ) -> Tuple[Dict[str, nn.Parameter], torch.Tensor]:
400
+ loss_fn = nn.MSELoss()
401
+
402
+ # Determines the initial learning rate by computing the initial loss and multiplying it by
403
+ # the order of magnitude of the loss divided by 2
404
+ quant = quantize(x, params, qtz_func, bits, target_dtype)
405
+ dequant = dequantize(quant, params, deqtz_func, bits, x.dtype)
406
+ loss = loss_fn(x, dequant)
407
+
408
+ base_lr = 0.1
409
+ exponent = int(np.floor(np.log10(loss.item())))
410
+ lr = base_lr * (10 ** (exponent // 2))
411
+
412
+ # Requires gradients in the parameters
413
+ for p in params.values():
414
+ p.requires_grad = True
415
+ p.grad = None
416
+
417
+ param_keys = list(params.keys())
418
+ param_values = list(params.values())
419
+
420
+ # Defines optimizer and loss function
421
+ optimizer = torch.optim.Adam(param_values, lr=lr)
422
+ scheduler = torch.optim.lr_scheduler.LinearLR(optimizer, start_factor=1.0, end_factor=0.01, total_iters=epochs // 10)
423
+
424
+ # Contains the best loss and the best parameters
425
+ best_loss = float("inf")
426
+ best_params = None
427
+
428
+ # Used to stop the search early
429
+ min_delta = 1e-7
430
+ acc_loss = []
431
+ percent_epochs_before_stop = 0.1
432
+
433
+ for i in range(epochs):
434
+ optimizer.zero_grad()
435
+
436
+ quant = quantize(x, params, qtz_func, bits, target_dtype)
437
+ dequant = dequantize(quant, params, deqtz_func, bits, x.dtype)
438
+ loss = loss_fn(x, dequant)
439
+
440
+ if loss.isnan() or loss.isinf():
441
+ raise Exception("Loss is NaN or Inf. Stopping the search.")
442
+
443
+ loss.backward()
444
+ optimizer.step()
445
+ scheduler.step()
446
+
447
+ acc_loss.append(loss.item())
448
+
449
+ # Reports loss every 10 steps
450
+ if i % 10 == 0 and do_report:
451
+ print(f"Epoch {i}: Loss {loss.item()}")
452
+
453
+ # Optimizes the parameter search by storing the best loss and the parameters
454
+ if loss.item() < best_loss:
455
+ best_loss = loss.item()
456
+ best_params = copy.deepcopy({
457
+ k: v for k, v in params.items() if k in param_keys
458
+ })
459
+
460
+ # We also stop the search if the loss has not considerably during the last 10% epochs
461
+ if early_stop:
462
+ epochs_before_stop = int(epochs * percent_epochs_before_stop)
463
+ if i > epochs_before_stop and abs(acc_loss[i - epochs_before_stop] - acc_loss[i]) < min_delta:
464
+ break
465
+
466
+ # No longer requires gradients in the parameters
467
+ for p in best_params.values():
468
+ p.requires_grad = False
469
+ p.grad = None
470
+
471
+ if do_report:
472
+ return best_params, acc_loss
473
+ else:
474
+ return best_params
475
+
476
+
477
+ def quantize(
478
+ x: torch.Tensor,
479
+ params: Dict[str, nn.Parameter],
480
+ func: nn.Module,
481
+ bits: int,
482
+ target_dtype: torch.dtype = torch.int8
483
+ ) -> torch.Tensor:
484
+ quant_min, quant_max = get_min_max_from_bits_signed(bits)
485
+ x = x.transpose(0, 1) # Aligns shapes
486
+ x = func(x=x, **params)
487
+ x = x.transpose(0, 1)
488
+ x = torch.clamp(round_func_BPDA(x), quant_min, quant_max).to(target_dtype)
489
+ return x
490
+
491
+
492
+ def dequantize(
493
+ x: torch.Tensor,
494
+ params: Dict[str, nn.Parameter],
495
+ func: nn.Module,
496
+ bits: int,
497
+ out_dtype: torch.dtype
498
+ ) -> torch.Tensor:
499
+ x = x.to(dtype=out_dtype)
500
+ x = x.transpose(0, 1)
501
+ x = func(x=x, **params)
502
+ x = x.transpose(0, 1)
503
+ return x
504
+
505
+
506
+ def round_func_BPDA(input):
507
+ # This is equivalent to replacing round function (non-differentiable) with
508
+ # an identity function (differentiable) only when backward.
509
+ forward_value = torch.round(input)
510
+ out = input.clone()
511
+ out.data = forward_value.data
512
+ return out
513
+
514
+
515
+ def get_min_max_from_bits_signed(bit_width: int) -> Tuple[int, int]:
516
+ return -2 ** (bit_width - 1), 2 ** (bit_width - 1) - 1
517
+
518
+
519
+
520
+ ############## Numpy ###############
521
+
522
+ def np_domain_guard(
523
+ x: np.ndarray,
524
+ min: float = None,
525
+ max: float = None,
526
+ posinf: float = None,
527
+ neginf: float = None,
528
+ nan: float = None
529
+ ) -> np.ndarray:
530
+ """Guard a tensor to a valid domain."""
531
+ x = np.nan_to_num(x, posinf=posinf, neginf=neginf, nan=nan)
532
+ if min is not None or max is not None:
533
+ x = np.clip(x, min, max)
534
+ return x
535
+
536
+
537
+ def np_replace_num(x: np.ndarray, num: float, to: float) -> np.ndarray:
538
+ """Replace a number in a tensor with another number.
539
+
540
+ Args:
541
+ x (np.ndarray): The input tensor.
542
+ num (float): The number to replace.
543
+ to (float): The number to replace with.
544
+
545
+ Returns:
546
+ np.ndarray: The tensor with the number replaced.
547
+ """
548
+ return np.where(x == num, to, x)
549
+
550
+
551
+ def np_guarded_power(x: np.ndarray, exp: float) -> np.ndarray:
552
+ """Guard the power operation to a valid domain."""
553
+ return np.power(x, exp) if exp >= 1 else np.power(np.maximum(x, 0), exp)
554
+
fn_gen/rnd_search/1/loss.png ADDED
fn_gen/rnd_search/1/quantization.png ADDED
fn_gen/rnd_search/10/__pycache__/fn.cpython-311.pyc ADDED
Binary file (24.2 kB). View file
 
fn_gen/rnd_search/10/distortion.png ADDED
fn_gen/rnd_search/10/expressions.txt ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ x**2/_s
2
+ sqrt(_s*x)
fn_gen/rnd_search/10/fn.py ADDED
@@ -0,0 +1,483 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import torch
4
+ from torch import amin # Necessary for arcsin
5
+ import copy
6
+ import torch.nn as nn
7
+ import numpy as np
8
+
9
+ from scipy.optimize import curve_fit
10
+ from typing import Dict, Any, Tuple, List, Callable
11
+
12
+
13
+ def quantization(x, **params):
14
+ return (torch.div(1, replace_num(params['_s'], num=0, to=10000)) * guarded_torch_power(x, torch.tensor(2)))
15
+
16
+
17
+ def dequantization(x, **params):
18
+ return torch.sqrt(domain_guard((params['_s'] * x), min=0.1, nan=0.1))
19
+
20
+
21
+ def init_linear_scale( # Symmetric scale. From the study folder
22
+ x: torch.Tensor,
23
+ **kwargs: Dict[str, Any],
24
+ ) -> torch.Tensor:
25
+ assert "bits" in kwargs, "bits must be provided."
26
+ assert "params" in kwargs, "params must be provided."
27
+ assert "qtz_func" in kwargs, "qtz_func must be provided."
28
+
29
+ bits = kwargs.get('bits')
30
+ params = kwargs.get('params')
31
+ qtz_func = kwargs.get('qtz_func')
32
+
33
+ x_ = x.transpose(0, 1)
34
+ x_ = qtz_func(x=x_, **params, _s=init_ones(x, **kwargs))
35
+ x_ = x_.transpose(0, 1)
36
+
37
+ quant_min, quant_max = get_min_max_from_bits_signed(bits)
38
+ min_vals, max_vals = torch.aminmax(x_, dim=1)
39
+ min_vals = torch.min(min_vals, torch.zeros_like(min_vals))
40
+ max_vals = torch.max(max_vals, torch.zeros_like(max_vals))
41
+
42
+ eps = torch.finfo(torch.float32).eps
43
+
44
+ abs_max_val_per_ch = torch.max(-min_vals, max_vals)
45
+ scale = abs_max_val_per_ch / (float(quant_max - quant_min) / 2)
46
+
47
+ scale = torch.clamp(scale, min=eps).to(dtype=torch.float32, device=min_vals.device)
48
+
49
+ # Introduces some noise in scale
50
+ # If I don't introduce noise, the accuracy is going to be 0.0 and not learn anything
51
+ # scale = scale + 0.01 * torch.randn_like(scale)
52
+ return scale
53
+
54
+
55
+ def init_params(x: torch.Tensor, **kwargs: Dict[str, Any]) -> Dict[str, nn.Parameter]:
56
+ params = {
57
+ }
58
+ params['_s'] = init_linear_scale(x, params=params, qtz_func=quantization, **kwargs)
59
+ params = {k: nn.Parameter(v, requires_grad=False) for k, v in params.items()}
60
+
61
+ if 'post_init_hook' in kwargs:
62
+ kwargs['post_init_hook'](parameters=params)
63
+
64
+
65
+ if 'post_train_hook' in kwargs:
66
+ kwargs['post_train_hook'](parameters=params)
67
+
68
+ return params
69
+
70
+
71
+ ############### Numpy Qtz ###############
72
+
73
+
74
+ def np_quantization(x, _s):
75
+ return (np.divide(1, np_replace_num(_s, num=0, to=10000)) * np_guarded_power(x, np.array(2)))
76
+
77
+
78
+ def np_dequantization(x, _s):
79
+ return np.sqrt(np_domain_guard((_s * x), min=0.1, nan=0.1))
80
+
81
+
82
+ def fit_func(x, _s):
83
+ x_ = np_quantization(x, _s)
84
+ x_ = np_dequantization(x_, _s)
85
+ return x_
86
+
87
+
88
+
89
+ ############### HELPERS ###############
90
+
91
+ def domain_guard(
92
+ x: torch.Tensor,
93
+ min: float = None,
94
+ max: float = None,
95
+ posinf: float = None,
96
+ neginf: float = None,
97
+ nan: float = None
98
+ ) -> torch.Tensor:
99
+ """Guard a tensor to a valid domain."""
100
+ x = torch.nan_to_num(x, posinf=posinf, neginf=neginf, nan=nan)
101
+ if min is not None or max is not None:
102
+ x = torch.clamp(x, min=min, max=max)
103
+ return x
104
+
105
+
106
+ def replace_num(x: torch.Tensor, num: float, to: float) -> torch.Tensor:
107
+ """Replace a number in a tensor with another number.
108
+
109
+ Args:
110
+ x (torch.Tensor): The input tensor.
111
+ num (float): The number to replace.
112
+ to (float): The number to replace with.
113
+
114
+ Returns:
115
+ torch.Tensor: The tensor with the number replaced.
116
+ """
117
+ return torch.where(x == num, to, x)
118
+
119
+
120
+ def guarded_torch_power(x: torch.Tensor, exp: float) -> torch.Tensor:
121
+ """Guard the power operation to a valid domain."""
122
+ return torch.pow(x, exp) if exp >= 1 else torch.pow(torch.relu(x), exp)
123
+
124
+
125
+ def init_ones(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor:
126
+ val = torch.amin(x, dim=1)
127
+ return torch.ones_like(val, dtype=torch.float32, device=x.device)
128
+
129
+
130
+ def init_rand(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor:
131
+ val = torch.amin(x, dim=1)
132
+ return torch.randn_like(val, dtype=torch.float32, device=x.device)
133
+
134
+
135
+ def init_space_search(
136
+ x: torch.Tensor,
137
+ **kwargs: Dict[str, Any],
138
+ ) -> torch.Tensor:
139
+
140
+ def _build_initial_param(tensor: torch.Tensor, max_initial: int, n_params: int):
141
+ """Generates the initial set of parameters. The first iteration generates 10 times more parameters."""
142
+ for _ in range(n_params * 10): # The first iteration generates 10 times more parameters
143
+ yield init_rand(tensor) * max_initial # Generates n_params in range [-max_initial, max_initial]
144
+
145
+ def _search_param(tensors: List[torch.tensor], n_params):
146
+ """Takes the best parameters and generates new parameters around the mean of the best parameters."""
147
+ torch_tensors = torch.stack(tensors)
148
+ min_vals, max_vals = torch.aminmax(torch_tensors, dim=0)
149
+ abs_max_val_per_ch = torch.max(-min_vals, max_vals)
150
+ mean = torch.mean(torch_tensors, dim=0)
151
+ for _ in range(n_params): # Generates n_params around the mean of the tensors
152
+ yield torch.randn_like(min_vals) * abs_max_val_per_ch + mean
153
+
154
+ def _calc(x, qtz_func, deqtz_func, **params):
155
+ x_ = x.transpose(0, 1)
156
+ x_ = qtz_func(x=x_, **params)
157
+ x_ = deqtz_func(x=x_, **params)
158
+ x_ = x_.transpose(0, 1)
159
+ return x_
160
+
161
+ assert "qtz_func" in kwargs, "qtz_func must be provided."
162
+ assert "deqtz_func" in kwargs, "deqtz_func must be provided."
163
+ assert "params_list" in kwargs, "params list must be provided."
164
+ assert "param" in kwargs, "param must be provided."
165
+
166
+ qtz_func = kwargs.get('qtz_func')
167
+ deqtz_func = kwargs.get('deqtz_func')
168
+ params_list = kwargs.get('params_list')
169
+ param = kwargs.get('param')
170
+
171
+ n_runs = 50 # Number of runs to try to find the best parameters
172
+ n_random_params = 50 # Number of random parameters to generate
173
+ n_best_to_pick = 5 # Number of best parameters to pick after each run
174
+ max_initial = 10000 # Maximum value to initialize the parameters
175
+
176
+ # Initializes the parameters
177
+ base_params = { p: init_ones(x, **kwargs) for p in params_list if p != param }
178
+ params = _build_initial_param(x, max_initial, n_random_params)
179
+
180
+ # Performs the search
181
+ for _ in range(n_runs):
182
+
183
+ best_params = []
184
+ for param_ in params:
185
+ try:
186
+ x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: param_})
187
+ loss_ones = nn.MSELoss()(x, x_)
188
+
189
+ if len(best_params) < n_best_to_pick:
190
+ best_params.append((param_, loss_ones.item()))
191
+ best_params = sorted(best_params, key=lambda x: x[1])
192
+ elif loss_ones < best_params[-1][1]:
193
+ best_params[-1] = (param_, loss_ones.item())
194
+ best_params = sorted(best_params, key=lambda x: x[1])
195
+
196
+ except Exception: # The parameters might not be valid for the function's domain
197
+ continue
198
+
199
+ # Generates new parameters around the mean
200
+ params = _search_param([p for p, _ in best_params], n_random_params)
201
+
202
+ return best_params[0][0]
203
+
204
+
205
+ def init_linear_scale( # Symmetric scale. From the study folder
206
+ x: torch.Tensor,
207
+ **kwargs: Dict[str, Any],
208
+ ) -> torch.Tensor:
209
+ assert "bits" in kwargs, "bits must be provided."
210
+ assert "params" in kwargs, "params must be provided."
211
+ assert "qtz_func" in kwargs, "qtz_func must be provided."
212
+
213
+ bits = kwargs.get('bits')
214
+ params = kwargs.get('params')
215
+ qtz_func = kwargs.get('qtz_func')
216
+
217
+ x_ = x.transpose(0, 1)
218
+ x_ = qtz_func(x=x_, **params, _s=init_ones(x, **kwargs))
219
+ x_ = x_.transpose(0, 1)
220
+
221
+ quant_min, quant_max = get_min_max_from_bits_signed(bits)
222
+ min_vals, max_vals = torch.aminmax(x_, dim=1)
223
+ min_vals = torch.min(min_vals, torch.zeros_like(min_vals))
224
+ max_vals = torch.max(max_vals, torch.zeros_like(max_vals))
225
+
226
+ eps = torch.finfo(torch.float32).eps
227
+
228
+ abs_max_val_per_ch = torch.max(-min_vals, max_vals)
229
+ scale = abs_max_val_per_ch / (float(quant_max - quant_min) / 2)
230
+
231
+ scale = torch.clamp(scale, min=eps).to(dtype=torch.float32, device=min_vals.device)
232
+
233
+ # Introduces some noise in scale
234
+ # If I don't introduce noise, the accuracy is going to be 0.0 and not learn anything
235
+ # scale = scale + 0.01 * torch.randn_like(scale)
236
+ return scale
237
+
238
+
239
+ def init_non_linear_regression_fit(
240
+ x: torch.Tensor,
241
+ **kwargs: Dict[str, Any],
242
+ ) -> torch.Tensor:
243
+
244
+ assert "params_list" in kwargs, "params list must be provided."
245
+ assert "np_fit_func" in kwargs, "np_fit_func must be provided."
246
+ assert "p0" in kwargs, "p0 must be provided."
247
+ np_fit_func = kwargs.get('np_fit_func')
248
+ params_list = kwargs.get('params_list')
249
+ p0 = kwargs.get('p0')
250
+
251
+ def _fit(xdata: np.ndarray, ydata: np.ndarray, func: Callable, p0: List[float]):
252
+ popt, _ = curve_fit(
253
+ func,
254
+ xdata,
255
+ ydata,
256
+ maxfev=1000,
257
+ p0=p0,
258
+ method='lm'
259
+ )
260
+ return popt
261
+
262
+ # 1. Needs to convert the torch tensor to numpy tensor
263
+ xdata = x.cpu().numpy()
264
+
265
+ # 2. Sorts the data so that it makes it easier to fit to it
266
+ sorted_xdata = np.sort(xdata, axis=-1)
267
+
268
+ p0 = {k: v.cpu().numpy() for k, v in p0.items()}
269
+ params_list = sorted(params_list) # We need to make sure that it matches the numpy fit func arg order
270
+
271
+ # 3. Finds the best parameters for each channel
272
+ try:
273
+ params = []
274
+ for i in range(sorted_xdata.shape[0]):
275
+ xdata_ = sorted_xdata[i]
276
+ p0_ = [p0[p][i] for p in params_list]
277
+ ch_params = _fit(xdata_, xdata_, np_fit_func, p0_)
278
+ params.append(ch_params)
279
+
280
+ # 4. Builds the parameters
281
+ result = {}
282
+ for i, p in enumerate(params_list):
283
+ result[p] = torch.tensor([p_[i] for p_ in params], dtype=torch.float32).to(x.device)
284
+
285
+ return result
286
+
287
+ except ValueError as e:
288
+ print(f"Could not fit the function with error: {e}")
289
+ print(f"Using fallback result...")
290
+ return {
291
+ k: torch.tensor(v, dtype=torch.float32).to(x.device) for k, v in p0.items()
292
+ }
293
+
294
+
295
+ def init_zeros(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor:
296
+ val = torch.amin(x, dim=1)
297
+ return torch.zeros_like(val, dtype=torch.float32, device=x.device)
298
+
299
+
300
+ def init_inner_scale(tensor: torch.Tensor, _min: float = torch.inf, _max: float = torch.inf) -> torch.Tensor:
301
+ # Calculate the original minimum and maximum values
302
+ min_vals, max_vals = torch.aminmax(tensor, dim=-1)
303
+ x_min = torch.min(min_vals, torch.zeros_like(min_vals))
304
+ x_max = torch.max(max_vals, torch.zeros_like(max_vals))
305
+
306
+ if _max is torch.inf: # We do not need to scale the tensor. Just need to move it
307
+ return torch.ones_like(x_min)
308
+
309
+ # Calculate the scale factor
310
+ scale = (_max - _min) / (x_max - x_min)
311
+ return scale
312
+
313
+
314
+
315
+ ############## Quant ###############
316
+
317
+ @torch.enable_grad()
318
+ def learn_parameters(
319
+ x: torch.Tensor,
320
+ params: Dict[str, nn.Parameter],
321
+ qtz_func: nn.Module,
322
+ deqtz_func: nn.Module,
323
+ bits: int,
324
+ target_dtype: torch.dtype,
325
+ epochs: int = 1000,
326
+ early_stop: bool = True,
327
+ do_report: bool = False
328
+ ) -> Tuple[Dict[str, nn.Parameter], torch.Tensor]:
329
+ loss_fn = nn.MSELoss()
330
+
331
+ # Determines the initial learning rate by computing the initial loss and multiplying it by
332
+ # the order of magnitude of the loss divided by 2
333
+ quant = quantize(x, params, qtz_func, bits, target_dtype)
334
+ dequant = dequantize(quant, params, deqtz_func, bits, x.dtype)
335
+ loss = loss_fn(x, dequant)
336
+
337
+ base_lr = 0.1
338
+ exponent = int(np.floor(np.log10(loss.item())))
339
+ lr = base_lr * (10 ** (exponent // 2))
340
+
341
+ # Requires gradients in the parameters
342
+ for p in params.values():
343
+ p.requires_grad = True
344
+ p.grad = None
345
+
346
+ param_keys = list(params.keys())
347
+ param_values = list(params.values())
348
+
349
+ # Defines optimizer and loss function
350
+ optimizer = torch.optim.Adam(param_values, lr=lr)
351
+ scheduler = torch.optim.lr_scheduler.LinearLR(optimizer, start_factor=1.0, end_factor=0.01, total_iters=epochs // 10)
352
+
353
+ # Contains the best loss and the best parameters
354
+ best_loss = float("inf")
355
+ best_params = None
356
+
357
+ # Used to stop the search early
358
+ min_delta = 1e-7
359
+ acc_loss = []
360
+ percent_epochs_before_stop = 0.1
361
+
362
+ for i in range(epochs):
363
+ optimizer.zero_grad()
364
+
365
+ quant = quantize(x, params, qtz_func, bits, target_dtype)
366
+ dequant = dequantize(quant, params, deqtz_func, bits, x.dtype)
367
+ loss = loss_fn(x, dequant)
368
+
369
+ if loss.isnan() or loss.isinf():
370
+ raise Exception("Loss is NaN or Inf. Stopping the search.")
371
+
372
+ loss.backward()
373
+ optimizer.step()
374
+ scheduler.step()
375
+
376
+ acc_loss.append(loss.item())
377
+
378
+ # Reports loss every 10 steps
379
+ if i % 10 == 0 and do_report:
380
+ print(f"Epoch {i}: Loss {loss.item()}")
381
+
382
+ # Optimizes the parameter search by storing the best loss and the parameters
383
+ if loss.item() < best_loss:
384
+ best_loss = loss.item()
385
+ best_params = copy.deepcopy({
386
+ k: v for k, v in params.items() if k in param_keys
387
+ })
388
+
389
+ # We also stop the search if the loss has not changed considerably during the last 10% of epochs
390
+ if early_stop:
391
+ epochs_before_stop = int(epochs * percent_epochs_before_stop)
392
+ if i > epochs_before_stop and abs(acc_loss[i - epochs_before_stop] - acc_loss[i]) < min_delta:
393
+ break
394
+
395
+ # No longer requires gradients in the parameters
396
+ for p in best_params.values():
397
+ p.requires_grad = False
398
+ p.grad = None
399
+
400
+ if do_report:
401
+ return best_params, acc_loss
402
+ else:
403
+ return best_params
404
+
405
+
406
+ def quantize(
407
+ x: torch.Tensor,
408
+ params: Dict[str, nn.Parameter],
409
+ func: nn.Module,
410
+ bits: int,
411
+ target_dtype: torch.dtype = torch.int8
412
+ ) -> torch.Tensor:
413
+ quant_min, quant_max = get_min_max_from_bits_signed(bits)
414
+ x = x.transpose(0, 1) # Aligns shapes
415
+ x = func(x=x, **params)
416
+ x = x.transpose(0, 1)
417
+ x = torch.clamp(round_func_BPDA(x), quant_min, quant_max).to(target_dtype)
418
+ return x
419
+
420
+
421
+ def dequantize(
422
+ x: torch.Tensor,
423
+ params: Dict[str, nn.Parameter],
424
+ func: nn.Module,
425
+ bits: int,
426
+ out_dtype: torch.dtype
427
+ ) -> torch.Tensor:
428
+ x = x.to(dtype=out_dtype)
429
+ x = x.transpose(0, 1)
430
+ x = func(x=x, **params)
431
+ x = x.transpose(0, 1)
432
+ return x
433
+
434
+
435
+ def round_func_BPDA(input):
436
+ # This is equivalent to replacing round function (non-differentiable) with
437
+ # an identity function (differentiable) only when backward.
438
+ forward_value = torch.round(input)
439
+ out = input.clone()
440
+ out.data = forward_value.data
441
+ return out
442
+
443
+
444
+ def get_min_max_from_bits_signed(bit_width: int) -> Tuple[int, int]:
445
+ return -2 ** (bit_width - 1), 2 ** (bit_width - 1) - 1
446
+
447
+
448
+
449
+ ############## Numpy ###############
450
+
451
+ def np_domain_guard(
452
+ x: np.ndarray,
453
+ min: float = None,
454
+ max: float = None,
455
+ posinf: float = None,
456
+ neginf: float = None,
457
+ nan: float = None
458
+ ) -> np.ndarray:
459
+ """Guard a tensor to a valid domain."""
460
+ x = np.nan_to_num(x, posinf=posinf, neginf=neginf, nan=nan)
461
+ if min is not None or max is not None:
462
+ x = np.clip(x, min, max)
463
+ return x
464
+
465
+
466
+ def np_replace_num(x: np.ndarray, num: float, to: float) -> np.ndarray:
467
+ """Replace a number in a tensor with another number.
468
+
469
+ Args:
470
+ x (np.ndarray): The input tensor.
471
+ num (float): The number to replace.
472
+ to (float): The number to replace with.
473
+
474
+ Returns:
475
+ np.ndarray: The tensor with the number replaced.
476
+ """
477
+ return np.where(x == num, to, x)
478
+
479
+
480
+ def np_guarded_power(x: np.ndarray, exp: float) -> np.ndarray:
481
+ """Guard the power operation to a valid domain."""
482
+ return np.power(x, exp) if exp >= 1 else np.power(np.maximum(x, 0), exp)
483
+
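The quantize/dequantize pair defined above can be exercised end to end. A minimal usage sketch follows (not part of the uploaded file; the module name `fn`, the tensor shape, and the parameter values are illustrative assumptions):

import torch
import torch.nn as nn
import fn  # hypothetical: the fn.py shown in this diff, importable as a module

x = torch.randn(16, 64)  # rows are treated as channels by quantize/dequantize
params = {  # made-up per-channel parameters; normally produced by the init_* helpers
    '_0': nn.Parameter(torch.ones(16), requires_grad=False),
    '_s': nn.Parameter(torch.ones(16), requires_grad=False),
}
q = fn.quantize(x, params, fn.quantization, bits=8, target_dtype=torch.int8)
xr = fn.dequantize(q, params, fn.dequantization, bits=8, out_dtype=x.dtype)
print(nn.MSELoss()(x, xr).item())  # reconstruction error of the round trip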
fn_gen/rnd_search/10/loss.png ADDED
fn_gen/rnd_search/10/quantization.png ADDED
fn_gen/rnd_search/11/__pycache__/fn.cpython-311.pyc ADDED
Binary file (27.1 kB). View file
 
fn_gen/rnd_search/11/distortion.png ADDED
fn_gen/rnd_search/11/expressions.txt ADDED
@@ -0,0 +1,2 @@
1
+ asinh(_0*x)/_s
2
+ sinh(_s*x)/_0
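The two expressions above are a forward/inverse pair: before rounding and clamping, sinh(_s * asinh(_0*x)/_s)/_0 recovers x exactly. A quick numerical check (illustrative parameter values only):

import torch
_0, _s = torch.tensor(2.0), torch.tensor(0.5)  # made-up parameter values
x = torch.linspace(-3, 3, 7)
q = torch.asinh(_0 * x) / _s   # expression on line 1
d = torch.sinh(_s * q) / _0    # expression on line 2
assert torch.allclose(d, x, atol=1e-5)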
fn_gen/rnd_search/11/fn.py ADDED
@@ -0,0 +1,554 @@
1
+ from __future__ import annotations
2
+
3
+ import torch
4
+ from torch import amin # Necessary for arcsin
5
+ import copy
6
+ import torch.nn as nn
7
+ import numpy as np
8
+
9
+ from scipy.optimize import curve_fit
10
+ from typing import Dict, Any, Tuple, List, Callable
11
+
12
+
13
+ def quantization(x, **params):
14
+ return (torch.div(1, replace_num(params['_s'], num=0, to=10000)) * torch.asinh((params['_0'] * x)))
15
+
16
+
17
+ def dequantization(x, **params):
18
+ return (torch.div(1, replace_num(params['_0'], num=0, to=10000)) * torch.sinh((params['_s'] * x)))
19
+
20
+
21
+ def init_space_search(
22
+ x: torch.Tensor,
23
+ **kwargs: Dict[str, Any],
24
+ ) -> torch.Tensor:
25
+
26
+ def _build_initial_param(tensor: torch.Tensor, max_initial: int, n_params: int):
27
+ """Generates the initial set of parameters. The first iteration generates 10 times more parameters."""
28
+ for _ in range(n_params * 10): # The first iteration generates 10 times more parameters
29
+ yield init_rand(tensor) * max_initial # Generates n_params in range [-max_initial, max_initial]
30
+
31
+ def _search_param(tensors: List[torch.tensor], n_params):
32
+ """Takes the best parameters and generates new parameters around the mean of the best parameters."""
33
+ torch_tensors = torch.stack(tensors)
34
+ min_vals, max_vals = torch.aminmax(torch_tensors, dim=0)
35
+ abs_max_val_per_ch = torch.max(-min_vals, max_vals)
36
+ mean = torch.mean(torch_tensors, dim=0)
37
+ for _ in range(n_params): # Generates n_params around the mean of the tensors
38
+ yield torch.randn_like(min_vals) * abs_max_val_per_ch + mean
39
+
40
+ def _calc(x, qtz_func, deqtz_func, **params):
41
+ x_ = x.transpose(0, 1)
42
+ x_ = qtz_func(x=x_, **params)
43
+ x_ = deqtz_func(x=x_, **params)
44
+ x_ = x_.transpose(0, 1)
45
+ return x_
46
+
47
+ assert "qtz_func" in kwargs, "qtz_func must be provided."
48
+ assert "deqtz_func" in kwargs, "deqtz_func must be provided."
49
+ assert "params_list" in kwargs, "params list must be provided."
50
+ assert "param" in kwargs, "param must be provided."
51
+
52
+ qtz_func = kwargs.get('qtz_func')
53
+ deqtz_func = kwargs.get('deqtz_func')
54
+ params_list = kwargs.get('params_list')
55
+ param = kwargs.get('param')
56
+
57
+ n_runs = 50 # Number of runs to try to find the best parameters
58
+ n_random_params = 50 # Number of random parameters to generate
59
+ n_best_to_pick = 5 # Number of best parameters to pick after each run
60
+ max_initial = 10000 # Maximum value to initialize the parameters
61
+
62
+ # Initializes the parameters
63
+ base_params = { p: init_ones(x, **kwargs) for p in params_list if p != param }
64
+ params = _build_initial_param(x, max_initial, n_random_params)
65
+
66
+ # Performs the search
67
+ for _ in range(n_runs):
68
+
69
+ best_params = []
70
+ for param_ in params:
71
+ try:
72
+ x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: param_})
73
+ loss_ones = nn.MSELoss()(x, x_)
74
+
75
+ if len(best_params) < n_best_to_pick:
76
+ best_params.append((param_, loss_ones.item()))
77
+ best_params = sorted(best_params, key=lambda x: x[1])
78
+ elif loss_ones < best_params[-1][1]:
79
+ best_params[-1] = (param_, loss_ones.item())
80
+ best_params = sorted(best_params, key=lambda x: x[1])
81
+
82
+ except Exception: # The parameters might not be valid for the function's domain
83
+ continue
84
+
85
+ # Generates new parameters around the mean
86
+ params = _search_param([p for p, _ in best_params], n_random_params)
87
+
88
+ return best_params[0][0]
89
+
90
+
91
+ def init_linear_scale( # Symmetric scale. From the study folder
92
+ x: torch.Tensor,
93
+ **kwargs: Dict[str, Any],
94
+ ) -> torch.Tensor:
95
+ assert "bits" in kwargs, "bits must be provided."
96
+ assert "params" in kwargs, "params must be provided."
97
+ assert "qtz_func" in kwargs, "qtz_func must be provided."
98
+
99
+ bits = kwargs.get('bits')
100
+ params = kwargs.get('params')
101
+ qtz_func = kwargs.get('qtz_func')
102
+
103
+ x_ = x.transpose(0, 1)
104
+ x_ = qtz_func(x=x_, **params, _s=init_ones(x, **kwargs))
105
+ x_ = x_.transpose(0, 1)
106
+
107
+ quant_min, quant_max = get_min_max_from_bits_signed(bits)
108
+ min_vals, max_vals = torch.aminmax(x_, dim=1)
109
+ min_vals = torch.min(min_vals, torch.zeros_like(min_vals))
110
+ max_vals = torch.max(max_vals, torch.zeros_like(max_vals))
111
+
112
+ eps = torch.finfo(torch.float32).eps
113
+
114
+ abs_max_val_per_ch = torch.max(-min_vals, max_vals)
115
+ scale = abs_max_val_per_ch / (float(quant_max - quant_min) / 2)
116
+
117
+ scale = torch.clamp(scale, min=eps).to(dtype=torch.float32, device=min_vals.device)
118
+
119
+ # Introduces some noise in scale
120
+ # If I don't introduce noise, the accuracy is going to be 0.0 and not learn anything
121
+ # scale = scale + 0.01 * torch.randn_like(scale)
122
+ return scale
123
+
124
+
125
+ def init_params(x: torch.Tensor, **kwargs: Dict[str, Any]) -> Dict[str, nn.Parameter]:
126
+ params = {
127
+ '_0': init_space_search(x, qtz_func=quantization, deqtz_func=dequantization, param='_0', params_list=['_0', '_s'], **kwargs),
128
+ }
129
+ params['_s'] = init_linear_scale(x, params=params, qtz_func=quantization, **kwargs)
130
+ params = {k: nn.Parameter(v, requires_grad=False) for k, v in params.items()}
131
+
132
+ if 'post_init_hook' in kwargs:
133
+ kwargs['post_init_hook'](parameters=params)
134
+
135
+
136
+ if 'post_train_hook' in kwargs:
137
+ kwargs['post_train_hook'](parameters=params)
138
+
139
+ return params
140
+
141
+
142
+ ############### Numpy Qtz ###############
143
+
144
+
145
+ def np_quantization(x, _0, _s):
146
+ return (np.divide(1, np_replace_num(_s, num=0, to=10000)) * np.arcsinh((_0 * x)))
147
+
148
+
149
+ def np_dequantization(x, _0, _s):
150
+ return (np.divide(1, np_replace_num(_0, num=0, to=10000)) * np.sinh((_s * x)))
151
+
152
+
153
+ def fit_func(x, _0, _s):
154
+ x_ = np_quantization(x, _0, _s)
155
+ x_ = np_dequantization(x_, _0, _s)
156
+ return x_
157
+
158
+
159
+
160
+ ############### HELPERS ###############
161
+
162
+ def domain_guard(
163
+ x: torch.Tensor,
164
+ min: float = None,
165
+ max: float = None,
166
+ posinf: float = None,
167
+ neginf: float = None,
168
+ nan: float = None
169
+ ) -> torch.Tensor:
170
+ """Guard a tensor to a valid domain."""
171
+ x = torch.nan_to_num(x, posinf=posinf, neginf=neginf, nan=nan)
172
+ if min is not None or max is not None:
173
+ x = torch.clamp(x, min=min, max=max)
174
+ return x
175
+
176
+
177
+ def replace_num(x: torch.Tensor, num: float, to: float) -> torch.Tensor:
178
+ """Replace a number in a tensor with another number.
179
+
180
+ Args:
181
+ x (torch.Tensor): The input tensor.
182
+ num (float): The number to replace.
183
+ to (float): The number to replace with.
184
+
185
+ Returns:
186
+ torch.Tensor: The tensor with the number replaced.
187
+ """
188
+ return torch.where(x == num, to, x)
189
+
190
+
191
+ def guarded_torch_power(x: torch.Tensor, exp: float) -> torch.Tensor:
192
+ """Guard the power operation to a valid domain."""
193
+ return torch.pow(x, exp) if exp >= 1 else torch.pow(torch.relu(x), exp)
194
+
195
+
196
+ def init_ones(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor:
197
+ val = torch.amin(x, dim=1)
198
+ return torch.ones_like(val, dtype=torch.float32, device=x.device)
199
+
200
+
201
+ def init_rand(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor:
202
+ val = torch.amin(x, dim=1)
203
+ return torch.randn_like(val, dtype=torch.float32, device=x.device)
204
+
205
+
206
+ def init_space_search(
207
+ x: torch.Tensor,
208
+ **kwargs: Dict[str, Any],
209
+ ) -> torch.Tensor:
210
+
211
+ def _build_initial_param(tensor: torch.Tensor, max_initial: int, n_params: int):
212
+ """Generates the initial set of parameters. The first iteration generates 10 times more parameters."""
213
+ for _ in range(n_params * 10): # The first iteration generates 10 times more parameters
214
+ yield init_rand(tensor) * max_initial # Generates n_params in range [-max_initial, max_initial]
215
+
216
+ def _search_param(tensors: List[torch.tensor], n_params):
217
+ """Takes the best parameters and generates new parameters around the mean of the best parameters."""
218
+ torch_tensors = torch.stack(tensors)
219
+ min_vals, max_vals = torch.aminmax(torch_tensors, dim=0)
220
+ abs_max_val_per_ch = torch.max(-min_vals, max_vals)
221
+ mean = torch.mean(torch_tensors, dim=0)
222
+ for _ in range(n_params): # Generates n_params around the mean of the tensors
223
+ yield torch.randn_like(min_vals) * abs_max_val_per_ch + mean
224
+
225
+ def _calc(x, qtz_func, deqtz_func, **params):
226
+ x_ = x.transpose(0, 1)
227
+ x_ = qtz_func(x=x_, **params)
228
+ x_ = deqtz_func(x=x_, **params)
229
+ x_ = x_.transpose(0, 1)
230
+ return x_
231
+
232
+ assert "qtz_func" in kwargs, "qtz_func must be provided."
233
+ assert "deqtz_func" in kwargs, "deqtz_func must be provided."
234
+ assert "params_list" in kwargs, "params list must be provided."
235
+ assert "param" in kwargs, "param must be provided."
236
+
237
+ qtz_func = kwargs.get('qtz_func')
238
+ deqtz_func = kwargs.get('deqtz_func')
239
+ params_list = kwargs.get('params_list')
240
+ param = kwargs.get('param')
241
+
242
+ n_runs = 50 # Number of runs to try to find the best parameters
243
+ n_random_params = 50 # Number of random parameters to generate
244
+ n_best_to_pick = 5 # Number of best parameters to pick after each run
245
+ max_initial = 10000 # Maximum value to initialize the parameters
246
+
247
+ # Initializes the parameters
248
+ base_params = { p: init_ones(x, **kwargs) for p in params_list if p != param }
249
+ params = _build_initial_param(x, max_initial, n_random_params)
250
+
251
+ # Performs the search
252
+ for _ in range(n_runs):
253
+
254
+ best_params = []
255
+ for param_ in params:
256
+ try:
257
+ x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: param_})
258
+ loss_ones = nn.MSELoss()(x, x_)
259
+
260
+ if len(best_params) < n_best_to_pick:
261
+ best_params.append((param_, loss_ones.item()))
262
+ best_params = sorted(best_params, key=lambda x: x[1])
263
+ elif loss_ones < best_params[-1][1]:
264
+ best_params[-1] = (param_, loss_ones.item())
265
+ best_params = sorted(best_params, key=lambda x: x[1])
266
+
267
+ except Exception: # The parameters might not be valid for the function's domain
268
+ continue
269
+
270
+ # Generates new parameters around the mean
271
+ params = _search_param([p for p, _ in best_params], n_random_params)
272
+
273
+ return best_params[0][0]
274
+
275
+
276
+ def init_linear_scale( # Symmetric scale. From the study folder
277
+ x: torch.Tensor,
278
+ **kwargs: Dict[str, Any],
279
+ ) -> torch.Tensor:
280
+ assert "bits" in kwargs, "bits must be provided."
281
+ assert "params" in kwargs, "params must be provided."
282
+ assert "qtz_func" in kwargs, "qtz_func must be provided."
283
+
284
+ bits = kwargs.get('bits')
285
+ params = kwargs.get('params')
286
+ qtz_func = kwargs.get('qtz_func')
287
+
288
+ x_ = x.transpose(0, 1)
289
+ x_ = qtz_func(x=x_, **params, _s=init_ones(x, **kwargs))
290
+ x_ = x_.transpose(0, 1)
291
+
292
+ quant_min, quant_max = get_min_max_from_bits_signed(bits)
293
+ min_vals, max_vals = torch.aminmax(x_, dim=1)
294
+ min_vals = torch.min(min_vals, torch.zeros_like(min_vals))
295
+ max_vals = torch.max(max_vals, torch.zeros_like(max_vals))
296
+
297
+ eps = torch.finfo(torch.float32).eps
298
+
299
+ abs_max_val_per_ch = torch.max(-min_vals, max_vals)
300
+ scale = abs_max_val_per_ch / (float(quant_max - quant_min) / 2)
301
+
302
+ scale = torch.clamp(scale, min=eps).to(dtype=torch.float32, device=min_vals.device)
303
+
304
+ # Introduces some noise in scale
305
+ # If I don't introduce noise, the accuracy is going to be 0.0 and not learn anything
306
+ # scale = scale + 0.01 * torch.randn_like(scale)
307
+ return scale
308
+
309
+
310
+ def init_non_linear_regression_fit(
311
+ x: torch.Tensor,
312
+ **kwargs: Dict[str, Any],
313
+ ) -> torch.Tensor:
314
+
315
+ assert "params_list" in kwargs, "params list must be provided."
316
+ assert "np_fit_func" in kwargs, "np_fit_func must be provided."
317
+ assert "p0" in kwargs, "p0 must be provided."
318
+ np_fit_func = kwargs.get('np_fit_func')
319
+ params_list = kwargs.get('params_list')
320
+ p0 = kwargs.get('p0')
321
+
322
+ def _fit(xdata: np.ndarray, ydata: np.ndarray, func: Callable, p0: List[float]):
323
+ popt, _ = curve_fit(
324
+ func,
325
+ xdata,
326
+ ydata,
327
+ maxfev=1000,
328
+ p0=p0,
329
+ method='lm'
330
+ )
331
+ return popt
332
+
333
+ # 1. Needs to convert the torch tensor to numpy tensor
334
+ xdata = x.cpu().numpy()
335
+
336
+ # 2. Sorts the data so that it makes it easier to fit to it
337
+ sorted_xdata = np.sort(xdata, axis=-1)
338
+
339
+ p0 = {k: v.cpu().numpy() for k, v in p0.items()}
340
+ params_list = sorted(params_list) # We need to make sure that it matches the numpy fit func arg order
341
+
342
+ # 3. Finds the best parameters for each channel
343
+ try:
344
+ params = []
345
+ for i in range(sorted_xdata.shape[0]):
346
+ xdata_ = sorted_xdata[i]
347
+ p0_ = [p0[p][i] for p in params_list]
348
+ ch_params = _fit(xdata_, xdata_, np_fit_func, p0_)
349
+ params.append(ch_params)
350
+
351
+ # 4. Builds the parameters
352
+ result = {}
353
+ for i, p in enumerate(params_list):
354
+ result[p] = torch.tensor([p_[i] for p_ in params], dtype=torch.float32).to(x.device)
355
+
356
+ return result
357
+
358
+ except ValueError as e:
359
+ print(f"Could not fit the function with error: {e}")
360
+ print(f"Using fallback result...")
361
+ return {
362
+ k: torch.tensor(v, dtype=torch.float32).to(x.device) for k, v in p0.items()
363
+ }
364
+
365
+
366
+ def init_zeros(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor:
367
+ val = torch.amin(x, dim=1)
368
+ return torch.zeros_like(val, dtype=torch.float32, device=x.device)
369
+
370
+
371
+ def init_inner_scale(tensor: torch.Tensor, _min: float = torch.inf, _max: float = torch.inf) -> torch.Tensor:
372
+ # Calculate the original minimum and maximum values
373
+ min_vals, max_vals = torch.aminmax(tensor, dim=-1)
374
+ x_min = torch.min(min_vals, torch.zeros_like(min_vals))
375
+ x_max = torch.max(max_vals, torch.zeros_like(max_vals))
376
+
377
+ if _max is torch.inf: # We do not need to scale the tensor. Just need to move it
378
+ return torch.ones_like(x_min)
379
+
380
+ # Calculate the scale factor
381
+ scale = (_max - _min) / (x_max - x_min)
382
+ return scale
383
+
384
+
385
+
386
+ ############## Quant ###############
387
+
388
+ @torch.enable_grad()
389
+ def learn_parameters(
390
+ x: torch.Tensor,
391
+ params: Dict[str, nn.Parameter],
392
+ qtz_func: nn.Module,
393
+ deqtz_func: nn.Module,
394
+ bits: int,
395
+ target_dtype: torch.dtype,
396
+ epochs: int = 1000,
397
+ early_stop: bool = True,
398
+ do_report: bool = False
399
+ ) -> Tuple[Dict[str, nn.Parameter], torch.Tensor]:
400
+ loss_fn = nn.MSELoss()
401
+
402
+ # Determines the initial learning rate by computing the initial loss and multiplying it by
403
+ # the order of magnitude of the loss divided by 2
404
+ quant = quantize(x, params, qtz_func, bits, target_dtype)
405
+ dequant = dequantize(quant, params, deqtz_func, bits, x.dtype)
406
+ loss = loss_fn(x, dequant)
407
+
408
+ base_lr = 0.1
409
+ exponent = int(np.floor(np.log10(loss.item())))
410
+ lr = base_lr * (10 ** (exponent // 2))
411
+
412
+ # Requires gradients in the parameters
413
+ for p in params.values():
414
+ p.requires_grad = True
415
+ p.grad = None
416
+
417
+ param_keys = list(params.keys())
418
+ param_values = list(params.values())
419
+
420
+ # Defines optimizer and loss function
421
+ optimizer = torch.optim.Adam(param_values, lr=lr)
422
+ scheduler = torch.optim.lr_scheduler.LinearLR(optimizer, start_factor=1.0, end_factor=0.01, total_iters=epochs // 10)
423
+
424
+ # Contains the best loss and the best parameters
425
+ best_loss = float("inf")
426
+ best_params = None
427
+
428
+ # Used to stop the search early
429
+ min_delta = 1e-7
430
+ acc_loss = []
431
+ percent_epochs_before_stop = 0.1
432
+
433
+ for i in range(epochs):
434
+ optimizer.zero_grad()
435
+
436
+ quant = quantize(x, params, qtz_func, bits, target_dtype)
437
+ dequant = dequantize(quant, params, deqtz_func, bits, x.dtype)
438
+ loss = loss_fn(x, dequant)
439
+
440
+ if loss.isnan() or loss.isinf():
441
+ raise Exception("Loss is NaN or Inf. Stopping the search.")
442
+
443
+ loss.backward()
444
+ optimizer.step()
445
+ scheduler.step()
446
+
447
+ acc_loss.append(loss.item())
448
+
449
+ # Reports loss every 10 steps
450
+ if i % 10 == 0 and do_report:
451
+ print(f"Epoch {i}: Loss {loss.item()}")
452
+
453
+ # Optimizes the parameter search by storing the best loss and the parameters
454
+ if loss.item() < best_loss:
455
+ best_loss = loss.item()
456
+ best_params = copy.deepcopy({
457
+ k: v for k, v in params.items() if k in param_keys
458
+ })
459
+
460
+ # We also stop the search if the loss has not changed considerably during the last 10% of epochs
461
+ if early_stop:
462
+ epochs_before_stop = int(epochs * percent_epochs_before_stop)
463
+ if i > epochs_before_stop and abs(acc_loss[i - epochs_before_stop] - acc_loss[i]) < min_delta:
464
+ break
465
+
466
+ # No longer requires gradients in the parameters
467
+ for p in best_params.values():
468
+ p.requires_grad = False
469
+ p.grad = None
470
+
471
+ if do_report:
472
+ return best_params, acc_loss
473
+ else:
474
+ return best_params
475
+
476
+
477
+ def quantize(
478
+ x: torch.Tensor,
479
+ params: Dict[str, nn.Parameter],
480
+ func: nn.Module,
481
+ bits: int,
482
+ target_dtype: torch.dtype = torch.int8
483
+ ) -> torch.Tensor:
484
+ quant_min, quant_max = get_min_max_from_bits_signed(bits)
485
+ x = x.transpose(0, 1) # Aligns shapes
486
+ x = func(x=x, **params)
487
+ x = x.transpose(0, 1)
488
+ x = torch.clamp(round_func_BPDA(x), quant_min, quant_max).to(target_dtype)
489
+ return x
490
+
491
+
492
+ def dequantize(
493
+ x: torch.Tensor,
494
+ params: Dict[str, nn.Parameter],
495
+ func: nn.Module,
496
+ bits: int,
497
+ out_dtype: torch.dtype
498
+ ) -> torch.Tensor:
499
+ x = x.to(dtype=out_dtype)
500
+ x = x.transpose(0, 1)
501
+ x = func(x=x, **params)
502
+ x = x.transpose(0, 1)
503
+ return x
504
+
505
+
506
+ def round_func_BPDA(input):
507
+ # This is equivalent to replacing round function (non-differentiable) with
508
+ # an identity function (differentiable) only when backward.
509
+ forward_value = torch.round(input)
510
+ out = input.clone()
511
+ out.data = forward_value.data
512
+ return out
513
+
514
+
515
+ def get_min_max_from_bits_signed(bit_width: int) -> Tuple[int, int]:
516
+ return -2 ** (bit_width - 1), 2 ** (bit_width - 1) - 1
517
+
518
+
519
+
520
+ ############## Numpy ###############
521
+
522
+ def np_domain_guard(
523
+ x: np.ndarray,
524
+ min: float = None,
525
+ max: float = None,
526
+ posinf: float = None,
527
+ neginf: float = None,
528
+ nan: float = None
529
+ ) -> np.ndarray:
530
+ """Guard a tensor to a valid domain."""
531
+ x = np.nan_to_num(x, posinf=posinf, neginf=neginf, nan=nan)
532
+ if min is not None or max is not None:
533
+ x = np.clip(x, min, max)
534
+ return x
535
+
536
+
537
+ def np_replace_num(x: np.ndarray, num: float, to: float) -> np.ndarray:
538
+ """Replace a number in a tensor with another number.
539
+
540
+ Args:
541
+ x (np.ndarray): The input tensor.
542
+ num (float): The number to replace.
543
+ to (float): The number to replace with.
544
+
545
+ Returns:
546
+ np.ndarray: The tensor with the number replaced.
547
+ """
548
+ return np.where(x == num, to, x)
549
+
550
+
551
+ def np_guarded_power(x: np.ndarray, exp: float) -> np.ndarray:
552
+ """Guard the power operation to a valid domain."""
553
+ return np.power(x, exp) if exp >= 1 else np.power(np.maximum(x, 0), exp)
554
+
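round_func_BPDA in the file above is a straight-through estimator: the forward pass uses the rounded value while the backward pass sees the identity, so gradients flow through the quantizer. A small check of that behaviour (illustrative, reusing the same function body):

import torch

def round_func_BPDA(input):
    forward_value = torch.round(input)
    out = input.clone()
    out.data = forward_value.data
    return out

x = torch.tensor([0.3, 1.7], requires_grad=True)
y = round_func_BPDA(x)
print(y)            # rounded values in the forward pass
y.sum().backward()
print(x.grad)       # tensor([1., 1.]): identity gradient, as if no rounding happened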
fn_gen/rnd_search/11/loss.png ADDED
fn_gen/rnd_search/11/quantization.png ADDED
fn_gen/rnd_search/12/__pycache__/fn.cpython-311.pyc ADDED
Binary file (27.3 kB). View file
 
fn_gen/rnd_search/12/distortion.png ADDED
fn_gen/rnd_search/12/expressions.txt ADDED
@@ -0,0 +1,2 @@
1
+ acos(_0*x)/_s
2
+ cos(_s*x)/_0
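Unlike the asinh/sinh pair of folder 11, acos is only defined on [-1, 1], so the fn.py below clamps its argument with domain_guard before applying torch.acos. A tiny illustration of why (values made up):

import torch
x = torch.tensor([-2.0, 0.5, 2.0])                      # _0 * x can leave [-1, 1]
print(torch.acos(x))                                    # nan outside the domain
print(torch.acos(torch.clamp(x, -0.99999, 0.99999)))    # guarded, as domain_guard does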
fn_gen/rnd_search/12/fn.py ADDED
@@ -0,0 +1,554 @@
1
+ from __future__ import annotations
2
+
3
+ import torch
4
+ from torch import amin # Necessary for arcsin
5
+ import copy
6
+ import torch.nn as nn
7
+ import numpy as np
8
+
9
+ from scipy.optimize import curve_fit
10
+ from typing import Dict, Any, Tuple, List, Callable
11
+
12
+
13
+ def quantization(x, **params):
14
+ return (torch.div(1, replace_num(params['_s'], num=0, to=10000)) * torch.acos(domain_guard((params['_0'] * x), min=-0.99999, max=0.99999, nan=0)))
15
+
16
+
17
+ def dequantization(x, **params):
18
+ return (torch.div(1, replace_num(params['_0'], num=0, to=10000)) * torch.cos((params['_s'] * x)))
19
+
20
+
21
+ def init_space_search(
22
+ x: torch.Tensor,
23
+ **kwargs: Dict[str, Any],
24
+ ) -> torch.Tensor:
25
+
26
+ def _build_initial_param(tensor: torch.Tensor, max_initial: int, n_params: int):
27
+ """Generates the initial set of parameters. The first iteration generates 10 times more parameters."""
28
+ for _ in range(n_params * 10): # The first iteration generates 10 times more parameters
29
+ yield init_rand(tensor) * max_initial # Generates n_params in range [-max_initial, max_initial]
30
+
31
+ def _search_param(tensors: List[torch.tensor], n_params):
32
+ """Takes the best parameters and generates new parameters around the mean of the best parameters."""
33
+ torch_tensors = torch.stack(tensors)
34
+ min_vals, max_vals = torch.aminmax(torch_tensors, dim=0)
35
+ abs_max_val_per_ch = torch.max(-min_vals, max_vals)
36
+ mean = torch.mean(torch_tensors, dim=0)
37
+ for _ in range(n_params): # Generates n_params around the mean of the tensors
38
+ yield torch.randn_like(min_vals) * abs_max_val_per_ch + mean
39
+
40
+ def _calc(x, qtz_func, deqtz_func, **params):
41
+ x_ = x.transpose(0, 1)
42
+ x_ = qtz_func(x=x_, **params)
43
+ x_ = deqtz_func(x=x_, **params)
44
+ x_ = x_.transpose(0, 1)
45
+ return x_
46
+
47
+ assert "qtz_func" in kwargs, "qtz_func must be provided."
48
+ assert "deqtz_func" in kwargs, "deqtz_func must be provided."
49
+ assert "params_list" in kwargs, "params list must be provided."
50
+ assert "param" in kwargs, "param must be provided."
51
+
52
+ qtz_func = kwargs.get('qtz_func')
53
+ deqtz_func = kwargs.get('deqtz_func')
54
+ params_list = kwargs.get('params_list')
55
+ param = kwargs.get('param')
56
+
57
+ n_runs = 50 # Number of runs to try to find the best parameters
58
+ n_random_params = 50 # Number of random parameters to generate
59
+ n_best_to_pick = 5 # Number of best parameters to pick after each run
60
+ max_initial = 10000 # Maximum value to initialize the parameters
61
+
62
+ # Initializes the parameters
63
+ base_params = { p: init_ones(x, **kwargs) for p in params_list if p != param }
64
+ params = _build_initial_param(x, max_initial, n_random_params)
65
+
66
+ # Performs the search
67
+ for _ in range(n_runs):
68
+
69
+ best_params = []
70
+ for param_ in params:
71
+ try:
72
+ x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: param_})
73
+ loss_ones = nn.MSELoss()(x, x_)
74
+
75
+ if len(best_params) < n_best_to_pick:
76
+ best_params.append((param_, loss_ones.item()))
77
+ best_params = sorted(best_params, key=lambda x: x[1])
78
+ elif loss_ones < best_params[-1][1]:
79
+ best_params[-1] = (param_, loss_ones.item())
80
+ best_params = sorted(best_params, key=lambda x: x[1])
81
+
82
+ except Exception: # The parameters might not be valid for the function's domain
83
+ continue
84
+
85
+ # Generates new parameters around the mean
86
+ params = _search_param([p for p, _ in best_params], n_random_params)
87
+
88
+ return best_params[0][0]
89
+
90
+
91
+ def init_linear_scale( # Symmetric scale. From the study folder
92
+ x: torch.Tensor,
93
+ **kwargs: Dict[str, Any],
94
+ ) -> torch.Tensor:
95
+ assert "bits" in kwargs, "bits must be provided."
96
+ assert "params" in kwargs, "params must be provided."
97
+ assert "qtz_func" in kwargs, "qtz_func must be provided."
98
+
99
+ bits = kwargs.get('bits')
100
+ params = kwargs.get('params')
101
+ qtz_func = kwargs.get('qtz_func')
102
+
103
+ x_ = x.transpose(0, 1)
104
+ x_ = qtz_func(x=x_, **params, _s=init_ones(x, **kwargs))
105
+ x_ = x_.transpose(0, 1)
106
+
107
+ quant_min, quant_max = get_min_max_from_bits_signed(bits)
108
+ min_vals, max_vals = torch.aminmax(x_, dim=1)
109
+ min_vals = torch.min(min_vals, torch.zeros_like(min_vals))
110
+ max_vals = torch.max(max_vals, torch.zeros_like(max_vals))
111
+
112
+ eps = torch.finfo(torch.float32).eps
113
+
114
+ abs_max_val_per_ch = torch.max(-min_vals, max_vals)
115
+ scale = abs_max_val_per_ch / (float(quant_max - quant_min) / 2)
116
+
117
+ scale = torch.clamp(scale, min=eps).to(dtype=torch.float32, device=min_vals.device)
118
+
119
+ # Introduces some noise in scale
120
+ # If I don't introduce noise, the accuracy is going to be 0.0 and not learn anything
121
+ # scale = scale + 0.01 * torch.randn_like(scale)
122
+ return scale
123
+
124
+
125
+ def init_params(x: torch.Tensor, **kwargs: Dict[str, Any]) -> Dict[str, nn.Parameter]:
126
+ params = {
127
+ '_0': init_space_search(x, qtz_func=quantization, deqtz_func=dequantization, param='_0', params_list=['_0', '_s'], **kwargs),
128
+ }
129
+ params['_s'] = init_linear_scale(x, params=params, qtz_func=quantization, **kwargs)
130
+ params = {k: nn.Parameter(v, requires_grad=False) for k, v in params.items()}
131
+
132
+ if 'post_init_hook' in kwargs:
133
+ kwargs['post_init_hook'](parameters=params)
134
+
135
+
136
+ if 'post_train_hook' in kwargs:
137
+ kwargs['post_train_hook'](parameters=params)
138
+
139
+ return params
140
+
141
+
142
+ ############### Numpy Qtz ###############
143
+
144
+
145
+ def np_quantization(x, _0, _s):
146
+ return (np.divide(1, np_replace_num(_s, num=0, to=10000)) * np.arccos(np_domain_guard((_0 * x), min=-0.99999, max=0.99999, nan=0)))
147
+
148
+
149
+ def np_dequantization(x, _0, _s):
150
+ return (np.divide(1, np_replace_num(_0, num=0, to=10000)) * np.cos((_s * x)))
151
+
152
+
153
+ def fit_func(x, _0, _s):
154
+ x_ = np_quantization(x, _0, _s)
155
+ x_ = np_dequantization(x_, _0, _s)
156
+ return x_
157
+
158
+
159
+
160
+ ############### HELPERS ###############
161
+
162
+ def domain_guard(
163
+ x: torch.Tensor,
164
+ min: float = None,
165
+ max: float = None,
166
+ posinf: float = None,
167
+ neginf: float = None,
168
+ nan: float = None
169
+ ) -> torch.Tensor:
170
+ """Guard a tensor to a valid domain."""
171
+ x = torch.nan_to_num(x, posinf=posinf, neginf=neginf, nan=nan)
172
+ if min is not None or max is not None:
173
+ x = torch.clamp(x, min=min, max=max)
174
+ return x
175
+
176
+
177
+ def replace_num(x: torch.Tensor, num: float, to: float) -> torch.Tensor:
178
+ """Replace a number in a tensor with another number.
179
+
180
+ Args:
181
+ x (torch.Tensor): The input tensor.
182
+ num (float): The number to replace.
183
+ to (float): The number to replace with.
184
+
185
+ Returns:
186
+ torch.Tensor: The tensor with the number replaced.
187
+ """
188
+ return torch.where(x == num, to, x)
189
+
190
+
191
+ def guarded_torch_power(x: torch.Tensor, exp: float) -> torch.Tensor:
192
+ """Guard the power operation to a valid domain."""
193
+ return torch.pow(x, exp) if exp >= 1 else torch.pow(torch.relu(x), exp)
194
+
195
+
196
+ def init_ones(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor:
197
+ val = torch.amin(x, dim=1)
198
+ return torch.ones_like(val, dtype=torch.float32, device=x.device)
199
+
200
+
201
+ def init_rand(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor:
202
+ val = torch.amin(x, dim=1)
203
+ return torch.randn_like(val, dtype=torch.float32, device=x.device)
204
+
205
+
206
+ def init_space_search(
207
+ x: torch.Tensor,
208
+ **kwargs: Dict[str, Any],
209
+ ) -> torch.Tensor:
210
+
211
+ def _build_initial_param(tensor: torch.Tensor, max_initial: int, n_params: int):
212
+ """Generates the initial set of parameters. The first iteration generates 10 times more parameters."""
213
+ for _ in range(n_params * 10): # The first iteration generates 10 times more parameters
214
+ yield init_rand(tensor) * max_initial # Generates n_params in range [-max_initial, max_initial]
215
+
216
+ def _search_param(tensors: List[torch.tensor], n_params):
217
+ """Takes the best parameters and generates new parameters around the mean of the best parameters."""
218
+ torch_tensors = torch.stack(tensors)
219
+ min_vals, max_vals = torch.aminmax(torch_tensors, dim=0)
220
+ abs_max_val_per_ch = torch.max(-min_vals, max_vals)
221
+ mean = torch.mean(torch_tensors, dim=0)
222
+ for _ in range(n_params): # Generates n_params around the mean of the tensors
223
+ yield torch.randn_like(min_vals) * abs_max_val_per_ch + mean
224
+
225
+ def _calc(x, qtz_func, deqtz_func, **params):
226
+ x_ = x.transpose(0, 1)
227
+ x_ = qtz_func(x=x_, **params)
228
+ x_ = deqtz_func(x=x_, **params)
229
+ x_ = x_.transpose(0, 1)
230
+ return x_
231
+
232
+ assert "qtz_func" in kwargs, "qtz_func must be provided."
233
+ assert "deqtz_func" in kwargs, "deqtz_func must be provided."
234
+ assert "params_list" in kwargs, "params list must be provided."
235
+ assert "param" in kwargs, "param must be provided."
236
+
237
+ qtz_func = kwargs.get('qtz_func')
238
+ deqtz_func = kwargs.get('deqtz_func')
239
+ params_list = kwargs.get('params_list')
240
+ param = kwargs.get('param')
241
+
242
+ n_runs = 50 # Number of runs to try to find the best parameters
243
+ n_random_params = 50 # Number of random parameters to generate
244
+ n_best_to_pick = 5 # Number of best parameters to pick after each run
245
+ max_initial = 10000 # Maximum value to initialize the parameters
246
+
247
+ # Initializes the parameters
248
+ base_params = { p: init_ones(x, **kwargs) for p in params_list if p != param }
249
+ params = _build_initial_param(x, max_initial, n_random_params)
250
+
251
+ # Performs the search
252
+ for _ in range(n_runs):
253
+
254
+ best_params = []
255
+ for param_ in params:
256
+ try:
257
+ x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: param_})
258
+ loss_ones = nn.MSELoss()(x, x_)
259
+
260
+ if len(best_params) < n_best_to_pick:
261
+ best_params.append((param_, loss_ones.item()))
262
+ best_params = sorted(best_params, key=lambda x: x[1])
263
+ elif loss_ones < best_params[-1][1]:
264
+ best_params[-1] = (param_, loss_ones.item())
265
+ best_params = sorted(best_params, key=lambda x: x[1])
266
+
267
+ except Exception: # The parameters might not be valid for the function's domain
268
+ continue
269
+
270
+ # Generates new parameters around the mean
271
+ params = _search_param([p for p, _ in best_params], n_random_params)
272
+
273
+ return best_params[0][0]
274
+
275
+
276
+ def init_linear_scale( # Symmetric scale. From the study folder
277
+ x: torch.Tensor,
278
+ **kwargs: Dict[str, Any],
279
+ ) -> torch.Tensor:
280
+ assert "bits" in kwargs, "bits must be provided."
281
+ assert "params" in kwargs, "params must be provided."
282
+ assert "qtz_func" in kwargs, "qtz_func must be provided."
283
+
284
+ bits = kwargs.get('bits')
285
+ params = kwargs.get('params')
286
+ qtz_func = kwargs.get('qtz_func')
287
+
288
+ x_ = x.transpose(0, 1)
289
+ x_ = qtz_func(x=x_, **params, _s=init_ones(x, **kwargs))
290
+ x_ = x_.transpose(0, 1)
291
+
292
+ quant_min, quant_max = get_min_max_from_bits_signed(bits)
293
+ min_vals, max_vals = torch.aminmax(x_, dim=1)
294
+ min_vals = torch.min(min_vals, torch.zeros_like(min_vals))
295
+ max_vals = torch.max(max_vals, torch.zeros_like(max_vals))
296
+
297
+ eps = torch.finfo(torch.float32).eps
298
+
299
+ abs_max_val_per_ch = torch.max(-min_vals, max_vals)
300
+ scale = abs_max_val_per_ch / (float(quant_max - quant_min) / 2)
301
+
302
+ scale = torch.clamp(scale, min=eps).to(dtype=torch.float32, device=min_vals.device)
303
+
304
+ # Introduces some noise in scale
305
+ # If I don't introduce noise, the accuracy is going to be 0.0 and not learn anything
306
+ # scale = scale + 0.01 * torch.randn_like(scale)
307
+ return scale
308
+
309
+
310
+ def init_non_linear_regression_fit(
311
+ x: torch.Tensor,
312
+ **kwargs: Dict[str, Any],
313
+ ) -> torch.Tensor:
314
+
315
+ assert "params_list" in kwargs, "params list must be provided."
316
+ assert "np_fit_func" in kwargs, "np_fit_func must be provided."
317
+ assert "p0" in kwargs, "p0 must be provided."
318
+ np_fit_func = kwargs.get('np_fit_func')
319
+ params_list = kwargs.get('params_list')
320
+ p0 = kwargs.get('p0')
321
+
322
+ def _fit(xdata: np.ndarray, ydata: np.ndarray, func: Callable, p0: List[float]):
323
+ popt, _ = curve_fit(
324
+ func,
325
+ xdata,
326
+ ydata,
327
+ maxfev=1000,
328
+ p0=p0,
329
+ method='lm'
330
+ )
331
+ return popt
332
+
333
+ # 1. Needs to convert the torch tensor to numpy tensor
334
+ xdata = x.cpu().numpy()
335
+
336
+ # 2. Sorts the data so that it makes it easier to fit to it
337
+ sorted_xdata = np.sort(xdata, axis=-1)
338
+
339
+ p0 = {k: v.cpu().numpy() for k, v in p0.items()}
340
+ params_list = sorted(params_list) # We need to make sure that it matches the numpy fit func arg order
341
+
342
+ # 3. Finds the best parameters for each channel
343
+ try:
344
+ params = []
345
+ for i in range(sorted_xdata.shape[0]):
346
+ xdata_ = sorted_xdata[i]
347
+ p0_ = [p0[p][i] for p in params_list]
348
+ ch_params = _fit(xdata_, xdata_, np_fit_func, p0_)
349
+ params.append(ch_params)
350
+
351
+ # 4. Builds the parameters
352
+ result = {}
353
+ for i, p in enumerate(params_list):
354
+ result[p] = torch.tensor([p_[i] for p_ in params], dtype=torch.float32).to(x.device)
355
+
356
+ return result
357
+
358
+ except ValueError as e:
359
+ print(f"Could not fit the function with error: {e}")
360
+ print(f"Using fallback result...")
361
+ return {
362
+ k: torch.tensor(v, dtype=torch.float32).to(x.device) for k, v in p0.items()
363
+ }
364
+
365
+
366
+ def init_zeros(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor:
367
+ val = torch.amin(x, dim=1)
368
+ return torch.zeros_like(val, dtype=torch.float32, device=x.device)
369
+
370
+
371
+ def init_inner_scale(tensor: torch.Tensor, _min: float = torch.inf, _max: float = torch.inf) -> torch.Tensor:
372
+ # Calculate the original minimum and maximum values
373
+ min_vals, max_vals = torch.aminmax(tensor, dim=-1)
374
+ x_min = torch.min(min_vals, torch.zeros_like(min_vals))
375
+ x_max = torch.max(max_vals, torch.zeros_like(max_vals))
376
+
377
+ if _max is torch.inf: # We do not need to scale the tensor. Just need to move it
378
+ return torch.ones_like(x_min)
379
+
380
+ # Calculate the scale factor
381
+ scale = (_max - _min) / (x_max - x_min)
382
+ return scale
383
+
384
+
385
+
386
+ ############## Quant ###############
387
+
388
+ @torch.enable_grad()
389
+ def learn_parameters(
390
+ x: torch.Tensor,
391
+ params: Dict[str, nn.Parameter],
392
+ qtz_func: nn.Module,
393
+ deqtz_func: nn.Module,
394
+ bits: int,
395
+ target_dtype: torch.dtype,
396
+ epochs: int = 1000,
397
+ early_stop: bool = True,
398
+ do_report: bool = False
399
+ ) -> Tuple[Dict[str, nn.Parameter], torch.Tensor]:
400
+ loss_fn = nn.MSELoss()
401
+
402
+ # Determines the initial learning rate by computing the initial loss and multiplying it by
403
+ # the order of magnitude of the loss divided by 2
404
+ quant = quantize(x, params, qtz_func, bits, target_dtype)
405
+ dequant = dequantize(quant, params, deqtz_func, bits, x.dtype)
406
+ loss = loss_fn(x, dequant)
407
+
408
+ base_lr = 0.1
409
+ exponent = int(np.floor(np.log10(loss.item())))
410
+ lr = base_lr * (10 ** (exponent // 2))
411
+
412
+ # Requires gradients in the parameters
413
+ for p in params.values():
414
+ p.requires_grad = True
415
+ p.grad = None
416
+
417
+ param_keys = list(params.keys())
418
+ param_values = list(params.values())
419
+
420
+ # Defines optimizer and loss function
421
+ optimizer = torch.optim.Adam(param_values, lr=lr)
422
+ scheduler = torch.optim.lr_scheduler.LinearLR(optimizer, start_factor=1.0, end_factor=0.01, total_iters=epochs // 10)
423
+
424
+ # Contains the best loss and the best parameters
425
+ best_loss = float("inf")
426
+ best_params = None
427
+
428
+ # Used to stop the search early
429
+ min_delta = 1e-7
430
+ acc_loss = []
431
+ percent_epochs_before_stop = 0.1
432
+
433
+ for i in range(epochs):
434
+ optimizer.zero_grad()
435
+
436
+ quant = quantize(x, params, qtz_func, bits, target_dtype)
437
+ dequant = dequantize(quant, params, deqtz_func, bits, x.dtype)
438
+ loss = loss_fn(x, dequant)
439
+
440
+ if loss.isnan() or loss.isinf():
441
+ raise Exception("Loss is NaN or Inf. Stopping the search.")
442
+
443
+ loss.backward()
444
+ optimizer.step()
445
+ scheduler.step()
446
+
447
+ acc_loss.append(loss.item())
448
+
449
+ # Reports loss every 10 steps
450
+ if i % 10 == 0 and do_report:
451
+ print(f"Epoch {i}: Loss {loss.item()}")
452
+
453
+ # Optimizes the parameter search by storing the best loss and the parameters
454
+ if loss.item() < best_loss:
455
+ best_loss = loss.item()
456
+ best_params = copy.deepcopy({
457
+ k: v for k, v in params.items() if k in param_keys
458
+ })
459
+
460
+ # We also stop the search if the loss has not changed considerably during the last 10% of epochs
461
+ if early_stop:
462
+ epochs_before_stop = int(epochs * percent_epochs_before_stop)
463
+ if i > epochs_before_stop and abs(acc_loss[i - epochs_before_stop] - acc_loss[i]) < min_delta:
464
+ break
465
+
466
+ # No longer requires gradients in the parameters
467
+ for p in best_params.values():
468
+ p.requires_grad = False
469
+ p.grad = None
470
+
471
+ if do_report:
472
+ return best_params, acc_loss
473
+ else:
474
+ return best_params
475
+
476
+
477
+ def quantize(
478
+ x: torch.Tensor,
479
+ params: Dict[str, nn.Parameter],
480
+ func: nn.Module,
481
+ bits: int,
482
+ target_dtype: torch.dtype = torch.int8
483
+ ) -> torch.Tensor:
484
+ quant_min, quant_max = get_min_max_from_bits_signed(bits)
485
+ x = x.transpose(0, 1) # Aligns shapes
486
+ x = func(x=x, **params)
487
+ x = x.transpose(0, 1)
488
+ x = torch.clamp(round_func_BPDA(x), quant_min, quant_max).to(target_dtype)
489
+ return x
490
+
491
+
492
+ def dequantize(
493
+ x: torch.Tensor,
494
+ params: Dict[str, nn.Parameter],
495
+ func: nn.Module,
496
+ bits: int,
497
+ out_dtype: torch.dtype
498
+ ) -> torch.Tensor:
499
+ x = x.to(dtype=out_dtype)
500
+ x = x.transpose(0, 1)
501
+ x = func(x=x, **params)
502
+ x = x.transpose(0, 1)
503
+ return x
504
+
505
+
506
+ def round_func_BPDA(input):
507
+ # This is equivalent to replacing round function (non-differentiable) with
508
+ # an identity function (differentiable) only when backward.
509
+ forward_value = torch.round(input)
510
+ out = input.clone()
511
+ out.data = forward_value.data
512
+ return out
513
+
514
+
515
+ def get_min_max_from_bits_signed(bit_width: int) -> Tuple[int, int]:
516
+ return -2 ** (bit_width - 1), 2 ** (bit_width - 1) - 1
517
+
518
+
519
+
520
+ ############## Numpy ###############
521
+
522
+ def np_domain_guard(
523
+ x: np.ndarray,
524
+ min: float = None,
525
+ max: float = None,
526
+ posinf: float = None,
527
+ neginf: float = None,
528
+ nan: float = None
529
+ ) -> np.ndarray:
530
+ """Guard a tensor to a valid domain."""
531
+ x = np.nan_to_num(x, posinf=posinf, neginf=neginf, nan=nan)
532
+ if min is not None or max is not None:
533
+ x = np.clip(x, min, max)
534
+ return x
535
+
536
+
537
+ def np_replace_num(x: np.ndarray, num: float, to: float) -> np.ndarray:
538
+ """Replace a number in a tensor with another number.
539
+
540
+ Args:
541
+ x (np.ndarray): The input tensor.
542
+ num (float): The number to replace.
543
+ to (float): The number to replace with.
544
+
545
+ Returns:
546
+ np.ndarray: The tensor with the number replaced.
547
+ """
548
+ return np.where(x == num, to, x)
549
+
550
+
551
+ def np_guarded_power(x: np.ndarray, exp: float) -> np.ndarray:
552
+ """Guard the power operation to a valid domain."""
553
+ return np.power(x, exp) if exp >= 1 else np.power(np.maximum(x, 0), exp)
554
+
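init_linear_scale above derives a symmetric per-channel scale from the absolute maximum of the already-transformed tensor. A small numeric illustration of that computation (the input values are made up):

import torch
quant_min, quant_max = -128, 127                      # get_min_max_from_bits_signed(8)
x_ = torch.tensor([[-6.0, 2.0], [0.5, 3.0]])          # output of the qtz function, one row per channel
min_vals, max_vals = torch.aminmax(x_, dim=1)
min_vals = torch.min(min_vals, torch.zeros_like(min_vals))
max_vals = torch.max(max_vals, torch.zeros_like(max_vals))
abs_max_val_per_ch = torch.max(-min_vals, max_vals)
scale = abs_max_val_per_ch / (float(quant_max - quant_min) / 2)
print(scale)  # tensor([0.0471, 0.0235]) -- one scale per channel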
fn_gen/rnd_search/12/loss.png ADDED
fn_gen/rnd_search/12/quantization.png ADDED
fn_gen/rnd_search/13/__pycache__/fn.cpython-311.pyc ADDED
Binary file (27.2 kB). View file
 
fn_gen/rnd_search/13/distortion.png ADDED
fn_gen/rnd_search/13/expressions.txt ADDED
@@ -0,0 +1,2 @@
1
+ acosh(_0*x)/_s
2
+ cosh(_s*x)/_0
fn_gen/rnd_search/13/fn.py ADDED
@@ -0,0 +1,554 @@
1
+ from __future__ import annotations
2
+
3
+ import torch
4
+ from torch import amin # Necessary for arcsin
5
+ import copy
6
+ import torch.nn as nn
7
+ import numpy as np
8
+
9
+ from scipy.optimize import curve_fit
10
+ from typing import Dict, Any, Tuple, List, Callable
11
+
12
+
13
+ def quantization(x, **params):
14
+ return (torch.div(1, replace_num(params['_s'], num=0, to=10000)) * torch.acosh(domain_guard((params['_0'] * x), min=1, nan=1)))
15
+
16
+
17
+ def dequantization(x, **params):
18
+ return (torch.div(1, replace_num(params['_0'], num=0, to=10000)) * torch.cosh((params['_s'] * x)))
19
+
20
+
21
+ def init_space_search(
22
+ x: torch.Tensor,
23
+ **kwargs: Dict[str, Any],
24
+ ) -> torch.Tensor:
25
+
26
+ def _build_initial_param(tensor: torch.Tensor, max_initial: int, n_params: int):
27
+ """Generates the initial set of parameters. The first iteration generates 10 times more parameters."""
28
+ for _ in range(n_params * 10): # The first iteration generates 10 times more parameters
29
+ yield init_rand(tensor) * max_initial # Generates n_params in range [-max_initial, max_initial]
30
+
31
+ def _search_param(tensors: List[torch.tensor], n_params):
32
+ """Takes the best parameters and generates new parameters around the mean of the best parameters."""
33
+ torch_tensors = torch.stack(tensors)
34
+ min_vals, max_vals = torch.aminmax(torch_tensors, dim=0)
35
+ abs_max_val_per_ch = torch.max(-min_vals, max_vals)
36
+ mean = torch.mean(torch_tensors, dim=0)
37
+ for _ in range(n_params): # Generates n_params around the mean of the tensors
38
+ yield torch.randn_like(min_vals) * abs_max_val_per_ch + mean
39
+
40
+ def _calc(x, qtz_func, deqtz_func, **params):
41
+ x_ = x.transpose(0, 1)
42
+ x_ = qtz_func(x=x_, **params)
43
+ x_ = deqtz_func(x=x_, **params)
44
+ x_ = x_.transpose(0, 1)
45
+ return x_
46
+
47
+ assert "qtz_func" in kwargs, "qtz_func must be provided."
48
+ assert "deqtz_func" in kwargs, "deqtz_func must be provided."
49
+ assert "params_list" in kwargs, "params list must be provided."
50
+ assert "param" in kwargs, "param must be provided."
51
+
52
+ qtz_func = kwargs.get('qtz_func')
53
+ deqtz_func = kwargs.get('deqtz_func')
54
+ params_list = kwargs.get('params_list')
55
+ param = kwargs.get('param')
56
+
57
+ n_runs = 50 # Number of runs to try to find the best parameters
58
+ n_random_params = 50 # Number of random parameters to generate
59
+ n_best_to_pick = 5 # Number of best parameters to pick after each run
60
+ max_initial = 10000 # Maximum value to initialize the parameters
61
+
62
+ # Initializes the parameters
63
+ base_params = { p: init_ones(x, **kwargs) for p in params_list if p != param }
64
+ params = _build_initial_param(x, max_initial, n_random_params)
65
+
66
+ # Performs the search
67
+ for _ in range(n_runs):
68
+
69
+ best_params = []
70
+ for param_ in params:
71
+ try:
72
+ x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: param_})
73
+ loss_ones = nn.MSELoss()(x, x_)
74
+
75
+ if len(best_params) < n_best_to_pick:
76
+ best_params.append((param_, loss_ones.item()))
77
+ best_params = sorted(best_params, key=lambda x: x[1])
78
+ elif loss_ones < best_params[-1][1]:
79
+ best_params[-1] = (param_, loss_ones.item())
80
+ best_params = sorted(best_params, key=lambda x: x[1])
81
+
82
+ except Exception: # The parameters might not be valid for the function's domain
83
+ continue
84
+
85
+ # Generates new parameters around the mean
86
+ params = _search_param([p for p, _ in best_params], n_random_params)
87
+
88
+ return best_params[0][0]
89
+
90
+
91
+ def init_linear_scale( # Symmetric scale. From the study folder
92
+ x: torch.Tensor,
93
+ **kwargs: Dict[str, Any],
94
+ ) -> torch.Tensor:
95
+ assert "bits" in kwargs, "bits must be provided."
96
+ assert "params" in kwargs, "params must be provided."
97
+ assert "qtz_func" in kwargs, "qtz_func must be provided."
98
+
99
+ bits = kwargs.get('bits')
100
+ params = kwargs.get('params')
101
+ qtz_func = kwargs.get('qtz_func')
102
+
103
+ x_ = x.transpose(0, 1)
104
+ x_ = qtz_func(x=x_, **params, _s=init_ones(x, **kwargs))
105
+ x_ = x_.transpose(0, 1)
106
+
107
+ quant_min, quant_max = get_min_max_from_bits_signed(bits)
108
+ min_vals, max_vals = torch.aminmax(x_, dim=1)
109
+ min_vals = torch.min(min_vals, torch.zeros_like(min_vals))
110
+ max_vals = torch.max(max_vals, torch.zeros_like(max_vals))
111
+
112
+ eps = torch.finfo(torch.float32).eps
113
+
114
+ abs_max_val_per_ch = torch.max(-min_vals, max_vals)
115
+ scale = abs_max_val_per_ch / (float(quant_max - quant_min) / 2)
116
+
117
+ scale = torch.clamp(scale, min=eps).to(dtype=torch.float32, device=min_vals.device)
118
+
119
+ # Introduces some noise in scale
120
+ # If I don't introduce noise, the accuracy is going to be 0.0 and not learn anything
121
+ # scale = scale + 0.01 * torch.randn_like(scale)
122
+ return scale
123
+
124
+
125
+ def init_params(x: torch.Tensor, **kwargs: Dict[str, Any]) -> Dict[str, nn.Parameter]:
126
+ params = {
127
+ '_0': init_space_search(x, qtz_func=quantization, deqtz_func=dequantization, param='_0', params_list=['_0', '_s'], **kwargs),
128
+ }
129
+ params['_s'] = init_linear_scale(x, params=params, qtz_func=quantization, **kwargs)
130
+ params = {k: nn.Parameter(v, requires_grad=False) for k, v in params.items()}
131
+
132
+ if 'post_init_hook' in kwargs:
133
+ kwargs['post_init_hook'](parameters=params)
134
+
135
+
136
+ if 'post_train_hook' in kwargs:
137
+ kwargs['post_train_hook'](parameters=params)
138
+
139
+ return params
140
+
141
+
142
+ ############### Numpy Qtz ###############
143
+
144
+
145
+ def np_quantization(x, _0, _s):
146
+ return (np.divide(1, np_replace_num(_s, num=0, to=10000)) * np.arccosh(np_domain_guard((_0 * x), min=1, nan=1)))
147
+
148
+
149
+ def np_dequantization(x, _0, _s):
150
+ return (np.divide(1, np_replace_num(_0, num=0, to=10000)) * np.cosh((_s * x)))
151
+
152
+
153
+ def fit_func(x, _0, _s):
154
+ x_ = np_quantization(x, _0, _s)
155
+ x_ = np_dequantization(x_, _0, _s)
156
+ return x_
157
+
158
+
159
+
160
+ ############### HELPERS ###############
161
+
162
+ def domain_guard(
163
+ x: torch.Tensor,
164
+ min: float = None,
165
+ max: float = None,
166
+ posinf: float = None,
167
+ neginf: float = None,
168
+ nan: float = None
169
+ ) -> torch.Tensor:
170
+ """Guard a tensor to a valid domain."""
171
+ x = torch.nan_to_num(x, posinf=posinf, neginf=neginf, nan=nan)
172
+ if min is not None or max is not None:
173
+ x = torch.clamp(x, min=min, max=max)
174
+ return x
175
+
176
+
177
+ def replace_num(x: torch.Tensor, num: float, to: float) -> torch.Tensor:
178
+ """Replace a number in a tensor with another number.
179
+
180
+ Args:
181
+ x (torch.Tensor): The input tensor.
182
+ num (float): The number to replace.
183
+ to (float): The number to replace with.
184
+
185
+ Returns:
186
+ torch.Tensor: The tensor with the number replaced.
187
+ """
188
+ return torch.where(x == num, to, x)
189
+
190
+
191
+ def guarded_torch_power(x: torch.Tensor, exp: float) -> torch.Tensor:
192
+ """Guard the power operation to a valid domain."""
193
+ return torch.pow(x, exp) if exp >= 1 else torch.pow(torch.relu(x), exp)
194
+
195
+
196
+ def init_ones(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor:
197
+ val = torch.amin(x, dim=1)
198
+ return torch.ones_like(val, dtype=torch.float32, device=x.device)
199
+
200
+
201
+ def init_rand(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor:
202
+ val = torch.amin(x, dim=1)
203
+ return torch.randn_like(val, dtype=torch.float32, device=x.device)
204
+
205
+
206
+ def init_space_search(
207
+ x: torch.Tensor,
208
+ **kwargs: Dict[str, Any],
209
+ ) -> torch.Tensor:
210
+
211
+ def _build_initial_param(tensor: torch.Tensor, max_initial: int, n_params: int):
212
+ """Generates the initial set of parameters. The first iteration generates 10 times more parameters."""
213
+ for _ in range(n_params * 10): # The first iteration generates 10 times more parameters
214
+ yield init_rand(tensor) * max_initial # Generates n_params in range [-max_initial, max_initial]
215
+
216
+ def _search_param(tensors: List[torch.tensor], n_params):
217
+ """Takes the best parameters and generates new parameters around the mean of the best parameters."""
218
+ torch_tensors = torch.stack(tensors)
219
+ min_vals, max_vals = torch.aminmax(torch_tensors, dim=0)
220
+ abs_max_val_per_ch = torch.max(-min_vals, max_vals)
221
+ mean = torch.mean(torch_tensors, dim=0)
222
+ for _ in range(n_params): # Generates n_params around the mean of the tensors
223
+ yield torch.randn_like(min_vals) * abs_max_val_per_ch + mean
224
+
225
+ def _calc(x, qtz_func, deqtz_func, **params):
226
+ x_ = x.transpose(0, 1)
227
+ x_ = qtz_func(x=x_, **params)
228
+ x_ = deqtz_func(x=x_, **params)
229
+ x_ = x_.transpose(0, 1)
230
+ return x_
231
+
232
+ assert "qtz_func" in kwargs, "qtz_func must be provided."
233
+ assert "deqtz_func" in kwargs, "deqtz_func must be provided."
234
+ assert "params_list" in kwargs, "params list must be provided."
235
+ assert "param" in kwargs, "param must be provided."
236
+
237
+ qtz_func = kwargs.get('qtz_func')
238
+ deqtz_func = kwargs.get('deqtz_func')
239
+ params_list = kwargs.get('params_list')
240
+ param = kwargs.get('param')
241
+
242
+ n_runs = 50 # Number of runs to try to find the best parameters
243
+ n_random_params = 50 # Number of random parameters to generate
244
+ n_best_to_pick = 5 # Number of best parameters to pick after each run
245
+ max_initial = 10000 # Maximum value to initialize the parameters
246
+
247
+ # Initializes the parameters
248
+ base_params = { p: init_ones(x, **kwargs) for p in params_list if p != param }
249
+ params = _build_initial_param(x, max_initial, n_random_params)
250
+
251
+ # Performs the search
252
+ for _ in range(n_runs):
253
+
254
+ best_params = []
255
+ for param_ in params:
256
+ try:
257
+ x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: param_})
258
+ loss_ones = nn.MSELoss()(x, x_)
259
+
260
+ if len(best_params) < n_best_to_pick:
261
+ best_params.append((param_, loss_ones.item()))
262
+ best_params = sorted(best_params, key=lambda x: x[1])
263
+ elif loss_ones < best_params[-1][1]:
264
+ best_params[-1] = (param_, loss_ones.item())
265
+ best_params = sorted(best_params, key=lambda x: x[1])
266
+
267
+ except Exception: # The parameters might not be valid for the function's domain
268
+ continue
269
+
270
+ # Generates new parameters around the mean
271
+ params = _search_param([p for p, _ in best_params], n_random_params)
272
+
273
+ return best_params[0][0]
274
+
275
+
276
+ def init_linear_scale( # Symmetric scale. From the study folder
277
+ x: torch.Tensor,
278
+ **kwargs: Dict[str, Any],
279
+ ) -> torch.Tensor:
280
+ assert "bits" in kwargs, "bits must be provided."
281
+ assert "params" in kwargs, "params must be provided."
282
+ assert "qtz_func" in kwargs, "qtz_func must be provided."
283
+
284
+ bits = kwargs.get('bits')
285
+ params = kwargs.get('params')
286
+ qtz_func = kwargs.get('qtz_func')
287
+
288
+ x_ = x.transpose(0, 1)
289
+ x_ = qtz_func(x=x_, **params, _s=init_ones(x, **kwargs))
290
+ x_ = x_.transpose(0, 1)
291
+
292
+ quant_min, quant_max = get_min_max_from_bits_signed(bits)
293
+ min_vals, max_vals = torch.aminmax(x_, dim=1)
294
+ min_vals = torch.min(min_vals, torch.zeros_like(min_vals))
295
+ max_vals = torch.max(max_vals, torch.zeros_like(max_vals))
296
+
297
+ eps = torch.finfo(torch.float32).eps
298
+
299
+ abs_max_val_per_ch = torch.max(-min_vals, max_vals)
300
+ scale = abs_max_val_per_ch / (float(quant_max - quant_min) / 2)
301
+
302
+ scale = torch.clamp(scale, min=eps).to(dtype=torch.float32, device=min_vals.device)
303
+
304
+ # Introduces some noise in scale
305
+ # If I don't introduce noise, the accuracy is going to be 0.0 and not learn anything
306
+ # scale = scale + 0.01 * torch.randn_like(scale)
307
+ return scale
308
+
309
+
310
+ def init_non_linear_regression_fit(
311
+ x: torch.Tensor,
312
+ **kwargs: Dict[str, Any],
313
+ ) -> Dict[str, torch.Tensor]:
314
+
315
+ assert "params_list" in kwargs, "params list must be provided."
316
+ assert "np_fit_func" in kwargs, "np_fit_func must be provided."
317
+ assert "p0" in kwargs, "p0 must be provided."
318
+ np_fit_func = kwargs.get('np_fit_func')
319
+ params_list = kwargs.get('params_list')
320
+ p0 = kwargs.get('p0')
321
+
322
+ def _fit(xdata: np.ndarray, ydata: np.ndarray, func: Callable, p0: List[float]):
323
+ popt, _ = curve_fit(
324
+ func,
325
+ xdata,
326
+ ydata,
327
+ maxfev=1000,
328
+ p0=p0,
329
+ method='lm'
330
+ )
331
+ return popt
332
+
333
+ # 1. Needs to convert the torch tensor to numpy tensor
334
+ xdata = x.cpu().numpy()
335
+
336
+ # 2. Sorts the data so that it makes it easier to fit to it
337
+ sorted_xdata = np.sort(xdata, axis=-1)
338
+
339
+ p0 = {k: v.cpu().numpy() for k, v in p0.items()}
340
+ params_list = sorted(params_list) # We need to make sure that it matches the numpy fit func arg order
341
+
342
+ # 3. Finds the best parameters for each channel
343
+ try:
344
+ params = []
345
+ for i in range(sorted_xdata.shape[0]):
346
+ xdata_ = sorted_xdata[i]
347
+ p0_ = [p0[p][i] for p in params_list]
348
+ ch_params = _fit(xdata_, xdata_, np_fit_func, p0_)
349
+ params.append(ch_params)
350
+
351
+ # 4. Builds the parameters
352
+ result = {}
353
+ for i, p in enumerate(params_list):
354
+ result[p] = torch.tensor([p_[i] for p_ in params], dtype=torch.float32).to(x.device)
355
+
356
+ return result
357
+
358
+ except ValueError as e:
359
+ print(f"Could not fit the function with error: {e}")
360
+ print(f"Using fallback result...")
361
+ return {
362
+ k: torch.tensor(v, dtype=torch.float32).to(x.device) for k, v in p0.items()
363
+ }
364
+
365
+
366
+ def init_zeros(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor:
367
+ val = torch.amin(x, dim=1)
368
+ return torch.zeros_like(val, dtype=torch.float32, device=x.device)
369
+
370
+
371
+ def init_inner_scale(tensor: torch.Tensor, _min: float = torch.inf, _max: float = torch.inf) -> torch.Tensor:
372
+ # Calculate the original minimum and maximum values
373
+ min_vals, max_vals = torch.aminmax(tensor, dim=-1)
374
+ x_min = torch.min(min_vals, torch.zeros_like(min_vals))
375
+ x_max = torch.max(max_vals, torch.zeros_like(max_vals))
376
+
377
+ if _max is torch.inf: # We do not need to scale the tensor. Just need to move it
378
+ return torch.ones_like(x_min)
379
+
380
+ # Calculate the scale factor
381
+ scale = (_max - _min) / (x_max - x_min)
382
+ return scale
383
+
384
+
385
+
386
+ ############## Quant ###############
387
+
388
+ @torch.enable_grad()
389
+ def learn_parameters(
390
+ x: torch.Tensor,
391
+ params: Dict[str, nn.Parameter],
392
+ qtz_func: nn.Module,
393
+ deqtz_func: nn.Module,
394
+ bits: int,
395
+ target_dtype: torch.dtype,
396
+ epochs: int = 1000,
397
+ early_stop: bool = True,
398
+ do_report: bool = False
399
+ ) -> Tuple[Dict[str, nn.Parameter], torch.Tensor]:
400
+ loss_fn = nn.MSELoss()
401
+
402
+ # Determines the initial learning rate from the initial loss: the base learning rate is scaled
+ # by 10 to the power of half (integer division) of the loss's order of magnitude
404
+ quant = quantize(x, params, qtz_func, bits, target_dtype)
405
+ dequant = dequantize(quant, params, deqtz_func, bits, x.dtype)
406
+ loss = loss_fn(x, dequant)
407
+
408
+ base_lr = 0.1
409
+ exponent = int(np.floor(np.log10(loss.item())))
410
+ lr = base_lr * (10 ** (exponent // 2))
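+ # e.g. an initial loss of 3.4e-2 gives exponent = -2, so lr = 0.1 * 10**(-2 // 2) = 0.01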
411
+
412
+ # Requires gradients in the parameters
413
+ for p in params.values():
414
+ p.requires_grad = True
415
+ p.grad = None
416
+
417
+ param_keys = list(params.keys())
418
+ param_values = list(params.values())
419
+
420
+ # Defines optimizer and loss function
421
+ optimizer = torch.optim.Adam(param_values, lr=lr)
422
+ scheduler = torch.optim.lr_scheduler.LinearLR(optimizer, start_factor=1.0, end_factor=0.01, total_iters=epochs // 10)
423
+
424
+ # Contains the best loss and the best parameters
425
+ best_loss = float("inf")
426
+ best_params = None
427
+
428
+ # Used to stop the search early
429
+ min_delta = 1e-7
430
+ acc_loss = []
431
+ percent_epochs_before_stop = 0.1
432
+
433
+ for i in range(epochs):
434
+ optimizer.zero_grad()
435
+
436
+ quant = quantize(x, params, qtz_func, bits, target_dtype)
437
+ dequant = dequantize(quant, params, deqtz_func, bits, x.dtype)
438
+ loss = loss_fn(x, dequant)
439
+
440
+ if loss.isnan() or loss.isinf():
441
+ raise Exception("Loss is NaN or Inf. Stopping the search.")
442
+
443
+ loss.backward()
444
+ optimizer.step()
445
+ scheduler.step()
446
+
447
+ acc_loss.append(loss.item())
448
+
449
+ # Reports loss every 10 steps
450
+ if i % 10 == 0 and do_report:
451
+ print(f"Epoch {i}: Loss {loss.item()}")
452
+
453
+ # Optimizes the parameter search by storing the best loss and the parameters
454
+ if loss.item() < best_loss:
455
+ best_loss = loss.item()
456
+ best_params = copy.deepcopy({
457
+ k: v for k, v in params.items() if k in param_keys
458
+ })
459
+
460
+ # We also stop the search if the loss has not changed considerably during the last 10% of epochs
461
+ if early_stop:
462
+ epochs_before_stop = int(epochs * percent_epochs_before_stop)
463
+ if i > epochs_before_stop and abs(acc_loss[i - epochs_before_stop] - acc_loss[i]) < min_delta:
464
+ break
465
+
466
+ # No longer requires gradients in the parameters
467
+ for p in best_params.values():
468
+ p.requires_grad = False
469
+ p.grad = None
470
+
471
+ if do_report:
472
+ return best_params, acc_loss
473
+ else:
474
+ return best_params
475
+
476
+
477
+ def quantize(
478
+ x: torch.Tensor,
479
+ params: Dict[str, nn.Parameter],
480
+ func: nn.Module,
481
+ bits: int,
482
+ target_dtype: torch.dtype = torch.int8
483
+ ) -> torch.Tensor:
484
+ quant_min, quant_max = get_min_max_from_bits_signed(bits)
485
+ x = x.transpose(0, 1) # Aligns shapes
486
+ x = func(x=x, **params)
487
+ x = x.transpose(0, 1)
488
+ x = torch.clamp(round_func_BPDA(x), quant_min, quant_max).to(target_dtype)
489
+ return x
490
+
491
+
492
+ def dequantize(
493
+ x: torch.Tensor,
494
+ params: Dict[str, nn.Parameter],
495
+ func: nn.Module,
496
+ bits: int,
497
+ out_dtype: torch.dtype
498
+ ) -> torch.Tensor:
499
+ x = x.to(dtype=out_dtype)
500
+ x = x.transpose(0, 1)
501
+ x = func(x=x, **params)
502
+ x = x.transpose(0, 1)
503
+ return x
504
+
505
+
506
+ def round_func_BPDA(input):
507
+ # This is equivalent to replacing round function (non-differentiable) with
508
+ # an identity function (differentiable) only when backward.
509
+ forward_value = torch.round(input)
510
+ out = input.clone()
511
+ out.data = forward_value.data
512
+ return out
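+ # Example: for t = torch.tensor([0.4, 1.6], requires_grad=True) the forward pass returns
+ # tensor([0., 2.]), while round_func_BPDA(t).sum().backward() leaves t.grad == tensor([1., 1.]):
+ # gradients flow through as if the rounding were the identity.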
513
+
514
+
515
+ def get_min_max_from_bits_signed(bit_width: int) -> Tuple[int, int]:
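+ # e.g. bit_width=8 gives (-128, 127), the signed int8 range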
516
+ return -2 ** (bit_width - 1), 2 ** (bit_width - 1) - 1
517
+
518
+
519
+
520
+ ############## Numpy ###############
521
+
522
+ def np_domain_guard(
523
+ x: np.ndarray,
524
+ min: float = None,
525
+ max: float = None,
526
+ posinf: float = None,
527
+ neginf: float = None,
528
+ nan: float = None
529
+ ) -> np.ndarray:
530
+ """Guard a tensor to a valid domain."""
531
+ x = np.nan_to_num(x, posinf=posinf, neginf=neginf, nan=nan)
532
+ if min is not None or max is not None:
533
+ x = np.clip(x, min, max)
534
+ return x
535
+
536
+
537
+ def np_replace_num(x: np.ndarray, num: float, to: float) -> np.ndarray:
538
+ """Replace a number in a tensor with another number.
539
+
540
+ Args:
541
+ x (np.ndarray): The input tensor.
542
+ num (float): The number to replace.
543
+ to (float): The number to replace with.
544
+
545
+ Returns:
546
+ np.ndarray: The tensor with the number replaced.
547
+ """
548
+ return np.where(x == num, to, x)
549
+
550
+
551
+ def np_guarded_power(x: np.ndarray, exp: float) -> np.ndarray:
552
+ """Guard the power operation to a valid domain."""
553
+ return np.power(x, exp) if exp >= 1 else np.power(np.maximum(x, 0), exp)
554
+
fn_gen/rnd_search/13/loss.png ADDED
fn_gen/rnd_search/13/quantization.png ADDED
fn_gen/rnd_search/14/__pycache__/fn.cpython-311.pyc ADDED
Binary file (27.3 kB). View file
 
fn_gen/rnd_search/14/distortion.png ADDED
fn_gen/rnd_search/14/expressions.txt ADDED
@@ -0,0 +1,2 @@
+ tan(_0*x)/_s
+ atan(_s*x)/_0
fn_gen/rnd_search/14/fn.py ADDED
@@ -0,0 +1,554 @@
1
+ from __future__ import annotations
2
+
3
+ import torch
4
+ from torch import amin # Necessary for arcsin
5
+ import copy
6
+ import torch.nn as nn
7
+ import numpy as np
8
+
9
+ from scipy.optimize import curve_fit
10
+ from typing import Dict, Any, Tuple, List, Callable
11
+
12
+
13
+ def quantization(x, **params):
14
+ return (torch.div(1, replace_num(params['_s'], num=0, to=10000)) * torch.tan(domain_guard((params['_0'] * x), posinf=1, neginf=-1, nan=0)))
15
+
16
+
17
+ def dequantization(x, **params):
18
+ return (torch.div(1, replace_num(params['_0'], num=0, to=10000)) * torch.atan((params['_s'] * x)))
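+ # Note: dequantization(quantization(x)) == atan(tan(_0*x))/_0 == x whenever _0*x lies in (-pi/2, pi/2),
+ # so before rounding and clamping the quantize/dequantize pair is an exact inverse on that interval.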
19
+
20
+
21
+ def init_space_search(
22
+ x: torch.Tensor,
23
+ **kwargs: Dict[str, Any],
24
+ ) -> torch.Tensor:
25
+
26
+ def _build_initial_param(tensor: torch.Tensor, max_initial: int, n_params: int):
27
+ """Generates the initial set of parameters. The first iteration generates 10 times more parameters."""
28
+ for _ in range(n_params * 10): # The first iteration generates 10 times more parameters
29
+ yield init_rand(tensor) * max_initial # Generates n_params in range [-max_initial, max_initial]
30
+
31
+ def _search_param(tensors: List[torch.tensor], n_params):
32
+ """Takes the best parameters and generates new parameters around the mean of the best parameters."""
33
+ torch_tensors = torch.stack(tensors)
34
+ min_vals, max_vals = torch.aminmax(torch_tensors, dim=0)
35
+ abs_max_val_per_ch = torch.max(-min_vals, max_vals)
36
+ mean = torch.mean(torch_tensors, dim=0)
37
+ for _ in range(n_params): # Generates n_params around the mean of the tensors
38
+ yield torch.randn_like(min_vals) * abs_max_val_per_ch + mean
39
+
40
+ def _calc(x, qtz_func, deqtz_func, **params):
41
+ x_ = x.transpose(0, 1)
42
+ x_ = qtz_func(x=x_, **params)
43
+ x_ = deqtz_func(x=x_, **params)
44
+ x_ = x_.transpose(0, 1)
45
+ return x_
46
+
47
+ assert "qtz_func" in kwargs, "qtz_func must be provided."
48
+ assert "deqtz_func" in kwargs, "deqtz_func must be provided."
49
+ assert "params_list" in kwargs, "params list must be provided."
50
+ assert "param" in kwargs, "param must be provided."
51
+
52
+ qtz_func = kwargs.get('qtz_func')
53
+ deqtz_func = kwargs.get('deqtz_func')
54
+ params_list = kwargs.get('params_list')
55
+ param = kwargs.get('param')
56
+
57
+ n_runs = 50 # Number of runs to try to find the best parameters
58
+ n_random_params = 50 # Number of random parameters to generate
59
+ n_best_to_pick = 5 # Number of best parameters to pick after each run
60
+ max_initial = 10000 # Maximum value to initialize the parameters
61
+
62
+ # Initializes the parameters
63
+ base_params = { p: init_ones(x, **kwargs) for p in params_list if p != param }
64
+ params = _build_initial_param(x, max_initial, n_random_params)
65
+
66
+ # Performs the search
67
+ for _ in range(n_runs):
68
+
69
+ best_params = []
70
+ for param_ in params:
71
+ try:
72
+ x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: param_})
73
+ loss_ones = nn.MSELoss()(x, x_)
74
+
75
+ if len(best_params) < n_best_to_pick:
76
+ best_params.append((param_, loss_ones.item()))
77
+ best_params = sorted(best_params, key=lambda x: x[1])
78
+ elif loss_ones < best_params[-1][1]:
79
+ best_params[-1] = (param_, loss_ones.item())
80
+ best_params = sorted(best_params, key=lambda x: x[1])
81
+
82
+ except Exception: # The parameters might not be valid for the function's domain
83
+ continue
84
+
85
+ # Generates new parameters around the mean
86
+ params = _search_param([p for p, _ in best_params], n_random_params)
87
+
88
+ return best_params[0][0]
89
+
90
+
91
+ def init_linear_scale( # Symmetric scale. From the study folder
92
+ x: torch.Tensor,
93
+ **kwargs: Dict[str, Any],
94
+ ) -> torch.Tensor:
95
+ assert "bits" in kwargs, "bits must be provided."
96
+ assert "params" in kwargs, "params must be provided."
97
+ assert "qtz_func" in kwargs, "qtz_func must be provided."
98
+
99
+ bits = kwargs.get('bits')
100
+ params = kwargs.get('params')
101
+ qtz_func = kwargs.get('qtz_func')
102
+
103
+ x_ = x.transpose(0, 1)
104
+ x_ = qtz_func(x=x_, **params, _s=init_ones(x, **kwargs))
105
+ x_ = x_.transpose(0, 1)
106
+
107
+ quant_min, quant_max = get_min_max_from_bits_signed(bits)
108
+ min_vals, max_vals = torch.aminmax(x_, dim=1)
109
+ min_vals = torch.min(min_vals, torch.zeros_like(min_vals))
110
+ max_vals = torch.max(max_vals, torch.zeros_like(max_vals))
111
+
112
+ eps = torch.finfo(torch.float32).eps
113
+
114
+ abs_max_val_per_ch = torch.max(-min_vals, max_vals)
115
+ scale = abs_max_val_per_ch / (float(quant_max - quant_min) / 2)
116
+
117
+ scale = torch.clamp(scale, min=eps).to(dtype=torch.float32, device=min_vals.device)
118
+
119
+ # Introduces some noise in scale
120
+ # If I don't introduce noise, the accuracy is going to be 0.0 and not learn anything
121
+ # scale = scale + 0.01 * torch.randn_like(scale)
122
+ return scale
123
+
124
+
125
+ def init_params(x: torch.Tensor, **kwargs: Dict[str, Any]) -> Dict[str, nn.Parameter]:
126
+ params = {
127
+ '_0': init_space_search(x, qtz_func=quantization, deqtz_func=dequantization, param='_0', params_list=['_0', '_s'], **kwargs),
128
+ }
129
+ params['_s'] = init_linear_scale(x, params=params, qtz_func=quantization, **kwargs)
130
+ params = {k: nn.Parameter(v, requires_grad=False) for k, v in params.items()}
131
+
132
+ if 'post_init_hook' in kwargs:
133
+ kwargs['post_init_hook'](parameters=params)
134
+
135
+
136
+ if 'post_train_hook' in kwargs:
137
+ kwargs['post_train_hook'](parameters=params)
138
+
139
+ return params
140
+
141
+
142
+ ############### Numpy Qtz ###############
143
+
144
+
145
+ def np_quantization(x, _0, _s):
146
+ return (np.divide(1, np_replace_num(_s, num=0, to=10000)) * np.tan(np_domain_guard((_0 * x), posinf=1, neginf=-1, nan=0)))
147
+
148
+
149
+ def np_dequantization(x, _0, _s):
150
+ return (np.divide(1, np_replace_num(_0, num=0, to=10000)) * np.arctan((_s * x)))
151
+
152
+
153
+ def fit_func(x, _0, _s):
154
+ x_ = np_quantization(x, _0, _s)
155
+ x_ = np_dequantization(x_, _0, _s)
156
+ return x_
157
+
158
+
159
+
160
+ ############### HELPERS ###############
161
+
162
+ def domain_guard(
163
+ x: torch.Tensor,
164
+ min: float = None,
165
+ max: float = None,
166
+ posinf: float = None,
167
+ neginf: float = None,
168
+ nan: float = None
169
+ ) -> torch.Tensor:
170
+ """Guard a tensor to a valid domain."""
171
+ x = torch.nan_to_num(x, posinf=posinf, neginf=neginf, nan=nan)
172
+ if min is not None or max is not None:
173
+ x = torch.clamp(x, min=min, max=max)
174
+ return x
175
+
176
+
177
+ def replace_num(x: torch.Tensor, num: float, to: float) -> torch.Tensor:
178
+ """Replace a number in a tensor with another number.
179
+
180
+ Args:
181
+ x (torch.Tensor): The input tensor.
182
+ num (float): The number to replace.
183
+ to (float): The number to replace with.
184
+
185
+ Returns:
186
+ torch.Tensor: The tensor with the number replaced.
187
+ """
188
+ return torch.where(x == num, to, x)
189
+
190
+
191
+ def guarded_torch_power(x: torch.Tensor, exp: float) -> torch.Tensor:
192
+ """Guard the power operation to a valid domain."""
193
+ return torch.pow(x, exp) if exp >= 1 else torch.pow(torch.relu(x), exp)
194
+
195
+
196
+ def init_ones(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor:
197
+ val = torch.amin(x, dim=1)
198
+ return torch.ones_like(val, dtype=torch.float32, device=x.device)
199
+
200
+
201
+ def init_rand(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor:
202
+ val = torch.amin(x, dim=1)
203
+ return torch.randn_like(val, dtype=torch.float32, device=x.device)
204
+
205
+
206
+ def init_space_search(
207
+ x: torch.Tensor,
208
+ **kwargs: Dict[str, Any],
209
+ ) -> torch.Tensor:
210
+
211
+ def _build_initial_param(tensor: torch.Tensor, max_initial: int, n_params: int):
212
+ """Generates the initial set of parameters. The first iteration generates 10 times more parameters."""
213
+ for _ in range(n_params * 10): # The first iteration generates 10 times more parameters
214
+ yield init_rand(tensor) * max_initial # Generates n_params in range [-max_initial, max_initial]
215
+
216
+ def _search_param(tensors: List[torch.tensor], n_params):
217
+ """Takes the best parameters and generates new parameters around the mean of the best parameters."""
218
+ torch_tensors = torch.stack(tensors)
219
+ min_vals, max_vals = torch.aminmax(torch_tensors, dim=0)
220
+ abs_max_val_per_ch = torch.max(-min_vals, max_vals)
221
+ mean = torch.mean(torch_tensors, dim=0)
222
+ for _ in range(n_params): # Generates n_params around the mean of the tensors
223
+ yield torch.randn_like(min_vals) * abs_max_val_per_ch + mean
224
+
225
+ def _calc(x, qtz_func, deqtz_func, **params):
226
+ x_ = x.transpose(0, 1)
227
+ x_ = qtz_func(x=x_, **params)
228
+ x_ = deqtz_func(x=x_, **params)
229
+ x_ = x_.transpose(0, 1)
230
+ return x_
231
+
232
+ assert "qtz_func" in kwargs, "qtz_func must be provided."
233
+ assert "deqtz_func" in kwargs, "deqtz_func must be provided."
234
+ assert "params_list" in kwargs, "params list must be provided."
235
+ assert "param" in kwargs, "param must be provided."
236
+
237
+ qtz_func = kwargs.get('qtz_func')
238
+ deqtz_func = kwargs.get('deqtz_func')
239
+ params_list = kwargs.get('params_list')
240
+ param = kwargs.get('param')
241
+
242
+ n_runs = 50 # Number of runs to try to find the best parameters
243
+ n_random_params = 50 # Number of random parameters to generate
244
+ n_best_to_pick = 5 # Number of best parameters to pick after each run
245
+ max_initial = 10000 # Maximum value to initialize the parameters
246
+
247
+ # Initializes the parameters
248
+ base_params = { p: init_ones(x, **kwargs) for p in params_list if p != param }
249
+ params = _build_initial_param(x, max_initial, n_random_params)
250
+
251
+ # Performs the search
252
+ for _ in range(n_runs):
253
+
254
+ best_params = []
255
+ for param_ in params:
256
+ try:
257
+ x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: param_})
258
+ loss_ones = nn.MSELoss()(x, x_)
259
+
260
+ if len(best_params) < n_best_to_pick:
261
+ best_params.append((param_, loss_ones.item()))
262
+ best_params = sorted(best_params, key=lambda x: x[1])
263
+ elif loss_ones < best_params[-1][1]:
264
+ best_params[-1] = (param_, loss_ones.item())
265
+ best_params = sorted(best_params, key=lambda x: x[1])
266
+
267
+ except Exception: # The parameters might not be valid for the function's domain
268
+ continue
269
+
270
+ # Generates new parameters around the mean
271
+ params = _search_param([p for p, _ in best_params], n_random_params)
272
+
273
+ return best_params[0][0]
274
+
275
+
276
+ def init_linear_scale( # Symmetric scale. From the study folder
277
+ x: torch.Tensor,
278
+ **kwargs: Dict[str, Any],
279
+ ) -> torch.Tensor:
280
+ assert "bits" in kwargs, "bits must be provided."
281
+ assert "params" in kwargs, "params must be provided."
282
+ assert "qtz_func" in kwargs, "qtz_func must be provided."
283
+
284
+ bits = kwargs.get('bits')
285
+ params = kwargs.get('params')
286
+ qtz_func = kwargs.get('qtz_func')
287
+
288
+ x_ = x.transpose(0, 1)
289
+ x_ = qtz_func(x=x_, **params, _s=init_ones(x, **kwargs))
290
+ x_ = x_.transpose(0, 1)
291
+
292
+ quant_min, quant_max = get_min_max_from_bits_signed(bits)
293
+ min_vals, max_vals = torch.aminmax(x_, dim=1)
294
+ min_vals = torch.min(min_vals, torch.zeros_like(min_vals))
295
+ max_vals = torch.max(max_vals, torch.zeros_like(max_vals))
296
+
297
+ eps = torch.finfo(torch.float32).eps
298
+
299
+ abs_max_val_per_ch = torch.max(-min_vals, max_vals)
300
+ scale = abs_max_val_per_ch / (float(quant_max - quant_min) / 2)
301
+
302
+ scale = torch.clamp(scale, min=eps).to(dtype=torch.float32, device=min_vals.device)
303
+
304
+ # Introduces some noise in scale
305
+ # If I don't introduce noise, the accuracy is going to be 0.0 and not learn anything
306
+ # scale = scale + 0.01 * torch.randn_like(scale)
307
+ return scale
308
+
309
+
310
+ def init_non_linear_regression_fit(
311
+ x: torch.Tensor,
312
+ **kwargs: Dict[str, Any],
313
+ ) -> Dict[str, torch.Tensor]:
314
+
315
+ assert "params_list" in kwargs, "params list must be provided."
316
+ assert "np_fit_func" in kwargs, "np_fit_func must be provided."
317
+ assert "p0" in kwargs, "p0 must be provided."
318
+ np_fit_func = kwargs.get('np_fit_func')
319
+ params_list = kwargs.get('params_list')
320
+ p0 = kwargs.get('p0')
321
+
322
+ def _fit(xdata: np.ndarray, ydata: np.ndarray, func: Callable, p0: List[float]):
323
+ popt, _ = curve_fit(
324
+ func,
325
+ xdata,
326
+ ydata,
327
+ maxfev=1000,
328
+ p0=p0,
329
+ method='lm'
330
+ )
331
+ return popt
332
+
333
+ # 1. Needs to convert the torch tensor to numpy tensor
334
+ xdata = x.cpu().numpy()
335
+
336
+ # 2. Sorts the data so that it makes it easier to fit to it
337
+ sorted_xdata = np.sort(xdata, axis=-1)
338
+
339
+ p0 = {k: v.cpu().numpy() for k, v in p0.items()}
340
+ params_list = sorted(params_list) # We need to make sure that it matches the numpy fit func arg order
341
+
342
+ # 3. Finds the best parameters for each channel
343
+ try:
344
+ params = []
345
+ for i in range(sorted_xdata.shape[0]):
346
+ xdata_ = sorted_xdata[i]
347
+ p0_ = [p0[p][i] for p in params_list]
348
+ ch_params = _fit(xdata_, xdata_, np_fit_func, p0_)
349
+ params.append(ch_params)
350
+
351
+ # 4. Builds the parameters
352
+ result = {}
353
+ for i, p in enumerate(params_list):
354
+ result[p] = torch.tensor([p_[i] for p_ in params], dtype=torch.float32).to(x.device)
355
+
356
+ return result
357
+
358
+ except ValueError as e:
359
+ print(f"Could not fit the function with error: {e}")
360
+ print(f"Using fallback result...")
361
+ return {
362
+ k: torch.tensor(v, dtype=torch.float32).to(x.device) for k, v in p0.items()
363
+ }
364
+
365
+
366
+ def init_zeros(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor:
367
+ val = torch.amin(x, dim=1)
368
+ return torch.zeros_like(val, dtype=torch.float32, device=x.device)
369
+
370
+
371
+ def init_inner_scale(tensor: torch.Tensor, _min: float = torch.inf, _max: float = torch.inf) -> torch.Tensor:
372
+ # Calculate the original minimum and maximum values
373
+ min_vals, max_vals = torch.aminmax(tensor, dim=-1)
374
+ x_min = torch.min(min_vals, torch.zeros_like(min_vals))
375
+ x_max = torch.max(max_vals, torch.zeros_like(max_vals))
376
+
377
+ if _max is torch.inf: # We do not need to scale the tensor. Just need to move it
378
+ return torch.ones_like(x_min)
379
+
380
+ # Calculate the scale factor
381
+ scale = (_max - _min) / (x_max - x_min)
382
+ return scale
383
+
384
+
385
+
386
+ ############## Quant ###############
387
+
388
+ @torch.enable_grad()
389
+ def learn_parameters(
390
+ x: torch.Tensor,
391
+ params: Dict[str, nn.Parameter],
392
+ qtz_func: nn.Module,
393
+ deqtz_func: nn.Module,
394
+ bits: int,
395
+ target_dtype: torch.dtype,
396
+ epochs: int = 1000,
397
+ early_stop: bool = True,
398
+ do_report: bool = False
399
+ ) -> Tuple[Dict[str, nn.Parameter], torch.Tensor]:
400
+ loss_fn = nn.MSELoss()
401
+
402
+ # Determines the initial learning rate from the initial loss: the base learning rate is scaled
+ # by 10 to the power of half (integer division) of the loss's order of magnitude
404
+ quant = quantize(x, params, qtz_func, bits, target_dtype)
405
+ dequant = dequantize(quant, params, deqtz_func, bits, x.dtype)
406
+ loss = loss_fn(x, dequant)
407
+
408
+ base_lr = 0.1
409
+ exponent = int(np.floor(np.log10(loss.item())))
410
+ lr = base_lr * (10 ** (exponent // 2))
411
+
412
+ # Requires gradients in the parameters
413
+ for p in params.values():
414
+ p.requires_grad = True
415
+ p.grad = None
416
+
417
+ param_keys = list(params.keys())
418
+ param_values = list(params.values())
419
+
420
+ # Defines optimizer and loss function
421
+ optimizer = torch.optim.Adam(param_values, lr=lr)
422
+ scheduler = torch.optim.lr_scheduler.LinearLR(optimizer, start_factor=1.0, end_factor=0.01, total_iters=epochs // 10)
423
+
424
+ # Contains the best loss and the best parameters
425
+ best_loss = float("inf")
426
+ best_params = None
427
+
428
+ # Used to stop the search early
429
+ min_delta = 1e-7
430
+ acc_loss = []
431
+ percent_epochs_before_stop = 0.1
432
+
433
+ for i in range(epochs):
434
+ optimizer.zero_grad()
435
+
436
+ quant = quantize(x, params, qtz_func, bits, target_dtype)
437
+ dequant = dequantize(quant, params, deqtz_func, bits, x.dtype)
438
+ loss = loss_fn(x, dequant)
439
+
440
+ if loss.isnan() or loss.isinf():
441
+ raise Exception("Loss is NaN or Inf. Stopping the search.")
442
+
443
+ loss.backward()
444
+ optimizer.step()
445
+ scheduler.step()
446
+
447
+ acc_loss.append(loss.item())
448
+
449
+ # Reports loss every 10 steps
450
+ if i % 10 == 0 and do_report:
451
+ print(f"Epoch {i}: Loss {loss.item()}")
452
+
453
+ # Optimizes the parameter search by storing the best loss and the parameters
454
+ if loss.item() < best_loss:
455
+ best_loss = loss.item()
456
+ best_params = copy.deepcopy({
457
+ k: v for k, v in params.items() if k in param_keys
458
+ })
459
+
460
+ # We also stop the search if the loss has not changed considerably during the last 10% of epochs
461
+ if early_stop:
462
+ epochs_before_stop = int(epochs * percent_epochs_before_stop)
463
+ if i > epochs_before_stop and abs(acc_loss[i - epochs_before_stop] - acc_loss[i]) < min_delta:
464
+ break
465
+
466
+ # No longer requires gradients in the parameters
467
+ for p in best_params.values():
468
+ p.requires_grad = False
469
+ p.grad = None
470
+
471
+ if do_report:
472
+ return best_params, acc_loss
473
+ else:
474
+ return best_params
475
+
476
+
477
+ def quantize(
478
+ x: torch.Tensor,
479
+ params: Dict[str, nn.Parameter],
480
+ func: nn.Module,
481
+ bits: int,
482
+ target_dtype: torch.dtype = torch.int8
483
+ ) -> torch.Tensor:
484
+ quant_min, quant_max = get_min_max_from_bits_signed(bits)
485
+ x = x.transpose(0, 1) # Aligns shapes
486
+ x = func(x=x, **params)
487
+ x = x.transpose(0, 1)
488
+ x = torch.clamp(round_func_BPDA(x), quant_min, quant_max).to(target_dtype)
489
+ return x
490
+
491
+
492
+ def dequantize(
493
+ x: torch.Tensor,
494
+ params: Dict[str, nn.Parameter],
495
+ func: nn.Module,
496
+ bits: int,
497
+ out_dtype: torch.dtype
498
+ ) -> torch.Tensor:
499
+ x = x.to(dtype=out_dtype)
500
+ x = x.transpose(0, 1)
501
+ x = func(x=x, **params)
502
+ x = x.transpose(0, 1)
503
+ return x
504
+
505
+
506
+ def round_func_BPDA(input):
507
+ # This is equivalent to replacing round function (non-differentiable) with
508
+ # an identity function (differentiable) only when backward.
509
+ forward_value = torch.round(input)
510
+ out = input.clone()
511
+ out.data = forward_value.data
512
+ return out
513
+
514
+
515
+ def get_min_max_from_bits_signed(bit_width: int) -> Tuple[int, int]:
516
+ return -2 ** (bit_width - 1), 2 ** (bit_width - 1) - 1
517
+
518
+
519
+
520
+ ############## Numpy ###############
521
+
522
+ def np_domain_guard(
523
+ x: np.ndarray,
524
+ min: float = None,
525
+ max: float = None,
526
+ posinf: float = None,
527
+ neginf: float = None,
528
+ nan: float = None
529
+ ) -> np.ndarray:
530
+ """Guard a tensor to a valid domain."""
531
+ x = np.nan_to_num(x, posinf=posinf, neginf=neginf, nan=nan)
532
+ if min is not None or max is not None:
533
+ x = np.clip(x, min, max)
534
+ return x
535
+
536
+
537
+ def np_replace_num(x: np.ndarray, num: float, to: float) -> np.ndarray:
538
+ """Replace a number in a tensor with another number.
539
+
540
+ Args:
541
+ x (np.ndarray): The input tensor.
542
+ num (float): The number to replace.
543
+ to (float): The number to replace with.
544
+
545
+ Returns:
546
+ np.ndarray: The tensor with the number replaced.
547
+ """
548
+ return np.where(x == num, to, x)
549
+
550
+
551
+ def np_guarded_power(x: np.ndarray, exp: float) -> np.ndarray:
552
+ """Guard the power operation to a valid domain."""
553
+ return np.power(x, exp) if exp >= 1 else np.power(np.maximum(x, 0), exp)
554
+
fn_gen/rnd_search/14/loss.png ADDED
fn_gen/rnd_search/14/quantization.png ADDED
fn_gen/rnd_search/16/__pycache__/fn.cpython-311.pyc ADDED
Binary file (27.3 kB). View file
 
fn_gen/rnd_search/16/distortion.png ADDED
fn_gen/rnd_search/16/expressions.txt ADDED
@@ -0,0 +1,2 @@
+ atan(_0*x)/_s
+ tan(_s*x)/_0
fn_gen/rnd_search/16/fn.py ADDED
@@ -0,0 +1,554 @@
1
+ from __future__ import annotations
2
+
3
+ import torch
4
+ from torch import amin # Necessary for arcsin
5
+ import copy
6
+ import torch.nn as nn
7
+ import numpy as np
8
+
9
+ from scipy.optimize import curve_fit
10
+ from typing import Dict, Any, Tuple, List, Callable
11
+
12
+
13
+ def quantization(x, **params):
14
+ return (torch.div(1, replace_num(params['_s'], num=0, to=10000)) * torch.atan((params['_0'] * x)))
15
+
16
+
17
+ def dequantization(x, **params):
18
+ return (torch.div(1, replace_num(params['_0'], num=0, to=10000)) * torch.tan(domain_guard((params['_s'] * x), posinf=1, neginf=-1, nan=0)))
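+ # Note: this is the mirrored pair of rnd_search/14: here the bounded atan is the quantizer and tan the
+ # dequantizer, so tan(atan(_0*x))/_0 == x for every real x and the round trip is exact before rounding and clamping.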
19
+
20
+
21
+ def init_space_search(
22
+ x: torch.Tensor,
23
+ **kwargs: Dict[str, Any],
24
+ ) -> torch.Tensor:
25
+
26
+ def _build_initial_param(tensor: torch.Tensor, max_initial: int, n_params: int):
27
+ """Generates the initial set of parameters. The first iteration generates 10 times more parameters."""
28
+ for _ in range(n_params * 10): # The first iteration generates 10 times more parameters
29
+ yield init_rand(tensor) * max_initial # Generates n_params in range [-max_initial, max_initial]
30
+
31
+ def _search_param(tensors: List[torch.tensor], n_params):
32
+ """Takes the best parameters and generates new parameters around the mean of the best parameters."""
33
+ torch_tensors = torch.stack(tensors)
34
+ min_vals, max_vals = torch.aminmax(torch_tensors, dim=0)
35
+ abs_max_val_per_ch = torch.max(-min_vals, max_vals)
36
+ mean = torch.mean(torch_tensors, dim=0)
37
+ for _ in range(n_params): # Generates n_params around the mean of the tensors
38
+ yield torch.randn_like(min_vals) * abs_max_val_per_ch + mean
39
+
40
+ def _calc(x, qtz_func, deqtz_func, **params):
41
+ x_ = x.transpose(0, 1)
42
+ x_ = qtz_func(x=x_, **params)
43
+ x_ = deqtz_func(x=x_, **params)
44
+ x_ = x_.transpose(0, 1)
45
+ return x_
46
+
47
+ assert "qtz_func" in kwargs, "qtz_func must be provided."
48
+ assert "deqtz_func" in kwargs, "deqtz_func must be provided."
49
+ assert "params_list" in kwargs, "params list must be provided."
50
+ assert "param" in kwargs, "param must be provided."
51
+
52
+ qtz_func = kwargs.get('qtz_func')
53
+ deqtz_func = kwargs.get('deqtz_func')
54
+ params_list = kwargs.get('params_list')
55
+ param = kwargs.get('param')
56
+
57
+ n_runs = 50 # Number of runs to try to find the best parameters
58
+ n_random_params = 50 # Number of random parameters to generate
59
+ n_best_to_pick = 5 # Number of best parameters to pick after each run
60
+ max_initial = 10000 # Maximum value to initialize the parameters
61
+
62
+ # Initializes the parameters
63
+ base_params = { p: init_ones(x, **kwargs) for p in params_list if p != param }
64
+ params = _build_initial_param(x, max_initial, n_random_params)
65
+
66
+ # Performs the search
67
+ for _ in range(n_runs):
68
+
69
+ best_params = []
70
+ for param_ in params:
71
+ try:
72
+ x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: param_})
73
+ loss_ones = nn.MSELoss()(x, x_)
74
+
75
+ if len(best_params) < n_best_to_pick:
76
+ best_params.append((param_, loss_ones.item()))
77
+ best_params = sorted(best_params, key=lambda x: x[1])
78
+ elif loss_ones < best_params[-1][1]:
79
+ best_params[-1] = (param_, loss_ones.item())
80
+ best_params = sorted(best_params, key=lambda x: x[1])
81
+
82
+ except Exception: # The parameters might not be valid for the function's domain
83
+ continue
84
+
85
+ # Generates new parameters around the mean
86
+ params = _search_param([p for p, _ in best_params], n_random_params)
87
+
88
+ return best_params[0][0]
89
+
90
+
91
+ def init_linear_scale( # Symmetric scale. From the study folder
92
+ x: torch.Tensor,
93
+ **kwargs: Dict[str, Any],
94
+ ) -> torch.Tensor:
95
+ assert "bits" in kwargs, "bits must be provided."
96
+ assert "params" in kwargs, "params must be provided."
97
+ assert "qtz_func" in kwargs, "qtz_func must be provided."
98
+
99
+ bits = kwargs.get('bits')
100
+ params = kwargs.get('params')
101
+ qtz_func = kwargs.get('qtz_func')
102
+
103
+ x_ = x.transpose(0, 1)
104
+ x_ = qtz_func(x=x_, **params, _s=init_ones(x, **kwargs))
105
+ x_ = x_.transpose(0, 1)
106
+
107
+ quant_min, quant_max = get_min_max_from_bits_signed(bits)
108
+ min_vals, max_vals = torch.aminmax(x_, dim=1)
109
+ min_vals = torch.min(min_vals, torch.zeros_like(min_vals))
110
+ max_vals = torch.max(max_vals, torch.zeros_like(max_vals))
111
+
112
+ eps = torch.finfo(torch.float32).eps
113
+
114
+ abs_max_val_per_ch = torch.max(-min_vals, max_vals)
115
+ scale = abs_max_val_per_ch / (float(quant_max - quant_min) / 2)
116
+
117
+ scale = torch.clamp(scale, min=eps).to(dtype=torch.float32, device=min_vals.device)
118
+
119
+ # Introduces some noise in scale
120
+ # If I don't introduce noise, the accuracy is going to be 0.0 and not learn anything
121
+ # scale = scale + 0.01 * torch.randn_like(scale)
122
+ return scale
123
+
124
+
125
+ def init_params(x: torch.Tensor, **kwargs: Dict[str, Any]) -> Dict[str, nn.Parameter]:
126
+ params = {
127
+ '_0': init_space_search(x, qtz_func=quantization, deqtz_func=dequantization, param='_0', params_list=['_0', '_s'], **kwargs),
128
+ }
129
+ params['_s'] = init_linear_scale(x, params=params, qtz_func=quantization, **kwargs)
130
+ params = {k: nn.Parameter(v, requires_grad=False) for k, v in params.items()}
131
+
132
+ if 'post_init_hook' in kwargs:
133
+ kwargs['post_init_hook'](parameters=params)
134
+
135
+
136
+ if 'post_train_hook' in kwargs:
137
+ kwargs['post_train_hook'](parameters=params)
138
+
139
+ return params
140
+
141
+
142
+ ############### Numpy Qtz ###############
143
+
144
+
145
+ def np_quantization(x, _0, _s):
146
+ return (np.divide(1, np_replace_num(_s, num=0, to=10000)) * np.arctan((_0 * x)))
147
+
148
+
149
+ def np_dequantization(x, _0, _s):
150
+ return (np.divide(1, np_replace_num(_0, num=0, to=10000)) * np.tan(np_domain_guard((_s * x), posinf=1, neginf=-1, nan=0)))
151
+
152
+
153
+ def fit_func(x, _0, _s):
154
+ x_ = np_quantization(x, _0, _s)
155
+ x_ = np_dequantization(x_, _0, _s)
156
+ return x_
157
+
158
+
159
+
160
+ ############### HELPERS ###############
161
+
162
+ def domain_guard(
163
+ x: torch.Tensor,
164
+ min: float = None,
165
+ max: float = None,
166
+ posinf: float = None,
167
+ neginf: float = None,
168
+ nan: float = None
169
+ ) -> torch.Tensor:
170
+ """Guard a tensor to a valid domain."""
171
+ x = torch.nan_to_num(x, posinf=posinf, neginf=neginf, nan=nan)
172
+ if min is not None or max is not None:
173
+ x = torch.clamp(x, min=min, max=max)
174
+ return x
175
+
176
+
177
+ def replace_num(x: torch.Tensor, num: float, to: float) -> torch.Tensor:
178
+ """Replace a number in a tensor with another number.
179
+
180
+ Args:
181
+ x (torch.Tensor): The input tensor.
182
+ num (float): The number to replace.
183
+ to (float): The number to replace with.
184
+
185
+ Returns:
186
+ torch.Tensor: The tensor with the number replaced.
187
+ """
188
+ return torch.where(x == num, to, x)
189
+
190
+
191
+ def guarded_torch_power(x: torch.Tensor, exp: float) -> torch.Tensor:
192
+ """Guard the power operation to a valid domain."""
193
+ return torch.pow(x, exp) if exp >= 1 else torch.pow(torch.relu(x), exp)
194
+
195
+
196
+ def init_ones(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor:
197
+ val = torch.amin(x, dim=1)
198
+ return torch.ones_like(val, dtype=torch.float32, device=x.device)
199
+
200
+
201
+ def init_rand(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor:
202
+ val = torch.amin(x, dim=1)
203
+ return torch.randn_like(val, dtype=torch.float32, device=x.device)
204
+
205
+
206
+ def init_space_search(
207
+ x: torch.Tensor,
208
+ **kwargs: Dict[str, Any],
209
+ ) -> torch.Tensor:
210
+
211
+ def _build_initial_param(tensor: torch.Tensor, max_initial: int, n_params: int):
212
+ """Generates the initial set of parameters. The first iteration generates 10 times more parameters."""
213
+ for _ in range(n_params * 10): # The first iteration generates 10 times more parameters
214
+ yield init_rand(tensor) * max_initial # Generates n_params in range [-max_initial, max_initial]
215
+
216
+ def _search_param(tensors: List[torch.tensor], n_params):
217
+ """Takes the best parameters and generates new parameters around the mean of the best parameters."""
218
+ torch_tensors = torch.stack(tensors)
219
+ min_vals, max_vals = torch.aminmax(torch_tensors, dim=0)
220
+ abs_max_val_per_ch = torch.max(-min_vals, max_vals)
221
+ mean = torch.mean(torch_tensors, dim=0)
222
+ for _ in range(n_params): # Generates n_params around the mean of the tensors
223
+ yield torch.randn_like(min_vals) * abs_max_val_per_ch + mean
224
+
225
+ def _calc(x, qtz_func, deqtz_func, **params):
226
+ x_ = x.transpose(0, 1)
227
+ x_ = qtz_func(x=x_, **params)
228
+ x_ = deqtz_func(x=x_, **params)
229
+ x_ = x_.transpose(0, 1)
230
+ return x_
231
+
232
+ assert "qtz_func" in kwargs, "qtz_func must be provided."
233
+ assert "deqtz_func" in kwargs, "deqtz_func must be provided."
234
+ assert "params_list" in kwargs, "params list must be provided."
235
+ assert "param" in kwargs, "param must be provided."
236
+
237
+ qtz_func = kwargs.get('qtz_func')
238
+ deqtz_func = kwargs.get('deqtz_func')
239
+ params_list = kwargs.get('params_list')
240
+ param = kwargs.get('param')
241
+
242
+ n_runs = 50 # Number of runs to try to find the best parameters
243
+ n_random_params = 50 # Number of random parameters to generate
244
+ n_best_to_pick = 5 # Number of best parameters to pick after each run
245
+ max_initial = 10000 # Maximum value to initialize the parameters
246
+
247
+ # Initializes the parameters
248
+ base_params = { p: init_ones(x, **kwargs) for p in params_list if p != param }
249
+ params = _build_initial_param(x, max_initial, n_random_params)
250
+
251
+ # Performs the search
252
+ for _ in range(n_runs):
253
+
254
+ best_params = []
255
+ for param_ in params:
256
+ try:
257
+ x_ = _calc(x, qtz_func, deqtz_func, **base_params, **{param: param_})
258
+ loss_ones = nn.MSELoss()(x, x_)
259
+
260
+ if len(best_params) < n_best_to_pick:
261
+ best_params.append((param_, loss_ones.item()))
262
+ best_params = sorted(best_params, key=lambda x: x[1])
263
+ elif loss_ones < best_params[-1][1]:
264
+ best_params[-1] = (param_, loss_ones.item())
265
+ best_params = sorted(best_params, key=lambda x: x[1])
266
+
267
+ except Exception: # The parameters might not be valid for the function's domain
268
+ continue
269
+
270
+ # Generates new parameters around the mean
271
+ params = _search_param([p for p, _ in best_params], n_random_params)
272
+
273
+ return best_params[0][0]
274
+
275
+
276
+ def init_linear_scale( # Symmetric scale. From the study folder
277
+ x: torch.Tensor,
278
+ **kwargs: Dict[str, Any],
279
+ ) -> torch.Tensor:
280
+ assert "bits" in kwargs, "bits must be provided."
281
+ assert "params" in kwargs, "params must be provided."
282
+ assert "qtz_func" in kwargs, "qtz_func must be provided."
283
+
284
+ bits = kwargs.get('bits')
285
+ params = kwargs.get('params')
286
+ qtz_func = kwargs.get('qtz_func')
287
+
288
+ x_ = x.transpose(0, 1)
289
+ x_ = qtz_func(x=x_, **params, _s=init_ones(x, **kwargs))
290
+ x_ = x_.transpose(0, 1)
291
+
292
+ quant_min, quant_max = get_min_max_from_bits_signed(bits)
293
+ min_vals, max_vals = torch.aminmax(x_, dim=1)
294
+ min_vals = torch.min(min_vals, torch.zeros_like(min_vals))
295
+ max_vals = torch.max(max_vals, torch.zeros_like(max_vals))
296
+
297
+ eps = torch.finfo(torch.float32).eps
298
+
299
+ abs_max_val_per_ch = torch.max(-min_vals, max_vals)
300
+ scale = abs_max_val_per_ch / (float(quant_max - quant_min) / 2)
301
+
302
+ scale = torch.clamp(scale, min=eps).to(dtype=torch.float32, device=min_vals.device)
303
+
304
+ # Introduces some noise in scale
305
+ # If I don't introduce noise, the accuracy is going to be 0.0 and not learn anything
306
+ # scale = scale + 0.01 * torch.randn_like(scale)
307
+ return scale
308
+
309
+
310
+ def init_non_linear_regression_fit(
311
+ x: torch.Tensor,
312
+ **kwargs: Dict[str, Any],
313
+ ) -> Dict[str, torch.Tensor]:
314
+
315
+ assert "params_list" in kwargs, "params list must be provided."
316
+ assert "np_fit_func" in kwargs, "np_fit_func must be provided."
317
+ assert "p0" in kwargs, "p0 must be provided."
318
+ np_fit_func = kwargs.get('np_fit_func')
319
+ params_list = kwargs.get('params_list')
320
+ p0 = kwargs.get('p0')
321
+
322
+ def _fit(xdata: np.ndarray, ydata: np.ndarray, func: Callable, p0: List[float]):
323
+ popt, _ = curve_fit(
324
+ func,
325
+ xdata,
326
+ ydata,
327
+ maxfev=1000,
328
+ p0=p0,
329
+ method='lm'
330
+ )
331
+ return popt
332
+
333
+ # 1. Needs to convert the torch tensor to numpy tensor
334
+ xdata = x.cpu().numpy()
335
+
336
+ # 2. Sorts the data so that it makes it easier to fit to it
337
+ sorted_xdata = np.sort(xdata, axis=-1)
338
+
339
+ p0 = {k: v.cpu().numpy() for k, v in p0.items()}
340
+ params_list = sorted(params_list) # We need to make sure that it matches the numpy fit func arg order
341
+
342
+ # 3. Finds the best parameters for each channel
343
+ try:
344
+ params = []
345
+ for i in range(sorted_xdata.shape[0]):
346
+ xdata_ = sorted_xdata[i]
347
+ p0_ = [p0[p][i] for p in params_list]
348
+ ch_params = _fit(xdata_, xdata_, np_fit_func, p0_)
349
+ params.append(ch_params)
350
+
351
+ # 4. Builds the parameters
352
+ result = {}
353
+ for i, p in enumerate(params_list):
354
+ result[p] = torch.tensor([p_[i] for p_ in params], dtype=torch.float32).to(x.device)
355
+
356
+ return result
357
+
358
+ except ValueError as e:
359
+ print(f"Could not fit the function with error: {e}")
360
+ print(f"Using fallback result...")
361
+ return {
362
+ k: torch.tensor(v, dtype=torch.float32).to(x.device) for k, v in p0.items()
363
+ }
364
+
365
+
366
+ def init_zeros(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor:
367
+ val = torch.amin(x, dim=1)
368
+ return torch.zeros_like(val, dtype=torch.float32, device=x.device)
369
+
370
+
371
+ def init_inner_scale(tensor: torch.Tensor, _min: float = torch.inf, _max: float = torch.inf) -> torch.Tensor:
372
+ # Calculate the original minimum and maximum values
373
+ min_vals, max_vals = torch.aminmax(tensor, dim=-1)
374
+ x_min = torch.min(min_vals, torch.zeros_like(min_vals))
375
+ x_max = torch.max(max_vals, torch.zeros_like(max_vals))
376
+
377
+ if _max is torch.inf: # We do not need to scale the tensor. Just need to move it
378
+ return torch.ones_like(x_min)
379
+
380
+ # Calculate the scale factor
381
+ scale = (_max - _min) / (x_max - x_min)
382
+ return scale
383
+
384
+
385
+
386
+ ############## Quant ###############
387
+
388
+ @torch.enable_grad()
389
+ def learn_parameters(
390
+ x: torch.Tensor,
391
+ params: Dict[str, nn.Parameter],
392
+ qtz_func: nn.Module,
393
+ deqtz_func: nn.Module,
394
+ bits: int,
395
+ target_dtype: torch.dtype,
396
+ epochs: int = 1000,
397
+ early_stop: bool = True,
398
+ do_report: bool = False
399
+ ) -> Tuple[Dict[str, nn.Parameter], torch.Tensor]:
400
+ loss_fn = nn.MSELoss()
401
+
402
+ # Determines the initial learning rate from the initial loss: the base learning rate is scaled
+ # by 10 to the power of half (integer division) of the loss's order of magnitude
404
+ quant = quantize(x, params, qtz_func, bits, target_dtype)
405
+ dequant = dequantize(quant, params, deqtz_func, bits, x.dtype)
406
+ loss = loss_fn(x, dequant)
407
+
408
+ base_lr = 0.1
409
+ exponent = int(np.floor(np.log10(loss.item())))
410
+ lr = base_lr * (10 ** (exponent // 2))
411
+
412
+ # Requires gradients in the parameters
413
+ for p in params.values():
414
+ p.requires_grad = True
415
+ p.grad = None
416
+
417
+ param_keys = list(params.keys())
418
+ param_values = list(params.values())
419
+
420
+ # Defines optimizer and loss function
421
+ optimizer = torch.optim.Adam(param_values, lr=lr)
422
+ scheduler = torch.optim.lr_scheduler.LinearLR(optimizer, start_factor=1.0, end_factor=0.01, total_iters=epochs // 10)
423
+
424
+ # Contains the best loss and the best parameters
425
+ best_loss = float("inf")
426
+ best_params = None
427
+
428
+ # Used to stop the search early
429
+ min_delta = 1e-7
430
+ acc_loss = []
431
+ percent_epochs_before_stop = 0.1
432
+
433
+ for i in range(epochs):
434
+ optimizer.zero_grad()
435
+
436
+ quant = quantize(x, params, qtz_func, bits, target_dtype)
437
+ dequant = dequantize(quant, params, deqtz_func, bits, x.dtype)
438
+ loss = loss_fn(x, dequant)
439
+
440
+ if loss.isnan() or loss.isinf():
441
+ raise Exception("Loss is NaN or Inf. Stopping the search.")
442
+
443
+ loss.backward()
444
+ optimizer.step()
445
+ scheduler.step()
446
+
447
+ acc_loss.append(loss.item())
448
+
449
+ # Reports loss every 10 steps
450
+ if i % 10 == 0 and do_report:
451
+ print(f"Epoch {i}: Loss {loss.item()}")
452
+
453
+ # Optimizes the parameter search by storing the best loss and the parameters
454
+ if loss.item() < best_loss:
455
+ best_loss = loss.item()
456
+ best_params = copy.deepcopy({
457
+ k: v for k, v in params.items() if k in param_keys
458
+ })
459
+
460
+ # We also stop the search if the loss has not changed considerably during the last 10% of epochs
461
+ if early_stop:
462
+ epochs_before_stop = int(epochs * percent_epochs_before_stop)
463
+ if i > epochs_before_stop and abs(acc_loss[i - epochs_before_stop] - acc_loss[i]) < min_delta:
464
+ break
465
+
466
+ # No longer requires gradients in the parameters
467
+ for p in best_params.values():
468
+ p.requires_grad = False
469
+ p.grad = None
470
+
471
+ if do_report:
472
+ return best_params, acc_loss
473
+ else:
474
+ return best_params
475
+
476
+
477
+ def quantize(
478
+ x: torch.Tensor,
479
+ params: Dict[str, nn.Parameter],
480
+ func: nn.Module,
481
+ bits: int,
482
+ target_dtype: torch.dtype = torch.int8
483
+ ) -> torch.Tensor:
484
+ quant_min, quant_max = get_min_max_from_bits_signed(bits)
485
+ x = x.transpose(0, 1) # Aligns shapes
486
+ x = func(x=x, **params)
487
+ x = x.transpose(0, 1)
488
+ x = torch.clamp(round_func_BPDA(x), quant_min, quant_max).to(target_dtype)
489
+ return x
490
+
491
+
492
+ def dequantize(
493
+ x: torch.Tensor,
494
+ params: Dict[str, nn.Parameter],
495
+ func: nn.Module,
496
+ bits: int,
497
+ out_dtype: torch.dtype
498
+ ) -> torch.Tensor:
499
+ x = x.to(dtype=out_dtype)
500
+ x = x.transpose(0, 1)
501
+ x = func(x=x, **params)
502
+ x = x.transpose(0, 1)
503
+ return x
504
+
505
+
506
+ def round_func_BPDA(input):
507
+ # This is equivalent to replacing round function (non-differentiable) with
508
+ # an identity function (differentiable) only when backward.
509
+ forward_value = torch.round(input)
510
+ out = input.clone()
511
+ out.data = forward_value.data
512
+ return out
513
+
514
+
515
+ def get_min_max_from_bits_signed(bit_width: int) -> Tuple[int, int]:
516
+ return -2 ** (bit_width - 1), 2 ** (bit_width - 1) - 1
517
+
518
+
519
+
520
+ ############## Numpy ###############
521
+
522
+ def np_domain_guard(
523
+ x: np.ndarray,
524
+ min: float = None,
525
+ max: float = None,
526
+ posinf: float = None,
527
+ neginf: float = None,
528
+ nan: float = None
529
+ ) -> np.ndarray:
530
+ """Guard a tensor to a valid domain."""
531
+ x = np.nan_to_num(x, posinf=posinf, neginf=neginf, nan=nan)
532
+ if min is not None or max is not None:
533
+ x = np.clip(x, min, max)
534
+ return x
535
+
536
+
537
+ def np_replace_num(x: np.ndarray, num: float, to: float) -> np.ndarray:
538
+ """Replace a number in a tensor with another number.
539
+
540
+ Args:
541
+ x (np.ndarray): The input tensor.
542
+ num (float): The number to replace.
543
+ to (float): The number to replace with.
544
+
545
+ Returns:
546
+ np.ndarray: The tensor with the number replaced.
547
+ """
548
+ return np.where(x == num, to, x)
549
+
550
+
551
+ def np_guarded_power(x: np.ndarray, exp: float) -> np.ndarray:
552
+ """Guard the power operation to a valid domain."""
553
+ return np.power(x, exp) if exp >= 1 else np.power(np.maximum(x, 0), exp)
554
+
fn_gen/rnd_search/16/loss.png ADDED
fn_gen/rnd_search/16/quantization.png ADDED
fn_gen/rnd_search/17/__pycache__/fn.cpython-311.pyc ADDED
Binary file (28.2 kB). View file
 
fn_gen/rnd_search/17/distortion.png ADDED