|
""" Milvus memory storage provider.""" |
|
from pymilvus import Collection, CollectionSchema, DataType, FieldSchema, connections |
|
|
|
from autogpt.memory.base import MemoryProviderSingleton, get_ada_embedding |
|
|
|
|
|
class MilvusMemory(MemoryProviderSingleton):
    """Milvus memory storage provider.

    Each memory is stored as one row made of an auto-generated primary key,
    the embedding of the text, and the raw text itself. Similarity queries
    are run against the embedding field.
    """

    # Vector field name; the index built on it reuses the same name.
    _VECTOR_FIELD = "embeddings"
    # Field holding the original text that is handed back to callers.
    _TEXT_FIELD = "raw_text"
    # Dimension declared for the vector field; the vectors returned by
    # get_ada_embedding must have this length.
    _EMBEDDING_DIM = 1536
    # The same metric is used to build the index and to search it.
    _METRIC_TYPE = "IP"
    # Floor for the HNSW search-time ``ef`` value (see get_relevant).
    _MIN_SEARCH_EF = 64

    def __init__(self, cfg) -> None:
        """Construct a milvus memory storage connection.

        Connects to the Milvus server, opens (or creates) the configured
        collection, makes sure the embedding index exists, and loads the
        collection so it can be searched.

        Args:
            cfg (Config): Auto-GPT global config. ``cfg.milvus_addr`` and
                ``cfg.milvus_collection`` are read.
        """
        connections.connect(address=cfg.milvus_addr)
        fields = [
            FieldSchema(name="pk", dtype=DataType.INT64, is_primary=True, auto_id=True),
            FieldSchema(
                name=self._VECTOR_FIELD,
                dtype=DataType.FLOAT_VECTOR,
                dim=self._EMBEDDING_DIM,
            ),
            FieldSchema(name=self._TEXT_FIELD, dtype=DataType.VARCHAR, max_length=65535),
        ]

        self.milvus_collection = cfg.milvus_collection
        self.schema = CollectionSchema(fields, "auto-gpt memory storage")
        self.collection = Collection(self.milvus_collection, self.schema)

        if not self.collection.has_index():
            # Release the collection before building the index on it.
            self.collection.release()
            self._create_index()
        # Load unconditionally: searching requires a loaded collection.
        self.collection.load()

    def _create_index(self) -> None:
        """Build the HNSW index on the embedding field of the collection."""
        self.collection.create_index(
            self._VECTOR_FIELD,
            {
                "metric_type": self._METRIC_TYPE,
                "index_type": "HNSW",
                "params": {"M": 8, "efConstruction": 64},
            },
            index_name=self._VECTOR_FIELD,
        )

    def add(self, data) -> str:
        """Add an embedding of data into memory.

        Args:
            data (str): The raw text to embed and store.

        Returns:
            str: A log message containing the primary key assigned to the
                inserted row and the stored text.
        """
        embedding = get_ada_embedding(data)
        # Column-oriented insert: one column per non-auto field, in schema
        # order (embedding column first, then raw text).
        result = self.collection.insert([[embedding], [data]])
        return (
            "Inserting data into memory at primary key: "
            f"{result.primary_keys[0]}:\n data: {data}"
        )

    def get(self, data) -> list:
        """Return the most relevant data in memory.

        Args:
            data: The data to compare to.

        Returns:
            list: The result of ``get_relevant(data, 1)``.
        """
        return self.get_relevant(data, 1)

    def clear(self) -> str:
        """Drop the collection and recreate it empty.

        The collection is dropped, created again from the stored schema,
        re-indexed and loaded, so the instance stays usable afterwards.

        Returns:
            str: log.
        """
        self.collection.drop()
        self.collection = Collection(self.milvus_collection, self.schema)
        self._create_index()
        self.collection.load()
        return "Obliviated"

    def get_relevant(self, data: str, num_relevant: int = 5) -> list:
        """Return the top-k relevant data in memory.

        Args:
            data: The data to compare to.
            num_relevant (int, optional): The max number of relevant data.
                Defaults to 5.

        Returns:
            list: The raw text of the top-k relevant entries.
        """
        embedding = get_ada_embedding(data)
        search_params = {
            # The key is "metric_type"; the previous "metrics_type" spelling
            # meant the metric was never actually passed to the search.
            "metric_type": self._METRIC_TYPE,
            # HNSW is tuned at search time with "ef" (which must be at least
            # the requested top-k); "nprobe" only applies to IVF indexes.
            "params": {"ef": max(num_relevant, self._MIN_SEARCH_EF)},
        }
        result = self.collection.search(
            [embedding],
            self._VECTOR_FIELD,
            search_params,
            num_relevant,
            output_fields=[self._TEXT_FIELD],
        )
        # A single query vector was sent, so only result[0] holds hits.
        return [hit.entity.value_of_field(self._TEXT_FIELD) for hit in result[0]]

    def get_stats(self) -> str:
        """Return the stats of the milvus cache.

        Returns:
            str: A message with the number of entities in the collection.
        """
        return f"Entities num: {self.collection.num_entities}"
|
|