# coding=utf-8
# Copyright 2024 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import random
import unittest

import numpy as np
import torch
from transformers import CLIPTextConfig, CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer

from diffusers import (
    AutoencoderKL,
    ControlNetModel,
    EulerDiscreteScheduler,
    StableDiffusionXLControlNetImg2ImgPipeline,
    UNet2DConditionModel,
)
from diffusers.utils.import_utils import is_xformers_available
from diffusers.utils.testing_utils import enable_full_determinism, floats_tensor, require_torch_gpu, torch_device

from ..pipeline_params import (
    IMAGE_TO_IMAGE_IMAGE_PARAMS,
    TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS,
    TEXT_GUIDED_IMAGE_VARIATION_PARAMS,
    TEXT_TO_IMAGE_CALLBACK_CFG_PARAMS,
)
from ..test_pipelines_common import (
    IPAdapterTesterMixin,
    PipelineKarrasSchedulerTesterMixin,
    PipelineLatentTesterMixin,
    PipelineTesterMixin,
)


enable_full_determinism()


class ControlNetPipelineSDXLImg2ImgFastTests(
    IPAdapterTesterMixin,
    PipelineLatentTesterMixin,
    PipelineKarrasSchedulerTesterMixin,
    PipelineTesterMixin,
    unittest.TestCase,
):
    """Fast tests for `StableDiffusionXLControlNetImg2ImgPipeline` built from tiny dummy components."""

    pipeline_class = StableDiffusionXLControlNetImg2ImgPipeline
    params = TEXT_GUIDED_IMAGE_VARIATION_PARAMS
    # `latents` is removed from the optional parameters required by the common tester mixin.
    required_optional_params = PipelineTesterMixin.required_optional_params - {"latents"}
    batch_params = TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS
    image_params = IMAGE_TO_IMAGE_IMAGE_PARAMS
    image_latents_params = IMAGE_TO_IMAGE_IMAGE_PARAMS
    callback_cfg_params = TEXT_TO_IMAGE_CALLBACK_CFG_PARAMS.union(
        {"add_text_embeds", "add_time_ids", "add_neg_time_ids"}
    )

    def get_dummy_components(self, skip_first_text_encoder=False):
        """Build the keyword arguments used to instantiate `pipeline_class` with tiny models.

        The global torch seed is reset to 0 before each model is created so that the
        randomly initialised weights are the same on every call.

        Args:
            skip_first_text_encoder: When true, `text_encoder` and `tokenizer` are set to
                `None` and the UNet is built with `cross_attention_dim=32` instead of 64.

        Returns:
            A dict mapping component names to the dummy components.
        """
        torch.manual_seed(0)
        unet = UNet2DConditionModel(
            block_out_channels=(32, 64),
            layers_per_block=2,
            sample_size=32,
            in_channels=4,
            out_channels=4,
            down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
            up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
            # SD2-specific config below
            attention_head_dim=(2, 4),
            use_linear_projection=True,
            addition_embed_type="text_time",
            addition_time_embed_dim=8,
            transformer_layers_per_block=(1, 2),
            projection_class_embeddings_input_dim=80,  # 6 * 8 + 32
            cross_attention_dim=64 if not skip_first_text_encoder else 32,
        )
        torch.manual_seed(0)
        controlnet = ControlNetModel(
            block_out_channels=(32, 64),
            layers_per_block=2,
            in_channels=4,
            down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
            conditioning_embedding_out_channels=(16, 32),
            # SD2-specific config below
            attention_head_dim=(2, 4),
            use_linear_projection=True,
            addition_embed_type="text_time",
            addition_time_embed_dim=8,
            transformer_layers_per_block=(1, 2),
            projection_class_embeddings_input_dim=80,  # 6 * 8 + 32
            cross_attention_dim=64,
        )
        torch.manual_seed(0)
        scheduler = EulerDiscreteScheduler(
            beta_start=0.00085,
            beta_end=0.012,
            steps_offset=1,
            beta_schedule="scaled_linear",
            timestep_spacing="leading",
        )
        torch.manual_seed(0)
        vae = AutoencoderKL(
            block_out_channels=[32, 64],
            in_channels=3,
            out_channels=3,
            down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"],
            up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"],
            latent_channels=4,
        )
        torch.manual_seed(0)
        text_encoder_config = CLIPTextConfig(
            bos_token_id=0,
            eos_token_id=2,
            hidden_size=32,
            intermediate_size=37,
            layer_norm_eps=1e-05,
            num_attention_heads=4,
            num_hidden_layers=5,
            pad_token_id=1,
            vocab_size=1000,
            # SD2-specific config below
            hidden_act="gelu",
            projection_dim=32,
        )
        text_encoder = CLIPTextModel(text_encoder_config)
        tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")

        # The second text encoder shares the same config but adds a projection head.
        text_encoder_2 = CLIPTextModelWithProjection(text_encoder_config)
        tokenizer_2 = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")

        components = {
            "unet": unet,
            "controlnet": controlnet,
            "scheduler": scheduler,
            "vae": vae,
            "text_encoder": text_encoder if not skip_first_text_encoder else None,
            "tokenizer": tokenizer if not skip_first_text_encoder else None,
            "text_encoder_2": text_encoder_2,
            "tokenizer_2": tokenizer_2,
            "image_encoder": None,
            "feature_extractor": None,
        }
        return components

    def get_dummy_inputs(self, device, seed=0):
        """Build seeded call arguments for the pipeline.

        Args:
            device: Device on which the input image and the generator are placed.
            seed: Seed for both the random image and the torch generator.

        Returns:
            A dict of keyword arguments; the same 64x64 image tensor is passed as both
            `image` and `control_image`.
        """
        controlnet_embedder_scale_factor = 2
        image = floats_tensor(
            (1, 3, 32 * controlnet_embedder_scale_factor, 32 * controlnet_embedder_scale_factor),
            rng=random.Random(seed),
        ).to(device)

        # On "mps" the global generator is seeded instead of creating a device-bound one.
        if str(device).startswith("mps"):
            generator = torch.manual_seed(seed)
        else:
            generator = torch.Generator(device=device).manual_seed(seed)

        inputs = {
            "prompt": "A painting of a squirrel eating a burger",
            "generator": generator,
            "num_inference_steps": 2,
            "guidance_scale": 6.0,
            "output_type": "np",
            "image": image,
            "control_image": image,
        }
        return inputs

    def test_ip_adapter_single(self):
        """Run the mixin's single IP-adapter test, pinning a reference slice only on CPU."""
        expected_pipe_slice = None
        if torch_device == "cpu":
            expected_pipe_slice = np.array([0.6265, 0.5441, 0.5384, 0.5446, 0.5810, 0.5908, 0.5414, 0.5428, 0.5353])
        return super().test_ip_adapter_single(expected_pipe_slice=expected_pipe_slice)

    def test_stable_diffusion_xl_controlnet_img2img(self):
        """Check output shape and a reference corner slice for a default CPU run."""
        device = "cpu"  # ensure determinism for the device-dependent torch.Generator
        components = self.get_dummy_components()
        sd_pipe = self.pipeline_class(**components)
        sd_pipe = sd_pipe.to(device)
        sd_pipe.set_progress_bar_config(disable=None)

        inputs = self.get_dummy_inputs(device)
        image = sd_pipe(**inputs).images
        image_slice = image[0, -3:, -3:, -1]

        assert image.shape == (1, 64, 64, 3)

        expected_slice = np.array(
            [0.5557202, 0.46418434, 0.46983826, 0.623529, 0.5557242, 0.49262643, 0.6070508, 0.5702978, 0.43777135]
        )

        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2

    def test_stable_diffusion_xl_controlnet_img2img_guess(self):
        """Check output shape and a reference corner slice for a CPU run with `guess_mode=True`."""
        device = "cpu"

        components = self.get_dummy_components()

        sd_pipe = self.pipeline_class(**components)
        sd_pipe = sd_pipe.to(device)

        sd_pipe.set_progress_bar_config(disable=None)

        inputs = self.get_dummy_inputs(device)
        inputs["guess_mode"] = True

        output = sd_pipe(**inputs)
        image_slice = output.images[0, -3:, -3:, -1]
        assert output.images.shape == (1, 64, 64, 3)

        expected_slice = np.array(
            [0.5557202, 0.46418434, 0.46983826, 0.623529, 0.5557242, 0.49262643, 0.6070508, 0.5702978, 0.43777135]
        )

        # make sure that it's equal
        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2

    def test_attention_slicing_forward_pass(self):
        """Run the shared attention-slicing check with a 2e-3 tolerance."""
        return self._test_attention_slicing_forward_pass(expected_max_diff=2e-3)

    @unittest.skipIf(
        torch_device != "cuda" or not is_xformers_available(),
        reason="XFormers attention is only available with CUDA and `xformers` installed",
    )
    def test_xformers_attention_forwardGenerator_pass(self):
        """Run the shared xformers attention check with a 2e-3 tolerance."""
        self._test_xformers_attention_forwardGenerator_pass(expected_max_diff=2e-3)

    def test_inference_batch_single_identical(self):
        """Run the shared batch-vs-single check with a 2e-3 tolerance."""
        self._test_inference_batch_single_identical(expected_max_diff=2e-3)

    # TODO(Patrick, Sayak) - skip for now as this requires more refiner tests
    @unittest.skip("Skipped for now as this requires more refiner tests")
    def test_save_load_optional_components(self):
        """Override the inherited test so it is reported as skipped rather than silently passing."""
        pass

    @require_torch_gpu
    def test_stable_diffusion_xl_offloads(self):
        """Compare outputs of a plain pipeline against model- and sequential-CPU-offload variants."""
        pipes = []
        components = self.get_dummy_components()
        sd_pipe = self.pipeline_class(**components).to(torch_device)
        pipes.append(sd_pipe)

        components = self.get_dummy_components()
        sd_pipe = self.pipeline_class(**components)
        sd_pipe.enable_model_cpu_offload()
        pipes.append(sd_pipe)

        components = self.get_dummy_components()
        sd_pipe = self.pipeline_class(**components)
        sd_pipe.enable_sequential_cpu_offload()
        pipes.append(sd_pipe)

        image_slices = []
        for pipe in pipes:
            pipe.unet.set_default_attn_processor()

            inputs = self.get_dummy_inputs(torch_device)
            image = pipe(**inputs).images

            image_slices.append(image[0, -3:, -3:, -1].flatten())

        assert np.abs(image_slices[0] - image_slices[1]).max() < 1e-3
        assert np.abs(image_slices[0] - image_slices[2]).max() < 1e-3

    def test_stable_diffusion_xl_multi_prompts(self):
        """Check that `prompt_2` / `negative_prompt_2` change the output only when they differ."""
        components = self.get_dummy_components()
        sd_pipe = self.pipeline_class(**components).to(torch_device)

        # forward with single prompt
        inputs = self.get_dummy_inputs(torch_device)
        output = sd_pipe(**inputs)
        image_slice_1 = output.images[0, -3:, -3:, -1]

        # forward with same prompt duplicated
        inputs = self.get_dummy_inputs(torch_device)
        inputs["prompt_2"] = inputs["prompt"]
        output = sd_pipe(**inputs)
        image_slice_2 = output.images[0, -3:, -3:, -1]

        # ensure the results are equal
        assert np.abs(image_slice_1.flatten() - image_slice_2.flatten()).max() < 1e-4

        # forward with different prompt
        inputs = self.get_dummy_inputs(torch_device)
        inputs["prompt_2"] = "different prompt"
        output = sd_pipe(**inputs)
        image_slice_3 = output.images[0, -3:, -3:, -1]

        # ensure the results are not equal
        assert np.abs(image_slice_1.flatten() - image_slice_3.flatten()).max() > 1e-4

        # manually set a negative_prompt
        inputs = self.get_dummy_inputs(torch_device)
        inputs["negative_prompt"] = "negative prompt"
        output = sd_pipe(**inputs)
        image_slice_1 = output.images[0, -3:, -3:, -1]

        # forward with same negative_prompt duplicated
        inputs = self.get_dummy_inputs(torch_device)
        inputs["negative_prompt"] = "negative prompt"
        inputs["negative_prompt_2"] = inputs["negative_prompt"]
        output = sd_pipe(**inputs)
        image_slice_2 = output.images[0, -3:, -3:, -1]

        # ensure the results are equal
        assert np.abs(image_slice_1.flatten() - image_slice_2.flatten()).max() < 1e-4

        # forward with different negative_prompt
        inputs = self.get_dummy_inputs(torch_device)
        inputs["negative_prompt"] = "negative prompt"
        inputs["negative_prompt_2"] = "different negative prompt"
        output = sd_pipe(**inputs)
        image_slice_3 = output.images[0, -3:, -3:, -1]

        # ensure the results are not equal
        assert np.abs(image_slice_1.flatten() - image_slice_3.flatten()).max() > 1e-4

    # Copied from test_stable_diffusion_xl.py
    def test_stable_diffusion_xl_prompt_embeds(self):
        """Check that passing embeddings from `encode_prompt` matches passing the raw prompt."""
        components = self.get_dummy_components()
        sd_pipe = self.pipeline_class(**components)
        sd_pipe = sd_pipe.to(torch_device)
        sd_pipe.set_progress_bar_config(disable=None)

        # forward without prompt embeds
        inputs = self.get_dummy_inputs(torch_device)
        inputs["prompt"] = 2 * [inputs["prompt"]]
        inputs["num_images_per_prompt"] = 2

        output = sd_pipe(**inputs)
        image_slice_1 = output.images[0, -3:, -3:, -1]

        # forward with prompt embeds
        inputs = self.get_dummy_inputs(torch_device)
        prompt = 2 * [inputs.pop("prompt")]

        (
            prompt_embeds,
            negative_prompt_embeds,
            pooled_prompt_embeds,
            negative_pooled_prompt_embeds,
        ) = sd_pipe.encode_prompt(prompt)

        output = sd_pipe(
            **inputs,
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
            pooled_prompt_embeds=pooled_prompt_embeds,
            negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
        )
        image_slice_2 = output.images[0, -3:, -3:, -1]

        # make sure that it's equal
        assert np.abs(image_slice_1.flatten() - image_slice_2.flatten()).max() < 1e-4