Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
31 commits
Select commit Hold shift + click to select a range
5392f57
nan handling for continuous and categorical coloring
Mar 3, 2025
d7a6d17
also plot nonfinite points so that we see nans
Mar 5, 2025
ac4677f
fix behavior for labels (continuous and categorical)
Mar 11, 2025
edbcb88
fix ds coloring by values in table and by categorical with nan
Mar 18, 2025
1bd9237
fix introduced mpl shapes cbar bug, ds continuous case: render nan po…
Mar 20, 2025
da69a42
fix label color vector dtype for int annotation
Mar 24, 2025
afe7f6a
add tests
Mar 25, 2025
be08f22
Merge branch 'main' into bugfix/issue355-nan-in-categorical-coloring
Sonja-Stockhaus Mar 25, 2025
673f0ad
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 25, 2025
a9b9873
add new test images
Mar 25, 2025
2c02cee
fix error from merging main
Mar 25, 2025
49b3ce2
Merge branch 'main' into bugfix/issue355-nan-in-categorical-coloring
Sonja-Stockhaus Apr 23, 2025
138ad75
debug ds points hex color handling
Apr 23, 2025
82d6d49
remove comment
Apr 23, 2025
85e6c08
Merge branch 'main' into bugfix/issue355-nan-in-categorical-coloring
timtreis Dec 11, 2025
bfe8066
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 11, 2025
b460b58
merge
timtreis Dec 11, 2025
ad5eb86
fixed
timtreis Dec 11, 2025
8aadf07
fixed test
timtreis Dec 11, 2025
3dc4f10
split test so that it actually render points with mpl
timtreis Dec 11, 2025
e1f2f17
fixed points edge case
timtreis Dec 11, 2025
15f3e93
test fixes
timtreis Dec 11, 2025
f9c892c
fixed dask categoricals
timtreis Dec 11, 2025
f32736e
fixed dask categoricals
timtreis Dec 11, 2025
05eb2ce
changed typecheck import
timtreis Dec 11, 2025
1d76909
fixed datashader point filtering
timtreis Dec 11, 2025
7b1adfd
images from runner
timtreis Dec 11, 2025
629f86c
legend ordering now deterministic again
timtreis Dec 11, 2025
272b36c
groups arg now behaves the same for points (shows non-groups in grey)
timtreis Dec 11, 2025
e9c083d
updated table filtering
timtreis Dec 11, 2025
9244f4f
images from runner
timtreis Dec 11, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
322 changes: 226 additions & 96 deletions src/spatialdata_plot/pl/render.py

Large diffs are not rendered by default.

