"""Mock of the Volcengine VikingDB SDK for the VikingDB vector-store tests.

When the ``MOCK_SWITCH`` environment variable is ``"true"``, the
``setup_vikingdb_mock`` fixture patches selected methods of
``VikingDBService``, ``Collection`` and ``Index`` with the canned
implementations in ``MockVikingDBClass`` instead of talking to the service.
"""

import json
import os
from typing import Union
from unittest.mock import MagicMock

import pytest
from _pytest.monkeypatch import MonkeyPatch
from volcengine.viking_db import (
    Collection,
    Data,
    DistanceType,
    Field,
    FieldType,
    Index,
    IndexType,
    QuantType,
    VectorIndexParams,
    VikingDBService,
)

from core.rag.datasource.vdb.field import Field as vdb_Field

# Metadata JSON attached to every canned search hit. Built with json.dumps so it
# is always valid JSON: a hand-written literal relying on backslash line
# continuations silently embeds "\ " sequences once the newlines are lost.
_MOCK_METADATA = json.dumps(
    {
        "source": "/var/folders/ml/xxx/xxx.txt",
        "document_id": "test_document_id",
        "dataset_id": "test_dataset_id",
        "doc_id": "test_id",
        "doc_hash": "test_hash",
    }
)

# Sample embedding returned by fetch_data/search. Kept as a tuple and copied into
# a fresh list on every call so no caller can mutate shared state.
_MOCK_VECTOR = (-0.00762577411336441, -0.01949881482151406, 0.008832383941428398)


def _mock_vector_index() -> VectorIndexParams:
    """Return the L2 / HNSW / float vector-index parameters used by the mock indexes."""
    return VectorIndexParams(
        distance=DistanceType.L2,
        index_type=IndexType.HNSW,
        quant=QuantType.Float,
    )


def _mock_search_hits(vector: list[float]) -> list[Data]:
    """Return a one-element result list: document "test_id", score 0.10, carrying *vector*."""
    return [
        Data(
            fields={
                vdb_Field.GROUP_KEY.value: "test_group",
                vdb_Field.METADATA_KEY.value: _MOCK_METADATA,
                vdb_Field.CONTENT_KEY.value: "content",
                vdb_Field.PRIMARY_KEY.value: "test_id",
                vdb_Field.VECTOR.value: vector,
            },
            id="test_id",
            score=0.10,
        )
    ]


