Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
23 commits
Select commit Hold shift + click to select a range
715d80f
Refactor add_node_attr_key/add_edge_attr_key with dtype parameter
JoOkuma Jan 20, 2026
e7aa9c7
Refactor: require explicit dtype parameter for add_node_attr_key and …
JoOkuma Jan 20, 2026
2c7a9e1
Update all tests to use explicit dtype parameter
JoOkuma Jan 20, 2026
5392cbf
fixing correct usage of dtypes and using default when possible
JoOkuma Jan 20, 2026
b2368e1
fixing more attr usage
JoOkuma Jan 20, 2026
e3ac6ea
removing unused function
JoOkuma Jan 20, 2026
b7b00f4
improving attr key default usage
JoOkuma Jan 20, 2026
2bcafcb
improving testing
JoOkuma Jan 20, 2026
8288a69
fixing spatial filtering typing
JoOkuma Jan 20, 2026
58256e2
simplifying sql type detection case
JoOkuma Jan 20, 2026
a05788c
refactoring generic usage
JoOkuma Jan 20, 2026
b83b587
moving tests and fixing get attr usage
JoOkuma Jan 20, 2026
3ca9164
cleanup on sql type conversion
JoOkuma Jan 20, 2026
7940fee
fixing docs
JoOkuma Jan 21, 2026
29e8be2
delayed attr key default value -- working version
JoOkuma Jan 22, 2026
3ef7d5c
simplifying code
JoOkuma Jan 22, 2026
8e584ee
not removing null when default value is None
JoOkuma Jan 23, 2026
1cb217f
Merge branch 'main' into jookuma/attr-key-refactoring
JoOkuma Jan 23, 2026
cab9e27
Merge branch 'jookuma/attr-key-refactoring' into jookuma/attr-key-del…
JoOkuma Jan 23, 2026
fe09cc9
reproducing bug
JoOkuma Jan 28, 2026
f19e6d6
adding subgraph schema matching test and fixing it
JoOkuma Jan 28, 2026
dfd1913
adding from_other schema matching
JoOkuma Jan 28, 2026
9b09154
Merge branch 'main' into jookuma/attr-key-delayed-default
JoOkuma Jan 29, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/tracksdata/array/_graph_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,7 @@ def __init__(
buffer_cache_size: int | None = None,
dtype: np.dtype | None = None,
):
if attr_key not in graph.node_attr_keys():
if attr_key not in graph.node_attr_keys(return_ids=True):
raise ValueError(f"Attribute key '{attr_key}' not found in graph. Expected '{graph.node_attr_keys()}'")

self.graph = graph
Expand Down
16 changes: 8 additions & 8 deletions src/tracksdata/edges/_test/test_distance_edges.py
Original file line number Diff line number Diff line change
Expand Up @@ -320,8 +320,8 @@ def test_distance_edges_neighbors_per_frame_false() -> None:
graph = RustWorkXGraph()

# Register attribute keys
graph.add_node_attr_key("x", 0.0)
graph.add_node_attr_key("y", 0.0)
graph.add_node_attr_key("x", pl.Float64, 0.0)
graph.add_node_attr_key("y", pl.Float64, 0.0)

# Add nodes at t=0, t=1, t=2
# At t=0: two nodes close to origin
Expand Down Expand Up @@ -363,8 +363,8 @@ def test_distance_edges_neighbors_per_frame_true() -> None:
graph = RustWorkXGraph()

# Register attribute keys
graph.add_node_attr_key("x", 0.0)
graph.add_node_attr_key("y", 0.0)
graph.add_node_attr_key("x", pl.Float64, 0.0)
graph.add_node_attr_key("y", pl.Float64, 0.0)

# Add nodes at t=0, t=1, t=2
# At t=0: two nodes close to origin
Expand Down Expand Up @@ -408,8 +408,8 @@ def test_distance_edges_neighbors_per_frame_with_distance_threshold() -> None:
graph = RustWorkXGraph()

# Register attribute keys
graph.add_node_attr_key("x", 0.0)
graph.add_node_attr_key("y", 0.0)
graph.add_node_attr_key("x", pl.Float64, 0.0)
graph.add_node_attr_key("y", pl.Float64, 0.0)

# Add nodes at t=0 (far away), t=1 (close), t=2
# At t=0: nodes very far from where t=2 node will be
Expand Down Expand Up @@ -444,8 +444,8 @@ def test_distance_edges_neighbors_per_frame_single_delta_t() -> None:
graph = RustWorkXGraph()

# Register attribute keys
graph.add_node_attr_key("x", 0.0)
graph.add_node_attr_key("y", 0.0)
graph.add_node_attr_key("x", pl.Float64, 0.0)
graph.add_node_attr_key("y", pl.Float64, 0.0)

