import contextlib
import json

import torch
import open_clip
from tqdm import tqdm

# Pick the best available accelerator: CUDA, then Apple MPS, then CPU.
# (Note: `torch.device("cpu") == "cpu"` is False, so a string comparison
# against the device object would never trigger the MPS fallback.)
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")


def generate_cache(texts: list[str], model_name: str, batch_size: int = 16) -> dict:
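    """Encode every string in `texts` with the given open_clip model and
    return a {text: embedding} dict of L2-normalized embeddings as lists."""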
    model, _, _ = open_clip.create_model_and_transforms(model_name, device=device)
    tokenizer = open_clip.get_tokenizer(model_name)
    model.eval()

    cache = {}
    for i in tqdm(range(0, len(texts), batch_size)):
        batch = texts[i : i + batch_size]
        tokens = tokenizer(batch).to(device)
        # Mixed precision only on CUDA; CPU and MPS run in full precision.
        amp = torch.autocast("cuda") if device.type == "cuda" else contextlib.nullcontext()
        with torch.no_grad(), amp:
            embeddings = model.encode_text(tokens, normalize=True).cpu().numpy()
        for text, embedding in zip(batch, embeddings):
            # Store plain lists so the cache is JSON-serializable.
            cache[text] = embedding.tolist()
    return cache


def flatten_taxonomy(taxonomy: dict) -> list[str]:
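    """Recursively collect every category name from a nested taxonomy whose
    dict keys are categories and whose leaves are lists of subcategory strings."""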
    classes = []
    for key, value in taxonomy.items():
        classes.append(key)
        if isinstance(value, dict):
            # Nested sub-taxonomy: recurse into it.
            classes.extend(flatten_taxonomy(value))
        elif isinstance(value, list):
            # Leaf level: a flat list of category names.
            classes.extend(value)
    return classes


def main():
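    """Embed the flattened Amazon taxonomy with each model and write one
    JSON cache file per model."""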
    models = [
        "hf-hub:Marqo/marqo-ecommerce-embeddings-B",
        "hf-hub:Marqo/marqo-ecommerce-embeddings-L",
        "ViT-B-16",
    ]

    with open("amazon.json") as f:
        taxonomy = json.load(f)
    print("Loaded taxonomy")

    print("Flattening taxonomy")
    texts = flatten_taxonomy(taxonomy)

    print("Generating cache")
    for model in models:
        cache = generate_cache(texts, model)
        # One cache file per model, e.g. "marqo-ecommerce-embeddings-B.json".
        with open(f'{model.split("/")[-1]}.json', "w") as f:
            json.dump(cache, f)


if __name__ == "__main__":
    main()