---
title: ""
---

# Introduction


```{python}
#| echo: false
import seaborn as sns
import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap
import torch


def quantization_error(tensor, dequantized_tensor):
    """Mean squared error between the original and de-quantized tensors."""
    return (dequantized_tensor - tensor).abs().square().mean()


def plot_quantization_errors(original_tensor, quantized_tensor, dequantized_tensor, dtype=torch.int8, n_bits=8):
    """
    Plot four matrices in a 2x2 grid: the original tensor, the quantized tensor,
    the de-quantized tensor, and the quantization error tensor.
    """
    # Create a figure of 4 plots arranged in 2 rows and 2 columns
    fig, axes = plt.subplots(2, 2, figsize=(8, 4))  # Adjust the size as needed

    # Flatten the axes array for easier indexing
    axes = axes.flatten()

    # Plot the original tensor
    plot_matrix(original_tensor, axes[0], 'Original Tensor', cmap=ListedColormap(['white']))

    # Get the quantization range and plot the quantized tensor
    q_min, q_max = torch.iinfo(dtype).min, torch.iinfo(dtype).max
    plot_matrix(quantized_tensor, axes[1], f'{n_bits}-bit Linear Quantized Tensor', vmin=q_min, vmax=q_max, cmap='coolwarm')

    # Plot the de-quantized tensor
    plot_matrix(dequantized_tensor, axes[2], 'Dequantized Tensor', cmap='coolwarm')

    # Calculate and plot the absolute quantization error
    q_error_tensor = abs(original_tensor - dequantized_tensor)
    plot_matrix(q_error_tensor, axes[3], 'Quantization Error Tensor', cmap=ListedColormap(['white']))

    # Adjust layout to prevent overlap
    fig.tight_layout()
    plt.show()


def plot_matrix(tensor, ax, title, vmin=0, vmax=1, cmap=None):
    """Plot a heatmap of a tensor using seaborn."""
    sns.heatmap(tensor.cpu().numpy(), ax=ax, vmin=vmin, vmax=vmax, cmap=cmap, annot=True, fmt=".2f", cbar=False)
    ax.set_title(title)
    ax.set_yticklabels([])
    ax.set_xticklabels([])


def linear_q_with_scale_and_zero_point(tensor, scale, zero_point, dtype=torch.int8):
    """Linearly quantize a floating-point tensor with a given scale and zero point."""
    # Scale and shift the tensor onto the integer grid
    scaled_and_shifted_tensor = tensor / scale + zero_point
    rounded_tensor = torch.round(scaled_and_shifted_tensor)

    # Clamp to the representable range of the target dtype and cast
    q_min = torch.iinfo(dtype).min
    q_max = torch.iinfo(dtype).max
    q_tensor = rounded_tensor.clamp(q_min, q_max).to(dtype)

    return q_tensor
```


```{python}
#| echo: false
# Set the random seed for reproducibility
torch.manual_seed(41)

# Define the desired range
a = -1024  # Lower bound of the range
b = 1024   # Upper bound of the range

# Create a 6x6 matrix with random numbers in the range [a, b]
test_tensor = a + (b - a) * torch.rand(6, 6)
test_tensor
```

# Mastering Tensor Quantization in PyTorch

Quantization is a powerful technique used to reduce the memory footprint of neural networks, making them faster and more efficient, particularly on devices with limited computational power like mobile phones and embedded systems. This guide dives deep into how quantization works using PyTorch and provides a step-by-step approach to quantize tensors effectively.

## Implementing Asymmetric Quantization in PyTorch

Quantization in the context of deep learning involves approximating a high-precision tensor (like a floating point tensor) with a lower-precision format (like integers). This is crucial for deploying models on hardware that supports or performs better with lower precision arithmetic.

Let's begin by understanding the two fundamental parameters of linear quantization: scale and zero point. The `scale` is the factor that stretches or shrinks the tensor's range to fit the dynamic range of the target data type (e.g., `int8`), and the `zero_point` is the integer that real-valued zero maps to, which lets an asymmetric range be represented without bias.
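
Concretely, and matching the code that follows, the two parameters are computed from the floating-point range $[r_{\min}, r_{\max}]$ and the integer range $[q_{\min}, q_{\max}]$:

$$
\text{Scale} = \frac{r_{\max} - r_{\min}}{q_{\max} - q_{\min}}, \qquad \text{Zero-point} = q_{\min} - \frac{r_{\min}}{\text{Scale}}
$$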

### Determining Scale and Zero Point

First, you need the minimum and maximum values that your chosen data type can hold. Here’s how you can find these for the `int8` type in PyTorch:

```{python}
import torch
q_min = torch.iinfo(torch.int8).min
q_max = torch.iinfo(torch.int8).max
print(f"Min: {q_min}, Max: {q_max}")
```

For our tensor `test_tensor`, find the minimum and maximum values:

```{python}
r_min = test_tensor.min().item()
r_max = test_tensor.max().item()
print(f"Min: {r_min}, Max: {r_max}")
```

With these values, you can compute the `scale` and `zero_point`:

```{python}
scale = (r_max - r_min) / (q_max - q_min)
zero_point = q_min - (r_min / scale)
print(f"Scale: {scale}, Zero-Point: {zero_point}")
```

### Automating Quantization

To streamline the process, you can define a function `get_q_scale_and_zero_point` that automatically computes the `scale` and `zero_point`:

```{python}
def get_q_scale_and_zero_point(tensor, dtype=torch.int8):
    r_min = tensor.min().item()
    r_max = tensor.max().item()
    q_min = torch.iinfo(dtype).min
    q_max = torch.iinfo(dtype).max
    scale = (r_max - r_min) / (q_max - q_min)
    zero_point = q_min - (r_min / scale)
    return scale, zero_point
```
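
As a quick sanity check, calling it on `test_tensor` should reproduce the values we computed manually above:

```{python}
# Should match the manually computed scale and zero-point
scale, zero_point = get_q_scale_and_zero_point(test_tensor)
print(f"Scale: {scale}, Zero-Point: {zero_point}")
```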

### Applying Quantization and Dequantization

Now, let's quantize and dequantize a tensor. Quantization maps real values to integers using the scale and zero point; dequantization approximately inverts that mapping:



```{python}
def linear_quantization(tensor, dtype=torch.int8):
    scale, zero_point = get_q_scale_and_zero_point(tensor, dtype=dtype)
    quantized_tensor = linear_q_with_scale_and_zero_point(tensor, scale, zero_point, dtype=dtype)
    return quantized_tensor, scale, zero_point


def linear_dequantization(quantized_tensor, scale, zero_point):
    dequantized_tensor = scale * (quantized_tensor.float() - zero_point)
    return dequantized_tensor
```



### Visualization of Quantization Effects

Finally, it’s insightful to visualize the effects of quantization:

```{python}
quantized_tensor, scale, zero_point = linear_quantization(test_tensor)
dequantized_tensor = linear_dequantization(quantized_tensor, scale, zero_point)

plot_quantization_errors(test_tensor, quantized_tensor, dequantized_tensor)
```





```{python}
# Calculate and print the quantization error
error = quantization_error(test_tensor, dequantized_tensor)
print(f"Quantization Error: {error}")
```
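
Besides accuracy, the other half of the trade-off is memory. A quick check of the raw storage used by the two tensors (float32 takes 4 bytes per element, int8 takes 1):

```{python}
# Compare raw storage: float32 original vs int8 quantized
original_bytes = test_tensor.element_size() * test_tensor.nelement()
quantized_bytes = quantized_tensor.element_size() * quantized_tensor.nelement()
print(f"Original: {original_bytes} bytes, Quantized: {quantized_bytes} bytes")
```

(The scale and zero-point add only a small constant overhead on top of this.)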







## Implementing Symmetric Quantization in PyTorch



Symmetric quantization is a variant of linear quantization in which the representable range is centered on zero. This simplifies the scheme: the zero point is fixed at zero, so it never needs to be computed or stored. Here, we explore how to implement symmetric quantization in PyTorch.
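
With the zero-point pinned at zero, the only parameter left to choose is the scale. As implemented below, it maps the largest absolute value in the tensor onto the largest representable integer:

$$
\text{Scale} = \frac{\max(|r|)}{q_{\max}}, \qquad \text{Zero-point} = 0
$$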



### Calculating the Scale for Symmetric Quantization



The scale factor in symmetric quantization is crucial as it defines the conversion ratio between the floating point values and their integer representations. The scale is computed based on the maximum absolute value in the tensor and the maximum value storable in the specified integer data type. Here's how you can calculate the scale:

```{python}
def get_q_scale_symmetric(tensor, dtype=torch.int8):
    r_max = tensor.abs().max().item()  # Get the maximum absolute value in the tensor
    q_max = torch.iinfo(dtype).max     # Get the maximum storable value for the dtype

    # Return the scale
    return r_max / q_max
```

### Testing the Scale Calculation

We'll test this function on our random 6x6 `test_tensor`:



```{python}
print(get_q_scale_symmetric(test_tensor))
```



### Performing Symmetric Quantization



Once the scale is determined, the tensor can be quantized by converting the floating-point values to integers based on that scale. Before writing the code, it helps to restate the two equations involved; in symmetric mode the zero-point is simply zero.



### Quantization Equation

The quantization equation transforms the original floating-point values into quantized integer values. This is achieved by scaling the original values down by the scale factor, then rounding them to the nearest integer, and finally adjusting by the zero-point:



$$
\text{Quantized Value} = \text{round}\left(\frac{\text{Original Value}}{\text{Scale}}\right) + \text{Zero-point}
$$



### Dequantization Equation

The dequantization equation reverses the quantization process to approximate the original floating-point values from the quantized integers. This involves subtracting the zero-point from the quantized value, and then scaling it up by the scale factor:



$$
\text{Dequantized Value} = (\text{Quantized Value} - \text{Zero-point}) \times \text{Scale}
$$



These two equations are the core of linear quantization: the first compresses values to integers, the second approximately recovers them, which is what enables efficient storage and computation in neural network models.
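
To make them concrete, here is a minimal round trip for a single element of `test_tensor`, using the symmetric scale (so the zero-point is zero):

```{python}
# Round-trip one element through the equations above
s = get_q_scale_symmetric(test_tensor)
zp = 0                               # zero-point is fixed at 0 in symmetric mode
x = test_tensor[0, 0].item()
q = round(x / s) + zp                # quantization equation
x_hat = (q - zp) * s                 # dequantization equation
print(f"original: {x:.4f}  quantized: {q}  dequantized: {x_hat:.4f}")
```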



```{python}
def linear_q_symmetric(tensor, dtype=torch.int8):
    # Calculate the scale
    scale = get_q_scale_symmetric(tensor)

    # Perform quantization with zero_point = 0 for symmetric mode
    quantized_tensor = linear_q_with_scale_and_zero_point(tensor, scale=scale, zero_point=0, dtype=dtype)

    return quantized_tensor, scale


quantized_tensor, scale = linear_q_symmetric(test_tensor)
```



### Dequantization and Error Visualization



Dequantization is the reverse process of quantization, converting integers back to floating-point numbers using the same scale and zero point. Here's how to dequantize and plot quantization errors:

```{python}
dequantized_tensor = linear_dequantization(quantized_tensor, scale, zero_point=0)

plot_quantization_errors(test_tensor, quantized_tensor, dequantized_tensor)
```


```{python}
error = quantization_error(test_tensor, dequantized_tensor)
print(f"Quantization Error: {error}")
```

### Understanding Per-Tensor Quantization

In per-tensor quantization, a single scale and zero point, derived from the entire tensor's range, are applied to every element. This works well when values do not vary dramatically in magnitude across different parts of the tensor, and it keeps the quantization scheme simple and uniform.



### Testing with a Sample Tensor



We'll quantize our `test_tensor` to see how per-tensor symmetric quantization behaves:

```{python}
quantized_tensor, scale = linear_q_symmetric(test_tensor)
dequantized_tensor = linear_dequantization(quantized_tensor, scale, 0)
```

### Visualizing Quantization Errors

To assess the impact of quantization on tensor values, we'll visualize the errors between original and dequantized tensors:



```{python}
plot_quantization_errors(test_tensor, quantized_tensor, dequantized_tensor)
```



### Quantization Error Analysis



Quantization error is a critical metric for evaluating the information lost to quantization. Here it is computed as the mean squared difference between the original and dequantized values:



```{python}
# Calculate and print the quantization error
error = quantization_error(test_tensor, dequantized_tensor)
print(f"Quantization Error: {error}")
```



## Understanding Per-channel Quantization





In per-channel quantization, each channel of a tensor (e.g., the weight tensor in convolutional layers) is treated as an independent unit for quantization. Here's a basic outline of the process:

1. **Determine Scale and Zero-point**: For each channel, calculate a scale and zero-point based on the range of data values present in that channel. This might involve finding the minimum and maximum values of each channel and then using these values to compute the scale and zero-point that map the floating-point numbers to integers.

2. **Quantization**: Apply the quantization formula to each channel using its respective scale and zero-point. This step converts the floating-point values to integers.

   $$ 
   \text{Quantized Value} = \text{round}\left(\frac{\text{Original Value}}{\text{Scale}}\right) + \text{Zero-point}
   $$

3. **Storage and Computation**: The quantized values are stored and used for computations in the quantized model. The unique scales and zero-points for each channel are also stored for use during dequantization or inference.

4. **Dequantization**: To convert the quantized integers back to floating-point numbers (e.g., during inference), the inverse operation is performed using the per-channel scales and zero-points.

   $$ 
   \text{Dequantized Value} = (\text{Quantized Value} - \text{Zero-point}) \times \text{Scale} 
   $$


```{python}
def linear_q_symmetric_per_channel(r_tensor, dim, dtype=torch.int8):
    output_dim = r_tensor.shape[dim]

    # Compute one scale per slice along `dim`
    scale = torch.zeros(output_dim)
    for index in range(output_dim):
        sub_tensor = r_tensor.select(dim, index)
        scale[index] = get_q_scale_symmetric(sub_tensor, dtype=dtype)

    # Reshape the scales so they broadcast along `dim`
    scale_shape = [1] * r_tensor.dim()
    scale_shape[dim] = -1
    scale = scale.view(scale_shape)

    # Quantize with zero_point = 0 (symmetric mode)
    quantized_tensor = linear_q_with_scale_and_zero_point(
        r_tensor, scale=scale, zero_point=0, dtype=dtype)

    return quantized_tensor, scale
```
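
The loop above keeps the logic explicit. As a side note, the same per-channel scales can also be computed in one shot by reducing $|r|$ over every dimension except `dim`; the helper below is a hypothetical vectorized equivalent, not part of the original walkthrough:

```{python}
def get_per_channel_scales_vectorized(r_tensor, dim, dtype=torch.int8):
    # Hypothetical alternative: reduce |r| over all dimensions except `dim`
    q_max = torch.iinfo(dtype).max
    reduce_dims = tuple(d for d in range(r_tensor.dim()) if d != dim)
    # keepdim=True keeps the result broadcastable against r_tensor
    return r_tensor.abs().amax(dim=reduce_dims, keepdim=True) / q_max


# Should agree with the scales produced by the loop-based implementation
print(torch.allclose(get_per_channel_scales_vectorized(test_tensor, dim=0),
                     linear_q_symmetric_per_channel(test_tensor, dim=0)[1]))
```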


### Scaled on Rows (Dim 0)


```{python}
quantized_tensor_0, scale_0 = linear_q_symmetric_per_channel(test_tensor, dim=0)

dequantized_tensor_0 = linear_dequantization(quantized_tensor_0, scale_0, 0)

plot_quantization_errors(
    test_tensor, quantized_tensor_0, dequantized_tensor_0)
```




```{python}
print(f"Quantization Error : {quantization_error(test_tensor, dequantized_tensor_0)}")
```


### Scaled on Columns (Dim 1)

```{python}
quantized_tensor_1, scale_1 = linear_q_symmetric_per_channel(test_tensor, dim=1)

dequantized_tensor_1 = linear_dequantization(quantized_tensor_1, scale_1, 0)

plot_quantization_errors(
    test_tensor, quantized_tensor_1, dequantized_tensor_1)
```


```{python}
print(f"Quantization Error : {quantization_error(test_tensor, dequantized_tensor_1)}")
```