|
from langchain.embeddings.huggingface import HuggingFaceEmbeddings |
|
|
|
from typing import Any, List |
|
|
|
|
|
class MyEmbeddings(HuggingFaceEmbeddings):
    """HuggingFace embeddings that always request normalized vectors.

    Both embedding methods of ``HuggingFaceEmbeddings`` are overridden so
    that every call to ``self.client.encode`` passes
    ``normalize_embeddings=True``.

    NOTE(review): ``self.client`` is set up by the base class, which is not
    visible here; it is presumably a sentence-transformers model whose
    ``encode`` returns an array-like exposing ``tolist()`` -- confirm.
    """

    def __init__(self, **kwargs: Any):
        """Forward all keyword arguments unchanged to the base class."""
        super().__init__(**kwargs)

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Compute doc embeddings using a HuggingFace transformer model.

        Newlines in each text are replaced with spaces before encoding.

        Args:
            texts: The list of texts to embed.

        Returns:
            List of embeddings, one for each text. An empty input yields an
            empty list without invoking the model.
        """
        cleaned = [text.replace("\n", " ") for text in texts]
        if not cleaned:
            # Nothing to embed; skip the model call for an empty batch.
            return []
        embeddings = self.client.encode(cleaned, normalize_embeddings=True)
        return embeddings.tolist()

    def embed_query(self, text: str) -> List[float]:
        """Compute query embeddings using a HuggingFace transformer model.

        Newlines in the text are replaced with spaces before encoding.

        Args:
            text: The text to embed.

        Returns:
            Embeddings for the text.
        """
        cleaned = text.replace("\n", " ")
        embedding = self.client.encode(cleaned, normalize_embeddings=True)
        return embedding.tolist()
|
|