File size: 3,582 Bytes
e331aa7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
import argparse
import requests
import json
import os
import shutil
from collections import defaultdict
from inspect import signature
from tempfile import TemporaryDirectory
from typing import Dict, List, Optional, Set

import torch
from io import BytesIO

from huggingface_hub import CommitInfo, Discussion, HfApi, hf_hub_download
from huggingface_hub.file_download import repo_folder_name
from diffusers.pipelines.stable_diffusion.convert_from_ckpt import download_from_original_stable_diffusion_ckpt
from transformers import CONFIG_MAPPING


# Description attached to the conversion PR; ``{}`` is filled with the repo id.
COMMIT_MESSAGE = "This PR adds both fp32 and fp16 weights in PyTorch and safetensors formats to {}"


def convert_single(model_id: str, filename: str, model_type: str, sample_size: int, scheduler_type: str, extract_ema: bool, folder: str):
    from_safetensors = filename.endswith(".safetensors")

    local_file = os.path.join(model_id, filename)
    ckpt_file = local_file if os.path.isfile(local_file) else hf_hub_download(repo_id=model_id, filename=filename)

    if model_type == "v1":
        config_url = "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/configs/stable-diffusion/v1-inference.yaml"
    elif model_type == "v2.0":
        config_url = "https://raw.githubusercontent.com/Stability-AI/stablediffusion/main/configs/stable-diffusion/v2-inference.yaml"
    elif model_type == "v2.1":
        config_url = "https://raw.githubusercontent.com/Stability-AI/stablediffusion/main/configs/stable-diffusion/v2-inference-v.yaml"

    config_file = BytesIO(requests.get(config_url).content)

    pipeline = download_from_original_stable_diffusion_ckpt(ckpt_file, config_file, image_size=sample_size, scheduler_type=scheduler_type, from_safetensors=from_safetensors, extract_ema=extract_ema)

    pipeline.save_pretrained(folder)
    pipeline.save_pretrained(folder, safe_serialization=True)

    pipeline = pipeline.to(torch_dtype=torch.float16)
    pipeline.save_pretrained(folder, variant="fp16")
    pipeline.save_pretrained(folder, safe_serialization=True, variant="fp16")

    return folder


def previous_pr(api: "HfApi", model_id: str, pr_title: str) -> Optional["Discussion"]:
    """Return an open pull request titled ``pr_title`` that targets ``main``.

    This is best-effort: if listing the repo's discussions fails, ``None`` is
    returned instead of raising.

    Args:
        api: Client used to list discussions and fetch their details.
        model_id: Repo whose discussions are searched.
        pr_title: Exact title to match.

    Returns:
        The first matching discussion, or ``None`` if none matches or the
        listing failed.
    """
    try:
        # ``get_repo_discussions`` may return a lazy iterator (it is a generator
        # in huggingface_hub), so network errors only surface while iterating.
        # Materialize inside the ``try`` so those errors are actually caught.
        discussions = list(api.get_repo_discussions(repo_id=model_id))
    except Exception:
        return None
    for discussion in discussions:
        if discussion.status == "open" and discussion.is_pull_request and discussion.title == pr_title:
            # The PR's target branch is only available from the detailed view.
            details = api.get_discussion_details(repo_id=model_id, discussion_num=discussion.num)
            if details.target_branch == "refs/heads/main":
                return discussion
    return None


def convert(token: str, model_id: str, filename: str, model_type: str, sample_size: int = 512, scheduler_type: str = "pndm", extract_ema: bool = True):
    """Convert a checkpoint of ``model_id`` and open a Hub PR with the ``diffusers`` weights.

    Args:
        token: Hub token used for the upload.
        model_id: Source repo id; the PR is opened against this repo.
        filename: Checkpoint file inside the repo (or inside a local ``model_id`` directory).
        model_type: Forwarded to ``convert_single`` (``"v1"``, ``"v2.0"`` or ``"v2.1"``).
        sample_size: Forwarded to ``convert_single``.
        scheduler_type: Forwarded to ``convert_single``.
        extract_ema: Forwarded to ``convert_single``.

    Returns:
        Whatever ``HfApi.upload_folder`` returned for the created PR. Depending
        on the huggingface_hub version this is a URL string or a ``CommitInfo``;
        confirm against the installed version.
    """
    api = HfApi()

    pr_title = "Adding `diffusers` weights of this model"

    # TemporaryDirectory removes the whole tree (including ``folder``) on exit,
    # on both success and error, so no manual cleanup is needed.
    with TemporaryDirectory() as d:
        folder = os.path.join(d, repo_folder_name(repo_id=model_id, repo_type="models"))
        os.makedirs(folder)
        folder = convert_single(model_id, filename, model_type, sample_size, scheduler_type, extract_ema, folder)
        new_pr = api.upload_folder(
            folder_path=folder,
            path_in_repo="./",
            repo_id=model_id,
            repo_type="model",
            token=token,
            commit_message=pr_title,  # the commit message becomes the PR title
            commit_description=COMMIT_MESSAGE.format(model_id),
            create_pr=True,
        )
        # Newer huggingface_hub returns a CommitInfo exposing ``pr_url``.
        # Older versions return a URL string ending in ``.../tree/refs%2Fpr%2F<n>/...``.
        # URLs are built with '/' explicitly (os.path.join would use '\\' on Windows).
        pr_url = getattr(new_pr, "pr_url", None)
        if not pr_url:
            pr_number = str(new_pr).split("%2F")[-1].split("/")[0]
            pr_url = f"https://huggingface.co/{model_id}/discussions/{pr_number}"
        print(f"Pr created at: {pr_url}")

        return new_pr