# e-commerce-taxonomy-mapping / cache_taxonomy_vectors.py
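"""Precompute OpenCLIP text embeddings for every class in an e-commerce
taxonomy and cache them to JSON, one file per model."""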
import contextlib
import json

import torch
import open_clip
from tqdm import tqdm
# Prefer CUDA, then Apple MPS, then fall back to CPU. (Comparing a
# torch.device against the string "cpu" does not work, so the original
# MPS fallback branch could never trigger.)
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")

def generate_cache(texts: list[str], model_name: str, batch_size: int = 16) -> dict:
    """Embed each text with the given OpenCLIP model and return a {text: vector} dict."""
    model, _, _ = open_clip.create_model_and_transforms(model_name, device=device)
    tokenizer = open_clip.get_tokenizer(model_name)

    # Autocast only on CUDA; torch.cuda.amp.autocast is deprecated and has no
    # effect on CPU or MPS anyway.
    autocast = torch.autocast("cuda") if device.type == "cuda" else contextlib.nullcontext()

    cache = {}
    for i in tqdm(range(0, len(texts), batch_size)):
        batch = texts[i : i + batch_size]
        tokens = tokenizer(batch).to(device)
        with torch.no_grad(), autocast:
            # normalize=True yields unit-length embeddings, so downstream
            # cosine similarity reduces to a dot product.
            embeddings = model.encode_text(tokens, normalize=True).cpu().numpy()
        for text, embedding in zip(batch, embeddings):
            cache[text] = embedding.tolist()
    return cache
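
# Illustrative only (values are placeholders, not real model output):
#   generate_cache(["running shoes"], "ViT-B-16")
#   -> {"running shoes": [0.012, -0.034, ...]}  # one unit-norm vector per text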

def flatten_taxonomy(taxonomy: dict) -> list[str]:
    """Recursively collect every category name from a nested taxonomy.

    Dict keys are intermediate categories; lists hold leaf categories.
    """
    classes = []
    for key, value in taxonomy.items():
        classes.append(key)
        if isinstance(value, dict):
            classes.extend(flatten_taxonomy(value))
        if isinstance(value, list):
            classes.extend(value)
    return classes
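
# Illustrative only:
#   flatten_taxonomy({"Clothing": {"Shoes": ["Sneakers", "Boots"]}})
#   -> ["Clothing", "Shoes", "Sneakers", "Boots"]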

def main():
    models = [
        "hf-hub:Marqo/marqo-ecommerce-embeddings-B",
        "hf-hub:Marqo/marqo-ecommerce-embeddings-L",
        "ViT-B-16",
    ]

    with open("amazon.json") as f:
        taxonomy = json.load(f)
    print("Loaded taxonomy")

    print("Flattening taxonomy")
    texts = flatten_taxonomy(taxonomy)

    print("Generating cache")
    for model in models:
        cache = generate_cache(texts, model)
        # Name each cache file after the model, e.g. "marqo-ecommerce-embeddings-B.json".
        with open(f'{model.split("/")[-1]}.json', "w") as f:
            json.dump(cache, f)


if __name__ == "__main__":
    main()
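
# Hypothetical downstream usage (not part of this script): load a cache file
# and map a query embedding to its nearest taxonomy class. Assumes `query` is
# a unit-norm vector from the same model's encode_text/encode_image.
#
#   import json
#   import numpy as np
#
#   with open("marqo-ecommerce-embeddings-B.json") as f:
#       cache = json.load(f)
#   labels = list(cache)
#   matrix = np.array([cache[t] for t in labels])  # (num_classes, dim)
#   scores = matrix @ query                        # cosine similarity via dot product
#   print(labels[int(np.argmax(scores))])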