diff --git a/python/tests/detail/test_collection_dql.py b/python/tests/detail/test_collection_dql.py index d52529111..51ab590e8 100644 --- a/python/tests/detail/test_collection_dql.py +++ b/python/tests/detail/test_collection_dql.py @@ -731,7 +731,7 @@ def test_query_multivector_rrf(self, full_collection: Collection, doc_num): ) expected_score = expected_rrf_scores[doc_id] actual_score = doc.score - assert abs(actual_score - expected_score) < 1e-10, ( + assert abs(actual_score - expected_score) < 1e-6, ( f"RRF score mismatch for document {doc_id}: expected {expected_score}, got {actual_score}" ) assert doc.score <= prev_score, ( @@ -799,7 +799,7 @@ def test_query_multivector_weighted( ) expected_score = expected_weighted_scores[doc_id] actual_score = doc.score - assert abs(actual_score - expected_score) < 1e-10, ( + assert abs(actual_score - expected_score) < 1e-6, ( f"score mismatch for document {doc_id}: expected {expected_score}, got {actual_score}" ) assert doc.score <= prev_score, ( diff --git a/python/tests/test_collection.py b/python/tests/test_collection.py index 9b84eb723..2c31c6757 100644 --- a/python/tests/test_collection.py +++ b/python/tests/test_collection.py @@ -27,11 +27,16 @@ InvertIndexParam, LogLevel, LogType, + MetricType, OptimizeOption, StatusCode, Query, VectorSchema, ) +from zvec.extension.multi_vector_reranker import ( + RrfReRanker, + WeightedReRanker, +) # ==================== Common ==================== @@ -60,9 +65,18 @@ def collection_schema(): dimension=128, index_param=HnswIndexParam(), ), + VectorSchema( + "dense2", + DataType.VECTOR_FP32, + dimension=128, + index_param=HnswIndexParam(), + ), VectorSchema( "sparse", DataType.SPARSE_VECTOR_FP32, index_param=HnswIndexParam() ), + VectorSchema( + "sparse2", DataType.SPARSE_VECTOR_FP32, index_param=HnswIndexParam() + ), ], ) @@ -78,7 +92,12 @@ def single_doc(): return Doc( id=f"{id}", fields={"id": id, "name": "test", "weight": 80.0, "height": id + 140}, - vectors={"dense": [id + 0.1] * 
128, "sparse": {1: 1.0, 2: 2.0, 3: 3.0}}, + vectors={ + "dense": [id + 0.1] * 128, + "dense2": [id + 0.2] * 128, + "sparse": {1: 1.0, 2: 2.0, 3: 3.0}, + "sparse2": {4: 1.5, 5: 2.5, 6: 3.5}, + }, ) @@ -88,7 +107,12 @@ def multiple_docs(): Doc( id=f"{id}", fields={"id": id, "name": "test", "weight": 80.0, "height": 210}, - vectors={"dense": [id + 0.1] * 128, "sparse": {1: 1.0, 2: 2.0, 3: 3.0}}, + vectors={ + "dense": [id + 0.1] * 128, + "dense2": [id + 0.2] * 128, + "sparse": {1: 1.0, 2: 2.0, 3: 3.0}, + "sparse2": {4: 1.5, 5: 2.5, 6: 3.5}, + }, ) for id in range(1, 101) ] @@ -182,9 +206,11 @@ def test_collection_stats(self, test_collection: Collection): assert test_collection.stats is not None stats = test_collection.stats assert stats.doc_count == 0 - assert len(stats.index_completeness) == 2 + assert len(stats.index_completeness) == 4 assert stats.index_completeness["dense"] == 1 + assert stats.index_completeness["dense2"] == 1 assert stats.index_completeness["sparse"] == 1 + assert stats.index_completeness["sparse2"] == 1 # ---------------------------- @@ -449,7 +475,12 @@ def test_collection_insert_with_nullable_false_field(self, test_collection): "id": 1, "name": "test", }, - vectors={"dense": [1 + 0.1] * 128, "sparse": {1: 1.0, 2: 2.0, 3: 3.0}}, + vectors={ + "dense": [1 + 0.1] * 128, + "dense2": [1 + 0.2] * 128, + "sparse": {1: 1.0, 2: 2.0, 3: 3.0}, + "sparse2": {4: 1.5, 5: 2.5, 6: 3.5}, + }, ) result = test_collection.insert(doc) assert bool(result) @@ -465,7 +496,12 @@ def test_collection_insert_without_nullable_false_field(self, test_collection): # without id, name doc = Doc( id="0", - vectors={"dense": [1 + 0.1] * 128, "sparse": {1: 1.0, 2: 2.0, 3: 3.0}}, + vectors={ + "dense": [1 + 0.1] * 128, + "dense2": [1 + 0.2] * 128, + "sparse": {1: 1.0, 2: 2.0, 3: 3.0}, + "sparse2": {4: 1.5, 5: 2.5, 6: 3.5}, + }, ) with pytest.raises(ValueError) as e: # ValueError: Invalid doc: field[id] is required but not provided @@ -478,7 +514,12 @@ def 
test_collection_insert_without_nullable_false_field(self, test_collection): fields={ "id": 1, }, - vectors={"dense": [1 + 0.1] * 128, "sparse": {1: 1.0, 2: 2.0, 3: 3.0}}, + vectors={ + "dense": [1 + 0.1] * 128, + "dense2": [1 + 0.2] * 128, + "sparse": {1: 1.0, 2: 2.0, 3: 3.0}, + "sparse2": {4: 1.5, 5: 2.5, 6: 3.5}, + }, ) with pytest.raises(ValueError) as e: test_collection.insert(doc) @@ -494,7 +535,12 @@ def test_collection_insert_with_nullable_true_field(self, test_collection): "id": 1, "name": "test", }, - vectors={"dense": [1 + 0.1] * 128, "sparse": {1: 1.0, 2: 2.0, 3: 3.0}}, + vectors={ + "dense": [1 + 0.1] * 128, + "dense2": [1 + 0.2] * 128, + "sparse": {1: 1.0, 2: 2.0, 3: 3.0}, + "sparse2": {4: 1.5, 5: 2.5, 6: 3.5}, + }, ) result = test_collection.insert(doc) assert bool(result) @@ -969,70 +1015,177 @@ def test_collection_query_by_id( def test_collection_query_multi_vector_with_same_field( self, collection_with_multiple_docs: Collection, multiple_docs ): - with pytest.raises(ValueError): + # Multi-vector query on same field without reranker should raise ValueError + with pytest.raises(ValueError, match="Reranker is required"): collection_with_multiple_docs.query( [ Query(field_name="dense", vector=multiple_docs[0].vector("dense")), - Query(field_name="dense", vector=multiple_docs[0].vector("dense")), + Query(field_name="dense", vector=multiple_docs[1].vector("dense")), ] ) - @pytest.mark.skip(reason="TODO: This test case is pending implementation") + # Same field name with reranker should also raise ValueError + reranker = RrfReRanker(topn=10, rank_constant=60) + with pytest.raises(ValueError, match="appears more than once"): + collection_with_multiple_docs.query( + [ + Query(field_name="dense", vector=multiple_docs[0].vector("dense")), + Query(field_name="dense", vector=multiple_docs[1].vector("dense")), + ], + topk=10, + reranker=reranker, + ) + def test_collection_query_by_dense_vector( self, collection_with_multiple_docs: Collection, multiple_docs ): - 
pass + result = collection_with_multiple_docs.query( + Query(field_name="dense", vector=multiple_docs[0].vector("dense")), + topk=10, + ) + assert len(result) > 0 + assert len(result) <= 10 - @pytest.mark.skip(reason="TODO: This test case is pending implementation") def test_collection_query_by_sparse_vector( self, collection_with_multiple_docs: Collection, multiple_docs ): - pass + result = collection_with_multiple_docs.query( + Query(field_name="sparse", vector=multiple_docs[0].vector("sparse")), + topk=10, + ) + assert len(result) > 0 + assert len(result) <= 10 - @pytest.mark.skip(reason="TODO: This test case is pending implementation") def test_collection_query_by_dense_vector_with_filter( self, collection_with_multiple_docs: Collection, multiple_docs ): - pass + result = collection_with_multiple_docs.query( + Query(field_name="dense", vector=multiple_docs[0].vector("dense")), + topk=10, + filter="id > 50", + ) + assert len(result) > 0 + assert len(result) <= 10 + for doc in result: + assert int(doc.id) > 50 - @pytest.mark.skip(reason="TODO: This test case is pending implementation") def test_collection_query_by_sparse_vector_with_filter( self, collection_with_multiple_docs: Collection, multiple_docs ): - pass + result = collection_with_multiple_docs.query( + Query(field_name="sparse", vector=multiple_docs[0].vector("sparse")), + topk=10, + filter="id > 50", + ) + assert len(result) > 0 + assert len(result) <= 10 + for doc in result: + assert int(doc.id) > 50 - @pytest.mark.skip(reason="TODO: This test case is pending implementation") def test_collection_query_with_rrf_reranker_by_multi_dense_vector( self, collection_with_multiple_docs: Collection, multiple_docs ): - pass + """Test multi-vector query with RRF reranker on multiple dense vectors.""" + reranker = RrfReRanker(topn=10, rank_constant=60) + result = collection_with_multiple_docs.query( + [ + Query(field_name="dense", vector=multiple_docs[0].vector("dense")), + Query(field_name="dense2", 
vector=multiple_docs[0].vector("dense2")), + ], + topk=10, + reranker=reranker, + ) + assert len(result) > 0 + assert len(result) <= 10 + # Results should have RRF-fused scores + for doc in result: + assert hasattr(doc, "score") - @pytest.mark.skip(reason="TODO: This test case is pending implementation") def test_collection_query_with_rrf_reranker_by_multi_sparse_vector( self, collection_with_multiple_docs: Collection, multiple_docs ): - pass + """Test multi-vector query with RRF reranker on multiple sparse vectors.""" + reranker = RrfReRanker(topn=10, rank_constant=60) + result = collection_with_multiple_docs.query( + [ + Query(field_name="sparse", vector=multiple_docs[0].vector("sparse")), + Query( + field_name="sparse2", + vector=multiple_docs[0].vector("sparse2"), + ), + ], + topk=10, + reranker=reranker, + ) + assert len(result) > 0 + assert len(result) <= 10 - @pytest.mark.skip(reason="TODO: This test case is pending implementation") def test_collection_query_with_rrf_reranker_by_hybrid_vector( self, collection_with_multiple_docs: Collection, multiple_docs ): - pass + """Test multi-vector query with RRF reranker combining dense + sparse.""" + reranker = RrfReRanker(topn=10, rank_constant=60) + result = collection_with_multiple_docs.query( + [ + Query(field_name="dense", vector=multiple_docs[0].vector("dense")), + Query(field_name="sparse", vector=multiple_docs[0].vector("sparse")), + ], + topk=10, + reranker=reranker, + ) + assert len(result) > 0 + assert len(result) <= 10 - @pytest.mark.skip(reason="TODO: This test case is pending implementation") def test_collection_query_with_weighted_reranker_by_multi_dense_vector( self, collection_with_multiple_docs: Collection, multiple_docs ): - pass + """Test multi-vector query with Weighted reranker on multiple dense vectors.""" + weights = {"dense": 0.6, "dense2": 0.4} + reranker = WeightedReRanker(topn=10, metric=MetricType.IP, weights=weights) + result = collection_with_multiple_docs.query( + [ + 
Query(field_name="dense", vector=multiple_docs[0].vector("dense")), + Query(field_name="dense2", vector=multiple_docs[0].vector("dense2")), + ], + topk=10, + reranker=reranker, + ) + assert len(result) > 0 + assert len(result) <= 10 - @pytest.mark.skip(reason="TODO: This test case is pending implementation") def test_collection_query_with_weighted_reranker_by_multi_sparse_vector( self, collection_with_multiple_docs: Collection, multiple_docs ): - pass + """Test multi-vector query with Weighted reranker on multiple sparse vectors.""" + weights = {"sparse": 0.6, "sparse2": 0.4} + reranker = WeightedReRanker(topn=10, metric=MetricType.IP, weights=weights) + result = collection_with_multiple_docs.query( + [ + Query(field_name="sparse", vector=multiple_docs[0].vector("sparse")), + Query( + field_name="sparse2", + vector=multiple_docs[0].vector("sparse2"), + ), + ], + topk=10, + reranker=reranker, + ) + assert len(result) > 0 + assert len(result) <= 10 - @pytest.mark.skip(reason="TODO: This test case is pending implementation") def test_collection_query_with_weighted_reranker_by_hybrid_vector( self, collection_with_multiple_docs: Collection, multiple_docs ): - pass + """Test multi-vector query with Weighted reranker combining dense + sparse.""" + weights = {"dense": 0.7, "sparse": 0.3} + reranker = WeightedReRanker(topn=10, metric=MetricType.IP, weights=weights) + result = collection_with_multiple_docs.query( + [ + Query(field_name="dense", vector=multiple_docs[0].vector("dense")), + Query(field_name="sparse", vector=multiple_docs[0].vector("sparse")), + ], + topk=10, + reranker=reranker, + ) + assert len(result) > 0 + assert len(result) <= 10 diff --git a/python/zvec/executor/query_executor.py b/python/zvec/executor/query_executor.py index 3e54e37d2..eeb13be1b 100644 --- a/python/zvec/executor/query_executor.py +++ b/python/zvec/executor/query_executor.py @@ -19,8 +19,8 @@ from typing import Optional, Union, final import numpy as np -from _zvec import _Collection -from 
_zvec.param import _VectorQuery +from _zvec import _Collection, _MultiVectorQuery +from _zvec.param import _SubVectorQuery, _VectorQuery from ..extension import ReRanker, RrfReRanker, WeightedReRanker from ..model.convert import convert_to_py_doc @@ -290,11 +290,44 @@ def _do_validate(self, ctx: QueryContext) -> None: raise ValueError(f"Query field name '{field}' appears more than once") seen_fields.add(field) + def execute(self, ctx: QueryContext, collection: _Collection) -> list[Doc]: + # 1. validate query + self._do_validate(ctx) + # 2. build query vectors + query_vectors = self._do_build(ctx, collection) + if not query_vectors: + raise ValueError("No query to execute") + + # Fast path: use C++ MultiQuery for multi-vector with C++ reranker + if len(query_vectors) > 1 and ctx.reranker is not None: + cpp_reranker = ctx.reranker._get_object() + if cpp_reranker is not None: + mvq = _MultiVectorQuery() + mvq.queries = [self._to_sub_vector_query(vq) for vq in query_vectors] + mvq.topk = ctx.topk + if ctx.filter: + mvq.filter = ctx.filter + mvq.include_vector = ctx.include_vector + if ctx.output_fields: + mvq.output_fields = ctx.output_fields + mvq.reranker = cpp_reranker + docs = collection.MultiQuery(mvq) + return [convert_to_py_doc(doc, self._schema) for doc in docs] + + # 3. execute query (fallback to Python path) + docs = self._do_execute(query_vectors, collection) + # 4. 
merge and rerank result + return self._do_merge_rerank_results(ctx, docs) + def _do_execute( self, vectors: list[_VectorQuery], collection: _Collection ) -> dict[str, list[Doc]]: return super()._do_execute(vectors, collection) + @staticmethod + def _to_sub_vector_query(vq: _VectorQuery) -> _SubVectorQuery: + return _SubVectorQuery.from_vector_query(vq) + class QueryExecutorFactory: @staticmethod diff --git a/python/zvec/extension/multi_vector_reranker.py b/python/zvec/extension/multi_vector_reranker.py index ba3a2363b..a31182804 100644 --- a/python/zvec/extension/multi_vector_reranker.py +++ b/python/zvec/extension/multi_vector_reranker.py @@ -18,6 +18,8 @@ from collections import defaultdict from typing import Optional +from _zvec import _RrfReRanker, _WeightedReRanker + from ..model.doc import Doc from ..typing import MetricType from .rerank_function import RerankFunction @@ -51,11 +53,17 @@ def __init__( ): super().__init__(topn=topn, rerank_field=rerank_field) self._rank_constant = rank_constant + # Use C++ implementation for performance + self._cpp_reranker = _RrfReRanker(topn, rank_constant) @property def rank_constant(self) -> int: return self._rank_constant + def _get_object(self): + """Return the underlying C++ RrfReRanker instance.""" + return self._cpp_reranker + def _rrf_score(self, rank: int) -> float: return 1.0 / (self._rank_constant + rank + 1) @@ -121,6 +129,8 @@ def __init__( super().__init__(topn=topn, rerank_field=rerank_field) self._weights = weights or {} self._metric = metric + # Use C++ implementation for performance + self._cpp_reranker = _WeightedReRanker(topn, metric, self._weights) @property def weights(self) -> dict[str, float]: @@ -132,6 +142,10 @@ def metric(self) -> MetricType: """MetricType: Distance metric used for score normalization.""" return self._metric + def _get_object(self): + """Return the underlying C++ WeightedReRanker instance.""" + return self._cpp_reranker + def rerank(self, query_results: dict[str, list[Doc]]) -> 
list[Doc]: """Combine scores from multiple vector fields using weighted sum. diff --git a/python/zvec/extension/rerank_function.py b/python/zvec/extension/rerank_function.py index c558a2bc4..0d8d00263 100644 --- a/python/zvec/extension/rerank_function.py +++ b/python/zvec/extension/rerank_function.py @@ -67,3 +67,15 @@ def rerank(self, query_results: dict[str, list[Doc]]) -> list[Doc]: with updated ``score`` fields. """ ... + + def _get_object(self): + """Return the underlying C++ Reranker instance, if available. + + This is used internally by the query executor to pass the reranker + to the C++ MultiQuery method. Subclasses that wrap a C++ reranker + should override this method. + + Returns: + The C++ Reranker shared pointer, or None if not available. + """ + return None # noqa: RET501 diff --git a/src/binding/c/c_api.cc b/src/binding/c/c_api.cc index 2c3489ab9..1fc9a1d1a 100644 --- a/src/binding/c/c_api.cc +++ b/src/binding/c/c_api.cc @@ -33,6 +33,7 @@ #include #include #include +#include #include #include @@ -5340,6 +5341,367 @@ zvec_error_code_t zvec_group_by_vector_query_set_flat_params( return ZVEC_OK; } +// ============================================================================= +// Reranker Implementation +// ============================================================================= + +zvec_reranker_t *zvec_reranker_create_rrf(int topn, int rank_constant) { + ZVEC_TRY_RETURN_NULL("Failed to create RRF Reranker", + auto *reranker = + new zvec::Reranker::Ptr( + std::make_shared( + topn, rank_constant)); + return reinterpret_cast(reranker);) + return nullptr; +} + +zvec_reranker_t *zvec_reranker_create_weighted( + int topn, int metric, const char **fields, const double *weights, + size_t weight_count) { + if ((!fields || !weights) && weight_count > 0) { + set_last_error("Fields and weights pointers cannot be null when " + "weight_count > 0"); + return nullptr; + } + + ZVEC_TRY_RETURN_NULL( + "Failed to create Weighted Reranker", + std::map weight_map; 
+ for (size_t i = 0; i < weight_count; ++i) { + if (!fields[i]) { + set_last_error("Null field name at index " + std::to_string(i)); + return nullptr; + } + weight_map[fields[i]] = weights[i]; + } + + auto *reranker = new zvec::Reranker::Ptr( + std::make_shared( + topn, static_cast(metric), weight_map)); + return reinterpret_cast(reranker);) + return nullptr; +} + +void zvec_reranker_destroy(zvec_reranker_t *reranker) { + if (reranker) { + delete reinterpret_cast(reranker); + } +} + +int zvec_reranker_get_topn(const zvec_reranker_t *reranker) { + if (!reranker) return 0; + auto *ptr = reinterpret_cast(reranker); + return (*ptr)->topn(); +} + +int zvec_reranker_get_rank_constant(const zvec_reranker_t *reranker) { + if (!reranker) return -1; + auto *ptr = reinterpret_cast(reranker); + auto *rrf = dynamic_cast(ptr->get()); + return rrf ? rrf->rank_constant() : -1; +} + +// ============================================================================= +// MultiVectorQuery Implementation +// ============================================================================= + +zvec_multi_vector_query_t *zvec_multi_vector_query_create(void) { + ZVEC_TRY_RETURN_NULL("Failed to create MultiVectorQuery", + auto *query = new zvec::MultiVectorQuery(); + return reinterpret_cast( + query);) + return nullptr; +} + +void zvec_multi_vector_query_destroy(zvec_multi_vector_query_t *query) { + if (query) { + delete reinterpret_cast(query); + } +} + +zvec_error_code_t zvec_multi_vector_query_add_query( + zvec_multi_vector_query_t *query, + const zvec_multi_vector_sub_query_t *sub_query) { + if (!query || !sub_query) { + SET_LAST_ERROR(ZVEC_ERROR_INVALID_ARGUMENT, + "Query or sub_query pointer is null"); + return ZVEC_ERROR_INVALID_ARGUMENT; + } + + auto *mvq = reinterpret_cast(query); + auto *sub = reinterpret_cast(sub_query); + mvq->queries.push_back(*sub); + + return ZVEC_OK; +} + +size_t zvec_multi_vector_query_get_query_count( + const zvec_multi_vector_query_t *query) { + if (!query) 
return 0; + auto *mvq = reinterpret_cast(query); + return mvq->queries.size(); +} + +zvec_error_code_t zvec_multi_vector_query_set_topk( + zvec_multi_vector_query_t *query, int topk) { + if (!query) { + SET_LAST_ERROR(ZVEC_ERROR_INVALID_ARGUMENT, + "Multi-vector query pointer is null"); + return ZVEC_ERROR_INVALID_ARGUMENT; + } + auto *mvq = reinterpret_cast(query); + mvq->topk = topk; + return ZVEC_OK; +} + +int zvec_multi_vector_query_get_topk( + const zvec_multi_vector_query_t *query) { + if (!query) return 0; + auto *mvq = reinterpret_cast(query); + return mvq->topk; +} + +zvec_error_code_t zvec_multi_vector_query_set_filter( + zvec_multi_vector_query_t *query, const char *filter) { + if (!query || !filter) { + SET_LAST_ERROR(ZVEC_ERROR_INVALID_ARGUMENT, + "Query or filter pointer is null"); + return ZVEC_ERROR_INVALID_ARGUMENT; + } + auto *mvq = reinterpret_cast(query); + mvq->filter = std::string(filter); + return ZVEC_OK; +} + +const char *zvec_multi_vector_query_get_filter( + const zvec_multi_vector_query_t *query) { + if (!query) return nullptr; + auto *mvq = reinterpret_cast(query); + return mvq->filter.c_str(); +} + +zvec_error_code_t zvec_multi_vector_query_set_include_vector( + zvec_multi_vector_query_t *query, bool include) { + if (!query) { + SET_LAST_ERROR(ZVEC_ERROR_INVALID_ARGUMENT, + "Multi-vector query pointer is null"); + return ZVEC_ERROR_INVALID_ARGUMENT; + } + auto *mvq = reinterpret_cast(query); + mvq->include_vector = include; + return ZVEC_OK; +} + +bool zvec_multi_vector_query_get_include_vector( + const zvec_multi_vector_query_t *query) { + if (!query) return false; + auto *mvq = reinterpret_cast(query); + return mvq->include_vector; +} + +zvec_error_code_t zvec_multi_vector_query_set_output_fields( + zvec_multi_vector_query_t *query, const char **fields, size_t count) { + if (!query || (!fields && count > 0)) { + SET_LAST_ERROR(ZVEC_ERROR_INVALID_ARGUMENT, + "Query pointer is null or fields is null with count > 0"); + return 
ZVEC_ERROR_INVALID_ARGUMENT; + } + + auto *mvq = reinterpret_cast(query); + std::vector field_vec; + field_vec.reserve(count); + for (size_t i = 0; i < count; ++i) { + if (!fields[i]) { + SET_LAST_ERROR(ZVEC_ERROR_INVALID_ARGUMENT, + "Null field name at index " + std::to_string(i)); + return ZVEC_ERROR_INVALID_ARGUMENT; + } + field_vec.emplace_back(fields[i]); + } + mvq->output_fields = std::move(field_vec); + + return ZVEC_OK; +} + +zvec_error_code_t zvec_multi_vector_query_get_output_fields( + zvec_multi_vector_query_t *query, const char ***fields, size_t *count) { + if (!query || !fields || !count) { + SET_LAST_ERROR(ZVEC_ERROR_INVALID_ARGUMENT, + "Query, fields or count pointer is null"); + return ZVEC_ERROR_INVALID_ARGUMENT; + } + + auto *mvq = reinterpret_cast(query); + if (!mvq->output_fields.has_value() || mvq->output_fields->empty()) { + *fields = nullptr; + *count = 0; + return ZVEC_OK; + } + + const auto &field_vec = mvq->output_fields.value(); + *count = field_vec.size(); + *fields = static_cast(malloc(*count * sizeof(const char *))); + if (!*fields) { + SET_LAST_ERROR(ZVEC_ERROR_INTERNAL_ERROR, "Failed to allocate memory"); + return ZVEC_ERROR_INTERNAL_ERROR; + } + for (size_t i = 0; i < *count; ++i) { + (*fields)[i] = field_vec[i].c_str(); + } + + return ZVEC_OK; +} + +zvec_error_code_t zvec_multi_vector_query_set_reranker( + zvec_multi_vector_query_t *query, zvec_reranker_t *reranker) { + if (!query || !reranker) { + SET_LAST_ERROR(ZVEC_ERROR_INVALID_ARGUMENT, + "Query or reranker pointer is null"); + return ZVEC_ERROR_INVALID_ARGUMENT; + } + + auto *mvq = reinterpret_cast(query); + auto *reranker_ptr = + reinterpret_cast(reranker); + mvq->reranker = *reranker_ptr; + + return ZVEC_OK; +} + +// ============================================================================= +// SubVectorQuery Implementation +// ============================================================================= + +zvec_multi_vector_sub_query_t 
*zvec_multi_vector_sub_query_create(void) { + ZVEC_TRY_RETURN_NULL("Failed to create SubVectorQuery", + auto *query = new zvec::SubVectorQuery(); + query->num_candidates_ = 10; + return reinterpret_cast( + query);) + return nullptr; +} + +void zvec_multi_vector_sub_query_destroy(zvec_multi_vector_sub_query_t *query) { + if (query) { + delete reinterpret_cast(query); + } +} + +zvec_error_code_t zvec_multi_vector_sub_query_set_num_candidates( + zvec_multi_vector_sub_query_t *query, int num_candidates) { + if (!query) { + SET_LAST_ERROR(ZVEC_ERROR_INVALID_ARGUMENT, + "Sub-vector query pointer is null"); + return ZVEC_ERROR_INVALID_ARGUMENT; + } + auto *ptr = reinterpret_cast(query); + ptr->num_candidates_ = num_candidates; + return ZVEC_OK; +} + +int zvec_multi_vector_sub_query_get_num_candidates( + const zvec_multi_vector_sub_query_t *query) { + if (!query) return 0; + auto *ptr = reinterpret_cast(query); + return ptr->num_candidates_; +} + +zvec_error_code_t zvec_multi_vector_sub_query_set_field_name( + zvec_multi_vector_sub_query_t *query, const char *field_name) { + if (!query || !field_name) { + SET_LAST_ERROR(ZVEC_ERROR_INVALID_ARGUMENT, + "Sub-vector query or field_name pointer is null"); + return ZVEC_ERROR_INVALID_ARGUMENT; + } + auto *ptr = reinterpret_cast(query); + ptr->field_name_ = std::string(field_name); + return ZVEC_OK; +} + +const char *zvec_multi_vector_sub_query_get_field_name( + const zvec_multi_vector_sub_query_t *query) { + if (!query) return nullptr; + auto *ptr = reinterpret_cast(query); + return ptr->field_name_.c_str(); +} + +zvec_error_code_t zvec_multi_vector_sub_query_set_query_vector( + zvec_multi_vector_sub_query_t *query, const void *data, size_t size) { + if (!query || !data || size == 0) { + SET_LAST_ERROR(ZVEC_ERROR_INVALID_ARGUMENT, + "Sub-vector query pointer or data is null/empty"); + return ZVEC_ERROR_INVALID_ARGUMENT; + } + auto *ptr = reinterpret_cast(query); + ptr->query_vector_.assign(static_cast(data), size); + return 
ZVEC_OK; +} + +zvec_error_code_t zvec_multi_vector_sub_query_set_sparse_indices( + zvec_multi_vector_sub_query_t *query, const uint32_t *indices, size_t count) { + if (!query || (!indices && count > 0)) { + SET_LAST_ERROR(ZVEC_ERROR_INVALID_ARGUMENT, + "Sub-vector query or indices pointer is null"); + return ZVEC_ERROR_INVALID_ARGUMENT; + } + auto *ptr = reinterpret_cast(query); + ptr->query_sparse_indices_.assign( + reinterpret_cast(indices), count * sizeof(uint32_t)); + return ZVEC_OK; +} + +zvec_error_code_t zvec_multi_vector_sub_query_set_sparse_values( + zvec_multi_vector_sub_query_t *query, const float *values, size_t count) { + if (!query || (!values && count > 0)) { + SET_LAST_ERROR(ZVEC_ERROR_INVALID_ARGUMENT, + "Sub-vector query or values pointer is null"); + return ZVEC_ERROR_INVALID_ARGUMENT; + } + auto *ptr = reinterpret_cast(query); + ptr->query_sparse_values_.assign( + reinterpret_cast(values), count * sizeof(float)); + return ZVEC_OK; +} + +zvec_error_code_t zvec_multi_vector_sub_query_set_hnsw_params( + zvec_multi_vector_sub_query_t *query, zvec_hnsw_query_params_t *hnsw_params) { + if (!query || !hnsw_params) { + SET_LAST_ERROR(ZVEC_ERROR_INVALID_ARGUMENT, + "Sub-vector query or HNSW params pointer is null"); + return ZVEC_ERROR_INVALID_ARGUMENT; + } + auto *ptr = reinterpret_cast(query); + auto *params_ptr = reinterpret_cast(hnsw_params); + ptr->query_params_.reset(params_ptr); + return ZVEC_OK; +} + +zvec_error_code_t zvec_multi_vector_sub_query_set_ivf_params( + zvec_multi_vector_sub_query_t *query, zvec_ivf_query_params_t *ivf_params) { + if (!query || !ivf_params) { + SET_LAST_ERROR(ZVEC_ERROR_INVALID_ARGUMENT, + "Sub-vector query or IVF params pointer is null"); + return ZVEC_ERROR_INVALID_ARGUMENT; + } + auto *ptr = reinterpret_cast(query); + auto *params_ptr = reinterpret_cast(ivf_params); + ptr->query_params_.reset(params_ptr); + return ZVEC_OK; +} + +zvec_error_code_t zvec_multi_vector_sub_query_set_flat_params( + 
zvec_multi_vector_sub_query_t *query, zvec_flat_query_params_t *flat_params) { + if (!query || !flat_params) { + SET_LAST_ERROR(ZVEC_ERROR_INVALID_ARGUMENT, + "Sub-vector query or Flat params pointer is null"); + return ZVEC_ERROR_INVALID_ARGUMENT; + } + auto *ptr = reinterpret_cast(query); + auto *params_ptr = reinterpret_cast(flat_params); + ptr->query_params_.reset(params_ptr); + return ZVEC_OK; +} + // ============================================================================= // Index Interface Implementation // ============================================================================= @@ -5998,6 +6360,41 @@ zvec_error_code_t zvec_collection_query(const zvec_collection_t *collection, return error_code;) } +zvec_error_code_t zvec_collection_multi_query( + const zvec_collection_t *collection, + const zvec_multi_vector_query_t *query, + zvec_doc_t ***results, size_t *result_count) { + if (!collection || !query || !results || !result_count) { + set_last_error( + "Invalid arguments: collection, query, results and result_count " + "cannot be null"); + return ZVEC_ERROR_INVALID_ARGUMENT; + } + + ZVEC_TRY_RETURN_ERROR( + "Exception occurred", + auto coll_ptr = + reinterpret_cast *>( + collection); + + auto *internal_query = + reinterpret_cast(query); + + auto result = (*coll_ptr)->MultiQuery(*internal_query); + zvec_error_code_t error_code = handle_expected_result(result); + + if (error_code == ZVEC_OK) { + const auto &query_results = result.value(); + error_code = + convert_document_results(query_results, results, result_count); + } else { + *results = nullptr; + *result_count = 0; + } + + return error_code;) +} + zvec_error_code_t zvec_collection_fetch(zvec_collection_t *collection, const char *const *pks, size_t pk_count, zvec_doc_t ***results, size_t *doc_count) { diff --git a/src/binding/python/CMakeLists.txt b/src/binding/python/CMakeLists.txt index d17f56289..a02287c8d 100644 --- a/src/binding/python/CMakeLists.txt +++ b/src/binding/python/CMakeLists.txt 
@@ -10,6 +10,7 @@ set(SRC_LISTS binding.cc model/python_collection.cc model/python_doc.cc + model/python_reranker.cc model/param/python_param.cc model/schema/python_schema.cc model/common/python_config.cc diff --git a/src/binding/python/binding.cc b/src/binding/python/binding.cc index ed8d6918d..c1bdad367 100644 --- a/src/binding/python/binding.cc +++ b/src/binding/python/binding.cc @@ -16,6 +16,7 @@ #include "python_config.h" #include "python_doc.h" #include "python_param.h" +#include "python_reranker.h" #include "python_schema.h" #include "python_type.h" @@ -26,6 +27,7 @@ PYBIND11_MODULE(_zvec, m) { ZVecPyTyping::Initialize(m); ZVecPyParams::Initialize(m); ZVecPySchemas::Initialize(m); + ZVecPyReranker::Initialize(m); ZVecPyConfig::Initialize(m); ZVecPyDoc::Initialize(m); ZVecPyCollection::Initialize(m); diff --git a/src/binding/python/include/python_reranker.h b/src/binding/python/include/python_reranker.h new file mode 100644 index 000000000..4ab174a62 --- /dev/null +++ b/src/binding/python/include/python_reranker.h @@ -0,0 +1,31 @@ +// Copyright 2025-present the zvec project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+#pragma once + +#include +#include + +namespace py = pybind11; + +namespace zvec { + +class ZVecPyReranker { + public: + ZVecPyReranker() = delete; + + public: + static void Initialize(py::module_ &m); +}; + +} // namespace zvec diff --git a/src/binding/python/model/param/python_param.cc b/src/binding/python/model/param/python_param.cc index 268246cd6..6a195684e 100644 --- a/src/binding/python/model/param/python_param.cc +++ b/src/binding/python/model/param/python_param.cc @@ -17,6 +17,7 @@ #include #include #include +#include #include "python_doc.h" namespace zvec { @@ -1372,6 +1373,27 @@ Constructs an AlterColumnOption instance. } void ZVecPyParams::bind_vector_query(py::module_ &m) { + // Bind SubVectorQuery (used by MultiVectorQuery) + py::class_(m, "_SubVectorQuery") + .def(py::init<>()) + .def_readwrite("num_candidates", &SubVectorQuery::num_candidates_) + .def_readwrite("field_name", &SubVectorQuery::field_name_) + .def_readwrite("query_params", &SubVectorQuery::query_params_) + .def_static( + "from_vector_query", + [](const VectorQuery &vq) { + SubVectorQuery sub; + sub.num_candidates_ = vq.topk_; + sub.field_name_ = vq.field_name_; + sub.query_vector_ = vq.query_vector_; + sub.query_sparse_indices_ = vq.query_sparse_indices_; + sub.query_sparse_values_ = vq.query_sparse_values_; + sub.query_params_ = vq.query_params_; + return sub; + }, + py::arg("vector_query"), + "Create a SubVectorQuery from a VectorQuery"); + py::class_(m, "_VectorQuery") .def(py::init<>()) // properties diff --git a/src/binding/python/model/python_collection.cc b/src/binding/python/model/python_collection.cc index ae2ac572f..671a26d02 100644 --- a/src/binding/python/model/python_collection.cc +++ b/src/binding/python/model/python_collection.cc @@ -292,7 +292,16 @@ void ZVecPyCollection::bind_dql_methods( "given vector column. One of 'mmap', 'buffer_pool', 'contiguous'. 
" "Raises KeyError if no HNSW index exists on the column, or " "ValueError if the column's index is not an HNSW index. Intended " - "for introspection and testing only; not part of the stable API."); + "for introspection and testing only; not part of the stable API.") + // MultiQuery: multi-vector query with optional reranker + .def( + "MultiQuery", + [](Collection &self, const MultiVectorQuery &query) { + const auto result = self.MultiQuery(query); + return unwrap_expected(result); + }, + py::arg("query"), + "Execute a multi-vector query with optional re-ranking."); } } // namespace zvec \ No newline at end of file diff --git a/src/binding/python/model/python_reranker.cc b/src/binding/python/model/python_reranker.cc new file mode 100644 index 000000000..27906f8b9 --- /dev/null +++ b/src/binding/python/model/python_reranker.cc @@ -0,0 +1,56 @@ +// Copyright 2025-present the zvec project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "python_reranker.h" +#include +#include +#include + +namespace zvec { + +void ZVecPyReranker::Initialize(py::module_ &m) { + // Bind Reranker base class (abstract, cannot be instantiated directly) + py::class_(m, "_Reranker") + .def_property_readonly("topn", &Reranker::topn); + + // Bind RrfReRanker + py::class_>( + m, "_RrfReRanker") + .def(py::init(), py::arg("topn") = 10, + py::arg("rank_constant") = 60) + .def_property_readonly("topn", &RrfReRanker::topn) + .def_property_readonly("rank_constant", &RrfReRanker::rank_constant); + + // Bind WeightedReRanker + py::class_>( + m, "_WeightedReRanker") + .def(py::init>(), + py::arg("topn") = 10, py::arg("metric") = MetricType::L2, + py::arg("weights") = std::map{}) + .def_property_readonly("topn", &WeightedReRanker::topn) + .def_property_readonly("metric", &WeightedReRanker::metric) + .def_property_readonly("weights", &WeightedReRanker::weights); + + // Bind MultiVectorQuery struct + py::class_(m, "_MultiVectorQuery") + .def(py::init<>()) + .def_readwrite("queries", &MultiVectorQuery::queries) + .def_readwrite("topk", &MultiVectorQuery::topk) + .def_readwrite("filter", &MultiVectorQuery::filter) + .def_readwrite("include_vector", &MultiVectorQuery::include_vector) + .def_readwrite("output_fields", &MultiVectorQuery::output_fields) + .def_readwrite("reranker", &MultiVectorQuery::reranker); +} + +} // namespace zvec diff --git a/src/db/collection.cc b/src/db/collection.cc index 4e9fa2275..b0feeebcc 100644 --- a/src/db/collection.cc +++ b/src/db/collection.cc @@ -17,6 +17,7 @@ #include #include #include +#include #include #include #include @@ -29,6 +30,7 @@ #include #include #include +#include #include #include #include "db/common/constants.h" @@ -117,6 +119,8 @@ class CollectionImpl : public Collection { Result Query(const VectorQuery &query) const override; + Result MultiQuery(const MultiVectorQuery &query) const override; + Result GroupByQuery( const GroupByVectorQuery &query) const override; @@ -1594,6 
+1598,118 @@ Result CollectionImpl::Query(const VectorQuery &query) const { return sql_engine_->execute(schema_, sanitized, segments); }
+// Runs every sub-query of `query` one after another over all segments and
+// fuses the per-field result lists with `query.reranker`.
+//
+// Returns InvalidArgument when fewer than two sub-queries are given, when no
+// reranker is set, when a field name is unknown or repeated, or when a
+// WeightedReRanker's metric differs from the metric of a queried field's
+// index. Returns an empty list when the collection has no segments. The
+// first failing sub-query aborts the call and its error is returned.
+Result CollectionImpl::MultiQuery(
+    const MultiVectorQuery &query) const {
+  std::shared_lock lock(schema_handle_mtx_);
+
+  CHECK_DESTROY_RETURN_STATUS_EXPECTED(destroyed_, false);
+
+  if (query.queries.size() < 2) {
+    return tl::make_unexpected(
+        Status::InvalidArgument("MultiQuery requires at least 2 sub-queries"));
+  }
+
+  if (!query.reranker) {
+    return tl::make_unexpected(
+        Status::InvalidArgument("Reranker is required for multi-vector query"));
+  }
+
+  // Use query.topk as reranker's topn
+  // NOTE(review): set_topn() writes to the reranker object that
+  // query.reranker shares with the caller, and only a shared lock is held
+  // here, so concurrent MultiQuery calls reusing one reranker race on topn.
+  // The caller's reranker also keeps the overwritten topn afterwards.
+  query.reranker->set_topn(query.topk);
+
+  // If WeightedReRanker, verify metric consistency with field schemas
+  // (the cast yields nullptr for every other reranker type, skipping this).
+  auto *weighted = dynamic_cast(query.reranker.get());
+  if (weighted) {
+    for (const auto &sub : query.queries) {
+      auto *field_schema = schema_->get_vector_field(sub.field_name_);
+      if (!field_schema) {
+        return tl::make_unexpected(Status::InvalidArgument(
+            "Vector field not found: ", sub.field_name_));
+      }
+      // A failed cast leaves vec_params null; such fields pass unchecked.
+      auto *vec_params = dynamic_cast(
+          field_schema->index_params().get());
+      if (vec_params && vec_params->metric_type() != weighted->metric()) {
+        return tl::make_unexpected(Status::InvalidArgument(
+            "WeightedReRanker metric mismatch for field: ", sub.field_name_,
+            ". Reranker metric: ",
+            std::to_string(static_cast(weighted->metric())),
+            ", field metric: ",
+            std::to_string(static_cast(vec_params->metric_type()))));
+      }
+    }
+  }
+
+  // Convert SubVectorQuery to VectorQuery and validate
+  std::set seen_fields;
+  std::vector converted_queries;
+  converted_queries.reserve(query.queries.size());
+
+  for (const auto &sub : query.queries) {
+    if (seen_fields.count(sub.field_name_)) {
+      return tl::make_unexpected(Status::InvalidArgument(
+          "Duplicate field name in multi-vector query: ", sub.field_name_));
+    }
+    seen_fields.insert(sub.field_name_);
+    auto *field_schema = schema_->get_vector_field(sub.field_name_);
+    if (!field_schema) {
+      return tl::make_unexpected(
+          Status::InvalidArgument("Vector field not found: ", sub.field_name_));
+    }
+
+    VectorQuery vq;
+    vq.topk_ = sub.num_candidates_;
+    vq.field_name_ = sub.field_name_;
+    vq.query_vector_ = sub.query_vector_;
+    vq.query_sparse_indices_ = sub.query_sparse_indices_;
+    vq.query_sparse_values_ = sub.query_sparse_values_;
+    vq.query_params_ = sub.query_params_;
+    // Query-level options are copied unchanged into every sub-query.
+    vq.filter_ = query.filter;
+    vq.include_vector_ = query.include_vector;
+    vq.include_doc_id_ = query.include_doc_id_;
+    vq.output_fields_ = query.output_fields;
+
+    auto s = vq.validate_and_sanitize(field_schema);
+    CHECK_RETURN_STATUS_EXPECTED(s);
+    converted_queries.push_back(std::move(vq));
+  }
+
+  auto segments = get_all_segments();
+  if (segments.empty()) {
+    // No segments to search: return an empty list without reranking.
+    return DocPtrList();
+  }
+
+  // Execute each VectorQuery and collect results per field
+  // (keyed by field name; duplicate fields were rejected above).
+  std::map query_results;
+
+  for (const auto &vq : converted_queries) {
+    auto result = sql_engine_->execute(schema_, vq, segments);
+    if (!result.has_value()) {
+      return tl::make_unexpected(result.error());
+    }
+    query_results[vq.field_name_] = std::move(result.value());
+  }
+
+  // Merge and rerank results
+  return query.reranker->rerank(query_results);
+}
+
 Result CollectionImpl::GroupByQuery( const GroupByVectorQuery &query) const { std::shared_lock
lock(schema_handle_mtx_); diff --git a/src/db/index/common/doc.cc b/src/db/index/common/doc.cc index 0405eac1d..fe8be100e 100644 --- a/src/db/index/common/doc.cc +++ b/src/db/index/common/doc.cc @@ -22,6 +22,7 @@ #include #include #include +#include #include "db/common/constants.h" #include "db/index/common/type_helper.h"
diff --git a/src/db/reranker/reranker.cc b/src/db/reranker/reranker.cc
new file mode 100644
index 000000000..1d5c540a7
--- /dev/null
+++ b/src/db/reranker/reranker.cc
@@ -0,0 +1,133 @@
+// Copyright 2025-present the zvec project
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Defined before any include so that <cmath> exposes M_PI (needed by MSVC).
+#define _USE_MATH_DEFINES
+#include <zvec/db/reranker.h>
+#include <algorithm>
+#include <cmath>
+#include <map>
+#include <memory>
+#include <stdexcept>
+#include <string>
+#include <unordered_map>
+#include <utility>
+#include <vector>
+
+namespace zvec {
+
+namespace {
+
+using ScoredId = std::pair<std::string, double>;
+using ScoreMap = std::unordered_map<std::string, double>;
+using DocMap = std::unordered_map<std::string, DocPtrList::value_type>;
+
+// Orders candidates best-first: higher score wins, equal scores fall back to
+// ascending doc id so the result does not depend on hash-map iteration order.
+bool ranks_before(const ScoredId &a, const ScoredId &b) {
+  if (a.second != b.second) {
+    return a.second > b.second;
+  }
+  return a.first < b.first;
+}
+
+// Returns copies of the `topn` best-scoring documents in descending score
+// order, each carrying its fused score. A non-positive `topn` yields an empty
+// list.
+DocPtrList collect_topn(const ScoreMap &scores, const DocMap &id_to_doc,
+                        int topn) {
+  DocPtrList results;
+  if (topn <= 0 || scores.empty()) {
+    return results;
+  }
+
+  std::vector<ScoredId> ranked(scores.begin(), scores.end());
+  const size_t keep = std::min(ranked.size(), static_cast<size_t>(topn));
+  std::partial_sort(ranked.begin(), ranked.begin() + keep, ranked.end(),
+                    ranks_before);
+
+  results.reserve(keep);
+  for (size_t i = 0; i < keep; ++i) {
+    const auto &[doc_id, score] = ranked[i];
+    // Copy the document so the per-field result lists keep their own scores.
+    auto doc = std::make_shared<Doc>(*id_to_doc.at(doc_id));
+    doc->set_score(static_cast<float>(score));
+    results.push_back(std::move(doc));
+  }
+  return results;
+}
+
+}  // namespace
+
+// ==================== RrfReRanker ====================
+
+DocPtrList RrfReRanker::rerank(
+    const std::map<std::string, DocPtrList> &query_results) const {
+  ScoreMap rrf_scores;  // doc_id -> cumulative RRF score
+  DocMap id_to_doc;     // doc_id -> first-seen Doc pointer
+
+  for (const auto &[field_name, docs] : query_results) {
+    for (size_t rank = 0; rank < docs.size(); ++rank) {
+      const auto &doc = docs[rank];
+      const std::string &doc_id = doc->pk();
+      // rank is 0-based, hence the +1: score = 1 / (k + rank + 1).
+      rrf_scores[doc_id] += 1.0 / (static_cast<double>(rank_constant_) +
+                                   static_cast<double>(rank) + 1.0);
+      id_to_doc.emplace(doc_id, doc);  // keeps the first-seen entry
+    }
+  }
+
+  return collect_topn(rrf_scores, id_to_doc, topn_);
+}
+
+// ==================== WeightedReRanker ====================
+
+WeightedReRanker::WeightedReRanker(
+    int topn, MetricType metric, const std::map<std::string, double> &weights)
+    : Reranker(topn), metric_(metric), weights_(weights) {}
+
+double WeightedReRanker::normalize_score(double score, MetricType metric) {
+  switch (metric) {
+    case MetricType::L2:
+      return 1.0 - 2.0 * std::atan(score) / M_PI;
+    case MetricType::IP:
+      return 0.5 + std::atan(score) / M_PI;
+    case MetricType::COSINE:
+      return 1.0 - score / 2.0;
+    default:
+      throw std::invalid_argument("Unsupported metric type for normalization");
+  }
+}
+
+DocPtrList WeightedReRanker::rerank(
+    const std::map<std::string, DocPtrList> &query_results) const {
+  ScoreMap weighted_scores;  // doc_id -> cumulative weighted score
+  DocMap id_to_doc;          // doc_id -> first-seen Doc pointer
+
+  for (const auto &[vector_name, docs] : query_results) {
+    // Fields without an explicit weight contribute with weight 1.0.
+    const auto it = weights_.find(vector_name);
+    const double weight = it != weights_.end() ? it->second : 1.0;
+    for (const auto &doc : docs) {
+      const std::string &doc_id = doc->pk();
+      weighted_scores[doc_id] +=
+          normalize_score(static_cast<double>(doc->score()), metric_) * weight;
+      id_to_doc.emplace(doc_id, doc);  // keeps the first-seen entry
+    }
+  }
+
+  return collect_topn(weighted_scores, id_to_doc, topn_);
+}
+
+}  // namespace zvec
diff --git a/src/db/sqlengine/parser/sql_info_helper.h b/src/db/sqlengine/parser/sql_info_helper.h index 465ccdce2..760dbc4e3 100644 --- a/src/db/sqlengine/parser/sql_info_helper.h +++
b/src/db/sqlengine/parser/sql_info_helper.h @@ -14,7 +14,7 @@ #pragma once -#include +#include #include "db/sqlengine/common/group_by.h" #include "db/sqlengine/parser/node.h" #include "db/sqlengine/parser/sql_info.h" diff --git a/src/db/sqlengine/sqlengine.h b/src/db/sqlengine/sqlengine.h index d86fd69bf..47143b60f 100644 --- a/src/db/sqlengine/sqlengine.h +++ b/src/db/sqlengine/sqlengine.h @@ -14,7 +14,7 @@ #pragma once -#include +#include #include #include "db/common/profiler.h" #include "db/index/segment/segment.h" diff --git a/src/db/sqlengine/sqlengine_impl.h b/src/db/sqlengine/sqlengine_impl.h index 88c279283..5e258346b 100644 --- a/src/db/sqlengine/sqlengine_impl.h +++ b/src/db/sqlengine/sqlengine_impl.h @@ -18,7 +18,7 @@ #include #include #include -#include +#include #include #include "analyzer/query_info.h" #include "common/group_by.h" diff --git a/src/include/zvec/c_api.h b/src/include/zvec/c_api.h index c64190d50..e7bb64aa6 100644 --- a/src/include/zvec/c_api.h +++ b/src/include/zvec/c_api.h @@ -1032,6 +1032,30 @@ typedef struct zvec_vector_query_t zvec_vector_query_t; */ typedef struct zvec_group_by_vector_query_t zvec_group_by_vector_query_t; +/** + * @brief Reranker structure (opaque pointer) + * Aligned with zvec::Reranker + * Use zvec_reranker_create_rrf() or zvec_reranker_create_weighted() to create + * and zvec_reranker_destroy() to destroy + */ +typedef struct zvec_reranker_t zvec_reranker_t; + +/** + * @brief Multi-vector query structure (opaque pointer) + * Aligned with zvec::MultiVectorQuery + * Use zvec_multi_vector_query_create() to create and + * zvec_multi_vector_query_destroy() to destroy + */ +typedef struct zvec_multi_vector_query_t zvec_multi_vector_query_t; + +/** + * @brief Sub-vector query structure for multi-vector queries (opaque pointer) + * Aligned with zvec::SubVectorQuery + * Use zvec_multi_vector_sub_query_create() to create and + * zvec_multi_vector_sub_query_destroy() to destroy + */ +typedef struct 
zvec_multi_vector_sub_query_t zvec_multi_vector_sub_query_t; + // ============================================================================= // Query Parameters Management Functions @@ -1704,6 +1728,300 @@ ZVEC_EXPORT zvec_error_code_t ZVEC_CALL zvec_group_by_vector_query_set_flat_params( zvec_group_by_vector_query_t *query, zvec_flat_query_params_t *flat_params); +// ----------------------------------------------------------------------------- +// zvec_reranker_t (Reranker) +// ----------------------------------------------------------------------------- + +/** + * @brief Create an RRF (Reciprocal Rank Fusion) reranker + * @param topn Maximum number of results to return after re-ranking + * @param rank_constant RRF rank constant (default: 60) + * @return zvec_reranker_t* Pointer to the newly created reranker + */ +ZVEC_EXPORT zvec_reranker_t *ZVEC_CALL +zvec_reranker_create_rrf(int topn, int rank_constant); + +/** + * @brief Create a Weighted reranker + * @param topn Maximum number of results to return after re-ranking + * @param metric Metric type: 0=L2, 1=IP, 2=COSINE + * @param fields Array of weight_count field names + * @param weights Array of weights; weights[i] applies to fields[i] 
+ * @param weight_count Number of entries in both fields and weights + * @return zvec_reranker_t* Pointer to the newly created reranker + */ +ZVEC_EXPORT zvec_reranker_t *ZVEC_CALL +zvec_reranker_create_weighted(int topn, int metric, const char **fields, + const double *weights, size_t weight_count); + +/** + * @brief Destroy reranker + * @param reranker Reranker pointer + */ +ZVEC_EXPORT void ZVEC_CALL zvec_reranker_destroy(zvec_reranker_t *reranker); + +/** + * @brief Get reranker topn + * @param reranker Reranker pointer + * @return int TopN value + */ +ZVEC_EXPORT int ZVEC_CALL +zvec_reranker_get_topn(const zvec_reranker_t *reranker); + +/** + * @brief Get RRF rank constant (only valid for RRF reranker) + * @param reranker Reranker pointer + * @return int Rank constant, or -1 if not an RRF reranker + */ +ZVEC_EXPORT int ZVEC_CALL +zvec_reranker_get_rank_constant(const zvec_reranker_t *reranker); + +// ----------------------------------------------------------------------------- +// zvec_multi_vector_query_t (Multi-Vector Query) +// ----------------------------------------------------------------------------- + +/** + * @brief Create multi-vector query + * @return zvec_multi_vector_query_t* Pointer to the newly created multi-vector + * query + */ +ZVEC_EXPORT zvec_multi_vector_query_t *ZVEC_CALL +zvec_multi_vector_query_create(void); + +/** + * @brief Destroy multi-vector query + * @param query Multi-vector query pointer + */ +ZVEC_EXPORT void ZVEC_CALL +zvec_multi_vector_query_destroy(zvec_multi_vector_query_t *query); + +/** + * @brief Add a sub-vector query to the multi-vector query + * @param query Multi-vector query pointer + * @param sub_query Sub-vector query to add (copied, caller retains ownership) + * @return zvec_error_code_t Error code + */ +ZVEC_EXPORT zvec_error_code_t ZVEC_CALL zvec_multi_vector_query_add_query( + zvec_multi_vector_query_t *query, + const zvec_multi_vector_sub_query_t *sub_query); + +/** + * @brief Get number of vector queries + * @param 
query Multi-vector query pointer + * @return size_t Number of vector queries + */ +ZVEC_EXPORT size_t ZVEC_CALL +zvec_multi_vector_query_get_query_count(const zvec_multi_vector_query_t *query); + +/** + * @brief Set topk + * @param query Multi-vector query pointer + * @param topk Number of results + * @return zvec_error_code_t Error code + */ +ZVEC_EXPORT zvec_error_code_t ZVEC_CALL +zvec_multi_vector_query_set_topk(zvec_multi_vector_query_t *query, int topk); + +/** + * @brief Get topk + * @param query Multi-vector query pointer + * @return int Number of results + */ +ZVEC_EXPORT int ZVEC_CALL +zvec_multi_vector_query_get_topk(const zvec_multi_vector_query_t *query); + +/** + * @brief Set filter expression + * @param query Multi-vector query pointer + * @param filter Filter expression string + * @return zvec_error_code_t Error code + */ +ZVEC_EXPORT zvec_error_code_t ZVEC_CALL zvec_multi_vector_query_set_filter( + zvec_multi_vector_query_t *query, const char *filter); + +/** + * @brief Get filter expression + * @param query Multi-vector query pointer + * @return const char* Filter expression (owned by query, do not free) + */ +ZVEC_EXPORT const char *ZVEC_CALL +zvec_multi_vector_query_get_filter(const zvec_multi_vector_query_t *query); + +/** + * @brief Set whether to include vector data in results + * @param query Multi-vector query pointer + * @param include Whether to include vectors + * @return zvec_error_code_t Error code + */ +ZVEC_EXPORT zvec_error_code_t ZVEC_CALL +zvec_multi_vector_query_set_include_vector(zvec_multi_vector_query_t *query, + bool include); + +/** + * @brief Get whether to include vector data in results + * @param query Multi-vector query pointer + * @return bool Whether to include vectors + */ +ZVEC_EXPORT bool ZVEC_CALL zvec_multi_vector_query_get_include_vector( + const zvec_multi_vector_query_t *query); + +/** + * @brief Set output fields + * @param query Multi-vector query pointer + * @param fields Array of field names + * @param 
count Number of fields + * @return zvec_error_code_t Error code + */ +ZVEC_EXPORT zvec_error_code_t ZVEC_CALL +zvec_multi_vector_query_set_output_fields(zvec_multi_vector_query_t *query, + const char **fields, size_t count); + +/** + * @brief Get output fields + * @param query Multi-vector query pointer + * @param[out] fields Output array of field names (allocated by library) + * @param[out] count Number of fields + * @return zvec_error_code_t Error code + * + * @note The returned array is allocated by the library and should be freed + * using zvec_free() when no longer needed. The individual string pointers + * are owned by the query and must NOT be freed. + */ +ZVEC_EXPORT zvec_error_code_t ZVEC_CALL +zvec_multi_vector_query_get_output_fields(zvec_multi_vector_query_t *query, + const char ***fields, size_t *count); + +/** + * @brief Set reranker (copies shared pointer, caller must still destroy + * reranker) + * @param query Multi-vector query pointer + * @param reranker Reranker pointer (remains valid, caller must call + * zvec_reranker_destroy after use) + * @return zvec_error_code_t Error code + */ +ZVEC_EXPORT zvec_error_code_t ZVEC_CALL zvec_multi_vector_query_set_reranker( + zvec_multi_vector_query_t *query, zvec_reranker_t *reranker); + +// ----------------------------------------------------------------------------- +// zvec_multi_vector_sub_query_t (Sub-Vector Query for Multi-Vector Queries) +// ----------------------------------------------------------------------------- + +/** + * @brief Create sub-vector query + * @return zvec_multi_vector_sub_query_t* Pointer to the newly created + * sub-vector query + */ +ZVEC_EXPORT zvec_multi_vector_sub_query_t *ZVEC_CALL +zvec_multi_vector_sub_query_create(void); + +/** + * @brief Destroy sub-vector query + * @param query Sub-vector query pointer + */ +ZVEC_EXPORT void ZVEC_CALL +zvec_multi_vector_sub_query_destroy(zvec_multi_vector_sub_query_t *query); + +/** + * @brief Set number of candidates to retrieve per 
field + * @param query Sub-vector query pointer + * @param num_candidates Number of candidates + * @return zvec_error_code_t Error code + */ +ZVEC_EXPORT zvec_error_code_t ZVEC_CALL +zvec_multi_vector_sub_query_set_num_candidates( + zvec_multi_vector_sub_query_t *query, int num_candidates); + +/** + * @brief Get number of candidates + * @param query Sub-vector query pointer + * @return int Number of candidates + */ +ZVEC_EXPORT int ZVEC_CALL zvec_multi_vector_sub_query_get_num_candidates( + const zvec_multi_vector_sub_query_t *query); + +/** + * @brief Set field name + * @param query Sub-vector query pointer + * @param field_name Field name + * @return zvec_error_code_t Error code + */ +ZVEC_EXPORT zvec_error_code_t ZVEC_CALL +zvec_multi_vector_sub_query_set_field_name(zvec_multi_vector_sub_query_t *query, + const char *field_name); + +/** + * @brief Get field name + * @param query Sub-vector query pointer + * @return const char* Field name (owned by query, do not free) + */ +ZVEC_EXPORT const char *ZVEC_CALL zvec_multi_vector_sub_query_get_field_name( + const zvec_multi_vector_sub_query_t *query); + +/** + * @brief Set query vector data + * @param query Sub-vector query pointer + * @param data Vector data pointer + * @param size Data size in bytes + * @return zvec_error_code_t Error code + */ +ZVEC_EXPORT zvec_error_code_t ZVEC_CALL +zvec_multi_vector_sub_query_set_query_vector( + zvec_multi_vector_sub_query_t *query, const void *data, size_t size); + +/** + * @brief Set sparse vector indices + * @param query Sub-vector query pointer + * @param indices Array of uint32_t indices + * @param count Number of indices + * @return zvec_error_code_t Error code + */ +ZVEC_EXPORT zvec_error_code_t ZVEC_CALL +zvec_multi_vector_sub_query_set_sparse_indices( + zvec_multi_vector_sub_query_t *query, const uint32_t *indices, + size_t count); + +/** + * @brief Set sparse vector values + * @param query Sub-vector query pointer + * @param values Array of float values + * @param 
count Number of values + * @return zvec_error_code_t Error code + */ +ZVEC_EXPORT zvec_error_code_t ZVEC_CALL +zvec_multi_vector_sub_query_set_sparse_values( + zvec_multi_vector_sub_query_t *query, const float *values, size_t count); + +/** + * @brief Set HNSW query parameters (takes ownership) + * @param query Sub-vector query pointer + * @param hnsw_params HNSW query parameters pointer + * @return zvec_error_code_t Error code + */ +ZVEC_EXPORT zvec_error_code_t ZVEC_CALL +zvec_multi_vector_sub_query_set_hnsw_params( + zvec_multi_vector_sub_query_t *query, + zvec_hnsw_query_params_t *hnsw_params); + +/** + * @brief Set IVF query parameters (takes ownership) + * @param query Sub-vector query pointer + * @param ivf_params IVF query parameters pointer + * @return zvec_error_code_t Error code + */ +ZVEC_EXPORT zvec_error_code_t ZVEC_CALL +zvec_multi_vector_sub_query_set_ivf_params(zvec_multi_vector_sub_query_t *query, + zvec_ivf_query_params_t *ivf_params); + +/** + * @brief Set Flat query parameters (takes ownership) + * @param query Sub-vector query pointer + * @param flat_params Flat query parameters pointer + * @return zvec_error_code_t Error code + */ +ZVEC_EXPORT zvec_error_code_t ZVEC_CALL +zvec_multi_vector_sub_query_set_flat_params( + zvec_multi_vector_sub_query_t *query, + zvec_flat_query_params_t *flat_params); // ============================================================================= // Collection Options and Statistics (Opaque Pointer Pattern) // ============================================================================= @@ -2645,6 +2962,19 @@ ZVEC_EXPORT zvec_error_code_t ZVEC_CALL zvec_collection_query( const zvec_collection_t *collection, const zvec_vector_query_t *query, zvec_doc_t ***results, size_t *result_count); +/** + * @brief Multi-vector similarity search with re-ranking + * @param collection Collection handle + * @param query Multi-vector query parameters pointer + * @param[out] results Returned document array (needs to be freed by 
calling + * zvec_docs_free) + * @param[out] result_count Number of returned results + * @return zvec_error_code_t Error code + */ +ZVEC_EXPORT zvec_error_code_t ZVEC_CALL zvec_collection_multi_query( + const zvec_collection_t *collection, const zvec_multi_vector_query_t *query, + zvec_doc_t ***results, size_t *result_count); + /** * @brief Fetch documents by primary keys * @param collection Collection handle diff --git a/src/include/zvec/db/collection.h b/src/include/zvec/db/collection.h index 010ba36fa..35fe35f81 100644 --- a/src/include/zvec/db/collection.h +++ b/src/include/zvec/db/collection.h @@ -16,8 +16,8 @@ #include #include #include -#include #include +#include #include #include @@ -98,6 +98,19 @@ class Collection { virtual Result Query(const VectorQuery &query) const = 0; + /** + * @brief Execute a multi-vector query and re-rank the combined results. + * + * Runs the sub-queries sequentially, then combines and re-ranks their + * results using the query's reranker. Returns an error if the query has + * fewer than two sub-queries or no reranker. + * + * @param query The multi-vector query specification. + * @return Combined and re-ranked document list OR an error. 
+ */ + virtual Result MultiQuery( + const MultiVectorQuery &query) const = 0; + virtual Result GroupByQuery( const GroupByVectorQuery &query) const = 0; diff --git a/src/include/zvec/db/doc.h b/src/include/zvec/db/doc.h index f702a43c3..3dbe9a7c9 100644 --- a/src/include/zvec/db/doc.h +++ b/src/include/zvec/db/doc.h @@ -20,7 +20,6 @@ #include #include #include -#include #include #include #include @@ -364,44 +363,4 @@ using DocPtrMap = std::unordered_map; using WriteResults = std::vector; -struct VectorQuery { - int topk_; - std::string field_name_; - std::string query_vector_; // fp16, void * - std::string query_sparse_indices_; - std::string query_sparse_values_; - std::string filter_; - bool include_vector_{false}; - bool include_doc_id_{false}; - // select * by default, select no field if output_fields_ is empty, select - // specific fields if output_fields_ is not empty - std::optional> output_fields_; - QueryParams::Ptr query_params_; - - Status validate_and_sanitize(const FieldSchema *schema); -}; - -struct GroupByVectorQuery { - std::string field_name_; - std::string query_vector_; - std::string query_sparse_indices_; - std::string query_sparse_values_; - std::string filter_; - bool include_vector_; - // select * by default, select no field if output_fields_ is empty, select - // specific fields if output_fields_ is not empty - std::optional> output_fields_; - std::string group_by_field_name_; - uint32_t group_count_ = 2; - uint32_t group_topk_ = 3; - QueryParams::Ptr query_params_; -}; - -struct GroupResult { - std::string group_by_value_; - std::vector docs_; -}; - -using GroupResults = std::vector; - } // namespace zvec diff --git a/src/include/zvec/db/query.h b/src/include/zvec/db/query.h new file mode 100644 index 000000000..f79f71db0 --- /dev/null +++ b/src/include/zvec/db/query.h @@ -0,0 +1,88 @@ +// Copyright 2025-present the zvec project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in 
compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#pragma once + +#include +#include +#include +#include +#include +#include +#include + +namespace zvec { + +struct VectorQuery { + int topk_; + std::string field_name_; + std::string query_vector_; // fp16, void * + std::string query_sparse_indices_; + std::string query_sparse_values_; + std::string filter_; + bool include_vector_{false}; + bool include_doc_id_{false}; + // select * by default, select no field if output_fields_ is empty, select + // specific fields if output_fields_ is not empty + std::optional> output_fields_; + QueryParams::Ptr query_params_; + + Status validate_and_sanitize(const FieldSchema *schema); +}; + +struct GroupByVectorQuery { + std::string field_name_; + std::string query_vector_; + std::string query_sparse_indices_; + std::string query_sparse_values_; + std::string filter_; + bool include_vector_; + // select * by default, select no field if output_fields_ is empty, select + // specific fields if output_fields_ is not empty + std::optional> output_fields_; + std::string group_by_field_name_; + uint32_t group_count_ = 2; + uint32_t group_topk_ = 3; + QueryParams::Ptr query_params_; +}; + +//! Multi-vector query structure for querying multiple vector fields +//! with optional re-ranking of combined results. 
+ +struct SubVectorQuery { + int num_candidates_; + std::string field_name_; + std::string query_vector_; + std::string query_sparse_indices_; + std::string query_sparse_values_; + QueryParams::Ptr query_params_; +}; + +struct MultiVectorQuery { + std::vector queries; + int topk{10}; + std::string filter; + bool include_vector{false}; + bool include_doc_id_{false}; + std::optional> output_fields; + std::shared_ptr reranker{nullptr}; +}; + +struct GroupResult { + std::string group_by_value_; + std::vector docs_; +}; + +using GroupResults = std::vector; + +} // namespace zvec diff --git a/src/include/zvec/db/reranker.h b/src/include/zvec/db/reranker.h new file mode 100644 index 000000000..e53066f98 --- /dev/null +++ b/src/include/zvec/db/reranker.h @@ -0,0 +1,122 @@ +// Copyright 2025-present the zvec project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#pragma once + +#include +#include +#include +#include +#include +#include + +namespace zvec { + +//! Reranker abstract base class for re-ranking search results +class Reranker { + public: + using Ptr = std::shared_ptr; + + explicit Reranker(int topn = 10) : topn_(topn) {} + virtual ~Reranker() = default; + + int topn() const { + return topn_; + } + void set_topn(int topn) { + topn_ = topn; + } + + //! Re-rank documents from one or more vector queries. + //! \param query_results Mapping from vector field name to list of retrieved + //! documents (sorted by relevance). + //! 
\return Re-ranked list of documents (length <= topn), with updated scores. + virtual DocPtrList rerank( + const std::map &query_results) const = 0; + + protected: + int topn_; +}; + +//! Re-ranker using Reciprocal Rank Fusion (RRF) for multi-vector search. +//! +//! RRF combines results from multiple vector queries without requiring +//! relevance scores. The RRF score for a document at rank r is: +//! score = 1 / (k + r + 1) +//! where k is the rank constant. +class RrfReRanker : public Reranker { + public: + RrfReRanker(int topn = 10, int rank_constant = 60) + : Reranker(topn), rank_constant_(rank_constant) {} + + int rank_constant() const { + return rank_constant_; + } + + DocPtrList rerank( + const std::map &query_results) const override; + + private: + int rank_constant_; +}; + +//! Re-ranker that combines scores from multiple vector fields using weights. +//! +//! Each vector field's relevance score is normalized based on its metric type, +//! then scaled by a user-provided weight. Final scores are summed across +//! fields. Supported metrics: L2, IP, COSINE. +class WeightedReRanker : public Reranker { + public: + WeightedReRanker(int topn = 10, MetricType metric = MetricType::L2, + const std::map &weights = {}); + + MetricType metric() const { + return metric_; + } + const std::map &weights() const { + return weights_; + } + + DocPtrList rerank( + const std::map &query_results) const override; + + //! Normalize a raw distance/similarity score to [0, 1] range + static double normalize_score(double score, MetricType metric); + + private: + MetricType metric_; + std::map weights_; +}; + +//! Callback-based re-ranker for cross-language bridging. +//! +//! Wraps a user-provided callback (e.g., a Python callable) as a Reranker. +//! When the callback is a Python function, GIL must be managed by the caller. 
+class CallbackReRanker : public Reranker { + public: + using Callback = + std::function &)>; + + CallbackReRanker(Callback fn, int topn = 10) + : Reranker(topn), callback_(std::move(fn)) {} + + DocPtrList rerank( + const std::map &query_results) const override { + return callback_(query_results); + } + + private: + Callback callback_; +}; + +} // namespace zvec diff --git a/tests/c/c_api_test.c b/tests/c/c_api_test.c index 4f38d6912..649e1c4c4 100644 --- a/tests/c/c_api_test.c +++ b/tests/c/c_api_test.c @@ -4126,6 +4126,235 @@ void test_actual_vector_queries(void) { TEST_END(); } +void test_reranker_functions(void) { + TEST_START(); + + // Test 1: Create RRF reranker + zvec_reranker_t *rrf = zvec_reranker_create_rrf(10, 60); + TEST_ASSERT(rrf != NULL); + if (rrf) { + TEST_ASSERT(zvec_reranker_get_topn(rrf) == 10); + TEST_ASSERT(zvec_reranker_get_rank_constant(rrf) == 60); + zvec_reranker_destroy(rrf); + } + + // Test 2: Create RRF reranker with different params + zvec_reranker_t *rrf2 = zvec_reranker_create_rrf(5, 100); + TEST_ASSERT(rrf2 != NULL); + if (rrf2) { + TEST_ASSERT(zvec_reranker_get_topn(rrf2) == 5); + TEST_ASSERT(zvec_reranker_get_rank_constant(rrf2) == 100); + zvec_reranker_destroy(rrf2); + } + + // Test 3: Create Weighted reranker + const char *fields[] = {"embedding1", "embedding2"}; + double weights[] = {0.7, 0.3}; + zvec_reranker_t *weighted = + zvec_reranker_create_weighted(10, 0, fields, weights, 2); + TEST_ASSERT(weighted != NULL); + if (weighted) { + TEST_ASSERT(zvec_reranker_get_topn(weighted) == 10); + TEST_ASSERT(zvec_reranker_get_rank_constant(weighted) == -1); + zvec_reranker_destroy(weighted); + } + + // Test 4: Create Weighted reranker with no weights + zvec_reranker_t *weighted2 = + zvec_reranker_create_weighted(20, 2, NULL, NULL, 0); + TEST_ASSERT(weighted2 != NULL); + if (weighted2) { + TEST_ASSERT(zvec_reranker_get_topn(weighted2) == 20); + zvec_reranker_destroy(weighted2); + } + + // Test 5: NULL reranker operations + 
TEST_ASSERT(zvec_reranker_get_topn(NULL) == 0); + TEST_ASSERT(zvec_reranker_get_rank_constant(NULL) == -1); + zvec_reranker_destroy(NULL); // Should not crash + + TEST_END(); +} + +void test_multi_vector_query_with_reranker(void) { + TEST_START(); + + char temp_dir[] = "./zvec_test_multi_query_reranker"; + + // Create schema with two vector fields + zvec_collection_schema_t *schema = + zvec_collection_schema_create("multi_query_test"); + TEST_ASSERT(schema != NULL); + + if (schema) { + // Add ID field + zvec_field_schema_t *id_field = + zvec_field_schema_create("id", ZVEC_DATA_TYPE_INT64, false, 0); + zvec_collection_schema_add_field(schema, id_field); + + // Add first vector field (embedding1) with HNSW index + zvec_index_params_t *hnsw1 = zvec_index_params_create(ZVEC_INDEX_TYPE_HNSW); + zvec_index_params_set_metric_type(hnsw1, ZVEC_METRIC_TYPE_L2); + zvec_index_params_set_hnsw_params(hnsw1, 16, 100); + zvec_field_schema_t *vec1 = zvec_field_schema_create( + "embedding1", ZVEC_DATA_TYPE_VECTOR_FP32, false, 4); + zvec_field_schema_set_index_params(vec1, hnsw1); + zvec_collection_schema_add_field(schema, vec1); + zvec_index_params_destroy(hnsw1); + + // Add second vector field (embedding2) with HNSW index + zvec_index_params_t *hnsw2 = zvec_index_params_create(ZVEC_INDEX_TYPE_HNSW); + zvec_index_params_set_metric_type(hnsw2, ZVEC_METRIC_TYPE_L2); + zvec_index_params_set_hnsw_params(hnsw2, 16, 100); + zvec_field_schema_t *vec2 = zvec_field_schema_create( + "embedding2", ZVEC_DATA_TYPE_VECTOR_FP32, false, 4); + zvec_field_schema_set_index_params(vec2, hnsw2); + zvec_collection_schema_add_field(schema, vec2); + zvec_index_params_destroy(hnsw2); + + zvec_collection_t *collection = NULL; + zvec_error_code_t err = + zvec_collection_create_and_open(temp_dir, schema, NULL, &collection); + TEST_ASSERT(err == ZVEC_OK); + TEST_ASSERT(collection != NULL); + + if (collection) { + // Insert test documents with both vector fields + float e1_v1[] = {1.0f, 0.0f, 0.0f, 0.0f}; + 
float e1_v2[] = {0.0f, 1.0f, 0.0f, 0.0f}; + float e1_v3[] = {0.0f, 0.0f, 1.0f, 0.0f}; + float e1_v4[] = {0.7f, 0.7f, 0.0f, 0.0f}; + + float e2_v1[] = {0.0f, 1.0f, 0.0f, 0.0f}; + float e2_v2[] = {1.0f, 0.0f, 0.0f, 0.0f}; + float e2_v3[] = {0.0f, 0.0f, 0.0f, 1.0f}; + float e2_v4[] = {0.5f, 0.5f, 0.0f, 0.0f}; + + zvec_doc_t *docs[4]; + for (int i = 0; i < 4; i++) { + docs[i] = zvec_doc_create(); + zvec_doc_set_pk(docs[i], zvec_test_make_pk(i + 1)); + zvec_doc_add_field_by_value(docs[i], "id", ZVEC_DATA_TYPE_INT64, + &(int64_t){i + 1}, sizeof(int64_t)); + } + + zvec_doc_add_field_by_value(docs[0], "embedding1", ZVEC_DATA_TYPE_VECTOR_FP32, + e1_v1, sizeof(e1_v1)); + zvec_doc_add_field_by_value(docs[0], "embedding2", ZVEC_DATA_TYPE_VECTOR_FP32, + e2_v1, sizeof(e2_v1)); + + zvec_doc_add_field_by_value(docs[1], "embedding1", ZVEC_DATA_TYPE_VECTOR_FP32, + e1_v2, sizeof(e1_v2)); + zvec_doc_add_field_by_value(docs[1], "embedding2", ZVEC_DATA_TYPE_VECTOR_FP32, + e2_v2, sizeof(e2_v2)); + + zvec_doc_add_field_by_value(docs[2], "embedding1", ZVEC_DATA_TYPE_VECTOR_FP32, + e1_v3, sizeof(e1_v3)); + zvec_doc_add_field_by_value(docs[2], "embedding2", ZVEC_DATA_TYPE_VECTOR_FP32, + e2_v3, sizeof(e2_v3)); + + zvec_doc_add_field_by_value(docs[3], "embedding1", ZVEC_DATA_TYPE_VECTOR_FP32, + e1_v4, sizeof(e1_v4)); + zvec_doc_add_field_by_value(docs[3], "embedding2", ZVEC_DATA_TYPE_VECTOR_FP32, + e2_v4, sizeof(e2_v4)); + + size_t success_count, error_count; + err = zvec_collection_insert(collection, (const zvec_doc_t **)docs, 4, + &success_count, &error_count); + TEST_ASSERT(err == ZVEC_OK); + TEST_ASSERT(success_count == 4); + + zvec_collection_flush(collection); + + // Test 1: MultiQuery with RRF reranker + zvec_reranker_t *rrf = zvec_reranker_create_rrf(3, 60); + TEST_ASSERT(rrf != NULL); + + zvec_multi_vector_query_t *mvq = zvec_multi_vector_query_create(); + TEST_ASSERT(mvq != NULL); + zvec_multi_vector_query_set_topk(mvq, 3); + zvec_multi_vector_query_set_include_vector(mvq, false); + 
+ // Add first sub-query for embedding1 + zvec_multi_vector_sub_query_t *vq1 = zvec_multi_vector_sub_query_create(); + zvec_multi_vector_sub_query_set_field_name(vq1, "embedding1"); + zvec_multi_vector_sub_query_set_query_vector(vq1, e1_v1, sizeof(e1_v1)); + zvec_multi_vector_sub_query_set_num_candidates(vq1, 3); + zvec_multi_vector_query_add_query(mvq, vq1); + + // Add second sub-query for embedding2 + zvec_multi_vector_sub_query_t *vq2 = zvec_multi_vector_sub_query_create(); + zvec_multi_vector_sub_query_set_field_name(vq2, "embedding2"); + zvec_multi_vector_sub_query_set_query_vector(vq2, e2_v1, sizeof(e2_v1)); + zvec_multi_vector_sub_query_set_num_candidates(vq2, 3); + zvec_multi_vector_query_add_query(mvq, vq2); + + // Set reranker + zvec_multi_vector_query_set_reranker(mvq, rrf); + + TEST_ASSERT(zvec_multi_vector_query_get_query_count(mvq) == 2); + TEST_ASSERT(zvec_multi_vector_query_get_topk(mvq) == 3); + TEST_ASSERT(zvec_multi_vector_query_get_include_vector(mvq) == false); + + // Execute multi query + zvec_doc_t **results = NULL; + size_t result_count = 0; + err = zvec_collection_multi_query(collection, mvq, &results, &result_count); + TEST_ASSERT(err == ZVEC_OK); + TEST_ASSERT(results != NULL); + TEST_ASSERT(result_count > 0); + TEST_ASSERT(result_count <= 3); + + zvec_docs_free(results, result_count); + + // Cleanup + zvec_multi_vector_sub_query_destroy(vq1); + zvec_multi_vector_sub_query_destroy(vq2); + zvec_multi_vector_query_destroy(mvq); + zvec_reranker_destroy(rrf); + + // Test 2: MultiVectorQuery property setters/getters + zvec_multi_vector_query_t *mvq2 = zvec_multi_vector_query_create(); + TEST_ASSERT(mvq2 != NULL); + zvec_multi_vector_query_set_topk(mvq2, 5); + TEST_ASSERT(zvec_multi_vector_query_get_topk(mvq2) == 5); + + zvec_multi_vector_query_set_filter(mvq2, "id > 1"); + TEST_ASSERT(strcmp(zvec_multi_vector_query_get_filter(mvq2), "id > 1") == 0); + + zvec_multi_vector_query_set_include_vector(mvq2, true); + 
TEST_ASSERT(zvec_multi_vector_query_get_include_vector(mvq2) == true); + + const char *out_fields[] = {"id"}; + zvec_multi_vector_query_set_output_fields(mvq2, out_fields, 1); + const char **got_fields = NULL; + size_t field_count = 0; + err = zvec_multi_vector_query_get_output_fields(mvq2, &got_fields, + &field_count); + TEST_ASSERT(err == ZVEC_OK); + TEST_ASSERT(field_count == 1); + if (field_count > 0) { + TEST_ASSERT(strcmp(got_fields[0], "id") == 0); + zvec_free((char *)got_fields); + } + + zvec_multi_vector_query_destroy(mvq2); + + // Cleanup documents + for (int i = 0; i < 4; i++) { + zvec_doc_destroy(docs[i]); + } + zvec_collection_destroy(collection); + } + + zvec_collection_schema_destroy(schema); + } + + cleanup_temp_directory(temp_dir); + + TEST_END(); +} + void test_index_creation_and_management(void) { TEST_START(); @@ -5409,6 +5638,8 @@ int main(void) { // Query tests test_query_params_functions(); test_actual_vector_queries(); + test_reranker_functions(); + test_multi_vector_query_with_reranker(); // Performance tests // test_performance_benchmarks(); diff --git a/tests/db/collection_test.cc b/tests/db/collection_test.cc index 6053ad04e..3a88de378 100644 --- a/tests/db/collection_test.cc +++ b/tests/db/collection_test.cc @@ -33,6 +33,7 @@ #include "zvec/db/doc.h" #include "zvec/db/index_params.h" #include "zvec/db/options.h" +#include "zvec/db/reranker.h" #include "zvec/db/schema.h" #include "zvec/db/status.h" #include "zvec/db/type.h" @@ -3586,6 +3587,342 @@ TEST_F(CollectionTest, Feature_Query_WithoutVector_WithScalarIndex) { "array_int32 contain_any (1)", 1); } +// ============================================================================= +// MultiQuery Tests +// ============================================================================= + +TEST_F(CollectionTest, Feature_MultiQuery_Validate) { + FileHelper::RemoveDirectory(col_path); + + int doc_count = 100; + auto schema = TestHelper::CreateNormalSchema(); + auto options = 
CollectionOptions{false, true, 100 * 1024 * 1024}; + auto collection = TestHelper::CreateCollectionWithDoc(col_path, *schema, + options, 0, doc_count); + ASSERT_NE(collection, nullptr); + + // Test 1: Empty queries should fail + { + MultiVectorQuery mvq; + mvq.topk = 10; + mvq.reranker = std::make_shared(10, 60); + auto result = collection->MultiQuery(mvq); + ASSERT_FALSE(result.has_value()); + EXPECT_EQ(result.error().code(), StatusCode::INVALID_ARGUMENT); + } + + // Test 2: No reranker with multiple queries should fail + { + MultiVectorQuery mvq; + mvq.topk = 10; + auto query_doc = TestHelper::CreateDoc(1, *schema); + + SubVectorQuery vq1; + vq1.num_candidates_ = 10; + vq1.field_name_ = "dense_fp32"; + auto vector = query_doc.get>("dense_fp32"); + ASSERT_TRUE(vector.has_value()); + vq1.query_vector_.assign((char *)vector.value().data(), + vector.value().size() * sizeof(float)); + mvq.queries.push_back(vq1); + + SubVectorQuery vq2; + vq2.num_candidates_ = 10; + vq2.field_name_ = "dense_fp16"; + auto vector2 = query_doc.get>("dense_fp32"); + ASSERT_TRUE(vector2.has_value()); + vq2.query_vector_.assign((char *)vector2.value().data(), + vector2.value().size() * sizeof(float)); + mvq.queries.push_back(vq2); + + auto result = collection->MultiQuery(mvq); + ASSERT_FALSE(result.has_value()); + EXPECT_EQ(result.error().code(), StatusCode::INVALID_ARGUMENT); + } + + // Test 3: Invalid field name should fail + { + MultiVectorQuery mvq; + mvq.topk = 10; + mvq.reranker = std::make_shared(10, 60); + + SubVectorQuery vq1; + vq1.num_candidates_ = 10; + vq1.field_name_ = "nonexistent_field"; + vq1.query_vector_.assign(128 * sizeof(float), '\0'); + mvq.queries.push_back(vq1); + + SubVectorQuery vq2; + vq2.num_candidates_ = 10; + vq2.field_name_ = "dense_fp32"; + vq2.query_vector_.assign(128 * sizeof(float), '\0'); + mvq.queries.push_back(vq2); + + auto result = collection->MultiQuery(mvq); + ASSERT_FALSE(result.has_value()); + EXPECT_EQ(result.error().code(), 
StatusCode::INVALID_ARGUMENT); + } + + // Test 4: Duplicate field names should fail + { + MultiVectorQuery mvq; + mvq.topk = 10; + mvq.reranker = std::make_shared(10, 60); + + SubVectorQuery vq1; + vq1.num_candidates_ = 10; + vq1.field_name_ = "dense_fp32"; + vq1.query_vector_.assign(128 * sizeof(float), '\0'); + mvq.queries.push_back(vq1); + + SubVectorQuery vq2; + vq2.num_candidates_ = 10; + vq2.field_name_ = "dense_fp32"; + vq2.query_vector_.assign(128 * sizeof(float), '\0'); + mvq.queries.push_back(vq2); + + auto result = collection->MultiQuery(mvq); + ASSERT_FALSE(result.has_value()); + EXPECT_EQ(result.error().code(), StatusCode::INVALID_ARGUMENT); + } +} + +TEST_F(CollectionTest, Feature_MultiQuery_SingleFieldWithReranker) { + FileHelper::RemoveDirectory(col_path); + + int doc_count = 100; + auto schema = TestHelper::CreateNormalSchema(); + auto options = CollectionOptions{false, true, 100 * 1024 * 1024}; + auto collection = TestHelper::CreateCollectionWithDoc(col_path, *schema, + options, 0, doc_count); + ASSERT_NE(collection, nullptr); + + // Single query with reranker should fail (requires at least 2 sub-queries) + auto query_doc = TestHelper::CreateDoc(1, *schema); + + MultiVectorQuery mvq; + mvq.topk = 10; + mvq.reranker = std::make_shared(10, 60); + + SubVectorQuery vq; + vq.num_candidates_ = 10; + vq.field_name_ = "dense_fp32"; + auto vector = query_doc.get>("dense_fp32"); + ASSERT_TRUE(vector.has_value()); + vq.query_vector_.assign((char *)vector.value().data(), + vector.value().size() * sizeof(float)); + mvq.queries.push_back(vq); + + auto result = collection->MultiQuery(mvq); + ASSERT_FALSE(result.has_value()); + EXPECT_EQ(result.error().code(), StatusCode::INVALID_ARGUMENT); +} + +TEST_F(CollectionTest, Feature_MultiQuery_MultiFieldRRF) { + FileHelper::RemoveDirectory(col_path); + + int doc_count = 100; + auto schema = TestHelper::CreateNormalSchema(); + auto options = CollectionOptions{false, true, 100 * 1024 * 1024}; + auto collection = 
TestHelper::CreateCollectionWithDoc(col_path, *schema, + options, 0, doc_count); + ASSERT_NE(collection, nullptr); + + auto query_doc = TestHelper::CreateDoc(1, *schema); + + MultiVectorQuery mvq; + mvq.topk = 10; + mvq.reranker = std::make_shared(10, 60); + + // Query dense_fp32 and dense_fp16 fields with different vectors + auto vector1 = query_doc.get>("dense_fp32"); + ASSERT_TRUE(vector1.has_value()); + + { + SubVectorQuery vq; + vq.num_candidates_ = 10; + vq.field_name_ = "dense_fp32"; + vq.query_vector_.assign((char *)vector1.value().data(), + vector1.value().size() * sizeof(float)); + mvq.queries.push_back(vq); + } + + // Query sparse_fp32 field + auto sparse = query_doc.get< + std::pair, std::vector>>("sparse_fp32"); + ASSERT_TRUE(sparse.has_value()); + + { + SubVectorQuery vq; + vq.num_candidates_ = 10; + vq.field_name_ = "sparse_fp32"; + vq.query_sparse_indices_.assign( + (char *)sparse.value().first.data(), + sparse.value().first.size() * sizeof(uint32_t)); + vq.query_sparse_values_.assign( + (char *)sparse.value().second.data(), + sparse.value().second.size() * sizeof(float)); + mvq.queries.push_back(vq); + } + + auto result = collection->MultiQuery(mvq); + ASSERT_TRUE(result.has_value()) << result.error().message(); + EXPECT_GT(result.value().size(), 0u); + EXPECT_LE(result.value().size(), 10u); + + // All results should have valid scores (RRF fused) + for (const auto &doc : result.value()) { + EXPECT_NE(doc->score(), 0.0f); + } +} + +TEST_F(CollectionTest, Feature_MultiQuery_MultiFieldWeighted) { + FileHelper::RemoveDirectory(col_path); + + int doc_count = 100; + auto schema = TestHelper::CreateNormalSchema(); + auto options = CollectionOptions{false, true, 100 * 1024 * 1024}; + auto collection = TestHelper::CreateCollectionWithDoc(col_path, *schema, + options, 0, doc_count); + ASSERT_NE(collection, nullptr); + + auto query_doc = TestHelper::CreateDoc(1, *schema); + + MultiVectorQuery mvq; + mvq.topk = 10; + std::map weights = {{"dense_fp32", 0.7}, + 
{"sparse_fp32", 0.3}}; + mvq.reranker = std::make_shared(10, MetricType::IP, weights); + + // Query dense_fp32 field + { + SubVectorQuery vq; + vq.num_candidates_ = 10; + vq.field_name_ = "dense_fp32"; + auto vector = query_doc.get>("dense_fp32"); + ASSERT_TRUE(vector.has_value()); + vq.query_vector_.assign((char *)vector.value().data(), + vector.value().size() * sizeof(float)); + mvq.queries.push_back(vq); + } + + // Query sparse_fp32 field + { + SubVectorQuery vq; + vq.num_candidates_ = 10; + vq.field_name_ = "sparse_fp32"; + auto sparse = query_doc.get< + std::pair, std::vector>>("sparse_fp32"); + ASSERT_TRUE(sparse.has_value()); + vq.query_sparse_indices_.assign( + (char *)sparse.value().first.data(), + sparse.value().first.size() * sizeof(uint32_t)); + vq.query_sparse_values_.assign( + (char *)sparse.value().second.data(), + sparse.value().second.size() * sizeof(float)); + mvq.queries.push_back(vq); + } + + auto result = collection->MultiQuery(mvq); + ASSERT_TRUE(result.has_value()) << result.error().message(); + EXPECT_GT(result.value().size(), 0u); + EXPECT_LE(result.value().size(), 10u); +} + +TEST_F(CollectionTest, Feature_MultiQuery_WithFilter) { + FileHelper::RemoveDirectory(col_path); + + int doc_count = 100; + auto schema = TestHelper::CreateNormalSchema(); + auto options = CollectionOptions{false, true, 100 * 1024 * 1024}; + auto collection = TestHelper::CreateCollectionWithDoc(col_path, *schema, + options, 0, doc_count); + ASSERT_NE(collection, nullptr); + + auto query_doc = TestHelper::CreateDoc(1, *schema); + + MultiVectorQuery mvq; + mvq.topk = 10; + mvq.filter = "int32 > 50"; + mvq.reranker = std::make_shared(10, 60); + + SubVectorQuery vq1; + vq1.num_candidates_ = 10; + vq1.field_name_ = "dense_fp32"; + auto vector = query_doc.get>("dense_fp32"); + ASSERT_TRUE(vector.has_value()); + vq1.query_vector_.assign((char *)vector.value().data(), + vector.value().size() * sizeof(float)); + mvq.queries.push_back(vq1); + + auto sparse = query_doc.get< + 
std::pair, std::vector>>("sparse_fp32"); + ASSERT_TRUE(sparse.has_value()); + SubVectorQuery vq2; + vq2.num_candidates_ = 10; + vq2.field_name_ = "sparse_fp32"; + vq2.query_sparse_indices_.assign( + (char *)sparse.value().first.data(), + sparse.value().first.size() * sizeof(uint32_t)); + vq2.query_sparse_values_.assign( + (char *)sparse.value().second.data(), + sparse.value().second.size() * sizeof(float)); + mvq.queries.push_back(vq2); + + auto result = collection->MultiQuery(mvq); + ASSERT_TRUE(result.has_value()) << result.error().message(); + EXPECT_GT(result.value().size(), 0u); + EXPECT_LE(result.value().size(), 10u); +} + +TEST_F(CollectionTest, Feature_MultiQuery_WithOutputFields) { + FileHelper::RemoveDirectory(col_path); + + int doc_count = 100; + auto schema = TestHelper::CreateNormalSchema(); + auto options = CollectionOptions{false, true, 100 * 1024 * 1024}; + auto collection = TestHelper::CreateCollectionWithDoc(col_path, *schema, + options, 0, doc_count); + ASSERT_NE(collection, nullptr); + + auto query_doc = TestHelper::CreateDoc(1, *schema); + + MultiVectorQuery mvq; + mvq.topk = 5; + mvq.include_vector = false; + mvq.output_fields = std::make_optional>( + std::vector{"int32", "string"}); + mvq.reranker = std::make_shared(5, 60); + + SubVectorQuery vq1; + vq1.num_candidates_ = 10; + vq1.field_name_ = "dense_fp32"; + auto vector = query_doc.get>("dense_fp32"); + ASSERT_TRUE(vector.has_value()); + vq1.query_vector_.assign((char *)vector.value().data(), + vector.value().size() * sizeof(float)); + mvq.queries.push_back(vq1); + + auto sparse = query_doc.get< + std::pair, std::vector>>("sparse_fp32"); + ASSERT_TRUE(sparse.has_value()); + SubVectorQuery vq2; + vq2.num_candidates_ = 10; + vq2.field_name_ = "sparse_fp32"; + vq2.query_sparse_indices_.assign( + (char *)sparse.value().first.data(), + sparse.value().first.size() * sizeof(uint32_t)); + vq2.query_sparse_values_.assign( + (char *)sparse.value().second.data(), + sparse.value().second.size() * 
sizeof(float)); + mvq.queries.push_back(vq2); + + auto result = collection->MultiQuery(mvq); + ASSERT_TRUE(result.has_value()) << result.error().message(); + EXPECT_GT(result.value().size(), 0u); + EXPECT_LE(result.value().size(), 5u); +} + TEST_F(CollectionTest, Feature_GroupByQuery) {} TEST_F(CollectionTest, Feature_AddColumn_General) { diff --git a/tests/db/reranker_test.cc b/tests/db/reranker_test.cc new file mode 100644 index 000000000..73eee0cb7 --- /dev/null +++ b/tests/db/reranker_test.cc @@ -0,0 +1,192 @@ +// Copyright 2025-present the zvec project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +using namespace zvec; + +namespace { + +//! 
Helper to create a Doc::Ptr with given id and score +Doc::Ptr MakeDoc(const std::string& id, float score) { + auto doc = std::make_shared(); + doc->set_pk(id); + doc->set_score(score); + return doc; +} + +} // namespace + +// ==================== RrfReRanker Tests ==================== + +TEST(RrfReRankerTest, BasicRRF) { + RrfReRanker reranker(/*topn=*/10, /*rank_constant=*/60); + + // Two vector fields, each returning 3 documents with some overlap + std::map query_results; + query_results["vec1"] = {MakeDoc("a", 0.9f), MakeDoc("b", 0.8f), + MakeDoc("c", 0.7f)}; + query_results["vec2"] = {MakeDoc("b", 0.95f), MakeDoc("a", 0.85f), + MakeDoc("d", 0.75f)}; + + auto results = reranker.rerank(query_results); + + // "a" appears at rank 0 in vec1 and rank 1 in vec2: + // rrf_score = 1/(60+0+1) + 1/(60+1+1) = 1/61 + 1/62 + // "b" appears at rank 1 in vec1 and rank 0 in vec2: + // rrf_score = 1/(60+1+1) + 1/(60+0+1) = 1/62 + 1/61 + // So a and b should have equal scores and be at the top + ASSERT_GE(results.size(), 3u); + + // "a" and "b" should have the highest RRF scores + EXPECT_EQ(results[0]->pk(), "a"); + EXPECT_EQ(results[1]->pk(), "b"); + // Verify scores are close (a and b have same RRF score) + EXPECT_NEAR(results[0]->score(), results[1]->score(), 1e-10); +} + +TEST(RrfReRankerTest, Topn) { + RrfReRanker reranker(/*topn=*/2, /*rank_constant=*/60); + + std::map query_results; + query_results["vec1"] = {MakeDoc("a", 0.9f), MakeDoc("b", 0.8f), + MakeDoc("c", 0.7f)}; + + auto results = reranker.rerank(query_results); + ASSERT_EQ(results.size(), 2u); +} + +TEST(RrfReRankerTest, SingleField) { + RrfReRanker reranker(/*topn=*/10, /*rank_constant=*/60); + + std::map query_results; + query_results["vec1"] = {MakeDoc("a", 0.9f), MakeDoc("b", 0.8f)}; + + auto results = reranker.rerank(query_results); + ASSERT_EQ(results.size(), 2u); + // With single field, RRF score for rank 0 > rank 1 + EXPECT_GT(results[0]->score(), results[1]->score()); +} + +TEST(RrfReRankerTest, 
EmptyResults) { + RrfReRanker reranker(/*topn=*/10, /*rank_constant=*/60); + + std::map query_results; + auto results = reranker.rerank(query_results); + EXPECT_TRUE(results.empty()); +} + +// ==================== WeightedReRanker Tests ==================== + +TEST(WeightedReRankerTest, BasicWeighted) { + WeightedReRanker reranker(/*topn=*/10, MetricType::L2, + {{"vec1", 0.7}, {"vec2", 0.3}}); + + std::map query_results; + query_results["vec1"] = {MakeDoc("a", 0.5f), MakeDoc("b", 0.3f)}; + query_results["vec2"] = {MakeDoc("a", 0.8f), MakeDoc("c", 0.6f)}; + + auto results = reranker.rerank(query_results); + ASSERT_GE(results.size(), 2u); + // "a" appears in both fields, should have highest combined score + EXPECT_EQ(results[0]->pk(), "a"); +} + +TEST(WeightedReRankerTest, NormalizeL2) { + // L2: normalize_score = 1 - 2*atan(score)/pi + // For score=0: 1 - 0 = 1.0 + // For score->inf: 1 - 2*(pi/2)/pi = 0.0 + EXPECT_NEAR(WeightedReRanker::normalize_score(0.0, MetricType::L2), 1.0, + 1e-10); + EXPECT_GT(WeightedReRanker::normalize_score(1.0, MetricType::L2), 0.0); + EXPECT_LT(WeightedReRanker::normalize_score(1.0, MetricType::L2), 1.0); +} + +TEST(WeightedReRankerTest, NormalizeIP) { + // IP: normalize_score = 0.5 + atan(score)/pi + // For score=0: 0.5 + 0 = 0.5 + EXPECT_NEAR(WeightedReRanker::normalize_score(0.0, MetricType::IP), 0.5, + 1e-10); + EXPECT_GT(WeightedReRanker::normalize_score(1.0, MetricType::IP), 0.5); +} + +TEST(WeightedReRankerTest, NormalizeCosine) { + // COSINE: normalize_score = 1 - score/2 + // For score=0: 1 - 0 = 1.0 + // For score=1: 1 - 0.5 = 0.5 + // For score=2: 1 - 1.0 = 0.0 + EXPECT_NEAR(WeightedReRanker::normalize_score(0.0, MetricType::COSINE), 1.0, + 1e-10); + EXPECT_NEAR(WeightedReRanker::normalize_score(1.0, MetricType::COSINE), 0.5, + 1e-10); + EXPECT_NEAR(WeightedReRanker::normalize_score(2.0, MetricType::COSINE), 0.0, + 1e-10); +} + +TEST(WeightedReRankerTest, Topn) { + WeightedReRanker reranker(/*topn=*/2, MetricType::L2, {}); + + 
std::map query_results; + query_results["vec1"] = {MakeDoc("a", 0.1f), MakeDoc("b", 0.2f), + MakeDoc("c", 0.3f)}; + + auto results = reranker.rerank(query_results); + ASSERT_EQ(results.size(), 2u); +} + +TEST(WeightedReRankerTest, UnsupportedMetric) { + EXPECT_THROW(WeightedReRanker::normalize_score(1.0, MetricType::UNDEFINED), + std::invalid_argument); +} + +// ==================== CallbackReRanker Tests ==================== + +TEST(CallbackReRankerTest, BasicCallback) { + // Simple callback that returns docs sorted by score descending + CallbackReRanker::Callback cb = + [](const std::map& query_results) -> DocPtrList { + DocPtrList all_docs; + for (const auto& [_, docs] : query_results) { + for (const auto& doc : docs) { + all_docs.push_back(doc); + } + } + std::sort(all_docs.begin(), all_docs.end(), + [](const Doc::Ptr& a, const Doc::Ptr& b) { + return a->score() > b->score(); + }); + return all_docs; + }; + + CallbackReRanker reranker(cb, /*topn=*/10); + + std::map query_results; + query_results["vec1"] = {MakeDoc("a", 0.5f), MakeDoc("b", 0.9f)}; + query_results["vec2"] = {MakeDoc("c", 0.7f)}; + + auto results = reranker.rerank(query_results); + ASSERT_EQ(results.size(), 3u); + // Should be sorted by score descending + EXPECT_EQ(results[0]->pk(), "b"); + EXPECT_EQ(results[1]->pk(), "c"); + EXPECT_EQ(results[2]->pk(), "a"); +}