"""Convert an original Stable Diffusion checkpoint to ``diffusers`` format and
open a pull request on the Hugging Face Hub with the converted weights."""

import argparse
import json
import os
import shutil
from collections import defaultdict
from inspect import signature
from io import BytesIO
from tempfile import TemporaryDirectory
from typing import Dict, List, Optional, Set

import requests
import torch
from huggingface_hub import CommitInfo, Discussion, HfApi, hf_hub_download
from huggingface_hub.file_download import repo_folder_name
from diffusers.pipelines.stable_diffusion.convert_from_ckpt import download_from_original_stable_diffusion_ckpt
from transformers import CONFIG_MAPPING

COMMIT_MESSAGE = " This PR adds the both fp32 and fp16 in PyTorch and safetensors format to {}"

# Original inference config YAML for each supported checkpoint family.
_CONFIG_URLS = {
    "v1": "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/configs/stable-diffusion/v1-inference.yaml",
    "v2.0": "https://raw.githubusercontent.com/Stability-AI/stablediffusion/main/configs/stable-diffusion/v2-inference.yaml",
    "v2.1": "https://raw.githubusercontent.com/Stability-AI/stablediffusion/main/configs/stable-diffusion/v2-inference-v.yaml",
}

# Seconds to wait for the config YAML download before giving up.
_CONFIG_DOWNLOAD_TIMEOUT_S = 60


def convert_single(
    model_id: str,
    filename: str,
    model_type: str,
    sample_size: int,
    scheduler_type: str,
    extract_ema: bool,
    folder: str,
) -> str:
    """Convert one original checkpoint into a diffusers pipeline saved in ``folder``.

    The checkpoint is read from ``<model_id>/<filename>`` if that file exists
    locally, otherwise it is downloaded from the Hub repo ``model_id``.

    Args:
        model_id: Hub repo id (also used as a local directory prefix).
        filename: Checkpoint file name; a ``.safetensors`` suffix selects
            safetensors loading.
        model_type: One of ``"v1"``, ``"v2.0"``, ``"v2.1"``; selects the
            original inference config.
        sample_size: Passed as ``image_size`` to the converter.
        scheduler_type: Passed through to the converter.
        extract_ema: Passed through to the converter.
        folder: Output directory for the saved pipeline.

    Returns:
        ``folder``.

    Raises:
        ValueError: If ``model_type`` is not supported.
        requests.HTTPError: If the config YAML cannot be fetched.
    """
    # Validate before downloading anything: the checkpoint may be several GB.
    try:
        config_url = _CONFIG_URLS[model_type]
    except KeyError:
        raise ValueError(
            f"Unsupported model_type {model_type!r}; expected one of {sorted(_CONFIG_URLS)}"
        ) from None

    from_safetensors = filename.endswith(".safetensors")
    local_file = os.path.join(model_id, filename)
    ckpt_file = local_file if os.path.isfile(local_file) else hf_hub_download(repo_id=model_id, filename=filename)

    response = requests.get(config_url, timeout=_CONFIG_DOWNLOAD_TIMEOUT_S)
    # Fail loudly instead of handing an HTML error page to the config parser.
    response.raise_for_status()
    config_file = BytesIO(response.content)

    pipeline = download_from_original_stable_diffusion_ckpt(
        ckpt_file,
        config_file,
        image_size=sample_size,
        scheduler_type=scheduler_type,
        from_safetensors=from_safetensors,
        extract_ema=extract_ema,
    )
    # Save fp32 weights with the default serialization and with safetensors,
    # then the same two formats again in fp16 under the "fp16" variant.
    pipeline.save_pretrained(folder)
    pipeline.save_pretrained(folder, safe_serialization=True)
    pipeline = pipeline.to(torch_dtype=torch.float16)
    pipeline.save_pretrained(folder, variant="fp16")
    pipeline.save_pretrained(folder, safe_serialization=True, variant="fp16")
    return folder


def previous_pr(api: "HfApi", model_id: str, pr_title: str) -> Optional["Discussion"]:
    """Return the open PR on ``model_id`` titled ``pr_title`` that targets ``main``.

    Best-effort: returns ``None`` if no such PR exists or if the repo's
    discussions cannot be listed.
    """
    try:
        # get_repo_discussions returns a lazy iterator, so materialize it here;
        # otherwise listing errors surface in the loop below, outside the try.
        discussions = list(api.get_repo_discussions(repo_id=model_id))
    except Exception:
        return None
    for discussion in discussions:
        if discussion.status == "open" and discussion.is_pull_request and discussion.title == pr_title:
            details = api.get_discussion_details(repo_id=model_id, discussion_num=discussion.num)
            if details.target_branch == "refs/heads/main":
                return discussion
    return None


def convert(
    token: str,
    model_id: str,
    filename: str,
    model_type: str,
    sample_size: int = 512,
    scheduler_type: str = "pndm",
    extract_ema: bool = True,
):
    """Convert ``filename`` from ``model_id`` and open a PR with the diffusers weights.

    Conversion happens in a temporary directory that is removed afterwards.

    Returns:
        The value returned by ``HfApi.upload_folder``, or ``None`` if it was
        never reached.
    """
    api = HfApi()
    pr_title = "Adding `diffusers` weights of this model"
    with TemporaryDirectory() as d:
        folder = os.path.join(d, repo_folder_name(repo_id=model_id, repo_type="model"))
        os.makedirs(folder)
        new_pr = None
        try:
            folder = convert_single(model_id, filename, model_type, sample_size, scheduler_type, extract_ema, folder)
            new_pr = api.upload_folder(
                folder_path=folder,
                path_in_repo="./",
                repo_id=model_id,
                repo_type="model",
                token=token,
                # The commit message becomes the PR title; use pr_title so the
                # PR can be found later by previous_pr().
                commit_message=pr_title,
                commit_description=COMMIT_MESSAGE.format(model_id),
                create_pr=True,
            )
            # Extract the PR number from the URL-encoded "refs%2Fpr%2F<num>/..." segment.
            pr_number = new_pr.split("%2F")[-1].split("/")[0]
            # Build the URL with "/" explicitly; os.path.join would use "\" on Windows.
            print(f"Pr created at: https://huggingface.co/{model_id}/discussions/{pr_number}")
        finally:
            # TemporaryDirectory also cleans up; ignore errors here so a cleanup
            # failure never masks an exception raised by the conversion/upload.
            shutil.rmtree(folder, ignore_errors=True)
    return new_pr