 )
 from supervision.config import CLASS_NAME_DATA_FIELD, ORIENTED_BOX_COORDINATES
 from supervision.detection.core import Detections
-from supervision.detection.utils import clip_boxes, mask_to_polygons, spread_out_boxes
+from supervision.detection.utils import (
+    clip_boxes,
+    mask_to_polygons,
+    polygon_to_mask,
+    spread_out_boxes,
+    xyxy_to_polygons,
+)
 from supervision.draw.color import Color, ColorPalette
-from supervision.draw.utils import draw_polygon
-from supervision.geometry.core import Position
+from supervision.draw.utils import draw_polygon, draw_rounded_rectangle, draw_text
+from supervision.geometry.core import Point, Position, Rect
 from supervision.utils.conversion import (
     ensure_cv2_image_for_annotation,
     ensure_pil_image_for_annotation,
@@ -2683,3 +2689,242 @@ def annotate(self, scene: ImageType, detections: Detections) -> ImageType:
 
         np.copyto(scene, colored_mask)
         return scene
+
+
+class ComparisonAnnotator:
+    """
+    Highlights the differences between two sets of detections.
+    Useful for comparing the results of two different models, or a prediction
+    against the ground truth.
+
+    If present, uses the oriented bounding box data.
+    Otherwise, if present, uses the mask data.
+    Otherwise, uses the regular bounding box (`xyxy`) data.
+    """
+
+    def __init__(
+        self,
+        color_1: Color = Color.RED,
+        color_2: Color = Color.GREEN,
+        color_overlap: Color = Color.BLUE,
+        *,
+        opacity: float = 0.75,
+        label_1: str = "",
+        label_2: str = "",
+        label_overlap: str = "",
+        label_scale: float = 1.0,
+    ):
+        """
+        Args:
+            color_1 (Color): Color of areas present only in the first set of
+                detections.
+            color_2 (Color): Color of areas present only in the second set of
+                detections.
+            color_overlap (Color): Color of areas present in both sets of detections.
+            opacity (float): Annotator opacity, from `0` to `1`.
+            label_1 (str): Label for the first set of detections.
+            label_2 (str): Label for the second set of detections.
+            label_overlap (str): Label for areas present in both sets of detections.
+            label_scale (float): Scale factor for the legend labels.
+        """
+
+        self.color_1 = color_1
+        self.color_2 = color_2
+        self.color_overlap = color_overlap
+
+        self.opacity = opacity
+        self.label_1 = label_1
+        self.label_2 = label_2
+        self.label_overlap = label_overlap
+        self.label_scale = label_scale
+        self.text_thickness = int(self.label_scale + 1.2)
+
+    @ensure_cv2_image_for_annotation
+    def annotate(
+        self, scene: ImageType, detections_1: Detections, detections_2: Detections
+    ) -> ImageType:
+        """
+        Highlights the differences between two sets of detections.
+
+        Args:
+            scene (ImageType): The image where detections will be drawn.
+                `ImageType` is a flexible type, accepting either `numpy.ndarray`
+                or `PIL.Image.Image`.
+            detections_1 (Detections): The first set of detections
+                (e.g. the predictions of one model).
+            detections_2 (Detections): The second set of detections to compare
+                against (e.g. another model's predictions or the ground truth).
+
+        Returns:
+            The annotated image.
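+
+        Example:
+            A minimal usage sketch (this assumes `ComparisonAnnotator` is exported
+            from the top-level `supervision` namespace like the other annotators;
+            `detections_1` and `detections_2` can come from any two models, or from
+            a model and the ground truth):
+
+            ```python
+            import supervision as sv
+
+            image = ...
+            detections_1 = sv.Detections(...)
+            detections_2 = sv.Detections(...)
+
+            comparison_annotator = sv.ComparisonAnnotator(
+                label_1="Model A", label_2="Model B", label_overlap="Both"
+            )
+            annotated_frame = comparison_annotator.annotate(
+                scene=image.copy(),
+                detections_1=detections_1,
+                detections_2=detections_2,
+            )
+            ```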
+        """
+        assert isinstance(scene, np.ndarray)
+        if detections_1.is_empty() and detections_2.is_empty():
+            return scene
+
+        use_obb = self._use_obb(detections_1, detections_2)
+        use_mask = self._use_mask(detections_1, detections_2)
+
+        if use_obb:
+            mask_1 = self._mask_from_obb(scene, detections_1)
+            mask_2 = self._mask_from_obb(scene, detections_2)
+
+        elif use_mask:
+            mask_1 = self._mask_from_mask(scene, detections_1)
+            mask_2 = self._mask_from_mask(scene, detections_2)
+
+        else:
+            mask_1 = self._mask_from_xyxy(scene, detections_1)
+            mask_2 = self._mask_from_xyxy(scene, detections_2)
+
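+        # Split pixels into an overlap region and two exclusive regions, so each
+        # pixel is assigned exactly one color.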
+        mask_overlap = mask_1 & mask_2
+        mask_1 = mask_1 & ~mask_overlap
+        mask_2 = mask_2 & ~mask_overlap
+
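+        # Paint the regions onto a color layer, then alpha-blend each region
+        # onto the scene.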
+        color_layer = np.zeros_like(scene, dtype=np.uint8)
+        color_layer[mask_overlap] = self.color_overlap.as_bgr()
+        color_layer[mask_1] = self.color_1.as_bgr()
+        color_layer[mask_2] = self.color_2.as_bgr()
+
+        scene[mask_overlap] = (1 - self.opacity) * scene[
+            mask_overlap
+        ] + self.opacity * color_layer[mask_overlap]
+        scene[mask_1] = (1 - self.opacity) * scene[mask_1] + self.opacity * color_layer[
+            mask_1
+        ]
+        scene[mask_2] = (1 - self.opacity) * scene[mask_2] + self.opacity * color_layer[
+            mask_2
+        ]
+
+        self._draw_labels(scene)
+
+        return scene
+
+    @staticmethod
+    def _use_obb(detections_1: Detections, detections_2: Detections) -> bool:
+        assert not detections_1.is_empty() or not detections_2.is_empty()
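+        # Use oriented bounding boxes when every non-empty detections set
+        # provides them.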
+        is_obb_1 = ORIENTED_BOX_COORDINATES in detections_1.data
+        is_obb_2 = ORIENTED_BOX_COORDINATES in detections_2.data
+        return (
+            (is_obb_1 and is_obb_2)
+            or (is_obb_1 and detections_2.is_empty())
+            or (detections_1.is_empty() and is_obb_2)
+        )
+
+    @staticmethod
+    def _use_mask(detections_1: Detections, detections_2: Detections) -> bool:
+        assert not detections_1.is_empty() or not detections_2.is_empty()
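+        # Same rule as `_use_obb`, but for segmentation masks.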
+        is_mask_1 = detections_1.mask is not None
+        is_mask_2 = detections_2.mask is not None
+        return (
+            (is_mask_1 and is_mask_2)
+            or (is_mask_1 and detections_2.is_empty())
+            or (detections_1.is_empty() and is_mask_2)
+        )
+
+    @staticmethod
+    def _mask_from_xyxy(scene: np.ndarray, detections: Detections) -> np.ndarray:
+        mask = np.zeros(scene.shape[:2], dtype=np.bool_)
+        if detections.is_empty():
+            return mask
+
+        resolution_wh = scene.shape[1], scene.shape[0]
+        polygons = xyxy_to_polygons(detections.xyxy)
+
+        for polygon in polygons:
+            polygon_mask = polygon_to_mask(polygon, resolution_wh=resolution_wh)
+            mask |= polygon_mask.astype(np.bool_)
+        return mask
+
+    @staticmethod
+    def _mask_from_obb(scene: np.ndarray, detections: Detections) -> np.ndarray:
+        mask = np.zeros(scene.shape[:2], dtype=np.bool_)
+        if detections.is_empty():
+            return mask
+
+        resolution_wh = scene.shape[1], scene.shape[0]
+
+        for polygon in detections.data[ORIENTED_BOX_COORDINATES]:
+            polygon_mask = polygon_to_mask(polygon, resolution_wh=resolution_wh)
+            mask |= polygon_mask.astype(np.bool_)
+        return mask
+
+    @staticmethod
+    def _mask_from_mask(scene: np.ndarray, detections: Detections) -> np.ndarray:
+        mask = np.zeros(scene.shape[:2], dtype=np.bool_)
+        if detections.is_empty():
+            return mask
+        assert detections.mask is not None
+
+        for detections_mask in detections.mask:
+            mask |= detections_mask.astype(np.bool_)
+        return mask
+
+    def _draw_labels(self, scene: np.ndarray) -> None:
+        """
+        Draw the legend labels explaining what each color represents, at
+        automatically computed positions.
+
+        Args:
+            scene (np.ndarray): The image where the labels will be drawn.
+        """
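+        # Legend layout: every size below is derived from label_scale, so the
+        # legend grows and shrinks with the requested label size.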
+        margin = int(50 * self.label_scale)
+        gap = int(40 * self.label_scale)
+        y0 = int(50 * self.label_scale)
+        height = int(50 * self.label_scale)
+
+        marker_size = int(20 * self.label_scale)
+        padding = int(10 * self.label_scale)
+        text_box_corner_radius = int(10 * self.label_scale)
+        marker_corner_radius = int(4 * self.label_scale)
+        text_scale = self.label_scale
+
+        label_color_pairs = [
+            (self.label_1, self.color_1),
+            (self.label_2, self.color_2),
+            (self.label_overlap, self.color_overlap),
+        ]
+
+        x0 = margin
+        for text, color in label_color_pairs:
+            if not text:
+                continue
+
+            (text_w, _) = cv2.getTextSize(
+                text=text,
+                fontFace=CV2_FONT,
+                fontScale=self.label_scale,
+                thickness=self.text_thickness,
+            )[0]
+
+            width = text_w + marker_size + padding * 4
+            center_x = x0 + width // 2
+            center_y = y0 + height // 2
+
+            draw_rounded_rectangle(
+                scene=scene,
+                rect=Rect(x=x0, y=y0, width=width, height=height),
+                color=Color.WHITE,
+                border_radius=text_box_corner_radius,
+            )
+
+            draw_rounded_rectangle(
+                scene=scene,
+                rect=Rect(
+                    x=x0 + padding,
+                    y=center_y - marker_size / 2,
+                    width=marker_size,
+                    height=marker_size,
+                ),
+                color=color,
+                border_radius=marker_corner_radius,
+            )
+
+            draw_text(
+                scene,
+                text,
+                text_anchor=Point(x=center_x + marker_size, y=center_y),
+                text_scale=text_scale,
+                text_thickness=self.text_thickness,
+            )
+
+            x0 += width + gap