diff --git a/src/memos/multi_mem_cube/single_cube.py b/src/memos/multi_mem_cube/single_cube.py index 6a91f436f..355cb8cee 100644 --- a/src/memos/multi_mem_cube/single_cube.py +++ b/src/memos/multi_mem_cube/single_cube.py @@ -22,7 +22,7 @@ ) from memos.memories.textual.item import TextualMemoryItem from memos.multi_mem_cube.views import MemCubeView -from memos.search import search_text_memories +from memos.search import resolve_filter_for_cube, search_text_memories from memos.templates.mem_reader_prompts import PROMPT_MAPPING from memos.types.general_types import ( FINE_STRATEGY, @@ -91,6 +91,13 @@ def search_memories(self, search_req: APISearchRequest) -> dict[str, Any]: Unified memory search handling (text + preference memories). Preference memories are now searched through the same _search_text flow. """ + cube_filter = resolve_filter_for_cube(search_req.filter, self.cube_id) + if cube_filter is not search_req.filter: + import copy + + search_req = copy.copy(search_req) + search_req.filter = cube_filter + # Create UserContext object user_context = UserContext( user_id=search_req.user_id, diff --git a/src/memos/search/__init__.py b/src/memos/search/__init__.py index 71388c62b..d2c197403 100644 --- a/src/memos/search/__init__.py +++ b/src/memos/search/__init__.py @@ -1,4 +1,14 @@ -from .search_service import SearchContext, build_search_context, search_text_memories +from .search_service import ( + SearchContext, + build_search_context, + resolve_filter_for_cube, + search_text_memories, +) -__all__ = ["SearchContext", "build_search_context", "search_text_memories"] +__all__ = [ + "SearchContext", + "build_search_context", + "resolve_filter_for_cube", + "search_text_memories", +] diff --git a/src/memos/search/search_service.py b/src/memos/search/search_service.py index fa713a7d1..f4092d168 100644 --- a/src/memos/search/search_service.py +++ b/src/memos/search/search_service.py @@ -36,6 +36,35 @@ def build_search_context( ) +def resolve_filter_for_cube( + raw_filter: dict[str, Any] | None, cube_id: str +) -> dict[str, Any] | None: + """Resolve a multi-cube filter dict into the sub-filter for a single cube. + + Supported forms: + - None → None (no filter) + - {"and": [...]} / {"or": [...]} → returned as-is (unified, all cubes share) + - {"cube_A": {...}, "cube_B": {...}} → return raw_filter[cube_id] or None + Mixed top-level (and/or + cube keys) is rejected. + """ + if raw_filter is None: + return None + + has_logic_key = "and" in raw_filter or "or" in raw_filter + other_keys = {k for k in raw_filter if k not in ("and", "or")} + + if has_logic_key and other_keys: + raise ValueError( + "Invalid filter: top-level 'and'/'or' cannot coexist with per-cube keys " + f"{other_keys}. Use either a unified filter or per-cube filter, not both." + ) + + if has_logic_key: + return raw_filter + + return raw_filter.get(cube_id) + + def search_text_memories( text_mem: Any, search_req: APISearchRequest, diff --git a/tests/search/test_resolve_filter_for_cube.py b/tests/search/test_resolve_filter_for_cube.py new file mode 100644 index 000000000..ab6be71a3 --- /dev/null +++ b/tests/search/test_resolve_filter_for_cube.py @@ -0,0 +1,78 @@ +import pytest + +from memos.search.search_service import resolve_filter_for_cube + + +class TestResolveFilterForCube: + """Tests for resolve_filter_for_cube — multi-cube filter routing.""" + + # ── None passthrough ── + + def test_none_returns_none(self): + assert resolve_filter_for_cube(None, "cube_001") is None + + # ── Unified filter (filter2): top-level and/or ── + + def test_unified_and_returns_same_for_any_cube(self): + f = {"and": [{"tags": {"contains": "阅读"}}, {"created_at": {"gte": "2025-01-01"}}]} + assert resolve_filter_for_cube(f, "cube_001") is f + assert resolve_filter_for_cube(f, "cube_999") is f + + def test_unified_or_returns_same_for_any_cube(self): + f = {"or": [{"tags": {"contains": "A"}}, {"tags": {"contains": "B"}}]} + assert resolve_filter_for_cube(f, "cube_001") is f + + # ── Per-cube filter (filter1 / filter4) ── + + def test_per_cube_returns_matching_sub_filter(self): + sub_a = {"and": [{"tags": {"contains": "阅读"}}]} + sub_b = {"and": [{"tags": {"contains": "工作"}}]} + f = {"cube_A": sub_a, "cube_B": sub_b} + + assert resolve_filter_for_cube(f, "cube_A") is sub_a + assert resolve_filter_for_cube(f, "cube_B") is sub_b + + def test_per_cube_missing_key_returns_none(self): + f = { + "cube_A": {"and": [{"tags": {"contains": "阅读"}}]}, + "cube_B": {"and": [{"tags": {"contains": "工作"}}]}, + } + assert resolve_filter_for_cube(f, "cube_C") is None + + def test_per_cube_single_key(self): + sub = {"and": [{"created_at": {"gte": "2025-01-01"}}]} + f = {"cube_only": sub} + assert resolve_filter_for_cube(f, "cube_only") is sub + assert resolve_filter_for_cube(f, "other") is None + + # ── Mixed (filter3): illegal ── + + def test_mixed_and_with_cube_key_raises(self): + f = { + "and": [{"tags": {"contains": "阅读"}}], + "cube_A": {"and": [{"tags": {"contains": "工作"}}]}, + } + with pytest.raises(ValueError, match="cannot coexist"): + resolve_filter_for_cube(f, "cube_A") + + def test_mixed_or_with_cube_key_raises(self): + f = { + "or": [{"tags": {"contains": "阅读"}}], + "cube_B": {"and": [{"tags": {"contains": "工作"}}]}, + } + with pytest.raises(ValueError, match="cannot coexist"): + resolve_filter_for_cube(f, "cube_B") + + # ── Edge cases ── + + def test_empty_dict_returns_none(self): + assert resolve_filter_for_cube({}, "cube_001") is None + + def test_per_cube_with_empty_sub_filter(self): + f = {"cube_A": {}} + result = resolve_filter_for_cube(f, "cube_A") + assert result == {} + + def test_unified_and_empty_list(self): + f = {"and": []} + assert resolve_filter_for_cube(f, "any") is f