svjack's picture
Upload 1392 files
43b7e92 verified
raw
history blame
3.79 kB
import pickle as pkl
import unittest
from dataclasses import dataclass
from typing import List, Union
import numpy as np
import PIL.Image
from diffusers.utils.outputs import BaseOutput
from diffusers.utils.testing_utils import require_torch
@dataclass
class CustomOutput(BaseOutput):
    """Minimal ``BaseOutput`` subclass with a single ``images`` field, used as the test fixture."""

    # Payload under test: either a list of PIL images or a NumPy array.
    images: Union[List[PIL.Image.Image], np.ndarray]
class ConfigTester(unittest.TestCase):
    """Tests for ``CustomOutput``: field access paths, dict init, pickling and torch pytree support.

    All checks use ``unittest`` assertion methods rather than bare ``assert`` so
    they are not stripped when Python runs with ``-O`` and so every test in the
    class reports failures the same way.
    """

    @staticmethod
    def _access_paths(outputs):
        """Return the ``images`` field fetched by attribute, by string key and by position."""
        return {
            "attribute": outputs.images,
            "key": outputs["images"],
            "index": outputs[0],
        }

    def _check_array_output(self, outputs, expected_shape=(1, 3, 4, 4)):
        """Assert every access path yields an ``np.ndarray`` of ``expected_shape``."""
        for how, value in self._access_paths(outputs).items():
            with self.subTest(access=how):
                self.assertIsInstance(value, np.ndarray)
                self.assertEqual(value.shape, expected_shape)

    def _check_image_list_output(self, outputs):
        """Assert every access path yields a list whose first item is a PIL image."""
        for how, value in self._access_paths(outputs).items():
            with self.subTest(access=how):
                self.assertIsInstance(value, list)
                self.assertIsInstance(value[0], PIL.Image.Image)

    def test_outputs_single_attribute(self):
        """Keyword construction exposes the field by attribute, key and index."""
        outputs = CustomOutput(images=np.random.rand(1, 3, 4, 4))
        self._check_array_output(outputs)

        # test with a non-tensor attribute
        outputs = CustomOutput(images=[PIL.Image.new("RGB", (4, 4))])
        self._check_image_list_output(outputs)

    def test_outputs_dict_init(self):
        """Construction from a single positional ``dict`` behaves like keyword construction."""
        # test output reinitialization with a `dict` for compatibility with `accelerate`
        outputs = CustomOutput({"images": np.random.rand(1, 3, 4, 4)})
        self._check_array_output(outputs)

        # test with a non-tensor attribute
        outputs = CustomOutput({"images": [PIL.Image.new("RGB", (4, 4))]})
        self._check_image_list_output(outputs)

    def test_outputs_serialization(self):
        """A pickle round trip preserves attributes, mapping contents and instance ``__dict__``."""
        outputs_orig = CustomOutput(images=[PIL.Image.new("RGB", (4, 4))])
        # The pickled bytes are produced right here, so loading them is safe.
        serialized = pkl.dumps(outputs_orig)
        outputs_copy = pkl.loads(serialized)

        # Check original and copy are equal
        self.assertEqual(dir(outputs_orig), dir(outputs_copy))
        self.assertEqual(dict(outputs_orig), dict(outputs_copy))
        self.assertEqual(vars(outputs_orig), vars(outputs_copy))

    @require_torch
    def test_torch_pytree(self):
        """``torch.utils._pytree`` flattens and unflattens ``CustomOutput`` as a node."""
        # ensure torch.utils._pytree treats ModelOutput subclasses as nodes (and not leaves)
        # this is important for DistributedDataParallel gradient synchronization with static_graph=True
        import torch
        import torch.utils._pytree

        data = np.random.rand(1, 3, 4, 4)
        x = CustomOutput(images=data)
        self.assertFalse(torch.utils._pytree._is_leaf(x))

        expected_flat_outs = [data]
        expected_tree_spec = torch.utils._pytree.TreeSpec(
            CustomOutput, ["images"], [torch.utils._pytree.LeafSpec()]
        )

        actual_flat_outs, actual_tree_spec = torch.utils._pytree.tree_flatten(x)
        self.assertEqual(expected_flat_outs, actual_flat_outs)
        self.assertEqual(expected_tree_spec, actual_tree_spec)

        unflattened_x = torch.utils._pytree.tree_unflatten(actual_flat_outs, actual_tree_spec)
        # Compare the class and the payload explicitly instead of `x == unflattened_x`.
        # NOTE(review): a dataclass-generated `__eq__` that compares the ndarray field with
        # `==` returns an elementwise array whose truth value is ambiguous; whether that
        # happens depends on the Python version — confirm on the interpreters in CI.
        self.assertIs(type(unflattened_x), CustomOutput)
        np.testing.assert_array_equal(unflattened_x.images, x.images)