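"""Custom embedding pipeline for person embeddings.

Wraps a ModernBERT-style encoder (loaded with AutoModel) with a small projection
head that maps mean-pooled token states to 1536-dimensional embeddings. The
tokenizer and base model come from the same Hugging Face repo so the vocabulary
matches the fine-tuned weights, which are loaded from a local pytorch_model.bin.
"""
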
from transformers import AutoTokenizer, AutoModel
import torch
import torch.nn as nn
from typing import List

class PersonEmbeddings(nn.Module):
    def __init__(self, model_id: str):
        super().__init__()
        self.base_model = AutoModel.from_pretrained(model_id)
        self.projection = nn.Sequential(
            nn.Linear(768, 1024),
            nn.ReLU(),
            nn.Linear(1024, 1536)
        )

    def forward(self, input_ids, attention_mask):
        outputs = self.base_model(input_ids=input_ids, attention_mask=attention_mask)
        last_hidden = outputs.last_hidden_state            # (B, seq_len, 768)
        # Mask-aware mean pooling: padding positions are excluded from the average.
        # (For a single unpadded sequence this is identical to a plain mean over dim=1.)
        mask = attention_mask.unsqueeze(-1).to(last_hidden.dtype)                        # (B, seq_len, 1)
        mean_pooled = (last_hidden * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1e-9)  # (B, 768)
        embeddings = self.projection(mean_pooled)          # (B, 1536)
        return embeddings

class CustomEmbeddingPipeline:
    """
    Loads tokenizer + PersonEmbeddings from the *same* HF repo so that
    the vocabulary is consistent with the model weights.
    """
    def __init__(self, repo_id="charlieoneill/my_modernbert_person_embeddings"):
        # 1. Load tokenizer from your own HF repo, 
        #    which contains tokenizer.json, special_tokens_map.json, etc.
        self.tokenizer = AutoTokenizer.from_pretrained(repo_id)

        # 2. Create your PersonEmbeddings using the same repo_id 
        #    so AutoModel inside PersonEmbeddings will match
        self.model = PersonEmbeddings(repo_id)

        # 3. Load the fine-tuned state dict from a local pytorch_model.bin
        #    (expected in the working directory, e.g. a checkout of the HF repo).
        ckpt_path = "pytorch_model.bin"
        state_dict = torch.load(ckpt_path, map_location="cpu")
        self.model.load_state_dict(state_dict)
        self.model.eval()

    def __call__(self, text: str) -> List[float]:
        # Tokenize input
        inputs = self.tokenizer(
            [text], 
            padding=True, 
            truncation=True, 
            return_tensors="pt"
        )

        with torch.no_grad():
            emb = self.model(
                inputs["input_ids"], 
                inputs["attention_mask"]
            )  # shape: (1, 1536)

        # Return as a Python list
        return emb[0].tolist()

def pipeline(*args, **kwargs):
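    """Factory entry point: returns a CustomEmbeddingPipeline (extra args are accepted but ignored)."""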
    return CustomEmbeddingPipeline()
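

# Illustrative usage sketch (assumptions: the default repo id above is reachable,
# pytorch_model.bin is present in the working directory, and the example text is made up).
if __name__ == "__main__":
    embedder = pipeline()
    vector = embedder("Jane Doe, senior software engineer based in Sydney.")
    print(len(vector))  # 1536-dimensional embedding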