|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import gc |
|
import inspect |
|
import unittest |
|
|
|
import torch |
|
from parameterized import parameterized |
|
|
|
from diffusers import PriorTransformer |
|
from diffusers.utils.testing_utils import ( |
|
backend_empty_cache, |
|
enable_full_determinism, |
|
floats_tensor, |
|
slow, |
|
torch_all_close, |
|
torch_device, |
|
) |
|
|
|
from ..test_modeling_common import ModelTesterMixin |
|
|
|
|
|
# Executed once at import time, before any test in this module runs.
# NOTE(review): presumably configures torch for deterministic execution so the
# seeded reference slices in the tests are reproducible — confirm against
# diffusers.utils.testing_utils.enable_full_determinism.
enable_full_determinism()
|
|
|
|
|
class PriorTransformerTests(ModelTesterMixin, unittest.TestCase):
    """Fast unit tests for ``PriorTransformer`` on top of the shared ``ModelTesterMixin`` suite."""

    model_class = PriorTransformer
    main_input_name = "hidden_states"

    @property
    def dummy_input(self):
        """Return a small random input dict built with ``floats_tensor`` and moved to ``torch_device``.

        ``hidden_states`` and ``proj_embedding`` have shape (4, 8),
        ``encoder_hidden_states`` has shape (4, 7, 8) and ``timestep`` is the int 2.
        """
        batch_size = 4
        embedding_dim = 8
        num_embeddings = 7

        hidden_states = floats_tensor((batch_size, embedding_dim)).to(torch_device)
        proj_embedding = floats_tensor((batch_size, embedding_dim)).to(torch_device)
        encoder_hidden_states = floats_tensor((batch_size, num_embeddings, embedding_dim)).to(torch_device)

        return {
            "hidden_states": hidden_states,
            "timestep": 2,
            "proj_embedding": proj_embedding,
            "encoder_hidden_states": encoder_hidden_states,
        }

    def get_dummy_seed_input(self, seed=0):
        """Return an input dict with the same shapes as ``dummy_input``, drawn with ``torch.randn``.

        ``torch.manual_seed(seed)`` is called first, so the values are reproducible
        for a given ``seed``. The tensors are created and then moved to ``torch_device``.
        The draw order (hidden_states, proj_embedding, encoder_hidden_states) matters
        for the recorded reference slice in ``test_output_pretrained``.
        """
        torch.manual_seed(seed)
        batch_size = 4
        embedding_dim = 8
        num_embeddings = 7

        hidden_states = torch.randn((batch_size, embedding_dim)).to(torch_device)
        proj_embedding = torch.randn((batch_size, embedding_dim)).to(torch_device)
        encoder_hidden_states = torch.randn((batch_size, num_embeddings, embedding_dim)).to(torch_device)

        return {
            "hidden_states": hidden_states,
            "timestep": 2,
            "proj_embedding": proj_embedding,
            "encoder_hidden_states": encoder_hidden_states,
        }

    @property
    def input_shape(self):
        """Shape of ``hidden_states`` in ``dummy_input``."""
        return (4, 8)

    @property
    def output_shape(self):
        """Expected shape of the model output for ``dummy_input``."""
        return (4, 8)

    def prepare_init_args_and_inputs_for_common(self):
        """Return ``(init_dict, inputs_dict)`` for the shared ``ModelTesterMixin`` tests.

        ``embedding_dim`` and ``num_embeddings`` match the sizes used by ``dummy_input``.
        """
        init_dict = {
            "num_attention_heads": 2,
            "attention_head_dim": 4,
            "num_layers": 2,
            "embedding_dim": 8,
            "num_embeddings": 7,
            "additional_embeddings": 4,
        }
        inputs_dict = self.dummy_input
        return init_dict, inputs_dict

    def test_from_pretrained_hub(self):
        """Load the tiny hub checkpoint, check no keys are missing, and run one forward pass."""
        model, loading_info = PriorTransformer.from_pretrained(
            "hf-internal-testing/prior-dummy", output_loading_info=True
        )
        self.assertIsNotNone(model)
        self.assertEqual(len(loading_info["missing_keys"]), 0)

        model.to(torch_device)
        hidden_states = model(**self.dummy_input)[0]

        # Use a unittest assertion rather than a bare ``assert``, which is stripped under ``python -O``.
        self.assertIsNotNone(hidden_states, "Make sure output is not None")

    def test_forward_signature(self):
        """The first two positional parameters of ``forward`` must be ``hidden_states`` and ``timestep``."""
        init_dict, _ = self.prepare_init_args_and_inputs_for_common()

        model = self.model_class(**init_dict)
        signature = inspect.signature(model.forward)

        # ``signature.parameters`` is an ordered mapping, so iterating it yields names in declaration order.
        arg_names = list(signature.parameters)

        expected_arg_names = ["hidden_states", "timestep"]
        self.assertListEqual(arg_names[:2], expected_arg_names)

    def test_output_pretrained(self):
        """Compare the first five output values of the tiny checkpoint against recorded references."""
        model = PriorTransformer.from_pretrained("hf-internal-testing/prior-dummy")
        model = model.to(torch_device)

        # NOTE(review): presumably selects a fixed attention implementation so the
        # reference values do not depend on which attention backend is active — confirm.
        if hasattr(model, "set_default_attn_processor"):
            model.set_default_attn_processor()

        inputs = self.get_dummy_seed_input()

        with torch.no_grad():
            output = model(**inputs)[0]

        output_slice = output[0, :5].flatten().cpu()

        expected_output_slice = torch.tensor([-1.3436, -0.2870, 0.7538, 0.4368, -0.0239])
        # Report the actual slice only on failure instead of printing it unconditionally.
        self.assertTrue(
            torch_all_close(output_slice, expected_output_slice, rtol=1e-2),
            msg=f"output slice {output_slice} does not match expected {expected_output_slice}",
        )
