# Scraped page metadata (not part of the program):
# author: charlieoneill; commit "Update pipeline.py" (b3ae512, verified);
# file size 1.71 kB.
from transformers import AutoTokenizer, AutoModel
import torch
from typing import List
import torch.nn as nn
class PersonEmbeddings(nn.Module):
    """Text encoder plus MLP head producing 1536-dimensional embeddings.

    A pretrained transformer encoder produces per-token hidden states.
    These are mean-pooled over the non-padding tokens and passed through a
    two-layer projection: hidden -> 1024 -> ReLU -> 1536.
    """

    def __init__(self, model_id: str):
        """Load the base encoder and build the projection head.

        Args:
            model_id: Model identifier or local path passed to
                ``AutoModel.from_pretrained``.
        """
        super().__init__()
        self.base_model = AutoModel.from_pretrained(model_id)
        # Size the head from the encoder's config when it exposes
        # ``hidden_size``. 768 is the width this head was originally built for.
        hidden_size = getattr(self.base_model.config, "hidden_size", 768)
        # Keep the Sequential layout (indices 0 / 1 / 2) unchanged so saved
        # state_dict keys such as "projection.0.weight" still load.
        self.projection = nn.Sequential(
            nn.Linear(hidden_size, 1024),
            nn.ReLU(),
            nn.Linear(1024, 1536),
        )

    def forward(self, input_ids, attention_mask):
        """Embed a batch of token sequences.

        Args:
            input_ids: Token ids, shape (B, seq_len).
            attention_mask: Mask of shape (B, seq_len), as produced by the
                tokenizer (1 = real token, 0 = padding). If None, every
                position is averaged.

        Returns:
            Tensor of shape (B, 1536).
        """
        outputs = self.base_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
        )
        last_hidden = outputs.last_hidden_state  # (B, seq_len, hidden)
        if attention_mask is None:
            mean_pooled = last_hidden.mean(dim=1)
        else:
            # Average only over real tokens. A plain .mean(dim=1) would let
            # padding positions dilute the embedding whenever a batch mixes
            # sequences of different lengths.
            mask = attention_mask.unsqueeze(-1).to(last_hidden.dtype)
            summed = (last_hidden * mask).sum(dim=1)
            # Clamp so a row that is entirely padding cannot divide by zero.
            counts = mask.sum(dim=1).clamp(min=1.0)
            mean_pooled = summed / counts  # (B, hidden)
        return self.projection(mean_pooled)  # (B, 1536)
class CustomEmbeddingPipeline:
    """Callable wrapper that turns a single string into a person embedding."""

    def __init__(
        self,
        model_id="answerdotai/ModernBERT-base",
        ckpt_path="pytorch_model.bin",
        tokenizer_id="charlieoneill/my_modernbert_person_embeddings",
    ):
        """Load the tokenizer, build the model and restore its trained weights.

        Args:
            model_id: Base encoder identifier passed to ``PersonEmbeddings``.
            ckpt_path: Path to the saved ``state_dict``. A relative path is
                resolved against the current working directory.
            tokenizer_id: Identifier or path of the tokenizer to load.
        """
        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_id)
        self.model = PersonEmbeddings(model_id)
        # The model is never moved off the CPU, so deserialize onto the CPU
        # too. Without map_location, a checkpoint saved from a GPU fails to
        # load on a CPU-only machine.
        # NOTE(review): torch.load unpickles the file; only point ckpt_path
        # at trusted checkpoints.
        state_dict = torch.load(ckpt_path, map_location="cpu")
        self.model.load_state_dict(state_dict)
        self.model.eval()

    def __call__(self, text: str) -> List[float]:
        """Embed one string and return the embedding as a list of floats.

        Args:
            text: Input text. It is truncated by the tokenizer if too long.

        Returns:
            The embedding vector for ``text`` as a Python list.
        """
        inputs = self.tokenizer(
            [text], padding=True, truncation=True, return_tensors="pt"
        )
        # Inference only: skip building the autograd graph.
        with torch.no_grad():
            emb = self.model(inputs["input_ids"], inputs["attention_mask"])
        # Batch of one -> take row 0.
        return emb[0].tolist()
def pipeline(*args, **kwargs):
    """Factory entry point returning a ready-to-use embedding pipeline.

    Positional and keyword arguments are accepted only for call-site
    compatibility. They are ignored, and the pipeline is always built with
    its default configuration.
    """
    del args, kwargs  # intentionally unused
    embedder = CustomEmbeddingPipeline()
    return embedder