Spaces:
Sleeping
Sleeping
File size: 1,493 Bytes
bc679dd |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 |
import json
import argparse
from scipy import spatial
import numpy as np
from tqdm import tqdm
from matplotlib import pyplot as plt
# Command-line interface: two JSON feature files whose ROI features are
# compared pairwise, the name of the distance function to use, and an
# optional path to save the resulting figure to.
parser = argparse.ArgumentParser(
    description="Plot the pairwise distance matrix between the ROI features "
    "stored in two JSON feature files.",
)
parser.add_argument(
    "--rows",
    type=str,
    required=True,
    help="JSON feature file whose ROIs form the rows of the matrix",
)
parser.add_argument(
    "--cols",
    type=str,
    required=True,
    help="JSON feature file whose ROIs form the columns of the matrix",
)
parser.add_argument(
    "--distance",
    type=str,
    default="cosine",
    help="name of a function in scipy.spatial.distance (default: %(default)s)",
)
parser.add_argument(
    "--output",
    type=str,
    default="",
    help="path to save the figure to; if omitted, the figure is shown "
    "interactively",
)
args = parser.parse_args()
# Load both feature files. Each top-level JSON value is iterated with
# .values() when flattening, so it is expected to be an object whose values
# are lists of ROI dicts (each carrying "type" and "features" keys).
# Using `with` closes the files promptly instead of leaking the handles;
# JSON text is UTF-8, so decode it as such regardless of the locale.
with open(args.rows, "r", encoding="utf-8") as rows_file:
    rows_features = json.load(rows_file)
with open(args.cols, "r", encoding="utf-8") as cols_file:
    cols_features = json.load(cols_file)

# Resolve the distance function by name. Report an unknown name as a usage
# error (exit status 2 with the usage line) instead of an AttributeError
# traceback.
dist_fn = getattr(spatial.distance, args.distance, None)
if not callable(dist_fn):
    parser.error(
        "--distance: %r is not a function in scipy.spatial.distance"
        % (args.distance,)
    )
# Flatten each {key: [roi, ...]} mapping into a single list of ROI dicts,
# then sort by the ROI "type" so ROIs of the same type sit next to each other
# in the distance matrix. sorted() is stable, so ROIs sharing a type keep
# the order in which they appear in the file.
rows_features_rois = sorted(
    (roi for rois in rows_features.values() for roi in rois),
    key=lambda roi: roi["type"],
)
cols_features_rois = sorted(
    (roi for rois in cols_features.values() for roi in rois),
    key=lambda roi: roi["type"],
)
# Pairwise distance matrix: entry [i, j] is dist_fn applied to the features
# of the i-th row ROI and the j-th column ROI (in type-sorted order).
# Convert each feature list to an ndarray once up front. The
# scipy.spatial.distance functions coerce their inputs to arrays on every
# call, so passing raw lists would repeat that conversion n*m times.
row_vectors = [np.asarray(roi["features"]) for roi in rows_features_rois]
col_vectors = [np.asarray(roi["features"]) for roi in cols_features_rois]
matrix = np.zeros((len(row_vectors), len(col_vectors)))
for i, row_vec in tqdm(enumerate(row_vectors), total=len(row_vectors)):
    for j, col_vec in enumerate(col_vectors):
        matrix[i, j] = dist_fn(row_vec, col_vec)
# Render the distance matrix as a heatmap with a colorbar. imshow draws
# matrix row i along the vertical axis and column j along the horizontal
# axis. Matrix rows come from the --rows file and columns from the --cols
# file, so the y axis is labelled with args.rows and the x axis with
# args.cols.
fig, ax = plt.subplots()
ax.set_xlabel(args.cols)
ax.set_ylabel(args.rows)
image = ax.imshow(matrix)
fig.colorbar(image, ax=ax)
if not args.output:
    # No output path given: display the figure in an interactive window.
    plt.show()
else:
    fig.savefig(args.output)
|