1212
1313import matplotlib
1414import matplotlib .patches as mpatches
15- import matplotlib .patches as mplp
1615import matplotlib .path as mpath
1716import matplotlib .pyplot as plt
18- import multiscale_spatial_image as msi
1917import numpy as np
2018import pandas as pd
2119import shapely
4947from scanpy .plotting ._tools .scatterplots import _add_categorical_legend
5048from scanpy .plotting ._utils import add_colors_for_categorical_sample_annotation
5149from scanpy .plotting .palettes import default_20 , default_28 , default_102
52- from shapely .geometry import LineString , Polygon
5350from skimage .color import label2rgb
5451from skimage .morphology import erosion , square
5552from skimage .segmentation import find_boundaries
@@ -283,6 +280,30 @@ def _sanitise_na_color(na_color: ColorLike | None) -> tuple[str, bool]:
283280 raise ValueError (f"Invalid na_color value: { na_color } " )
284281
285282
283+ def _get_centroid_of_pathpatch (pathpatch : mpatches .PathPatch ) -> tuple [float , float ]:
284+ # Extract the vertices from the PathPatch
285+ path = pathpatch .get_path ()
286+ vertices = path .vertices
287+ x = vertices [:, 0 ]
288+ y = vertices [:, 1 ]
289+
290+ area = 0.5 * np .sum (x [:- 1 ] * y [1 :] - x [1 :] * y [:- 1 ])
291+
292+ # Calculate the centroid coordinates
293+ centroid_x = np .sum ((x [:- 1 ] + x [1 :]) * (x [:- 1 ] * y [1 :] - x [1 :] * y [:- 1 ])) / (6 * area )
294+ centroid_y = np .sum ((y [:- 1 ] + y [1 :]) * (x [:- 1 ] * y [1 :] - x [1 :] * y [:- 1 ])) / (6 * area )
295+
296+ return centroid_x , centroid_y
297+
298+
299+ def _scale_pathpatch_around_centroid (pathpatch : mpatches .PathPatch , scale_factor : float ) -> None :
300+
301+ centroid = _get_centroid_of_pathpatch (pathpatch )
302+ vertices = pathpatch .get_path ().vertices
303+ scaled_vertices = np .array ([centroid + (vertex - centroid ) * scale_factor for vertex in vertices ])
304+ pathpatch .get_path ().vertices = scaled_vertices
305+
306+
286307def _get_collection_shape (
287308 shapes : list [GeoDataFrame ],
288309 c : Any ,
@@ -352,63 +373,64 @@ def _get_collection_shape(
352373 outline_c = outline_c * fill_c .shape [0 ]
353374
354375 shapes_df = pd .DataFrame (shapes , copy = True )
355-
356- # remove empty points/polygons
357376 shapes_df = shapes_df [shapes_df ["geometry" ].apply (lambda geom : not geom .is_empty )]
358-
359- # reset index of shapes_df for case of spatial query
360377 shapes_df = shapes_df .reset_index (drop = True )
361378
362- rows = []
363-
364- def assign_fill_and_outline_to_row (
365- shapes : list [GeoDataFrame ], fill_c : list [Any ], outline_c : list [Any ], row : pd .Series , idx : int
379+ def _assign_fill_and_outline_to_row (
380+ fill_c : list [Any ], outline_c : list [Any ], row : dict [str , Any ], idx : int , is_multiple_shapes : bool
366381 ) -> None :
367382 try :
368- if len ( shapes ) > 1 and len (fill_c ) == 1 :
369- row ["fill_c" ] = fill_c
370- row ["outline_c" ] = outline_c
383+ if is_multiple_shapes and len (fill_c ) == 1 :
384+ row ["fill_c" ] = fill_c [ 0 ]
385+ row ["outline_c" ] = outline_c [ 0 ]
371386 else :
372387 row ["fill_c" ] = fill_c [idx ]
373388 row ["outline_c" ] = outline_c [idx ]
374389 except IndexError as e :
375- raise IndexError ("Could not assign fill and outline colors due to a mismatch in row-numbers." ) from e
376-
377- # Match colors to the geometry, potentially expanding the row in case of
378- # multipolygons
379- for idx , row in shapes_df .iterrows ():
380- geom = row ["geometry" ]
381- if geom .geom_type == "Polygon" :
382- row = row .to_dict ()
383- coords = np .array (geom .exterior .coords )
384- centroid = np .mean (coords , axis = 0 )
385- scaled_coords = [(centroid + (np .array (coord ) - centroid ) * s ).tolist () for coord in geom .exterior .coords ]
386- row ["geometry" ] = mplp .Polygon (scaled_coords , closed = True )
387- assign_fill_and_outline_to_row (shapes , fill_c , outline_c , row , idx )
388- rows .append (row )
389-
390- elif geom .geom_type == "MultiPolygon" :
391- # mp = _make_patch_from_multipolygon(geom)
392- for polygon in geom .geoms :
393- mp_copy = row .to_dict ()
394- coords = np .array (polygon .exterior .coords )
395- centroid = np .mean (coords , axis = 0 )
396- scaled_coords = [(centroid + (coord - centroid ) * s ).tolist () for coord in coords ]
397- mp_copy ["geometry" ] = mplp .Polygon (scaled_coords , closed = True )
398- assign_fill_and_outline_to_row (shapes , fill_c , outline_c , mp_copy , idx )
399- rows .append (mp_copy )
400-
401- elif geom .geom_type == "Point" :
402- row = row .to_dict ()
403- scaled_radius = row ["radius" ] * s
404- row ["geometry" ] = mplp .Circle (
405- (geom .x , geom .y ), radius = scaled_radius
406- ) # Circle is always scaled from its center
407- assign_fill_and_outline_to_row (shapes , fill_c , outline_c , row , idx )
408- rows .append (row )
409-
410- patches = pd .DataFrame (rows )
411-
390+ raise IndexError ("Could not assign fill and outline colors due to a mismatch in row numbers." ) from e
391+
392+ def _process_polygon (row : pd .Series , s : float ) -> dict [str , Any ]:
393+ coords = np .array (row ["geometry" ].exterior .coords )
394+ centroid = np .mean (coords , axis = 0 )
395+ scaled_coords = (centroid + (coords - centroid ) * s ).tolist ()
396+ return {** row .to_dict (), "geometry" : mpatches .Polygon (scaled_coords , closed = True )}
397+
398+ def _process_multipolygon (row : pd .Series , s : float ) -> list [dict [str , Any ]]:
399+ mp = _make_patch_from_multipolygon (row ["geometry" ])
400+ row_dict = row .to_dict ()
401+ for m in mp :
402+ _scale_pathpatch_around_centroid (m , s )
403+
404+ return [{** row_dict , "geometry" : m } for m in mp ]
405+
406+ def _process_point (row : pd .Series , s : float ) -> dict [str , Any ]:
407+ return {
408+ ** row .to_dict (),
409+ "geometry" : mpatches .Circle ((row ["geometry" ].x , row ["geometry" ].y ), radius = row ["radius" ] * s ),
410+ }
411+
412+ def _create_patches (shapes_df : GeoDataFrame , fill_c : list [Any ], outline_c : list [Any ], s : float ) -> pd .DataFrame :
413+ rows = []
414+ is_multiple_shapes = len (shapes_df ) > 1
415+
416+ for idx , row in shapes_df .iterrows ():
417+ geom_type = row ["geometry" ].geom_type
418+ processed_rows = []
419+
420+ if geom_type == "Polygon" :
421+ processed_rows .append (_process_polygon (row , s ))
422+ elif geom_type == "MultiPolygon" :
423+ processed_rows .extend (_process_multipolygon (row , s ))
424+ elif geom_type == "Point" :
425+ processed_rows .append (_process_point (row , s ))
426+
427+ for processed_row in processed_rows :
428+ _assign_fill_and_outline_to_row (fill_c , outline_c , processed_row , idx , is_multiple_shapes )
429+ rows .append (processed_row )
430+
431+ return pd .DataFrame (rows )
432+
433+ patches = _create_patches (shapes_df , fill_c , outline_c , s )
412434 return PatchCollection (
413435 patches ["geometry" ].values .tolist (),
414436 snap = False ,
@@ -788,7 +810,7 @@ def _map_color_seg(
788810 cell_id = np .array (cell_id )
789811 if color_vector is not None and isinstance (color_vector .dtype , pd .CategoricalDtype ):
790812 # users wants to plot a categorical column
791- if isinstance ( na_color , tuple ) and len ( na_color ) == 4 and np .any (color_source_vector .isna ()):
813+ if np .any (color_source_vector .isna ()):
792814 cell_id [color_source_vector .isna ()] = 0
793815 val_im : ArrayLike = map_array (seg , cell_id , color_vector .codes + 1 )
794816 cols = colors .to_rgba_array (color_vector .categories )
@@ -873,9 +895,9 @@ def _modify_categorical_color_mapping(
873895 modified_mapping = {key : mapping [key ] for key in mapping if key in groups or key == "NaN" }
874896 elif len (palette ) == len (groups ) and isinstance (groups , list ) and isinstance (palette , list ):
875897 modified_mapping = dict (zip (groups , palette ))
876-
877898 else :
878899 raise ValueError (f"Expected palette to be of length `{ len (groups )} `, found `{ len (palette )} `." )
900+
879901 return modified_mapping
880902
881903
@@ -891,7 +913,7 @@ def _get_default_categorial_color_mapping(
891913 palette = default_102
892914 else :
893915 palette = ["grey" for _ in range (len_cat )]
894- logger .info ("input has more than 103 categories. Uniform " " 'grey' color will be used for all categories." )
916+ logger .info ("input has more than 103 categories. Uniform 'grey' color will be used for all categories." )
895917
896918 return {cat : to_hex (to_rgba (col )[:3 ]) for cat , col in zip (color_source_vector .categories , palette [:len_cat ])}
897919
@@ -922,54 +944,6 @@ def _get_categorical_color_mapping(
922944 return _modify_categorical_color_mapping (base_mapping , groups , palette )
923945
924946
925- def _get_palette (
926- categories : Sequence [Any ],
927- adata : AnnData | None = None ,
928- cluster_key : None | str = None ,
929- palette : ListedColormap | str | list [str ] | None = None ,
930- alpha : float = 1.0 ,
931- ) -> Mapping [str , str ] | None :
932- palette = None if isinstance (palette , list ) and palette [0 ] is None else palette
933- if adata is not None and palette is None :
934- try :
935- palette = adata .uns [f"{ cluster_key } _colors" ] # type: ignore[arg-type]
936- if len (palette ) != len (categories ):
937- raise ValueError (
938- f"Expected palette to be of length `{ len (categories )} `, found `{ len (palette )} `. "
939- + f"Removing the colors in `adata.uns` with `adata.uns.pop('{ cluster_key } _colors')` may help."
940- )
941- return {cat : to_hex (to_rgba (col )[:3 ]) for cat , col in zip (categories , palette )}
942- except KeyError as e :
943- logger .warning (e )
944- return None
945-
946- len_cat = len (categories )
947-
948- if palette is None :
949- if len_cat <= 20 :
950- palette = default_20
951- elif len_cat <= 28 :
952- palette = default_28
953- elif len_cat <= len (default_102 ): # 103 colors
954- palette = default_102
955- else :
956- palette = ["grey" for _ in range (len_cat )]
957- logger .info ("input has more than 103 categories. Uniform " "'grey' color will be used for all categories." )
958- return {cat : to_hex (to_rgba (col )[:3 ]) for cat , col in zip (categories , palette [:len_cat ])}
959-
960- if isinstance (palette , str ):
961- cmap = ListedColormap ([palette ])
962- elif isinstance (palette , list ):
963- cmap = ListedColormap (palette )
964- elif isinstance (palette , ListedColormap ):
965- cmap = palette
966- else :
967- raise TypeError (f"Palette is { type (palette )} but should be string or list." )
968- palette = [to_hex (np .round (x , 5 )) for x in cmap (np .linspace (0 , 1 , len_cat ), alpha = alpha )]
969-
970- return dict (zip (categories , palette ))
971-
972-
973947def _maybe_set_colors (
974948 source : AnnData , target : AnnData , key : str , palette : str | ListedColormap | Cycler | Sequence [Any ] | None = None
975949) -> None :
@@ -1137,34 +1111,6 @@ def save_fig(fig: Figure, path: str | Path, make_dir: bool = True, ext: str = "p
11371111 fig .savefig (path , ** kwargs )
11381112
11391113
1140- def _get_cs_element_map (
1141- element : str | Sequence [str ] | None ,
1142- element_map : Mapping [str , Any ],
1143- ) -> Mapping [str , str ]:
1144- """Get the mapping between the coordinate system and the class."""
1145- # from spatialdata.models import Image2DModel, Image3DModel, Labels2DModel, Labels3DModel, PointsModel, ShapesModel
1146- element = list (element_map .keys ())[0 ] if element is None else element
1147- element = [element ] if isinstance (element , str ) else element
1148- d = {}
1149- for e in element :
1150- cs = list (element_map [e ].attrs ["transform" ].keys ())[0 ]
1151- d [cs ] = e
1152- # model = get_model(element_map["blobs_labels"])
1153- # if model in [Image2DModel, Image3DModel, Labels2DModel, Labels3DModel]
1154- return d
1155-
1156-
1157- def _multiscale_to_image (sdata : sd .SpatialData ) -> sd .SpatialData :
1158- if sdata .images is None :
1159- raise ValueError ("No images found in the SpatialData object." )
1160-
1161- for k , v in sdata .images .items ():
1162- if isinstance (v , msi .multiscale_spatial_image .DataTree ):
1163- sdata .images [k ] = Image2DModel .parse (v ["scale0" ].ds .to_array ().squeeze (axis = 0 ))
1164-
1165- return sdata
1166-
1167-
11681114def _get_linear_colormap (colors : list [str ], background : str ) -> list [LinearSegmentedColormap ]:
11691115 return [LinearSegmentedColormap .from_list (c , [background , c ], N = 256 ) for c in colors ]
11701116
@@ -1176,62 +1122,6 @@ def _get_listed_colormap(color_dict: dict[str, str]) -> ListedColormap:
11761122 return ListedColormap (["black" ] + colors , N = len (colors ) + 1 )
11771123
11781124
1179- def _translate_image (
1180- image : DataArray ,
1181- translation : sd .transformations .transformations .Translation ,
1182- ) -> DataArray :
1183- shifts : dict [str , int ] = {axis : int (translation .translation [idx ]) for idx , axis in enumerate (translation .axes )}
1184- img = image .values .copy ()
1185- # for yx images (important for rasterized MultiscaleImages as labels)
1186- expanded_dims = False
1187- if len (img .shape ) == 2 :
1188- img = np .expand_dims (img , axis = 0 )
1189- expanded_dims = True
1190-
1191- shifted_channels = []
1192-
1193- # split channels, shift axes individually, them recombine
1194- if len (img .shape ) == 3 :
1195- for c in range (img .shape [0 ]):
1196- channel = img [c , :, :]
1197-
1198- # iterates over [x, y]
1199- for axis , shift in shifts .items ():
1200- pad_x , pad_y = (0 , 0 ), (0 , 0 )
1201- if axis == "x" and shift > 0 :
1202- pad_x = (abs (shift ), 0 )
1203- elif axis == "x" and shift < 0 :
1204- pad_x = (0 , abs (shift ))
1205-
1206- if axis == "y" and shift > 0 :
1207- pad_y = (abs (shift ), 0 )
1208- elif axis == "y" and shift < 0 :
1209- pad_y = (0 , abs (shift ))
1210-
1211- channel = np .pad (channel , (pad_y , pad_x ), mode = "constant" )
1212-
1213- shifted_channels .append (channel )
1214-
1215- if expanded_dims :
1216- return Labels2DModel .parse (
1217- np .array (shifted_channels [0 ]),
1218- dims = ["y" , "x" ],
1219- transformations = image .attrs ["transform" ],
1220- )
1221- return Image2DModel .parse (
1222- np .array (shifted_channels ),
1223- dims = ["c" , "y" , "x" ],
1224- transformations = image .attrs ["transform" ],
1225- )
1226-
1227-
1228- def _convert_polygon_to_linestrings (polygon : Polygon ) -> list [LineString ]:
1229- b = polygon .boundary .coords
1230- linestrings = [LineString (b [k : k + 2 ]) for k in range (len (b ) - 1 )]
1231-
1232- return [list (ls .coords ) for ls in linestrings ]
1233-
1234-
12351125def _split_multipolygon_into_outer_and_inner (mp : shapely .MultiPolygon ): # type: ignore
12361126 # https://stackoverflow.com/a/21922058
12371127
0 commit comments