charlieoneill commited on
Commit
3d541d7
·
verified ·
1 Parent(s): 0456053

Update pipeline.py

Browse files
Files changed (1) hide show
  1. pipeline.py +20 -1
pipeline.py CHANGED
@@ -1,8 +1,27 @@
1
  from transformers import AutoTokenizer, AutoModel
2
  import torch
3
  from typing import List
 
4
 
5
- from model import PersonEmbeddings
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6
 
7
  class CustomEmbeddingPipeline:
8
  def __init__(self, model_id="answerdotai/ModernBERT-base"):
 
1
  from transformers import AutoTokenizer, AutoModel
2
  import torch
3
  from typing import List
4
+ import torch.nn as nn
5
 
6
+ class PersonEmbeddings(nn.Module):
7
+ def __init__(self, model_id: str):
8
+ super().__init__()
9
+ self.base_model = AutoModel.from_pretrained(model_id)
10
+ self.projection = nn.Sequential(
11
+ nn.Linear(768, 1024),
12
+ nn.ReLU(),
13
+ nn.Linear(1024, 1536)
14
+ )
15
+
16
+ def forward(self, input_ids, attention_mask):
17
+ outputs = self.base_model(
18
+ input_ids=input_ids,
19
+ attention_mask=attention_mask
20
+ )
21
+ last_hidden = outputs.last_hidden_state # (B, seq_len, 768)
22
+ mean_pooled = last_hidden.mean(dim=1) # (B, 768)
23
+ embeddings = self.projection(mean_pooled) # (B, 1536)
24
+ return embeddings
25
 
26
  class CustomEmbeddingPipeline:
27
  def __init__(self, model_id="answerdotai/ModernBERT-base"):