107 changes: 92 additions & 15 deletions src/spatialdata_plot/pl/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -371,17 +371,19 @@ def _get_collection_shape(
def _as_rgba_array(x: Any) -> np.ndarray:
return np.asarray(ColorConverter().to_rgba_array(x))

n_shapes = len(shapes)

# Case A: per-row numeric colors given as Nx3 or Nx4 float array
if (
c_arr.ndim == 2
and c_arr.shape[0] == len(shapes)
and c_arr.shape[0] == n_shapes
and c_arr.shape[1] in (3, 4)
and np.issubdtype(c_arr.dtype, np.number)
):
fill_c = _as_rgba_array(c_arr)

# Case B: continuous numeric vector len == n_shapes (possibly with NaNs)
elif c_arr.ndim == 1 and len(c_arr) == len(shapes) and np.issubdtype(c_arr.dtype, np.number):
elif c_arr.ndim == 1 and len(c_arr) == n_shapes and np.issubdtype(c_arr.dtype, np.number):
finite_mask = np.isfinite(c_arr)

# Select or build a normalization that ignores NaNs for scaling
Expand All @@ -403,7 +405,8 @@ def _as_rgba_array(x: Any) -> np.ndarray:
if finite_mask.any():
fill_c[finite_mask] = cmap(used_norm(c_arr[finite_mask]))

elif c_arr.ndim == 1 and len(c_arr) == len(shapes) and c_arr.dtype == object:
# Case B': 1D object/str column: may contain numeric-like and/or explicit color specs
elif c_arr.ndim == 1 and len(c_arr) == n_shapes and c_arr.dtype == object:
# Split into numeric vs color-like
c_series = pd.Series(c_arr, copy=False)
num = pd.to_numeric(c_series, errors="coerce").to_numpy()
Expand Down Expand Up @@ -434,15 +437,19 @@ def _as_rgba_array(x: Any) -> np.ndarray:
else:
fill_c = _as_rgba_array(c)

# Apply optional fill alpha without destroying existing transparency
# Apply global fill alpha from render_params
if getattr(render_params, "fill_alpha", None) is not None:
fill_c[..., -1] *= float(render_params.fill_alpha)

# Override with explicit fill_alpha if provided
if fill_alpha is not None:
nonzero_alpha = fill_c[..., -1] > 0
fill_c[nonzero_alpha, -1] = fill_alpha
fill_c[nonzero_alpha, -1] = float(fill_alpha)

# Outline handling
if outline_alpha and outline_alpha > 0.0:
if outline_alpha is not None and outline_alpha > 0.0:
outline_c_array = _as_rgba_array(outline_color)
outline_c_array[..., -1] = outline_alpha
outline_c_array[..., -1] = float(outline_alpha)
outline_c = outline_c_array.tolist()
else:
outline_c = [None] * fill_c.shape[0]
Expand Down Expand Up @@ -536,9 +543,7 @@ def _create_patches(
rows.append(pr)
return pd.DataFrame(rows)

patches = _create_patches(
shapes_df, fill_c.tolist(), outline_c.tolist() if hasattr(outline_c, "tolist") else outline_c, s
)
patches = _create_patches(shapes_df, fill_c.tolist(), outline_c, s)

return PatchCollection(
patches["geometry"].values.tolist(),
Expand Down Expand Up @@ -998,7 +1003,8 @@ def _set_color_source_vec(

if len(origins) > 1:
raise ValueError(
f"Color key '{value_to_plot}' for element '{element_name}' been found in multiple locations: {origins}."
f"Color key '{value_to_plot}' for element '{element_name}' was found in multiple locations: {origins}. "
"Please keep it in exactly one place (preferably on the points parquet for speed) to avoid ambiguity."
)

if len(origins) == 1 and value_to_plot is not None:
Expand All @@ -1021,6 +1027,64 @@ def _set_color_source_vec(
color_source_vector if isinstance(color_source_vector, pd.Series) else pd.Series(color_source_vector)
)

if color_series.isna().all():
element_label = _format_element_name(element_name)
location = f"table '{table_name}'" if table_name is not None else "the element"
# Provide dtype hints to help diagnose index alignment issues
dtype_hints: list[str] = []
color_index_dtype = getattr(color_series.index, "dtype", None)
element_index_dtype = (
getattr(getattr(element, "index", None), "dtype", None) if element is not None else None
)

table_instance_dtype = None
table_index_dtype = None
instance_key = None
if table_name is not None and sdata is not None and table_name in sdata.tables:
table = sdata.tables[table_name]
table_index_dtype = getattr(getattr(table, "obs", None), "index", None)
if table_index_dtype is not None:
table_index_dtype = getattr(table_index_dtype, "dtype", None)
try:
_, _, instance_key = get_table_keys(table)
except (KeyError, ValueError, TypeError, AttributeError):
instance_key = None
if instance_key is not None and hasattr(table, "obs") and instance_key in table.obs:
table_instance_dtype = table.obs[instance_key].dtype

if (
element_index_dtype is not None
and table_instance_dtype is not None
and element_index_dtype != table_instance_dtype
):
dtype_hints.append(
f"element index dtype is {element_index_dtype}, '{instance_key}' dtype is {table_instance_dtype}"
)
if (
table_index_dtype is not None
and table_instance_dtype is not None
and table_index_dtype != table_instance_dtype
):
dtype_hints.append(
f"table index dtype is {table_index_dtype}, '{instance_key}' dtype is {table_instance_dtype}"
)
if (
color_index_dtype is not None
and element_index_dtype is not None
and color_index_dtype != element_index_dtype
):
dtype_hints.append(
f"color index dtype is {color_index_dtype}, element index dtype is {element_index_dtype}"
)

dtype_hint = f" (hint: {'; '.join(dtype_hints)})" if dtype_hints else ""
raise ValueError(
f"Column '{value_to_plot}' for element '{element_label}' contains only missing values after aligning "
f"with {location}. This usually means the instance ids/indices could not be aligned or converted, so "
"colors cannot be determined. Please ensure the table annotates the element with matching instance ids."
f"{dtype_hint}"
)

kind, processed = _infer_color_data_kind(
series=color_series,
value_to_plot=value_to_plot,
Expand All @@ -1045,6 +1109,9 @@ def _set_color_source_vec(
return None, numeric_vector, False

assert isinstance(processed, pd.Categorical)
if not processed.ordered:
# ensure deterministic category order when the source is unordered (e.g., from a Python set)
processed = processed.reorder_categories(sorted(processed.categories))
color_source_vector = processed # convert, e.g., `pd.Series`

# Use the provided table_name parameter, fall back to only one present
Expand Down Expand Up @@ -1121,6 +1188,12 @@ def _set_color_source_vec(

# do not rename categories, as colors need not be unique
color_vector = color_source_vector.map(color_mapping)
# nan handling: only add the NA category if needed, and store it as a hex string
na_color_hex = na_color.get_hex_with_alpha() if isinstance(na_color, Color) else str(na_color)
if pd.isna(color_vector).any():
if na_color_hex not in color_vector.categories:
color_vector = color_vector.add_categories(na_color_hex)
color_vector[pd.isna(color_vector)] = na_color_hex

return color_source_vector, color_vector, True

Expand Down Expand Up @@ -1148,15 +1221,18 @@ def _map_color_seg(

if pd.api.types.is_categorical_dtype(color_vector.dtype):
# Case A: user wants to plot a categorical column
if np.any(color_source_vector.isna()):
cell_id[color_source_vector.isna()] = 0
Comment on lines -1151 to -1152
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In seg, the value 0 depicts the background, so this would lead to the bg being mapped to the NaN color
The actual label(s) with na in the color_source_vector don't have their id in cell_id anymore, so they're mapped to nothing! => would look like background

val_im: ArrayLike = map_array(seg.copy(), cell_id, color_vector.codes + 1)
cols = colors.to_rgba_array(color_vector.categories)
elif pd.api.types.is_numeric_dtype(color_vector.dtype):
# Case B: user wants to plot a continuous column
if isinstance(color_vector, pd.Series):
color_vector = color_vector.to_numpy()
cols = cmap_params.cmap(cmap_params.norm(color_vector))
# normalize only the not nan values, else the whole array would contain only nan values
normed_color_vector = color_vector.copy().astype(float)
normed_color_vector[~np.isnan(normed_color_vector)] = cmap_params.norm(
normed_color_vector[~np.isnan(normed_color_vector)]
)
cols = cmap_params.cmap(normed_color_vector)
val_im = map_array(seg.copy(), cell_id, cell_id)
else:
# Case C: User didn't specify any colors
Expand Down Expand Up @@ -2639,6 +2715,7 @@ def _validate_col_for_column_table(
elif table_name is not None:
tables = get_element_annotators(sdata, element_name)
if table_name not in tables:
logger.warning(f"Table '{table_name}' does not annotate element '{element_name}'.")
raise KeyError(f"Table '{table_name}' does not annotate element '{element_name}'.")
if col_for_color not in sdata[table_name].obs.columns and col_for_color not in sdata[table_name].var_names:
raise KeyError(
Expand Down Expand Up @@ -3032,7 +3109,7 @@ def _prepare_transformation(
def _datashader_map_aggregate_to_color(
agg: DataArray,
cmap: str | list[str] | ListedColormap,
color_key: None | list[str] = None,
color_key: list[str] | dict[str, str] | None = None,
min_alpha: float = 40,
span: None | list[float] = None,
clip: bool = True,
Expand Down
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified tests/_images/Points_datashader_can_color_by_category.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified tests/_images/Points_datashader_colors_from_table_obs.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified tests/_images/Shapes_can_filter_with_groups.png
Binary file modified tests/_images/Shapes_datashader_can_color_by_category.png
Binary file modified tests/_images/Shapes_datashader_can_color_by_category_with_cmap.png
42 changes: 42 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,48 @@ def test_sdata_multiple_images_diverging_dims():
return sd.SpatialData(images=images)


@pytest.fixture
def sdata_blobs_points_with_nans_in_table() -> SpatialData:
    """Get blobs sdata where the table annotates the points and includes NaN values.

    The table has one observation per point in ``blobs_points``. NaNs are placed in
    the first 30 rows of ``X[:, 0]`` (variable ``col1``), in the first 30 rows of
    ``obs["col_a"]``, and in every third entry of the categorical ``obs["category"]``.

    Returns
    -------
    The blobs ``SpatialData`` object with the annotating table stored under ``"table"``.
    """
    blob = blobs()
    n_obs = len(blob["blobs_points"])

    # continuous values in X, with NaNs in the first variable
    adata = AnnData(get_standard_RNG().normal(size=(n_obs, 2)))
    adata.X[0:30, 0] = np.nan
    adata.var = pd.DataFrame({}, index=["col1", "col2"])

    # continuous values in obs, with NaNs in "col_a"
    adata.obs = pd.DataFrame(get_standard_RNG().normal(size=(n_obs, 3)), columns=["col_a", "col_b", "col_c"])
    adata.obs.iloc[0:30, adata.obs.columns.get_loc("col_a")] = np.nan

    # instance ids are assigned exactly once (previously set twice with equal values)
    adata.obs["instance_id"] = list(range(adata.n_obs))

    # categorical values in obs: repeat the pattern often enough, then trim to n_obs
    cat_pattern = ["a", "b", np.nan]
    repeats = (n_obs + len(cat_pattern) - 1) // len(cat_pattern)
    adata.obs["category"] = pd.Categorical((cat_pattern * repeats)[:n_obs])

    adata.obs["region"] = "blobs_points"
    table = TableModel.parse(adata=adata, region_key="region", instance_key="instance_id", region="blobs_points")
    blob["table"] = table
    return blob


@pytest.fixture
def sdata_blobs_shapes_with_nans_in_table() -> SpatialData:
    """Get blobs sdata where the table annotates the shapes and includes NaN values.

    The table has one observation per shape in ``blobs_polygons``. NaNs are placed in
    ``X[0, 0]`` (variable ``col1``), in the first row of ``obs["col_a"]``, and in the
    categorical ``obs["category"]`` (third entry of each repetition of the pattern).

    Returns
    -------
    The blobs ``SpatialData`` object with the annotating table stored under ``"table"``.
    """
    blob = blobs()
    n_obs = len(blob["blobs_polygons"])

    # continuous values in X, with a NaN in the first variable
    adata = AnnData(get_standard_RNG().normal(size=(n_obs, 2)))
    adata.X[0, 0] = np.nan
    adata.var = pd.DataFrame({}, index=["col1", "col2"])

    # continuous values in obs, with a NaN in "col_a"
    adata.obs = pd.DataFrame(get_standard_RNG().normal(size=(n_obs, 3)), columns=["col_a", "col_b", "col_c"])
    adata.obs.iloc[0, adata.obs.columns.get_loc("col_a")] = np.nan

    # instance ids are assigned exactly once (previously set twice with equal values)
    adata.obs["instance_id"] = list(range(adata.n_obs))

    # categorical values in obs: repeat the pattern often enough, then trim to n_obs
    cat_pattern = ["a", "b", np.nan, "c", "a"]
    repeats = (n_obs + len(cat_pattern) - 1) // len(cat_pattern)
    adata.obs["category"] = pd.Categorical((cat_pattern * repeats)[:n_obs])

    adata.obs["region"] = "blobs_polygons"
    table = TableModel.parse(adata=adata, region_key="region", instance_key="instance_id", region="blobs_polygons")
    blob["table"] = table
    return blob


@pytest.fixture
def sdata_blobs_shapes_annotated() -> SpatialData:
"""Get blobs sdata with continuous annotation of polygons."""
Expand Down
13 changes: 13 additions & 0 deletions tests/pl/test_render_labels.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import dask.array as da
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import pytest
import scanpy as sc
Expand Down Expand Up @@ -263,6 +264,18 @@ def test_plot_can_annotate_labels_with_table_layer(self, sdata_blobs: SpatialDat
sdata_blobs["table"].layers["normalized"] = get_standard_RNG().random(sdata_blobs["table"].X.shape)
sdata_blobs.pl.render_labels("blobs_labels", color="channel_0_sum", table_layer="normalized").pl.show()

def test_plot_can_annotate_labels_with_nan_in_table_obs_categorical(self, sdata_blobs: SpatialData):
    """Render labels colored by a categorical obs column whose last entry is NaN."""
    category_values = ["a", "b", "b", "a", "b"] * 5 + [np.nan]
    table = sdata_blobs["table"]
    table.obs["cat_color"] = pd.Categorical(category_values)

    rendered = sdata_blobs.pl.render_labels("blobs_labels", color="cat_color")
    rendered.pl.show()

def test_plot_can_annotate_labels_with_nan_in_table_obs_continuous(self, sdata_blobs: SpatialData):
    """Render labels colored by a continuous obs column that contains NaN entries."""
    # one NaN followed by the integers 2..13, repeated twice
    continuous_values = [np.nan, *range(2, 14)] * 2
    table = sdata_blobs["table"]
    table.obs["cont_color"] = continuous_values

    rendered = sdata_blobs.pl.render_labels("blobs_labels", color="cont_color")
    rendered.pl.show()

def test_plot_can_annotate_labels_with_nan_in_table_X_continuous(self, sdata_blobs: SpatialData):
    """Render labels colored by a variable of ``X`` whose first five rows are NaN."""
    expression = sdata_blobs["table"].X
    expression[:5, 0] = np.nan

    rendered = sdata_blobs.pl.render_labels("blobs_labels", color="channel_0_sum")
    rendered.pl.show()

def _prepare_labels_with_small_objects(self, sdata_blobs: SpatialData) -> SpatialData:
# add a categorical column
adata = sdata_blobs["table"]
Expand Down
54 changes: 54 additions & 0 deletions tests/pl/test_render_points.py
Original file line number Diff line number Diff line change
Expand Up @@ -522,6 +522,60 @@ def test_plot_can_annotate_points_with_table_layer(self, sdata_blobs: SpatialDat

sdata_blobs.pl.render_points("blobs_points", color="feature0", size=10, table_layer="normalized").pl.show()

def test_plot_can_annotate_points_with_nan_in_table_obs_categorical_matplotlib(
    self, sdata_blobs_points_with_nans_in_table: SpatialData
):
    """Render points with matplotlib, colored by the table's obs column "category"."""
    sdata = sdata_blobs_points_with_nans_in_table
    rendered = sdata.pl.render_points("blobs_points", color="category", size=40, method="matplotlib")
    rendered.pl.show()

def test_plot_can_annotate_points_with_nan_in_table_obs_categorical_datashader(
    self, sdata_blobs_points_with_nans_in_table: SpatialData
):
    """Render points with datashader, colored by the table's obs column "category"."""
    sdata = sdata_blobs_points_with_nans_in_table
    rendered = sdata.pl.render_points("blobs_points", color="category", size=40, method="datashader")
    rendered.pl.show()

def test_plot_can_annotate_points_with_nan_in_table_obs_continuous(
    self, sdata_blobs_points_with_nans_in_table: SpatialData
):
    """Render points colored by the table's obs column "col_a" (no method specified)."""
    sdata = sdata_blobs_points_with_nans_in_table
    rendered = sdata.pl.render_points("blobs_points", color="col_a", size=30)
    rendered.pl.show()

def test_plot_can_annotate_points_with_nan_in_table_obs_continuous_datashader(
    self, sdata_blobs_points_with_nans_in_table: SpatialData
):
    """Render points with datashader, colored by the table's obs column "col_a"."""
    sdata = sdata_blobs_points_with_nans_in_table
    rendered = sdata.pl.render_points("blobs_points", color="col_a", size=40, method="datashader")
    rendered.pl.show()

def test_plot_can_annotate_points_with_nan_in_table_X_continuous(
    self, sdata_blobs_points_with_nans_in_table: SpatialData
):
    """Render points colored by "col1" (no method specified)."""
    sdata = sdata_blobs_points_with_nans_in_table
    rendered = sdata.pl.render_points("blobs_points", color="col1", size=30)
    rendered.pl.show()

def test_plot_can_annotate_points_with_nan_in_table_X_continuous_datashader(
    self, sdata_blobs_points_with_nans_in_table: SpatialData
):
    """Render points with datashader, colored by "col1"."""
    sdata = sdata_blobs_points_with_nans_in_table
    rendered = sdata.pl.render_points("blobs_points", color="col1", size=40, method="datashader")
    rendered.pl.show()

def test_plot_can_annotate_points_with_nan_in_df_categorical(self, sdata_blobs: SpatialData):
    """Render points colored by a categorical column of the points element that contains NaN."""
    category_values = pd.Series([np.nan, "a", "b", "c"] * 50, dtype="category")
    points = sdata_blobs["blobs_points"]
    points["cat_color"] = category_values

    rendered = sdata_blobs.pl.render_points("blobs_points", color="cat_color", size=30)
    rendered.pl.show()

def test_plot_can_annotate_points_with_nan_in_df_categorical_datashader(self, sdata_blobs: SpatialData):
    """Render points with datashader, colored by a categorical points column that contains NaN."""
    category_values = pd.Series([np.nan, "a", "b", "c"] * 50, dtype="category")
    points = sdata_blobs["blobs_points"]
    points["cat_color"] = category_values

    rendered = sdata_blobs.pl.render_points("blobs_points", color="cat_color", size=40, method="datashader")
    rendered.pl.show()

def test_plot_can_annotate_points_with_nan_in_df_continuous(self, sdata_blobs: SpatialData):
    """Render points colored by a continuous column of the points element that contains NaN."""
    continuous_values = pd.Series([np.nan, 2, 9, 13] * 50)
    points = sdata_blobs["blobs_points"]
    points["cont_color"] = continuous_values

    rendered = sdata_blobs.pl.render_points("blobs_points", color="cont_color", size=30)
    rendered.pl.show()

def test_plot_can_annotate_points_with_nan_in_df_continuous_datashader(self, sdata_blobs: SpatialData):
    """Render points with datashader, colored by a continuous points column that contains NaN."""
    continuous_values = pd.Series([np.nan, 2, 9, 13] * 50)
    points = sdata_blobs["blobs_points"]
    points["cont_color"] = continuous_values

    rendered = sdata_blobs.pl.render_points("blobs_points", color="cont_color", size=40, method="datashader")
    rendered.pl.show()


def test_raises_when_table_does_not_annotate_element(sdata_blobs: SpatialData):
# Work on an independent copy since we mutate tables
Expand Down
Loading