@lru_cache(maxsize=32)
def _cached_enriched_columns(
    columns_tuple: tuple[str, ...],
) -> tuple[dict[str, str], dict[str, list[str]]]:
    """Build enriched-name lookup tables for a fixed tuple of columns.

    Results are memoised per ``columns_tuple`` (up to 32 distinct tuples),
    so the argument must be hashable -- pass a tuple, not a list or Index.

    Args:
        columns_tuple: Column names to enrich.

    Returns:
        A pair ``(enriched_map, split_tags)`` where ``enriched_map`` maps
        each enriched name (as produced by ``col_name_tag_enrichment``) to
        the column it came from, and ``split_tags`` maps each enriched name
        to its ``"__"``-separated parts.

    Note:
        The returned dictionaries are the cached objects themselves and are
        shared by every caller that passes the same tuple. Treat them as
        read-only; mutating them corrupts later lookups.
    """
    max_level = get_tags_max_level(columns_tuple)

    # If two columns enrich to the same name, the later one wins the value
    # while the key keeps its first insertion position (plain dict rules).
    enriched_map = {
        col_name_tag_enrichment(col, max_level): col for col in columns_tuple
    }
    split_tags = {enriched: enriched.split("__") for enriched in enriched_map}

    return enriched_map, split_tags


@lru_cache(maxsize=32)
def _build_tag_index(
    columns_tuple: tuple[str, ...],
) -> tuple[dict[str, set[str]], dict[str, int]]:
    """Build an inverted tag -> columns index for a fixed tuple of columns.

    Results are memoised per ``columns_tuple`` (up to 32 distinct tuples),
    so the argument must be hashable -- pass a tuple, not a list or Index.

    Args:
        columns_tuple: Column names to index.

    Returns:
        A pair ``(tag_index, order)`` where ``tag_index`` maps every tag
        found in an enriched column name to the set of original columns
        carrying that tag, and ``order`` maps each column to the position
        of its first occurrence in ``columns_tuple``.

    Note:
        The returned containers are the cached objects themselves and are
        shared by every caller that passes the same tuple. Treat them as
        read-only; derive new sets (e.g. via intersection) instead of
        mutating the indexed ones.
    """
    max_level = get_tags_max_level(columns_tuple)

    tag_index: dict[str, set[str]] = {}
    order: dict[str, int] = {}

    for position, col in enumerate(columns_tuple):
        # Keep the FIRST position of a repeated column label so that results
        # sorted with ``order`` follow the original column order, as the
        # previous linear scan over the columns did.
        order.setdefault(col, position)

        enriched = col_name_tag_enrichment(col, max_level)
        for tag in enriched.split("__"):
            tag_index.setdefault(tag, set()).add(col)

    return tag_index, order
) - for enriched_name, original in enriched_map.items(): - tags = enriched_name.split("__") + candidate_sets = [] + + for tag in group_tags: + if tag not in tag_index: + candidate_sets = [] + break + candidate_sets.append(tag_index[tag]) - # Exact per-tag match - if all(tag in tags for tag in group_tags): - selected.append(original) + if candidate_sets: + matches = set.intersection(*candidate_sets) + selected.extend(sorted(matches, key=lambda c: order[c])) return list(dict.fromkeys(selected))