import json

import open_clip
import torch
from tqdm import tqdm

# Prefer CUDA, then fall back to Apple's MPS backend, then CPU. Keeping the
# device as a plain string means the comparisons below behave predictably.
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"


def generate_cache(texts: list[str], model_name: str, batch_size: int = 16) -> dict:
    """Embed each string in `texts` and return a {text: embedding-list} dict."""
    model, _, _ = open_clip.create_model_and_transforms(model_name, device=device)
    model.eval()
    tokenizer = open_clip.get_tokenizer(model_name)

    cache = {}

    for i in tqdm(range(0, len(texts), batch_size)):
        batch = texts[i : i + batch_size]
        tokens = tokenizer(batch).to(device)
        # torch.cuda.amp.autocast() is deprecated; torch.autocast with
        # enabled=False is a no-op, so mixed precision only kicks in on CUDA.
        with torch.no_grad(), torch.autocast("cuda", enabled=device == "cuda"):
            embeddings = model.encode_text(tokens, normalize=True).cpu().numpy()
        for text, embedding in zip(batch, embeddings):
            cache[text] = embedding.tolist()

    return cache
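

# Example (hypothetical input), showing the shape of the returned cache:
#     generate_cache(["red running shoes"], "ViT-B-16")
#     -> {"red running shoes": [0.013, -0.027, ...]}
# Each input string maps to its L2-normalized text embedding as a plain
# Python list, so the dict can be dumped straight to JSON.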


def flatten_taxonomy(taxonomy: dict) -> list[str]:
    """Depth-first flatten of a nested taxonomy into a flat list of class names."""
    classes = []
    for key, value in taxonomy.items():
        classes.append(key)
        if isinstance(value, dict):
            classes.extend(flatten_taxonomy(value))
        elif isinstance(value, list):
            classes.extend(value)
    return classes
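

# Example with a hypothetical taxonomy:
#     flatten_taxonomy({"Electronics": {"Audio": ["Headphones", "Speakers"]}})
#     -> ["Electronics", "Audio", "Headphones", "Speakers"]
# Every parent key is emitted before its children, depth-first.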


def main():
    models = [
        "hf-hub:Marqo/marqo-ecommerce-embeddings-B",
        "hf-hub:Marqo/marqo-ecommerce-embeddings-L",
        # A bare architecture name loads randomly initialized weights unless a
        # `pretrained` tag is passed to create_model_and_transforms.
        "ViT-B-16",
    ]

    with open("amazon.json") as f:
        taxonomy = json.load(f)
    print("Loaded taxonomy")

    print("Flattening taxonomy")
    texts = flatten_taxonomy(taxonomy)

    print("Generating cache")
    for model in models:
        cache = generate_cache(texts, model)
        # "hf-hub:Org/name" -> "name.json"; plain model names are used as-is.
        with open(f'{model.split("/")[-1]}.json', "w") as f:
            json.dump(cache, f)


if __name__ == "__main__":
    main()
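
# Sketch of reading a cache back (filenames follow main() above; the key
# "Electronics" is a hypothetical taxonomy entry):
#     with open("marqo-ecommerce-embeddings-B.json") as f:
#         cache = json.load(f)
#     embedding = torch.tensor(cache["Electronics"])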