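"""Precompute text-embedding caches for a product taxonomy.

Flattens the category tree in amazon.json into a flat list of class
names, embeds each name with several open_clip text encoders, and
writes one JSON file per model mapping text -> embedding vector.
A cached vector can be restored with torch.tensor(cache[text]).
"""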
import json
import torch
import open_clip
from tqdm import tqdm

# Prefer CUDA, then Apple's MPS backend, then CPU. Picking the device
# explicitly avoids the original's fragile comparison of a torch.device
# object against the string "cpu".
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")


def generate_cache(texts: list[str], model_name: str, batch_size: int = 16) -> dict:
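    """Embed `texts` in batches and return a {text: embedding} mapping.

    encode_text(..., normalize=True) L2-normalizes each embedding, so a
    dot product between two cached vectors is a cosine similarity.
    """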
    model, _, _ = open_clip.create_model_and_transforms(model_name, device=device)
    model.eval()  # inference only: disable dropout/batch-norm updates
    tokenizer = open_clip.get_tokenizer(model_name)

    cache = {}

    for i in tqdm(range(0, len(texts), batch_size)):
        batch = texts[i : i + batch_size]
        tokens = tokenizer(batch).to(device)
        # torch.cuda.amp.autocast is deprecated; enable autocast only on
        # CUDA so the same code also runs on CPU and MPS.
        with torch.no_grad(), torch.autocast(device_type="cuda", enabled=device.type == "cuda"):
            embeddings = model.encode_text(tokens, normalize=True).cpu().numpy()
        for text, embedding in zip(batch, embeddings):
            cache[text] = embedding.tolist()

    return cache


def flatten_taxonomy(taxonomy: dict) -> list[str]:
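    """Collect every category name in a nested taxonomy, depth-first.

    Example: {"Electronics": {"Audio": ["Headphones"]}} flattens to
    ["Electronics", "Audio", "Headphones"].
    """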
    classes = []
    for key, value in taxonomy.items():
        classes.append(key)
        if isinstance(value, dict):
            classes.extend(flatten_taxonomy(value))
        elif isinstance(value, list):
            classes.extend(value)
    return classes


def main():
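    """Build one embedding cache file per model listed below."""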
    models = [
        "hf-hub:Marqo/marqo-ecommerce-embeddings-B",
        "hf-hub:Marqo/marqo-ecommerce-embeddings-L",
        "ViT-B-16",
    ]

    with open("amazon.json") as f:
        taxonomy = json.load(f)
    print("Loaded taxonomy")

    print("Flattening taxonomy")
    texts = flatten_taxonomy(taxonomy)

    print("Generating cache")
    for model_name in models:
        cache = generate_cache(texts, model_name)
        # "hf-hub:Marqo/marqo-ecommerce-embeddings-B" -> "marqo-ecommerce-embeddings-B.json";
        # plain names like "ViT-B-16" pass through unchanged.
        with open(f'{model_name.split("/")[-1]}.json', "w") as f:
            json.dump(cache, f)


if __name__ == "__main__":
    main()