Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 41 additions & 11 deletions tide/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from bigtree import dict_to_tree, levelordergroup_iter
from bigtree.node import node
from typing import TypeVar
from functools import lru_cache

T = TypeVar("T", bound=node.Node)

Expand Down Expand Up @@ -33,6 +34,34 @@
2: {"name": 1},
}

@lru_cache(maxsize=32)
def _cached_enriched_columns(columns_tuple: tuple[str, ...]):
    """Enrich every column name once and cache the result per column tuple.

    Parameters
    ----------
    columns_tuple : tuple[str, ...]
        Original column names; must be a tuple so the call is hashable
        for ``lru_cache``.

    Returns
    -------
    tuple[dict, dict]
        ``(enriched_map, split_tags)`` where ``enriched_map`` maps each
        enriched name back to its original column and ``split_tags`` maps
        each enriched name to its ``"__"``-separated tag list.

    NOTE(review): the cached dicts are mutable and shared across callers —
    callers must treat them as read-only. Confirm no call site mutates them.
    """
    # Enrichment depth is driven by the deepest tag level present.
    depth = get_tags_max_level(columns_tuple)

    enriched_map = {}
    for original in columns_tuple:
        enriched_map[col_name_tag_enrichment(original, depth)] = original

    split_tags = {name: name.split("__") for name in enriched_map}

    return enriched_map, split_tags

@lru_cache(maxsize=32)
def _build_tag_index(columns_tuple: tuple[str, ...]):
    """Build an inverted tag index over the given columns, cached per tuple.

    Parameters
    ----------
    columns_tuple : tuple[str, ...]
        Original column names; a tuple so the argument is hashable
        for ``lru_cache``.

    Returns
    -------
    tuple[dict, dict]
        ``(tag_index, order)`` where ``tag_index`` maps each tag (a
        ``"__"``-separated piece of the enriched column name) to the set of
        original columns carrying it, and ``order`` maps each column to its
        position in ``columns_tuple`` (useful for order-stable sorting).

    NOTE(review): the cached structures are mutable and shared across
    callers — they must be treated as read-only by every call site.
    """
    # Enrichment depth is driven by the deepest tag level present.
    depth = get_tags_max_level(columns_tuple)

    tag_index: dict[str, set[str]] = {}
    for column in columns_tuple:
        enriched_name = col_name_tag_enrichment(column, depth)
        for tag in enriched_name.split("__"):
            tag_index.setdefault(tag, set()).add(column)

    # Positional rank of each column, preserving the input ordering.
    order = dict(zip(columns_tuple, range(len(columns_tuple))))

    return tag_index, order

def get_tree_depth_from_level(tree_max_depth: int, level: int | str):
level = LEVEL_NAME_MAP[level] if isinstance(level, int) else level
Expand Down Expand Up @@ -213,12 +242,8 @@ def tide_request(
f"request must be str, list[str], pd.Index or None, got {type(request)}"
)

max_level = get_tags_max_level(data_columns)

# Enrich columns once
enriched_map = {
col_name_tag_enrichment(col, max_level): col for col in data_columns
}
columns_tuple = tuple(data_columns)
tag_index, order = _build_tag_index(columns_tuple)

selected = []

Expand All @@ -232,12 +257,17 @@ def tide_request(
"Use up to 4 tags separated by '__'."
)

for enriched_name, original in enriched_map.items():
tags = enriched_name.split("__")
candidate_sets = []

for tag in group_tags:
if tag not in tag_index:
candidate_sets = []
break
candidate_sets.append(tag_index[tag])

# Exact per-tag match
if all(tag in tags for tag in group_tags):
selected.append(original)
if candidate_sets:
matches = set.intersection(*candidate_sets)
selected.extend(sorted(matches, key=lambda c: order[c]))

return list(dict.fromkeys(selected))

Expand Down
Loading