Refactor function to also show min and max values in boxplots

This commit is contained in:
Alberto Soutullo 2024-10-01 11:58:43 +02:00
parent eee6c8cf27
commit dfd0d7ad25
No known key found for this signature in database
GPG Key ID: A7CAC0D8343B0387
1 changed files with 31 additions and 15 deletions

View File

@ -67,7 +67,11 @@ class Plotter:
box_plot.xaxis.set_tick_params(rotation=45)
box_plot.legend(loc='upper right', bbox_to_anchor=(1, 1))
self._add_median_labels(box_plot)
self._add_stat_labels(box_plot)
show_min_max = plot_specs.get("show_min_max", False)
if show_min_max:
self._add_stat_labels(box_plot, value_type="min")
self._add_stat_labels(box_plot, value_type="max")
def _create_subplot_paths_group(self, plot_specs: Dict) -> List:
subplot_path = [[f"{folder}{data}" for folder in plot_specs["folder"]] for data in
@ -75,29 +79,41 @@ class Plotter:
return subplot_path
def _add_median_labels(self, ax: plt.Axes, fmt: str = ".3f") -> None:
# https://stackoverflow.com/a/63295846
"""Add text labels to the median lines of a seaborn boxplot.
def _add_stat_labels(self, ax: plt.Axes, fmt: str = ".3f", value_type: str = "median") -> None:
# Refactor from https://stackoverflow.com/a/63295846
"""
Add text labels to the median, minimum, or maximum lines of a seaborn boxplot.
Args:
ax: plt.Axes, e.g. the return value of sns.boxplot()
fmt: format string for the median value
ax: plt.Axes, e.g., the return value of sns.boxplot()
fmt: Format string for the value (e.g., min/max/median).
value_type: The type of value to label. Can be 'median', 'min', or 'max'.
"""
lines = ax.get_lines()
boxes = [c for c in ax.get_children() if "Patch" in str(c)]
boxes = [c for c in ax.get_children() if "Patch" in str(c)] # Get box patches
start = 4
if not boxes: # seaborn v0.13 => fill=False => no patches => +1 line
if not boxes: # seaborn v0.13 or above (no patches => need to shift index)
boxes = [c for c in ax.get_lines() if len(c.get_xdata()) == 5]
start += 1
lines_per_box = len(lines) // len(boxes)
for median in lines[start::lines_per_box]:
x, y = (data.mean() for data in median.get_data())
if value_type == "median":
line_idx = start
elif value_type == "min":
line_idx = start - 2 # min line comes 2 positions before the median
elif value_type == "max":
line_idx = start - 1 # max line comes 1 position before the median
else:
raise ValueError("Invalid value_type. Must be 'min', 'max', or 'median'.")
for value_line in lines[line_idx::lines_per_box]:
x, y = (data.mean() for data in value_line.get_data())
# choose value depending on horizontal or vertical plot orientation
value = x if len(set(median.get_xdata())) == 1 else y
text = ax.text(x, y, f'{value/1000:{fmt}}', ha='center', va='center',
fontweight='bold', color='white')
# create median-colored border around white text for contrast
value = x if len(set(value_line.get_xdata())) == 1 else y
text = ax.text(x, y, f'{value / 1000:{fmt}}', ha='center', va='center',
fontweight='bold', color='white', size=10)
# create colored border around white text for contrast
text.set_path_effects([
path_effects.Stroke(linewidth=3, foreground=median.get_color()),
path_effects.Stroke(linewidth=3, foreground=value_line.get_color()),
path_effects.Normal(),
])