# coding=utf-8
# Copyright 2024 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import sys
import tempfile
import unittest

import numpy as np
import torch
from transformers import AutoTokenizer, CLIPTextConfig, CLIPTextModelWithProjection, CLIPTokenizer, T5EncoderModel

from diffusers import (
    AutoencoderKL,
    FlowMatchEulerDiscreteScheduler,
    SD3Transformer2DModel,
    StableDiffusion3Pipeline,
)
from diffusers.utils.testing_utils import is_peft_available, require_peft_backend, require_torch_gpu, torch_device


if is_peft_available():
    from peft import LoraConfig
    from peft.utils import get_peft_model_state_dict

sys.path.append(".")

from utils import check_if_lora_correctly_set  # noqa: E402


@require_peft_backend
class SD3LoRATests(unittest.TestCase):
    pipeline_class = StableDiffusion3Pipeline

    def get_dummy_components(self):
        # Tiny model configurations keep these tests fast and CPU-friendly.
        torch.manual_seed(0)
        transformer = SD3Transformer2DModel(
            sample_size=32,
            patch_size=1,
            in_channels=4,
            num_layers=1,
            attention_head_dim=8,
            num_attention_heads=4,
            caption_projection_dim=32,
            joint_attention_dim=32,
            pooled_projection_dim=64,
            out_channels=4,
        )
        clip_text_encoder_config = CLIPTextConfig(
            bos_token_id=0,
            eos_token_id=2,
            hidden_size=32,
            intermediate_size=37,
            layer_norm_eps=1e-05,
            num_attention_heads=4,
            num_hidden_layers=5,
            pad_token_id=1,
            vocab_size=1000,
            hidden_act="gelu",
            projection_dim=32,
        )

        torch.manual_seed(0)
        text_encoder = CLIPTextModelWithProjection(clip_text_encoder_config)

        torch.manual_seed(0)
        text_encoder_2 = CLIPTextModelWithProjection(clip_text_encoder_config)

        text_encoder_3 = T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5")

        tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
        tokenizer_2 = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
        tokenizer_3 = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5")

        torch.manual_seed(0)
        vae = AutoencoderKL(
            sample_size=32,
            in_channels=3,
            out_channels=3,
            block_out_channels=(4,),
            layers_per_block=1,
            latent_channels=4,
            norm_num_groups=1,
            use_quant_conv=False,
            use_post_quant_conv=False,
            shift_factor=0.0609,
            scaling_factor=1.5035,
        )

        scheduler = FlowMatchEulerDiscreteScheduler()

        return {
            "scheduler": scheduler,
            "text_encoder": text_encoder,
            "text_encoder_2": text_encoder_2,
            "text_encoder_3": text_encoder_3,
            "tokenizer": tokenizer,
            "tokenizer_2": tokenizer_2,
            "tokenizer_3": tokenizer_3,
            "transformer": transformer,
            "vae": vae,
        }

    def get_dummy_inputs(self, device, seed=0):
        if str(device).startswith("mps"):
            generator = torch.manual_seed(seed)
        else:
            generator = torch.Generator(device="cpu").manual_seed(seed)

        inputs = {
            "prompt": "A painting of a squirrel eating a burger",
            "generator": generator,
            "num_inference_steps": 2,
            "guidance_scale": 5.0,
            "output_type": "np",
        }
        return inputs

    def get_lora_config_for_transformer(self):
        # Target the attention projections of the SD3 transformer.
        lora_config = LoraConfig(
            r=4,
            lora_alpha=4,
            target_modules=["to_q", "to_k", "to_v", "to_out.0"],
            init_lora_weights=False,
            use_dora=False,
        )
        return lora_config
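
    # The tests below walk the LoRA lifecycle on the SD3 transformer:
    # attach an adapter, run inference, save/unload/reload the weights,
    # scale the adapter at call time, and fuse/unfuse it into the base weights.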

    def test_simple_inference_with_transformer_lora_save_load(self):
        components = self.get_dummy_components()
        transformer_config = self.get_lora_config_for_transformer()

        pipe = self.pipeline_class(**components)
        pipe = pipe.to(torch_device)
        pipe.set_progress_bar_config(disable=None)

        pipe.transformer.add_adapter(transformer_config)
        self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "LoRA not correctly set in transformer")

        inputs = self.get_dummy_inputs(torch_device)
        images_lora = pipe(**inputs).images

        with tempfile.TemporaryDirectory() as tmpdirname:
            transformer_state_dict = get_peft_model_state_dict(pipe.transformer)

            self.pipeline_class.save_lora_weights(
                save_directory=tmpdirname,
                transformer_lora_layers=transformer_state_dict,
            )

            self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")))
            pipe.unload_lora_weights()

            pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))

        inputs = self.get_dummy_inputs(torch_device)
        images_lora_from_pretrained = pipe(**inputs).images
        self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "LoRA not correctly set in transformer")

        self.assertTrue(
            np.allclose(images_lora, images_lora_from_pretrained, atol=1e-3, rtol=1e-3),
            "Loading from saved checkpoints should give the same results.",
        )

    def test_simple_inference_with_transformer_lora_and_scale(self):
        components = self.get_dummy_components()
        transformer_lora_config = self.get_lora_config_for_transformer()

        pipe = self.pipeline_class(**components)
        pipe = pipe.to(torch_device)
        pipe.set_progress_bar_config(disable=None)

        inputs = self.get_dummy_inputs(torch_device)
        output_no_lora = pipe(**inputs).images

        pipe.transformer.add_adapter(transformer_lora_config)
        self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "LoRA not correctly set in transformer")

        inputs = self.get_dummy_inputs(torch_device)
        output_lora = pipe(**inputs).images
        self.assertTrue(
            not np.allclose(output_lora, output_no_lora, atol=1e-3, rtol=1e-3), "LoRA should change the output"
        )

        # A partial scale should differ from the full-scale LoRA output.
        inputs = self.get_dummy_inputs(torch_device)
        output_lora_scale = pipe(**inputs, joint_attention_kwargs={"scale": 0.5}).images
        self.assertTrue(
            not np.allclose(output_lora, output_lora_scale, atol=1e-3, rtol=1e-3),
            "LoRA + scale should change the output",
        )

        # A scale of 0.0 disables the adapter, so the output should match the no-LoRA baseline.
        inputs = self.get_dummy_inputs(torch_device)
        output_lora_0_scale = pipe(**inputs, joint_attention_kwargs={"scale": 0.0}).images
        self.assertTrue(
            np.allclose(output_no_lora, output_lora_0_scale, atol=1e-3, rtol=1e-3),
            "LoRA + 0 scale should lead to the same result as no LoRA",
        )

    def test_simple_inference_with_transformer_fused(self):
        components = self.get_dummy_components()
        transformer_lora_config = self.get_lora_config_for_transformer()

        pipe = self.pipeline_class(**components)
        pipe = pipe.to(torch_device)
        pipe.set_progress_bar_config(disable=None)

        inputs = self.get_dummy_inputs(torch_device)
        output_no_lora = pipe(**inputs).images

        pipe.transformer.add_adapter(transformer_lora_config)
        self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "LoRA not correctly set in transformer")

        pipe.fuse_lora()
        # Fusing should still keep the LoRA layers.
        self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "LoRA not correctly set in transformer")

        inputs = self.get_dummy_inputs(torch_device)
        output_fused = pipe(**inputs).images
        self.assertFalse(
            np.allclose(output_fused, output_no_lora, atol=1e-3, rtol=1e-3), "Fused LoRA should change the output"
        )
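
    # Fusing folds the LoRA delta into the base weights. The two tests below check
    # that fused inference matches unfused LoRA inference, and that unfusing
    # restores the same output as before.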

    def test_simple_inference_with_transformer_fused_with_no_fusion(self):
        components = self.get_dummy_components()
        transformer_lora_config = self.get_lora_config_for_transformer()

        pipe = self.pipeline_class(**components)
        pipe = pipe.to(torch_device)
        pipe.set_progress_bar_config(disable=None)

        inputs = self.get_dummy_inputs(torch_device)
        output_no_lora = pipe(**inputs).images

        pipe.transformer.add_adapter(transformer_lora_config)
        self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "LoRA not correctly set in transformer")

        inputs = self.get_dummy_inputs(torch_device)
        output_lora = pipe(**inputs).images

        pipe.fuse_lora()
        # Fusing should still keep the LoRA layers.
        self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "LoRA not correctly set in transformer")

        inputs = self.get_dummy_inputs(torch_device)
        output_fused = pipe(**inputs).images
        self.assertFalse(
            np.allclose(output_fused, output_no_lora, atol=1e-3, rtol=1e-3), "Fused LoRA should change the output"
        )
        self.assertTrue(
            np.allclose(output_fused, output_lora, atol=1e-3, rtol=1e-3),
            "Fused LoRA output should match the unfused LoRA output.",
        )

    def test_simple_inference_with_transformer_fuse_unfuse(self):
        components = self.get_dummy_components()
        transformer_lora_config = self.get_lora_config_for_transformer()

        pipe = self.pipeline_class(**components)
        pipe = pipe.to(torch_device)
        pipe.set_progress_bar_config(disable=None)

        inputs = self.get_dummy_inputs(torch_device)
        output_no_lora = pipe(**inputs).images

        pipe.transformer.add_adapter(transformer_lora_config)
        self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "LoRA not correctly set in transformer")

        pipe.fuse_lora()
        # Fusing should still keep the LoRA layers.
        self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "LoRA not correctly set in transformer")

        inputs = self.get_dummy_inputs(torch_device)
        output_fused = pipe(**inputs).images
        self.assertFalse(
            np.allclose(output_fused, output_no_lora, atol=1e-3, rtol=1e-3), "Fused LoRA should change the output"
        )

        pipe.unfuse_lora()
        self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "LoRA not correctly set in transformer")

        inputs = self.get_dummy_inputs(torch_device)
        output_unfused_lora = pipe(**inputs).images
        self.assertTrue(
            np.allclose(output_fused, output_unfused_lora, atol=1e-3, rtol=1e-3),
            "Unfusing should give the same output as fused LoRA",
        )

    @require_torch_gpu
    def test_sd3_lora(self):
        """
        Test loading LoRAs that are saved in both the diffusers and the peft formats.
        Related PR: https://github.com/huggingface/diffusers/pull/8584
        """
        components = self.get_dummy_components()

        pipe = self.pipeline_class(**components)
        pipe = pipe.to(torch_device)
        pipe.set_progress_bar_config(disable=None)

        lora_model_id = "hf-internal-testing/tiny-sd3-loras"

        lora_filename = "lora_diffusers_format.safetensors"
        pipe.load_lora_weights(lora_model_id, weight_name=lora_filename)
        pipe.unload_lora_weights()

        lora_filename = "lora_peft_format.safetensors"
        pipe.load_lora_weights(lora_model_id, weight_name=lora_filename)
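

# Example invocation (the file path below is an assumption; adjust it to wherever
# this module lives in your checkout). A PEFT-enabled environment is required:
#   pytest tests/lora/test_lora_layers_sd3.py -v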