Convert weights to JAX

#18, opened by jfacevedo

Hello. I tried converting the weights to JAX, but I'm running into an error.

Code:

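from diffusers import FlaxStableDiffusionPipeline

model_name = 'riffusion/riffusion-model-v1'
# from_pt=True loads the PyTorch checkpoint and converts it to Flax params
pipeline, params = FlaxStableDiffusionPipeline.from_pretrained(model_name, from_pt=True)
# Save the converted pipeline (component configs + Flax weights) to a local directory
pipeline.save_pretrained('riffusion_jax', params=params)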

Error:

File "riffusion.py", line 39, in <module>
    pipeline.save_pretrained('riffusion_jax', params=params)
  File "/python3.8/site-packages/diffusers/pipeline_flax_utils.py", line 189, in save_pretrained
    save_method(
  File "/python3.8/site-packages/diffusers/modeling_flax_utils.py", line 518, in save_pretrained
    model_to_save.save_config(save_directory)
  File "/python3.8/site-packages/diffusers/configuration_utils.py", line 137, in save_config
    self.to_json_file(output_config_file)
  File "/python3.8/site-packages/diffusers/configuration_utils.py", line 524, in to_json_file
    writer.write(self.to_json_string())
  File "/python3.8/site-packages/diffusers/configuration_utils.py", line 504, in to_json_string
    config_dict["_class_name"] = self.__class__.__name__
  File "/python3.8/site-packages/flax/core/frozen_dict.py", line 72, in __setitem__
    raise ValueError('FrozenDict is immutable.')
ValueError: FrozenDict is immutable.
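For context, the last two frames show where it breaks: `to_json_string` adds a `_class_name` key to the config dict in place, and that dict is a flax `FrozenDict` here, whose `__setitem__` always raises. The sketch below reproduces that pattern with a hypothetical pure-Python stand-in (`FrozenConfig`) for the flax class, next to a variant that copies into a plain dict before adding the key. `FrozenConfig`, both helper functions, and `ExampleFlaxModel` are made up for illustration; this is not the actual diffusers code and not necessarily how diffusers fixes it.

```python
import json
from collections.abc import Mapping


class FrozenConfig(Mapping):
    """Hypothetical stand-in mimicking flax.core.FrozenDict: readable, not writable."""

    def __init__(self, *args, **kwargs):
        self._data = dict(*args, **kwargs)

    def __getitem__(self, key):
        return self._data[key]

    def __iter__(self):
        return iter(self._data)

    def __len__(self):
        return len(self._data)

    def __setitem__(self, key, value):
        # Same error message the last frame of the traceback raises
        raise ValueError("FrozenDict is immutable.")


def to_json_string_in_place(config, class_name):
    """Hypothetical helper mirroring the failing line: mutates the config it receives."""
    config_dict = config
    config_dict["_class_name"] = class_name  # raises if config is frozen
    return json.dumps(config_dict, sort_keys=True)


def to_json_string_copied(config, class_name):
    """Hypothetical helper that copies into a plain dict first, so nothing frozen is mutated."""
    config_dict = dict(config)
    config_dict["_class_name"] = class_name
    return json.dumps(config_dict, sort_keys=True)


if __name__ == "__main__":
    frozen = FrozenConfig(sample_size=64)
    try:
        to_json_string_in_place(frozen, "ExampleFlaxModel")
    except ValueError as err:
        print(err)  # FrozenDict is immutable.
    print(to_json_string_copied(frozen, "ExampleFlaxModel"))
    # {"_class_name": "ExampleFlaxModel", "sample_size": 64}
```

The copying variant leaves the frozen config untouched, which is the usual way to add keys to data held in an immutable mapping.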

The same script works without issues for the CompVis/stable-diffusion-v1-4 and runwayml/stable-diffusion-v1-5 models. Any idea what could be causing this? Thanks.
