From dfd0d7ad2578bd2b36edb7e75620a68287558343 Mon Sep 17 00:00:00 2001 From: Alberto Soutullo Date: Tue, 1 Oct 2024 11:58:43 +0200 Subject: [PATCH] Refactor function to also show min and max values in boxplots --- src/plotting/plotter.py | 46 +++++++++++++++++++++++++++-------------- 1 file changed, 31 insertions(+), 15 deletions(-) diff --git a/src/plotting/plotter.py b/src/plotting/plotter.py index 00359bc..159ab51 100644 --- a/src/plotting/plotter.py +++ b/src/plotting/plotter.py @@ -67,7 +67,11 @@ class Plotter: box_plot.xaxis.set_tick_params(rotation=45) box_plot.legend(loc='upper right', bbox_to_anchor=(1, 1)) - self._add_median_labels(box_plot) + self._add_stat_labels(box_plot) + show_min_max = plot_specs.get("show_min_max", False) + if show_min_max: + self._add_stat_labels(box_plot, value_type="min") + self._add_stat_labels(box_plot, value_type="max") def _create_subplot_paths_group(self, plot_specs: Dict) -> List: subplot_path = [[f"{folder}{data}" for folder in plot_specs["folder"]] for data in @@ -75,29 +79,41 @@ class Plotter: return subplot_path - def _add_median_labels(self, ax: plt.Axes, fmt: str = ".3f") -> None: - # https://stackoverflow.com/a/63295846 - """Add text labels to the median lines of a seaborn boxplot. + def _add_stat_labels(self, ax: plt.Axes, fmt: str = ".3f", value_type: str = "median") -> None: + # Refactor from https://stackoverflow.com/a/63295846 + """ + Add text labels to the median, minimum, or maximum lines of a seaborn boxplot. Args: - ax: plt.Axes, e.g. the return value of sns.boxplot() - fmt: format string for the median value + ax: plt.Axes, e.g., the return value of sns.boxplot() + fmt: Format string for the value (e.g., min/max/median). + value_type: The type of value to label. Can be 'median', 'min', or 'max'. """ lines = ax.get_lines() - boxes = [c for c in ax.get_children() if "Patch" in str(c)] + boxes = [c for c in ax.get_children() if "Patch" in str(c)] # Get box patches start = 4 - if not boxes: # seaborn v0.13 => fill=False => no patches => +1 line + if not boxes: # seaborn v0.13 or above (no patches => need to shift index) boxes = [c for c in ax.get_lines() if len(c.get_xdata()) == 5] start += 1 lines_per_box = len(lines) // len(boxes) - for median in lines[start::lines_per_box]: - x, y = (data.mean() for data in median.get_data()) + + if value_type == "median": + line_idx = start + elif value_type == "min": + line_idx = start - 2 # min line comes 2 positions before the median + elif value_type == "max": + line_idx = start - 1 # max line comes 1 position before the median + else: + raise ValueError("Invalid value_type. Must be 'min', 'max', or 'median'.") + + for value_line in lines[line_idx::lines_per_box]: + x, y = (data.mean() for data in value_line.get_data()) # choose value depending on horizontal or vertical plot orientation - value = x if len(set(median.get_xdata())) == 1 else y - text = ax.text(x, y, f'{value/1000:{fmt}}', ha='center', va='center', - fontweight='bold', color='white') - # create median-colored border around white text for contrast + value = x if len(set(value_line.get_xdata())) == 1 else y + text = ax.text(x, y, f'{value / 1000:{fmt}}', ha='center', va='center', + fontweight='bold', color='white', size=10) + # create colored border around white text for contrast text.set_path_effects([ - path_effects.Stroke(linewidth=3, foreground=median.get_color()), + path_effects.Stroke(linewidth=3, foreground=value_line.get_color()), path_effects.Normal(), ])