File size: 2,460 Bytes
ea6afa4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
import numpy as np
import plotly.graph_objects as go
from scipy.interpolate import griddata

def gen_three_D_plot(detectability_val, distortion_val, euclidean_val, *,
                     grid_resolution=30, contour_size=None):
    """Build a 3D surface of Euclidean distance over (detectability, distortion).

    The scattered samples are linearly interpolated onto a regular grid and
    drawn as a Plasma-coloured surface with z-contours. The sample with the
    best composite score is marked as the "Sweet Spot". A good score means
    high detectability, low distortion and low Euclidean distance, with each
    metric min-max normalised to [0, 1].

    Args:
        detectability_val: 1-D sequence of detectability scores (x axis).
        distortion_val: 1-D sequence of distortion scores (y axis), same length.
        euclidean_val: 1-D sequence of Euclidean distances (z axis), same length.
        grid_resolution: Number of grid points per axis used for the
            interpolated surface (keyword-only, default 30).
        contour_size: Step between z contour lines (keyword-only). ``None``
            uses one tenth of the Euclidean-distance range (0.1 when the range
            is zero). This keeps the number of contour lines independent of
            the data's units.

    Returns:
        A ``plotly.graph_objects.Figure`` with the surface and the sweet-spot
        marker.

    Raises:
        ValueError: if the inputs are not 1-D sequences of equal length, hold
            fewer than three points, or cannot be interpolated (for example
            when all points lie on a single line).
    """
    # Convert inputs to float arrays so arithmetic and interpolation are uniform.
    detectability = np.asarray(detectability_val, dtype=float)
    distortion = np.asarray(distortion_val, dtype=float)
    euclidean = np.asarray(euclidean_val, dtype=float)

    if not (detectability.ndim == distortion.ndim == euclidean.ndim == 1):
        raise ValueError("Inputs must be one-dimensional sequences.")
    if not (len(detectability) == len(distortion) == len(euclidean)):
        raise ValueError("Inputs must all have the same length.")
    # Linear 2-D interpolation triangulates the points, which needs at least 3.
    if len(detectability) < 3:
        raise ValueError("At least three data points are required.")
    if grid_resolution < 2:
        raise ValueError("grid_resolution must be at least 2.")

    def normalize(data):
        """Min-max scale *data* to [0, 1]; a constant array maps to zeros."""
        min_val, max_val = np.min(data), np.max(data)
        return (data - min_val) / (max_val - min_val) if max_val > min_val else np.zeros_like(data)

    norm_detectability = normalize(detectability)
    norm_distortion = normalize(distortion)
    norm_euclidean = normalize(euclidean)

    # Composite score: reward detectability, penalise distortion and distance.
    composite_score = norm_detectability - (norm_distortion + norm_euclidean)

    # The sweet spot is reported in the original (un-normalised) units.
    sweet_spot_index = np.argmax(composite_score)
    sweet_spot = (detectability[sweet_spot_index], distortion[sweet_spot_index], euclidean[sweet_spot_index])

    # Regular grid spanning the observed x/y ranges.
    x_grid, y_grid = np.meshgrid(
        np.linspace(np.min(detectability), np.max(detectability), grid_resolution),
        np.linspace(np.min(distortion), np.max(distortion), grid_resolution)
    )

    error_msg = "griddata could not generate a valid interpolation. Check your input data."
    try:
        z_grid = griddata((detectability, distortion), euclidean, (x_grid, y_grid), method='linear')
    except RuntimeError as err:
        # Degenerate point sets (e.g. collinear) make the underlying Delaunay
        # triangulation raise QhullError, a RuntimeError subclass.
        raise ValueError(error_msg) from err

    # griddata never returns None: cells outside the convex hull are NaN, so an
    # all-NaN grid is the real "no usable interpolation" signal.
    if np.all(np.isnan(z_grid)):
        raise ValueError(error_msg)

    z_min, z_max = float(np.min(euclidean)), float(np.max(euclidean))
    if contour_size is None:
        # Scale the contour step to the data range; a fixed step would draw
        # thousands of lines for large ranges and none for tiny ones.
        span = z_max - z_min
        contour_size = span / 10 if span > 0 else 0.1

    # 3D surface with z-contours coloured by the Plasma scale.
    fig = go.Figure(data=go.Surface(
        z=z_grid,
        x=x_grid,
        y=y_grid,
        contours={"z": {"show": True, "start": z_min, "end": z_max, "size": contour_size, "usecolormap": True}},
        colorscale='Plasma'
    ))

    # Marker for the sweet spot.
    fig.add_trace(go.Scatter3d(
        x=[sweet_spot[0]],
        y=[sweet_spot[1]],
        z=[sweet_spot[2]],
        mode='markers+text',
        marker=dict(size=10, color='red', symbol='circle'),
        text=["Sweet Spot"],
        textposition="top center"
    ))

    # Axis labels and a tight layout.
    fig.update_layout(
        scene=dict(
            xaxis_title='Detectability Score',
            yaxis_title='Distortion Score',
            zaxis_title='Euclidean Distance'
        ),
        margin=dict(l=0, r=0, b=0, t=0)
    )

    return fig