"""
This module provides functions for plotting statistical data
related to a given variable or dataset:
    - `plot_statistics` is designed to plot raw statistical measures
      (mean, maximum, minimum, and optional standard deviation) of a
      variable across a series of x-axis values. It allows for
      customization of plot labels, title, and saving the plot to a file.
    - `plot_bin_statistics` extends these capabilities to binned data,
      offering a way to visualize binned statistics (mean, maximum, minimum)
      along with their respective standard deviations. This function also 
      supports label and title customization, as well as saving the plot to 
      a specified path.
"""
from typing import Optional, Dict, Any
import numpy as np
import matplotlib.pyplot as plt
from mhkit.utils.type_handling import to_numeric_array
# pylint: disable=R0914
def plot_statistics(
    x: np.ndarray,
    y_mean: np.ndarray,
    y_max: np.ndarray,
    y_min: np.ndarray,
    y_stdev: Optional[np.ndarray] = None,
    **kwargs: Any,
) -> plt.Axes:
    """
    Plot the raw statistics of a variable against a set of x-axis values.

    The maximum, mean and minimum series are always drawn as unfilled
    markers; the standard deviation series is drawn only when it is
    provided and non-empty.

    Parameters
    ----------
    x : numpy array
        Array of x-axis values
    y_mean : numpy array
        Array of mean statistical values of variable
    y_max : numpy array
        Array of max statistical values of variable
    y_min : numpy array
        Array of min statistical values of variable
    y_stdev : numpy array, optional
        Array of standard deviation statistical values of variable.
        When omitted, no standard deviation series is plotted.
    **kwargs : optional
        x_label : string
            x axis label for plot
        y_label : string
            y axis label for plot
        title : string, optional
            Title for plot
        save_path : string
            Path and filename to save figure. When omitted, the figure is
            shown with ``plt.show()`` instead of being saved.

    Returns
    -------
    ax : matplotlib pyplot axes

    Raises
    ------
    TypeError
        If ``x_label``, ``y_label``, ``title`` or ``save_path`` is given
        and is not a string. The array inputs are passed through
        ``to_numeric_array``; any error it raises propagates unchanged.
    """
    # An omitted standard deviation is represented by an empty sequence so
    # that it goes through the same conversion as the other inputs.
    if y_stdev is None:
        y_stdev = []

    named_inputs = {
        "x": x,
        "y_mean": y_mean,
        "y_max": y_max,
        "y_min": y_min,
        "y_stdev": y_stdev,
    }
    # Convert each input variable to a numeric array, ensuring all are numeric
    x, y_mean, y_max, y_min, y_stdev = (
        to_numeric_array(values, name) for name, values in named_inputs.items()
    )

    # Only these keyword options are read; each must be a string or absent.
    text_options: Dict[str, Any] = {
        option: kwargs.get(option)
        for option in ("x_label", "y_label", "title", "save_path")
    }
    for option, value in text_options.items():
        if value is not None and not isinstance(value, str):
            raise TypeError(f"{option} must be of type str. Got: {type(value)}")

    fig, ax = plt.subplots(figsize=(6, 4))
    ax.plot(x, y_max, "^", label="max", mfc="none")
    ax.plot(x, y_mean, "o", label="mean", mfc="none")
    ax.plot(x, y_min, "v", label="min", mfc="none")
    if len(y_stdev) > 0:
        ax.plot(x, y_stdev, "+", label="stdev", c="m")
    ax.grid(alpha=0.4)
    ax.legend(loc="best")

    # Empty strings are treated like missing values: nothing is set.
    for option, apply_text in (
        ("x_label", ax.set_xlabel),
        ("y_label", ax.set_ylabel),
        ("title", ax.set_title),
    ):
        if text_options[option]:
            apply_text(text_options[option])

    fig.tight_layout()

    save_path = text_options["save_path"]
    if save_path is None:
        plt.show()
    else:
        fig.savefig(save_path)
        # Close this specific figure rather than whichever one is current.
        plt.close(fig)
    return ax
# pylint: disable=R0913
# pylint: disable=R0917
def plot_bin_statistics(
    bin_centers: np.ndarray,
    bin_mean: np.ndarray,
    bin_max: np.ndarray,
    bin_min: np.ndarray,
    bin_mean_std: np.ndarray,
    bin_max_std: np.ndarray,
    bin_min_std: np.ndarray,
    **kwargs: Any,
) -> plt.Axes:
    """
    Plot the binned statistics of a single variable with error bars.

    The binned maximum, mean and minimum are each drawn against the bin
    centers, with the matching standard deviation used as the error bar.

    Parameters
    ----------
    bin_centers : numpy array
        x-axis bin center values
    bin_mean : numpy array
        Binned mean statistical values of variable
    bin_max : numpy array
        Binned max statistical values of variable
    bin_min : numpy array
        Binned min statistical values of variable
    bin_mean_std : numpy array
        Standard deviations of mean binned statistics
    bin_max_std : numpy array
        Standard deviations of max binned statistics
    bin_min_std : numpy array
        Standard deviations of min binned statistics
    **kwargs : optional
        x_label : string
            x axis label for plot
        y_label : string
            y axis label for plot
        title : string, optional
            Title for plot
        save_path : string
            Path and filename to save figure. When omitted, the figure is
            shown with ``plt.show()`` instead of being saved.

    Returns
    -------
    ax : matplotlib pyplot axes

    Raises
    ------
    TypeError
        If ``x_label``, ``y_label``, ``title`` or ``save_path`` is given
        and is not a string. The array inputs are passed through
        ``to_numeric_array``; any error it raises propagates unchanged.
    """
    named_inputs = {
        "bin_centers": bin_centers,
        "bin_mean": bin_mean,
        "bin_max": bin_max,
        "bin_min": bin_min,
        "bin_mean_std": bin_mean_std,
        "bin_max_std": bin_max_std,
        "bin_min_std": bin_min_std,
    }
    # Convert each input variable to a numeric array, ensuring all are numeric
    (
        bin_centers,
        bin_mean,
        bin_max,
        bin_min,
        bin_mean_std,
        bin_max_std,
        bin_min_std,
    ) = (to_numeric_array(values, name) for name, values in named_inputs.items())

    # Only these keyword options are read; each must be a string or absent.
    text_options: Dict[str, Any] = {
        option: kwargs.get(option)
        for option in ("x_label", "y_label", "title", "save_path")
    }
    for option, value in text_options.items():
        if value is not None and not isinstance(value, str):
            raise TypeError(f"{option} must be of type str. Got: {type(value)}")

    fig, ax = plt.subplots(figsize=(7, 5))

    # Draw order (max, mean, min) is kept so each series gets the same
    # colour from the default cycle and the same legend position.
    for values, spread, marker, label in (
        (bin_max, bin_max_std, "^", "max"),
        (bin_mean, bin_mean_std, "o", "mean"),
        (bin_min, bin_min_std, "v", "min"),
    ):
        ax.errorbar(
            bin_centers,
            values,
            marker=marker,
            mfc="none",
            yerr=spread,
            capsize=4,
            label=label,
        )

    ax.grid(alpha=0.5)
    ax.legend(loc="best")

    # Empty strings are treated like missing values: nothing is set.
    for option, apply_text in (
        ("x_label", ax.set_xlabel),
        ("y_label", ax.set_ylabel),
        ("title", ax.set_title),
    ):
        if text_options[option]:
            apply_text(text_options[option])

    fig.tight_layout()

    save_path = text_options["save_path"]
    if save_path is None:
        plt.show()
    else:
        fig.savefig(save_path)
        # Close this specific figure rather than whichever one is current.
        plt.close(fig)
    return ax