Spaces:
Build error
Build error
import os | |
from typing import Union | |
from unittest.mock import MagicMock | |
import pytest | |
from _pytest.monkeypatch import MonkeyPatch | |
from volcengine.viking_db import ( | |
Collection, | |
Data, | |
DistanceType, | |
Field, | |
FieldType, | |
Index, | |
IndexType, | |
QuantType, | |
VectorIndexParams, | |
VikingDBService, | |
) | |
from core.rag.datasource.vdb.field import Field as vdb_Field | |
class MockVikingDBClass:
    """Canned replacements for the VikingDB SDK calls exercised by the tests.

    These methods are not used through a ``MockVikingDBClass`` instance.
    ``setup_vikingdb_mock`` patches them individually onto ``VikingDBService``,
    ``Collection`` and ``Index``, so ``self`` is an instance of whichever SDK
    class a method was attached to. For that reason no method calls a helper
    through ``self``; only the service-level methods read
    ``self._viking_db_service``, which the patched ``__init__`` installs.
    """

    def __init__(
        self,
        host="api-vikingdb.volces.com",
        region="cn-north-1",
        ak="",
        sk="",
        scheme="http",
        connection_timeout=30,
        socket_timeout=30,
        proxy=None,
    ):
        """Ignore every connection argument and install a ``MagicMock`` service.

        The mocked service's ``get_exception`` returns a fixed JSON string
        naming ``test_id`` as the primary key.
        """
        service = MagicMock()
        service.get_exception = MagicMock(return_value='{"data": {"primary_key": "test_id"}}')
        self._viking_db_service = service

    def get_collection(self, collection_name) -> Collection:
        """Return a fixed-schema collection named ``collection_name``.

        The collection carries five fields (primary key, metadata, group,
        content and a 768-dimensional vector) and one HNSW/L2 index called
        ``<collection_name>_idx``.
        """
        service = self._viking_db_service
        primary_key_name = vdb_Field.PRIMARY_KEY.value

        schema = [
            Field(field_name=primary_key_name, field_type=FieldType.String, is_primary_key=True),
            Field(field_name=vdb_Field.METADATA_KEY.value, field_type=FieldType.String),
            Field(field_name=vdb_Field.GROUP_KEY.value, field_type=FieldType.String),
            Field(field_name=vdb_Field.CONTENT_KEY.value, field_type=FieldType.Text),
            Field(field_name=vdb_Field.VECTOR.value, field_type=FieldType.Vector, dim=768),
        ]
        hnsw_params = VectorIndexParams(
            distance=DistanceType.L2,
            index_type=IndexType.HNSW,
            quant=QuantType.Float,
        )
        default_index = Index(
            collection_name=collection_name,
            index_name=f"{collection_name}_idx",
            vector_index=hnsw_params,
            scalar_index=None,
            stat=None,
            viking_db_service=service,
        )

        return Collection(
            collection_name=collection_name,
            description="Collection For Dify",
            viking_db_service=service,
            primary_key=primary_key_name,
            fields=schema,
            indexes=[default_index],
        )

    def drop_collection(self, collection_name):
        """Do nothing beyond checking that a non-empty collection name was given."""
        assert collection_name != ""

    def create_collection(self, collection_name, fields, description="") -> Collection:
        """Return a collection built from the caller's ``fields`` and ``description``."""
        service = self._viking_db_service
        return Collection(
            collection_name=collection_name,
            description=description,
            fields=fields,
            primary_key=vdb_Field.PRIMARY_KEY.value,
            viking_db_service=service,
        )

    def get_index(self, collection_name, index_name) -> Index:
        """Return an HNSW/L2 float index with the requested names and no stats."""
        hnsw_params = VectorIndexParams(
            distance=DistanceType.L2,
            index_type=IndexType.HNSW,
            quant=QuantType.Float,
        )
        return Index(
            collection_name=collection_name,
            index_name=index_name,
            vector_index=hnsw_params,
            scalar_index=None,
            stat=None,
            viking_db_service=self._viking_db_service,
        )

    def create_index(
        self,
        collection_name,
        index_name,
        vector_index=None,
        cpu_quota=2,
        description="",
        partition_by="",
        scalar_index=None,
        shard_count=None,
        shard_policy=None,
    ):
        """Return an ``Index`` that simply echoes every argument it was given."""
        index_options = {
            "collection_name": collection_name,
            "index_name": index_name,
            "vector_index": vector_index,
            "cpu_quota": cpu_quota,
            "description": description,
            "partition_by": partition_by,
            "scalar_index": scalar_index,
            "shard_count": shard_count,
            "shard_policy": shard_policy,
            "viking_db_service": self._viking_db_service,
            "stat": None,
        }
        return Index(**index_options)

    def drop_index(self, collection_name, index_name):
        """Do nothing beyond checking that both names are non-empty."""
        assert collection_name != ""
        assert index_name != ""

    def upsert_data(self, data: Union[Data, list[Data]]):
        """Accept any non-``None`` payload and discard it."""
        assert data is not None

    def fetch_data(self, id: Union[str, list[str], int, list[int]]):
        """Return one canned record whose id and primary-key field echo ``id``."""
        record_fields = {
            vdb_Field.GROUP_KEY.value: "test_group",
            vdb_Field.METADATA_KEY.value: "{}",
            vdb_Field.CONTENT_KEY.value: "content",
            vdb_Field.PRIMARY_KEY.value: id,
            vdb_Field.VECTOR.value: [-0.00762577411336441, -0.01949881482151406, 0.008832383941428398],
        }
        return Data(fields=record_fields, id=id)

    def delete_data(self, id: Union[str, list[str], int, list[int]]):
        """Accept any non-``None`` id and delete nothing."""
        assert id is not None

    def search_by_vector(
        self,
        vector,
        sparse_vectors=None,
        filter=None,
        limit=10,
        output_fields=None,
        partition="default",
        dense_weight=None,
    ) -> list[Data]:
        """Return a single canned hit (id ``test_id``) that echoes the query ``vector``.

        Every argument other than ``vector`` is ignored.
        """
        # Adjacent literals concatenate into one JSON document string.
        metadata_json = (
            '{"source": "/var/folders/ml/xxx/xxx.txt", '
            '"document_id": "test_document_id", '
            '"dataset_id": "test_dataset_id", '
            '"doc_id": "test_id", '
            '"doc_hash": "test_hash"}'
        )
        hit_fields = {
            vdb_Field.GROUP_KEY.value: "test_group",
            vdb_Field.METADATA_KEY.value: metadata_json,
            vdb_Field.CONTENT_KEY.value: "content",
            vdb_Field.PRIMARY_KEY.value: "test_id",
            vdb_Field.VECTOR.value: vector,
        }
        hit = Data(fields=hit_fields, id="test_id", score=0.10)
        return [hit]

    def search(
        self, order=None, filter=None, limit=10, output_fields=None, partition="default", dense_weight=None
    ) -> list[Data]:
        """Return a single canned hit (id ``test_id``) with a fixed three-value vector.

        All arguments are ignored.
        """
        # Adjacent literals concatenate into one JSON document string.
        metadata_json = (
            '{"source": "/var/folders/ml/xxx/xxx.txt", '
            '"document_id": "test_document_id", '
            '"dataset_id": "test_dataset_id", '
            '"doc_id": "test_id", '
            '"doc_hash": "test_hash"}'
        )
        hit_fields = {
            vdb_Field.GROUP_KEY.value: "test_group",
            vdb_Field.METADATA_KEY.value: metadata_json,
            vdb_Field.CONTENT_KEY.value: "content",
            vdb_Field.PRIMARY_KEY.value: "test_id",
            vdb_Field.VECTOR.value: [-0.00762577411336441, -0.01949881482151406, 0.008832383941428398],
        }
        hit = Data(fields=hit_fields, id="test_id", score=0.10)
        return [hit]