|
|
|
|
|
@slow
class PriorTransformerIntegrationTests(unittest.TestCase):
    """Slow tests that run the Kandinsky 2.1 prior checkpoint and check output slices."""

    def get_dummy_seed_input(self, batch_size=1, embedding_dim=768, num_embeddings=77, seed=0):
        """Return a seeded random input dict for the full-size prior.

        Args:
            batch_size: leading dimension of every tensor.
            embedding_dim: trailing dimension of every tensor.
            num_embeddings: middle dimension of ``encoder_hidden_states``.
            seed: value passed to ``torch.manual_seed`` before drawing.

        Returns:
            A dict with ``hidden_states`` and ``proj_embedding`` of shape
            ``(batch_size, embedding_dim)``, ``encoder_hidden_states`` of shape
            ``(batch_size, num_embeddings, embedding_dim)`` (all moved to
            ``torch_device``) and ``timestep`` set to the int 2.
        """
        torch.manual_seed(seed)

        # The draw order matters: the recorded expected slices depend on it.
        hidden_states = torch.randn((batch_size, embedding_dim)).to(torch_device)
        proj_embedding = torch.randn((batch_size, embedding_dim)).to(torch_device)
        encoder_hidden_states = torch.randn((batch_size, num_embeddings, embedding_dim)).to(torch_device)

        return {
            "hidden_states": hidden_states,
            "timestep": 2,
            "proj_embedding": proj_embedding,
            "encoder_hidden_states": encoder_hidden_states,
        }

    def tearDown(self):
        """Collect garbage and empty the backend cache for ``torch_device`` after each test."""
        super().tearDown()
        gc.collect()
        backend_empty_cache(torch_device)

    @parameterized.expand(
        [
            # [seed, expected values of sample[0, :8]]
            [13, [-0.5861, 0.1283, -0.0931, 0.0882, 0.4476, 0.1329, -0.0498, 0.0640]],
            [37, [-0.4913, 0.0110, -0.0483, 0.0541, 0.4954, -0.0170, 0.0354, 0.1651]],
        ]
    )
    def test_kandinsky_prior(self, seed, expected_slice):
        """Run the prior on seeded inputs and compare the first 8 output values with references."""
        model = PriorTransformer.from_pretrained("kandinsky-community/kandinsky-2-1-prior", subfolder="prior")
        model.to(torch_device)
        inputs = self.get_dummy_seed_input(seed=seed)

        with torch.no_grad():
            sample = model(**inputs)[0]

        # unittest assertions survive ``python -O``, unlike bare ``assert`` statements.
        self.assertEqual(list(sample.shape), [1, 768])

        output_slice = sample[0, :8].flatten().cpu()
        expected_output_slice = torch.tensor(expected_slice)

        # Report the actual slice only on failure instead of printing it unconditionally.
        self.assertTrue(
            torch_all_close(output_slice, expected_output_slice, atol=1e-3),
            msg=f"output slice {output_slice} does not match expected {expected_output_slice}",
        )
|
|