import tempfile
import unittest

import numpy as np
import torch

from diffusers import ScoreSdeVeScheduler


class ScoreSdeVeSchedulerTest(unittest.TestCase):
    """Unit tests for ``ScoreSdeVeScheduler``.

    The tests cover three things:

    * a scheduler rebuilt from a saved config produces the same
      ``step_pred`` / ``step_correct`` outputs as the original one,
    * a short predictor/corrector loop reproduces fixed reference numbers,
    * ``step_pred`` returns tensors shaped like its input sample.
    """

    # Scheduler classes exercised by the looping tests.
    scheduler_classes = (ScoreSdeVeScheduler,)
    # Extra keyword arguments forwarded to every step call. Kept as an empty
    # sequence so that ``dict(forward_default_kwargs)`` yields a fresh ``{}``.
    forward_default_kwargs = ()

    @property
    def dummy_sample(self):
        """Return a random tensor of shape ``(4, 3, 8, 8)`` drawn with ``torch.rand``."""
        batch_size = 4
        num_channels = 3
        height = 8
        width = 8

        return torch.rand((batch_size, num_channels, height, width))

    @property
    def dummy_sample_deter(self):
        """Return a deterministic tensor of shape ``(4, 3, 8, 8)``.

        The values are ``arange(n) / n`` (``n`` being the element count), laid
        out as ``(channels, height, width, batch)`` and then permuted so the
        batch axis comes first.
        """
        batch_size = 4
        num_channels = 3
        height = 8
        width = 8

        num_elems = batch_size * num_channels * height * width
        sample = torch.arange(num_elems)
        sample = sample.reshape(num_channels, height, width, batch_size)
        sample = sample / num_elems
        sample = sample.permute(3, 0, 1, 2)

        return sample

    def dummy_model(self):
        """Return a stand-in model callable computing ``sample * t / (t + 1)``."""

        def model(sample, t, *args):
            return sample * t / (t + 1)

        return model

    def get_scheduler_config(self, **kwargs):
        """Build the scheduler constructor kwargs.

        Args:
            **kwargs: Entries that override or extend the default config.

        Returns:
            A new ``dict`` with the defaults updated by ``kwargs``.
        """
        config = {
            "num_train_timesteps": 2000,
            "snr": 0.15,
            "sigma_min": 0.01,
            "sigma_max": 1348,
            "sampling_eps": 1e-5,
        }

        config.update(**kwargs)
        return config

    def _check_save_reload_consistency(self, time_step, config_overrides, step_kwargs):
        """Assert a scheduler and its save/reload copy step identically.

        For every class in ``scheduler_classes`` a scheduler is built, its
        config is written to a temporary directory and loaded back with
        ``from_pretrained``. Both instances then run ``step_pred`` and
        ``step_correct`` on the same inputs, each call seeded through
        ``torch.manual_seed(0)``, and the summed absolute difference of the
        ``prev_sample`` outputs must stay below ``1e-5``.

        Args:
            time_step: Time step value handed to ``step_pred``.
            config_overrides: Mapping passed to ``get_scheduler_config``.
            step_kwargs: Extra keyword arguments forwarded to both step calls.
        """
        for scheduler_class in self.scheduler_classes:
            sample = self.dummy_sample
            residual = 0.1 * sample

            scheduler_config = self.get_scheduler_config(**config_overrides)
            scheduler = scheduler_class(**scheduler_config)

            with tempfile.TemporaryDirectory() as tmpdirname:
                scheduler.save_config(tmpdirname)
                new_scheduler = scheduler_class.from_pretrained(tmpdirname)

            # ``torch.manual_seed(0)`` is evaluated anew for every call so each
            # step starts from the same generator state.
            output = scheduler.step_pred(
                residual, time_step, sample, generator=torch.manual_seed(0), **step_kwargs
            ).prev_sample
            new_output = new_scheduler.step_pred(
                residual, time_step, sample, generator=torch.manual_seed(0), **step_kwargs
            ).prev_sample

            self.assertLess(
                torch.sum(torch.abs(output - new_output)).item(),
                1e-5,
                "Scheduler outputs are not identical",
            )

            output = scheduler.step_correct(
                residual, sample, generator=torch.manual_seed(0), **step_kwargs
            ).prev_sample
            new_output = new_scheduler.step_correct(
                residual, sample, generator=torch.manual_seed(0), **step_kwargs
            ).prev_sample

            self.assertLess(
                torch.sum(torch.abs(output - new_output)).item(),
                1e-5,
                "Scheduler correction are not identical",
            )

    def check_over_configs(self, time_step=0, **config):
        """Check save/reload consistency for a scheduler built with ``config`` overrides.

        Args:
            time_step: Time step value handed to ``step_pred``.
            **config: Overrides applied on top of the default scheduler config.
        """
        kwargs = dict(self.forward_default_kwargs)
        self._check_save_reload_consistency(time_step, config, kwargs)

    def check_over_forward(self, time_step=0, **forward_kwargs):
        """Check save/reload consistency with the default config and extra step kwargs.

        Args:
            time_step: Time step value handed to ``step_pred``.
            **forward_kwargs: Additional keyword arguments merged into
                ``forward_default_kwargs`` and forwarded to the step calls.
        """
        kwargs = dict(self.forward_default_kwargs)
        kwargs.update(forward_kwargs)
        self._check_save_reload_consistency(time_step, {}, kwargs)

    def test_timesteps(self):
        """Run the config check for several ``num_train_timesteps`` values."""
        for timesteps in [10, 100, 1000]:
            self.check_over_configs(num_train_timesteps=timesteps)

    def test_sigmas(self):
        """Run the config check for paired ``sigma_min`` / ``sigma_max`` values."""
        for sigma_min, sigma_max in zip([0.0001, 0.001, 0.01], [1, 100, 1000]):
            self.check_over_configs(sigma_min=sigma_min, sigma_max=sigma_max)

    def test_time_indices(self):
        """Run the forward check for several fractional time steps."""
        for t in [0.1, 0.5, 0.75]:
            self.check_over_forward(time_step=t)

    def test_full_loop_no_noise(self):
        """Run a 3-step corrector/predictor loop and compare against reference values.

        Each iteration applies ``config.correct_steps`` corrector updates
        followed by one predictor update, feeding the dummy model with the
        current sigma. The summed and mean absolute values of the final sample
        are compared to stored reference numbers.
        """
        kwargs = dict(self.forward_default_kwargs)

        scheduler_class = self.scheduler_classes[0]
        scheduler_config = self.get_scheduler_config()
        scheduler = scheduler_class(**scheduler_config)

        num_inference_steps = 3

        model = self.dummy_model()
        sample = self.dummy_sample_deter

        scheduler.set_sigmas(num_inference_steps)
        scheduler.set_timesteps(num_inference_steps)
        generator = torch.manual_seed(0)

        for i, t in enumerate(scheduler.timesteps):
            sigma_t = scheduler.sigmas[i]

            for _ in range(scheduler.config.correct_steps):
                with torch.no_grad():
                    model_output = model(sample, sigma_t)
                sample = scheduler.step_correct(model_output, sample, generator=generator, **kwargs).prev_sample

            with torch.no_grad():
                model_output = model(sample, sigma_t)

            # Only ``prev_sample`` feeds the next iteration.
            sample = scheduler.step_pred(model_output, t, sample, generator=generator, **kwargs).prev_sample

        result_sum = torch.sum(torch.abs(sample))
        result_mean = torch.mean(torch.abs(sample))

        self.assertTrue(np.isclose(result_sum.item(), 14372758528.0))
        self.assertTrue(np.isclose(result_mean.item(), 18714530.0))

    def test_step_shape(self):
        """Check that ``step_pred`` outputs keep the shape of the input sample."""
        kwargs = dict(self.forward_default_kwargs)

        num_inference_steps = kwargs.pop("num_inference_steps", None)

        for scheduler_class in self.scheduler_classes:
            scheduler_config = self.get_scheduler_config()
            scheduler = scheduler_class(**scheduler_config)

            sample = self.dummy_sample
            residual = 0.1 * sample

            if num_inference_steps is not None:
                if hasattr(scheduler, "set_timesteps"):
                    scheduler.set_timesteps(num_inference_steps)
                else:
                    # Without ``set_timesteps`` the step count is passed per call instead.
                    kwargs["num_inference_steps"] = num_inference_steps

            output_0 = scheduler.step_pred(residual, 0, sample, generator=torch.manual_seed(0), **kwargs).prev_sample
            output_1 = scheduler.step_pred(residual, 1, sample, generator=torch.manual_seed(0), **kwargs).prev_sample

            self.assertEqual(output_0.shape, sample.shape)
            self.assertEqual(output_0.shape, output_1.shape)