# Mocking is opt-in: it is enabled only when MOCK_SWITCH equals "true"
# (compared case-insensitively); an unset variable counts as "false".
MOCK = os.environ.get("MOCK_SWITCH", "false").lower() == "true"
@pytest.fixture
def setup_vikingdb_mock(monkeypatch: MonkeyPatch):
    """Pytest fixture that swaps VikingDB SDK methods for the mock versions.

    When ``MOCK`` is true, each listed method on ``VikingDBService``,
    ``Collection`` and ``Index`` is replaced by the same-named function from
    ``MockVikingDBClass`` before the test runs, and the patches are undone
    after it. When ``MOCK`` is false the fixture does nothing, so the test
    talks to the real SDK.

    Args:
        monkeypatch: pytest's built-in ``monkeypatch`` fixture.

    Yields:
        None. Control is handed to the test while the patches are active.
    """
    if MOCK:
        # (SDK class to patch, names of the methods to replace). The mock
        # functions are looked up by the same name on MockVikingDBClass.
        patch_table = (
            (
                VikingDBService,
                (
                    "__init__",
                    "get_collection",
                    "create_collection",
                    "drop_collection",
                    "get_index",
                    "create_index",
                    "drop_index",
                ),
            ),
            (Collection, ("upsert_data", "fetch_data", "delete_data")),
            (Index, ("search_by_vector", "search")),
        )
        for sdk_class, method_names in patch_table:
            for method_name in method_names:
                monkeypatch.setattr(sdk_class, method_name, getattr(MockVikingDBClass, method_name))

    yield

    if MOCK:
        monkeypatch.undo()