# oral-lesions-detection / plots / plot_matrix_distance.py
# Last commit bc679dd by Federico Galatolo: "work in progress" (1.49 kB)
import json
import argparse
from scipy import spatial
import numpy as np
from tqdm import tqdm
from matplotlib import pyplot as plt
# Command-line interface. --rows and --cols point at JSON files whose values
# are lists of ROI records; each record is expected to carry a "type" entry
# (used for ordering) and a "features" entry (the vector that is compared).
parser = argparse.ArgumentParser(
    description="Plot the pairwise distance matrix between the ROI feature "
    "vectors of two JSON feature files."
)
parser.add_argument(
    "--rows",
    type=str,
    required=True,
    help="JSON feature file whose ROIs become the matrix rows (y-axis).",
)
parser.add_argument(
    "--cols",
    type=str,
    required=True,
    help="JSON feature file whose ROIs become the matrix columns (x-axis).",
)
parser.add_argument(
    "--distance",
    type=str,
    default="cosine",
    help="Name of a function in scipy.spatial.distance, e.g. cosine or "
    "euclidean (default: %(default)s).",
)
parser.add_argument(
    "--output",
    type=str,
    default="",
    help="Path to save the figure to; if omitted the figure is shown "
    "interactively.",
)
args = parser.parse_args()
# Resolve the distance metric by name from scipy.spatial.distance before
# loading anything, so a typo fails fast with a usage message instead of an
# AttributeError traceback after both feature files have been parsed.
# NOTE: any callable attribute of that module passes this check; a name that
# is not a two-vector distance function would still fail when it is called.
dist_fn = getattr(spatial.distance, args.distance, None)
if not callable(dist_fn):
    parser.error(
        f"unknown distance {args.distance!r}: "
        "not a function in scipy.spatial.distance"
    )

# Load both feature files, closing each handle as soon as it is parsed.
# JSON is UTF-8 by specification, so decode it as such rather than with the
# platform's locale encoding.
with open(args.rows, "r", encoding="utf-8") as rows_file:
    rows_features = json.load(rows_file)
with open(args.cols, "r", encoding="utf-8") as cols_file:
    cols_features = json.load(cols_file)
def _flatten_rois(features_by_key):
    """Concatenate the ROI lists stored under every key into one flat list."""
    return [roi for rois in features_by_key.values() for roi in rois]


def _roi_type(roi):
    """Sort key: the ROI's "type" entry."""
    return roi["type"]


# One flat list of ROI records per side, ordered by ROI type so that ROIs of
# the same type form contiguous bands in the matrix. The sort is stable, so
# within a type the original file order is kept.
rows_features_rois = _flatten_rois(rows_features)
cols_features_rois = _flatten_rois(cols_features)
rows_features_rois.sort(key=_roi_type)
cols_features_rois.sort(key=_roi_type)
# Dense distance matrix: cell (r, c) holds dist_fn applied to the "features"
# of row ROI r and column ROI c, filled one pair at a time with a progress
# bar over the rows.
n_row_rois = len(rows_features_rois)
n_col_rois = len(cols_features_rois)
matrix = np.zeros((n_row_rois, n_col_rois))
row_progress = tqdm(enumerate(rows_features_rois), total=n_row_rois)
for r_idx, r_roi in row_progress:
    for c_idx, c_roi in enumerate(cols_features_rois):
        matrix[r_idx, c_idx] = dist_fn(r_roi["features"], c_roi["features"])
# Render the matrix as an image. imshow draws matrix[i, j] at row i
# (vertical axis) and column j (horizontal axis), so the --rows file labels
# the y-axis and the --cols file labels the x-axis.
fig, ax = plt.subplots()
ax.set_xlabel(args.cols)
ax.set_ylabel(args.rows)
pos = ax.imshow(matrix)
fig.colorbar(pos, ax=ax, label=f"{args.distance} distance")
# An empty --output (the default) means "show interactively".
if args.output:
    fig.savefig(args.output)
else:
    plt.show()