import contextlib
import json

import open_clip
import torch
from tqdm import tqdm

# Prefer CUDA, then Apple's MPS backend, then CPU.
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"


def generate_cache(texts: list[str], model_name: str, batch_size: int = 16) -> dict:
    """Embed every text with the given open_clip model and return a text -> vector dict."""
    model, _, _ = open_clip.create_model_and_transforms(model_name, device=device)
    tokenizer = open_clip.get_tokenizer(model_name)
    model.eval()

    # Mixed precision is only useful on CUDA; fall back to a no-op context elsewhere.
    amp_ctx = torch.autocast("cuda") if device == "cuda" else contextlib.nullcontext()

    cache = {}
    for i in tqdm(range(0, len(texts), batch_size)):
        batch = texts[i : i + batch_size]
        tokens = tokenizer(batch).to(device)
        with torch.no_grad(), amp_ctx:
            embeddings = model.encode_text(tokens, normalize=True).cpu().numpy()
        for text, embedding in zip(batch, embeddings):
            cache[text] = embedding.tolist()
    return cache


def flatten_taxonomy(taxonomy: dict) -> list[str]:
    """Collect every category name from a nested taxonomy of dicts and lists."""
    classes = []
    for key, value in taxonomy.items():
        classes.append(key)
        if isinstance(value, dict):
            classes.extend(flatten_taxonomy(value))
        if isinstance(value, list):
            classes.extend(value)
    return classes


def main():
    models = [
        "hf-hub:Marqo/marqo-ecommerce-embeddings-B",
        "hf-hub:Marqo/marqo-ecommerce-embeddings-L",
        "ViT-B-16",
    ]

    with open("amazon.json") as f:
        taxonomy = json.load(f)
    print("Loaded taxonomy")

    print("Flattening taxonomy")
    texts = flatten_taxonomy(taxonomy)

    print("Generating cache")
    for model in models:
        cache = generate_cache(texts, model)
        # One cache file per model, named after the last path component of the model id.
        with open(f'{model.split("/")[-1]}.json', "w") as f:
            json.dump(cache, f)


if __name__ == "__main__":
    main()