
Model Details

Model Description

This model is fine-tuned from stable-diffusion-v1-5 on 110,000 image-text pairs from the MIMIC dataset using the SVDiff [1] PEFT method. Under this fine-tuning strategy, only the singular values of the weight matrices in the U-Net are updated, while all other parameters are kept frozen.
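
As a rough, illustrative sketch (not the actual training code), the SVDiff reparameterization decomposes each pretrained weight matrix once with SVD, keeps the factors frozen, and trains only a small vector of spectral shifts added to the singular values:

import torch

# Illustrative SVDiff reparameterization for a single weight matrix.
W = torch.randn(320, 320)  # frozen pretrained weight
U, sigma, Vt = torch.linalg.svd(W, full_matrices=False)

# Only this spectral-shift vector receives gradients; U, sigma, Vt stay frozen.
delta = torch.zeros_like(sigma, requires_grad=True)

# Reconstruct the weight with shifted, non-negative singular values.
W_adapted = U @ torch.diag(torch.relu(sigma + delta)) @ Vt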


Direct Use

This model can be directly used to generate realistic medical images from text prompts.

How to Get Started with the Model

import inspect
import os

import accelerate
import huggingface_hub
import torch
from accelerate.utils import set_module_tensor_to_device
from diffusers import StableDiffusionPipeline, UNet2DConditionModel
from safetensors import safe_open

# UNet2DConditionModelForSVDiff is the SVDiff U-Net variant; it is assumed
# here to come from the svdiff-pytorch implementation.
from svdiff_pytorch import UNet2DConditionModelForSVDiff


# Define the loading function for the SVDiff U-Net

def load_unet_for_svdiff(pretrained_model_name_or_path, spectral_shifts_ckpt=None, hf_hub_kwargs=None, **kwargs):
    # Load the base U-Net configuration and pretrained weights
    config = UNet2DConditionModel.load_config(pretrained_model_name_or_path, **kwargs)
    original_model = UNet2DConditionModel.from_pretrained(pretrained_model_name_or_path, **kwargs)
    state_dict = original_model.state_dict()
    # Instantiate the SVDiff U-Net variant on the meta device (no memory allocated yet)
    with accelerate.init_empty_weights():
        model = UNet2DConditionModelForSVDiff.from_config(config)
    # load pre-trained weights
    param_device = "cpu"
    torch_dtype = kwargs["torch_dtype"] if "torch_dtype" in kwargs else None
    spectral_shifts_weights = {n: torch.zeros(p.shape) for n, p in model.named_parameters() if "delta" in n}
    state_dict.update(spectral_shifts_weights)
    # move the params from meta device to cpu
    missing_keys = set(model.state_dict().keys()) - set(state_dict.keys())
    if len(missing_keys) > 0:
        raise ValueError(
            f"Cannot load {model.__class__.__name__} from {pretrained_model_name_or_path} because the following keys are"
            f" missing: \n {', '.join(missing_keys)}. \n Please make sure to pass"
            " `low_cpu_mem_usage=False` and `device_map=None` if you want to randomely initialize"
            " those weights or else make sure your checkpoint file is correct."
        )

    # Copy each parameter from the merged state dict onto the (meta) model
    for param_name, param in state_dict.items():
        accepts_dtype = "dtype" in set(inspect.signature(set_module_tensor_to_device).parameters.keys())
        if accepts_dtype:
            set_module_tensor_to_device(model, param_name, param_device, value=param, dtype=torch_dtype)
        else:
            set_module_tensor_to_device(model, param_name, param_device, value=param)
    
    if spectral_shifts_ckpt:
        if os.path.isdir(spectral_shifts_ckpt):
            spectral_shifts_ckpt = os.path.join(spectral_shifts_ckpt, "spectral_shifts.safetensors")
        elif not os.path.exists(spectral_shifts_ckpt):
            # download from hub
            hf_hub_kwargs = {} if hf_hub_kwargs is None else hf_hub_kwargs
            spectral_shifts_ckpt = huggingface_hub.hf_hub_download(spectral_shifts_ckpt, filename="spectral_shifts.safetensors", **hf_hub_kwargs)
        assert os.path.exists(spectral_shifts_ckpt)

        with safe_open(spectral_shifts_ckpt, framework="pt", device="cpu") as f:
            for key in f.keys():
                accepts_dtype = "dtype" in set(inspect.signature(set_module_tensor_to_device).parameters.keys())
                if accepts_dtype:
                    set_module_tensor_to_device(model, key, param_device, value=f.get_tensor(key), dtype=torch_dtype)
                else:
                    set_module_tensor_to_device(model, key, param_device, value=f.get_tensor(key))
        print(f"Resumed from {spectral_shifts_ckpt}")
    if "torch_dtype"in kwargs:
        model = model.to(kwargs["torch_dtype"])
    model.register_to_config(_name_or_path=pretrained_model_name_or_path)
    # Set model in evaluation mode to deactivate DropOut modules by default
    model.eval()
    del original_model
    torch.cuda.empty_cache()
    return model

# Create the base pipeline, then swap in the SVDiff-adapted U-Net
pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
pipe.unet = load_unet_for_svdiff(
    "runwayml/stable-diffusion-v1-5",
    spectral_shifts_ckpt=os.path.join("unet", "spectral_shifts.safetensors"),
    subfolder="unet",
)
# Run the one-time SVD on every adapted layer so the spectral shifts can be applied
for module in pipe.unet.modules():
    if hasattr(module, "perform_svd"):
        module.perform_svd()

# Move the pipeline to the GPU
pipe.to("cuda:0")

# Generate images with text prompts

TEXT_PROMPT = "No acute cardiopulmonary abnormality."
GUIDANCE_SCALE = 4
INFERENCE_STEPS = 75

result_image = pipe(
    prompt=TEXT_PROMPT,
    height=224,
    width=224,
    guidance_scale=GUIDANCE_SCALE,
    num_inference_steps=INFERENCE_STEPS,
)

result_pil_image = result_image["images"][0]
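
The pipeline returns PIL images, so the generated sample can be saved directly (the filename here is just an example):

result_pil_image.save("generated_cxr.png")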

Training Details

Training Data

This model has been fine-tuned on 110K image-text pairs from the MIMIC dataset.

Training Procedure

The training procedure is described in detail in Section 4.3 of the accompanying paper (see Citation below).

Metrics

This model has been evaluated using the Fréchet Inception Distance (FID) on the MIMIC dataset.
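
As an illustrative sketch (not the exact evaluation script used for the numbers below), FID between batches of real and generated images can be computed with torchmetrics (which needs torch-fidelity installed); the image tensors here are placeholders:

import torch
from torchmetrics.image.fid import FrechetInceptionDistance

# Illustrative FID computation on placeholder data (uint8 tensors, [N, 3, H, W]).
fid = FrechetInceptionDistance(feature=2048)

real_images = torch.randint(0, 256, (16, 3, 224, 224), dtype=torch.uint8)  # placeholder real batch
fake_images = torch.randint(0, 256, (16, 3, 224, 224), dtype=torch.uint8)  # placeholder generated batch

fid.update(real_images, real=True)
fid.update(fake_images, real=False)
print(f"FID: {fid.compute().item():.2f}")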

Results

Fine-Tuning Strategy    FID Score
Full FT                 58.74
Attention               52.41
Bias                    20.81
Norm                    29.84
Bias+Norm+Attention     35.93
LoRA                    439.65
SV-Diff                 23.59
DiffFit                 42.50

Environmental Impact

Parameter-efficient fine-tuning potentially reduces environmental impact: only a small fraction of the model's parameters are updated, which substantially lowers the compute and hardware requirements of fine-tuning.
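
For a rough sense of the savings, the share of trainable parameters can be inspected directly. This sketch assumes the `pipe` object from the loading example above, where the trainable spectral shifts are the parameters with "delta" in their names:

# Fraction of U-Net parameters that SVDiff actually trains.
total_params = sum(p.numel() for p in pipe.unet.parameters())
svdiff_params = sum(p.numel() for n, p in pipe.unet.named_parameters() if "delta" in n)
print(f"Trainable spectral shifts: {svdiff_params:,} of {total_params:,} "
      f"({100 * svdiff_params / total_params:.3f}%)")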

Citation

BibTeX:

@article{dutt2023parameter,
  title={Parameter-Efficient Fine-Tuning for Medical Image Analysis: The Missed Opportunity},
  author={Dutt, Raman and Ericsson, Linus and Sanchez, Pedro and Tsaftaris, Sotirios A and Hospedales, Timothy},
  journal={arXiv preprint arXiv:2305.08252},
  year={2023}
}

APA:
Dutt, R., Ericsson, L., Sanchez, P., Tsaftaris, S. A., & Hospedales, T. (2023). Parameter-Efficient Fine-Tuning for Medical Image Analysis: The Missed Opportunity. arXiv preprint arXiv:2305.08252.

Model Card Authors

Raman Dutt

References

  1. Han, Ligong, et al. "SVDiff: Compact Parameter Space for Diffusion Fine-Tuning." arXiv preprint arXiv:2303.11305 (2023).