# Add nodes at t=0 and t=1
for i in range(3):
Expand Down
7 changes: 5 additions & 2 deletions src/tracksdata/functional/_test/test_apply.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,8 +112,8 @@ def test_apply_tiled_default_attrs(sample_graph: RustWorkXGraph) -> None:
def test_apply_tiled_2d_tiling() -> None:
"""Test apply_tiled with 2D spatial coordinates."""
graph = RustWorkXGraph()
graph.add_node_attr_key("y", dtype=pl.Int64)
graph.add_node_attr_key("x", dtype=pl.Int64)
graph.add_node_attr_key("y", dtype=pl.Float64)
graph.add_node_attr_key("x", dtype=pl.Float64)

for y in [5, 11, 14]:
for x in [10, 30]:
Expand Down Expand Up @@ -192,6 +192,9 @@ def test_apply_tile_scale_invariance() -> None:

for scale in scales:
graph = RustWorkXGraph()
# hack: updating schema
graph._node_attr_schemas()["t"].dtype = pl.Float64

for p in pos:
graph.add_node({"t": p * scale})

Expand Down
66 changes: 48 additions & 18 deletions src/tracksdata/graph/_base_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ def supports_custom_indices(self) -> bool:
def _validate_attributes(
attrs: dict[str, Any],
reference_keys: list[str],
mode: str,
mode: Literal["node", "edge"],
) -> None:
"""
Validate the attributes of a node.
Expand All @@ -85,11 +85,18 @@ def _validate_attributes(
f"`graph.add_{mode}_attr_key(key, default_value)`"
)

for ref_key in reference_keys:
if ref_key not in attrs.keys() and ref_key != DEFAULT_ATTR_KEYS.NODE_ID:
raise ValueError(
f"Attribute '{ref_key}' not found in attrs: '{attrs.keys()}'\nRequested keys: '{reference_keys}'"
)
missing_keys = set(reference_keys) - set(attrs.keys())
missing_keys = missing_keys - {
DEFAULT_ATTR_KEYS.NODE_ID,
DEFAULT_ATTR_KEYS.EDGE_ID,
DEFAULT_ATTR_KEYS.EDGE_SOURCE,
DEFAULT_ATTR_KEYS.EDGE_TARGET,
}

if missing_keys:
raise ValueError(
f"{mode} attribute keys not found in attrs: '{missing_keys}'\nRequested keys: '{reference_keys}'"
)

@abc.abstractmethod
def add_node(
Expand Down Expand Up @@ -626,15 +633,27 @@ def edge_attrs(
"""

@abc.abstractmethod
def node_attr_keys(self) -> list[str]:
def node_attr_keys(self, return_ids: bool = False) -> list[str]:
"""
Get the keys of the attributes of the nodes.

Parameters
----------
return_ids : bool, default False
Whether to include NODE_ID in the returned keys. Defaults to False.
If True, NODE_ID will be included in the list.
"""

@abc.abstractmethod
def edge_attr_keys(self) -> list[str]:
def edge_attr_keys(self, return_ids: bool = False) -> list[str]:
"""
Get the keys of the attributes of the edges.

Parameters
----------
return_ids : bool, optional
Whether to include EDGE_ID, EDGE_SOURCE, and EDGE_TARGET in the returned keys.
Defaults to False. If True, these ID fields will be included in the list.
"""

@overload
Expand Down Expand Up @@ -1169,11 +1188,10 @@ def from_other(cls: type[T], other: "BaseGraph", **kwargs) -> T:
graph = cls(**kwargs)
graph.update_metadata(**other.metadata())

for col in node_attrs.columns:
if col != DEFAULT_ATTR_KEYS.T:
# Use the dtype from the source DataFrame
dtype = node_attrs[col].dtype
graph.add_node_attr_key(col, dtype)
current_node_attr_schemas = graph._node_attr_schemas()
for k, v in other._node_attr_schemas().items():
if k not in current_node_attr_schemas:
graph.add_node_attr_key(k, v.dtype, v.default_value)

