File size: 2,607 Bytes
dfd4a53 662ed4b 611a3ed 8a91492 611a3ed d0f55c6 8a91492 2f4d877 d0f55c6 523fad9 d0f55c6 719c272 dfd4a53 719c272 dfd4a53 719c272 dfd4a53 719c272 8a91492 719c272 7e32ac7 d0f55c6 fae0e19 d0f55c6 719c272 d0f55c6 fae0e19 2f4d877 719c272 662ed4b 2f4d877 d0f55c6 523fad9 719c272 523fad9 719c272 523fad9 8a91492 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 |
import asyncio
import io
import json
import os
import httpx
from huggingface_hub import HfApi, HfFileSystem, ModelCard, hf_hub_url
from huggingface_hub.utils import build_hf_headers
import src.constants as constants
class Client:
    """Async HTTP helper whose ``get`` returns ``None`` instead of raising on HTTP errors.

    A single ``httpx.AsyncClient`` (configured to follow redirects) is created
    once and reused for every request made through this object.
    """

    def __init__(self):
        self.client = httpx.AsyncClient(follow_redirects=True)

    async def _get(self, url, headers=None, params=None):
        """GET ``url``; raise ``httpx.HTTPStatusError`` for a 4xx/5xx response."""
        r = await self.client.get(url, headers=headers, params=params)
        r.raise_for_status()
        return r

    async def get(self, url, headers=None, params=None):
        """GET ``url``, retrying (via ``retry``) when the first attempt hits a read timeout.

        Returns:
            The ``httpx.Response`` on success. Returns ``None`` when any
            ``httpx.HTTPError`` occurs, on the first attempt or during the
            retries, or when the retries are exhausted.
        """
        try:
            return await self._get(url, headers=headers, params=params)
        except httpx.ReadTimeout:
            pass  # fall through to the retry loop below
        except httpx.HTTPError:
            return None
        # The retry runs outside the first ``try`` statement. An exception
        # raised inside an ``except`` handler is not caught by that handler's
        # sibling clauses, so it needs its own guard. Otherwise a non-timeout
        # error on a retried attempt (e.g. a 404/500 status) would escape to
        # the caller instead of becoming ``None``.
        try:
            return await self.retry(self._get, url, headers=headers, params=params)
        except httpx.HTTPError:
            return None

    async def retry(self, func, url, max_retries=4, max_wait_time=8, wait_time=1, **kwargs):
        """Re-attempt ``func(url, **kwargs)`` with exponential backoff on read timeouts.

        Before each attempt it sleeps ``wait_time`` seconds. After every
        ``httpx.ReadTimeout`` the wait doubles. It gives up once ``max_retries``
        attempts have been made, or once the doubled wait exceeds
        ``max_wait_time``.

        Returns:
            The result of the first successful attempt. Otherwise it prints a
            timeout message and returns ``None``. Exceptions other than
            ``httpx.ReadTimeout`` raised by ``func`` propagate.
        """
        for _ in range(max_retries):
            await asyncio.sleep(wait_time)
            try:
                return await func(url, **kwargs)
            except httpx.ReadTimeout:
                wait_time = wait_time * 2
                if wait_time > max_wait_time:
                    break
        # Reached both when the backoff cap is hit and when max_retries is
        # exhausted, so the caller-visible outcome is always reported.
        print("HTTP Timeout: max retries exceeded with url:", url)
        return None
# Module-level singletons shared by the helpers in this file.
fs = HfFileSystem()  # Hub filesystem view; used by glob()
client = Client()  # async HTTP helper; used by the load_* functions and list_models()
api = HfApi()  # Hub API client; used by restart_space()
def glob(path):
    """Return the Hub paths matching the glob pattern ``path``, via the module's ``HfFileSystem``."""
    return fs.glob(path)
async def load_json_file(path):
    """Download the Hub file at ``path`` and return its decoded JSON content.

    ``path`` is converted to a download URL with ``to_url``. Returns ``None``
    when the request fails, i.e. when ``client.get`` yields ``None``.
    """
    response = await client.get(to_url(path))
    if response is not None:
        return response.json()
    return None
async def load_jsonlines_file(path):
    """Download the JSON Lines file at ``path`` from the Hub and parse it.

    The request is sent with the headers produced by ``build_hf_headers()``.

    Returns:
        A list with one decoded object per non-blank line, or ``None`` if the
        request failed.

    Raises:
        json.JSONDecodeError: if a non-blank line is not valid JSON.
    """
    url = to_url(path)
    r = await client.get(url, headers=build_hf_headers())
    if r is None:
        return None
    # Iterating a StringIO splits on "\n" only. str.splitlines would also
    # split on characters such as U+2028, which can appear raw inside JSON
    # strings. Blank lines (e.g. an extra trailing newline) are skipped
    # because json.loads rejects them.
    return [json.loads(line) for line in io.StringIO(r.text) if line.strip()]
def to_url(path):
    """Convert a Hub file path into the file's download URL via ``hf_hub_url``.

    ``path`` has the form ``[<prefix>/]<org>/<repo>/<filename>``. Examples are
    ``"datasets/org/name/dir/file.json"`` and ``"org/model/README.md"``.
    The optional prefix is one of ``datasets``, ``models`` or ``spaces``, and
    ``<filename>`` may itself contain ``/``. Without a prefix, ``repo_type``
    is passed as ``None``.

    Raises:
        ValueError: if fewer than three ``/``-separated segments remain after
        the optional prefix.
    """
    prefix, _, rest = path.partition("/")
    if prefix in ("datasets", "models", "spaces"):
        repo_type = prefix[:-1]  # "datasets" -> "dataset", etc.
    else:
        # Only a known plural prefix is a repo type. This keeps a model file
        # in a subfolder ("org/model/sub/file") from having "org" mistaken
        # for the repo type.
        repo_type, rest = None, path
    org_name, repo_name, filename = rest.split("/", 2)
    return hf_hub_url(repo_id=f"{org_name}/{repo_name}", filename=filename, repo_type=repo_type)
async def load_model_card(model_id):
    """Fetch the ``README.md`` of ``model_id`` from the Hub and wrap it in a ``ModelCard``.

    The card is built with ``ignore_metadata_errors=True``. Returns ``None``
    when the request fails.
    """
    readme = await client.get(to_url(f"{model_id}/README.md"))
    if readme is None:
        return None
    return ModelCard(readme.text, ignore_metadata_errors=True)
async def list_models(filtering=None):
    """Query the Hub ``/models`` endpoint and return the decoded JSON response.

    Args:
        filtering: value sent as the ``filter`` query parameter. The parameter
            is omitted entirely when this is falsy.

    Returns:
        The decoded JSON body, or ``None`` if the request failed.
    """
    query = {"filter": filtering} if filtering else {}
    response = await client.get(f"{constants.HF_API_URL}/models", params=query)
    return None if response is None else response.json()
def restart_space():
    """Ask the Hub API to restart the Space named by the ``SPACE_ID`` environment variable.

    The ``HF_TOKEN`` environment variable is passed as the token. Does nothing
    when ``SPACE_ID`` is unset or empty.
    """
    space_id = os.getenv("SPACE_ID")
    if not space_id:
        return
    api.restart_space(repo_id=space_id, token=os.getenv("HF_TOKEN"))
|