class MockVikingDBClass:
    """Canned replacements for VikingDBService, Collection and Index methods.

    These methods are not called on instances of this class. The
    ``setup_vikingdb_mock`` fixture grafts them onto the real SDK classes, so at
    call time ``self`` is a VikingDBService, Collection or Index object.
    """

    def __init__(
        self,
        host="api-vikingdb.volces.com",
        region="cn-north-1",
        ak="",
        sk="",
        scheme="http",
        connection_timeout=30,
        socket_timeout=30,
        proxy=None,
    ):
        """Replace VikingDBService.__init__: accept the connection arguments but ignore them."""
        self._viking_db_service = MagicMock()
        # NOTE(review): canned payload for get_exception; presumably read by SDK
        # error handling in the code under test. Confirm against the VikingDB SDK.
        self._viking_db_service.get_exception = MagicMock(return_value='{"data": {"primary_key": "test_id"}}')

    def get_collection(self, collection_name) -> Collection:
        """Return a collection with Dify's five-field schema and a single HNSW index."""
        return Collection(
            collection_name=collection_name,
            description="Collection For Dify",
            viking_db_service=self._viking_db_service,
            primary_key=vdb_Field.PRIMARY_KEY.value,
            fields=[
                Field(field_name=vdb_Field.PRIMARY_KEY.value, field_type=FieldType.String, is_primary_key=True),
                Field(field_name=vdb_Field.METADATA_KEY.value, field_type=FieldType.String),
                Field(field_name=vdb_Field.GROUP_KEY.value, field_type=FieldType.String),
                Field(field_name=vdb_Field.CONTENT_KEY.value, field_type=FieldType.Text),
                Field(field_name=vdb_Field.VECTOR.value, field_type=FieldType.Vector, dim=768),
            ],
            indexes=[
                Index(
                    collection_name=collection_name,
                    index_name=f"{collection_name}_idx",
                    vector_index=_mock_vector_index(),
                    scalar_index=None,
                    stat=None,
                    viking_db_service=self._viking_db_service,
                )
            ],
        )

    def drop_collection(self, collection_name):
        """Check that a non-empty collection name was supplied; nothing is dropped."""
        assert collection_name != ""

    def create_collection(self, collection_name, fields, description="") -> Collection:
        """Return a Collection object built from the given fields; nothing is created remotely."""
        return Collection(
            collection_name=collection_name,
            description=description,
            primary_key=vdb_Field.PRIMARY_KEY.value,
            viking_db_service=self._viking_db_service,
            fields=fields,
        )

    def get_index(self, collection_name, index_name) -> Index:
        """Return an L2/HNSW index object with the requested names."""
        return Index(
            collection_name=collection_name,
            index_name=index_name,
            viking_db_service=self._viking_db_service,
            stat=None,
            scalar_index=None,
            vector_index=_mock_vector_index(),
        )

    def create_index(
        self,
        collection_name,
        index_name,
        vector_index=None,
        cpu_quota=2,
        description="",
        partition_by="",
        scalar_index=None,
        shard_count=None,
        shard_policy=None,
    ):
        """Echo the requested index settings back as an Index object; nothing is created remotely."""
        return Index(
            collection_name=collection_name,
            index_name=index_name,
            vector_index=vector_index,
            cpu_quota=cpu_quota,
            description=description,
            partition_by=partition_by,
            scalar_index=scalar_index,
            shard_count=shard_count,
            shard_policy=shard_policy,
            viking_db_service=self._viking_db_service,
            stat=None,
        )

    def drop_index(self, collection_name, index_name):
        """Check that both names are non-empty; nothing is dropped."""
        assert collection_name != ""
        assert index_name != ""

    def upsert_data(self, data: Union[Data, list[Data]]):
        """Check that some data was passed; nothing is stored."""
        assert data is not None

    # `id` shadows the builtin but must match the SDK's keyword argument name.
    def fetch_data(self, id: Union[str, list[str], int, list[int]]):
        """Return one Data record whose primary key and id are the requested *id*, with empty metadata."""
        return Data(
            fields={
                vdb_Field.GROUP_KEY.value: "test_group",
                vdb_Field.METADATA_KEY.value: "{}",
                vdb_Field.CONTENT_KEY.value: "content",
                vdb_Field.PRIMARY_KEY.value: id,
                vdb_Field.VECTOR.value: list(_MOCK_VECTOR),
            },
            id=id,
        )

    def delete_data(self, id: Union[str, list[str], int, list[int]]):
        """Check that an id (or ids) was passed; nothing is deleted."""
        assert id is not None

    def search_by_vector(
        self,
        vector,
        sparse_vectors=None,
        filter=None,
        limit=10,
        output_fields=None,
        partition="default",
        dense_weight=None,
    ) -> list[Data]:
        """Return the canned single hit, echoing the query *vector* back; other arguments are ignored."""
        return _mock_search_hits(vector)

    def search(
        self, order=None, filter=None, limit=10, output_fields=None, partition="default", dense_weight=None
    ) -> list[Data]:
        """Return the canned single hit carrying the sample vector; all arguments are ignored."""
        return _mock_search_hits(list(_MOCK_VECTOR))


# Mocking is opt-in: set MOCK_SWITCH=true (case-insensitive) to patch the SDK.
MOCK = os.getenv("MOCK_SWITCH", "false").lower() == "true"

# (SDK class, attribute) pairs replaced by the identically named MockVikingDBClass
# method, in the order they are patched.
_PATCHED_ATTRIBUTES = (
    (VikingDBService, "__init__"),
    (VikingDBService, "get_collection"),
    (VikingDBService, "create_collection"),
    (VikingDBService, "drop_collection"),
    (VikingDBService, "get_index"),
    (VikingDBService, "create_index"),
    (VikingDBService, "drop_index"),
    (Collection, "upsert_data"),
    (Collection, "fetch_data"),
    (Collection, "delete_data"),
    (Index, "search_by_vector"),
    (Index, "search"),
)


@pytest.fixture
def setup_vikingdb_mock(monkeypatch: MonkeyPatch):
    """Patch the VikingDB SDK with MockVikingDBClass methods when MOCK is enabled.

    When MOCK is False the fixture does nothing and tests hit the real SDK.
    """
    if MOCK:
        for target, attribute in _PATCHED_ATTRIBUTES:
            monkeypatch.setattr(target, attribute, getattr(MockVikingDBClass, attribute))

    yield

    if MOCK:
        # Restore the real SDK methods right after the test; the monkeypatch
        # fixture's own teardown would also do this, slightly later.
        monkeypatch.undo()