aiisc-watermarking-modelv3 / threeD_plot.py
jgyasu's picture
Upload folder using huggingface_hub
436c4c1 verified
# import numpy as np
# import plotly.graph_objects as go
# from scipy.interpolate import griddata
# def gen_three_D_plot(detectability_val, distortion_val, euclidean_val):
# detectability = np.array(detectability_val)
# distortion = np.array(distortion_val)
# euclidean = np.array(euclidean_val)
# # Find the closest point to the origin
# distances_to_origin = np.linalg.norm(np.array([distortion, detectability, euclidean]).T, axis=1)
# closest_point_index = np.argmin(distances_to_origin)
# # Determine the closest points to each axis
# closest_to_x_axis = np.argmin(distortion)
# closest_to_y_axis = np.argmin(detectability)
# closest_to_z_axis = np.argmin(euclidean)
# # Use the detected closest point as the "sweet spot"
# sweet_spot_detectability = detectability[closest_point_index]
# sweet_spot_distortion = distortion[closest_point_index]
# sweet_spot_euclidean = euclidean[closest_point_index]
# # Create a meshgrid from the data
# x_grid, y_grid = np.meshgrid(np.linspace(min(detectability), max(detectability), 30),
# np.linspace(min(distortion), max(distortion), 30))
# # Interpolate z values (Euclidean distances) to fit the grid
# z_grid = griddata((detectability, distortion), euclidean, (x_grid, y_grid), method='linear')
# if z_grid is None:
# raise ValueError("griddata could not generate a valid interpolation. Check your input data.")
# # Create the 3D contour plot with the Plasma color scale
# fig = go.Figure(data=go.Surface(
# z=z_grid,
# x=x_grid,
# y=y_grid,
# contours={
# "z": {"show": True, "start": min(euclidean), "end": max(euclidean), "size": 0.1, "usecolormap": True}
# },
# colorscale='Plasma'
# ))
# # Add a marker for the sweet spot
# fig.add_trace(go.Scatter3d(
# x=[sweet_spot_detectability],
# y=[sweet_spot_distortion],
# z=[sweet_spot_euclidean],
# mode='markers+text',
# marker=dict(size=10, color='red', symbol='circle'),
# text=["Sweet Spot"],
# textposition="top center"
# ))
# # Set axis labels
# fig.update_layout(
# scene=dict(
# xaxis_title='Detectability Score',
# yaxis_title='Distortion Score',
# zaxis_title='Euclidean Distance'
# ),
# margin=dict(l=0, r=0, b=0, t=0)
# )
# return fig
import numpy as np
import plotly.graph_objects as go
from scipy.interpolate import griddata
def _min_max_normalize(values):
    """Linearly rescale a 1-D float array onto [0, 1].

    A constant array (max == min) has no spread to rescale, so it maps to all
    zeros rather than the NaNs a plain ``0 / 0`` division would produce.
    """
    lowest = values.min()
    span = values.max() - lowest
    if span == 0:
        return np.zeros_like(values, dtype=float)
    return (values - lowest) / span


def _find_sweet_spot_index(detectability, distortion, euclidean):
    """Return the index of the sample with the best composite score.

    Composite score = norm(detectability) - (norm(distortion) + norm(euclidean)),
    i.e. detectability is maximized while distortion and Euclidean distance
    are minimized. Each metric is min-max normalized first so that none of
    them dominates merely because of its units.
    """
    composite_score = _min_max_normalize(detectability) - (
        _min_max_normalize(distortion) + _min_max_normalize(euclidean)
    )
    return int(np.argmax(composite_score))


def gen_three_D_plot(detectability_val, distortion_val, euclidean_val, *,
                     grid_resolution=30, contour_size=0.1):
    """Plot Euclidean distance as a 3-D surface over (detectability, distortion).

    The scattered ``(detectability, distortion, euclidean)`` samples are
    linearly interpolated onto a regular grid and drawn as a Plasma-coloured
    surface with z-contours. The "sweet spot" is the sample with the highest
    composite score: high detectability, low distortion and low Euclidean
    distance after min-max normalization. It is highlighted with a red marker.

    Args:
        detectability_val: Detectability score per sample (x axis).
        distortion_val: Distortion score per sample (y axis).
        euclidean_val: Euclidean distance per sample (z axis).
        grid_resolution: Number of grid points along each horizontal axis
            (default 30, as before).
        contour_size: Spacing between z-contour levels (default 0.1, as before).

    Returns:
        A ``plotly.graph_objects.Figure`` containing the surface and the
        sweet-spot marker.

    Raises:
        ValueError: If the inputs are empty, not one-dimensional, of different
            lengths, or if interpolation yields no finite surface value.
    """
    detectability = np.asarray(detectability_val, dtype=float)
    distortion = np.asarray(distortion_val, dtype=float)
    euclidean = np.asarray(euclidean_val, dtype=float)

    if not (detectability.ndim == distortion.ndim == euclidean.ndim == 1):
        raise ValueError("Inputs must be one-dimensional sequences of scores.")
    if not (len(detectability) == len(distortion) == len(euclidean)):
        raise ValueError("detectability, distortion and euclidean must have the same length.")
    if len(detectability) == 0:
        raise ValueError("At least one data point is required.")

    # Sweet spot: best trade-off between the three metrics.
    sweet_spot_index = _find_sweet_spot_index(detectability, distortion, euclidean)
    sweet_spot_detectability = detectability[sweet_spot_index]
    sweet_spot_distortion = distortion[sweet_spot_index]
    sweet_spot_euclidean = euclidean[sweet_spot_index]

    # Regular grid spanning the observed detectability (x) and distortion (y) ranges.
    x_grid, y_grid = np.meshgrid(
        np.linspace(detectability.min(), detectability.max(), grid_resolution),
        np.linspace(distortion.min(), distortion.max(), grid_resolution),
    )

    # Interpolate Euclidean distance (z) onto the grid. Linear griddata
    # triangulates the points with Qhull: degenerate inputs (fewer than three
    # non-collinear (x, y) pairs) raise from scipy, and grid points outside
    # the convex hull come back as NaN. griddata never returns None.
    z_grid = griddata((detectability, distortion), euclidean, (x_grid, y_grid), method='linear')
    if np.all(np.isnan(z_grid)):
        raise ValueError("griddata could not generate a valid interpolation. Check your input data.")

    # Create the 3D contour plot with the Plasma color scale
    fig = go.Figure(data=go.Surface(
        z=z_grid,
        x=x_grid,
        y=y_grid,
        contours={
            "z": {
                "show": True,
                "start": float(euclidean.min()),
                "end": float(euclidean.max()),
                "size": contour_size,
                "usecolormap": True,
            }
        },
        colorscale='Plasma'
    ))

    # Add a marker for the sweet spot
    fig.add_trace(go.Scatter3d(
        x=[sweet_spot_detectability],
        y=[sweet_spot_distortion],
        z=[sweet_spot_euclidean],
        mode='markers+text',
        marker=dict(size=10, color='red', symbol='circle'),
        text=["Sweet Spot"],
        textposition="top center"
    ))

    # Set axis labels
    fig.update_layout(
        scene=dict(
            xaxis_title='Detectability Score',
            yaxis_title='Distortion Score',
            zaxis_title='Euclidean Distance'
        ),
        margin=dict(l=0, r=0, b=0, t=0)
    )
    return fig