File size: 6,832 Bytes
43b7e92 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 |
import torch
from diffusers import TCDScheduler
from .test_schedulers import SchedulerCommonTest
class TCDSchedulerTest(SchedulerCommonTest):
scheduler_classes = (TCDScheduler,)
forward_default_kwargs = (("num_inference_steps", 10),)
def get_scheduler_config(self, **kwargs):
config = {
"num_train_timesteps": 1000,
"beta_start": 0.00085,
"beta_end": 0.0120,
"beta_schedule": "scaled_linear",
"prediction_type": "epsilon",
}
config.update(**kwargs)
return config
@property
def default_num_inference_steps(self):
return 10
@property
def default_valid_timestep(self):
kwargs = dict(self.forward_default_kwargs)
num_inference_steps = kwargs.pop("num_inference_steps", None)
scheduler_config = self.get_scheduler_config()
scheduler = self.scheduler_classes[0](**scheduler_config)
scheduler.set_timesteps(num_inference_steps)
timestep = scheduler.timesteps[-1]
return timestep
def test_timesteps(self):
for timesteps in [100, 500, 1000]:
# 0 is not guaranteed to be in the timestep schedule, but timesteps - 1 is
self.check_over_configs(time_step=timesteps - 1, num_train_timesteps=timesteps)
def test_betas(self):
for beta_start, beta_end in zip([0.0001, 0.001, 0.01, 0.1], [0.002, 0.02, 0.2, 2]):
self.check_over_configs(time_step=self.default_valid_timestep, beta_start=beta_start, beta_end=beta_end)
def test_schedules(self):
for schedule in ["linear", "scaled_linear", "squaredcos_cap_v2"]:
self.check_over_configs(time_step=self.default_valid_timestep, beta_schedule=schedule)
def test_prediction_type(self):
for prediction_type in ["epsilon", "v_prediction"]:
self.check_over_configs(time_step=self.default_valid_timestep, prediction_type=prediction_type)
def test_clip_sample(self):
for clip_sample in [True, False]:
self.check_over_configs(time_step=self.default_valid_timestep, clip_sample=clip_sample)
def test_thresholding(self):
self.check_over_configs(time_step=self.default_valid_timestep, thresholding=False)
for threshold in [0.5, 1.0, 2.0]:
for prediction_type in ["epsilon", "v_prediction"]:
self.check_over_configs(
time_step=self.default_valid_timestep,
thresholding=True,
prediction_type=prediction_type,
sample_max_value=threshold,
)
def test_time_indices(self):
# Get default timestep schedule.
kwargs = dict(self.forward_default_kwargs)
num_inference_steps = kwargs.pop("num_inference_steps", None)
scheduler_config = self.get_scheduler_config()
scheduler = self.scheduler_classes[0](**scheduler_config)
scheduler.set_timesteps(num_inference_steps)
timesteps = scheduler.timesteps
for t in timesteps:
self.check_over_forward(time_step=t)
def test_inference_steps(self):
# Hardcoded for now
for t, num_inference_steps in zip([99, 39, 39, 19], [10, 25, 26, 50]):
self.check_over_forward(time_step=t, num_inference_steps=num_inference_steps)
def full_loop(self, num_inference_steps=10, seed=0, **config):
scheduler_class = self.scheduler_classes[0]
scheduler_config = self.get_scheduler_config(**config)
scheduler = scheduler_class(**scheduler_config)
eta = 0.0 # refer to gamma in the paper
model = self.dummy_model()
sample = self.dummy_sample_deter
generator = torch.manual_seed(seed)
scheduler.set_timesteps(num_inference_steps)
for t in scheduler.timesteps:
residual = model(sample, t)
sample = scheduler.step(residual, t, sample, eta, generator).prev_sample
return sample
def test_full_loop_onestep_deter(self):
sample = self.full_loop(num_inference_steps=1)
result_sum = torch.sum(torch.abs(sample))
result_mean = torch.mean(torch.abs(sample))
assert abs(result_sum.item() - 29.8715) < 1e-3 # 0.0778918
assert abs(result_mean.item() - 0.0389) < 1e-3
def test_full_loop_multistep_deter(self):
sample = self.full_loop(num_inference_steps=10)
result_sum = torch.sum(torch.abs(sample))
result_mean = torch.mean(torch.abs(sample))
assert abs(result_sum.item() - 181.2040) < 1e-3
assert abs(result_mean.item() - 0.2359) < 1e-3
def test_custom_timesteps(self):
scheduler_class = self.scheduler_classes[0]
scheduler_config = self.get_scheduler_config()
scheduler = scheduler_class(**scheduler_config)
timesteps = [100, 87, 50, 1, 0]
scheduler.set_timesteps(timesteps=timesteps)
scheduler_timesteps = scheduler.timesteps
for i, timestep in enumerate(scheduler_timesteps):
if i == len(timesteps) - 1:
expected_prev_t = -1
else:
expected_prev_t = timesteps[i + 1]
prev_t = scheduler.previous_timestep(timestep)
prev_t = prev_t.item()
self.assertEqual(prev_t, expected_prev_t)
def test_custom_timesteps_increasing_order(self):
scheduler_class = self.scheduler_classes[0]
scheduler_config = self.get_scheduler_config()
scheduler = scheduler_class(**scheduler_config)
timesteps = [100, 87, 50, 51, 0]
with self.assertRaises(ValueError, msg="`custom_timesteps` must be in descending order."):
scheduler.set_timesteps(timesteps=timesteps)
def test_custom_timesteps_passing_both_num_inference_steps_and_timesteps(self):
scheduler_class = self.scheduler_classes[0]
scheduler_config = self.get_scheduler_config()
scheduler = scheduler_class(**scheduler_config)
timesteps = [100, 87, 50, 1, 0]
num_inference_steps = len(timesteps)
with self.assertRaises(ValueError, msg="Can only pass one of `num_inference_steps` or `custom_timesteps`."):
scheduler.set_timesteps(num_inference_steps=num_inference_steps, timesteps=timesteps)
def test_custom_timesteps_too_large(self):
scheduler_class = self.scheduler_classes[0]
scheduler_config = self.get_scheduler_config()
scheduler = scheduler_class(**scheduler_config)
timesteps = [scheduler.config.num_train_timesteps]
with self.assertRaises(
ValueError,
msg="`timesteps` must start before `self.config.train_timesteps`: {scheduler.config.num_train_timesteps}}",
):
scheduler.set_timesteps(timesteps=timesteps)
|