if graph.supports_custom_indices():
new_node_ids = graph.bulk_add_nodes(
Expand All @@ -1195,11 +1213,11 @@ def from_other(cls: type[T], other: "BaseGraph", **kwargs) -> T:
edge_attrs = other.edge_attrs()
edge_attrs = edge_attrs.drop(DEFAULT_ATTR_KEYS.EDGE_ID)

for col in edge_attrs.columns:
if col not in [DEFAULT_ATTR_KEYS.EDGE_SOURCE, DEFAULT_ATTR_KEYS.EDGE_TARGET]:
# Use the dtype from the source DataFrame
dtype = edge_attrs[col].dtype
graph.add_edge_attr_key(col, dtype)
current_edge_attr_schemas = graph._edge_attr_schemas()
for k, v in other._edge_attr_schemas().items():
if k not in current_edge_attr_schemas:
print(f"Adding edge attribute key: {k} with dtype: {v.dtype} and default value: {v.default_value}")
graph.add_edge_attr_key(k, v.dtype, v.default_value)

edge_attrs = edge_attrs.with_columns(
edge_attrs[col].map_elements(node_map.get, return_dtype=pl.Int64).alias(col)
Expand Down Expand Up @@ -1930,6 +1948,18 @@ def __getitem__(self, node_id: int) -> "NodeInterface":
raise ValueError(f"graph index must be a integer, found '{node_id}' of type {type(node_id)}")
return NodeInterface(self, node_id)

@abc.abstractmethod
def _node_attr_schemas(self) -> dict[str, AttrSchema]:
"""
Get the attribute schemas for the nodes.
"""

@abc.abstractmethod
def _edge_attr_schemas(self) -> dict[str, AttrSchema]:
"""
Get the attribute schemas for the edges.
"""


class NodeInterface:
"""
Expand Down
69 changes: 63 additions & 6 deletions src/tracksdata/graph/_graph_view.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,9 +113,29 @@ def __init__(
self._node_attr_keys = node_attr_keys
self._edge_attr_keys = edge_attr_keys

# add default keys to the node and edge attr keys
if self._node_attr_keys is not None:
self._node_attr_keys = [DEFAULT_ATTR_KEYS.NODE_ID, DEFAULT_ATTR_KEYS.T, *self._node_attr_keys]
self._node_attr_keys = list(dict.fromkeys(self._node_attr_keys))

if self._edge_attr_keys is not None:
self._edge_attr_keys = [
DEFAULT_ATTR_KEYS.EDGE_ID,
DEFAULT_ATTR_KEYS.EDGE_SOURCE,
DEFAULT_ATTR_KEYS.EDGE_TARGET,
*self._edge_attr_keys,
]
self._edge_attr_keys = list(dict.fromkeys(self._edge_attr_keys))

# use parent graph overlaps
self._overlaps = None

def _node_attr_schemas(self) -> dict[str, AttrSchema]:
return self._root._node_attr_schemas()

def _edge_attr_schemas(self) -> dict[str, AttrSchema]:
return self._root._edge_attr_schemas()

def supports_custom_indices(self) -> bool:
return self._root.supports_custom_indices()

Expand Down Expand Up @@ -230,11 +250,48 @@ def filter(
include_sources=include_sources,
)

def node_attr_keys(self) -> list[str]:
return self._root.node_attr_keys() if self._node_attr_keys is None else self._node_attr_keys
def node_attr_keys(self, return_ids: bool = False) -> list[str]:
"""
Get the keys of the attributes of the nodes.

def edge_attr_keys(self) -> list[str]:
return self._root.edge_attr_keys() if self._edge_attr_keys is None else self._edge_attr_keys
Parameters
----------
return_ids : bool, optional
Whether to include NODE_ID in the returned keys. Defaults to False.
If True, NODE_ID will be included in the list.
"""
if self._node_attr_keys is None:
return self._root.node_attr_keys(return_ids=return_ids)
else:
keys = self._node_attr_keys.copy()
if not return_ids:
try:
keys.remove(DEFAULT_ATTR_KEYS.NODE_ID)
except ValueError:
pass
return keys

def edge_attr_keys(self, return_ids: bool = False) -> list[str]:
"""
Get the keys of the attributes of the edges.

Parameters
----------
return_ids : bool, optional
Whether to include EDGE_ID, EDGE_SOURCE, and EDGE_TARGET in the returned keys.
Defaults to False. If True, these ID fields will be included in the list.
"""
if self._edge_attr_keys is None:
return self._root.edge_attr_keys(return_ids=return_ids)
else:
keys = self._edge_attr_keys.copy()
if not return_ids:
for k in [DEFAULT_ATTR_KEYS.EDGE_ID, DEFAULT_ATTR_KEYS.EDGE_SOURCE, DEFAULT_ATTR_KEYS.EDGE_TARGET]:
try:
keys.remove(k)
except ValueError:
pass
return keys

def add_node_attr_key(
self,
Expand All @@ -258,7 +315,7 @@ def add_node_attr_key(
if not self._is_root_rx_graph:
if self.sync:
# Get the schema from root to get the actual default value used
schema = self._root._node_attr_schemas[key]
schema = self._root._node_attr_schemas()[key]
# Apply to local rx_graph
rx_graph = self.rx_graph
for node_id in rx_graph.node_indices():
Expand Down Expand Up @@ -300,7 +357,7 @@ def add_edge_attr_key(
if not self._is_root_rx_graph:
if self.sync:
# Get the schema from root to get the actual default value used
schema = self._root._edge_attr_schemas[key]
schema = self._root._edge_attr_schemas()[key]
# Apply to local rx_graph
for _, _, edge_attr in self.rx_graph.weighted_edge_list():
edge_attr[key] = schema.default_value
Expand Down
Loading
Loading