metadata
license: creativeml-openrail-m
language:
- en
base_model: black-forest-labs/FLUX.1-schnell
pipeline_tag: text-to-image
library_name: diffusers
license: creativeml-openrail-m
Individuals may use this model for personal and commercial purposes, including generating and selling images. Commercial use by companies or organizations is strictly prohibited.
Maxwell Model
Acknowledgements
First, a big thanks to @sayakpaul, who fixed most of the issues we were facing with Diffusers. This model uses his bitsandbytes NF4 (bnb-NF4) quantization approach.
Installation
- Install the required packages:
pip install torch accelerate safetensors diffusers huggingface_hub bitsandbytes transformers
- Download convert_nf4_flux.py and place it in the same directory as the generation script below.
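Before running the generation script, it can help to confirm that the packages above are importable and that convert_nf4_flux.py sits next to the script. The following is a minimal sketch of such a check; the helper names `missing_packages` and `preflight` are hypothetical and not part of any of the libraries listed.

```python
import importlib.util
from pathlib import Path

# Package names taken from the pip install command above.
REQUIRED = [
    "torch", "accelerate", "safetensors", "diffusers",
    "huggingface_hub", "bitsandbytes", "transformers",
]


def missing_packages(names):
    """Return the names that cannot be found on the import path."""
    return [name for name in names if importlib.util.find_spec(name) is None]


def preflight(script_dir="."):
    """Collect human-readable problems; an empty list means all checks passed."""
    problems = [f"package not installed: {name}" for name in missing_packages(REQUIRED)]
    # The generation script imports helpers from this file, so it must be alongside it.
    if not (Path(script_dir) / "convert_nf4_flux.py").is_file():
        problems.append("convert_nf4_flux.py not found in the script directory")
    return problems


if __name__ == "__main__":
    for problem in preflight():
        print(problem)
```

If the script prints nothing, the packages were found and the helper file is in place. It does not check for a CUDA-capable GPU, which the generation script also needs.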
Usage
Run the following Python code:
# Generative Code
from huggingface_hub import hf_hub_download
from accelerate.utils import set_module_tensor_to_device, compute_module_sizes
from accelerate import init_empty_weights
from convert_nf4_flux import replace_with_bnb_linear, create_quantized_param, check_quantized_param
from diffusers import FluxTransformer2DModel, FluxPipeline
import safetensors.torch
import gc
import torch
# Set dtype and check for float8 support
dtype = torch.bfloat16
is_torch_e4m3fn_available = hasattr(torch, "float8_e4m3fn")
# Download the model checkpoint
ckpt_path = hf_hub_download("ABDALLALSWAITI/Maxwell", filename="diffusion_pytorch_model.safetensors")
original_state_dict = safetensors.torch.load_file(ckpt_path)
# Initialize the model with empty weights
with init_empty_weights():
    config = FluxTransformer2DModel.load_config("ABDALLALSWAITI/Maxwell")
    model = FluxTransformer2DModel.from_config(config).to(dtype)
    expected_state_dict_keys = list(model.state_dict().keys())
# Replace layers with NF4 quantized versions
replace_with_bnb_linear(model, "nf4")
# Load the state dict into the quantized model
for param_name, param in original_state_dict.items():
    if param_name not in expected_state_dict_keys:
        continue
    is_param_float8_e4m3fn = is_torch_e4m3fn_available and param.dtype == torch.float8_e4m3fn
    if torch.is_floating_point(param) and not is_param_float8_e4m3fn:
        param = param.to(dtype)
    if not check_quantized_param(model, param_name):
        set_module_tensor_to_device(model, param_name, device=0, value=param)
    else:
        create_quantized_param(
            model, param, param_name, target_device=0, state_dict=original_state_dict, pre_quantized=True
        )
# Clean up
del original_state_dict
gc.collect()
# Print model size in MiB (bytes / 1024 / 1024)
print(compute_module_sizes(model)[""] / 1024 / 1024)
# Initialize the pipeline; the remaining components come from the base model listed in the metadata
pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", transformer=model, torch_dtype=dtype)
pipe.enable_model_cpu_offload()
# Generate an image from a prompt
prompt = "A mystic Tiger play guitar with sign that says hello world!"
image = pipe(prompt, guidance_scale=0.0, num_inference_steps=4, generator=torch.manual_seed(0)).images[0]
image.save("simple.png")
This code downloads the Maxwell checkpoint, loads its pre-quantized NF4 weights into the transformer, and generates an image from the given prompt.
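For context on the size that the script prints, the sketch below converts bytes to MiB and gives a rough estimate of how much storage 4-bit blockwise quantization needs compared with bfloat16. The helper names, the block size of 64, and the parameter count of about 12 billion are assumptions for illustration, not values taken from this model card. Layers that stay unquantized are ignored, so real numbers will differ.

```python
def bytes_to_mib(num_bytes: int) -> float:
    """Convert a byte count to mebibytes (1 MiB = 1024 * 1024 bytes)."""
    return num_bytes / 1024 / 1024


def estimate_nf4_bytes(num_params: int, block_size: int = 64) -> int:
    """Rough storage estimate for 4-bit blockwise quantized weights.

    Assumes two 4-bit values packed per byte, plus one float32 scale
    (4 bytes) per block. block_size=64 is an assumed value.
    """
    packed = (num_params + 1) // 2             # 4 bits per weight, rounded up
    num_blocks = -(-num_params // block_size)  # ceiling division
    return packed + 4 * num_blocks


if __name__ == "__main__":
    params = 12_000_000_000  # assumed, approximate parameter count
    print(f"bf16: {bytes_to_mib(2 * params):,.0f} MiB")
    print(f"nf4 : {bytes_to_mib(estimate_nf4_bytes(params)):,.0f} MiB")
```

Under these assumptions the 4-bit weights take a little over a quarter of the bfloat16 footprint, which is the main reason for loading the transformer in NF4.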