File size: 2,114 Bytes
bc679dd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
import argparse
import numpy as np
import json

from matplotlib import pyplot as plt
from matplotlib.colors import ListedColormap

from sklearn.decomposition import PCA
from sklearn.manifold import TSNE
from scipy import spatial

# Command-line interface. Every entry is (flag, add_argument keyword options);
# the flags are registered in the order listed here.
parser = argparse.ArgumentParser()

_cli_options = (
    ("--features-database", {"type": str, "required": True}),
    ("--decomposition", {"type": str, "default": "TSNE", "choices": ["TSNE", "PCA"]}),
    ("--output", {"type": str, "default": ""}),
    ("--fig-h", {"type": int, "default": 1080}),
    ("--fig-w", {"type": int, "default": 720}),
    ("--fig-dpi", {"type": int, "default": 100}),
    ("--distance", {"type": str, "default": "cosine"}),
    ("--point", {"type": str, "default": ""}),
)
for _flag, _options in _cli_options:
    parser.add_argument(_flag, **_options)

args = parser.parse_args()

# --point is an optional JSON document given on the command line; the empty
# default means "no extra point", represented as None.
point = json.loads(args.point) if args.point != "" else None


# Resolve the distance function by name from scipy.spatial.distance
# (raises AttributeError if --distance names no attribute of that module).
dist_fn = getattr(spatial.distance, args.distance)

# Load the JSON database; the context manager closes the file handle as soon
# as parsing is done instead of leaving it open until interpreter exit.
with open(args.features_database, "r") as database_file:
    features_database = json.load(database_file)

# The database is a mapping whose values are lists of records; each record
# provides a "features" entry (the vector) and a "type" entry (its class).
# The mapping keys are not needed here, so only the values are iterated.
features = []
classes = []
for feature_list in features_database.values():
    for feature in feature_list:
        features.append(feature["features"])
        classes.append(feature["type"])

# The optional query point goes in as the LAST row so that it is projected
# together with the database samples. It gets no entry in `classes`.
if point is not None:
    features.append(point)

features = np.array(features)
classes = np.array(classes)

# Project the feature vectors to 2-D with the method chosen by --decomposition.
# Only TSNE receives the --distance function; PCA is built without it.
if args.decomposition == "TSNE":
    decomposition = TSNE(n_components=2, metric=dist_fn)
elif args.decomposition == "PCA":
    decomposition = PCA(n_components=2)
transformed = decomposition.fit_transform(features)

if point is not None:
    # The query point was appended as the last row of `features`, so its
    # projection is the last row of `transformed`. It must be read BEFORE the
    # row is trimmed off; doing it the other way round would pick up the last
    # database sample instead and lose the query point entirely.
    transformed_point = transformed[-1, :]
    transformed = transformed[:-1, :]

# Figure size in inches is the requested pixel size divided by the dpi.
# NOTE(review): matplotlib's figsize is (width, height), so as written
# --fig-h drives the width and --fig-w the height -- confirm this is intended.
fig_size_inches = (args.fig_h / args.fig_dpi, args.fig_w / args.fig_dpi)
plt.figure(figsize=fig_size_inches, dpi=args.fig_dpi)

# One colour per class value: red, blue, green.
cmap = ListedColormap(["r", "b", "g"])
xs = transformed[:, 0]
ys = transformed[:, 1]
scatter = plt.scatter(xs, ys, c=classes, cmap=cmap, s=10)

# Mark the optional query point with a large black cross.
if point is not None:
    plt.scatter(transformed_point[0], transformed_point[1], marker="x", s=200, c="k")

# NOTE(review): the labels are paired positionally with the legend handles;
# presumably the class values sort into this order -- verify against the data.
legend_handles, _ = scatter.legend_elements()
plt.legend(handles=legend_handles, labels=["neoplastic", "aphthous", "traumatic"])

# With an --output path the figure is written to disk; otherwise it is shown
# in an interactive window.
if args.output:
    plt.savefig(args.output, dpi=args.fig_dpi)
else:
    plt.show()