import json
import asyncio
import logging
import time

import requests
from tqdm.asyncio import tqdm_asyncio
from huggingface_hub import get_inference_endpoint

from models import env_config, embed_config

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Handle to the Text Embeddings Inference (TEI) Inference Endpoint on the Hub.
endpoint = get_inference_endpoint(env_config.tei_name, token=env_config.hf_token)


async def embed_chunk(sentence, semaphore, tmp_file):
    """Embed one text chunk via the TEI endpoint and append the result to tmp_file as JSONL."""
    async with semaphore:
        payload = {
            "inputs": sentence,
            "truncate": True,
        }
        try:
            resp = await endpoint.async_client.post(json=payload)
        except Exception as e:
            raise RuntimeError(str(e)) from e
        result = json.loads(resp)
        tmp_file.write(
            json.dumps({"vector": result[0], env_config.input_text_col: sentence}) + "\n"
        )


async def embed_wrapper(input_ds, temp_file):
    """Embed every non-empty chunk in input_ds, bounding concurrency with a semaphore."""
    semaphore = asyncio.BoundedSemaphore(embed_config.semaphore_bound)
    jobs = [
        asyncio.create_task(embed_chunk(row[env_config.input_text_col], semaphore, temp_file))
        for row in input_ds
        if row[env_config.input_text_col].strip()
    ]
    logger.info(f"num chunks to embed: {len(jobs)}")
    tic = time.time()
    await tqdm_asyncio.gather(*jobs)
    logger.info(f"embed time: {time.time() - tic}")


def wake_up_endpoint():
    """Resume the TEI endpoint if it is not running and wait until it is ready."""
    endpoint.fetch()
    if endpoint.status != "running":
        logger.info("Starting up TEI endpoint")
        endpoint.resume().wait().fetch()
        # n_loop = 0
        # while requests.get(
        #     url=endpoint.url,
        #     headers={"Authorization": f"Bearer {env_config.hf_token}"}
        # ).status_code != 200:
        #     time.sleep(2)
        #     n_loop += 1
        #     if n_loop > 20:
        #         raise TimeoutError("TEI endpoint is unavailable")
    logger.info("TEI endpoint is up")
    return
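

# --- Usage sketch (illustrative, not part of the original module) ---
# A minimal example of how these helpers might be wired together, assuming the
# input dataset is an iterable of dict-like rows whose text lives in
# `env_config.input_text_col`. The file names ("chunks.jsonl", "embeddings.jsonl")
# and the use of `datasets.load_dataset` are assumptions made for this example only.
if __name__ == "__main__":
    from datasets import load_dataset  # assumed dependency for this example

    # Make sure the endpoint is running before firing embedding requests.
    wake_up_endpoint()

    # Hypothetical input: one JSON object per line containing the configured text column.
    ds = load_dataset("json", data_files="chunks.jsonl", split="train")

    # Stream results to a JSONL file as each chunk is embedded.
    with open("embeddings.jsonl", "w") as tmp_file:
        asyncio.run(embed_wrapper(ds, tmp_file))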