Disty0 commited on
Commit
fd65655
1 Parent(s): 065bb00

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +2 -2
README.md CHANGED
@@ -32,7 +32,7 @@ def load_quanto_transformer(repo_path):
32
  with torch.device("meta"):
33
  transformer = diffusers.FluxTransformer2DModel.from_config(hf_hub_download(repo_path, "transformer/config.json")).to(torch.bfloat16)
34
  state_dict = load_file(hf_hub_download(repo_path, "transformer/diffusion_pytorch_model.safetensors"))
35
- requantize(transformer, state_dict, quantization_map, device=torch.device("cpu"))
36
  return transformer
37
 
38
 
@@ -44,7 +44,7 @@ def load_quanto_text_encoder_2(repo_path):
44
  with torch.device("meta"):
45
  text_encoder_2 = transformers.T5EncoderModel(t5_config).to(torch.bfloat16)
46
  state_dict = load_file(hf_hub_download(repo_path, "text_encoder_2/model.safetensors"))
47
- requantize(text_encoder_2, state_dict, quantization_map, device=torch.device("cpu"))
48
  return text_encoder_2
49
 
50
 
 
32
  with torch.device("meta"):
33
  transformer = diffusers.FluxTransformer2DModel.from_config(hf_hub_download(repo_path, "transformer/config.json")).to(torch.bfloat16)
34
  state_dict = load_file(hf_hub_download(repo_path, "transformer/diffusion_pytorch_model.safetensors"))
35
+ requantize(transformer, state_dict, quantization_map, device=torch.device("cuda"))
36
  return transformer
37
 
38
 
 
44
  with torch.device("meta"):
45
  text_encoder_2 = transformers.T5EncoderModel(t5_config).to(torch.bfloat16)
46
  state_dict = load_file(hf_hub_download(repo_path, "text_encoder_2/model.safetensors"))
47
+ requantize(text_encoder_2, state_dict, quantization_map, device=torch.device("cuda"))
48
  return text_encoder_2
49
 
50