Refactored plotter

This commit is contained in:
Alberto Soutullo 2024-03-14 13:04:26 +01:00
parent 098d6debaf
commit bed4592fd3
No known key found for this signature in database
GPG Key ID: A7CAC0D8343B0387
1 changed files with 31 additions and 17 deletions

View File

@ -8,6 +8,7 @@ import matplotlib.pyplot as plt
import matplotlib.patheffects as path_effects
from typing import List, Dict
from matplotlib import ticker
from result import Err, Ok
# Project Imports
from src.utils.file_utils import get_files_from_folder_path, get_file_name_from_path
@ -29,36 +30,49 @@ class Plotter:
fig, axs = plt.subplots(nrows=1, ncols=len(plot_specs['data']), sharey='row', figsize=(15,15))
subplot_paths_group = self._create_subplot_paths_group(plot_specs)
self._insert_data_in_axs(subplot_paths_group, axs)
self._save_plot(plot_name)
def _insert_data_in_axs(self, subplot_paths_group: List, axs: np.ndarray):
for i, subplot_path_group in enumerate(subplot_paths_group):
subplot_title = subplot_path_group[1]
self._create_subplot(subplot_path_group[0], subplot_title, i, axs)
self._save_plot(plot_name)
subplot_df = self._create_subplot_df(subplot_path_group[0])
self._add_subplot_df_to_axs(subplot_df, i, subplot_title, axs)
def _save_plot(self, plot_name: str):
plt.tight_layout()
plt.savefig(plot_name)
plt.show()
def _create_subplot(self, subplot_paths_group: str, subplot_title: str, index: int, axs: np.ndarray):
def _create_subplot_df(self, subplot_paths_group: List):
subplot_df = pd.DataFrame()
for subplot_path in subplot_paths_group:
group_df = pd.DataFrame()
data_files_path = get_files_from_folder_path(subplot_path)
for file_path in data_files_path:
group_df = self._dump_file_mean_into_df(subplot_path+"/"+file_path, group_df)
group_df["class"] = subplot_path.split("/")[-2]
subplot_df = pd.concat([subplot_df, group_df])
subplot_df = self._concat_subplot_df(subplot_df, subplot_path)
subplot_df = pd.melt(subplot_df, id_vars=["class"])
self._add_subplot_to_axs(subplot_df, index, subplot_title, axs)
return subplot_df
def _add_subplot_to_axs(self, df: pd.DataFrame, index: int, subplot_title: str, axs: np.ndarray):
def _concat_subplot_df(self, subplot_df: pd.DataFrame, subplot_path: str):
group_df = pd.DataFrame()
data_files_path = get_files_from_folder_path(subplot_path)
for file_path in data_files_path:
result = self._dump_file_mean_into_df(subplot_path + "/" + file_path, group_df)
if result.is_err():
exit(1)
group_df = result.ok_value
group_df["class"] = subplot_path.split("/")[-2]
subplot_df = pd.concat([subplot_df, group_df])
return subplot_df
def _add_subplot_df_to_axs(self, df: pd.DataFrame, index: int, subplot_title: str, axs: np.ndarray):
box_plot = sns.boxplot(data=df, x="variable", y="value", hue="class", ax=axs[index],
showfliers=False)
@ -71,12 +85,12 @@ class Plotter:
box_plot.tick_params(labelbottom=True)
box_plot.xaxis.set_tick_params(rotation=45)
self.add_median_labels(box_plot)
self._add_median_labels(box_plot)
def _dump_file_mean_into_df(self, file_path: str, group_df: pd.DataFrame):
if not os.path.exists(file_path):
logger.error(f"Missing {file_path}")
return
return Err("")
file_name = get_file_name_from_path(file_path)
@ -85,7 +99,7 @@ class Plotter:
df_mean = pd.DataFrame(df_mean, columns=[file_name])
group_df = pd.concat([group_df, df_mean], axis=1)
return group_df
return Ok(group_df)
def _create_subplot_paths_group(self, plot_specs: Dict) -> List:
subplot_path = [
@ -94,7 +108,7 @@ class Plotter:
return subplot_path
def add_median_labels(self, ax: plt.Axes, fmt: str = ".3f") -> None:
def _add_median_labels(self, ax: plt.Axes, fmt: str = ".3f") -> None:
# https://stackoverflow.com/a/63295846
"""Add text labels to the median lines of a seaborn boxplot.