|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import tempfile |
|
import unittest |
|
|
|
from diffusers import ( |
|
DDIMScheduler, |
|
DDPMScheduler, |
|
DPMSolverMultistepScheduler, |
|
EulerAncestralDiscreteScheduler, |
|
EulerDiscreteScheduler, |
|
PNDMScheduler, |
|
logging, |
|
) |
|
from diffusers.configuration_utils import ConfigMixin, register_to_config |
|
from diffusers.utils.testing_utils import CaptureLogger |
|
|
|
|
|
class SampleObject(ConfigMixin):
    """Config fixture whose constructor accepts the arguments ``a`` through ``e``.

    The constructor has no logic of its own; it is wrapped by
    ``register_to_config``.
    """

    config_name = "config.json"

    @register_to_config
    def __init__(self, a=2, b=5, c=(2, 5), d="for diffusion", e=[1, 3]):
        """Accept the fixture arguments without doing any work.

        NOTE(review): ``e`` keeps its list default on purpose here; replacing it
        with ``None`` or a tuple would change the default value seen through
        the decorated constructor.
        """
|
|
|
|
|
class SampleObject2(ConfigMixin):
    """Config fixture whose constructor accepts ``a``, ``b``, ``c``, ``d`` and ``f``.

    The constructor has no logic of its own; it is wrapped by
    ``register_to_config``.
    """

    config_name = "config.json"

    @register_to_config
    def __init__(self, a=2, b=5, c=(2, 5), d="for diffusion", f=[1, 3]):
        """Accept the fixture arguments without doing any work.

        NOTE(review): ``f`` keeps its list default on purpose here; replacing it
        with ``None`` or a tuple would change the default value seen through
        the decorated constructor.
        """
|
|
|
|
|
class SampleObject3(ConfigMixin):
    """Config fixture whose constructor accepts the arguments ``a`` through ``f``.

    The constructor has no logic of its own; it is wrapped by
    ``register_to_config``.
    """

    config_name = "config.json"

    @register_to_config
    def __init__(self, a=2, b=5, c=(2, 5), d="for diffusion", e=[1, 3], f=[1, 3]):
        """Accept the fixture arguments without doing any work.

        NOTE(review): ``e`` and ``f`` keep their list defaults on purpose here;
        replacing them with ``None`` or tuples would change the default values
        seen through the decorated constructor.
        """
|
|
|
|
|
class SampleObject4(ConfigMixin):
    """Config fixture taking ``a`` through ``f``, with defaults ``e=[1, 5]`` and ``f=[5, 4]``.

    The constructor has no logic of its own; it is wrapped by
    ``register_to_config``.
    """

    config_name = "config.json"

    @register_to_config
    def __init__(self, a=2, b=5, c=(2, 5), d="for diffusion", e=[1, 5], f=[5, 4]):
        """Accept the fixture arguments without doing any work.

        NOTE(review): ``e`` and ``f`` keep their list defaults on purpose here;
        replacing them with ``None`` or tuples would change the default values
        seen through the decorated constructor.
        """
