# File size: 2,006 Bytes
# a8b3f00
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
from unittest.mock import MagicMock, patch

import pytest

from core.rag.datasource.vdb.oceanbase.oceanbase_vector import (
    OceanBaseVector,
    OceanBaseVectorConfig,
)
from tests.integration_tests.vdb.__mock.tcvectordb import setup_tcvectordb_mock
from tests.integration_tests.vdb.test_vector_store import (
    AbstractVectorTest,
    get_example_text,
    setup_mock_redis,
)


@pytest.fixture
def oceanbase_vector():
    """Build the ``OceanBaseVector`` under test.

    The instance is bound to the ``dify_test_collection`` collection and is
    configured with fixed local connection settings.
    """
    connection_config = OceanBaseVectorConfig(
        host="127.0.0.1",
        port="2881",
        user="root@test",
        database="test",
        password="test",
    )
    return OceanBaseVector("dify_test_collection", config=connection_config)


class OceanBaseVectorTest(AbstractVectorTest):
    """Vector-store checks specialised for ``OceanBaseVector``.

    Each method below asserts the result expected from the wrapped vector
    store. NOTE(review): these presumably override the steps that the
    inherited ``run_all_tests`` drives, and ``example_embedding`` /
    ``example_doc_id`` are presumably set up by the base ``__init__`` --
    confirm against ``AbstractVectorTest``.
    """

    def __init__(self, vector: OceanBaseVector):
        """Store the vector store instance that the checks run against.

        Args:
            vector: The ``OceanBaseVector`` to exercise.
        """
        super().__init__()
        self.vector = vector

    def search_by_vector(self) -> None:
        """Assert that searching by the example embedding yields no hits."""
        hits_by_vector = self.vector.search_by_vector(query_vector=self.example_embedding)
        # Keep the explicit length check: truthiness is not equivalent for
        # mock-like return values.
        assert len(hits_by_vector) == 0

    def search_by_full_text(self) -> None:
        """Assert that a full-text search for the example text yields no hits."""
        hits_by_full_text = self.vector.search_by_full_text(query=get_example_text())
        assert len(hits_by_full_text) == 0

    def text_exists(self) -> None:
        """Assert that the example document id is reported as existing."""
        exist = self.vector.text_exists(self.example_doc_id)
        # Identity check instead of ``== True`` (PEP 8 / Ruff E712); this also
        # rejects truthy non-bool values.
        assert exist is True

    def get_ids_by_metadata_field(self) -> None:
        """Assert that looking up ids by ``document_id`` metadata yields nothing."""
        ids = self.vector.get_ids_by_metadata_field(key="document_id", value=self.example_doc_id)
        assert len(ids) == 0


@pytest.fixture
def setup_mock_oceanbase_client():
    """Patch ``ObVecClient`` in the OceanBase vector module and yield the mock.

    The replacement is created via ``MagicMock`` and stays installed until
    the fixture is torn down.
    """
    client_target = "core.rag.datasource.vdb.oceanbase.oceanbase_vector.ObVecClient"
    client_patcher = patch(client_target, new_callable=MagicMock)
    with client_patcher as patched_client:
        yield patched_client


@pytest.fixture
def setup_mock_oceanbase_vector(oceanbase_vector):
    """Yield the ``oceanbase_vector`` fixture with its ``_client`` attribute patched.

    ``patch.object`` replaces ``_client`` while the fixture is active and
    restores the original attribute on teardown.
    """
    vector_under_test = oceanbase_vector
    client_attr_patcher = patch.object(vector_under_test, "_client")
    with client_attr_patcher:
        yield vector_under_test


def test_oceanbase_vector(
    setup_mock_redis,
    setup_mock_oceanbase_client,
    setup_mock_oceanbase_vector,
    oceanbase_vector,
):
    """Run the full ``OceanBaseVectorTest`` suite against the mocked vector store.

    The ``setup_mock_*`` fixtures are requested only for their patching side
    effects; the suite itself is built from the ``oceanbase_vector`` fixture.
    """
    suite = OceanBaseVectorTest(oceanbase_vector)
    suite.run_all_tests()