File size: 8,526 Bytes
43b7e92 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 |
import torch
from diffusers import HeunDiscreteScheduler
from diffusers.utils.testing_utils import torch_device
from .test_schedulers import SchedulerCommonTest
class HeunDiscreteSchedulerTest(SchedulerCommonTest):
    """Tests for `HeunDiscreteScheduler` on top of the shared `SchedulerCommonTest` harness."""

    scheduler_classes = (HeunDiscreteScheduler,)
    num_inference_steps = 10

    def get_scheduler_config(self, **kwargs):
        """Return the default scheduler config with any `kwargs` overriding its entries."""
        config = {
            "num_train_timesteps": 1100,
            "beta_start": 0.0001,
            "beta_end": 0.02,
            "beta_schedule": "linear",
        }
        config.update(**kwargs)
        return config

    def test_timesteps(self):
        for timesteps in [10, 50, 100, 1000]:
            self.check_over_configs(num_train_timesteps=timesteps)

    def test_betas(self):
        for beta_start, beta_end in zip([0.00001, 0.0001, 0.001], [0.0002, 0.002, 0.02]):
            self.check_over_configs(beta_start=beta_start, beta_end=beta_end)

    def test_schedules(self):
        for schedule in ["linear", "scaled_linear", "exp"]:
            self.check_over_configs(beta_schedule=schedule)

    def test_clip_sample(self):
        for clip_sample_range in [1.0, 2.0, 3.0]:
            self.check_over_configs(clip_sample_range=clip_sample_range, clip_sample=True)

    def test_prediction_type(self):
        for prediction_type in ["epsilon", "v_prediction", "sample"]:
            self.check_over_configs(prediction_type=prediction_type)

    # ------------------------------------------------------------------
    # Shared helpers (leading underscore so pytest does not collect them)
    # ------------------------------------------------------------------

    def _build_scheduler(self, **config):
        """Instantiate the scheduler under test from the default config plus `config` overrides."""
        scheduler_class = self.scheduler_classes[0]
        return scheduler_class(**self.get_scheduler_config(**config))

    @staticmethod
    def _denoise(scheduler, model, sample, timesteps):
        """Step `scheduler` over `timesteps` with `model` and return the final sample.

        Every full-loop test runs the same scale -> predict -> step sequence, so it
        lives here once instead of being copied into each test.
        """
        for t in timesteps:
            sample = scheduler.scale_model_input(sample, t)
            model_output = model(sample, t)
            sample = scheduler.step(model_output, t, sample).prev_sample
        return sample

    def full_loop(self, **config):
        """Run the full loop over the default inference timesteps and return the final sample.

        Args:
            **config: overrides forwarded to `get_scheduler_config`.
        """
        scheduler = self._build_scheduler(**config)
        scheduler.set_timesteps(self.num_inference_steps)
        model = self.dummy_model()
        sample = (self.dummy_sample_deter * scheduler.init_noise_sigma).to(torch_device)
        return self._denoise(scheduler, model, sample, scheduler.timesteps)

    def full_loop_custom_timesteps(self, **config):
        """Like `full_loop`, but drive a fresh scheduler with an explicit `timesteps` tensor.

        The tensor is derived from the default schedule, so the result is expected
        to match `full_loop` (asserted by `test_custom_timesteps`).

        Args:
            **config: overrides forwarded to `get_scheduler_config`.
        """
        reference = self._build_scheduler(**config)
        reference.set_timesteps(self.num_inference_steps)
        default_timesteps = reference.timesteps
        # Keep the first timestep and every other one after it. Presumably this
        # undoes the repetition of interior timesteps in the second-order schedule
        # -- verify against HeunDiscreteScheduler.set_timesteps.
        custom_timesteps = torch.cat([default_timesteps[:1], default_timesteps[1::2]])

        # reset the timesteps using `timesteps` on a fresh scheduler
        scheduler = self._build_scheduler(**config)
        scheduler.set_timesteps(num_inference_steps=None, timesteps=custom_timesteps)

        model = self.dummy_model()
        sample = (self.dummy_sample_deter * scheduler.init_noise_sigma).to(torch_device)
        return self._denoise(scheduler, model, sample, scheduler.timesteps)

    # ------------------------------------------------------------------
    # Full-loop regression tests
    # ------------------------------------------------------------------

    def test_full_loop_no_noise(self):
        sample = self.full_loop()
        result_sum = torch.sum(torch.abs(sample))
        result_mean = torch.mean(torch.abs(sample))

        # CPU, MPS and CUDA all share the same reference values.
        assert abs(result_sum.item() - 0.1233) < 1e-2
        assert abs(result_mean.item() - 0.0002) < 1e-3

    def test_full_loop_with_v_prediction(self):
        sample = self.full_loop(prediction_type="v_prediction")
        result_sum = torch.sum(torch.abs(sample))
        result_mean = torch.mean(torch.abs(sample))

        if torch_device in ["cpu", "mps"]:
            assert abs(result_sum.item() - 4.6934e-07) < 1e-2
            assert abs(result_mean.item() - 6.1112e-10) < 1e-3
        else:
            # CUDA. The expected mean must be consistent with the expected sum:
            # mean(|x|) <= sum(|x|), so a mean of 2e-4 next to a sum of ~4.7e-7 was
            # impossible. Use the same mean as the CPU reference.
            assert abs(result_sum.item() - 4.693428650170972e-07) < 1e-2
            assert abs(result_mean.item() - 6.1112e-10) < 1e-3

    def test_full_loop_device(self):
        scheduler = self._build_scheduler()
        scheduler.set_timesteps(self.num_inference_steps, device=torch_device)
        model = self.dummy_model()
        sample = self.dummy_sample_deter.to(torch_device) * scheduler.init_noise_sigma
        sample = self._denoise(scheduler, model, sample, scheduler.timesteps)

        result_sum = torch.sum(torch.abs(sample))
        result_mean = torch.mean(torch.abs(sample))

        if str(torch_device).startswith("mps"):
            # Only the mean is checked on mps, with a larger tolerance.
            # NOTE(review): an earlier note said the sum varies between 148 and 156
            # on mps; that figure looks stale next to the 0.1233 reference -- confirm.
            assert abs(result_mean.item() - 0.0002) < 1e-2
        else:
            # CPU and CUDA share the same reference values.
            assert abs(result_sum.item() - 0.1233) < 1e-2
            assert abs(result_mean.item() - 0.0002) < 1e-3

    def test_full_loop_device_karras_sigmas(self):
        scheduler = self._build_scheduler(use_karras_sigmas=True)
        scheduler.set_timesteps(self.num_inference_steps, device=torch_device)
        model = self.dummy_model()
        # `.to(torch_device)` already places the sample on the target device.
        sample = self.dummy_sample_deter.to(torch_device) * scheduler.init_noise_sigma
        sample = self._denoise(scheduler, model, sample, scheduler.timesteps)

        result_sum = torch.sum(torch.abs(sample))
        result_mean = torch.mean(torch.abs(sample))

        assert abs(result_sum.item() - 0.00015) < 1e-2
        assert abs(result_mean.item() - 1.9869554535034695e-07) < 1e-2

    def test_full_loop_with_noise(self):
        scheduler = self._build_scheduler()
        scheduler.set_timesteps(self.num_inference_steps)
        model = self.dummy_model()
        sample = (self.dummy_sample_deter * scheduler.init_noise_sigma).to(torch_device)

        # Start partway through the schedule: skip the first `t_start` inference steps,
        # scaled by `scheduler.order` entries per step.
        t_start = self.num_inference_steps - 2
        noise = self.dummy_noise_deter.to(torch_device)
        timesteps = scheduler.timesteps[t_start * scheduler.order :]
        sample = scheduler.add_noise(sample, noise, timesteps[:1])

        sample = self._denoise(scheduler, model, sample, timesteps)

        result_sum = torch.sum(torch.abs(sample))
        result_mean = torch.mean(torch.abs(sample))

        assert abs(result_sum.item() - 75074.8906) < 1e-2, f" expected result sum 75074.8906, but get {result_sum}"
        assert abs(result_mean.item() - 97.7538) < 1e-3, f" expected result mean 97.7538, but get {result_mean}"

    def test_custom_timesteps(self):
        """Passing the default schedule explicitly must reproduce `full_loop`'s output."""
        for prediction_type in ["epsilon", "sample", "v_prediction"]:
            for timestep_spacing in ["linspace", "leading"]:
                sample = self.full_loop(
                    prediction_type=prediction_type,
                    timestep_spacing=timestep_spacing,
                )
                sample_custom_timesteps = self.full_loop_custom_timesteps(
                    prediction_type=prediction_type,
                    timestep_spacing=timestep_spacing,
                )
                assert (
                    torch.sum(torch.abs(sample - sample_custom_timesteps)) < 1e-5
                ), f"Scheduler outputs are not identical for prediction_type: {prediction_type}, timestep_spacing: {timestep_spacing}"
|