Update README.md
Browse files
README.md
CHANGED
@@ -32,7 +32,7 @@ def load_quanto_transformer(repo_path):
|
|
32 |
with torch.device("meta"):
|
33 |
transformer = diffusers.FluxTransformer2DModel.from_config(hf_hub_download(repo_path, "transformer/config.json")).to(torch.bfloat16)
|
34 |
state_dict = load_file(hf_hub_download(repo_path, "transformer/diffusion_pytorch_model.safetensors"))
|
35 |
-
requantize(transformer, state_dict, quantization_map, device=torch.device("
|
36 |
return transformer
|
37 |
|
38 |
|
@@ -44,7 +44,7 @@ def load_quanto_text_encoder_2(repo_path):
|
|
44 |
with torch.device("meta"):
|
45 |
text_encoder_2 = transformers.T5EncoderModel(t5_config).to(torch.bfloat16)
|
46 |
state_dict = load_file(hf_hub_download(repo_path, "text_encoder_2/model.safetensors"))
|
47 |
-
requantize(text_encoder_2, state_dict, quantization_map, device=torch.device("
|
48 |
return text_encoder_2
|
49 |
|
50 |
|
|
|
32 |
with torch.device("meta"):
|
33 |
transformer = diffusers.FluxTransformer2DModel.from_config(hf_hub_download(repo_path, "transformer/config.json")).to(torch.bfloat16)
|
34 |
state_dict = load_file(hf_hub_download(repo_path, "transformer/diffusion_pytorch_model.safetensors"))
|
35 |
+
requantize(transformer, state_dict, quantization_map, device=torch.device("cuda"))
|
36 |
return transformer
|
37 |
|
38 |
|
|
|
44 |
with torch.device("meta"):
|
45 |
text_encoder_2 = transformers.T5EncoderModel(t5_config).to(torch.bfloat16)
|
46 |
state_dict = load_file(hf_hub_download(repo_path, "text_encoder_2/model.safetensors"))
|
47 |
+
requantize(text_encoder_2, state_dict, quantization_map, device=torch.device("cuda"))
|
48 |
return text_encoder_2
|
49 |
|
50 |
|