|
|
|
|
|
class ConfigTester(unittest.TestCase):
    """Checks config registration, save/load round trips and scheduler config loading."""

    def _assert_silent_scheduler_load(self, scheduler_cls):
        """Load ``scheduler_cls`` from the tiny test repo and require that nothing was logged.

        Args:
            scheduler_cls: Scheduler class whose ``from_pretrained`` is called.
        """
        config_logger = logging.get_logger("diffusers.configuration_utils")
        # 30 is the numeric value of the standard WARNING level.
        config_logger.setLevel(30)

        with CaptureLogger(config_logger) as captured:
            scheduler = scheduler_cls.from_pretrained(
                "hf-internal-testing/tiny-stable-diffusion-torch", subfolder="scheduler"
            )

        assert scheduler.__class__ == scheduler_cls
        # Whatever the logger emitted during loading must be empty.
        assert captured.out == ""

    def test_load_not_from_mixin(self):
        """Calling ``load_config`` directly on ``ConfigMixin`` must raise ``ValueError``."""
        self.assertRaises(ValueError, ConfigMixin.load_config, "dummy_path")

    def test_register_to_config(self):
        """Constructor defaults, keyword overrides and positional overrides all land in ``config``."""
        defaults = {"a": 2, "b": 5, "c": (2, 5), "d": "for diffusion", "e": [1, 3]}
        # (positional args, keyword args, expected registered values)
        scenarios = (
            ((), {}, defaults),
            # An extra private keyword must leave the registered values untouched.
            ((), {"_name_or_path": "lalala"}, defaults),
            ((), {"c": 6}, {**defaults, "c": 6}),
            ((1,), {"c": 6}, {**defaults, "a": 1, "c": 6}),
        )

        for args, kwargs, expected in scenarios:
            registered = SampleObject(*args, **kwargs).config
            for key, value in expected.items():
                assert registered[key] == value

    def test_save_load(self):
        """A config saved to disk and loaded back matches the original, apart from ``c``'s type."""
        original = SampleObject()
        config = original.config

        expected = {"a": 2, "b": 5, "c": (2, 5), "d": "for diffusion", "e": [1, 3]}
        for key, value in expected.items():
            assert config[key] == value

        with tempfile.TemporaryDirectory() as tmpdirname:
            original.save_config(tmpdirname)
            reloaded = SampleObject.from_config(SampleObject.load_config(tmpdirname))
            new_config = reloaded.config

        # Work on plain dict copies so entries can be popped.
        config, new_config = dict(config), dict(new_config)

        # The tuple comes back as a list after the round trip
        # (presumably because the file format has no tuple type).
        assert config.pop("c") == (2, 5)
        assert new_config.pop("c") == [2, 5]
        # Dropped from the original side before the two dicts are compared.
        config.pop("_use_default_values")
        assert config == new_config

    def test_load_ddim_from_pndm(self):
        """``DDIMScheduler`` loads from the shared scheduler config without log output."""
        self._assert_silent_scheduler_load(DDIMScheduler)

    def test_load_euler_from_pndm(self):
        """``EulerDiscreteScheduler`` loads from the shared scheduler config without log output."""
        self._assert_silent_scheduler_load(EulerDiscreteScheduler)

    def test_load_euler_ancestral_from_pndm(self):
        """``EulerAncestralDiscreteScheduler`` loads from the shared scheduler config without log output."""
        self._assert_silent_scheduler_load(EulerAncestralDiscreteScheduler)

    def test_load_pndm(self):
        """``PNDMScheduler`` loads from the shared scheduler config without log output."""
        self._assert_silent_scheduler_load(PNDMScheduler)

    def test_overwrite_config_on_load(self):
        """Keyword arguments given to ``from_pretrained`` override config values silently."""
        config_logger = logging.get_logger("diffusers.configuration_utils")
        # 30 is the numeric value of the standard WARNING level.
        config_logger.setLevel(30)

        with CaptureLogger(config_logger) as cap_logger:
            ddpm = DDPMScheduler.from_pretrained(
                "hf-internal-testing/tiny-stable-diffusion-torch",
                subfolder="scheduler",
                prediction_type="sample",
                beta_end=8,
            )

        with CaptureLogger(config_logger) as cap_logger_2:
            ddpm_2 = DDPMScheduler.from_pretrained("google/ddpm-celebahq-256", beta_start=88)

        assert ddpm.__class__ == DDPMScheduler
        assert ddpm.config.prediction_type == "sample"
        assert ddpm.config.beta_end == 8
        assert ddpm_2.config.beta_start == 88

        # Neither load may have produced log output.
        assert cap_logger.out == ""
        assert cap_logger_2.out == ""

    def test_load_dpmsolver(self):
        """``DPMSolverMultistepScheduler`` loads from the shared scheduler config without log output."""
        self._assert_silent_scheduler_load(DPMSolverMultistepScheduler)

    def test_use_default_values(self):
        """Entries listed in ``_use_default_values`` fall back to the loading class's own defaults."""
        source = SampleObject()

        # Every public key of a default-constructed object is listed as defaulted.
        public_keys = {name for name in source.config if not name.startswith("_")}
        assert public_keys == set(source.config._use_default_values)

        with tempfile.TemporaryDirectory() as tmpdirname:
            source.save_config(tmpdirname)
            # SampleObject2 has an ``f`` argument that SampleObject does not have.
            reloaded = SampleObject2.from_config(SampleObject2.load_config(tmpdirname))

        assert "f" in reloaded.config._use_default_values
        assert reloaded.config.f == [1, 3]

        # While ``f`` is listed as defaulted, SampleObject4 uses its own default [5, 4].
        from_flagged = SampleObject4.from_config(reloaded.config)
        assert from_flagged.config.f == [5, 4]

        # Drops the last listed entry; the next assertion implies that entry is ``f``.
        reloaded.config._use_default_values.pop()
        from_unflagged = SampleObject4.from_config(reloaded.config)
        assert from_unflagged.config.f == [1, 3]

        # ``e`` is taken from the passed config rather than SampleObject4's default [1, 5].
        assert from_unflagged.config.e == [1, 3]
|
|