File size: 2,288 Bytes
39a6792
 
 
f70898c
39a6792
 
 
f70898c
39a6792
65d64be
39a6792
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
65d64be
 
 
 
 
 
 
 
 
 
 
f70898c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
39a6792
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
import functools
import inspect
import json
import os
from typing import Callable, TypeVar

import anyio
import httpx
from anyio import Semaphore
from huggingface_hub._snapshot_download import snapshot_download
from typing_extensions import ParamSpec

# Type variables threading the wrapped callable's return type (T) and its
# parameter specification (P) through `async_call`.
T = TypeVar("T")
P = ParamSpec("P")

# How many `async_call` invocations may hold a worker thread at once;
# 1 means calls are serialised.
MAX_CONCURRENT_THREADS: int = 1
MAX_THREADS_GUARD = Semaphore(initial_value=MAX_CONCURRENT_THREADS)


@functools.lru_cache(maxsize=128)
def load_json(path: str) -> dict:
    """Decode the UTF-8 JSON document stored at ``path``.

    Results are memoised per ``path`` (up to 128 entries), so the file is only
    read the first time. Repeated calls hand back the very same object, so
    callers must not mutate it in place. Changes made on disk after the first
    read are not picked up until ``load_json.cache_clear()`` is called.
    """
    with open(path, encoding="utf-8") as handle:
        raw_text = handle.read()
    return json.loads(raw_text)


@functools.lru_cache(maxsize=128)
def read_file(path: str) -> str:
    """Return the entire contents of the UTF-8 text file at ``path``.

    Memoised per ``path`` (up to 128 entries). Later calls reuse the cached
    string, so edits made on disk after the first read are not seen until
    ``read_file.cache_clear()`` is called.
    """
    with open(path, encoding="utf-8") as handle:
        contents = handle.read()
    return contents


def download_repo_files(repo_id, allow_patterns, token=None, revision="main"):
    """Download the files of a Hugging Face *model* repo matching ``allow_patterns``.

    Thin wrapper around ``snapshot_download`` with ``repo_type="model"`` and no
    ignore patterns.

    Args:
        repo_id: Hub repository id.
        allow_patterns: Glob pattern(s) selecting which files to fetch.
        token: Optional Hub access token, forwarded as-is.
        revision: Branch, tag or commit hash to fetch. Defaults to ``"main"``
            (the value that used to be hard-coded). Pass a commit hash to pin
            downloads for reproducibility.

    Returns:
        Whatever ``snapshot_download`` returns. Per the huggingface_hub docs,
        this is the local snapshot folder path.
    """
    return snapshot_download(
        repo_id=repo_id,
        repo_type="model",
        revision=revision,
        token=token,
        allow_patterns=allow_patterns,
        ignore_patterns=None,
    )


def download_civit_file(lora_id, version_id, file_path=".", token=None):
    """Download a Civitai model version to ``<file_path>/<lora_id>.<version_id>.safetensors``.

    Does nothing if the target file already exists. The response body is
    streamed to a temporary ``.part`` file next to the target and moved into
    place only after it has been fully written. An interrupted download
    therefore never leaves a truncated file that the existence check would
    later mistake for a complete one.

    Best-effort: HTTP and transport errors are reported with ``print`` and
    swallowed. The function always returns ``None``.

    Args:
        lora_id: Identifier used only to name the output file.
        version_id: Civitai model-version id, appended to the download endpoint.
        file_path: Directory to write into; created if missing.
        token: Optional Civitai API token, sent as the ``token`` query parameter.
    """
    base_url = "https://civitai.com/api/download/models"
    file = os.path.join(file_path, f"{lora_id}.{version_id}.safetensors")

    if os.path.exists(file):
        return

    # httpx would encode a None value as an empty "token=" parameter; omit it instead.
    params = {"token": token} if token else {}
    tmp_file = f"{file}.part"

    try:
        with httpx.stream(
            "GET",
            f"{base_url}/{version_id}",
            timeout=None,
            params=params,
            follow_redirects=True,
        ) as response:
            if not response.is_success:
                # Load the error body now so it can still be reported after the
                # stream is closed (e.response.text needs a read response).
                response.read()
            response.raise_for_status()

            if file_path:
                os.makedirs(file_path, exist_ok=True)

            # Stream to disk instead of buffering the whole model in memory.
            with open(tmp_file, "wb") as f:
                for chunk in response.iter_bytes():
                    f.write(chunk)

        # Atomic within one filesystem: the target appears only when complete.
        os.replace(tmp_file, file)
    except httpx.HTTPStatusError as e:
        # Drop the API token from the URL so it never ends up in logs.
        print(e.request.url.copy_remove_param("token"))
        print(f"HTTPError: {e.response.status_code} {e.response.text}")
    except httpx.RequestError as e:
        print(f"RequestError: {e}")
    finally:
        # Remove any partial download; this is a no-op after a successful rename.
        if os.path.exists(tmp_file):
            os.remove(tmp_file)


# Adapted from huggingface-inference-toolkit's async_call (linked below), but accepts *args/**kwargs instead of a single dict
# https://github.com/huggingface/huggingface-inference-toolkit/blob/0.2.0/src/huggingface_inference_toolkit/async_utils.py
async def async_call(fn: Callable[P, T], *args: P.args, **kwargs: P.kwargs) -> T:
    """Run the blocking callable ``fn(*args, **kwargs)`` in a worker thread and await it.

    At most ``MAX_CONCURRENT_THREADS`` such calls run at a time. Additional
    callers wait on ``MAX_THREADS_GUARD``.

    The arguments are bound against ``fn``'s signature up front, so a malformed
    call raises ``TypeError`` immediately without waiting for a thread slot.
    Defaults are then filled in, and the call is rebuilt from
    ``BoundArguments.args`` / ``.kwargs``. The previous approach of passing
    ``bound.arguments`` as keywords broke functions with ``*args``,
    ``**kwargs`` or positional-only parameters.

    Raises:
        TypeError: If the arguments do not match ``fn``'s signature.
        Any exception raised by ``fn`` is propagated to the awaiting task.
    """
    bound_args = inspect.signature(fn).bind(*args, **kwargs)
    bound_args.apply_defaults()
    partial_fn = functools.partial(fn, *bound_args.args, **bound_args.kwargs)
    async with MAX_THREADS_GUARD:
        return await anyio.to_thread.run_sync(partial_fn)