Spaces:
Build error
Build error
File size: 1,491 Bytes
235b9c1 93e5f33 235b9c1 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 |
"""
Based on the Hugging Face transformers Python API.
This script turns a list of strings into sentence embeddings.
"""
from transformers import AutoTokenizer, TFAutoModel
import tensorflow as tf
class Embed:
    """Turn text into L2-normalized sentence embeddings.

    Loads a tokenizer and a TensorFlow model from the same checkpoint,
    mean-pools the model's token embeddings using the attention mask, and
    L2-normalizes each pooled vector.
    """

    # Checkpoint used when no ``model_name`` is given (the previously
    # hard-coded value, kept as the default for backward compatibility).
    DEFAULT_MODEL_NAME = "sentence-transformers/paraphrase-multilingual-mpnet-base-v2"

    def __init__(self, model_name=DEFAULT_MODEL_NAME):
        """Load the tokenizer and TF model for ``model_name``.

        Args:
            model_name: checkpoint identifier passed unchanged to both
                ``AutoTokenizer.from_pretrained`` and
                ``TFAutoModel.from_pretrained``.
        """
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = TFAutoModel.from_pretrained(model_name)

    @staticmethod
    def mean_pooling(model_output, attention_mask):
        """Average token embeddings over attended (non-padding) positions.

        Args:
            model_output: model output exposing ``last_hidden_state``
                (assumed shape ``(batch, seq_len, hidden)`` -- TODO confirm
                for non-default checkpoints).
            attention_mask: mask that is non-zero for real tokens and 0 for
                padding (assumed shape ``(batch, seq_len)``).

        Returns:
            Tensor of pooled embeddings, one row per input (reduced over
            axis 1 of ``last_hidden_state``).
        """
        token_embeddings = model_output.last_hidden_state
        # Keep the mask at (batch, seq_len, 1) and let broadcasting expand it
        # over the hidden axis, instead of tiling a full-size copy. Casting to
        # the embeddings' dtype keeps the element-wise product in one dtype.
        mask = tf.cast(tf.expand_dims(attention_mask, -1), token_embeddings.dtype)
        summed = tf.math.reduce_sum(token_embeddings * mask, axis=1)
        # Clamp the per-row token count so a row with no attended tokens does
        # not divide by zero.
        counts = tf.math.maximum(tf.math.reduce_sum(mask, axis=1), 1e-9)
        return summed / counts

    def encode(self, texts):
        """Encode ``texts`` into L2-normalized sentence embeddings.

        Args:
            texts: the input passed straight to the tokenizer (presumably a
                list of strings, per the module docstring).

        Returns:
            Tensor of pooled embeddings, one row per input, each row scaled
            to unit L2 norm along axis 1.
        """
        # Pad the batch to a common length and truncate over-long inputs.
        encoded_input = self.tokenizer(texts, padding=True, truncation=True, return_tensors='tf')
        # Token-level embeddings from the transformer.
        model_output = self.model(**encoded_input, return_dict=True)
        # Dispatch through self so subclasses can override the pooling.
        embeddings = self.mean_pooling(model_output, encoded_input['attention_mask'])
        return tf.math.l2_normalize(embeddings, axis=1)
|