File size: 2,607 Bytes
dfd4a53
662ed4b
611a3ed
8a91492
611a3ed
d0f55c6
8a91492
2f4d877
d0f55c6
523fad9
 
d0f55c6
719c272
 
 
 
dfd4a53
 
 
 
 
719c272
 
dfd4a53
 
 
719c272
 
 
 
dfd4a53
 
 
 
 
 
 
 
 
 
 
719c272
8a91492
719c272
7e32ac7
 
 
 
 
 
d0f55c6
 
fae0e19
d0f55c6
 
719c272
 
d0f55c6
 
 
fae0e19
2f4d877
 
719c272
 
662ed4b
 
2f4d877
 
d0f55c6
523fad9
 
 
 
 
 
 
 
719c272
 
523fad9
 
 
 
 
 
 
 
719c272
 
523fad9
8a91492
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
import asyncio
import io
import json
import os

import httpx
from huggingface_hub import HfApi, HfFileSystem, ModelCard, hf_hub_url
from huggingface_hub.utils import build_hf_headers

import src.constants as constants


class Client:
    """Async HTTP GET helper that returns ``None`` instead of raising on HTTP errors.

    Read timeouts are retried with exponential backoff (see ``retry``); any other
    ``httpx.HTTPError`` makes ``get`` return ``None``.
    """

    def __init__(self):
        # One AsyncClient reused for every request. Redirects are followed
        # (presumably because Hub "resolve" URLs redirect to file storage).
        self.client = httpx.AsyncClient(follow_redirects=True)

    async def _get(self, url, headers=None, params=None):
        """GET *url*; raise ``httpx.HTTPStatusError`` for a 4xx/5xx response."""
        r = await self.client.get(url, headers=headers, params=params)
        r.raise_for_status()
        return r

    async def get(self, url, headers=None, params=None):
        """GET *url* and return the response, or ``None`` on failure.

        A read timeout hands off to ``retry``. Any other ``httpx.HTTPError``
        (connection error, non-2xx status, ...) yields ``None``.
        """
        try:
            return await self._get(url, headers=headers, params=params)
        except httpx.ReadTimeout:
            # retry() itself converts failures into None, so nothing escapes here.
            return await self.retry(self._get, url, headers=headers, params=params)
        except httpx.HTTPError:
            return None

    async def retry(self, func, url, max_retries=4, max_wait_time=8, wait_time=1, **kwargs):
        """Re-call ``func(url, **kwargs)`` after read timeouts, with exponential backoff.

        Before each attempt, sleeps ``wait_time`` seconds; the wait doubles after
        every ``httpx.ReadTimeout``. Gives up after ``max_retries`` attempts, or as
        soon as the next wait would exceed ``max_wait_time``. In that case it prints
        a message and returns ``None``.

        Any other ``httpx.HTTPError`` returns ``None`` immediately, matching ``get``.

        Returns:
            Whatever ``func`` returns on success, otherwise ``None``.
        """
        for _ in range(max_retries):
            await asyncio.sleep(wait_time)
            try:
                return await func(url, **kwargs)
            except httpx.ReadTimeout:
                wait_time *= 2
                if wait_time > max_wait_time:
                    break
            except httpx.HTTPError:
                # Not a timeout, so retrying will not help. Previously this
                # propagated out of get() despite its "None on HTTP error" contract.
                return None
        # Reached both when the wait cap is hit and when the attempts run out.
        print("HTTP Timeout: max retries exceeded with url:", url)
        return None


# Module-level singletons shared by the helper functions in this module.
api = HfApi()  # Hub API client; used by restart_space()
client = Client()  # shared async HTTP client; its get() returns None on HTTP errors
fs = HfFileSystem()  # Hub filesystem view; used by glob()


def glob(path):
    """Return the Hub filesystem paths matching the glob pattern *path*.

    Thin wrapper around the module-level ``HfFileSystem.glob``.
    """
    return fs.glob(path)


async def load_json_file(path):
    """Download the Hub file at *path* and return its decoded JSON content.

    *path* uses the format accepted by ``to_url``. The same Hub auth headers as
    ``load_jsonlines_file`` are sent, so files in private or gated repos can be read.

    Returns:
        The parsed JSON value, or ``None`` if the request failed.
    """
    url = to_url(path)
    r = await client.get(url, headers=build_hf_headers())
    if r is None:
        return None
    return r.json()


async def load_jsonlines_file(path):
    """Download the JSON Lines file at *path* and return one parsed object per line.

    Hub auth headers are sent with the request. Whitespace-only lines (for example
    an extra newline at the end of the file) are skipped instead of making
    ``json.loads`` raise.

    Returns:
        A list of decoded objects, or ``None`` if the request failed.
    """
    url = to_url(path)
    r = await client.get(url, headers=build_hf_headers())
    if r is None:
        return None
    # Iterating StringIO splits on "\n" only, unlike str.splitlines(), which would
    # also break on U+2028/U+0085 (both legal inside JSON strings).
    f = io.StringIO(r.text)
    return [json.loads(line) for line in f if line.strip()]


def to_url(path):
    """Resolve a Hub file path to its download URL via ``hf_hub_url``.

    *path* has the form ``[<repo_type>s/]<org>/<name>/<filename>``. Examples:
    ``datasets/org/name/results.json``, or ``org/model/README.md`` for a model
    (no prefix means ``repo_type=None``). *filename* may itself contain ``/``.

    Raises:
        ValueError: if fewer than three segments remain after the optional prefix.
    """
    parts = path.split("/")
    repo_type = None
    # Treat only a known "<type>s" prefix as a repo type. Previously any path with
    # 4+ segments was assumed to carry one, which mis-parsed model files in
    # subdirectories (e.g. "org/model/sub/file").
    if parts[0] in ("datasets", "spaces", "models"):
        repo_type = parts.pop(0)[:-1]
    if len(parts) < 3:
        raise ValueError(f"Expected '[<repo_type>s/]<org>/<name>/<filename>', got {path!r}")
    org_name, repo_name, *file_parts = parts
    return hf_hub_url(repo_id=f"{org_name}/{repo_name}", filename="/".join(file_parts), repo_type=repo_type)


async def load_model_card(model_id):
    """Fetch the ``README.md`` of model repo *model_id* and parse it as a model card.

    The card is built with ``ignore_metadata_errors=True``, so errors in its
    metadata section are ignored rather than raised.

    Returns:
        A ``ModelCard``, or ``None`` when the download failed.
    """
    response = await client.get(to_url(f"{model_id}/README.md"))
    if response is not None:
        return ModelCard(response.text, ignore_metadata_errors=True)
    return None


async def list_models(filtering=None):
    """Query the ``/models`` endpoint under ``constants.HF_API_URL`` and decode the JSON reply.

    When *filtering* is truthy, it is sent as the ``filter`` query parameter.

    Returns:
        The decoded JSON body, or ``None`` if the request failed.
    """
    query = {"filter": filtering} if filtering else {}
    response = await client.get(f"{constants.HF_API_URL}/models", params=query)
    return None if response is None else response.json()


def restart_space():
    """Restart the Space named by the ``SPACE_ID`` environment variable.

    Does nothing when ``SPACE_ID`` is unset or empty. Otherwise it calls
    ``api.restart_space``, authenticated with the ``HF_TOKEN`` environment
    variable (which may be unset).
    """
    space_id = os.getenv("SPACE_ID")
    if not space_id:
        return
    api.restart_space(repo_id=space_id, token=os.getenv("HF_TOKEN"))