diff --git a/src/spatialdata_plot/pl/render.py b/src/spatialdata_plot/pl/render.py index 2d607fe9..73786d3c 100644 --- a/src/spatialdata_plot/pl/render.py +++ b/src/spatialdata_plot/pl/render.py @@ -1103,13 +1103,11 @@ def _render_images( layers[ch] = img.sel(c=ch).copy(deep=True).squeeze() if isinstance(render_params.cmap_params, list): ch_norm = render_params.cmap_params[ch_idx].norm - ch_cmap_is_default = render_params.cmap_params[ch_idx].cmap_is_default else: ch_norm = render_params.cmap_params.norm - ch_cmap_is_default = render_params.cmap_params.cmap_is_default - if not ch_cmap_is_default and ch_norm is not None: - layers[ch_idx] = ch_norm(layers[ch_idx]) + if ch_norm is not None: + layers[ch] = ch_norm(layers[ch]) # 2A) Image has 3 channels, no palette info, and no/only one cmap was given if palette is None and n_channels == 3 and not isinstance(render_params.cmap_params, list): diff --git a/tests/_images/Extent_extent_of_img_is_correct_after_spatial_query.png b/tests/_images/Extent_extent_of_img_is_correct_after_spatial_query.png old mode 100644 new mode 100755 index bb0246c6..aecc4ca0 Binary files a/tests/_images/Extent_extent_of_img_is_correct_after_spatial_query.png and b/tests/_images/Extent_extent_of_img_is_correct_after_spatial_query.png differ diff --git a/tests/_images/Images_correctly_normalizes_multichannel_images.png b/tests/_images/Images_correctly_normalizes_multichannel_images.png new file mode 100755 index 00000000..0cf51c43 Binary files /dev/null and b/tests/_images/Images_correctly_normalizes_multichannel_images.png differ diff --git a/tests/pl/test_render_images.py b/tests/pl/test_render_images.py index 5cba4e88..3e28799a 100644 --- a/tests/pl/test_render_images.py +++ b/tests/pl/test_render_images.py @@ -1,9 +1,12 @@ import dask.array as da import matplotlib +import matplotlib.pyplot as plt +import numpy as np import scanpy as sc from matplotlib.colors import Normalize from spatial_image import to_spatial_image from spatialdata import SpatialData +from spatialdata.models import Image2DModel import spatialdata_plot # noqa: F401 from tests.conftest import DPI, PlotTester, PlotTesterMeta, _viridis_with_under_over @@ -130,3 +133,17 @@ def test_plot_can_stick_to_zorder(self, sdata_blobs: SpatialData): def test_plot_can_render_multiscale_image_with_custom_cmap(self, sdata_blobs: SpatialData): sdata_blobs.pl.render_images("blobs_multiscale_image", channel=0, scale="scale2", cmap="Greys").pl.show() + + def test_plot_correctly_normalizes_multichannel_images(self, sdata_raccoon: SpatialData): + sdata_raccoon["raccoon_int16"] = Image2DModel.parse( + sdata_raccoon["raccoon"].data.astype(np.uint16) * 257, # 255 * 257 = 65535, + dims=("c", "y", "x"), + ) + + # show multi-channel vs single-channel + fig, axs = plt.subplots(nrows=1, ncols=2) + sdata_raccoon.pl.render_images("raccoon_int16", channel=[0]).pl.show(ax=axs[0], colorbar=False) + axs[0].set_title("single-channel uint16") + sdata_raccoon.pl.render_images("raccoon_int16", channel=[0, 1], palette=["yellow", "red"]).pl.show(ax=axs[1]) + axs[1].set_title("two-channel uint16") + fig.tight_layout()