import seaborn as sns
import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap
import torch


def plot_matrix(tensor, ax, title, vmin=0, vmax=1, cmap=None):
    """
    Plot a 2-D tensor as an annotated seaborn heatmap on ``ax``.

    Args:
        tensor: 2-D torch tensor to draw. It may live on any device and may
            require grad; it is detached and copied to the CPU before plotting.
        ax: matplotlib Axes to draw into.
        title: Title shown above the heatmap.
        vmin, vmax: Colour-scale limits forwarded to ``sns.heatmap``; pass
            ``None`` to let seaborn infer them from the data.
        cmap: Colormap forwarded to ``sns.heatmap``.
    """
    # detach() so tensors that require grad can still be converted to numpy;
    # numpy has no bfloat16 dtype, so upcast that case before conversion.
    data = tensor.detach().cpu()
    if data.dtype == torch.bfloat16:
        data = data.float()
    sns.heatmap(data.numpy(), ax=ax, vmin=vmin, vmax=vmax, cmap=cmap,
                annot=True, fmt=".2f", cbar=False)
    ax.set_title(title)
    # Every cell is annotated with its value, so tick labels are hidden.
    ax.set_yticklabels([])
    ax.set_xticklabels([])


def plot_quantization_errors(original_tensor, quantized_tensor,
                             dequantized_tensor, dtype=torch.int8,
                             n_bits=None):
    """
    Plot four heatmaps side by side: the original tensor, the quantized
    tensor, the de-quantized tensor and the absolute quantization error.

    Args:
        original_tensor: Tensor before quantization.
        quantized_tensor: Quantized tensor to display.
        dequantized_tensor: Tensor reconstructed from ``quantized_tensor``.
        dtype: Integer dtype whose representable range sets the colour
            scale of the quantized plot.
        n_bits: Bit width shown in the quantized plot's title. Defaults to
            the bit width of ``dtype`` (8 for ``torch.int8``).
    """
    q_info = torch.iinfo(dtype)
    if n_bits is None:
        n_bits = q_info.bits

    # One row of four plots.
    fig, axes = plt.subplots(1, 4, figsize=(15, 4))

    # Original values: a single-colour (white) map, so only annotations show.
    plot_matrix(original_tensor, axes[0], 'Original Tensor',
                cmap=ListedColormap(['white']))

    # Colour the quantized values over the full range representable by dtype.
    plot_matrix(quantized_tensor, axes[1],
                f'{n_bits}-bit Linear Quantized Tensor',
                vmin=q_info.min, vmax=q_info.max, cmap='coolwarm')

    # Let seaborn infer the colour range from the data; plot_matrix's default
    # [0, 1] range would clip any value outside it to the end colours.
    plot_matrix(dequantized_tensor, axes[2], 'Dequantized Tensor',
                vmin=None, vmax=None, cmap='coolwarm')

    # Element-wise absolute reconstruction error.
    q_error_tensor = (original_tensor - dequantized_tensor).abs()
    plot_matrix(q_error_tensor, axes[3], 'Quantization Error Tensor',
                cmap=ListedColormap(['white']))

    fig.tight_layout()
    plt.show()