diff --git a/.gitattributes b/.gitattributes new file mode 100644 index 000000000..bb178984d --- /dev/null +++ b/.gitattributes @@ -0,0 +1,3 @@ +# Auto-generated files — collapsed in GitHub PR diffs +src/db/index/column/fts_column/gen/** linguist-generated=true +src/db/sqlengine/antlr/gen/** linguist-generated=true diff --git a/.gitmodules b/.gitmodules index 51934dfed..2f501c34b 100644 --- a/.gitmodules +++ b/.gitmodules @@ -40,3 +40,12 @@ [submodule "thirdparty/RaBitQ-Library/RaBitQ-Library-0.1"] path = thirdparty/RaBitQ-Library/RaBitQ-Library-0.1 url = https://github.com/VectorDB-NTU/RaBitQ-Library.git +[submodule "thirdparty/cppjieba/cppjieba-5.6.7"] + path = thirdparty/cppjieba/cppjieba-5.6.7 + url = https://github.com/yanyiwu/cppjieba.git +[submodule "thirdparty/FastPFOR/FastPFOR-0.4.0"] + path = thirdparty/FastPFOR/FastPFOR-0.4.0 + url = https://github.com/fast-pack/FastPFOR.git +[submodule "thirdparty/limonp/limonp-v1.0.2"] + path = thirdparty/limonp/limonp-v1.0.2 + url = https://github.com/yanyiwu/limonp.git diff --git a/CMakeLists.txt b/CMakeLists.txt index a33e61e99..c492a95c9 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -145,4 +145,14 @@ if(BUILD_PYTHON_BINDINGS) message(STATUS "Zvec install path: ${ZVEC_PY_INSTALL_DIR}") install(TARGETS _zvec LIBRARY DESTINATION ${ZVEC_PY_INSTALL_DIR}) + + # Bundle cppjieba's dictionary files so the `jieba` FTS tokenizer works + # out of the box. python/zvec/__init__.py resolves this directory via + # importlib.resources and registers it with set_default_jieba_dict_dir(). + set(ZVEC_JIEBA_DICT_SRC + "${PROJECT_SOURCE_DIR}/thirdparty/cppjieba/cppjieba-5.6.7/dict") + install(FILES + "${ZVEC_JIEBA_DICT_SRC}/jieba.dict.utf8" + "${ZVEC_JIEBA_DICT_SRC}/hmm_model.utf8" + DESTINATION ${ZVEC_PY_INSTALL_DIR}/zvec/data/jieba_dict) endif() diff --git a/python/tests/test_collection_fts.py b/python/tests/test_collection_fts.py new file mode 100644 index 000000000..55832a143 --- /dev/null +++ b/python/tests/test_collection_fts.py @@ -0,0 +1,188 @@ +# Copyright 2025-present the zvec project +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""End-to-end tests for FTS-only collections (no vector field). + +The schema validation rule "must have at least one vector field" has been +lifted; these tests pin the new behavior so insert / query / delete / +optimize all work on a vector-less collection. +""" + +from __future__ import annotations + +import pytest +import zvec +from zvec import ( + Collection, + CollectionOption, + DataType, + Doc, + FieldSchema, + FtsIndexParam, + OptimizeOption, +) +from zvec.model.param.query import Fts, Query + + +# ==================== Fixtures ==================== + + +@pytest.fixture(scope="function") +def fts_collection(tmp_path_factory) -> Collection: + """FTS-only collection: a STRING field for forward + an FTS-indexed STRING.""" + temp_dir = tmp_path_factory.mktemp("zvec_fts_only") + collection_path = temp_dir / "fts_collection" + + schema = zvec.CollectionSchema( + name="fts_only", + fields=[ + FieldSchema("title", DataType.STRING, nullable=False), + FieldSchema( + "content", + DataType.STRING, + nullable=False, + index_param=FtsIndexParam( + tokenizer_name="standard", + filters=["lowercase"], + ), + ), + ], + # vectors omitted on purpose — schema validation must accept this. + ) + + coll = zvec.create_and_open( + path=str(collection_path), + schema=schema, + option=CollectionOption(read_only=False, enable_mmap=True), + ) + assert coll is not None + + try: + yield coll + finally: + try: + coll.destroy() + except Exception as e: + print(f"Warning: failed to destroy collection: {e}") + + +def _make_docs() -> list[Doc]: + """5-doc corpus where 4 contain 'hello' and doc 4 is the only outlier.""" + return [ + Doc(id="pk_0", fields={"title": "intro", "content": "hello world"}), + Doc(id="pk_1", fields={"title": "guide", "content": "hello foo bar"}), + Doc(id="pk_2", fields={"title": "tips", "content": "hello baz"}), + Doc(id="pk_3", fields={"title": "more", "content": "hello hello"}), + Doc(id="pk_4", fields={"title": "other", "content": "nothing relevant"}), + ] + + +def _fts_query(coll: Collection, term: str) -> list[Doc]: + """Run a single-term FTS match query against the `content` field.""" + return coll.query( + queries=Query(field_name="content", fts=Fts(match_string=term)), + topk=10, + ) + + +# ==================== Tests ==================== + + +class TestFtsOnlyCollectionSchema: + def test_create_and_open_without_vectors(self, fts_collection: Collection): + """Schema with zero vector fields must be accepted by validate().""" + assert fts_collection.schema.name == "fts_only" + assert {f.name for f in fts_collection.schema.fields} == {"title", "content"} + # Empty vectors is the whole point of the test. + assert list(fts_collection.schema.vectors) == [] + assert fts_collection.stats.doc_count == 0 + + def test_create_schema_omitting_vectors_kwarg(self): + """Constructing CollectionSchema without `vectors=` argument is valid.""" + schema = zvec.CollectionSchema( + name="bare_fts", + fields=[ + FieldSchema( + "content", + DataType.STRING, + nullable=False, + index_param=FtsIndexParam(), + ), + ], + ) + assert list(schema.vectors) == [] + assert {f.name for f in schema.fields} == {"content"} + + +class TestFtsOnlyCollectionLifecycle: + def test_insert_and_fts_query(self, fts_collection: Collection): + """FTS-only collection supports insert + FTS query end-to-end.""" + results = fts_collection.insert(_make_docs()) + assert all(r.ok() for r in results) + assert fts_collection.stats.doc_count == 5 + + hits = _fts_query(fts_collection, "hello") + assert len(hits) == 4 + assert {doc.id for doc in hits} == {"pk_0", "pk_1", "pk_2", "pk_3"} + + # Term that nothing in the surviving corpus contains. + assert _fts_query(fts_collection, "missing_term_xyz") == [] + + def test_delete_then_query(self, fts_collection: Collection): + """Tombstone filter must drop deleted docs from FTS results.""" + fts_collection.insert(_make_docs()) + statuses = fts_collection.delete(["pk_0", "pk_4"]) + assert all(s.ok() for s in statuses) + assert fts_collection.stats.doc_count == 3 + + hits = _fts_query(fts_collection, "hello") + assert len(hits) == 3 + assert {doc.id for doc in hits} == {"pk_1", "pk_2", "pk_3"} + # pk_4's unique term is filtered out post-delete. + assert _fts_query(fts_collection, "nothing") == [] + + def test_optimize_rebuilds_fts(self, fts_collection: Collection): + """Optimize with >30% deletes triggers ReduceFts; recall unchanged.""" + fts_collection.insert(_make_docs()) + # 40% delete ratio — above COMPACT_DELETE_RATIO_THRESHOLD=0.3, so + # build_compact_task picks the rebuild path and ReduceFts runs. + fts_collection.delete(["pk_0", "pk_4"]) + + before = {doc.id for doc in _fts_query(fts_collection, "hello")} + assert before == {"pk_1", "pk_2", "pk_3"} + + fts_collection.optimize(option=OptimizeOption()) + assert fts_collection.stats.doc_count == 3 + + after = {doc.id for doc in _fts_query(fts_collection, "hello")} + assert after == before + assert _fts_query(fts_collection, "nothing") == [] + + +class TestFtsOnlyCollectionQueryValidation: + def test_vector_query_rejected(self, fts_collection: Collection): + """Vector query on a no-vector collection must raise.""" + with pytest.raises(ValueError, match="vector or id"): + fts_collection.query( + queries=Query(field_name="content", vector=[0.1, 0.2, 0.3]), + topk=5, + ) + + def test_id_query_rejected(self, fts_collection: Collection): + """ID-based query on a no-vector collection must raise.""" + fts_collection.insert(_make_docs()[:1]) + with pytest.raises(ValueError, match="vector or id"): + fts_collection.query( + queries=Query(field_name="content", id="pk_0"), + topk=5, + ) diff --git a/python/tests/test_fts_query.py b/python/tests/test_fts_query.py new file mode 100644 index 000000000..74cca6a9b --- /dev/null +++ b/python/tests/test_fts_query.py @@ -0,0 +1,158 @@ +# Copyright 2025-present the zvec project +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tests for FTS (Full-Text Search) query support in the Python SDK.""" + +import pickle + +import pytest + +from zvec.model.param.query import Fts, Query + + +class TestFtsQueryValidation: + """Test FTS parameter validation in Query dataclass.""" + + def test_fts_query_string_only(self): + """Query with only query_string in Fts should be valid.""" + q = Query( + field_name="content", fts=Fts(query_string='+hello -world "exact phrase"') + ) + q._validate() + assert q.fts.query_string == '+hello -world "exact phrase"' + assert q.fts.match_string is None + assert q.has_fts() is True + + def test_fts_match_string_only(self): + """Query with only match_string in Fts should be valid.""" + q = Query(field_name="content", fts=Fts(match_string="machine learning")) + q._validate() + assert q.fts.match_string == "machine learning" + assert q.fts.query_string is None + assert q.has_fts() is True + + def test_fts_query_string_and_match_string_mutually_exclusive(self): + """Cannot provide both query_string and match_string in Fts.""" + q = Query( + field_name="content", + fts=Fts(query_string="+hello", match_string="hello world"), + ) + with pytest.raises(ValueError, match="mutually exclusive"): + q._validate() + + def test_no_fts(self): + """Query without FTS fields should have has_fts() == False.""" + q = Query(field_name="embedding", vector=[0.1, 0.2, 0.3]) + assert q.has_fts() is False + + def test_vector_and_fts_mutually_exclusive(self): + """Cannot combine vector search with FTS in a single Query.""" + q = Query( + field_name="embedding", + vector=[0.1, 0.2, 0.3], + fts=Fts(match_string="deep learning"), + ) + with pytest.raises(ValueError, match="Cannot combine fts with vector search"): + q._validate() + + def test_fts_without_vector_or_id(self): + """Query with only FTS (no vector, no id) should be valid.""" + q = Query(field_name="content", fts=Fts(query_string="hello")) + q._validate() + assert q.has_vector() is False + assert q.has_id() is False + assert q.has_fts() is True + + +class TestFtsQueryBinding: + """Test FTS binding layer (_Fts).""" + + def test_import_fts_query(self): + """_Fts should be importable from _zvec.param.""" + from _zvec.param import _Fts + + fts = _Fts() + assert fts.query_string == "" + assert fts.match_string == "" + + def test_fts_query_set_fields(self): + """Setting fields on _Fts should work.""" + from _zvec.param import _Fts + + fts = _Fts() + fts.query_string = "+hello -world" + assert fts.query_string == "+hello -world" + + fts2 = _Fts() + fts2.match_string = "machine learning" + assert fts2.match_string == "machine learning" + + def test_fts_query_pickle(self): + """_Fts should support pickling.""" + from _zvec.param import _Fts + + fts = _Fts() + fts.query_string = "+vector search" + fts.match_string = "" + + data = pickle.dumps(fts) + restored = pickle.loads(data) + assert restored.query_string == "+vector search" + assert restored.match_string == "" + + def test_vector_query_fts_field(self): + """_VectorQuery should have fts field.""" + from _zvec.param import _Fts, _VectorQuery + + vq = _VectorQuery() + # fts should be None by default (optional) + assert vq.fts is None + + # set fts + fts = _Fts() + fts.query_string = "hello" + vq.fts = fts + assert vq.fts is not None + assert vq.fts.query_string == "hello" + + def test_vector_query_pickle_with_fts(self): + """_VectorQuery with fts should survive pickling.""" + from _zvec.param import _Fts, _VectorQuery + + vq = _VectorQuery() + vq.topk = 10 + vq.field_name = "embedding" + fts = _Fts() + fts.match_string = "test query" + vq.fts = fts + + data = pickle.dumps(vq) + restored = pickle.loads(data) + assert restored.topk == 10 + assert restored.field_name == "embedding" + assert restored.fts is not None + assert restored.fts.match_string == "test query" + + def test_vector_query_pickle_without_fts(self): + """_VectorQuery without fts should survive pickling.""" + from _zvec.param import _VectorQuery + + vq = _VectorQuery() + vq.topk = 5 + vq.field_name = "vec" + + data = pickle.dumps(vq) + restored = pickle.loads(data) + assert restored.topk == 5 + assert restored.field_name == "vec" + assert restored.fts is None diff --git a/python/tests/test_jieba_default_dict.py b/python/tests/test_jieba_default_dict.py new file mode 100644 index 000000000..278a3f1fc --- /dev/null +++ b/python/tests/test_jieba_default_dict.py @@ -0,0 +1,143 @@ +# Copyright 2025-present the zvec project +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""End-to-end: jieba FTS works without any user configuration. + +`import zvec` is supposed to register the wheel-bundled jieba dict +directory via `set_default_jieba_dict_dir`. With that in place a user can +declare an FTS field with `tokenizer_name="jieba"`, leave `extra_params` +empty, and Chinese full-text search just works. + +Falls back to GTEST_SKIP-equivalent when running against a build that did +not bundle the dict (e.g., source-tree dev install without the install +step). In that case CI will rely on the C++ unit tests instead. +""" + +from __future__ import annotations + +import pytest +import zvec +from zvec import ( + Collection, + CollectionOption, + DataType, + Doc, + FieldSchema, + FtsIndexParam, +) +from zvec.model.param.query import Fts, Query + + +def _bundled_dict_dir() -> str: + """Path zvec.__init__ would have registered; empty when not bundled.""" + return zvec.get_default_jieba_dict_dir() + + +def _bundled_dict_files_exist() -> bool: + """Whether the registered default actually contains the dict files. + + `importlib.resources` happily returns a path even when the data dir was + not installed (e.g. source-tree dev runs); only an installed wheel has + the files on disk. + """ + import os + + base = _bundled_dict_dir() + if not base: + return False + return os.path.isfile(os.path.join(base, "jieba.dict.utf8")) and os.path.isfile( + os.path.join(base, "hmm_model.utf8") + ) + + +@pytest.fixture(scope="module", autouse=True) +def _require_bundled_dict(): + if not _bundled_dict_files_exist(): + pytest.skip( + "Bundled jieba dict not found at zvec/data/jieba_dict/ — " + "this test requires an installed wheel (not a source-tree dev " + "build without the install step).", + ) + + +@pytest.fixture(scope="function") +def jieba_collection(tmp_path_factory) -> Collection: + """FTS-only collection using jieba tokenizer and no explicit dict path.""" + temp_dir = tmp_path_factory.mktemp("zvec_jieba_default") + collection_path = temp_dir / "fts_jieba" + + schema = zvec.CollectionSchema( + name="fts_jieba_default", + fields=[ + FieldSchema("title", DataType.STRING, nullable=False), + FieldSchema( + "content", + DataType.STRING, + nullable=False, + # Deliberately omit extra_params — the bundled default must + # be picked up via GlobalConfig.jieba_dict_dir. + index_param=FtsIndexParam( + tokenizer_name="jieba", + filters=["lowercase"], + ), + ), + ], + ) + + coll = zvec.create_and_open( + path=str(collection_path), + schema=schema, + option=CollectionOption(read_only=False, enable_mmap=True), + ) + assert coll is not None + try: + yield coll + finally: + try: + coll.destroy() + except Exception as e: + print(f"Warning: failed to destroy collection: {e}") + + +def test_jieba_works_without_explicit_dict_path(jieba_collection: Collection): + """User opens collection, inserts CJK doc, searches — no init() / no + extra_params / no env var / no manual setter call. Just `import zvec`.""" + docs = [ + Doc(id="pk_1", fields={"title": "t1", "content": "中华人民共和国成立"}), + Doc(id="pk_2", fields={"title": "t2", "content": "无关文档"}), + ] + insert_results = jieba_collection.insert(docs) + assert all(r.ok() for r in insert_results) + + hits = jieba_collection.query( + queries=Query(field_name="content", fts=Fts(match_string="中华")), + topk=10, + ) + ids = {doc.id for doc in hits} + assert "pk_1" in ids + assert "pk_2" not in ids + + +def test_default_dict_dir_is_registered_on_import(): + """Sanity check: zvec.__init__ registered a non-empty default.""" + assert _bundled_dict_dir() != "" + + +def test_user_can_override_default_at_runtime(): + """zvec.set_default_jieba_dict_dir can be called any time to override.""" + saved = zvec.get_default_jieba_dict_dir() + try: + zvec.set_default_jieba_dict_dir("/tmp/zvec/jieba-override") + assert zvec.get_default_jieba_dict_dir() == "/tmp/zvec/jieba-override" + finally: + zvec.set_default_jieba_dict_dir(saved) diff --git a/python/tests/test_query_executor.py b/python/tests/test_query_executor.py index 6b9b76356..0581183d5 100644 --- a/python/tests/test_query_executor.py +++ b/python/tests/test_query_executor.py @@ -225,7 +225,9 @@ def test_init(self): def test_do_validate_with_queries(self): schema = MockCollectionSchema() executor = NoVectorQueryExecutor(schema) - ctx = QueryContext(topk=10, queries=[Query(field_name="test")]) + ctx = QueryContext( + topk=10, queries=[Query(field_name="test", vector=[0.1, 0.2, 0.3])] + ) with pytest.raises( ValueError, match="Collection does not support query with vector or id" diff --git a/python/zvec/__init__.py b/python/zvec/__init__.py index 705f3e366..655535ebe 100644 --- a/python/zvec/__init__.py +++ b/python/zvec/__init__.py @@ -21,6 +21,24 @@ from importlib.metadata import PackageNotFoundError +# Register the wheel-bundled jieba dict dir so `import zvec` alone makes +# the jieba FTS tokenizer usable. Users can still override via +# zvec.init(jieba_dict_dir=...), zvec.set_default_jieba_dict_dir(...), +# ZVEC_JIEBA_DICT_DIR, or per-field FtsIndexParam.extra_params. +try: + from importlib.resources import files as _resource_files + + from _zvec import ( + get_default_jieba_dict_dir, + set_default_jieba_dict_dir, + ) + + set_default_jieba_dict_dir(str(_resource_files("zvec").joinpath("data/jieba_dict"))) +except Exception: + # Custom builds without bundled dict; users must configure explicitly. + pass + + # ============================== # Public API — grouped by category # ============================== @@ -56,11 +74,14 @@ from .model.doc import Doc # —— Query & index parameters —— +# —— FTS params (C++ binding) —— from .model.param import ( AddColumnOption, AlterColumnOption, CollectionOption, FlatIndexParam, + FtsIndexParam, + FtsQueryParam, HnswIndexParam, HnswQueryParam, HnswRabitqIndexParam, @@ -73,7 +94,7 @@ VamanaIndexParam, VamanaQueryParam, ) -from .model.param.query import Query, VectorQuery +from .model.param.query import Fts, Query, VectorQuery # —— Schema & field definitions —— from .model.schema import CollectionSchema, CollectionStats, FieldSchema, VectorSchema @@ -101,6 +122,8 @@ "create_and_open", "init", "open", + "set_default_jieba_dict_dir", + "get_default_jieba_dict_dir", # Core classes "Collection", "Doc", @@ -112,6 +135,9 @@ # Parameters "Query", "VectorQuery", + "Fts", + "FtsIndexParam", + "FtsQueryParam", "InvertIndexParam", "HnswIndexParam", "HnswRabitqIndexParam", diff --git a/python/zvec/executor/query_executor.py b/python/zvec/executor/query_executor.py index 3e54e37d2..d2d3391ca 100644 --- a/python/zvec/executor/query_executor.py +++ b/python/zvec/executor/query_executor.py @@ -20,7 +20,7 @@ import numpy as np from _zvec import _Collection -from _zvec.param import _VectorQuery +from _zvec.param import _Fts, _VectorQuery from ..extension import ReRanker, RrfReRanker, WeightedReRanker from ..model.convert import convert_to_py_doc @@ -141,6 +141,14 @@ def _do_build_query_wo_vector(self, ctx: QueryContext) -> _VectorQuery: core_vector.output_fields = ctx.output_fields return core_vector + def _do_build_fts_query(self, query: Query, core_vector: _VectorQuery) -> None: + """Set FTS query on core_vector if the query has FTS parameters.""" + if query.has_fts(): + fts = _Fts() + fts.query_string = query.fts.query_string or "" + fts.match_string = query.fts.match_string or "" + core_vector.fts = fts + def _do_build_query_with_vector( self, ctx: QueryContext, query: Query, collection: _Collection ) -> _VectorQuery: @@ -149,25 +157,34 @@ def _do_build_query_with_vector( if query.param: core_vector.query_params = query.param - vector_schema = ( - self._schema.vector(query.field_name) if query else self._schema.vectors[0] - ) - - if vector_schema is None: - raise ValueError("No vector field found") + # set FTS query if provided + self._do_build_fts_query(query, core_vector) # set output_fields core_vector.output_fields = ctx.output_fields + vector_schema = None + if query.has_vector() or query.has_id(): + vector_schema = ( + self._schema.vector(query.field_name) + if query + else self._schema.vectors[0] + ) + + if vector_schema is None: + raise ValueError("No vector field found") + # set vector if query.has_vector(): vec_data = query.vector - else: + elif query.has_id(): fetched = collection.Fetch([query.id]) doc = next(iter(fetched.values())) if not doc: return core_vector vec_data = doc.get_any(vector_schema.name, vector_schema.data_type) + else: + return core_vector target_dtype = DTYPE_MAP.get(vector_schema.data_type.value) core_vector.set_vector( @@ -243,13 +260,21 @@ def __init__(self, schema: CollectionSchema): super().__init__(schema) def _do_validate(self, ctx: QueryContext) -> None: - if len(ctx.queries) > 0: - raise ValueError("Collection does not support query with vector or id") + for query in ctx.queries: + if query.has_vector() or query.has_id(): + raise ValueError("Collection does not support query with vector or id") + query._validate() def _do_build( - self, ctx: QueryContext, _collection: _Collection + self, ctx: QueryContext, collection: _Collection ) -> list[_VectorQuery]: - return [self._do_build_query_wo_vector(ctx)] + if len(ctx.queries) == 0: + return [self._do_build_query_wo_vector(ctx)] + # FTS-only branch in _do_build_query_with_vector skips vector resolution. + return [ + self._do_build_query_with_vector(ctx, query, collection) + for query in ctx.queries + ] class SingleVectorQueryExecutor(NoVectorQueryExecutor): diff --git a/python/zvec/model/__init__.py b/python/zvec/model/__init__.py index f193f10bb..7d5b0689b 100644 --- a/python/zvec/model/__init__.py +++ b/python/zvec/model/__init__.py @@ -15,7 +15,7 @@ from .collection import Collection from .doc import Doc -from .param.query import Query, VectorQuery +from .param.query import Fts, Query, VectorQuery from .schema.collection_schema import CollectionSchema from .schema.field_schema import FieldSchema @@ -24,6 +24,7 @@ "CollectionSchema", "Doc", "FieldSchema", + "Fts", "Query", "VectorQuery", ] diff --git a/python/zvec/model/param/__init__.py b/python/zvec/model/param/__init__.py index 5758218d9..05909e90c 100644 --- a/python/zvec/model/param/__init__.py +++ b/python/zvec/model/param/__init__.py @@ -18,6 +18,8 @@ AlterColumnOption, CollectionOption, FlatIndexParam, + FtsIndexParam, + FtsQueryParam, HnswIndexParam, HnswQueryParam, HnswRabitqIndexParam, @@ -36,6 +38,8 @@ "AlterColumnOption", "CollectionOption", "FlatIndexParam", + "FtsIndexParam", + "FtsQueryParam", "HnswIndexParam", "HnswQueryParam", "HnswRabitqIndexParam", diff --git a/python/zvec/model/param/query.py b/python/zvec/model/param/query.py index f14c28509..f2c15ecd2 100644 --- a/python/zvec/model/param/query.py +++ b/python/zvec/model/param/query.py @@ -20,26 +20,42 @@ from ...common import VectorType from . import HnswQueryParam, HnswRabitqQueryParam, IVFQueryParam -__all__ = ["Query", "VectorQuery"] +__all__ = ["Fts", "Query", "VectorQuery"] + + +@dataclass(frozen=True) +class Fts: + """Full-text search query parameters. + + Attributes: + query_string (Optional[str]): FTS query expression + (e.g. '+vector -slow "exact phrase"'). Mutually exclusive with match_string. + match_string (Optional[str]): Natural language match string, + tokenized and combined using the default operator. + Mutually exclusive with query_string. + """ + + query_string: Optional[str] = None + match_string: Optional[str] = None @dataclass(frozen=True) class Query: """Represents a search query for a specific field in a collection. - A `Query` can be constructed using either a document ID (to look up - its vector) or an explicit vector. It may optionally include index-specific - query parameters to control search behavior (e.g., `ef` for HNSW, `nprobe` for IVF). + A `Query` can be constructed for either vector search or full-text search, + but not both simultaneously. - Exactly one of `id` or `vector` should be provided. If both are given, - behavior is implementation-defined (typically `id` takes precedence). + For vector search, provide `id` or `vector` (and optionally `param`). + For FTS, provide `fts`. Attributes: field_name (str): Name of the field to query. id (Optional[str], optional): Document ID to fetch vector from. Default is None. vector (VectorType, optional): Explicit query vector. Default is None. param (Optional[Union[HnswQueryParam, IVFQueryParam]], optional): - Index-specific query parameters. Default is None. + Index-specific query parameters for vector search. Default is None. + fts (Optional[Fts], optional): Full-text search parameters. Default is None. Examples: >>> import zvec @@ -51,12 +67,18 @@ class Query: ... vector=[0.1, 0.2, 0.3], ... param=HnswQueryParam(ef=300) ... ) + >>> # FTS query + >>> q3 = zvec.Query( + ... field_name="content", + ... fts=Fts(match_string="machine learning") + ... ) """ field_name: str id: Optional[str] = None vector: VectorType = None param: Optional[Union[HnswQueryParam, HnswRabitqQueryParam, IVFQueryParam]] = None + fts: Optional[Fts] = None def has_id(self) -> bool: """Check if the query is based on a document ID. @@ -74,11 +96,32 @@ def has_vector(self) -> bool: """ return self.vector is not None and len(self.vector) > 0 + def has_fts(self) -> bool: + """Check if the query contains an FTS (full-text search) condition. + + Returns: + bool: True if `fts` is set with a query_string or match_string. + """ + if self.fts is not None: + return bool(self.fts.query_string) or bool(self.fts.match_string) + return False + def _validate(self) -> None: if self.field_name is None: raise ValueError("Field name cannot be empty") if self.id and self.vector: raise ValueError("Cannot provide both id and vector") + if self.has_fts() and ( + self.has_vector() or self.has_id() or self.param is not None + ): + raise ValueError( + "Cannot combine fts with vector search fields (id/vector/param) in a single Query" + ) + if self.fts is not None and self.fts.query_string and self.fts.match_string: + raise ValueError( + "Cannot provide both query_string and match_string in Fts; " + "they are mutually exclusive" + ) class VectorQuery(Query): diff --git a/python/zvec/zvec.py b/python/zvec/zvec.py index 114fb49c9..9f3e815bb 100644 --- a/python/zvec/zvec.py +++ b/python/zvec/zvec.py @@ -38,7 +38,9 @@ def init( optimize_threads: Optional[int] = None, invert_to_forward_scan_ratio: Optional[float] = None, brute_force_by_keys_ratio: Optional[float] = None, + fts_brute_force_by_keys_ratio: Optional[float] = None, memory_limit_mb: Optional[int] = None, + jieba_dict_dir: Optional[str] = None, ) -> None: """Initialize Zvec with configuration options. @@ -88,11 +90,25 @@ def init( Threshold to use brute-force key lookup over index. Lower → prefer index; higher → prefer brute-force. Range: [0.0, 1.0]. Default: ``0.1``. + fts_brute_force_by_keys_ratio (Optional[float], optional): + Threshold to switch FTS scan from posting-driven to + candidate-driven (brute-force) when the invert filter is + highly selective. Independent from ``brute_force_by_keys_ratio`` + because per-candidate FTS cost is higher. + Range: [0.0, 1.0]. Default: ``0.05``. memory_limit_mb (Optional[int], optional): Soft memory cap in MB. Zvec may throttle or fail operations approaching this limit. If ``None``, inferred from cgroup memory limit * 0.8 (e.g., in Docker). Must be > 0 if provided. + jieba_dict_dir (Optional[str], optional): + Override the default directory containing ``jieba.dict.utf8`` and + ``hmm_model.utf8`` for the jieba FTS tokenizer. When ``None``, the + value previously registered by ``zvec.set_default_jieba_dict_dir`` + (called automatically on ``import zvec`` to point at the wheel's + bundled dict) is preserved. JiebaTokenizer also honors the + ``ZVEC_JIEBA_DICT_DIR`` environment variable and per-field + ``FtsIndexParam.extra_params.jieba_dict_dir`` ahead of this value. Raises: RuntimeError: If Zvec is already initialized. @@ -157,8 +173,12 @@ def init( config_dict["invert_to_forward_scan_ratio"] = invert_to_forward_scan_ratio if brute_force_by_keys_ratio is not None: config_dict["brute_force_by_keys_ratio"] = brute_force_by_keys_ratio + if fts_brute_force_by_keys_ratio is not None: + config_dict["fts_brute_force_by_keys_ratio"] = fts_brute_force_by_keys_ratio if memory_limit_mb is not None: config_dict["memory_limit_mb"] = memory_limit_mb + if jieba_dict_dir is not None: + config_dict["jieba_dict_dir"] = jieba_dict_dir Initialize(config_dict) diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index a3787dc6b..807c86208 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -138,10 +138,10 @@ target_include_directories(zvec_shared # Strip symbols in release builds to reduce library size if(NOT CMAKE_BUILD_TYPE STREQUAL "Debug") if(UNIX AND NOT APPLE) - add_custom_command(TARGET zvec_shared POST_BUILD - COMMAND ${CMAKE_STRIP} $ - COMMENT "Stripping symbols from libzvec.so" - ) + # add_custom_command(TARGET zvec_shared POST_BUILD + # COMMAND ${CMAKE_STRIP} $ + # COMMENT "Stripping symbols from libzvec.so" + # ) elseif(APPLE) add_custom_command(TARGET zvec_shared POST_BUILD COMMAND /usr/bin/strip -x $ diff --git a/src/binding/c/c_api.cc b/src/binding/c/c_api.cc index b23c7ecd8..4a957ace2 100644 --- a/src/binding/c/c_api.cc +++ b/src/binding/c/c_api.cc @@ -627,6 +627,27 @@ float zvec_config_data_get_brute_force_by_keys_ratio( return cpp_config->brute_force_by_keys_ratio; } +zvec_error_code_t zvec_config_data_set_fts_brute_force_by_keys_ratio( + zvec_config_data_t *config, float ratio) { + if (!config) { + SET_LAST_ERROR(ZVEC_ERROR_INVALID_ARGUMENT, "Config pointer is null"); + return ZVEC_ERROR_INVALID_ARGUMENT; + } + auto *cpp_config = reinterpret_cast(config); + cpp_config->fts_brute_force_by_keys_ratio = ratio; + return ZVEC_OK; +} + +float zvec_config_data_get_fts_brute_force_by_keys_ratio( + const zvec_config_data_t *config) { + if (!config) { + return 0.0f; + } + auto *cpp_config = + reinterpret_cast(config); + return cpp_config->fts_brute_force_by_keys_ratio; +} + zvec_error_code_t zvec_config_data_set_optimize_thread_count( zvec_config_data_t *config, uint32_t thread_count) { if (!config) { @@ -648,6 +669,27 @@ uint32_t zvec_config_data_get_optimize_thread_count( return cpp_config->optimize_thread_count; } +zvec_error_code_t zvec_config_data_set_jieba_dict_dir( + zvec_config_data_t *config, const char *dir) { + if (!config) { + SET_LAST_ERROR(ZVEC_ERROR_INVALID_ARGUMENT, "Config pointer is null"); + return ZVEC_ERROR_INVALID_ARGUMENT; + } + auto *cpp_config = reinterpret_cast(config); + cpp_config->jieba_dict_dir = (dir != nullptr) ? std::string(dir) : ""; + return ZVEC_OK; +} + +const char *zvec_config_data_get_jieba_dict_dir( + const zvec_config_data_t *config) { + if (!config) { + return ""; + } + auto *cpp_config = + reinterpret_cast(config); + return cpp_config->jieba_dict_dir.c_str(); +} + // ============================================================================= // Initialization and cleanup interface implementation @@ -703,6 +745,18 @@ bool zvec_is_initialized(void) { return g_initialized.load(); } +void zvec_set_default_jieba_dict_dir(const char *dir) { + zvec::GlobalConfig::Instance().set_default_jieba_dict_dir( + (dir != nullptr) ? std::string(dir) : std::string()); +} + +const char *zvec_get_default_jieba_dict_dir(void) { + // Thread-local buffer keeps c_str() valid until the next call on this thread. + thread_local std::string cached; + cached = zvec::GlobalConfig::Instance().jieba_dict_dir(); + return cached.c_str(); +} + // ============================================================================= // Error handling interface implementation // ============================================================================= @@ -879,6 +933,16 @@ static std::shared_ptr convert_c_index_params_to_cpp( ? std::make_shared(*invert_params) : nullptr; } + case zvec::IndexType::FTS: { + auto *fts_params = + dynamic_cast(cpp_params); + // FtsIndexParams is not copy-constructible; rebuild from accessors. + return fts_params ? std::make_shared( + fts_params->tokenizer_name(), + fts_params->filters(), + fts_params->extra_params()) + : nullptr; + } default: return nullptr; } @@ -1300,6 +1364,11 @@ zvec_index_params_t *zvec_index_params_create(zvec_index_type_t index_type) { new zvec::InvertIndexParams(true, // enable_range_optimization false); // enable_extended_wildcard break; + case ZVEC_INDEX_TYPE_FTS: + // Defaults align with FtsIndexParams default ctor: + // tokenizer="standard", filters=["lowercase"], extra="". + cpp_params = new zvec::FtsIndexParams(); + break; case ZVEC_INDEX_TYPE_HNSW: cpp_params = new zvec::HnswIndexParams( @@ -1635,6 +1704,77 @@ zvec_error_code_t zvec_index_params_get_invert_params(const zvec_index_params_t return ZVEC_OK; } +zvec_error_code_t zvec_index_params_set_fts_params( + zvec_index_params_t *params, const char *tokenizer_name, + const zvec_string_array_t *filters, const char *extra_params) { + if (!params) { + SET_LAST_ERROR(ZVEC_ERROR_INVALID_ARGUMENT, + "Invalid params or not FTS index type"); + return ZVEC_ERROR_INVALID_ARGUMENT; + } + auto *cpp_params = reinterpret_cast(params); + auto *fts_params = dynamic_cast(cpp_params); + if (!fts_params) { + SET_LAST_ERROR(ZVEC_ERROR_INVALID_ARGUMENT, + "Invalid params or not FTS index type"); + return ZVEC_ERROR_INVALID_ARGUMENT; + } + if (tokenizer_name) { + fts_params->set_tokenizer_name(std::string(tokenizer_name)); + } + if (filters) { + std::vector filter_vec; + filter_vec.reserve(filters->count); + for (size_t i = 0; i < filters->count; ++i) { + const auto &item = filters->strings[i]; + filter_vec.emplace_back(item.data ? item.data : "", + item.data ? item.length : 0); + } + fts_params->set_filters(std::move(filter_vec)); + } + if (extra_params) { + fts_params->set_extra_params(std::string(extra_params)); + } + return ZVEC_OK; +} + +zvec_error_code_t zvec_index_params_get_fts_params( + const zvec_index_params_t *params, const char **out_tokenizer_name, + zvec_string_array_t **out_filters, const char **out_extra_params) { + if (!params) { + SET_LAST_ERROR(ZVEC_ERROR_INVALID_ARGUMENT, + "Invalid params or not FTS index type"); + return ZVEC_ERROR_INVALID_ARGUMENT; + } + auto *cpp_params = reinterpret_cast(params); + auto *fts_params = dynamic_cast(cpp_params); + if (!fts_params) { + SET_LAST_ERROR(ZVEC_ERROR_INVALID_ARGUMENT, + "Invalid params or not FTS index type"); + return ZVEC_ERROR_INVALID_ARGUMENT; + } + if (out_tokenizer_name) { + *out_tokenizer_name = fts_params->tokenizer_name().c_str(); + } + if (out_extra_params) { + *out_extra_params = fts_params->extra_params().c_str(); + } + if (out_filters) { + const auto &filters = fts_params->filters(); + zvec_string_array_t *arr = zvec_string_array_create(filters.size()); + if (!arr) { + SET_LAST_ERROR(ZVEC_ERROR_RESOURCE_EXHAUSTED, + "Failed to allocate filters string array"); + return ZVEC_ERROR_RESOURCE_EXHAUSTED; + } + for (size_t i = 0; i < filters.size(); ++i) { + zvec_string_array_add(arr, i, filters[i].c_str()); + } + *out_filters = arr; + } + return ZVEC_OK; +} + // ============================================================================= // FieldSchema management interface implementation // ============================================================================= @@ -2482,6 +2622,8 @@ const char *zvec_index_type_to_string(zvec_index_type_t index_type) { return "FLAT"; case ZVEC_INDEX_TYPE_INVERT: return "INVERT"; + case ZVEC_INDEX_TYPE_FTS: + return "FTS"; default: return "UNKNOWN_INDEX_TYPE"; } @@ -4837,6 +4979,47 @@ bool zvec_query_params_flat_get_is_using_refiner( return ptr->is_using_refiner(); } +// ============================================================================= +// FtsQueryParams implementation - wrapper around zvec::FtsQueryParams +// ============================================================================= + +zvec_fts_query_params_t *zvec_query_params_fts_create( + const char *default_operator) { + ZVEC_TRY_RETURN_NULL( + "Failed to create FtsQueryParams", + auto *params = new zvec::FtsQueryParams(); + if (default_operator && *default_operator) { + params->set_default_operator(std::string(default_operator)); + } return reinterpret_cast(params);) + return nullptr; +} + +void zvec_query_params_fts_destroy(zvec_fts_query_params_t *params) { + if (params) { + delete reinterpret_cast(params); + } +} + +zvec_error_code_t zvec_query_params_fts_set_default_operator( + zvec_fts_query_params_t *params, const char *default_operator) { + if (!params) { + SET_LAST_ERROR(ZVEC_ERROR_INVALID_ARGUMENT, + "FTS query params pointer is null"); + return ZVEC_ERROR_INVALID_ARGUMENT; + } + auto *ptr = reinterpret_cast(params); + ptr->set_default_operator(std::string(default_operator ? default_operator + : "")); + return ZVEC_OK; +} + +const char *zvec_query_params_fts_get_default_operator( + const zvec_fts_query_params_t *params) { + if (!params) return nullptr; + auto *ptr = reinterpret_cast(params); + return ptr->default_operator().c_str(); +} + // ============================================================================= // VectorQuery implementation - owns zvec::VectorQuery via raw pointer // ============================================================================= @@ -5079,6 +5262,95 @@ zvec_error_code_t zvec_vector_query_set_flat_params( return ZVEC_OK; } +zvec_error_code_t zvec_vector_query_set_fts_params( + zvec_vector_query_t *query, zvec_fts_query_params_t *fts_params) { + if (!query || !fts_params) { + SET_LAST_ERROR(ZVEC_ERROR_INVALID_ARGUMENT, + "Query or FTS params pointer is null"); + return ZVEC_ERROR_INVALID_ARGUMENT; + } + + auto *query_ptr = reinterpret_cast(query); + auto *params_ptr = reinterpret_cast(fts_params); + + query_ptr->query_params_.reset(params_ptr); + + return ZVEC_OK; +} + +// ============================================================================= +// Fts payload implementation - wrapper around zvec::Fts (value type) +// ============================================================================= + +zvec_fts_t *zvec_fts_create(void) { + ZVEC_TRY_RETURN_NULL("Failed to create Fts payload", + auto *fts = new zvec::Fts(); + return reinterpret_cast(fts);) + return nullptr; +} + +void zvec_fts_destroy(zvec_fts_t *fts) { + if (fts) { + delete reinterpret_cast(fts); + } +} + +zvec_error_code_t zvec_fts_set_query_string(zvec_fts_t *fts, + const char *query_string) { + if (!fts) { + SET_LAST_ERROR(ZVEC_ERROR_INVALID_ARGUMENT, "Fts pointer is null"); + return ZVEC_ERROR_INVALID_ARGUMENT; + } + auto *ptr = reinterpret_cast(fts); + ptr->query_string_ = query_string ? query_string : ""; + return ZVEC_OK; +} + +zvec_error_code_t zvec_fts_set_match_string(zvec_fts_t *fts, + const char *match_string) { + if (!fts) { + SET_LAST_ERROR(ZVEC_ERROR_INVALID_ARGUMENT, "Fts pointer is null"); + return ZVEC_ERROR_INVALID_ARGUMENT; + } + auto *ptr = reinterpret_cast(fts); + ptr->match_string_ = match_string ? match_string : ""; + return ZVEC_OK; +} + +const char *zvec_fts_get_query_string(const zvec_fts_t *fts) { + if (!fts) return nullptr; + auto *ptr = reinterpret_cast(fts); + return ptr->query_string_.c_str(); +} + +const char *zvec_fts_get_match_string(const zvec_fts_t *fts) { + if (!fts) return nullptr; + auto *ptr = reinterpret_cast(fts); + return ptr->match_string_.c_str(); +} + +zvec_error_code_t zvec_vector_query_set_fts(zvec_vector_query_t *query, + const zvec_fts_t *fts) { + if (!query) { + SET_LAST_ERROR(ZVEC_ERROR_INVALID_ARGUMENT, "Vector query pointer is null"); + return ZVEC_ERROR_INVALID_ARGUMENT; + } + auto *query_ptr = reinterpret_cast(query); + if (!fts) { + query_ptr->fts_ = std::nullopt; + } else { + query_ptr->fts_ = *reinterpret_cast(fts); + } + return ZVEC_OK; +} + +const zvec_fts_t *zvec_vector_query_get_fts(const zvec_vector_query_t *query) { + if (!query) return nullptr; + auto *query_ptr = reinterpret_cast(query); + if (!query_ptr->fts_.has_value()) return nullptr; + return reinterpret_cast(&query_ptr->fts_.value()); +} + // ============================================================================= // GroupByVectorQuery implementation - owns zvec::GroupByVectorQuery via raw // pointer diff --git a/src/binding/python/model/common/python_config.cc b/src/binding/python/model/common/python_config.cc index bbcbb5bdb..8abd42184 100644 --- a/src/binding/python/model/common/python_config.cc +++ b/src/binding/python/model/common/python_config.cc @@ -177,6 +177,24 @@ void ZVecPyConfig::Initialize(pybind11::module_ &m) { data.brute_force_by_keys_ratio = static_cast(v); } + // set fts_brute_force_by_keys_ratio + if (has_key(config_dict, "fts_brute_force_by_keys_ratio")) { + auto v = + get_if(config_dict, "fts_brute_force_by_keys_ratio").value(); + if (v < 0.0 || v > 1.0) { + throw py::value_error( + "fts_brute_force_by_keys_ratio must be in [0.0, 1.0]"); + } + data.fts_brute_force_by_keys_ratio = static_cast(v); + } + + // jieba_dict_dir: optional override of the SDK-registered default. + // Empty value is a no-op (Initialize preserves the SDK default). + if (has_key(config_dict, "jieba_dict_dir")) { + data.jieba_dict_dir = + get_if(config_dict, "jieba_dict_dir").value(); + } + // initialize (contains validate) Status status = GlobalConfig::Instance().Initialize(data); if (!status.ok()) { @@ -184,6 +202,21 @@ void ZVecPyConfig::Initialize(pybind11::module_ &m) { } return py::none(); }); + + // Process-wide setter, independent of Initialize(); called by __init__.py + // on import to register the wheel-bundled dict path. + m.def( + "set_default_jieba_dict_dir", + [](const std::string &dir) { + GlobalConfig::Instance().set_default_jieba_dict_dir(dir); + }, + pybind11::arg("dir"), + "Register the process-wide default jieba dict directory."); + + m.def( + "get_default_jieba_dict_dir", + []() -> std::string { return GlobalConfig::Instance().jieba_dict_dir(); }, + "Read the currently registered default jieba dict directory."); } diff --git a/src/binding/python/model/param/python_param.cc b/src/binding/python/model/param/python_param.cc index 268246cd6..cc59fb404 100644 --- a/src/binding/python/model/param/python_param.cc +++ b/src/binding/python/model/param/python_param.cc @@ -35,6 +35,8 @@ static std::string index_type_to_string(const IndexType type) { return "HNSW_RABITQ"; case IndexType::VAMANA: return "VAMANA"; + case IndexType::FTS: + return "FTS"; default: return "UNDEFINED"; } @@ -251,6 +253,88 @@ Note: Prefix search is always enabled regardless of this setting. t[1].cast()); })); + // binding fts index params + py::class_> + fts_index_params(m, "FtsIndexParam", R"pbdoc( +Parameters for configuring a full-text search (FTS) index. + +Controls the tokenizer pipeline used during indexing and querying. + +Attributes: + type (IndexType): Always ``IndexType.FTS``. + tokenizer_name (str): Name of the tokenizer (e.g., "standard", "jieba"). + Default is "standard". + filters (list[str]): List of token filter names applied after tokenization. + Default is ["lowercase"]. + extra_params (str): Additional parameters passed to the tokenizer. + Default is "". + +Examples: + >>> params = FtsIndexParam(tokenizer_name="jieba", filters=["lowercase"]) + >>> print(params.tokenizer_name) + jieba +)pbdoc"); + fts_index_params + .def(py::init, std::string>(), + py::arg("tokenizer_name") = "standard", + py::arg("filters") = std::vector{"lowercase"}, + py::arg("extra_params") = "", + R"pbdoc( +Constructs an FtsIndexParam instance. + +Args: + tokenizer_name (str, optional): Tokenizer name. Defaults to "standard". + filters (list[str], optional): Token filter names. Defaults to ["lowercase"]. + extra_params (str, optional): Extra tokenizer parameters. Defaults to "". +)pbdoc") + .def_property_readonly("tokenizer_name", &FtsIndexParams::tokenizer_name, + "str: Name of the tokenizer.") + .def_property_readonly("filters", &FtsIndexParams::filters, + "list[str]: Token filter names.") + .def_property_readonly("extra_params", &FtsIndexParams::extra_params, + "str: Additional tokenizer parameters.") + .def( + "to_dict", + [](const FtsIndexParams &self) -> py::dict { + py::dict dict; + dict["type"] = index_type_to_string(self.type()); + dict["tokenizer_name"] = self.tokenizer_name(); + dict["filters"] = self.filters(); + dict["extra_params"] = self.extra_params(); + return dict; + }, + "Convert to dictionary with all fields") + .def("__repr__", + [](const FtsIndexParams &self) -> std::string { + std::string filters_str = "["; + for (size_t i = 0; i < self.filters().size(); ++i) { + if (i > 0) { + filters_str += ","; + } + filters_str += "\"" + self.filters()[i] + "\""; + } + filters_str += "]"; + return "{" + "\"type\":\"" + + index_type_to_string(self.type()) + + "\", \"tokenizer_name\":\"" + self.tokenizer_name() + + "\", \"filters\":" + filters_str + ", \"extra_params\":\"" + + self.extra_params() + "\"}"; + }) + .def(py::pickle( + [](const FtsIndexParams &self) { + return py::make_tuple(self.tokenizer_name(), self.filters(), + self.extra_params()); + }, + [](py::tuple t) { + if (t.size() != 3) { + throw std::runtime_error("Invalid state for FtsIndexParams"); + } + return std::make_shared( + t[0].cast(), t[1].cast>(), + t[2].cast()); + })); + // binding base vector index params py::class_> vector_params(m, "VectorIndexParam", R"pbdoc( @@ -1102,6 +1186,64 @@ Constructs a VamanaQueryParam instance. obj->set_is_using_refiner(t[3].cast()); return obj; })); + + // binding fts query params + py::class_> + fts_query_params(m, "FtsQueryParam", R"pbdoc( +Query parameters for full-text search (FTS) index. + +Controls the default boolean operator used to combine adjacent bare terms +in a query string. + +Attributes: + type (IndexType): Always ``IndexType.FTS``. + default_operator (str): Default boolean operator for adjacent bare terms. + Supported values (case-insensitive): "OR" (default), "AND". + +Examples: + >>> params = FtsQueryParam(default_operator="AND") + >>> print(params.default_operator) + AND +)pbdoc"); + fts_query_params + .def(py::init([](const std::string &default_operator) { + auto params = std::make_shared(); + if (!default_operator.empty()) { + params->set_default_operator(default_operator); + } + return params; + }), + py::arg("default_operator") = "", + R"pbdoc( +Constructs an FtsQueryParam instance. + +Args: + default_operator (str, optional): Default boolean operator for adjacent + bare terms. Supported: "OR", "AND". Defaults to "" (uses engine default). +)pbdoc") + .def_property_readonly("default_operator", + &FtsQueryParams::default_operator, + "str: Default boolean operator for bare terms.") + .def("__repr__", + [](const FtsQueryParams &self) -> std::string { + return "{" + "\"type\":\"" + + index_type_to_string(self.type()) + + "\", \"default_operator\":\"" + self.default_operator() + + "\"}"; + }) + .def(py::pickle( + [](const FtsQueryParams &self) { + return py::make_tuple(self.default_operator()); + }, + [](py::tuple t) { + if (t.size() != 1) { + throw std::runtime_error("Invalid state for FtsQueryParams"); + } + auto obj = std::make_shared(); + obj->set_default_operator(t[0].cast()); + return obj; + })); } void ZVecPyParams::bind_options(py::module_ &m) { // binding collection options @@ -1372,6 +1514,24 @@ Constructs an AlterColumnOption instance. } void ZVecPyParams::bind_vector_query(py::module_ &m) { + // bind Fts + py::class_(m, "_Fts") + .def(py::init<>()) + .def_readwrite("query_string", &Fts::query_string_) + .def_readwrite("match_string", &Fts::match_string_) + .def(py::pickle( + [](const Fts &self) { + return py::make_tuple(self.query_string_, self.match_string_); + }, + [](py::tuple t) { + if (t.size() != 2) + throw std::runtime_error("Invalid pickle data for Fts"); + Fts obj{}; + obj.query_string_ = t[0].cast(); + obj.match_string_ = t[1].cast(); + return obj; + })); + py::class_(m, "_VectorQuery") .def(py::init<>()) // properties @@ -1381,6 +1541,21 @@ void ZVecPyParams::bind_vector_query(py::module_ &m) { .def_readwrite("include_vector", &VectorQuery::include_vector_) .def_readwrite("query_params", &VectorQuery::query_params_) .def_readwrite("output_fields", &VectorQuery::output_fields_) + .def_property( + "fts", + [](const VectorQuery &self) -> py::object { + if (self.fts_.has_value()) { + return py::cast(self.fts_.value()); + } + return py::none(); + }, + [](VectorQuery &self, const py::object &obj) { + if (obj.is_none()) { + self.fts_ = std::nullopt; + } else { + self.fts_ = obj.cast(); + } + }) // vector .def("set_vector", [](VectorQuery &self, const FieldSchema &field_schema, @@ -1588,11 +1763,16 @@ void ZVecPyParams::bind_vector_query(py::module_ &m) { return py::make_tuple( self.topk_, self.field_name_, self.query_vector_, self.query_sparse_indices_, self.query_sparse_values_, - self.filter_, self.include_vector_, self.output_fields_, - self.query_params_ ? py::cast(self.query_params_) : py::none()); + self.filter_, self.include_vector_, + self.output_fields_.has_value() + ? py::cast(self.output_fields_.value()) + : py::none(), + self.query_params_ ? py::cast(self.query_params_) : py::none(), + self.fts_.has_value() ? py::cast(self.fts_.value()) + : py::none()); }, [](py::tuple t) { - if (t.size() != 9) + if (t.size() != 10) throw std::runtime_error("Invalid pickle data for VectorQuery"); VectorQuery obj{}; @@ -1603,11 +1783,16 @@ void ZVecPyParams::bind_vector_query(py::module_ &m) { obj.query_sparse_values_ = t[4].cast(); obj.filter_ = t[5].cast(); obj.include_vector_ = t[6].cast(); - obj.output_fields_ = t[7].cast>(); + if (!t[7].is_none()) { + obj.output_fields_ = t[7].cast>(); + } if (!t[8].is_none()) { obj.query_params_ = t[8].cast(); } + if (!t[9].is_none()) { + obj.fts_ = t[9].cast(); + } return obj; })); } diff --git a/src/db/CMakeLists.txt b/src/db/CMakeLists.txt index b2689278a..4a756a880 100644 --- a/src/db/CMakeLists.txt +++ b/src/db/CMakeLists.txt @@ -13,6 +13,23 @@ cc_directory(sqlengine) file(GLOB_RECURSE ALL_DB_SRCS *.cc *.c *.h) +# Ensure bitpacked_simd_sse41.cc is compiled with SSE4.1 flag and +# bitpacked_simd_avx2.cc with AVX2 flag in the packed zvec_db target as well +# (they are also compiled separately in zvec_index). +if(NOT ANDROID AND AUTO_DETECT_ARCH) + if(HOST_ARCH MATCHES "^(x86|x64)$") + setup_compiler_march_for_x86(_DB_MARCH_SSE _DB_MARCH_AVX2 _DB_MARCH_AVX512 _DB_MARCH_AVX512FP16) + set_source_files_properties( + ${CMAKE_CURRENT_SOURCE_DIR}/index/column/fts_column/posting/bitpacked_simd_sse41.cc + PROPERTIES COMPILE_FLAGS "${_DB_MARCH_SSE}" + ) + set_source_files_properties( + ${CMAKE_CURRENT_SOURCE_DIR}/index/column/fts_column/posting/bitpacked_simd_avx2.cc + PROPERTIES COMPILE_FLAGS "${_DB_MARCH_AVX2}" + ) + endif() +endif() + cc_library( NAME zvec_db STATIC STRICT SRCS_NO_GLOB PACKED SRCS ${ALL_DB_SRCS} ${CMAKE_CURRENT_BINARY_DIR}/proto/zvec.pb.cc @@ -26,6 +43,8 @@ cc_library( rocksdb antlr4 libprotobuf + FastPFOR + cppjieba Arrow::arrow_static Arrow::arrow_compute Arrow::arrow_dataset diff --git a/src/db/collection.cc b/src/db/collection.cc index 36f9a7420..d0c3ca667 100644 --- a/src/db/collection.cc +++ b/src/db/collection.cc @@ -1585,8 +1585,13 @@ Result CollectionImpl::Query(const VectorQuery &query) const { CHECK_DESTROY_RETURN_STATUS_EXPECTED(destroyed_, false); VectorQuery sanitized = query; - auto s = sanitized.validate_and_sanitize( - schema_->get_vector_field(sanitized.field_name_)); + // When field_name_ is set, use get_field to retrieve the schema uniformly. + // validate_and_sanitize checks that the field type matches the query type + // (FTS query requires an FTS field, vector query requires a vector field). + const FieldSchema *field_schema = + sanitized.field_name_.empty() ? nullptr + : schema_->get_field(sanitized.field_name_); + auto s = sanitized.validate_and_sanitize(field_schema); CHECK_RETURN_STATUS_EXPECTED(s); auto segments = get_all_segments(); diff --git a/src/db/common/config.cc b/src/db/common/config.cc index 5938f5375..57eaae812 100644 --- a/src/db/common/config.cc +++ b/src/db/common/config.cc @@ -37,7 +37,9 @@ GlobalConfig::ConfigData::ConfigData() query_thread_count(CgroupUtil::getCpuLimit()), invert_to_forward_scan_ratio(0.9), brute_force_by_keys_ratio(0.1), - optimize_thread_count(CgroupUtil::getCpuLimit()) {} + fts_brute_force_by_keys_ratio(0.05), + optimize_thread_count(CgroupUtil::getCpuLimit()), + jieba_dict_dir() {} Status GlobalConfig::Validate(const ConfigData &config) const { if (config.memory_limit_bytes < MIN_MEMORY_LIMIT_BYTES) { @@ -69,6 +71,13 @@ Status GlobalConfig::Validate(const ConfigData &config) const { "brute_force_by_keys_ratio must be between 0 and 1"); } + // Validate fts_brute_force_by_keys_ratio (should be between 0 and 1) + if (config.fts_brute_force_by_keys_ratio < 0.0f || + config.fts_brute_force_by_keys_ratio > 1.0f) { + return Status::InvalidArgument( + "fts_brute_force_by_keys_ratio must be between 0 and 1"); + } + // Validate optimize thread count if (config.optimize_thread_count == 0) { return Status::InvalidArgument( @@ -116,7 +125,16 @@ Status GlobalConfig::Initialize(const ConfigData &config) { auto s = Validate(config); CHECK_RETURN_STATUS(s); - config_ = config; + // Preserve the SDK-set jieba_dict_dir when caller didn't specify one. + // Lock spans the bulk assign so readers never see a half-written string. + { + std::lock_guard lk(mutex_); + std::string final_jieba = config.jieba_dict_dir.empty() + ? config_.jieba_dict_dir + : config.jieba_dict_dir; + config_ = config; + config_.jieba_dict_dir = std::move(final_jieba); + } s = LogUtil::Init(log_dir(), log_file_basename(), int(log_level()), log_type(), log_file_size(), log_overdue_days()); @@ -131,6 +149,16 @@ Status GlobalConfig::Initialize(const ConfigData &config) { return Status::OK(); } +void GlobalConfig::set_default_jieba_dict_dir(const std::string &dir) { + std::lock_guard lk(mutex_); + config_.jieba_dict_dir = dir; +} + +std::string GlobalConfig::jieba_dict_dir() const { + std::lock_guard lk(mutex_); + return config_.jieba_dict_dir; +} + uint64_t GlobalConfig::memory_limit_bytes() const noexcept { return config_.memory_limit_bytes; } diff --git a/src/db/common/constants.h b/src/db/common/constants.h index f987aa289..3aa0512a5 100644 --- a/src/db/common/constants.h +++ b/src/db/common/constants.h @@ -80,5 +80,11 @@ const std::string INVERT_KEY_SEALED{"$ZVEC$SEALED"}; const uint32_t INVERT_ID_LIST_SIZE_THRESHOLD = 3; +// FTS (Full-Text Search) column family name suffixes and shared CF name +constexpr const char *kFtsPositionsSuffix = "$POSITIONS"; +constexpr const char *kFtsTfSuffix = "$TF"; +constexpr const char *kFtsMaxTfSuffix = "$MAX_TF"; +constexpr const char *kFtsDocLenSuffix = "$DOC_LEN"; +constexpr const char *kFtsStatCfName = "$FTS_STAT"; } // namespace zvec diff --git a/src/db/common/file_helper.h b/src/db/common/file_helper.h index 065c80bd7..c983f4a86 100644 --- a/src/db/common/file_helper.h +++ b/src/db/common/file_helper.h @@ -139,6 +139,16 @@ class FileHelper { ailego::StringHelper::Concat("scalar.index.", block_id, ".rocksdb")); } + // e.g.: **/seg1/fts.rocksdb + static const std::string MakeFtsIndexPath(const std::string &path, + uint32_t seg_id) { + return ailego::FileHelper::PathJoin(path, seg_id, "fts.rocksdb"); + } + + static const std::string MakeFtsIndexPath(const std::string &seg_path) { + return ailego::FileHelper::PathJoin(seg_path, "fts.rocksdb"); + } + static const std::string MakeVectorIndexPath(const std::string &path, const std::string &column, uint32_t seg_id, diff --git a/src/db/common/rocksdb_context.cc b/src/db/common/rocksdb_context.cc index 42867cc7e..4bad92793 100644 --- a/src/db/common/rocksdb_context.cc +++ b/src/db/common/rocksdb_context.cc @@ -15,6 +15,8 @@ #include "rocksdb_context.h" #include +#include +#include #include #include #include @@ -27,39 +29,14 @@ namespace zvec { Status RocksdbContext::create( const std::string &db_path, std::shared_ptr merge_op) { - std::lock_guard lock(mutex_); - - if (db_) { - LOG_ERROR("RocksDB[%s] is already opened", db_path_.c_str()); - return Status::PermissionDenied(); - } - - if (auto s = validate_and_set_db_path(db_path, false); !s.ok()) { - return s; - } - - create_opts_.create_if_missing = true; - prepare_options(merge_op); - - // Open RocksDB - rocksdb::DB *db; - if (auto s = rocksdb::DB::Open(create_opts_, db_path, &db); !s.ok()) { - LOG_ERROR("Failed to create RocksDB[%s], code[%d], reason[%s]", - db_path.c_str(), s.code(), s.ToString().c_str()); - return Status::InternalError(); - } - - db_.reset(db); - read_only_ = false; - write_opts_.disableWAL = true; - LOG_DEBUG("Created RocksDB[%s]", db_path.c_str()); - return Status::OK(); + return create(Args{db_path, {}, std::move(merge_op), {}}); } -Status RocksdbContext::create( - const std::string &db_path, const std::vector &column_names, - std::shared_ptr merge_op) { +Status RocksdbContext::create(Args args) { + per_cf_merge_ops_ = std::move(args.per_cf_merge_ops); + enable_hash_skiplist_ = args.enable_hash_skiplist; + std::lock_guard lock(mutex_); if (db_) { @@ -67,26 +44,24 @@ Status RocksdbContext::create( return Status::PermissionDenied(); } - if (auto s = validate_and_set_db_path(db_path, false); !s.ok()) { + if (auto s = validate_and_set_db_path(args.db_path, false); !s.ok()) { return s; } create_opts_.create_if_missing = true; - prepare_options(merge_op); + prepare_options(std::move(args.merge_op)); - // Open RocksDB rocksdb::DB *db; - rocksdb::Status s = rocksdb::DB::Open(create_opts_, db_path, &db); + rocksdb::Status s = rocksdb::DB::Open(create_opts_, args.db_path, &db); if (!s.ok()) { LOG_ERROR("Failed to create RocksDB[%s], code[%d], reason[%s]", - db_path.c_str(), s.code(), s.ToString().c_str()); + args.db_path.c_str(), s.code(), s.ToString().c_str()); return Status::InternalError(); } db_.reset(db); - // Create column families bool has_default = false; - for (auto const &column_name : column_names) { + for (const auto &column_name : args.column_names) { if (column_name == rocksdb::kDefaultColumnFamilyName) { cf_handles_.push_back(db->DefaultColumnFamily()); has_default = true; @@ -94,10 +69,14 @@ Status RocksdbContext::create( } rocksdb::ColumnFamilyHandle *cf_handle{nullptr}; rocksdb::ColumnFamilyOptions cf_options(create_opts_); + auto it = per_cf_merge_ops_.find(column_name); + if (it != per_cf_merge_ops_.end() && it->second) { + cf_options.merge_operator = it->second; + } s = db->CreateColumnFamily(cf_options, column_name, &cf_handle); if (!s.ok()) { LOG_ERROR("Failed to create cf[%s] in RocksDB[%s], code[%d], reason[%s]", - column_name.c_str(), db_path.c_str(), s.code(), + column_name.c_str(), args.db_path.c_str(), s.code(), s.ToString().c_str()); delete_cf_handles(); db->Close(); @@ -112,53 +91,28 @@ Status RocksdbContext::create( read_only_ = false; write_opts_.disableWAL = true; - LOG_DEBUG("Created RocksDB[%s]", db_path.c_str()); + LOG_DEBUG("Created RocksDB[%s] with Args", args.db_path.c_str()); return Status::OK(); } -Status RocksdbContext::open(const std::string &db_path, bool read_only, - std::shared_ptr merge_op) { - std::lock_guard lock(mutex_); - - if (db_) { - LOG_ERROR("RocksDB[%s] is already opened", db_path_.c_str()); - return Status::PermissionDenied(); - } - - if (auto s = validate_and_set_db_path(db_path, true); !s.ok()) { - return s; - } - - create_opts_.create_if_missing = false; - prepare_options(merge_op); +Status RocksdbContext::create( + const std::string &db_path, const std::vector &column_names, + std::shared_ptr merge_op) { + return create(Args{db_path, column_names, std::move(merge_op), {}}); +} - // Open RocksDB - rocksdb::DB *db; - rocksdb::Status s; - if (read_only) { - s = rocksdb::DB::OpenForReadOnly(create_opts_, db_path, &db); - } else { - s = rocksdb::DB::Open(create_opts_, db_path, &db); - } - if (!s.ok()) { - LOG_ERROR("Failed to open RocksDB[%s], code[%d], reason[%s]", - db_path.c_str(), s.code(), s.ToString().c_str()); - return Status::InternalError(); - } - db_.reset(db); - read_only_ = read_only; - write_opts_.disableWAL = true; - LOG_DEBUG("Opened RocksDB[%s]", db_path.c_str()); - return Status::OK(); +Status RocksdbContext::open(const std::string &db_path, bool read_only, + std::shared_ptr merge_op) { + return open(Args{db_path, {}, std::move(merge_op), {}}, read_only); } -Status RocksdbContext::open(const std::string &db_path, - const std::vector &column_names, - bool read_only, - std::shared_ptr merge_op) { +Status RocksdbContext::open(Args args, bool read_only) { + per_cf_merge_ops_ = std::move(args.per_cf_merge_ops); + enable_hash_skiplist_ = args.enable_hash_skiplist; + std::lock_guard lock(mutex_); if (db_) { @@ -166,36 +120,44 @@ Status RocksdbContext::open(const std::string &db_path, return Status::PermissionDenied(); } - if (auto s = validate_and_set_db_path(db_path, true); !s.ok()) { + if (auto s = validate_and_set_db_path(args.db_path, true); !s.ok()) { return s; } create_opts_.create_if_missing = false; - prepare_options(merge_op); + prepare_options(std::move(args.merge_op)); - // Set up column families rocksdb::Status s; std::vector existing_cf_names{}; std::vector cf_descriptors{}; - s = rocksdb::DB::ListColumnFamilies(create_opts_, db_path, + s = rocksdb::DB::ListColumnFamilies(create_opts_, args.db_path, &existing_cf_names); if (!s.ok()) { LOG_ERROR("Failed to list cf in RocksDB[%s], code[%d], reason[%s]", - db_path.c_str(), s.code(), s.ToString().c_str()); + args.db_path.c_str(), s.code(), s.ToString().c_str()); return Status::InternalError(); } - rocksdb::ColumnFamilyOptions cf_options(create_opts_); - if (column_names.empty()) { // Get all column families from DB - for (auto const &column_name : existing_cf_names) { - cf_descriptors.emplace_back(column_name, cf_options); + + auto make_cf_options = [&](const std::string &cf_name) { + rocksdb::ColumnFamilyOptions cf_options(create_opts_); + auto it = per_cf_merge_ops_.find(cf_name); + if (it != per_cf_merge_ops_.end() && it->second) { + cf_options.merge_operator = it->second; + } + return cf_options; + }; + + if (args.column_names.empty()) { + for (const auto &column_name : existing_cf_names) { + cf_descriptors.emplace_back(column_name, make_cf_options(column_name)); } } else { bool has_default = false; - for (const auto &column_name : column_names) { + for (const auto &column_name : args.column_names) { if (std::find(existing_cf_names.begin(), existing_cf_names.end(), column_name) == existing_cf_names.end()) { - LOG_ERROR("Column family[%s] does not exist in RocksDB[%s]", - column_name.c_str(), db_path.c_str()); + LOG_WARN("Column family[%s] does not exist in RocksDB[%s]", + column_name.c_str(), args.db_path.c_str()); return Status::InvalidArgument(); } if (column_name == rocksdb::kDefaultColumnFamilyName) { @@ -203,43 +165,51 @@ Status RocksdbContext::open(const std::string &db_path, } } if (read_only) { - for (const auto &column_name : column_names) { - cf_descriptors.emplace_back(column_name, cf_options); + for (const auto &column_name : args.column_names) { + cf_descriptors.emplace_back(column_name, make_cf_options(column_name)); } if (!has_default) { - cf_descriptors.emplace_back(rocksdb::kDefaultColumnFamilyName, - cf_options); + cf_descriptors.emplace_back( + rocksdb::kDefaultColumnFamilyName, + make_cf_options(rocksdb::kDefaultColumnFamilyName)); } - } else { // Rocksdb must be opened with all column families in write mode - for (auto const &column_name : existing_cf_names) { - cf_descriptors.emplace_back(column_name, cf_options); + } else { + for (const auto &column_name : existing_cf_names) { + cf_descriptors.emplace_back(column_name, make_cf_options(column_name)); } } } - // Open RocksDB rocksdb::DB *db; if (read_only) { - s = rocksdb::DB::OpenForReadOnly(create_opts_, db_path, cf_descriptors, + s = rocksdb::DB::OpenForReadOnly(create_opts_, args.db_path, cf_descriptors, &cf_handles_, &db); } else { - s = rocksdb::DB::Open(create_opts_, db_path, cf_descriptors, &cf_handles_, - &db); + s = rocksdb::DB::Open(create_opts_, args.db_path, cf_descriptors, + &cf_handles_, &db); } if (!s.ok()) { LOG_ERROR("Failed to open RocksDB[%s], code[%d], reason[%s]", - db_path.c_str(), s.code(), s.ToString().c_str()); + args.db_path.c_str(), s.code(), s.ToString().c_str()); return Status::InternalError(); } db_.reset(db); read_only_ = read_only; write_opts_.disableWAL = true; - LOG_DEBUG("Opened RocksDB[%s]", db_path.c_str()); + LOG_DEBUG("Opened RocksDB[%s] with Args", args.db_path.c_str()); return Status::OK(); } +Status RocksdbContext::open(const std::string &db_path, + const std::vector &column_names, + bool read_only, + std::shared_ptr merge_op) { + return open(Args{db_path, column_names, std::move(merge_op), {}}, read_only); +} + + Status RocksdbContext::validate_and_set_db_path(const std::string &db_path, bool should_exist) { if (db_path.empty()) { @@ -321,6 +291,18 @@ void RocksdbContext::prepare_options( // Disable direct reads (use buffered I/O instead) create_opts_.use_direct_reads = false; + + // Hash skip list memtable for prefix-based lookups + if (enable_hash_skiplist_) { + create_opts_.prefix_extractor.reset(rocksdb::NewCappedPrefixTransform(8)); + create_opts_.memtable_factory.reset(rocksdb::NewHashSkipListRepFactory( + 1000000, // bucket_count + 4, // skiplist_height + 4 // skiplist_branching_factor + )); + create_opts_.allow_concurrent_memtable_write = false; + read_opts_.total_order_seek = true; + } } @@ -443,8 +425,13 @@ Status RocksdbContext::create_cf(const std::string &cf_name) { } rocksdb::ColumnFamilyHandle *cf_handle{nullptr}; - auto s = db_->CreateColumnFamily(rocksdb::ColumnFamilyOptions(create_opts_), - cf_name, &cf_handle); + rocksdb::ColumnFamilyOptions cf_options(create_opts_); + // Apply per-CF merge operator if one was registered for this CF name + auto it = per_cf_merge_ops_.find(cf_name); + if (it != per_cf_merge_ops_.end() && it->second) { + cf_options.merge_operator = it->second; + } + auto s = db_->CreateColumnFamily(cf_options, cf_name, &cf_handle); if (s.ok()) { cf_handles_.push_back(cf_handle); LOG_DEBUG("Created cf[%s] in RocksDB[%s]", cf_name.c_str(), @@ -590,6 +577,4 @@ size_t RocksdbContext::count() { return 0; } } - - } // namespace zvec \ No newline at end of file diff --git a/src/db/common/rocksdb_context.h b/src/db/common/rocksdb_context.h index 302d7ca8c..d47d90245 100644 --- a/src/db/common/rocksdb_context.h +++ b/src/db/common/rocksdb_context.h @@ -16,7 +16,12 @@ #pragma once +#include +#include +#include +#include #include +#include #include #include @@ -27,9 +32,18 @@ namespace zvec { // A very thin wrapper around RocksDB struct RocksdbContext { public: + struct Args { + std::string db_path; + std::vector column_names; + std::shared_ptr merge_op; + std::unordered_map> + per_cf_merge_ops; + bool enable_hash_skiplist = false; + }; std::unique_ptr db_{nullptr}; std::string db_path_; bool read_only_; + bool enable_hash_skiplist_{false}; std::vector cf_handles_; rocksdb::Options create_opts_; rocksdb::WriteOptions write_opts_; @@ -37,6 +51,9 @@ struct RocksdbContext { rocksdb::FlushOptions flush_opts_; rocksdb::CompactRangeOptions compact_range_opts_; std::mutex mutex_; + // Per-CF merge operators (keyed by CF name) + std::unordered_map> + per_cf_merge_ops_; public: @@ -79,7 +96,7 @@ struct RocksdbContext { rocksdb::ColumnFamilyHandle *get_cf(const std::string &cf_name); - // Create a column family + // Create a column family (uses per_cf_merge_ops_ if set for cf_name) Status create_cf(const std::string &cf_name); @@ -103,6 +120,13 @@ struct RocksdbContext { size_t count(); + // Create a Rocksdb instance from Args + Status create(Args args); + + // Open an existing Rocksdb instance from Args + Status open(Args args, bool read_only); + + private: using FILE = ailego::File; diff --git a/src/db/index/CMakeLists.txt b/src/db/index/CMakeLists.txt index 4420050e6..d4efc32c9 100644 --- a/src/db/index/CMakeLists.txt +++ b/src/db/index/CMakeLists.txt @@ -1,9 +1,25 @@ include(${PROJECT_ROOT_DIR}/cmake/bazel.cmake) include(${PROJECT_ROOT_DIR}/cmake/option.cmake) +if(NOT ANDROID AND AUTO_DETECT_ARCH) + if (HOST_ARCH MATCHES "^(x86|x64)$") + setup_compiler_march_for_x86(INDEX_MARCH_FLAG_SSE INDEX_MARCH_FLAG_AVX2 INDEX_MARCH_FLAG_AVX512 INDEX_MARCH_FLAG_AVX512FP16) + set_source_files_properties( + ${CMAKE_CURRENT_SOURCE_DIR}/column/fts_column/posting/bitpacked_simd_sse41.cc + PROPERTIES + COMPILE_FLAGS "${INDEX_MARCH_FLAG_SSE}" + ) + set_source_files_properties( + ${CMAKE_CURRENT_SOURCE_DIR}/column/fts_column/posting/bitpacked_simd_avx2.cc + PROPERTIES + COMPILE_FLAGS "${INDEX_MARCH_FLAG_AVX2}" + ) + endif() +endif() + cc_library( NAME zvec_index STATIC STRICT - SRCS *.cc segment/*.cc column/vector_column/*.cc column/inverted_column/*.cc storage/*.cc storage/wal/*.cc common/*.cc + SRCS *.cc segment/*.cc column/vector_column/*.cc column/inverted_column/*.cc column/fts_column/*.cc column/fts_column/tokenizer/*.cc column/fts_column/posting/*.cc column/fts_column/iterator/*.cc storage/*.cc storage/wal/*.cc common/*.cc LIBS zvec_common zvec_proto rocksdb @@ -11,6 +27,8 @@ cc_library( Arrow::arrow_static Arrow::arrow_compute Arrow::arrow_dataset + cppjieba + FastPFOR INCS . ${PROJECT_ROOT_DIR}/src VERSION "${PROXIMA_ZVEC_VERSION}" ) diff --git a/src/db/index/column/fts_column/FtsLexer.g4 b/src/db/index/column/fts_column/FtsLexer.g4 new file mode 100644 index 000000000..1456e4ba5 --- /dev/null +++ b/src/db/index/column/fts_column/FtsLexer.g4 @@ -0,0 +1,59 @@ +// Copyright 2025-present the zvec project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +lexer grammar FtsLexer; + +// ── Boolean operators ──────────────────────────────────────────────────────── +OR : [Oo][Rr]; +AND : [Aa][Nn][Dd]; +NOT : [Nn][Oo][Tt]; + +// ── Modifier prefixes ──────────────────────────────────────────────────────── +PLUS_SIGN: '+'; +MINUS_SIGN: '-'; + +COLON: ':'; +CARET: '^'; + +// ── Grouping ───────────────────────────────────────────────────────────────── +LP: '('; +RP: ')'; + +// ── Quoted strings (phrase queries) ────────────────────────────────────────── +DQUOTA_STRING + : '"' (~["\\\r\n] | '\\' .)* '"' + ; + + +fragment ASCII_ALNUM : [A-Za-z0-9_]; +fragment ESCAPED_CHAR + : '\\' [-+=&|!(){}[\]^"~*?:\\/] + ; +fragment UNI_CHAR : [\u0080-\uFFFF]; +fragment TERM_START : ASCII_ALNUM | UNI_CHAR; +fragment TERM_BODY : ASCII_ALNUM | UNI_CHAR | [._#/%\-'@] | ESCAPED_CHAR; + +// Matches sequences of letters, digits, underscores and hyphens that start +// with a letter or underscore (same as the original SQLLexer REGULAR_ID). +REGULAR_ID: [A-Za-z_] [A-Za-z0-9_\-]*; + +NUMBER: [0-9]+ ('.' [0-9]+)?; + +// Generic term +TERM: TERM_START TERM_BODY*; + +// ── Whitespace (skip) ───────────────────────────────────────────────────────── +SPACES: [ \t\r\n]+ -> skip; + +DEFAULT: . ; diff --git a/src/db/index/column/fts_column/FtsParser.g4 b/src/db/index/column/fts_column/FtsParser.g4 new file mode 100644 index 000000000..82613748e --- /dev/null +++ b/src/db/index/column/fts_column/FtsParser.g4 @@ -0,0 +1,92 @@ +// Copyright 2025-present the zvec project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +parser grammar FtsParser; + +options { tokenVocab = FtsLexer; } + +// ── Entry point ─────────────────────────────────────────────────────────────── +fts_query_unit + : fts_or_expr EOF + ; + +// ── OR (lowest precedence) ──────────────────────────────────────────────────── +fts_or_expr + : fts_and_expr (OR fts_and_expr)* + ; + +// ── AND / NOT (same precedence) ────────────────────────────────────────────── +// `a NOT b` is the binary `a AND NOT b` operator: documents matching `a` +// excluding those matching `b`. The explicit form `a AND NOT b` is also +// accepted for readability; semantically it is identical to `a NOT b`. +fts_and_expr + : fts_seq_expr ((AND NOT? | NOT) fts_seq_expr)* + ; + +// ── Implicit adjacency ──────────────────────────────────────────────────────── +// Adjacent atoms without an explicit operator are grouped together; the +// builder treats them as an implicit OR (same behaviour as the original SQL +// parser). +fts_seq_expr + : fts_unary+ + ; + +// ── Unary modifier ──────────────────────────────────────────────────────────── +// NOT is *not* a unary modifier here — it is consumed by fts_and_expr above +// as a binary operator. Unary modifiers are limited to `+` (must) and `-` +// (must_not). +fts_unary + : PLUS_SIGN fts_atom # must_atom + | MINUS_SIGN fts_atom # must_not_atom + | fts_atom # plain_atom + ; + +// ── Atom: optional field prefix + primary + optional boost ─────────────────── +fts_atom + : fts_field_prefix? fts_primary fts_boost? + ; + +// ── Field prefix: REGULAR_ID ':' ───────────────────────────────────────────── +fts_field_prefix + : REGULAR_ID COLON + ; + +// ── Primary: term | phrase | parenthesised sub-expression ──────────────────── +fts_primary + : fts_term + | fts_phrase + | LP fts_or_expr RP + ; + +// ── Boost: '^' NUMBER ──────────────────────────────────────────────────────── +fts_boost + : CARET NUMBER + ; + +fts_natural_term + : DEFAULT+ // One or more default characters forming a natural language term + ; + +// ── Term: identifier, number, or generic token ─────────────────────────────── +fts_term + : TERM + | REGULAR_ID + | NUMBER + | fts_natural_term + ; + +// ── Phrase: double-quoted string ───────────────────────────────────────────── +fts_phrase + : DQUOTA_STRING + ; diff --git a/src/db/index/column/fts_column/bm25_scorer.cc b/src/db/index/column/fts_column/bm25_scorer.cc new file mode 100644 index 000000000..ed8f34fd3 --- /dev/null +++ b/src/db/index/column/fts_column/bm25_scorer.cc @@ -0,0 +1,186 @@ +// Copyright 2025-present the zvec project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "bm25_scorer.h" +#include +#include +#include +#include "fts_utils.h" + +namespace zvec::fts { + +// ============================================================ +// BM25Scorer implementation +// ============================================================ + +int BM25Scorer::load_segment_stats(const std::string &field_name, + RocksdbContext *ctx, + rocksdb::ColumnFamilyHandle *stat_cf) { + if (!ctx || !stat_cf) { + LOG_WARN("BM25Scorer::load_segment_stats: null ctx/stat_cf for field[%s]", + field_name.c_str()); + return -1; + } + + // Read total_docs + std::string total_docs_value; + auto ret = ctx->db_->Get(ctx->read_opts_, stat_cf, + make_total_docs_key(field_name), &total_docs_value); + if (!ret.ok()) { + LOG_ERROR( + "BM25Scorer::load_segment_stats: failed to read total_docs. " + "field[%s]", + field_name.c_str()); + return -1; + } + if (total_docs_value.size() < sizeof(uint64_t)) { + LOG_ERROR( + "BM25Scorer::load_segment_stats: total_docs value too small. " + "field[%s] value_size[%zu]", + field_name.c_str(), total_docs_value.size()); + return -1; + } + uint64_t total_docs = decode_uint64_value(total_docs_value.data()); + stats_.total_docs.store(total_docs, std::memory_order_release); + + // Read total_tokens + std::string total_tokens_value; + auto status = + ctx->db_->Get(ctx->read_opts_, stat_cf, make_total_tokens_key(field_name), + &total_tokens_value); + if (!status.ok()) { + LOG_ERROR( + "BM25Scorer::load_segment_stats: failed to read total_tokens. " + "field[%s]", + field_name.c_str()); + return -1; + } + if (total_tokens_value.size() < sizeof(uint64_t)) { + LOG_ERROR( + "BM25Scorer::load_segment_stats: total_tokens value too small. " + "field[%s] value_size[%zu]", + field_name.c_str(), total_tokens_value.size()); + return -1; + } + uint64_t total_tokens = decode_uint64_value(total_tokens_value.data()); + stats_.total_tokens.store(total_tokens, std::memory_order_release); + + return 0; +} + +float BM25Scorer::idf(uint64_t term_doc_freq) const { + const auto snap = stats_.snapshot(); + if (snap.total_docs == 0) { + return 0.0f; + } + // Robertson-Sparck Jones IDF formula (with smoothing): + // IDF(t) = ln((N - df + 0.5) / (df + 0.5) + 1) + const float total_docs = static_cast(snap.total_docs); + const float df = static_cast(term_doc_freq); + return std::log((total_docs - df + 0.5f) / (df + 0.5f) + 1.0f); +} + +float BM25Scorer::score(uint64_t term_doc_freq, uint32_t term_freq, + uint32_t doc_len) const { + // Take a single snapshot so that IDF and TF normalization use the same + // consistent values of total_docs / total_tokens. + const auto snap = stats_.snapshot(); + if (snap.total_docs == 0) { + return 0.0f; + } + + // IDF + const float total_docs = static_cast(snap.total_docs); + const float df = static_cast(term_doc_freq); + const float idf_value = + std::log((total_docs - df + 0.5f) / (df + 0.5f) + 1.0f); + if (idf_value <= 0.0f) { + return 0.0f; + } + + // TF normalization + const float tf = static_cast(term_freq); + const float doc_length = static_cast(doc_len); + const float avg_dl = snap.avg_doc_len(); + + // BM25 TF normalization formula: + // tf_norm = tf * (k1 + 1) / (tf + k1 * (1 - b + b * |d| / avgdl)) + const float tf_norm = + tf * (params_.k1 + 1.0f) / + (tf + params_.k1 * (1.0f - params_.b + params_.b * doc_length / avg_dl)); + + return idf_value * tf_norm; +} + +float BM25Scorer::score_with_idf(float idf_value, uint32_t term_freq, + uint32_t doc_len) const { + return score_with_idf(idf_value, term_freq, doc_len, 1.0f); +} + +float BM25Scorer::score_with_idf(float idf_value, uint32_t term_freq, + uint32_t doc_len, float boost) const { + if (idf_value <= 0.0f) { + return 0.0f; + } + const auto snap = stats_.snapshot(); + if (snap.total_docs == 0) { + return 0.0f; + } + + const float tf = static_cast(term_freq); + const float doc_length = static_cast(doc_len); + const float avg_dl = snap.avg_doc_len(); + + const float tf_norm = + tf * (params_.k1 + 1.0f) / + (tf + params_.k1 * (1.0f - params_.b + params_.b * doc_length / avg_dl)); + + return boost * idf_value * tf_norm; +} + +// ============================================================ +// WandOptimizer implementation +// ============================================================ + +int WandOptimizer::open(BM25ScorerPtr scorer, RocksdbContext *ctx, + rocksdb::ColumnFamilyHandle *max_tf_cf, uint32_t topk) { + if (!scorer || !ctx || !max_tf_cf) { + LOG_ERROR( + "WandOptimizer open failed: null arguments scorer[%p] ctx[%p] " + "max_tf_cf[%p]", + (void *)scorer.get(), (void *)ctx, (void *)max_tf_cf); + return -1; + } + scorer_ = std::move(scorer); + ctx_ = ctx; + max_tf_cf_ = max_tf_cf; + topk_ = topk; + return 0; +} + +uint32_t WandOptimizer::read_max_tf(const std::string &term) const { + if (!max_tf_cf_) { + return 1; + } + std::string max_tf_value; + if (!ctx_->db_->Get(ctx_->read_opts_, max_tf_cf_, term, &max_tf_value).ok() || + max_tf_value.size() < sizeof(uint32_t)) { + return 1; // Default max term frequency is 1 + } + uint32_t max_tf = 0; + std::memcpy(&max_tf, max_tf_value.data(), sizeof(uint32_t)); + return max_tf; +} + +} // namespace zvec::fts diff --git a/src/db/index/column/fts_column/bm25_scorer.h b/src/db/index/column/fts_column/bm25_scorer.h new file mode 100644 index 000000000..dd8bcfe9c --- /dev/null +++ b/src/db/index/column/fts_column/bm25_scorer.h @@ -0,0 +1,209 @@ +// Copyright 2025-present the zvec project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include +#include +#include "db/common/rocksdb_context.h" + +namespace zvec::fts { + +/*! BM25 scoring parameters + */ +struct BM25Params { + // Term frequency saturation parameter, typical value 1.2 + float k1{1.2f}; + // Document length normalization parameter, typical value 0.75 + float b{0.75f}; +}; + +/*! Plain snapshot of per-segment BM25 statistics (non-atomic, for callers) + */ +struct SegmentStatsSnapshot { + uint64_t total_docs{0}; + uint64_t total_tokens{0}; + + float avg_doc_len() const { + if (total_docs == 0) { + return 1.0f; + } + return static_cast(total_tokens) / static_cast(total_docs); + } +}; + +/*! Per-segment BM25 statistics (thread-safe) + * Fields are std::atomic so that concurrent insert (writer) and search + * (reader) threads do not race on the raw values. + */ +struct SegmentStats { + // Total number of documents in segment + std::atomic total_docs{0}; + // Total number of tokens in all documents in segment (used to calculate + // average document length) + std::atomic total_tokens{0}; + + SegmentStats() = default; + + // std::atomic is neither copyable nor movable; provide manual move + // semantics so that BM25Scorer (which embeds SegmentStats) stays movable. + // These are only used during single-threaded construction / NRVO and are + // therefore safe with relaxed ordering. + SegmentStats(SegmentStats &&other) noexcept + : total_docs(other.total_docs.load(std::memory_order_relaxed)), + total_tokens(other.total_tokens.load(std::memory_order_relaxed)) {} + + SegmentStats &operator=(SegmentStats &&other) noexcept { + total_docs.store(other.total_docs.load(std::memory_order_relaxed), + std::memory_order_relaxed); + total_tokens.store(other.total_tokens.load(std::memory_order_relaxed), + std::memory_order_relaxed); + return *this; + } + + SegmentStats(const SegmentStats &) = delete; + SegmentStats &operator=(const SegmentStats &) = delete; + + // Take a consistent snapshot: load total_tokens first (the value that + // grows together with total_docs) so the pair is *at least* as fresh as + // the docs count, avoiding avg_doc_len() returning an inflated value. + SegmentStatsSnapshot snapshot() const { + const uint64_t tokens = total_tokens.load(std::memory_order_acquire); + const uint64_t docs = total_docs.load(std::memory_order_acquire); + return {docs, tokens}; + } + + // Average document length (total_tokens / total_docs) + float avg_doc_len() const { + return snapshot().avg_doc_len(); + } +}; + +/*! BM25 scorer + * Encapsulates standard BM25 formula, supports per-segment statistics loading + * and WAND optimization + * + * BM25 formula: + * score(q, d) = Σ IDF(t) * (tf(t,d) * (k1+1)) / (tf(t,d) + + * k1*(1-b+b*|d|/avgdl)) IDF(t) = ln((N - df(t) + 0.5) / (df(t) + 0.5) + 1) + */ +class BM25Scorer { + public: + explicit BM25Scorer(BM25Params params = BM25Params{}) : params_(params) {} + + /*! Load per-segment statistics from $SEGMENT_STAT CF + * \param field_name Field name + * \param stat_cf $SEGMENT_STAT CF + * \return 0 for success, non-0 for failure + */ + int load_segment_stats(const std::string &field_name, RocksdbContext *ctx, + rocksdb::ColumnFamilyHandle *stat_cf); + + /*! Calculate BM25 contribution score of a single term for a single document + * \param term_doc_freq Document frequency of this term in segment (df) + * \param term_freq Term frequency of this term in current document + * (tf) \param doc_len Length of current document (number of tokens) + * \return BM25 score contribution + */ + float score(uint64_t term_doc_freq, uint32_t term_freq, + uint32_t doc_len) const; + + /*! Calculate IDF value of a term + * \param term_doc_freq Document frequency of this term in segment (df) + * \return IDF value + */ + float idf(uint64_t term_doc_freq) const; + + /*! Calculate BM25 score using a pre-computed IDF value. + * Avoids recomputing log() on every call — IDF is constant per term. + * \param idf_value Pre-computed IDF value (from idf()) + * \param term_freq Term frequency in current document + * \param doc_len Document length (number of tokens) + * \return BM25 score contribution + */ + float score_with_idf(float idf_value, uint32_t term_freq, + uint32_t doc_len) const; + + /*! Calculate BM25 score with a per-term boost multiplier. + * Boost > 1 represents a term that appears multiple times in the original + * query (collapsed by the AST rewriter) or carries an explicit user weight. + * The multiplier is linear so that the post-rewrite score exactly matches + * the pre-rewrite "sum of N independent scorers" value. + * \param idf_value Pre-computed IDF value (from idf()) + * \param term_freq Term frequency in current document + * \param doc_len Document length (number of tokens) + * \param boost Per-term boost (1.0 = no boost) + * \return BM25 score contribution scaled by boost + */ + float score_with_idf(float idf_value, uint32_t term_freq, uint32_t doc_len, + float boost) const; + + /*! Update in-memory segment statistics (called by FtsColumnIndexer after + * each insert so that search() uses up-to-date stats for BM25 scoring) + * \param total_docs Current total number of documents + * \param total_tokens Current total number of tokens + */ + void update_stats(uint64_t total_docs, uint64_t total_tokens) { + // Store total_docs first so that a concurrent reader calling snapshot() + // (which loads total_tokens before total_docs) never sees a new docs + // count paired with a stale tokens count, which would deflate avg_doc_len. + stats_.total_docs.store(total_docs, std::memory_order_release); + stats_.total_tokens.store(total_tokens, std::memory_order_release); + } + + SegmentStatsSnapshot stats() const { + return stats_.snapshot(); + } + const BM25Params ¶ms() const { + return params_; + } + + private: + BM25Params params_; + SegmentStats stats_; +}; + +using BM25ScorerPtr = std::shared_ptr; + +/*! WAND optimizer + * Uses $MAX_TF as upper bound for TopK pruning, reduces unnecessary document + * scoring + */ +class WandOptimizer { + public: + /*! Initialize WAND optimizer + * \param scorer BM25 scorer (with segment statistics loaded) + * \param max_tf_cf $MAX_TF CF (stores maximum term frequency for each + * term) \param topk Number of TopK results to return + */ + int open(BM25ScorerPtr scorer, RocksdbContext *ctx, + rocksdb::ColumnFamilyHandle *max_tf_cf, uint32_t topk); + + /*! Read the maximum term frequency for a term from $MAX_TF CF. + * Used by TermDocIterator to precompute WAND upper bound score. + * \param term The term to look up + * \return Maximum term frequency, or 1 if not found + */ + uint32_t read_max_tf(const std::string &term) const; + + private: + BM25ScorerPtr scorer_; + RocksdbContext *ctx_{nullptr}; + rocksdb::ColumnFamilyHandle *max_tf_cf_{nullptr}; + uint32_t topk_{10}; +}; + +} // namespace zvec::fts diff --git a/src/db/index/column/fts_column/fts_ast_rewriter.cc b/src/db/index/column/fts_column/fts_ast_rewriter.cc new file mode 100644 index 000000000..475f71b9d --- /dev/null +++ b/src/db/index/column/fts_column/fts_ast_rewriter.cc @@ -0,0 +1,398 @@ +// Copyright 2025-present the zvec project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "fts_ast_rewriter.h" +#include +#include +#include + +namespace zvec::fts { + +namespace { + +// Two AST nodes are dedup-equivalent when they are the same leaf kind and +// carry identical modifiers and identical scoring key (term string for +// TermNode, terms vector for PhraseNode). Boost is intentionally NOT part of +// the key — it is what we accumulate during dedup. +bool same_dedup_key(const FtsAstNode &a, const FtsAstNode &b) { + if (a.type() != b.type()) { + return false; + } + if (a.must != b.must || a.must_not != b.must_not) { + return false; + } + if (a.type() == FtsNodeType::TERM) { + return static_cast(a).term == + static_cast(b).term; + } + if (a.type() == FtsNodeType::PHRASE) { + return static_cast(a).terms == + static_cast(b).terms; + } + return false; +} + +// Same scoring key as same_dedup_key but ignores modifiers — used to detect +// `+apple -apple` style conflicts inside an AND node. +bool same_term_or_phrase_text(const FtsAstNode &a, const FtsAstNode &b) { + if (a.type() != b.type()) { + return false; + } + if (a.type() == FtsNodeType::TERM) { + return static_cast(a).term == + static_cast(b).term; + } + if (a.type() == FtsNodeType::PHRASE) { + return static_cast(a).terms == + static_cast(b).terms; + } + return false; +} + +// Collapse adjacent duplicates (TermNode/PhraseNode siblings sharing the same +// dedup key) into a single node whose boost is the linear sum. O(K^2) — K is +// the sibling count, typically small enough that a hash map would cost more in +// allocations than it would save in comparisons. +void merge_duplicate_siblings(std::vector &children) { + for (size_t i = 0; i < children.size(); ++i) { + auto &a = children[i]; + if (!a) { + continue; + } + if (a->type() != FtsNodeType::TERM && a->type() != FtsNodeType::PHRASE) { + continue; + } + for (size_t j = i + 1; j < children.size();) { + auto &b = children[j]; + if (b && same_dedup_key(*a, *b)) { + a->boost += b->boost; + children.erase(children.begin() + j); + } else { + ++j; + } + } + } +} + +// Flatten guard: an inner OrNode can be inlined into a parent OR only when it +// is a pure disjunction — itself unmodified and containing no must/must_not +// children. Otherwise inlining would change semantics (a must_not child would +// silently widen its exclusion scope from the inner OR to the outer OR). +bool can_inline_into_or(const FtsAstNode &child) { + if (child.type() != FtsNodeType::OR) { + return false; + } + if (child.must || child.must_not) { + return false; + } + const auto &inner = static_cast(child); + for (const auto &c : inner.children) { + if (c && (c->must || c->must_not)) { + return false; + } + } + return true; +} + +// Flatten guard: an inner AndNode can be inlined into a parent AND only when +// itself unmodified and containing no must_not children. must children inside +// an AND are equivalent to plain children (build_and_iterator treats both as +// MUST), so they are safe to inline. must_not children are NOT safe to lift +// across a must_not parent boundary. +bool can_inline_into_and(const FtsAstNode &child) { + if (child.type() != FtsNodeType::AND) { + return false; + } + if (child.must || child.must_not) { + return false; + } + const auto &inner = static_cast(child); + for (const auto &c : inner.children) { + if (c && c->must_not) { + return false; + } + } + return true; +} + +// Splice inlinable OR children's grandchildren in place of the child. Reuses +// each grandchild's unique_ptr — no AST node allocations. +void flatten_or_children(std::vector &children) { + std::vector out; + out.reserve(children.size()); + for (auto &child : children) { + if (child && can_inline_into_or(*child)) { + auto &inner = static_cast(*child); + for (auto &grandchild : inner.children) { + if (grandchild) { + out.push_back(std::move(grandchild)); + } + } + } else { + out.push_back(std::move(child)); + } + } + children = std::move(out); +} + +void flatten_and_children(std::vector &children) { + std::vector out; + out.reserve(children.size()); + for (auto &child : children) { + if (child && can_inline_into_and(*child)) { + auto &inner = static_cast(*child); + for (auto &grandchild : inner.children) { + if (grandchild) { + out.push_back(std::move(grandchild)); + } + } + } else { + out.push_back(std::move(child)); + } + } + children = std::move(out); +} + +// Drop null children left behind by recursive simplify() reporting "this +// subtree contributed nothing" via a moved-out pointer. +void drop_nulls(std::vector &children) { + children.erase(std::remove_if(children.begin(), children.end(), + [](const FtsAstNodePtr &p) { return !p; }), + children.end()); +} + +// Make an EmptyNode carrying the modifier of the node being replaced. This +// preserves +/- semantics so parent nodes interpret the replacement the same +// way they would the original. +FtsAstNodePtr make_empty_like(const FtsAstNode &original) { + auto e = std::make_unique(); + e->must = original.must; + e->must_not = original.must_not; + // Boost is meaningless on EmptyNode — it matches nothing — but keep the + // value for round-trippable debug output. + e->boost = original.boost; + return e; +} + +// If the AND contains a positive child and a must_not child with the same +// term/phrase key, the conjunction matches nothing. +bool and_has_mustnot_conflict(const AndNode &n) { + for (size_t i = 0; i < n.children.size(); ++i) { + const auto &pi = n.children[i]; + if (!pi || pi->must_not) { + continue; + } + if (pi->type() != FtsNodeType::TERM && pi->type() != FtsNodeType::PHRASE) { + continue; + } + for (size_t j = 0; j < n.children.size(); ++j) { + if (i == j) { + continue; + } + const auto &pj = n.children[j]; + if (!pj || !pj->must_not) { + continue; + } + if (same_term_or_phrase_text(*pi, *pj)) { + return true; + } + } + } + return false; +} + +void simplify_and(FtsAstNodePtr &node); +void simplify_or(FtsAstNodePtr &node); + +void simplify_and(FtsAstNodePtr &node) { + auto &n = static_cast(*node); + + // 1. Recurse first so children are already in normal form. + for (auto &child : n.children) { + simplify(child); + } + drop_nulls(n.children); + + // 2. EmptyNode propagation: a positive EMPTY makes the whole AND empty; + // a must_not EMPTY (i.e. "exclude nothing") is a no-op and is dropped. + for (auto it = n.children.begin(); it != n.children.end();) { + if ((*it)->type() == FtsNodeType::EMPTY) { + if ((*it)->must_not) { + it = n.children.erase(it); + } else { + node = make_empty_like(n); + return; + } + } else { + ++it; + } + } + + // 3. Flatten nested AND, then dedup siblings (linear-boost sum). + flatten_and_children(n.children); + merge_duplicate_siblings(n.children); + + // 4. `+apple -apple` style conflict → empty doc set. + if (and_has_mustnot_conflict(n)) { + node = make_empty_like(n); + return; + } + + // 5. AND containing only must_not children has no positive base set to + // subtract from — by convention this matches nothing. + bool any_positive = false; + for (const auto &c : n.children) { + if (!c->must_not) { + any_positive = true; + break; + } + } + if (!any_positive) { + node = make_empty_like(n); + return; + } + + // 6. Single-child fold. Combine the outer AND's modifier with the surviving + // child; if the combination yields must && must_not, replace with EMPTY + // (a self-contradictory clause matches nothing). + if (n.children.size() == 1) { + FtsAstNodePtr child = std::move(n.children[0]); + child->must = child->must || n.must; + child->must_not = child->must_not || n.must_not; + if (child->must && child->must_not) { + auto e = std::make_unique(); + e->must = n.must; + e->must_not = n.must_not; + node = std::move(e); + return; + } + node = std::move(child); + } +} + +void simplify_or(FtsAstNodePtr &node) { + auto &n = static_cast(*node); + + for (auto &child : n.children) { + simplify(child); + } + drop_nulls(n.children); + + // EmptyNode in OR: a positive EMPTY contributes no documents → drop it. + // A must_not EMPTY excludes nothing → also drop. Either way, simply remove. + n.children.erase(std::remove_if(n.children.begin(), n.children.end(), + [](const FtsAstNodePtr &p) { + return p && p->type() == FtsNodeType::EMPTY; + }), + n.children.end()); + + flatten_or_children(n.children); + merge_duplicate_siblings(n.children); + + // OR with no remaining positive children matches nothing. (must_not children + // inside an OR mean "exclude from the disjunction"; with no positive base + // the result is empty.) + bool any_positive = false; + size_t mustnot_count = 0; + for (const auto &c : n.children) { + if (c->must_not) { + ++mustnot_count; + } else { + any_positive = true; + } + } + if (!any_positive) { + node = make_empty_like(n); + return; + } + + // Canonicalize OR-with-must_not into AND(OR(positives), must_nots...). After + // this, an OrNode never carries must_not children, so the iterator builder + // can drop its special-case wrapping. Conflict cases like `apple -apple` end + // up inside the new AND where and_has_mustnot_conflict catches them and + // collapses the whole subtree to EmptyNode for free. + if (mustnot_count > 0) { + std::vector positives; + std::vector negatives; + positives.reserve(n.children.size() - mustnot_count); + negatives.reserve(mustnot_count); + for (auto &c : n.children) { + if (c->must_not) { + negatives.push_back(std::move(c)); + } else { + positives.push_back(std::move(c)); + } + } + + FtsAstNodePtr positive_part; + if (positives.size() == 1) { + positive_part = std::move(positives[0]); + } else { + auto inner_or = std::make_unique(); + inner_or->children = std::move(positives); + positive_part = std::move(inner_or); + } + + auto wrap = std::make_unique(); + wrap->children.reserve(1 + negatives.size()); + wrap->children.push_back(std::move(positive_part)); + for (auto &mn : negatives) { + wrap->children.push_back(std::move(mn)); + } + wrap->must = n.must; + wrap->must_not = n.must_not; + wrap->boost = n.boost; + + FtsAstNodePtr replacement = std::move(wrap); + simplify_and(replacement); + node = std::move(replacement); + return; + } + + if (n.children.size() == 1) { + FtsAstNodePtr child = std::move(n.children[0]); + child->must = child->must || n.must; + child->must_not = child->must_not || n.must_not; + if (child->must && child->must_not) { + auto e = std::make_unique(); + e->must = n.must; + e->must_not = n.must_not; + node = std::move(e); + return; + } + node = std::move(child); + } +} + +} // namespace + +void simplify(FtsAstNodePtr &node) { + if (!node) { + return; + } + switch (node->type()) { + case FtsNodeType::TERM: + case FtsNodeType::PHRASE: + case FtsNodeType::EMPTY: + return; + case FtsNodeType::AND: + simplify_and(node); + return; + case FtsNodeType::OR: + simplify_or(node); + return; + } +} + +} // namespace zvec::fts diff --git a/src/db/index/column/fts_column/fts_ast_rewriter.h b/src/db/index/column/fts_column/fts_ast_rewriter.h new file mode 100644 index 000000000..071f77e91 --- /dev/null +++ b/src/db/index/column/fts_column/fts_ast_rewriter.h @@ -0,0 +1,43 @@ +// Copyright 2025-present the zvec project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "fts_query_ast.h" + +namespace zvec::fts { + +/*! Structural simplification of an FTS AST. + * + * Performs a single post-order pass that: + * - flattens nested AND-of-AND / OR-of-OR (with Lucene-style guards that + * preserve the must/must_not semantics of the inner node) + * - dedups sibling TermNode / PhraseNode duplicates by summing boosts + * linearly, so the resulting score equals the pre-rewrite "sum of N + * independent scorers" output exactly + * - propagates EmptyNode (AND short-circuits, OR drops empties) + * - folds single-child AND/OR into the child + * - detects must vs must_not contradictions inside an AND + * (e.g. `+apple -apple`) and rewrites the AND to EmptyNode + * + * Idempotent: simplify(simplify(x)) == simplify(x). The transformation + * preserves the document-set semantics of the original AST and, under the + * linear-boost rule, also preserves the per-document BM25 score. + * + * Mutates the node in place via the unique_ptr (may replace it with a + * different node, e.g. EmptyNode or a folded child). + */ +void simplify(FtsAstNodePtr &node); + +} // namespace zvec::fts diff --git a/src/db/index/column/fts_column/fts_column_indexer.cc b/src/db/index/column/fts_column/fts_column_indexer.cc new file mode 100644 index 000000000..f67dd5d99 --- /dev/null +++ b/src/db/index/column/fts_column/fts_column_indexer.cc @@ -0,0 +1,888 @@ +// Copyright 2025-present the zvec project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "fts_column_indexer.h" +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include "db/common/typedef.h" +#include "iterator/fts_candidate_iterator.h" +#include "iterator/fts_conjunction_iterator.h" +#include "iterator/fts_disjunction_iterator.h" +#include "iterator/fts_phrase_iterator.h" +#include "iterator/fts_term_iterator.h" +#include "posting/bitpacked_posting_list.h" +#include "fts_pipeline.h" +#include "fts_utils.h" + +namespace zvec::fts { + +// ============================================================ +// Lifecycle +// ============================================================ + +FtsColumnIndexer::~FtsColumnIndexer() { + // Pipeline release is handled by FtsIndexParams destructor via fts_params_. + if (opened_.load()) { + (void)close(); + } +} + +// ============================================================ +// Initialization — shared reader core +// ============================================================ + +Result FtsColumnIndexer::open_reader( + const std::string &field_name, RocksdbContext *ctx, + rocksdb::ColumnFamilyHandle *postings_cf, + rocksdb::ColumnFamilyHandle *positions_cf, + rocksdb::ColumnFamilyHandle *term_freq_cf, + rocksdb::ColumnFamilyHandle *max_tf_cf, + rocksdb::ColumnFamilyHandle *doc_len_cf, + rocksdb::ColumnFamilyHandle *stat_cf, BM25Params bm25_params) { + if (opened_.load()) { + return tl::make_unexpected(Status::InternalError( + "FtsColumnIndexer already opened. field=", field_name)); + } + + field_name_ = field_name; + ctx_ = ctx; + postings_cf_ = postings_cf; + positions_cf_ = positions_cf; + term_freq_cf_ = term_freq_cf; + max_tf_cf_ = max_tf_cf; + doc_len_cf_ = doc_len_cf; + stat_cf_ = stat_cf; + + scorer_ = std::make_shared(bm25_params); + + // doc_len_cf == nullptr → immutable path, load persisted stats. + // doc_len_cf != nullptr → mutable path, stats maintained in-memory. + if (doc_len_cf == nullptr) { + int ret = scorer_->load_segment_stats(field_name, ctx, stat_cf); + if (ret != 0) { + return tl::make_unexpected(Status::InternalError( + "FtsColumnIndexer failed to load segment stats. field=", field_name)); + } + } + + opened_.store(true); + return {}; +} + +// ============================================================ +// Initialization — read+write (mutable) +// ============================================================ + +Result FtsColumnIndexer::open(FieldSchema::Ptr field_meta, + RocksdbContext *ctx, + rocksdb::ColumnFamilyHandle *postings_cf, + rocksdb::ColumnFamilyHandle *positions_cf, + rocksdb::ColumnFamilyHandle *term_freq_cf, + rocksdb::ColumnFamilyHandle *max_tf_cf, + rocksdb::ColumnFamilyHandle *doc_len_cf, + rocksdb::ColumnFamilyHandle *stat_cf) { + if (!field_meta || !ctx) { + return tl::make_unexpected( + Status::InvalidArgument("FtsColumnIndexer: null field_meta or ctx")); + } + + // Obtain FtsIndexParams from field_meta's index_params. + auto index_params = field_meta->index_params(); + auto fts_param = + std::dynamic_pointer_cast(index_params); + if (!fts_param) { + return tl::make_unexpected(Status::InvalidArgument( + "FtsColumnIndexer: field has no FtsIndexParams. field=", + field_meta->name())); + } + + auto pipeline_result = zvec::detail::AcquireFtsPipeline(*fts_param); + if (!pipeline_result.has_value()) { + return tl::make_unexpected(Status::InternalError( + "FtsColumnIndexer: failed to create tokenizer pipeline. field=", + field_meta->name(), " err=", pipeline_result.error().message())); + } + + field_meta_ = std::move(field_meta); + tokenizer_pipeline_ = std::move(pipeline_result.value()); + fts_params_ = fts_param; + + return open_reader(field_meta_->name(), ctx, postings_cf, positions_cf, + term_freq_cf, max_tf_cf, doc_len_cf, stat_cf); +} + +// ============================================================ +// Initialization — read-only (immutable / standalone) +// ============================================================ + +// ============================================================ +// Close +// ============================================================ + +Result FtsColumnIndexer::close() { + if (!opened_.load()) { + return tl::make_unexpected(Status::InternalError( + "FtsColumnIndexer::close: not opened. field=", field_name_)); + } + + postings_cf_ = nullptr; + positions_cf_ = nullptr; + term_freq_cf_.store(nullptr, std::memory_order_release); + max_tf_cf_.store(nullptr, std::memory_order_release); + doc_len_cf_.store(nullptr, std::memory_order_release); + stat_cf_ = nullptr; + scorer_.reset(); + + opened_.store(false); + return {}; +} + +// ============================================================ +// Query entry point +// ============================================================ + +Result> FtsColumnIndexer::search( + const FtsAstNode &ast, const FtsQueryParams &query_params) const { + if (!scorer_) { + LOG_ERROR("FtsColumnIndexer::search: not opened. field[%s]", + field_name_.c_str()); + return tl::make_unexpected(Status::InternalError( + "FtsColumnIndexer::search: not opened. field=", field_name_)); + } + + if (query_params.topk == 0) { + return std::vector{}; + } + + if (ast.must_not) { + LOG_WARN( + "FtsColumnIndexer::search: must_not on root is not allowed. field[%s]", + field_name_.c_str()); + return tl::make_unexpected(Status::InvalidArgument( + "FtsColumnIndexer::search: must_not on root is not allowed. field=", + field_name_)); + } + + auto iter_result = build_iterator(ast); + if (!iter_result.has_value()) { + LOG_ERROR("FtsColumnIndexer::search: build_iterator failed. field[%s] %s", + field_name_.c_str(), iter_result.error().message().c_str()); + return tl::make_unexpected(iter_result.error()); + } + DocIteratorPtr root_iter = std::move(iter_result.value()); + if (!root_iter) { + // No matching terms found — valid empty result, not an error. + return std::vector{}; + } + + // Candidate-driven mode: AND a CandidateDocIterator into the root so the + // small candidate set leads (Conjunction sorts by cost asc), turning the + // posting walk into per-candidate advance()+matches()+score(). + if (!query_params.candidate_ids.empty()) { + std::vector musts; + musts.reserve(2); + musts.push_back( + std::make_unique(query_params.candidate_ids)); + musts.push_back(std::move(root_iter)); + root_iter = std::make_unique( + std::move(musts), std::vector{}); + } + + const uint32_t topk = query_params.topk; + const zvec::IndexFilter *filter_ptr = query_params.filter.get(); + + using MinHeap = std::priority_queue, + std::greater>; + MinHeap min_heap; + + // Filter pushdown: when a filter is present, use the filter-aware next_doc + // overload so composite iterators skip filtered docs before paying for + // block-max binary search, do_next alignment, or phase-2 position checks. + uint32_t doc_id = + filter_ptr ? root_iter->next_doc(filter_ptr) : root_iter->next_doc(); + while (doc_id != DocIterator::NO_MORE_DOCS) { + const uint64_t global_doc_id = static_cast(doc_id); + + if (root_iter->matches()) { + float s = root_iter->score(); + if (s > 0.0f) { + if (min_heap.size() < topk) { + min_heap.push({global_doc_id, s}); + if (min_heap.size() == topk) { + root_iter->set_min_competitive_score(min_heap.top().score); + } + } else if (s > min_heap.top().score) { + min_heap.pop(); + min_heap.push({global_doc_id, s}); + root_iter->set_min_competitive_score(min_heap.top().score); + } + } + } + doc_id = + filter_ptr ? root_iter->next_doc(filter_ptr) : root_iter->next_doc(); + } + + std::vector results(min_heap.size()); + for (auto it = results.rbegin(); it != results.rend(); ++it) { + *it = min_heap.top(); + min_heap.pop(); + } + + return results; +} + +// ============================================================ +// Side CF reset (dump path) +// ============================================================ + +void FtsColumnIndexer::reset_side_cfs() { + cf_dropped_.store(true); + while (cf_counter_.load() > 0) { + std::this_thread::sleep_for(std::chrono::milliseconds(1)); + } + term_freq_cf_.store(nullptr, std::memory_order_release); + max_tf_cf_.store(nullptr, std::memory_order_release); + doc_len_cf_.store(nullptr, std::memory_order_release); +} + +// ============================================================ +// Iterator tree construction +// ============================================================ + +Result FtsColumnIndexer::build_iterator( + const FtsAstNode &node) const { + switch (node.type()) { + case FtsNodeType::TERM: + return build_term_iterator(static_cast(node)); + case FtsNodeType::PHRASE: + return build_phrase_iterator(static_cast(node)); + case FtsNodeType::AND: + return build_and_iterator(static_cast(node)); + case FtsNodeType::OR: + return build_or_iterator(static_cast(node)); + case FtsNodeType::EMPTY: + // Null iterator reuses the existing AND/OR/search() null-handling path. + return DocIteratorPtr{nullptr}; + default: + return tl::make_unexpected(Status::InternalError( + "FtsColumnIndexer::build_iterator: unknown node type. field=", + field_name_)); + } +} + +Result FtsColumnIndexer::create_term_iterator_from_raw( + const std::string &term, rocksdb::PinnableSlice raw_data, + float boost) const { + if (BitPackedPostingList::is_bitpacked_format(raw_data.data(), + raw_data.size())) { + auto iter = std::make_unique(term, std::move(raw_data), + scorer_, boost); + if (iter->cost() == 0) { + return DocIteratorPtr{nullptr}; + } + return iter; + } + + roaring_bitmap_t *bitmap = roaring_bitmap_portable_deserialize_safe( + raw_data.data(), raw_data.size()); + if (!bitmap) { + return tl::make_unexpected(Status::InternalError( + "FtsColumnIndexer: failed to deserialize roaring bitmap. field=", + field_name_, " term=", term)); + } + + const uint64_t df = roaring_bitmap_get_cardinality(bitmap); + if (df == 0) { + roaring_bitmap_free(bitmap); + return nullptr; + } + + ++cf_counter_; + auto *term_freq_cf = term_freq_cf_.load(std::memory_order_acquire); + auto *doc_len_cf = doc_len_cf_.load(std::memory_order_acquire); + auto *max_tf_cf = max_tf_cf_.load(std::memory_order_acquire); + auto *cf_counter = &cf_counter_; + if (cf_dropped_) { + term_freq_cf = nullptr; + doc_len_cf = nullptr; + cf_counter = nullptr; + max_tf_cf = nullptr; + --cf_counter_; + } + + float max_score_val = 0.0f; + if (max_tf_cf) { + WandOptimizer wand; + if (wand.open(scorer_, ctx_, max_tf_cf, 0) == 0) { + uint32_t max_tf = wand.read_max_tf(term); + uint32_t min_dl = min_doc_len_.load(std::memory_order_relaxed); + if (min_dl == std::numeric_limits::max()) { + min_dl = 1; + } + max_score_val = scorer_->score(df, max_tf, min_dl); + } + } + + return std::make_unique(term, bitmap, df, scorer_, + max_score_val, ctx_, term_freq_cf, + doc_len_cf, cf_counter, boost); +} + +Result FtsColumnIndexer::build_term_iterator( + const TermNode &term_node) const { + const std::string &term = term_node.term; + + rocksdb::PinnableSlice raw_data; + auto s = ctx_->db_->Get(ctx_->read_opts_, postings_cf_, term, &raw_data); + if (!s.ok() || raw_data.empty()) { + return DocIteratorPtr{nullptr}; + } + + return create_term_iterator_from_raw(term, std::move(raw_data), + term_node.boost); +} + +std::vector FtsColumnIndexer::batch_get_postings( + const std::vector &terms) const { + std::vector raw_postings(terms.size()); + if (terms.empty()) { + return raw_postings; + } + + std::vector cfs(terms.size(), postings_cf_); + std::vector statuses(terms.size()); + ctx_->db_->MultiGet(ctx_->read_opts_, terms.size(), cfs.data(), terms.data(), + raw_postings.data(), statuses.data()); + // Ignore failed lookups as callers can check via empty() + return raw_postings; +} + +Result FtsColumnIndexer::build_phrase_iterator( + const PhraseNode &phrase_node) const { + if (phrase_node.terms.empty()) { + return DocIteratorPtr{nullptr}; + } + + const std::vector &terms = phrase_node.terms; + std::vector term_slices; + term_slices.reserve(terms.size()); + for (const auto &t : terms) { + term_slices.emplace_back(t); + } + auto raw_postings = batch_get_postings(term_slices); + + std::vector term_iterators; + term_iterators.reserve(terms.size()); + + // Phrase-level boost is distributed across the internal term iterators. + // PhraseDocIterator.score() delegates to conjunction.score() which sums the + // internal contributions, so multiplying each contribution by boost yields + // boost * (sum) = boost-applied-once at the phrase level. + for (size_t i = 0; i < terms.size(); ++i) { + if (raw_postings[i].empty()) { + return DocIteratorPtr{nullptr}; + } + auto iter_result = create_term_iterator_from_raw( + terms[i], std::move(raw_postings[i]), phrase_node.boost); + if (!iter_result.has_value()) { + return iter_result; + } + if (!iter_result.value()) { + return DocIteratorPtr{nullptr}; + } + term_iterators.push_back(std::move(iter_result.value())); + } + + if (term_iterators.empty()) { + return DocIteratorPtr{nullptr}; + } + + auto conjunction = std::make_unique( + std::move(term_iterators), std::vector{}); + + return std::make_unique(std::move(conjunction), terms, + ctx_, positions_cf_); +} + +Result FtsColumnIndexer::build_and_iterator( + const AndNode &and_node) const { + if (and_node.children.empty()) { + return DocIteratorPtr{nullptr}; + } + + std::vector term_key_slices; + std::vector term_child_indices; + term_key_slices.reserve(and_node.children.size()); + term_child_indices.reserve(and_node.children.size()); + + for (size_t i = 0; i < and_node.children.size(); ++i) { + const auto &child = and_node.children[i]; + if (child && child->type() == FtsNodeType::TERM) { + term_key_slices.emplace_back(static_cast(*child).term); + term_child_indices.push_back(i); + } + } + + auto term_raw_postings = batch_get_postings(term_key_slices); + + std::vector must_iterators; + std::vector must_not_iterators; + size_t batched_cursor = 0; + + for (size_t i = 0; i < and_node.children.size(); ++i) { + const auto &child = and_node.children[i]; + const bool is_must_not = child->must_not; + + DocIteratorPtr iter; + if (batched_cursor < term_child_indices.size() && + term_child_indices[batched_cursor] == i) { + rocksdb::PinnableSlice &raw = term_raw_postings[batched_cursor]; + const auto &term_node = static_cast(*child); + if (!raw.empty()) { + auto iter_result = create_term_iterator_from_raw( + term_node.term, std::move(raw), term_node.boost); + if (!iter_result.has_value()) { + return iter_result; + } + iter = std::move(iter_result.value()); + } + ++batched_cursor; + } else { + auto iter_result = build_iterator(*child); + if (!iter_result.has_value()) { + return iter_result; + } + iter = std::move(iter_result.value()); + } + + if (!iter) { + if (!is_must_not) { + return DocIteratorPtr{nullptr}; + } + continue; + } + + if (is_must_not) { + must_not_iterators.push_back(std::move(iter)); + } else { + must_iterators.push_back(std::move(iter)); + } + } + + if (must_iterators.empty()) { + return DocIteratorPtr{nullptr}; + } + + if (must_iterators.size() == 1 && must_not_iterators.empty()) { + return std::move(must_iterators[0]); + } + + return std::make_unique(std::move(must_iterators), + std::move(must_not_iterators)); +} + +Result FtsColumnIndexer::build_or_iterator( + const OrNode &or_node) const { + if (or_node.children.empty()) { + return DocIteratorPtr{nullptr}; + } + + std::vector term_key_slices; + std::vector term_child_indices; + term_key_slices.reserve(or_node.children.size()); + term_child_indices.reserve(or_node.children.size()); + + for (size_t i = 0; i < or_node.children.size(); ++i) { + const auto &child = or_node.children[i]; + if (child && child->type() == FtsNodeType::TERM) { + term_key_slices.emplace_back(static_cast(*child).term); + term_child_indices.push_back(i); + } + } + + auto term_raw_postings = batch_get_postings(term_key_slices); + + // Invariant: the AST rewriter (fts::simplify) lifts any must_not children + // out of OrNode into a wrapping AndNode before we get here, so the loop + // below only ever sees SHOULD-style positives. A must_not child reaching + // this point indicates a caller that bypassed simplify — bail out loudly + // rather than silently produce wrong scores. + std::vector positive_iterators; + size_t batched_cursor = 0; + + for (size_t i = 0; i < or_node.children.size(); ++i) { + const auto &child = or_node.children[i]; + if (child->must_not) { + LOG_ERROR( + "build_or_iterator: must_not child reached OR (rewriter bypassed)"); + return tl::make_unexpected(Status::InternalError( + "FtsColumnIndexer::build_or_iterator: OR contains must_not child")); + } + + DocIteratorPtr iter; + if (batched_cursor < term_child_indices.size() && + term_child_indices[batched_cursor] == i) { + rocksdb::PinnableSlice &raw = term_raw_postings[batched_cursor]; + const auto &term_node = static_cast(*child); + if (!raw.empty()) { + auto iter_result = create_term_iterator_from_raw( + term_node.term, std::move(raw), term_node.boost); + if (!iter_result.has_value()) { + return iter_result; + } + iter = std::move(iter_result.value()); + } + ++batched_cursor; + } else { + auto iter_result = build_iterator(*child); + if (!iter_result.has_value()) { + return iter_result; + } + iter = std::move(iter_result.value()); + } + + if (iter) { + positive_iterators.push_back(std::move(iter)); + } + } + + if (positive_iterators.empty()) { + return DocIteratorPtr{nullptr}; + } + if (positive_iterators.size() == 1) { + return std::move(positive_iterators[0]); + } + return std::make_unique(std::move(positive_iterators)); +} + +// ============================================================ +// Write operations +// ============================================================ + +Result FtsColumnIndexer::insert(uint64_t seg_doc_id, + const std::string &text) { + // safe access check + + if (!tokenizer_pipeline_ || !ctx_) { + return tl::make_unexpected(Status::InternalError( + "FtsColumnIndexer::insert: not opened. field=", field_name_)); + } + + // Tokenize + std::vector tokens = tokenizer_pipeline_->process(text); + const uint32_t doc_len = static_cast(tokens.size()); + + // Aggregate position lists by term + std::unordered_map> term_positions; + for (const auto &token : tokens) { + term_positions[token.text].push_back(token.position); + } + + // Store seg_doc_id in RocksDB directly, similar to invert indexer + const uint32_t doc_id_32 = static_cast(seg_doc_id); + + // Pre-serialize a single-element Roaring Bitmap for this doc_id once, + // reused across all terms to avoid repeated create/serialize/free overhead. + roaring_bitmap_t *single_bitmap = roaring_bitmap_create_with_capacity(1); + roaring_bitmap_add(single_bitmap, doc_id_32); + size_t bitmap_size = roaring_bitmap_portable_size_in_bytes(single_bitmap); + std::string bitmap_data(bitmap_size, '\0'); + roaring_bitmap_portable_serialize(single_bitmap, bitmap_data.data()); + roaring_bitmap_free(single_bitmap); + + // Batch all writes for this document into a single cross-CF WriteBatch, + // reducing 4N+1 individual RocksDB Write() calls to one atomic write. + rocksdb::WriteBatch batch; + + for (const auto &[term, positions] : term_positions) { + const uint32_t tf = static_cast(positions.size()); + + // 1. Postings CF: merge doc_id bitmap + batch.Merge(postings_cf_, term, bitmap_data); + + // 2. Positions CF: term\0doc_id -> delta-varint positions + const std::string doc_term_key = make_doc_term_key(term, doc_id_32); + batch.Put(positions_cf_, doc_term_key, encode_positions(positions)); + + // 3. Term-freq CF: term\0doc_id -> uint32_t tf + std::string tf_value(sizeof(uint32_t), '\0'); + std::memcpy(tf_value.data(), &tf, sizeof(uint32_t)); + batch.Put(term_freq_cf_.load(), doc_term_key, tf_value); + + // 4. Max-TF CF: term -> max(tf) via merge + batch.Merge(max_tf_cf_.load(), term, tf_value); + } + + // 5. Doc-len CF: doc_id -> uint32_t doc_len + std::string doc_id_key(sizeof(uint32_t), '\0'); + std::memcpy(doc_id_key.data(), &doc_id_32, sizeof(uint32_t)); + std::string doc_len_value(sizeof(uint32_t), '\0'); + std::memcpy(doc_len_value.data(), &doc_len, sizeof(uint32_t)); + batch.Put(doc_len_cf_.load(), doc_id_key, doc_len_value); + + if (auto s = ctx_->db_->Write(ctx_->write_opts_, &batch); !s.ok()) { + return tl::make_unexpected(Status::InternalError( + "FtsColumnIndexer::insert: write batch failed. field=", field_name_, + " status=", s.ToString())); + } + + // 6. Update in-memory statistics atomically so concurrent search() calls + // see up-to-date values for BM25 scoring. + const uint64_t new_total_docs = + total_docs_.fetch_add(1, std::memory_order_relaxed) + 1; + const uint64_t new_total_tokens = + total_tokens_.fetch_add(doc_len, std::memory_order_relaxed) + doc_len; + + // Propagate updated stats to the scorer so that search() uses current avgdl. + if (scorer_) { + scorer_->update_stats(new_total_docs, new_total_tokens); + } + + // CAS-update min_doc_len_ only when this document has tokens (doc_len > 0). + if (doc_len > 0) { + uint32_t cur = min_doc_len_.load(std::memory_order_relaxed); + while (doc_len < cur && !min_doc_len_.compare_exchange_weak( + cur, doc_len, std::memory_order_relaxed)) { + } + } + + return {}; +} + +Result FtsColumnIndexer::flush() { + // safe access check + + if (!stat_cf_) { + return {}; + } + + // Write total_docs and total_tokens to $SEGMENT_STAT CF. + // Use acquire ordering so we see all inserts that happened before flush(). + const uint64_t snapshot_total_docs = + total_docs_.load(std::memory_order_acquire); + const uint64_t snapshot_total_tokens = + total_tokens_.load(std::memory_order_acquire); + + ctx_->db_->Put(ctx_->write_opts_, stat_cf_, make_total_docs_key(field_name_), + encode_uint64_value(snapshot_total_docs)); + ctx_->db_->Put(ctx_->write_opts_, stat_cf_, + make_total_tokens_key(field_name_), + encode_uint64_value(snapshot_total_tokens)); + + return {}; +} + +// ============================================================ +// BitPacked conversion (called by MutableSegment::dump_fts_column_indexers) +// ============================================================ + +Result FtsColumnIndexer::convert_postings_to_bitpacked() { + // safe access check + + if (!postings_cf_ || !term_freq_cf_ || !doc_len_cf_ || !scorer_) { + return tl::make_unexpected(Status::InternalError( + "FtsColumnIndexer::convert_postings_to_bitpacked: not opened. field=", + field_name_)); + } + + // --------------------------------------------------------------- + // 1) Load doc_len_cf into an in-memory vector indexed by local doc_id. + // Single segment is at most a few MB even for 1M docs (4B per doc), + // so a flat vector is by far the cheapest lookup structure. + // --------------------------------------------------------------- + std::vector doc_lens; + { + std::unique_ptr iter( + ctx_->db_->NewIterator(ctx_->read_opts_, doc_len_cf_.load())); + iter->SeekToFirst(); + while (iter->Valid()) { + const std::string key = iter->key().ToString(); + const std::string value = iter->value().ToString(); + if (key.size() != sizeof(uint32_t) || value.size() != sizeof(uint32_t)) { + LOG_WARN( + "FtsColumnIndexer::convert_postings_to_bitpacked: malformed " + "doc_len entry. field[%s] key_size[%zu] value_size[%zu]", + field_name_.c_str(), key.size(), value.size()); + iter->Next(); + continue; + } + uint32_t local_doc_id = 0; + uint32_t doc_len = 0; + std::memcpy(&local_doc_id, key.data(), sizeof(uint32_t)); + std::memcpy(&doc_len, value.data(), sizeof(uint32_t)); + if (local_doc_id >= doc_lens.size()) { + // Resize with default 1 to avoid divide-by-zero / log(0) downstream + // if a stray doc_id ever shows up without a doc_len entry. + doc_lens.resize(local_doc_id + 1, 1); + } + doc_lens[local_doc_id] = doc_len; + iter->Next(); + } + } + + // --------------------------------------------------------------- + // 2) Streaming scan of term_freq_cf, grouped by term. + // RocksDB BytewiseComparator + big-endian doc_id encoding guarantees + // that within a term, doc_ids appear in ascending order — exactly what + // BitPackedPostingList::encode() requires. + // --------------------------------------------------------------- + std::string current_term; + std::vector doc_ids; + std::vector tfs; + std::vector term_doc_lens; // reused buffer + + auto flush_current_term = [&]() -> Result { + if (current_term.empty() || doc_ids.empty()) { + return {}; + } + // Idempotency: skip if this term's postings are already BitPacked. + // Important for crash-recovery — a re-run of dump after a partial + // conversion must not double-encode. + std::string existing; + auto get_ret = + ctx_->db_->Get(ctx_->read_opts_, postings_cf_, current_term, &existing); + if (get_ret.ok() && !existing.empty() && + BitPackedPostingList::is_bitpacked_format(existing.data(), + existing.size())) { + return {}; + } + + term_doc_lens.assign(doc_ids.size(), 1); + for (size_t i = 0; i < doc_ids.size(); ++i) { + const uint32_t did = doc_ids[i]; + if (did < doc_lens.size() && doc_lens[did] > 0) { + term_doc_lens[i] = doc_lens[did]; + } + } + std::string packed = BitPackedPostingList::encode( + doc_ids.data(), tfs.data(), term_doc_lens.data(), doc_ids.size(), + /*df=*/doc_ids.size(), *scorer_); + if (!ctx_->db_->Put(ctx_->write_opts_, postings_cf_, current_term, packed) + .ok()) { + return tl::make_unexpected(Status::InternalError( + "FtsColumnIndexer::convert_postings_to_bitpacked: put failed. field=", + field_name_, " term=", current_term)); + } + return {}; + }; + + { + std::unique_ptr iter( + ctx_->db_->NewIterator(ctx_->read_opts_, term_freq_cf_.load())); + iter->SeekToFirst(); + while (iter->Valid()) { + const std::string key = iter->key().ToString(); + const std::string value = iter->value().ToString(); + std::string term; + uint32_t local_doc_id = 0; + if (!parse_doc_term_key(key, &term, &local_doc_id) || + value.size() != sizeof(uint32_t)) { + LOG_WARN( + "FtsColumnIndexer::convert_postings_to_bitpacked: malformed " + "term_freq entry. field[%s] key_size[%zu] value_size[%zu]", + field_name_.c_str(), key.size(), value.size()); + iter->Next(); + continue; + } + uint32_t tf = 0; + std::memcpy(&tf, value.data(), sizeof(uint32_t)); + + if (term != current_term) { + auto ret = flush_current_term(); + if (!ret) { + return ret; + } + current_term = std::move(term); + doc_ids.clear(); + tfs.clear(); + } + doc_ids.push_back(local_doc_id); + tfs.push_back(tf); + iter->Next(); + } + } + // Flush the last term. + auto ret = flush_current_term(); + if (!ret) { + return ret; + } + + // --------------------------------------------------------------- + // 3) Clear $TF / $DOC_LEN / $MAX_TF CFs via DeleteRange. + // + // All payloads (tf, doc_len, max_score) have been inlined into the + // BitPacked postings in step 2. Wiping them here ensures the SST files + // are cleaned up during the dump-side compaction, so the dumped immutable + // segment is significantly smaller. MutableSegment then drops the CFs + // entirely after all indexers finish conversion. + // + // DeleteRange uses [begin, end) semantics; an empty begin and a 256-byte + // 0xFF end together cover every possible key in these CFs. + // --------------------------------------------------------------- + static const std::string kClearBegin{}; + static const std::string kClearEnd(256, '\xFF'); + + const std::pair cfs_to_clear[] = + { + {"$TF", term_freq_cf_.load()}, + {"$DOC_LEN", doc_len_cf_.load()}, + {"$MAX_TF", max_tf_cf_.load()}, + }; + for (const auto &[cf_name, cf] : cfs_to_clear) { + if (cf == nullptr) { + continue; + } + if (!ctx_->db_->DeleteRange(ctx_->write_opts_, cf, kClearBegin, kClearEnd) + .ok()) { + return tl::make_unexpected(Status::InternalError( + "FtsColumnIndexer::convert_postings_to_bitpacked: failed to clear ", + cf_name, " CF. field=", field_name_)); + } + } + + return {}; +} + +// ============================================================ +// Private helper methods +// ============================================================ + +void FtsColumnIndexer::encode_varint(uint32_t value, std::string *output) { + while (value >= 0x80) { + output->push_back(static_cast((value & 0x7F) | 0x80)); + value >>= 7; + } + output->push_back(static_cast(value)); +} + +std::string FtsColumnIndexer::encode_positions( + const std::vector &positions) { + std::string result; + uint32_t prev_position = 0; + for (uint32_t position : positions) { + // Delta encoding: store the difference between adjacent positions + encode_varint(position - prev_position, &result); + prev_position = position; + } + return result; +} + +} // namespace zvec::fts diff --git a/src/db/index/column/fts_column/fts_column_indexer.h b/src/db/index/column/fts_column/fts_column_indexer.h new file mode 100644 index 000000000..e57fcd40a --- /dev/null +++ b/src/db/index/column/fts_column/fts_column_indexer.h @@ -0,0 +1,238 @@ +// Copyright 2025-present the zvec project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include "db/common/rocksdb_context.h" +#include "db/index/column/fts_column/fts_types.h" +#include "iterator/fts_doc_iterator.h" +#include "tokenizer/tokenizer_factory.h" +#include "bm25_scorer.h" +#include "fts_query_ast.h" + + +namespace zvec::fts { + +/*! Single document in FTS query results. + * + * Note: `doc_id` here is the GLOBAL doc_id */ +struct FtsResult { + uint64_t doc_id{0}; + float score{0.0f}; + + bool operator>(const FtsResult &other) const { + return score > other.score; + } +}; + +/*! FTS column indexer + * Handles both read (search with BM25 + WAND) and write (insert / flush) + * operations on a single FTS column backed by RocksDB. + * Uses cross-CF WriteBatch to batch all per-document writes into a single + * atomic RocksDB Write() call for optimal write throughput. + */ +class FtsColumnIndexer { + public: + FtsColumnIndexer() = default; + ~FtsColumnIndexer(); + + // ----------------------------------------------------------------- + // Initialization + // ----------------------------------------------------------------- + + /*! Initialize for read+write (mutable path). + * \param field_meta Field meta describing this FTS field; provides both + * the field name and the tokenizer extra params used + * to acquire/release the shared pipeline. + * \param ctx RocksdbContext pointer + * \param postings_cf postings CF (main CF) + * \param positions_cf $POS CF + * \param term_freq_cf $TF CF + * \param max_tf_cf $MAX_TF CF + * \param doc_len_cf $DOC_LEN CF + * \param stat_cf $SEGMENT_STAT CF + * \return Result on success, or Status on failure + */ + Result open(FieldSchema::Ptr field_meta, RocksdbContext *ctx, + rocksdb::ColumnFamilyHandle *postings_cf, + rocksdb::ColumnFamilyHandle *positions_cf, + rocksdb::ColumnFamilyHandle *term_freq_cf, + rocksdb::ColumnFamilyHandle *max_tf_cf, + rocksdb::ColumnFamilyHandle *doc_len_cf, + rocksdb::ColumnFamilyHandle *stat_cf); + + /*! Initialize for read-only (immutable / standalone reader path). + * No tokenizer is acquired; insert() will fail if called. + * \param field_name Field name + * \param ctx RocksdbContext pointer + * \param postings_cf postings CF + * \param positions_cf $POS CF + * \param term_freq_cf $TF CF (may be nullptr for immutable) + * \param max_tf_cf $MAX_TF CF (may be nullptr) + * \param doc_len_cf $DOC_LEN CF (may be nullptr) + * \param stat_cf $SEGMENT_STAT CF + * \param bm25_params BM25 parameters (k1, b) + * \return Result on success, or Status on failure + */ + Result open_reader(const std::string &field_name, RocksdbContext *ctx, + rocksdb::ColumnFamilyHandle *postings_cf, + rocksdb::ColumnFamilyHandle *positions_cf, + rocksdb::ColumnFamilyHandle *term_freq_cf, + rocksdb::ColumnFamilyHandle *max_tf_cf, + rocksdb::ColumnFamilyHandle *doc_len_cf, + rocksdb::ColumnFamilyHandle *stat_cf, + BM25Params bm25_params = BM25Params{}); + + /*! Release all CF pointers and reset internal state. + * Thread-safe: waits for in-flight search() calls to drain before + * invalidating any state. Must be called before the underlying + * RocksdbStore is closed. + * \return Result on success, or Status on failure (e.g. already + * closed). + */ + Result close(); + + // ----------------------------------------------------------------- + // Query + // ----------------------------------------------------------------- + + /*! Execute FTS query and return result list with BM25 scores + * \param ast Pre-parsed FTS AST (caller owns the parse step) + * \param query_params Query parameters (topk, filter, etc.) + * \return Result containing sorted results (descending score), or Status + */ + Result> search( + const FtsAstNode &ast, const FtsQueryParams &query_params) const; + + /*! Atomically reset $TF/$MAX_TF/$DOC_LEN CF pointers to nullptr. + * Called before dropping these CFs so that concurrent search() calls + * on the Roaring path gracefully degrade (return default tf=1/doc_len=1). + */ + void reset_side_cfs(); + + // ----------------------------------------------------------------- + // Write + // ----------------------------------------------------------------- + + /*! Insert FTS field content for a document + * \param seg_doc_id Segment-local document ID + * \param text UTF-8 encoded text content + * \return Result on success, or Status on failure + */ + Result insert(uint64_t seg_doc_id, const std::string &text); + + /*! Flush in-memory statistics to RocksDB (called before segment dump) + * \return Result on success, or Status on failure + */ + Result flush(); + + /*! Convert all Roaring-format postings in postings_cf to BitPacked format + * with inline tf/doc_len/max_score payloads, then DeleteRange-clear the + * $TF, $DOC_LEN, and $MAX_TF CFs. + * + * Called by MutableSegment::dump_fts_column_indexers() right before the + * SST dump. After all indexers finish conversion, MutableSegment drops + * the $TF/$MAX_TF/$DOC_LEN CFs entirely (via reset_side_cfs() + + * RocksdbStore::drop_column_family()), so the dumped immutable segment + * no longer contains these CFs at all. + * + * Idempotent: terms whose postings are already in BitPacked format are + * skipped, so re-running after a partial-failure dump is safe. + * + * Must be called after flush() so that the BM25 scorer used by encode() + * sees the up-to-date segment statistics. + * + * \return Result on success, or Status on failure + */ + Result convert_postings_to_bitpacked(); + + uint64_t total_docs() const { + return total_docs_.load(std::memory_order_relaxed); + } + uint64_t total_tokens() const { + return total_tokens_.load(std::memory_order_relaxed); + } + + // Accessors used by the compaction-time FTS reducer to feed source segments + // (postings + positions) without going through the higher-level search path. + RocksdbContext *ctx() const { + return ctx_; + } + rocksdb::ColumnFamilyHandle *postings_cf() const { + return postings_cf_; + } + rocksdb::ColumnFamilyHandle *positions_cf() const { + return positions_cf_; + } + + private: + // --- Iterator tree construction (search internals) --- + Result build_iterator(const FtsAstNode &node) const; + Result build_term_iterator(const TermNode &term_node) const; + Result build_phrase_iterator( + const PhraseNode &phrase_node) const; + Result build_and_iterator(const AndNode &and_node) const; + Result build_or_iterator(const OrNode &or_node) const; + Result create_term_iterator_from_raw( + const std::string &term, rocksdb::PinnableSlice raw_data, + float boost = 1.0f) const; + std::vector batch_get_postings( + const std::vector &terms) const; + + // --- Write helpers --- + static void encode_varint(uint32_t value, std::string *output); + static std::string encode_positions(const std::vector &positions); + + // --- Tokenizer (write path only) --- + FieldSchema::Ptr field_meta_{}; + TokenizerPipelinePtr tokenizer_pipeline_{nullptr}; + std::shared_ptr fts_params_; + + // --- Reader state --- + std::string field_name_; + RocksdbContext *ctx_{nullptr}; + BM25ScorerPtr scorer_; + + rocksdb::ColumnFamilyHandle *postings_cf_{nullptr}; + rocksdb::ColumnFamilyHandle *positions_cf_{nullptr}; + std::atomic term_freq_cf_{nullptr}; + std::atomic max_tf_cf_{nullptr}; + std::atomic doc_len_cf_{nullptr}; + mutable std::atomic cf_counter_{0}; + std::atomic cf_dropped_{false}; + rocksdb::ColumnFamilyHandle *stat_cf_{nullptr}; + + // Minimum doc length observed so far. Used as a (loose) lower bound on + // doc_len when computing the WAND max_score for Roaring-format postings. + std::atomic min_doc_len_{std::numeric_limits::max()}; + + mutable std::atomic counter_{0}; + std::atomic opened_{false}; + + // --- Write-path statistics --- + std::atomic total_docs_{0}; + std::atomic total_tokens_{0}; +}; + +using FtsColumnIndexerPtr = std::shared_ptr; + +} // namespace zvec::fts diff --git a/src/db/index/column/fts_column/fts_index_results.h b/src/db/index/column/fts_column/fts_index_results.h new file mode 100644 index 000000000..dc65c42a8 --- /dev/null +++ b/src/db/index/column/fts_column/fts_index_results.h @@ -0,0 +1,85 @@ +// Copyright 2025-present the zvec project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include "db/common/constants.h" +#include "db/index/column/common/index_results.h" +#include "db/index/column/fts_column/fts_column_indexer.h" + +namespace zvec { + +// IndexResults adapter for FTS search results (doc_id + BM25 score pairs). +// Results are ordered by descending score from FtsColumnIndexer::search(). +class FtsIndexResults : public IndexResults, + public std::enable_shared_from_this { + public: + using Ptr = std::shared_ptr; + + explicit FtsIndexResults(std::vector results) + : results_(std::move(results)) {} + + size_t count() const override { + return results_.size(); + } + + const std::vector &results() const { + return results_; + } + + class FtsIterator : public Iterator { + public: + explicit FtsIterator(std::shared_ptr owner) + : owner_(std::move(owner)), pos_(0) {} + + idx_t doc_id() const override { + if (pos_ < owner_->results_.size()) { + return static_cast(owner_->results_[pos_].doc_id); + } + return INVALID_DOC_ID; + } + + float score() const override { + if (pos_ < owner_->results_.size()) { + return owner_->results_[pos_].score; + } + return 0.0f; + } + + void next() override { + if (pos_ < owner_->results_.size()) { + ++pos_; + } + } + + bool valid() const override { + return pos_ < owner_->results_.size(); + } + + private: + std::shared_ptr owner_; + size_t pos_; + }; + + IteratorUPtr create_iterator() override { + return std::make_unique(shared_from_this()); + } + + private: + std::vector results_; +}; + +} // namespace zvec diff --git a/src/db/index/column/fts_column/fts_pipeline.h b/src/db/index/column/fts_column/fts_pipeline.h new file mode 100644 index 000000000..793f6007f --- /dev/null +++ b/src/db/index/column/fts_column/fts_pipeline.h @@ -0,0 +1,37 @@ +// Copyright 2025-present the zvec project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include + +namespace zvec { + +namespace fts { +class TokenizerPipeline; +} // namespace fts + +namespace detail { + +// Internal entry to lazily acquire (and cache, per FtsIndexParams instance) +// the tokenizer pipeline. Thread-safe; same params instance returns the +// same shared_ptr on subsequent calls; the manager-side reference is +// released when the params instance is destroyed. +Result> AcquireFtsPipeline( + FtsIndexParams ¶ms); + +} // namespace detail +} // namespace zvec diff --git a/src/db/index/column/fts_column/fts_query_ast.h b/src/db/index/column/fts_column/fts_query_ast.h new file mode 100644 index 000000000..61d0a0a0e --- /dev/null +++ b/src/db/index/column/fts_column/fts_query_ast.h @@ -0,0 +1,186 @@ +// Copyright 2025-present the zvec project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include +#include + +namespace zvec::fts { + +/*! AST node type enumeration + */ +enum class FtsNodeType { + TERM, // Term node, e.g., "vector" + PHRASE, // Phrase node, e.g., "\"exact phrase\"" + AND, // AND combination node (intersection) + OR, // OR combination node (union) + EMPTY, // Matches zero documents (analogous to Lucene MatchNoDocsQuery). +}; + +/*! AST node base class + * All FTS AST nodes carry must/must_not modifiers so that the +/- prefix + * (and AND NOT semantics) can be applied uniformly to terms, phrases and + * composite (AND/OR) sub-expressions. + */ +struct FtsAstNode { + bool must{false}; // Prefix + means must + bool must_not{false}; // Prefix - / right-hand side of AND NOT means must_not + // Per-node scoring weight. Currently meaningful only on TermNode / PhraseNode + // (composite nodes inherit boost from their scored leaves). Repeated terms in + // a sibling list are collapsed by the AST rewriter into a single node whose + // boost is the linear sum of duplicates, so that the post-rewrite score + // matches the pre-rewrite "sum of independent scorers" semantics exactly. + float boost{1.0f}; + + virtual ~FtsAstNode() = default; + virtual FtsNodeType type() const = 0; + + // Return a human-readable text representation for debugging / logging + virtual std::string text() const = 0; + + protected: + // Helper: prepend +/- modifier prefix + std::string modifier_prefix() const { + if (must) { + return "+"; + } + if (must_not) { + return "-"; + } + return ""; + } + + // Helper: append ^X boost suffix when boost differs from default 1.0 + std::string boost_suffix() const { + if (std::fabs(boost - 1.0f) < 1e-6f) { + return ""; + } + return "^" + std::to_string(boost); + } +}; + +using FtsAstNodePtr = std::unique_ptr; + +/*! Term node + * Represents a single query term, can have must (+) or must_not (-) modifiers + * inherited from FtsAstNode. + */ +struct TermNode : public FtsAstNode { + std::string term; + + explicit TermNode(std::string term_text, bool is_must = false, + bool is_must_not = false) + : term(std::move(term_text)) { + must = is_must; + must_not = is_must_not; + } + + FtsNodeType type() const override { + return FtsNodeType::TERM; + } + + std::string text() const override { + return modifier_prefix() + term + boost_suffix(); + } +}; + +/*! Phrase node + * Represents an exact phrase query, e.g., "exact phrase" + * Requires exact match of word order and adjacent positions + */ +struct PhraseNode : public FtsAstNode { + std::vector terms; // Individual words in the phrase + + FtsNodeType type() const override { + return FtsNodeType::PHRASE; + } + + std::string text() const override { + std::string result = modifier_prefix() + "\""; + for (size_t i = 0; i < terms.size(); ++i) { + if (i > 0) { + result += " "; + } + result += terms[i]; + } + result += "\""; + result += boost_suffix(); + return result; + } +}; + +/*! Match-nothing node — used when the analyzer drops every term (e.g. + * pure punctuation or all stop-words). Composes naturally with AND/OR so + * callers don't have to special-case nullptr. + */ +struct EmptyNode : public FtsAstNode { + FtsNodeType type() const override { + return FtsNodeType::EMPTY; + } + + std::string text() const override { + return modifier_prefix() + ""; + } +}; + +/*! AND combination node + * All child nodes must match (intersection semantics) + */ +struct AndNode : public FtsAstNode { + std::vector children; + + FtsNodeType type() const override { + return FtsNodeType::AND; + } + + std::string text() const override { + std::string result = modifier_prefix() + "AND("; + for (size_t i = 0; i < children.size(); ++i) { + if (i > 0) { + result += " "; + } + result += children[i]->text(); + } + result += ")"; + return result; + } +}; + +/*! OR combination node + * Any child node matches (union semantics) + */ +struct OrNode : public FtsAstNode { + std::vector children; + + FtsNodeType type() const override { + return FtsNodeType::OR; + } + + std::string text() const override { + std::string result = modifier_prefix() + "OR("; + for (size_t i = 0; i < children.size(); ++i) { + if (i > 0) { + result += " "; + } + result += children[i]->text(); + } + result += ")"; + return result; + } +}; + +} // namespace zvec::fts diff --git a/src/db/index/column/fts_column/fts_rocksdb_merge.cc b/src/db/index/column/fts_column/fts_rocksdb_merge.cc new file mode 100644 index 000000000..737e321ca --- /dev/null +++ b/src/db/index/column/fts_column/fts_rocksdb_merge.cc @@ -0,0 +1,181 @@ +// Copyright 2025-present the zvec project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "fts_rocksdb_merge.h" +#include +#include +#include +#include "db/index/column/fts_column/posting/bitpacked_posting_list.h" + +namespace zvec::fts { + +// ============================================================ +// Helper: deserialize a posting value (Roaring Bitmap or BitPacked) into a +// Roaring Bitmap. Caller owns the returned bitmap and must free it. +// Returns nullptr on failure. +// ============================================================ + +static roaring_bitmap_t *deserialize_posting_to_roaring(const char *data, + size_t size) { + if (BitPackedPostingList::is_bitpacked_format(data, size)) { + // Decode BitPacked format into a new Roaring Bitmap + BitPackedPostingIterator bp_iter; + if (bp_iter.open(data, size) != 0) { + LOG_ERROR( + "FtsPostingsMerge: failed to open bitpacked posting during merge, " + "size[%zu]", + size); + return nullptr; + } + roaring_bitmap_t *bitmap = roaring_bitmap_create(); + uint32_t doc_id = bp_iter.next_doc(); + while (doc_id != BitPackedPostingIterator::NO_MORE_DOCS) { + roaring_bitmap_add(bitmap, doc_id); + doc_id = bp_iter.next_doc(); + } + return bitmap; + } + + // Roaring Bitmap format + return roaring_bitmap_portable_deserialize_safe(data, size); +} + +// ============================================================ +// FtsPostingsMerge: Roaring Bitmap OR merge (supports BitPacked input) +// ============================================================ + +bool FtsPostingsMerge::FullMergeV2(const MergeOperationInput &merge_in, + MergeOperationOutput *merge_out) const { + // If there is only one operand and no existing_value, return directly + if (merge_in.existing_value == nullptr && merge_in.operand_list.size() == 1) { + merge_out->new_value = std::string(merge_in.operand_list[0].data(), + merge_in.operand_list[0].size()); + return true; + } + + // Deserialize bitmap from existing_value + roaring_bitmap_t *result_bitmap = roaring_bitmap_create(); + + if (merge_in.existing_value != nullptr) { + roaring_bitmap_t *existing_bitmap = deserialize_posting_to_roaring( + merge_in.existing_value->data(), merge_in.existing_value->size()); + if (existing_bitmap != nullptr) { + roaring_bitmap_or_inplace(result_bitmap, existing_bitmap); + roaring_bitmap_free(existing_bitmap); + } + } + + // Merge all operands + for (const auto &operand : merge_in.operand_list) { + roaring_bitmap_t *operand_bitmap = + deserialize_posting_to_roaring(operand.data(), operand.size()); + if (operand_bitmap != nullptr) { + roaring_bitmap_or_inplace(result_bitmap, operand_bitmap); + roaring_bitmap_free(operand_bitmap); + } + } + + // Serialize result as Roaring Bitmap + roaring_bitmap_run_optimize(result_bitmap); + size_t serialized_size = roaring_bitmap_portable_size_in_bytes(result_bitmap); + merge_out->new_value.resize(serialized_size); + roaring_bitmap_portable_serialize(result_bitmap, merge_out->new_value.data()); + roaring_bitmap_free(result_bitmap); + return true; +} + +bool FtsPostingsMerge::PartialMerge(const rocksdb::Slice & /*key*/, + const rocksdb::Slice &left_operand, + const rocksdb::Slice &right_operand, + std::string *new_value, + rocksdb::Logger * /*logger*/) const { + roaring_bitmap_t *left_bitmap = + deserialize_posting_to_roaring(left_operand.data(), left_operand.size()); + roaring_bitmap_t *right_bitmap = deserialize_posting_to_roaring( + right_operand.data(), right_operand.size()); + + if (left_bitmap == nullptr || right_bitmap == nullptr) { + LOG_ERROR( + "FtsPostingsMerge::PartialMerge: failed to deserialize operand. " + "left_size[%zu] right_size[%zu]", + left_operand.size(), right_operand.size()); + if (left_bitmap != nullptr) roaring_bitmap_free(left_bitmap); + if (right_bitmap != nullptr) roaring_bitmap_free(right_bitmap); + return false; + } + + roaring_bitmap_or_inplace(left_bitmap, right_bitmap); + roaring_bitmap_free(right_bitmap); + + size_t serialized_size = roaring_bitmap_portable_size_in_bytes(left_bitmap); + new_value->resize(serialized_size); + roaring_bitmap_portable_serialize(left_bitmap, new_value->data()); + roaring_bitmap_free(left_bitmap); + return true; +} + +// ============================================================ +// FtsMaxTfMerge: uint32_t max merge +// ============================================================ + +bool FtsMaxTfMerge::FullMergeV2(const MergeOperationInput &merge_in, + MergeOperationOutput *merge_out) const { + uint32_t max_tf = 0; + + if (merge_in.existing_value != nullptr && + merge_in.existing_value->size() >= sizeof(uint32_t)) { + std::memcpy(&max_tf, merge_in.existing_value->data(), sizeof(uint32_t)); + } + + for (const auto &operand : merge_in.operand_list) { + if (operand.size() >= sizeof(uint32_t)) { + uint32_t operand_tf = 0; + std::memcpy(&operand_tf, operand.data(), sizeof(uint32_t)); + if (operand_tf > max_tf) { + max_tf = operand_tf; + } + } + } + + merge_out->new_value.resize(sizeof(uint32_t)); + std::memcpy(merge_out->new_value.data(), &max_tf, sizeof(uint32_t)); + return true; +} + +bool FtsMaxTfMerge::PartialMerge(const rocksdb::Slice & /*key*/, + const rocksdb::Slice &left_operand, + const rocksdb::Slice &right_operand, + std::string *new_value, + rocksdb::Logger * /*logger*/) const { + if (left_operand.size() < sizeof(uint32_t) || + right_operand.size() < sizeof(uint32_t)) { + LOG_ERROR( + "FtsMaxTfMerge::PartialMerge: operand too small. " + "left_size[%zu] right_size[%zu] expected[%zu]", + left_operand.size(), right_operand.size(), sizeof(uint32_t)); + return false; + } + + uint32_t left_tf = 0; + uint32_t right_tf = 0; + std::memcpy(&left_tf, left_operand.data(), sizeof(uint32_t)); + std::memcpy(&right_tf, right_operand.data(), sizeof(uint32_t)); + + uint32_t max_tf = (left_tf > right_tf) ? left_tf : right_tf; + new_value->resize(sizeof(uint32_t)); + std::memcpy(new_value->data(), &max_tf, sizeof(uint32_t)); + return true; +} + +} // namespace zvec::fts diff --git a/src/db/index/column/fts_column/fts_rocksdb_merge.h b/src/db/index/column/fts_column/fts_rocksdb_merge.h new file mode 100644 index 000000000..1bed8f4b6 --- /dev/null +++ b/src/db/index/column/fts_column/fts_rocksdb_merge.h @@ -0,0 +1,59 @@ +// Copyright 2025-present the zvec project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include + +namespace zvec::fts { + +/*! FTS postings CF-specific Merge Operator + * Performs OR merge on Roaring Bitmap serialized values, used for + * incrementally updating term document lists + */ +class FtsPostingsMerge : public ROCKSDB_NAMESPACE::MergeOperator { + public: + bool FullMergeV2(const MergeOperationInput &merge_in, + MergeOperationOutput *merge_out) const override; + + bool PartialMerge(const rocksdb::Slice &key, + const rocksdb::Slice &left_operand, + const rocksdb::Slice &right_operand, std::string *new_value, + rocksdb::Logger *logger) const override; + + const char *Name() const override { + return "FtsPostingsMerge"; + } +}; + +/*! FTS $MAX_TF CF-specific Merge Operator + * Performs max merge on uint32_t values, used for maintaining the maximum term + * frequency for each term (WAND upper bound) + */ +class FtsMaxTfMerge : public ROCKSDB_NAMESPACE::MergeOperator { + public: + bool FullMergeV2(const MergeOperationInput &merge_in, + MergeOperationOutput *merge_out) const override; + + bool PartialMerge(const rocksdb::Slice &key, + const rocksdb::Slice &left_operand, + const rocksdb::Slice &right_operand, std::string *new_value, + rocksdb::Logger *logger) const override; + + const char *Name() const override { + return "FtsMaxTfMerge"; + } +}; + +} // namespace zvec::fts diff --git a/src/db/index/column/fts_column/fts_rocksdb_reducer.cc b/src/db/index/column/fts_column/fts_rocksdb_reducer.cc new file mode 100644 index 000000000..21f3567b2 --- /dev/null +++ b/src/db/index/column/fts_column/fts_rocksdb_reducer.cc @@ -0,0 +1,492 @@ +// Copyright 2025-present the zvec project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "fts_rocksdb_reducer.h" +#include +#include +#include +#include +#include "db/index/column/fts_column/fts_utils.h" +#include "db/index/column/fts_column/posting/bitpacked_posting_list.h" + +namespace zvec::fts { + +namespace { + +// Dense survivor index in [0, effective_total_docs), or kFilteredRank if +// scan_pos is in the delete bitmap. Roaring rank(x) counts elements ≤ x; +// for an alive scan_pos that's exactly the number of deletes strictly +// before it, so `scan_pos - rank(scan_pos)` is its survivor rank. +constexpr uint32_t kFilteredRank = std::numeric_limits::max(); + +inline uint32_t dense_rank(uint64_t scan_pos, const roaring::Roaring &bitmap) { + const uint32_t pos32 = static_cast(scan_pos); + if (bitmap.contains(pos32)) { + return kFilteredRank; + } + return static_cast(scan_pos - bitmap.rank(pos32)); +} + +} // namespace + +// ============================================================ +// Design notes +// ============================================================ +// +// Immutable FTS segment CFs: +// - postings_cf : term -> BitPacked posting (inline tf/doc_len/max_score) +// - positions_cf : term\0doc_id -> varint delta positions (phrase queries) +// - stat_cf : field_total_docs / field_total_tokens +// +// Multi-way merge N source segments into one destination, in two passes. +// All input postings must be BitPacked; output is BitPacked too — no +// Roaring intermediate, no side CFs ($TF/$MAX_TF/$DOC_LEN) read or written. +// +// Doc id spaces: +// SRC LOCAL ∈ [0, stats.doc_count): value stored in src postings. +// SCAN POS ∈ [0, Σ stats.doc_count): feed-order concatenated position; +// same id space as SegmentHelper::delete_row_id_bitmap. +// scan_pos = scan_offset_per_seg_[seg] + local +// DST LOCAL ∈ [0, effective_total_docs_): dense survivor rank. +// Equals the row index ReduceScalar writes into the new +// segment's densified forward storage, so post-merge fetch() +// needs no translation. +// dst_local = scan_pos - bitmap.rank(scan_pos) +// +// Pass 1 (collect_effective_stats): no per-doc materialization. +// - effective_total_docs_ = Σ stats.doc_count - bitmap.cardinality() +// - effective_total_tokens_ = sum of survivors' inline doc_len +// (per-segment dedup uses vector, ~125 KB / 1M docs) +// +// Pass 2 (merge_and_flush_postings): N RocksDB iterators, term-by-term +// multi-way merge in lex order; per-term entries are encoded + put +// immediately so peak memory is one term's entries. dst_local resolved +// on the fly via dense_rank(scan_pos), sharing the bitmap with the +// vector reducer. + +// ============================================================ +// Public interface +// ============================================================ + +Result FtsRocksdbReducer::init( + const std::string &field_name, RocksdbContext *ctx, + rocksdb::ColumnFamilyHandle *dst_postings_cf, + rocksdb::ColumnFamilyHandle *dst_positions_cf, + rocksdb::ColumnFamilyHandle *dst_stat_cf) { + if (!dst_postings_cf || !dst_positions_cf || !dst_stat_cf) { + return tl::make_unexpected(Status::InvalidArgument( + "FtsRocksdbReducer: null destination CF. field=", field_name)); + } + + field_name_ = field_name; + ctx_ = ctx; + dst_postings_cf_ = dst_postings_cf; + dst_positions_cf_ = dst_positions_cf; + dst_stat_cf_ = dst_stat_cf; + + state_ = STATE_INITED; + return {}; +} + +Result FtsRocksdbReducer::cleanup() { + segment_stats_.clear(); + src_ctxs_.clear(); + src_postings_cfs_.clear(); + src_positions_cfs_.clear(); + scan_offset_per_seg_.clear(); + num_segments_ = 0; + state_ = STATE_UNINITED; + return {}; +} + +Result FtsRocksdbReducer::feed( + FtsSegmentStats segment_stats, RocksdbContext *src_ctx, + rocksdb::ColumnFamilyHandle *src_postings_cf, + rocksdb::ColumnFamilyHandle *src_positions_cf) { + if (state_ != STATE_INITED && state_ != STATE_FEED) { + return tl::make_unexpected(Status::InternalError( + "FtsRocksdbReducer: call init() before feed(). field=", field_name_)); + } + + if (!src_postings_cf || !src_positions_cf) { + return tl::make_unexpected(Status::InvalidArgument( + "FtsRocksdbReducer: null source CF. field=", field_name_)); + } + + // doc_count == 0 segments contribute nothing; mark state and skip so the + // contiguity check and scan_offset cumsum only see non-empty inputs (the + // matching FilterRecordBatch / RowIdFilter id space behaves the same way). + if (segment_stats.doc_count == 0) { + state_ = STATE_FEED; + return {}; + } + + // Require consecutive global doc_id ranges between non-empty segments so + // the shared delete_row_id_bitmap stays aligned with input scan order. + if (!segment_stats_.empty() && + segment_stats.min_doc_id != segment_stats_.back().max_doc_id + 1) { + return tl::make_unexpected(Status::InternalError( + "FtsRocksdbReducer: segments not in consecutive doc_id order. field=", + field_name_)); + } + + segment_stats_.emplace_back(std::move(segment_stats)); + src_ctxs_.emplace_back(src_ctx); + src_postings_cfs_.emplace_back(src_postings_cf); + src_positions_cfs_.emplace_back(src_positions_cf); + ++num_segments_; + + state_ = STATE_FEED; + return {}; +} + +Result FtsRocksdbReducer::reduce( + const roaring::Roaring &delete_row_id_bitmap) { + if (state_ != STATE_FEED || num_segments_ == 0) { + return tl::make_unexpected(Status::InternalError( + "FtsRocksdbReducer: call feed() before reduce(). field=", field_name_)); + } + + effective_total_docs_ = 0; + effective_total_tokens_ = 0; + + // Precompute scan_offset = cumulative doc_count. Combined with the + // bitmap this lets dense_rank() resolve any (seg, local) in + // O(roaring::rank) without a per-doc table. + scan_offset_per_seg_.assign(num_segments_, 0); + uint64_t cumulative = 0; + for (uint32_t seg = 0; seg < num_segments_; ++seg) { + scan_offset_per_seg_[seg] = cumulative; + cumulative += segment_stats_[seg].doc_count; + } + + // Phase 1: streaming per-term BitPacked merge into dst_postings_cf; + // accumulates effective_total_docs_ / effective_total_tokens_. + auto ret = reduce_postings(delete_row_id_bitmap); + if (!ret) { + LOG_ERROR("FtsRocksdbReducer: reduce_postings failed. field[%s]", + field_name_.c_str()); + return ret; + } + + // Phase 2: per-segment positions CF remap (phrase queries). + for (uint32_t segment_index = 0; segment_index < num_segments_; + ++segment_index) { + ret = reduce_positions(segment_index, delete_row_id_bitmap); + if (!ret) { + LOG_ERROR( + "FtsRocksdbReducer: reduce_positions failed. segment[%u] field[%s]", + segment_index, field_name_.c_str()); + return ret; + } + } + + // Phase 3: persist effective stats — same source of truth used by Phase 1 + // when encoding block_max_score, so search-time IDF/avgdl stays consistent. + ret = flush_stat(effective_total_docs_, effective_total_tokens_); + if (!ret) { + LOG_ERROR("FtsRocksdbReducer: flush_stat failed. field[%s]", + field_name_.c_str()); + return ret; + } + + state_ = STATE_REDUCE; + LOG_INFO( + "FtsRocksdbReducer: reduce done. field[%s] segments[%u] " + "effective_docs[%zu] effective_tokens[%zu]", + field_name_.c_str(), num_segments_, (size_t)effective_total_docs_, + (size_t)effective_total_tokens_); + return {}; +} + +// ============================================================ +// Private +// ============================================================ + +Result FtsRocksdbReducer::reduce_postings( + const roaring::Roaring &delete_row_id_bitmap) { + auto ret = collect_effective_stats(delete_row_id_bitmap); + if (!ret) { + return ret; + } + // Scorer seeded with final effective stats; used by Pass 2 to compute + // block_max_score consistent with the values flushed to stat_cf. + scorer_ = std::make_shared(); + scorer_->update_stats(effective_total_docs_, effective_total_tokens_); + return merge_and_flush_postings(delete_row_id_bitmap); +} + +Result FtsRocksdbReducer::collect_effective_stats( + const roaring::Roaring &delete_row_id_bitmap) { + effective_total_docs_ = 0; + effective_total_tokens_ = 0; + + // effective_total_docs = Σ doc_count - |deletes|. Bitmap covers scan + // positions [0, Σ doc_count), so cardinality() is the exact filtered + // count. Includes empty docs, matching mutable indexer semantics. + uint64_t total_input_docs = 0; + for (const auto &s : segment_stats_) { + total_input_docs += s.doc_count; + } + const uint64_t total_deletes = delete_row_id_bitmap.cardinality(); + if (total_deletes > total_input_docs) { + return tl::make_unexpected( + Status::InternalError("FtsRocksdbReducer: delete bitmap cardinality[", + total_deletes, "] exceeds total input docs[", + total_input_docs, "]. field=", field_name_)); + } + effective_total_docs_ = total_input_docs - total_deletes; + + // effective_total_tokens_: walk every posting, sum doc_len once per + // surviving local_doc_id. Per-segment vector dedup (~125 KB / 1M + // docs) is required because immutable segments have no per-doc doc_len + // column to read from directly. + for (uint32_t seg = 0; seg < num_segments_; ++seg) { + const uint64_t seg_doc_count = segment_stats_[seg].doc_count; + const uint64_t scan_offset = scan_offset_per_seg_[seg]; + std::vector seen_docs(seg_doc_count, false); + + auto *src_cf = src_postings_cfs_[seg]; + auto iter = std::unique_ptr( + src_ctxs_[seg]->db_->NewIterator(src_ctxs_[seg]->read_opts_, src_cf)); + iter->SeekToFirst(); + + while (iter->Valid()) { + const std::string posting_data = iter->value().ToString(); + + if (!BitPackedPostingList::is_bitpacked_format(posting_data.data(), + posting_data.size())) { + return tl::make_unexpected(Status::InternalError( + "FtsRocksdbReducer: source postings is not BitPacked. field=", + field_name_)); + } + + BitPackedPostingIterator bp_iter; + if (bp_iter.open(posting_data.data(), posting_data.size()) != 0) { + return tl::make_unexpected(Status::InternalError( + "FtsRocksdbReducer: failed to open bitpacked postings. field=", + field_name_)); + } + + uint32_t local_doc_id = bp_iter.next_doc(); + while (local_doc_id != BitPackedPostingIterator::NO_MORE_DOCS) { + if (local_doc_id < seg_doc_count && !seen_docs[local_doc_id]) { + const uint64_t scan_pos = scan_offset + local_doc_id; + if (!delete_row_id_bitmap.contains(static_cast(scan_pos))) { + seen_docs[local_doc_id] = true; + effective_total_tokens_ += bp_iter.doc_len(); + } + } + local_doc_id = bp_iter.next_doc(); + } + iter->Next(); + } + } + + LOG_INFO( + "FtsRocksdbReducer: collect_effective_stats done. field[%s] " + "effective_docs[%zu] effective_tokens[%zu]", + field_name_.c_str(), (size_t)effective_total_docs_, + (size_t)effective_total_tokens_); + return {}; +} + +Result FtsRocksdbReducer::merge_and_flush_postings( + const roaring::Roaring &delete_row_id_bitmap) { + struct PostingEntry { + uint32_t doc_id; + uint32_t tf; + uint32_t doc_len; + }; + + // Open N iterators, one per source segment. + struct SegmentCursor { + uint32_t segment_index; + std::unique_ptr iter; + const FtsSegmentStats *stats; + }; + std::vector cursors; + cursors.reserve(num_segments_); + for (uint32_t i = 0; i < num_segments_; ++i) { + auto it = std::unique_ptr(src_ctxs_[i]->db_->NewIterator( + src_ctxs_[i]->read_opts_, src_postings_cfs_[i])); + it->SeekToFirst(); + cursors.push_back(SegmentCursor{i, std::move(it), &segment_stats_[i]}); + } + + // Reusable buffers. + std::vector term_entries; + std::vector doc_ids_buf, tfs_buf, doc_lens_buf; + + while (true) { + // Pick the lex-smallest current term across cursors. + std::string min_term; + bool found = false; + for (auto &c : cursors) { + if (!c.iter->Valid()) { + continue; + } + const std::string t = c.iter->key().ToString(); + if (!found || t < min_term) { + min_term = t; + found = true; + } + } + if (!found) { + break; + } + + // Cursors visited in segment order ⇒ dense ranks emerge ascending. + term_entries.clear(); + for (auto &c : cursors) { + if (!c.iter->Valid()) { + continue; + } + if (c.iter->key().ToString() != min_term) { + continue; + } + + const std::string posting_data = c.iter->value().ToString(); + if (!BitPackedPostingList::is_bitpacked_format(posting_data.data(), + posting_data.size())) { + return tl::make_unexpected(Status::InternalError( + "FtsRocksdbReducer: source postings is not BitPacked. field=", + field_name_, " term=", min_term)); + } + + BitPackedPostingIterator bp_iter; + if (bp_iter.open(posting_data.data(), posting_data.size()) != 0) { + return tl::make_unexpected(Status::InternalError( + "FtsRocksdbReducer: failed to open bitpacked postings. field=", + field_name_, " term=", min_term)); + } + + term_entries.reserve(term_entries.size() + bp_iter.cost()); + const uint64_t scan_offset = scan_offset_per_seg_[c.segment_index]; + const uint64_t seg_doc_count = c.stats->doc_count; + uint32_t local_doc_id = bp_iter.next_doc(); + while (local_doc_id != BitPackedPostingIterator::NO_MORE_DOCS) { + if (local_doc_id < seg_doc_count) { + const uint32_t new_doc_id = + dense_rank(scan_offset + local_doc_id, delete_row_id_bitmap); + if (new_doc_id != kFilteredRank) { + term_entries.push_back( + {new_doc_id, bp_iter.term_freq(), bp_iter.doc_len()}); + } + } + local_doc_id = bp_iter.next_doc(); + } + c.iter->Next(); + } + + if (term_entries.empty()) { + continue; + } + + // Encode + put per term ⇒ peak memory is one term's entries. + doc_ids_buf.clear(); + tfs_buf.clear(); + doc_lens_buf.clear(); + doc_ids_buf.reserve(term_entries.size()); + tfs_buf.reserve(term_entries.size()); + doc_lens_buf.reserve(term_entries.size()); + for (const auto &e : term_entries) { + doc_ids_buf.push_back(e.doc_id); + tfs_buf.push_back(e.tf); + doc_lens_buf.push_back(e.doc_len); + } + + std::string packed = BitPackedPostingList::encode( + doc_ids_buf.data(), tfs_buf.data(), doc_lens_buf.data(), + doc_ids_buf.size(), doc_ids_buf.size(), *scorer_); + + if (!ctx_->db_->Put(ctx_->write_opts_, dst_postings_cf_, min_term, packed) + .ok()) { + return tl::make_unexpected(Status::InternalError( + "FtsRocksdbReducer: failed to put bitpacked postings. field=", + field_name_)); + } + } + + return {}; +} + +Result FtsRocksdbReducer::reduce_positions( + uint32_t segment_index, const roaring::Roaring &delete_row_id_bitmap) { + auto *src_positions_cf = src_positions_cfs_[segment_index]; + const uint64_t scan_offset = scan_offset_per_seg_[segment_index]; + const uint64_t seg_doc_count = segment_stats_[segment_index].doc_count; + + auto iter = std::unique_ptr( + src_ctxs_[segment_index]->db_->NewIterator( + src_ctxs_[segment_index]->read_opts_, src_positions_cf)); + iter->SeekToFirst(); + + for (; iter->Valid(); iter->Next()) { + const std::string key = iter->key().ToString(); + + std::string term; + uint32_t local_doc_id = 0; + if (!parse_doc_term_key(key, &term, &local_doc_id)) { + return tl::make_unexpected(Status::InternalError( + "FtsRocksdbReducer: malformed positions key. field=", field_name_)); + } + + if (local_doc_id >= seg_doc_count) { + continue; + } + const uint32_t new_doc_id = + dense_rank(scan_offset + local_doc_id, delete_row_id_bitmap); + if (new_doc_id == kFilteredRank) { + continue; + } + const std::string new_key = make_doc_term_key(term, new_doc_id); + + if (!ctx_->db_ + ->Put(ctx_->write_opts_, dst_positions_cf_, new_key, + iter->value().ToString()) + .ok()) { + return tl::make_unexpected(Status::InternalError( + "FtsRocksdbReducer: failed to write positions. field=", field_name_)); + } + } + + return {}; +} + +Result FtsRocksdbReducer::flush_stat(uint64_t total_docs, + uint64_t total_tokens) { + if (!ctx_->db_ + ->Put(ctx_->write_opts_, dst_stat_cf_, + make_total_docs_key(field_name_), + encode_uint64_value(total_docs)) + .ok()) { + return tl::make_unexpected(Status::InternalError( + "FtsRocksdbReducer: failed to write total_docs. field=", field_name_)); + } + + if (!ctx_->db_ + ->Put(ctx_->write_opts_, dst_stat_cf_, + make_total_tokens_key(field_name_), + encode_uint64_value(total_tokens)) + .ok()) { + return tl::make_unexpected(Status::InternalError( + "FtsRocksdbReducer: failed to write total_tokens. field=", + field_name_)); + } + + return {}; +} + +} // namespace zvec::fts diff --git a/src/db/index/column/fts_column/fts_rocksdb_reducer.h b/src/db/index/column/fts_column/fts_rocksdb_reducer.h new file mode 100644 index 000000000..02a6b4711 --- /dev/null +++ b/src/db/index/column/fts_column/fts_rocksdb_reducer.h @@ -0,0 +1,155 @@ +// Copyright 2025-present the zvec project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include +#include +#include +#include "db/common/rocksdb_context.h" +#include "db/index/column/fts_column/bm25_scorer.h" +#include "db/index/column/fts_column/fts_types.h" + +namespace zvec::fts { + +class FtsRocksdbReducer; +using FtsRocksdbReducerPtr = std::shared_ptr; + +/*! FTS RocksDB segment reducer + * Merges FTS index data from multiple source segments into one destination + * segment, remapping doc_ids and filtering deleted documents. Reads only + * postings_cf (BitPacked) and positions_cf from each source segment; writes + * only postings_cf, positions_cf, and stat_cf on the destination side. + */ +class FtsRocksdbReducer { + public: + /*! Initialize the reducer with destination column families. + * \param field_name FTS field name (used for stat_cf keys) + * \param dst_postings_cf Destination postings CF (BitPacked output) + * \param dst_positions_cf Destination positions CF (phrase support) + * \param dst_stat_cf Destination segment-stat CF + * \return Result on success, or Status on failure + */ + Result init(const std::string &field_name, RocksdbContext *ctx, + rocksdb::ColumnFamilyHandle *dst_postings_cf, + rocksdb::ColumnFamilyHandle *dst_positions_cf, + rocksdb::ColumnFamilyHandle *dst_stat_cf); + + /*! Clean up internal state. */ + Result cleanup(); + + /*! Feed a source segment to be merged. + * Segments must be fed in consecutive doc_id order. + * \param segment_stats Stats of the source segment (min/max doc_id) + * \param src_ctx RocksdbContext owning the source CFs + * \param src_postings_cf Source postings CF (must be BitPacked) + * \param src_positions_cf Source positions CF + * \return Result on success, or Status on failure + */ + Result feed(FtsSegmentStats segment_stats, RocksdbContext *src_ctx, + rocksdb::ColumnFamilyHandle *src_postings_cf, + rocksdb::ColumnFamilyHandle *src_positions_cf); + + /*! Merge fed segments into the destination: per-term BitPacked postings + * to dst_postings_cf, doc_ids remapped to the new segment's dense space, + * effective total_docs / total_tokens to dst_stat_cf for BM25. + * + * \param delete_row_id_bitmap Deleted positions in input scan order, + * id space [0, Σ stats.doc_count). For segment i with + * scan_offset = Σ_{j reduce(const roaring::Roaring &delete_row_id_bitmap); + + /*! No-op: FTS data is written directly during reduce(). */ + Result dump() { + return {}; + } + + private: + // Two-pass streaming merge. Pass 1: collect effective stats. Pass 2: + // multi-way merge by term, encode + put one BitPacked posting per term + // (peak memory bounded by one term's entries). Both passes take the + // shared delete bitmap by reference rather than storing it on the + // reducer so its lifetime stays scoped to reduce(). + Result reduce_postings(const roaring::Roaring &delete_row_id_bitmap); + + // Pass 1: effective_total_docs_ = Σ stats.doc_count - bitmap.cardinality + // (counts empty docs too, like the mutable indexer); effective_total_tokens_ + // is summed from inline doc_len payloads of surviving docs. + Result collect_effective_stats( + const roaring::Roaring &delete_row_id_bitmap); + + // Pass 2: see reduce_postings. Dense rank looked up on the fly via + // the file-local dense_rank helper in the .cc. + Result merge_and_flush_postings( + const roaring::Roaring &delete_row_id_bitmap); + + // Per-segment positions CF remap (phrase query support). + Result reduce_positions(uint32_t segment_index, + const roaring::Roaring &delete_row_id_bitmap); + + // Write accumulated stats to destination stat CF. + Result flush_stat(uint64_t total_docs, uint64_t total_tokens); + + private: + enum State { + STATE_UNINITED = 0, + STATE_INITED = 1, + STATE_FEED = 2, + STATE_REDUCE = 3, + }; + + std::string field_name_{}; + + // RocksdbContext for CF-level operations (get/put/create_iter) + RocksdbContext *ctx_{nullptr}; + + // Destination column families (only the 3 active ones are tracked here; + // $TF/$MAX_TF/$DOC_LEN dst CFs exist in the RocksDB schema but the reducer + // never writes them — they will be empty in the output SST). + rocksdb::ColumnFamilyHandle *dst_postings_cf_{nullptr}; + rocksdb::ColumnFamilyHandle *dst_positions_cf_{nullptr}; + rocksdb::ColumnFamilyHandle *dst_stat_cf_{nullptr}; + + // Per-segment source RocksdbContexts, column families and stats (only + // postings + positions are needed; the empty $TF/$MAX_TF/$DOC_LEN side CFs + // are not opened here). + std::vector segment_stats_{}; + std::vector src_ctxs_{}; + std::vector src_postings_cfs_{}; + std::vector src_positions_cfs_{}; + + uint32_t num_segments_{0}; + + // Survivor-only stats; fed into scorer_ for block_max_score and written + // to dst stat_cf. + uint64_t effective_total_docs_{0}; + uint64_t effective_total_tokens_{0}; + + // Precomputed cumsum: scan_offset_per_seg_[i] = Σ_{j scan_offset_per_seg_{}; + + // BM25 scorer for computing block_max_score during BitPacked encoding. + // Initialized inside reduce() once effective stats are known. + BM25ScorerPtr scorer_; + + State state_{STATE_UNINITED}; +}; + +} // namespace zvec::fts diff --git a/src/db/index/column/fts_column/fts_types.h b/src/db/index/column/fts_column/fts_types.h new file mode 100644 index 000000000..6647e5caf --- /dev/null +++ b/src/db/index/column/fts_column/fts_types.h @@ -0,0 +1,58 @@ +// Copyright 2025-present the zvec project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include +#include "db/index/common/index_filter.h" + +namespace zvec::fts { + +/*! FTS query parameters passed to FtsColumnIndexer::search(). */ +struct FtsQueryParams { + uint32_t topk{10}; + // Optional filter: returns true if a doc should be EXCLUDED. + // Wraps zvec::IndexFilter for push-down filtering inside the search loop. + IndexFilter::Ptr filter{nullptr}; + // Candidate-driven (brute-force) mode: ascending segment-local doc_ids; + // when non-empty, FtsColumnIndexer restricts evaluation to this set by + // AND-ing it with the root iterator. Filled by the planner via + // DocFilter::get_bf_by_keys_and_update when an invert result is highly + // selective. + std::vector candidate_ids; +}; + +/*! Per-segment statistics needed by the FTS reducer for doc_id remapping. + * - min_doc_id / max_doc_id: GLOBAL doc_id range used by the delete filter + * (filter.is_filtered() takes a global doc_id). + * - doc_count: number of FTS LOCAL doc_ids in the source segment; the posting + * list domain is [0, doc_count). For fresh (non-merged) segments this + * equals max_doc_id - min_doc_id + 1, and the local-to-global mapping is + * `global = min_doc_id + local`. + */ +struct FtsSegmentStats { + uint64_t min_doc_id{0}; + uint64_t max_doc_id{0}; + uint64_t doc_count{0}; +}; + +struct FtsIndexParams { + std::string tokenizer_name{"standard"}; + std::vector filters{"lowercase"}; + std::string extra_params; +}; + +} // namespace zvec::fts diff --git a/src/db/index/column/fts_column/fts_utils.cc b/src/db/index/column/fts_column/fts_utils.cc new file mode 100644 index 000000000..7cf8e495c --- /dev/null +++ b/src/db/index/column/fts_column/fts_utils.cc @@ -0,0 +1,38 @@ +// Copyright 2025-present the zvec project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "fts_utils.h" +#include + +namespace zvec::fts { + +bool parse_doc_term_key(const std::string &key, std::string *term_out, + uint32_t *doc_id_out) { + // Key format: term + '\0' + doc_id(4B big-endian) + // Minimum length: 1 byte term + 1 byte '\0' + 4 bytes doc_id = 6 bytes. + if (key.size() < 6) { + LOG_WARN("parse_doc_term_key: key too short. size[%zu]", key.size()); + return false; + } + const size_t separator_pos = key.size() - sizeof(uint32_t) - 1; + if (key[separator_pos] != '\0') { + LOG_WARN("parse_doc_term_key: missing separator. size[%zu]", key.size()); + return false; + } + *term_out = key.substr(0, separator_pos); + *doc_id_out = decode_uint32_big_endian(key.data() + separator_pos + 1); + return true; +} + +} // namespace zvec::fts diff --git a/src/db/index/column/fts_column/fts_utils.h b/src/db/index/column/fts_column/fts_utils.h new file mode 100644 index 000000000..3214bd354 --- /dev/null +++ b/src/db/index/column/fts_column/fts_utils.h @@ -0,0 +1,99 @@ +// Copyright 2025-present the zvec project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include + +namespace zvec::fts { + +// Big-endian uint32 encoding/decoding. +inline uint32_t decode_uint32_big_endian(const char *data) { + return (static_cast(static_cast(data[0])) << 24) | + (static_cast(static_cast(data[1])) << 16) | + (static_cast(static_cast(data[2])) << 8) | + static_cast(static_cast(data[3])); +} + +inline void encode_uint32_big_endian(uint32_t value, std::string *output) { + output->push_back(static_cast((value >> 24) & 0xFF)); + output->push_back(static_cast((value >> 16) & 0xFF)); + output->push_back(static_cast((value >> 8) & 0xFF)); + output->push_back(static_cast(value & 0xFF)); +} + +// Doc-term key: term + '\0' + doc_id (4-byte big-endian). +// Used by postings ($TF/$POS) column families. +inline std::string make_doc_term_key(const std::string &term, uint32_t doc_id) { + std::string key; + key.reserve(term.size() + 1 + sizeof(uint32_t)); + key.append(term); + key.push_back('\0'); + encode_uint32_big_endian(doc_id, &key); + return key; +} + +// In-place variant of make_doc_term_key: appends the key to an existing buffer. +// Callers that build many keys in a row can reserve once and reuse the buffer, +// avoiding per-key allocation. Returns the number of bytes appended so the +// caller can build Slices into the buffer. +inline size_t append_doc_term_key(const std::string &term, uint32_t doc_id, + std::string *buf) { + const size_t bytes = term.size() + 1 + sizeof(uint32_t); + buf->append(term); + buf->push_back('\0'); + encode_uint32_big_endian(doc_id, buf); + return bytes; +} + +bool parse_doc_term_key(const std::string &key, std::string *term_out, + uint32_t *doc_id_out); + +// Per-field segment-stat keys (stat_cf) for BM25 scoring. +inline std::string make_total_docs_key(const std::string &field_name) { + return field_name + "_total_docs"; +} + +inline std::string make_total_tokens_key(const std::string &field_name) { + return field_name + "_total_tokens"; +} + +// uint64 big-endian encoding for stat values. +inline std::string encode_uint64_value(uint64_t value) { + std::string out(sizeof(uint64_t), '\0'); + out[0] = static_cast((value >> 56) & 0xFF); + out[1] = static_cast((value >> 48) & 0xFF); + out[2] = static_cast((value >> 40) & 0xFF); + out[3] = static_cast((value >> 32) & 0xFF); + out[4] = static_cast((value >> 24) & 0xFF); + out[5] = static_cast((value >> 16) & 0xFF); + out[6] = static_cast((value >> 8) & 0xFF); + out[7] = static_cast(value & 0xFF); + return out; +} + +inline uint64_t decode_uint64_value(const char *data) { + return (static_cast(static_cast(data[0])) << 56) | + (static_cast(static_cast(data[1])) << 48) | + (static_cast(static_cast(data[2])) << 40) | + (static_cast(static_cast(data[3])) << 32) | + (static_cast(static_cast(data[4])) << 24) | + (static_cast(static_cast(data[5])) << 16) | + (static_cast(static_cast(data[6])) << 8) | + static_cast(static_cast(data[7])); +} + +} // namespace zvec::fts diff --git a/src/db/index/column/fts_column/gen/FtsLexer.cc b/src/db/index/column/fts_column/gen/FtsLexer.cc new file mode 100644 index 000000000..0034ad5f8 --- /dev/null +++ b/src/db/index/column/fts_column/gen/FtsLexer.cc @@ -0,0 +1,257 @@ + +// Generated from FtsLexer.g4 by ANTLR 4.8 + + +#include "FtsLexer.h" + + +using namespace antlr4; + +using namespace antlr4; + +FtsLexer::FtsLexer(CharStream *input) : Lexer(input) { + _interpreter = new atn::LexerATNSimulator(this, _atn, _decisionToDFA, + _sharedContextCache); +} + +FtsLexer::~FtsLexer() { + delete _interpreter; +} + +std::string FtsLexer::getGrammarFileName() const { + return "FtsLexer.g4"; +} + +const std::vector &FtsLexer::getRuleNames() const { + return _ruleNames; +} + +const std::vector &FtsLexer::getChannelNames() const { + return _channelNames; +} + +const std::vector &FtsLexer::getModeNames() const { + return _modeNames; +} + +const std::vector &FtsLexer::getTokenNames() const { + return _tokenNames; +} + +dfa::Vocabulary &FtsLexer::getVocabulary() const { + return _vocabulary; +} + +const std::vector FtsLexer::getSerializedATN() const { + return _serializedATN; +} + +const atn::ATN &FtsLexer::getATN() const { + return _atn; +} + + +// Static vars and initialization. +std::vector FtsLexer::_decisionToDFA; +atn::PredictionContextCache FtsLexer::_sharedContextCache; + +// We own the ATN which in turn owns the ATN states. +atn::ATN FtsLexer::_atn; +std::vector FtsLexer::_serializedATN; + +std::vector FtsLexer::_ruleNames = { + "OR", "AND", "NOT", "PLUS_SIGN", "MINUS_SIGN", + "COLON", "CARET", "LP", "RP", "DQUOTA_STRING", + "ASCII_ALNUM", "ESCAPED_CHAR", "UNI_CHAR", "TERM_START", "TERM_BODY", + "REGULAR_ID", "NUMBER", "TERM", "SPACES", "DEFAULT"}; + +std::vector FtsLexer::_channelNames = {"DEFAULT_TOKEN_CHANNEL", + "HIDDEN"}; + +std::vector FtsLexer::_modeNames = {"DEFAULT_MODE"}; + +std::vector FtsLexer::_literalNames = { + "", "", "", "", "'+'", "'-'", "':'", "'^'", "'('", "')'"}; + +std::vector FtsLexer::_symbolicNames = { + "", "OR", "AND", "NOT", "PLUS_SIGN", "MINUS_SIGN", + "COLON", "CARET", "LP", "RP", "DQUOTA_STRING", "REGULAR_ID", + "NUMBER", "TERM", "SPACES", "DEFAULT"}; + +dfa::Vocabulary FtsLexer::_vocabulary(_literalNames, _symbolicNames); + +std::vector FtsLexer::_tokenNames; + +FtsLexer::Initializer::Initializer() { + // This code could be in a static initializer lambda, but VS doesn't allow + // access to private class members from there. + for (size_t i = 0; i < _symbolicNames.size(); ++i) { + std::string name = _vocabulary.getLiteralName(i); + if (name.empty()) { + name = _vocabulary.getSymbolicName(i); + } + + if (name.empty()) { + _tokenNames.push_back(""); + } else { + _tokenNames.push_back(name); + } + } + + _serializedATN = { + 0x3, 0x608b, 0xa72a, 0x8133, 0xb9ed, 0x417c, 0x3be7, 0x7786, 0x5964, + 0x2, 0x11, 0x82, 0x8, 0x1, 0x4, 0x2, 0x9, 0x2, + 0x4, 0x3, 0x9, 0x3, 0x4, 0x4, 0x9, 0x4, 0x4, + 0x5, 0x9, 0x5, 0x4, 0x6, 0x9, 0x6, 0x4, 0x7, + 0x9, 0x7, 0x4, 0x8, 0x9, 0x8, 0x4, 0x9, 0x9, + 0x9, 0x4, 0xa, 0x9, 0xa, 0x4, 0xb, 0x9, 0xb, + 0x4, 0xc, 0x9, 0xc, 0x4, 0xd, 0x9, 0xd, 0x4, + 0xe, 0x9, 0xe, 0x4, 0xf, 0x9, 0xf, 0x4, 0x10, + 0x9, 0x10, 0x4, 0x11, 0x9, 0x11, 0x4, 0x12, 0x9, + 0x12, 0x4, 0x13, 0x9, 0x13, 0x4, 0x14, 0x9, 0x14, + 0x4, 0x15, 0x9, 0x15, 0x3, 0x2, 0x3, 0x2, 0x3, + 0x2, 0x3, 0x3, 0x3, 0x3, 0x3, 0x3, 0x3, 0x3, + 0x3, 0x4, 0x3, 0x4, 0x3, 0x4, 0x3, 0x4, 0x3, + 0x5, 0x3, 0x5, 0x3, 0x6, 0x3, 0x6, 0x3, 0x7, + 0x3, 0x7, 0x3, 0x8, 0x3, 0x8, 0x3, 0x9, 0x3, + 0x9, 0x3, 0xa, 0x3, 0xa, 0x3, 0xb, 0x3, 0xb, + 0x3, 0xb, 0x3, 0xb, 0x7, 0xb, 0x47, 0xa, 0xb, + 0xc, 0xb, 0xe, 0xb, 0x4a, 0xb, 0xb, 0x3, 0xb, + 0x3, 0xb, 0x3, 0xc, 0x3, 0xc, 0x3, 0xd, 0x3, + 0xd, 0x3, 0xd, 0x3, 0xe, 0x3, 0xe, 0x3, 0xf, + 0x3, 0xf, 0x5, 0xf, 0x57, 0xa, 0xf, 0x3, 0x10, + 0x3, 0x10, 0x3, 0x10, 0x3, 0x10, 0x5, 0x10, 0x5d, + 0xa, 0x10, 0x3, 0x11, 0x3, 0x11, 0x7, 0x11, 0x61, + 0xa, 0x11, 0xc, 0x11, 0xe, 0x11, 0x64, 0xb, 0x11, + 0x3, 0x12, 0x6, 0x12, 0x67, 0xa, 0x12, 0xd, 0x12, + 0xe, 0x12, 0x68, 0x3, 0x12, 0x3, 0x12, 0x6, 0x12, + 0x6d, 0xa, 0x12, 0xd, 0x12, 0xe, 0x12, 0x6e, 0x5, + 0x12, 0x71, 0xa, 0x12, 0x3, 0x13, 0x3, 0x13, 0x7, + 0x13, 0x75, 0xa, 0x13, 0xc, 0x13, 0xe, 0x13, 0x78, + 0xb, 0x13, 0x3, 0x14, 0x6, 0x14, 0x7b, 0xa, 0x14, + 0xd, 0x14, 0xe, 0x14, 0x7c, 0x3, 0x14, 0x3, 0x14, + 0x3, 0x15, 0x3, 0x15, 0x2, 0x2, 0x16, 0x3, 0x3, + 0x5, 0x4, 0x7, 0x5, 0x9, 0x6, 0xb, 0x7, 0xd, + 0x8, 0xf, 0x9, 0x11, 0xa, 0x13, 0xb, 0x15, 0xc, + 0x17, 0x2, 0x19, 0x2, 0x1b, 0x2, 0x1d, 0x2, 0x1f, + 0x2, 0x21, 0xd, 0x23, 0xe, 0x25, 0xf, 0x27, 0x10, + 0x29, 0x11, 0x3, 0x2, 0x11, 0x4, 0x2, 0x51, 0x51, + 0x71, 0x71, 0x4, 0x2, 0x54, 0x54, 0x74, 0x74, 0x4, + 0x2, 0x43, 0x43, 0x63, 0x63, 0x4, 0x2, 0x50, 0x50, + 0x70, 0x70, 0x4, 0x2, 0x46, 0x46, 0x66, 0x66, 0x4, + 0x2, 0x56, 0x56, 0x76, 0x76, 0x6, 0x2, 0xc, 0xc, + 0xf, 0xf, 0x24, 0x24, 0x5e, 0x5e, 0x6, 0x2, 0x32, + 0x3b, 0x43, 0x5c, 0x61, 0x61, 0x63, 0x7c, 0xc, 0x2, + 0x23, 0x24, 0x28, 0x28, 0x2a, 0x2d, 0x2f, 0x2f, 0x31, + 0x31, 0x3c, 0x3c, 0x3f, 0x3f, 0x41, 0x41, 0x5d, 0x60, + 0x7d, 0x80, 0x3, 0x2, 0x82, 0x1, 0x8, 0x2, 0x25, + 0x25, 0x27, 0x27, 0x29, 0x29, 0x2f, 0x31, 0x42, 0x42, + 0x61, 0x61, 0x5, 0x2, 0x43, 0x5c, 0x61, 0x61, 0x63, + 0x7c, 0x7, 0x2, 0x2f, 0x2f, 0x32, 0x3b, 0x43, 0x5c, + 0x61, 0x61, 0x63, 0x7c, 0x3, 0x2, 0x32, 0x3b, 0x5, + 0x2, 0xb, 0xc, 0xf, 0xf, 0x22, 0x22, 0x2, 0x88, + 0x2, 0x3, 0x3, 0x2, 0x2, 0x2, 0x2, 0x5, 0x3, + 0x2, 0x2, 0x2, 0x2, 0x7, 0x3, 0x2, 0x2, 0x2, + 0x2, 0x9, 0x3, 0x2, 0x2, 0x2, 0x2, 0xb, 0x3, + 0x2, 0x2, 0x2, 0x2, 0xd, 0x3, 0x2, 0x2, 0x2, + 0x2, 0xf, 0x3, 0x2, 0x2, 0x2, 0x2, 0x11, 0x3, + 0x2, 0x2, 0x2, 0x2, 0x13, 0x3, 0x2, 0x2, 0x2, + 0x2, 0x15, 0x3, 0x2, 0x2, 0x2, 0x2, 0x21, 0x3, + 0x2, 0x2, 0x2, 0x2, 0x23, 0x3, 0x2, 0x2, 0x2, + 0x2, 0x25, 0x3, 0x2, 0x2, 0x2, 0x2, 0x27, 0x3, + 0x2, 0x2, 0x2, 0x2, 0x29, 0x3, 0x2, 0x2, 0x2, + 0x3, 0x2b, 0x3, 0x2, 0x2, 0x2, 0x5, 0x2e, 0x3, + 0x2, 0x2, 0x2, 0x7, 0x32, 0x3, 0x2, 0x2, 0x2, + 0x9, 0x36, 0x3, 0x2, 0x2, 0x2, 0xb, 0x38, 0x3, + 0x2, 0x2, 0x2, 0xd, 0x3a, 0x3, 0x2, 0x2, 0x2, + 0xf, 0x3c, 0x3, 0x2, 0x2, 0x2, 0x11, 0x3e, 0x3, + 0x2, 0x2, 0x2, 0x13, 0x40, 0x3, 0x2, 0x2, 0x2, + 0x15, 0x42, 0x3, 0x2, 0x2, 0x2, 0x17, 0x4d, 0x3, + 0x2, 0x2, 0x2, 0x19, 0x4f, 0x3, 0x2, 0x2, 0x2, + 0x1b, 0x52, 0x3, 0x2, 0x2, 0x2, 0x1d, 0x56, 0x3, + 0x2, 0x2, 0x2, 0x1f, 0x5c, 0x3, 0x2, 0x2, 0x2, + 0x21, 0x5e, 0x3, 0x2, 0x2, 0x2, 0x23, 0x66, 0x3, + 0x2, 0x2, 0x2, 0x25, 0x72, 0x3, 0x2, 0x2, 0x2, + 0x27, 0x7a, 0x3, 0x2, 0x2, 0x2, 0x29, 0x80, 0x3, + 0x2, 0x2, 0x2, 0x2b, 0x2c, 0x9, 0x2, 0x2, 0x2, + 0x2c, 0x2d, 0x9, 0x3, 0x2, 0x2, 0x2d, 0x4, 0x3, + 0x2, 0x2, 0x2, 0x2e, 0x2f, 0x9, 0x4, 0x2, 0x2, + 0x2f, 0x30, 0x9, 0x5, 0x2, 0x2, 0x30, 0x31, 0x9, + 0x6, 0x2, 0x2, 0x31, 0x6, 0x3, 0x2, 0x2, 0x2, + 0x32, 0x33, 0x9, 0x5, 0x2, 0x2, 0x33, 0x34, 0x9, + 0x2, 0x2, 0x2, 0x34, 0x35, 0x9, 0x7, 0x2, 0x2, + 0x35, 0x8, 0x3, 0x2, 0x2, 0x2, 0x36, 0x37, 0x7, + 0x2d, 0x2, 0x2, 0x37, 0xa, 0x3, 0x2, 0x2, 0x2, + 0x38, 0x39, 0x7, 0x2f, 0x2, 0x2, 0x39, 0xc, 0x3, + 0x2, 0x2, 0x2, 0x3a, 0x3b, 0x7, 0x3c, 0x2, 0x2, + 0x3b, 0xe, 0x3, 0x2, 0x2, 0x2, 0x3c, 0x3d, 0x7, + 0x60, 0x2, 0x2, 0x3d, 0x10, 0x3, 0x2, 0x2, 0x2, + 0x3e, 0x3f, 0x7, 0x2a, 0x2, 0x2, 0x3f, 0x12, 0x3, + 0x2, 0x2, 0x2, 0x40, 0x41, 0x7, 0x2b, 0x2, 0x2, + 0x41, 0x14, 0x3, 0x2, 0x2, 0x2, 0x42, 0x48, 0x7, + 0x24, 0x2, 0x2, 0x43, 0x47, 0xa, 0x8, 0x2, 0x2, + 0x44, 0x45, 0x7, 0x5e, 0x2, 0x2, 0x45, 0x47, 0xb, + 0x2, 0x2, 0x2, 0x46, 0x43, 0x3, 0x2, 0x2, 0x2, + 0x46, 0x44, 0x3, 0x2, 0x2, 0x2, 0x47, 0x4a, 0x3, + 0x2, 0x2, 0x2, 0x48, 0x46, 0x3, 0x2, 0x2, 0x2, + 0x48, 0x49, 0x3, 0x2, 0x2, 0x2, 0x49, 0x4b, 0x3, + 0x2, 0x2, 0x2, 0x4a, 0x48, 0x3, 0x2, 0x2, 0x2, + 0x4b, 0x4c, 0x7, 0x24, 0x2, 0x2, 0x4c, 0x16, 0x3, + 0x2, 0x2, 0x2, 0x4d, 0x4e, 0x9, 0x9, 0x2, 0x2, + 0x4e, 0x18, 0x3, 0x2, 0x2, 0x2, 0x4f, 0x50, 0x7, + 0x5e, 0x2, 0x2, 0x50, 0x51, 0x9, 0xa, 0x2, 0x2, + 0x51, 0x1a, 0x3, 0x2, 0x2, 0x2, 0x52, 0x53, 0x9, + 0xb, 0x2, 0x2, 0x53, 0x1c, 0x3, 0x2, 0x2, 0x2, + 0x54, 0x57, 0x5, 0x17, 0xc, 0x2, 0x55, 0x57, 0x5, + 0x1b, 0xe, 0x2, 0x56, 0x54, 0x3, 0x2, 0x2, 0x2, + 0x56, 0x55, 0x3, 0x2, 0x2, 0x2, 0x57, 0x1e, 0x3, + 0x2, 0x2, 0x2, 0x58, 0x5d, 0x5, 0x17, 0xc, 0x2, + 0x59, 0x5d, 0x5, 0x1b, 0xe, 0x2, 0x5a, 0x5d, 0x9, + 0xc, 0x2, 0x2, 0x5b, 0x5d, 0x5, 0x19, 0xd, 0x2, + 0x5c, 0x58, 0x3, 0x2, 0x2, 0x2, 0x5c, 0x59, 0x3, + 0x2, 0x2, 0x2, 0x5c, 0x5a, 0x3, 0x2, 0x2, 0x2, + 0x5c, 0x5b, 0x3, 0x2, 0x2, 0x2, 0x5d, 0x20, 0x3, + 0x2, 0x2, 0x2, 0x5e, 0x62, 0x9, 0xd, 0x2, 0x2, + 0x5f, 0x61, 0x9, 0xe, 0x2, 0x2, 0x60, 0x5f, 0x3, + 0x2, 0x2, 0x2, 0x61, 0x64, 0x3, 0x2, 0x2, 0x2, + 0x62, 0x60, 0x3, 0x2, 0x2, 0x2, 0x62, 0x63, 0x3, + 0x2, 0x2, 0x2, 0x63, 0x22, 0x3, 0x2, 0x2, 0x2, + 0x64, 0x62, 0x3, 0x2, 0x2, 0x2, 0x65, 0x67, 0x9, + 0xf, 0x2, 0x2, 0x66, 0x65, 0x3, 0x2, 0x2, 0x2, + 0x67, 0x68, 0x3, 0x2, 0x2, 0x2, 0x68, 0x66, 0x3, + 0x2, 0x2, 0x2, 0x68, 0x69, 0x3, 0x2, 0x2, 0x2, + 0x69, 0x70, 0x3, 0x2, 0x2, 0x2, 0x6a, 0x6c, 0x7, + 0x30, 0x2, 0x2, 0x6b, 0x6d, 0x9, 0xf, 0x2, 0x2, + 0x6c, 0x6b, 0x3, 0x2, 0x2, 0x2, 0x6d, 0x6e, 0x3, + 0x2, 0x2, 0x2, 0x6e, 0x6c, 0x3, 0x2, 0x2, 0x2, + 0x6e, 0x6f, 0x3, 0x2, 0x2, 0x2, 0x6f, 0x71, 0x3, + 0x2, 0x2, 0x2, 0x70, 0x6a, 0x3, 0x2, 0x2, 0x2, + 0x70, 0x71, 0x3, 0x2, 0x2, 0x2, 0x71, 0x24, 0x3, + 0x2, 0x2, 0x2, 0x72, 0x76, 0x5, 0x1d, 0xf, 0x2, + 0x73, 0x75, 0x5, 0x1f, 0x10, 0x2, 0x74, 0x73, 0x3, + 0x2, 0x2, 0x2, 0x75, 0x78, 0x3, 0x2, 0x2, 0x2, + 0x76, 0x74, 0x3, 0x2, 0x2, 0x2, 0x76, 0x77, 0x3, + 0x2, 0x2, 0x2, 0x77, 0x26, 0x3, 0x2, 0x2, 0x2, + 0x78, 0x76, 0x3, 0x2, 0x2, 0x2, 0x79, 0x7b, 0x9, + 0x10, 0x2, 0x2, 0x7a, 0x79, 0x3, 0x2, 0x2, 0x2, + 0x7b, 0x7c, 0x3, 0x2, 0x2, 0x2, 0x7c, 0x7a, 0x3, + 0x2, 0x2, 0x2, 0x7c, 0x7d, 0x3, 0x2, 0x2, 0x2, + 0x7d, 0x7e, 0x3, 0x2, 0x2, 0x2, 0x7e, 0x7f, 0x8, + 0x14, 0x2, 0x2, 0x7f, 0x28, 0x3, 0x2, 0x2, 0x2, + 0x80, 0x81, 0xb, 0x2, 0x2, 0x2, 0x81, 0x2a, 0x3, + 0x2, 0x2, 0x2, 0xd, 0x2, 0x46, 0x48, 0x56, 0x5c, + 0x62, 0x68, 0x6e, 0x70, 0x76, 0x7c, 0x3, 0x8, 0x2, + 0x2, + }; + + atn::ATNDeserializer deserializer; + _atn = deserializer.deserialize(_serializedATN); + + size_t count = _atn.getNumberOfDecisions(); + _decisionToDFA.reserve(count); + for (size_t i = 0; i < count; i++) { + _decisionToDFA.emplace_back(_atn.getDecisionState(i), i); + } +} + +FtsLexer::Initializer FtsLexer::_init; diff --git a/src/db/index/column/fts_column/gen/FtsLexer.h b/src/db/index/column/fts_column/gen/FtsLexer.h new file mode 100644 index 000000000..9843b865e --- /dev/null +++ b/src/db/index/column/fts_column/gen/FtsLexer.h @@ -0,0 +1,73 @@ + +// Generated from FtsLexer.g4 by ANTLR 4.8 + +#pragma once + + +#include "antlr4-runtime.h" + + +namespace antlr4 { + + +class FtsLexer : public antlr4::Lexer { + public: + enum { + OR = 1, + AND = 2, + NOT = 3, + PLUS_SIGN = 4, + MINUS_SIGN = 5, + COLON = 6, + CARET = 7, + LP = 8, + RP = 9, + DQUOTA_STRING = 10, + REGULAR_ID = 11, + NUMBER = 12, + TERM = 13, + SPACES = 14, + DEFAULT = 15 + }; + + FtsLexer(antlr4::CharStream *input); + ~FtsLexer(); + + virtual std::string getGrammarFileName() const override; + virtual const std::vector &getRuleNames() const override; + + virtual const std::vector &getChannelNames() const override; + virtual const std::vector &getModeNames() const override; + virtual const std::vector &getTokenNames() + const override; // deprecated, use vocabulary instead + virtual antlr4::dfa::Vocabulary &getVocabulary() const override; + + virtual const std::vector getSerializedATN() const override; + virtual const antlr4::atn::ATN &getATN() const override; + + private: + static std::vector _decisionToDFA; + static antlr4::atn::PredictionContextCache _sharedContextCache; + static std::vector _ruleNames; + static std::vector _tokenNames; + static std::vector _channelNames; + static std::vector _modeNames; + + static std::vector _literalNames; + static std::vector _symbolicNames; + static antlr4::dfa::Vocabulary _vocabulary; + static antlr4::atn::ATN _atn; + static std::vector _serializedATN; + + + // Individual action functions triggered by action() above. + + // Individual semantic predicate functions triggered by sempred() above. + + struct Initializer { + Initializer(); + }; + static Initializer _init; +}; + +} // namespace antlr4 diff --git a/src/db/index/column/fts_column/gen/FtsLexer.interp b/src/db/index/column/fts_column/gen/FtsLexer.interp new file mode 100644 index 000000000..384c23305 --- /dev/null +++ b/src/db/index/column/fts_column/gen/FtsLexer.interp @@ -0,0 +1,67 @@ +token literal names: +null +null +null +null +'+' +'-' +':' +'^' +'(' +')' +null +null +null +null +null +null + +token symbolic names: +null +OR +AND +NOT +PLUS_SIGN +MINUS_SIGN +COLON +CARET +LP +RP +DQUOTA_STRING +REGULAR_ID +NUMBER +TERM +SPACES +DEFAULT + +rule names: +OR +AND +NOT +PLUS_SIGN +MINUS_SIGN +COLON +CARET +LP +RP +DQUOTA_STRING +ASCII_ALNUM +ESCAPED_CHAR +UNI_CHAR +TERM_START +TERM_BODY +REGULAR_ID +NUMBER +TERM +SPACES +DEFAULT + +channel names: +DEFAULT_TOKEN_CHANNEL +HIDDEN + +mode names: +DEFAULT_MODE + +atn: +[3, 24715, 42794, 33075, 47597, 16764, 15335, 30598, 22884, 2, 17, 130, 8, 1, 4, 2, 9, 2, 4, 3, 9, 3, 4, 4, 9, 4, 4, 5, 9, 5, 4, 6, 9, 6, 4, 7, 9, 7, 4, 8, 9, 8, 4, 9, 9, 9, 4, 10, 9, 10, 4, 11, 9, 11, 4, 12, 9, 12, 4, 13, 9, 13, 4, 14, 9, 14, 4, 15, 9, 15, 4, 16, 9, 16, 4, 17, 9, 17, 4, 18, 9, 18, 4, 19, 9, 19, 4, 20, 9, 20, 4, 21, 9, 21, 3, 2, 3, 2, 3, 2, 3, 3, 3, 3, 3, 3, 3, 3, 3, 4, 3, 4, 3, 4, 3, 4, 3, 5, 3, 5, 3, 6, 3, 6, 3, 7, 3, 7, 3, 8, 3, 8, 3, 9, 3, 9, 3, 10, 3, 10, 3, 11, 3, 11, 3, 11, 3, 11, 7, 11, 71, 10, 11, 12, 11, 14, 11, 74, 11, 11, 3, 11, 3, 11, 3, 12, 3, 12, 3, 13, 3, 13, 3, 13, 3, 14, 3, 14, 3, 15, 3, 15, 5, 15, 87, 10, 15, 3, 16, 3, 16, 3, 16, 3, 16, 5, 16, 93, 10, 16, 3, 17, 3, 17, 7, 17, 97, 10, 17, 12, 17, 14, 17, 100, 11, 17, 3, 18, 6, 18, 103, 10, 18, 13, 18, 14, 18, 104, 3, 18, 3, 18, 6, 18, 109, 10, 18, 13, 18, 14, 18, 110, 5, 18, 113, 10, 18, 3, 19, 3, 19, 7, 19, 117, 10, 19, 12, 19, 14, 19, 120, 11, 19, 3, 20, 6, 20, 123, 10, 20, 13, 20, 14, 20, 124, 3, 20, 3, 20, 3, 21, 3, 21, 2, 2, 22, 3, 3, 5, 4, 7, 5, 9, 6, 11, 7, 13, 8, 15, 9, 17, 10, 19, 11, 21, 12, 23, 2, 25, 2, 27, 2, 29, 2, 31, 2, 33, 13, 35, 14, 37, 15, 39, 16, 41, 17, 3, 2, 17, 4, 2, 81, 81, 113, 113, 4, 2, 84, 84, 116, 116, 4, 2, 67, 67, 99, 99, 4, 2, 80, 80, 112, 112, 4, 2, 70, 70, 102, 102, 4, 2, 86, 86, 118, 118, 6, 2, 12, 12, 15, 15, 36, 36, 94, 94, 6, 2, 50, 59, 67, 92, 97, 97, 99, 124, 12, 2, 35, 36, 40, 40, 42, 45, 47, 47, 49, 49, 60, 60, 63, 63, 65, 65, 93, 96, 125, 128, 3, 2, 130, 1, 8, 2, 37, 37, 39, 39, 41, 41, 47, 49, 66, 66, 97, 97, 5, 2, 67, 92, 97, 97, 99, 124, 7, 2, 47, 47, 50, 59, 67, 92, 97, 97, 99, 124, 3, 2, 50, 59, 5, 2, 11, 12, 15, 15, 34, 34, 2, 136, 2, 3, 3, 2, 2, 2, 2, 5, 3, 2, 2, 2, 2, 7, 3, 2, 2, 2, 2, 9, 3, 2, 2, 2, 2, 11, 3, 2, 2, 2, 2, 13, 3, 2, 2, 2, 2, 15, 3, 2, 2, 2, 2, 17, 3, 2, 2, 2, 2, 19, 3, 2, 2, 2, 2, 21, 3, 2, 2, 2, 2, 33, 3, 2, 2, 2, 2, 35, 3, 2, 2, 2, 2, 37, 3, 2, 2, 2, 2, 39, 3, 2, 2, 2, 2, 41, 3, 2, 2, 2, 3, 43, 3, 2, 2, 2, 5, 46, 3, 2, 2, 2, 7, 50, 3, 2, 2, 2, 9, 54, 3, 2, 2, 2, 11, 56, 3, 2, 2, 2, 13, 58, 3, 2, 2, 2, 15, 60, 3, 2, 2, 2, 17, 62, 3, 2, 2, 2, 19, 64, 3, 2, 2, 2, 21, 66, 3, 2, 2, 2, 23, 77, 3, 2, 2, 2, 25, 79, 3, 2, 2, 2, 27, 82, 3, 2, 2, 2, 29, 86, 3, 2, 2, 2, 31, 92, 3, 2, 2, 2, 33, 94, 3, 2, 2, 2, 35, 102, 3, 2, 2, 2, 37, 114, 3, 2, 2, 2, 39, 122, 3, 2, 2, 2, 41, 128, 3, 2, 2, 2, 43, 44, 9, 2, 2, 2, 44, 45, 9, 3, 2, 2, 45, 4, 3, 2, 2, 2, 46, 47, 9, 4, 2, 2, 47, 48, 9, 5, 2, 2, 48, 49, 9, 6, 2, 2, 49, 6, 3, 2, 2, 2, 50, 51, 9, 5, 2, 2, 51, 52, 9, 2, 2, 2, 52, 53, 9, 7, 2, 2, 53, 8, 3, 2, 2, 2, 54, 55, 7, 45, 2, 2, 55, 10, 3, 2, 2, 2, 56, 57, 7, 47, 2, 2, 57, 12, 3, 2, 2, 2, 58, 59, 7, 60, 2, 2, 59, 14, 3, 2, 2, 2, 60, 61, 7, 96, 2, 2, 61, 16, 3, 2, 2, 2, 62, 63, 7, 42, 2, 2, 63, 18, 3, 2, 2, 2, 64, 65, 7, 43, 2, 2, 65, 20, 3, 2, 2, 2, 66, 72, 7, 36, 2, 2, 67, 71, 10, 8, 2, 2, 68, 69, 7, 94, 2, 2, 69, 71, 11, 2, 2, 2, 70, 67, 3, 2, 2, 2, 70, 68, 3, 2, 2, 2, 71, 74, 3, 2, 2, 2, 72, 70, 3, 2, 2, 2, 72, 73, 3, 2, 2, 2, 73, 75, 3, 2, 2, 2, 74, 72, 3, 2, 2, 2, 75, 76, 7, 36, 2, 2, 76, 22, 3, 2, 2, 2, 77, 78, 9, 9, 2, 2, 78, 24, 3, 2, 2, 2, 79, 80, 7, 94, 2, 2, 80, 81, 9, 10, 2, 2, 81, 26, 3, 2, 2, 2, 82, 83, 9, 11, 2, 2, 83, 28, 3, 2, 2, 2, 84, 87, 5, 23, 12, 2, 85, 87, 5, 27, 14, 2, 86, 84, 3, 2, 2, 2, 86, 85, 3, 2, 2, 2, 87, 30, 3, 2, 2, 2, 88, 93, 5, 23, 12, 2, 89, 93, 5, 27, 14, 2, 90, 93, 9, 12, 2, 2, 91, 93, 5, 25, 13, 2, 92, 88, 3, 2, 2, 2, 92, 89, 3, 2, 2, 2, 92, 90, 3, 2, 2, 2, 92, 91, 3, 2, 2, 2, 93, 32, 3, 2, 2, 2, 94, 98, 9, 13, 2, 2, 95, 97, 9, 14, 2, 2, 96, 95, 3, 2, 2, 2, 97, 100, 3, 2, 2, 2, 98, 96, 3, 2, 2, 2, 98, 99, 3, 2, 2, 2, 99, 34, 3, 2, 2, 2, 100, 98, 3, 2, 2, 2, 101, 103, 9, 15, 2, 2, 102, 101, 3, 2, 2, 2, 103, 104, 3, 2, 2, 2, 104, 102, 3, 2, 2, 2, 104, 105, 3, 2, 2, 2, 105, 112, 3, 2, 2, 2, 106, 108, 7, 48, 2, 2, 107, 109, 9, 15, 2, 2, 108, 107, 3, 2, 2, 2, 109, 110, 3, 2, 2, 2, 110, 108, 3, 2, 2, 2, 110, 111, 3, 2, 2, 2, 111, 113, 3, 2, 2, 2, 112, 106, 3, 2, 2, 2, 112, 113, 3, 2, 2, 2, 113, 36, 3, 2, 2, 2, 114, 118, 5, 29, 15, 2, 115, 117, 5, 31, 16, 2, 116, 115, 3, 2, 2, 2, 117, 120, 3, 2, 2, 2, 118, 116, 3, 2, 2, 2, 118, 119, 3, 2, 2, 2, 119, 38, 3, 2, 2, 2, 120, 118, 3, 2, 2, 2, 121, 123, 9, 16, 2, 2, 122, 121, 3, 2, 2, 2, 123, 124, 3, 2, 2, 2, 124, 122, 3, 2, 2, 2, 124, 125, 3, 2, 2, 2, 125, 126, 3, 2, 2, 2, 126, 127, 8, 20, 2, 2, 127, 40, 3, 2, 2, 2, 128, 129, 11, 2, 2, 2, 129, 42, 3, 2, 2, 2, 13, 2, 70, 72, 86, 92, 98, 104, 110, 112, 118, 124, 3, 8, 2, 2] diff --git a/src/db/index/column/fts_column/gen/FtsLexer.tokens b/src/db/index/column/fts_column/gen/FtsLexer.tokens new file mode 100644 index 000000000..cd6e2db20 --- /dev/null +++ b/src/db/index/column/fts_column/gen/FtsLexer.tokens @@ -0,0 +1,21 @@ +OR=1 +AND=2 +NOT=3 +PLUS_SIGN=4 +MINUS_SIGN=5 +COLON=6 +CARET=7 +LP=8 +RP=9 +DQUOTA_STRING=10 +REGULAR_ID=11 +NUMBER=12 +TERM=13 +SPACES=14 +DEFAULT=15 +'+'=4 +'-'=5 +':'=6 +'^'=7 +'('=8 +')'=9 diff --git a/src/db/index/column/fts_column/gen/FtsParser.cc b/src/db/index/column/fts_column/gen/FtsParser.cc new file mode 100644 index 000000000..8fc31950b --- /dev/null +++ b/src/db/index/column/fts_column/gen/FtsParser.cc @@ -0,0 +1,1116 @@ + +// Generated from FtsParser.g4 by ANTLR 4.8 + + +#include "FtsParser.h" +#include "FtsParserListener.h" + + +using namespace antlrcpp; +using namespace antlr4; +using namespace antlr4; + +FtsParser::FtsParser(TokenStream *input) : Parser(input) { + _interpreter = new atn::ParserATNSimulator(this, _atn, _decisionToDFA, + _sharedContextCache); +} + +FtsParser::~FtsParser() { + delete _interpreter; +} + +std::string FtsParser::getGrammarFileName() const { + return "FtsParser.g4"; +} + +const std::vector &FtsParser::getRuleNames() const { + return _ruleNames; +} + +dfa::Vocabulary &FtsParser::getVocabulary() const { + return _vocabulary; +} + + +//----------------- Fts_query_unitContext +//------------------------------------------------------------------ + +FtsParser::Fts_query_unitContext::Fts_query_unitContext( + ParserRuleContext *parent_ctx, size_t invoking_state) + : ParserRuleContext(parent_ctx, invoking_state) {} + +FtsParser::Fts_or_exprContext *FtsParser::Fts_query_unitContext::fts_or_expr() { + return getRuleContext(0); +} + +tree::TerminalNode *FtsParser::Fts_query_unitContext::EOF() { + return getToken(FtsParser::EOF, 0); +} + + +size_t FtsParser::Fts_query_unitContext::getRuleIndex() const { + return FtsParser::RuleFts_query_unit; +} + +void FtsParser::Fts_query_unitContext::enterRule( + tree::ParseTreeListener *listener) { + auto parserListener = dynamic_cast(listener); + if (parserListener != nullptr) parserListener->enterFts_query_unit(this); +} + +void FtsParser::Fts_query_unitContext::exitRule( + tree::ParseTreeListener *listener) { + auto parserListener = dynamic_cast(listener); + if (parserListener != nullptr) parserListener->exitFts_query_unit(this); +} + +FtsParser::Fts_query_unitContext *FtsParser::fts_query_unit() { + Fts_query_unitContext *_localctx = + _tracker.createInstance(_ctx, getState()); + enterRule(_localctx, 0, FtsParser::RuleFts_query_unit); + + auto onExit = finally([=] { exitRule(); }); + try { + enterOuterAlt(_localctx, 1); + setState(24); + fts_or_expr(); + setState(25); + match(FtsParser::EOF); + + } catch (RecognitionException &e) { + _errHandler->reportError(this, e); + _localctx->exception = std::current_exception(); + _errHandler->recover(this, _localctx->exception); + } + + return _localctx; +} + +//----------------- Fts_or_exprContext +//------------------------------------------------------------------ + +FtsParser::Fts_or_exprContext::Fts_or_exprContext(ParserRuleContext *parent_ctx, + size_t invoking_state) + : ParserRuleContext(parent_ctx, invoking_state) {} + +std::vector +FtsParser::Fts_or_exprContext::fts_and_expr() { + return getRuleContexts(); +} + +FtsParser::Fts_and_exprContext *FtsParser::Fts_or_exprContext::fts_and_expr( + size_t i) { + return getRuleContext(i); +} + +std::vector FtsParser::Fts_or_exprContext::OR() { + return getTokens(FtsParser::OR); +} + +tree::TerminalNode *FtsParser::Fts_or_exprContext::OR(size_t i) { + return getToken(FtsParser::OR, i); +} + + +size_t FtsParser::Fts_or_exprContext::getRuleIndex() const { + return FtsParser::RuleFts_or_expr; +} + +void FtsParser::Fts_or_exprContext::enterRule( + tree::ParseTreeListener *listener) { + auto parserListener = dynamic_cast(listener); + if (parserListener != nullptr) parserListener->enterFts_or_expr(this); +} + +void FtsParser::Fts_or_exprContext::exitRule( + tree::ParseTreeListener *listener) { + auto parserListener = dynamic_cast(listener); + if (parserListener != nullptr) parserListener->exitFts_or_expr(this); +} + +FtsParser::Fts_or_exprContext *FtsParser::fts_or_expr() { + Fts_or_exprContext *_localctx = + _tracker.createInstance(_ctx, getState()); + enterRule(_localctx, 2, FtsParser::RuleFts_or_expr); + size_t _la = 0; + + auto onExit = finally([=] { exitRule(); }); + try { + enterOuterAlt(_localctx, 1); + setState(27); + fts_and_expr(); + setState(32); + _errHandler->sync(this); + _la = _input->LA(1); + while (_la == FtsParser::OR) { + setState(28); + match(FtsParser::OR); + setState(29); + fts_and_expr(); + setState(34); + _errHandler->sync(this); + _la = _input->LA(1); + } + + } catch (RecognitionException &e) { + _errHandler->reportError(this, e); + _localctx->exception = std::current_exception(); + _errHandler->recover(this, _localctx->exception); + } + + return _localctx; +} + +//----------------- Fts_and_exprContext +//------------------------------------------------------------------ + +FtsParser::Fts_and_exprContext::Fts_and_exprContext( + ParserRuleContext *parent_ctx, size_t invoking_state) + : ParserRuleContext(parent_ctx, invoking_state) {} + +std::vector +FtsParser::Fts_and_exprContext::fts_seq_expr() { + return getRuleContexts(); +} + +FtsParser::Fts_seq_exprContext *FtsParser::Fts_and_exprContext::fts_seq_expr( + size_t i) { + return getRuleContext(i); +} + +std::vector FtsParser::Fts_and_exprContext::AND() { + return getTokens(FtsParser::AND); +} + +tree::TerminalNode *FtsParser::Fts_and_exprContext::AND(size_t i) { + return getToken(FtsParser::AND, i); +} + +std::vector FtsParser::Fts_and_exprContext::NOT() { + return getTokens(FtsParser::NOT); +} + +tree::TerminalNode *FtsParser::Fts_and_exprContext::NOT(size_t i) { + return getToken(FtsParser::NOT, i); +} + + +size_t FtsParser::Fts_and_exprContext::getRuleIndex() const { + return FtsParser::RuleFts_and_expr; +} + +void FtsParser::Fts_and_exprContext::enterRule( + tree::ParseTreeListener *listener) { + auto parserListener = dynamic_cast(listener); + if (parserListener != nullptr) parserListener->enterFts_and_expr(this); +} + +void FtsParser::Fts_and_exprContext::exitRule( + tree::ParseTreeListener *listener) { + auto parserListener = dynamic_cast(listener); + if (parserListener != nullptr) parserListener->exitFts_and_expr(this); +} + +FtsParser::Fts_and_exprContext *FtsParser::fts_and_expr() { + Fts_and_exprContext *_localctx = + _tracker.createInstance(_ctx, getState()); + enterRule(_localctx, 4, FtsParser::RuleFts_and_expr); + size_t _la = 0; + + auto onExit = finally([=] { exitRule(); }); + try { + enterOuterAlt(_localctx, 1); + setState(35); + fts_seq_expr(); + setState(46); + _errHandler->sync(this); + _la = _input->LA(1); + while (_la == FtsParser::AND + + || _la == FtsParser::NOT) { + setState(41); + _errHandler->sync(this); + switch (_input->LA(1)) { + case FtsParser::AND: { + setState(36); + match(FtsParser::AND); + setState(38); + _errHandler->sync(this); + + _la = _input->LA(1); + if (_la == FtsParser::NOT) { + setState(37); + match(FtsParser::NOT); + } + break; + } + + case FtsParser::NOT: { + setState(40); + match(FtsParser::NOT); + break; + } + + default: + throw NoViableAltException(this); + } + setState(43); + fts_seq_expr(); + setState(48); + _errHandler->sync(this); + _la = _input->LA(1); + } + + } catch (RecognitionException &e) { + _errHandler->reportError(this, e); + _localctx->exception = std::current_exception(); + _errHandler->recover(this, _localctx->exception); + } + + return _localctx; +} + +//----------------- Fts_seq_exprContext +//------------------------------------------------------------------ + +FtsParser::Fts_seq_exprContext::Fts_seq_exprContext( + ParserRuleContext *parent_ctx, size_t invoking_state) + : ParserRuleContext(parent_ctx, invoking_state) {} + +std::vector +FtsParser::Fts_seq_exprContext::fts_unary() { + return getRuleContexts(); +} + +FtsParser::Fts_unaryContext *FtsParser::Fts_seq_exprContext::fts_unary( + size_t i) { + return getRuleContext(i); +} + + +size_t FtsParser::Fts_seq_exprContext::getRuleIndex() const { + return FtsParser::RuleFts_seq_expr; +} + +void FtsParser::Fts_seq_exprContext::enterRule( + tree::ParseTreeListener *listener) { + auto parserListener = dynamic_cast(listener); + if (parserListener != nullptr) parserListener->enterFts_seq_expr(this); +} + +void FtsParser::Fts_seq_exprContext::exitRule( + tree::ParseTreeListener *listener) { + auto parserListener = dynamic_cast(listener); + if (parserListener != nullptr) parserListener->exitFts_seq_expr(this); +} + +FtsParser::Fts_seq_exprContext *FtsParser::fts_seq_expr() { + Fts_seq_exprContext *_localctx = + _tracker.createInstance(_ctx, getState()); + enterRule(_localctx, 6, FtsParser::RuleFts_seq_expr); + size_t _la = 0; + + auto onExit = finally([=] { exitRule(); }); + try { + enterOuterAlt(_localctx, 1); + setState(50); + _errHandler->sync(this); + _la = _input->LA(1); + do { + setState(49); + fts_unary(); + setState(52); + _errHandler->sync(this); + _la = _input->LA(1); + } while ( + (((_la & ~0x3fULL) == 0) && + ((1ULL << _la) & + ((1ULL << FtsParser::PLUS_SIGN) | (1ULL << FtsParser::MINUS_SIGN) | + (1ULL << FtsParser::LP) | (1ULL << FtsParser::DQUOTA_STRING) | + (1ULL << FtsParser::REGULAR_ID) | (1ULL << FtsParser::NUMBER) | + (1ULL << FtsParser::TERM) | (1ULL << FtsParser::DEFAULT))) != 0)); + + } catch (RecognitionException &e) { + _errHandler->reportError(this, e); + _localctx->exception = std::current_exception(); + _errHandler->recover(this, _localctx->exception); + } + + return _localctx; +} + +//----------------- Fts_unaryContext +//------------------------------------------------------------------ + +FtsParser::Fts_unaryContext::Fts_unaryContext(ParserRuleContext *parent_ctx, + size_t invoking_state) + : ParserRuleContext(parent_ctx, invoking_state) {} + + +size_t FtsParser::Fts_unaryContext::getRuleIndex() const { + return FtsParser::RuleFts_unary; +} + +void FtsParser::Fts_unaryContext::copyFrom(Fts_unaryContext *ctx) { + ParserRuleContext::copyFrom(ctx); +} + +//----------------- Must_not_atomContext +//------------------------------------------------------------------ + +tree::TerminalNode *FtsParser::Must_not_atomContext::MINUS_SIGN() { + return getToken(FtsParser::MINUS_SIGN, 0); +} + +FtsParser::Fts_atomContext *FtsParser::Must_not_atomContext::fts_atom() { + return getRuleContext(0); +} + +FtsParser::Must_not_atomContext::Must_not_atomContext(Fts_unaryContext *ctx) { + copyFrom(ctx); +} + +void FtsParser::Must_not_atomContext::enterRule( + tree::ParseTreeListener *listener) { + auto parserListener = dynamic_cast(listener); + if (parserListener != nullptr) parserListener->enterMust_not_atom(this); +} +void FtsParser::Must_not_atomContext::exitRule( + tree::ParseTreeListener *listener) { + auto parserListener = dynamic_cast(listener); + if (parserListener != nullptr) parserListener->exitMust_not_atom(this); +} +//----------------- Must_atomContext +//------------------------------------------------------------------ + +tree::TerminalNode *FtsParser::Must_atomContext::PLUS_SIGN() { + return getToken(FtsParser::PLUS_SIGN, 0); +} + +FtsParser::Fts_atomContext *FtsParser::Must_atomContext::fts_atom() { + return getRuleContext(0); +} + +FtsParser::Must_atomContext::Must_atomContext(Fts_unaryContext *ctx) { + copyFrom(ctx); +} + +void FtsParser::Must_atomContext::enterRule(tree::ParseTreeListener *listener) { + auto parserListener = dynamic_cast(listener); + if (parserListener != nullptr) parserListener->enterMust_atom(this); +} +void FtsParser::Must_atomContext::exitRule(tree::ParseTreeListener *listener) { + auto parserListener = dynamic_cast(listener); + if (parserListener != nullptr) parserListener->exitMust_atom(this); +} +//----------------- Plain_atomContext +//------------------------------------------------------------------ + +FtsParser::Fts_atomContext *FtsParser::Plain_atomContext::fts_atom() { + return getRuleContext(0); +} + +FtsParser::Plain_atomContext::Plain_atomContext(Fts_unaryContext *ctx) { + copyFrom(ctx); +} + +void FtsParser::Plain_atomContext::enterRule( + tree::ParseTreeListener *listener) { + auto parserListener = dynamic_cast(listener); + if (parserListener != nullptr) parserListener->enterPlain_atom(this); +} +void FtsParser::Plain_atomContext::exitRule(tree::ParseTreeListener *listener) { + auto parserListener = dynamic_cast(listener); + if (parserListener != nullptr) parserListener->exitPlain_atom(this); +} +FtsParser::Fts_unaryContext *FtsParser::fts_unary() { + Fts_unaryContext *_localctx = + _tracker.createInstance(_ctx, getState()); + enterRule(_localctx, 8, FtsParser::RuleFts_unary); + + auto onExit = finally([=] { exitRule(); }); + try { + setState(59); + _errHandler->sync(this); + switch (_input->LA(1)) { + case FtsParser::PLUS_SIGN: { + _localctx = dynamic_cast( + _tracker.createInstance(_localctx)); + enterOuterAlt(_localctx, 1); + setState(54); + match(FtsParser::PLUS_SIGN); + setState(55); + fts_atom(); + break; + } + + case FtsParser::MINUS_SIGN: { + _localctx = dynamic_cast( + _tracker.createInstance( + _localctx)); + enterOuterAlt(_localctx, 2); + setState(56); + match(FtsParser::MINUS_SIGN); + setState(57); + fts_atom(); + break; + } + + case FtsParser::LP: + case FtsParser::DQUOTA_STRING: + case FtsParser::REGULAR_ID: + case FtsParser::NUMBER: + case FtsParser::TERM: + case FtsParser::DEFAULT: { + _localctx = dynamic_cast( + _tracker.createInstance(_localctx)); + enterOuterAlt(_localctx, 3); + setState(58); + fts_atom(); + break; + } + + default: + throw NoViableAltException(this); + } + + } catch (RecognitionException &e) { + _errHandler->reportError(this, e); + _localctx->exception = std::current_exception(); + _errHandler->recover(this, _localctx->exception); + } + + return _localctx; +} + +//----------------- Fts_atomContext +//------------------------------------------------------------------ + +FtsParser::Fts_atomContext::Fts_atomContext(ParserRuleContext *parent_ctx, + size_t invoking_state) + : ParserRuleContext(parent_ctx, invoking_state) {} + +FtsParser::Fts_primaryContext *FtsParser::Fts_atomContext::fts_primary() { + return getRuleContext(0); +} + +FtsParser::Fts_field_prefixContext * +FtsParser::Fts_atomContext::fts_field_prefix() { + return getRuleContext(0); +} + +FtsParser::Fts_boostContext *FtsParser::Fts_atomContext::fts_boost() { + return getRuleContext(0); +} + + +size_t FtsParser::Fts_atomContext::getRuleIndex() const { + return FtsParser::RuleFts_atom; +} + +void FtsParser::Fts_atomContext::enterRule(tree::ParseTreeListener *listener) { + auto parserListener = dynamic_cast(listener); + if (parserListener != nullptr) parserListener->enterFts_atom(this); +} + +void FtsParser::Fts_atomContext::exitRule(tree::ParseTreeListener *listener) { + auto parserListener = dynamic_cast(listener); + if (parserListener != nullptr) parserListener->exitFts_atom(this); +} + +FtsParser::Fts_atomContext *FtsParser::fts_atom() { + Fts_atomContext *_localctx = + _tracker.createInstance(_ctx, getState()); + enterRule(_localctx, 10, FtsParser::RuleFts_atom); + size_t _la = 0; + + auto onExit = finally([=] { exitRule(); }); + try { + enterOuterAlt(_localctx, 1); + setState(62); + _errHandler->sync(this); + + switch (getInterpreter()->adaptivePredict( + _input, 6, _ctx)) { + case 1: { + setState(61); + fts_field_prefix(); + break; + } + } + setState(64); + fts_primary(); + setState(66); + _errHandler->sync(this); + + _la = _input->LA(1); + if (_la == FtsParser::CARET) { + setState(65); + fts_boost(); + } + + } catch (RecognitionException &e) { + _errHandler->reportError(this, e); + _localctx->exception = std::current_exception(); + _errHandler->recover(this, _localctx->exception); + } + + return _localctx; +} + +//----------------- Fts_field_prefixContext +//------------------------------------------------------------------ + +FtsParser::Fts_field_prefixContext::Fts_field_prefixContext( + ParserRuleContext *parent_ctx, size_t invoking_state) + : ParserRuleContext(parent_ctx, invoking_state) {} + +tree::TerminalNode *FtsParser::Fts_field_prefixContext::REGULAR_ID() { + return getToken(FtsParser::REGULAR_ID, 0); +} + +tree::TerminalNode *FtsParser::Fts_field_prefixContext::COLON() { + return getToken(FtsParser::COLON, 0); +} + + +size_t FtsParser::Fts_field_prefixContext::getRuleIndex() const { + return FtsParser::RuleFts_field_prefix; +} + +void FtsParser::Fts_field_prefixContext::enterRule( + tree::ParseTreeListener *listener) { + auto parserListener = dynamic_cast(listener); + if (parserListener != nullptr) parserListener->enterFts_field_prefix(this); +} + +void FtsParser::Fts_field_prefixContext::exitRule( + tree::ParseTreeListener *listener) { + auto parserListener = dynamic_cast(listener); + if (parserListener != nullptr) parserListener->exitFts_field_prefix(this); +} + +FtsParser::Fts_field_prefixContext *FtsParser::fts_field_prefix() { + Fts_field_prefixContext *_localctx = + _tracker.createInstance(_ctx, getState()); + enterRule(_localctx, 12, FtsParser::RuleFts_field_prefix); + + auto onExit = finally([=] { exitRule(); }); + try { + enterOuterAlt(_localctx, 1); + setState(68); + match(FtsParser::REGULAR_ID); + setState(69); + match(FtsParser::COLON); + + } catch (RecognitionException &e) { + _errHandler->reportError(this, e); + _localctx->exception = std::current_exception(); + _errHandler->recover(this, _localctx->exception); + } + + return _localctx; +} + +//----------------- Fts_primaryContext +//------------------------------------------------------------------ + +FtsParser::Fts_primaryContext::Fts_primaryContext(ParserRuleContext *parent_ctx, + size_t invoking_state) + : ParserRuleContext(parent_ctx, invoking_state) {} + +FtsParser::Fts_termContext *FtsParser::Fts_primaryContext::fts_term() { + return getRuleContext(0); +} + +FtsParser::Fts_phraseContext *FtsParser::Fts_primaryContext::fts_phrase() { + return getRuleContext(0); +} + +tree::TerminalNode *FtsParser::Fts_primaryContext::LP() { + return getToken(FtsParser::LP, 0); +} + +FtsParser::Fts_or_exprContext *FtsParser::Fts_primaryContext::fts_or_expr() { + return getRuleContext(0); +} + +tree::TerminalNode *FtsParser::Fts_primaryContext::RP() { + return getToken(FtsParser::RP, 0); +} + + +size_t FtsParser::Fts_primaryContext::getRuleIndex() const { + return FtsParser::RuleFts_primary; +} + +void FtsParser::Fts_primaryContext::enterRule( + tree::ParseTreeListener *listener) { + auto parserListener = dynamic_cast(listener); + if (parserListener != nullptr) parserListener->enterFts_primary(this); +} + +void FtsParser::Fts_primaryContext::exitRule( + tree::ParseTreeListener *listener) { + auto parserListener = dynamic_cast(listener); + if (parserListener != nullptr) parserListener->exitFts_primary(this); +} + +FtsParser::Fts_primaryContext *FtsParser::fts_primary() { + Fts_primaryContext *_localctx = + _tracker.createInstance(_ctx, getState()); + enterRule(_localctx, 14, FtsParser::RuleFts_primary); + + auto onExit = finally([=] { exitRule(); }); + try { + setState(77); + _errHandler->sync(this); + switch (_input->LA(1)) { + case FtsParser::REGULAR_ID: + case FtsParser::NUMBER: + case FtsParser::TERM: + case FtsParser::DEFAULT: { + enterOuterAlt(_localctx, 1); + setState(71); + fts_term(); + break; + } + + case FtsParser::DQUOTA_STRING: { + enterOuterAlt(_localctx, 2); + setState(72); + fts_phrase(); + break; + } + + case FtsParser::LP: { + enterOuterAlt(_localctx, 3); + setState(73); + match(FtsParser::LP); + setState(74); + fts_or_expr(); + setState(75); + match(FtsParser::RP); + break; + } + + default: + throw NoViableAltException(this); + } + + } catch (RecognitionException &e) { + _errHandler->reportError(this, e); + _localctx->exception = std::current_exception(); + _errHandler->recover(this, _localctx->exception); + } + + return _localctx; +} + +//----------------- Fts_boostContext +//------------------------------------------------------------------ + +FtsParser::Fts_boostContext::Fts_boostContext(ParserRuleContext *parent_ctx, + size_t invoking_state) + : ParserRuleContext(parent_ctx, invoking_state) {} + +tree::TerminalNode *FtsParser::Fts_boostContext::CARET() { + return getToken(FtsParser::CARET, 0); +} + +tree::TerminalNode *FtsParser::Fts_boostContext::NUMBER() { + return getToken(FtsParser::NUMBER, 0); +} + + +size_t FtsParser::Fts_boostContext::getRuleIndex() const { + return FtsParser::RuleFts_boost; +} + +void FtsParser::Fts_boostContext::enterRule(tree::ParseTreeListener *listener) { + auto parserListener = dynamic_cast(listener); + if (parserListener != nullptr) parserListener->enterFts_boost(this); +} + +void FtsParser::Fts_boostContext::exitRule(tree::ParseTreeListener *listener) { + auto parserListener = dynamic_cast(listener); + if (parserListener != nullptr) parserListener->exitFts_boost(this); +} + +FtsParser::Fts_boostContext *FtsParser::fts_boost() { + Fts_boostContext *_localctx = + _tracker.createInstance(_ctx, getState()); + enterRule(_localctx, 16, FtsParser::RuleFts_boost); + + auto onExit = finally([=] { exitRule(); }); + try { + enterOuterAlt(_localctx, 1); + setState(79); + match(FtsParser::CARET); + setState(80); + match(FtsParser::NUMBER); + + } catch (RecognitionException &e) { + _errHandler->reportError(this, e); + _localctx->exception = std::current_exception(); + _errHandler->recover(this, _localctx->exception); + } + + return _localctx; +} + +//----------------- Fts_natural_termContext +//------------------------------------------------------------------ + +FtsParser::Fts_natural_termContext::Fts_natural_termContext( + ParserRuleContext *parent_ctx, size_t invoking_state) + : ParserRuleContext(parent_ctx, invoking_state) {} + +std::vector +FtsParser::Fts_natural_termContext::DEFAULT() { + return getTokens(FtsParser::DEFAULT); +} + +tree::TerminalNode *FtsParser::Fts_natural_termContext::DEFAULT(size_t i) { + return getToken(FtsParser::DEFAULT, i); +} + + +size_t FtsParser::Fts_natural_termContext::getRuleIndex() const { + return FtsParser::RuleFts_natural_term; +} + +void FtsParser::Fts_natural_termContext::enterRule( + tree::ParseTreeListener *listener) { + auto parserListener = dynamic_cast(listener); + if (parserListener != nullptr) parserListener->enterFts_natural_term(this); +} + +void FtsParser::Fts_natural_termContext::exitRule( + tree::ParseTreeListener *listener) { + auto parserListener = dynamic_cast(listener); + if (parserListener != nullptr) parserListener->exitFts_natural_term(this); +} + +FtsParser::Fts_natural_termContext *FtsParser::fts_natural_term() { + Fts_natural_termContext *_localctx = + _tracker.createInstance(_ctx, getState()); + enterRule(_localctx, 18, FtsParser::RuleFts_natural_term); + + auto onExit = finally([=] { exitRule(); }); + try { + size_t alt; + enterOuterAlt(_localctx, 1); + setState(83); + _errHandler->sync(this); + alt = 1; + do { + switch (alt) { + case 1: { + setState(82); + match(FtsParser::DEFAULT); + break; + } + + default: + throw NoViableAltException(this); + } + setState(85); + _errHandler->sync(this); + alt = getInterpreter()->adaptivePredict(_input, + 9, _ctx); + } while (alt != 2 && alt != atn::ATN::INVALID_ALT_NUMBER); + + } catch (RecognitionException &e) { + _errHandler->reportError(this, e); + _localctx->exception = std::current_exception(); + _errHandler->recover(this, _localctx->exception); + } + + return _localctx; +} + +//----------------- Fts_termContext +//------------------------------------------------------------------ + +FtsParser::Fts_termContext::Fts_termContext(ParserRuleContext *parent_ctx, + size_t invoking_state) + : ParserRuleContext(parent_ctx, invoking_state) {} + +tree::TerminalNode *FtsParser::Fts_termContext::TERM() { + return getToken(FtsParser::TERM, 0); +} + +tree::TerminalNode *FtsParser::Fts_termContext::REGULAR_ID() { + return getToken(FtsParser::REGULAR_ID, 0); +} + +tree::TerminalNode *FtsParser::Fts_termContext::NUMBER() { + return getToken(FtsParser::NUMBER, 0); +} + +FtsParser::Fts_natural_termContext * +FtsParser::Fts_termContext::fts_natural_term() { + return getRuleContext(0); +} + + +size_t FtsParser::Fts_termContext::getRuleIndex() const { + return FtsParser::RuleFts_term; +} + +void FtsParser::Fts_termContext::enterRule(tree::ParseTreeListener *listener) { + auto parserListener = dynamic_cast(listener); + if (parserListener != nullptr) parserListener->enterFts_term(this); +} + +void FtsParser::Fts_termContext::exitRule(tree::ParseTreeListener *listener) { + auto parserListener = dynamic_cast(listener); + if (parserListener != nullptr) parserListener->exitFts_term(this); +} + +FtsParser::Fts_termContext *FtsParser::fts_term() { + Fts_termContext *_localctx = + _tracker.createInstance(_ctx, getState()); + enterRule(_localctx, 20, FtsParser::RuleFts_term); + + auto onExit = finally([=] { exitRule(); }); + try { + setState(91); + _errHandler->sync(this); + switch (_input->LA(1)) { + case FtsParser::TERM: { + enterOuterAlt(_localctx, 1); + setState(87); + match(FtsParser::TERM); + break; + } + + case FtsParser::REGULAR_ID: { + enterOuterAlt(_localctx, 2); + setState(88); + match(FtsParser::REGULAR_ID); + break; + } + + case FtsParser::NUMBER: { + enterOuterAlt(_localctx, 3); + setState(89); + match(FtsParser::NUMBER); + break; + } + + case FtsParser::DEFAULT: { + enterOuterAlt(_localctx, 4); + setState(90); + fts_natural_term(); + break; + } + + default: + throw NoViableAltException(this); + } + + } catch (RecognitionException &e) { + _errHandler->reportError(this, e); + _localctx->exception = std::current_exception(); + _errHandler->recover(this, _localctx->exception); + } + + return _localctx; +} + +//----------------- Fts_phraseContext +//------------------------------------------------------------------ + +FtsParser::Fts_phraseContext::Fts_phraseContext(ParserRuleContext *parent_ctx, + size_t invoking_state) + : ParserRuleContext(parent_ctx, invoking_state) {} + +tree::TerminalNode *FtsParser::Fts_phraseContext::DQUOTA_STRING() { + return getToken(FtsParser::DQUOTA_STRING, 0); +} + + +size_t FtsParser::Fts_phraseContext::getRuleIndex() const { + return FtsParser::RuleFts_phrase; +} + +void FtsParser::Fts_phraseContext::enterRule( + tree::ParseTreeListener *listener) { + auto parserListener = dynamic_cast(listener); + if (parserListener != nullptr) parserListener->enterFts_phrase(this); +} + +void FtsParser::Fts_phraseContext::exitRule(tree::ParseTreeListener *listener) { + auto parserListener = dynamic_cast(listener); + if (parserListener != nullptr) parserListener->exitFts_phrase(this); +} + +FtsParser::Fts_phraseContext *FtsParser::fts_phrase() { + Fts_phraseContext *_localctx = + _tracker.createInstance(_ctx, getState()); + enterRule(_localctx, 22, FtsParser::RuleFts_phrase); + + auto onExit = finally([=] { exitRule(); }); + try { + enterOuterAlt(_localctx, 1); + setState(93); + match(FtsParser::DQUOTA_STRING); + + } catch (RecognitionException &e) { + _errHandler->reportError(this, e); + _localctx->exception = std::current_exception(); + _errHandler->recover(this, _localctx->exception); + } + + return _localctx; +} + +// Static vars and initialization. +std::vector FtsParser::_decisionToDFA; +atn::PredictionContextCache FtsParser::_sharedContextCache; + +// We own the ATN which in turn owns the ATN states. +atn::ATN FtsParser::_atn; +std::vector FtsParser::_serializedATN; + +std::vector FtsParser::_ruleNames = { + "fts_query_unit", "fts_or_expr", "fts_and_expr", "fts_seq_expr", + "fts_unary", "fts_atom", "fts_field_prefix", "fts_primary", + "fts_boost", "fts_natural_term", "fts_term", "fts_phrase"}; + +std::vector FtsParser::_literalNames = { + "", "", "", "", "'+'", "'-'", "':'", "'^'", "'('", "')'"}; + +std::vector FtsParser::_symbolicNames = { + "", "OR", "AND", "NOT", "PLUS_SIGN", "MINUS_SIGN", + "COLON", "CARET", "LP", "RP", "DQUOTA_STRING", "REGULAR_ID", + "NUMBER", "TERM", "SPACES", "DEFAULT"}; + +dfa::Vocabulary FtsParser::_vocabulary(_literalNames, _symbolicNames); + +std::vector FtsParser::_tokenNames; + +FtsParser::Initializer::Initializer() { + for (size_t i = 0; i < _symbolicNames.size(); ++i) { + std::string name = _vocabulary.getLiteralName(i); + if (name.empty()) { + name = _vocabulary.getSymbolicName(i); + } + + if (name.empty()) { + _tokenNames.push_back(""); + } else { + _tokenNames.push_back(name); + } + } + + _serializedATN = { + 0x3, 0x608b, 0xa72a, 0x8133, 0xb9ed, 0x417c, 0x3be7, 0x7786, 0x5964, + 0x3, 0x11, 0x62, 0x4, 0x2, 0x9, 0x2, 0x4, 0x3, + 0x9, 0x3, 0x4, 0x4, 0x9, 0x4, 0x4, 0x5, 0x9, + 0x5, 0x4, 0x6, 0x9, 0x6, 0x4, 0x7, 0x9, 0x7, + 0x4, 0x8, 0x9, 0x8, 0x4, 0x9, 0x9, 0x9, 0x4, + 0xa, 0x9, 0xa, 0x4, 0xb, 0x9, 0xb, 0x4, 0xc, + 0x9, 0xc, 0x4, 0xd, 0x9, 0xd, 0x3, 0x2, 0x3, + 0x2, 0x3, 0x2, 0x3, 0x3, 0x3, 0x3, 0x3, 0x3, + 0x7, 0x3, 0x21, 0xa, 0x3, 0xc, 0x3, 0xe, 0x3, + 0x24, 0xb, 0x3, 0x3, 0x4, 0x3, 0x4, 0x3, 0x4, + 0x5, 0x4, 0x29, 0xa, 0x4, 0x3, 0x4, 0x5, 0x4, + 0x2c, 0xa, 0x4, 0x3, 0x4, 0x7, 0x4, 0x2f, 0xa, + 0x4, 0xc, 0x4, 0xe, 0x4, 0x32, 0xb, 0x4, 0x3, + 0x5, 0x6, 0x5, 0x35, 0xa, 0x5, 0xd, 0x5, 0xe, + 0x5, 0x36, 0x3, 0x6, 0x3, 0x6, 0x3, 0x6, 0x3, + 0x6, 0x3, 0x6, 0x5, 0x6, 0x3e, 0xa, 0x6, 0x3, + 0x7, 0x5, 0x7, 0x41, 0xa, 0x7, 0x3, 0x7, 0x3, + 0x7, 0x5, 0x7, 0x45, 0xa, 0x7, 0x3, 0x8, 0x3, + 0x8, 0x3, 0x8, 0x3, 0x9, 0x3, 0x9, 0x3, 0x9, + 0x3, 0x9, 0x3, 0x9, 0x3, 0x9, 0x5, 0x9, 0x50, + 0xa, 0x9, 0x3, 0xa, 0x3, 0xa, 0x3, 0xa, 0x3, + 0xb, 0x6, 0xb, 0x56, 0xa, 0xb, 0xd, 0xb, 0xe, + 0xb, 0x57, 0x3, 0xc, 0x3, 0xc, 0x3, 0xc, 0x3, + 0xc, 0x5, 0xc, 0x5e, 0xa, 0xc, 0x3, 0xd, 0x3, + 0xd, 0x3, 0xd, 0x2, 0x2, 0xe, 0x2, 0x4, 0x6, + 0x8, 0xa, 0xc, 0xe, 0x10, 0x12, 0x14, 0x16, 0x18, + 0x2, 0x2, 0x2, 0x64, 0x2, 0x1a, 0x3, 0x2, 0x2, + 0x2, 0x4, 0x1d, 0x3, 0x2, 0x2, 0x2, 0x6, 0x25, + 0x3, 0x2, 0x2, 0x2, 0x8, 0x34, 0x3, 0x2, 0x2, + 0x2, 0xa, 0x3d, 0x3, 0x2, 0x2, 0x2, 0xc, 0x40, + 0x3, 0x2, 0x2, 0x2, 0xe, 0x46, 0x3, 0x2, 0x2, + 0x2, 0x10, 0x4f, 0x3, 0x2, 0x2, 0x2, 0x12, 0x51, + 0x3, 0x2, 0x2, 0x2, 0x14, 0x55, 0x3, 0x2, 0x2, + 0x2, 0x16, 0x5d, 0x3, 0x2, 0x2, 0x2, 0x18, 0x5f, + 0x3, 0x2, 0x2, 0x2, 0x1a, 0x1b, 0x5, 0x4, 0x3, + 0x2, 0x1b, 0x1c, 0x7, 0x2, 0x2, 0x3, 0x1c, 0x3, + 0x3, 0x2, 0x2, 0x2, 0x1d, 0x22, 0x5, 0x6, 0x4, + 0x2, 0x1e, 0x1f, 0x7, 0x3, 0x2, 0x2, 0x1f, 0x21, + 0x5, 0x6, 0x4, 0x2, 0x20, 0x1e, 0x3, 0x2, 0x2, + 0x2, 0x21, 0x24, 0x3, 0x2, 0x2, 0x2, 0x22, 0x20, + 0x3, 0x2, 0x2, 0x2, 0x22, 0x23, 0x3, 0x2, 0x2, + 0x2, 0x23, 0x5, 0x3, 0x2, 0x2, 0x2, 0x24, 0x22, + 0x3, 0x2, 0x2, 0x2, 0x25, 0x30, 0x5, 0x8, 0x5, + 0x2, 0x26, 0x28, 0x7, 0x4, 0x2, 0x2, 0x27, 0x29, + 0x7, 0x5, 0x2, 0x2, 0x28, 0x27, 0x3, 0x2, 0x2, + 0x2, 0x28, 0x29, 0x3, 0x2, 0x2, 0x2, 0x29, 0x2c, + 0x3, 0x2, 0x2, 0x2, 0x2a, 0x2c, 0x7, 0x5, 0x2, + 0x2, 0x2b, 0x26, 0x3, 0x2, 0x2, 0x2, 0x2b, 0x2a, + 0x3, 0x2, 0x2, 0x2, 0x2c, 0x2d, 0x3, 0x2, 0x2, + 0x2, 0x2d, 0x2f, 0x5, 0x8, 0x5, 0x2, 0x2e, 0x2b, + 0x3, 0x2, 0x2, 0x2, 0x2f, 0x32, 0x3, 0x2, 0x2, + 0x2, 0x30, 0x2e, 0x3, 0x2, 0x2, 0x2, 0x30, 0x31, + 0x3, 0x2, 0x2, 0x2, 0x31, 0x7, 0x3, 0x2, 0x2, + 0x2, 0x32, 0x30, 0x3, 0x2, 0x2, 0x2, 0x33, 0x35, + 0x5, 0xa, 0x6, 0x2, 0x34, 0x33, 0x3, 0x2, 0x2, + 0x2, 0x35, 0x36, 0x3, 0x2, 0x2, 0x2, 0x36, 0x34, + 0x3, 0x2, 0x2, 0x2, 0x36, 0x37, 0x3, 0x2, 0x2, + 0x2, 0x37, 0x9, 0x3, 0x2, 0x2, 0x2, 0x38, 0x39, + 0x7, 0x6, 0x2, 0x2, 0x39, 0x3e, 0x5, 0xc, 0x7, + 0x2, 0x3a, 0x3b, 0x7, 0x7, 0x2, 0x2, 0x3b, 0x3e, + 0x5, 0xc, 0x7, 0x2, 0x3c, 0x3e, 0x5, 0xc, 0x7, + 0x2, 0x3d, 0x38, 0x3, 0x2, 0x2, 0x2, 0x3d, 0x3a, + 0x3, 0x2, 0x2, 0x2, 0x3d, 0x3c, 0x3, 0x2, 0x2, + 0x2, 0x3e, 0xb, 0x3, 0x2, 0x2, 0x2, 0x3f, 0x41, + 0x5, 0xe, 0x8, 0x2, 0x40, 0x3f, 0x3, 0x2, 0x2, + 0x2, 0x40, 0x41, 0x3, 0x2, 0x2, 0x2, 0x41, 0x42, + 0x3, 0x2, 0x2, 0x2, 0x42, 0x44, 0x5, 0x10, 0x9, + 0x2, 0x43, 0x45, 0x5, 0x12, 0xa, 0x2, 0x44, 0x43, + 0x3, 0x2, 0x2, 0x2, 0x44, 0x45, 0x3, 0x2, 0x2, + 0x2, 0x45, 0xd, 0x3, 0x2, 0x2, 0x2, 0x46, 0x47, + 0x7, 0xd, 0x2, 0x2, 0x47, 0x48, 0x7, 0x8, 0x2, + 0x2, 0x48, 0xf, 0x3, 0x2, 0x2, 0x2, 0x49, 0x50, + 0x5, 0x16, 0xc, 0x2, 0x4a, 0x50, 0x5, 0x18, 0xd, + 0x2, 0x4b, 0x4c, 0x7, 0xa, 0x2, 0x2, 0x4c, 0x4d, + 0x5, 0x4, 0x3, 0x2, 0x4d, 0x4e, 0x7, 0xb, 0x2, + 0x2, 0x4e, 0x50, 0x3, 0x2, 0x2, 0x2, 0x4f, 0x49, + 0x3, 0x2, 0x2, 0x2, 0x4f, 0x4a, 0x3, 0x2, 0x2, + 0x2, 0x4f, 0x4b, 0x3, 0x2, 0x2, 0x2, 0x50, 0x11, + 0x3, 0x2, 0x2, 0x2, 0x51, 0x52, 0x7, 0x9, 0x2, + 0x2, 0x52, 0x53, 0x7, 0xe, 0x2, 0x2, 0x53, 0x13, + 0x3, 0x2, 0x2, 0x2, 0x54, 0x56, 0x7, 0x11, 0x2, + 0x2, 0x55, 0x54, 0x3, 0x2, 0x2, 0x2, 0x56, 0x57, + 0x3, 0x2, 0x2, 0x2, 0x57, 0x55, 0x3, 0x2, 0x2, + 0x2, 0x57, 0x58, 0x3, 0x2, 0x2, 0x2, 0x58, 0x15, + 0x3, 0x2, 0x2, 0x2, 0x59, 0x5e, 0x7, 0xf, 0x2, + 0x2, 0x5a, 0x5e, 0x7, 0xd, 0x2, 0x2, 0x5b, 0x5e, + 0x7, 0xe, 0x2, 0x2, 0x5c, 0x5e, 0x5, 0x14, 0xb, + 0x2, 0x5d, 0x59, 0x3, 0x2, 0x2, 0x2, 0x5d, 0x5a, + 0x3, 0x2, 0x2, 0x2, 0x5d, 0x5b, 0x3, 0x2, 0x2, + 0x2, 0x5d, 0x5c, 0x3, 0x2, 0x2, 0x2, 0x5e, 0x17, + 0x3, 0x2, 0x2, 0x2, 0x5f, 0x60, 0x7, 0xc, 0x2, + 0x2, 0x60, 0x19, 0x3, 0x2, 0x2, 0x2, 0xd, 0x22, + 0x28, 0x2b, 0x30, 0x36, 0x3d, 0x40, 0x44, 0x4f, 0x57, + 0x5d, + }; + + atn::ATNDeserializer deserializer; + _atn = deserializer.deserialize(_serializedATN); + + size_t count = _atn.getNumberOfDecisions(); + _decisionToDFA.reserve(count); + for (size_t i = 0; i < count; i++) { + _decisionToDFA.emplace_back(_atn.getDecisionState(i), i); + } +} + +FtsParser::Initializer FtsParser::_init; diff --git a/src/db/index/column/fts_column/gen/FtsParser.h b/src/db/index/column/fts_column/gen/FtsParser.h new file mode 100644 index 000000000..3f291557b --- /dev/null +++ b/src/db/index/column/fts_column/gen/FtsParser.h @@ -0,0 +1,303 @@ + +// Generated from FtsParser.g4 by ANTLR 4.8 + +#pragma once + + +#include "antlr4-runtime.h" + + +namespace antlr4 { + + +class FtsParser : public antlr4::Parser { + public: + enum { + OR = 1, + AND = 2, + NOT = 3, + PLUS_SIGN = 4, + MINUS_SIGN = 5, + COLON = 6, + CARET = 7, + LP = 8, + RP = 9, + DQUOTA_STRING = 10, + REGULAR_ID = 11, + NUMBER = 12, + TERM = 13, + SPACES = 14, + DEFAULT = 15 + }; + + enum { + RuleFts_query_unit = 0, + RuleFts_or_expr = 1, + RuleFts_and_expr = 2, + RuleFts_seq_expr = 3, + RuleFts_unary = 4, + RuleFts_atom = 5, + RuleFts_field_prefix = 6, + RuleFts_primary = 7, + RuleFts_boost = 8, + RuleFts_natural_term = 9, + RuleFts_term = 10, + RuleFts_phrase = 11 + }; + + FtsParser(antlr4::TokenStream *input); + ~FtsParser(); + + virtual std::string getGrammarFileName() const override; + virtual const antlr4::atn::ATN &getATN() const override { + return _atn; + }; + virtual const std::vector &getTokenNames() const override { + return _tokenNames; + }; // deprecated: use vocabulary instead. + virtual const std::vector &getRuleNames() const override; + virtual antlr4::dfa::Vocabulary &getVocabulary() const override; + + + class Fts_query_unitContext; + class Fts_or_exprContext; + class Fts_and_exprContext; + class Fts_seq_exprContext; + class Fts_unaryContext; + class Fts_atomContext; + class Fts_field_prefixContext; + class Fts_primaryContext; + class Fts_boostContext; + class Fts_natural_termContext; + class Fts_termContext; + class Fts_phraseContext; + + class Fts_query_unitContext : public antlr4::ParserRuleContext { + public: + Fts_query_unitContext(antlr4::ParserRuleContext *parent_ctx, + size_t invoking_state); + virtual size_t getRuleIndex() const override; + Fts_or_exprContext *fts_or_expr(); + antlr4::tree::TerminalNode *EOF(); + + virtual void enterRule(antlr4::tree::ParseTreeListener *listener) override; + virtual void exitRule(antlr4::tree::ParseTreeListener *listener) override; + }; + + Fts_query_unitContext *fts_query_unit(); + + class Fts_or_exprContext : public antlr4::ParserRuleContext { + public: + Fts_or_exprContext(antlr4::ParserRuleContext *parent_ctx, + size_t invoking_state); + virtual size_t getRuleIndex() const override; + std::vector fts_and_expr(); + Fts_and_exprContext *fts_and_expr(size_t i); + std::vector OR(); + antlr4::tree::TerminalNode *OR(size_t i); + + virtual void enterRule(antlr4::tree::ParseTreeListener *listener) override; + virtual void exitRule(antlr4::tree::ParseTreeListener *listener) override; + }; + + Fts_or_exprContext *fts_or_expr(); + + class Fts_and_exprContext : public antlr4::ParserRuleContext { + public: + Fts_and_exprContext(antlr4::ParserRuleContext *parent_ctx, + size_t invoking_state); + virtual size_t getRuleIndex() const override; + std::vector fts_seq_expr(); + Fts_seq_exprContext *fts_seq_expr(size_t i); + std::vector AND(); + antlr4::tree::TerminalNode *AND(size_t i); + std::vector NOT(); + antlr4::tree::TerminalNode *NOT(size_t i); + + virtual void enterRule(antlr4::tree::ParseTreeListener *listener) override; + virtual void exitRule(antlr4::tree::ParseTreeListener *listener) override; + }; + + Fts_and_exprContext *fts_and_expr(); + + class Fts_seq_exprContext : public antlr4::ParserRuleContext { + public: + Fts_seq_exprContext(antlr4::ParserRuleContext *parent_ctx, + size_t invoking_state); + virtual size_t getRuleIndex() const override; + std::vector fts_unary(); + Fts_unaryContext *fts_unary(size_t i); + + virtual void enterRule(antlr4::tree::ParseTreeListener *listener) override; + virtual void exitRule(antlr4::tree::ParseTreeListener *listener) override; + }; + + Fts_seq_exprContext *fts_seq_expr(); + + class Fts_unaryContext : public antlr4::ParserRuleContext { + public: + Fts_unaryContext(antlr4::ParserRuleContext *parent_ctx, + size_t invoking_state); + + Fts_unaryContext() = default; + void copyFrom(Fts_unaryContext *context); + using antlr4::ParserRuleContext::copyFrom; + + virtual size_t getRuleIndex() const override; + }; + + class Must_not_atomContext : public Fts_unaryContext { + public: + Must_not_atomContext(Fts_unaryContext *ctx); + + antlr4::tree::TerminalNode *MINUS_SIGN(); + Fts_atomContext *fts_atom(); + virtual void enterRule(antlr4::tree::ParseTreeListener *listener) override; + virtual void exitRule(antlr4::tree::ParseTreeListener *listener) override; + }; + + class Must_atomContext : public Fts_unaryContext { + public: + Must_atomContext(Fts_unaryContext *ctx); + + antlr4::tree::TerminalNode *PLUS_SIGN(); + Fts_atomContext *fts_atom(); + virtual void enterRule(antlr4::tree::ParseTreeListener *listener) override; + virtual void exitRule(antlr4::tree::ParseTreeListener *listener) override; + }; + + class Plain_atomContext : public Fts_unaryContext { + public: + Plain_atomContext(Fts_unaryContext *ctx); + + Fts_atomContext *fts_atom(); + virtual void enterRule(antlr4::tree::ParseTreeListener *listener) override; + virtual void exitRule(antlr4::tree::ParseTreeListener *listener) override; + }; + + Fts_unaryContext *fts_unary(); + + class Fts_atomContext : public antlr4::ParserRuleContext { + public: + Fts_atomContext(antlr4::ParserRuleContext *parent_ctx, + size_t invoking_state); + virtual size_t getRuleIndex() const override; + Fts_primaryContext *fts_primary(); + Fts_field_prefixContext *fts_field_prefix(); + Fts_boostContext *fts_boost(); + + virtual void enterRule(antlr4::tree::ParseTreeListener *listener) override; + virtual void exitRule(antlr4::tree::ParseTreeListener *listener) override; + }; + + Fts_atomContext *fts_atom(); + + class Fts_field_prefixContext : public antlr4::ParserRuleContext { + public: + Fts_field_prefixContext(antlr4::ParserRuleContext *parent_ctx, + size_t invoking_state); + virtual size_t getRuleIndex() const override; + antlr4::tree::TerminalNode *REGULAR_ID(); + antlr4::tree::TerminalNode *COLON(); + + virtual void enterRule(antlr4::tree::ParseTreeListener *listener) override; + virtual void exitRule(antlr4::tree::ParseTreeListener *listener) override; + }; + + Fts_field_prefixContext *fts_field_prefix(); + + class Fts_primaryContext : public antlr4::ParserRuleContext { + public: + Fts_primaryContext(antlr4::ParserRuleContext *parent_ctx, + size_t invoking_state); + virtual size_t getRuleIndex() const override; + Fts_termContext *fts_term(); + Fts_phraseContext *fts_phrase(); + antlr4::tree::TerminalNode *LP(); + Fts_or_exprContext *fts_or_expr(); + antlr4::tree::TerminalNode *RP(); + + virtual void enterRule(antlr4::tree::ParseTreeListener *listener) override; + virtual void exitRule(antlr4::tree::ParseTreeListener *listener) override; + }; + + Fts_primaryContext *fts_primary(); + + class Fts_boostContext : public antlr4::ParserRuleContext { + public: + Fts_boostContext(antlr4::ParserRuleContext *parent_ctx, + size_t invoking_state); + virtual size_t getRuleIndex() const override; + antlr4::tree::TerminalNode *CARET(); + antlr4::tree::TerminalNode *NUMBER(); + + virtual void enterRule(antlr4::tree::ParseTreeListener *listener) override; + virtual void exitRule(antlr4::tree::ParseTreeListener *listener) override; + }; + + Fts_boostContext *fts_boost(); + + class Fts_natural_termContext : public antlr4::ParserRuleContext { + public: + Fts_natural_termContext(antlr4::ParserRuleContext *parent_ctx, + size_t invoking_state); + virtual size_t getRuleIndex() const override; + std::vector DEFAULT(); + antlr4::tree::TerminalNode *DEFAULT(size_t i); + + virtual void enterRule(antlr4::tree::ParseTreeListener *listener) override; + virtual void exitRule(antlr4::tree::ParseTreeListener *listener) override; + }; + + Fts_natural_termContext *fts_natural_term(); + + class Fts_termContext : public antlr4::ParserRuleContext { + public: + Fts_termContext(antlr4::ParserRuleContext *parent_ctx, + size_t invoking_state); + virtual size_t getRuleIndex() const override; + antlr4::tree::TerminalNode *TERM(); + antlr4::tree::TerminalNode *REGULAR_ID(); + antlr4::tree::TerminalNode *NUMBER(); + Fts_natural_termContext *fts_natural_term(); + + virtual void enterRule(antlr4::tree::ParseTreeListener *listener) override; + virtual void exitRule(antlr4::tree::ParseTreeListener *listener) override; + }; + + Fts_termContext *fts_term(); + + class Fts_phraseContext : public antlr4::ParserRuleContext { + public: + Fts_phraseContext(antlr4::ParserRuleContext *parent_ctx, + size_t invoking_state); + virtual size_t getRuleIndex() const override; + antlr4::tree::TerminalNode *DQUOTA_STRING(); + + virtual void enterRule(antlr4::tree::ParseTreeListener *listener) override; + virtual void exitRule(antlr4::tree::ParseTreeListener *listener) override; + }; + + Fts_phraseContext *fts_phrase(); + + + private: + static std::vector _decisionToDFA; + static antlr4::atn::PredictionContextCache _sharedContextCache; + static std::vector _ruleNames; + static std::vector _tokenNames; + + static std::vector _literalNames; + static std::vector _symbolicNames; + static antlr4::dfa::Vocabulary _vocabulary; + static antlr4::atn::ATN _atn; + static std::vector _serializedATN; + + + struct Initializer { + Initializer(); + }; + static Initializer _init; +}; + +} // namespace antlr4 diff --git a/src/db/index/column/fts_column/gen/FtsParser.interp b/src/db/index/column/fts_column/gen/FtsParser.interp new file mode 100644 index 000000000..88d3cfe81 --- /dev/null +++ b/src/db/index/column/fts_column/gen/FtsParser.interp @@ -0,0 +1,53 @@ +token literal names: +null +null +null +null +'+' +'-' +':' +'^' +'(' +')' +null +null +null +null +null +null + +token symbolic names: +null +OR +AND +NOT +PLUS_SIGN +MINUS_SIGN +COLON +CARET +LP +RP +DQUOTA_STRING +REGULAR_ID +NUMBER +TERM +SPACES +DEFAULT + +rule names: +fts_query_unit +fts_or_expr +fts_and_expr +fts_seq_expr +fts_unary +fts_atom +fts_field_prefix +fts_primary +fts_boost +fts_natural_term +fts_term +fts_phrase + + +atn: +[3, 24715, 42794, 33075, 47597, 16764, 15335, 30598, 22884, 3, 17, 98, 4, 2, 9, 2, 4, 3, 9, 3, 4, 4, 9, 4, 4, 5, 9, 5, 4, 6, 9, 6, 4, 7, 9, 7, 4, 8, 9, 8, 4, 9, 9, 9, 4, 10, 9, 10, 4, 11, 9, 11, 4, 12, 9, 12, 4, 13, 9, 13, 3, 2, 3, 2, 3, 2, 3, 3, 3, 3, 3, 3, 7, 3, 33, 10, 3, 12, 3, 14, 3, 36, 11, 3, 3, 4, 3, 4, 3, 4, 5, 4, 41, 10, 4, 3, 4, 5, 4, 44, 10, 4, 3, 4, 7, 4, 47, 10, 4, 12, 4, 14, 4, 50, 11, 4, 3, 5, 6, 5, 53, 10, 5, 13, 5, 14, 5, 54, 3, 6, 3, 6, 3, 6, 3, 6, 3, 6, 5, 6, 62, 10, 6, 3, 7, 5, 7, 65, 10, 7, 3, 7, 3, 7, 5, 7, 69, 10, 7, 3, 8, 3, 8, 3, 8, 3, 9, 3, 9, 3, 9, 3, 9, 3, 9, 3, 9, 5, 9, 80, 10, 9, 3, 10, 3, 10, 3, 10, 3, 11, 6, 11, 86, 10, 11, 13, 11, 14, 11, 87, 3, 12, 3, 12, 3, 12, 3, 12, 5, 12, 94, 10, 12, 3, 13, 3, 13, 3, 13, 2, 2, 14, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 2, 2, 2, 100, 2, 26, 3, 2, 2, 2, 4, 29, 3, 2, 2, 2, 6, 37, 3, 2, 2, 2, 8, 52, 3, 2, 2, 2, 10, 61, 3, 2, 2, 2, 12, 64, 3, 2, 2, 2, 14, 70, 3, 2, 2, 2, 16, 79, 3, 2, 2, 2, 18, 81, 3, 2, 2, 2, 20, 85, 3, 2, 2, 2, 22, 93, 3, 2, 2, 2, 24, 95, 3, 2, 2, 2, 26, 27, 5, 4, 3, 2, 27, 28, 7, 2, 2, 3, 28, 3, 3, 2, 2, 2, 29, 34, 5, 6, 4, 2, 30, 31, 7, 3, 2, 2, 31, 33, 5, 6, 4, 2, 32, 30, 3, 2, 2, 2, 33, 36, 3, 2, 2, 2, 34, 32, 3, 2, 2, 2, 34, 35, 3, 2, 2, 2, 35, 5, 3, 2, 2, 2, 36, 34, 3, 2, 2, 2, 37, 48, 5, 8, 5, 2, 38, 40, 7, 4, 2, 2, 39, 41, 7, 5, 2, 2, 40, 39, 3, 2, 2, 2, 40, 41, 3, 2, 2, 2, 41, 44, 3, 2, 2, 2, 42, 44, 7, 5, 2, 2, 43, 38, 3, 2, 2, 2, 43, 42, 3, 2, 2, 2, 44, 45, 3, 2, 2, 2, 45, 47, 5, 8, 5, 2, 46, 43, 3, 2, 2, 2, 47, 50, 3, 2, 2, 2, 48, 46, 3, 2, 2, 2, 48, 49, 3, 2, 2, 2, 49, 7, 3, 2, 2, 2, 50, 48, 3, 2, 2, 2, 51, 53, 5, 10, 6, 2, 52, 51, 3, 2, 2, 2, 53, 54, 3, 2, 2, 2, 54, 52, 3, 2, 2, 2, 54, 55, 3, 2, 2, 2, 55, 9, 3, 2, 2, 2, 56, 57, 7, 6, 2, 2, 57, 62, 5, 12, 7, 2, 58, 59, 7, 7, 2, 2, 59, 62, 5, 12, 7, 2, 60, 62, 5, 12, 7, 2, 61, 56, 3, 2, 2, 2, 61, 58, 3, 2, 2, 2, 61, 60, 3, 2, 2, 2, 62, 11, 3, 2, 2, 2, 63, 65, 5, 14, 8, 2, 64, 63, 3, 2, 2, 2, 64, 65, 3, 2, 2, 2, 65, 66, 3, 2, 2, 2, 66, 68, 5, 16, 9, 2, 67, 69, 5, 18, 10, 2, 68, 67, 3, 2, 2, 2, 68, 69, 3, 2, 2, 2, 69, 13, 3, 2, 2, 2, 70, 71, 7, 13, 2, 2, 71, 72, 7, 8, 2, 2, 72, 15, 3, 2, 2, 2, 73, 80, 5, 22, 12, 2, 74, 80, 5, 24, 13, 2, 75, 76, 7, 10, 2, 2, 76, 77, 5, 4, 3, 2, 77, 78, 7, 11, 2, 2, 78, 80, 3, 2, 2, 2, 79, 73, 3, 2, 2, 2, 79, 74, 3, 2, 2, 2, 79, 75, 3, 2, 2, 2, 80, 17, 3, 2, 2, 2, 81, 82, 7, 9, 2, 2, 82, 83, 7, 14, 2, 2, 83, 19, 3, 2, 2, 2, 84, 86, 7, 17, 2, 2, 85, 84, 3, 2, 2, 2, 86, 87, 3, 2, 2, 2, 87, 85, 3, 2, 2, 2, 87, 88, 3, 2, 2, 2, 88, 21, 3, 2, 2, 2, 89, 94, 7, 15, 2, 2, 90, 94, 7, 13, 2, 2, 91, 94, 7, 14, 2, 2, 92, 94, 5, 20, 11, 2, 93, 89, 3, 2, 2, 2, 93, 90, 3, 2, 2, 2, 93, 91, 3, 2, 2, 2, 93, 92, 3, 2, 2, 2, 94, 23, 3, 2, 2, 2, 95, 96, 7, 12, 2, 2, 96, 25, 3, 2, 2, 2, 13, 34, 40, 43, 48, 54, 61, 64, 68, 79, 87, 93] diff --git a/src/db/index/column/fts_column/gen/FtsParser.tokens b/src/db/index/column/fts_column/gen/FtsParser.tokens new file mode 100644 index 000000000..cd6e2db20 --- /dev/null +++ b/src/db/index/column/fts_column/gen/FtsParser.tokens @@ -0,0 +1,21 @@ +OR=1 +AND=2 +NOT=3 +PLUS_SIGN=4 +MINUS_SIGN=5 +COLON=6 +CARET=7 +LP=8 +RP=9 +DQUOTA_STRING=10 +REGULAR_ID=11 +NUMBER=12 +TERM=13 +SPACES=14 +DEFAULT=15 +'+'=4 +'-'=5 +':'=6 +'^'=7 +'('=8 +')'=9 diff --git a/src/db/index/column/fts_column/gen/FtsParserBaseListener.cc b/src/db/index/column/fts_column/gen/FtsParserBaseListener.cc new file mode 100644 index 000000000..a78804a3a --- /dev/null +++ b/src/db/index/column/fts_column/gen/FtsParserBaseListener.cc @@ -0,0 +1,8 @@ + +// Generated from FtsParser.g4 by ANTLR 4.8 + + +#include "FtsParserBaseListener.h" + + +using namespace antlr4; diff --git a/src/db/index/column/fts_column/gen/FtsParserBaseListener.h b/src/db/index/column/fts_column/gen/FtsParserBaseListener.h new file mode 100644 index 000000000..e88465570 --- /dev/null +++ b/src/db/index/column/fts_column/gen/FtsParserBaseListener.h @@ -0,0 +1,89 @@ + +// Generated from FtsParser.g4 by ANTLR 4.8 + +#pragma once + + +#include "FtsParserListener.h" +#include "antlr4-runtime.h" + + +namespace antlr4 { + +/** + * This class provides an empty implementation of FtsParserListener, + * which can be extended to create a listener which only needs to handle a + * subset of the available methods. + */ +class FtsParserBaseListener : public FtsParserListener { + public: + virtual void enterFts_query_unit( + FtsParser::Fts_query_unitContext * /*ctx*/) override {} + virtual void exitFts_query_unit( + FtsParser::Fts_query_unitContext * /*ctx*/) override {} + + virtual void enterFts_or_expr( + FtsParser::Fts_or_exprContext * /*ctx*/) override {} + virtual void exitFts_or_expr( + FtsParser::Fts_or_exprContext * /*ctx*/) override {} + + virtual void enterFts_and_expr( + FtsParser::Fts_and_exprContext * /*ctx*/) override {} + virtual void exitFts_and_expr( + FtsParser::Fts_and_exprContext * /*ctx*/) override {} + + virtual void enterFts_seq_expr( + FtsParser::Fts_seq_exprContext * /*ctx*/) override {} + virtual void exitFts_seq_expr( + FtsParser::Fts_seq_exprContext * /*ctx*/) override {} + + virtual void enterMust_atom(FtsParser::Must_atomContext * /*ctx*/) override {} + virtual void exitMust_atom(FtsParser::Must_atomContext * /*ctx*/) override {} + + virtual void enterMust_not_atom( + FtsParser::Must_not_atomContext * /*ctx*/) override {} + virtual void exitMust_not_atom( + FtsParser::Must_not_atomContext * /*ctx*/) override {} + + virtual void enterPlain_atom( + FtsParser::Plain_atomContext * /*ctx*/) override {} + virtual void exitPlain_atom(FtsParser::Plain_atomContext * /*ctx*/) override { + } + + virtual void enterFts_atom(FtsParser::Fts_atomContext * /*ctx*/) override {} + virtual void exitFts_atom(FtsParser::Fts_atomContext * /*ctx*/) override {} + + virtual void enterFts_field_prefix( + FtsParser::Fts_field_prefixContext * /*ctx*/) override {} + virtual void exitFts_field_prefix( + FtsParser::Fts_field_prefixContext * /*ctx*/) override {} + + virtual void enterFts_primary( + FtsParser::Fts_primaryContext * /*ctx*/) override {} + virtual void exitFts_primary( + FtsParser::Fts_primaryContext * /*ctx*/) override {} + + virtual void enterFts_boost(FtsParser::Fts_boostContext * /*ctx*/) override {} + virtual void exitFts_boost(FtsParser::Fts_boostContext * /*ctx*/) override {} + + virtual void enterFts_natural_term( + FtsParser::Fts_natural_termContext * /*ctx*/) override {} + virtual void exitFts_natural_term( + FtsParser::Fts_natural_termContext * /*ctx*/) override {} + + virtual void enterFts_term(FtsParser::Fts_termContext * /*ctx*/) override {} + virtual void exitFts_term(FtsParser::Fts_termContext * /*ctx*/) override {} + + virtual void enterFts_phrase( + FtsParser::Fts_phraseContext * /*ctx*/) override {} + virtual void exitFts_phrase(FtsParser::Fts_phraseContext * /*ctx*/) override { + } + + + virtual void enterEveryRule(antlr4::ParserRuleContext * /*ctx*/) override {} + virtual void exitEveryRule(antlr4::ParserRuleContext * /*ctx*/) override {} + virtual void visitTerminal(antlr4::tree::TerminalNode * /*node*/) override {} + virtual void visitErrorNode(antlr4::tree::ErrorNode * /*node*/) override {} +}; + +} // namespace antlr4 diff --git a/src/db/index/column/fts_column/gen/FtsParserListener.cc b/src/db/index/column/fts_column/gen/FtsParserListener.cc new file mode 100644 index 000000000..b794fd4db --- /dev/null +++ b/src/db/index/column/fts_column/gen/FtsParserListener.cc @@ -0,0 +1,8 @@ + +// Generated from FtsParser.g4 by ANTLR 4.8 + + +#include "FtsParserListener.h" + + +using namespace antlr4; diff --git a/src/db/index/column/fts_column/gen/FtsParserListener.h b/src/db/index/column/fts_column/gen/FtsParserListener.h new file mode 100644 index 000000000..71be04b8a --- /dev/null +++ b/src/db/index/column/fts_column/gen/FtsParserListener.h @@ -0,0 +1,66 @@ + +// Generated from FtsParser.g4 by ANTLR 4.8 + +#pragma once + + +#include "FtsParser.h" +#include "antlr4-runtime.h" + + +namespace antlr4 { + +/** + * This interface defines an abstract listener for a parse tree produced by + * FtsParser. + */ +class FtsParserListener : public antlr4::tree::ParseTreeListener { + public: + virtual void enterFts_query_unit(FtsParser::Fts_query_unitContext *ctx) = 0; + virtual void exitFts_query_unit(FtsParser::Fts_query_unitContext *ctx) = 0; + + virtual void enterFts_or_expr(FtsParser::Fts_or_exprContext *ctx) = 0; + virtual void exitFts_or_expr(FtsParser::Fts_or_exprContext *ctx) = 0; + + virtual void enterFts_and_expr(FtsParser::Fts_and_exprContext *ctx) = 0; + virtual void exitFts_and_expr(FtsParser::Fts_and_exprContext *ctx) = 0; + + virtual void enterFts_seq_expr(FtsParser::Fts_seq_exprContext *ctx) = 0; + virtual void exitFts_seq_expr(FtsParser::Fts_seq_exprContext *ctx) = 0; + + virtual void enterMust_atom(FtsParser::Must_atomContext *ctx) = 0; + virtual void exitMust_atom(FtsParser::Must_atomContext *ctx) = 0; + + virtual void enterMust_not_atom(FtsParser::Must_not_atomContext *ctx) = 0; + virtual void exitMust_not_atom(FtsParser::Must_not_atomContext *ctx) = 0; + + virtual void enterPlain_atom(FtsParser::Plain_atomContext *ctx) = 0; + virtual void exitPlain_atom(FtsParser::Plain_atomContext *ctx) = 0; + + virtual void enterFts_atom(FtsParser::Fts_atomContext *ctx) = 0; + virtual void exitFts_atom(FtsParser::Fts_atomContext *ctx) = 0; + + virtual void enterFts_field_prefix( + FtsParser::Fts_field_prefixContext *ctx) = 0; + virtual void exitFts_field_prefix( + FtsParser::Fts_field_prefixContext *ctx) = 0; + + virtual void enterFts_primary(FtsParser::Fts_primaryContext *ctx) = 0; + virtual void exitFts_primary(FtsParser::Fts_primaryContext *ctx) = 0; + + virtual void enterFts_boost(FtsParser::Fts_boostContext *ctx) = 0; + virtual void exitFts_boost(FtsParser::Fts_boostContext *ctx) = 0; + + virtual void enterFts_natural_term( + FtsParser::Fts_natural_termContext *ctx) = 0; + virtual void exitFts_natural_term( + FtsParser::Fts_natural_termContext *ctx) = 0; + + virtual void enterFts_term(FtsParser::Fts_termContext *ctx) = 0; + virtual void exitFts_term(FtsParser::Fts_termContext *ctx) = 0; + + virtual void enterFts_phrase(FtsParser::Fts_phraseContext *ctx) = 0; + virtual void exitFts_phrase(FtsParser::Fts_phraseContext *ctx) = 0; +}; + +} // namespace antlr4 diff --git a/src/db/index/column/fts_column/gen_parser.sh b/src/db/index/column/fts_column/gen_parser.sh new file mode 100644 index 000000000..8797a4d5e --- /dev/null +++ b/src/db/index/column/fts_column/gen_parser.sh @@ -0,0 +1,9 @@ +#!/bin/sh +#****************************************************************# +# ScriptName: gen_parser.sh +# Author: fancy.lf +# Function: command to generate antlr sql parser code in se directory +#***************************************************************# + +java -jar ../../../../deps/thirdparty/antlr/antlr-4.8-complete.jar -Dlanguage=Cpp -package antlr4 FtsLexer.g4 FtsParser.g4 -o gen +sed -i 's/\bu8"/"/g' gen/*.cc diff --git a/src/db/index/column/fts_column/iterator/fts_candidate_iterator.cc b/src/db/index/column/fts_column/iterator/fts_candidate_iterator.cc new file mode 100644 index 000000000..5b1d3687d --- /dev/null +++ b/src/db/index/column/fts_column/iterator/fts_candidate_iterator.cc @@ -0,0 +1,53 @@ +// Copyright 2025-present the zvec project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "fts_candidate_iterator.h" +#include + +namespace zvec::fts { + +CandidateDocIterator::CandidateDocIterator( + const std::vector &sorted_local_ids) { + ids_.reserve(sorted_local_ids.size()); + for (uint64_t id : sorted_local_ids) { + ids_.push_back(static_cast(id)); + } + cached_max_score_ = 0.0f; +} + + +uint32_t CandidateDocIterator::next_doc() { + if (pos_ >= ids_.size()) { + cached_doc_id_ = NO_MORE_DOCS; + return NO_MORE_DOCS; + } + cached_doc_id_ = ids_[pos_++]; + return cached_doc_id_; +} + +uint32_t CandidateDocIterator::advance(uint32_t target) { + // Start from pos_: everything before it is already consumed. + auto begin = ids_.begin() + pos_; + auto it = std::lower_bound(begin, ids_.end(), target); + if (it == ids_.end()) { + pos_ = ids_.size(); + cached_doc_id_ = NO_MORE_DOCS; + return NO_MORE_DOCS; + } + pos_ = static_cast(it - ids_.begin()) + 1; + cached_doc_id_ = *it; + return cached_doc_id_; +} + +} // namespace zvec::fts diff --git a/src/db/index/column/fts_column/iterator/fts_candidate_iterator.h b/src/db/index/column/fts_column/iterator/fts_candidate_iterator.h new file mode 100644 index 000000000..5f7cce1dd --- /dev/null +++ b/src/db/index/column/fts_column/iterator/fts_candidate_iterator.h @@ -0,0 +1,55 @@ +// Copyright 2025-present the zvec project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include "fts_doc_iterator.h" + +namespace zvec::fts { + +/*! Candidate-driven document iterator. + * + * AND-ed with an FTS iterator tree under ConjunctionIterator: since cost() + * returns the (small) candidate count, this iterator becomes the lead and + * the FTS tree is only asked to advance() to each candidate — reusing the + * existing BM25 / matches / filter-pushdown machinery. + * + * Input MUST be ascending segment-local doc_ids (the space TermDocIterator + * uses; no GLOBAL→LOCAL translation needed in zvec). + */ +class CandidateDocIterator : public DocIterator { + public: + explicit CandidateDocIterator(const std::vector &sorted_local_ids); + + uint32_t next_doc() override; + uint32_t advance(uint32_t target) override; + + float score() override { + return 0.0f; + } + uint64_t cost() const override { + return ids_.size(); + } + float max_score() const override { + return 0.0f; + } + + private: + std::vector ids_; // ascending segment-local doc_ids + size_t pos_{0}; // index of next element to return +}; + +} // namespace zvec::fts diff --git a/src/db/index/column/fts_column/iterator/fts_conjunction_iterator.cc b/src/db/index/column/fts_column/iterator/fts_conjunction_iterator.cc new file mode 100644 index 000000000..51e92c44c --- /dev/null +++ b/src/db/index/column/fts_column/iterator/fts_conjunction_iterator.cc @@ -0,0 +1,187 @@ +// Copyright 2025-present the zvec project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "fts_conjunction_iterator.h" +#include + +namespace zvec::fts { + +ConjunctionIterator::ConjunctionIterator( + std::vector must_iterators, + std::vector must_not_iterators) + : must_iterators_(std::move(must_iterators)), + must_not_iterators_(std::move(must_not_iterators)) { + // Sort must iterators by cost (ascending) so the cheapest leads + std::sort(must_iterators_.begin(), must_iterators_.end(), + [](const DocIteratorPtr &a, const DocIteratorPtr &b) { + return a->cost() < b->cost(); + }); + // Compute and cache max_score in base class field + float total = 0.0f; + for (auto &iter : must_iterators_) { + total += iter->cached_max_score_; + } + cached_max_score_ = total; +} + +uint32_t ConjunctionIterator::next_doc() { + if (must_iterators_.empty()) { + cached_doc_id_ = NO_MORE_DOCS; + return NO_MORE_DOCS; + } + + // MaxScore pruning: If the maximum possible score of this AND node + // cannot beat the threshold, terminate iteration early. + if (min_competitive_score_ > 0.0f && max_score() < min_competitive_score_) { + cached_doc_id_ = NO_MORE_DOCS; + return NO_MORE_DOCS; + } + + // Advance the lead iterator and try to find agreement + uint32_t candidate = must_iterators_[0]->next_doc(); + cached_doc_id_ = do_next(candidate); + return cached_doc_id_; +} + +uint32_t ConjunctionIterator::next_doc(const zvec::IndexFilter *filter) { + if (must_iterators_.empty()) { + cached_doc_id_ = NO_MORE_DOCS; + return NO_MORE_DOCS; + } + + // MaxScore pruning + if (min_competitive_score_ > 0.0f && max_score() < min_competitive_score_) { + cached_doc_id_ = NO_MORE_DOCS; + return NO_MORE_DOCS; + } + + // Lead iterator advances with filter-awareness so filtered docs never + // reach do_next() alignment. + uint32_t candidate = must_iterators_[0]->next_doc(filter); + while (candidate != NO_MORE_DOCS) { + candidate = do_next(candidate); + if (candidate == NO_MORE_DOCS || !filter->is_filtered(candidate)) { + break; + } + // do_next may have re-anchored the lead onto a filtered doc; advance + // the lead past it (still filter-aware) and try again. + candidate = must_iterators_[0]->next_doc(filter); + } + cached_doc_id_ = candidate; + return candidate; +} + +uint32_t ConjunctionIterator::advance(uint32_t target) { + if (must_iterators_.empty()) { + cached_doc_id_ = NO_MORE_DOCS; + return NO_MORE_DOCS; + } + + // MaxScore pruning + if (min_competitive_score_ > 0.0f && max_score() < min_competitive_score_) { + cached_doc_id_ = NO_MORE_DOCS; + return NO_MORE_DOCS; + } + + uint32_t candidate = must_iterators_[0]->advance(target); + cached_doc_id_ = do_next(candidate); + return cached_doc_id_; +} + +uint32_t ConjunctionIterator::do_next(uint32_t candidate) { + if (candidate == NO_MORE_DOCS) { + return NO_MORE_DOCS; + } + + while (true) { + // Try to advance all other must iterators to the candidate + bool all_match = true; + for (size_t i = 1; i < must_iterators_.size(); ++i) { + uint32_t other_doc = must_iterators_[i]->advance(candidate); + if (other_doc == NO_MORE_DOCS) { + return NO_MORE_DOCS; + } + if (other_doc != candidate) { + // Mismatch: use the higher doc_id as the new candidate + // and re-advance the lead iterator + candidate = must_iterators_[0]->advance(other_doc); + if (candidate == NO_MORE_DOCS) { + return NO_MORE_DOCS; + } + all_match = false; + break; + } + } + + if (all_match) { + // All must iterators agree on this candidate + // Check must_not exclusion + if (!is_excluded(candidate)) { + return candidate; + } + // Excluded by must_not, advance lead to next doc + candidate = must_iterators_[0]->next_doc(); + if (candidate == NO_MORE_DOCS) { + return NO_MORE_DOCS; + } + } + } +} + +bool ConjunctionIterator::is_excluded(uint32_t candidate) { + for (auto ¬_iter : must_not_iterators_) { + uint32_t not_doc = not_iter->advance(candidate); + if (not_doc == candidate) { + // This document is excluded by a must_not clause + return true; + } + } + return false; +} + +bool ConjunctionIterator::matches() { + // Phase-2 verification: all must sub-iterators must pass matches() + for (auto &iter : must_iterators_) { + if (!iter->matches()) { + return false; + } + } + return true; +} + +float ConjunctionIterator::score() { + float total = 0.0f; + for (auto &iter : must_iterators_) { + total += iter->score(); + } + return total; +} + +uint64_t ConjunctionIterator::cost() const { + if (must_iterators_.empty()) { + return 0; + } + // Cost is determined by the shortest (lead) iterator + return must_iterators_[0]->cost(); +} + +float ConjunctionIterator::max_score() const { + float total = 0.0f; + for (auto &iter : must_iterators_) { + total += iter->max_score(); + } + return total; +} + +} // namespace zvec::fts diff --git a/src/db/index/column/fts_column/iterator/fts_conjunction_iterator.h b/src/db/index/column/fts_column/iterator/fts_conjunction_iterator.h new file mode 100644 index 000000000..561fa8f07 --- /dev/null +++ b/src/db/index/column/fts_column/iterator/fts_conjunction_iterator.h @@ -0,0 +1,69 @@ +// Copyright 2025-present the zvec project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include +#include "fts_doc_iterator.h" + +namespace zvec::fts { + +/*! Conjunction (AND) document iterator + * + * Implements multi-way intersection of must sub-iterators with must_not + * exclusion filtering. The lead iterator (lowest cost) drives the iteration; + * other iterators are advanced to match the lead's current doc_id. + */ +class ConjunctionIterator : public DocIterator { + public: + /*! Construct a conjunction iterator. + * \param must_iterators Sub-iterators that must all match (AND) + * \param must_not_iterators Sub-iterators whose matches are excluded (NOT) + */ + ConjunctionIterator(std::vector must_iterators, + std::vector must_not_iterators); + + uint32_t next_doc() override; + //! Internal-driven filter skip: pushes filter into the lead iterator so + //! filtered candidates never trigger the do_next alignment cascade. + uint32_t next_doc(const zvec::IndexFilter *filter) override; + uint32_t advance(uint32_t target) override; + bool matches() override; + float score() override; + uint64_t cost() const override; + float max_score() const override; + + void set_min_competitive_score(float min_score) override { + min_competitive_score_ = min_score; + } + + private: + // Try to find the next doc_id where all must iterators agree, + // starting from the lead iterator's current position. + // Returns NO_MORE_DOCS if no such document exists. + uint32_t do_next(uint32_t candidate); + + // Check if candidate doc_id is excluded by any must_not iterator + bool is_excluded(uint32_t candidate); + + private: + // must_iterators_[0] is the lead (lowest cost) + std::vector must_iterators_; + std::vector must_not_iterators_; + float min_competitive_score_{0.0f}; +}; + +} // namespace zvec::fts diff --git a/src/db/index/column/fts_column/iterator/fts_disjunction_iterator.cc b/src/db/index/column/fts_column/iterator/fts_disjunction_iterator.cc new file mode 100644 index 000000000..785f7f0fd --- /dev/null +++ b/src/db/index/column/fts_column/iterator/fts_disjunction_iterator.cc @@ -0,0 +1,258 @@ +// Copyright 2025-present the zvec project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "fts_disjunction_iterator.h" +#include + +namespace zvec::fts { + +namespace { + +// Move element at `idx` forward (toward higher indices) to restore sorted +// order. Only the element at `idx` may be out of place; all other elements +// must already be sorted. +inline void sift_forward(std::vector &vec, size_t idx) { + DocIterator *elem = vec[idx]; + uint32_t elem_doc = elem->cached_doc_id_; + size_t pos = idx; + size_t end = vec.size(); + while (pos + 1 < end && vec[pos + 1]->cached_doc_id_ < elem_doc) { + vec[pos] = vec[pos + 1]; + ++pos; + } + vec[pos] = elem; +} + +} // namespace + +DisjunctionIterator::DisjunctionIterator( + std::vector sub_iterators) + : sub_iterators_(std::move(sub_iterators)) { + // Initialize each sub-iterator to its first doc and prepare postings array + total_cost_ = 0; + total_max_score_ = 0.0f; + for (auto &iter : sub_iterators_) { + total_cost_ += iter->cost(); + total_max_score_ += iter->cached_max_score_; + iter->next_doc(); + postings_.push_back(iter.get()); + } + // Initial sort to establish sorted order + resort_postings(); + cached_max_score_ = total_max_score_; +} + +void DisjunctionIterator::set_min_competitive_score(float min_score) { + min_competitive_score_ = min_score; +} + +// Re-establish sorted order of postings_ by cached_doc_id_ ascending. +// Called when multiple iterators may have changed position. +void DisjunctionIterator::resort_postings() { + std::sort(postings_.begin(), postings_.end(), + [](const DocIterator *a, const DocIterator *b) { + return a->cached_doc_id_ < b->cached_doc_id_; + }); +} + +uint32_t DisjunctionIterator::next_doc() { + return next_doc_impl(nullptr); +} + +uint32_t DisjunctionIterator::next_doc(const zvec::IndexFilter *filter) { + return next_doc_impl(filter); +} + +uint32_t DisjunctionIterator::next_doc_impl(const zvec::IndexFilter *filter) { + // Advance matched from the previous document + for (auto *iter : matching_iterators_) { + iter->next_doc(); + } + matching_iterators_.clear(); + + // Restore sorted order — multiple iterators may have changed + resort_postings(); + + while (true) { + // 1. postings_ is maintained in sorted order + + if (postings_.empty() || postings_[0]->cached_doc_id_ == NO_MORE_DOCS) { + cached_doc_id_ = NO_MORE_DOCS; + return NO_MORE_DOCS; + } + + // 2. Find Pivot: accumulate max_score until it reaches the threshold + float partial_max_score = 0.0f; + size_t pivot_idx = 0; + bool found_pivot = false; + for (; pivot_idx < postings_.size(); ++pivot_idx) { + if (postings_[pivot_idx]->cached_doc_id_ == NO_MORE_DOCS) { + break; + } + partial_max_score += postings_[pivot_idx]->cached_max_score_; + if (partial_max_score >= min_competitive_score_) { + found_pivot = true; + break; + } + } + + if (!found_pivot) { + // If all remaining iterators' max_score sum is less than threshold, + // no more competitive documents can be produced. + cached_doc_id_ = NO_MORE_DOCS; + return NO_MORE_DOCS; + } + + uint32_t pivot_doc = postings_[pivot_idx]->cached_doc_id_; + + // 3. Check alignment + if (postings_[0]->cached_doc_id_ == pivot_doc) { + // 3.1 Filter pushdown: if pivot_doc is filtered, skip it before paying + // for block-max accumulation, matches(), or score(). Advance every + // posting currently sitting at pivot_doc past it, then resort. + if (filter && filter->is_filtered(pivot_doc)) { + for (size_t i = 0; i < postings_.size(); ++i) { + if (postings_[i]->cached_doc_id_ == pivot_doc) { + postings_[i]->next_doc(); + } else { + break; // postings_ is sorted; rest are > pivot_doc + } + } + resort_postings(); + continue; + } + + // 3.5 Block-Max WAND pruning (Ding & Suel 2011). + // First accumulate block_max_scores from [0..pivot_idx]. + // If already >= threshold, skip the pruning check (fast path). + // Otherwise, lazily include iterators beyond pivot_idx whose + // posting lists may also contain pivot_doc — their block_max_score + // contributions must be counted to avoid underestimating the + // potential score and incorrectly skipping TopK documents. + if (min_competitive_score_ > 0.0f) { + float block_score_sum = 0.0f; + uint32_t min_block_end = NO_MORE_DOCS; + bool can_skip = true; + + // Phase 1: accumulate [0..pivot_idx] (always needed) + for (size_t i = 0; i <= pivot_idx; ++i) { + auto info = postings_[i]->block_max_info_for(pivot_doc); + block_score_sum += info.block_max_score; + if (info.block_last_doc < min_block_end) { + min_block_end = info.block_last_doc; + } + } + + // Phase 2: if [0..pivot_idx] sum is already sufficient, no pruning + if (block_score_sum >= min_competitive_score_) { + can_skip = false; + } else { + // Lazily accumulate remaining iterators beyond pivot_idx. + // They may also contribute scores for pivot_doc. + for (size_t i = pivot_idx + 1; i < postings_.size(); ++i) { + if (postings_[i]->cached_doc_id_ == NO_MORE_DOCS) { + break; + } + auto info = postings_[i]->block_max_info_for(pivot_doc); + block_score_sum += info.block_max_score; + if (info.block_last_doc < min_block_end) { + min_block_end = info.block_last_doc; + } + if (block_score_sum >= min_competitive_score_) { + can_skip = false; + break; + } + } + } + + if (can_skip && block_score_sum < min_competitive_score_ && + min_block_end != NO_MORE_DOCS) { + // All iterators' blocks containing pivot_doc cannot produce a + // competitive score. Advance ALL iterators in [0..pivot_idx] past + // the smallest block boundary to maximize the jump distance. + uint32_t skip_target = min_block_end + 1; + for (size_t i = 0; i <= pivot_idx; ++i) { + if (postings_[i]->cached_doc_id_ < skip_target) { + postings_[i]->advance(skip_target); + } + } + // Multiple iterators changed — full resort + resort_postings(); + continue; + } + } + + // Candidate doc passed block-level check. Collect all matching iterators. + for (size_t i = 0; i < postings_.size(); ++i) { + if (postings_[i]->cached_doc_id_ == pivot_doc) { + matching_iterators_.push_back(postings_[i]); + } else { + break; // because postings_ is sorted by cached_doc_id_ + } + } + cached_doc_id_ = pivot_doc; + return pivot_doc; + } else { + // 4. Iterator Jumping: advance the iterator with the smallest doc_id + // to at least the pivot's doc_id. This bypasses scoring and checking + // for all documents smaller than pivot_doc! + // Only postings_[0] changed — use sift_forward instead of full sort. + postings_[0]->advance(pivot_doc); + sift_forward(postings_, 0); + } + } +} + +uint32_t DisjunctionIterator::advance(uint32_t target) { + // Clear pending matches as they will be re-advanced below + matching_iterators_.clear(); + + for (auto *iter : postings_) { + if (iter->cached_doc_id_ < target) { + iter->advance(target); + } + } + return next_doc(); +} + +bool DisjunctionIterator::matches() { + // At least one matching sub-iterator must pass phase-2 verification + for (DocIterator *iter : matching_iterators_) { + if (iter->matches()) { + return true; + } + } + return false; +} + +float DisjunctionIterator::score() { + // Sum scores of all matching sub-iterators that pass phase-2 verification + float total = 0.0f; + for (DocIterator *iter : matching_iterators_) { + if (iter->matches()) { + total += iter->score(); + } + } + return total; +} + +uint64_t DisjunctionIterator::cost() const { + return total_cost_; +} + +float DisjunctionIterator::max_score() const { + return total_max_score_; +} + +} // namespace zvec::fts diff --git a/src/db/index/column/fts_column/iterator/fts_disjunction_iterator.h b/src/db/index/column/fts_column/iterator/fts_disjunction_iterator.h new file mode 100644 index 000000000..41fe55ae7 --- /dev/null +++ b/src/db/index/column/fts_column/iterator/fts_disjunction_iterator.h @@ -0,0 +1,63 @@ +// Copyright 2025-present the zvec project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include +#include "fts_doc_iterator.h" + +namespace zvec::fts { + +/*! Disjunction (OR) document iterator with WAND pruning + */ +class DisjunctionIterator : public DocIterator { + public: + /*! Construct a disjunction iterator. + * \param sub_iterators Sub-iterators to merge (OR semantics) + */ + explicit DisjunctionIterator(std::vector sub_iterators); + + uint32_t next_doc() override; + //! Internal-driven filter skip: checks filter inside the WAND loop after + //! pivot alignment, before block-max accumulation and resort overhead. + uint32_t next_doc(const zvec::IndexFilter *filter) override; + uint32_t advance(uint32_t target) override; + bool matches() override; + float score() override; + uint64_t cost() const override; + float max_score() const override; + + //! Update the minimum competitive score threshold for WAND pruning. + //! Documents whose total max_score sum falls below this threshold + //! are skipped without exact scoring. + void set_min_competitive_score(float min_score) override; + + private: + void resort_postings(); + + //! Unified WAND loop body. \p filter may be null (no-filter fast path). + uint32_t next_doc_impl(const zvec::IndexFilter *filter); + + private: + std::vector sub_iterators_; // Owns the sub-iterators + std::vector postings_; // Pointers for fast sorting (WAND) + std::vector matching_iterators_; // Current doc matches + float min_competitive_score_{0.0f}; + uint64_t total_cost_{0}; + float total_max_score_{0.0f}; +}; + +} // namespace zvec::fts diff --git a/src/db/index/column/fts_column/iterator/fts_doc_iterator.h b/src/db/index/column/fts_column/iterator/fts_doc_iterator.h new file mode 100644 index 000000000..58f0782c0 --- /dev/null +++ b/src/db/index/column/fts_column/iterator/fts_doc_iterator.h @@ -0,0 +1,123 @@ +// Copyright 2025-present the zvec project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include +#include "db/index/common/index_filter.h" + +namespace zvec::fts { + +/*! Abstract base class for FTS document iterators. + * + * All query nodes (Term, Phrase, AND, OR) implement this interface to form + * a composable iterator tree. The iterator produces matching documents in + * ascending doc_id order. + * + * Two-phase iteration: + * Phase 1: next_doc() / advance() locate candidate documents using only + * doc_id information (cheap). + * Phase 2: matches() performs exact verification (e.g. position check for + * phrase queries). Only called after Phase 1 succeeds. + */ +class DocIterator { + public: + virtual ~DocIterator() = default; + + //! Sentinel value indicating no more matching documents + static constexpr uint32_t NO_MORE_DOCS = UINT32_MAX; + + //! Cached doc_id for hot-path access without virtual dispatch. + //! Sub-classes MUST update this in next_doc() / advance() before returning. + uint32_t cached_doc_id_{NO_MORE_DOCS}; + + //! Cached max_score for hot-path access without virtual dispatch. + //! Sub-classes MUST set this in constructors (and update if max_score + //! changes, which is rare for most iterators). + float cached_max_score_{0.0f}; + + //! Advance to the next matching document. + //! \return doc_id of the next match, or NO_MORE_DOCS if exhausted. + virtual uint32_t next_doc() = 0; + + //! Filter-aware next_doc. Composite iterators (Disjunction/Conjunction/ + //! Phrase) override to check the filter at the optimal point inside their + //! loops — before block-max binary search, do_next alignment, or phase-2 + //! position verification — so filtered docs do not pay that cost. + //! Default implementation just loops over next_doc() and skips filtered + //! docs (functionally equivalent to a caller-side post-filter check). + //! \param filter Must be non-null; true means SKIP the doc. + virtual uint32_t next_doc(const zvec::IndexFilter *filter) { + uint32_t doc = next_doc(); + while (doc != NO_MORE_DOCS && filter->is_filtered(doc)) { + doc = next_doc(); + } + return doc; + } + + //! Advance to the first matching document with doc_id >= target. + //! \param target Minimum doc_id to seek to. + //! \return doc_id of the match (>= target), or NO_MORE_DOCS if exhausted. + virtual uint32_t advance(uint32_t target) = 0; + + //! Return the current document ID. + //! Undefined before the first call to next_doc() or advance(). + uint32_t doc_id() const { + return cached_doc_id_; + } + + //! Phase-2 exact verification for the current document. + //! For most iterators this is a no-op (returns true). + //! PhraseDocIterator overrides this to check position adjacency. + //! \return true if the current document truly matches. + virtual bool matches() { + return true; + } + + //! Compute the BM25 score of the current document. + //! Must only be called after matches() returns true. + virtual float score() = 0; + + //! Estimated cost of this iterator (e.g. posting list length). + //! Used to order sub-iterators in ConjunctionIterator (shortest first). + virtual uint64_t cost() const = 0; + + //! Upper bound on the score this iterator can produce for any document. + //! Used by WAND pruning in DisjunctionIterator. + virtual float max_score() const { + return std::numeric_limits::max(); + } + + //! Update the minimum competitive score threshold for WAND pruning. + //! Only DisjunctionIterator implements meaningful behavior; other iterators + //! ignore this call. + //! \param min_score Current minimum score needed to enter the TopK heap. + virtual void set_min_competitive_score(float /*min_score*/) {} + + //! Block-Max WAND support: return both block_max_score and max_doc_id + //! for the block containing \p target in a single skip list binary search. + struct BlockMaxInfo { + float block_max_score{0.0f}; + uint32_t block_last_doc{NO_MORE_DOCS}; + }; + virtual BlockMaxInfo block_max_info_for(uint32_t /*target*/) const { + return {max_score(), NO_MORE_DOCS}; + } +}; + +using DocIteratorPtr = std::unique_ptr; + +} // namespace zvec::fts diff --git a/src/db/index/column/fts_column/iterator/fts_phrase_iterator.cc b/src/db/index/column/fts_column/iterator/fts_phrase_iterator.cc new file mode 100644 index 000000000..04f2bee9b --- /dev/null +++ b/src/db/index/column/fts_column/iterator/fts_phrase_iterator.cc @@ -0,0 +1,205 @@ +// Copyright 2025-present the zvec project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "fts_phrase_iterator.h" +#include +#include +#include +#include "../fts_utils.h" + +namespace zvec::fts { + +PhraseDocIterator::PhraseDocIterator(DocIteratorPtr conjunction, + std::vector terms, + RocksdbContext *ctx, + rocksdb::ColumnFamilyHandle *positions_cf) + : conjunction_(std::move(conjunction)), + terms_(std::move(terms)), + ctx_(ctx), + positions_cf_(positions_cf) { + cached_max_score_ = conjunction_->cached_max_score_; +} + +uint32_t PhraseDocIterator::next_doc() { + cached_doc_id_ = conjunction_->next_doc(); + return cached_doc_id_; +} + +uint32_t PhraseDocIterator::next_doc(const zvec::IndexFilter *filter) { + cached_doc_id_ = conjunction_->next_doc(filter); + return cached_doc_id_; +} + +uint32_t PhraseDocIterator::advance(uint32_t target) { + cached_doc_id_ = conjunction_->advance(target); + return cached_doc_id_; +} + +bool PhraseDocIterator::matches() { + if (cached_doc_id_ == NO_MORE_DOCS) { + return false; + } + // Phase 2: verify position adjacency (deferred IO) + return verify_phrase_positions(cached_doc_id_); +} + +float PhraseDocIterator::score() { + return conjunction_->score(); +} + +uint64_t PhraseDocIterator::cost() const { + return conjunction_->cost(); +} + +float PhraseDocIterator::max_score() const { + return conjunction_->max_score(); +} + +bool PhraseDocIterator::verify_phrase_positions(uint32_t doc_id) const { + const size_t n = terms_.size(); + if (n == 0) { + return false; + } + + // Deduplicate terms within the phrase. Repeated terms (e.g., "to be or not + // to be") collapse into one $POS lookup; term_to_unique_idx maps each phrase + // position back to its slot in the unique list. + std::vector term_to_unique_idx(n); + std::vector unique_to_first_term_idx; + unique_to_first_term_idx.reserve(n); + std::unordered_map seen; + seen.reserve(n); + for (size_t i = 0; i < n; ++i) { + const size_t next_idx = unique_to_first_term_idx.size(); + auto [it, inserted] = seen.try_emplace(terms_[i], next_idx); + if (inserted) { + unique_to_first_term_idx.push_back(i); + } + term_to_unique_idx[i] = it->second; + } + const size_t unique_size = unique_to_first_term_idx.size(); + + // Build unique (term, doc_id) keys into a single reusable buffer; reserve + // up-front so the buffer never reallocates and the Slice pointers below stay + // valid until the MultiGet returns. + size_t total_key_bytes = 0; + for (size_t u = 0; u < unique_size; ++u) { + total_key_bytes += + terms_[unique_to_first_term_idx[u]].size() + 1 + sizeof(uint32_t); + } + std::string key_buffer; + key_buffer.reserve(total_key_bytes); + + std::vector key_slices; + key_slices.reserve(unique_size); + for (size_t u = 0; u < unique_size; ++u) { + const std::string &term = terms_[unique_to_first_term_idx[u]]; + const size_t offset = key_buffer.size(); + const size_t bytes = fts::append_doc_term_key(term, doc_id, &key_buffer); + key_slices.emplace_back(key_buffer.data() + offset, bytes); + } + + // Batched read across unique (term, doc_id) keys — single MultiGet instead + // of per-anchor-position Gets. + std::vector cfs(unique_size, positions_cf_); + std::vector values(unique_size); + std::vector statuses(unique_size); + ctx_->db_->MultiGet(ctx_->read_opts_, unique_size, cfs.data(), + key_slices.data(), values.data(), statuses.data()); + + // Decode every position list once. A missing entry means this doc cannot + // be a phrase match — this happens for docs filtered through the conjunction + // without a position-CF entry, so we do NOT log here. + std::vector> positions_cache(unique_size); + for (size_t u = 0; u < unique_size; ++u) { + if (!statuses[u].ok() || values[u].size() == 0) { + return false; + } + positions_cache[u] = decode_positions(values[u]); + if (positions_cache[u].empty()) { + return false; + } + } + + // Pick the term with the shortest position list as anchor so the outer + // loop iterates as few candidates as possible. anchor_term_idx stays in + // original phrase order — the phrase start equals anchor_pos - + // anchor_term_idx. + size_t anchor_term_idx = 0; + size_t min_size = positions_cache[term_to_unique_idx[0]].size(); + for (size_t i = 1; i < n; ++i) { + const size_t sz = positions_cache[term_to_unique_idx[i]].size(); + if (sz < min_size) { + min_size = sz; + anchor_term_idx = i; + } + } + + const auto &anchor_positions = + positions_cache[term_to_unique_idx[anchor_term_idx]]; + const uint32_t anchor_offset = static_cast(anchor_term_idx); + for (uint32_t anchor_pos : anchor_positions) { + if (anchor_pos < anchor_offset) { + // phrase start would be negative — impossible + continue; + } + const uint32_t start = anchor_pos - anchor_offset; + bool phrase_matched = true; + for (size_t i = 0; i < n; ++i) { + if (i == anchor_term_idx) { + continue; + } + const uint32_t expected = start + static_cast(i); + const auto &positions = positions_cache[term_to_unique_idx[i]]; + if (!std::binary_search(positions.begin(), positions.end(), expected)) { + phrase_matched = false; + break; + } + } + if (phrase_matched) { + return true; + } + } + + return false; +} + +std::vector PhraseDocIterator::decode_positions( + const rocksdb::Slice &data) { + std::vector positions; + size_t index = 0; + uint32_t current_position = 0; + const char *bytes = data.data(); + const size_t size = data.size(); + + while (index < size) { + // Decode varint + uint32_t delta = 0; + uint32_t shift = 0; + while (index < size) { + const uint8_t byte = static_cast(bytes[index++]); + delta |= static_cast(byte & 0x7F) << shift; + shift += 7; + if ((byte & 0x80) == 0) { + break; + } + } + current_position += delta; + positions.push_back(current_position); + } + + return positions; +} + +} // namespace zvec::fts diff --git a/src/db/index/column/fts_column/iterator/fts_phrase_iterator.h b/src/db/index/column/fts_column/iterator/fts_phrase_iterator.h new file mode 100644 index 000000000..c8245a74c --- /dev/null +++ b/src/db/index/column/fts_column/iterator/fts_phrase_iterator.h @@ -0,0 +1,75 @@ +// Copyright 2025-present the zvec project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include +#include +#include "db/common/rocksdb_context.h" +#include "fts_conjunction_iterator.h" +#include "fts_doc_iterator.h" +#include "../bm25_scorer.h" + +namespace zvec::fts { + +/*! Phrase document iterator (two-phase) + * + * Internally wraps a ConjunctionIterator for phase-1 doc_id intersection. + * Phase-2 matches() reads position payloads and checks adjacency. + */ +class PhraseDocIterator : public DocIterator { + public: + /*! Construct a phrase iterator. + * \param conjunction ConjunctionIterator over all terms in the phrase + * \param terms Processed (tokenized) term strings in phrase order + * \param positions_cf $POS column family for reading position lists + */ + PhraseDocIterator(DocIteratorPtr conjunction, std::vector terms, + RocksdbContext *ctx, + rocksdb::ColumnFamilyHandle *positions_cf); + + uint32_t next_doc() override; + //! Internal-driven filter skip: delegates to the inner conjunction so the + //! expensive phase-2 verify_phrase_positions() ($POS CF reads) is never + //! run on filtered docs. + uint32_t next_doc(const zvec::IndexFilter *filter) override; + uint32_t advance(uint32_t target) override; + + //! Phase-2: verify position adjacency for the current document. + //! Reads position lists from $POS CF (deferred IO). + bool matches() override; + + float score() override; + uint64_t cost() const override; + float max_score() const override; + + private: + // Verify that terms appear at consecutive positions in the document. + // Issues a single MultiGet across the unique terms in the phrase, decodes + // every position list once, then validates adjacency entirely in memory. + bool verify_phrase_positions(uint32_t doc_id) const; + + // Decode varint delta-encoded position list out of a RocksDB value slice. + static std::vector decode_positions(const rocksdb::Slice &data); + + private: + DocIteratorPtr conjunction_; + std::vector terms_; + RocksdbContext *ctx_; + rocksdb::ColumnFamilyHandle *positions_cf_; +}; + +} // namespace zvec::fts diff --git a/src/db/index/column/fts_column/iterator/fts_term_iterator.cc b/src/db/index/column/fts_column/iterator/fts_term_iterator.cc new file mode 100644 index 000000000..5d47cd7f0 --- /dev/null +++ b/src/db/index/column/fts_column/iterator/fts_term_iterator.cc @@ -0,0 +1,206 @@ +// Copyright 2025-present the zvec project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "fts_term_iterator.h" +#include +#include +#include +#include "../fts_utils.h" + +namespace zvec::fts { + +// ============================================================ +// Constructors +// ============================================================ + +// Roaring Bitmap mode — takes ownership of bitmap, iterates lazily. +TermDocIterator::TermDocIterator(std::string term, roaring_bitmap_t *bitmap, + uint64_t df, BM25ScorerPtr scorer, + float max_score_val, RocksdbContext *ctx, + rocksdb::ColumnFamilyHandle *term_freq_cf, + rocksdb::ColumnFamilyHandle *doc_len_cf, + std::atomic *cf_counter, float boost) + : mode_(Mode::ROARING), + term_(std::move(term)), + df_(df), + scorer_(std::move(scorer)), + max_score_val_(max_score_val * boost), + boost_(boost), + bitmap_(bitmap), + ctx_(ctx), + term_freq_cf_(term_freq_cf), + doc_len_cf_(doc_len_cf), + cf_counter_(cf_counter) { + roaring_init_iterator(bitmap_, &roaring_iter_); + cached_max_score_ = max_score_val_; + idf_weight_ = scorer_->idf(df_); +} + +TermDocIterator::~TermDocIterator() { + if (bitmap_) { + roaring_bitmap_free(bitmap_); + bitmap_ = nullptr; + } + if (cf_counter_) { + --*cf_counter_; + } +} + +// BitPacked mode +TermDocIterator::TermDocIterator(std::string term, + rocksdb::PinnableSlice packed_data, + BM25ScorerPtr scorer, float boost) + : mode_(Mode::BITPACKED), + term_(std::move(term)), + scorer_(std::move(scorer)), + boost_(boost), + packed_data_(std::move(packed_data)) { + // Failure here means the term will produce no docs (next_doc returns + // NO_MORE_DOCS). bp_iter_.open() already logs the underlying parse error; + // surface it once more here with the term context for easier triage. + if (bp_iter_.open(packed_data_.data(), packed_data_.size()) != 0) { + LOG_ERROR( + "TermDocIterator: failed to open bitpacked posting for term[%s], " + "iterator will yield no documents", + term_.c_str()); + } + df_ = bp_iter_.cost(); + // Apply boost to max_score_val_ so that DisjunctionIterator's WAND pivot + // computation matches the actual scores returned by score() below. + max_score_val_ = bp_iter_.max_score() * boost_; + cached_max_score_ = max_score_val_; + idf_weight_ = scorer_->idf(df_); +} + +// ============================================================ +// Iterator interface +// ============================================================ + +uint32_t TermDocIterator::next_doc() { + if (mode_ == Mode::BITPACKED) { + cached_doc_id_ = bp_iter_.next_doc(); + return cached_doc_id_; + } + + // Roaring mode: stream via roaring_uint32_iterator_t + if (!roaring_iter_started_) { + // First call: iterator already points at the first element after + // roaring_init_iterator in the constructor. + roaring_iter_started_ = true; + } else { + roaring_advance_uint32_iterator(&roaring_iter_); + } + if (!roaring_iter_.has_value) { + cached_doc_id_ = NO_MORE_DOCS; + return NO_MORE_DOCS; + } + cached_doc_id_ = roaring_iter_.current_value; + return cached_doc_id_; +} + +uint32_t TermDocIterator::advance(uint32_t target) { + if (mode_ == Mode::BITPACKED) { + cached_doc_id_ = bp_iter_.advance(target); + return cached_doc_id_; + } + + // Roaring mode: skip to the first doc_id >= target + roaring_iter_started_ = true; + if (!roaring_move_uint32_iterator_equalorlarger(&roaring_iter_, target)) { + cached_doc_id_ = NO_MORE_DOCS; + return NO_MORE_DOCS; + } + cached_doc_id_ = roaring_iter_.current_value; + return cached_doc_id_; +} + +float TermDocIterator::score() { + if (cached_doc_id_ == NO_MORE_DOCS) { + return 0.0f; + } + + if (mode_ == Mode::BITPACKED) { + // Fast path: read tf/doc_len from inline payload (zero I/O) + const uint32_t tf = bp_iter_.term_freq(); + const uint32_t dl = bp_iter_.doc_len(); + return scorer_->score_with_idf(idf_weight_, tf, dl, boost_); + } + + // Roaring mode: read from RocksDB + const uint32_t tf = read_term_freq(cached_doc_id_); + const uint32_t doc_len = read_doc_len(cached_doc_id_); + return scorer_->score_with_idf(idf_weight_, tf, doc_len, boost_); +} + +uint64_t TermDocIterator::cost() const { + if (mode_ == Mode::BITPACKED) { + return bp_iter_.cost(); + } + return df_; +} + +// ============================================================ +// Block-Max WAND support +// ============================================================ + +DocIterator::BlockMaxInfo TermDocIterator::block_max_info_for( + uint32_t target) const { + if (mode_ == Mode::BITPACKED) { + auto info = bp_iter_.block_max_info_for(target); + // Apply boost so the upper bound matches score() (which multiplies by + // boost_) and stays consistent with max_score_val_ for WAND pivoting. + return {info.block_max_score * boost_, info.block_last_doc}; + } + // Roaring mode: fall back to global max_score (already boosted in ctor), + // no block structure available. + return {max_score_val_, NO_MORE_DOCS}; +} + +// ============================================================ +// Roaring mode helpers +// ============================================================ + +uint32_t TermDocIterator::read_term_freq(uint32_t doc_id) const { + if (!term_freq_cf_) { + return 1; // CF dropped after convert_postings_to_bitpacked + } + const std::string key = fts::make_doc_term_key(term_, doc_id); + std::string value; + if (!ctx_->db_->Get(ctx_->read_opts_, term_freq_cf_, key, &value).ok() || + value.size() < sizeof(uint32_t)) { + return 1; // Default term frequency is 1 + } + uint32_t tf = 0; + std::memcpy(&tf, value.data(), sizeof(uint32_t)); + return tf; +} + +uint32_t TermDocIterator::read_doc_len(uint32_t doc_id) const { + if (!doc_len_cf_) { + return 1; // CF dropped after convert_postings_to_bitpacked + } + std::string doc_id_key(sizeof(uint32_t), '\0'); + std::memcpy(doc_id_key.data(), &doc_id, sizeof(uint32_t)); + + std::string value; + if (!ctx_->db_->Get(ctx_->read_opts_, doc_len_cf_, doc_id_key, &value).ok() || + value.size() < sizeof(uint32_t)) { + return 1; // Default document length is 1 + } + uint32_t doc_len = 0; + std::memcpy(&doc_len, value.data(), sizeof(uint32_t)); + return doc_len; +} + +} // namespace zvec::fts diff --git a/src/db/index/column/fts_column/iterator/fts_term_iterator.h b/src/db/index/column/fts_column/iterator/fts_term_iterator.h new file mode 100644 index 000000000..1d3d6b427 --- /dev/null +++ b/src/db/index/column/fts_column/iterator/fts_term_iterator.h @@ -0,0 +1,134 @@ +// Copyright 2025-present the zvec project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include +#include +#include "db/common/rocksdb_context.h" +#include "fts_doc_iterator.h" +#include "../bm25_scorer.h" +#include "../posting/bitpacked_posting_list.h" + +namespace zvec::fts { + +/*! Term document iterator + * Supports two internal modes: + * 1. Roaring mode: sorted doc_id array + RocksDB Get for tf/doc_len + * 2. BitPacked mode: inline payloads, zero RocksDB I/O for score() + */ +class TermDocIterator : public DocIterator { + public: + /*! Roaring Bitmap mode constructor. + * Takes ownership of the bitmap and iterates lazily via + * roaring_uint32_iterator_t — no N×4-byte doc_id array is materialised. + * + * \param term Processed (tokenized) term string + * \param bitmap Deserialized Roaring bitmap (ownership transferred) + * \param df Document frequency of this term in the segment + * \param scorer BM25 scorer (with segment stats loaded) + * \param max_score_val Precomputed WAND upper bound score for this term + * (caller must NOT pre-multiply by boost — the + * constructor applies boost to both score() output + * and max_score_val_ to keep WAND pivot correct) + * \param term_freq_cf $TF column family for reading per-doc term freq + * \param doc_len_cf $DOC_LEN column family for reading doc length + * \param cf_counter CF reference counter for term_freq_cf and doc_len_cf + * \param boost Per-term boost (1.0 = no boost) + */ + TermDocIterator(std::string term, roaring_bitmap_t *bitmap, uint64_t df, + BM25ScorerPtr scorer, float max_score_val, + RocksdbContext *ctx, + rocksdb::ColumnFamilyHandle *term_freq_cf, + rocksdb::ColumnFamilyHandle *doc_len_cf, + std::atomic *cf_counter, float boost = 1.0f); + + ~TermDocIterator() override; + + /*! BitPacked mode constructor. + * All payloads (tf, doc_len, per-block max_score, global max_score) are + * embedded inline in packed_data, so this iterator is completely + * self-contained on the read path: + * - score() reads tf/doc_len from bp_iter_ — zero RocksDB I/O. + * - block_max_info_for() / max_score() all read from the BitPacked + * skip-list / block headers — no $MAX_TF lookup needed. + * Construction takes neither $TF, $DOC_LEN, nor $MAX_TF column families: + * the immutable segment SST may have these CFs entirely empty (cleared + * by FtsColumnIndexer::convert_postings_to_bitpacked at dump time) and + * this iterator still works correctly. + * + * df and max_score are read from bp_iter_ after open(); on open failure + * cost() returns 0 and callers should treat the iterator as empty. + * + * \param term Processed (tokenized) term string + * \param packed_data Serialized BitPacked posting list (ownership taken) + * \param scorer BM25 scorer (with segment stats loaded) + * \param boost Per-term boost (1.0 = no boost) + */ + TermDocIterator(std::string term, rocksdb::PinnableSlice packed_data, + BM25ScorerPtr scorer, float boost = 1.0f); + + // Prevent move/copy: bp_iter_ holds a raw pointer into packed_data_'s + // buffer, so moving would create a dangling pointer. + TermDocIterator(const TermDocIterator &) = delete; + TermDocIterator &operator=(const TermDocIterator &) = delete; + TermDocIterator(TermDocIterator &&) = delete; + TermDocIterator &operator=(TermDocIterator &&) = delete; + + uint32_t next_doc() override; + uint32_t advance(uint32_t target) override; + float score() override; + uint64_t cost() const override; + float max_score() const override { + return max_score_val_; + } + + // Block-Max WAND support (only effective in BitPacked mode) + BlockMaxInfo block_max_info_for(uint32_t target) const override; + + private: + // Read term frequency for the current document (Roaring mode only) + uint32_t read_term_freq(uint32_t doc_id) const; + + // Read document length for the current document (Roaring mode only) + uint32_t read_doc_len(uint32_t doc_id) const; + + private: + enum class Mode { ROARING, BITPACKED }; + Mode mode_; + + std::string term_; + uint64_t df_; + BM25ScorerPtr scorer_; + float max_score_val_; + float idf_weight_{0.0f}; // Pre-computed IDF to avoid log() per score() + float boost_{1.0f}; // Per-term boost (collapsed from repeated terms) + + // Roaring mode state (owns the bitmap; iterator is stack-allocated) + roaring_bitmap_t *bitmap_{nullptr}; + roaring_uint32_iterator_t roaring_iter_{}; + bool roaring_iter_started_{false}; // tracks whether first next_doc called + RocksdbContext *ctx_{nullptr}; + rocksdb::ColumnFamilyHandle *term_freq_cf_{nullptr}; + rocksdb::ColumnFamilyHandle *doc_len_cf_{nullptr}; + std::atomic *cf_counter_{nullptr}; + + // BitPacked mode state + rocksdb::PinnableSlice packed_data_; // owns the serialized data (zero-copy) + BitPackedPostingIterator bp_iter_; // zero-copy iterator over packed_data_ +}; + +} // namespace zvec::fts diff --git a/src/db/index/column/fts_column/parser/fts_query_parser.cc b/src/db/index/column/fts_column/parser/fts_query_parser.cc new file mode 100644 index 000000000..5a7fdccc4 --- /dev/null +++ b/src/db/index/column/fts_column/parser/fts_query_parser.cc @@ -0,0 +1,398 @@ +// Copyright 2025-present the zvec project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "fts_query_parser.h" +#include +#include "db/index/column/fts_column/gen/FtsLexer.h" +#include "db/index/column/fts_column/gen/FtsParser.h" +#include "antlr4-runtime.h" + +using namespace antlr4; + +namespace zvec::fts { + +// ============================================================ +// Error listener that captures the first error message +// ============================================================ + +class FtsErrorListener : public BaseErrorListener { + public: + void syntaxError(Recognizer * /*recognizer*/, + antlr4::Token * /*offending_symbol*/, size_t line, + size_t char_position_in_line, const std::string &msg, + std::exception_ptr /*exception*/) override { + if (err_msg_.empty()) { + err_msg_ = ailego::StringHelper::Concat( + "[", line, " ", char_position_in_line, " ", msg, "]"); + } + } + + const std::string &err_msg() const { + return err_msg_; + } + + private: + std::string err_msg_; +}; + +// ============================================================ +// AST builder helpers (anonymous namespace) +// ============================================================ + +namespace { + +// Forward declaration +FtsAstNodePtr build_fts_or_expr(FtsParser::Fts_or_exprContext *or_ctx, + const TokenizerPipeline &pipeline, + FtsDefaultOperator default_op, + std::string *err_msg); + +// Strip surrounding single or double quotes from a quoted string token. +std::string strip_quotes(const std::string "ed) { + if (quoted.size() >= 2 && + ((quoted.front() == '\'' && quoted.back() == '\'') || + (quoted.front() == '"' && quoted.back() == '"'))) { + return quoted.substr(1, quoted.size() - 2); + } + return quoted; +} + +// Propagate must/must_not modifier to the root of an already-built AST node. +// Now that must/must_not live on the FtsAstNode base class, this works +// uniformly for terms, phrases and composite (AND/OR) sub-expressions. +// OR-merge with any existing flags so a second application on the same +// node never silently clears modifiers set by a prior pass. +void apply_modifier(FtsAstNode *node, bool is_must, bool is_must_not) { + if (!node || (!is_must && !is_must_not)) { + return; + } + node->must = node->must || is_must; + node->must_not = node->must_not || is_must_not; +} + +// atom: fts_field_prefix? fts_primary fts_boost? +// +// fts_field_prefix (e.g. "title:") and fts_boost (e.g. "^2") are parsed by +// the grammar but not supported at query execution time — return an error. +// +// fts_primary: fts_term | fts_phrase | LP fts_or_expr RP +FtsAstNodePtr build_fts_atom(FtsParser::Fts_atomContext *atom_ctx, bool is_must, + bool is_must_not, + const TokenizerPipeline &pipeline, + FtsDefaultOperator default_op, + std::string *err_msg) { + // Reject field-prefixed queries (e.g. "title:cancer") + if (atom_ctx->fts_field_prefix() != nullptr) { + if (err_msg) { + *err_msg = "field-prefixed queries are not supported"; + } + return nullptr; + } + + // Reject boosted queries (e.g. "term^2") + if (atom_ctx->fts_boost() != nullptr) { + if (err_msg) { + *err_msg = "boost queries are not supported"; + } + return nullptr; + } + + FtsParser::Fts_primaryContext *primary_ctx = atom_ctx->fts_primary(); + if (primary_ctx == nullptr) { + return nullptr; + } + + if (primary_ctx->fts_term() != nullptr) { + std::string term_text = primary_ctx->fts_term()->getText(); + auto tokens = pipeline.process(term_text); + if (tokens.empty()) { + // Term filtered out (e.g. stop-word, pure punctuation). Returning + // nullptr here lets the seq/and/or builders skip this child. + return nullptr; + } + if (tokens.size() == 1) { + return std::make_unique(std::move(tokens[0].text), is_must, + is_must_not); + } + // Multi-token bare term: combine via the configured default operator and + // attach must/must_not on the composite root. + FtsAstNodePtr composite; + if (default_op == FtsDefaultOperator::AND) { + auto and_node = std::make_unique(); + and_node->children.reserve(tokens.size()); + for (auto &t : tokens) { + and_node->children.push_back( + std::make_unique(std::move(t.text))); + } + composite = std::move(and_node); + } else { + auto or_node = std::make_unique(); + or_node->children.reserve(tokens.size()); + for (auto &t : tokens) { + or_node->children.push_back( + std::make_unique(std::move(t.text))); + } + composite = std::move(or_node); + } + apply_modifier(composite.get(), is_must, is_must_not); + return composite; + } + + if (primary_ctx->fts_phrase() != nullptr) { + std::string raw = primary_ctx->fts_phrase()->getText(); + std::string phrase_text = strip_quotes(raw); + auto tokens = pipeline.process(phrase_text); + auto phrase_node = std::make_unique(); + phrase_node->must = is_must; + phrase_node->must_not = is_must_not; + phrase_node->terms.reserve(tokens.size()); + for (auto &t : tokens) { + phrase_node->terms.push_back(std::move(t.text)); + } + return phrase_node; + } + + if (primary_ctx->fts_or_expr() != nullptr) { + // Parenthesised sub-expression — propagate default_op so that adjacent + // bare terms inside the parentheses share the same implicit semantics. + auto inner = build_fts_or_expr(primary_ctx->fts_or_expr(), pipeline, + default_op, err_msg); + apply_modifier(inner.get(), is_must, is_must_not); + return inner; + } + + return nullptr; +} + +// unary: (PLUS_SIGN | MINUS_SIGN)? atom +// NOT is no longer a unary modifier — it is handled as a binary operator in +// build_fts_and_expr. antlr4 generates separate subclasses for each labeled +// alternative. +FtsAstNodePtr build_fts_unary(FtsParser::Fts_unaryContext *unary_ctx, + const TokenizerPipeline &pipeline, + FtsDefaultOperator default_op, + std::string *err_msg) { + if (auto *must_ctx = dynamic_cast(unary_ctx)) { + return build_fts_atom(must_ctx->fts_atom(), /*is_must=*/true, + /*is_must_not=*/false, pipeline, default_op, err_msg); + } + if (auto *must_not_ctx = + dynamic_cast(unary_ctx)) { + return build_fts_atom(must_not_ctx->fts_atom(), /*is_must=*/false, + /*is_must_not=*/true, pipeline, default_op, err_msg); + } + // Plain_atomContext (no modifier) + if (auto *plain_ctx = + dynamic_cast(unary_ctx)) { + return build_fts_atom(plain_ctx->fts_atom(), /*is_must=*/false, + /*is_must_not=*/false, pipeline, default_op, err_msg); + } + return nullptr; +} + +// seqExpr: unary+ +// Adjacent terms use the implicit default operator passed in (OR or AND). +// This is the only place where FtsDefaultOperator actually changes the AST +// structure; all other build_* helpers simply propagate the value. +FtsAstNodePtr build_fts_seq_expr(FtsParser::Fts_seq_exprContext *seq_ctx, + const TokenizerPipeline &pipeline, + FtsDefaultOperator default_op, + std::string *err_msg) { + auto unary_list = seq_ctx->fts_unary(); + if (unary_list.size() == 1) { + return build_fts_unary(unary_list[0], pipeline, default_op, err_msg); + } + + // Parse all children first + std::vector children; + for (auto *unary_ctx : unary_list) { + auto child = build_fts_unary(unary_ctx, pipeline, default_op, err_msg); + if (!child) { + if (err_msg && !err_msg->empty()) { + return nullptr; + } + continue; + } + children.push_back(std::move(child)); + } + if (children.size() == 1) { + return std::move(children[0]); + } + + // Assign children to the appropriate node type + if (default_op == FtsDefaultOperator::AND) { + auto and_node = std::make_unique(); + and_node->children = std::move(children); + return and_node; + } + auto or_node = std::make_unique(); + or_node->children = std::move(children); + return or_node; +} + +// andExpr: seqExpr ((AND | NOT) seqExpr)* +// +// NOT shares the same precedence as AND. Each `NOT seqExpr` on the right of +// the operator marks the produced child as must_not, then the whole +// sub-expression collapses into a single AndNode. Example: +// `a NOT b` => And[a, b{must_not}] +// `a AND b NOT c` => And[a, b, c{must_not}] +FtsAstNodePtr build_fts_and_expr(FtsParser::Fts_and_exprContext *and_ctx, + const TokenizerPipeline &pipeline, + FtsDefaultOperator default_op, + std::string *err_msg) { + auto and_node = std::make_unique(); + bool next_is_not = false; + for (auto *raw : and_ctx->children) { + if (auto *term = dynamic_cast(raw)) { + const auto token_type = term->getSymbol()->getType(); + if (token_type == FtsParser::AND) { + next_is_not = false; + } else if (token_type == FtsParser::NOT) { + next_is_not = true; + } + continue; + } + auto *seq_ctx = dynamic_cast(raw); + if (seq_ctx == nullptr) { + continue; + } + auto child = build_fts_seq_expr(seq_ctx, pipeline, default_op, err_msg); + bool is_not_for_this_child = next_is_not; + next_is_not = false; + if (!child) { + if (err_msg && !err_msg->empty()) { + return nullptr; + } + continue; + } + if (is_not_for_this_child) { + apply_modifier(child.get(), /*is_must=*/false, /*is_must_not=*/true); + } + and_node->children.push_back(std::move(child)); + } + if (and_node->children.empty()) { + return nullptr; + } + if (and_node->children.size() == 1) { + return std::move(and_node->children[0]); + } + return and_node; +} + +// orExpr: andExpr (OR andExpr)* +FtsAstNodePtr build_fts_or_expr(FtsParser::Fts_or_exprContext *or_ctx, + const TokenizerPipeline &pipeline, + FtsDefaultOperator default_op, + std::string *err_msg) { + auto and_list = or_ctx->fts_and_expr(); + if (and_list.size() == 1) { + return build_fts_and_expr(and_list[0], pipeline, default_op, err_msg); + } + auto or_node = std::make_unique(); + for (auto *and_ctx : and_list) { + auto child = build_fts_and_expr(and_ctx, pipeline, default_op, err_msg); + if (!child) { + if (err_msg && !err_msg->empty()) { + return nullptr; + } + continue; + } + or_node->children.push_back(std::move(child)); + } + if (or_node->children.size() == 1) { + return std::move(or_node->children[0]); + } + return or_node; +} + +} // anonymous namespace + +// ============================================================ +// FtsQueryParser::parse() +// ============================================================ + +FtsAstNodePtr FtsQueryParser::parse(const std::string &query, + const TokenizerPipelinePtr &pipeline, + FtsDefaultOperator default_op) { + err_msg_.clear(); + if (!pipeline) { + err_msg_ = "fts parser: pipeline is required"; + return nullptr; + } + + try { + ANTLRInputStream input(query); + FtsLexer lexer(&input); + + FtsErrorListener lexer_error_listener; + lexer.removeErrorListeners(); + lexer.addErrorListener(&lexer_error_listener); + + CommonTokenStream tokens(&lexer); + + FtsParser parser(&tokens); + + FtsErrorListener parser_error_listener; + parser.removeErrorListeners(); + parser.addErrorListener(&parser_error_listener); + + // First attempt with SLL prediction mode (fast path) + parser.getInterpreter()->setPredictionMode( + atn::PredictionMode::SLL); + FtsParser::Fts_query_unitContext *tree = parser.fts_query_unit(); + + // Fall back to full LL mode if SLL produced errors + if (lexer.getNumberOfSyntaxErrors() > 0 || + parser.getNumberOfSyntaxErrors() > 0) { + tokens.reset(); + parser.reset(); + parser.getInterpreter()->setPredictionMode( + atn::PredictionMode::LL); + tree = parser.fts_query_unit(); + } + + if (lexer.getNumberOfSyntaxErrors() > 0) { + err_msg_ = "fts lexer error " + lexer_error_listener.err_msg(); + return nullptr; + } + if (parser.getNumberOfSyntaxErrors() > 0) { + err_msg_ = "fts syntax error " + parser_error_listener.err_msg(); + return nullptr; + } + + if (tree == nullptr || tree->fts_or_expr() == nullptr) { + err_msg_ = "fts parse error: empty or invalid query"; + return nullptr; + } + + auto result = build_fts_or_expr(tree->fts_or_expr(), *pipeline, default_op, + &err_msg_); + if (!result && !err_msg_.empty()) { + return nullptr; + } + if (!result) { + // Grammar valid but analyzer dropped every term: return EmptyNode so + // callers don't have to treat zero-doc queries as parse errors. + return std::make_unique(); + } + return result; + + } catch (const std::exception &exception) { + err_msg_ = "fts parse exception: " + std::string(exception.what()); + return nullptr; + } +} + +} // namespace zvec::fts diff --git a/src/db/index/column/fts_column/parser/fts_query_parser.h b/src/db/index/column/fts_column/parser/fts_query_parser.h new file mode 100644 index 000000000..fb1ff9ef6 --- /dev/null +++ b/src/db/index/column/fts_column/parser/fts_query_parser.h @@ -0,0 +1,67 @@ +// Copyright 2025-present the zvec project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include "db/index/column/fts_column/fts_query_ast.h" +#include "db/index/column/fts_column/tokenizer/tokenizer_factory.h" + +namespace zvec::fts { + +/*! Default boolean operator applied to adjacent bare terms that are not + * separated by an explicit operator (AND / OR / + / -). + * This is equivalent to Lucene/Elasticsearch's `default_operator` semantics. + */ +enum class FtsDefaultOperator { + OR, // Adjacent bare terms are combined with OR (historical default). + AND, // Adjacent bare terms are combined with AND. +}; + +/*! FTS query parser + * Thread-compatible but not thread-safe: create one instance per parse call + * or protect with a mutex. + */ +class FtsQueryParser { + public: + FtsQueryParser() = default; + + /*! Parse an FTS query expression string into an AST. + * \param query Query string, e.g. '+vector -slow "exact phrase" 中文 + * AND 分词' + * \param pipeline Tokenizer pipeline used to tokenize phrase contents + * and bare terms so that query-side segmentation + * matches the doc-side index. Must be non-null. + * \param default_op Default operator for adjacent bare terms with no + * explicit operator. Defaults to OR for backward + * compatibility. Does not change the semantics of + * explicit AND / OR / + / - usages. + * \return Root AST node, or nullptr on parse failure. Call err_msg() to + * retrieve the error description. + */ + FtsAstNodePtr parse(const std::string &query, + const TokenizerPipelinePtr &pipeline, + FtsDefaultOperator default_op = FtsDefaultOperator::OR); + + /*! Return the error message from the most recent failed parse() call. */ + const std::string &err_msg() const { + return err_msg_; + } + + private: + std::string err_msg_; +}; + +} // namespace zvec::fts diff --git a/src/db/index/column/fts_column/posting/bitpacked_posting_list.cc b/src/db/index/column/fts_column/posting/bitpacked_posting_list.cc new file mode 100644 index 000000000..c085681cc --- /dev/null +++ b/src/db/index/column/fts_column/posting/bitpacked_posting_list.cc @@ -0,0 +1,704 @@ +// Copyright 2025-present the zvec project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "bitpacked_posting_list.h" +#include +#include +#include +#include +#include "bitpacked_simd_dispatch.h" + +#ifdef _MSC_VER +#include +#include +#endif + +namespace zvec::fts { + +// ============================================================ +// BitPacked Posting List on-disk format +// ============================================================ +// +// Encodes doc_id deltas, term frequencies, and document lengths using +// per-block bitpacking. Each block stores up to 128 entries and carries +// a precomputed BM25 score upper bound to support Block-Max WAND pruning. +// +// File layout: +// [Header 16B] [SkipList N*12B] [Block0] [Block1] ... +// +// Block layout: +// [BlockHeader 12B] [packed_deltas] [packed_tfs] [packed_dlens] + +namespace { + +/// Round up \p value to the next multiple of \p alignment. +constexpr size_t align_up(size_t value, size_t alignment) { + return (value + alignment - 1) & ~(alignment - 1); +} + +/// Allocate 16-byte-aligned memory for \p count uint32_t values, returned as +/// a unique_ptr with a custom deleter. +inline auto make_aligned_uint32_array(size_t count) { + const size_t num_bytes = align_up(count * sizeof(uint32_t), 16); +#ifdef _MSC_VER + auto *ptr = static_cast(_aligned_malloc(num_bytes, 16)); + return std::unique_ptr(ptr, + _aligned_free); +#else + auto *ptr = static_cast(std::aligned_alloc(16, num_bytes)); + return std::unique_ptr(ptr, std::free); +#endif +} + +} // namespace + +// ============================================================ +// Low-level bitpacking primitives +// ============================================================ + +uint8_t BitPackedPostingList::bits_needed(uint32_t max_value) { + if (max_value == 0) return 0; +#ifdef _MSC_VER + unsigned long index = 0; + _BitScanReverse(&index, max_value); + return static_cast(index + 1); +#else + return static_cast(32 - __builtin_clz(max_value)); +#endif +} + +void BitPackedPostingList::pack_uint32(const uint32_t *in, uint8_t bitwidth, + uint32_t count, uint8_t *out) { + if (bitwidth == 0 || count == 0) return; + + // Full block path: 128 values at once via dispatch (SIMD or scalar) + if (count == DOCS_PER_BLOCK) { + simd::get_dispatch().pack_uint32_128(in, bitwidth, out); + return; + } + + // Tail block path (count < 128): use scalar fastpack, 32 at a time + const size_t total_bytes = packed_byte_size(bitwidth, count); + std::memset(out, 0, total_bytes); + + uint32_t *out32 = reinterpret_cast(out); + uint32_t offset = 0; + + while (offset + 32 <= count) { + FastPForLib::fastpackwithoutmask(in + offset, out32, bitwidth); + out32 += bitwidth; + offset += 32; + } + + // Tail: fewer than 32 integers + if (offset < count) { + alignas(16) uint32_t padded_in[32] = {}; + std::memcpy(padded_in, in + offset, (count - offset) * sizeof(uint32_t)); + alignas(16) uint32_t padded_out[32] = {}; + FastPForLib::fastpackwithoutmask(padded_in, padded_out, bitwidth); + size_t tail_bytes = packed_byte_size(bitwidth, count - offset); + std::memcpy(out32, padded_out, tail_bytes); + } +} + +void BitPackedPostingList::unpack_uint32(const uint8_t *in, uint8_t bitwidth, + uint32_t count, uint32_t *out) { + if (bitwidth == 0 || count == 0) { + for (uint32_t i = 0; i < count; ++i) { + out[i] = 0; + } + return; + } + + // Full block path: 128 values at once via dispatch (SIMD or scalar) + if (count == DOCS_PER_BLOCK) { + simd::get_dispatch().unpack_uint32_128(in, bitwidth, out); + return; + } + + // Tail block path (count < 128): use scalar fastunpack, 32 at a time + const uint32_t *in32 = reinterpret_cast(in); + uint32_t offset = 0; + + while (offset + 32 <= count) { + FastPForLib::fastunpack(in32, out + offset, bitwidth); + in32 += bitwidth; + offset += 32; + } + + // Tail: fewer than 32 integers + if (offset < count) { + const size_t tail_bytes = packed_byte_size(bitwidth, count - offset); + alignas(16) uint32_t padded_in[32] = {}; + std::memcpy(padded_in, in32, tail_bytes); + alignas(16) uint32_t padded_out[32] = {}; + FastPForLib::fastunpack(padded_in, padded_out, bitwidth); + std::memcpy(out + offset, padded_out, (count - offset) * sizeof(uint32_t)); + } +} + +// ============================================================ +// Encoder +// ============================================================ + +std::string BitPackedPostingList::encode(const uint32_t *doc_ids, + const uint32_t *tfs, + const uint32_t *doc_lens, size_t count, + uint64_t df, + const BM25Scorer &scorer) { + if (count == 0) { + // Encode an empty posting list (just the header) + Header hdr{}; + hdr.magic = MAGIC; + hdr.version = VERSION; + hdr.num_docs = 0; + hdr.num_blocks = 0; + std::string result(HEADER_SIZE, '\0'); + std::memcpy(result.data(), &hdr, HEADER_SIZE); + return result; + } + + const uint32_t num_blocks = + static_cast((count + DOCS_PER_BLOCK - 1) / DOCS_PER_BLOCK); + + // ---- Phase 1: Compute delta-encoded doc_ids ---- + // Use 16-byte-aligned allocation so SIMD pack/max paths can use aligned loads + auto deltas = make_aligned_uint32_array(count); + deltas[0] = doc_ids[0]; + for (size_t i = 1; i < count; ++i) { + deltas[i] = doc_ids[i] - doc_ids[i - 1]; + } + + // ---- Phase 2: Compute per-block metadata and packed sizes ---- + struct BlockInfo { + size_t start; // index into the arrays + uint32_t num_docs; // number of docs in this block + uint8_t bw_id; // bitwidth for doc_id deltas + uint8_t bw_tf; // bitwidth for tfs + uint8_t bw_dl; // bitwidth for doc_lens + float max_score; // block max BM25 score + size_t packed_size; // total packed data size for this block + }; + + std::vector blocks(num_blocks); + + for (uint32_t b = 0; b < num_blocks; ++b) { + const size_t start = static_cast(b) * DOCS_PER_BLOCK; + const uint32_t num_docs = static_cast( + std::min(static_cast(DOCS_PER_BLOCK), count - start)); + + // Find max values in block for bitwidth computation + uint32_t max_delta = 0, max_tf = 0, max_dl = 0; + float block_max = 0.0f; + + if (num_docs == DOCS_PER_BLOCK) { + // Dispatch max for full blocks (SSE4.1 or scalar fallback) + simd::get_dispatch().max_128(deltas.get(), tfs, doc_lens, start, + DOCS_PER_BLOCK, max_delta, max_tf, max_dl); + // block_max_score still needs scalar loop (float BM25 scoring) + for (uint32_t i = 0; i < DOCS_PER_BLOCK; ++i) { + float s = scorer.score(df, tfs[start + i], doc_lens[start + i]); + block_max = std::max(block_max, s); + } + } else { + // Scalar path for tail blocks + for (uint32_t i = 0; i < num_docs; ++i) { + max_delta = std::max(max_delta, deltas[start + i]); + max_tf = std::max(max_tf, tfs[start + i]); + max_dl = std::max(max_dl, doc_lens[start + i]); + float s = scorer.score(df, tfs[start + i], doc_lens[start + i]); + block_max = std::max(block_max, s); + } + } + + blocks[b].start = start; + blocks[b].num_docs = num_docs; + blocks[b].bw_id = bits_needed(max_delta); + blocks[b].bw_tf = bits_needed(max_tf); + blocks[b].bw_dl = bits_needed(max_dl); + blocks[b].max_score = block_max; + // Full block (128 values): use SIMD packed size; tail block: use scalar + if (num_docs == DOCS_PER_BLOCK) { + blocks[b].packed_size = simd_packed_byte_size(blocks[b].bw_id) + + simd_packed_byte_size(blocks[b].bw_tf) + + simd_packed_byte_size(blocks[b].bw_dl); + } else { + blocks[b].packed_size = packed_byte_size(blocks[b].bw_id, num_docs) + + packed_byte_size(blocks[b].bw_tf, num_docs) + + packed_byte_size(blocks[b].bw_dl, num_docs); + } + } + + // ---- Phase 3: Compute total size and block offsets ---- + const size_t skip_list_size = num_blocks * sizeof(BlockMeta); + const size_t block_header_size = sizeof(BlockHeader); + + // Compute block offsets, aligning each block start to a 16-byte boundary + // so that SIMD decode paths can use aligned loads on the packed data. + size_t current_offset = align_up(HEADER_SIZE + skip_list_size, 16); + std::vector block_offsets(num_blocks); + for (uint32_t b = 0; b < num_blocks; ++b) { + block_offsets[b] = static_cast(current_offset); + current_offset = align_up( + current_offset + block_header_size + blocks[b].packed_size, 16); + } + + const size_t total_size = current_offset; + + // ---- Phase 4: Serialize ---- + std::string result(total_size, '\0'); + char *buf = result.data(); + + // File Header + Header hdr{}; + hdr.magic = MAGIC; + hdr.version = VERSION; + hdr.num_docs = static_cast(count); + hdr.num_blocks = num_blocks; + std::memcpy(buf, &hdr, HEADER_SIZE); + + // Skip List + BlockMeta *skip = reinterpret_cast(buf + HEADER_SIZE); + for (uint32_t b = 0; b < num_blocks; ++b) { + const size_t last_idx = blocks[b].start + blocks[b].num_docs - 1; + skip[b].max_doc_id = doc_ids[last_idx]; + skip[b].block_offset = block_offsets[b]; + skip[b].block_max_score = blocks[b].max_score; + } + + // Blocks + for (uint32_t b = 0; b < num_blocks; ++b) { + char *block_ptr = buf + block_offsets[b]; + + // Block Header + BlockHeader bhdr{}; + bhdr.min_doc_id = doc_ids[blocks[b].start]; + bhdr.bitwidth_id = blocks[b].bw_id; + bhdr.bitwidth_tf = blocks[b].bw_tf; + bhdr.bitwidth_dl = blocks[b].bw_dl; + bhdr.num_docs = static_cast(blocks[b].num_docs); + bhdr.block_max_score = blocks[b].max_score; + std::memcpy(block_ptr, &bhdr, sizeof(BlockHeader)); + + uint8_t *packed_ptr = + reinterpret_cast(block_ptr + sizeof(BlockHeader)); + + const bool is_full_block = (blocks[b].num_docs == DOCS_PER_BLOCK); + + // Pack doc_id deltas + const size_t id_bytes = + is_full_block ? simd_packed_byte_size(blocks[b].bw_id) + : packed_byte_size(blocks[b].bw_id, blocks[b].num_docs); + pack_uint32(&deltas[blocks[b].start], blocks[b].bw_id, blocks[b].num_docs, + packed_ptr); + packed_ptr += id_bytes; + + // Pack term frequencies + const size_t tf_bytes = + is_full_block ? simd_packed_byte_size(blocks[b].bw_tf) + : packed_byte_size(blocks[b].bw_tf, blocks[b].num_docs); + pack_uint32(&tfs[blocks[b].start], blocks[b].bw_tf, blocks[b].num_docs, + packed_ptr); + packed_ptr += tf_bytes; + + // Pack document lengths + pack_uint32(&doc_lens[blocks[b].start], blocks[b].bw_dl, blocks[b].num_docs, + packed_ptr); + } + + return result; +} + +// ============================================================ +// Iterator +// ============================================================ + +int BitPackedPostingIterator::open(const char *data, size_t size) { + if (!data || size < BitPackedPostingList::HEADER_SIZE) { + LOG_ERROR( + "BitPackedPostingIterator open failed: truncated data, " + "size[%zu] expected_min[%zu]", + size, BitPackedPostingList::HEADER_SIZE); + return -1; + } + + // Parse file header + BitPackedPostingList::Header hdr{}; + std::memcpy(&hdr, data, sizeof(hdr)); + + if (hdr.magic != BitPackedPostingList::MAGIC) { + LOG_ERROR( + "BitPackedPostingIterator open failed: bad magic, " + "got[0x%x] expected[0x%x]", + hdr.magic, BitPackedPostingList::MAGIC); + return -1; + } + if (hdr.version != BitPackedPostingList::VERSION) { + LOG_ERROR( + "BitPackedPostingIterator open failed: unsupported version, " + "got[%u] expected[%u]", + hdr.version, BitPackedPostingList::VERSION); + return -1; + } + + num_docs_ = hdr.num_docs; + num_blocks_ = hdr.num_blocks; + data_ = data; + data_size_ = size; + + if (num_docs_ == 0) { + current_doc_id_ = NO_MORE_DOCS; + return 0; + } + + // Validate skip list fits + const size_t skip_list_offset = BitPackedPostingList::HEADER_SIZE; + const size_t skip_list_size = + num_blocks_ * sizeof(BitPackedPostingList::BlockMeta); + if (skip_list_offset + skip_list_size > size) { + LOG_ERROR( + "BitPackedPostingIterator open failed: skip list overruns buffer, " + "num_blocks[%u] data_size[%zu] need[%zu]", + num_blocks_, size, skip_list_offset + skip_list_size); + return -1; + } + + skip_list_ = reinterpret_cast( + data + skip_list_offset); + + // Compute global max score + global_max_score_ = 0.0f; + for (uint32_t b = 0; b < num_blocks_; ++b) { + global_max_score_ = + std::max(global_max_score_, skip_list_[b].block_max_score); + } + + // Initialize to before-first-block state + current_block_idx_ = 0; + in_block_pos_ = 0; + current_block_size_ = 0; + block_decoded_ = false; + current_doc_id_ = NO_MORE_DOCS; + + // Cache SIMD dispatch function pointers to avoid PLT overhead on hot path + const auto &dispatch = simd::get_dispatch(); + prefix_sum_fn_ = dispatch.prefix_sum_128; + find_first_ge_fn_ = dispatch.find_first_ge; + unpack_fn_ = dispatch.unpack_uint32_128; + + return 0; +} + +void BitPackedPostingIterator::decode_block(size_t block_idx) { + if (block_idx >= num_blocks_) { + LOG_WARN( + "BitPackedPostingIterator decode_block out of range: " + "block_idx[%zu] num_blocks[%u]", + block_idx, num_blocks_); + current_block_size_ = 0; + block_decoded_ = false; + return; + } + + const auto &meta = skip_list_[block_idx]; + const char *block_ptr = data_ + meta.block_offset; + + // Parse block header + BitPackedPostingList::BlockHeader bhdr{}; + std::memcpy(&bhdr, block_ptr, sizeof(bhdr)); + + current_block_size_ = bhdr.num_docs; + current_block_idx_ = block_idx; + in_block_pos_ = 0; + + const uint8_t *packed_ptr = + reinterpret_cast(block_ptr + sizeof(bhdr)); + + const bool is_full_block = + (bhdr.num_docs == BitPackedPostingList::DOCS_PER_BLOCK); + + // Unpack doc_id deltas + const size_t id_bytes = + is_full_block + ? BitPackedPostingList::simd_packed_byte_size(bhdr.bitwidth_id) + : BitPackedPostingList::packed_byte_size(bhdr.bitwidth_id, + bhdr.num_docs); + alignas(16) uint32_t deltas[BitPackedPostingList::DOCS_PER_BLOCK]; + if (is_full_block) { + // Fast path: use cached function pointer directly for full blocks + unpack_fn_(packed_ptr, bhdr.bitwidth_id, deltas); + } else { + BitPackedPostingList::unpack_uint32(packed_ptr, bhdr.bitwidth_id, + bhdr.num_docs, deltas); + } + packed_ptr += id_bytes; + + // Reconstruct absolute doc_ids from deltas using prefix-sum + if (is_full_block) { + prefix_sum_fn_(deltas, bhdr.min_doc_id, + BitPackedPostingList::DOCS_PER_BLOCK, block_doc_ids_); + } else { + // Scalar prefix-sum for tail block + block_doc_ids_[0] = bhdr.min_doc_id; + for (uint32_t i = 1; i < bhdr.num_docs; ++i) { + block_doc_ids_[i] = block_doc_ids_[i - 1] + deltas[i]; + } + } + + // Lazy decode: record packed data pointers and bitwidths for tf/doc_len. + // Actual decoding is deferred until term_freq() or doc_len() is called. + const size_t tf_bytes = + is_full_block + ? BitPackedPostingList::simd_packed_byte_size(bhdr.bitwidth_tf) + : BitPackedPostingList::packed_byte_size(bhdr.bitwidth_tf, + bhdr.num_docs); + packed_tf_ptr_ = packed_ptr; + current_bitwidth_tf_ = bhdr.bitwidth_tf; + packed_ptr += tf_bytes; + + packed_dl_ptr_ = packed_ptr; + current_bitwidth_dl_ = bhdr.bitwidth_dl; + + current_block_num_docs_ = bhdr.num_docs; + current_block_is_full_ = is_full_block; + + // Reset lazy decode flags + tf_decoded_ = false; + dl_decoded_ = false; + + block_decoded_ = true; +} + +uint32_t BitPackedPostingIterator::next_doc() { + if (num_docs_ == 0) { + current_doc_id_ = NO_MORE_DOCS; + return NO_MORE_DOCS; + } + + // If no block is decoded yet, decode the first block + if (!block_decoded_) { + decode_block(0); + if (current_block_size_ == 0) { + current_doc_id_ = NO_MORE_DOCS; + return NO_MORE_DOCS; + } + current_doc_id_ = block_doc_ids_[0]; + in_block_pos_ = 0; + return current_doc_id_; + } + + // Advance within current block + ++in_block_pos_; + if (in_block_pos_ < current_block_size_) { + current_doc_id_ = block_doc_ids_[in_block_pos_]; + return current_doc_id_; + } + + // Move to next block + size_t next_block = current_block_idx_ + 1; + if (next_block >= num_blocks_) { + current_doc_id_ = NO_MORE_DOCS; + return NO_MORE_DOCS; + } + + decode_block(next_block); + if (current_block_size_ == 0) { + current_doc_id_ = NO_MORE_DOCS; + return NO_MORE_DOCS; + } + current_doc_id_ = block_doc_ids_[0]; + in_block_pos_ = 0; + return current_doc_id_; +} + +size_t BitPackedPostingIterator::simd_find_first_ge(uint32_t target, + size_t start) const { + return find_first_ge_fn_(block_doc_ids_, current_block_size_, target, start); +} + +uint32_t BitPackedPostingIterator::advance(uint32_t target) { + if (num_docs_ == 0) { + current_doc_id_ = NO_MORE_DOCS; + return NO_MORE_DOCS; + } + + // If current doc_id already >= target, return it + if (current_doc_id_ != NO_MORE_DOCS && current_doc_id_ >= target) { + return current_doc_id_; + } + + // Use skip list to find the target block via binary search. + // Find the first block whose max_doc_id >= target. + size_t lo = 0, hi = num_blocks_; + + // If we have a current block and its max_doc_id >= target, + // we can search within the current block first. + if (block_decoded_ && current_block_idx_ < num_blocks_ && + skip_list_[current_block_idx_].max_doc_id >= target) { + // Target might be in current block - SIMD scan from current position + { + size_t pos = simd_find_first_ge(target, in_block_pos_); + if (pos < current_block_size_) { + in_block_pos_ = pos; + current_doc_id_ = block_doc_ids_[pos]; + return current_doc_id_; + } + } + // Not found in current block (shouldn't happen if skip list is correct) + lo = current_block_idx_ + 1; + } else if (block_decoded_) { + // Current block's max_doc_id < target, start search from next block + lo = current_block_idx_ + 1; + } + + // Binary search in skip list for the first block with max_doc_id >= target + size_t target_block = hi; // sentinel: no block found + while (lo < hi) { + size_t mid = lo + (hi - lo) / 2; + if (skip_list_[mid].max_doc_id >= target) { + target_block = mid; + hi = mid; + } else { + lo = mid + 1; + } + } + + if (target_block >= num_blocks_) { + current_doc_id_ = NO_MORE_DOCS; + return NO_MORE_DOCS; + } + + // Decode the target block + decode_block(target_block); + if (current_block_size_ == 0) { + current_doc_id_ = NO_MORE_DOCS; + return NO_MORE_DOCS; + } + + // SIMD scan within the block for the first doc_id >= target + { + size_t pos = simd_find_first_ge(target, 0); + if (pos < current_block_size_) { + in_block_pos_ = pos; + current_doc_id_ = block_doc_ids_[pos]; + return current_doc_id_; + } + } + + // All docs in this block are < target (shouldn't happen with correct skip + // list), try next block + size_t next = target_block + 1; + if (next >= num_blocks_) { + current_doc_id_ = NO_MORE_DOCS; + return NO_MORE_DOCS; + } + decode_block(next); + if (current_block_size_ == 0) { + current_doc_id_ = NO_MORE_DOCS; + return NO_MORE_DOCS; + } + { + size_t pos = simd_find_first_ge(target, 0); + if (pos < current_block_size_) { + in_block_pos_ = pos; + current_doc_id_ = block_doc_ids_[pos]; + return current_doc_id_; + } + } + current_doc_id_ = NO_MORE_DOCS; + return NO_MORE_DOCS; +} + +void BitPackedPostingIterator::ensure_tf_decoded() { + if (tf_decoded_) { + return; + } + if (current_block_is_full_) { + unpack_fn_(packed_tf_ptr_, current_bitwidth_tf_, block_tfs_); + } else { + BitPackedPostingList::unpack_uint32(packed_tf_ptr_, current_bitwidth_tf_, + current_block_num_docs_, block_tfs_); + } + tf_decoded_ = true; +} + +void BitPackedPostingIterator::ensure_dl_decoded() { + if (dl_decoded_) { + return; + } + if (current_block_is_full_) { + unpack_fn_(packed_dl_ptr_, current_bitwidth_dl_, block_doc_lens_); + } else { + BitPackedPostingList::unpack_uint32(packed_dl_ptr_, current_bitwidth_dl_, + current_block_num_docs_, + block_doc_lens_); + } + dl_decoded_ = true; +} + +uint32_t BitPackedPostingIterator::term_freq() { + if (!block_decoded_ || in_block_pos_ >= current_block_size_) { + return 0; + } + ensure_tf_decoded(); + return block_tfs_[in_block_pos_]; +} + +uint32_t BitPackedPostingIterator::doc_len() { + if (!block_decoded_ || in_block_pos_ >= current_block_size_) { + return 1; + } + ensure_dl_decoded(); + return block_doc_lens_[in_block_pos_]; +} + +BitPackedPostingIterator::BlockMaxInfo +BitPackedPostingIterator::block_max_info_for(uint32_t target) const { + if (num_blocks_ == 0 || skip_list_ == nullptr) { + return {0.0f, NO_MORE_DOCS}; + } + + // Fast path: check if target falls within the previously cached block + if (cached_bmi_valid_ && target <= cached_bmi_last_doc_) { + // target is in the same or earlier block as last query. + // Check if it's still in the same block (block_idx is correct). + if (cached_bmi_block_idx_ == 0 || + target > skip_list_[cached_bmi_block_idx_ - 1].max_doc_id) { + return {cached_bmi_score_, cached_bmi_last_doc_}; + } + } + + size_t lo = 0, hi = num_blocks_; + while (lo < hi) { + size_t mid = lo + (hi - lo) / 2; + if (skip_list_[mid].max_doc_id >= target) { + hi = mid; + } else { + lo = mid + 1; + } + } + if (lo >= num_blocks_) { + return {0.0f, NO_MORE_DOCS}; + } + + // Update cache + cached_bmi_block_idx_ = lo; + cached_bmi_last_doc_ = skip_list_[lo].max_doc_id; + cached_bmi_score_ = skip_list_[lo].block_max_score; + cached_bmi_valid_ = true; + + return {cached_bmi_score_, cached_bmi_last_doc_}; +} + +} // namespace zvec::fts diff --git a/src/db/index/column/fts_column/posting/bitpacked_posting_list.h b/src/db/index/column/fts_column/posting/bitpacked_posting_list.h new file mode 100644 index 000000000..aeeb7f12f --- /dev/null +++ b/src/db/index/column/fts_column/posting/bitpacked_posting_list.h @@ -0,0 +1,237 @@ +// Copyright 2025-present the zvec project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include +#include "bitpacked_simd_dispatch.h" +#include "../bm25_scorer.h" + +namespace zvec::fts { + +// ============================================================ +// BitPacked Posting List encoder +// ============================================================ + +class BitPackedPostingList { + public: + static constexpr uint32_t DOCS_PER_BLOCK = 128; + static constexpr uint32_t MAGIC = 0x42504B44; // "BPKD" + static constexpr uint32_t VERSION = 1; + + /// Skip-list entry stored after the file header. + struct BlockMeta { + uint32_t max_doc_id; ///< Last (largest) doc_id in this block + uint32_t block_offset; ///< Byte offset from data start to block header + float block_max_score; ///< BM25 score upper bound for this block + }; + + /// File header (16 bytes). + struct Header { + uint32_t magic; + uint32_t version; + uint32_t num_docs; + uint32_t num_blocks; + }; + static constexpr size_t HEADER_SIZE = sizeof(Header); + + /// Block header (16 bytes, padded for SIMD alignment). + struct BlockHeader { + uint32_t min_doc_id; + uint8_t bitwidth_id; + uint8_t bitwidth_tf; + uint8_t bitwidth_dl; + uint8_t num_docs; ///< Number of docs in this block (<=128) + float block_max_score; ///< Redundant copy for fast in-block access + uint32_t padding_{ + 0}; ///< Padding to make BlockHeader 16 bytes (SIMD alignment) + }; + + /// Encode a posting list with inline payloads. + /// \param doc_ids Sorted ascending doc_id array + /// \param tfs Term frequency for each doc + /// \param doc_lens Document length for each doc + /// \param count Number of entries + /// \param df Document frequency (used for IDF in block_max_score) + /// \param scorer BM25 scorer with segment stats loaded + /// \return Serialized bitpacked posting list + static std::string encode(const uint32_t *doc_ids, const uint32_t *tfs, + const uint32_t *doc_lens, size_t count, uint64_t df, + const BM25Scorer &scorer); + + /// Check if raw data starts with the BitPacked magic number. + static bool is_bitpacked_format(const char *data, size_t size) { + if (size < sizeof(uint32_t)) return false; + uint32_t magic = 0; + std::memcpy(&magic, data, sizeof(uint32_t)); + return magic == MAGIC; + } + + // ---- Low-level bitpacking primitives ---- + + /// Pack \p count uint32 values (each using \p bitwidth bits) into \p out. + /// \p out must have at least ceil(bitwidth * count / 8) bytes. + /// \p count must be <= DOCS_PER_BLOCK (128). + static void pack_uint32(const uint32_t *in, uint8_t bitwidth, uint32_t count, + uint8_t *out); + + /// Unpack \p count uint32 values (each using \p bitwidth bits) from \p in. + /// \p out must have room for \p count uint32_t values. + static void unpack_uint32(const uint8_t *in, uint8_t bitwidth, uint32_t count, + uint32_t *out); + + /// Compute the minimum number of bits needed to represent \p max_value. + /// Returns 0 if max_value == 0. + static uint8_t bits_needed(uint32_t max_value); + + /// Compute packed byte size for \p count values at \p bitwidth bits each + /// (scalar format, used for tail blocks with count < DOCS_PER_BLOCK). + static size_t packed_byte_size(uint8_t bitwidth, uint32_t count) { + return (static_cast(bitwidth) * count + 7) / 8; + } + + /// Compute packed byte size for a full SIMD block (128 values). + /// SIMD format stores bitwidth __m128i values = bitwidth * 16 bytes. + static size_t simd_packed_byte_size(uint8_t bitwidth) { + return static_cast(bitwidth) * 16; + } +}; + +// ============================================================ +// BitPacked Posting Iterator (zero-copy, block-at-a-time) +// ============================================================ + +/// Zero-copy iterator over a serialized BitPacked posting list. +/// Decodes one block at a time into stack-allocated arrays. +class BitPackedPostingIterator { + public: + static constexpr uint32_t NO_MORE_DOCS = UINT32_MAX; + + BitPackedPostingIterator() = default; + + /// Open from serialized data (zero-copy, does not own the data). + /// \param data Pointer to serialized bitpacked posting list + /// \param size Size of the serialized data in bytes + /// \return 0 on success, -1 on error (bad magic, truncated data, etc.) + int open(const char *data, size_t size); + + /// Advance to the next document. + /// \return doc_id of the next document, or NO_MORE_DOCS if exhausted. + uint32_t next_doc(); + + /// Advance to the first document with doc_id >= target. + /// Uses the skip list for O(log N_blocks) block-level seeking. + /// \return doc_id >= target, or NO_MORE_DOCS if exhausted. + uint32_t advance(uint32_t target); + + /// Current document ID (valid after next_doc/advance). + uint32_t doc_id() const { + return current_doc_id_; + } + + /// Term frequency of the current document (valid after next_doc/advance). + /// NOTE: non-const because lazy decode may be triggered on first access. + uint32_t term_freq(); + + /// Document length of the current document (valid after next_doc/advance). + /// NOTE: non-const because lazy decode may be triggered on first access. + uint32_t doc_len(); + + /// Return both block_max_score and max_doc_id for the block containing + /// \p target in a single binary search on the skip list. + /// Does NOT move the iterator position. + struct BlockMaxInfo { + float block_max_score{0.0f}; + uint32_t block_last_doc{NO_MORE_DOCS}; + }; + BlockMaxInfo block_max_info_for(uint32_t target) const; + + /// Total number of documents in this posting list. + uint64_t cost() const { + return num_docs_; + } + + /// Maximum block_max_score across all blocks (global upper bound). + float max_score() const { + return global_max_score_; + } + + private: + /// Decode block at index \p block_idx into the stack arrays. + void decode_block(size_t block_idx); + + /// Lazy decode: ensure tf values are decoded before access. + void ensure_tf_decoded(); + + /// Lazy decode: ensure doc_len values are decoded before access. + void ensure_dl_decoded(); + + /// SIMD search: find first index i in block_doc_ids_[start..size) + /// where doc_id >= target. Uses SSE4.1 for 4-wide comparison. + size_t simd_find_first_ge(uint32_t target, size_t start) const; + + // File header fields + uint32_t num_docs_{0}; + uint32_t num_blocks_{0}; + + // Skip list (pointer into data_, not owned) + const BitPackedPostingList::BlockMeta *skip_list_{nullptr}; + + // Raw data pointer (not owned) + const char *data_{nullptr}; + size_t data_size_{0}; + + // Current block state (decoded into stack arrays) + alignas(16) uint32_t block_doc_ids_[BitPackedPostingList::DOCS_PER_BLOCK]; + alignas(16) uint32_t block_tfs_[BitPackedPostingList::DOCS_PER_BLOCK]; + alignas(16) uint32_t block_doc_lens_[BitPackedPostingList::DOCS_PER_BLOCK]; + size_t current_block_idx_{0}; + uint32_t current_block_size_{0}; + size_t in_block_pos_{0}; ///< Position within current decoded block + bool block_decoded_{false}; ///< Whether current block is decoded + + // Lazy decode state: tf and doc_len are decoded on first access + bool tf_decoded_{false}; + bool dl_decoded_{false}; + + // Store packed data pointers for lazy decode + const uint8_t *packed_tf_ptr_{nullptr}; + const uint8_t *packed_dl_ptr_{nullptr}; + uint8_t current_bitwidth_tf_{0}; + uint8_t current_bitwidth_dl_{0}; + uint32_t current_block_num_docs_{0}; ///< num_docs for lazy decode dispatch + bool current_block_is_full_{false}; ///< Whether current block is full (128) + + uint32_t current_doc_id_{NO_MORE_DOCS}; + float global_max_score_{0.0f}; + + // Cached SIMD dispatch function pointers (initialized in open()). + // Avoids repeated PLT/indirect calls through get_dispatch() on every + // decode_block / simd_find_first_ge invocation. + simd::PrefixSumFunc prefix_sum_fn_{nullptr}; + simd::FindFirstGeFunc find_first_ge_fn_{nullptr}; + simd::UnpackFunc unpack_fn_{nullptr}; + + // Cache for block_max_info_for to avoid repeated binary searches. + // If target falls within [cached_bmi_block_min_doc_+1, cached_bmi_last_doc_], + // we can return the cached result directly. + mutable uint32_t cached_bmi_last_doc_{0}; + mutable float cached_bmi_score_{0.0f}; + mutable size_t cached_bmi_block_idx_{0}; + mutable bool cached_bmi_valid_{false}; +}; + +} // namespace zvec::fts diff --git a/src/db/index/column/fts_column/posting/bitpacked_simd_avx2.cc b/src/db/index/column/fts_column/posting/bitpacked_simd_avx2.cc new file mode 100644 index 000000000..91f5ed002 --- /dev/null +++ b/src/db/index/column/fts_column/posting/bitpacked_simd_avx2.cc @@ -0,0 +1,216 @@ +// Copyright 2025-present the zvec project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "bitpacked_simd_avx2.h" + +#if defined(__AVX2__) || \ + (defined(_MSC_VER) && (defined(_M_X64) || defined(_M_IX86))) + +#include +#include +#include "bitpacked_simd_sse41.h" + +#ifdef _MSC_VER +#include +static inline int ctz_u32(unsigned int v) { + unsigned long index; + _BitScanForward(&index, v); + return static_cast(index); +} +#else +static inline int ctz_u32(unsigned int v) { + return __builtin_ctz(v); +} +#endif + +namespace zvec::fts::simd { + +// ------------------------------------------------------------ +// avx2_max_128 +// ------------------------------------------------------------ + +void avx2_max_128(const uint32_t *deltas, const uint32_t *tfs, + const uint32_t *doc_lens, size_t start, uint32_t count, + uint32_t &max_delta, uint32_t &max_tf, uint32_t &max_dl) { + __m256i vmax_delta = _mm256_setzero_si256(); + __m256i vmax_tf = _mm256_setzero_si256(); + __m256i vmax_dl = _mm256_setzero_si256(); + + for (uint32_t i = 0; i < count; i += 8) { + vmax_delta = _mm256_max_epu32( + vmax_delta, _mm256_loadu_si256( + reinterpret_cast(&deltas[start + i]))); + vmax_tf = _mm256_max_epu32( + vmax_tf, + _mm256_loadu_si256(reinterpret_cast(&tfs[start + i]))); + vmax_dl = _mm256_max_epu32( + vmax_dl, _mm256_loadu_si256( + reinterpret_cast(&doc_lens[start + i]))); + } + + // Horizontal max: reduce 8 lanes to scalar + auto hmax = [](__m256i v) -> uint32_t { + // Reduce 256-bit to 128-bit by taking max of high and low halves + __m128i lo = _mm256_castsi256_si128(v); + __m128i hi = _mm256_extracti128_si256(v, 1); + __m128i m = _mm_max_epu32(lo, hi); + // Reduce 128-bit to scalar + m = _mm_max_epu32(m, _mm_shuffle_epi32(m, _MM_SHUFFLE(2, 3, 0, 1))); + m = _mm_max_epu32(m, _mm_shuffle_epi32(m, _MM_SHUFFLE(1, 0, 3, 2))); + return static_cast(_mm_extract_epi32(m, 0)); + }; + + max_delta = hmax(vmax_delta); + max_tf = hmax(vmax_tf); + max_dl = hmax(vmax_dl); +} + +// ------------------------------------------------------------ +// avx2_pack_uint32_128 — fallback to SSE4.1 +// ------------------------------------------------------------ + +void avx2_pack_uint32_128(const uint32_t *in, uint8_t bitwidth, uint8_t *out) { + // FastPForLib does not provide AVX2 bitpacking; delegate to SSE4.1. + sse41_pack_uint32_128(in, bitwidth, out); +} + +// ------------------------------------------------------------ +// avx2_unpack_uint32_128 — fallback to SSE4.1 +// ------------------------------------------------------------ + +void avx2_unpack_uint32_128(const uint8_t *in, uint8_t bitwidth, + uint32_t *out) { + // FastPForLib does not provide AVX2 bitpacking; delegate to SSE4.1. + sse41_unpack_uint32_128(in, bitwidth, out); +} + +// ------------------------------------------------------------ +// avx2_prefix_sum_128 +// ------------------------------------------------------------ + +void avx2_prefix_sum_128(const uint32_t *deltas, uint32_t min_doc_id, + uint32_t /*count*/, uint32_t *out) { + // Process 8 elements per iteration (16 groups of 8 = 128 elements). + // Within each 256-bit register we compute a prefix-sum, then propagate + // the carry (last element) to the next group. + __m256i carry = _mm256_set1_epi32(static_cast(min_doc_id) - + static_cast(deltas[0])); + + for (uint32_t g = 0; g < 16; ++g) { + __m256i v = + _mm256_loadu_si256(reinterpret_cast(&deltas[g * 8])); + + // In-register prefix-sum for 8 elements (two 128-bit lanes independently, + // then cross-lane fixup). + + // Step 1: shift by 1 element (4 bytes) within each 128-bit lane + __m256i shifted1 = _mm256_bslli_epi128(v, 4); + v = _mm256_add_epi32(v, shifted1); + + // Step 2: shift by 2 elements (8 bytes) within each 128-bit lane + __m256i shifted2 = _mm256_bslli_epi128(v, 8); + v = _mm256_add_epi32(v, shifted2); + + // Step 3: cross-lane fixup — high lane needs the sum of the low lane's + // last element (index 3) added to all its elements. + // Broadcast low lane's element[3] to all positions of high lane. + __m128i lo = _mm256_castsi256_si128(v); + __m128i lo_last = _mm_shuffle_epi32(lo, _MM_SHUFFLE(3, 3, 3, 3)); + __m256i cross = _mm256_set_m128i(lo_last, _mm_setzero_si128()); + v = _mm256_add_epi32(v, cross); + + // Add carry from previous group + v = _mm256_add_epi32(v, carry); + + _mm256_storeu_si256(reinterpret_cast<__m256i *>(&out[g * 8]), v); + + // Broadcast the last element (index 7) as carry for next group. + // Element 7 is in the high lane at position 3. + __m128i hi = _mm256_extracti128_si256(v, 1); + __m128i hi_last = _mm_shuffle_epi32(hi, _MM_SHUFFLE(3, 3, 3, 3)); + carry = _mm256_set_m128i(hi_last, hi_last); + } +} + +// ------------------------------------------------------------ +// avx2_find_first_ge +// ------------------------------------------------------------ + +size_t avx2_find_first_ge(const uint32_t *arr, uint32_t size, uint32_t target, + size_t start) { + const __m256i vtarget = _mm256_set1_epi32(static_cast(target)); + const __m256i sign_bit = _mm256_set1_epi32(static_cast(0x80000000u)); + const __m256i starget = _mm256_xor_si256(vtarget, sign_bit); + + size_t i = start; + // Scalar until aligned to 4-element boundary (minimum for unaligned AVX2) + for (; i < size && (i & 3); ++i) { + if (arr[i] >= target) { + return i; + } + } + // SIMD scan: 8 elements at a time + for (; i + 8 <= size; i += 8) { + __m256i v = _mm256_loadu_si256(reinterpret_cast(&arr[i])); + __m256i sv = _mm256_xor_si256(v, sign_bit); + // cmpgt: sv < starget means arr[i] < target + __m256i cmp = _mm256_cmpgt_epi32(starget, sv); + int mask = _mm256_movemask_ps(_mm256_castsi256_ps(cmp)); + if (mask != 0xFF) { + // At least one element >= target + int first = ctz_u32(static_cast(~mask & 0xFF)); + return i + first; + } + } + // Scalar tail + for (; i < size; ++i) { + if (arr[i] >= target) { + return i; + } + } + return size; +} + +} // namespace zvec::fts::simd + +#else // !defined(__AVX2__) && !(defined(_MSC_VER) && (defined(_M_X64) || + // defined(_M_IX86))) + +// Stub implementations when AVX2 is not available at compile time. +// The runtime dispatch layer (bitpacked_simd_dispatch.cc) will never call +// these on non-AVX2 machines, but the linker still needs the symbols. + +namespace zvec::fts::simd { + +void avx2_max_128(const uint32_t *, const uint32_t *, const uint32_t *, size_t, + uint32_t, uint32_t &max_delta, uint32_t &max_tf, + uint32_t &max_dl) { + max_delta = 0; + max_tf = 0; + max_dl = 0; +} + +void avx2_pack_uint32_128(const uint32_t *, uint8_t, uint8_t *) {} + +void avx2_unpack_uint32_128(const uint8_t *, uint8_t, uint32_t *) {} + +void avx2_prefix_sum_128(const uint32_t *, uint32_t, uint32_t, uint32_t *) {} + +size_t avx2_find_first_ge(const uint32_t *, uint32_t size, uint32_t, size_t) { + return size; +} + +} // namespace zvec::fts::simd + +#endif // defined(__AVX2__) diff --git a/src/db/index/column/fts_column/posting/bitpacked_simd_avx2.h b/src/db/index/column/fts_column/posting/bitpacked_simd_avx2.h new file mode 100644 index 000000000..d86796016 --- /dev/null +++ b/src/db/index/column/fts_column/posting/bitpacked_simd_avx2.h @@ -0,0 +1,49 @@ +// Copyright 2025-present the zvec project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include + +namespace zvec::fts::simd { + +/// Compute element-wise max of 128 uint32 values across three arrays using +/// AVX2 _mm256_max_epu32. \p deltas must be 32-byte aligned; \p tfs and +/// \p doc_lens may be unaligned. +void avx2_max_128(const uint32_t *deltas, const uint32_t *tfs, + const uint32_t *doc_lens, size_t start, uint32_t count, + uint32_t &max_delta, uint32_t &max_tf, uint32_t &max_dl); + +/// Pack 128 uint32 values at \p bitwidth bits each into \p out. +/// Falls back to SSE4.1 implementation (FastPForLib lacks AVX2 bitpacking). +void avx2_pack_uint32_128(const uint32_t *in, uint8_t bitwidth, uint8_t *out); + +/// Unpack 128 uint32 values at \p bitwidth bits each from \p in. +/// Falls back to SSE4.1 implementation (FastPForLib lacks AVX2 bitpacking). +void avx2_unpack_uint32_128(const uint8_t *in, uint8_t bitwidth, uint32_t *out); + +/// Compute prefix-sum over \p count (must be 128) delta values, producing +/// absolute doc_ids. Uses AVX2 SIMD prefix-sum with carry propagation. +/// \p deltas must be 32-byte aligned; \p out must be 32-byte aligned. +void avx2_prefix_sum_128(const uint32_t *deltas, uint32_t min_doc_id, + uint32_t count, uint32_t *out); + +/// Find the first index i in arr[start..size) where arr[i] >= target. +/// Uses AVX2 8-wide comparison with unsigned-to-signed trick. +/// \p arr must be 32-byte aligned. +size_t avx2_find_first_ge(const uint32_t *arr, uint32_t size, uint32_t target, + size_t start); + +} // namespace zvec::fts::simd diff --git a/src/db/index/column/fts_column/posting/bitpacked_simd_dispatch.cc b/src/db/index/column/fts_column/posting/bitpacked_simd_dispatch.cc new file mode 100644 index 000000000..c850703cd --- /dev/null +++ b/src/db/index/column/fts_column/posting/bitpacked_simd_dispatch.cc @@ -0,0 +1,60 @@ +// Copyright 2025-present the zvec project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "bitpacked_simd_dispatch.h" +#include +#include "bitpacked_simd_scalar.h" +#if defined(__x86_64__) || defined(_M_X64) || defined(__i386__) || \ + defined(_M_IX86) +#include "bitpacked_simd_avx2.h" +#include "bitpacked_simd_sse41.h" +#endif + +namespace zvec::fts::simd { + +static DispatchTable init_dispatch() { + DispatchTable t{}; +#if defined(__x86_64__) || defined(_M_X64) || defined(__i386__) || \ + defined(_M_IX86) + if (zvec::ailego::internal::CpuFeatures::static_flags_.AVX2) { + t.max_128 = avx2_max_128; + t.pack_uint32_128 = avx2_pack_uint32_128; + t.unpack_uint32_128 = avx2_unpack_uint32_128; + t.prefix_sum_128 = avx2_prefix_sum_128; + t.find_first_ge = avx2_find_first_ge; + return t; + } + if (zvec::ailego::internal::CpuFeatures::static_flags_.SSE4_1) { + t.max_128 = sse41_max_128; + t.pack_uint32_128 = sse41_pack_uint32_128; + t.unpack_uint32_128 = sse41_unpack_uint32_128; + t.prefix_sum_128 = sse41_prefix_sum_128; + t.find_first_ge = sse41_find_first_ge; + return t; + } +#endif + t.max_128 = scalar_max_128; + t.pack_uint32_128 = scalar_pack_uint32_128; + t.unpack_uint32_128 = scalar_unpack_uint32_128; + t.prefix_sum_128 = scalar_prefix_sum_128; + t.find_first_ge = scalar_find_first_ge; + return t; +} + +const DispatchTable &get_dispatch() { + static const DispatchTable table = init_dispatch(); + return table; +} + +} // namespace zvec::fts::simd diff --git a/src/db/index/column/fts_column/posting/bitpacked_simd_dispatch.h b/src/db/index/column/fts_column/posting/bitpacked_simd_dispatch.h new file mode 100644 index 000000000..64c498e06 --- /dev/null +++ b/src/db/index/column/fts_column/posting/bitpacked_simd_dispatch.h @@ -0,0 +1,44 @@ +// Copyright 2025-present the zvec project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include + +namespace zvec::fts::simd { + +// Function pointer types for SIMD-dispatched operations. +using MaxFunc = void (*)(const uint32_t *, const uint32_t *, const uint32_t *, + size_t, uint32_t, uint32_t &, uint32_t &, uint32_t &); +using PackFunc = void (*)(const uint32_t *, uint8_t, uint8_t *); +using UnpackFunc = void (*)(const uint8_t *, uint8_t, uint32_t *); +using PrefixSumFunc = void (*)(const uint32_t *, uint32_t, uint32_t, + uint32_t *); +using FindFirstGeFunc = size_t (*)(const uint32_t *, uint32_t, uint32_t, + size_t); + +/// Dispatch table populated once at startup via CPU feature detection. +struct DispatchTable { + MaxFunc max_128; + PackFunc pack_uint32_128; + UnpackFunc unpack_uint32_128; + PrefixSumFunc prefix_sum_128; + FindFirstGeFunc find_first_ge; +}; + +/// Get the global dispatch table (initialized on first call). +const DispatchTable &get_dispatch(); + +} // namespace zvec::fts::simd diff --git a/src/db/index/column/fts_column/posting/bitpacked_simd_scalar.cc b/src/db/index/column/fts_column/posting/bitpacked_simd_scalar.cc new file mode 100644 index 000000000..4877751ba --- /dev/null +++ b/src/db/index/column/fts_column/posting/bitpacked_simd_scalar.cc @@ -0,0 +1,97 @@ +// Copyright 2025-present the zvec project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "bitpacked_simd_scalar.h" +#include +#include +#include +#include "bitpacked_posting_list.h" + +namespace zvec::fts::simd { + +// ------------------------------------------------------------ +// scalar_max_128 +// ------------------------------------------------------------ + +void scalar_max_128(const uint32_t *deltas, const uint32_t *tfs, + const uint32_t *doc_lens, size_t start, uint32_t count, + uint32_t &max_delta, uint32_t &max_tf, uint32_t &max_dl) { + uint32_t md = 0, mt = 0, ml = 0; + for (uint32_t i = 0; i < count; ++i) { + md = std::max(md, deltas[start + i]); + mt = std::max(mt, tfs[start + i]); + ml = std::max(ml, doc_lens[start + i]); + } + max_delta = md; + max_tf = mt; + max_dl = ml; +} + +// ------------------------------------------------------------ +// scalar_pack_uint32_128 +// ------------------------------------------------------------ + +void scalar_pack_uint32_128(const uint32_t *in, uint8_t bitwidth, + uint8_t *out) { + // Scalar fastpack processes 32 values at a time; loop 4 times for 128. + const size_t total_bytes = + BitPackedPostingList::simd_packed_byte_size(bitwidth); + std::memset(out, 0, total_bytes); + + uint32_t *out32 = reinterpret_cast(out); + for (uint32_t g = 0; g < 4; ++g) { + FastPForLib::fastpackwithoutmask(in + g * 32, out32, bitwidth); + out32 += bitwidth; + } +} + +// ------------------------------------------------------------ +// scalar_unpack_uint32_128 +// ------------------------------------------------------------ + +void scalar_unpack_uint32_128(const uint8_t *in, uint8_t bitwidth, + uint32_t *out) { + const uint32_t *in32 = reinterpret_cast(in); + for (uint32_t g = 0; g < 4; ++g) { + FastPForLib::fastunpack(in32, out + g * 32, bitwidth); + in32 += bitwidth; + } +} + +// ------------------------------------------------------------ +// scalar_prefix_sum_128 +// ------------------------------------------------------------ + +void scalar_prefix_sum_128(const uint32_t *deltas, uint32_t min_doc_id, + uint32_t count, uint32_t *out) { + // First element: min_doc_id corresponds to deltas[0] + out[0] = min_doc_id; + for (uint32_t i = 1; i < count; ++i) { + out[i] = out[i - 1] + deltas[i]; + } +} + +// ------------------------------------------------------------ +// scalar_find_first_ge +// ------------------------------------------------------------ + +size_t scalar_find_first_ge(const uint32_t *arr, uint32_t size, uint32_t target, + size_t start) { + for (size_t i = start; i < size; ++i) { + if (arr[i] >= target) return i; + } + return size; +} + +} // namespace zvec::fts::simd diff --git a/src/db/index/column/fts_column/posting/bitpacked_simd_scalar.h b/src/db/index/column/fts_column/posting/bitpacked_simd_scalar.h new file mode 100644 index 000000000..ce0cbf9f7 --- /dev/null +++ b/src/db/index/column/fts_column/posting/bitpacked_simd_scalar.h @@ -0,0 +1,47 @@ +// Copyright 2025-present the zvec project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include + +namespace zvec::fts::simd { + +/// Scalar fallback: compute element-wise max of up to 128 uint32 values across +/// three arrays using a simple loop. +void scalar_max_128(const uint32_t *deltas, const uint32_t *tfs, + const uint32_t *doc_lens, size_t start, uint32_t count, + uint32_t &max_delta, uint32_t &max_tf, uint32_t &max_dl); + +/// Scalar fallback: pack 128 uint32 values at \p bitwidth bits each into \p out +/// using FastPForLib::fastpackwithoutmask (32 values at a time, 4 iterations). +void scalar_pack_uint32_128(const uint32_t *in, uint8_t bitwidth, uint8_t *out); + +/// Scalar fallback: unpack 128 uint32 values at \p bitwidth bits each from +/// \p in using FastPForLib::fastunpack (32 values at a time, 4 iterations). +void scalar_unpack_uint32_128(const uint8_t *in, uint8_t bitwidth, + uint32_t *out); + +/// Scalar fallback: compute prefix-sum over \p count delta values, producing +/// absolute doc_ids. +void scalar_prefix_sum_128(const uint32_t *deltas, uint32_t min_doc_id, + uint32_t count, uint32_t *out); + +/// Scalar fallback: find the first index i in arr[start..size) where +/// arr[i] >= target using a linear scan. +size_t scalar_find_first_ge(const uint32_t *arr, uint32_t size, uint32_t target, + size_t start); + +} // namespace zvec::fts::simd diff --git a/src/db/index/column/fts_column/posting/bitpacked_simd_sse41.cc b/src/db/index/column/fts_column/posting/bitpacked_simd_sse41.cc new file mode 100644 index 000000000..1a7ccd20f --- /dev/null +++ b/src/db/index/column/fts_column/posting/bitpacked_simd_sse41.cc @@ -0,0 +1,202 @@ +// Copyright 2025-present the zvec project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "bitpacked_simd_sse41.h" + +#if defined(__SSE4_1__) || \ + (defined(_MSC_VER) && (defined(_M_X64) || defined(_M_IX86))) + +#include +#include // SSE2 +#include +#include // SSE4.1 +#include +#include "bitpacked_posting_list.h" + +#ifdef _MSC_VER +#include +static inline int ctz_u32(unsigned int v) { + unsigned long index; + _BitScanForward(&index, v); + return static_cast(index); +} +#else +static inline int ctz_u32(unsigned int v) { + return __builtin_ctz(v); +} +#endif + +namespace zvec::fts::simd { + +// ------------------------------------------------------------ +// sse41_max_128 +// ------------------------------------------------------------ + +void sse41_max_128(const uint32_t *deltas, const uint32_t *tfs, + const uint32_t *doc_lens, size_t start, uint32_t count, + uint32_t &max_delta, uint32_t &max_tf, uint32_t &max_dl) { + __m128i vmax_delta = _mm_setzero_si128(); + __m128i vmax_tf = _mm_setzero_si128(); + __m128i vmax_dl = _mm_setzero_si128(); + for (uint32_t i = 0; i < count; i += 4) { + vmax_delta = _mm_max_epu32( + vmax_delta, + _mm_load_si128(reinterpret_cast(&deltas[start + i]))); + vmax_tf = _mm_max_epu32( + vmax_tf, + _mm_loadu_si128(reinterpret_cast(&tfs[start + i]))); + vmax_dl = _mm_max_epu32( + vmax_dl, _mm_loadu_si128( + reinterpret_cast(&doc_lens[start + i]))); + } + // Horizontal max: reduce 4 lanes to scalar + auto hmax = [](__m128i v) -> uint32_t { + v = _mm_max_epu32(v, _mm_shuffle_epi32(v, _MM_SHUFFLE(2, 3, 0, 1))); + v = _mm_max_epu32(v, _mm_shuffle_epi32(v, _MM_SHUFFLE(1, 0, 3, 2))); + return static_cast(_mm_extract_epi32(v, 0)); + }; + max_delta = hmax(vmax_delta); + max_tf = hmax(vmax_tf); + max_dl = hmax(vmax_dl); +} + +// ------------------------------------------------------------ +// sse41_pack_uint32_128 +// ------------------------------------------------------------ + +void sse41_pack_uint32_128(const uint32_t *in, uint8_t bitwidth, uint8_t *out) { + const size_t total_bytes = + BitPackedPostingList::simd_packed_byte_size(bitwidth); + if ((reinterpret_cast(out) & 15) == 0) { + FastPForLib::SIMD_fastpackwithoutmask_32( + in, reinterpret_cast<__m128i *>(out), bitwidth); + } else { + alignas(16) __m128i simd_out[32]; + FastPForLib::SIMD_fastpackwithoutmask_32(in, simd_out, bitwidth); + std::memcpy(out, simd_out, total_bytes); + } +} + +// ------------------------------------------------------------ +// sse41_unpack_uint32_128 +// ------------------------------------------------------------ + +void sse41_unpack_uint32_128(const uint8_t *in, uint8_t bitwidth, + uint32_t *out) { + if ((reinterpret_cast(in) & 15) == 0) { + FastPForLib::SIMD_fastunpack_32(reinterpret_cast(in), out, + bitwidth); + } else { + const size_t packed_bytes = + BitPackedPostingList::simd_packed_byte_size(bitwidth); + alignas(16) __m128i simd_in[32]; + std::memcpy(simd_in, in, packed_bytes); + FastPForLib::SIMD_fastunpack_32(simd_in, out, bitwidth); + } +} + +// ------------------------------------------------------------ +// sse41_prefix_sum_128 +// ------------------------------------------------------------ + +void sse41_prefix_sum_128(const uint32_t *deltas, uint32_t min_doc_id, + uint32_t /*count*/, uint32_t *out) { + __m128i carry = _mm_set1_epi32(static_cast(min_doc_id) - + static_cast(deltas[0])); + + for (uint32_t g = 0; g < 32; ++g) { + __m128i v = + _mm_load_si128(reinterpret_cast(&deltas[g * 4])); + + // In-register prefix-sum for 4 elements + __m128i shifted1 = _mm_slli_si128(v, 4); + v = _mm_add_epi32(v, shifted1); + __m128i shifted2 = _mm_slli_si128(v, 8); + v = _mm_add_epi32(v, shifted2); + + // Add carry from previous group + v = _mm_add_epi32(v, carry); + + _mm_store_si128(reinterpret_cast<__m128i *>(&out[g * 4]), v); + + // Broadcast the last element as carry for next group + carry = _mm_shuffle_epi32(v, _MM_SHUFFLE(3, 3, 3, 3)); + } +} + +// ------------------------------------------------------------ +// sse41_find_first_ge +// ------------------------------------------------------------ + +size_t sse41_find_first_ge(const uint32_t *arr, uint32_t size, uint32_t target, + size_t start) { + const __m128i vtarget = _mm_set1_epi32(static_cast(target)); + const __m128i sign_bit = _mm_set1_epi32(static_cast(0x80000000u)); + const __m128i starget = _mm_xor_si128(vtarget, sign_bit); + + size_t i = start; + // Scalar until aligned to 4-element boundary + for (; i < size && (i & 3); ++i) { + if (arr[i] >= target) return i; + } + // SIMD scan: 4 elements at a time + for (; i + 4 <= size; i += 4) { + __m128i v = _mm_load_si128(reinterpret_cast(&arr[i])); + __m128i sv = _mm_xor_si128(v, sign_bit); + __m128i cmp = _mm_cmplt_epi32(sv, starget); + int mask = _mm_movemask_ps(_mm_castsi128_ps(cmp)); + if (mask != 0xF) { + int first = ctz_u32(static_cast(~mask)); + return i + first; + } + } + // Scalar tail + for (; i < size; ++i) { + if (arr[i] >= target) return i; + } + return size; +} + +} // namespace zvec::fts::simd + +#else // !defined(__SSE4_1__) && !(defined(_MSC_VER) && (defined(_M_X64) || + // defined(_M_IX86))) + +// Stub implementations when SSE4.1 is not available at compile time. +// The runtime dispatch layer (bitpacked_simd_dispatch.cc) will never call +// these on non-SSE4.1 machines, but the linker still needs the symbols. + +namespace zvec::fts::simd { + +void sse41_max_128(const uint32_t *, const uint32_t *, const uint32_t *, size_t, + uint32_t, uint32_t &max_delta, uint32_t &max_tf, + uint32_t &max_dl) { + max_delta = 0; + max_tf = 0; + max_dl = 0; +} + +void sse41_pack_uint32_128(const uint32_t *, uint8_t, uint8_t *) {} + +void sse41_unpack_uint32_128(const uint8_t *, uint8_t, uint32_t *) {} + +void sse41_prefix_sum_128(const uint32_t *, uint32_t, uint32_t, uint32_t *) {} + +size_t sse41_find_first_ge(const uint32_t *, uint32_t size, uint32_t, size_t) { + return size; +} + +} // namespace zvec::fts::simd + +#endif // defined(__SSE4_1__) diff --git a/src/db/index/column/fts_column/posting/bitpacked_simd_sse41.h b/src/db/index/column/fts_column/posting/bitpacked_simd_sse41.h new file mode 100644 index 000000000..ca82514c4 --- /dev/null +++ b/src/db/index/column/fts_column/posting/bitpacked_simd_sse41.h @@ -0,0 +1,50 @@ +// Copyright 2025-present the zvec project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include + +namespace zvec::fts::simd { + +/// Compute element-wise max of 128 uint32 values across three arrays using +/// SSE4.1 _mm_max_epu32. \p deltas must be 16-byte aligned; \p tfs and +/// \p doc_lens may be unaligned. +void sse41_max_128(const uint32_t *deltas, const uint32_t *tfs, + const uint32_t *doc_lens, size_t start, uint32_t count, + uint32_t &max_delta, uint32_t &max_tf, uint32_t &max_dl); + +/// Pack 128 uint32 values at \p bitwidth bits each into \p out using SSE SIMD +/// interleaved layout (SIMD_fastpackwithoutmask_32). +void sse41_pack_uint32_128(const uint32_t *in, uint8_t bitwidth, uint8_t *out); + +/// Unpack 128 uint32 values at \p bitwidth bits each from \p in using SSE SIMD +/// interleaved layout (SIMD_fastunpack_32). +void sse41_unpack_uint32_128(const uint8_t *in, uint8_t bitwidth, + uint32_t *out); + +/// Compute prefix-sum over \p count (must be 128) delta values, producing +/// absolute doc_ids. Uses SSE2 SIMD prefix-sum with carry propagation. +/// \p deltas must be 16-byte aligned; \p out must be 16-byte aligned. +void sse41_prefix_sum_128(const uint32_t *deltas, uint32_t min_doc_id, + uint32_t count, uint32_t *out); + +/// Find the first index i in arr[start..size) where arr[i] >= target. +/// Uses SSE2 SIMD 4-wide comparison with unsigned-to-signed trick. +/// \p arr must be 16-byte aligned. +size_t sse41_find_first_ge(const uint32_t *arr, uint32_t size, uint32_t target, + size_t start); + +} // namespace zvec::fts::simd diff --git a/src/db/index/column/fts_column/tokenizer/jieba_tokenizer.cc b/src/db/index/column/fts_column/tokenizer/jieba_tokenizer.cc new file mode 100644 index 000000000..77c084f6c --- /dev/null +++ b/src/db/index/column/fts_column/tokenizer/jieba_tokenizer.cc @@ -0,0 +1,177 @@ +// Copyright 2025-present the zvec project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "jieba_tokenizer.h" +#include +#include +// Drop the ERROR macro that cppjieba's transitive defines so it +// does not collide with zvec::GlobalConfig::LogLevel::ERROR below. +#ifdef ERROR +#undef ERROR +#endif +#include + +namespace zvec::fts { + +static std::string get_string_or_default(const ailego::JsonObject &config, + const char *key, + const std::string &default_value) { + auto val = config[key]; + if (val.is_string()) { + std::string result = val.as_string().c_str(); + if (!result.empty()) { + return result; + } + } + return default_value; +} + +// Priority: per-field config > ZVEC_JIEBA_DICT_DIR > GlobalConfig. +static std::string resolve_jieba_dict_dir(const ailego::JsonObject &config) { + std::string dir = get_string_or_default(config, "jieba_dict_dir", ""); + if (!dir.empty()) { + return dir; + } + if (const char *env = std::getenv("ZVEC_JIEBA_DICT_DIR"); env && *env) { + return env; + } + return GlobalConfig::Instance().jieba_dict_dir(); +} + +bool JiebaTokenizer::init(const ailego::JsonObject &config) { + std::string user_dict_path = + get_string_or_default(config, "user_dict_path", ""); + + std::string mode_str = get_string_or_default(config, "cut_mode", "search"); + if (mode_str == "search") { + cut_mode_ = CutMode::kSearch; + } else if (mode_str == "mix") { + cut_mode_ = CutMode::kMix; + } else if (mode_str == "full") { + cut_mode_ = CutMode::kFull; + } else if (mode_str == "hmm") { + cut_mode_ = CutMode::kHmm; + } else { + LOG_ERROR("JiebaTokenizer: unknown cut_mode '%s'", mode_str.c_str()); + return false; + } + + bool needs_dict = cut_mode_ != CutMode::kHmm; + bool needs_model = cut_mode_ != CutMode::kFull; + + std::string dict_dir = resolve_jieba_dict_dir(config); + if ((needs_dict || needs_model) && dict_dir.empty()) { + LOG_ERROR( + "JiebaTokenizer: jieba_dict_dir not configured. Set via " + "extra_params.jieba_dict_dir, ZVEC_JIEBA_DICT_DIR env var, " + "or zvec.set_default_jieba_dict_dir() / " + "zvec.init(jieba_dict_dir=...)."); + return false; + } + + std::string dict_path = needs_dict ? dict_dir + "/jieba.dict.utf8" : ""; + std::string model_path = needs_model ? dict_dir + "/hmm_model.utf8" : ""; + + reset(); + + try { + if (needs_dict) { + dict_trie_ = + std::make_unique(dict_path, user_dict_path); + } + if (needs_model) { + hmm_model_ = std::make_unique(model_path); + } + switch (cut_mode_) { + case CutMode::kSearch: + query_seg_ = std::make_unique(dict_trie_.get(), + hmm_model_.get()); + break; + case CutMode::kMix: + mix_seg_ = std::make_unique(dict_trie_.get(), + hmm_model_.get()); + break; + case CutMode::kFull: + full_seg_ = std::make_unique(dict_trie_.get()); + break; + case CutMode::kHmm: + hmm_seg_ = std::make_unique(hmm_model_.get()); + break; + } + } catch (const std::exception &e) { + LOG_ERROR("JiebaTokenizer init failed: %s", e.what()); + reset(); + return false; + } + + initialized_ = true; + LOG_INFO("JiebaTokenizer init success. dict_dir[%s] cut_mode[%s]", + dict_dir.c_str(), mode_str.c_str()); + return true; +} + +JiebaTokenizer::~JiebaTokenizer() = default; + +void JiebaTokenizer::reset() { + query_seg_.reset(); + mix_seg_.reset(); + full_seg_.reset(); + hmm_seg_.reset(); + dict_trie_.reset(); + hmm_model_.reset(); + initialized_ = false; +} + +std::vector JiebaTokenizer::tokenize(const std::string &text) const { + std::vector tokens; + if (!initialized_ || text.empty()) { + return tokens; + } + + std::vector words; + switch (cut_mode_) { + case CutMode::kSearch: + query_seg_->Cut(text, words, true); + break; + case CutMode::kMix: + mix_seg_->Cut(text, words, true); + break; + case CutMode::kFull: + full_seg_->Cut(text, words); + break; + case CutMode::kHmm: + hmm_seg_->Cut(text, words); + break; + } + + tokens.reserve(words.size()); + // Position = output sequence index, not cppjieba's unicode_offset: + // overlapping sub-words emitted after long parents share unicode_offset, + // which breaks PhraseDocIterator's strict anchor+1 adjacency check. + uint32_t seq = 0; + for (const auto &word : words) { + if (word.word.empty()) { + continue; + } + Token token; + token.text = word.word; + token.offset = word.offset; + token.position = seq++; + tokens.push_back(std::move(token)); + } + + return tokens; +} + +} // namespace zvec::fts diff --git a/src/db/index/column/fts_column/tokenizer/jieba_tokenizer.h b/src/db/index/column/fts_column/tokenizer/jieba_tokenizer.h new file mode 100644 index 000000000..591551ab8 --- /dev/null +++ b/src/db/index/column/fts_column/tokenizer/jieba_tokenizer.h @@ -0,0 +1,86 @@ +// Copyright 2025-present the zvec project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include +#include "tokenizer.h" + +namespace zvec::fts { + +/*! Jieba tokenizer + * + * Wraps cppjieba's low-level segmenters to provide Chinese (and mixed + * Chinese/English) word segmentation. Uses CutForSearch (QuerySegment) by + * default, which produces the finer granularity used for indexing/search. + * + * After init(), the active segmenter is thread-safe for concurrent Cut + * calls, so tokenize() can be invoked from multiple threads. + */ +class JiebaTokenizer : public Tokenizer { + public: + JiebaTokenizer() = default; + ~JiebaTokenizer() override; + + // Non-copyable + JiebaTokenizer(const JiebaTokenizer &) = delete; + JiebaTokenizer &operator=(const JiebaTokenizer &) = delete; + + // JSON config keys: + // "jieba_dict_dir" - directory containing jieba.dict.utf8 + hmm_model.utf8 + // "user_dict_path" - optional user.dict.utf8 + // "cut_mode" - "search" (default) | "mix" | "full" | "hmm" + // + // jieba_dict_dir resolution: per-field > ZVEC_JIEBA_DICT_DIR > + // zvec::GlobalConfig::jieba_dict_dir() (set by SDK on import or via init). + // Stop-word filtering belongs to a TokenFilter, not here. + bool init(const ailego::JsonObject &config) override; + + std::vector tokenize(const std::string &text) const override; + + const char *name() const override { + return "jieba"; + } + + bool is_valid() const { + return initialized_; + } + + // Move-only (unique_ptr members) + JiebaTokenizer(JiebaTokenizer &&) = default; + JiebaTokenizer &operator=(JiebaTokenizer &&) = default; + + private: + enum class CutMode { kSearch, kMix, kFull, kHmm }; + + // Release segmenters first (they hold raw pointers into dict_trie_ / + // hmm_model_), then release the underlying dict/model. + void reset(); + + // Declared before segmenters: reverse-order destruction keeps the raw + // pointers held by segmenters valid until the segmenters die. + std::unique_ptr dict_trie_; + std::unique_ptr hmm_model_; + std::unique_ptr query_seg_; + std::unique_ptr mix_seg_; + std::unique_ptr full_seg_; + std::unique_ptr hmm_seg_; + + CutMode cut_mode_{CutMode::kSearch}; + bool initialized_{false}; +}; + +} // namespace zvec::fts diff --git a/src/db/index/column/fts_column/tokenizer/standard_tokenizer.cc b/src/db/index/column/fts_column/tokenizer/standard_tokenizer.cc new file mode 100644 index 000000000..122d9878b --- /dev/null +++ b/src/db/index/column/fts_column/tokenizer/standard_tokenizer.cc @@ -0,0 +1,76 @@ +// Copyright 2025-present the zvec project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "standard_tokenizer.h" +#include + +namespace zvec::fts { + +bool StandardTokenizer::init(const ailego::JsonObject &config) { + // Read optional max_token_length; keep default (255) if not present or + // if the provided value is zero. + auto length_val = config["max_token_length"]; + if (length_val.is_integer()) { + uint32_t configured_length = static_cast(length_val.as_integer()); + if (configured_length > 0) { + max_token_length_ = configured_length; + } + } + return true; +} + +std::vector StandardTokenizer::tokenize(const std::string &text) const { + std::vector tokens; + uint32_t position = 0; + size_t index = 0; + const size_t text_length = text.size(); + + while (index < text_length) { + // Skip non-alphanumeric characters (delimiters / punctuation). + while (index < text_length && + !std::isalnum(static_cast(text[index]))) { + ++index; + } + if (index >= text_length) { + break; + } + + // Mark the start of an alphanumeric run. + const uint32_t token_start = static_cast(index); + + // Advance to the end of the alphanumeric run. + while (index < text_length && + std::isalnum(static_cast(text[index]))) { + ++index; + } + + const uint32_t token_length = static_cast(index) - token_start; + + // Discard tokens that exceed the configured length limit. + if (token_length > max_token_length_) { + ++position; + continue; + } + + Token token; + token.text = text.substr(token_start, token_length); + token.offset = token_start; + token.position = position++; + tokens.push_back(std::move(token)); + } + + return tokens; +} + +} // namespace zvec::fts diff --git a/src/db/index/column/fts_column/tokenizer/standard_tokenizer.h b/src/db/index/column/fts_column/tokenizer/standard_tokenizer.h new file mode 100644 index 000000000..48a3c25e7 --- /dev/null +++ b/src/db/index/column/fts_column/tokenizer/standard_tokenizer.h @@ -0,0 +1,48 @@ +// Copyright 2025-present the zvec project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include "tokenizer.h" + +namespace zvec::fts { + +/*! Standard tokenizer + * Splits text on non-alphanumeric characters (punctuation, whitespace, etc.) + * and discards the delimiters. Produces lowercase-ready tokens composed of + * letters and digits only. + */ +class StandardTokenizer : public Tokenizer { + public: + /*! Initialise from JSON config. + * Supported keys: + * "max_token_length" (uint32, default 255): tokens longer than this limit + * are silently discarded. + * Always returns true. + */ + bool init(const ailego::JsonObject &config) override; + + std::vector tokenize(const std::string &text) const override; + + const char *name() const override { + return "standard"; + } + + private: + // Tokens whose byte length exceeds this value are discarded. + uint32_t max_token_length_{255}; +}; + +} // namespace zvec::fts diff --git a/src/db/index/column/fts_column/tokenizer/token_filter.cc b/src/db/index/column/fts_column/tokenizer/token_filter.cc new file mode 100644 index 000000000..ffcb9b961 --- /dev/null +++ b/src/db/index/column/fts_column/tokenizer/token_filter.cc @@ -0,0 +1,32 @@ +// Copyright 2025-present the zvec project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "token_filter.h" +#include +#include + +namespace zvec::fts { + +std::vector LowercaseTokenFilter::filter( + std::vector tokens) const { + for (auto &token : tokens) { + std::transform(token.text.begin(), token.text.end(), token.text.begin(), + [](unsigned char character) { + return static_cast(std::tolower(character)); + }); + } + return tokens; +} + +} // namespace zvec::fts diff --git a/src/db/index/column/fts_column/tokenizer/token_filter.h b/src/db/index/column/fts_column/tokenizer/token_filter.h new file mode 100644 index 000000000..ce11fbe14 --- /dev/null +++ b/src/db/index/column/fts_column/tokenizer/token_filter.h @@ -0,0 +1,57 @@ +// Copyright 2025-present the zvec project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include +#include "tokenizer.h" + +namespace zvec::fts { + +/*! Token Filter abstract interface + * Post-process tokenization results, such as case conversion, stopword + * filtering, etc. + */ +class TokenFilter { + public: + virtual ~TokenFilter() = default; + + /*! Filter/transform a list of tokens. + * \param tokens input token list (may be modified in place) + * \return processed token list + */ + virtual std::vector filter(std::vector tokens) const = 0; + + /*! Return filter name + */ + virtual const char *name() const = 0; +}; + +using TokenFilterPtr = std::shared_ptr; + +/*! Lowercase Token Filter + * Convert all token text to lowercase (only handles ASCII characters) + */ +class LowercaseTokenFilter : public TokenFilter { + public: + std::vector filter(std::vector tokens) const override; + + const char *name() const override { + return "lowercase"; + } +}; + +} // namespace zvec::fts diff --git a/src/db/index/column/fts_column/tokenizer/tokenizer.h b/src/db/index/column/fts_column/tokenizer/tokenizer.h new file mode 100644 index 000000000..efc7906fa --- /dev/null +++ b/src/db/index/column/fts_column/tokenizer/tokenizer.h @@ -0,0 +1,64 @@ +// Copyright 2025-present the zvec project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include +#include +#include + +namespace zvec::fts { + +/*! A single token in the tokenization result + */ +struct Token { + // token text content + std::string text; + // start byte offset of token in original text + uint32_t offset{0}; + // token position in document (which word, starting from 0) + uint32_t position{0}; +}; + +/*! Abstract tokenizer interface + * All tokenizer implementations must inherit from this interface + */ +class Tokenizer { + public: + virtual ~Tokenizer() = default; + + /*! Initialise the tokenizer from a JSON configuration object. + * Must be called once before tokenize(). + * \param config JSON object containing tokenizer-specific parameters. + * \return true on success, false on failure. + */ + virtual bool init(const ailego::JsonObject &config) = 0; + + /*! Tokenize input text + * \param text UTF-8 encoded input text + * \return Tokenization result list, sorted by position in ascending + * order + */ + virtual std::vector tokenize(const std::string &text) const = 0; + + /*! Return tokenizer name + */ + virtual const char *name() const = 0; +}; + +using TokenizerPtr = std::shared_ptr; + +} // namespace zvec::fts diff --git a/src/db/index/column/fts_column/tokenizer/tokenizer_factory.cc b/src/db/index/column/fts_column/tokenizer/tokenizer_factory.cc new file mode 100644 index 000000000..ec775678e --- /dev/null +++ b/src/db/index/column/fts_column/tokenizer/tokenizer_factory.cc @@ -0,0 +1,104 @@ +// Copyright 2025-present the zvec project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "tokenizer_factory.h" +#include +#include +#include "jieba_tokenizer.h" +#include "standard_tokenizer.h" +#include "whitespace_tokenizer.h" + +namespace zvec::fts { + +TokenizerPipelinePtr TokenizerFactory::create(const FtsIndexParams ¶ms) { + // Parse extra_params JSON string into a JsonObject. + // Empty string is treated as an empty object; malformed JSON fails. + ailego::JsonObject extra_json; + if (!params.extra_params.empty()) { + ailego::JsonValue parsed; + if (!parsed.parse(params.extra_params.c_str())) { + LOG_ERROR("[TokenizerFactory] failed to parse extra_params JSON: %s", + params.extra_params.c_str()); + return nullptr; + } + if (!parsed.is_object()) { + LOG_ERROR("[TokenizerFactory] extra_params is not a JSON object: %s", + params.extra_params.c_str()); + return nullptr; + } + extra_json = parsed.as_object(); + } + + TokenizerPtr tokenizer = create_tokenizer(params.tokenizer_name, extra_json); + if (!tokenizer) { + LOG_ERROR("[TokenizerFactory] failed to create tokenizer: %s", + params.tokenizer_name.c_str()); + return nullptr; + } + + std::vector filters; + for (const auto &filter_name : params.filters) { + TokenFilterPtr filter = create_filter(filter_name); + if (!filter) { + LOG_ERROR("[TokenizerFactory] failed to create filter: %s", + filter_name.c_str()); + return nullptr; + } + filters.push_back(std::move(filter)); + } + + return std::make_shared(std::move(tokenizer), + std::move(filters)); +} + +std::vector TokenizerPipeline::process(const std::string &text) const { + std::vector tokens = tokenizer_->tokenize(text); + for (const auto &filter : filters_) { + tokens = filter->filter(std::move(tokens)); + } + return tokens; +} + +TokenizerPtr TokenizerFactory::create_tokenizer( + const std::string &tokenizer_name, const ailego::JsonObject &extra_json) { + TokenizerPtr tokenizer; + if (tokenizer_name.empty() || tokenizer_name == "standard") { + tokenizer = std::make_shared(); + } else if (tokenizer_name == "jieba") { + tokenizer = std::make_shared(); + } else if (tokenizer_name == "whitespace") { + tokenizer = std::make_shared(); + } else { + LOG_ERROR("[TokenizerFactory] unknown tokenizer name: %s", + tokenizer_name.c_str()); + return nullptr; + } + + if (!tokenizer->init(extra_json)) { + LOG_ERROR("[TokenizerFactory] failed to init tokenizer: %s", + tokenizer_name.c_str()); + return nullptr; + } + return tokenizer; +} + +TokenFilterPtr TokenizerFactory::create_filter(const std::string &filter_name) { + if (filter_name == "lowercase") { + return std::make_shared(); + } + LOG_ERROR("[TokenizerFactory] unknown filter name: %s", filter_name.c_str()); + return nullptr; +} + +} // namespace zvec::fts diff --git a/src/db/index/column/fts_column/tokenizer/tokenizer_factory.h b/src/db/index/column/fts_column/tokenizer/tokenizer_factory.h new file mode 100644 index 000000000..f118f8e1a --- /dev/null +++ b/src/db/index/column/fts_column/tokenizer/tokenizer_factory.h @@ -0,0 +1,64 @@ +// Copyright 2025-present the zvec project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include +#include "token_filter.h" +#include "tokenizer.h" +#include "../fts_types.h" + +namespace zvec::fts { + +/*! Tokenizer pipeline: contains one tokenizer and a set of token filters + * Execution order: tokenizer → filter[0] → filter[1] → ... + */ +class TokenizerPipeline { + public: + TokenizerPipeline(TokenizerPtr tokenizer, std::vector filters) + : tokenizer_(std::move(tokenizer)), filters_(std::move(filters)) {} + + /*! Tokenize text and apply all filters + */ + std::vector process(const std::string &text) const; + + private: + TokenizerPtr tokenizer_; + std::vector filters_; +}; + +using TokenizerPipelinePtr = std::shared_ptr; + +/*! Tokenizer factory + * Create TokenizerPipeline based on FtsIndexParams configuration. + */ +class TokenizerFactory { + public: + /*! Create tokenizer pipeline from FtsIndexParams. + * \param params FTS index parameters containing tokenizer_name, filters, + * and extra_params (JSON string for tokenizer-specific + * configuration). + * \return Tokenizer pipeline, returns nullptr on failure + */ + static TokenizerPipelinePtr create(const FtsIndexParams ¶ms); + + private: + static TokenizerPtr create_tokenizer(const std::string &tokenizer_name, + const ailego::JsonObject &extra_json); + static TokenFilterPtr create_filter(const std::string &filter_name); +}; + +} // namespace zvec::fts diff --git a/src/db/index/column/fts_column/tokenizer/tokenizer_pipeline_manager.cc b/src/db/index/column/fts_column/tokenizer/tokenizer_pipeline_manager.cc new file mode 100644 index 000000000..b3261319d --- /dev/null +++ b/src/db/index/column/fts_column/tokenizer/tokenizer_pipeline_manager.cc @@ -0,0 +1,124 @@ +// Copyright 2025-present the zvec project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "tokenizer_pipeline_manager.h" +#include +#include +#include + +namespace zvec::fts { + +// ============================================================ +// Key generation +// ============================================================ + +std::string TokenizerPipelineManager::make_key(const FtsIndexParams ¶ms) { + // Build a stable cache key from the three FtsIndexParams fields. + // Format: "tokenizer_name|filter0,filter1,...|extra_params_json" + std::string key; + key += params.tokenizer_name; + key += "|"; + for (size_t i = 0; i < params.filters.size(); ++i) { + if (i > 0) { + key += ","; + } + key += params.filters[i]; + } + key += "|"; + key += params.extra_params; + return key; +} + +// ============================================================ +// acquire +// ============================================================ + +TokenizerPipelinePtr TokenizerPipelineManager::acquire( + const FtsIndexParams ¶ms) { + const std::string key = make_key(params); + + // Fast path: pipeline already exists. + { + std::unique_lock lock(mutex_); + auto it = pipelines_.find(key); + if (it != pipelines_.end()) { + it->second.ref_count++; + LOG_DEBUG( + "TokenizerPipelineManager: reuse pipeline key[%s] ref_count[%d]", + key.c_str(), it->second.ref_count); + return it->second.pipeline; + } + } + + // Create the pipeline outside of the lock to avoid blocking other + // acquire/release calls during the (potentially expensive) construction. + TokenizerPipelinePtr pipeline = TokenizerFactory::create(params); + if (!pipeline) { + LOG_ERROR( + "TokenizerPipelineManager: failed to create pipeline for " + "tokenizer[%s] key[%s]", + params.tokenizer_name.c_str(), key.c_str()); + return nullptr; + } + + // Re-acquire the lock and check whether another thread has already + // created a pipeline with the same key while we were constructing ours. + std::unique_lock lock(mutex_); + auto it = pipelines_.find(key); + if (it != pipelines_.end()) { + it->second.ref_count++; + LOG_DEBUG( + "TokenizerPipelineManager: another thread created pipeline first, " + "discard newly created one. key[%s] ref_count[%d]", + key.c_str(), it->second.ref_count); + return it->second.pipeline; + } + + Entry entry; + entry.pipeline = pipeline; + entry.ref_count = 1; + pipelines_.emplace(key, std::move(entry)); + + LOG_DEBUG("TokenizerPipelineManager: created pipeline key[%s]", key.c_str()); + return pipeline; +} + +// ============================================================ +// release +// ============================================================ + +void TokenizerPipelineManager::release(const FtsIndexParams ¶ms) { + const std::string key = make_key(params); + + std::unique_lock lock(mutex_); + + auto it = pipelines_.find(key); + if (it == pipelines_.end()) { + LOG_WARN("TokenizerPipelineManager: release called for unknown key[%s]", + key.c_str()); + return; + } + + it->second.ref_count--; + LOG_DEBUG("TokenizerPipelineManager: release key[%s] ref_count[%d]", + key.c_str(), it->second.ref_count); + + if (it->second.ref_count <= 0) { + pipelines_.erase(it); + LOG_DEBUG("TokenizerPipelineManager: destroyed pipeline key[%s]", + key.c_str()); + } +} + +} // namespace zvec::fts diff --git a/src/db/index/column/fts_column/tokenizer/tokenizer_pipeline_manager.h b/src/db/index/column/fts_column/tokenizer/tokenizer_pipeline_manager.h new file mode 100644 index 000000000..9c975a062 --- /dev/null +++ b/src/db/index/column/fts_column/tokenizer/tokenizer_pipeline_manager.h @@ -0,0 +1,88 @@ +// Copyright 2025-present the zvec project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include +#include +#include "tokenizer_factory.h" + +namespace zvec::fts { + +/*! + * TokenizerPipelineManager + * + * Global singleton that creates, caches and reference-counts + * TokenizerPipeline instances. Two callers that request a pipeline with + * the same FtsIndexParams will receive the same shared_ptr, and the + * underlying pipeline is destroyed only when the last caller releases it. + * + * The cache key is built from tokenizer_name, filters and extra_params + * fields of FtsIndexParams, producing a deterministic string. + * + * Thread-safety: all public methods are protected by a std::shared_mutex. + * acquire() and release() take an exclusive (write) lock; the map itself is + * never read concurrently with a write. + */ +class TokenizerPipelineManager + : public ailego::Singleton { + public: + /*! + * Build a canonical cache key from the given FtsIndexParams. + * The key is deterministic: tokenizer_name + sorted filters + extra_params. + * + * \param params FTS index parameters + * \return Canonical string key + */ + static std::string make_key(const FtsIndexParams ¶ms); + + /*! + * Acquire a shared pipeline for the given configuration. + * If a pipeline with the same key already exists its reference count is + * incremented and the existing instance is returned. Otherwise a new + * pipeline is created via TokenizerFactory::create(). + * + * \param params FTS index parameters + * \return Shared pipeline pointer, or nullptr on failure + */ + TokenizerPipelinePtr acquire(const FtsIndexParams ¶ms); + + /*! + * Release a previously acquired pipeline identified by its FtsIndexParams. + * Decrements the reference count; when it reaches zero the entry is + * removed from the map and the pipeline is destroyed. + * + * \param params Same FtsIndexParams used during acquire() + */ + void release(const FtsIndexParams ¶ms); + + protected: + //! Constructor (protected, accessed via Singleton::Instance()) + TokenizerPipelineManager() = default; + friend class ailego::Singleton; + + private: + //! Internal map entry + struct Entry { + TokenizerPipelinePtr pipeline; + int ref_count{0}; + }; + + std::shared_mutex mutex_; + std::unordered_map pipelines_; +}; + +} // namespace zvec::fts diff --git a/src/db/index/column/fts_column/tokenizer/whitespace_tokenizer.cc b/src/db/index/column/fts_column/tokenizer/whitespace_tokenizer.cc new file mode 100644 index 000000000..aad42fc7d --- /dev/null +++ b/src/db/index/column/fts_column/tokenizer/whitespace_tokenizer.cc @@ -0,0 +1,56 @@ +// Copyright 2025-present the zvec project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "whitespace_tokenizer.h" +#include + +namespace zvec::fts { + +std::vector WhitespaceTokenizer::tokenize( + const std::string &text) const { + std::vector tokens; + uint32_t position = 0; + size_t index = 0; + const size_t text_length = text.size(); + + while (index < text_length) { + // skip whitespace characters + while (index < text_length && + std::isspace(static_cast(text[index]))) { + ++index; + } + if (index >= text_length) { + break; + } + + // find token start position + const uint32_t token_start = static_cast(index); + + // find token end position + while (index < text_length && + !std::isspace(static_cast(text[index]))) { + ++index; + } + + Token token; + token.text = text.substr(token_start, index - token_start); + token.offset = token_start; + token.position = position++; + tokens.push_back(std::move(token)); + } + + return tokens; +} + +} // namespace zvec::fts diff --git a/src/db/index/column/fts_column/tokenizer/whitespace_tokenizer.h b/src/db/index/column/fts_column/tokenizer/whitespace_tokenizer.h new file mode 100644 index 000000000..e2668c671 --- /dev/null +++ b/src/db/index/column/fts_column/tokenizer/whitespace_tokenizer.h @@ -0,0 +1,39 @@ +// Copyright 2025-present the zvec project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "tokenizer.h" + +namespace zvec::fts { + +/*! Whitespace tokenizer + * Split text by whitespace characters (space, tab, newline, etc.), used as + * default tokenizer + */ +class WhitespaceTokenizer : public Tokenizer { + public: + // WhitespaceTokenizer requires no configuration; always succeeds. + bool init(const ailego::JsonObject & /*config*/) override { + return true; + } + + std::vector tokenize(const std::string &text) const override; + + const char *name() const override { + return "whitespace"; + } +}; + +} // namespace zvec::fts diff --git a/src/db/index/common/doc.cc b/src/db/index/common/doc.cc index 0405eac1d..e298af97b 100644 --- a/src/db/index/common/doc.cc +++ b/src/db/index/common/doc.cc @@ -1281,21 +1281,45 @@ Status VectorQuery::validate_and_sanitize(const FieldSchema *schema) { kMaxOutputFieldSize); } + // Mutual exclusion: fts_ and vector fields cannot be set together. + if (fts_.has_value()) { + if (!query_vector_.empty() || !query_sparse_indices_.empty()) { + return Status::InvalidArgument( + "Invalid query: fts and vector query fields " + "(query_vector/query_sparse_indices) are mutually exclusive"); + } + } + if (schema == nullptr) { + if (fts_.has_value()) { + // FTS query requires a valid field_name_ that resolves to an FTS field. + return Status::InvalidArgument( + "Invalid query: fts requires a valid FTS field, but field[", + field_name_, "] does not exist in the collection"); + } if (query_vector_.empty() && query_sparse_indices_.empty()) { - // Scalar-only filter query + // Scalar-only filter query (no field_name_ needed) return Status::OK(); - } else { - // If a query vector was provided, the field must exist as a vector field - // since we are performing a vector similarity search. + } + // If a query vector was provided, the field must exist as a vector field. + return Status::InvalidArgument( + "Invalid query: query vector is provided, but query field[", + field_name_, + "] does not exist or is not a vector field in the collection"); + } + + // FTS query: field must be an FTS-indexed field. + if (fts_.has_value()) { + if (schema->index_type() != IndexType::FTS) { return Status::InvalidArgument( - "Invalid query: query vector is provided, but query field[", - field_name_, - "] does not exist or is not a vector field in the collection"); + "Invalid query: fts requires an FTS-indexed field, but field[", + field_name_, "] has index type ", + IndexTypeCodeBook::AsString(schema->index_type())); } + return Status::OK(); } - // Vector query + // Vector query: field must be a vector field. if (schema->is_dense_vector()) { // Validate dimension auto dim = schema->dimension(); diff --git a/src/db/index/common/index_params.cc b/src/db/index/common/index_params.cc index cb06f0779..0b696956a 100644 --- a/src/db/index/common/index_params.cc +++ b/src/db/index/common/index_params.cc @@ -12,8 +12,14 @@ // See the License for the specific language governing permissions and // limitations under the License. +#include +#include #include +#include #include +#include "db/index/column/fts_column/fts_pipeline.h" +#include "db/index/column/fts_column/fts_types.h" +#include "db/index/column/fts_column/tokenizer/tokenizer_pipeline_manager.h" #include "type_helper.h" namespace zvec { @@ -38,4 +44,93 @@ std::string VectorIndexParams::vector_index_params_to_string( return oss.str(); } +// ============================================================ +// FtsIndexParams — helpers +// ============================================================ + +static fts::FtsIndexParams to_internal(const FtsIndexParams ¶ms) { + fts::FtsIndexParams p; + p.tokenizer_name = params.tokenizer_name(); + p.filters = params.filters(); + p.extra_params = params.extra_params(); + return p; +} + +// ============================================================ +// FtsIndexParams — opaque pipeline state (Pimpl) +// ============================================================ + +namespace detail { +struct FtsState { + std::once_flag once; + std::shared_ptr pipeline; + bool created{false}; +}; + +struct FtsPipelineHelper { + static std::unique_ptr &state(FtsIndexParams &p) { + return p.state_; + } +}; +} // namespace detail + +// ============================================================ +// FtsIndexParams — ctor / dtor / move +// ============================================================ + +FtsIndexParams::FtsIndexParams(std::string tokenizer_name, + std::vector filters, + std::string extra_params) + : IndexParams(IndexType::FTS), + tokenizer_name_(std::move(tokenizer_name)), + filters_(std::move(filters)), + extra_params_(std::move(extra_params)), + state_(std::make_unique()) {} + +FtsIndexParams::FtsIndexParams(FtsIndexParams &&other) noexcept + : IndexParams(IndexType::FTS), + tokenizer_name_(std::move(other.tokenizer_name_)), + filters_(std::move(other.filters_)), + extra_params_(std::move(other.extra_params_)), + state_(std::move(other.state_)) {} + +FtsIndexParams::~FtsIndexParams() { + if (state_ && state_->created) { + auto internal = to_internal(*this); + fts::TokenizerPipelineManager::Instance().release(internal); + } +} + +// ============================================================ +// FtsIndexParams — pipeline acquisition (internal) +// ============================================================ + +namespace detail { + +Result> AcquireFtsPipeline( + FtsIndexParams ¶ms) { + auto &state_uptr = FtsPipelineHelper::state(params); + if (!state_uptr) { + // Lazily reconstruct after a move-from; not thread-safe vs. a concurrent + // move on the same instance, but moves on a live instance already need + // external synchronisation. + state_uptr = std::make_unique(); + } + auto &st = *state_uptr; + std::call_once(st.once, [&]() { + auto internal = to_internal(params); + st.pipeline = fts::TokenizerPipelineManager::Instance().acquire(internal); + if (st.pipeline) { + st.created = true; + } + }); + if (!st.pipeline) { + return tl::make_unexpected( + Status::InternalError("Failed to create tokenizer pipeline")); + } + return st.pipeline; +} + +} // namespace detail + } // namespace zvec \ No newline at end of file diff --git a/src/db/index/common/proto_converter.cc b/src/db/index/common/proto_converter.cc index d58dc1897..109a09fe0 100644 --- a/src/db/index/common/proto_converter.cc +++ b/src/db/index/common/proto_converter.cc @@ -144,6 +144,28 @@ proto::InvertIndexParams ProtoConverter::ToPb(const InvertIndexParams *params) { return params_pb; } +// FtsIndexParams +FtsIndexParams::Ptr ProtoConverter::FromPb( + const proto::FtsIndexParams ¶ms_pb) { + std::vector filters; + filters.reserve(params_pb.filters_size()); + for (const auto &filter : params_pb.filters()) { + filters.push_back(filter); + } + return std::make_shared( + params_pb.tokenizer_name(), std::move(filters), params_pb.extra_params()); +} + +proto::FtsIndexParams ProtoConverter::ToPb(const FtsIndexParams *params) { + proto::FtsIndexParams params_pb; + params_pb.set_tokenizer_name(params->tokenizer_name()); + for (const auto &filter : params->filters()) { + params_pb.add_filters(filter); + } + params_pb.set_extra_params(params->extra_params()); + return params_pb; +} + // FieldSchema FieldSchema::Ptr ProtoConverter::FromPb(const proto::FieldSchema &schema_pb) { auto schema = std::make_shared(); @@ -215,6 +237,8 @@ IndexParams::Ptr ProtoConverter::FromPb(const proto::IndexParams ¶ms_pb) { return ProtoConverter::FromPb(params_pb.hnsw_rabitq()); } else if (params_pb.has_vamana()) { return ProtoConverter::FromPb(params_pb.vamana()); + } else if (params_pb.has_fts()) { + return ProtoConverter::FromPb(params_pb.fts()); } return nullptr; @@ -286,6 +310,13 @@ proto::IndexParams ProtoConverter::ToPb(const IndexParams *params) { } break; } + case IndexType::FTS: { + auto fts_params = dynamic_cast(params); + if (fts_params) { + params_pb.mutable_fts()->CopyFrom(ProtoConverter::ToPb(fts_params)); + } + break; + } default: break; } diff --git a/src/db/index/common/proto_converter.h b/src/db/index/common/proto_converter.h index 362f95047..4850bac9c 100644 --- a/src/db/index/common/proto_converter.h +++ b/src/db/index/common/proto_converter.h @@ -48,6 +48,10 @@ struct ProtoConverter { const proto::InvertIndexParams ¶ms_pb); static proto::InvertIndexParams ToPb(const InvertIndexParams *params); + // FtsIndexParams + static FtsIndexParams::Ptr FromPb(const proto::FtsIndexParams ¶ms_pb); + static proto::FtsIndexParams ToPb(const FtsIndexParams *params); + // IndexParams static IndexParams::Ptr FromPb(const proto::IndexParams ¶ms_pb); static proto::IndexParams ToPb(const IndexParams *params); diff --git a/src/db/index/common/schema.cc b/src/db/index/common/schema.cc index 1236f5fc2..3c4d92495 100644 --- a/src/db/index/common/schema.cc +++ b/src/db/index/common/schema.cc @@ -301,11 +301,11 @@ Status CollectionSchema::validate() const { "schema validate failed: max_doc_count_per_segment must >= ", MAX_DOC_COUNT_PER_SEGMENT_MIN_THRESHOLD); } - auto v_fields = vector_fields(); - if (v_fields.empty()) { - return Status::InvalidArgument( - "schema validate failed: vector fields is empty"); + if (fields_.empty()) { + return Status::InvalidArgument("schema validate failed: collection[", name_, + "] has no fields"); } + auto v_fields = vector_fields(); if (v_fields.size() > kMaxVectorFieldSize) { return Status::InvalidArgument( "schema validate failed: collection[", name_, @@ -549,6 +549,25 @@ FieldSchemaPtrList CollectionSchema::vector_fields() const { return vector_fields; } +bool CollectionSchema::has_fts_field() const { + for (const auto &field : fields_) { + if (field->index_type() == IndexType::FTS) { + return true; + } + } + return false; +} + +FieldSchemaPtrList CollectionSchema::fts_fields() const { + FieldSchemaPtrList fts; + for (const auto &field : fields_) { + if (field->index_type() == IndexType::FTS) { + fts.push_back(field); + } + } + return fts; +} + uint64_t CollectionSchema::max_doc_count_per_segment() const { return max_doc_count_per_segment_; } diff --git a/src/db/index/common/type_helper.h b/src/db/index/common/type_helper.h index 02b7c0bad..0fe42d0c1 100644 --- a/src/db/index/common/type_helper.h +++ b/src/db/index/common/type_helper.h @@ -37,6 +37,8 @@ struct IndexTypeCodeBook { return IndexType::VAMANA; case proto::IT_INVERT: return IndexType::INVERT; + case proto::IT_FTS: + return IndexType::FTS; default: break; } @@ -58,6 +60,8 @@ struct IndexTypeCodeBook { return proto::IT_VAMANA; case IndexType::INVERT: return proto::IT_INVERT; + case IndexType::FTS: + return proto::IT_FTS; default: break; } @@ -79,6 +83,8 @@ struct IndexTypeCodeBook { return "VAMANA"; case IndexType::INVERT: return "INVERT"; + case IndexType::FTS: + return "FTS"; default: break; } diff --git a/src/db/index/segment/segment.cc b/src/db/index/segment/segment.cc index 96ec3dc37..d643894b5 100644 --- a/src/db/index/segment/segment.cc +++ b/src/db/index/segment/segment.cc @@ -45,6 +45,9 @@ #include "db/common/file_helper.h" #include "db/common/global_resource.h" #include "db/common/typedef.h" +#include "db/index/column/fts_column/fts_column_indexer.h" +#include "db/index/column/fts_column/fts_rocksdb_merge.h" +#include "db/index/column/fts_column/fts_types.h" #include "db/index/column/inverted_column/inverted_indexer.h" #include "db/index/column/vector_column/engine_helper.hpp" #include "db/index/column/vector_column/vector_column_indexer.h" @@ -68,6 +71,7 @@ namespace zvec { + void global_init() { static std::once_flag once; // run once @@ -160,6 +164,13 @@ class SegmentImpl : public Segment, InvertedColumnIndexer::Ptr get_scalar_indexer( const std::string &field_name) const override; + fts::FtsColumnIndexerPtr get_fts_indexer( + const std::string &field_name) const override; + + Result> fts_search( + const std::string &field_name, const fts::FtsAstNode &ast, + const fts::FtsQueryParams ¶ms) override; + const IndexFilter::Ptr get_filter() override; Status create_all_vector_index( @@ -279,6 +290,7 @@ class SegmentImpl : public Segment, const vector_column_params::VectorDataBuffer &buf, Doc *doc); Status insert_scalar_indexer(Doc &doc); + Status insert_fts_indexer(Doc &doc); Status insert_vector_indexer(Doc &doc); Status internal_insert(Doc &doc); Status internal_update(Doc &doc); @@ -298,6 +310,12 @@ class SegmentImpl : public Segment, Status reopen_invert_indexer(bool read_only = false); + // FTS helpers + Status open_fts_indexers(bool create); + Status close_fts_indexers(); + Status flush_fts_indexers(); + Status dump_fts_indexers(); + Status insert_array_to_invert_indexer( const FieldSchema::Ptr &schema, const std::shared_ptr &data, @@ -322,6 +340,11 @@ class SegmentImpl : public Segment, // scalar index (uses segment-local doc ID) InvertedIndexer::Ptr invert_indexers_; + // FTS index (uses segment-local doc ID) + std::shared_ptr fts_ctx_; + std::unordered_map fts_indexers_; + bool has_fts_{false}; + // vector index (uses block-local doc ID, each indexer starts from 0) std::unordered_map memory_vector_indexers_; @@ -447,6 +470,10 @@ Status SegmentImpl::Open(const SegmentOptions &options) { s = load_scalar_index_blocks(); CHECK_RETURN_STATUS(s); + // load FTS indexes + s = open_fts_indexers(false); + CHECK_RETURN_STATUS(s); + // load vector indexes s = load_vector_index_blocks(); CHECK_RETURN_STATUS(s); @@ -510,6 +537,9 @@ Status SegmentImpl::Create(const SegmentOptions &options, uint64_t min_doc_id) { auto s = load_scalar_index_blocks(true); CHECK_RETURN_STATUS(s); + s = open_fts_indexers(true); + CHECK_RETURN_STATUS(s); + doc_id_allocator_.store(min_doc_id); return Status::OK(); @@ -520,6 +550,7 @@ Status SegmentImpl::close() { if (invert_indexers_) { invert_indexers_.reset(); } + close_fts_indexers(); for (const auto &[name, indexers] : vector_indexers_) { for (auto indexer : indexers) { indexer->Close(); @@ -818,6 +849,9 @@ Status SegmentImpl::internal_insert(Doc &doc) { if (!s.ok() && s.code() != StatusCode::ALREADY_EXISTS) { return s; } + // write FTS index + s = insert_fts_indexer(doc); + CHECK_RETURN_STATUS(s); // write vector index s = insert_vector_indexer(doc); if (!s.ok() && s != Status::AlreadyExists()) { @@ -2143,6 +2177,9 @@ Status SegmentImpl::dump() { CHECK_RETURN_STATUS(s); } + s = dump_fts_indexers(); + CHECK_RETURN_STATUS(s); + sealed_ = true; return Status::OK(); @@ -2175,6 +2212,12 @@ Status SegmentImpl::flush() { CHECK_RETURN_STATUS(s); } + // flush FTS indexers + if (has_fts_) { + s = flush_fts_indexers(); + CHECK_RETURN_STATUS(s); + } + // flush vector indexer for (const auto &indexer : memory_vector_indexers_) { if (indexer.second) { @@ -4418,4 +4461,240 @@ Result Segment::Open(const std::string &path, return segment; } +//////////////////////////////////////////////////////////////////////////////////// +// FTS integration +//////////////////////////////////////////////////////////////////////////////////// + +Status SegmentImpl::open_fts_indexers(bool create) { + if (!collection_schema_->has_fts_field()) { + return Status::OK(); + } + + auto fts_fields = collection_schema_->fts_fields(); + has_fts_ = true; + + auto fts_path = FileHelper::MakeFtsIndexPath(seg_path_); + + // Collect CF names and per-CF merge operators + std::vector cf_names; + std::unordered_map> + per_cf_merge_ops; + + for (const auto &field : fts_fields) { + const auto &name = field->name(); + cf_names.push_back(name); // postings + cf_names.push_back(name + kFtsPositionsSuffix); // positions + + per_cf_merge_ops[name] = std::make_shared(); + + // Side CFs (_tf / _max_tf / _doc_len) are present in mutable segments + // that have not yet been dumped. After dump, + // convert_postings_to_bitpacked() inlines their payloads into BitPacked + // postings and the CFs are dropped. + // + // When opening an existing segment (create=false), we always include the + // side CF names so that segments closed without dump (e.g. graceful + // shutdown with only flush) can still perform accurate BM25 scoring via + // the Roaring posting path. If the CFs were already dropped (post-dump + // immutable segment), the open will fail and we retry without them. + if (create) { + cf_names.push_back(name + kFtsTfSuffix); + cf_names.push_back(name + kFtsMaxTfSuffix); + cf_names.push_back(name + kFtsDocLenSuffix); + per_cf_merge_ops[name + kFtsMaxTfSuffix] = + std::make_shared(); + } + } + cf_names.push_back(kFtsStatCfName); + + fts_ctx_ = std::make_shared(); + Status s; + + // Whether side CFs are available after open + bool has_side_cfs = create; + + bool enable_hash_skiplist = true; + if (create) { + s = fts_ctx_->create(RocksdbContext::Args{ + fts_path, cf_names, nullptr, per_cf_merge_ops, enable_hash_skiplist}); + } else { + // Try opening with side CFs first (un-dumped mutable segment). + // If they don't exist (post-dump), retry without them. + std::vector cf_names_with_side = cf_names; + auto per_cf_merge_ops_with_side = per_cf_merge_ops; + for (const auto &field : fts_fields) { + const auto &name = field->name(); + cf_names_with_side.push_back(name + kFtsTfSuffix); + cf_names_with_side.push_back(name + kFtsMaxTfSuffix); + cf_names_with_side.push_back(name + kFtsDocLenSuffix); + per_cf_merge_ops_with_side[name + kFtsMaxTfSuffix] = + std::make_shared(); + } + s = fts_ctx_->open( + RocksdbContext::Args{fts_path, cf_names_with_side, nullptr, + per_cf_merge_ops_with_side, enable_hash_skiplist}, + options_.read_only_); + if (s.ok()) { + has_side_cfs = true; + } else { + // Side CFs not found (immutable segment after dump) — retry without. + fts_ctx_ = std::make_shared(); + s = fts_ctx_->open( + RocksdbContext::Args{fts_path, cf_names, nullptr, per_cf_merge_ops}, + options_.read_only_); + } + } + if (!s.ok()) { + LOG_ERROR("open_fts_indexers: failed to %s FTS RocksDB at [%s]: %s", + create ? "create" : "open", fts_path.c_str(), + s.message().c_str()); + return s; + } + + auto *stat_cf = fts_ctx_->get_cf(kFtsStatCfName); + + for (const auto &field : fts_fields) { + const auto &name = field->name(); + auto *postings_cf = fts_ctx_->get_cf(name); + auto *positions_cf = fts_ctx_->get_cf(name + kFtsPositionsSuffix); + // Side CF handles are available when the segment has not been dumped + // (side CFs still exist). For dumped immutable segments the handles + // are nullptr and FtsColumnIndexer falls back to BitPacked inline + // payloads or tf=1/doc_len=1 defaults. + auto *term_freq_cf = + has_side_cfs ? fts_ctx_->get_cf(name + kFtsTfSuffix) : nullptr; + auto *max_tf_cf = + has_side_cfs ? fts_ctx_->get_cf(name + kFtsMaxTfSuffix) : nullptr; + auto *doc_len_cf = + has_side_cfs ? fts_ctx_->get_cf(name + kFtsDocLenSuffix) : nullptr; + + auto indexer = std::make_shared(); + + auto ret = indexer->open(field, fts_ctx_.get(), postings_cf, positions_cf, + term_freq_cf, max_tf_cf, doc_len_cf, stat_cf); + if (!ret.has_value()) { + LOG_ERROR( + "open_fts_indexers: FtsColumnIndexer::open failed for field[%s] " + "err[%s] postings_cf[%p] positions_cf[%p] stat_cf[%p]", + name.c_str(), ret.error().message().c_str(), (void *)postings_cf, + (void *)positions_cf, (void *)stat_cf); + return Status::InternalError("Failed to open FTS indexer: ", name, " ", + ret.error().message()); + } + + fts_indexers_[name] = indexer; + } + + return Status::OK(); +} + +Status SegmentImpl::flush_fts_indexers() { + for (const auto &[name, indexer] : fts_indexers_) { + auto ret = indexer->flush(); + if (!ret.has_value()) { + return Status::InternalError("FTS flush failed: ", name, " ", + ret.error().message()); + } + } + auto s = fts_ctx_->flush(); + CHECK_RETURN_STATUS(s); + return Status::OK(); +} + +Status SegmentImpl::close_fts_indexers() { + fts_indexers_.clear(); + if (fts_ctx_) { + auto s = fts_ctx_->close(); + fts_ctx_.reset(); + return s; + } + return Status::OK(); +} + +Status SegmentImpl::insert_fts_indexer(Doc &doc) { + if (!has_fts_) { + return Status::OK(); + } + for (const auto &field : collection_schema_->fts_fields()) { + auto it = fts_indexers_.find(field->name()); + if (it == fts_indexers_.end()) { + return Status::InternalError("FTS indexer not found: ", field->name()); + } + auto value = doc.get(field->name()); + if (value.has_value()) { + auto segment_doc_id = doc_ids_.size(); + auto ret = it->second->insert(segment_doc_id, value.value()); + if (!ret.has_value()) { + return Status::InternalError("FTS insert failed: ", field->name(), " ", + ret.error().message()); + } + } + } + return Status::OK(); +} + +Status SegmentImpl::dump_fts_indexers() { + if (!has_fts_) { + return Status::OK(); + } + + // flush all indexers + for (const auto &[name, indexer] : fts_indexers_) { + auto ret = indexer->flush(); + if (!ret.has_value()) { + return Status::InternalError("FTS flush failed during dump: ", name, " ", + ret.error().message()); + } + } + + // convert postings to bitpacked format + for (const auto &[name, indexer] : fts_indexers_) { + auto ret = indexer->convert_postings_to_bitpacked(); + if (!ret.has_value()) { + return Status::InternalError("FTS convert_postings_to_bitpacked failed: ", + name, " ", ret.error().message()); + } + } + + // reset side CFs and drop $TF/$MAX_TF/$DOC_LEN CFs + for (const auto &[name, indexer] : fts_indexers_) { + indexer->reset_side_cfs(); + } + for (const auto &field : collection_schema_->fts_fields()) { + const auto &name = field->name(); + fts_ctx_->drop_cf(name + kFtsTfSuffix); + fts_ctx_->drop_cf(name + kFtsMaxTfSuffix); + fts_ctx_->drop_cf(name + kFtsDocLenSuffix); + } + + return Status::OK(); +} + +fts::FtsColumnIndexerPtr SegmentImpl::get_fts_indexer( + const std::string &field_name) const { + auto it = fts_indexers_.find(field_name); + if (it != fts_indexers_.end()) { + return it->second; + } + return nullptr; +} + +Result> SegmentImpl::fts_search( + const std::string &field_name, const fts::FtsAstNode &ast, + const fts::FtsQueryParams ¶ms) { + auto indexer = get_fts_indexer(field_name); + if (!indexer) { + return tl::make_unexpected( + Status::NotFound("FTS indexer not found: ", field_name)); + } + + auto ret = indexer->search(ast, params); + if (!ret.has_value()) { + return tl::make_unexpected(Status::InternalError( + "FTS search failed: ", field_name, " ", ret.error().message())); + } + + return std::move(ret.value()); +} + } // namespace zvec \ No newline at end of file diff --git a/src/db/index/segment/segment.h b/src/db/index/segment/segment.h index 06e05d78c..3b21c6487 100644 --- a/src/db/index/segment/segment.h +++ b/src/db/index/segment/segment.h @@ -25,6 +25,7 @@ #include #include #include +#include "db/index/column/fts_column/fts_column_indexer.h" #include "db/index/column/inverted_column/inverted_column_indexer.h" #include "db/index/column/inverted_column/inverted_indexer.h" #include "db/index/column/vector_column/combined_vector_column_indexer.h" @@ -172,6 +173,14 @@ class Segment { virtual InvertedColumnIndexer::Ptr get_scalar_indexer( const std::string &field_name) const = 0; + // caller hold segment shared_ptr for segment handle the indexer's lifetime + virtual fts::FtsColumnIndexerPtr get_fts_indexer( + const std::string &field_name) const = 0; + + virtual Result> fts_search( + const std::string &field_name, const fts::FtsAstNode &ast, + const fts::FtsQueryParams ¶ms) = 0; + virtual const IndexFilter::Ptr get_filter() = 0; // for others diff --git a/src/db/index/segment/segment_helper.cc b/src/db/index/segment/segment_helper.cc index 7d1adc792..ff5204a00 100644 --- a/src/db/index/segment/segment_helper.cc +++ b/src/db/index/segment/segment_helper.cc @@ -24,10 +24,16 @@ #if RABITQ_SUPPORTED #include "core/algorithm/hnsw_rabitq/rabitq_params.h" #endif +#include #include "db/common/constants.h" #include "db/common/file_helper.h" #include "db/common/global_resource.h" +#include "db/common/rocksdb_context.h" #include "db/common/typedef.h" +#include "db/index/column/fts_column/fts_column_indexer.h" +#include "db/index/column/fts_column/fts_rocksdb_merge.h" +#include "db/index/column/fts_column/fts_rocksdb_reducer.h" +#include "db/index/column/fts_column/fts_types.h" #include "db/index/column/inverted_column/inverted_indexer.h" #include "db/index/column/vector_column/engine_helper.hpp" #include "db/index/column/vector_column/vector_column_indexer.h" @@ -38,7 +44,6 @@ #include "zvec/core/framework/index_factory.h" #include "zvec/core/framework/index_meta.h" #include "zvec/core/framework/index_reformer.h" -#include "roaring.hh" namespace zvec { @@ -68,7 +73,9 @@ Status SegmentHelper::Execute(SegmentTask::Ptr &task) { class RowIdFilter : public IndexFilter { public: - explicit RowIdFilter(roaring::Roaring &&delete_row_id_bitmap) + // Copies the bitmap so callers can keep using it (e.g. share with FTS + // reduce). + explicit RowIdFilter(const roaring::Roaring &delete_row_id_bitmap) : delete_row_id_bitmap_(delete_row_id_bitmap) {} bool is_filtered(uint64_t id) const override { @@ -87,6 +94,10 @@ Status SegmentHelper::ExecuteCompactTask(CompactTask &task) { auto filter = task.filter_; auto output_segment_id = task.output_segment_id_; + // input_segments must be pre-sorted by ascending min_doc_id so the + // shared delete_row_id_bitmap (built by FilterRecordBatch, consumed by + // both vector and FTS reducers) is well-defined. Guaranteed upstream by + // SegmentManager::get_segments(). auto columns = schema->forward_field_names(); // make segment path @@ -118,8 +129,10 @@ Status SegmentHelper::ExecuteCompactTask(CompactTask &task) { return Status::OK(); } + // RowIdFilter copies the bitmap so ReduceFts below can reuse it; sharing + // lets the FTS reducer skip its own per-doc dense rank table. std::shared_ptr row_id_filter = - std::make_shared(std::move(delete_row_id_bitmap)); + std::make_shared(delete_row_id_bitmap); s = ReduceVectorIndex(schema, input_segments, output_segment_path, row_id_filter, block_id_generator, min_doc_id, @@ -128,6 +141,12 @@ Status SegmentHelper::ExecuteCompactTask(CompactTask &task) { LOG_INFO("Compacted vector index"); + s = ReduceFts(schema, input_segments, output_segment_path, + delete_row_id_bitmap); + CHECK_RETURN_STATUS(s); + + LOG_INFO("Compacted fts index"); + auto new_segment_meta = std::make_shared(); new_segment_meta->set_id(task.output_segment_id_); new_segment_meta->set_persisted_blocks(block_metas); @@ -903,6 +922,117 @@ arrow::Status SegmentHelper::FilterRecordBatch( return arrow::Status::OK(); } +Status SegmentHelper::ReduceFts(const CollectionSchema::Ptr &schema, + const std::vector &input_segments, + const std::string &output_segment_path, + const roaring::Roaring &delete_row_id_bitmap) { + if (!schema->has_fts_field()) { + return Status::OK(); + } + if (input_segments.empty()) { + return Status::OK(); + } + + auto fts_fields = schema->fts_fields(); + + // Build the destination FTS RocksDB with the post-dump CF layout: + // postings + positions per field, plus the shared stat CF. Side CFs + // ($TF/$MAX_TF/$DOC_LEN) are skipped — the reducer writes BitPacked + // directly, matching the immutable-segment shape after + // convert_postings_to_bitpacked(). + auto dst_fts_path = FileHelper::MakeFtsIndexPath(output_segment_path); + std::vector cf_names; + std::unordered_map> + per_cf_merge_ops; + for (const auto &field : fts_fields) { + const auto &name = field->name(); + cf_names.push_back(name); + cf_names.push_back(name + kFtsPositionsSuffix); + per_cf_merge_ops[name] = std::make_shared(); + } + cf_names.push_back(kFtsStatCfName); + + auto dst_ctx = std::make_shared(); + Status s = dst_ctx->create( + RocksdbContext::Args{dst_fts_path, cf_names, nullptr, per_cf_merge_ops, + /*enable_hash_skiplist=*/true}); + if (!s.ok()) { + LOG_ERROR("ReduceFts: create destination FTS RocksDB failed at [%s]: %s", + dst_fts_path.c_str(), s.message().c_str()); + return s; + } + + // Feed segments in caller's order — matches the scan order + // delete_row_id_bitmap is keyed by. + auto *dst_stat_cf = dst_ctx->get_cf(kFtsStatCfName); + for (const auto &field : fts_fields) { + const auto &name = field->name(); + auto *dst_postings_cf = dst_ctx->get_cf(name); + auto *dst_positions_cf = dst_ctx->get_cf(name + kFtsPositionsSuffix); + + fts::FtsRocksdbReducer reducer; + auto init_ret = reducer.init(name, dst_ctx.get(), dst_postings_cf, + dst_positions_cf, dst_stat_cf); + if (!init_ret) { + auto err = init_ret.error(); + LOG_ERROR("ReduceFts: reducer.init failed. field[%s] err[%s]", + name.c_str(), err.message().c_str()); + (void)dst_ctx->close(); + return err; + } + + for (auto &seg : input_segments) { + auto src_indexer = seg->get_fts_indexer(name); + if (!src_indexer) { + auto err = Status::InternalError( + "ReduceFts: source segment missing FTS indexer. segment_id=", + seg->id(), " field=", name); + LOG_ERROR("%s", err.message().c_str()); + (void)dst_ctx->close(); + return err; + } + fts::FtsSegmentStats stats{seg->meta()->min_doc_id(), + seg->meta()->max_doc_id(), + seg->meta()->doc_count()}; + auto feed_ret = + reducer.feed(stats, src_indexer->ctx(), src_indexer->postings_cf(), + src_indexer->positions_cf()); + if (!feed_ret) { + auto err = feed_ret.error(); + LOG_ERROR("ReduceFts: reducer.feed failed. field[%s] err[%s]", + name.c_str(), err.message().c_str()); + (void)dst_ctx->close(); + return err; + } + } + + auto reduce_ret = reducer.reduce(delete_row_id_bitmap); + if (!reduce_ret) { + auto err = reduce_ret.error(); + LOG_ERROR("ReduceFts: reducer.reduce failed. field[%s] err[%s]", + name.c_str(), err.message().c_str()); + (void)dst_ctx->close(); + return err; + } + (void)reducer.cleanup(); + } + + s = dst_ctx->flush(); + if (!s.ok()) { + LOG_ERROR("ReduceFts: flush destination FTS RocksDB failed: %s", + s.message().c_str()); + (void)dst_ctx->close(); + return s; + } + s = dst_ctx->close(); + if (!s.ok()) { + LOG_ERROR("ReduceFts: close destination FTS RocksDB failed: %s", + s.message().c_str()); + return s; + } + return Status::OK(); +} + Status SegmentHelper::ExecuteCreateVectorIndexTask( CreateVectorIndexTask &task) { if (task.column_to_build_vector_index_ == "") { @@ -936,4 +1066,4 @@ Status SegmentHelper::ExecuteDropScalarIndexTask(DropScalarIndexTask &task) { &task.output_scalar_indexer_); } -} // namespace zvec \ No newline at end of file +} // namespace zvec diff --git a/src/db/index/segment/segment_helper.h b/src/db/index/segment/segment_helper.h index a1d5bb754..96b8ee8fd 100644 --- a/src/db/index/segment/segment_helper.h +++ b/src/db/index/segment/segment_helper.h @@ -18,6 +18,7 @@ #include #include #include +#include #include #include #include @@ -230,6 +231,16 @@ class SegmentHelper { const core::IndexProvider::Pointer &raw_vector_provider, std::shared_ptr *out_field); + // Build a fresh FTS RocksDB under output_segment_path by streaming all + // FTS fields from input_segments through FtsRocksdbReducer. + // - input_segments: ascending min_doc_id, contiguous doc_id range. + // - delete_row_id_bitmap: deleted positions in input scan order + // (shared with the vector path); empty for pure consolidation. + static Status ReduceFts(const CollectionSchema::Ptr &schema, + const std::vector &input_segments, + const std::string &output_segment_path, + const roaring::Roaring &delete_row_id_bitmap); + static arrow::Status FilterRecordBatch( const std::shared_ptr &batch, const IndexFilter::Ptr filter, uint32_t row_id_offset, @@ -238,4 +249,4 @@ class SegmentHelper { uint64_t *max_doc_id); }; -} // namespace zvec \ No newline at end of file +} // namespace zvec diff --git a/src/db/proto/zvec.proto b/src/db/proto/zvec.proto index 197914c76..e94d1d399 100644 --- a/src/db/proto/zvec.proto +++ b/src/db/proto/zvec.proto @@ -62,6 +62,8 @@ enum IndexType { IT_VAMANA = 5; // Invert Index IT_INVERT = 10; + // Full-Text Search Index + IT_FTS = 11; }; enum QuantizeType { @@ -131,6 +133,12 @@ message VamanaIndexParams { bool use_id_map = 7; } +message FtsIndexParams { + string tokenizer_name = 1; + repeated string filters = 2; + string extra_params = 3; +}; + message IndexParams { oneof params { InvertIndexParams invert = 1; @@ -139,6 +147,7 @@ message IndexParams { IVFIndexParams ivf = 4; HnswRabitqIndexParams hnsw_rabitq = 5; VamanaIndexParams vamana = 6; + FtsIndexParams fts = 7; }; }; diff --git a/src/db/sqlengine/analyzer/query_analyzer.cc b/src/db/sqlengine/analyzer/query_analyzer.cc index 4d981370a..c4af8f366 100644 --- a/src/db/sqlengine/analyzer/query_analyzer.cc +++ b/src/db/sqlengine/analyzer/query_analyzer.cc @@ -400,6 +400,11 @@ Result QueryAnalyzer::create_queryinfo_from_sqlinfo( // set group by query_info->set_group_by(select_info->group_by()); + // set fts query + if (select_info->has_fts_query()) { + query_info->set_fts_cond_info(select_info->fts_cond_info()); + } + return query_info; } diff --git a/src/db/sqlengine/analyzer/query_info.cc b/src/db/sqlengine/analyzer/query_info.cc index f6f066312..3a506272c 100644 --- a/src/db/sqlengine/analyzer/query_info.cc +++ b/src/db/sqlengine/analyzer/query_info.cc @@ -85,6 +85,12 @@ std::string QueryInfo::to_string() const { ")\n"); } + str += "fts_cond:\n"; + if (fts_cond_info_ != nullptr) { + str += fts_cond_info_->to_string(); + str += "\n"; + } + str += "filter_cond:\n"; if (filter_cond_ != nullptr) { str += filter_cond_->text(); diff --git a/src/db/sqlengine/analyzer/query_info.h b/src/db/sqlengine/analyzer/query_info.h index 653231a74..ad9b381fc 100644 --- a/src/db/sqlengine/analyzer/query_info.h +++ b/src/db/sqlengine/analyzer/query_info.h @@ -22,6 +22,7 @@ #include #include #include "db/common/constants.h" +#include "db/sqlengine/common/fts_cond_info.h" #include "db/sqlengine/common/group_by.h" #include "query_field_info.h" #include "query_node.h" @@ -125,6 +126,7 @@ class QueryInfo { bool reverse_sort_{false}; }; + public: QueryInfo() = default; ~QueryInfo() = default; @@ -161,6 +163,14 @@ class QueryInfo { return vector_cond_info_; } + void set_fts_cond_info(FtsCondInfo::Ptr value) { + fts_cond_info_ = std::move(value); + } + + const FtsCondInfo::Ptr &fts_cond_info() const { + return fts_cond_info_; + } + void set_query_topn(uint32_t value) { query_topn_ = value; } @@ -340,6 +350,7 @@ class QueryInfo { QueryNode::Ptr filter_cond_{nullptr}; QueryVectorCondInfo::Ptr vector_cond_info_{nullptr}; + FtsCondInfo::Ptr fts_cond_info_{nullptr}; // these two are for post filtering only QueryNode::Ptr post_invert_cond_{nullptr}; diff --git a/src/db/sqlengine/common/fts_cond_info.h b/src/db/sqlengine/common/fts_cond_info.h new file mode 100644 index 000000000..17de4ad75 --- /dev/null +++ b/src/db/sqlengine/common/fts_cond_info.h @@ -0,0 +1,43 @@ +// Copyright 2025-present the zvec project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include "db/index/column/fts_column/fts_query_ast.h" + +namespace zvec::sqlengine { + +struct FtsCondInfo { + using Ptr = std::shared_ptr; + + FtsCondInfo() = default; + + FtsCondInfo(std::string field_name, fts::FtsAstNodePtr ast) + : field_name(std::move(field_name)), fts_ast(std::move(ast)) {} + + std::string to_string() const { + std::string str = field_name + " MATCH "; + if (fts_ast) { + str += fts_ast->text(); + } + return str; + } + + std::string field_name; + fts::FtsAstNodePtr fts_ast; +}; + +} // namespace zvec::sqlengine diff --git a/src/db/sqlengine/parser/select_info.cc b/src/db/sqlengine/parser/select_info.cc index 87ac39975..c4bed19df 100644 --- a/src/db/sqlengine/parser/select_info.cc +++ b/src/db/sqlengine/parser/select_info.cc @@ -196,6 +196,11 @@ std::string SelectInfo::to_string() { str += "\n"; } + if (fts_cond_info_ != nullptr) { + str += "fts_cond: " + fts_cond_info_->to_string(); + str += "\n"; + } + return str; } diff --git a/src/db/sqlengine/parser/select_info.h b/src/db/sqlengine/parser/select_info.h index e1a312013..c393ef756 100644 --- a/src/db/sqlengine/parser/select_info.h +++ b/src/db/sqlengine/parser/select_info.h @@ -17,6 +17,7 @@ #include #include #include +#include "db/sqlengine/common/fts_cond_info.h" #include "db/sqlengine/common/group_by.h" #include "base_info.h" #include "node.h" @@ -69,6 +70,18 @@ class SelectInfo : public BaseInfo { return group_by_; } + void set_fts_cond_info(FtsCondInfo::Ptr value) { + fts_cond_info_ = std::move(value); + } + + const FtsCondInfo::Ptr &fts_cond_info() const { + return fts_cond_info_; + } + + bool has_fts_query() const { + return fts_cond_info_ != nullptr; + } + std::string to_string(); private: @@ -82,6 +95,7 @@ class SelectInfo : public BaseInfo { int limit_{-1}; bool include_vector_{false}; bool include_doc_id_{false}; + FtsCondInfo::Ptr fts_cond_info_{nullptr}; }; } // namespace zvec::sqlengine diff --git a/src/db/sqlengine/planner/doc_filter.cc b/src/db/sqlengine/planner/doc_filter.cc index 756a1b972..0f44e6e97 100644 --- a/src/db/sqlengine/planner/doc_filter.cc +++ b/src/db/sqlengine/planner/doc_filter.cc @@ -17,7 +17,6 @@ #include #include #include -#include #include "db/sqlengine/planner/invert_search.h" namespace zvec::sqlengine { @@ -107,7 +106,8 @@ std::optional DocFilter::get_forward_bit(uint64_t id) const { return std::nullopt; } -std::optional> DocFilter::get_bf_by_keys_and_update() { +std::optional> DocFilter::get_bf_by_keys_and_update( + float ratio) { auto meta = segment_->meta(); if (!meta) { return std::nullopt; @@ -117,9 +117,7 @@ std::optional> DocFilter::get_bf_by_keys_and_update() { return std::nullopt; } size_t doc_count = meta->doc_count(); - float brute_force_by_keys_ratio = - GlobalConfig::Instance().brute_force_by_keys_ratio(); - uint64_t bf_by_keys_threshold = meta->doc_count() * brute_force_by_keys_ratio; + uint64_t bf_by_keys_threshold = static_cast(doc_count * ratio); // decide to use brute force by keys or not if (size_t match_count = invert_result_->count(); @@ -128,13 +126,16 @@ std::optional> DocFilter::get_bf_by_keys_and_update() { invert_result_->extract_ids(&ids); invert_filter_.reset(); invert_result_.reset(); - LOG_INFO("Use brute force by keys, doc_count[%zu] invert_result_count[%zu]", - doc_count, match_count); + LOG_INFO( + "Use brute force by keys, doc_count[%zu] invert_result_count[%zu] " + "ratio[%.4f]", + doc_count, match_count, ratio); return std::vector(ids.begin(), ids.end()); } else { LOG_DEBUG( - "Not use brute force by keys, doc_count[%zu] invert_result_count[%zu]", - doc_count, match_count); + "Not use brute force by keys, doc_count[%zu] invert_result_count[%zu] " + "ratio[%.4f]", + doc_count, match_count, ratio); } return std::nullopt; } diff --git a/src/db/sqlengine/planner/doc_filter.h b/src/db/sqlengine/planner/doc_filter.h index b662a7425..7f4dffbd1 100644 --- a/src/db/sqlengine/planner/doc_filter.h +++ b/src/db/sqlengine/planner/doc_filter.h @@ -44,8 +44,11 @@ class DocFilter : public IndexFilter { bool is_filtered(uint64_t id) const override; - //! get brute force by keys and clear `invert_filter_` if suitable - std::optional> get_bf_by_keys_and_update(); + //! When invert cardinality <= \p ratio * doc_count, extract the ids and + //! clear invert_filter_ so the caller drives evaluation by ids instead of + //! bitmap-checking. Ratio is per-caller (vector vs FTS use different + //! GlobalConfig knobs) because per-candidate cost differs. + std::optional> get_bf_by_keys_and_update(float ratio); bool empty() const; diff --git a/src/db/sqlengine/planner/fts_recall_node.cc b/src/db/sqlengine/planner/fts_recall_node.cc new file mode 100644 index 000000000..45313d9e0 --- /dev/null +++ b/src/db/sqlengine/planner/fts_recall_node.cc @@ -0,0 +1,143 @@ +// Copyright 2025-present the zvec project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "db/sqlengine/planner/fts_recall_node.h" +#include +#include +#include +#include "db/sqlengine/common/util.h" + +namespace cp = arrow::compute; + +namespace zvec::sqlengine { + +FtsRecallNode::FtsRecallNode(Segment::Ptr segment, QueryInfo::Ptr query_info, + DocFilter::Ptr doc_filter, int batch_size) + : segment_(std::move(segment)), + query_info_(std::move(query_info)), + doc_filter_(std::move(doc_filter)), + fetched_columns_(query_info_->get_all_fetched_scalar_field_names()), + batch_size_(batch_size) { + auto table = segment_->fetch(fetched_columns_, std::vector{}); + // Append BM25 score column so downstream fill_doc_score() surfaces it to + // the Python Doc.score, matching the vector-recall path. + schema_ = Util::append_field(*table->schema(), kFieldScore, arrow::float32()); +} + +arrow::AsyncGenerator> FtsRecallNode::gen() { + auto state_ptr = std::make_shared(); + return [self = shared_from_this(), state_ptr = std::move(state_ptr)]() + -> arrow::Future> { + auto &state = *state_ptr; + + if (!state.iter_) { + auto fts_ret = self->prepare(); + if (!fts_ret) { + return arrow::Future>::MakeFinished( + arrow::Status::ExecutionError("prepare fts failed:", + fts_ret.error().c_str())); + } + state.fts_result_ = fts_ret.value(); + state.iter_ = state.fts_result_->create_iterator(); + } + + if (!state.iter_->valid()) { + return arrow::Future>::MakeFinished( + std::nullopt); + } + + std::vector indices; + indices.reserve(self->batch_size_); + arrow::FloatBuilder score_builder; + for (int i = 0; state.iter_->valid() && i < self->batch_size_; + i++, state.iter_->next()) { + indices.push_back(state.iter_->doc_id()); + auto s = score_builder.Append(state.iter_->score()); + if (!s.ok()) { + return arrow::Future>::MakeFinished( + arrow::Status::ExecutionError("score builder append failed:", + s.ToString())); + } + } + if (indices.empty()) { + return arrow::Future>::MakeFinished( + std::nullopt); + } + + auto table = self->segment_->fetch(self->fetched_columns_, indices); + if (!table) { + return arrow::Future>::MakeFinished( + arrow::Status::UnknownError("fetch table failed")); + } + auto batch = table->CombineChunksToBatch(); + if (!batch.ok()) { + return arrow::Future>::MakeFinished( + arrow::Status::ExecutionError("combine chunks to batch failed:", + batch.status().ToString())); + } + auto score_array = score_builder.Finish(); + if (!score_array.ok()) { + return arrow::Future>::MakeFinished( + arrow::Status::ExecutionError("finish score builder failed:", + score_array.status().ToString())); + } + auto record_batch = std::move(batch.ValueUnsafe()); + auto with_score = + record_batch->AddColumn(record_batch->num_columns(), kFieldScore, + score_array.MoveValueUnsafe()); + if (!with_score.ok()) { + return arrow::Future>::MakeFinished( + arrow::Status::ExecutionError("add score column failed:", + with_score.status().ToString())); + } + cp::ExecBatch exec_batch(*with_score.ValueUnsafe()); + return arrow::Future>::MakeFinished( + std::move(exec_batch)); + }; +} + +Result FtsRecallNode::prepare() { + auto filter_status = doc_filter_->compute_filter(); + if (!filter_status.ok()) { + return tl::make_unexpected(filter_status); + } + + const auto &fts_cond = query_info_->fts_cond_info(); + if (!fts_cond) { + return tl::make_unexpected( + Status::InvalidArgument("FtsRecallNode: no fts_cond_info in query")); + } + + fts::FtsQueryParams params; + params.topk = query_info_->query_topn(); + // Brute-force path: get_bf_by_keys_and_update also clears invert_filter_ + // when it returns ids, so the filter set below won't double-check them. + if (auto bf_keys = doc_filter_->get_bf_by_keys_and_update( + GlobalConfig::Instance().fts_brute_force_by_keys_ratio())) { + params.candidate_ids = std::move(bf_keys.value()); + } + // Push down remaining filters (delete / forward) so filtered docs are + // skipped during scoring and we still return up to topk results. + params.filter = doc_filter_->empty() ? nullptr : doc_filter_; + + auto results = + segment_->fts_search(fts_cond->field_name, *fts_cond->fts_ast, params); + if (!results) { + return tl::make_unexpected(results.error()); + } + + return std::make_shared(std::move(results.value())); +} + +} // namespace zvec::sqlengine diff --git a/src/db/sqlengine/planner/fts_recall_node.h b/src/db/sqlengine/planner/fts_recall_node.h new file mode 100644 index 000000000..ec1079fc3 --- /dev/null +++ b/src/db/sqlengine/planner/fts_recall_node.h @@ -0,0 +1,59 @@ +// Copyright 2025-present the zvec project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include +#include "db/index/column/common/index_results.h" +#include "db/index/column/fts_column/fts_index_results.h" +#include "db/index/segment/segment.h" +#include "db/sqlengine/analyzer/query_info.h" +#include "db/sqlengine/planner/doc_filter.h" + +namespace cp = arrow::compute; + +namespace zvec::sqlengine { + +class FtsRecallNode : public std::enable_shared_from_this { + public: + FtsRecallNode(Segment::Ptr segment, QueryInfo::Ptr query_info, + DocFilter::Ptr doc_filter, int batch_size); + + //! get schema + std::shared_ptr schema() const { + return schema_; + } + + arrow::AsyncGenerator> gen(); + + private: + Result prepare(); + + private: + struct State { + FtsIndexResults::Ptr fts_result_; + IndexResults::IteratorUPtr iter_; + }; + + Segment::Ptr segment_; + QueryInfo::Ptr query_info_; + DocFilter::Ptr doc_filter_; + const std::vector &fetched_columns_; + int batch_size_; + std::shared_ptr schema_; +}; + +} // namespace zvec::sqlengine diff --git a/src/db/sqlengine/planner/query_planner.cc b/src/db/sqlengine/planner/query_planner.cc index c0c588a30..fdde20f0f 100644 --- a/src/db/sqlengine/planner/query_planner.cc +++ b/src/db/sqlengine/planner/query_planner.cc @@ -28,6 +28,7 @@ #include "db/sqlengine/analyzer/query_info.h" #include "db/sqlengine/analyzer/query_node.h" #include "db/sqlengine/common/util.h" +#include "db/sqlengine/planner/fts_recall_node.h" #include "db/sqlengine/planner/invert_recall_node.h" #include "db/sqlengine/planner/ops/check_not_filtered_op.h" #include "db/sqlengine/planner/ops/contain_op.h" @@ -356,6 +357,7 @@ Result QueryPlanner::make_physical_plan( query_info->to_string().c_str()); int topn = query_info->query_topn(); auto vector_cond = query_info->vector_cond_info(); + auto fts_cond = query_info->fts_cond_info(); bool has_group_by = query_info->group_by() != nullptr; // optimize plan by instrument query info condition, eg adjust invert cond @@ -406,6 +408,9 @@ Result QueryPlanner::make_physical_plan( if (query_info->vector_cond_info()) { seg_plan = vector_scan(segment, std::move(segment_query_info), std::move(forward_filter), single_stage_search); + } else if (query_info->fts_cond_info()) { + seg_plan = fts_scan(segment, std::move(segment_query_info), + std::move(forward_filter), single_stage_search); } else if (query_info->invert_cond()) { seg_plan = invert_scan(segment, std::move(segment_query_info), std::move(forward_filter)); @@ -439,6 +444,14 @@ Result QueryPlanner::make_physical_plan( kFieldScore, vector_cond->is_reverse_sort() ? cp::SortOrder::Descending : cp::SortOrder::Ascending}}}}}; + } else if (fts_cond) { + // FTS uses BM25 where higher score = more relevant. Per-segment results + // are already in descending score order; merging multiple segments + // requires a global re-sort to keep the contract. + node = ac::Declaration{"order_by", + {std::move(node)}, + ac::OrderByNodeOptions{cp::Ordering{{cp::SortKey{ + kFieldScore, cp::SortOrder::Descending}}}}}; } // group by need to collect all docs @@ -515,14 +528,14 @@ Result QueryPlanner::forward_scan( return std::make_shared(std::move(node), std::move(schema)); } -Result QueryPlanner::vector_scan( - Segment::Ptr seg, QueryInfo::Ptr query_info, - std::unique_ptr forward_filter, +DocFilter::Ptr QueryPlanner::build_doc_filter( + const Segment::Ptr &seg, const QueryInfo::Ptr &query_info, + std::unique_ptr &forward_filter, bool single_stage_search) { std::unique_ptr forward_filter_plan; // if single stage search is not enabled, first run acero plan to get - // forward bitmap, then filter during vector search. otherwise, filter - // forward during forward search. + // forward bitmap, then filter during search. otherwise, filter forward + // during search. if (forward_filter && !single_stage_search) { ac::RecordBatchReaderSourceNodeOptions source_options{ seg->scan(query_info->get_forward_filter_field_names())}; @@ -536,9 +549,17 @@ Result QueryPlanner::vector_scan( })}); forward_filter.reset(); } - auto doc_filter = std::make_shared(seg, query_info, - std::move(forward_filter_plan), - std::move(forward_filter)); + return std::make_shared(seg, query_info, + std::move(forward_filter_plan), + std::move(forward_filter)); +} + +Result QueryPlanner::vector_scan( + Segment::Ptr seg, QueryInfo::Ptr query_info, + std::unique_ptr forward_filter, + bool single_stage_search) { + auto doc_filter = + build_doc_filter(seg, query_info, forward_filter, single_stage_search); int topn = query_info->query_topn(); int batch_size = get_batch_size(*query_info, false); @@ -616,6 +637,28 @@ Result QueryPlanner::invert_scan( return std::make_shared(std::move(node), std::move(schema)); } +Result QueryPlanner::fts_scan( + Segment::Ptr seg, QueryInfo::Ptr query_info, + std::unique_ptr forward_filter, + bool single_stage_search) { + auto doc_filter = + build_doc_filter(seg, query_info, forward_filter, single_stage_search); + + auto topn = query_info->query_topn(); + int batch_size = get_batch_size(*query_info, false); + auto recall_node = std::make_shared( + std::move(seg), std::move(query_info), std::move(doc_filter), batch_size); + + auto source_node_options = + arrow::acero::SourceNodeOptions{recall_node->schema(), recall_node->gen(), + arrow::compute::Ordering::Implicit()}; + ac::Declaration node{"source", source_node_options}; + + node = ac::Declaration{ + "fetch", {std::move(node)}, ac::FetchNodeOptions{0, topn}}; + return std::make_shared(std::move(node), recall_node->schema()); +} + int QueryPlanner::get_batch_size(const QueryInfo &info, bool has_later_filter) { // ref https://arrow.apache.org/docs/developers/cpp/acero.html#batch-size if (!info.query_orderbys().empty() || has_later_filter) { diff --git a/src/db/sqlengine/planner/query_planner.h b/src/db/sqlengine/planner/query_planner.h index b93fa34e9..c0cc61993 100644 --- a/src/db/sqlengine/planner/query_planner.h +++ b/src/db/sqlengine/planner/query_planner.h @@ -22,6 +22,7 @@ #include #include "db/index/segment/segment.h" #include "db/sqlengine/analyzer/query_info.h" +#include "db/sqlengine/planner/doc_filter.h" #include "plan_info.h" namespace zvec::sqlengine { @@ -59,6 +60,15 @@ class QueryPlanner { Result forward_scan( Segment::Ptr seg, QueryInfo::Ptr query_info, std::unique_ptr forward_filter); + Result fts_scan( + Segment::Ptr seg, QueryInfo::Ptr query_info, + std::unique_ptr forward_filter, + bool single_stage_search); + + static DocFilter::Ptr build_doc_filter( + const Segment::Ptr &seg, const QueryInfo::Ptr &query_info, + std::unique_ptr &forward_filter, + bool single_stage_search); static int get_batch_size(const QueryInfo &info, bool has_later_filter); diff --git a/src/db/sqlengine/planner/vector_recall_node.cc b/src/db/sqlengine/planner/vector_recall_node.cc index f56bb44e8..f58d02c1b 100644 --- a/src/db/sqlengine/planner/vector_recall_node.cc +++ b/src/db/sqlengine/planner/vector_recall_node.cc @@ -21,6 +21,7 @@ #include #include #include +#include #include #include #include @@ -159,7 +160,8 @@ Result VectorRecallNode::prepare() { query_params.data_type = vector_cond_->vector_schema()->data_type(); query_params.dimension = vector_cond_->dimension(); query_params.query_params = vector_cond_->query_params(); - auto brute_force_keys = doc_filter_->get_bf_by_keys_and_update(); + auto brute_force_keys = doc_filter_->get_bf_by_keys_and_update( + GlobalConfig::Instance().brute_force_by_keys_ratio()); if (brute_force_keys) { query_params.bf_pks.emplace_back(std::move(brute_force_keys.value())); } diff --git a/src/db/sqlengine/sqlengine_impl.cc b/src/db/sqlengine/sqlengine_impl.cc index 1f5bd5141..ece72b3b5 100644 --- a/src/db/sqlengine/sqlengine_impl.cc +++ b/src/db/sqlengine/sqlengine_impl.cc @@ -14,11 +14,17 @@ #include "db/sqlengine/sqlengine_impl.h" #include +#include #include #include +#include #include #include "db/common/constants.h" +#include "db/index/column/fts_column/fts_ast_rewriter.h" +#include "db/index/column/fts_column/fts_pipeline.h" +#include "db/index/column/fts_column/fts_query_ast.h" #include "db/sqlengine/analyzer/query_analyzer.h" +#include "db/sqlengine/parser/select_info.h" #include "db/sqlengine/parser/sql_info_helper.h" #include "db/sqlengine/parser/zvec_parser.h" #include "db/sqlengine/planner/op_register.h" @@ -120,6 +126,113 @@ Result SQLEngineImpl::execute_group_by( return fill_group_by_result(*query_info.value(), reader.value().get()); } +Result SQLEngineImpl::parse_fts_query( + CollectionSchema::Ptr collection, const std::string &field_name, + const Fts &fts, const QueryParams::Ptr &query_params) { + // Exactly one of query_string_ or match_string_ must be provided. + bool has_query = !fts.query_string_.empty(); + bool has_match_string = !fts.match_string_.empty(); + if (has_query == has_match_string) { + return tl::make_unexpected(Status::InvalidArgument( + "Exactly one of query_string or match_string must be provided")); + } + + auto *fts_query_param = dynamic_cast(query_params.get()); + + // Determine default operator once, shared by both query_string and + // match_string paths. Accept "and"/"or" case-insensitively, empty means OR; + // any other value is a user error and must be reported, not silently + // downgraded to OR. strcasecmp is mapped to _stricmp on MSVC by platform.h. + fts::FtsDefaultOperator default_op = fts::FtsDefaultOperator::OR; + if (fts_query_param) { + const auto &op_str = fts_query_param->default_operator(); + if (op_str.empty() || strcasecmp(op_str.c_str(), "or") == 0) { + default_op = fts::FtsDefaultOperator::OR; + } else if (strcasecmp(op_str.c_str(), "and") == 0) { + default_op = fts::FtsDefaultOperator::AND; + } else { + return tl::make_unexpected(Status::InvalidArgument( + "FTS default_operator must be empty, 'and' or 'or' (case-insensitive)" + ", got: ", + op_str)); + } + } + + // Tokenizer pipeline is required by both branches: query_string needs it to + // tokenize phrase contents and bare terms, match_string needs it to split + // the natural-language input. Resolve once and share. + auto *field_schema = collection->get_field(field_name); + if (!field_schema) { + return tl::make_unexpected( + Status::InvalidArgument("FTS field not found: ", field_name)); + } + auto fts_idx_param = + std::dynamic_pointer_cast(field_schema->index_params()); + if (!fts_idx_param) { + return tl::make_unexpected(Status::InvalidArgument( + "FTS field has no FtsIndexParams: ", field_name)); + } + auto pipeline_result = detail::AcquireFtsPipeline(*fts_idx_param); + if (!pipeline_result.has_value()) { + return tl::make_unexpected(Status::InternalError( + "Failed to create tokenizer pipeline for field: ", field_name, " ", + pipeline_result.error().message())); + } + auto &pipeline = pipeline_result.value(); + + fts::FtsAstNodePtr ast; + if (has_query) { + // Structured query expression: parse via ANTLR grammar; phrase/term + // bodies are tokenized through the same pipeline used at index time. + fts::FtsQueryParser fts_parser; + ast = fts_parser.parse(fts.query_string_, pipeline, default_op); + if (!ast) { + LOG_ERROR("FTS query parse failed: %s", fts_parser.err_msg().c_str()); + return tl::make_unexpected(Status::InvalidArgument( + "FTS query parse failed: ", fts_parser.err_msg())); + } + } else { + // Natural language match_string: tokenize and combine with default_op. + auto tokens = pipeline->process(fts.match_string_); + if (tokens.empty()) { + // Analyzer dropped everything → zero-doc query, not an error. + return std::make_shared(field_name, + std::make_unique()); + } + if (tokens.size() == 1) { + ast = std::make_unique(std::move(tokens[0].text)); + } else { + if (default_op == fts::FtsDefaultOperator::AND) { + auto and_node = std::make_unique(); + for (auto &token : tokens) { + and_node->children.push_back( + std::make_unique(std::move(token.text))); + } + ast = std::move(and_node); + } else { + auto or_node = std::make_unique(); + for (auto &token : tokens) { + or_node->children.push_back( + std::make_unique(std::move(token.text))); + } + ast = std::move(or_node); + } + } + } + + // Structural rewrite: dedup repeated terms (collapsed into a single node + // with summed boost), flatten same-type composites for better WAND pruning, + // propagate EmptyNode, and detect must/must_not contradictions. The pre- + // rewrite AST is logged at DEBUG so the transform is auditable. LOG_DEBUG + // is gated by the configured log level, so ast->text() is only built when + // debug logging is enabled. + LOG_DEBUG("FTS AST before rewrite: %s", ast ? ast->text().c_str() : ""); + fts::simplify(ast); + LOG_DEBUG("FTS AST after rewrite : %s", ast ? ast->text().c_str() : ""); + + return std::make_shared(field_name, std::move(ast)); +} + Result SQLEngineImpl::parse_sql_info( const CollectionSchema &schema, const SQLInfo::Ptr &sql_info) { profiler_->open_stage("analyze stage"); @@ -172,6 +285,21 @@ Result SQLEngineImpl::parse_request( return tl::make_unexpected(Status::InvalidArgument( "Convert message to SQL info failed: ", err_msg)); } + + // If the request carries an FTS query, parse it and attach to SelectInfo + // so that query_analyzer can propagate it to QueryInfo. + if (request.fts_.has_value()) { + auto fts_result = + parse_fts_query(collection, request.field_name_, request.fts_.value(), + request.query_params_); + if (!fts_result) { + return tl::make_unexpected(fts_result.error()); + } + auto select_info = + std::dynamic_pointer_cast(sql_info->base_info()); + select_info->set_fts_cond_info(std::move(fts_result.value())); + } + LOG_DEBUG("Sql info is %s", sql_info->to_string().c_str()); return parse_sql_info(*collection, std::move(sql_info)); } @@ -577,4 +705,4 @@ Result SQLEngineImpl::fill_group_by_result( return group_results; } -} // namespace zvec::sqlengine \ No newline at end of file +} // namespace zvec::sqlengine diff --git a/src/db/sqlengine/sqlengine_impl.h b/src/db/sqlengine/sqlengine_impl.h index 88c279283..d59222e1a 100644 --- a/src/db/sqlengine/sqlengine_impl.h +++ b/src/db/sqlengine/sqlengine_impl.h @@ -22,6 +22,8 @@ #include #include "analyzer/query_info.h" #include "common/group_by.h" +#include "db/index/column/fts_column/fts_query_ast.h" +#include "db/index/column/fts_column/parser/fts_query_parser.h" #include "db/sqlengine/common/util.h" #include "db/sqlengine/parser/sql_info.h" #include "db/sqlengine/sqlengine.h" @@ -67,6 +69,11 @@ class SQLEngineImpl : public SQLEngine { Result fill_group_by_result(const QueryInfo &query_info, arrow::RecordBatchReader *reader); + //! Parse FTS query into a FtsCondInfo (AST + field name). + Result parse_fts_query( + CollectionSchema::Ptr collection, const std::string &field_name, + const Fts &fts, const QueryParams::Ptr &query_params); + private: zvec::Profiler::Ptr profiler_; std::string execution_time_info_{}; diff --git a/src/include/zvec/ailego/internal/platform.h b/src/include/zvec/ailego/internal/platform.h index ccd33971e..d30cb8865 100644 --- a/src/include/zvec/ailego/internal/platform.h +++ b/src/include/zvec/ailego/internal/platform.h @@ -67,6 +67,7 @@ typedef unsigned int id_t; #define ailego_bswap64(x) _byteswap_uint64(x) #define strncasecmp _strnicmp +#define strcasecmp _stricmp #else // !_MSC_VER diff --git a/src/include/zvec/c_api.h b/src/include/zvec/c_api.h index 74cc1bfbd..01512f9f6 100644 --- a/src/include/zvec/c_api.h +++ b/src/include/zvec/c_api.h @@ -679,6 +679,24 @@ zvec_config_data_set_brute_force_by_keys_ratio(zvec_config_data_t *config, ZVEC_EXPORT float ZVEC_CALL zvec_config_data_get_brute_force_by_keys_ratio( const zvec_config_data_t *config); +/** + * @brief Set FTS brute force by keys ratio in configuration data + * @param config Configuration data pointer + * @param ratio FTS brute force by keys ratio + * @return zvec_error_code_t Error code + */ +ZVEC_EXPORT zvec_error_code_t ZVEC_CALL +zvec_config_data_set_fts_brute_force_by_keys_ratio(zvec_config_data_t *config, + float ratio); + +/** + * @brief Get FTS brute force by keys ratio from configuration data + * @param config Configuration data pointer + * @return float FTS brute force by keys ratio + */ +ZVEC_EXPORT float ZVEC_CALL zvec_config_data_get_fts_brute_force_by_keys_ratio( + const zvec_config_data_t *config); + /** * @brief Set optimize thread count in configuration data * @param config Configuration data pointer @@ -697,6 +715,20 @@ zvec_config_data_set_optimize_thread_count(zvec_config_data_t *config, ZVEC_EXPORT uint32_t ZVEC_CALL zvec_config_data_get_optimize_thread_count(const zvec_config_data_t *config); +/** + * @brief Set jieba dict directory in configuration data + * @param dir Dict directory; NULL or empty leaves the field empty + */ +ZVEC_EXPORT zvec_error_code_t ZVEC_CALL zvec_config_data_set_jieba_dict_dir( + zvec_config_data_t *config, const char *dir); + +/** + * @brief Get jieba dict directory from configuration data + * @return Pointer owned by config (do not free); empty when unset + */ +ZVEC_EXPORT const char *ZVEC_CALL +zvec_config_data_get_jieba_dict_dir(const zvec_config_data_t *config); + // ============================================================================= // Initialization and Cleanup Interface // ============================================================================= @@ -722,6 +754,26 @@ ZVEC_EXPORT zvec_error_code_t ZVEC_CALL zvec_shutdown(void); */ ZVEC_EXPORT bool ZVEC_CALL zvec_is_initialized(void); +/** + * @brief Set the process-wide default jieba dict directory. + * + * For language SDKs to call on module load. Thread-safe, decoupled from + * zvec_initialize(); last writer wins. A subsequent zvec_initialize() with + * non-empty config.jieba_dict_dir overrides this. JiebaTokenizer priority: + * per-field > ZVEC_JIEBA_DICT_DIR > this. + * + * @param dir UTF-8 directory containing jieba.dict.utf8 + hmm_model.utf8; + * NULL or empty clears the value. + */ +ZVEC_EXPORT void ZVEC_CALL zvec_set_default_jieba_dict_dir(const char *dir); + +/** + * @brief Get the process-wide default jieba dict directory. + * @return Thread-local string valid until the next call on this thread; + * empty when unset. + */ +ZVEC_EXPORT const char *ZVEC_CALL zvec_get_default_jieba_dict_dir(void); + // ============================================================================= // Data Type Enumerations // ============================================================================= @@ -775,6 +827,7 @@ typedef uint32_t zvec_index_type_t; #define ZVEC_INDEX_TYPE_IVF 2 #define ZVEC_INDEX_TYPE_FLAT 3 #define ZVEC_INDEX_TYPE_INVERT 10 +#define ZVEC_INDEX_TYPE_FTS 11 /** * @brief Distance metric type codes (must match zvec::MetricType in @@ -977,6 +1030,34 @@ ZVEC_EXPORT zvec_error_code_t ZVEC_CALL zvec_index_params_get_invert_params( ZVEC_EXPORT zvec_error_code_t ZVEC_CALL zvec_index_params_set_invert_params( zvec_index_params_t *params, bool enable_range_opt, bool enable_wildcard); +/** + * @brief Set FTS index specific parameters + * @param params Index parameters (must be FTS type) + * @param tokenizer_name Tokenizer pipeline name (NULL keeps current value) + * @param filters Token filter names (NULL keeps current value) + * @param extra_params Additional tokenizer parameters (NULL keeps current + * value) + * @return ZVEC_OK on success, error code on failure + */ +ZVEC_EXPORT zvec_error_code_t ZVEC_CALL zvec_index_params_set_fts_params( + zvec_index_params_t *params, const char *tokenizer_name, + const zvec_string_array_t *filters, const char *extra_params); + +/** + * @brief Get FTS index parameters (all at once) + * @param params Index parameters (must be FTS type) + * @param out_tokenizer_name Output parameter for tokenizer name (can be NULL, + * owned by params, do not free) + * @param out_filters Output parameter for filter list (can be NULL); caller + * must call zvec_string_array_destroy() to free + * @param out_extra_params Output parameter for extra params (can be NULL, + * owned by params, do not free) + * @return ZVEC_OK on success, error code on failure + */ +ZVEC_EXPORT zvec_error_code_t ZVEC_CALL zvec_index_params_get_fts_params( + const zvec_index_params_t *params, const char **out_tokenizer_name, + zvec_string_array_t **out_filters, const char **out_extra_params); + // ============================================================================= // Query Parameters Structures (Opaque Pointer Pattern) // ============================================================================= @@ -1011,6 +1092,16 @@ typedef struct zvec_ivf_query_params_t zvec_ivf_query_params_t; */ typedef struct zvec_flat_query_params_t zvec_flat_query_params_t; +/** + * @brief FTS query parameters handle (opaque pointer) + * + * Internally maps to zvec::FtsQueryParams* (raw pointer). + * Created by zvec_query_params_fts_create() and destroyed by + * zvec_query_params_fts_destroy(). Caller owns the pointer and must explicitly + * destroy it. + */ +typedef struct zvec_fts_query_params_t zvec_fts_query_params_t; + // ============================================================================= // Query Structures (Opaque Pointer Pattern) @@ -1032,6 +1123,13 @@ typedef struct zvec_vector_query_t zvec_vector_query_t; */ typedef struct zvec_group_by_vector_query_t zvec_group_by_vector_query_t; +/** + * @brief FTS query payload structure (opaque pointer) + * Aligned with zvec::Fts + * Use zvec_fts_create() to create and zvec_fts_destroy() to destroy + */ +typedef struct zvec_fts_t zvec_fts_t; + // ============================================================================= // Query Parameters Management Functions @@ -1327,6 +1425,46 @@ zvec_query_params_flat_set_is_using_refiner(zvec_flat_query_params_t *params, ZVEC_EXPORT bool ZVEC_CALL zvec_query_params_flat_get_is_using_refiner( const zvec_flat_query_params_t *params); +// ----------------------------------------------------------------------------- +// zvec_fts_query_params_t (FTS Query Parameters) +// ----------------------------------------------------------------------------- + +/** + * @brief Create FTS query parameters + * @param default_operator Default boolean operator for adjacent bare terms: + * "OR" / "AND" (case-insensitive); NULL or "" keeps + * the built-in default + * @return zvec_fts_query_params_t* Pointer to the newly created FTS query + * parameters + */ +ZVEC_EXPORT zvec_fts_query_params_t *ZVEC_CALL +zvec_query_params_fts_create(const char *default_operator); + +/** + * @brief Destroy FTS query parameters + * @param params FTS query parameters pointer + */ +ZVEC_EXPORT void ZVEC_CALL +zvec_query_params_fts_destroy(zvec_fts_query_params_t *params); + +/** + * @brief Set default boolean operator + * @param params FTS query parameters pointer + * @param default_operator Default boolean operator + * @return zvec_error_code_t Error code + */ +ZVEC_EXPORT zvec_error_code_t ZVEC_CALL +zvec_query_params_fts_set_default_operator(zvec_fts_query_params_t *params, + const char *default_operator); + +/** + * @brief Get default boolean operator + * @param params FTS query parameters pointer + * @return const char* Default boolean operator (owned by params, do not free) + */ +ZVEC_EXPORT const char *ZVEC_CALL zvec_query_params_fts_get_default_operator( + const zvec_fts_query_params_t *params); + // ----------------------------------------------------------------------------- // zvec_vector_query_t (Vector Query) // ----------------------------------------------------------------------------- @@ -1500,6 +1638,83 @@ ZVEC_EXPORT zvec_error_code_t ZVEC_CALL zvec_vector_query_set_ivf_params( ZVEC_EXPORT zvec_error_code_t ZVEC_CALL zvec_vector_query_set_flat_params( zvec_vector_query_t *query, zvec_flat_query_params_t *flat_params); +/** + * @brief Set FTS query parameters (takes ownership) + * @param query Vector query pointer + * @param fts_params FTS query parameters pointer + * @return zvec_error_code_t Error code + */ +ZVEC_EXPORT zvec_error_code_t ZVEC_CALL zvec_vector_query_set_fts_params( + zvec_vector_query_t *query, zvec_fts_query_params_t *fts_params); + +// ----------------------------------------------------------------------------- +// zvec_fts_t (FTS query payload) +// ----------------------------------------------------------------------------- + +/** + * @brief Create FTS query payload + * @return zvec_fts_t* Pointer to the newly created FTS query payload + */ +ZVEC_EXPORT zvec_fts_t *ZVEC_CALL zvec_fts_create(void); + +/** + * @brief Destroy FTS query payload + * @param fts FTS query payload pointer + */ +ZVEC_EXPORT void ZVEC_CALL zvec_fts_destroy(zvec_fts_t *fts); + +/** + * @brief Set FTS boolean / advanced query expression + * @param fts FTS query payload pointer + * @param query_string Query expression (NULL is treated as empty string) + * @return zvec_error_code_t Error code + */ +ZVEC_EXPORT zvec_error_code_t ZVEC_CALL +zvec_fts_set_query_string(zvec_fts_t *fts, const char *query_string); + +/** + * @brief Set FTS natural-language match string + * @param fts FTS query payload pointer + * @param match_string Match string (NULL is treated as empty string) + * @return zvec_error_code_t Error code + */ +ZVEC_EXPORT zvec_error_code_t ZVEC_CALL +zvec_fts_set_match_string(zvec_fts_t *fts, const char *match_string); + +/** + * @brief Get FTS query expression + * @param fts FTS query payload pointer + * @return const char* Query expression (owned by fts, do not free) + */ +ZVEC_EXPORT const char *ZVEC_CALL +zvec_fts_get_query_string(const zvec_fts_t *fts); + +/** + * @brief Get FTS match string + * @param fts FTS query payload pointer + * @return const char* Match string (owned by fts, do not free) + */ +ZVEC_EXPORT const char *ZVEC_CALL +zvec_fts_get_match_string(const zvec_fts_t *fts); + +/** + * @brief Set FTS payload on a vector query (payload is copied) + * @param query Vector query pointer + * @param fts FTS query payload pointer (NULL clears the payload) + * @return zvec_error_code_t Error code + */ +ZVEC_EXPORT zvec_error_code_t ZVEC_CALL +zvec_vector_query_set_fts(zvec_vector_query_t *query, const zvec_fts_t *fts); + +/** + * @brief Get FTS payload attached to a vector query + * @param query Vector query pointer + * @return const zvec_fts_t* FTS payload (owned by query, do not free), or + * NULL if no payload is attached + */ +ZVEC_EXPORT const zvec_fts_t *ZVEC_CALL +zvec_vector_query_get_fts(const zvec_vector_query_t *query); + // ----------------------------------------------------------------------------- // zvec_group_by_vector_query_t (Group By Vector Query) // ----------------------------------------------------------------------------- diff --git a/src/include/zvec/db/config.h b/src/include/zvec/db/config.h index 29fe19674..d5e7827d6 100644 --- a/src/include/zvec/db/config.h +++ b/src/include/zvec/db/config.h @@ -16,6 +16,8 @@ #include #include #include +#include +#include #include #include @@ -92,10 +94,17 @@ class GlobalConfig : public ailego::Singleton { uint32_t query_thread_count; float invert_to_forward_scan_ratio; float brute_force_by_keys_ratio; + // Independent from brute_force_by_keys_ratio: per-candidate FTS cost + // (phrase phase-2 IO, BM25) is higher, so a tighter default fits. + float fts_brute_force_by_keys_ratio; // optimize uint32_t optimize_thread_count; + // FTS jieba tokenizer default dict dir (lowest-priority fallback; + // per-field config > ZVEC_JIEBA_DICT_DIR > this). Empty by default. + std::string jieba_dict_dir; + ConfigData(); }; @@ -104,6 +113,11 @@ class GlobalConfig : public ailego::Singleton { Status Validate(const ConfigData &config) const; + // Set the process-wide default jieba dict dir. Thread-safe and decoupled + // from Initialize() so language SDKs can call it on module load. + // Initialize() with a non-empty config.jieba_dict_dir overrides this. + void set_default_jieba_dict_dir(const std::string &dir); + // Read-only accessors uint64_t memory_limit_bytes() const noexcept; @@ -161,17 +175,29 @@ class GlobalConfig : public ailego::Singleton { return config_.brute_force_by_keys_ratio; } + //! FTS brute force by keys ratio (independent from brute_force_by_keys_ratio + //! because FTS per-candidate cost is higher). + float fts_brute_force_by_keys_ratio() const noexcept { + return config_.fts_brute_force_by_keys_ratio; + } + //! Optimize thread count uint32_t optimize_thread_count() const noexcept { return config_.optimize_thread_count; } + //! Effective jieba dict dir. Thread-safe. + std::string jieba_dict_dir() const; + private: // Configuration data ConfigData config_; // Atomic flag to ensure initialization happens only once std::atomic initialized_{false}; + + // Guards config_ fields that may be written outside Initialize(). + mutable std::mutex mutex_; }; -} // namespace zvec \ No newline at end of file +} // namespace zvec diff --git a/src/include/zvec/db/doc.h b/src/include/zvec/db/doc.h index f702a43c3..cf076f71a 100644 --- a/src/include/zvec/db/doc.h +++ b/src/include/zvec/db/doc.h @@ -364,6 +364,14 @@ using DocPtrMap = std::unordered_map; using WriteResults = std::vector; +struct Fts { + std::string query_string_; // FTS query expression (e.g. "+vector -slow + // \"exact phrase\"") + std::string match_string_; // Natural language match string, tokenized and + // combined using default_operator. Mutually + // exclusive with query_string_. +}; + struct VectorQuery { int topk_; std::string field_name_; @@ -378,6 +386,8 @@ struct VectorQuery { std::optional> output_fields_; QueryParams::Ptr query_params_; + std::optional fts_; + Status validate_and_sanitize(const FieldSchema *schema); }; diff --git a/src/include/zvec/db/index_params.h b/src/include/zvec/db/index_params.h index 5f6faff4e..a6649b88b 100644 --- a/src/include/zvec/db/index_params.h +++ b/src/include/zvec/db/index_params.h @@ -16,13 +16,20 @@ #include #include #include +#include #include +#include #include #include "zvec/core/framework/index_provider.h" #include "zvec/core/framework/index_reformer.h" namespace zvec { +namespace detail { +struct FtsState; +struct FtsPipelineHelper; +} // namespace detail + /* * Column index params */ @@ -558,4 +565,84 @@ class VamanaIndexParams : public VectorIndexParams { bool use_id_map_; }; +/* + * FTS (Full-Text Search) index params + * + * Not copyable. Use shared_ptr for shared ownership. + */ +class FtsIndexParams : public IndexParams { + public: + FtsIndexParams(std::string tokenizer_name = "standard", + std::vector filters = {"lowercase"}, + std::string extra_params = ""); + + // Not copyable. + FtsIndexParams(const FtsIndexParams &) = delete; + FtsIndexParams &operator=(const FtsIndexParams &) = delete; + + // Movable. + FtsIndexParams(FtsIndexParams &&other) noexcept; + FtsIndexParams &operator=(FtsIndexParams &&) = delete; + + ~FtsIndexParams() override; + + Ptr clone() const override { + return std::make_shared(tokenizer_name_, filters_, + extra_params_); + } + + std::string to_string() const override { + std::ostringstream oss; + oss << "{FtsIndexParams,tokenizer_name:" << tokenizer_name_ << ",filters:["; + for (size_t i = 0; i < filters_.size(); ++i) { + if (i > 0) { + oss << ","; + } + oss << filters_[i]; + } + oss << "],extra_params:" << extra_params_ << "}"; + return oss.str(); + } + + bool operator==(const IndexParams &other) const override { + if (type() != other.type()) { + return false; + } + auto &other_fts = static_cast(other); + return tokenizer_name_ == other_fts.tokenizer_name_ && + filters_ == other_fts.filters_ && + extra_params_ == other_fts.extra_params_; + } + + const std::string &tokenizer_name() const { + return tokenizer_name_; + } + void set_tokenizer_name(std::string tokenizer_name) { + tokenizer_name_ = std::move(tokenizer_name); + } + + const std::vector &filters() const { + return filters_; + } + void set_filters(std::vector filters) { + filters_ = std::move(filters); + } + + const std::string &extra_params() const { + return extra_params_; + } + void set_extra_params(std::string extra_params) { + extra_params_ = std::move(extra_params); + } + + private: + std::string tokenizer_name_; + std::vector filters_; + std::string extra_params_; + + std::unique_ptr state_; + + friend struct detail::FtsPipelineHelper; +}; + } // namespace zvec \ No newline at end of file diff --git a/src/include/zvec/db/query_params.h b/src/include/zvec/db/query_params.h index fc0667252..df148aed0 100644 --- a/src/include/zvec/db/query_params.h +++ b/src/include/zvec/db/query_params.h @@ -14,6 +14,7 @@ #pragma once #include +#include #include #include @@ -197,4 +198,25 @@ class VamanaQueryParams : public QueryParams { int ef_search_; }; +class FtsQueryParams : public QueryParams { + public: + using Ptr = std::shared_ptr; + + FtsQueryParams() : QueryParams(IndexType::FTS) {} + ~FtsQueryParams() override = default; + + const std::string &default_operator() const { + return default_operator_; + } + + void set_default_operator(const std::string &default_operator) { + default_operator_ = default_operator; + } + + private: + // Default boolean operator for adjacent bare terms. + // Supported values (case-insensitive): "OR" (default), "AND". + std::string default_operator_; +}; + } // namespace zvec \ No newline at end of file diff --git a/src/include/zvec/db/schema.h b/src/include/zvec/db/schema.h index 80e6cabd4..291abc571 100644 --- a/src/include/zvec/db/schema.h +++ b/src/include/zvec/db/schema.h @@ -359,6 +359,10 @@ class CollectionSchema { FieldSchemaPtrList vector_fields() const; + bool has_fts_field() const; + + FieldSchemaPtrList fts_fields() const; + uint64_t max_doc_count_per_segment() const; void set_max_doc_count_per_segment(uint64_t max_doc_count_per_segment); diff --git a/src/include/zvec/db/type.h b/src/include/zvec/db/type.h index 31b8850f3..a48267994 100644 --- a/src/include/zvec/db/type.h +++ b/src/include/zvec/db/type.h @@ -28,6 +28,7 @@ enum class IndexType : uint32_t { HNSW_RABITQ = 4, VAMANA = 5, INVERT = 10, + FTS = 11, }; /* diff --git a/tests/c/c_api_test.c b/tests/c/c_api_test.c index 846cc548c..f292a4e87 100644 --- a/tests/c/c_api_test.c +++ b/tests/c/c_api_test.c @@ -4127,6 +4127,244 @@ void test_actual_vector_queries(void) { TEST_END(); } +// ============================================================================= +// FTS (full-text search) tests +// ============================================================================= + +void test_fts_index_params_functions(void) { + TEST_START(); + + // Defaults: tokenizer="standard", filters=["lowercase"], extra_params="". + zvec_index_params_t *params = zvec_index_params_create(ZVEC_INDEX_TYPE_FTS); + TEST_ASSERT(params != NULL); + TEST_ASSERT(zvec_index_params_get_type(params) == ZVEC_INDEX_TYPE_FTS); + + const char *tokenizer = NULL; + const char *extra = NULL; + zvec_string_array_t *filters = NULL; + zvec_error_code_t err = + zvec_index_params_get_fts_params(params, &tokenizer, &filters, &extra); + TEST_ASSERT(err == ZVEC_OK); + TEST_ASSERT(tokenizer != NULL && strcmp(tokenizer, "standard") == 0); + TEST_ASSERT(extra != NULL && strcmp(extra, "") == 0); + TEST_ASSERT(filters != NULL && filters->count == 1); + TEST_ASSERT(strcmp(filters->strings[0].data, "lowercase") == 0); + zvec_string_array_destroy(filters); + filters = NULL; + + // Override via set; filters list of 2 + extra_params + tokenizer. + zvec_string_array_t *new_filters = zvec_string_array_create(2); + TEST_ASSERT(new_filters != NULL); + zvec_string_array_add(new_filters, 0, "lowercase"); + zvec_string_array_add(new_filters, 1, "stop"); + + err = zvec_index_params_set_fts_params(params, "jieba", new_filters, + "key=value"); + TEST_ASSERT(err == ZVEC_OK); + zvec_string_array_destroy(new_filters); + + err = zvec_index_params_get_fts_params(params, &tokenizer, &filters, &extra); + TEST_ASSERT(err == ZVEC_OK); + TEST_ASSERT(tokenizer != NULL && strcmp(tokenizer, "jieba") == 0); + TEST_ASSERT(extra != NULL && strcmp(extra, "key=value") == 0); + TEST_ASSERT(filters != NULL && filters->count == 2); + TEST_ASSERT(strcmp(filters->strings[0].data, "lowercase") == 0); + TEST_ASSERT(strcmp(filters->strings[1].data, "stop") == 0); + zvec_string_array_destroy(filters); + + // Type-mismatch error path: invert params must not accept fts setter. + zvec_index_params_t *invert = + zvec_index_params_create(ZVEC_INDEX_TYPE_INVERT); + TEST_ASSERT(invert != NULL); + err = zvec_index_params_set_fts_params(invert, "standard", NULL, ""); + TEST_ASSERT(err == ZVEC_ERROR_INVALID_ARGUMENT); + zvec_index_params_destroy(invert); + + // index_type_to_string should report FTS. + const char *type_str = zvec_index_type_to_string(ZVEC_INDEX_TYPE_FTS); + TEST_ASSERT(type_str != NULL && strcmp(type_str, "FTS") == 0); + + zvec_index_params_destroy(params); + TEST_END(); +} + +void test_fts_query_params_functions(void) { + TEST_START(); + + // Empty default_operator → engine default (empty string). + zvec_fts_query_params_t *p0 = zvec_query_params_fts_create(NULL); + TEST_ASSERT(p0 != NULL); + const char *op0 = zvec_query_params_fts_get_default_operator(p0); + TEST_ASSERT(op0 != NULL && strcmp(op0, "") == 0); + zvec_query_params_fts_destroy(p0); + + // Explicit AND. + zvec_fts_query_params_t *p1 = zvec_query_params_fts_create("AND"); + TEST_ASSERT(p1 != NULL); + const char *op1 = zvec_query_params_fts_get_default_operator(p1); + TEST_ASSERT(op1 != NULL && strcmp(op1, "AND") == 0); + + zvec_error_code_t err = zvec_query_params_fts_set_default_operator(p1, "OR"); + TEST_ASSERT(err == ZVEC_OK); + const char *op2 = zvec_query_params_fts_get_default_operator(p1); + TEST_ASSERT(op2 != NULL && strcmp(op2, "OR") == 0); + + // NULL → invalid arg. + err = zvec_query_params_fts_set_default_operator(NULL, "AND"); + TEST_ASSERT(err == ZVEC_ERROR_INVALID_ARGUMENT); + + zvec_query_params_fts_destroy(p1); + TEST_END(); +} + +void test_fts_wiring_on_vector_query(void) { + TEST_START(); + + zvec_fts_t *fts = zvec_fts_create(); + TEST_ASSERT(fts != NULL); + TEST_ASSERT(strcmp(zvec_fts_get_query_string(fts), "") == 0); + TEST_ASSERT(strcmp(zvec_fts_get_match_string(fts), "") == 0); + + zvec_error_code_t err = + zvec_fts_set_query_string(fts, "+hello -world \"phrase\""); + TEST_ASSERT(err == ZVEC_OK); + TEST_ASSERT( + strcmp(zvec_fts_get_query_string(fts), "+hello -world \"phrase\"") == 0); + err = zvec_fts_set_match_string(fts, "machine learning"); + TEST_ASSERT(err == ZVEC_OK); + TEST_ASSERT(strcmp(zvec_fts_get_match_string(fts), "machine learning") == 0); + + zvec_vector_query_t *query = zvec_vector_query_create(); + TEST_ASSERT(query != NULL); + TEST_ASSERT(zvec_vector_query_get_fts(query) == NULL); + + err = zvec_vector_query_set_fts(query, fts); + TEST_ASSERT(err == ZVEC_OK); + + const zvec_fts_t *got = zvec_vector_query_get_fts(query); + TEST_ASSERT(got != NULL); + TEST_ASSERT( + strcmp(zvec_fts_get_query_string(got), "+hello -world \"phrase\"") == 0); + TEST_ASSERT(strcmp(zvec_fts_get_match_string(got), "machine learning") == 0); + + // Setter copies the payload — mutating the original must not affect the + // attached one. + zvec_fts_set_query_string(fts, "changed"); + TEST_ASSERT( + strcmp(zvec_fts_get_query_string(zvec_vector_query_get_fts(query)), + "+hello -world \"phrase\"") == 0); + + // Clearing. + err = zvec_vector_query_set_fts(query, NULL); + TEST_ASSERT(err == ZVEC_OK); + TEST_ASSERT(zvec_vector_query_get_fts(query) == NULL); + + // Attach FtsQueryParams (transfers ownership). + zvec_fts_query_params_t *fts_params = zvec_query_params_fts_create("AND"); + TEST_ASSERT(fts_params != NULL); + err = zvec_vector_query_set_fts_params(query, fts_params); + TEST_ASSERT(err == ZVEC_OK); + // Ownership transferred — do NOT call zvec_query_params_fts_destroy on it. + + zvec_vector_query_destroy(query); + zvec_fts_destroy(fts); + TEST_END(); +} + +void test_fts_end_to_end(void) { + TEST_START(); + + char temp_dir[] = "./zvec_test_fts_end_to_end"; + cleanup_temp_directory(temp_dir); + + zvec_collection_schema_t *schema = zvec_collection_schema_create("fts_e2e"); + TEST_ASSERT(schema != NULL); + if (!schema) { + TEST_END(); + return; + } + + // id (int64) — primary scalar + zvec_field_schema_t *id_field = + zvec_field_schema_create("id", ZVEC_DATA_TYPE_INT64, false, 0); + zvec_collection_schema_add_field(schema, id_field); + + // content (string) — FTS-indexed field, no vector field in the schema. + zvec_index_params_t *fts_params = + zvec_index_params_create(ZVEC_INDEX_TYPE_FTS); + TEST_ASSERT(fts_params != NULL); + zvec_field_schema_t *content_field = + zvec_field_schema_create("content", ZVEC_DATA_TYPE_STRING, false, 0); + zvec_field_schema_set_index_params(content_field, fts_params); + zvec_collection_schema_add_field(schema, content_field); + zvec_index_params_destroy(fts_params); + + zvec_collection_t *collection = NULL; + zvec_error_code_t err = + zvec_collection_create_and_open(temp_dir, schema, NULL, &collection); + TEST_ASSERT(err == ZVEC_OK); + TEST_ASSERT(collection != NULL); + + if (collection) { + const char *texts[3] = { + "machine learning is fun", + "deep learning uses neural networks", + "vector databases store embeddings", + }; + zvec_doc_t *docs[3]; + for (int i = 0; i < 3; i++) { + docs[i] = zvec_doc_create(); + zvec_doc_set_pk(docs[i], zvec_test_make_pk(i + 1)); + int64_t id = i + 1; + zvec_doc_add_field_by_value(docs[i], "id", ZVEC_DATA_TYPE_INT64, &id, + sizeof(id)); + zvec_doc_add_field_by_value(docs[i], "content", ZVEC_DATA_TYPE_STRING, + texts[i], strlen(texts[i])); + } + + size_t success_count = 0, error_count = 0; + err = zvec_collection_insert(collection, (const zvec_doc_t **)docs, 3, + &success_count, &error_count); + TEST_ASSERT(err == ZVEC_OK); + TEST_ASSERT(success_count == 3); + TEST_ASSERT(error_count == 0); + + zvec_collection_flush(collection); + + // FTS-only query (no query vector): match on "learning" should hit docs + // 1+2. + zvec_vector_query_t *query = zvec_vector_query_create(); + TEST_ASSERT(query != NULL); + zvec_vector_query_set_field_name(query, "content"); + zvec_vector_query_set_topk(query, 10); + zvec_vector_query_set_include_doc_id(query, true); + + zvec_fts_t *fts = zvec_fts_create(); + zvec_fts_set_match_string(fts, "learning"); + err = zvec_vector_query_set_fts(query, fts); + TEST_ASSERT(err == ZVEC_OK); + zvec_fts_destroy(fts); + + zvec_doc_t **results = NULL; + size_t result_count = 0; + err = zvec_collection_query(collection, query, &results, &result_count); + TEST_ASSERT(err == ZVEC_OK); + TEST_ASSERT(result_count >= 2); + + zvec_docs_free(results, result_count); + zvec_vector_query_destroy(query); + + for (int i = 0; i < 3; i++) { + zvec_doc_destroy(docs[i]); + } + zvec_collection_destroy(collection); + } + + zvec_collection_schema_destroy(schema); + cleanup_temp_directory(temp_dir); + TEST_END(); +} + void test_index_creation_and_management(void) { TEST_START(); @@ -5449,6 +5687,12 @@ int main(void) { test_query_params_functions(); test_actual_vector_queries(); + // FTS tests + test_fts_index_params_functions(); + test_fts_query_params_functions(); + test_fts_wiring_on_vector_query(); + test_fts_end_to_end(); + // Performance tests // test_performance_benchmarks(); diff --git a/tests/db/collection_test.cc b/tests/db/collection_test.cc index 9e2adfbbb..ee454a8e2 100644 --- a/tests/db/collection_test.cc +++ b/tests/db/collection_test.cc @@ -4841,18 +4841,6 @@ TEST_F(CollectionTest, CornerCase_CreateAndOpen) { ASSERT_EQ(result.error().code(), StatusCode::INVALID_ARGUMENT); std::cout << result.error().message() << std::endl; } - - { - std::cout << "Collection::CreateAndOpen case 5" << std::endl; - FileHelper::RemoveDirectory(col_path); - // abnormal schema - auto schema = TestHelper::CreateScalarSchema(); - auto result = Collection::CreateAndOpen(col_path, *schema, - CollectionOptions{false, true}); - ASSERT_FALSE(result.has_value()); - ASSERT_EQ(result.error().code(), StatusCode::INVALID_ARGUMENT); - std::cout << result.error().message() << std::endl; - } } { @@ -5115,3 +5103,77 @@ TEST_F(CollectionTest, Feature_Fetch_OutputFields) { ASSERT_TRUE(collection->Destroy().ok()); } + +// FTS-only collection (no vector field). Covers Create / Insert / FTS Query +// / Delete / Optimize-with-rebuild round trip — the rebuild path exercises +// SegmentHelper::ReduceFts, which is the most invasive consumer of the +// "schema may have zero vector fields" relaxation. +TEST_F(CollectionTest, Feature_NoVectorCollection_FtsLifecycle) { + FileHelper::RemoveDirectory(col_path); + + auto schema = std::make_shared("fts_only"); + schema->add_field(std::make_shared("title", DataType::STRING)); + schema->add_field(std::make_shared( + "content", DataType::STRING, false, std::make_shared())); + + auto create_res = Collection::CreateAndOpen(col_path, *schema, + CollectionOptions{false, true}); + ASSERT_TRUE(create_res.has_value()) << create_res.error().message(); + auto col = create_res.value(); + + // Insert a corpus where 4 of 5 docs contain "hello". Doc 4 is the only + // doc without "hello"; we'll delete it later to verify Optimize correctly + // rewrites postings + stats. + auto make_doc = [](uint64_t id, const std::string &title, + const std::string &content) { + Doc d; + d.set_pk("pk_" + std::to_string(id)); + d.set("title", title); + d.set("content", content); + return d; + }; + std::vector docs; + docs.push_back(make_doc(0, "intro", "hello world")); + docs.push_back(make_doc(1, "guide", "hello foo bar")); + docs.push_back(make_doc(2, "tips", "hello baz")); + docs.push_back(make_doc(3, "more", "hello hello")); + docs.push_back(make_doc(4, "other", "nothing relevant")); + ASSERT_TRUE(col->Insert(docs).has_value()); + ASSERT_EQ(col->Stats().value().doc_count, 5u); + + auto fts_search = [&](const std::string &term) { + VectorQuery vq; + vq.field_name_ = "content"; + vq.topk_ = 10; + Fts fts_q; + fts_q.query_string_ = term; + vq.fts_ = fts_q; + auto r = col->Query(vq); + EXPECT_TRUE(r.has_value()) << r.error().message(); + return r.has_value() ? r.value() : DocPtrList{}; + }; + + // Baseline: 4 docs hit "hello". + ASSERT_EQ(fts_search("hello").size(), 4u); + + // Delete enough to push delete ratio above COMPACT_DELETE_RATIO_THRESHOLD + // (0.3) so the next Optimize sets rebuild=true and exercises ReduceFts. + // Drop pk_0 and pk_4: 2/5 = 40% deletes, and pk_0 carries one "hello". + ASSERT_TRUE(col->Delete({"pk_0", "pk_4"}).has_value()); + ASSERT_EQ(col->Stats().value().doc_count, 3u); + + // Tombstone filter applied at query time — "hello" now returns 3 docs. + ASSERT_EQ(fts_search("hello").size(), 3u); + // Doc 4 (only "nothing") is deleted ⇒ no hit for its unique term. + ASSERT_EQ(fts_search("nothing").size(), 0u); + + // Optimize physically removes tombstones and rebuilds FTS postings via + // FtsRocksdbReducer. Same recall expected after rebuild. + ASSERT_TRUE(col->Optimize().ok()); + ASSERT_EQ(col->Stats().value().doc_count, 3u); + ASSERT_EQ(fts_search("hello").size(), 3u); + ASSERT_EQ(fts_search("nothing").size(), 0u); + + col.reset(); + FileHelper::RemoveDirectory(col_path); +} diff --git a/tests/db/common/config_test.cc b/tests/db/common/config_test.cc index fe4f027f1..1ca75d815 100644 --- a/tests/db/common/config_test.cc +++ b/tests/db/common/config_test.cc @@ -43,6 +43,7 @@ TEST_F(ConfigTest, InitializeWithDefaultConfig) { ASSERT_GT(GlobalConfig::Instance().query_thread_count(), 0); ASSERT_EQ(GlobalConfig::Instance().invert_to_forward_scan_ratio(), 0.9f); ASSERT_EQ(GlobalConfig::Instance().brute_force_by_keys_ratio(), 0.1f); + ASSERT_EQ(GlobalConfig::Instance().fts_brute_force_by_keys_ratio(), 0.05f); ASSERT_GT(GlobalConfig::Instance().optimize_thread_count(), 0); } @@ -150,6 +151,16 @@ TEST_F(ConfigTest, ValidateConfigWithInvalidRatios) { ASSERT_NE(status.message().find( "brute_force_by_keys_ratio must be between 0 and 1"), std::string::npos); + + // Test invalid fts_brute_force_by_keys_ratio + config.brute_force_by_keys_ratio = 0.1f; // Reset to valid value + config.fts_brute_force_by_keys_ratio = -0.5f; // Invalid value + status = config_instance.Validate(config); + ASSERT_FALSE(status.ok()); + ASSERT_EQ(status.code(), StatusCode::INVALID_ARGUMENT); + ASSERT_NE(status.message().find( + "fts_brute_force_by_keys_ratio must be between 0 and 1"), + std::string::npos); } TEST_F(ConfigTest, ValidateConfigWithInvalidFileLogSettings) { @@ -209,4 +220,26 @@ TEST_F(ConfigTest, LogConfigPolymorphism) { ASSERT_EQ(console_config->GetLoggerType(), CONSOLE_LOG_TYPE_NAME); ASSERT_EQ(file_config->GetLoggerType(), FILE_LOG_TYPE_NAME); +} + +// jieba_dict_dir is the only ConfigData field that can be written outside +// of Initialize() — language SDKs call set_default_jieba_dict_dir() at +// module-load to register the dict path they bundled. The setter is +// independent of the Initialize() one-shot lifecycle. +TEST_F(ConfigTest, JiebaDictDirSetterIsIndependentOfInitialize) { + auto saved = GlobalConfig::Instance().jieba_dict_dir(); + + // Setter works regardless of whether Initialize was called. + GlobalConfig::Instance().set_default_jieba_dict_dir("/tmp/zvec/dict-A"); + ASSERT_EQ(GlobalConfig::Instance().jieba_dict_dir(), "/tmp/zvec/dict-A"); + + // Last writer wins. + GlobalConfig::Instance().set_default_jieba_dict_dir("/tmp/zvec/dict-B"); + ASSERT_EQ(GlobalConfig::Instance().jieba_dict_dir(), "/tmp/zvec/dict-B"); + + // Empty clears. + GlobalConfig::Instance().set_default_jieba_dict_dir(""); + ASSERT_EQ(GlobalConfig::Instance().jieba_dict_dir(), ""); + + GlobalConfig::Instance().set_default_jieba_dict_dir(saved); } \ No newline at end of file diff --git a/tests/db/fts_query_test.cc b/tests/db/fts_query_test.cc new file mode 100644 index 000000000..e52489b7c --- /dev/null +++ b/tests/db/fts_query_test.cc @@ -0,0 +1,230 @@ +// Copyright 2025-present the zvec project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include +#include +#include +#include +#include "db/common/file_helper.h" +#include "zvec/db/collection.h" +#include "zvec/db/doc.h" +#include "zvec/db/index_params.h" +#include "zvec/db/options.h" +#include "zvec/db/schema.h" +#include "zvec/db/status.h" +#include "zvec/db/type.h" + +using namespace zvec; + +static const std::string kTestPath = "./test_fts_query"; + +class FtsQueryTest : public ::testing::Test { + protected: + void SetUp() override { + FileHelper::RemoveDirectory(kTestPath); + } + void TearDown() override { + FileHelper::RemoveDirectory(kTestPath); + } + + // Create a schema with one STRING field (for forward) and one FTS field. + static CollectionSchema::Ptr CreateFtsSchema() { + auto schema = std::make_shared("fts_demo"); + // A simple scalar field for forward store + schema->add_field(std::make_shared("title", DataType::STRING)); + // FTS indexed field + schema->add_field( + std::make_shared("content", DataType::STRING, false, + std::make_shared())); + // A vector field is required for Collection to work (segment open expects + // at least one vector field in the normal schema path). + schema->add_field(std::make_shared( + "vec", DataType::VECTOR_FP32, 4, false, + std::make_shared(MetricType::IP))); + return schema; + } + + static Doc MakeDoc(uint64_t id, const std::string &title, + const std::string &content) { + Doc doc; + doc.set_pk("pk_" + std::to_string(id)); + doc.set("title", title); + doc.set("content", content); + // dummy vector + doc.set>("vec", std::vector(4, float(id + 0.1))); + return doc; + } +}; + +TEST_F(FtsQueryTest, BasicFtsQuery) { + auto schema = CreateFtsSchema(); + CollectionOptions options; + options.read_only_ = false; + + auto result = Collection::CreateAndOpen(kTestPath, *schema, options); + ASSERT_TRUE(result.has_value()) << result.error().message(); + auto col = result.value(); + + // Insert documents + std::vector docs; + docs.push_back(MakeDoc(0, "intro", "hello world from zvec")); + docs.push_back(MakeDoc(1, "guide", "hello foo bar")); + docs.push_back(MakeDoc(2, "faq", "baz qux nothing here")); + docs.push_back(MakeDoc(3, "tips", "hello hello hello world")); + + auto insert_res = col->Insert(docs); + ASSERT_TRUE(insert_res.has_value()) << insert_res.error().message(); + + // FTS query: search for "hello" + VectorQuery vq; + vq.field_name_ = "content"; + vq.topk_ = 10; + Fts fts; + fts.query_string_ = "hello"; + vq.fts_ = fts; + + auto query_res = col->Query(vq); + ASSERT_TRUE(query_res.has_value()) << query_res.error().message(); + + auto &results = query_res.value(); + // Documents 0, 1, 3 contain "hello"; document 2 does not. + ASSERT_GE(results.size(), 2u); + ASSERT_LE(results.size(), 3u); +} + +TEST_F(FtsQueryTest, FtsQueryEmptyField) { + auto schema = CreateFtsSchema(); + CollectionOptions options; + options.read_only_ = false; + + auto result = Collection::CreateAndOpen(kTestPath, *schema, options); + ASSERT_TRUE(result.has_value()); + auto col = result.value(); + + VectorQuery vq; + vq.field_name_ = ""; // empty + vq.topk_ = 10; + Fts fts; + fts.query_string_ = "hello"; + vq.fts_ = fts; + + auto query_res = col->Query(vq); + ASSERT_FALSE(query_res.has_value()); +} + +TEST_F(FtsQueryTest, FtsQueryNoMatch) { + auto schema = CreateFtsSchema(); + CollectionOptions options; + options.read_only_ = false; + + auto result = Collection::CreateAndOpen(kTestPath, *schema, options); + ASSERT_TRUE(result.has_value()); + auto col = result.value(); + + std::vector docs; + docs.push_back(MakeDoc(0, "intro", "hello world")); + auto insert_res = col->Insert(docs); + ASSERT_TRUE(insert_res.has_value()); + + VectorQuery vq; + vq.field_name_ = "content"; + vq.topk_ = 10; + Fts fts; + fts.query_string_ = "nonexistent_term_xyz"; + vq.fts_ = fts; + + auto query_res = col->Query(vq); + ASSERT_TRUE(query_res.has_value()); + ASSERT_EQ(query_res.value().size(), 0u); +} + +// Verify that FTS fields do NOT support add/alter/drop column operations. +// The schema change validation only allows basic numeric types [INT32..DOUBLE]. +TEST_F(FtsQueryTest, FtsFieldUnsupportedAddColumn) { + auto schema = CreateFtsSchema(); + CollectionOptions options; + options.read_only_ = false; + + auto result = Collection::CreateAndOpen(kTestPath, *schema, options); + ASSERT_TRUE(result.has_value()); + auto col = result.value(); + + // Insert a document so the collection is non-empty + std::vector docs; + docs.push_back(MakeDoc(0, "intro", "hello world")); + auto insert_res = col->Insert(docs); + ASSERT_TRUE(insert_res.has_value()); + ASSERT_TRUE(col->Flush().ok()); + + // Attempt to add a new FTS column — should fail + auto fts_field = std::make_shared( + "new_fts", DataType::STRING, true, std::make_shared()); + auto status = col->AddColumn(fts_field, "", AddColumnOptions()); + ASSERT_FALSE(status.ok()); + ASSERT_EQ(status.code(), StatusCode::INVALID_ARGUMENT); +} + +TEST_F(FtsQueryTest, FtsFieldUnsupportedDropColumn) { + auto schema = CreateFtsSchema(); + CollectionOptions options; + options.read_only_ = false; + + auto result = Collection::CreateAndOpen(kTestPath, *schema, options); + ASSERT_TRUE(result.has_value()); + auto col = result.value(); + + // Insert a document so the collection is non-empty + std::vector docs; + docs.push_back(MakeDoc(0, "intro", "hello world")); + auto insert_res = col->Insert(docs); + ASSERT_TRUE(insert_res.has_value()); + ASSERT_TRUE(col->Flush().ok()); + + // Attempt to drop an existing FTS column — should fail + auto status = col->DropColumn("content"); + ASSERT_FALSE(status.ok()); + ASSERT_EQ(status.code(), StatusCode::INVALID_ARGUMENT); +} + +TEST_F(FtsQueryTest, FtsFieldUnsupportedAlterColumn) { + auto schema = CreateFtsSchema(); + CollectionOptions options; + options.read_only_ = false; + + auto result = Collection::CreateAndOpen(kTestPath, *schema, options); + ASSERT_TRUE(result.has_value()); + auto col = result.value(); + + // Insert a document so the collection is non-empty + std::vector docs; + docs.push_back(MakeDoc(0, "intro", "hello world")); + auto insert_res = col->Insert(docs); + ASSERT_TRUE(insert_res.has_value()); + ASSERT_TRUE(col->Flush().ok()); + + // Attempt to alter (rename) the FTS column — should fail + auto status = col->AlterColumn("content", "content_renamed", nullptr, + AlterColumnOptions()); + ASSERT_FALSE(status.ok()); + ASSERT_EQ(status.code(), StatusCode::INVALID_ARGUMENT); + + // Attempt to alter the FTS column with a new schema — should also fail + auto new_fts_field = std::make_shared( + "content", DataType::STRING, true, std::make_shared()); + status = col->AlterColumn("content", "", new_fts_field, AlterColumnOptions()); + ASSERT_FALSE(status.ok()); + ASSERT_EQ(status.code(), StatusCode::INVALID_ARGUMENT); +} diff --git a/tests/db/index/CMakeLists.txt b/tests/db/index/CMakeLists.txt index d600dca6a..441f49009 100644 --- a/tests/db/index/CMakeLists.txt +++ b/tests/db/index/CMakeLists.txt @@ -54,3 +54,10 @@ foreach(CC_SRCS ${ALL_TEST_SRCS}) ) cc_test_suite(zvec_index ${CC_TARGET}) endforeach() + +# Inject TEST_SOURCE_DIR for fts_column_indexer_test so it can locate testdata/ +if(TARGET fts_column_indexer_test) + target_compile_definitions(fts_column_indexer_test PRIVATE + TEST_SOURCE_DIR="${CMAKE_CURRENT_SOURCE_DIR}/column/fts_column" + JIEBA_DICT_DIR="${PROJECT_SOURCE_DIR}/thirdparty/cppjieba/cppjieba-5.6.7/dict") +endif() diff --git a/tests/db/index/column/fts_column/bitpacked_posting_list_test.cc b/tests/db/index/column/fts_column/bitpacked_posting_list_test.cc new file mode 100644 index 000000000..76d28cd6e --- /dev/null +++ b/tests/db/index/column/fts_column/bitpacked_posting_list_test.cc @@ -0,0 +1,681 @@ +// Copyright 2025-present the zvec project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "db/index/column/fts_column/posting/bitpacked_posting_list.h" +#include +#include +#include +#include +#include +#include +#include "db/index/column/fts_column/bm25_scorer.h" + +using namespace zvec::fts; + +// ============================================================ +// Helper: create a BM25Scorer with reasonable defaults +// ============================================================ + +static BM25Scorer make_scorer(uint64_t total_docs = 1000, + uint64_t total_tokens = 50000) { + BM25Scorer scorer; + scorer.update_stats(total_docs, total_tokens); + return scorer; +} + +// ============================================================ +// bits_needed() +// ============================================================ + +TEST(BitPackedPostingListTest, BitsNeededZero) { + EXPECT_EQ(BitPackedPostingList::bits_needed(0), 0); +} + +TEST(BitPackedPostingListTest, BitsNeededOne) { + EXPECT_EQ(BitPackedPostingList::bits_needed(1), 1); +} + +TEST(BitPackedPostingListTest, BitsNeededPowerOfTwo) { + EXPECT_EQ(BitPackedPostingList::bits_needed(2), 2); + EXPECT_EQ(BitPackedPostingList::bits_needed(4), 3); + EXPECT_EQ(BitPackedPostingList::bits_needed(8), 4); + EXPECT_EQ(BitPackedPostingList::bits_needed(256), 9); + EXPECT_EQ(BitPackedPostingList::bits_needed(1024), 11); +} + +TEST(BitPackedPostingListTest, BitsNeededNonPowerOfTwo) { + EXPECT_EQ(BitPackedPostingList::bits_needed(3), 2); + EXPECT_EQ(BitPackedPostingList::bits_needed(5), 3); + EXPECT_EQ(BitPackedPostingList::bits_needed(7), 3); + EXPECT_EQ(BitPackedPostingList::bits_needed(255), 8); + EXPECT_EQ(BitPackedPostingList::bits_needed(1023), 10); +} + +TEST(BitPackedPostingListTest, BitsNeededMaxUint32) { + EXPECT_EQ(BitPackedPostingList::bits_needed(0xFFFFFFFF), 32); +} + +// ============================================================ +// pack_uint32 / unpack_uint32 round-trip +// ============================================================ + +class BitPackingTest : public ::testing::TestWithParam {}; + +TEST_P(BitPackingTest, PackUnpackRoundTrip128) { + const uint8_t bitwidth = GetParam(); + if (bitwidth == 0) return; + + const uint32_t count = 128; + const uint32_t mask = + (bitwidth == 32) ? 0xFFFFFFFFu : ((1u << bitwidth) - 1u); + + // Generate test values + std::vector original(count); + for (uint32_t i = 0; i < count; ++i) { + original[i] = (i * 17 + 3) & mask; // deterministic pattern + } + + // Pack + const size_t packed_size = + BitPackedPostingList::packed_byte_size(bitwidth, count); + std::vector packed(packed_size, 0); + BitPackedPostingList::pack_uint32(original.data(), bitwidth, count, + packed.data()); + + // Unpack + std::vector decoded(count, 0); + BitPackedPostingList::unpack_uint32(packed.data(), bitwidth, count, + decoded.data()); + + // Verify + for (uint32_t i = 0; i < count; ++i) { + EXPECT_EQ(decoded[i], original[i]) + << "Mismatch at index " << i << " with bitwidth " << (int)bitwidth; + } +} + +TEST_P(BitPackingTest, PackUnpackRoundTripSmall) { + const uint8_t bitwidth = GetParam(); + if (bitwidth == 0) return; + + // Test with a small count (not a full block) + const uint32_t count = 7; + const uint32_t mask = + (bitwidth == 32) ? 0xFFFFFFFFu : ((1u << bitwidth) - 1u); + + std::vector original(count); + for (uint32_t i = 0; i < count; ++i) { + original[i] = i & mask; + } + + const size_t packed_size = + BitPackedPostingList::packed_byte_size(bitwidth, count); + std::vector packed(packed_size, 0); + BitPackedPostingList::pack_uint32(original.data(), bitwidth, count, + packed.data()); + + std::vector decoded(count, 0); + BitPackedPostingList::unpack_uint32(packed.data(), bitwidth, count, + decoded.data()); + + for (uint32_t i = 0; i < count; ++i) { + EXPECT_EQ(decoded[i], original[i]) + << "Mismatch at index " << i << " with bitwidth " << (int)bitwidth; + } +} + +// Test all bitwidths from 1 to 32 +INSTANTIATE_TEST_SUITE_P(AllBitwidths, BitPackingTest, + ::testing::Range(static_cast(1), + static_cast(33))); + +TEST(BitPackingTest, PackUnpackZeroBitwidth) { + const uint32_t count = 128; + std::vector original(count, 0); + std::vector decoded(count, 99); + + // bitwidth 0: all values must be 0 + BitPackedPostingList::unpack_uint32(nullptr, 0, count, decoded.data()); + for (uint32_t i = 0; i < count; ++i) { + EXPECT_EQ(decoded[i], 0u); + } +} + +// ============================================================ +// Encode / Decode: empty posting list +// ============================================================ + +TEST(BitPackedPostingListTest, EncodeDecodeEmpty) { + BM25Scorer scorer = make_scorer(); + std::string encoded = + BitPackedPostingList::encode(nullptr, nullptr, nullptr, 0, 0, scorer); + + EXPECT_TRUE(BitPackedPostingList::is_bitpacked_format(encoded.data(), + encoded.size())); + + BitPackedPostingIterator iter; + EXPECT_EQ(iter.open(encoded.data(), encoded.size()), 0); + EXPECT_EQ(iter.cost(), 0u); + EXPECT_EQ(iter.next_doc(), BitPackedPostingIterator::NO_MORE_DOCS); +} + +// ============================================================ +// Encode / Decode: single element +// ============================================================ + +TEST(BitPackedPostingListTest, EncodeDecodeSingleElement) { + BM25Scorer scorer = make_scorer(); + uint32_t doc_ids[] = {42}; + uint32_t tfs[] = {3}; + uint32_t doc_lens[] = {100}; + + std::string encoded = + BitPackedPostingList::encode(doc_ids, tfs, doc_lens, 1, 1, scorer); + + BitPackedPostingIterator iter; + EXPECT_EQ(iter.open(encoded.data(), encoded.size()), 0); + EXPECT_EQ(iter.cost(), 1u); + + EXPECT_EQ(iter.next_doc(), 42u); + EXPECT_EQ(iter.doc_id(), 42u); + EXPECT_EQ(iter.term_freq(), 3u); + EXPECT_EQ(iter.doc_len(), 100u); + + EXPECT_EQ(iter.next_doc(), BitPackedPostingIterator::NO_MORE_DOCS); +} + +// ============================================================ +// Encode / Decode: small list (< 128) +// ============================================================ + +TEST(BitPackedPostingListTest, EncodeDecodeSmallList) { + BM25Scorer scorer = make_scorer(); + const size_t count = 10; + std::vector doc_ids(count); + std::vector tfs(count); + std::vector doc_lens(count); + + for (size_t i = 0; i < count; ++i) { + doc_ids[i] = static_cast(i * 5); + tfs[i] = static_cast(i + 1); + doc_lens[i] = static_cast(50 + i * 10); + } + + std::string encoded = BitPackedPostingList::encode( + doc_ids.data(), tfs.data(), doc_lens.data(), count, count, scorer); + + BitPackedPostingIterator iter; + EXPECT_EQ(iter.open(encoded.data(), encoded.size()), 0); + EXPECT_EQ(iter.cost(), count); + + for (size_t i = 0; i < count; ++i) { + uint32_t doc = iter.next_doc(); + EXPECT_EQ(doc, doc_ids[i]) << "Mismatch at index " << i; + EXPECT_EQ(iter.term_freq(), tfs[i]) << "TF mismatch at index " << i; + EXPECT_EQ(iter.doc_len(), doc_lens[i]) << "DocLen mismatch at index " << i; + } + + EXPECT_EQ(iter.next_doc(), BitPackedPostingIterator::NO_MORE_DOCS); +} + +// ============================================================ +// Encode / Decode: exactly 128 elements (one full block) +// ============================================================ + +TEST(BitPackedPostingListTest, EncodeDecodeExact128) { + BM25Scorer scorer = make_scorer(); + const size_t count = 128; + std::vector doc_ids(count); + std::vector tfs(count); + std::vector doc_lens(count); + + for (size_t i = 0; i < count; ++i) { + doc_ids[i] = static_cast(i * 3); + tfs[i] = static_cast((i % 10) + 1); + doc_lens[i] = static_cast(100 + i); + } + + std::string encoded = BitPackedPostingList::encode( + doc_ids.data(), tfs.data(), doc_lens.data(), count, count, scorer); + + BitPackedPostingIterator iter; + EXPECT_EQ(iter.open(encoded.data(), encoded.size()), 0); + EXPECT_EQ(iter.cost(), count); + + for (size_t i = 0; i < count; ++i) { + uint32_t doc = iter.next_doc(); + EXPECT_EQ(doc, doc_ids[i]) << "Mismatch at index " << i; + EXPECT_EQ(iter.term_freq(), tfs[i]) << "TF mismatch at index " << i; + EXPECT_EQ(iter.doc_len(), doc_lens[i]) << "DocLen mismatch at index " << i; + } + + EXPECT_EQ(iter.next_doc(), BitPackedPostingIterator::NO_MORE_DOCS); +} + +// ============================================================ +// Encode / Decode: 129 elements (two blocks, last block has 1 element) +// ============================================================ + +TEST(BitPackedPostingListTest, EncodeDecodeCrossBlockBoundary) { + BM25Scorer scorer = make_scorer(); + const size_t count = 129; + std::vector doc_ids(count); + std::vector tfs(count); + std::vector doc_lens(count); + + for (size_t i = 0; i < count; ++i) { + doc_ids[i] = static_cast(i * 2); + tfs[i] = static_cast((i % 5) + 1); + doc_lens[i] = static_cast(200 + i); + } + + std::string encoded = BitPackedPostingList::encode( + doc_ids.data(), tfs.data(), doc_lens.data(), count, count, scorer); + + BitPackedPostingIterator iter; + EXPECT_EQ(iter.open(encoded.data(), encoded.size()), 0); + EXPECT_EQ(iter.cost(), count); + + for (size_t i = 0; i < count; ++i) { + uint32_t doc = iter.next_doc(); + EXPECT_EQ(doc, doc_ids[i]) << "Mismatch at index " << i; + EXPECT_EQ(iter.term_freq(), tfs[i]) << "TF mismatch at index " << i; + EXPECT_EQ(iter.doc_len(), doc_lens[i]) << "DocLen mismatch at index " << i; + } + + EXPECT_EQ(iter.next_doc(), BitPackedPostingIterator::NO_MORE_DOCS); +} + +// ============================================================ +// Encode / Decode: large list (multiple blocks) +// ============================================================ + +TEST(BitPackedPostingListTest, EncodeDecodeLargeList) { + BM25Scorer scorer = make_scorer(10000, 500000); + const size_t count = 1000; + std::vector doc_ids(count); + std::vector tfs(count); + std::vector doc_lens(count); + + for (size_t i = 0; i < count; ++i) { + doc_ids[i] = static_cast(i * 10); + tfs[i] = static_cast((i % 20) + 1); + doc_lens[i] = static_cast(50 + (i % 200)); + } + + std::string encoded = BitPackedPostingList::encode( + doc_ids.data(), tfs.data(), doc_lens.data(), count, count, scorer); + + BitPackedPostingIterator iter; + EXPECT_EQ(iter.open(encoded.data(), encoded.size()), 0); + EXPECT_EQ(iter.cost(), count); + + for (size_t i = 0; i < count; ++i) { + uint32_t doc = iter.next_doc(); + EXPECT_EQ(doc, doc_ids[i]) << "Mismatch at index " << i; + EXPECT_EQ(iter.term_freq(), tfs[i]) << "TF mismatch at index " << i; + EXPECT_EQ(iter.doc_len(), doc_lens[i]) << "DocLen mismatch at index " << i; + } + + EXPECT_EQ(iter.next_doc(), BitPackedPostingIterator::NO_MORE_DOCS); +} + +// ============================================================ +// advance(): basic skip-list functionality +// ============================================================ + +TEST(BitPackedPostingListTest, AdvanceToExactDocId) { + BM25Scorer scorer = make_scorer(); + const size_t count = 500; + std::vector doc_ids(count); + std::vector tfs(count, 1); + std::vector doc_lens(count, 100); + + for (size_t i = 0; i < count; ++i) { + doc_ids[i] = static_cast(i * 3); + } + + std::string encoded = BitPackedPostingList::encode( + doc_ids.data(), tfs.data(), doc_lens.data(), count, count, scorer); + + BitPackedPostingIterator iter; + EXPECT_EQ(iter.open(encoded.data(), encoded.size()), 0); + + // Advance to exact doc_id + EXPECT_EQ(iter.advance(300), 300u); + EXPECT_EQ(iter.doc_id(), 300u); + + // Advance to a doc_id that doesn't exist (should return next >= target) + EXPECT_EQ(iter.advance(301), 303u); + EXPECT_EQ(iter.doc_id(), 303u); +} + +TEST(BitPackedPostingListTest, AdvanceToFirstDoc) { + BM25Scorer scorer = make_scorer(); + uint32_t doc_ids[] = {10, 20, 30, 40, 50}; + uint32_t tfs[] = {1, 2, 3, 4, 5}; + uint32_t doc_lens[] = {100, 200, 300, 400, 500}; + + std::string encoded = + BitPackedPostingList::encode(doc_ids, tfs, doc_lens, 5, 5, scorer); + + BitPackedPostingIterator iter; + EXPECT_EQ(iter.open(encoded.data(), encoded.size()), 0); + + // Advance to 0 should return the first doc (10) + EXPECT_EQ(iter.advance(0), 10u); + EXPECT_EQ(iter.term_freq(), 1u); + EXPECT_EQ(iter.doc_len(), 100u); +} + +TEST(BitPackedPostingListTest, AdvanceBeyondLastDoc) { + BM25Scorer scorer = make_scorer(); + uint32_t doc_ids[] = {10, 20, 30}; + uint32_t tfs[] = {1, 2, 3}; + uint32_t doc_lens[] = {100, 200, 300}; + + std::string encoded = + BitPackedPostingList::encode(doc_ids, tfs, doc_lens, 3, 3, scorer); + + BitPackedPostingIterator iter; + EXPECT_EQ(iter.open(encoded.data(), encoded.size()), 0); + + EXPECT_EQ(iter.advance(31), BitPackedPostingIterator::NO_MORE_DOCS); +} + +TEST(BitPackedPostingListTest, AdvanceAcrossBlocks) { + BM25Scorer scorer = make_scorer(); + const size_t count = 300; + std::vector doc_ids(count); + std::vector tfs(count, 2); + std::vector doc_lens(count, 50); + + for (size_t i = 0; i < count; ++i) { + doc_ids[i] = static_cast(i * 5); + } + + std::string encoded = BitPackedPostingList::encode( + doc_ids.data(), tfs.data(), doc_lens.data(), count, count, scorer); + + BitPackedPostingIterator iter; + EXPECT_EQ(iter.open(encoded.data(), encoded.size()), 0); + + // Advance from start to a doc in the 3rd block (block 2, index 256+) + // Block 0: doc_ids 0..635 (indices 0..127) + // Block 1: doc_ids 640..1275 (indices 128..255) + // Block 2: doc_ids 1280..1495 (indices 256..299) + EXPECT_EQ(iter.advance(1280), 1280u); + EXPECT_EQ(iter.doc_id(), 1280u); + EXPECT_EQ(iter.term_freq(), 2u); + + // Continue with next_doc + EXPECT_EQ(iter.next_doc(), 1285u); +} + +TEST(BitPackedPostingListTest, AdvanceSequentialCalls) { + BM25Scorer scorer = make_scorer(); + const size_t count = 200; + std::vector doc_ids(count); + std::vector tfs(count, 1); + std::vector doc_lens(count, 100); + + for (size_t i = 0; i < count; ++i) { + doc_ids[i] = static_cast(i * 7); + } + + std::string encoded = BitPackedPostingList::encode( + doc_ids.data(), tfs.data(), doc_lens.data(), count, count, scorer); + + BitPackedPostingIterator iter; + EXPECT_EQ(iter.open(encoded.data(), encoded.size()), 0); + + // Multiple sequential advance calls + EXPECT_EQ(iter.advance(100), 105u); // 15*7=105 + EXPECT_EQ(iter.advance(500), 504u); // 72*7=504 + EXPECT_EQ(iter.advance(1000), 1001u); // 143*7=1001 + EXPECT_EQ(iter.advance(1400), BitPackedPostingIterator::NO_MORE_DOCS); +} + +// ============================================================ +// advance() after next_doc() +// ============================================================ + +TEST(BitPackedPostingListTest, AdvanceAfterNextDoc) { + BM25Scorer scorer = make_scorer(); + const size_t count = 256; + std::vector doc_ids(count); + std::vector tfs(count, 1); + std::vector doc_lens(count, 50); + + for (size_t i = 0; i < count; ++i) { + doc_ids[i] = static_cast(i * 4); + } + + std::string encoded = BitPackedPostingList::encode( + doc_ids.data(), tfs.data(), doc_lens.data(), count, count, scorer); + + BitPackedPostingIterator iter; + EXPECT_EQ(iter.open(encoded.data(), encoded.size()), 0); + + // Read a few docs + EXPECT_EQ(iter.next_doc(), 0u); + EXPECT_EQ(iter.next_doc(), 4u); + EXPECT_EQ(iter.next_doc(), 8u); + + // Now advance past the current block + EXPECT_EQ(iter.advance(600), 600u); // 150*4=600 + EXPECT_EQ(iter.term_freq(), 1u); + + // Continue with next_doc + EXPECT_EQ(iter.next_doc(), 604u); +} + +// ============================================================ +// block_max_score correctness +// ============================================================ + +TEST(BitPackedPostingListTest, BlockMaxScoreCorrectness) { + BM25Scorer scorer = make_scorer(100, 5000); + const size_t count = 256; // 2 blocks + std::vector doc_ids(count); + std::vector tfs(count); + std::vector doc_lens(count); + + for (size_t i = 0; i < count; ++i) { + doc_ids[i] = static_cast(i); + tfs[i] = static_cast((i % 10) + 1); + doc_lens[i] = static_cast(50 + (i % 50)); + } + + std::string encoded = BitPackedPostingList::encode( + doc_ids.data(), tfs.data(), doc_lens.data(), count, count, scorer); + + BitPackedPostingIterator iter; + EXPECT_EQ(iter.open(encoded.data(), encoded.size()), 0); + + // Verify block_max_score for block 0 via block_max_info_for() + auto info0 = iter.block_max_info_for(0); + + // Manually compute max score for block 0 + float expected_max = 0.0f; + for (size_t i = 0; i < 128; ++i) { + float s = scorer.score(count, tfs[i], doc_lens[i]); + expected_max = std::max(expected_max, s); + } + EXPECT_FLOAT_EQ(info0.block_max_score, expected_max); + EXPECT_EQ(info0.block_last_doc, 127u); + + // Verify block_max_score for block 1 via block_max_info_for() + auto info1 = iter.block_max_info_for(128); + + expected_max = 0.0f; + for (size_t i = 128; i < 256; ++i) { + float s = scorer.score(count, tfs[i], doc_lens[i]); + expected_max = std::max(expected_max, s); + } + EXPECT_FLOAT_EQ(info1.block_max_score, expected_max); + EXPECT_EQ(info1.block_last_doc, 255u); +} + +// ============================================================ +// max_score() (global) +// ============================================================ + +TEST(BitPackedPostingListTest, GlobalMaxScore) { + BM25Scorer scorer = make_scorer(100, 5000); + const size_t count = 256; + std::vector doc_ids(count); + std::vector tfs(count); + std::vector doc_lens(count); + + for (size_t i = 0; i < count; ++i) { + doc_ids[i] = static_cast(i); + tfs[i] = static_cast((i % 10) + 1); + doc_lens[i] = static_cast(50 + (i % 50)); + } + + std::string encoded = BitPackedPostingList::encode( + doc_ids.data(), tfs.data(), doc_lens.data(), count, count, scorer); + + BitPackedPostingIterator iter; + EXPECT_EQ(iter.open(encoded.data(), encoded.size()), 0); + + // Global max_score should be the maximum of all block_max_scores + float global_max = 0.0f; + for (size_t i = 0; i < count; ++i) { + float s = scorer.score(count, tfs[i], doc_lens[i]); + global_max = std::max(global_max, s); + } + EXPECT_FLOAT_EQ(iter.max_score(), global_max); +} + +// ============================================================ +// is_bitpacked_format() +// ============================================================ + +TEST(BitPackedPostingListTest, IsBitpackedFormatTrue) { + BM25Scorer scorer = make_scorer(); + uint32_t doc_ids[] = {1}; + uint32_t tfs[] = {1}; + uint32_t doc_lens[] = {10}; + + std::string encoded = + BitPackedPostingList::encode(doc_ids, tfs, doc_lens, 1, 1, scorer); + EXPECT_TRUE(BitPackedPostingList::is_bitpacked_format(encoded.data(), + encoded.size())); +} + +TEST(BitPackedPostingListTest, IsBitpackedFormatFalse) { + // Random data that doesn't start with the magic number + std::string random_data = "hello world"; + EXPECT_FALSE(BitPackedPostingList::is_bitpacked_format(random_data.data(), + random_data.size())); +} + +TEST(BitPackedPostingListTest, IsBitpackedFormatTooShort) { + std::string short_data = "ab"; + EXPECT_FALSE(BitPackedPostingList::is_bitpacked_format(short_data.data(), + short_data.size())); +} + +// ============================================================ +// Error handling: open() with invalid data +// ============================================================ + +TEST(BitPackedPostingListTest, OpenWithNullData) { + BitPackedPostingIterator iter; + EXPECT_NE(iter.open(nullptr, 0), 0); +} + +TEST(BitPackedPostingListTest, OpenWithTruncatedHeader) { + BitPackedPostingIterator iter; + char data[4] = {0}; + EXPECT_NE(iter.open(data, 4), 0); +} + +TEST(BitPackedPostingListTest, OpenWithBadMagic) { + BitPackedPostingIterator iter; + char data[16] = {0}; + EXPECT_NE(iter.open(data, 16), 0); +} + +// ============================================================ +// Consistency: advance() vs sequential next_doc() +// ============================================================ + +TEST(BitPackedPostingListTest, AdvanceConsistentWithNextDoc) { + BM25Scorer scorer = make_scorer(); + const size_t count = 500; + std::vector doc_ids(count); + std::vector tfs(count); + std::vector doc_lens(count); + + std::mt19937 rng(42); + uint32_t current = 0; + for (size_t i = 0; i < count; ++i) { + current += (rng() % 10) + 1; + doc_ids[i] = current; + tfs[i] = (rng() % 10) + 1; + doc_lens[i] = (rng() % 200) + 10; + } + + std::string encoded = BitPackedPostingList::encode( + doc_ids.data(), tfs.data(), doc_lens.data(), count, count, scorer); + + // Collect all docs via next_doc + BitPackedPostingIterator iter1; + EXPECT_EQ(iter1.open(encoded.data(), encoded.size()), 0); + std::vector all_docs; + std::vector all_tfs; + std::vector all_doc_lens; + uint32_t doc = iter1.next_doc(); + while (doc != BitPackedPostingIterator::NO_MORE_DOCS) { + all_docs.push_back(doc); + all_tfs.push_back(iter1.term_freq()); + all_doc_lens.push_back(iter1.doc_len()); + doc = iter1.next_doc(); + } + + ASSERT_EQ(all_docs.size(), count); + + // Verify advance to various targets matches sequential scan + BitPackedPostingIterator iter2; + EXPECT_EQ(iter2.open(encoded.data(), encoded.size()), 0); + + std::vector targets = {0, + 1, + doc_ids[50], + doc_ids[127], + doc_ids[128], + doc_ids[200], + doc_ids[count - 1]}; + + for (uint32_t target : targets) { + BitPackedPostingIterator iter_adv; + EXPECT_EQ(iter_adv.open(encoded.data(), encoded.size()), 0); + uint32_t adv_doc = iter_adv.advance(target); + + // Find expected result via linear scan + auto it = std::lower_bound(all_docs.begin(), all_docs.end(), target); + if (it == all_docs.end()) { + EXPECT_EQ(adv_doc, BitPackedPostingIterator::NO_MORE_DOCS) + << "target=" << target; + } else { + size_t idx = it - all_docs.begin(); + EXPECT_EQ(adv_doc, all_docs[idx]) << "target=" << target; + EXPECT_EQ(iter_adv.term_freq(), all_tfs[idx]) << "target=" << target; + EXPECT_EQ(iter_adv.doc_len(), all_doc_lens[idx]) << "target=" << target; + } + } +} diff --git a/tests/db/index/column/fts_column/fts_ast_rewriter_test.cc b/tests/db/index/column/fts_column/fts_ast_rewriter_test.cc new file mode 100644 index 000000000..8b17781c9 --- /dev/null +++ b/tests/db/index/column/fts_column/fts_ast_rewriter_test.cc @@ -0,0 +1,466 @@ +// Copyright 2025-present the zvec project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "db/index/column/fts_column/fts_ast_rewriter.h" +#include +#include +#include +#include +#include +#include "db/index/column/fts_column/fts_query_ast.h" + +namespace zvec::fts { + +namespace { + +// Convenience constructors keep the test bodies focused on what's being +// asserted rather than on AST scaffolding. + +FtsAstNodePtr term(const std::string &t, bool must = false, + bool must_not = false, float boost = 1.0f) { + auto n = std::make_unique(t, must, must_not); + n->boost = boost; + return n; +} + +FtsAstNodePtr phrase(std::vector ts, bool must = false, + bool must_not = false, float boost = 1.0f) { + auto n = std::make_unique(); + n->terms = std::move(ts); + n->must = must; + n->must_not = must_not; + n->boost = boost; + return n; +} + +FtsAstNodePtr empty_node() { + return std::make_unique(); +} + +template +FtsAstNodePtr composite(std::vector children, bool must = false, + bool must_not = false) { + auto n = std::make_unique(); + n->children = std::move(children); + n->must = must; + n->must_not = must_not; + return n; +} + +FtsAstNodePtr or_node(std::vector c, bool must = false, + bool must_not = false) { + return composite(std::move(c), must, must_not); +} +FtsAstNodePtr and_node(std::vector c, bool must = false, + bool must_not = false) { + return composite(std::move(c), must, must_not); +} + +// Pull the single TermNode child out of a composite for boost assertions. +const TermNode &as_term(const FtsAstNode &n) { + return static_cast(n); +} +const PhraseNode &as_phrase(const FtsAstNode &n) { + return static_cast(n); +} +const OrNode &as_or(const FtsAstNode &n) { + return static_cast(n); +} +const AndNode &as_and(const FtsAstNode &n) { + return static_cast(n); +} + +} // namespace + +// --- Dedup --- + +TEST(FtsAstRewriterTest, OrDedupsRepeatedTerms) { + // OR(apple, apple, banana) → OR(apple^2, banana) + std::vector children; + children.push_back(term("apple")); + children.push_back(term("apple")); + children.push_back(term("banana")); + auto ast = or_node(std::move(children)); + + simplify(ast); + + ASSERT_EQ(ast->type(), FtsNodeType::OR); + const auto &n = as_or(*ast); + ASSERT_EQ(n.children.size(), 2u); + EXPECT_EQ(as_term(*n.children[0]).term, "apple"); + EXPECT_FLOAT_EQ(n.children[0]->boost, 2.0f); + EXPECT_EQ(as_term(*n.children[1]).term, "banana"); + EXPECT_FLOAT_EQ(n.children[1]->boost, 1.0f); +} + +TEST(FtsAstRewriterTest, AndDedupsRepeatedTerms) { + std::vector children; + children.push_back(term("apple")); + children.push_back(term("apple")); + children.push_back(term("apple")); + auto ast = and_node(std::move(children)); + + simplify(ast); + + // Single-child fold collapses AND(apple^3) → apple^3. + ASSERT_EQ(ast->type(), FtsNodeType::TERM); + EXPECT_EQ(as_term(*ast).term, "apple"); + EXPECT_FLOAT_EQ(ast->boost, 3.0f); +} + +TEST(FtsAstRewriterTest, DifferentOccurDoesNotMerge) { + // OR(apple, +apple) — same term, different modifiers must NOT collapse; + // dedup keys include the must/must_not bits so the two stay distinct. + std::vector children; + children.push_back(term("apple")); + children.push_back(term("apple", /*must=*/true)); + auto ast = or_node(std::move(children)); + + simplify(ast); + + ASSERT_EQ(ast->type(), FtsNodeType::OR); + EXPECT_EQ(as_or(*ast).children.size(), 2u); +} + +// --- Conflict --- + +TEST(FtsAstRewriterTest, AndMustVsMustNotSameTermBecomesEmpty) { + std::vector children; + children.push_back(term("apple", /*must=*/true)); + children.push_back(term("apple", /*must=*/false, /*must_not=*/true)); + children.push_back(term("banana")); + auto ast = and_node(std::move(children)); + + simplify(ast); + + EXPECT_EQ(ast->type(), FtsNodeType::EMPTY); +} + +TEST(FtsAstRewriterTest, AndAllMustNotBecomesEmpty) { + std::vector children; + children.push_back(term("apple", false, true)); + children.push_back(term("banana", false, true)); + auto ast = and_node(std::move(children)); + + simplify(ast); + + EXPECT_EQ(ast->type(), FtsNodeType::EMPTY); +} + +// --- Flattening --- + +TEST(FtsAstRewriterTest, OrFlattensNestedOr) { + // OR(a, OR(b, c)) → OR(a, b, c) + std::vector inner; + inner.push_back(term("b")); + inner.push_back(term("c")); + std::vector outer; + outer.push_back(term("a")); + outer.push_back(or_node(std::move(inner))); + auto ast = or_node(std::move(outer)); + + simplify(ast); + + ASSERT_EQ(ast->type(), FtsNodeType::OR); + ASSERT_EQ(as_or(*ast).children.size(), 3u); +} + +TEST(FtsAstRewriterTest, AndFlattensNestedAndWithoutMustNot) { + std::vector inner; + inner.push_back(term("b")); + inner.push_back(term("c", true)); + std::vector outer; + outer.push_back(term("a")); + outer.push_back(and_node(std::move(inner))); + auto ast = and_node(std::move(outer)); + + simplify(ast); + + ASSERT_EQ(ast->type(), FtsNodeType::AND); + EXPECT_EQ(as_and(*ast).children.size(), 3u); +} + +TEST(FtsAstRewriterTest, AndDoesNotFlattenAndWithMustNotChild) { + std::vector inner; + inner.push_back(term("b")); + inner.push_back(term("c", false, true)); + std::vector outer; + outer.push_back(term("a")); + outer.push_back(and_node(std::move(inner))); + auto ast = and_node(std::move(outer)); + + simplify(ast); + + ASSERT_EQ(ast->type(), FtsNodeType::AND); + EXPECT_EQ(as_and(*ast).children.size(), 2u); +} + +TEST(FtsAstRewriterTest, FlattenThenDedupCrossLayer) { + // OR(a, OR(a, b)) → flatten → OR(a, a, b) → dedup → OR(a^2, b) + std::vector inner; + inner.push_back(term("a")); + inner.push_back(term("b")); + std::vector outer; + outer.push_back(term("a")); + outer.push_back(or_node(std::move(inner))); + auto ast = or_node(std::move(outer)); + + simplify(ast); + + ASSERT_EQ(ast->type(), FtsNodeType::OR); + const auto &n = as_or(*ast); + ASSERT_EQ(n.children.size(), 2u); + EXPECT_EQ(as_term(*n.children[0]).term, "a"); + EXPECT_FLOAT_EQ(n.children[0]->boost, 2.0f); + EXPECT_EQ(as_term(*n.children[1]).term, "b"); +} + +// --- Phrase --- + +TEST(FtsAstRewriterTest, PhraseSameTermsAreDeduped) { + std::vector children; + children.push_back(phrase({"new", "york"})); + children.push_back(phrase({"new", "york"})); + children.push_back(phrase({"new", "york"})); + auto ast = or_node(std::move(children)); + + simplify(ast); + + ASSERT_EQ(ast->type(), FtsNodeType::PHRASE); + EXPECT_FLOAT_EQ(ast->boost, 3.0f); +} + +TEST(FtsAstRewriterTest, PhraseInternalRepeatNotMerged) { + // Position-sensitive: "new new york" must keep its internal duplication. + auto p = phrase({"new", "new", "york"}); + FtsAstNodePtr ast = std::move(p); + simplify(ast); + + ASSERT_EQ(ast->type(), FtsNodeType::PHRASE); + ASSERT_EQ(as_phrase(*ast).terms.size(), 3u); + EXPECT_EQ(as_phrase(*ast).terms[0], "new"); + EXPECT_EQ(as_phrase(*ast).terms[1], "new"); + EXPECT_EQ(as_phrase(*ast).terms[2], "york"); +} + +TEST(FtsAstRewriterTest, DifferentPhrasesDoNotMerge) { + std::vector children; + children.push_back(phrase({"new", "york"})); + children.push_back(phrase({"york", "new"})); + auto ast = or_node(std::move(children)); + + simplify(ast); + + ASSERT_EQ(ast->type(), FtsNodeType::OR); + EXPECT_EQ(as_or(*ast).children.size(), 2u); +} + +// --- EmptyNode propagation --- + +TEST(FtsAstRewriterTest, AndWithEmptyChildShortCircuits) { + std::vector children; + children.push_back(term("apple")); + children.push_back(empty_node()); + auto ast = and_node(std::move(children)); + + simplify(ast); + + EXPECT_EQ(ast->type(), FtsNodeType::EMPTY); +} + +TEST(FtsAstRewriterTest, OrDropsEmptyChild) { + std::vector children; + children.push_back(term("apple")); + children.push_back(empty_node()); + children.push_back(term("banana")); + auto ast = or_node(std::move(children)); + + simplify(ast); + + ASSERT_EQ(ast->type(), FtsNodeType::OR); + EXPECT_EQ(as_or(*ast).children.size(), 2u); +} + +TEST(FtsAstRewriterTest, MustNotEmptyInAndIsNoOp) { + // AND(apple, -EMPTY) — excluding nothing has no effect. + std::vector children; + children.push_back(term("apple")); + auto e = std::make_unique(); + e->must_not = true; + children.push_back(std::move(e)); + auto ast = and_node(std::move(children)); + + simplify(ast); + + ASSERT_EQ(ast->type(), FtsNodeType::TERM); + EXPECT_EQ(as_term(*ast).term, "apple"); +} + +// --- Single-child fold --- + +TEST(FtsAstRewriterTest, SingleChildOrFolds) { + std::vector children; + children.push_back(term("apple")); + auto ast = or_node(std::move(children)); + + simplify(ast); + + ASSERT_EQ(ast->type(), FtsNodeType::TERM); + EXPECT_EQ(as_term(*ast).term, "apple"); +} + +TEST(FtsAstRewriterTest, FoldedSingleChildInheritsParentModifier) { + // +OR(apple) → +apple (must flag lifts onto the surviving child) + std::vector children; + children.push_back(term("apple")); + auto ast = or_node(std::move(children), /*must=*/true); + + simplify(ast); + + ASSERT_EQ(ast->type(), FtsNodeType::TERM); + EXPECT_TRUE(ast->must); + EXPECT_FALSE(ast->must_not); +} + +// --- Idempotence --- + +TEST(FtsAstRewriterTest, SimplifyIsIdempotent) { + // Build something gnarly enough to exercise multiple rules at once. + std::vector inner_or; + inner_or.push_back(term("a")); + inner_or.push_back(term("a")); + std::vector children; + children.push_back(term("a")); + children.push_back(or_node(std::move(inner_or))); + children.push_back(term("b")); + children.push_back(empty_node()); + auto ast = or_node(std::move(children)); + + simplify(ast); + const std::string after_first = ast->text(); + simplify(ast); + const std::string after_second = ast->text(); + + EXPECT_EQ(after_first, after_second); +} + +// --- OR must_not canonicalization --- + +TEST(FtsAstRewriterTest, OrWithSinglePositiveAndMustNotBecomesAnd) { + // OR(a, -b) → AND(a, -b) + std::vector children; + children.push_back(term("a")); + children.push_back(term("b", false, true)); + auto ast = or_node(std::move(children)); + + simplify(ast); + + ASSERT_EQ(ast->type(), FtsNodeType::AND); + const auto &n = as_and(*ast); + ASSERT_EQ(n.children.size(), 2u); + EXPECT_EQ(as_term(*n.children[0]).term, "a"); + EXPECT_FALSE(n.children[0]->must_not); + EXPECT_EQ(as_term(*n.children[1]).term, "b"); + EXPECT_TRUE(n.children[1]->must_not); +} + +TEST(FtsAstRewriterTest, OrWithMultiplePositivesAndMustNotWrapsInAnd) { + // OR(a, b, -c) → AND(OR(a, b), -c) + std::vector children; + children.push_back(term("a")); + children.push_back(term("b")); + children.push_back(term("c", false, true)); + auto ast = or_node(std::move(children)); + + simplify(ast); + + ASSERT_EQ(ast->type(), FtsNodeType::AND); + const auto &n = as_and(*ast); + ASSERT_EQ(n.children.size(), 2u); + ASSERT_EQ(n.children[0]->type(), FtsNodeType::OR); + EXPECT_EQ(as_or(*n.children[0]).children.size(), 2u); + EXPECT_EQ(as_term(*n.children[1]).term, "c"); + EXPECT_TRUE(n.children[1]->must_not); +} + +TEST(FtsAstRewriterTest, OrCanonicalizationCatchesSameTermConflict) { + // OR(a, -a) — canonicalization moves -a into AND with a, then + // and_has_mustnot_conflict fires → EmptyNode. + std::vector children; + children.push_back(term("a")); + children.push_back(term("a", false, true)); + auto ast = or_node(std::move(children)); + + simplify(ast); + + EXPECT_EQ(ast->type(), FtsNodeType::EMPTY); +} + +TEST(FtsAstRewriterTest, OrCanonicalizationLiftsParentModifier) { + // +OR(a, -b) → +AND(a, -b) + std::vector children; + children.push_back(term("a")); + children.push_back(term("b", false, true)); + auto ast = or_node(std::move(children), /*must=*/true); + + simplify(ast); + + ASSERT_EQ(ast->type(), FtsNodeType::AND); + EXPECT_TRUE(ast->must); + EXPECT_FALSE(ast->must_not); +} + +TEST(FtsAstRewriterTest, NestedOrWithMustNotCanonicalizedAtBothLevels) { + // OR(x, OR(b, -c)) — inner canonicalizes to AND(b, -c); outer keeps OR + // since it has no must_not after recursion. + std::vector inner; + inner.push_back(term("b")); + inner.push_back(term("c", false, true)); + std::vector outer; + outer.push_back(term("x")); + outer.push_back(or_node(std::move(inner))); + auto ast = or_node(std::move(outer)); + + simplify(ast); + + ASSERT_EQ(ast->type(), FtsNodeType::OR); + const auto &n = as_or(*ast); + ASSERT_EQ(n.children.size(), 2u); + EXPECT_EQ(as_term(*n.children[0]).term, "x"); + EXPECT_EQ(n.children[1]->type(), FtsNodeType::AND); +} + +TEST(FtsAstRewriterTest, OrWithoutMustNotIsLeftAlone) { + std::vector children; + children.push_back(term("a")); + children.push_back(term("b")); + auto ast = or_node(std::move(children)); + + simplify(ast); + + ASSERT_EQ(ast->type(), FtsNodeType::OR); + EXPECT_EQ(as_or(*ast).children.size(), 2u); +} + +// --- Leaf untouched --- + +TEST(FtsAstRewriterTest, BareTermPassthrough) { + FtsAstNodePtr ast = term("apple"); + simplify(ast); + ASSERT_EQ(ast->type(), FtsNodeType::TERM); + EXPECT_EQ(as_term(*ast).term, "apple"); + EXPECT_FLOAT_EQ(ast->boost, 1.0f); +} + +} // namespace zvec::fts diff --git a/tests/db/index/column/fts_column/fts_candidate_iterator_test.cc b/tests/db/index/column/fts_column/fts_candidate_iterator_test.cc new file mode 100644 index 000000000..bd4728525 --- /dev/null +++ b/tests/db/index/column/fts_column/fts_candidate_iterator_test.cc @@ -0,0 +1,98 @@ +// Copyright 2025-present the zvec project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "db/index/column/fts_column/iterator/fts_candidate_iterator.h" +#include +#include +#include +#include "db/index/column/fts_column/iterator/fts_doc_iterator.h" + +using zvec::fts::CandidateDocIterator; +using zvec::fts::DocIterator; + +namespace { + +constexpr uint32_t kNoMore = DocIterator::NO_MORE_DOCS; + +} // namespace + +TEST(CandidateDocIteratorTest, EmptyVectorYieldsNothing) { + CandidateDocIterator it({}); + EXPECT_EQ(it.cost(), 0u); + EXPECT_EQ(it.next_doc(), kNoMore); + EXPECT_EQ(it.next_doc(), kNoMore); +} + +TEST(CandidateDocIteratorTest, NextDocStreamsAscending) { + CandidateDocIterator it({0, 5, 10, 100}); + EXPECT_EQ(it.cost(), 4u); + EXPECT_FLOAT_EQ(it.max_score(), 0.0f); + EXPECT_FLOAT_EQ(it.score(), 0.0f); + EXPECT_TRUE(it.matches()); + + EXPECT_EQ(it.next_doc(), 0u); + EXPECT_EQ(it.doc_id(), 0u); + EXPECT_EQ(it.next_doc(), 5u); + EXPECT_EQ(it.next_doc(), 10u); + EXPECT_EQ(it.next_doc(), 100u); + EXPECT_EQ(it.next_doc(), kNoMore); + EXPECT_EQ(it.next_doc(), kNoMore); +} + +TEST(CandidateDocIteratorTest, AdvanceLandsOnExactMatch) { + CandidateDocIterator it({10, 20, 30, 40, 50}); + EXPECT_EQ(it.advance(20), 20u); + EXPECT_EQ(it.doc_id(), 20u); + // Subsequent next_doc continues past the advanced position. + EXPECT_EQ(it.next_doc(), 30u); +} + +TEST(CandidateDocIteratorTest, AdvanceSeeksToNextHigher) { + CandidateDocIterator it({10, 20, 30, 40, 50}); + EXPECT_EQ(it.advance(25), 30u); + EXPECT_EQ(it.next_doc(), 40u); +} + +TEST(CandidateDocIteratorTest, AdvancePastLastYieldsNoMore) { + CandidateDocIterator it({10, 20, 30}); + EXPECT_EQ(it.advance(50), kNoMore); + EXPECT_EQ(it.next_doc(), kNoMore); +} + +TEST(CandidateDocIteratorTest, AdvanceBeforeAnyConsumeWorks) { + CandidateDocIterator it({10, 20, 30}); + EXPECT_EQ(it.advance(0), 10u); + EXPECT_EQ(it.next_doc(), 20u); +} + +TEST(CandidateDocIteratorTest, AdvanceInterleavedWithNext) { + CandidateDocIterator it({5, 10, 15, 20, 25, 30}); + EXPECT_EQ(it.next_doc(), 5u); + EXPECT_EQ(it.advance(15), 15u); + EXPECT_EQ(it.next_doc(), 20u); + EXPECT_EQ(it.advance(99), kNoMore); +} + +TEST(CandidateDocIteratorTest, SingleElement) { + CandidateDocIterator it({42}); + EXPECT_EQ(it.cost(), 1u); + EXPECT_EQ(it.advance(42), 42u); + EXPECT_EQ(it.next_doc(), kNoMore); +} + +TEST(CandidateDocIteratorTest, AdvanceCachesDocId) { + CandidateDocIterator it({1, 2, 3}); + EXPECT_EQ(it.advance(2), 2u); + EXPECT_EQ(it.doc_id(), 2u); +} diff --git a/tests/db/index/column/fts_column/fts_column_indexer_test.cc b/tests/db/index/column/fts_column/fts_column_indexer_test.cc new file mode 100644 index 000000000..1712ad7ae --- /dev/null +++ b/tests/db/index/column/fts_column/fts_column_indexer_test.cc @@ -0,0 +1,1757 @@ +// Copyright 2025-present the zvec project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "db/index/column/fts_column/fts_column_indexer.h" +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include "db/common/file_helper.h" +#include "db/index/common/index_filter.h" +// FtsQueryParams defined below +#include "db/index/column/fts_column/fts_ast_rewriter.h" +#include "db/index/column/fts_column/fts_rocksdb_merge.h" +#include "db/index/column/fts_column/parser/fts_query_parser.h" +#include "db/index/column/fts_column/tokenizer/tokenizer_factory.h" +// meta.h not needed in zvec +#include "db/common/constants.h" +#include "db/common/rocksdb_context.h" + +using namespace zvec; +using namespace zvec::fts; + +namespace { + +// Build a transient FieldSchema for FTS unit tests. +// When fts_params is provided, it is attached as the field's index_params +// so that FtsColumnIndexer::open() can retrieve the tokenizer configuration. +FieldSchema::Ptr make_test_field_meta( + const std::string &field_name, + std::shared_ptr fts_params = nullptr) { + if (fts_params) { + return std::make_shared(field_name, DataType::STRING, false, + fts_params); + } + return std::make_shared(field_name, DataType::STRING); +} + +} // namespace + +// Build a tokenizer pipeline matching the indexer config used by the tests. +// A standalone helper so tests can pass it to parser.parse() without +// reaching into FtsColumnIndexer internals. +static zvec::fts::TokenizerPipelinePtr make_whitespace_pipeline() { + zvec::fts::FtsIndexParams params; + params.tokenizer_name = "whitespace"; + params.filters = {"lowercase"}; + return zvec::fts::TokenizerFactory::create(params); +} + +// Helper: parse a query string and call search() on a reader/indexer. +// Terminates the test with ASSERT if parsing fails. +template +static bool search_ok(Reader &reader, const std::string &query_str, + uint32_t topk, std::vector *results, + const zvec::fts::TokenizerPipelinePtr &pipeline = + make_whitespace_pipeline()) { + FtsQueryParser parser; + auto ast = parser.parse(query_str, pipeline); + if (!ast) { + ADD_FAILURE() << "FtsQueryParser failed to parse: " << query_str + << " err: " << parser.err_msg(); + return false; + } + // Apply the same AST rewrite the production sqlengine path runs so that + // FtsColumnIndexer::search() sees a canonical AST (no must_not children + // inside an OrNode, dedup-collapsed siblings, etc.). + zvec::fts::simplify(ast); + zvec::fts::FtsQueryParams qp; + qp.topk = topk; + auto ret = reader.search(*ast, qp); + if (!ret.has_value()) { + return false; + } + *results = std::move(ret.value()); + return true; +} + +// Helper: parse a query string with a filter and call search(). +template +static bool search_ok_with_filter(Reader &reader, const std::string &query_str, + uint32_t topk, zvec::IndexFilter::Ptr filter, + std::vector *results, + const zvec::fts::TokenizerPipelinePtr + &pipeline = make_whitespace_pipeline()) { + FtsQueryParser parser; + auto ast = parser.parse(query_str, pipeline); + if (!ast) { + ADD_FAILURE() << "FtsQueryParser failed to parse: " << query_str + << " err: " << parser.err_msg(); + return false; + } + zvec::fts::simplify(ast); + zvec::fts::FtsQueryParams qp; + qp.topk = topk; + qp.filter = std::move(filter); + auto ret = reader.search(*ast, qp); + if (!ret.has_value()) { + return false; + } + *results = std::move(ret.value()); + return true; +} + +// ============================================================ +// Test fixture +// ============================================================ + +static const std::string kDbPath{"./test_fts_db"}; + +static const std::string kPostingsCf{"fts"}; +static const std::string kMaxTfCf{kPostingsCf + zvec::kFtsMaxTfSuffix}; +static const std::string kPositionsCf{kPostingsCf + zvec::kFtsPositionsSuffix}; +static const std::string kTermFreqCf{kPostingsCf + zvec::kFtsTfSuffix}; +static const std::string kDocLenCf{kPostingsCf + zvec::kFtsDocLenSuffix}; +static const std::string kStatCf{zvec::kFtsStatCfName}; + +class FtsColumnIndexerTest : public ::testing::Test { + protected: + void SetUp() override { + zvec::FileHelper::RemoveDirectory(kDbPath); + + // Single RocksDB instance with per-CF merge operators. + std::vector cf_names = {kPostingsCf, kMaxTfCf, kPositionsCf, + kTermFreqCf, kDocLenCf, kStatCf}; + std::unordered_map> + per_cf_ops = { + {kPostingsCf, std::make_shared()}, + {kMaxTfCf, std::make_shared()}, + }; + ASSERT_TRUE( + db_.create(RocksdbContext::Args{kDbPath, cf_names, nullptr, per_cf_ops}) + .ok()); + + postings_cf_ = db_.get_cf(kPostingsCf); + max_tf_cf_ = db_.get_cf(kMaxTfCf); + positions_cf_ = db_.get_cf(kPositionsCf); + term_freq_cf_ = db_.get_cf(kTermFreqCf); + doc_len_cf_ = db_.get_cf(kDocLenCf); + stat_cf_ = db_.get_cf(kStatCf); + + ASSERT_NE(postings_cf_, nullptr); + ASSERT_NE(max_tf_cf_, nullptr); + ASSERT_NE(positions_cf_, nullptr); + ASSERT_NE(term_freq_cf_, nullptr); + ASSERT_NE(doc_len_cf_, nullptr); + ASSERT_NE(stat_cf_, nullptr); + } + + void TearDown() override { + db_.close(); + zvec::FileHelper::RemoveDirectory(kDbPath); + } + + // Create and open a fresh indexer with whitespace tokenizer. + // Returns unique_ptr because FtsColumnIndexer is not copyable (atomic + // members). + std::unique_ptr make_indexer( + const std::string &field_name = "content") { + auto fts_params = std::make_shared("whitespace"); + auto field_meta = make_test_field_meta(field_name, fts_params); + auto indexer = std::make_unique(); + auto ret = indexer->open(field_meta, &db_, postings_cf_, positions_cf_, + term_freq_cf_, max_tf_cf_, doc_len_cf_, stat_cf_); + EXPECT_TRUE(ret.has_value()); + return indexer; + } + + RocksdbContext db_; + + rocksdb::ColumnFamilyHandle *postings_cf_{nullptr}; + rocksdb::ColumnFamilyHandle *max_tf_cf_{nullptr}; + rocksdb::ColumnFamilyHandle *positions_cf_{nullptr}; + rocksdb::ColumnFamilyHandle *term_freq_cf_{nullptr}; + rocksdb::ColumnFamilyHandle *doc_len_cf_{nullptr}; + rocksdb::ColumnFamilyHandle *stat_cf_{nullptr}; +}; +// ============================================================ +// open() +// ============================================================ + +TEST_F(FtsColumnIndexerTest, OpenWithValidTokenizer) { + auto fts_params = std::make_shared("whitespace"); + auto field_meta = make_test_field_meta("content", fts_params); + FtsColumnIndexer indexer; + auto ret = indexer.open(field_meta, &db_, postings_cf_, positions_cf_, + term_freq_cf_, max_tf_cf_, doc_len_cf_, stat_cf_); + EXPECT_TRUE(ret.has_value()); + EXPECT_EQ(indexer.total_docs(), 0u); + EXPECT_EQ(indexer.total_tokens(), 0u); +} + +TEST_F(FtsColumnIndexerTest, OpenWithNullFieldMetaFails) { + FtsColumnIndexer indexer; + auto ret = + indexer.open(FieldSchema::Ptr{nullptr}, &db_, postings_cf_, positions_cf_, + term_freq_cf_, max_tf_cf_, doc_len_cf_, stat_cf_); + EXPECT_FALSE(ret.has_value()); +} + +TEST_F(FtsColumnIndexerTest, OpenWithNullStoreFails) { + auto fts_params = std::make_shared("whitespace"); + auto field_meta = make_test_field_meta("content", fts_params); + FtsColumnIndexer indexer; + auto ret = + indexer.open(field_meta, /*store=*/nullptr, postings_cf_, positions_cf_, + term_freq_cf_, max_tf_cf_, doc_len_cf_, stat_cf_); + EXPECT_FALSE(ret.has_value()); +} + +// ============================================================ +// insert() - statistics update +// ============================================================ + +TEST_F(FtsColumnIndexerTest, InsertUpdatesTotalDocs) { + auto indexer = make_indexer(); + + EXPECT_TRUE(indexer->insert(0, "hello world").has_value()); + EXPECT_EQ(indexer->total_docs(), 1u); + + EXPECT_TRUE(indexer->insert(1, "foo bar baz").has_value()); + EXPECT_EQ(indexer->total_docs(), 2u); +} + +TEST_F(FtsColumnIndexerTest, InsertUpdatesTotalTokens) { + auto indexer = make_indexer(); + + EXPECT_TRUE(indexer->insert(0, "hello world").has_value()); + EXPECT_EQ(indexer->total_tokens(), 2u); // "hello", "world" + + EXPECT_TRUE(indexer->insert(1, "foo bar baz").has_value()); + EXPECT_EQ(indexer->total_tokens(), 5u); // 2 + 3 +} + +TEST_F(FtsColumnIndexerTest, InsertEmptyTextCountsAsZeroTokens) { + auto indexer = make_indexer(); + + EXPECT_TRUE(indexer->insert(0, "").has_value()); + EXPECT_EQ(indexer->total_docs(), 1u); + EXPECT_EQ(indexer->total_tokens(), 0u); +} + +// ============================================================ +// flush() - persist stats to RocksDB +// ============================================================ + +TEST_F(FtsColumnIndexerTest, FlushPersistsStats) { + auto indexer = make_indexer("content"); + EXPECT_TRUE(indexer->insert(0, "hello world").has_value()); + EXPECT_TRUE(indexer->insert(1, "foo bar").has_value()); + EXPECT_TRUE(indexer->flush().has_value()); + + // Verify stats were written to stat_cf by opening a standalone reader. + // Pass doc_len_cf as nullptr so the reader loads stats from stat_cf. + FtsColumnIndexer reader; + auto ret = reader.open_reader("content", &db_, postings_cf_, positions_cf_, + term_freq_cf_, max_tf_cf_, + /*doc_len_cf=*/nullptr, stat_cf_); + EXPECT_TRUE(ret.has_value()); + // Reader loads stats from stat_cf on open; search should succeed + std::vector results; + EXPECT_TRUE(search_ok(reader, "hello", 10, &results)); + ASSERT_EQ(results.size(), 1u); +} + +// ============================================================ +// search() - term query +// ============================================================ + +TEST_F(FtsColumnIndexerTest, SearchTermFound) { + auto indexer = make_indexer(); + EXPECT_TRUE(indexer->insert(0, "hello world").has_value()); + EXPECT_TRUE(indexer->insert(1, "hello foo").has_value()); + EXPECT_TRUE(indexer->insert(2, "bar baz").has_value()); + + std::vector results; + EXPECT_TRUE(search_ok(*indexer, "hello", 10, &results)); + EXPECT_EQ(results.size(), 2u); + + bool found_doc0 = false; + bool found_doc1 = false; + for (const auto &result : results) { + if (result.doc_id == 0) found_doc0 = true; + if (result.doc_id == 1) found_doc1 = true; + } + EXPECT_TRUE(found_doc0); + EXPECT_TRUE(found_doc1); +} + +TEST_F(FtsColumnIndexerTest, SearchTermNotFound) { + auto indexer = make_indexer(); + EXPECT_TRUE(indexer->insert(0, "hello world").has_value()); + + std::vector results; + EXPECT_TRUE(search_ok(*indexer, "missing", 10, &results)); + EXPECT_TRUE(results.empty()); +} + +TEST_F(FtsColumnIndexerTest, SearchResultsSortedByScoreDescending) { + auto indexer = make_indexer(); + // Doc 0: "hello" appears once + EXPECT_TRUE(indexer->insert(0, "hello world").has_value()); + // Doc 1: "hello" appears twice (higher TF -> higher BM25 score) + EXPECT_TRUE(indexer->insert(1, "hello hello").has_value()); + + std::vector results; + EXPECT_TRUE(search_ok(*indexer, "hello", 10, &results)); + ASSERT_EQ(results.size(), 2u); + + // Results must be in descending score order + EXPECT_GE(results[0].score, results[1].score); + // Doc 1 (higher TF) should rank first + EXPECT_EQ(results[0].doc_id, 1ull); +} + +TEST_F(FtsColumnIndexerTest, SearchTopkLimitsResults) { + auto indexer = make_indexer(); + for (uint64_t doc_id = 0; doc_id < 10; ++doc_id) { + EXPECT_TRUE(indexer->insert(doc_id, "hello world").has_value()); + } + + std::vector results; + EXPECT_TRUE(search_ok(*indexer, "hello", 3, &results)); + EXPECT_LE(results.size(), 3u); +} + +// ============================================================ +// search() - phrase query +// ============================================================ + +TEST_F(FtsColumnIndexerTest, SearchPhraseFound) { + auto indexer = make_indexer(); + EXPECT_TRUE(indexer->insert(0, "machine learning model").has_value()); + EXPECT_TRUE(indexer->insert(1, "learning machine translation").has_value()); + + std::vector results; + EXPECT_TRUE(search_ok(*indexer, "\"machine learning\"", 10, &results)); + ASSERT_EQ(results.size(), 1u); + EXPECT_EQ(results[0].doc_id, 0ull); +} + +TEST_F(FtsColumnIndexerTest, SearchPhraseNotFound) { + auto indexer = make_indexer(); + EXPECT_TRUE(indexer->insert(0, "hello world foo").has_value()); + + std::vector results; + EXPECT_TRUE(search_ok(*indexer, "\"hello foo\"", 10, &results)); + EXPECT_TRUE(results.empty()); +} + +// Phrase with a repeated term ("a b a") exercises the dedup path in +// PhraseDocIterator::verify_phrase_positions: the two "a" entries must share +// a single MultiGet slot while still validating positions 0 and 2. +TEST_F(FtsColumnIndexerTest, SearchPhraseWithRepeatedTermFound) { + auto indexer = make_indexer(); + EXPECT_TRUE(indexer->insert(0, "a b a").has_value()); // match + EXPECT_TRUE(indexer->insert(1, "a b c").has_value()); // a b ✓, trailing a ✗ + EXPECT_TRUE(indexer->insert(2, "b a c").has_value()); // wrong order + EXPECT_TRUE(indexer->insert(3, "a a b").has_value()); // wrong adjacency + + std::vector results; + EXPECT_TRUE(search_ok(*indexer, "\"a b a\"", 10, &results)); + ASSERT_EQ(results.size(), 1u); + EXPECT_EQ(results[0].doc_id, 0ull); +} + +// When the first phrase term is high-frequency in the doc (e.g., "the the the +// the model"), the anchor must be chosen from the rarest position list rather +// than terms_[0]; otherwise the anchor loop iterates many useless candidates. +// This test only asserts correctness — the anchor heuristic is internal — but +// guards against regressions in the shortest-list selection. +TEST_F(FtsColumnIndexerTest, SearchPhraseHighFrequencyLeadingTerm) { + auto indexer = make_indexer(); + EXPECT_TRUE(indexer->insert(0, "the the the the model").has_value()); + EXPECT_TRUE(indexer->insert(1, "the model the the the").has_value()); + EXPECT_TRUE( + indexer->insert(2, "the the the the the").has_value()); // no model + + std::vector results; + EXPECT_TRUE(search_ok(*indexer, "\"the model\"", 10, &results)); + ASSERT_EQ(results.size(), 2u); + std::vector ids{results[0].doc_id, results[1].doc_id}; + std::sort(ids.begin(), ids.end()); + EXPECT_EQ(ids[0], 0ull); + EXPECT_EQ(ids[1], 1ull); +} + +// ============================================================ +// search() - boolean query (AND / OR) +// ============================================================ + +TEST_F(FtsColumnIndexerTest, SearchExplicitAnd) { + auto indexer = make_indexer(); + EXPECT_TRUE(indexer->insert(0, "hello world").has_value()); // matches both + EXPECT_TRUE(indexer->insert(1, "hello foo").has_value()); // only hello + EXPECT_TRUE(indexer->insert(2, "world bar").has_value()); // only world + + std::vector results; + EXPECT_TRUE(search_ok(*indexer, "hello AND world", 10, &results)); + ASSERT_EQ(results.size(), 1u); + EXPECT_EQ(results[0].doc_id, 0ull); +} + +TEST_F(FtsColumnIndexerTest, SearchExplicitOr) { + auto indexer = make_indexer(); + EXPECT_TRUE(indexer->insert(0, "hello world").has_value()); + EXPECT_TRUE(indexer->insert(1, "foo bar").has_value()); + EXPECT_TRUE(indexer->insert(2, "baz qux").has_value()); + + std::vector results; + EXPECT_TRUE(search_ok(*indexer, "hello OR foo", 10, &results)); + ASSERT_EQ(results.size(), 2u); +} + +TEST_F(FtsColumnIndexerTest, SearchImplicitAdjacency) { + auto indexer = make_indexer(); + EXPECT_TRUE(indexer->insert(0, "hello world").has_value()); + EXPECT_TRUE(indexer->insert(1, "foo bar").has_value()); + + // Adjacent terms without operator -> OR semantics (default operator) + std::vector results; + EXPECT_TRUE(search_ok(*indexer, "hello foo", 10, &results)); + EXPECT_EQ(results.size(), 2u); +} + +// ============================================================ +// search() - EmptyNode (matches zero docs) +// ============================================================ + +TEST_F(FtsColumnIndexerTest, SearchEmptyNodeReturnsNoResults) { + auto indexer = make_indexer(); + EXPECT_TRUE(indexer->insert(0, "hello world").has_value()); + EXPECT_TRUE(indexer->insert(1, "foo bar").has_value()); + + EmptyNode empty; + zvec::fts::FtsQueryParams qp; + qp.topk = 10; + auto ret = indexer->search(empty, qp); + ASSERT_TRUE(ret.has_value()); + EXPECT_TRUE(ret.value().empty()); +} + +TEST_F(FtsColumnIndexerTest, SearchAndWithEmptyChildReturnsNoResults) { + auto indexer = make_indexer(); + EXPECT_TRUE(indexer->insert(0, "hello world").has_value()); + EXPECT_TRUE(indexer->insert(1, "hello foo").has_value()); + + // AND with EmptyNode child → whole conjunction matches nothing. + AndNode and_node; + and_node.children.push_back(std::make_unique()); + and_node.children.push_back(std::make_unique("hello")); + + zvec::fts::FtsQueryParams qp; + qp.topk = 10; + auto ret = indexer->search(and_node, qp); + ASSERT_TRUE(ret.has_value()); + EXPECT_TRUE(ret.value().empty()); +} + +TEST_F(FtsColumnIndexerTest, SearchOrWithEmptyChildIgnoresIt) { + auto indexer = make_indexer(); + EXPECT_TRUE(indexer->insert(0, "hello world").has_value()); + EXPECT_TRUE(indexer->insert(1, "foo bar").has_value()); + + // OR with EmptyNode child → empty is skipped, equivalent to OR(hello). + OrNode or_node; + or_node.children.push_back(std::make_unique()); + or_node.children.push_back(std::make_unique("hello")); + + zvec::fts::FtsQueryParams qp; + qp.topk = 10; + auto ret = indexer->search(or_node, qp); + ASSERT_TRUE(ret.has_value()); + ASSERT_EQ(ret.value().size(), 1u); + EXPECT_EQ(ret.value()[0].doc_id, 0ull); +} + +// ============================================================ +// search() - must_not modifier +// ============================================================ + +TEST_F(FtsColumnIndexerTest, SearchMustNotExcludesDoc) { + auto indexer = make_indexer(); + EXPECT_TRUE(indexer->insert(0, "hello world").has_value()); + EXPECT_TRUE(indexer->insert(1, "hello foo").has_value()); + + // "hello" matches both; "- world" (with space) excludes doc 0 + std::vector results; + EXPECT_TRUE(search_ok(*indexer, "hello - world", 10, &results)); + ASSERT_EQ(results.size(), 1u); + EXPECT_EQ(results[0].doc_id, 1ull); +} + +// `a NOT b` is the new binary AND-NOT operator (`a AND NOT b`). +TEST_F(FtsColumnIndexerTest, SearchBinaryNotExcludesDoc) { + auto indexer = make_indexer(); + EXPECT_TRUE(indexer->insert(0, "hello world").has_value()); + EXPECT_TRUE(indexer->insert(1, "hello foo").has_value()); + + std::vector results; + EXPECT_TRUE(search_ok(*indexer, "hello NOT world", 10, &results)); + ASSERT_EQ(results.size(), 1u); + EXPECT_EQ(results[0].doc_id, 1ull); +} + +// `a NOT (b OR c)` — must_not on a parenthesised OR sub-expression must +// exclude every doc matching either `b` or `c`. +TEST_F(FtsColumnIndexerTest, SearchMustNotOnGroupedOrExcludesDocs) { + auto indexer = make_indexer(); + EXPECT_TRUE( + indexer->insert(0, "hello world").has_value()); // excluded (has world) + EXPECT_TRUE( + indexer->insert(1, "hello foo").has_value()); // excluded (has foo) + EXPECT_TRUE(indexer->insert(2, "hello bar").has_value()); // kept + + std::vector results; + EXPECT_TRUE(search_ok(*indexer, "hello NOT (world OR foo)", 10, &results)); + ASSERT_EQ(results.size(), 1u); + EXPECT_EQ(results[0].doc_id, 2ull); +} + +// Top-level `-(...)` produces a must_not root and must be rejected by +// search() (see fts_column_indexer.cc::search early-out). +TEST_F(FtsColumnIndexerTest, SearchTopLevelMustNotIsRejected) { + auto indexer = make_indexer(); + EXPECT_TRUE(indexer->insert(0, "hello world").has_value()); + + // -(hello AND world) => AndNode with must_not=true at the root + FtsQueryParser parser; + auto ast = parser.parse("-(hello AND world)", make_whitespace_pipeline()); + ASSERT_NE(ast, nullptr); + EXPECT_TRUE(ast->must_not); + + std::vector results; + FtsQueryParams query_params; + query_params.topk = 10; + EXPECT_FALSE(indexer->search(*ast, query_params).has_value()); +} + +// ============================================================ +// BM25 stats are updated in real-time after insert +// ============================================================ + +TEST_F(FtsColumnIndexerTest, BM25StatsUpdatedAfterInsert) { + auto indexer = make_indexer(); + EXPECT_EQ(indexer->total_docs(), 0u); + EXPECT_EQ(indexer->total_tokens(), 0u); + + EXPECT_TRUE(indexer->insert(0, "hello world foo").has_value()); + EXPECT_EQ(indexer->total_docs(), 1u); + EXPECT_EQ(indexer->total_tokens(), 3u); + + EXPECT_TRUE(indexer->insert(1, "bar baz").has_value()); + EXPECT_EQ(indexer->total_docs(), 2u); + EXPECT_EQ(indexer->total_tokens(), 5u); +} + +TEST_F(FtsColumnIndexerTest, SearchScorePositiveAfterInsert) { + auto indexer = make_indexer(); + EXPECT_TRUE(indexer->insert(0, "hello world").has_value()); + + std::vector results; + EXPECT_TRUE(search_ok(*indexer, "hello", 10, &results)); + ASSERT_EQ(results.size(), 1u); + EXPECT_GT(results[0].score, 0.0f); +} + +// ============================================================ +// End-to-end: multiple inserts and searches +// ============================================================ + +TEST_F(FtsColumnIndexerTest, MultipleInsertsAndSearches) { + auto indexer = make_indexer("content"); + + const std::vector docs = { + "the quick brown fox", + "the lazy dog", + "quick brown dog", + "fox and dog", + }; + + for (uint64_t doc_id = 0; doc_id < docs.size(); ++doc_id) { + EXPECT_TRUE(indexer->insert(doc_id, docs[doc_id]).has_value()); + } + + EXPECT_EQ(indexer->total_docs(), docs.size()); + + // "quick" appears in doc 0 and doc 2 + std::vector results; + EXPECT_TRUE(search_ok(*indexer, "quick", 10, &results)); + EXPECT_EQ(results.size(), 2u); + + // "the" appears in doc 0 and doc 1 + results.clear(); + EXPECT_TRUE(search_ok(*indexer, "the", 10, &results)); + EXPECT_EQ(results.size(), 2u); + + // "quick AND dog" -> only doc 2 + results.clear(); + EXPECT_TRUE(search_ok(*indexer, "quick AND dog", 10, &results)); + ASSERT_EQ(results.size(), 1u); + EXPECT_EQ(results[0].doc_id, 2ull); +} + +// ============================================================ +// Jieba Chinese tokenizer tests +// ============================================================ + +// JIEBA_DICT_DIR points to thirdparty/cppjieba/.../dict/ (injected by CMake). +#ifndef JIEBA_DICT_DIR +#define JIEBA_DICT_DIR "." +#endif + +static const std::string kJiebaDictDir{JIEBA_DICT_DIR}; + +static bool jieba_dict_available() { + std::string path = kJiebaDictDir + "/jieba.dict.utf8"; + std::ifstream ifs(path); + return ifs.good(); +} + +static std::string make_jieba_extra_params() { + return std::string(R"({"jieba_dict_dir":")") + kJiebaDictDir + R"("})"; +} + +class FtsColumnIndexerJiebaTest : public FtsColumnIndexerTest { + protected: + void SetUp() override { + if (!jieba_dict_available()) { + GTEST_SKIP() << "Jieba dict not available at: " << kJiebaDictDir; + } + FtsColumnIndexerTest::SetUp(); + } + // Create and open a fresh indexer with jieba tokenizer. + std::unique_ptr make_jieba_indexer( + const std::string &field_name = "content") { + auto fts_params = std::make_shared( + "jieba", std::vector{"lowercase"}, + make_jieba_extra_params()); + auto field_meta = make_test_field_meta(field_name, fts_params); + auto indexer = std::make_unique(); + auto ret = indexer->open(field_meta, &db_, postings_cf_, positions_cf_, + term_freq_cf_, max_tf_cf_, doc_len_cf_, stat_cf_); + EXPECT_TRUE(ret.has_value()); + return indexer; + } +}; + +// Verify that jieba tokenizer opens successfully with valid dict paths. +TEST_F(FtsColumnIndexerJiebaTest, OpenWithJiebaTokenizerSucceeds) { + auto fts_params = std::make_shared( + "jieba", std::vector{"lowercase"}, + make_jieba_extra_params()); + auto field_meta = make_test_field_meta("content", fts_params); + FtsColumnIndexer indexer; + auto ret = indexer.open(field_meta, &db_, postings_cf_, positions_cf_, + term_freq_cf_, max_tf_cf_, doc_len_cf_, stat_cf_); + EXPECT_TRUE(ret.has_value()); +} + +// Verify that jieba tokenizer fails when no jieba_dict_dir source resolves. +// (cppjieba FATAL-aborts on non-existent dict files, so we test the init-time +// validation in JiebaTokenizer instead.) +TEST_F(FtsColumnIndexerJiebaTest, OpenWithJiebaTokenizerFailsWithoutDictDir) { + // Make sure neither env-var nor GlobalConfig has a value; ensure + // extra_params is also empty. + ::unsetenv("ZVEC_JIEBA_DICT_DIR"); + zvec::GlobalConfig::Instance().set_default_jieba_dict_dir(""); + + fts::FtsIndexParams bad_params; + bad_params.tokenizer_name = "jieba"; + bad_params.extra_params = ""; + auto pipeline = TokenizerFactory::create(bad_params); + EXPECT_EQ(pipeline, nullptr); +} + +// Insert a Chinese sentence and verify that total_docs and total_tokens are +// updated correctly (jieba should produce at least one token). +TEST_F(FtsColumnIndexerJiebaTest, InsertChineseTextUpdatesStats) { + auto indexer = make_jieba_indexer(); + + // "中文分词测试" should be segmented into multiple tokens by jieba. + EXPECT_TRUE(indexer->insert(0, "中文分词测试").has_value()); + EXPECT_EQ(indexer->total_docs(), 1u); + EXPECT_GT(indexer->total_tokens(), 0u); +} + +// Insert multiple Chinese documents and verify that a segmented term can be +// found via search(). The dedicated FtsLexer supports UNICODE_TERM so Chinese +// words can be used as bare terms without quoting. +TEST_F(FtsColumnIndexerJiebaTest, SearchChineseTermFound) { + auto indexer = make_jieba_indexer(); + + // doc 0: contains "中文" and "分词" + EXPECT_TRUE(indexer->insert(0, "中文分词技术").has_value()); + // doc 1: contains "搜索" and "引擎" + EXPECT_TRUE(indexer->insert(1, "搜索引擎优化").has_value()); + // doc 2: contains "中文" again + EXPECT_TRUE(indexer->insert(2, "中文搜索").has_value()); + + // jieba CutForSearch segments "中文分词技术" → [中文, 分词, 技术, ...] and + // "中文搜索" → [中文, 搜索], so doc 0 and + // doc 2 should match "中文". + std::vector results; + EXPECT_TRUE(search_ok(*indexer, "中文", 10, &results)); + EXPECT_GE(results.size(), 1u); + + bool found_doc0 = false; + bool found_doc2 = false; + for (const auto &result : results) { + if (result.doc_id == 0) found_doc0 = true; + if (result.doc_id == 2) found_doc2 = true; + } + EXPECT_TRUE(found_doc0); + EXPECT_TRUE(found_doc2); +} + +// Verify that a term not present in any document returns empty results. +TEST_F(FtsColumnIndexerJiebaTest, SearchChineseTermNotFound) { + auto indexer = make_jieba_indexer(); + + EXPECT_TRUE(indexer->insert(0, "中文分词技术").has_value()); + + std::vector results; + EXPECT_TRUE(search_ok(*indexer, "日语", 10, &results)); + EXPECT_EQ(results.size(), 0u); +} + +// Verify BM25 scores are positive after inserting Chinese documents. +TEST_F(FtsColumnIndexerJiebaTest, SearchChineseTermHasPositiveScore) { + auto indexer = make_jieba_indexer(); + + EXPECT_TRUE(indexer->insert(0, "自然语言处理技术").has_value()); + EXPECT_TRUE(indexer->insert(1, "机器学习算法").has_value()); + + // Search for a token that jieba should produce from "自然语言处理技术". + std::vector results; + EXPECT_TRUE(search_ok(*indexer, "自然语言", 10, &results)); + if (!results.empty()) { + EXPECT_GT(results[0].score, 0.0f); + } +} + +// Verify that topk limits the number of results for Chinese queries. +TEST_F(FtsColumnIndexerJiebaTest, SearchChineseTermTopkLimitsResults) { + auto indexer = make_jieba_indexer(); + + // Insert 5 documents all containing "技术" + for (uint64_t doc_id = 0; doc_id < 5; ++doc_id) { + EXPECT_TRUE(indexer->insert(doc_id, "人工智能技术发展").has_value()); + } + + std::vector results; + EXPECT_TRUE(search_ok(*indexer, "技术", /*topk=*/3, &results)); + EXPECT_LE(results.size(), 3u); +} + +// End-to-end: flush and reload with jieba tokenizer. +TEST_F(FtsColumnIndexerJiebaTest, FlushAndReloadWithJiebaTokenizer) { + auto indexer = make_jieba_indexer("content"); + + EXPECT_TRUE(indexer->insert(0, "深度学习模型").has_value()); + EXPECT_TRUE(indexer->insert(1, "神经网络结构").has_value()); + EXPECT_TRUE(indexer->flush().has_value()); + + // Reload via a standalone reader (no tokenizer needed for reading). + // Pass doc_len_cf as nullptr so the reader loads stats from stat_cf. + FtsColumnIndexer reader; + auto ret = reader.open_reader("content", &db_, postings_cf_, positions_cf_, + term_freq_cf_, max_tf_cf_, + /*doc_len_cf=*/nullptr, stat_cf_); + EXPECT_TRUE(ret.has_value()); + + // Search with a term that jieba produces from "深度学习模型": + // jieba CutForSearch segments it into [深度, 学习, 深度学习, 模型]. + TermNode term_node("模型"); + FtsQueryParams query_params; + query_params.topk = 10; + auto search_ret = reader.search(term_node, query_params); + EXPECT_TRUE(search_ret.has_value()); + EXPECT_GE(search_ret.value().size(), 1u); +} + +// Construct a jieba pipeline matching the indexer config so phrase queries +// tokenize the same way the index did. +static zvec::fts::TokenizerPipelinePtr make_jieba_pipeline_for_test() { + zvec::fts::FtsIndexParams params; + params.tokenizer_name = "jieba"; + params.filters = {"lowercase"}; + params.extra_params = + std::string(R"({"jieba_dict_dir":")") + kJiebaDictDir + R"("})"; + return zvec::fts::TokenizerFactory::create(params); +} + +// Phrase queries on a jieba-indexed doc must hit when the query goes through +// the same pipeline as the document. Before the parser was pipeline-aware +// the query path split the phrase on ASCII whitespace, so a CJK phrase +// became a single opaque token and failed to match the per-segment tokens +// the index actually stored. +TEST_F(FtsColumnIndexerJiebaTest, PhraseSearchHitsAfterJiebaTokenization) { + auto indexer = make_jieba_indexer(); + EXPECT_TRUE(indexer->insert(0, "中华人民共和国成立").has_value()); + EXPECT_TRUE(indexer->insert(1, "无关文档").has_value()); + EXPECT_TRUE(indexer->flush().has_value()); + + auto pipeline = make_jieba_pipeline_for_test(); + ASSERT_NE(pipeline, nullptr); + + // Phrase covering the full doc text — query and doc tokenize identically + // so the strict anchor+i adjacency check in PhraseDocIterator succeeds. + std::vector results; + EXPECT_TRUE( + search_ok(*indexer, "\"中华人民共和国成立\"", 10, &results, pipeline)); + ASSERT_EQ(results.size(), 1u); + EXPECT_EQ(results[0].doc_id, 0ull); + + // A single-token phrase still works after the position-as-sequence fix: + // jieba emits "成立" once with a deterministic sequence position, the + // single-term phrase trivially matches. + results.clear(); + EXPECT_TRUE(search_ok(*indexer, "\"成立\"", 10, &results, pipeline)); + ASSERT_EQ(results.size(), 1u); + EXPECT_EQ(results[0].doc_id, 0ull); +} + +// JiebaTokenizer.position must be a strictly increasing per-output-token +// sequence number. CutForSearch emits overlapping sub-words for a long +// parent word; using cppjieba's unicode_offset would assign duplicate or +// non-monotonic positions and break PhraseDocIterator's strict adjacency +// check. Sequence numbers are guaranteed contiguous across all emitted +// tokens. +TEST(JiebaTokenizerTest, PositionIsContiguousSequence) { + if (!jieba_dict_available()) { + GTEST_SKIP() << "Jieba dict not available at: " << kJiebaDictDir; + } + auto pipeline = make_jieba_pipeline_for_test(); + ASSERT_NE(pipeline, nullptr); + + // CutForSearch on this string emits the long parent word followed by its + // shorter sub-words; the sub-words share a unicode_offset with the parent + // but get distinct sequence numbers under the new scheme. + auto tokens = pipeline->process("中华人民共和国"); + ASSERT_FALSE(tokens.empty()); + for (size_t i = 0; i < tokens.size(); ++i) { + EXPECT_EQ(tokens[i].position, static_cast(i)) + << "tokens[" << i << "].text=" << tokens[i].text; + } +} + +// ============================================================ +// jieba_dict_dir resolution priority chain +// ============================================================ +// +// JiebaTokenizer::init resolves jieba_dict_dir in this order: +// 1. extra_params.jieba_dict_dir (per-field) +// 2. ZVEC_JIEBA_DICT_DIR env var +// 3. zvec::GlobalConfig::jieba_dict_dir() (set by SDK or zvec.init) +// +// The fixture below exercises each tier independently. + +class JiebaDictDirPriorityTest : public FtsColumnIndexerJiebaTest { + protected: + void SetUp() override { + FtsColumnIndexerJiebaTest::SetUp(); + if (IsSkipped()) { + return; + } + saved_env_set_ = false; + if (const char *prev = std::getenv("ZVEC_JIEBA_DICT_DIR"); + prev != nullptr) { + saved_env_set_ = true; + saved_env_ = prev; + } + saved_global_ = zvec::GlobalConfig::Instance().jieba_dict_dir(); + ::unsetenv("ZVEC_JIEBA_DICT_DIR"); + zvec::GlobalConfig::Instance().set_default_jieba_dict_dir(""); + } + + void TearDown() override { + if (saved_env_set_) { + ::setenv("ZVEC_JIEBA_DICT_DIR", saved_env_.c_str(), /*overwrite=*/1); + } else { + ::unsetenv("ZVEC_JIEBA_DICT_DIR"); + } + zvec::GlobalConfig::Instance().set_default_jieba_dict_dir(saved_global_); + FtsColumnIndexerJiebaTest::TearDown(); + } + + // Build an indexer with arbitrary extra_params (so individual cases can + // toggle whether jieba_dict_dir is in the per-field config). + std::unique_ptr make_indexer_with_extra_params( + const std::string &extra_params, + const std::string &field_name = "content") { + auto fts_params = std::make_shared( + "jieba", std::vector{"lowercase"}, extra_params); + auto field_meta = make_test_field_meta(field_name, fts_params); + auto indexer = std::make_unique(); + auto ret = indexer->open(field_meta, &db_, postings_cf_, positions_cf_, + term_freq_cf_, max_tf_cf_, doc_len_cf_, stat_cf_); + EXPECT_TRUE(ret.has_value()); + return indexer; + } + + private: + bool saved_env_set_{false}; + std::string saved_env_; + std::string saved_global_; +}; + +// Core scenario: SDK in module-load called set_default_jieba_dict_dir; user +// never called zvec_initialize; per-field extra_params is empty. Jieba must +// still work end-to-end. Validates that SDK auto-registration is decoupled +// from the GlobalConfig::Initialize one-shot lifecycle. +TEST_F(JiebaDictDirPriorityTest, GlobalConfigDefaultUsedWithoutInitialize) { + zvec::GlobalConfig::Instance().set_default_jieba_dict_dir(kJiebaDictDir); + + auto indexer = make_indexer_with_extra_params(""); + EXPECT_TRUE(indexer->insert(0, "中华人民共和国").has_value()); + EXPECT_TRUE(indexer->flush().has_value()); + + std::vector results; + auto pipeline = make_jieba_pipeline_for_test(); + EXPECT_TRUE(search_ok(*indexer, "中华", 10, &results, pipeline)); + EXPECT_GE(results.size(), 1u); +} + +// env-var must override GlobalConfig even when zvec_initialize was never +// called. Set GlobalConfig to a bogus path; with env-var pointing at the +// real dict, jieba should resolve via env-var and succeed. +TEST_F(JiebaDictDirPriorityTest, EnvVarBeatsGlobalConfig) { + zvec::GlobalConfig::Instance().set_default_jieba_dict_dir( + "/zvec/intentionally/missing/global"); + ::setenv("ZVEC_JIEBA_DICT_DIR", kJiebaDictDir.c_str(), /*overwrite=*/1); + + auto indexer = make_indexer_with_extra_params(""); + EXPECT_TRUE(indexer->insert(0, "搜索引擎").has_value()); + EXPECT_TRUE(indexer->flush().has_value()); + + std::vector results; + auto pipeline = make_jieba_pipeline_for_test(); + EXPECT_TRUE(search_ok(*indexer, "搜索", 10, &results, pipeline)); + EXPECT_GE(results.size(), 1u); +} + +// per-field extra_params.jieba_dict_dir must beat env-var and GlobalConfig +// even when both of them are bogus. +TEST_F(JiebaDictDirPriorityTest, PerFieldBeatsEnvAndGlobalConfig) { + zvec::GlobalConfig::Instance().set_default_jieba_dict_dir( + "/zvec/intentionally/missing/global"); + ::setenv("ZVEC_JIEBA_DICT_DIR", "/zvec/intentionally/missing/env", + /*overwrite=*/1); + + auto extra = std::string(R"({"jieba_dict_dir":")") + kJiebaDictDir + R"("})"; + auto indexer = make_indexer_with_extra_params(extra); + EXPECT_TRUE(indexer->insert(0, "自然语言处理").has_value()); + EXPECT_TRUE(indexer->flush().has_value()); + + std::vector results; + auto pipeline = make_jieba_pipeline_for_test(); + EXPECT_TRUE(search_ok(*indexer, "自然", 10, &results, pipeline)); + EXPECT_GE(results.size(), 1u); +} + +// ============================================================ +// convert_postings_to_bitpacked() +// ============================================================ +// +// These tests exercise the BitPacked conversion path that is invoked from +// MutableSegment::dump_fts_column_indexers() right before the SST dump. +// They use the BitPackedPostingList::is_bitpacked_format magic-number probe +// to verify that postings have been re-encoded, and iterate $TF / $DOC_LEN +// CFs to verify the DeleteRange tombstones effectively removed all entries. + +#include "db/index/column/fts_column/posting/bitpacked_posting_list.h" // NOLINT: in-test include + +namespace { + +// Count entries in a CF by iterating from the first key. Used to verify that +// $TF / $DOC_LEN have been DeleteRange-cleared. +size_t count_cf_entries(RocksdbContext &db, rocksdb::ColumnFamilyHandle *cf) { + size_t count = 0; + std::unique_ptr iter( + db.db_->NewIterator(db.read_opts_, cf)); + for (iter->SeekToFirst(); iter->Valid(); iter->Next()) { + ++count; + } + return count; +} + +// Verify every value in postings_cf_ is in BitPacked format. +size_t count_postings_entries_and_check_bitpacked( + RocksdbContext &db, rocksdb::ColumnFamilyHandle *cf) { + size_t count = 0; + std::unique_ptr iter( + db.db_->NewIterator(db.read_opts_, cf)); + for (iter->SeekToFirst(); iter->Valid(); iter->Next()) { + const std::string value = iter->value().ToString(); + EXPECT_TRUE( + BitPackedPostingList::is_bitpacked_format(value.data(), value.size())) + << "Posting for term[" << iter->key().ToString() + << "] is not BitPacked"; + ++count; + } + return count; +} + +} // namespace + +// Insert N docs, run the conversion, and verify: +// - postings_cf_ values all carry the BitPacked magic +// - decoded posting iterators yield the original (doc_id, tf, doc_len) +// - $TF / $DOC_LEN CFs are empty +TEST_F(FtsColumnIndexerTest, ConvertPostingsToBitpackedBasic) { + auto indexer = make_indexer("content"); + EXPECT_TRUE(indexer->insert(0, "hello world").has_value()); + EXPECT_TRUE(indexer->insert(1, "hello foo bar").has_value()); + EXPECT_TRUE(indexer->insert(2, "hello hello world").has_value()); + EXPECT_TRUE(indexer->flush().has_value()); + + EXPECT_TRUE(indexer->convert_postings_to_bitpacked().has_value()); + + // All postings must now be BitPacked. + size_t postings_count = + count_postings_entries_and_check_bitpacked(db_, postings_cf_); + EXPECT_GT(postings_count, 0u); + + // Spot-check: decode the "hello" posting and confirm doc_ids/tfs/doc_lens + // match what we wrote. Doc 0 -> tf=1, dl=2; Doc 1 -> tf=1, dl=3; Doc 2 -> + // tf=2, dl=3. + std::string raw; + ASSERT_TRUE(db_.db_->Get(db_.read_opts_, postings_cf_, "hello", &raw).ok()); + ASSERT_FALSE(raw.empty()); + ASSERT_TRUE( + BitPackedPostingList::is_bitpacked_format(raw.data(), raw.size())); + + BitPackedPostingIterator iter; + ASSERT_EQ(iter.open(raw.data(), raw.size()), 0); + + std::vector> decoded; + while (true) { + uint32_t did = iter.next_doc(); + if (did == BitPackedPostingIterator::NO_MORE_DOCS) break; + decoded.emplace_back(did, iter.term_freq(), iter.doc_len()); + } + ASSERT_EQ(decoded.size(), 3u); + EXPECT_EQ(std::get<0>(decoded[0]), 0u); + EXPECT_EQ(std::get<1>(decoded[0]), 1u); + EXPECT_EQ(std::get<2>(decoded[0]), 2u); + EXPECT_EQ(std::get<0>(decoded[1]), 1u); + EXPECT_EQ(std::get<1>(decoded[1]), 1u); + EXPECT_EQ(std::get<2>(decoded[1]), 3u); + EXPECT_EQ(std::get<0>(decoded[2]), 2u); + EXPECT_EQ(std::get<1>(decoded[2]), 2u); + EXPECT_EQ(std::get<2>(decoded[2]), 3u); +} + +// After conversion the $TF / $DOC_LEN / $MAX_TF side CFs must be EMPTY: the +// indexer DeleteRange's them once their content has been inlined into the +// BitPacked posting list. MutableSegment then drops the CFs entirely. +TEST_F(FtsColumnIndexerTest, ConvertPostingsToBitpackedClearsSideCfs) { + auto indexer = make_indexer("content"); + for (uint64_t doc_id = 0; doc_id < 5; ++doc_id) { + EXPECT_TRUE(indexer->insert(doc_id, "alpha beta gamma").has_value()); + } + EXPECT_TRUE(indexer->flush().has_value()); + + // Sanity: side CFs are populated before conversion. + EXPECT_GT(count_cf_entries(db_, term_freq_cf_), 0u); + EXPECT_GT(count_cf_entries(db_, doc_len_cf_), 0u); + EXPECT_GT(count_cf_entries(db_, max_tf_cf_), 0u); + + EXPECT_TRUE(indexer->convert_postings_to_bitpacked().has_value()); + + // Side CFs must be empty after conversion (DeleteRange'd by the indexer). + EXPECT_EQ(count_cf_entries(db_, term_freq_cf_), 0u); + EXPECT_EQ(count_cf_entries(db_, doc_len_cf_), 0u); + EXPECT_EQ(count_cf_entries(db_, max_tf_cf_), 0u); + + // After reset_side_cfs, search should still work (BitPacked path). + indexer->reset_side_cfs(); + std::vector results; + EXPECT_TRUE(search_ok(*indexer, "alpha", 10, &results)); + EXPECT_EQ(results.size(), 5u); +} + +// Conversion must be idempotent: calling it twice should not corrupt postings, +// nor should it re-encode terms that are already BitPacked. +TEST_F(FtsColumnIndexerTest, ConvertPostingsToBitpackedIsIdempotent) { + auto indexer = make_indexer("content"); + EXPECT_TRUE(indexer->insert(0, "hello world").has_value()); + EXPECT_TRUE(indexer->insert(1, "hello foo").has_value()); + EXPECT_TRUE(indexer->flush().has_value()); + + EXPECT_TRUE(indexer->convert_postings_to_bitpacked().has_value()); + + // Snapshot the BitPacked posting for "hello" after the first conversion. + std::string snapshot; + ASSERT_TRUE( + db_.db_->Get(db_.read_opts_, postings_cf_, "hello", &snapshot).ok()); + ASSERT_FALSE(snapshot.empty()); + + // Second invocation must succeed and leave the posting byte-for-byte + // identical (the idempotency guard skips re-encoding). + EXPECT_TRUE(indexer->convert_postings_to_bitpacked().has_value()); + + std::string after; + ASSERT_TRUE(db_.db_->Get(db_.read_opts_, postings_cf_, "hello", &after).ok()); + EXPECT_EQ(snapshot, after); +} + +// An indexer with no inserted documents must still allow the conversion to +// succeed (no-op path) — this matches MutableSegment dump-flow expectations +// for FTS fields that received zero writes. +TEST_F(FtsColumnIndexerTest, ConvertPostingsToBitpackedEmptyIndexer) { + auto indexer = make_indexer("content"); + EXPECT_TRUE(indexer->flush().has_value()); + EXPECT_TRUE(indexer->convert_postings_to_bitpacked().has_value()); + EXPECT_EQ(count_postings_entries_and_check_bitpacked(db_, postings_cf_), 0u); + // Side CFs were never populated (empty indexer); no special expectation + // about them here beyond "the conversion did not crash". +} + +// After conversion the search() path must keep working — readers fall through +// to the BitPacked branch via is_bitpacked_format(), and no longer require the +// $TF / $DOC_LEN CFs. +TEST_F(FtsColumnIndexerTest, SearchAfterConvertPostingsToBitpacked) { + auto indexer = make_indexer("content"); + EXPECT_TRUE(indexer->insert(0, "the quick brown fox").has_value()); + EXPECT_TRUE(indexer->insert(1, "the lazy dog").has_value()); + EXPECT_TRUE(indexer->insert(2, "quick brown dog").has_value()); + EXPECT_TRUE(indexer->flush().has_value()); + + // Pre-conversion baseline: "quick" hits doc 0 and doc 2. + std::vector baseline; + EXPECT_TRUE(search_ok(*indexer, "quick", 10, &baseline)); + ASSERT_EQ(baseline.size(), 2u); + + EXPECT_TRUE(indexer->convert_postings_to_bitpacked().has_value()); + + // Post-conversion via a standalone reader (mirrors immutable segment use). + // Side CFs are passed as nullptr — immutable segments no longer register + // them. + FtsColumnIndexer reader; + ASSERT_TRUE(reader + .open_reader("content", &db_, postings_cf_, positions_cf_, + /*term_freq_cf=*/nullptr, /*max_tf_cf=*/nullptr, + /*doc_len_cf=*/nullptr, stat_cf_) + .has_value()); + std::vector results; + EXPECT_TRUE(search_ok(reader, "quick", 10, &results)); + ASSERT_EQ(results.size(), 2u); + + // Same set of doc_ids as the baseline; scores may differ slightly because + // the reader loaded stats fresh from stat_cf, but both must be positive. + std::vector ids; + for (const auto &r : results) { + ids.push_back(r.doc_id); + EXPECT_GT(r.score, 0.0f); + } + std::sort(ids.begin(), ids.end()); + EXPECT_EQ(ids[0], 0ull); + EXPECT_EQ(ids[1], 2ull); +} + +// ============================================================ +// Multi-column shared RocksDB tests +// +// Mirrors the CF-naming scheme used by SegmentImpl::open_fts_indexers(): +// field_name -> postings CF +// field_name_positions -> positions CF +// field_name_tf -> term-freq CF +// field_name_max_tf -> max-tf CF +// field_name_doc_len -> doc-len CF +// fts_stat -> shared stat CF +// ============================================================ + +static const std::string kMultiDbPath{"./test_fts_multi_db"}; + +class FtsMultiColumnSharedDbTest : public ::testing::Test { + protected: + // Two FTS fields sharing the same RocksDB instance. + static constexpr const char *kFields[] = {"title", "body"}; + static constexpr size_t kNumFields = 2; + + void SetUp() override { + zvec::FileHelper::RemoveDirectory(kMultiDbPath); + + // Build CF names and per-CF merge operators following the segment pattern. + std::vector cf_names; + std::unordered_map> + per_cf_ops; + + for (size_t i = 0; i < kNumFields; ++i) { + std::string f{kFields[i]}; + cf_names.push_back(f); // postings + cf_names.push_back(f + kFtsPositionsSuffix); // positions + cf_names.push_back(f + kFtsTfSuffix); // term freq + cf_names.push_back(f + kFtsMaxTfSuffix); // max tf + cf_names.push_back(f + kFtsDocLenSuffix); // doc len + + per_cf_ops[f] = std::make_shared(); + per_cf_ops[f + kFtsMaxTfSuffix] = std::make_shared(); + } + cf_names.push_back(zvec::kFtsStatCfName); + + ASSERT_TRUE(db_.create(RocksdbContext::Args{kMultiDbPath, cf_names, nullptr, + per_cf_ops}) + .ok()); + + // Resolve CF handles per field. + for (size_t i = 0; i < kNumFields; ++i) { + std::string f{kFields[i]}; + postings_cf_[i] = db_.get_cf(f); + positions_cf_[i] = db_.get_cf(f + kFtsPositionsSuffix); + term_freq_cf_[i] = db_.get_cf(f + kFtsTfSuffix); + max_tf_cf_[i] = db_.get_cf(f + kFtsMaxTfSuffix); + doc_len_cf_[i] = db_.get_cf(f + kFtsDocLenSuffix); + ASSERT_NE(postings_cf_[i], nullptr) << "field=" << f; + ASSERT_NE(positions_cf_[i], nullptr) << "field=" << f; + ASSERT_NE(term_freq_cf_[i], nullptr) << "field=" << f; + ASSERT_NE(max_tf_cf_[i], nullptr) << "field=" << f; + ASSERT_NE(doc_len_cf_[i], nullptr) << "field=" << f; + } + stat_cf_ = db_.get_cf(zvec::kFtsStatCfName); + ASSERT_NE(stat_cf_, nullptr); + } + + void TearDown() override { + db_.close(); + zvec::FileHelper::RemoveDirectory(kMultiDbPath); + } + + // Return the array index for a field name (0 = title, 1 = body). + size_t field_index(const std::string &field_name) const { + for (size_t i = 0; i < kNumFields; ++i) { + if (field_name == kFields[i]) return i; + } + ADD_FAILURE() << "Unknown field: " << field_name; + return 0; + } + + // Create and open a FtsColumnIndexer bound to the CFs of the given field. + std::unique_ptr make_indexer( + const std::string &field_name) { + size_t idx = field_index(field_name); + auto fts_params = std::make_shared("whitespace"); + auto field_meta = make_test_field_meta(field_name, fts_params); + auto indexer = std::make_unique(); + auto ret = indexer->open(field_meta, &db_, postings_cf_[idx], + positions_cf_[idx], term_freq_cf_[idx], + max_tf_cf_[idx], doc_len_cf_[idx], stat_cf_); + EXPECT_TRUE(ret.has_value()); + return indexer; + } + + RocksdbContext db_; + rocksdb::ColumnFamilyHandle *postings_cf_[kNumFields]{}; + rocksdb::ColumnFamilyHandle *positions_cf_[kNumFields]{}; + rocksdb::ColumnFamilyHandle *term_freq_cf_[kNumFields]{}; + rocksdb::ColumnFamilyHandle *max_tf_cf_[kNumFields]{}; + rocksdb::ColumnFamilyHandle *doc_len_cf_[kNumFields]{}; + rocksdb::ColumnFamilyHandle *stat_cf_{nullptr}; +}; + +// Two FTS columns write different documents; search on each column only +// returns hits from that column's data. +TEST_F(FtsMultiColumnSharedDbTest, MultiColumnInsertAndSearchIsolation) { + auto title_indexer = make_indexer("title"); + auto body_indexer = make_indexer("body"); + + // title column: documents about animals + EXPECT_TRUE(title_indexer->insert(0, "quick brown fox").has_value()); + EXPECT_TRUE(title_indexer->insert(1, "lazy dog").has_value()); + + // body column: documents about programming + EXPECT_TRUE(body_indexer->insert(0, "hello world program").has_value()); + EXPECT_TRUE(body_indexer->insert(1, "quick sort algorithm").has_value()); + + // Search "quick" in title -> only doc 0 + { + std::vector results; + EXPECT_TRUE(search_ok(*title_indexer, "quick", 10, &results)); + ASSERT_EQ(results.size(), 1u); + EXPECT_EQ(results[0].doc_id, 0ull); + } + + // Search "quick" in body -> only doc 1 + { + std::vector results; + EXPECT_TRUE(search_ok(*body_indexer, "quick", 10, &results)); + ASSERT_EQ(results.size(), 1u); + EXPECT_EQ(results[0].doc_id, 1ull); + } + + // Search "hello" in title -> no results + { + std::vector results; + EXPECT_TRUE(search_ok(*title_indexer, "hello", 10, &results)); + EXPECT_TRUE(results.empty()); + } + + // Search "fox" in body -> no results + { + std::vector results; + EXPECT_TRUE(search_ok(*body_indexer, "fox", 10, &results)); + EXPECT_TRUE(results.empty()); + } +} + +// Flush both columns, then open read-only readers and verify each column's +// search results survive the reload. +TEST_F(FtsMultiColumnSharedDbTest, MultiColumnFlushAndReload) { + auto title_indexer = make_indexer("title"); + auto body_indexer = make_indexer("body"); + + EXPECT_TRUE(title_indexer->insert(0, "alpha beta gamma").has_value()); + EXPECT_TRUE(body_indexer->insert(0, "delta epsilon").has_value()); + EXPECT_TRUE(body_indexer->insert(1, "alpha zeta").has_value()); + + EXPECT_TRUE(title_indexer->flush().has_value()); + EXPECT_TRUE(body_indexer->flush().has_value()); + + // Open standalone readers (pass doc_len_cf as nullptr to exercise the + // stat-CF reload path, matching immutable segment behaviour). + size_t ti = field_index("title"); + size_t bi = field_index("body"); + + FtsColumnIndexer title_reader; + ASSERT_TRUE(title_reader + .open_reader("title", &db_, postings_cf_[ti], + positions_cf_[ti], term_freq_cf_[ti], + max_tf_cf_[ti], /*doc_len_cf=*/nullptr, stat_cf_) + .has_value()); + + FtsColumnIndexer body_reader; + ASSERT_TRUE(body_reader + .open_reader("body", &db_, postings_cf_[bi], + positions_cf_[bi], term_freq_cf_[bi], + max_tf_cf_[bi], /*doc_len_cf=*/nullptr, stat_cf_) + .has_value()); + + // title reader: "alpha" -> doc 0 only + { + std::vector results; + EXPECT_TRUE(search_ok(title_reader, "alpha", 10, &results)); + ASSERT_EQ(results.size(), 1u); + EXPECT_EQ(results[0].doc_id, 0ull); + } + + // body reader: "alpha" -> doc 1 only + { + std::vector results; + EXPECT_TRUE(search_ok(body_reader, "alpha", 10, &results)); + ASSERT_EQ(results.size(), 1u); + EXPECT_EQ(results[0].doc_id, 1ull); + } + + // body reader: "delta" -> doc 0 only + { + std::vector results; + EXPECT_TRUE(search_ok(body_reader, "delta", 10, &results)); + ASSERT_EQ(results.size(), 1u); + EXPECT_EQ(results[0].doc_id, 0ull); + } +} + +// Each column maintains independent total_docs and total_tokens counters. +TEST_F(FtsMultiColumnSharedDbTest, MultiColumnStatsIndependent) { + auto title_indexer = make_indexer("title"); + auto body_indexer = make_indexer("body"); + + // title: 2 docs, 4 tokens + EXPECT_TRUE(title_indexer->insert(0, "hello world").has_value()); + EXPECT_TRUE(title_indexer->insert(1, "foo bar").has_value()); + EXPECT_EQ(title_indexer->total_docs(), 2u); + EXPECT_EQ(title_indexer->total_tokens(), 4u); + + // body: 1 doc, 3 tokens + EXPECT_TRUE(body_indexer->insert(0, "alpha beta gamma").has_value()); + EXPECT_EQ(body_indexer->total_docs(), 1u); + EXPECT_EQ(body_indexer->total_tokens(), 3u); + + // Inserting into body must not affect title's counters. + EXPECT_EQ(title_indexer->total_docs(), 2u); + EXPECT_EQ(title_indexer->total_tokens(), 4u); +} + +// ============================================================ +// Filter pushdown into FTS iterators (single-term / OR / Phrase) +// ============================================================ + +namespace { + +// Build an IndexFilter that excludes any doc_id present in `blocked`. +zvec::IndexFilter::Ptr make_blocked_filter( + std::initializer_list blocked) { + std::unordered_set set(blocked); + return zvec::EasyIndexFilter::Create( + [set = std::move(set)](uint64_t id) { return set.count(id) > 0; }); +} + +} // namespace + +// Single-term query path: TermDocIterator inherits the base-class default +// next_doc(filter), which loops over next_doc() and skips filtered docs. +TEST_F(FtsColumnIndexerTest, FilterPushdownExcludesFilteredDocs) { + auto indexer = make_indexer("content"); + EXPECT_TRUE(indexer->insert(0, "hello world").has_value()); + EXPECT_TRUE(indexer->insert(1, "hello foo").has_value()); + EXPECT_TRUE(indexer->insert(2, "hello world bar").has_value()); + EXPECT_TRUE(indexer->insert(3, "hello baz").has_value()); + EXPECT_TRUE(indexer->flush().has_value()); + + // Baseline: no filter — all 4 docs match "hello". + std::vector baseline; + EXPECT_TRUE(search_ok(*indexer, "hello", 10, &baseline)); + EXPECT_EQ(baseline.size(), 4u); + + // Block docs 1 and 3. + std::vector filtered; + EXPECT_TRUE(search_ok_with_filter(*indexer, "hello", 10, + make_blocked_filter({1, 3}), &filtered)); + ASSERT_EQ(filtered.size(), 2u); + + std::vector ids; + for (const auto &r : filtered) { + ids.push_back(r.doc_id); + EXPECT_GT(r.score, 0.0f); + } + std::sort(ids.begin(), ids.end()); + EXPECT_EQ(ids[0], 0ull); + EXPECT_EQ(ids[1], 2ull); +} + +// OR query exercises DisjunctionIterator::next_doc(filter) override — +// pivot_doc is filter-checked before block-max accumulation and resort. +TEST_F(FtsColumnIndexerTest, FilterPushdownWithDisjunction) { + auto indexer = make_indexer("content"); + EXPECT_TRUE(indexer->insert(0, "alpha beta").has_value()); + EXPECT_TRUE(indexer->insert(1, "alpha gamma").has_value()); + EXPECT_TRUE(indexer->insert(2, "beta gamma").has_value()); + EXPECT_TRUE(indexer->insert(3, "alpha beta gamma").has_value()); + EXPECT_TRUE(indexer->flush().has_value()); + + // Baseline: "alpha OR beta" matches all 4 docs. + std::vector baseline; + EXPECT_TRUE(search_ok(*indexer, "alpha beta", 10, &baseline)); + EXPECT_EQ(baseline.size(), 4u); + + std::vector filtered; + EXPECT_TRUE(search_ok_with_filter(*indexer, "alpha beta", 10, + make_blocked_filter({0, 2}), &filtered)); + ASSERT_EQ(filtered.size(), 2u); + + std::vector ids; + for (const auto &r : filtered) { + ids.push_back(r.doc_id); + EXPECT_GT(r.score, 0.0f); + } + std::sort(ids.begin(), ids.end()); + EXPECT_EQ(ids[0], 1ull); + EXPECT_EQ(ids[1], 3ull); +} + +// Phrase query exercises PhraseDocIterator::next_doc(filter) -> inner +// ConjunctionIterator::next_doc(filter), ensuring verify_phrase_positions() +// is never executed for filtered docs. +TEST_F(FtsColumnIndexerTest, FilterPushdownWithPhrase) { + auto indexer = make_indexer("content"); + EXPECT_TRUE(indexer->insert(0, "machine learning model").has_value()); + EXPECT_TRUE(indexer->insert(1, "machine learning notes").has_value()); + EXPECT_TRUE(indexer->insert(2, "learning machine translation").has_value()); + EXPECT_TRUE(indexer->insert(3, "machine learning systems").has_value()); + EXPECT_TRUE(indexer->flush().has_value()); + + // Baseline: phrase "machine learning" matches docs 0, 1, 3 (not 2, where + // the order is reversed). + std::vector baseline; + EXPECT_TRUE(search_ok(*indexer, "\"machine learning\"", 10, &baseline)); + EXPECT_EQ(baseline.size(), 3u); + + // Block docs 1 and 3 — only doc 0 should remain. + std::vector filtered; + EXPECT_TRUE(search_ok_with_filter(*indexer, "\"machine learning\"", 10, + make_blocked_filter({1, 3}), &filtered)); + ASSERT_EQ(filtered.size(), 1u); + EXPECT_EQ(filtered[0].doc_id, 0ull); + EXPECT_GT(filtered[0].score, 0.0f); +} + +// ============================================================ +// Brute-force (candidate-driven) mode via FtsQueryParams.candidate_ids +// ============================================================ + +namespace { + +// Helper: run a query with an explicit candidate id list. +template +static bool search_ok_with_candidates(Reader &reader, + const std::string &query_str, + uint32_t topk, + std::vector candidates, + std::vector *results) { + FtsQueryParser parser; + auto ast = parser.parse(query_str, make_whitespace_pipeline()); + if (!ast) { + ADD_FAILURE() << "FtsQueryParser failed to parse: " << query_str + << " err: " << parser.err_msg(); + return false; + } + zvec::fts::FtsQueryParams qp; + qp.topk = topk; + qp.candidate_ids = std::move(candidates); + auto ret = reader.search(*ast, qp); + if (!ret.has_value()) { + return false; + } + *results = std::move(ret.value()); + return true; +} + +// Compare two result vectors as (doc_id, score) sets — order independent on +// doc_id, scores compared with FLOAT_EQ. Brute-force and posting-driven +// paths reuse the same TermDocIterator / BM25Scorer so scores must agree. +static void ExpectSameResults(std::vector a, + std::vector b) { + ASSERT_EQ(a.size(), b.size()); + auto by_id = [](const FtsResult &x, const FtsResult &y) { + return x.doc_id < y.doc_id; + }; + std::sort(a.begin(), a.end(), by_id); + std::sort(b.begin(), b.end(), by_id); + for (size_t i = 0; i < a.size(); ++i) { + EXPECT_EQ(a[i].doc_id, b[i].doc_id) << "i=" << i; + EXPECT_FLOAT_EQ(a[i].score, b[i].score) << "i=" << i; + } +} + +} // namespace + +// Single-term query: candidate-driven path returns the intersection of the +// term posting and the candidate set, with the same BM25 scores as the +// posting-driven baseline. +TEST_F(FtsColumnIndexerTest, BruteForceTermMatchesPostingDriven) { + auto indexer = make_indexer("content"); + EXPECT_TRUE(indexer->insert(0, "hello world").has_value()); + EXPECT_TRUE(indexer->insert(1, "hello foo").has_value()); + EXPECT_TRUE(indexer->insert(2, "hello world bar").has_value()); + EXPECT_TRUE(indexer->insert(3, "hello baz").has_value()); + EXPECT_TRUE(indexer->insert(4, "world only").has_value()); + EXPECT_TRUE(indexer->flush().has_value()); + + // Baseline: "hello" matches docs 0,1,2,3. + std::vector baseline; + EXPECT_TRUE(search_ok(*indexer, "hello", 10, &baseline)); + ASSERT_EQ(baseline.size(), 4u); + + // Candidate-driven with {1, 2, 4} -> expect {1, 2} (4 is not in posting). + std::vector bf; + EXPECT_TRUE(search_ok_with_candidates(*indexer, "hello", 10, + /*candidates=*/{1, 2, 4}, &bf)); + + std::vector expected; + for (const auto &r : baseline) { + if (r.doc_id == 1 || r.doc_id == 2) expected.push_back(r); + } + ExpectSameResults(std::move(expected), std::move(bf)); +} + +// Disjunction (OR) — same BM25 score, only intersected docs returned. +TEST_F(FtsColumnIndexerTest, BruteForceDisjunctionMatchesPostingDriven) { + auto indexer = make_indexer("content"); + EXPECT_TRUE(indexer->insert(0, "alpha beta").has_value()); + EXPECT_TRUE(indexer->insert(1, "alpha gamma").has_value()); + EXPECT_TRUE(indexer->insert(2, "beta gamma").has_value()); + EXPECT_TRUE(indexer->insert(3, "alpha beta gamma").has_value()); + EXPECT_TRUE(indexer->insert(4, "delta").has_value()); + EXPECT_TRUE(indexer->flush().has_value()); + + std::vector baseline; + EXPECT_TRUE(search_ok(*indexer, "alpha beta", 10, &baseline)); + ASSERT_EQ(baseline.size(), 4u); // 0,1,2,3 all match OR + + std::vector bf; + EXPECT_TRUE(search_ok_with_candidates(*indexer, "alpha beta", 10, + /*candidates=*/{0, 3, 4}, &bf)); + + std::vector expected; + for (const auto &r : baseline) { + if (r.doc_id == 0 || r.doc_id == 3) expected.push_back(r); + } + ExpectSameResults(std::move(expected), std::move(bf)); +} + +// Conjunction (AND) — wrapped AND-of-AND is semantically transparent. +TEST_F(FtsColumnIndexerTest, BruteForceConjunctionMatchesPostingDriven) { + auto indexer = make_indexer("content"); + EXPECT_TRUE(indexer->insert(0, "alpha beta gamma").has_value()); + EXPECT_TRUE(indexer->insert(1, "alpha gamma").has_value()); // missing beta + EXPECT_TRUE(indexer->insert(2, "alpha beta").has_value()); // missing gamma + EXPECT_TRUE(indexer->insert(3, "alpha beta gamma").has_value()); + EXPECT_TRUE(indexer->insert(4, "alpha beta gamma").has_value()); + EXPECT_TRUE(indexer->flush().has_value()); + + std::vector baseline; + EXPECT_TRUE(search_ok(*indexer, "alpha AND beta AND gamma", 10, &baseline)); + ASSERT_EQ(baseline.size(), 3u); // 0,3,4 + + std::vector bf; + EXPECT_TRUE(search_ok_with_candidates(*indexer, "alpha AND beta AND gamma", + 10, /*candidates=*/{0, 1, 4}, &bf)); + + std::vector expected; + for (const auto &r : baseline) { + if (r.doc_id == 0 || r.doc_id == 4) expected.push_back(r); + } + ExpectSameResults(std::move(expected), std::move(bf)); +} + +// Phrase query — phase-2 position check is preserved in candidate-driven mode. +TEST_F(FtsColumnIndexerTest, BruteForcePhraseMatchesPostingDriven) { + auto indexer = make_indexer("content"); + EXPECT_TRUE(indexer->insert(0, "machine learning model").has_value()); + EXPECT_TRUE(indexer->insert(1, "machine notes learning").has_value()); + EXPECT_TRUE(indexer->insert(2, "the machine learning jumps").has_value()); + EXPECT_TRUE(indexer->insert(3, "learning machine").has_value()); + EXPECT_TRUE(indexer->flush().has_value()); + + std::vector baseline; + EXPECT_TRUE(search_ok(*indexer, "\"machine learning\"", 10, &baseline)); + ASSERT_EQ(baseline.size(), 2u); // 0,2 + + // Candidate set = {1, 2, 3}: only 2 is a real phrase match. + std::vector bf; + EXPECT_TRUE(search_ok_with_candidates(*indexer, "\"machine learning\"", 10, + /*candidates=*/{1, 2, 3}, &bf)); + + std::vector expected; + for (const auto &r : baseline) { + if (r.doc_id == 2) expected.push_back(r); + } + ExpectSameResults(std::move(expected), std::move(bf)); +} + +// Nested (AND of OR) — root iterator type does not matter; wrap is +// transparent. +TEST_F(FtsColumnIndexerTest, BruteForceNestedMatchesPostingDriven) { + auto indexer = make_indexer("content"); + EXPECT_TRUE(indexer->insert(0, "alpha").has_value()); + EXPECT_TRUE(indexer->insert(1, "beta").has_value()); + EXPECT_TRUE(indexer->insert(2, "alpha gamma").has_value()); // matches + EXPECT_TRUE(indexer->insert(3, "beta gamma").has_value()); // matches + EXPECT_TRUE(indexer->insert(4, "gamma only").has_value()); // no alpha/beta + EXPECT_TRUE(indexer->flush().has_value()); + + // (alpha OR beta) AND gamma -> docs 2, 3 + std::vector baseline; + EXPECT_TRUE(search_ok(*indexer, "(alpha OR beta) AND gamma", 10, &baseline)); + ASSERT_EQ(baseline.size(), 2u); + + std::vector bf; + EXPECT_TRUE(search_ok_with_candidates(*indexer, "(alpha OR beta) AND gamma", + 10, /*candidates=*/{2, 4}, &bf)); + + std::vector expected; + for (const auto &r : baseline) { + if (r.doc_id == 2) expected.push_back(r); + } + ExpectSameResults(std::move(expected), std::move(bf)); +} + +// Candidate-driven coexists with the existing filter pushdown: +// candidate_ids narrows the doc set; filter further drops some. +TEST_F(FtsColumnIndexerTest, BruteForceCoexistsWithFilterPushdown) { + auto indexer = make_indexer("content"); + EXPECT_TRUE(indexer->insert(0, "alpha").has_value()); + EXPECT_TRUE(indexer->insert(1, "alpha").has_value()); + EXPECT_TRUE(indexer->insert(2, "alpha").has_value()); + EXPECT_TRUE(indexer->insert(3, "alpha").has_value()); + EXPECT_TRUE(indexer->flush().has_value()); + + FtsQueryParser parser; + auto ast = parser.parse("alpha", make_whitespace_pipeline()); + ASSERT_NE(ast, nullptr); + + zvec::fts::FtsQueryParams qp; + qp.topk = 10; + qp.candidate_ids = {0, 1, 2}; // candidates restrict to {0,1,2} + qp.filter = make_blocked_filter({1}); // further drop doc 1 + auto ret = indexer->search(*ast, qp); + ASSERT_TRUE(ret.has_value()); + auto results = std::move(ret.value()); + ASSERT_EQ(results.size(), 2u); + + std::vector ids; + for (const auto &r : results) ids.push_back(r.doc_id); + std::sort(ids.begin(), ids.end()); + EXPECT_EQ(ids[0], 0ull); + EXPECT_EQ(ids[1], 2ull); +} + +// Empty candidate_ids takes the regular posting-driven path (the wrap guard +// requires non-empty), so search still finds all matching docs. +TEST_F(FtsColumnIndexerTest, BruteForceEmptyCandidatesFallsBack) { + auto indexer = make_indexer("content"); + EXPECT_TRUE(indexer->insert(0, "alpha beta").has_value()); + EXPECT_TRUE(indexer->insert(1, "alpha gamma").has_value()); + EXPECT_TRUE(indexer->flush().has_value()); + + std::vector r; + EXPECT_TRUE(search_ok_with_candidates(*indexer, "alpha", 10, {}, &r)); + EXPECT_EQ(r.size(), 2u); +} + +// Regression guard: a null filter yields the same doc_ids and scores as the +// baseline path (which still uses the no-filter next_doc() overload). +TEST_F(FtsColumnIndexerTest, FilterPushdownNullFilterUnchanged) { + auto indexer = make_indexer("content"); + EXPECT_TRUE(indexer->insert(0, "quick brown fox").has_value()); + EXPECT_TRUE(indexer->insert(1, "lazy brown dog").has_value()); + EXPECT_TRUE(indexer->flush().has_value()); + + std::vector baseline; + EXPECT_TRUE(search_ok(*indexer, "brown", 10, &baseline)); + ASSERT_EQ(baseline.size(), 2u); + + std::vector with_null; + EXPECT_TRUE(search_ok_with_filter(*indexer, "brown", 10, /*filter=*/nullptr, + &with_null)); + ASSERT_EQ(with_null.size(), 2u); + + auto by_id = [](const FtsResult &a, const FtsResult &b) { + return a.doc_id < b.doc_id; + }; + std::sort(baseline.begin(), baseline.end(), by_id); + std::sort(with_null.begin(), with_null.end(), by_id); + for (size_t i = 0; i < baseline.size(); ++i) { + EXPECT_EQ(baseline[i].doc_id, with_null[i].doc_id); + EXPECT_FLOAT_EQ(baseline[i].score, with_null[i].score); + } +} diff --git a/tests/db/index/column/fts_column/fts_rocksdb_reducer_test.cc b/tests/db/index/column/fts_column/fts_rocksdb_reducer_test.cc new file mode 100644 index 000000000..d860ba2d5 --- /dev/null +++ b/tests/db/index/column/fts_column/fts_rocksdb_reducer_test.cc @@ -0,0 +1,1124 @@ +// Copyright 2025-present the zvec project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "db/index/column/fts_column/fts_rocksdb_reducer.h" +#include +#include +#include +#include +#include +#include +#include +#include +#include "db/common/file_helper.h" +// FtsSegmentStats defined below +#include "db/index/column/fts_column/fts_column_indexer.h" +#include "db/index/column/fts_column/fts_rocksdb_merge.h" +#include "db/index/column/fts_column/parser/fts_query_parser.h" +#include "db/index/column/fts_column/posting/bitpacked_posting_list.h" +#include "db/index/column/fts_column/tokenizer/tokenizer_factory.h" +// meta.h not needed in zvec +#include "db/common/constants.h" +#include "db/common/rocksdb_context.h" +#include "db/index/column/fts_column/fts_utils.h" + +using namespace zvec::fts; +using namespace zvec; +using namespace zvec::fts; + +// Build the same whitespace pipeline used by the reducer's source indexers +// so the query path tokenizes identically to what the index stored. +static zvec::fts::TokenizerPipelinePtr make_reducer_test_pipeline() { + zvec::fts::FtsIndexParams params; + params.tokenizer_name = "whitespace"; + params.filters = {"lowercase"}; + return zvec::fts::TokenizerFactory::create(params); +} + +// Helper: parse a query string and call search() on a reader. +// Returns true on success, false on failure. +template +static bool search_str_ok(Reader &reader, const std::string &query_str, + uint32_t topk, std::vector *results) { + FtsQueryParser parser; + auto ast = parser.parse(query_str, make_reducer_test_pipeline()); + if (!ast) { + ADD_FAILURE() << "FtsQueryParser failed to parse: " << query_str + << " err: " << parser.err_msg(); + return false; + } + zvec::fts::FtsQueryParams qp; + qp.topk = topk; + auto ret = reader.search(*ast, qp); + if (!ret.has_value()) { + return false; + } + *results = std::move(ret.value()); + return true; +} + +// ============================================================ +// Constants +// ============================================================ + +static const std::string kTestDir{"./test_fts_reducer"}; +static const std::string kSrc0Dir{kTestDir + "/src0"}; +static const std::string kSrc1Dir{kTestDir + "/src1"}; +static const std::string kDstDir{kTestDir + "/dst"}; +static const std::string kMid0Dir{kTestDir + "/mid0"}; +static const std::string kMid1Dir{kTestDir + "/mid1"}; +static const std::string kDst2Dir{kTestDir + "/dst2"}; + +static const std::string kPostingsCf{"fts"}; +static const std::string kMaxTfCf{kPostingsCf + zvec::kFtsMaxTfSuffix}; +static const std::string kPositionsCf{kPostingsCf + zvec::kFtsPositionsSuffix}; +static const std::string kTermFreqCf{kPostingsCf + zvec::kFtsTfSuffix}; +static const std::string kDocLenCf{kPostingsCf + zvec::kFtsDocLenSuffix}; +static const std::string kStatCf{zvec::kFtsStatCfName}; + +static const std::string kFieldName{"content"}; + +// ============================================================ +// Helper: build a transient FieldMeta with whitespace tokenizer for tests +// ============================================================ + +static FieldSchema::Ptr MakeWhitespaceFieldMeta(const std::string &field_name) { + auto fts_params = std::make_shared("whitespace"); + return std::make_shared(field_name, DataType::STRING, false, + fts_params); +} + +// ============================================================ +// Helper: open a RocksDB store with FTS merge operators +// ============================================================ + +// Build RocksDB args for source/indexer stores (mutable stage: includes side +// CFs). +static Status OpenFtsStoreWithSideCfs(RocksdbContext &db, + const std::string &data_dir) { + std::vector cf_names = {kPostingsCf, kMaxTfCf, kPositionsCf, + kTermFreqCf, kDocLenCf, kStatCf}; + std::unordered_map> + per_cf_ops = { + {kPostingsCf, std::make_shared()}, + {kMaxTfCf, std::make_shared()}, + }; + return db.create( + RocksdbContext::Args{data_dir, cf_names, nullptr, per_cf_ops}); +} + +// Build RocksDB args for destination/reader stores (immutable stage: no side +// CFs). +static Status OpenFtsStore(RocksdbContext &db, const std::string &data_dir) { + std::vector cf_names = {kPostingsCf, kPositionsCf, kStatCf}; + std::unordered_map> + per_cf_ops = { + {kPostingsCf, std::make_shared()}, + }; + return db.create( + RocksdbContext::Args{data_dir, cf_names, nullptr, per_cf_ops}); +} + +// Open an existing RocksDB FTS store (immutable stage: no side CFs). +static Status OpenExistingFtsStore(RocksdbContext &db, + const std::string &data_dir) { + std::vector cf_names = {kPostingsCf, kPositionsCf, kStatCf}; + std::unordered_map> + per_cf_ops = { + {kPostingsCf, std::make_shared()}, + }; + return db.open(RocksdbContext::Args{data_dir, cf_names, nullptr, per_cf_ops}, + false); +} + + +// ============================================================ +// Helper: build a SegmentStats with given doc_id range +// ============================================================ + +static FtsSegmentStats MakeSegmentStats(uint64_t min_doc_id, + uint64_t max_doc_id) { + FtsSegmentStats stats; + stats.min_doc_id = min_doc_id; + stats.max_doc_id = max_doc_id; + // Tests build fresh source segments where local doc_id space is dense over + // [min_doc_id, max_doc_id], so doc_count is the range size. + stats.doc_count = max_doc_id - min_doc_id + 1; + return stats; +} + +// ============================================================ +// Helper: insert documents into a source segment via FtsColumnIndexer +// ============================================================ + +static void InsertDocs( + FtsColumnIndexer *indexer, + const std::vector> &docs) { + for (const auto &[doc_id, text] : docs) { + ASSERT_TRUE(indexer->insert(doc_id, text).has_value()); + } + ASSERT_TRUE(indexer->flush().has_value()); + // The post-2026 reducer requires source postings_cf to be in BitPacked + // format (and the side CFs to be empty), which is exactly what + // MutableSegment::dump_fts_column_indexers() produces via + // convert_postings_to_bitpacked(). Mirror that here so every src segment + // looks identical to a real on-disk SST. + ASSERT_TRUE(indexer->convert_postings_to_bitpacked().has_value()); +} + +// ============================================================ +// Helper: build a roaring bitmap of deleted positions in input scan order. +// In these tests segments are contiguous starting at min_doc_id=0 with +// doc_count == range, so "scan position" of a global doc_id equals the +// global value itself. Kept under the original name for callsite stability. +// ============================================================ + +static roaring::Roaring NoDeleteFilter() { + return roaring::Roaring{}; +} + +static roaring::Roaring DeleteFilter( + std::initializer_list deleted_scan_positions) { + roaring::Roaring r; + for (uint32_t p : deleted_scan_positions) { + r.add(p); + } + return r; +} + +// ============================================================ +// Test fixture +// ============================================================ + +class FtsRocksdbReducerTest : public ::testing::Test { + protected: + void SetUp() override { + zvec::FileHelper::RemoveDirectory(kTestDir); + zvec::FileHelper::CreateDirectory(kTestDir); + + // Source stores need side CFs for FtsColumnIndexer::insert(). + ASSERT_TRUE(OpenFtsStoreWithSideCfs(src0_db_, kSrc0Dir).ok()); + ASSERT_TRUE(OpenFtsStoreWithSideCfs(src1_db_, kSrc1Dir).ok()); + // Destination store mirrors immutable/reducer layout - no side CFs. + ASSERT_TRUE(OpenFtsStore(dst_db_, kDstDir).ok()); + + // Grab CF pointers for src0 + src0_postings_ = src0_db_.get_cf(kPostingsCf); + src0_positions_ = src0_db_.get_cf(kPositionsCf); + src0_term_freq_ = src0_db_.get_cf(kTermFreqCf); + src0_max_tf_ = src0_db_.get_cf(kMaxTfCf); + src0_doc_len_ = src0_db_.get_cf(kDocLenCf); + src0_stat_ = src0_db_.get_cf(kStatCf); + + // Grab CF pointers for src1 + src1_postings_ = src1_db_.get_cf(kPostingsCf); + src1_positions_ = src1_db_.get_cf(kPositionsCf); + src1_term_freq_ = src1_db_.get_cf(kTermFreqCf); + src1_max_tf_ = src1_db_.get_cf(kMaxTfCf); + src1_doc_len_ = src1_db_.get_cf(kDocLenCf); + src1_stat_ = src1_db_.get_cf(kStatCf); + + // Grab CF pointers for dst (no side CFs) + dst_postings_ = dst_db_.get_cf(kPostingsCf); + dst_positions_ = dst_db_.get_cf(kPositionsCf); + dst_stat_ = dst_db_.get_cf(kStatCf); + } + + void TearDown() override { + src0_db_.close(); + src1_db_.close(); + dst_db_.close(); + zvec::FileHelper::RemoveDirectory(kTestDir); + } + + std::unique_ptr MakeSrc0Indexer() { + auto field_meta = MakeWhitespaceFieldMeta(kFieldName); + auto indexer = std::make_unique(); + EXPECT_TRUE(indexer + ->open(field_meta, &src0_db_, src0_postings_, + src0_positions_, src0_term_freq_, src0_max_tf_, + src0_doc_len_, src0_stat_) + .has_value()); + return indexer; + } + + // Create and open a FtsColumnIndexer for src1 (doc_ids start at offset) + std::unique_ptr MakeSrc1Indexer() { + auto field_meta = MakeWhitespaceFieldMeta(kFieldName); + auto indexer = std::make_unique(); + EXPECT_TRUE(indexer + ->open(field_meta, &src1_db_, src1_postings_, + src1_positions_, src1_term_freq_, src1_max_tf_, + src1_doc_len_, src1_stat_) + .has_value()); + return indexer; + } + + // Open a FtsColumnIndexer (read-only) on the merged destination store. + // Side CFs are nullptr — immutable/reducer stores no longer contain them. + std::unique_ptr MakeDstReader() { + auto reader = std::make_unique(); + EXPECT_TRUE(reader + ->open_reader(kFieldName, &dst_db_, dst_postings_, + dst_positions_, /*term_freq_cf=*/nullptr, + /*max_tf_cf=*/nullptr, + /*doc_len_cf=*/nullptr, dst_stat_) + .has_value()); + return reader; + } + + // Initialize a reducer targeting the destination store + FtsRocksdbReducer MakeReducer() { + FtsRocksdbReducer reducer; + EXPECT_TRUE(reducer + .init(kFieldName, &dst_db_, dst_postings_, dst_positions_, + dst_stat_) + .has_value()); + return reducer; + } + + RocksdbContext src0_db_; + RocksdbContext src1_db_; + RocksdbContext dst_db_; + + rocksdb::ColumnFamilyHandle *src0_postings_{nullptr}; + rocksdb::ColumnFamilyHandle *src0_positions_{nullptr}; + rocksdb::ColumnFamilyHandle *src0_term_freq_{nullptr}; + rocksdb::ColumnFamilyHandle *src0_max_tf_{nullptr}; + rocksdb::ColumnFamilyHandle *src0_doc_len_{nullptr}; + rocksdb::ColumnFamilyHandle *src0_stat_{nullptr}; + + rocksdb::ColumnFamilyHandle *src1_postings_{nullptr}; + rocksdb::ColumnFamilyHandle *src1_positions_{nullptr}; + rocksdb::ColumnFamilyHandle *src1_term_freq_{nullptr}; + rocksdb::ColumnFamilyHandle *src1_max_tf_{nullptr}; + rocksdb::ColumnFamilyHandle *src1_doc_len_{nullptr}; + rocksdb::ColumnFamilyHandle *src1_stat_{nullptr}; + + rocksdb::ColumnFamilyHandle *dst_postings_{nullptr}; + rocksdb::ColumnFamilyHandle *dst_positions_{nullptr}; + rocksdb::ColumnFamilyHandle *dst_stat_{nullptr}; +}; + +// ============================================================ +// init() error cases +// ============================================================ + +TEST_F(FtsRocksdbReducerTest, InitFailsWithNullCF) { + FtsRocksdbReducer reducer; + EXPECT_FALSE( + reducer.init(kFieldName, &dst_db_, nullptr, dst_positions_, dst_stat_) + .has_value()); +} + +// ============================================================ +// feed() error cases +// ============================================================ + +TEST_F(FtsRocksdbReducerTest, FeedFailsBeforeInit) { + FtsRocksdbReducer reducer; + FtsSegmentStats stats = MakeSegmentStats(0, 2); + EXPECT_FALSE(reducer.feed(stats, &src0_db_, src0_postings_, src0_positions_) + .has_value()); +} + +TEST_F(FtsRocksdbReducerTest, FeedFailsWithNonConsecutiveDocIds) { + FtsRocksdbReducer reducer = MakeReducer(); + + FtsSegmentStats stats0 = MakeSegmentStats(0, 2); + EXPECT_TRUE(reducer.feed(stats0, &src0_db_, src0_postings_, src0_positions_) + .has_value()); + + // Gap: src1 starts at 4 instead of 3 + FtsSegmentStats stats1 = MakeSegmentStats(4, 6); + EXPECT_FALSE(reducer.feed(stats1, &src1_db_, src1_postings_, src1_positions_) + .has_value()); +} + +TEST_F(FtsRocksdbReducerTest, FeedAcceptsEmptySegmentAsNoop) { + // Empty segments (doc_count == 0) silently contribute nothing — the + // surrounding non-empty segments still get their contiguity validated + // against each other, as if the empty one wasn't there. + auto indexer0 = MakeSrc0Indexer(); + InsertDocs(indexer0.get(), {{0, "hello world"}, {1, "foo"}, {2, "bar"}}); + auto indexer1 = MakeSrc1Indexer(); + InsertDocs(indexer1.get(), {{0, "baz"}}); + + FtsRocksdbReducer reducer = MakeReducer(); + + FtsSegmentStats stats0 = MakeSegmentStats(0, 2); + ASSERT_TRUE(reducer.feed(stats0, &src0_db_, src0_postings_, src0_positions_) + .has_value()); + + // Empty middle segment — accepted, doesn't break contiguity. + FtsSegmentStats empty_stats; + empty_stats.min_doc_id = 0; + empty_stats.max_doc_id = 0; + empty_stats.doc_count = 0; + EXPECT_TRUE( + reducer.feed(empty_stats, &src1_db_, src1_postings_, src1_positions_) + .has_value()); + + // src1 must still start at stats0.max_doc_id + 1 = 3, not be shifted by + // the (skipped) empty segment. + FtsSegmentStats stats1 = MakeSegmentStats(3, 3); + ASSERT_TRUE(reducer.feed(stats1, &src1_db_, src1_postings_, src1_positions_) + .has_value()); + + ASSERT_TRUE(reducer.reduce(NoDeleteFilter()).has_value()); + + auto reader = MakeDstReader(); + std::vector results; + ASSERT_TRUE(search_str_ok(*reader, "hello", 10, &results)); + EXPECT_EQ(results.size(), 1u); + EXPECT_EQ(results[0].doc_id, 0ull); + + results.clear(); + ASSERT_TRUE(search_str_ok(*reader, "baz", 10, &results)); + EXPECT_EQ(results.size(), 1u); + EXPECT_EQ(results[0].doc_id, 3ull); +} + +// ============================================================ +// Single segment: basic merge without deletes +// ============================================================ + +TEST_F(FtsRocksdbReducerTest, SingleSegmentMergeNoDeletes) { + // Segment 0: doc_ids 0..2 + auto indexer0 = MakeSrc0Indexer(); + InsertDocs(indexer0.get(), + {{0, "hello world"}, {1, "hello foo"}, {2, "bar"}}); + + FtsRocksdbReducer reducer = MakeReducer(); + FtsSegmentStats stats0 = MakeSegmentStats(0, 2); + ASSERT_TRUE(reducer.feed(stats0, &src0_db_, src0_postings_, src0_positions_) + .has_value()); + ASSERT_TRUE(reducer.reduce(NoDeleteFilter()).has_value()); + + // Verify: search "hello" should return doc_ids 0 and 1 + auto reader = MakeDstReader(); + std::vector results; + ASSERT_TRUE(search_str_ok(*reader, "hello", 10, &results)); + EXPECT_EQ(results.size(), 2u); + + std::vector found_ids; + for (const auto &result : results) { + found_ids.push_back(result.doc_id); + } + EXPECT_NE(std::find(found_ids.begin(), found_ids.end(), 0ull), + found_ids.end()); + EXPECT_NE(std::find(found_ids.begin(), found_ids.end(), 1ull), + found_ids.end()); + + // "bar" should return doc_id 2 + results.clear(); + ASSERT_TRUE(search_str_ok(*reader, "bar", 10, &results)); + EXPECT_EQ(results.size(), 1u); + EXPECT_EQ(results[0].doc_id, 2ull); +} + +// ============================================================ +// Single segment: delete filter removes documents +// ============================================================ + +TEST_F(FtsRocksdbReducerTest, SingleSegmentMergeWithDeletes) { + // Segment 0: doc_ids 0..2 + auto indexer0 = MakeSrc0Indexer(); + InsertDocs(indexer0.get(), + {{0, "hello world"}, {1, "hello foo"}, {2, "bar"}}); + + FtsRocksdbReducer reducer = MakeReducer(); + FtsSegmentStats stats0 = MakeSegmentStats(0, 2); + ASSERT_TRUE(reducer.feed(stats0, &src0_db_, src0_postings_, src0_positions_) + .has_value()); + + // Delete doc_id 0 (global). After reduce, the dst segment has dense local + // doc_ids; surviving global {1,2} get dense ranks {0,1}. + ASSERT_TRUE(reducer.reduce(DeleteFilter({0})).has_value()); + + auto reader = MakeDstReader(); + std::vector results; + + // "hello" survived in global doc 1 → dense doc_id 0. + ASSERT_TRUE(search_str_ok(*reader, "hello", 10, &results)); + EXPECT_EQ(results.size(), 1u); + EXPECT_EQ(results[0].doc_id, 0ull); + + // "world" should return nothing (its only document was deleted) + results.clear(); + ASSERT_TRUE(search_str_ok(*reader, "world", 10, &results)); + EXPECT_EQ(results.size(), 0u); +} + +// ============================================================ +// Two segments: doc_id remapping across segment boundary +// ============================================================ + +TEST_F(FtsRocksdbReducerTest, TwoSegmentsMergeDocIdRemapping) { + // Segment 0: GLOBAL doc_ids 0..2 + auto indexer0 = MakeSrc0Indexer(); + InsertDocs(indexer0.get(), + {{0, "hello world"}, {1, "hello baz"}, {2, "foo bar"}}); + + // Segment 1: GLOBAL doc_ids 3..3 (stored as LOCAL 0 in src1 RocksDB) + auto indexer1 = MakeSrc1Indexer(); + InsertDocs(indexer1.get(), {{0, "hello qux"}}); + + FtsRocksdbReducer reducer = MakeReducer(); + + FtsSegmentStats stats0 = MakeSegmentStats(0, 2); + ASSERT_TRUE(reducer.feed(stats0, &src0_db_, src0_postings_, src0_positions_) + .has_value()); + + FtsSegmentStats stats1 = MakeSegmentStats(3, 3); + ASSERT_TRUE(reducer.feed(stats1, &src1_db_, src1_postings_, src1_positions_) + .has_value()); + + ASSERT_TRUE(reducer.reduce(NoDeleteFilter()).has_value()); + + // Dst segment starts at GLOBAL doc_id 0 (covers 0..3); reader returns + // GLOBAL doc_ids by adding start_doc_id back to local doc_ids stored in + // the merged dst RocksDB. + auto reader = MakeDstReader(); + std::vector results; + + // "hello" appears in global doc_ids 0, 1 (seg0) and 3 (seg1) + ASSERT_TRUE(search_str_ok(*reader, "hello", 10, &results)); + EXPECT_EQ(results.size(), 3u); + + std::vector found_ids; + for (const auto &result : results) { + found_ids.push_back(result.doc_id); + } + EXPECT_NE(std::find(found_ids.begin(), found_ids.end(), 0ull), + found_ids.end()); + EXPECT_NE(std::find(found_ids.begin(), found_ids.end(), 1ull), + found_ids.end()); + EXPECT_NE(std::find(found_ids.begin(), found_ids.end(), 3ull), + found_ids.end()); + + // "world" appears only in global doc_id 0 + results.clear(); + ASSERT_TRUE(search_str_ok(*reader, "world", 10, &results)); + EXPECT_EQ(results.size(), 1u); + EXPECT_EQ(results[0].doc_id, 0ull); + + // "qux" appears only in global doc_id 3 + results.clear(); + ASSERT_TRUE(search_str_ok(*reader, "qux", 10, &results)); + EXPECT_EQ(results.size(), 1u); + EXPECT_EQ(results[0].doc_id, 3ull); +} + +// ============================================================ +// Two segments: delete from second segment +// ============================================================ + +TEST_F(FtsRocksdbReducerTest, TwoSegmentsMergeDeleteFromSecondSegment) { + // Segment 0: GLOBAL doc_ids 0..1 + auto indexer0 = MakeSrc0Indexer(); + InsertDocs(indexer0.get(), {{0, "hello world"}, {1, "foo bar"}}); + + // Segment 1: GLOBAL doc_ids 2..3 (stored as LOCAL 0..1 in src1 RocksDB) + auto indexer1 = MakeSrc1Indexer(); + InsertDocs(indexer1.get(), {{0, "hello baz"}, {1, "qux"}}); + + FtsRocksdbReducer reducer = MakeReducer(); + + FtsSegmentStats stats0 = MakeSegmentStats(0, 1); + ASSERT_TRUE(reducer.feed(stats0, &src0_db_, src0_postings_, src0_positions_) + .has_value()); + + FtsSegmentStats stats1 = MakeSegmentStats(2, 3); + ASSERT_TRUE(reducer.feed(stats1, &src1_db_, src1_postings_, src1_positions_) + .has_value()); + + // Delete global doc_id 2 (first doc of segment 1, local 0). Survivors in + // input scan order are global {0, 1, 3}, getting dense ranks {0, 1, 2}. + ASSERT_TRUE(reducer.reduce(DeleteFilter({2})).has_value()); + + auto reader = MakeDstReader(); + std::vector results; + + // "hello" survived in global doc 0 → dense rank 0. + ASSERT_TRUE(search_str_ok(*reader, "hello", 10, &results)); + EXPECT_EQ(results.size(), 1u); + EXPECT_EQ(results[0].doc_id, 0ull); + + // "qux" was global doc 3 → dense rank 2. + results.clear(); + ASSERT_TRUE(search_str_ok(*reader, "qux", 10, &results)); + EXPECT_EQ(results.size(), 1u); + EXPECT_EQ(results[0].doc_id, 2ull); +} + +// ============================================================ +// BM25 scores are positive after merge +// ============================================================ + +TEST_F(FtsRocksdbReducerTest, MergedResultsHavePositiveScores) { + auto indexer0 = MakeSrc0Indexer(); + InsertDocs(indexer0.get(), + {{0, "hello world"}, {1, "hello foo"}, {2, "bar baz"}}); + + FtsRocksdbReducer reducer = MakeReducer(); + FtsSegmentStats stats0 = MakeSegmentStats(0, 2); + ASSERT_TRUE(reducer.feed(stats0, &src0_db_, src0_postings_, src0_positions_) + .has_value()); + ASSERT_TRUE(reducer.reduce(NoDeleteFilter()).has_value()); + + auto reader = MakeDstReader(); + std::vector results; + ASSERT_TRUE(search_str_ok(*reader, "hello", 10, &results)); + ASSERT_EQ(results.size(), 2u); + + for (const auto &result : results) { + EXPECT_GT(result.score, 0.0f) + << "Expected positive BM25 score for doc_id " << result.doc_id; + } +} + +// ============================================================ +// reduce() fails if called before feed() +// ============================================================ + +TEST_F(FtsRocksdbReducerTest, ReduceFailsBeforeFeed) { + FtsRocksdbReducer reducer = MakeReducer(); + EXPECT_FALSE(reducer.reduce(NoDeleteFilter()).has_value()); +} + +// ============================================================ +// cleanup() resets state so reducer can be reused +// ============================================================ + +TEST_F(FtsRocksdbReducerTest, CleanupResetsState) { + FtsRocksdbReducer reducer = MakeReducer(); + + auto indexer0 = MakeSrc0Indexer(); + InsertDocs(indexer0.get(), {{0, "hello"}, {1, "world"}}); + + FtsSegmentStats stats0 = MakeSegmentStats(0, 1); + ASSERT_TRUE(reducer.feed(stats0, &src0_db_, src0_postings_, src0_positions_) + .has_value()); + ASSERT_TRUE(reducer.reduce(NoDeleteFilter()).has_value()); + ASSERT_TRUE(reducer.cleanup().has_value()); + + // After cleanup, reduce() should fail (no segments fed) + EXPECT_FALSE(reducer.reduce(NoDeleteFilter()).has_value()); +} + +// ============================================================ +// Verify reduce produces BitPacked format postings +// ============================================================ + +TEST_F(FtsRocksdbReducerTest, ReduceProducesBitPackedFormat) { + auto indexer0 = MakeSrc0Indexer(); + InsertDocs(indexer0.get(), + {{0, "hello world"}, {1, "hello foo"}, {2, "bar baz"}}); + + FtsRocksdbReducer reducer = MakeReducer(); + FtsSegmentStats stats0 = MakeSegmentStats(0, 2); + ASSERT_TRUE(reducer.feed(stats0, &src0_db_, src0_postings_, src0_positions_) + .has_value()); + ASSERT_TRUE(reducer.reduce(NoDeleteFilter()).has_value()); + + // Verify that postings in destination CF are in BitPacked format + std::string raw_data; + ASSERT_TRUE( + dst_db_.db_->Get(dst_db_.read_opts_, dst_postings_, "hello", &raw_data) + .ok()); + EXPECT_TRUE(fts::BitPackedPostingList::is_bitpacked_format(raw_data.data(), + raw_data.size())); + + // Verify the BitPacked data can be opened and iterated + fts::BitPackedPostingIterator iter; + EXPECT_EQ(iter.open(raw_data.data(), raw_data.size()), 0); + EXPECT_EQ(iter.cost(), 2u); // "hello" appears in doc 0 and doc 1 + + // Verify inline payloads are accessible + uint32_t doc = iter.next_doc(); + EXPECT_EQ(doc, 0u); + EXPECT_GT(iter.term_freq(), 0u); + EXPECT_GT(iter.doc_len(), 0u); + + doc = iter.next_doc(); + EXPECT_EQ(doc, 1u); + EXPECT_GT(iter.term_freq(), 0u); + EXPECT_GT(iter.doc_len(), 0u); + + EXPECT_EQ(iter.next_doc(), fts::BitPackedPostingIterator::NO_MORE_DOCS); +} + +// ============================================================ +// Verify two-segment merge produces correct BitPacked postings +// ============================================================ + +TEST_F(FtsRocksdbReducerTest, TwoSegmentMergeBitPackedCorrectness) { + // Segment 0: GLOBAL doc_ids 0..1 + auto indexer0 = MakeSrc0Indexer(); + InsertDocs(indexer0.get(), {{0, "hello world"}, {1, "foo bar"}}); + + // Segment 1: GLOBAL doc_ids 2..3 (stored as LOCAL 0..1 in src1 RocksDB) + auto indexer1 = MakeSrc1Indexer(); + InsertDocs(indexer1.get(), {{0, "hello baz"}, {1, "qux"}}); + + FtsRocksdbReducer reducer = MakeReducer(); + + FtsSegmentStats stats0 = MakeSegmentStats(0, 1); + ASSERT_TRUE(reducer.feed(stats0, &src0_db_, src0_postings_, src0_positions_) + .has_value()); + + FtsSegmentStats stats1 = MakeSegmentStats(2, 3); + ASSERT_TRUE(reducer.feed(stats1, &src1_db_, src1_postings_, src1_positions_) + .has_value()); + + ASSERT_TRUE(reducer.reduce(NoDeleteFilter()).has_value()); + + // Verify "hello" postings are BitPacked and contain both doc_ids + std::string raw_data; + ASSERT_TRUE( + dst_db_.db_->Get(dst_db_.read_opts_, dst_postings_, "hello", &raw_data) + .ok()); + EXPECT_TRUE(fts::BitPackedPostingList::is_bitpacked_format(raw_data.data(), + raw_data.size())); + + fts::BitPackedPostingIterator iter; + EXPECT_EQ(iter.open(raw_data.data(), raw_data.size()), 0); + EXPECT_EQ(iter.cost(), 2u); // "hello" in doc 0 and doc 2 + + EXPECT_EQ(iter.next_doc(), 0u); + EXPECT_EQ(iter.next_doc(), 2u); + EXPECT_EQ(iter.next_doc(), fts::BitPackedPostingIterator::NO_MORE_DOCS); + + // Verify search still works correctly via FtsColumnIndexer + auto reader = MakeDstReader(); + std::vector results; + ASSERT_TRUE(search_str_ok(*reader, "hello", 10, &results)); + EXPECT_EQ(results.size(), 2u); + + // Verify BM25 scores are positive + for (const auto &result : results) { + EXPECT_GT(result.score, 0.0f); + } +} + +// ============================================================ +// Two BitPacked segments merged: both source segments have already been +// reduced (postings in BitPacked format), verify the reducer can handle +// BitPacked-to-BitPacked merge correctly. +// ============================================================ + +TEST_F(FtsRocksdbReducerTest, MergeTwoBitPackedSegments) { + // --- Phase 1: Build two intermediate segments with BitPacked postings --- + // Each intermediate segment is produced by a single-segment reduce. + + // Mid0: reduce src0 -> mid0 (produces BitPacked postings) + { + auto indexer0 = MakeSrc0Indexer(); + InsertDocs(indexer0.get(), + {{0, "hello world"}, {1, "hello foo"}, {2, "bar"}}); + + RocksdbContext mid0_db; + ASSERT_TRUE(OpenFtsStore(mid0_db, kMid0Dir).ok()); + + auto *mid0_postings = mid0_db.get_cf(kPostingsCf); + auto *mid0_positions = mid0_db.get_cf(kPositionsCf); + auto *mid0_stat = mid0_db.get_cf(kStatCf); + FtsRocksdbReducer reducer0; + ASSERT_TRUE(reducer0 + .init(kFieldName, &mid0_db, mid0_postings, mid0_positions, + mid0_stat) + .has_value()); + ASSERT_TRUE(reducer0 + .feed(MakeSegmentStats(0, 2), &src0_db_, src0_postings_, + src0_positions_) + .has_value()); + ASSERT_TRUE(reducer0.reduce(NoDeleteFilter()).has_value()); + + // Verify mid0 postings are in BitPacked format + std::string raw; + ASSERT_TRUE( + mid0_db.db_->Get(mid0_db.read_opts_, mid0_postings, "hello", &raw) + .ok()); + ASSERT_TRUE( + fts::BitPackedPostingList::is_bitpacked_format(raw.data(), raw.size())); + + mid0_db.close(); + } + + // Mid1: reduce src1 -> mid1 (produces BitPacked postings) + { + auto indexer1 = MakeSrc1Indexer(); + InsertDocs(indexer1.get(), {{0, "hello baz"}, {1, "qux bar"}}); + + RocksdbContext mid1_db; + ASSERT_TRUE(OpenFtsStore(mid1_db, kMid1Dir).ok()); + + auto *mid1_postings = mid1_db.get_cf(kPostingsCf); + auto *mid1_positions = mid1_db.get_cf(kPositionsCf); + auto *mid1_stat = mid1_db.get_cf(kStatCf); + FtsRocksdbReducer reducer1; + ASSERT_TRUE(reducer1 + .init(kFieldName, &mid1_db, mid1_postings, mid1_positions, + mid1_stat) + .has_value()); + ASSERT_TRUE(reducer1 + .feed(MakeSegmentStats(0, 1), &src1_db_, src1_postings_, + src1_positions_) + .has_value()); + ASSERT_TRUE(reducer1.reduce(NoDeleteFilter()).has_value()); + + // Verify mid1 postings are in BitPacked format + std::string raw; + ASSERT_TRUE( + mid1_db.db_->Get(mid1_db.read_opts_, mid1_postings, "hello", &raw) + .ok()); + ASSERT_TRUE( + fts::BitPackedPostingList::is_bitpacked_format(raw.data(), raw.size())); + + mid1_db.close(); + } + + // --- Phase 2: Merge the two BitPacked intermediate segments --- + // Reopen mid0 and mid1 as source (existing=true since they were created + // in Phase 1), reduce into dst. + RocksdbContext mid0_db, mid1_db; + ASSERT_TRUE(OpenExistingFtsStore(mid0_db, kMid0Dir).ok()); + ASSERT_TRUE(OpenExistingFtsStore(mid1_db, kMid1Dir).ok()); + + auto *mid0_postings = mid0_db.get_cf(kPostingsCf); + auto *mid0_positions = mid0_db.get_cf(kPositionsCf); + auto *mid1_postings = mid1_db.get_cf(kPostingsCf); + auto *mid1_positions = mid1_db.get_cf(kPositionsCf); + FtsRocksdbReducer final_reducer = MakeReducer(); + // mid0 has doc_ids 0..2, mid1 has doc_ids 3..4 + ASSERT_TRUE( + final_reducer + .feed(MakeSegmentStats(0, 2), &mid0_db, mid0_postings, mid0_positions) + .has_value()); + ASSERT_TRUE( + final_reducer + .feed(MakeSegmentStats(3, 4), &mid1_db, mid1_postings, mid1_positions) + .has_value()); + ASSERT_TRUE(final_reducer.reduce(NoDeleteFilter()).has_value()); + + mid0_db.close(); + mid1_db.close(); + + // --- Phase 3: Verify merged results --- + // Verify output is BitPacked + std::string raw_data; + ASSERT_TRUE( + dst_db_.db_->Get(dst_db_.read_opts_, dst_postings_, "hello", &raw_data) + .ok()); + EXPECT_TRUE(fts::BitPackedPostingList::is_bitpacked_format(raw_data.data(), + raw_data.size())); + + // "hello" appears in doc 0, 1 (from mid0) and doc 3 (from mid1) + fts::BitPackedPostingIterator bp_iter; + ASSERT_EQ(bp_iter.open(raw_data.data(), raw_data.size()), 0); + EXPECT_EQ(bp_iter.cost(), 3u); + EXPECT_EQ(bp_iter.next_doc(), 0u); + EXPECT_EQ(bp_iter.next_doc(), 1u); + EXPECT_EQ(bp_iter.next_doc(), 3u); + EXPECT_EQ(bp_iter.next_doc(), fts::BitPackedPostingIterator::NO_MORE_DOCS); + + // "bar" appears in doc 2 (from mid0) and doc 4 (from mid1) + raw_data.clear(); + ASSERT_TRUE( + dst_db_.db_->Get(dst_db_.read_opts_, dst_postings_, "bar", &raw_data) + .ok()); + EXPECT_TRUE(fts::BitPackedPostingList::is_bitpacked_format(raw_data.data(), + raw_data.size())); + fts::BitPackedPostingIterator bar_iter; + ASSERT_EQ(bar_iter.open(raw_data.data(), raw_data.size()), 0); + EXPECT_EQ(bar_iter.cost(), 2u); + EXPECT_EQ(bar_iter.next_doc(), 2u); + EXPECT_EQ(bar_iter.next_doc(), 4u); + EXPECT_EQ(bar_iter.next_doc(), fts::BitPackedPostingIterator::NO_MORE_DOCS); + + // Verify search via FtsColumnIndexer still works + auto reader = MakeDstReader(); + std::vector results; + ASSERT_TRUE(search_str_ok(*reader, "hello", 10, &results)); + EXPECT_EQ(results.size(), 3u); + for (const auto &result : results) { + EXPECT_GT(result.score, 0.0f); + } + + results.clear(); + ASSERT_TRUE(search_str_ok(*reader, "bar", 10, &results)); + EXPECT_EQ(results.size(), 2u); +} + +// ============================================================ +// (Removed) Mixed BitPacked + Roaring Bitmap merge. +// The post-2026 reducer no longer accepts Roaring-format source segments +// (FtsColumnIndexer::convert_postings_to_bitpacked() always runs at dump +// time), so this scenario is no longer reachable in production. + +// ============================================================ +// Reducer over BitPacked-converted source segments with EMPTY side CFs +// ============================================================ +// +// After the post-2026 indexer change, +// MutableSegment::dump_fts_column_indexers() invokes +// FtsColumnIndexer::convert_postings_to_bitpacked(), which inlines +// tf/doc_len/max_tf into the BitPacked posting list AND DeleteRange's the +// $TF / $MAX_TF / $DOC_LEN side CFs. By the time the reducer sees the +// segment: +// - postings_cf : every value is BitPacked (magic 'BPKD') +// - term_freq_cf / max_tf_cf / doc_len_cf : empty (DeleteRange tombstones) +// +// The new reducer never reads the side CFs at all, so this test verifies +// the end-to-end pipeline produces a queryable destination index whose +// posting set matches the expected union — and that the empty side CFs +// cause no errors or stat under-counts. + +TEST_F(FtsRocksdbReducerTest, ReducerHandlesBitpackedConvertedSrcSegments) { + // ----- src0: insert + flush + convert (the helper already calls convert) + // ----- + auto indexer0 = MakeSrc0Indexer(); + InsertDocs(indexer0.get(), { + {0, "hello world"}, + {1, "hello foo"}, + {2, "bar baz"}, + }); + + // Sanity: src0 postings are BitPacked AND the side CFs are empty (the + // indexer DeleteRange'd them as part of convert_postings_to_bitpacked()). + { + std::string raw; + ASSERT_TRUE( + src0_db_.db_->Get(src0_db_.read_opts_, src0_postings_, "hello", &raw) + .ok()); + EXPECT_TRUE( + BitPackedPostingList::is_bitpacked_format(raw.data(), raw.size())); + auto it = std::unique_ptr( + src0_db_.db_->NewIterator(src0_db_.read_opts_, src0_term_freq_)); + it->SeekToFirst(); + EXPECT_FALSE(it->Valid()); + auto it2 = std::unique_ptr( + src0_db_.db_->NewIterator(src0_db_.read_opts_, src0_doc_len_)); + it2->SeekToFirst(); + EXPECT_FALSE(it2->Valid()); + auto it3 = std::unique_ptr( + src0_db_.db_->NewIterator(src0_db_.read_opts_, src0_max_tf_)); + it3->SeekToFirst(); + EXPECT_FALSE(it3->Valid()); + } + + // ----- src1: insert + flush + convert ----- + auto indexer1 = MakeSrc1Indexer(); + InsertDocs(indexer1.get(), { + {0, "hello qux"}, + {1, "qux quux"}, + }); + + // ----- Reduce ----- + // src0 covers GLOBAL [0, 2], src1 covers GLOBAL [3, 4] (consecutive). + FtsRocksdbReducer reducer = MakeReducer(); + ASSERT_TRUE(reducer + .feed(MakeSegmentStats(0, 2), &src0_db_, src0_postings_, + src0_positions_) + .has_value()); + ASSERT_TRUE(reducer + .feed(MakeSegmentStats(3, 4), &src1_db_, src1_postings_, + src1_positions_) + .has_value()); + ASSERT_TRUE(reducer.reduce(NoDeleteFilter()).has_value()); + + // ----- Verify dst can be queried ----- + // After reduce, dst postings get re-written to BitPacked again by the + // reducer's existing convert_postings_to_bitpacked step, so this exercises + // the full BitPacked-in / BitPacked-out path. + auto reader = MakeDstReader(); + + // "hello" appears in src0 doc 0 (global 0), src0 doc 1 (global 1), + // src1 doc 0 (global 3) -> 3 hits. + std::vector results; + ASSERT_TRUE(search_str_ok(*reader, "hello", 10, &results)); + EXPECT_EQ(results.size(), 3u); + std::vector hello_ids; + for (const auto &r : results) hello_ids.push_back(r.doc_id); + std::sort(hello_ids.begin(), hello_ids.end()); + EXPECT_EQ(hello_ids[0], 0ull); + EXPECT_EQ(hello_ids[1], 1ull); + EXPECT_EQ(hello_ids[2], 3ull); + + // "qux" appears in src1 docs 0 and 1 -> globals 3 and 4. + results.clear(); + ASSERT_TRUE(search_str_ok(*reader, "qux", 10, &results)); + EXPECT_EQ(results.size(), 2u); + std::vector qux_ids; + for (const auto &r : results) qux_ids.push_back(r.doc_id); + std::sort(qux_ids.begin(), qux_ids.end()); + EXPECT_EQ(qux_ids[0], 3ull); + EXPECT_EQ(qux_ids[1], 4ull); +} + +// ============================================================ +// Single-segment reduce when the source side CFs are completely empty: +// the reducer must rely only on the BitPacked inline payloads (tf, doc_len) +// for both the merged posting list and the destination stat_cf. Any +// regression that re-introduces a side-CF read would surface here as a +// missing tf / doc_len / score. +// ============================================================ + +TEST_F(FtsRocksdbReducerTest, ReduceWithEmptySideCFsProducesBitPacked) { + // InsertDocs() already calls convert_postings_to_bitpacked(), so by the + // time we reach reduce() the src $TF / $MAX_TF / $DOC_LEN CFs are empty. + auto indexer0 = MakeSrc0Indexer(); + InsertDocs(indexer0.get(), {{0, "alpha beta gamma"}, + {1, "alpha alpha gamma"}, + {2, "delta epsilon"}}); + + // Sanity: side CFs are empty after convert (DeleteRange'd by the indexer). + { + auto it = std::unique_ptr( + src0_db_.db_->NewIterator(src0_db_.read_opts_, src0_term_freq_)); + it->SeekToFirst(); + EXPECT_FALSE(it->Valid()); + auto it2 = std::unique_ptr( + src0_db_.db_->NewIterator(src0_db_.read_opts_, src0_doc_len_)); + it2->SeekToFirst(); + EXPECT_FALSE(it2->Valid()); + auto it3 = std::unique_ptr( + src0_db_.db_->NewIterator(src0_db_.read_opts_, src0_max_tf_)); + it3->SeekToFirst(); + EXPECT_FALSE(it3->Valid()); + } + + FtsRocksdbReducer reducer = MakeReducer(); + ASSERT_TRUE(reducer + .feed(MakeSegmentStats(0, 2), &src0_db_, src0_postings_, + src0_positions_) + .has_value()); + ASSERT_TRUE(reducer.reduce(NoDeleteFilter()).has_value()); + + // Destination postings_cf must be BitPacked and carry inline tf/doc_len + // recovered solely from the source BitPacked payloads. + std::string raw; + ASSERT_TRUE( + dst_db_.db_->Get(dst_db_.read_opts_, dst_postings_, "alpha", &raw).ok()); + ASSERT_TRUE( + fts::BitPackedPostingList::is_bitpacked_format(raw.data(), raw.size())); + fts::BitPackedPostingIterator bp; + ASSERT_EQ(bp.open(raw.data(), raw.size()), 0); + EXPECT_EQ(bp.cost(), 2u); + + EXPECT_EQ(bp.next_doc(), 0u); + EXPECT_EQ(bp.term_freq(), 1u); // doc 0: "alpha" once + EXPECT_EQ(bp.doc_len(), 3u); + EXPECT_EQ(bp.next_doc(), 1u); + EXPECT_EQ(bp.term_freq(), 2u); // doc 1: "alpha alpha" + EXPECT_EQ(bp.doc_len(), 3u); + EXPECT_EQ(bp.next_doc(), fts::BitPackedPostingIterator::NO_MORE_DOCS); + + // dst_stat_cf must reflect the inline doc_len totals: 3 docs, 8 tokens + // ("alpha beta gamma" = 3, "alpha alpha gamma" = 3, "delta epsilon" = 2). + std::string total_docs_raw, total_tokens_raw; + ASSERT_TRUE(dst_db_.db_ + ->Get(dst_db_.read_opts_, dst_stat_, + kFieldName + "_total_docs", &total_docs_raw) + .ok()); + ASSERT_TRUE(dst_db_.db_ + ->Get(dst_db_.read_opts_, dst_stat_, + kFieldName + "_total_tokens", &total_tokens_raw) + .ok()); + uint64_t total_docs = fts::decode_uint64_value(total_docs_raw.data()); + uint64_t total_tokens = fts::decode_uint64_value(total_tokens_raw.data()); + EXPECT_EQ(total_docs, 3u); + EXPECT_EQ(total_tokens, 8u); + + // dst no longer has side CFs ($TF/$MAX_TF/$DOC_LEN) — they are dropped + // at dump time. Verify search still works end-to-end. + auto reader = MakeDstReader(); + std::vector results; + ASSERT_TRUE(search_str_ok(*reader, "alpha", 10, &results)); + EXPECT_EQ(results.size(), 2u); + for (const auto &r : results) EXPECT_GT(r.score, 0.0f); +} + +// ============================================================ +// Cross-segment BM25 stats: the destination total_docs / total_tokens +// must equal the sum of the surviving documents from every fed segment, +// using the inline doc_len payloads (each surviving doc counted ONCE per +// its segment, regardless of how many terms it appears under). +// ============================================================ + +TEST_F(FtsRocksdbReducerTest, MultiSegmentBM25StatsAreAccumulatedCorrectly) { + // src0: 2 docs, doc_len 3 + 2 = 5 tokens + auto indexer0 = MakeSrc0Indexer(); + InsertDocs(indexer0.get(), {{0, "alpha beta gamma"}, {1, "alpha beta"}}); + + // src1: 2 docs, doc_len 4 + 1 = 5 tokens + auto indexer1 = MakeSrc1Indexer(); + InsertDocs(indexer1.get(), {{0, "alpha gamma delta epsilon"}, {1, "alpha"}}); + + FtsRocksdbReducer reducer = MakeReducer(); + ASSERT_TRUE(reducer + .feed(MakeSegmentStats(0, 1), &src0_db_, src0_postings_, + src0_positions_) + .has_value()); + ASSERT_TRUE(reducer + .feed(MakeSegmentStats(2, 3), &src1_db_, src1_postings_, + src1_positions_) + .has_value()); + ASSERT_TRUE(reducer.reduce(NoDeleteFilter()).has_value()); + + // 4 surviving docs across both segments; 5 + 5 = 10 tokens total. + std::string total_docs_raw, total_tokens_raw; + ASSERT_TRUE(dst_db_.db_ + ->Get(dst_db_.read_opts_, dst_stat_, + kFieldName + "_total_docs", &total_docs_raw) + .ok()); + ASSERT_TRUE(dst_db_.db_ + ->Get(dst_db_.read_opts_, dst_stat_, + kFieldName + "_total_tokens", &total_tokens_raw) + .ok()); + uint64_t total_docs = fts::decode_uint64_value(total_docs_raw.data()); + uint64_t total_tokens = fts::decode_uint64_value(total_tokens_raw.data()); + EXPECT_EQ(total_docs, 4u); + EXPECT_EQ(total_tokens, 10u); + + // With one doc filtered out (global doc_id 2 from src1, doc_len 4), + // totals must drop to 3 docs / 6 tokens. + // Reset destination CFs by re-opening the dst RocksDB? Simpler: build a + // second dst inside this test would require a second fixture; instead we + // assert via a dedicated Reducer + dst pair using the current dst (which + // has data already) is not safe. Skip the filter sub-case here — it's + // covered by SingleSegmentMergeWithDeletes for the single-segment path. + + // Verify "alpha" merged posting carries 4 entries with monotonic doc_ids. + std::string raw; + ASSERT_TRUE( + dst_db_.db_->Get(dst_db_.read_opts_, dst_postings_, "alpha", &raw).ok()); + ASSERT_TRUE( + fts::BitPackedPostingList::is_bitpacked_format(raw.data(), raw.size())); + fts::BitPackedPostingIterator bp; + ASSERT_EQ(bp.open(raw.data(), raw.size()), 0); + EXPECT_EQ(bp.cost(), 4u); + std::vector docs; + while (true) { + uint32_t d = bp.next_doc(); + if (d == fts::BitPackedPostingIterator::NO_MORE_DOCS) break; + docs.push_back(d); + } + ASSERT_EQ(docs.size(), 4u); + EXPECT_EQ(docs[0], 0u); + EXPECT_EQ(docs[1], 1u); + EXPECT_EQ(docs[2], 2u); + EXPECT_EQ(docs[3], 3u); +} diff --git a/tests/db/index/column/fts_column/tokenizer_pipeline_manager_test.cc b/tests/db/index/column/fts_column/tokenizer_pipeline_manager_test.cc new file mode 100644 index 000000000..5a9ba5b0d --- /dev/null +++ b/tests/db/index/column/fts_column/tokenizer_pipeline_manager_test.cc @@ -0,0 +1,271 @@ +// Copyright 2025-present the zvec project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "db/index/column/fts_column/tokenizer/tokenizer_pipeline_manager.h" +#include +#include +#include +#include +#include +#include +#include "db/index/column/fts_column/fts_types.h" + +using namespace zvec::fts; + +// ============================================================ +// Helpers +// ============================================================ + +static FtsIndexParams make_params(const std::string &tokenizer) { + FtsIndexParams params; + params.tokenizer_name = tokenizer; + return params; +} + +// ============================================================ +// make_key tests +// ============================================================ + +TEST(TokenizerPipelineManagerKeyTest, BasicKey) { + FtsIndexParams params; + params.tokenizer_name = "whitespace"; + std::string key = TokenizerPipelineManager::make_key(params); + EXPECT_FALSE(key.empty()); + EXPECT_NE(key.find("whitespace"), std::string::npos); +} + +TEST(TokenizerPipelineManagerKeyTest, SameParamsProduceSameKey) { + FtsIndexParams params1; + params1.tokenizer_name = "whitespace"; + params1.extra_params = R"({"dict_path":"/path/to/dict"})"; + + FtsIndexParams params2; + params2.tokenizer_name = "whitespace"; + params2.extra_params = R"({"dict_path":"/path/to/dict"})"; + + std::string key1 = TokenizerPipelineManager::make_key(params1); + std::string key2 = TokenizerPipelineManager::make_key(params2); + EXPECT_EQ(key1, key2); +} + +TEST(TokenizerPipelineManagerKeyTest, DifferentTokenizersDifferentKeys) { + FtsIndexParams params1 = make_params("whitespace"); + FtsIndexParams params2 = make_params("jieba"); + std::string key1 = TokenizerPipelineManager::make_key(params1); + std::string key2 = TokenizerPipelineManager::make_key(params2); + EXPECT_NE(key1, key2); +} + +TEST(TokenizerPipelineManagerKeyTest, FilterNamesAffectKey) { + FtsIndexParams params1 = make_params("whitespace"); + params1.filters.clear(); + + FtsIndexParams params2 = make_params("whitespace"); + params2.filters = {"lowercase"}; + + std::string key1 = TokenizerPipelineManager::make_key(params1); + std::string key2 = TokenizerPipelineManager::make_key(params2); + EXPECT_NE(key1, key2); +} + +// ============================================================ +// acquire / release tests +// ============================================================ + +class TokenizerPipelineManagerTest : public ::testing::Test { + protected: + void SetUp() override { + // Use whitespace tokenizer (always available, no dict needed) + params_ = make_params("whitespace"); + } + + void TearDown() override { + // Best-effort cleanup: release the params if it still exists + // (tests that fail mid-way may leave entries) + // We do this by calling release repeatedly; release on unknown key is a + // no-op + } + + FtsIndexParams params_; +}; + +TEST_F(TokenizerPipelineManagerTest, FirstAcquireCreatesPipeline) { + auto &mgr = TokenizerPipelineManager::Instance(); + auto pipeline = mgr.acquire(params_); + ASSERT_NE(pipeline, nullptr); + + // Cleanup + mgr.release(params_); +} + +TEST_F(TokenizerPipelineManagerTest, RepeatedAcquireReturnsSameInstance) { + auto &mgr = TokenizerPipelineManager::Instance(); + auto pipeline1 = mgr.acquire(params_); + auto pipeline2 = mgr.acquire(params_); + + ASSERT_NE(pipeline1, nullptr); + ASSERT_NE(pipeline2, nullptr); + // Both should point to the exact same underlying object + EXPECT_EQ(pipeline1.get(), pipeline2.get()); + + // Cleanup: two acquires → two releases + mgr.release(params_); + mgr.release(params_); +} + +TEST_F(TokenizerPipelineManagerTest, ReleaseDecrementsRefCount) { + auto &mgr = TokenizerPipelineManager::Instance(); + auto pipeline1 = mgr.acquire(params_); + auto pipeline2 = mgr.acquire(params_); + ASSERT_NE(pipeline1, nullptr); + + // Release one reference; pipeline should still be alive (ref_count = 1) + mgr.release(params_); + + // Acquire again — should still return the same instance (not recreated) + auto pipeline3 = mgr.acquire(params_); + ASSERT_NE(pipeline3, nullptr); + EXPECT_EQ(pipeline1.get(), pipeline3.get()); + + // Cleanup: we now have ref_count = 2 (pipeline2 + pipeline3) + mgr.release(params_); + mgr.release(params_); +} + +TEST_F(TokenizerPipelineManagerTest, RefCountZeroDestroysEntry) { + auto &mgr = TokenizerPipelineManager::Instance(); + + auto pipeline1 = mgr.acquire(params_); + ASSERT_NE(pipeline1, nullptr); + void *raw_ptr = pipeline1.get(); + + // Release the only reference → entry should be removed + mgr.release(params_); + + // Acquire again → a new pipeline should be created (possibly different + // address) + auto pipeline2 = mgr.acquire(params_); + ASSERT_NE(pipeline2, nullptr); + // The old shared_ptr (pipeline1) still holds the object alive, so raw_ptr + // is still valid, but the manager has created a fresh entry. + // We can't guarantee same/different address, but we can verify it works. + (void)raw_ptr; + + // Cleanup + mgr.release(params_); +} + +TEST_F(TokenizerPipelineManagerTest, ReleaseUnknownKeyIsNoOp) { + auto &mgr = TokenizerPipelineManager::Instance(); + // Should not crash or assert + FtsIndexParams unknown_params; + unknown_params.tokenizer_name = "nonexistent_tokenizer_name"; + EXPECT_NO_THROW(mgr.release(unknown_params)); +} + +TEST_F(TokenizerPipelineManagerTest, DifferentConfigsDifferentPipelines) { + auto &mgr = TokenizerPipelineManager::Instance(); + + FtsIndexParams params_ws = make_params("whitespace"); + + // scws tokenizer will fail to create (no dict), but whitespace should succeed + auto pipeline_ws = mgr.acquire(params_ws); + ASSERT_NE(pipeline_ws, nullptr); + + // Cleanup + mgr.release(params_ws); +} + +// ============================================================ +// Concurrent safety tests +// ============================================================ + +TEST_F(TokenizerPipelineManagerTest, ConcurrentAcquireSameKey) { + auto &mgr = TokenizerPipelineManager::Instance(); + constexpr int kThreads = 8; + constexpr int kAcquiresPerThread = 10; + + std::vector results(kThreads * kAcquiresPerThread); + std::vector threads; + std::atomic success_count{0}; + + for (int t = 0; t < kThreads; ++t) { + threads.emplace_back([&, t]() { + for (int i = 0; i < kAcquiresPerThread; ++i) { + auto pipeline = mgr.acquire(params_); + if (pipeline) { + results[t * kAcquiresPerThread + i] = pipeline; + success_count.fetch_add(1); + } + } + }); + } + + for (auto &th : threads) { + th.join(); + } + + // All acquires should succeed + EXPECT_EQ(success_count.load(), kThreads * kAcquiresPerThread); + + // All non-null results should point to the same underlying pipeline + void *expected_ptr = nullptr; + for (const auto &p : results) { + if (p) { + if (expected_ptr == nullptr) { + expected_ptr = p.get(); + } else { + EXPECT_EQ(p.get(), expected_ptr); + } + } + } + + // Cleanup: release all acquired references + for (int i = 0; i < kThreads * kAcquiresPerThread; ++i) { + mgr.release(params_); + } +} + +TEST_F(TokenizerPipelineManagerTest, ConcurrentAcquireAndRelease) { + auto &mgr = TokenizerPipelineManager::Instance(); + constexpr int kThreads = 4; + constexpr int kIterations = 20; + std::atomic errors{0}; + + std::vector threads; + for (int t = 0; t < kThreads; ++t) { + threads.emplace_back([&]() { + for (int i = 0; i < kIterations; ++i) { + auto pipeline = mgr.acquire(params_); + if (!pipeline) { + errors.fetch_add(1); + continue; + } + // Hold briefly then release + mgr.release(params_); + } + }); + } + + for (auto &th : threads) { + th.join(); + } + + EXPECT_EQ(errors.load(), 0); + // After all threads finish, ref_count should be 0 (all released) + // Verify by acquiring once more — should succeed + auto pipeline = mgr.acquire(params_); + EXPECT_NE(pipeline, nullptr); + mgr.release(params_); +} diff --git a/tests/db/index/common/doc_test.cc b/tests/db/index/common/doc_test.cc index 543141169..168024556 100644 --- a/tests/db/index/common/doc_test.cc +++ b/tests/db/index/common/doc_test.cc @@ -18,6 +18,7 @@ #include #include #include "utils/utils.h" +#include "zvec/db/index_params.h" #include "zvec/db/status.h" #include "zvec/db/type.h" @@ -823,8 +824,7 @@ TEST_F(DocDetailedTest, ValidateAndSanitization) { auto schema = test::TestHelper::CreateNormalSchema(false); std::vector invalid_names = { // Too long (>64) - std::string(65, 'a'), - std::string(64, 'a') + "_", + std::string(65, 'a'), std::string(64, 'a') + "_", // Illegal characters "a b", // space @@ -1409,6 +1409,55 @@ TEST(VectorQuery, ValidateAndSanitize) { s = query.validate_and_sanitize(&schema); EXPECT_TRUE(s.ok()); } + + // fts_ and vector fields are mutually exclusive + { + auto fts_params = std::make_shared(); + FieldSchema fts_schema("content", DataType::STRING, false, fts_params); + + VectorQuery query; + query.field_name_ = "embedding"; + query.topk_ = 10; + std::vector query_vector(128, 1.0f); + query.query_vector_ = + std::string(reinterpret_cast(query_vector.data()), + query_vector.size() * sizeof(float)); + Fts fts_hello; + fts_hello.query_string_ = "hello"; + query.fts_ = fts_hello; + + // Should fail: both vector and fts_ set + auto s = query.validate_and_sanitize(&fts_schema); + EXPECT_FALSE(s.ok()); + EXPECT_EQ(s.code(), StatusCode::INVALID_ARGUMENT); + + // Clear vector, should pass with FTS schema + query.query_vector_.clear(); + s = query.validate_and_sanitize(&fts_schema); + EXPECT_TRUE(s.ok()); + + // FTS query with proper FTS field schema -> OK + VectorQuery fts_only; + fts_only.field_name_ = "content"; + fts_only.topk_ = 10; + Fts fts_test; + fts_test.query_string_ = "test"; + fts_only.fts_ = fts_test; + s = fts_only.validate_and_sanitize(&fts_schema); + EXPECT_TRUE(s.ok()); + + // FTS query with nullptr schema -> fail (field not found) + s = fts_only.validate_and_sanitize(nullptr); + EXPECT_FALSE(s.ok()); + EXPECT_EQ(s.code(), StatusCode::INVALID_ARGUMENT); + + // FTS query with vector field schema -> fail (type mismatch) + FieldSchema vec_schema("embedding", DataType::VECTOR_FP32, 128, false, + std::make_shared(MetricType::L2)); + s = fts_only.validate_and_sanitize(&vec_schema); + EXPECT_FALSE(s.ok()); + EXPECT_EQ(s.code(), StatusCode::INVALID_ARGUMENT); + } } // Test null value diff --git a/tests/db/sqlengine/CMakeLists.txt b/tests/db/sqlengine/CMakeLists.txt index 7922bbf6b..8b046eeb0 100644 --- a/tests/db/sqlengine/CMakeLists.txt +++ b/tests/db/sqlengine/CMakeLists.txt @@ -25,6 +25,7 @@ foreach(CC_SRCS ${ALL_TEST_SRCS}) LIBS zvec_common zvec_proto zvec_sqlengine + zvec_db zvec_ailego core_metric core_utility diff --git a/tests/db/sqlengine/fts_multi_segment_test.cc b/tests/db/sqlengine/fts_multi_segment_test.cc new file mode 100644 index 000000000..5a130eb61 --- /dev/null +++ b/tests/db/sqlengine/fts_multi_segment_test.cc @@ -0,0 +1,233 @@ +// Copyright 2025-present the zvec project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License + +#include +#include +#include +#include +#include "db/common/file_helper.h" +#include "db/index/common/version_manager.h" +#include "db/index/segment/segment.h" +#include "db/sqlengine/sqlengine.h" +#include "zvec/db/doc.h" +#include "zvec/db/index_params.h" +#include "zvec/db/schema.h" +#include "zvec/db/type.h" + +namespace zvec::sqlengine { + +// Multi-segment FTS recall regression: +// +// The planner's SegmentNode drains per-segment readers in LIFO order, so the +// per-segment BM25 ordering is *not* preserved across the merged stream. The +// planner must therefore add a global order_by on the score column for FTS, +// mirroring what it already does for vector queries. +// +// To make the regression observable we engineer the two segments so that +// * segments_[0] (read LAST) holds the globally highest-scoring doc, and +// * segments_[1] (read FIRST) holds many low-scoring docs. +// +// Per-segment BM25 stats (rare term -> high IDF in segments_[0], common term +// -> low IDF in segments_[1]) guarantee s0_0 outranks every doc in +// segments_[1]. Without the global sort the first doc in the merged stream is +// the much lower-scoring s1_*, which breaks both the descending invariant and +// topk truncation. + +class FtsMultiSegmentTest : public ::testing::Test { + protected: + static void SetUpTestSuite() { + FileHelper::RemoveDirectory(root_path_); + FileHelper::CreateDirectory(root_path_); + + build_schema(); + + // segments_[0]: only one doc contains "apple" but with very high TF and + // very low df (rare term) -> high BM25. + auto seg0 = create_segment(root_path_ + "/seg0", "fts_ms_seg0"); + ASSERT_NE(seg0, nullptr); + insert_docs(seg0, /*pk_prefix=*/"s0_", + { + {"apple apple apple apple apple"}, // doc 0: TF=5, df=1 + {"banana"}, + {"cherry"}, + {"date"}, + {"elderberry"}, + }); + + // segments_[1]: all docs contain "apple" (df=N) -> very low IDF -> low + // BM25 across the board. + auto seg1 = create_segment(root_path_ + "/seg1", "fts_ms_seg1"); + ASSERT_NE(seg1, nullptr); + insert_docs(seg1, /*pk_prefix=*/"s1_", + { + {"apple banana"}, + {"apple cherry"}, + {"apple date"}, + {"apple elderberry"}, + }); + + segments_.push_back(seg0); + segments_.push_back(seg1); + + engine_ = SQLEngine::create(std::make_shared()); + } + + static void TearDownTestSuite() { + segments_.clear(); + engine_.reset(); + schema_.reset(); + FileHelper::RemoveDirectory(root_path_); + } + + Result fts_search(const std::string &query_string, + int topk = 10) { + VectorQuery vq; + vq.topk_ = topk; + vq.field_name_ = "content"; + Fts fts; + fts.query_string_ = query_string; + vq.fts_ = fts; + return engine_->execute(schema_, vq, segments_); + } + + private: + static void build_schema() { + auto fts_params = std::make_shared( + "whitespace", std::vector{"lowercase"}, ""); + schema_ = std::make_shared( + "fts_multi_segment_test", + std::vector{ + std::make_shared("content", DataType::STRING, false, + fts_params), + // Dummy vector field keeps the schema parity with the single- + // segment FTS fixture so the analyzer/planner paths behave the + // same. + std::make_shared( + "vec", DataType::VECTOR_FP32, 4, false, + std::make_shared(MetricType::L2)), + }); + } + + static Segment::Ptr create_segment(const std::string &seg_path, + const std::string &name) { + FileHelper::CreateDirectory(seg_path); + + auto segment_meta = std::make_shared(); + segment_meta->set_id(0); + + auto id_map = IDMap::CreateAndOpen(name, seg_path + "/id_map", true, false); + auto delete_store = std::make_shared(name); + + Version v1; + v1.set_schema(*schema_); + std::string v_path = seg_path + "/manifest"; + FileHelper::CreateDirectory(v_path); + auto vm = VersionManager::Create(v_path, v1); + if (!vm.has_value()) { + return nullptr; + } + + BlockMeta mem_block; + mem_block.id_ = 0; + mem_block.type_ = BlockType::SCALAR; + mem_block.min_doc_id_ = 0; + mem_block.max_doc_id_ = 0; + mem_block.doc_count_ = 0; + segment_meta->set_writing_forward_block(mem_block); + + SegmentOptions options; + options.read_only_ = false; + options.enable_mmap_ = true; + options.max_buffer_size_ = 256 * 1024; + + auto result = Segment::CreateAndOpen(seg_path, *schema_, 0, 0, id_map, + delete_store, vm.value(), options); + if (!result) { + return nullptr; + } + return result.value(); + } + + struct Entry { + std::string content; + }; + + static void insert_docs(const Segment::Ptr &segment, + const std::string &pk_prefix, + const std::vector &entries) { + for (size_t i = 0; i < entries.size(); ++i) { + Doc doc; + doc.set_pk(pk_prefix + std::to_string(i)); + doc.set_doc_id(i); + doc.set("content", entries[i].content); + auto status = segment->Insert(doc); + ASSERT_TRUE(status.ok()) + << pk_prefix << i << " insert failed: " << status.c_str(); + } + } + + protected: + static inline std::string root_path_ = "./fts_multi_segment_test_collection"; + static inline CollectionSchema::Ptr schema_; + static inline std::vector segments_; + static inline SQLEngine::Ptr engine_; +}; + +// The merged stream from all segments must be strictly non-increasing in +// score. Without the global order_by, segments_[1]'s low-scoring docs would +// appear before segments_[0]'s much higher-scoring s0_0, violating BM25 rank. +TEST_F(FtsMultiSegmentTest, ScoreDescendingAcrossSegments) { + auto result = fts_search("apple"); + ASSERT_TRUE(result.has_value()) << result.error().c_str(); + + // s0_0 + s1_0..s1_3 = 5 matches. + ASSERT_EQ(result->size(), 5u); + + // s0_0 (TF=5, rare term in seg0) dominates the 4 low-IDF s1_* docs. + EXPECT_EQ((*result)[0]->pk(), "s0_0"); + EXPECT_GT((*result)[0]->score(), (*result)[1]->score()); + + for (size_t i = 0; i + 1 < result->size(); ++i) { + EXPECT_GE((*result)[i]->score(), (*result)[i + 1]->score()) + << "score not descending at rank " << i << ": " << (*result)[i]->pk() + << "=" << (*result)[i]->score() << " vs " << (*result)[i + 1]->pk() + << "=" << (*result)[i + 1]->score(); + } +} + +// topk must cut against the globally-sorted stream. Without the fix the +// first batch surfaced from SegmentNode comes from segments_[1] (LIFO read), +// so topk=1 would silently drop the highest-scoring s0_0. +TEST_F(FtsMultiSegmentTest, TopkPicksGloballyHighestScore) { + auto result = fts_search("apple", /*topk=*/1); + ASSERT_TRUE(result.has_value()) << result.error().c_str(); + ASSERT_EQ(result->size(), 1u); + EXPECT_EQ((*result)[0]->pk(), "s0_0"); +} + +// Sanity: a cross-segment OR query still returns the union of matches and +// stays descending across the segment boundary. +TEST_F(FtsMultiSegmentTest, CrossSegmentUnionDescending) { + // apple: 5 docs (s0_0, s1_0..s1_3). banana: s0_1 (seg0), s1_0 (seg1). + // OR-union: {s0_0, s0_1, s1_0, s1_1, s1_2, s1_3} = 6 docs. + auto result = fts_search("apple banana"); + ASSERT_TRUE(result.has_value()) << result.error().c_str(); + ASSERT_EQ(result->size(), 6u); + for (size_t i = 0; i + 1 < result->size(); ++i) { + EXPECT_GE((*result)[i]->score(), (*result)[i + 1]->score()) + << "score not descending at rank " << i; + } +} + +} // namespace zvec::sqlengine diff --git a/tests/db/sqlengine/fts_parser_test.cc b/tests/db/sqlengine/fts_parser_test.cc new file mode 100644 index 000000000..2d77e7b1d --- /dev/null +++ b/tests/db/sqlengine/fts_parser_test.cc @@ -0,0 +1,790 @@ +// Copyright 2025-present the zvec project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include "db/index/column/fts_column/fts_query_ast.h" +#include "db/index/column/fts_column/fts_types.h" +#include "db/index/column/fts_column/parser/fts_query_parser.h" +#include "db/index/column/fts_column/tokenizer/tokenizer_factory.h" + +namespace zvec::fts { + +// ============================================================ +// Test fixture +// ============================================================ + +class FtsParserTest : public ::testing::Test { + protected: + void SetUp() override { + // Standard tokenizer + lowercase filter: ASCII tests behave the same as + // the previous whitespace split (alnum runs become tokens, delimiters + // get dropped) while CJK tests can exercise the per-character tokens + // standard produces from non-alnum bytes. + FtsIndexParams params; + params.tokenizer_name = "standard"; + params.filters = {"lowercase"}; + pipeline_ = TokenizerFactory::create(params); + ASSERT_NE(pipeline_, nullptr); + } + + FtsAstNodePtr parse(const std::string &query) { + return parser_.parse(query, pipeline_); + } + + // Overload for tests that need to specify the default operator explicitly. + FtsAstNodePtr parse(const std::string &query, FtsDefaultOperator default_op) { + return parser_.parse(query, pipeline_, default_op); + } + + const std::string &err_msg() { + return parser_.err_msg(); + } + + // Helpers for type-safe downcasting + static const TermNode &as_term(const FtsAstNode &node) { + EXPECT_EQ(node.type(), FtsNodeType::TERM); + return static_cast(node); + } + + static const PhraseNode &as_phrase(const FtsAstNode &node) { + EXPECT_EQ(node.type(), FtsNodeType::PHRASE); + return static_cast(node); + } + + static const AndNode &as_and(const FtsAstNode &node) { + EXPECT_EQ(node.type(), FtsNodeType::AND); + return static_cast(node); + } + + static const OrNode &as_or(const FtsAstNode &node) { + EXPECT_EQ(node.type(), FtsNodeType::OR); + return static_cast(node); + } + + private: + FtsQueryParser parser_; + TokenizerPipelinePtr pipeline_; +}; + +// ============================================================ +// Single term +// ============================================================ + +TEST_F(FtsParserTest, SingleTerm) { + auto ast = parse("vector"); + ASSERT_NE(ast, nullptr); + ASSERT_EQ(ast->type(), FtsNodeType::TERM); + const auto &term = as_term(*ast); + EXPECT_EQ(term.term, "vector"); + EXPECT_FALSE(term.must); + EXPECT_FALSE(term.must_not); +} + +TEST_F(FtsParserTest, SingleTermNumeric) { + auto ast = parse("2024"); + ASSERT_NE(ast, nullptr); + ASSERT_EQ(ast->type(), FtsNodeType::TERM); + EXPECT_EQ(as_term(*ast).term, "2024"); +} + +TEST_F(FtsParserTest, SingleTermWithHyphen) { + // The lexer's REGULAR_ID rule keeps hyphenated text as one token, but the + // standard tokenizer on the parser side splits non-alphanumerics. With the + // default OR operator the term decomposes into Or[full, text] so query + // segmentation matches the index segmentation. + auto ast = parse("full-text"); + ASSERT_NE(ast, nullptr); + ASSERT_EQ(ast->type(), FtsNodeType::OR); + const auto &or_node = as_or(*ast); + ASSERT_EQ(or_node.children.size(), 2u); + EXPECT_EQ(as_term(*or_node.children[0]).term, "full"); + EXPECT_EQ(as_term(*or_node.children[1]).term, "text"); +} + +// ============================================================ +// Must (+) and must_not (-/NOT) modifiers +// ============================================================ + +TEST_F(FtsParserTest, MustModifier) { + auto ast = parse("+vector"); + ASSERT_NE(ast, nullptr); + const auto &term = as_term(*ast); + EXPECT_EQ(term.term, "vector"); + EXPECT_TRUE(term.must); + EXPECT_FALSE(term.must_not); +} + +TEST_F(FtsParserTest, MustNotModifierMinus) { + // "-slow" is lexed as a single REGULAR_ID token (hyphen is part of the id). + // To express must_not, use a space: "- slow" -> MINUS_SIGN + REGULAR_ID. + auto ast = parse("- slow"); + ASSERT_NE(ast, nullptr); + const auto &term = as_term(*ast); + EXPECT_EQ(term.term, "slow"); + EXPECT_FALSE(term.must); + EXPECT_TRUE(term.must_not); +} + +TEST_F(FtsParserTest, MustNotModifierMinusNoSpace) { + // "-slow" without space: FtsLexer treats '-' as MINUS_SIGN modifier, + // so "-slow" is parsed as must_not:slow (same as "- slow"). + auto ast = parse("-slow"); + ASSERT_NE(ast, nullptr); + ASSERT_EQ(ast->type(), FtsNodeType::TERM); + EXPECT_EQ(as_term(*ast).term, "slow"); + EXPECT_TRUE(as_term(*ast).must_not); +} + +TEST_F(FtsParserTest, MustNotModifierNot) { + // NOT is now a strict binary operator (`a NOT b` <=> `a AND NOT b`). + // A leading `NOT a` is therefore a syntax error — there is no left-hand + // operand for NOT to subtract from. + auto ast = parse("NOT slow"); + EXPECT_EQ(ast, nullptr); + EXPECT_FALSE(err_msg().empty()); +} + +// ============================================================ +// Phrase query +// ============================================================ + +TEST_F(FtsParserTest, DoubleQuotedPhrase) { + auto ast = parse("\"exact phrase\""); + ASSERT_NE(ast, nullptr); + ASSERT_EQ(ast->type(), FtsNodeType::PHRASE); + const auto &phrase = as_phrase(*ast); + ASSERT_EQ(phrase.terms.size(), 2u); + EXPECT_EQ(phrase.terms[0], "exact"); + EXPECT_EQ(phrase.terms[1], "phrase"); + EXPECT_FALSE(phrase.must); + EXPECT_FALSE(phrase.must_not); +} + +TEST_F(FtsParserTest, SingleQuotedPhrase) { + // Single-quoted strings are not supported as phrase queries (no SQUOTA_STRING + // token). The lexer's TERM rule absorbs "'hello", "world", and "'" as + // individual term tokens, so the query parses as an implicit OR of terms. + auto ast = parse("'hello world'"); + ASSERT_NE(ast, nullptr); + ASSERT_EQ(ast->type(), FtsNodeType::OR); +} + +TEST_F(FtsParserTest, PhraseWithMustModifier) { + auto ast = parse("+\"exact phrase\""); + ASSERT_NE(ast, nullptr); + const auto &phrase = as_phrase(*ast); + EXPECT_TRUE(phrase.must); + EXPECT_FALSE(phrase.must_not); +} + +TEST_F(FtsParserTest, PhraseWithMustNotModifier) { + auto ast = parse("-\"bad phrase\""); + ASSERT_NE(ast, nullptr); + const auto &phrase = as_phrase(*ast); + EXPECT_FALSE(phrase.must); + EXPECT_TRUE(phrase.must_not); +} + +TEST_F(FtsParserTest, PhraseWithThreeWords) { + auto ast = parse("\"one two three\""); + ASSERT_NE(ast, nullptr); + const auto &phrase = as_phrase(*ast); + ASSERT_EQ(phrase.terms.size(), 3u); + EXPECT_EQ(phrase.terms[0], "one"); + EXPECT_EQ(phrase.terms[1], "two"); + EXPECT_EQ(phrase.terms[2], "three"); +} + +// ============================================================ +// Explicit OR +// ============================================================ + +TEST_F(FtsParserTest, ExplicitOr) { + auto ast = parse("cat OR dog"); + ASSERT_NE(ast, nullptr); + ASSERT_EQ(ast->type(), FtsNodeType::OR); + const auto &or_node = as_or(*ast); + ASSERT_EQ(or_node.children.size(), 2u); + EXPECT_EQ(as_term(*or_node.children[0]).term, "cat"); + EXPECT_EQ(as_term(*or_node.children[1]).term, "dog"); +} + +TEST_F(FtsParserTest, MultipleOr) { + auto ast = parse("a OR b OR c"); + ASSERT_NE(ast, nullptr); + const auto &or_node = as_or(*ast); + ASSERT_EQ(or_node.children.size(), 3u); +} + +// ============================================================ +// Explicit AND +// ============================================================ + +TEST_F(FtsParserTest, ExplicitAnd) { + auto ast = parse("cat AND dog"); + ASSERT_NE(ast, nullptr); + ASSERT_EQ(ast->type(), FtsNodeType::AND); + const auto &and_node = as_and(*ast); + ASSERT_EQ(and_node.children.size(), 2u); + EXPECT_EQ(as_term(*and_node.children[0]).term, "cat"); + EXPECT_EQ(as_term(*and_node.children[1]).term, "dog"); +} + +TEST_F(FtsParserTest, MultipleAnd) { + auto ast = parse("a AND b AND c"); + ASSERT_NE(ast, nullptr); + const auto &and_node = as_and(*ast); + ASSERT_EQ(and_node.children.size(), 3u); +} + +// ============================================================ +// Operator precedence: AND binds tighter than OR +// ============================================================ + +TEST_F(FtsParserTest, AndBindsTighterThanOr) { + // "a OR b AND c" should parse as "a OR (b AND c)" + auto ast = parse("a OR b AND c"); + ASSERT_NE(ast, nullptr); + const auto &or_node = as_or(*ast); + ASSERT_EQ(or_node.children.size(), 2u); + + // Left child: term "a" + EXPECT_EQ(as_term(*or_node.children[0]).term, "a"); + + // Right child: AND(b, c) + const auto &and_node = as_and(*or_node.children[1]); + ASSERT_EQ(and_node.children.size(), 2u); + EXPECT_EQ(as_term(*and_node.children[0]).term, "b"); + EXPECT_EQ(as_term(*and_node.children[1]).term, "c"); +} + +// ============================================================ +// Implicit adjacency (seqExpr / default operator) +// ============================================================ + +TEST_F(FtsParserTest, ImplicitAdjacency) { + // Adjacent terms without explicit operator: "a b" -> seqExpr -> OR(a, b) + auto ast = parse("a b"); + ASSERT_NE(ast, nullptr); + ASSERT_EQ(ast->type(), FtsNodeType::OR); + const auto &or_node = as_or(*ast); + ASSERT_EQ(or_node.children.size(), 2u); + EXPECT_EQ(as_term(*or_node.children[0]).term, "a"); + EXPECT_EQ(as_term(*or_node.children[1]).term, "b"); +} + +TEST_F(FtsParserTest, ImplicitAdjacencyThreeTerms) { + auto ast = parse("a b c"); + ASSERT_NE(ast, nullptr); + const auto &or_node = as_or(*ast); + ASSERT_EQ(or_node.children.size(), 3u); +} + +TEST_F(FtsParserTest, ImplicitAdjacencyWithModifiers) { + // "+a - b" -> seqExpr -> OR(must:a, must_not:b) + // Note: "-b" (no space) is lexed as a single REGULAR_ID; use "- b" for + // must_not. + auto ast = parse("+a - b"); + ASSERT_NE(ast, nullptr); + const auto &or_node = as_or(*ast); + ASSERT_EQ(or_node.children.size(), 2u); + EXPECT_TRUE(as_term(*or_node.children[0]).must); + EXPECT_TRUE(as_term(*or_node.children[1]).must_not); +} + +// ============================================================ +// Parentheses grouping +// ============================================================ + +TEST_F(FtsParserTest, Parentheses) { + // "(a OR b) AND c" + auto ast = parse("(a OR b) AND c"); + ASSERT_NE(ast, nullptr); + const auto &and_node = as_and(*ast); + ASSERT_EQ(and_node.children.size(), 2u); + + // Left: OR(a, b) + const auto &or_node = as_or(*and_node.children[0]); + ASSERT_EQ(or_node.children.size(), 2u); + + // Right: term c + EXPECT_EQ(as_term(*and_node.children[1]).term, "c"); +} + +TEST_F(FtsParserTest, NestedParentheses) { + auto ast = parse("((a OR b) AND c) OR d"); + ASSERT_NE(ast, nullptr); + const auto &outer_or = as_or(*ast); + ASSERT_EQ(outer_or.children.size(), 2u); + EXPECT_EQ(as_term(*outer_or.children[1]).term, "d"); +} + +// ============================================================ +// Mixed complex queries +// ============================================================ + +TEST_F(FtsParserTest, MixedTermAndPhrase) { + // "+vector - slow \"exact phrase\"" + // Note: use "- slow" (with space) so MINUS_SIGN is a separate token. + auto ast = parse("+vector - slow \"exact phrase\""); + ASSERT_NE(ast, nullptr); + // Four adjacent items -> seqExpr -> OR(must:vector, must_not:slow, phrase) + // Actually: +vector and - slow and phrase are three unary nodes in seqExpr + const auto &or_node = as_or(*ast); + ASSERT_EQ(or_node.children.size(), 3u); + + EXPECT_TRUE(as_term(*or_node.children[0]).must); + EXPECT_EQ(as_term(*or_node.children[0]).term, "vector"); + + EXPECT_TRUE(as_term(*or_node.children[1]).must_not); + EXPECT_EQ(as_term(*or_node.children[1]).term, "slow"); + + EXPECT_EQ(or_node.children[2]->type(), FtsNodeType::PHRASE); +} + +TEST_F(FtsParserTest, AndWithPhrase) { + auto ast = parse("\"machine learning\" AND model"); + ASSERT_NE(ast, nullptr); + const auto &and_node = as_and(*ast); + ASSERT_EQ(and_node.children.size(), 2u); + EXPECT_EQ(and_node.children[0]->type(), FtsNodeType::PHRASE); + EXPECT_EQ(as_term(*and_node.children[1]).term, "model"); +} + +TEST_F(FtsParserTest, ComplexBooleanQuery) { + // "a AND b OR c AND d" -> (a AND b) OR (c AND d) + auto ast = parse("a AND b OR c AND d"); + ASSERT_NE(ast, nullptr); + const auto &or_node = as_or(*ast); + ASSERT_EQ(or_node.children.size(), 2u); + + const auto &left_and = as_and(*or_node.children[0]); + ASSERT_EQ(left_and.children.size(), 2u); + + const auto &right_and = as_and(*or_node.children[1]); + ASSERT_EQ(right_and.children.size(), 2u); +} + +// ============================================================ +// Single-child simplification (no unnecessary wrapping) +// ============================================================ + +TEST_F(FtsParserTest, SingleChildNotWrapped) { + // A single term should not be wrapped in an AndNode/OrNode + auto ast = parse("hello"); + ASSERT_NE(ast, nullptr); + EXPECT_EQ(ast->type(), FtsNodeType::TERM); +} + +TEST_F(FtsParserTest, SinglePhraseNotWrapped) { + auto ast = parse("\"hello world\""); + ASSERT_NE(ast, nullptr); + EXPECT_EQ(ast->type(), FtsNodeType::PHRASE); +} + +// ============================================================ +// Error cases +// ============================================================ + +TEST_F(FtsParserTest, EmptyQueryReturnsNull) { + auto ast = parse(""); + EXPECT_EQ(ast, nullptr); +} + +TEST_F(FtsParserTest, OnlyParenthesesReturnsNull) { + auto ast = parse("()"); + EXPECT_EQ(ast, nullptr); +} + +TEST_F(FtsParserTest, UnclosedPhraseReturnsNull) { + // An unclosed double-quote causes the DQUOTA_STRING rule to fail. The + // remaining characters are absorbed by the TERM catch-all rule, so the + // query parses as a single term rather than returning nullptr. + auto ast = parse("\"unclosed phrase"); + ASSERT_NE(ast, nullptr); +} + +TEST_F(FtsParserTest, UnclosedParenReturnsNull) { + auto ast = parse("(a OR b"); + EXPECT_EQ(ast, nullptr); +} + +// ============================================================ +// Empty-AST cases: grammar valid, analyzer drops every term → EmptyNode. +// ============================================================ + +TEST_F(FtsParserTest, PunctuationOnlyReturnsEmpty) { + auto ast = parse("!!!"); + ASSERT_NE(ast, nullptr); + EXPECT_EQ(ast->type(), FtsNodeType::EMPTY); + EXPECT_TRUE(err_msg().empty()); +} + +TEST_F(FtsParserTest, MultiplePunctuationTermsReturnsEmpty) { + auto ast = parse("!!! ??? ..."); + ASSERT_NE(ast, nullptr); + EXPECT_EQ(ast->type(), FtsNodeType::EMPTY); + EXPECT_TRUE(err_msg().empty()); +} + +// ============================================================ +// NOT as a binary AND-NOT operator +// ============================================================ + +TEST_F(FtsParserTest, NotAsBinaryAndNot) { + // `foo NOT bar` <=> `foo AND NOT bar` -> And[foo, bar(must_not)] + auto ast = parse("foo NOT bar"); + ASSERT_NE(ast, nullptr); + const auto &and_node = as_and(*ast); + ASSERT_EQ(and_node.children.size(), 2u); + + EXPECT_EQ(as_term(*and_node.children[0]).term, "foo"); + EXPECT_FALSE(and_node.children[0]->must_not); + + EXPECT_EQ(as_term(*and_node.children[1]).term, "bar"); + EXPECT_TRUE(and_node.children[1]->must_not); +} + +TEST_F(FtsParserTest, AndAndNot) { + // `a AND NOT b` -> And[a, b(must_not)] + auto ast = parse("a AND NOT b"); + ASSERT_NE(ast, nullptr); + const auto &and_node = as_and(*ast); + ASSERT_EQ(and_node.children.size(), 2u); + EXPECT_EQ(as_term(*and_node.children[0]).term, "a"); + EXPECT_FALSE(and_node.children[0]->must_not); + EXPECT_EQ(as_term(*and_node.children[1]).term, "b"); + EXPECT_TRUE(and_node.children[1]->must_not); +} + +TEST_F(FtsParserTest, OrThenNot) { + // Precedence check: NOT shares AND's precedence (higher than OR). + // `a OR b NOT c` -> Or[a, And[b, c(must_not)]] + auto ast = parse("a OR b NOT c"); + ASSERT_NE(ast, nullptr); + const auto &or_node = as_or(*ast); + ASSERT_EQ(or_node.children.size(), 2u); + + EXPECT_EQ(as_term(*or_node.children[0]).term, "a"); + + const auto &right_and = as_and(*or_node.children[1]); + ASSERT_EQ(right_and.children.size(), 2u); + EXPECT_EQ(as_term(*right_and.children[0]).term, "b"); + EXPECT_FALSE(right_and.children[0]->must_not); + EXPECT_EQ(as_term(*right_and.children[1]).term, "c"); + EXPECT_TRUE(right_and.children[1]->must_not); +} + +TEST_F(FtsParserTest, NotWithGroup) { + // `a NOT (b OR c)` -> And[a, Or[b, c](must_not)] + auto ast = parse("a NOT (b OR c)"); + ASSERT_NE(ast, nullptr); + const auto &and_node = as_and(*ast); + ASSERT_EQ(and_node.children.size(), 2u); + + EXPECT_EQ(as_term(*and_node.children[0]).term, "a"); + EXPECT_FALSE(and_node.children[0]->must_not); + + ASSERT_EQ(and_node.children[1]->type(), FtsNodeType::OR); + EXPECT_TRUE(and_node.children[1]->must_not); + const auto &grouped_or = as_or(*and_node.children[1]); + ASSERT_EQ(grouped_or.children.size(), 2u); + EXPECT_EQ(as_term(*grouped_or.children[0]).term, "b"); + EXPECT_EQ(as_term(*grouped_or.children[1]).term, "c"); +} + +TEST_F(FtsParserTest, LeadingNotIsError) { + // Leading NOT has no left-hand operand and must fail to parse. + auto ast = parse("NOT a"); + EXPECT_EQ(ast, nullptr); + EXPECT_FALSE(err_msg().empty()); +} + +TEST_F(FtsParserTest, MultipleNotsAndAnds) { + // `a AND b NOT c AND d NOT e` -> And[a, b, c(must_not), d, e(must_not)] + auto ast = parse("a AND b NOT c AND d NOT e"); + ASSERT_NE(ast, nullptr); + const auto &and_node = as_and(*ast); + ASSERT_EQ(and_node.children.size(), 5u); + + EXPECT_EQ(as_term(*and_node.children[0]).term, "a"); + EXPECT_FALSE(and_node.children[0]->must_not); + + EXPECT_EQ(as_term(*and_node.children[1]).term, "b"); + EXPECT_FALSE(and_node.children[1]->must_not); + + EXPECT_EQ(as_term(*and_node.children[2]).term, "c"); + EXPECT_TRUE(and_node.children[2]->must_not); + + EXPECT_EQ(as_term(*and_node.children[3]).term, "d"); + EXPECT_FALSE(and_node.children[3]->must_not); + + EXPECT_EQ(as_term(*and_node.children[4]).term, "e"); + EXPECT_TRUE(and_node.children[4]->must_not); +} + +// ============================================================ +// +/- modifiers on parenthesised sub-expressions +// ============================================================ + +TEST_F(FtsParserTest, MustOnGroup) { + // `+(a OR b)` -> Or[a, b]{must=true} + auto ast = parse("+(a OR b)"); + ASSERT_NE(ast, nullptr); + ASSERT_EQ(ast->type(), FtsNodeType::OR); + EXPECT_TRUE(ast->must); + EXPECT_FALSE(ast->must_not); + const auto &or_node = as_or(*ast); + ASSERT_EQ(or_node.children.size(), 2u); + EXPECT_EQ(as_term(*or_node.children[0]).term, "a"); + EXPECT_EQ(as_term(*or_node.children[1]).term, "b"); +} + +TEST_F(FtsParserTest, MustNotOnGroup) { + // `-(a AND b)` -> And[a, b]{must_not=true} + auto ast = parse("-(a AND b)"); + ASSERT_NE(ast, nullptr); + ASSERT_EQ(ast->type(), FtsNodeType::AND); + EXPECT_FALSE(ast->must); + EXPECT_TRUE(ast->must_not); + const auto &and_node = as_and(*ast); + ASSERT_EQ(and_node.children.size(), 2u); + EXPECT_EQ(as_term(*and_node.children[0]).term, "a"); + EXPECT_EQ(as_term(*and_node.children[1]).term, "b"); +} + +TEST_F(FtsParserTest, MustGroupAndOther) { + // `+(a OR b) c` -> implicit-OR collapses three siblings into a single + // OrNode: Or[Or[a, b]{must=true}, c] + // (the inner OR keeps its must flag; implicit adjacency is still OR.) + auto ast = parse("+(a OR b) c"); + ASSERT_NE(ast, nullptr); + ASSERT_EQ(ast->type(), FtsNodeType::OR); + const auto &outer_or = as_or(*ast); + ASSERT_EQ(outer_or.children.size(), 2u); + + ASSERT_EQ(outer_or.children[0]->type(), FtsNodeType::OR); + EXPECT_TRUE(outer_or.children[0]->must); + const auto &inner_or = as_or(*outer_or.children[0]); + ASSERT_EQ(inner_or.children.size(), 2u); + EXPECT_EQ(as_term(*inner_or.children[0]).term, "a"); + EXPECT_EQ(as_term(*inner_or.children[1]).term, "b"); + + EXPECT_EQ(as_term(*outer_or.children[1]).term, "c"); +} + +TEST_F(FtsParserTest, NestedGroupModifier) { + // `+((a AND b) OR c)` -> the must flag attaches to the outermost OrNode. + auto ast = parse("+((a AND b) OR c)"); + ASSERT_NE(ast, nullptr); + ASSERT_EQ(ast->type(), FtsNodeType::OR); + EXPECT_TRUE(ast->must); + const auto &or_node = as_or(*ast); + ASSERT_EQ(or_node.children.size(), 2u); + + ASSERT_EQ(or_node.children[0]->type(), FtsNodeType::AND); + EXPECT_FALSE(or_node.children[0]->must); // inner AND not affected + const auto &inner_and = as_and(*or_node.children[0]); + ASSERT_EQ(inner_and.children.size(), 2u); + EXPECT_EQ(as_term(*inner_and.children[0]).term, "a"); + EXPECT_EQ(as_term(*inner_and.children[1]).term, "b"); + + EXPECT_EQ(as_term(*or_node.children[1]).term, "c"); +} + +// ============================================================ +// Default operator (FtsDefaultOperator::OR / AND) +// Only adjacent bare terms (no explicit operator) are affected; explicit +// AND / OR / + / - usages keep their original semantics. +// ============================================================ + +TEST_F(FtsParserTest, DefaultOperatorOr_AdjacentBareTerms) { + // Backward-compat: omitting default_op or passing OR yields the original + // implicit-OR behaviour for adjacent bare terms. + auto ast = parse("vector database", FtsDefaultOperator::OR); + ASSERT_NE(ast, nullptr); + ASSERT_EQ(ast->type(), FtsNodeType::OR); + const auto &or_node = as_or(*ast); + ASSERT_EQ(or_node.children.size(), 2u); + EXPECT_EQ(as_term(*or_node.children[0]).term, "vector"); + EXPECT_EQ(as_term(*or_node.children[1]).term, "database"); +} + +TEST_F(FtsParserTest, DefaultOperatorAnd_AdjacentBareTerms) { + // With AND default, two adjacent bare terms collapse into an AndNode. + auto ast = parse("vector database", FtsDefaultOperator::AND); + ASSERT_NE(ast, nullptr); + ASSERT_EQ(ast->type(), FtsNodeType::AND); + const auto &and_node = as_and(*ast); + ASSERT_EQ(and_node.children.size(), 2u); + EXPECT_EQ(as_term(*and_node.children[0]).term, "vector"); + EXPECT_EQ(as_term(*and_node.children[1]).term, "database"); +} + +TEST_F(FtsParserTest, DefaultOperatorAnd_SingleTermUnchanged) { + // A single term should not be wrapped in an AndNode. + auto ast = parse("vector", FtsDefaultOperator::AND); + ASSERT_NE(ast, nullptr); + ASSERT_EQ(ast->type(), FtsNodeType::TERM); + EXPECT_EQ(as_term(*ast).term, "vector"); +} + +TEST_F(FtsParserTest, DefaultOperatorAnd_PropagatesIntoParens) { + // Parenthesised sub-expressions inherit the same default operator. + // `(a b) c` with AND default -> And[And[a, b], c]. + auto ast = parse("(a b) c", FtsDefaultOperator::AND); + ASSERT_NE(ast, nullptr); + ASSERT_EQ(ast->type(), FtsNodeType::AND); + const auto &outer_and = as_and(*ast); + ASSERT_EQ(outer_and.children.size(), 2u); + + ASSERT_EQ(outer_and.children[0]->type(), FtsNodeType::AND); + const auto &inner_and = as_and(*outer_and.children[0]); + ASSERT_EQ(inner_and.children.size(), 2u); + EXPECT_EQ(as_term(*inner_and.children[0]).term, "a"); + EXPECT_EQ(as_term(*inner_and.children[1]).term, "b"); + + EXPECT_EQ(as_term(*outer_and.children[1]).term, "c"); +} + +TEST_F(FtsParserTest, DefaultOperatorAnd_DoesNotOverrideExplicitOr) { + // Explicit OR has higher-level structure; default_op only changes the + // implicit adjacency inside each seqExpr. + // `a OR b c` with AND default -> Or[a, And[b, c]]. + auto ast = parse("a OR b c", FtsDefaultOperator::AND); + ASSERT_NE(ast, nullptr); + ASSERT_EQ(ast->type(), FtsNodeType::OR); + const auto &or_node = as_or(*ast); + ASSERT_EQ(or_node.children.size(), 2u); + + EXPECT_EQ(as_term(*or_node.children[0]).term, "a"); + + ASSERT_EQ(or_node.children[1]->type(), FtsNodeType::AND); + const auto &inner_and = as_and(*or_node.children[1]); + ASSERT_EQ(inner_and.children.size(), 2u); + EXPECT_EQ(as_term(*inner_and.children[0]).term, "b"); + EXPECT_EQ(as_term(*inner_and.children[1]).term, "c"); +} + +TEST_F(FtsParserTest, DefaultOperatorOr_DoesNotOverrideExplicitAnd) { + // Grammar: andExpr = seqExpr ((AND|NOT) seqExpr)* + // `a AND b c` parses as seqExpr("a") AND seqExpr("b c"). + // With OR default, seqExpr("b c") -> Or[b, c]. + // Result: And[a, Or[b, c]]. + auto ast = parse("a AND b c", FtsDefaultOperator::OR); + ASSERT_NE(ast, nullptr); + ASSERT_EQ(ast->type(), FtsNodeType::AND); + const auto &and_node = as_and(*ast); + ASSERT_EQ(and_node.children.size(), 2u); + + EXPECT_EQ(as_term(*and_node.children[0]).term, "a"); + + ASSERT_EQ(and_node.children[1]->type(), FtsNodeType::OR); + const auto &inner_or = as_or(*and_node.children[1]); + ASSERT_EQ(inner_or.children.size(), 2u); + EXPECT_EQ(as_term(*inner_or.children[0]).term, "b"); + EXPECT_EQ(as_term(*inner_or.children[1]).term, "c"); +} + +TEST_F(FtsParserTest, DefaultOperatorAnd_PreservesPlusMinusModifiers) { + // `+a b -c` with AND default -> And[a{must}, b, c{must_not}]. + // Modifiers on individual terms are independent of default_op. + auto ast = parse("+a b -c", FtsDefaultOperator::AND); + ASSERT_NE(ast, nullptr); + ASSERT_EQ(ast->type(), FtsNodeType::AND); + const auto &and_node = as_and(*ast); + ASSERT_EQ(and_node.children.size(), 3u); + + const auto &t0 = as_term(*and_node.children[0]); + EXPECT_EQ(t0.term, "a"); + EXPECT_TRUE(t0.must); + EXPECT_FALSE(t0.must_not); + + const auto &t1 = as_term(*and_node.children[1]); + EXPECT_EQ(t1.term, "b"); + EXPECT_FALSE(t1.must); + EXPECT_FALSE(t1.must_not); + + const auto &t2 = as_term(*and_node.children[2]); + EXPECT_EQ(t2.term, "c"); + EXPECT_FALSE(t2.must); + EXPECT_TRUE(t2.must_not); +} + +// ============================================================ +// Pipeline-aware tokenization (phrase / bare term split through pipeline) +// ============================================================ + +TEST_F(FtsParserTest, MultiTokenBareTermAndDefaultGroupsAsAnd) { + // `full-text` lexes as one REGULAR_ID, but standard splits it into + // ["full", "text"]. With AND default operator the two tokens combine into + // an AndNode rather than the OR returned by the OR-default test above. + auto ast = parse("full-text", FtsDefaultOperator::AND); + ASSERT_NE(ast, nullptr); + ASSERT_EQ(ast->type(), FtsNodeType::AND); + const auto &and_node = as_and(*ast); + ASSERT_EQ(and_node.children.size(), 2u); + EXPECT_EQ(as_term(*and_node.children[0]).term, "full"); + EXPECT_EQ(as_term(*and_node.children[1]).term, "text"); +} + +TEST_F(FtsParserTest, MultiTokenBareTermPreservesMustModifier) { + // `+full-text` -> Or[full, text] with must=true on the composite root. + auto ast = parse("+full-text"); + ASSERT_NE(ast, nullptr); + ASSERT_EQ(ast->type(), FtsNodeType::OR); + EXPECT_TRUE(ast->must); + EXPECT_FALSE(ast->must_not); + const auto &or_node = as_or(*ast); + ASSERT_EQ(or_node.children.size(), 2u); + EXPECT_EQ(as_term(*or_node.children[0]).term, "full"); + EXPECT_EQ(as_term(*or_node.children[1]).term, "text"); +} + +TEST_F(FtsParserTest, PhraseTokensRunThroughPipeline) { + // The phrase body is tokenized exactly like document text. With the + // standard tokenizer, mixed delimiters between alnum runs collapse so + // "machine, learning!" becomes ["machine", "learning"]. + auto ast = parse("\"machine, learning!\""); + ASSERT_NE(ast, nullptr); + ASSERT_EQ(ast->type(), FtsNodeType::PHRASE); + const auto &phrase = as_phrase(*ast); + ASSERT_EQ(phrase.terms.size(), 2u); + EXPECT_EQ(phrase.terms[0], "machine"); + EXPECT_EQ(phrase.terms[1], "learning"); +} + +TEST_F(FtsParserTest, PhraseLowercaseFilterApplies) { + // The lowercase filter is part of the pipeline so phrase tokens come back + // lowercased even when the input mixed case. + auto ast = parse("\"Machine LEARNING\""); + ASSERT_NE(ast, nullptr); + ASSERT_EQ(ast->type(), FtsNodeType::PHRASE); + const auto &phrase = as_phrase(*ast); + ASSERT_EQ(phrase.terms.size(), 2u); + EXPECT_EQ(phrase.terms[0], "machine"); + EXPECT_EQ(phrase.terms[1], "learning"); +} + +TEST_F(FtsParserTest, AllPunctuationPhraseYieldsEmptyTerms) { + // Pure non-alnum content is filtered out entirely. The phrase node still + // exists but carries zero terms; the search engine treats this as + // "match nothing" without crashing. + auto ast = parse("\"!!! ???\""); + ASSERT_NE(ast, nullptr); + ASSERT_EQ(ast->type(), FtsNodeType::PHRASE); + EXPECT_TRUE(as_phrase(*ast).terms.empty()); +} + +} // namespace zvec::fts diff --git a/tests/db/sqlengine/fts_recall_test.cc b/tests/db/sqlengine/fts_recall_test.cc new file mode 100644 index 000000000..c84ccdd16 --- /dev/null +++ b/tests/db/sqlengine/fts_recall_test.cc @@ -0,0 +1,598 @@ +// Copyright 2025-present the zvec project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License + +#include +#include +#include +#include +#include +#include +#include "db/common/file_helper.h" +#include "db/index/common/version_manager.h" +#include "db/index/segment/segment.h" +#include "db/sqlengine/sqlengine.h" +#include "zvec/db/doc.h" +#include "zvec/db/index_params.h" +#include "zvec/db/query_params.h" +#include "zvec/db/schema.h" +#include "zvec/db/type.h" + +namespace zvec::sqlengine { + +// ============================================================ +// FTS Recall Test fixture (real Segment + SQLEngine::execute via VectorQuery) +// ============================================================ + +class FtsRecallTest : public ::testing::Test { + protected: + static void SetUpTestSuite() { + FileHelper::RemoveDirectory(seg_path_); + FileHelper::CreateDirectory(seg_path_); + + build_schema(); + auto segment = create_segment(); + ASSERT_NE(segment, nullptr); + insert_docs(segment); + segments_.push_back(segment); + + engine_ = SQLEngine::create(std::make_shared()); + } + + static void TearDownTestSuite() { + segments_.clear(); + engine_.reset(); + schema_.reset(); + FileHelper::RemoveDirectory(seg_path_); + } + + // Helper: execute FTS query_string search via VectorQuery + Result fts_search(const std::string &query_string, + int topk = 10) { + VectorQuery vq; + vq.topk_ = topk; + vq.field_name_ = "content"; + Fts fts; + fts.query_string_ = query_string; + vq.fts_ = fts; + return engine_->execute(schema_, vq, segments_); + } + + // Helper: execute FTS match_string search via VectorQuery + Result fts_match(const std::string &match_string, + const std::string &default_op = "", + int topk = 10) { + VectorQuery vq; + vq.topk_ = topk; + vq.field_name_ = "content"; + Fts fts; + fts.match_string_ = match_string; + vq.fts_ = fts; + if (!default_op.empty()) { + auto fts_qp = std::make_shared(); + fts_qp->set_default_operator(default_op); + vq.query_params_ = fts_qp; + } + return engine_->execute(schema_, vq, segments_); + } + + // Helper: execute FTS query_string with default_operator via VectorQuery + Result fts_query_with_op(const std::string &query_string, + const std::string &default_op, + int topk = 10) { + VectorQuery vq; + vq.topk_ = topk; + vq.field_name_ = "content"; + Fts fts; + fts.query_string_ = query_string; + vq.fts_ = fts; + auto fts_qp = std::make_shared(); + fts_qp->set_default_operator(default_op); + vq.query_params_ = fts_qp; + return engine_->execute(schema_, vq, segments_); + } + + // Helper: execute FTS query_string with WHERE filter via VectorQuery + Result fts_search_with_filter(const std::string &query_string, + const std::string &filter, + int topk = 10) { + VectorQuery vq; + vq.topk_ = topk; + vq.field_name_ = "content"; + vq.filter_ = filter; + Fts fts; + fts.query_string_ = query_string; + vq.fts_ = fts; + return engine_->execute(schema_, vq, segments_); + } + + private: + static void build_schema() { + auto fts_params = std::make_shared( + "whitespace", std::vector{"lowercase"}, ""); + auto invert_params = std::make_shared(true); + schema_ = std::make_shared( + "fts_recall_test", + std::vector{ + std::make_shared("content", DataType::STRING, false, + fts_params), + std::make_shared("tag", DataType::INT32, false, + invert_params), + // Dummy vector field required for filter parsing path in + // execute + std::make_shared( + "vec", DataType::VECTOR_FP32, 4, false, + std::make_shared(MetricType::L2)), + }); + } + + static Segment::Ptr create_segment() { + auto segment_meta = std::make_shared(); + segment_meta->set_id(0); + + auto id_map = IDMap::CreateAndOpen("fts_recall_test", seg_path_ + "/id_map", + true, false); + auto delete_store = std::make_shared("fts_recall_test"); + + Version v1; + v1.set_schema(*schema_); + std::string v_path = seg_path_ + "/manifest"; + FileHelper::CreateDirectory(v_path); + auto vm = VersionManager::Create(v_path, v1); + if (!vm.has_value()) { + return nullptr; + } + + BlockMeta mem_block; + mem_block.id_ = 0; + mem_block.type_ = BlockType::SCALAR; + mem_block.min_doc_id_ = 0; + mem_block.max_doc_id_ = 0; + mem_block.doc_count_ = 0; + segment_meta->set_writing_forward_block(mem_block); + + SegmentOptions options; + options.read_only_ = false; + options.enable_mmap_ = true; + options.max_buffer_size_ = 256 * 1024; + + auto result = Segment::CreateAndOpen(seg_path_, *schema_, 0, 0, id_map, + delete_store, vm.value(), options); + if (!result) { + return nullptr; + } + return result.value(); + } + + static void insert_docs(const Segment::Ptr &segment) { + // doc_id 0: "apple banana cherry" tag=1 + // doc_id 1: "banana date elderberry" tag=2 + // doc_id 2: "cherry fig grape" tag=1 + // doc_id 3: "apple fig honeydew" tag=2 + // doc_id 4: "date grape kiwi" tag=1 + // doc_id 5: "apple apple apple" tag=2 + // doc_id 6: "mango papaya starfruit" tag=1 + // doc_id 7: "banana banana grape" tag=2 + struct Entry { + std::string content; + int32_t tag; + }; + std::vector entries = { + {"apple banana cherry", 1}, {"banana date elderberry", 2}, + {"cherry fig grape", 1}, {"apple fig honeydew", 2}, + {"date grape kiwi", 1}, {"apple apple apple", 2}, + {"mango papaya starfruit", 1}, {"banana banana grape", 2}, + }; + + for (size_t i = 0; i < entries.size(); ++i) { + Doc doc; + doc.set_pk("pk_" + std::to_string(i)); + doc.set_doc_id(i); + doc.set("content", entries[i].content); + doc.set("tag", entries[i].tag); + auto status = segment->Insert(doc); + ASSERT_TRUE(status.ok()) + << "Insert doc " << i << " failed: " << status.c_str(); + } + } + + protected: + static inline std::string seg_path_ = "./fts_recall_test_collection"; + static inline CollectionSchema::Ptr schema_; + static inline std::vector segments_; + static inline SQLEngine::Ptr engine_; +}; + +// ============================================================ +// Basic FTS search tests +// ============================================================ + +// "apple" matches docs 0, 3, 5 +TEST_F(FtsRecallTest, BasicSingleTerm) { + auto result = fts_search("apple"); + ASSERT_TRUE(result.has_value()) << result.error().c_str(); + EXPECT_EQ(result->size(), 3u); +} + +// BM25 ordering: doc 5 ("apple apple apple") should have highest score +TEST_F(FtsRecallTest, BM25ScoreOrdering) { + auto result = fts_search("apple"); + ASSERT_TRUE(result.has_value()) << result.error().c_str(); + ASSERT_GE(result->size(), 2u); + + // Results should be sorted by score descending + for (size_t i = 0; i + 1 < result->size(); ++i) { + EXPECT_GE((*result)[i]->score(), (*result)[i + 1]->score()) + << "Results not sorted descending at index " << i; + } + // Doc 5 has highest TF for "apple" + EXPECT_EQ((*result)[0]->pk(), "pk_5"); +} + +// "kiwi" only in doc 4 +TEST_F(FtsRecallTest, SingleMatch) { + auto result = fts_search("kiwi"); + ASSERT_TRUE(result.has_value()) << result.error().c_str(); + ASSERT_EQ(result->size(), 1u); + EXPECT_EQ((*result)[0]->pk(), "pk_4"); +} + +// Nonexistent term +TEST_F(FtsRecallTest, NoMatch) { + auto result = fts_search("zzznomatch"); + ASSERT_TRUE(result.has_value()) << result.error().c_str(); + EXPECT_EQ(result->size(), 0u); +} + +// Topk limit: "banana" in docs 0, 1, 7 (3 matches), topk=2 +TEST_F(FtsRecallTest, TopkLimit) { + auto result = fts_search("banana", /*topk=*/2); + ASSERT_TRUE(result.has_value()) << result.error().c_str(); + EXPECT_LE(result->size(), 2u); +} + +// Multi-term implicit OR: "apple banana" matches union of {0,3,5} and {0,1,7} +TEST_F(FtsRecallTest, MultiTermImplicitOr) { + auto result = fts_search("apple banana"); + ASSERT_TRUE(result.has_value()) << result.error().c_str(); + // Union: {0,1,3,5,7} = 5 docs + EXPECT_EQ(result->size(), 5u); +} + +// "starfruit" only in doc 6 +TEST_F(FtsRecallTest, RareTerm) { + auto result = fts_search("starfruit"); + ASSERT_TRUE(result.has_value()) << result.error().c_str(); + ASSERT_EQ(result->size(), 1u); + EXPECT_EQ((*result)[0]->pk(), "pk_6"); +} + +// "grape" in docs 2, 4, 7 +TEST_F(FtsRecallTest, CommonTerm) { + auto result = fts_search("grape"); + ASSERT_TRUE(result.has_value()) << result.error().c_str(); + EXPECT_EQ(result->size(), 3u); +} + +// ============================================================ +// Explicit AND +// ============================================================ + +// "apple AND banana" -> intersection of {0,3,5} and {0,1,7} = {0} +TEST_F(FtsRecallTest, ExplicitAnd) { + auto result = fts_search("apple AND banana"); + ASSERT_TRUE(result.has_value()) << result.error().c_str(); + EXPECT_EQ(result->size(), 1u); + EXPECT_EQ((*result)[0]->pk(), "pk_0"); +} + +// "cherry AND fig" -> {0,2} AND {2,3} = {2} +TEST_F(FtsRecallTest, ExplicitAnd2) { + auto result = fts_search("cherry AND fig"); + ASSERT_TRUE(result.has_value()) << result.error().c_str(); + EXPECT_EQ(result->size(), 1u); + EXPECT_EQ((*result)[0]->pk(), "pk_2"); +} + +// ============================================================ +// Binary NOT (AND-NOT) +// ============================================================ + +// "apple NOT banana" -> {0,3,5} minus {0,1,7} = {3,5} +TEST_F(FtsRecallTest, BinaryNot) { + auto result = fts_search("apple NOT banana"); + ASSERT_TRUE(result.has_value()) << result.error().c_str(); + EXPECT_EQ(result->size(), 2u); + std::set pks; + for (auto &doc : *result) { + pks.insert(doc->pk()); + } + EXPECT_TRUE(pks.count("pk_3")); + EXPECT_TRUE(pks.count("pk_5")); +} + +// "banana NOT grape" -> {0,1,7} minus {2,4,7} = {0,1} +TEST_F(FtsRecallTest, BinaryNot2) { + auto result = fts_search("banana NOT grape"); + ASSERT_TRUE(result.has_value()) << result.error().c_str(); + EXPECT_EQ(result->size(), 2u); + std::set pks; + for (auto &doc : *result) { + pks.insert(doc->pk()); + } + EXPECT_TRUE(pks.count("pk_0")); + EXPECT_TRUE(pks.count("pk_1")); +} + +// ============================================================ +// Error cases +// ============================================================ + +// Leading NOT should fail parse +TEST_F(FtsRecallTest, LeadingNotIsRejected) { + auto result = fts_search("NOT apple"); + EXPECT_FALSE(result.has_value()); +} + +// Both query_string_ and match_string_ empty +TEST_F(FtsRecallTest, BothEmptyReturnsError) { + VectorQuery vq; + vq.topk_ = 10; + vq.field_name_ = "content"; + vq.fts_ = Fts{}; // both fields empty + auto result = engine_->execute(schema_, vq, segments_); + EXPECT_FALSE(result.has_value()); +} + +// Both query_string_ and match_string_ set +TEST_F(FtsRecallTest, BothSetReturnsError) { + VectorQuery vq; + vq.topk_ = 10; + vq.field_name_ = "content"; + Fts fts; + fts.query_string_ = "apple"; + fts.match_string_ = "banana"; + vq.fts_ = fts; + auto result = engine_->execute(schema_, vq, segments_); + EXPECT_FALSE(result.has_value()); +} + +// ============================================================ +// match_string tests +// ============================================================ + +// match_string "starfruit" -> doc 6 +TEST_F(FtsRecallTest, MatchStringRareTerm) { + auto result = fts_match("starfruit"); + ASSERT_TRUE(result.has_value()) << result.error().c_str(); + ASSERT_EQ(result->size(), 1u); + EXPECT_EQ((*result)[0]->pk(), "pk_6"); +} + +// match_string "grape" -> docs 2, 4, 7 +TEST_F(FtsRecallTest, MatchStringCommonTerm) { + auto result = fts_match("grape"); + ASSERT_TRUE(result.has_value()) << result.error().c_str(); + EXPECT_EQ(result->size(), 3u); +} + +// match_string "apple banana" -> OR -> union {0,1,3,5,7} +TEST_F(FtsRecallTest, MatchStringMultipleTokens) { + auto result = fts_match("apple banana"); + ASSERT_TRUE(result.has_value()) << result.error().c_str(); + EXPECT_EQ(result->size(), 5u); +} + +// match_string analysing to zero tokens → empty result, not an error. +TEST_F(FtsRecallTest, MatchStringEmptyTokensReturnsNoResults) { + auto result = fts_match(" \t "); + ASSERT_TRUE(result.has_value()) << result.error().c_str(); + EXPECT_TRUE(result->empty()); +} + +// ============================================================ +// default_operator tests +// ============================================================ + +// AND default for match_string: "apple banana" -> intersection = {0} +TEST_F(FtsRecallTest, DefaultOperatorAnd_MatchString) { + auto result = fts_match("apple banana", "AND"); + ASSERT_TRUE(result.has_value()) << result.error().c_str(); + EXPECT_EQ(result->size(), 1u); + EXPECT_EQ((*result)[0]->pk(), "pk_0"); +} + +// OR default for match_string (backward compat) +TEST_F(FtsRecallTest, DefaultOperatorOr_MatchString) { + auto result = fts_match("apple banana", "OR"); + ASSERT_TRUE(result.has_value()) << result.error().c_str(); + EXPECT_EQ(result->size(), 5u); +} + +// AND default for query_string: "apple banana" -> AND +TEST_F(FtsRecallTest, DefaultOperatorAnd_QueryString) { + auto result = fts_query_with_op("apple banana", "AND"); + ASSERT_TRUE(result.has_value()) << result.error().c_str(); + EXPECT_EQ(result->size(), 1u); + EXPECT_EQ((*result)[0]->pk(), "pk_0"); +} + +// Explicit OR in query not overridden by default_operator=AND +// "apple OR grape" with AND default -> OR still applies +TEST_F(FtsRecallTest, DefaultOperatorAnd_DoesNotOverrideExplicitOr) { + auto result = fts_query_with_op("apple OR grape", "AND"); + ASSERT_TRUE(result.has_value()) << result.error().c_str(); + // apple: {0,3,5}, grape: {2,4,7} -> union = 6 + EXPECT_EQ(result->size(), 6u); +} + +// Empty default_operator keeps historical OR for match_string +TEST_F(FtsRecallTest, DefaultOperatorEmpty_BackwardCompatibleOr) { + auto result = fts_match("apple banana"); // no default_op arg + ASSERT_TRUE(result.has_value()) << result.error().c_str(); + // OR semantics: union of apple{0,3,5} and banana{0,1,7} = 5 + EXPECT_EQ(result->size(), 5u); +} + +// Lowercase "and" must be accepted +TEST_F(FtsRecallTest, DefaultOperatorAndLowercase_Accepted) { + auto result = fts_match("apple banana", "and"); + ASSERT_TRUE(result.has_value()) << result.error().c_str(); + EXPECT_EQ(result->size(), 1u); +} + +// Mixed-case "And" / "oR" are accepted via case-insensitive normalisation. +TEST_F(FtsRecallTest, DefaultOperatorMixedCase_Accepted) { + { + // "And" -> AND semantics: intersection of apple{0,3,5} and banana{0,1,7} + auto result = fts_match("apple banana", "And"); + ASSERT_TRUE(result.has_value()) << result.error().c_str(); + EXPECT_EQ(result->size(), 1u); + } + { + // "oR" -> OR semantics: union = 5 docs + auto result = fts_match("apple banana", "oR"); + ASSERT_TRUE(result.has_value()) << result.error().c_str(); + EXPECT_EQ(result->size(), 5u); + } +} + +// Invalid default_operator value should be rejected (was previously silently +// downgraded to OR). +TEST_F(FtsRecallTest, DefaultOperatorInvalid_Rejected) { + auto result = fts_match("apple banana", "xor"); + EXPECT_FALSE(result.has_value()); +} + +// ============================================================ +// Error cases (additional) +// ============================================================ + +// Empty field_name should fail +TEST_F(FtsRecallTest, EmptyFieldNameReturnsError) { + VectorQuery vq; + vq.topk_ = 10; + vq.field_name_ = ""; + Fts fts; + fts.query_string_ = "apple"; + vq.fts_ = fts; + auto result = engine_->execute(schema_, vq, segments_); + EXPECT_FALSE(result.has_value()); +} + +// Empty query_string (with field_name set) should fail +TEST_F(FtsRecallTest, EmptyQueryStringReturnsError) { + VectorQuery vq; + vq.topk_ = 10; + vq.field_name_ = "content"; + // Both query_string_ and match_string_ empty -> error + vq.fts_ = Fts{}; + auto result = engine_->execute(schema_, vq, segments_); + EXPECT_FALSE(result.has_value()); +} + +// ============================================================ +// FTS search with WHERE filter +// ============================================================ + +// "apple" (docs 0,3,5) + tag = 1 (docs 0,2,4,6) -> intersection = {0} +TEST_F(FtsRecallTest, FtsSearchWithFilter_ScoreTag) { + auto result = fts_search_with_filter("apple", "tag = 1"); + ASSERT_TRUE(result.has_value()) << result.error().c_str(); + // Filter should reduce results to doc 0 only + EXPECT_LE(result->size(), 3u); + // Verify that at least doc 0 (which satisfies both FTS and filter) is present + bool found_pk0 = false; + for (auto &doc : *result) { + if (doc->pk() == "pk_0") { + found_pk0 = true; + } + } + EXPECT_TRUE(found_pk0); +} + +// "banana" (docs 0,1,7) + tag = 2 (docs 1,3,5,7) + topk=1 +TEST_F(FtsRecallTest, FtsSearchWithFilter_TopkRespected) { + auto result = fts_search_with_filter("banana", "tag = 2", /*topk=*/1); + ASSERT_TRUE(result.has_value()) << result.error().c_str(); + EXPECT_LE(result->size(), 1u); +} + +// ============================================================ +// Repeated-term linearity: the AST rewriter collapses a repeated term into a +// single TermNode whose boost equals the occurrence count. With linear boost +// the per-document score must be exactly N× the single-term score, matching +// the pre-rewrite "N independent scorers summed" semantics. +// ============================================================ + +TEST_F(FtsRecallTest, MatchStringRepeatedTermLinearBoost) { + auto baseline = fts_match("apple"); + auto repeated = fts_match("apple apple"); + ASSERT_TRUE(baseline.has_value()) << baseline.error().c_str(); + ASSERT_TRUE(repeated.has_value()) << repeated.error().c_str(); + ASSERT_EQ(baseline->size(), repeated->size()); + + // Same doc set, same ordering — only the absolute scores differ. + for (size_t i = 0; i < baseline->size(); ++i) { + EXPECT_EQ((*baseline)[i]->pk(), (*repeated)[i]->pk()) << "rank " << i; + EXPECT_FLOAT_EQ((*baseline)[i]->score() * 2.0f, (*repeated)[i]->score()) + << "rank " << i << " pk=" << (*repeated)[i]->pk(); + } +} + +// Unary `-` prefix inside an OR was previously executed via build_or_iterator +// wrapping the disjunction in a must_not Conjunction. After the rewriter +// canonicalizes OR-with-must_not into AND(positive..., -negative...), the +// must_not iterator path lives only in build_and_iterator. End-to-end the +// match set must be unchanged: apple{0,3,5} − banana{0,1,7} = {3, 5}. +TEST_F(FtsRecallTest, QueryStringUnaryMinusExcludesMatchingDocs) { + auto result = fts_search("apple -banana"); + ASSERT_TRUE(result.has_value()) << result.error().c_str(); + std::set pks; + for (const auto &d : *result) { + pks.insert(d->pk()); + } + EXPECT_EQ(pks, std::set({"pk_3", "pk_5"})); +} + +// `apple -apple` is a self-contradiction; the rewriter detects the must vs +// must_not conflict after canonicalization and rewrites the whole subtree +// to EmptyNode, so the query short-circuits to zero docs. +TEST_F(FtsRecallTest, QueryStringSelfContradictionReturnsNoResults) { + auto result = fts_search("apple -apple"); + ASSERT_TRUE(result.has_value()) << result.error().c_str(); + EXPECT_TRUE(result->empty()); +} + +TEST_F(FtsRecallTest, MatchStringRepeatedTermPreservesUnion) { + // "apple apple banana" — apple repeated, banana once. Doc set must equal + // "apple banana" (union), and apple-only docs should score 2× their + // single-term score plus zero for banana. + auto plain_union = fts_match("apple banana"); + auto repeated_union = fts_match("apple apple banana"); + ASSERT_TRUE(plain_union.has_value()) << plain_union.error().c_str(); + ASSERT_TRUE(repeated_union.has_value()) << repeated_union.error().c_str(); + EXPECT_EQ(plain_union->size(), repeated_union->size()); + + std::set plain_pks; + std::set repeated_pks; + for (const auto &d : *plain_union) { + plain_pks.insert(d->pk()); + } + for (const auto &d : *repeated_union) { + repeated_pks.insert(d->pk()); + } + EXPECT_EQ(plain_pks, repeated_pks); +} + +} // namespace zvec::sqlengine diff --git a/tests/db/sqlengine/mock_segment.h b/tests/db/sqlengine/mock_segment.h index ccb65f800..4892b2c66 100644 --- a/tests/db/sqlengine/mock_segment.h +++ b/tests/db/sqlengine/mock_segment.h @@ -499,6 +499,17 @@ class MockSegment : public Segment { return {}; } + fts::FtsColumnIndexerPtr get_fts_indexer( + const std::string &field_name) const override { + return nullptr; + } + + Result> fts_search( + const std::string &field_name, const fts::FtsAstNode &ast, + const fts::FtsQueryParams ¶ms) override { + return std::vector{}; + } + Status flush() override { return Status::OK(); } diff --git a/thirdparty/CMakeLists.txt b/thirdparty/CMakeLists.txt index 22f06ceae..01561e5c7 100644 --- a/thirdparty/CMakeLists.txt +++ b/thirdparty/CMakeLists.txt @@ -26,4 +26,7 @@ add_subdirectory(CRoaring CRoaring EXCLUDE_FROM_ALL) add_subdirectory(arrow arrow EXCLUDE_FROM_ALL) add_subdirectory(magic_enum magic_enum EXCLUDE_FROM_ALL) add_subdirectory(RaBitQ-Library RaBitQ-Library EXCLUDE_FROM_ALL) +add_subdirectory(FastPFOR FastPFOR EXCLUDE_FROM_ALL) +add_subdirectory(limonp limonp EXCLUDE_FROM_ALL) +add_subdirectory(cppjieba cppjieba EXCLUDE_FROM_ALL) diff --git a/thirdparty/FastPFOR/CMakeLists.txt b/thirdparty/FastPFOR/CMakeLists.txt new file mode 100644 index 000000000..77a8dfba9 --- /dev/null +++ b/thirdparty/FastPFOR/CMakeLists.txt @@ -0,0 +1,46 @@ +include(${CMAKE_SOURCE_DIR}/cmake/bazel.cmake) + +# On ARM platforms, FastPFOR uses SIMDe to emulate SSE intrinsics. +# Detection covers native ARM builds AND cross-compilation (e.g. iOS/Android). +set(_FASTPFOR_IS_ARM FALSE) +if(CMAKE_SYSTEM_PROCESSOR MATCHES "arm|aarch64|ARM64|arm64") + set(_FASTPFOR_IS_ARM TRUE) +elseif(CMAKE_OSX_ARCHITECTURES MATCHES "arm64") + set(_FASTPFOR_IS_ARM TRUE) +elseif(CMAKE_SYSTEM_NAME STREQUAL "iOS") + set(_FASTPFOR_IS_ARM TRUE) +endif() + +if(_FASTPFOR_IS_ARM) + include(FetchContent) + FetchContent_Declare( + simde + GIT_REPOSITORY https://github.com/simd-everywhere/simde.git + GIT_TAG v0.8.2 + ) + FetchContent_MakeAvailable(simde) + set(FASTPFOR_EXTRA_INCS ${simde_SOURCE_DIR}) + set(FASTPFOR_EXTRA_CXXFLAGS "") + set(FASTPFOR_EXTRA_DEFS SIMDE_ENABLE_NATIVE_ALIASES) +elseif(MSVC) + set(FASTPFOR_EXTRA_INCS "") + set(FASTPFOR_EXTRA_CXXFLAGS "") + set(FASTPFOR_EXTRA_DEFS "") +else() + set(FASTPFOR_EXTRA_INCS "") + set(FASTPFOR_EXTRA_CXXFLAGS -msse4.1) + set(FASTPFOR_EXTRA_DEFS "") +endif() + +cc_library( + NAME FastPFOR STATIC + SRCS FastPFOR-0.4.0/src/simdbitpacking.cpp + FastPFOR-0.4.0/src/bitpacking.cpp + FastPFOR-0.4.0/src/bitpackingaligned.cpp + FastPFOR-0.4.0/src/bitpackingunaligned.cpp + FastPFOR-0.4.0/src/simdunalignedbitpacking.cpp + INCS FastPFOR-0.4.0/headers ${FASTPFOR_EXTRA_INCS} + PUBINCS FastPFOR-0.4.0/headers ${FASTPFOR_EXTRA_INCS} + DEFS ${FASTPFOR_EXTRA_DEFS} + CXXFLAGS ${FASTPFOR_EXTRA_CXXFLAGS} +) diff --git a/thirdparty/FastPFOR/FastPFOR-0.4.0 b/thirdparty/FastPFOR/FastPFOR-0.4.0 new file mode 160000 index 000000000..2be1f9769 --- /dev/null +++ b/thirdparty/FastPFOR/FastPFOR-0.4.0 @@ -0,0 +1 @@ +Subproject commit 2be1f976935b8ff9296b029f574d7f964be9d35d diff --git a/thirdparty/cppjieba/CMakeLists.txt b/thirdparty/cppjieba/CMakeLists.txt new file mode 100644 index 000000000..4c80932cc --- /dev/null +++ b/thirdparty/cppjieba/CMakeLists.txt @@ -0,0 +1,17 @@ +set(cppjieba_SOURCE_DIR "${CMAKE_CURRENT_SOURCE_DIR}/cppjieba-5.6.7") + +if(NOT TARGET cppjieba) + add_library(cppjieba INTERFACE) + target_include_directories(cppjieba SYSTEM INTERFACE + ${cppjieba_SOURCE_DIR}/include + ) + target_link_libraries(cppjieba INTERFACE limonp) +endif() + +set(cppjieba_FOUND TRUE PARENT_SCOPE) +set(cppjieba_INCLUDE_DIR ${cppjieba_SOURCE_DIR}/include PARENT_SCOPE) +set(cppjieba_INCLUDE_DIRS + ${cppjieba_SOURCE_DIR}/include + ${limonp_INCLUDE_DIR} + PARENT_SCOPE) +set(cppjieba_DICT_DIR ${cppjieba_SOURCE_DIR}/dict PARENT_SCOPE) diff --git a/thirdparty/cppjieba/cppjieba-5.6.7 b/thirdparty/cppjieba/cppjieba-5.6.7 new file mode 160000 index 000000000..b3602bef7 --- /dev/null +++ b/thirdparty/cppjieba/cppjieba-5.6.7 @@ -0,0 +1 @@ +Subproject commit b3602bef7d1f67521a61788a74fb5801a0e62cd3 diff --git a/thirdparty/limonp/CMakeLists.txt b/thirdparty/limonp/CMakeLists.txt new file mode 100644 index 000000000..610327676 --- /dev/null +++ b/thirdparty/limonp/CMakeLists.txt @@ -0,0 +1,12 @@ +set(limonp_SOURCE_DIR "${CMAKE_CURRENT_SOURCE_DIR}/limonp-v1.0.2") + +if(NOT TARGET limonp) + add_library(limonp INTERFACE) + target_include_directories(limonp SYSTEM INTERFACE + ${limonp_SOURCE_DIR}/include + ) +endif() + +set(limonp_FOUND TRUE PARENT_SCOPE) +set(limonp_INCLUDE_DIR ${limonp_SOURCE_DIR}/include PARENT_SCOPE) +set(limonp_INCLUDE_DIRS ${limonp_SOURCE_DIR}/include PARENT_SCOPE) diff --git a/thirdparty/limonp/limonp-v1.0.2 b/thirdparty/limonp/limonp-v1.0.2 new file mode 160000 index 000000000..9d74077df --- /dev/null +++ b/thirdparty/limonp/limonp-v1.0.2 @@ -0,0 +1 @@ +Subproject commit 9d74077dfcdf8073536c97a00bb79d7a3c3fdaba diff --git a/tools/CMakeLists.txt b/tools/CMakeLists.txt index 4e17f1ec3..d01b22e1c 100644 --- a/tools/CMakeLists.txt +++ b/tools/CMakeLists.txt @@ -5,4 +5,5 @@ include(${PROJECT_ROOT_DIR}/cmake/option.cmake) git_version(ZVEC_VERSION ${CMAKE_CURRENT_SOURCE_DIR}) # Add repository -cc_directory(core) \ No newline at end of file +cc_directory(core) +cc_directory(db) \ No newline at end of file diff --git a/tools/db/CMakeLists.txt b/tools/db/CMakeLists.txt new file mode 100644 index 000000000..fc224e3f8 --- /dev/null +++ b/tools/db/CMakeLists.txt @@ -0,0 +1,13 @@ +include(${PROJECT_ROOT_DIR}/cmake/bazel.cmake) + +cc_binary( + NAME fts_bench PACKED + SRCS fts_bench_main.cc + LIBS + zvec_shared + gflags + roaring + rocksdb + INCS . ${PROJECT_SOURCE_DIR}/src + LDFLAGS ${APPLE_FRAMEWORK_LIBS} +) diff --git a/tools/db/fts_bench_main.cc b/tools/db/fts_bench_main.cc new file mode 100644 index 000000000..b0d729931 --- /dev/null +++ b/tools/db/fts_bench_main.cc @@ -0,0 +1,1868 @@ +// Copyright 2025-present the zvec project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include "db/common/constants.h" +#include "db/common/file_helper.h" +#include "db/common/rocksdb_context.h" +#include "db/index/column/fts_column/fts_column_indexer.h" +#include "db/index/column/fts_column/fts_pipeline.h" +#include "db/index/column/fts_column/fts_query_ast.h" +#include "db/index/column/fts_column/fts_rocksdb_merge.h" +#include "db/index/column/fts_column/fts_rocksdb_reducer.h" +#include "db/index/column/fts_column/fts_types.h" +#include "db/index/column/fts_column/fts_utils.h" +#include "db/index/column/fts_column/posting/bitpacked_posting_list.h" + +namespace { + +// Helper: build a public FtsIndexParams from FLAGS_extra_params JSON string. +// The JSON may contain a "tokenizer" key that specifies the tokenizer name; +// the remaining JSON is passed through as extra_params verbatim. +static std::shared_ptr build_fts_index_params( + const std::string &extra_params_json) { + std::string tokenizer_name = "standard"; + zvec::ailego::JsonValue jv; + if (jv.parse(extra_params_json) && jv.is_object()) { + const auto &obj = jv.as_object(); + zvec::ailego::JsonValue tok_val = obj["tokenizer"]; + if (tok_val.is_string()) { + tokenizer_name = tok_val.as_string().as_stl_string(); + } + } + return std::make_shared( + std::move(tokenizer_name), std::vector{"lowercase"}, + extra_params_json); +} + +// Helper: build a transient FieldSchema for FTS field with index params. +static zvec::FieldSchema::Ptr make_fts_field_schema( + const std::string &field_name, + std::shared_ptr fts_params = nullptr) { + if (!fts_params) { + fts_params = std::make_shared(); + } + return std::make_shared(field_name, zvec::DataType::STRING, + false, fts_params); +} + +} // namespace + +// --------------------------------------------------------------------------- +// gflags +// --------------------------------------------------------------------------- +DEFINE_string(cmd, "", + "Command to execute: build, search, stats. " + "If empty, auto-detect from -corpus / -query flags."); +DEFINE_string(index, "", "Path to FTS index directory"); +DEFINE_string(corpus, "", "Path to BEIR corpus.jsonl (build mode)"); +DEFINE_string(query, "", "Path to BEIR queries.jsonl (search mode)"); +DEFINE_string(qrels, "", "Path to BEIR qrels directory (search mode)"); +DEFINE_int32(topk, 10, "Top-K results to retrieve per query"); +DEFINE_string(extra_params, R"({"tokenizer":"standard"})", + "Extra params JSON for tokenizer pipeline"); +DEFINE_string(field, "text", "FTS field name"); +DEFINE_int32(threads, 16, "Number of threads for multi-threaded search"); +DEFINE_int32(max_queries, 0, + "Maximum number of queries to run in search mode. " + "0 means all queries (default)."); +DEFINE_bool(reduce, false, + "After build, run FtsRocksdbReducer to convert postings to " + "BitPacked format. Reduced index is written to -reduce."); +DEFINE_string(default_operator, "or", + "Default operator used to combine query tokens when searching " + "match_string-style queries. Valid values: 'or' (union, default) " + "or 'and' (intersection)."); +DEFINE_string(mode, "raw", + "Execution mode: 'raw' (default) operates directly on RocksDB " + "via FtsColumnIndexer; 'db' operates through " + "the zvec Collection API (CreateAndOpen / Insert / Query)."); + +// --------------------------------------------------------------------------- +// Constants +// --------------------------------------------------------------------------- +static const std::string kForwardCfName = "forward"; + +using namespace zvec; +using namespace zvec::fts; + +// --------------------------------------------------------------------------- +// Query AST builder: combine tokens with the configured default operator. +// Returns nullptr when tokens is empty. +// --------------------------------------------------------------------------- +template +static FtsAstNodePtr build_query_ast_from_tokens( + const TokenContainer &tokens, const std::string &default_operator) { + if (tokens.empty()) { + return nullptr; + } + if (default_operator == "and") { + auto and_node = std::make_unique(); + for (const auto &token : tokens) { + and_node->children.push_back(std::make_unique(token.text)); + } + return and_node; + } + // Default: OR + auto or_node = std::make_unique(); + for (const auto &token : tokens) { + or_node->children.push_back(std::make_unique(token.text)); + } + return or_node; +} + +// Validate -default_operator flag value. Returns true if valid. +static bool validate_default_operator(const std::string &op) { + return op == "or" || op == "and"; +} + +// --------------------------------------------------------------------------- +// Helper: open RocksdbStore with FTS column families. +// +// `with_side_cfs` controls whether the build-time only side CFs +// ($TF / $MAX_TF / $DOC_LEN) are listed in the open args. These three CFs +// are dropped at the end of build (after convert_postings_to_bitpacked() +// inlines their payloads into BitPacked postings), mirroring +// MutableSegment::dump_fts_column_indexers(). Search/stats paths therefore +// open the store without them so that the open call doesn't fail with +// "column family not found" against a built index. +// --------------------------------------------------------------------------- +static bool open_fts_store(RocksdbContext *store, const std::string &field_name, + bool existing, const std::string &index_path = "", + bool with_side_cfs = true, + bool with_forward_cf = true) { + const std::string &data_dir = index_path.empty() ? FLAGS_index : index_path; + const std::string max_tf_cf = field_name + zvec::kFtsMaxTfSuffix; + + std::vector cf_names = { + field_name, + field_name + zvec::kFtsPositionsSuffix, + zvec::kFtsStatCfName, + }; + if (with_forward_cf) { + cf_names.push_back(kForwardCfName); + } + if (with_side_cfs) { + cf_names.push_back(field_name + zvec::kFtsTfSuffix); + cf_names.push_back(max_tf_cf); + cf_names.push_back(field_name + zvec::kFtsDocLenSuffix); + } + + // Build per-CF merge operators map + std::unordered_map> + per_cf_merge_ops; + per_cf_merge_ops[field_name] = std::make_shared(); + if (with_side_cfs) { + per_cf_merge_ops[max_tf_cf] = std::make_shared(); + } + + Status status; + if (existing) { + status = store->open( + RocksdbContext::Args{data_dir, cf_names, nullptr, per_cf_merge_ops}, + false); + } else { + status = store->create(RocksdbContext::Args{data_dir, cf_names, nullptr, + per_cf_merge_ops, true}); + } + if (!status.ok()) { + fprintf(stderr, "ERROR: Failed to open RocksdbStore at [%s], status[%s]\n", + data_dir.c_str(), status.message().c_str()); + return false; + } + return true; +} + +// --------------------------------------------------------------------------- +// Helper: drop $TF / $MAX_TF / $DOC_LEN CFs after convert_postings_to_bitpacked +// has inlined their payloads into BitPacked postings. Mirrors +// MutableSegment::dump_fts_column_indexers(). The dumped immutable index is +// significantly smaller because these CFs no longer occupy SST space. +// Logs and ignores per-CF failures so that a partial drop (e.g. CF already +// missing on retry) does not abort the whole build. +// --------------------------------------------------------------------------- +static void drop_fts_side_cfs(RocksdbContext *store, + const std::string &field_name) { + const std::vector side_cf_names = { + field_name + zvec::kFtsTfSuffix, + field_name + zvec::kFtsMaxTfSuffix, + field_name + zvec::kFtsDocLenSuffix, + }; + for (const auto &cf_name : side_cf_names) { + Status drop_status = store->drop_cf(cf_name); + if (!drop_status.ok()) { + fprintf(stderr, + "WARN: Drop column family[%s] failed, status[%s] (ignored)\n", + cf_name.c_str(), drop_status.message().c_str()); + } + } +} + + +// --------------------------------------------------------------------------- +// Helper: parse a JSONL line and extract a string field +// --------------------------------------------------------------------------- +static bool parse_jsonl_line( + const std::string &line, + std::unordered_map *out) { + zvec::ailego::JsonValue jv; + if (!jv.parse(line) || !jv.is_object()) { + return false; + } + const auto &obj = jv.as_object(); + for (const auto &kv : obj) { + if (kv.value().is_string()) { + (*out)[kv.key().as_stl_string()] = kv.value().as_string().as_stl_string(); + } + } + return true; +} + +// --------------------------------------------------------------------------- +// Latency statistics helper +// --------------------------------------------------------------------------- +struct LatencyStats { + std::vector samples; // microseconds + + void add(uint64_t us) { + samples.push_back(us); + } + + void print(const std::string &label) const { + if (samples.empty()) { + std::cout << label << ": no samples" << std::endl; + return; + } + std::vector sorted = samples; + std::sort(sorted.begin(), sorted.end()); + + uint64_t sum = 0; + for (auto v : sorted) sum += v; + double avg = static_cast(sum) / sorted.size(); + + auto percentile = [&](double p) -> uint64_t { + size_t idx = static_cast(p * sorted.size()); + if (idx >= sorted.size()) idx = sorted.size() - 1; + return sorted[idx]; + }; + + std::cout << label << " latency (us):" << std::endl; + std::cout << " Count : " << sorted.size() << std::endl; + std::cout << " Average: " << static_cast(avg) << std::endl; + std::cout << " Min : " << sorted.front() << std::endl; + std::cout << " P50 : " << percentile(0.50) << std::endl; + std::cout << " P95 : " << percentile(0.95) << std::endl; + std::cout << " P99 : " << percentile(0.99) << std::endl; + std::cout << " Max : " << sorted.back() << std::endl; + } +}; + +// --------------------------------------------------------------------------- +// REDUCE MODE: convert Roaring Bitmap postings to BitPacked format +// --------------------------------------------------------------------------- +static int do_reduce(const std::string &src_index_path, uint32_t total_docs) { + const std::string dst_index_path = src_index_path + "-reduce"; + std::cout << std::endl; + std::cout << "=== REDUCE MODE ===" << std::endl; + std::cout << " Source : " << src_index_path << std::endl; + std::cout << " Dest : " << dst_index_path << std::endl; + + // Create destination directory + if (!zvec::FileHelper::DirectoryExists(dst_index_path)) { + if (!zvec::FileHelper::CreateDirectory(dst_index_path)) { + fprintf(stderr, "ERROR: Failed to create reduce output directory: %s\n", + dst_index_path.c_str()); + return -1; + } + } + + // Open source store (existing). $TF/$MAX_TF/$DOC_LEN were dropped at + // build time after convert_postings_to_bitpacked(), so we open without + // them. The reducer never consumed these CFs anyway (BitPacked postings + // already carry inline tf/doc_len/max_score payloads). + RocksdbContext src_store; + if (!open_fts_store(&src_store, FLAGS_field, /*existing=*/true, + src_index_path, /*with_side_cfs=*/false)) { + fprintf(stderr, "ERROR: Failed to open source store for reduce\n"); + return -1; + } + + // Open destination store (new) — same shape as a freshly-dumped immutable + // index: no side CFs. + RocksdbContext dst_store; + if (!open_fts_store(&dst_store, FLAGS_field, /*existing=*/false, + dst_index_path, /*with_side_cfs=*/false)) { + fprintf(stderr, "ERROR: Failed to open destination store for reduce\n"); + src_store.close(); + return -1; + } + + // Get source column families + rocksdb::ColumnFamilyHandle *src_postings = src_store.get_cf(FLAGS_field); + rocksdb::ColumnFamilyHandle *src_positions = + src_store.get_cf(FLAGS_field + zvec::kFtsPositionsSuffix); + rocksdb::ColumnFamilyHandle *src_stat = + src_store.get_cf(zvec::kFtsStatCfName); + rocksdb::ColumnFamilyHandle *src_forward = src_store.get_cf(kForwardCfName); + + // Get destination column families + rocksdb::ColumnFamilyHandle *dst_postings = dst_store.get_cf(FLAGS_field); + rocksdb::ColumnFamilyHandle *dst_positions = + dst_store.get_cf(FLAGS_field + zvec::kFtsPositionsSuffix); + rocksdb::ColumnFamilyHandle *dst_stat = + dst_store.get_cf(zvec::kFtsStatCfName); + rocksdb::ColumnFamilyHandle *dst_forward = dst_store.get_cf(kForwardCfName); + + if (!src_postings || !src_positions || !src_stat || !dst_postings || + !dst_positions || !dst_stat) { + fprintf(stderr, "ERROR: Failed to get column families for reduce\n"); + src_store.close(); + dst_store.close(); + return -1; + } + + zvec::ailego::ElapsedTime reduce_timer; + + // Initialize reducer. Side CFs ($TF/$MAX_TF/$DOC_LEN) are no longer + // consumed by the reducer; they remain in the schema for SST compatibility + // but the bench tool does not need to wire them in. + FtsRocksdbReducer reducer; + auto init_result = reducer.init(FLAGS_field, &dst_store, dst_postings, + dst_positions, dst_stat); + if (!init_result.has_value()) { + fprintf(stderr, "ERROR: FtsRocksdbReducer init failed, status[%s]\n", + init_result.error().message().c_str()); + src_store.close(); + dst_store.close(); + return -1; + } + + // Feed source as a single segment: doc_id range [0, total_docs-1] + FtsSegmentStats seg_stats; + seg_stats.min_doc_id = 0; + seg_stats.max_doc_id = total_docs > 0 ? total_docs - 1 : 0; + + auto feed_result = + reducer.feed(seg_stats, &src_store, src_postings, src_positions); + if (!feed_result.has_value()) { + fprintf(stderr, "ERROR: FtsRocksdbReducer feed failed, status[%s]\n", + feed_result.error().message().c_str()); + src_store.close(); + dst_store.close(); + return -1; + } + + // Run reduce with no-delete filter (empty bitmap = nothing deleted). + std::cout << " Running reduce..." << std::endl; + roaring::Roaring no_delete_bitmap; + auto reduce_result = reducer.reduce(no_delete_bitmap); + if (!reduce_result.has_value()) { + fprintf(stderr, "ERROR: FtsRocksdbReducer reduce failed, status[%s]\n", + reduce_result.error().message().c_str()); + src_store.close(); + dst_store.close(); + return -1; + } + + // Copy forward CF (doc_id -> corpus_id mapping) + if (src_forward && dst_forward) { + std::cout << " Copying forward CF..." << std::endl; + auto iter = std::unique_ptr( + src_store.db_->NewIterator(src_store.read_opts_, src_forward)); + while (iter->Valid()) { + dst_store.db_->Put(dst_store.write_opts_, dst_forward, + iter->key().ToString(), iter->value().ToString()); + iter->Next(); + } + // iter auto-closes via unique_ptr + } + + // Flush and compact destination. Side CFs are not present here. + dst_store.flush(); + // compact not available in RocksdbContext + + + uint64_t reduce_ms = reduce_timer.milli_seconds(); + + std::cout << "=== REDUCE COMPLETE ===" << std::endl; + std::cout << " Reduce time : " << reduce_ms << " ms" << std::endl; + std::cout << " Output path : " << dst_index_path << std::endl; + + (void)reducer.cleanup(); + src_store.close(); + dst_store.close(); + return 0; +} + + +struct CorpusEntry { + uint32_t doc_id; + std::string corpus_id; + std::string content; +}; + +static int do_build() { + const int num_threads = std::max(1, FLAGS_threads); + std::cout << "=== BUILD MODE ===" << std::endl; + std::cout << "Index : " << FLAGS_index << std::endl; + std::cout << "Corpus : " << FLAGS_corpus << std::endl; + std::cout << "Field : " << FLAGS_field << std::endl; + std::cout << "Threads: " << num_threads << std::endl; + std::cout << "ExtraParams: " << FLAGS_extra_params << std::endl; + + // Remove existing index directory so that RocksdbContext::create() starts + // fresh (it requires the path to NOT exist). + if (zvec::FileHelper::DirectoryExists(FLAGS_index)) { + std::cout << "Removing existing index directory: " << FLAGS_index + << std::endl; + zvec::FileHelper::RemoveDirectory(FLAGS_index); + } + + // Open RocksDB (new) + RocksdbContext store; + if (!open_fts_store(&store, FLAGS_field, /*existing=*/false)) { + return -1; + } + + // Get column families + const std::string max_tf_cf_name = FLAGS_field + zvec::kFtsMaxTfSuffix; + + rocksdb::ColumnFamilyHandle *postings_cf = store.get_cf(FLAGS_field); + rocksdb::ColumnFamilyHandle *positions_cf = + store.get_cf(FLAGS_field + zvec::kFtsPositionsSuffix); + rocksdb::ColumnFamilyHandle *term_freq_cf = + store.get_cf(FLAGS_field + zvec::kFtsTfSuffix); + rocksdb::ColumnFamilyHandle *max_tf_cf = store.get_cf(max_tf_cf_name); + rocksdb::ColumnFamilyHandle *doc_len_cf = + store.get_cf(FLAGS_field + zvec::kFtsDocLenSuffix); + rocksdb::ColumnFamilyHandle *stat_cf = store.get_cf(zvec::kFtsStatCfName); + rocksdb::ColumnFamilyHandle *forward_cf = store.get_cf(kForwardCfName); + + if (!postings_cf || !positions_cf || !term_freq_cf || !max_tf_cf || + !doc_len_cf || !stat_cf || !forward_cf) { + fprintf(stderr, "ERROR: Failed to get column families\n"); + return -1; + } + + // Pre-load all corpus entries into memory with pre-assigned doc_ids + std::vector corpus_entries; + uint64_t parse_failed_count = 0; + { + std::ifstream corpus_file(FLAGS_corpus); + if (!corpus_file.is_open()) { + fprintf(stderr, "ERROR: Failed to open corpus file: %s\n", + FLAGS_corpus.c_str()); + return -1; + } + + uint32_t doc_id = 0; + std::string line; + while (std::getline(corpus_file, line)) { + if (line.empty()) continue; + + std::unordered_map fields; + if (!parse_jsonl_line(line, &fields)) { + fprintf(stderr, "WARN: Failed to parse line: %s\n", + line.substr(0, 100).c_str()); + ++parse_failed_count; + continue; + } + + const std::string &corpus_id = fields["_id"]; + if (corpus_id.empty()) { + ++parse_failed_count; + continue; + } + + std::string content; + if (!fields["title"].empty()) { + content = fields["title"] + " " + fields["text"]; + } else { + content = fields["text"]; + } + + corpus_entries.push_back( + {doc_id, std::move(corpus_id), std::move(content)}); + ++doc_id; + } + } + std::cout << "Loaded " << corpus_entries.size() << " corpus entries." + << std::endl; + if (parse_failed_count > 0) { + std::cout << " Warning: " << parse_failed_count + << " entries failed to parse." << std::endl; + } + + auto fts_params = build_fts_index_params(FLAGS_extra_params); + auto field_meta = make_fts_field_schema(FLAGS_field, fts_params); + + FtsColumnIndexer indexer; + auto open_result = indexer.open(field_meta, &store, postings_cf, positions_cf, + term_freq_cf, max_tf_cf, doc_len_cf, stat_cf); + if (!open_result.has_value()) { + fprintf(stderr, "ERROR: Failed to open FtsColumnIndexer, status[%s]\n", + open_result.error().message().c_str()); + return -1; + } + + // Shared atomic index for work-stealing across threads + std::atomic next_entry_index{0}; + + // Per-thread result accumulators + struct ThreadResult { + uint64_t indexed_count{0}; + uint64_t failed_count{0}; + }; + std::vector thread_results(num_threads); + + std::cout << "Building index with " << num_threads << " thread(s)..." + << std::endl; + + zvec::ailego::ElapsedTime timer; + + auto worker = [&](int thread_id) { + ThreadResult &result = thread_results[thread_id]; + + while (true) { + size_t entry_idx = + next_entry_index.fetch_add(1, std::memory_order_relaxed); + if (entry_idx >= corpus_entries.size()) break; + + const CorpusEntry &entry = corpus_entries[entry_idx]; + + auto insert_result = indexer.insert(entry.doc_id, entry.content); + if (!insert_result.has_value()) { + fprintf(stderr, + "WARN: Thread[%d] failed to insert doc_id[%u] corpus_id[%s], " + "status[%s]\n", + thread_id, entry.doc_id, entry.corpus_id.c_str(), + insert_result.error().message().c_str()); + ++result.failed_count; + continue; + } + + // Write forward mapping: doc_id -> corpus_id + std::string doc_id_key; + fts::encode_uint32_big_endian(entry.doc_id, &doc_id_key); + store.db_->Put(store.write_opts_, forward_cf, doc_id_key, + entry.corpus_id); + + ++result.indexed_count; + + // Progress reporting (only from thread 0 to avoid interleaving) + if (thread_id == 0 && result.indexed_count % 1000 == 0) { + size_t total_done = 0; + for (const auto &tr : thread_results) { + total_done += tr.indexed_count + tr.failed_count; + } + std::cout << "\r Indexed ~" << total_done << " / " + << corpus_entries.size() << " docs..." << std::flush; + } + } + }; + + // Launch threads + std::vector threads; + threads.reserve(num_threads); + for (int thread_id = 0; thread_id < num_threads; ++thread_id) { + threads.emplace_back(worker, thread_id); + } + for (auto &thread : threads) { + thread.join(); + } + + uint64_t build_ms = timer.milli_seconds(); + + // Merge per-thread results + uint64_t total_indexed = 0; + uint64_t total_failed = 0; + for (const auto &result : thread_results) { + total_indexed += result.indexed_count; + total_failed += result.failed_count; + } + + std::cout << "\r Indexed " << total_indexed << " docs total." << std::endl; + if (total_failed > 0) { + std::cout << " Warning: " << total_failed << " docs failed to index." + << std::endl; + } + + // Flush statistics — single indexer tracks all docs/tokens atomically + std::cout << "Flushing statistics (total_docs=" << indexer.total_docs() + << ", total_tokens=" << indexer.total_tokens() << ")..." + << std::endl; + auto flush_result = indexer.flush(); + if (!flush_result.has_value()) { + fprintf(stderr, "WARN: FtsColumnIndexer flush failed, status[%s]\n", + flush_result.error().message().c_str()); + } + + // Convert Roaring postings to BitPacked before close/dump, mirroring + // MutableSegment::dump_fts_column_indexers(). Must run before close() + // for symmetry with the single-threaded path; convert itself does not + // depend on the tokenizer pipeline. + std::cout << "Converting postings to BitPacked..." << std::endl; + zvec::ailego::ElapsedTime bitpacked_timer2; + auto bitpacked_result = indexer.convert_postings_to_bitpacked(); + if (!bitpacked_result.has_value()) { + fprintf(stderr, + "WARN: FtsColumnIndexer convert_postings_to_bitpacked failed, " + "status[%s]\n", + bitpacked_result.error().message().c_str()); + } + std::cout << "convert_postings_to_bitpacked took " + << bitpacked_timer2.micro_seconds() / 1000.0 << " ms" << std::endl; + + // Drop $TF / $MAX_TF / $DOC_LEN CFs after their payloads have been inlined + // into BitPacked postings. Mirrors MutableSegment::dump_fts_column_ + // indexers(): reset_side_cfs() first so any concurrent reader-path access + // through the indexer falls back to default tf=1/doc_len=1 instead of + // touching a dropped handle, then drop the CFs from the underlying store. + indexer.reset_side_cfs(); + drop_fts_side_cfs(&store, FLAGS_field); + // Local pointers are now dangling; null them out so accidental use becomes + // an obvious crash instead of a use-after-free. + term_freq_cf = nullptr; + max_tf_cf = nullptr; + doc_len_cf = nullptr; + + (void)indexer.close(); + + // Flush RocksDB memtables + dump checkpoint + zvec::ailego::ElapsedTime dump_timer; + store.flush(); + + // Trigger compaction + checkpoint + std::cout << "Running compaction..." << std::endl; + store.compact(); + + uint64_t dump_ms = dump_timer.milli_seconds(); + uint64_t elapsed_ms = timer.milli_seconds(); + std::cout << "=== BUILD COMPLETE ===" << std::endl; + std::cout << " Total docs : " << total_indexed << std::endl; + std::cout << " Threads : " << num_threads << std::endl; + std::cout << " Build time : " << build_ms << " ms" << std::endl; + std::cout << " Dump time : " << dump_ms << " ms (flush + compaction)" + << std::endl; + std::cout << " Total time : " << elapsed_ms << " ms" << std::endl; + std::cout << " Throughput : " + << (total_indexed > 0 + ? total_indexed * 1000ULL / (build_ms > 0 ? build_ms : 1) + : 0) + << " docs/s (build only)" << std::endl; + + store.close(); + + // Optional: run reduce to convert postings to BitPacked format + if (FLAGS_reduce) { + int reduce_ret = do_reduce(FLAGS_index, total_indexed); + if (reduce_ret != 0) { + fprintf(stderr, "ERROR: Reduce step failed, ret[%d]\n", reduce_ret); + return reduce_ret; + } + } + + return 0; +} + +// --------------------------------------------------------------------------- +// BUILD MODE (db): use zvec Collection API +// --------------------------------------------------------------------------- +static int do_build_db() { + const int num_threads = std::max(1, FLAGS_threads); + std::cout << "=== BUILD MODE (db) ===" << std::endl; + std::cout << "Index : " << FLAGS_index << std::endl; + std::cout << "Corpus : " << FLAGS_corpus << std::endl; + std::cout << "Field : " << FLAGS_field << std::endl; + std::cout << "Threads: " << num_threads << std::endl; + + // Remove existing collection directory + if (zvec::FileHelper::DirectoryExists(FLAGS_index)) { + std::cout << "Removing existing collection directory: " << FLAGS_index + << std::endl; + zvec::FileHelper::RemoveDirectory(FLAGS_index); + } + + // Build schema: pk (implicit) + FTS field + dummy vector field (required + // by segment layer). + // Build FtsIndexParams from FLAGS_extra_params so that the tokenizer + // pipeline configuration (e.g. enable_simple_closet) matches raw mode. + auto db_fts_params = build_fts_index_params(FLAGS_extra_params); + + CollectionSchema schema("fts_bench"); + schema.add_field(std::make_shared(FLAGS_field, DataType::STRING, + false, db_fts_params)); + // Segment layer requires at least one vector field. Do NOT set + // index_params: fts_bench links with PACKED mode which strips core-layer + // metric static registrations, so creating a vector index would fail with + // "Failed to create metric". An unindexed vector field is sufficient. + schema.add_field(std::make_shared( + "__dummy_vec", DataType::VECTOR_FP32, 4, /*nullable=*/true)); + + CollectionOptions options; + options.read_only_ = false; + + auto create_result = Collection::CreateAndOpen(FLAGS_index, schema, options); + if (!create_result.has_value()) { + fprintf(stderr, "ERROR: Failed to create collection at [%s]: %s\n", + FLAGS_index.c_str(), create_result.error().message().c_str()); + return -1; + } + auto collection = create_result.value(); + + // Pre-load corpus entries + std::vector corpus_entries; + uint64_t parse_failed_count = 0; + { + std::ifstream corpus_file(FLAGS_corpus); + if (!corpus_file.is_open()) { + fprintf(stderr, "ERROR: Failed to open corpus file: %s\n", + FLAGS_corpus.c_str()); + return -1; + } + uint32_t doc_id = 0; + std::string line; + while (std::getline(corpus_file, line)) { + if (line.empty()) continue; + std::unordered_map fields; + if (!parse_jsonl_line(line, &fields)) { + ++parse_failed_count; + continue; + } + const std::string &corpus_id = fields["_id"]; + if (corpus_id.empty()) { + ++parse_failed_count; + continue; + } + std::string content; + if (!fields["title"].empty()) { + content = fields["title"] + " " + fields["text"]; + } else { + content = fields["text"]; + } + corpus_entries.push_back( + {doc_id, std::move(corpus_id), std::move(content)}); + ++doc_id; + } + } + std::cout << "Loaded " << corpus_entries.size() << " corpus entries." + << std::endl; + if (parse_failed_count > 0) { + std::cout << " Warning: " << parse_failed_count + << " entries failed to parse." << std::endl; + } + + // Insert in batches via Collection::Insert + const size_t batch_size = 1000; + uint64_t total_indexed = 0; + uint64_t total_failed = 0; + + std::cout << "Inserting documents via Collection API..." << std::endl; + zvec::ailego::ElapsedTime timer; + + for (size_t offset = 0; offset < corpus_entries.size(); + offset += batch_size) { + size_t end = std::min(offset + batch_size, corpus_entries.size()); + std::vector docs; + docs.reserve(end - offset); + for (size_t i = offset; i < end; ++i) { + const CorpusEntry &entry = corpus_entries[i]; + Doc doc; + doc.set_pk(entry.corpus_id); + doc.set(FLAGS_field, entry.content); + // dummy vector (nullable field still needs a value for WAL/forward) + doc.set>("__dummy_vec", {0.0f, 0.0f, 0.0f, 0.0f}); + docs.push_back(std::move(doc)); + } + auto insert_result = collection->Insert(docs); + if (!insert_result.has_value()) { + fprintf(stderr, "WARN: Batch insert failed at offset[%zu]: %s\n", offset, + insert_result.error().message().c_str()); + total_failed += (end - offset); + } else { + total_indexed += (end - offset); + } + if (total_indexed % 10000 < batch_size) { + std::cout << "\r Inserted " << total_indexed << " / " + << corpus_entries.size() << " docs..." << std::flush; + } + } + + uint64_t build_ms = timer.milli_seconds(); + + // Flush collection + auto flush_status = collection->Flush(); + if (!flush_status.ok()) { + fprintf(stderr, "WARN: Collection flush failed: %s\n", + flush_status.message().c_str()); + } + + // Optimize triggers segment dump which converts Roaring postings to + // BitPacked format (with inline tf/doc_len payloads). Without this step + // the immutable reader path falls back to tf=1/doc_len=1 because the + // side CFs (_tf/_doc_len/_max_tf) are not opened for read-only segments. + auto optimize_status = collection->Optimize(); + if (!optimize_status.ok()) { + fprintf(stderr, "WARN: Collection optimize failed: %s\n", + optimize_status.message().c_str()); + } + + std::cout << "\r Inserted " << total_indexed << " docs total." << std::endl; + if (total_failed > 0) { + std::cout << " Warning: " << total_failed << " docs failed to insert." + << std::endl; + } + std::cout << "=== BUILD COMPLETE (db) ===" << std::endl; + std::cout << " Total docs : " << total_indexed << std::endl; + std::cout << " Build time : " << build_ms << " ms" << std::endl; + std::cout << " Throughput : " + << (total_indexed > 0 + ? total_indexed * 1000ULL / (build_ms > 0 ? build_ms : 1) + : 0) + << " docs/s" << std::endl; + + return 0; +} + +// --------------------------------------------------------------------------- +// SEARCH MODE +// --------------------------------------------------------------------------- + +// Parse qrels TSV file: returns map of query_id -> set +static std::unordered_map> +load_qrels(const std::string &qrels_dir) { + std::unordered_map> qrels; + + // Try test.tsv first, then train.tsv + std::vector candidates = {qrels_dir + "/test.tsv", + qrels_dir + "/train.tsv"}; + std::string qrels_file; + for (const auto &f : candidates) { + if (FileHelper::FileExists(f)) { + qrels_file = f; + break; + } + } + + if (qrels_file.empty()) { + fprintf(stderr, "ERROR: No qrels file found in directory: %s\n", + qrels_dir.c_str()); + return qrels; + } + + std::cout << "Loading qrels from: " << qrels_file << std::endl; + + std::ifstream f(qrels_file); + if (!f.is_open()) { + fprintf(stderr, "ERROR: Failed to open qrels file: %s\n", + qrels_file.c_str()); + return qrels; + } + + std::string line; + bool first_line = true; + while (std::getline(f, line)) { + if (first_line) { + first_line = false; + continue; // skip header + } + if (line.empty()) continue; + + std::istringstream ss(line); + std::string query_id, corpus_id, score_str; + if (!std::getline(ss, query_id, '\t') || + !std::getline(ss, corpus_id, '\t') || + !std::getline(ss, score_str, '\t')) { + continue; + } + // Only include relevant docs (score > 0) + int score = std::stoi(score_str); + if (score > 0) { + qrels[query_id].insert(corpus_id); + } + } + + std::cout << "Loaded qrels for " << qrels.size() << " queries." << std::endl; + return qrels; +} + +// --------------------------------------------------------------------------- +// Unified single-/multi-threaded search: +// * Always pre-loads queries into memory and dispatches them to +// FLAGS_threads workers via an atomic index counter. +// * FtsColumnIndexer::search() and the shared TokenizerPipeline are both +// read-only / fork-safe, so a single shared reader and pipeline are +// reused across workers. +// * When FLAGS_threads == 1 the path collapses to a single worker, +// behaving equivalently to a sequential single-threaded search. +// --------------------------------------------------------------------------- + +struct QueryEntry { + std::string query_id; + std::string match_text; +}; + +struct RecallCounter { + double sum{0.0}; + uint64_t total{0}; + void add(double recall_value) { + sum += recall_value; + total++; + } + double ratio() const { + return total > 0 ? sum / static_cast(total) : 0.0; + } +}; + + +static int do_search() { + if (!validate_default_operator(FLAGS_default_operator)) { + fprintf(stderr, + "ERROR: Invalid -default_operator[%s]. Must be 'or' or 'and'.\n", + FLAGS_default_operator.c_str()); + return -1; + } + + const int num_threads = std::max(1, FLAGS_threads); + + const std::string fts_index_path = FLAGS_index; + + std::cout << "=== SEARCH MODE ===" << std::endl; + std::cout << "Index : " << fts_index_path << std::endl; + std::cout << "Query : " << FLAGS_query << std::endl; + std::cout << "Qrels : " << FLAGS_qrels << std::endl; + std::cout << "TopK : " << FLAGS_topk << std::endl; + std::cout << "Field : " << FLAGS_field << std::endl; + std::cout << "Threads : " << num_threads << std::endl; + std::cout << "Default operator : " << FLAGS_default_operator << std::endl; + + // Open FTS RocksDB (existing) — shared across threads (RocksDB reads are + // thread-safe at the CF level). Open without $TF/$MAX_TF/$DOC_LEN since + // those CFs were dropped at build time after convert_postings_to_bitpacked(). + RocksdbContext store; + if (!open_fts_store(&store, FLAGS_field, /*existing=*/true, + /*index_path=*/fts_index_path, + /*with_side_cfs=*/false, + /*with_forward_cf=*/true)) { + return -1; + } + + rocksdb::ColumnFamilyHandle *postings_cf = store.get_cf(FLAGS_field); + rocksdb::ColumnFamilyHandle *positions_cf = + store.get_cf(FLAGS_field + zvec::kFtsPositionsSuffix); + rocksdb::ColumnFamilyHandle *stat_cf = store.get_cf(zvec::kFtsStatCfName); + rocksdb::ColumnFamilyHandle *forward_cf = store.get_cf(kForwardCfName); + + if (!postings_cf || !positions_cf || !stat_cf || !forward_cf) { + fprintf(stderr, "ERROR: Failed to get column families\n"); + return -1; + } + + // Load qrels + auto qrels = load_qrels(FLAGS_qrels); + + // Pre-load all queries into memory so threads can access them without I/O + // contention + std::vector queries; + { + std::ifstream query_file(FLAGS_query); + if (!query_file.is_open()) { + fprintf(stderr, "ERROR: Failed to open query file: %s\n", + FLAGS_query.c_str()); + return -1; + } + std::string line; + while (std::getline(query_file, line)) { + if (line.empty()) continue; + std::unordered_map fields; + if (!parse_jsonl_line(line, &fields)) continue; + const std::string &query_id = fields["_id"]; + const std::string &query_text = fields["text"]; + if (query_id.empty() || query_text.empty()) continue; + queries.push_back({query_id, query_text}); + } + } + std::cout << "Loaded " << queries.size() << " queries." << std::endl; + + // Limit the number of queries if configured: first keep only queries that + // have qrels entries (relevant results), then truncate to max_queries. + if (FLAGS_max_queries > 0) { + std::vector filtered; + for (auto &q : queries) { + if (qrels.count(q.query_id) > 0) { + filtered.push_back(std::move(q)); + } + } + queries = std::move(filtered); + if (static_cast(FLAGS_max_queries) < queries.size()) { + queries.resize(FLAGS_max_queries); + } + std::cout << "Limited to " << queries.size() + << " queries with qrels (--max_queries)." << std::endl; + } + + // Shared atomic index for work-stealing across threads + std::atomic next_query_index{0}; + + // Per-thread result accumulators, merged after all threads finish + struct ThreadResult { + LatencyStats latency_stats; + RecallCounter recall1; + RecallCounter recall5; + RecallCounter recall10; + RecallCounter recallK; + uint64_t no_result_count{0}; + uint64_t query_count{0}; + }; + std::vector thread_results(num_threads); + + auto query_fts_params = build_fts_index_params(FLAGS_extra_params); + auto pipeline_result = zvec::detail::AcquireFtsPipeline(*query_fts_params); + if (!pipeline_result.has_value()) { + fprintf(stderr, + "ERROR: Failed to create tokenizer pipeline for " + "extra_params[%s]: %s\n", + FLAGS_extra_params.c_str(), + pipeline_result.error().message().c_str()); + return -1; + } + auto &query_pipeline = pipeline_result.value(); + + std::cout << "Running queries with " << num_threads << " thread(s)..." + << std::endl; + + // Create a single shared FtsColumnIndexer in read-only mode. search() is a + // const method that only performs read-only RocksDB lookups, so it is safe + // to share across threads. + FtsColumnIndexer reader; + { + // $TF/$MAX_TF/$DOC_LEN are dropped at build time; pass nullptr — the + // BitPacked path doesn't need them and the Roaring fallback degrades + // to default tf=1/doc_len=1 when these pointers are null. + auto open_result = + reader.open_reader(FLAGS_field, &store, postings_cf, positions_cf, + /*term_freq_cf=*/nullptr, + /*max_tf_cf=*/nullptr, + /*doc_len_cf=*/nullptr, stat_cf); + if (!open_result.has_value()) { + fprintf(stderr, "ERROR: Failed to open FtsColumnIndexer, status[%s]\n", + open_result.error().message().c_str()); + return -1; + } + } + + auto worker = [&](int thread_id) { + ThreadResult &result = thread_results[thread_id]; + + while (true) { + size_t query_idx = + next_query_index.fetch_add(1, std::memory_order_relaxed); + if (query_idx >= queries.size()) break; + + const QueryEntry &entry = queries[query_idx]; + + std::vector results; + bool search_ok = true; + uint64_t elapsed_us = 0; + { + zvec::ailego::ElapsedTime timer; + // Tokenize query text (match_string style: tokenize then build AST + // combining tokens with the configured default operator). + auto tokens = query_pipeline->process(entry.match_text); + auto ast_root = + build_query_ast_from_tokens(tokens, FLAGS_default_operator); + if (ast_root) { + fts::FtsQueryParams query_params; + query_params.topk = static_cast(FLAGS_topk); + auto search_result = reader.search(*ast_root, query_params); + if (!search_result.has_value()) { + fprintf(stderr, + "WARN: Thread[%d] search failed for query_id[%s], " + "status[%s]\n", + thread_id, entry.query_id.c_str(), + search_result.error().message().c_str()); + search_ok = false; + } else { + results = std::move(search_result.value()); + } + } + elapsed_us = timer.micro_seconds(); + } + + if (!search_ok) { + continue; + } + + result.latency_stats.add(elapsed_us); + ++result.query_count; + + if (results.empty()) { + ++result.no_result_count; + } + + // Resolve doc_id -> corpus_id (a.k.a. pk) via the forward CF. + std::vector retrieved_corpus_ids; + retrieved_corpus_ids.reserve(results.size()); + for (const auto &r : results) { + std::string corpus_id; + std::string doc_id_key; + fts::encode_uint32_big_endian(r.doc_id, &doc_id_key); + if (!store.db_ + ->Get(store.read_opts_, forward_cf, doc_id_key, &corpus_id) + .ok()) { + corpus_id = ""; + } + retrieved_corpus_ids.push_back(corpus_id); + } + + // Compute recall at various cutoffs + const auto qrels_it = qrels.find(entry.query_id); + if (qrels_it == qrels.end()) continue; + + const auto &relevant = qrels_it->second; + + // Standard IR Recall@K = |retrieved_topK ∩ relevant| / |relevant| + auto compute_recall = [&](int cutoff) -> double { + if (relevant.empty()) return 0.0; + int limit = + std::min(cutoff, static_cast(retrieved_corpus_ids.size())); + int hits = 0; + for (int i = 0; i < limit; ++i) { + if (relevant.count(retrieved_corpus_ids[i]) > 0) { + hits++; + } + } + return static_cast(hits) / static_cast(relevant.size()); + }; + + result.recall1.add(compute_recall(1)); + result.recall5.add(compute_recall(5)); + result.recall10.add(compute_recall(10)); + result.recallK.add(compute_recall(FLAGS_topk)); + } + }; + + // Launch threads and measure total wall-clock time + auto wall_start = std::chrono::steady_clock::now(); + + std::vector threads; + threads.reserve(num_threads); + for (int thread_id = 0; thread_id < num_threads; ++thread_id) { + threads.emplace_back(worker, thread_id); + } + for (auto &thread : threads) { + thread.join(); + } + + auto wall_end = std::chrono::steady_clock::now(); + uint64_t wall_ms = static_cast( + std::chrono::duration_cast(wall_end - + wall_start) + .count()); + + // Merge per-thread results + LatencyStats merged_latency; + RecallCounter merged_recall1, merged_recall5, merged_recall10, merged_recallK; + uint64_t total_query_count = 0; + uint64_t total_no_result_count = 0; + + for (const auto &result : thread_results) { + for (uint64_t sample : result.latency_stats.samples) { + merged_latency.add(sample); + } + merged_recall1.sum += result.recall1.sum; + merged_recall1.total += result.recall1.total; + merged_recall5.sum += result.recall5.sum; + merged_recall5.total += result.recall5.total; + merged_recall10.sum += result.recall10.sum; + merged_recall10.total += result.recall10.total; + merged_recallK.sum += result.recallK.sum; + merged_recallK.total += result.recallK.total; + total_query_count += result.query_count; + total_no_result_count += result.no_result_count; + } + + // Output results + std::cout << std::endl; + std::cout << "=== SEARCH RESULTS ===" << std::endl; + std::cout << "Threads : " << num_threads << std::endl; + std::cout << "Total queries : " << total_query_count << std::endl; + std::cout << "No-result queries: " << total_no_result_count << std::endl; + std::cout << "Wall-clock time : " << wall_ms << " ms" << std::endl; + if (wall_ms > 0) { + std::cout << "Throughput : " << total_query_count * 1000ULL / wall_ms + << " queries/s" << std::endl; + } + std::cout << std::endl; + + merged_latency.print("Search (per-query)"); + std::cout << std::endl; + + if (merged_recall1.total > 0) { + std::cout << "=== RECALL ===" << std::endl; + std::cout << " Recall@1 : " << merged_recall1.ratio() << " (evaluated on " + << merged_recall1.total << " queries)" << std::endl; + std::cout << " Recall@5 : " << merged_recall5.ratio() << " (evaluated on " + << merged_recall5.total << " queries)" << std::endl; + std::cout << " Recall@10 : " << merged_recall10.ratio() + << " (evaluated on " << merged_recall10.total << " queries)" + << std::endl; + if (FLAGS_topk > 10) { + std::cout << " Recall@" << FLAGS_topk << " : " << merged_recallK.ratio() + << " (evaluated on " << merged_recallK.total << " queries)" + << std::endl; + } + } else { + std::cout << "No qrels matched for evaluated queries." << std::endl; + } + + store.close(); + return 0; +} + +// --------------------------------------------------------------------------- +// SEARCH MODE (db): use zvec Collection::Query(Fts) +// --------------------------------------------------------------------------- +static int do_search_db() { + const int num_threads = std::max(1, FLAGS_threads); + + std::cout << "=== SEARCH MODE (db) ===" << std::endl; + std::cout << "Index : " << FLAGS_index << std::endl; + std::cout << "Query : " << FLAGS_query << std::endl; + std::cout << "Qrels : " << FLAGS_qrels << std::endl; + std::cout << "TopK : " << FLAGS_topk << std::endl; + std::cout << "Field : " << FLAGS_field << std::endl; + std::cout << "Threads : " << num_threads << std::endl; + + // Open existing collection in read-only mode + CollectionOptions options; + options.read_only_ = true; + + auto open_result = Collection::Open(FLAGS_index, options); + if (!open_result.has_value()) { + fprintf(stderr, "ERROR: Failed to open collection at [%s]: %s\n", + FLAGS_index.c_str(), open_result.error().message().c_str()); + return -1; + } + auto collection = open_result.value(); + + // Load qrels + auto qrels = load_qrels(FLAGS_qrels); + + // Pre-load queries + std::vector queries; + { + std::ifstream query_file(FLAGS_query); + if (!query_file.is_open()) { + fprintf(stderr, "ERROR: Failed to open query file: %s\n", + FLAGS_query.c_str()); + return -1; + } + std::string line; + while (std::getline(query_file, line)) { + if (line.empty()) continue; + std::unordered_map fields; + if (!parse_jsonl_line(line, &fields)) continue; + const std::string &query_id = fields["_id"]; + const std::string &query_text = fields["text"]; + if (query_id.empty() || query_text.empty()) continue; + queries.push_back({query_id, query_text}); + } + } + std::cout << "Loaded " << queries.size() << " queries." << std::endl; + + // Limit the number of queries if configured: first keep only queries that + // have qrels entries (relevant results), then truncate to max_queries. + if (FLAGS_max_queries > 0) { + std::vector filtered; + for (auto &q : queries) { + if (qrels.count(q.query_id) > 0) { + filtered.push_back(std::move(q)); + } + } + queries = std::move(filtered); + if (static_cast(FLAGS_max_queries) < queries.size()) { + queries.resize(FLAGS_max_queries); + } + std::cout << "Limited to " << queries.size() + << " queries with qrels (--max_queries)." << std::endl; + } + + // Per-thread result accumulators + std::atomic next_query_index{0}; + std::atomic fatal_error{false}; + + struct ThreadResult { + LatencyStats latency_stats; + RecallCounter recall1; + RecallCounter recall5; + RecallCounter recall10; + RecallCounter recallK; + uint64_t no_result_count{0}; + uint64_t query_count{0}; + }; + std::vector thread_results(num_threads); + + std::cout << "Running queries via Collection API with " << num_threads + << " thread(s)..." << std::endl; + + auto worker = [&](int thread_id) { + ThreadResult &result = thread_results[thread_id]; + + while (true) { + if (fatal_error.load(std::memory_order_relaxed)) break; + size_t query_idx = + next_query_index.fetch_add(1, std::memory_order_relaxed); + if (query_idx >= queries.size()) break; + + const QueryEntry &entry = queries[query_idx]; + + VectorQuery vq; + vq.field_name_ = FLAGS_field; + vq.topk_ = FLAGS_topk; + Fts fts; + fts.match_string_ = entry.match_text; + vq.fts_ = fts; + + uint64_t elapsed_us = 0; + std::vector retrieved_corpus_ids; + { + zvec::ailego::ElapsedTime query_timer; + auto query_result = collection->Query(vq); + elapsed_us = query_timer.micro_seconds(); + + if (query_result.has_value()) { + const auto &doc_list = query_result.value(); + retrieved_corpus_ids.reserve(doc_list.size()); + for (const auto &doc_ptr : doc_list) { + retrieved_corpus_ids.push_back(doc_ptr->pk()); + } + } else { + fprintf(stderr, "ERROR: Thread[%d] Fts failed for query_id[%s]: %s\n", + thread_id, entry.query_id.c_str(), + query_result.error().message().c_str()); + fatal_error.store(true, std::memory_order_relaxed); + break; + } + } + + result.latency_stats.add(elapsed_us); + ++result.query_count; + + if (retrieved_corpus_ids.empty()) { + ++result.no_result_count; + } + + // Compute recall + const auto qrels_it = qrels.find(entry.query_id); + if (qrels_it == qrels.end()) continue; + const auto &relevant = qrels_it->second; + + auto compute_recall = [&](int cutoff) -> double { + if (relevant.empty()) return 0.0; + int limit = + std::min(cutoff, static_cast(retrieved_corpus_ids.size())); + int hits = 0; + for (int i = 0; i < limit; ++i) { + if (relevant.count(retrieved_corpus_ids[i]) > 0) { + hits++; + } + } + return static_cast(hits) / static_cast(relevant.size()); + }; + + result.recall1.add(compute_recall(1)); + result.recall5.add(compute_recall(5)); + result.recall10.add(compute_recall(10)); + result.recallK.add(compute_recall(FLAGS_topk)); + } + }; + + auto wall_start = std::chrono::steady_clock::now(); + std::vector threads; + threads.reserve(num_threads); + for (int thread_id = 0; thread_id < num_threads; ++thread_id) { + threads.emplace_back(worker, thread_id); + } + for (auto &thread : threads) { + thread.join(); + } + + if (fatal_error.load()) { + fprintf(stderr, "ERROR: Aborting: Fts failed during search\n"); + return -1; + } + + auto wall_end = std::chrono::steady_clock::now(); + uint64_t wall_ms = static_cast( + std::chrono::duration_cast(wall_end - + wall_start) + .count()); + + // Merge per-thread results + LatencyStats merged_latency; + RecallCounter merged_recall1, merged_recall5, merged_recall10, merged_recallK; + uint64_t total_query_count = 0; + uint64_t total_no_result_count = 0; + + for (const auto &result : thread_results) { + for (uint64_t sample : result.latency_stats.samples) { + merged_latency.add(sample); + } + merged_recall1.sum += result.recall1.sum; + merged_recall1.total += result.recall1.total; + merged_recall5.sum += result.recall5.sum; + merged_recall5.total += result.recall5.total; + merged_recall10.sum += result.recall10.sum; + merged_recall10.total += result.recall10.total; + merged_recallK.sum += result.recallK.sum; + merged_recallK.total += result.recallK.total; + total_query_count += result.query_count; + total_no_result_count += result.no_result_count; + } + + std::cout << std::endl; + std::cout << "=== SEARCH RESULTS (db) ===" << std::endl; + std::cout << "Threads : " << num_threads << std::endl; + std::cout << "Total queries : " << total_query_count << std::endl; + std::cout << "No-result queries: " << total_no_result_count << std::endl; + std::cout << "Wall-clock time : " << wall_ms << " ms" << std::endl; + if (wall_ms > 0) { + std::cout << "Throughput : " << total_query_count * 1000ULL / wall_ms + << " queries/s" << std::endl; + } + std::cout << std::endl; + + merged_latency.print("Search (per-query)"); + std::cout << std::endl; + + if (merged_recall1.total > 0) { + std::cout << "=== RECALL ===" << std::endl; + std::cout << " Recall@1 : " << merged_recall1.ratio() << " (evaluated on " + << merged_recall1.total << " queries)" << std::endl; + std::cout << " Recall@5 : " << merged_recall5.ratio() << " (evaluated on " + << merged_recall5.total << " queries)" << std::endl; + std::cout << " Recall@10 : " << merged_recall10.ratio() + << " (evaluated on " << merged_recall10.total << " queries)" + << std::endl; + if (FLAGS_topk > 10) { + std::cout << " Recall@" << FLAGS_topk << " : " << merged_recallK.ratio() + << " (evaluated on " << merged_recallK.total << " queries)" + << std::endl; + } + } else { + std::cout << "No qrels matched for evaluated queries." << std::endl; + } + + return 0; +} + +// --------------------------------------------------------------------------- +// STATS MODE +// --------------------------------------------------------------------------- +static int do_stats() { + std::cout << "=== STATS MODE ===" << std::endl; + std::cout << "Index : " << FLAGS_index << std::endl; + std::cout << "Field : " << FLAGS_field << std::endl; + + // Open RocksDB (existing). $TF/$MAX_TF/$DOC_LEN are dropped at build + // time, so open without them. Sections that scan these CFs are now + // gated on the corresponding pointers being non-null (always null here + // post-drop) and simply skipped with an explanatory message. + RocksdbContext store; + if (!open_fts_store(&store, FLAGS_field, /*existing=*/true, + /*index_path=*/"", /*with_side_cfs=*/false)) { + return -1; + } + + rocksdb::ColumnFamilyHandle *postings_cf = store.get_cf(FLAGS_field); + rocksdb::ColumnFamilyHandle *stat_cf = store.get_cf(zvec::kFtsStatCfName); + // $MAX_TF/$DOC_LEN are not opened above; keep nullptrs so the + // doc-length / max-tf scan sections degrade gracefully. + rocksdb::ColumnFamilyHandle *max_tf_cf = nullptr; + rocksdb::ColumnFamilyHandle *doc_len_cf = nullptr; + + if (!postings_cf || !stat_cf) { + fprintf(stderr, "ERROR: Failed to get required column families\n"); + return -1; + } + + // --------------------------------------------------------------- + // 1. Segment-level statistics (total_docs, total_tokens) + // --------------------------------------------------------------- + uint64_t total_docs = 0; + uint64_t total_tokens = 0; + { + const std::string total_docs_key = FLAGS_field + "_total_docs"; + const std::string total_tokens_key = FLAGS_field + "_total_tokens"; + std::string value; + if (store.db_->Get(store.read_opts_, stat_cf, total_docs_key, &value) + .ok() && + value.size() >= sizeof(uint64_t)) { + std::memcpy(&total_docs, value.data(), sizeof(uint64_t)); + } + value.clear(); + if (store.db_->Get(store.read_opts_, stat_cf, total_tokens_key, &value) + .ok() && + value.size() >= sizeof(uint64_t)) { + std::memcpy(&total_tokens, value.data(), sizeof(uint64_t)); + } + } + + double avg_doc_len = total_docs > 0 ? static_cast(total_tokens) / + static_cast(total_docs) + : 0.0; + + std::cout << std::endl; + std::cout << "--- Segment Statistics ---" << std::endl; + std::cout << " Total documents : " << total_docs << std::endl; + std::cout << " Total tokens : " << total_tokens << std::endl; + std::cout << " Avg doc length : " << avg_doc_len << std::endl; + + // --------------------------------------------------------------- + // 2. Vocabulary & posting list statistics + // --------------------------------------------------------------- + std::cout << std::endl; + std::cout << "--- Vocabulary & Posting List Statistics ---" << std::endl; + std::cout << " Scanning postings CF..." << std::flush; + + uint64_t vocab_size = 0; + uint64_t total_postings_entries = 0; // sum of all posting list lengths + uint64_t total_postings_bytes = 0; // sum of serialized bitmap sizes + uint64_t max_posting_len = 0; + std::string max_posting_term; + + // Posting list length distribution buckets + // [1], [2-10], [11-100], [101-1K], [1K-10K], [10K-100K], [100K+] + uint64_t bucket_1 = 0; + uint64_t bucket_2_10 = 0; + uint64_t bucket_11_100 = 0; + uint64_t bucket_101_1k = 0; + uint64_t bucket_1k_10k = 0; + uint64_t bucket_10k_100k = 0; + uint64_t bucket_100k_plus = 0; + + // Format counters + uint64_t roaring_count = 0; + uint64_t bitpacked_count = 0; + + { + auto iter = std::unique_ptr( + store.db_->NewIterator(store.read_opts_, postings_cf)); + while (iter->Valid()) { + const std::string term = iter->key().ToString(); + const std::string posting_data = iter->value().ToString(); + + ++vocab_size; + total_postings_bytes += posting_data.size(); + + uint64_t cardinality = 0; + + if (BitPackedPostingList::is_bitpacked_format(posting_data.data(), + posting_data.size())) { + // BitPacked format: read num_docs from FileHeader + ++bitpacked_count; + fts::BitPackedPostingIterator bp_iter; + if (bp_iter.open(posting_data.data(), posting_data.size()) == 0) { + cardinality = bp_iter.cost(); + } + } else { + // Roaring Bitmap format + ++roaring_count; + roaring_bitmap_t *bitmap = roaring_bitmap_portable_deserialize_safe( + posting_data.data(), posting_data.size()); + if (bitmap) { + cardinality = roaring_bitmap_get_cardinality(bitmap); + roaring_bitmap_free(bitmap); + } + } + + total_postings_entries += cardinality; + + if (cardinality > max_posting_len) { + max_posting_len = cardinality; + max_posting_term = term; + } + + // Bucket distribution + if (cardinality <= 1) { + ++bucket_1; + } else if (cardinality <= 10) { + ++bucket_2_10; + } else if (cardinality <= 100) { + ++bucket_11_100; + } else if (cardinality <= 1000) { + ++bucket_101_1k; + } else if (cardinality <= 10000) { + ++bucket_1k_10k; + } else if (cardinality <= 100000) { + ++bucket_10k_100k; + } else { + ++bucket_100k_plus; + } + + if (vocab_size % 10000 == 0) { + std::cout << "\r Scanning postings CF... " << vocab_size << " terms" + << std::flush; + } + + iter->Next(); + } + // iter auto-closes via unique_ptr + } + + std::cout << "\r Scanning postings CF... done. " << std::endl; + std::cout << " Posting format : " << roaring_count << " Roaring, " + << bitpacked_count << " BitPacked" << std::endl; + std::cout << " Vocabulary size : " << vocab_size << std::endl; + std::cout << " Total postings entries : " << total_postings_entries + << std::endl; + std::cout << " Total postings bytes : " << total_postings_bytes / 1024 + << " KB" << std::endl; + if (vocab_size > 0) { + std::cout << " Avg posting list len : " + << static_cast(total_postings_entries) / vocab_size + << std::endl; + std::cout << " Avg posting bytes : " + << static_cast(total_postings_bytes) / vocab_size << " B" + << std::endl; + } + std::cout << " Max posting list len : " << max_posting_len; + if (!max_posting_term.empty()) { + std::cout << " (term: \"" << max_posting_term << "\")"; + } + std::cout << std::endl; + + std::cout << std::endl; + std::cout << " Posting list length distribution:" << std::endl; + std::cout << " [1] : " << bucket_1 << std::endl; + std::cout << " [2-10] : " << bucket_2_10 << std::endl; + std::cout << " [11-100] : " << bucket_11_100 << std::endl; + std::cout << " [101-1K] : " << bucket_101_1k << std::endl; + std::cout << " [1K-10K] : " << bucket_1k_10k << std::endl; + std::cout << " [10K-100K] : " << bucket_10k_100k << std::endl; + std::cout << " [100K+] : " << bucket_100k_plus << std::endl; + + // --------------------------------------------------------------- + // 3. Document length distribution + // --------------------------------------------------------------- + std::cout << std::endl; + std::cout << "--- Document Length Distribution ---" << std::endl; + + uint64_t doc_count = 0; + uint64_t sum_doc_len = 0; + uint32_t min_doc_len = UINT32_MAX; + uint32_t max_doc_len = 0; + std::vector all_doc_lens; + + if (doc_len_cf) { + auto iter = std::unique_ptr( + store.db_->NewIterator(store.read_opts_, doc_len_cf)); + while (iter->Valid()) { + const std::string value = iter->value().ToString(); + if (value.size() >= sizeof(uint32_t)) { + uint32_t doc_len = 0; + std::memcpy(&doc_len, value.data(), sizeof(uint32_t)); + ++doc_count; + sum_doc_len += doc_len; + if (doc_len < min_doc_len) min_doc_len = doc_len; + if (doc_len > max_doc_len) max_doc_len = doc_len; + all_doc_lens.push_back(doc_len); + } + iter->Next(); + } + // iter auto-closes via unique_ptr + } else { + std::cout << " $DOC_LEN CF was dropped at build time after" + << " convert_postings_to_bitpacked()." << std::endl + << " Per-doc length info is now inlined in BitPacked" + << " postings; skipping distribution scan." << std::endl; + } + + if (doc_count > 0) { + std::sort(all_doc_lens.begin(), all_doc_lens.end()); + + auto percentile = [&](double p) -> uint32_t { + size_t idx = static_cast(p * all_doc_lens.size()); + if (idx >= all_doc_lens.size()) idx = all_doc_lens.size() - 1; + return all_doc_lens[idx]; + }; + + std::cout << " Doc count : " << doc_count << std::endl; + std::cout << " Avg doc length: " + << static_cast(sum_doc_len) / doc_count << std::endl; + std::cout << " Min doc length: " << min_doc_len << std::endl; + std::cout << " P25 doc length: " << percentile(0.25) << std::endl; + std::cout << " P50 doc length: " << percentile(0.50) << std::endl; + std::cout << " P75 doc length: " << percentile(0.75) << std::endl; + std::cout << " P95 doc length: " << percentile(0.95) << std::endl; + std::cout << " P99 doc length: " << percentile(0.99) << std::endl; + std::cout << " Max doc length: " << max_doc_len << std::endl; + } else { + std::cout << " No documents found in $DOC_LEN CF." << std::endl; + } + + // --------------------------------------------------------------- + // 4. Max-TF statistics (top terms by max term frequency) + // --------------------------------------------------------------- + if (max_tf_cf) { + std::cout << std::endl; + std::cout << "--- Top Terms by Max Term Frequency ---" << std::endl; + + struct TermMaxTf { + std::string term; + uint32_t max_tf; + }; + + // Collect all and sort by max_tf descending, show top 20 + std::vector term_max_tfs; + { + auto iter = std::unique_ptr( + store.db_->NewIterator(store.read_opts_, max_tf_cf)); + while (iter->Valid()) { + const std::string term = iter->key().ToString(); + const std::string value = iter->value().ToString(); + uint32_t max_tf = 0; + if (value.size() >= sizeof(uint32_t)) { + std::memcpy(&max_tf, value.data(), sizeof(uint32_t)); + } + term_max_tfs.push_back({term, max_tf}); + iter->Next(); + } + // iter auto-closes via unique_ptr + } + + std::sort(term_max_tfs.begin(), term_max_tfs.end(), + [](const TermMaxTf &a, const TermMaxTf &b) { + return a.max_tf > b.max_tf; + }); + + size_t show_count = std::min(20, term_max_tfs.size()); + for (size_t i = 0; i < show_count; ++i) { + std::cout << " " << (i + 1) << ". \"" << term_max_tfs[i].term + << "\" max_tf=" << term_max_tfs[i].max_tf << std::endl; + } + } + + // --------------------------------------------------------------- + // 5. Storage size summary + // --------------------------------------------------------------- + std::cout << std::endl; + std::cout << "--- Storage Size Summary ---" << std::endl; + std::cout << " Postings CF ($POSTINGS) : " << total_postings_bytes / 1024 + << " KB (serialized bitmap data)" << std::endl; + std::cout << " (Other CF sizes require RocksDB property queries or dump)" + << std::endl; + + std::cout << std::endl; + std::cout << "=== STATS COMPLETE ===" << std::endl; + + store.close(); + return 0; +} + + +int main(int argc, char *argv[]) { + gflags::ParseCommandLineFlags(&argc, &argv, false); + + + if (FLAGS_index.empty()) { + std::cerr << "Error: -index is required." << std::endl; + std::cerr << "Usage:" << std::endl; + std::cerr << " Build : bin/fts_bench -cmd build -index -corpus " + "" + << std::endl; + std::cerr << " Search : bin/fts_bench -cmd search " + "-index -query -qrels " + << std::endl; + std::cerr << " Stats : bin/fts_bench -cmd stats -index " + << std::endl; + return 1; + } + + // Determine command: explicit -cmd flag takes priority, otherwise auto-detect + std::string cmd = FLAGS_cmd; + if (cmd.empty()) { + if (!FLAGS_corpus.empty()) { + cmd = "build"; + } else if (!FLAGS_query.empty()) { + cmd = "search"; + } else { + std::cerr << "Error: specify -cmd (build/search/stats) or -corpus/-query." + << std::endl; + return 1; + } + } + + + // Validate -mode flag + const bool db_mode = (FLAGS_mode == "db"); + if (FLAGS_mode != "raw" && FLAGS_mode != "db") { + std::cerr << "Error: unknown -mode '" << FLAGS_mode + << "'. Use 'raw' or 'db'." << std::endl; + return 1; + } + + + if (cmd == "build") { + if (FLAGS_corpus.empty()) { + std::cerr << "Error: -corpus is required in build mode." << std::endl; + return 1; + } + return db_mode ? do_build_db() : do_build(); + } else if (cmd == "search") { + if (FLAGS_query.empty()) { + std::cerr << "Error: -query is required in search mode." << std::endl; + return 1; + } + if (FLAGS_qrels.empty()) { + std::cerr << "Error: -qrels is required in search mode." << std::endl; + return 1; + } + return db_mode ? do_search_db() : do_search(); + } else if (cmd == "stats") { + if (db_mode) { + std::cerr << "Error: stats command is not supported in db mode." + << std::endl; + return 1; + } + return do_stats(); + } else { + std::cerr << "Error: unknown command '" << cmd + << "'. Use build, search, or stats." << std::endl; + return 1; + } +}