# Source code for rational.utils.utils

import numpy as np
from numpy import zeros, inf
import matplotlib.pyplot as plt
from .warnings import RationalWarning, \
    RationalImportSeabornWarning, RationalImportScipyWarning


def _wrap_func(func, xdata, ydata, degrees):
    """
    Build the residual function handed to the least-squares solver.

    Arguments:
            func (callable):
                Model called as ``func(xdata, first_params, second_params)``.
            xdata:
                Passed unchanged as first argument of ``func``.
            ydata:
                Target values, subtracted from the output of ``func``.
            degrees (sequence):
                ``degrees[0] + 1`` is the number of leading entries of the
                flat parameter vector given to ``func`` as its second
                argument; the remaining entries form its third argument.
    Returns:
        callable: ``params -> func(xdata, head, tail) - ydata``
    """
    def func_wrapped(params):
        # The flat vector is cut in two right after the first
        # ``degrees[0] + 1`` coefficients.
        split_at = degrees[0] + 1
        head = params[:split_at]
        tail = params[split_at:]
        prediction = func(xdata, head, tail)
        return prediction - ydata

    return func_wrapped


def _curve_fit(f, xdata, ydata, degrees, version, p0=None, absolute_sigma=False,
               method=None, jac=None, **kwargs):
    """
    Least-squares fit of ``f`` to ``(xdata, ydata)`` using
    ``scipy.optimize.leastsq`` (Levenberg-Marquardt).

    The model is called as ``f(xdata, first_params, second_params)`` where
    the flat parameter vector is split after its ``degrees[0] + 1`` first
    entries (see ``_wrap_func``).

    Arguments:
            f (callable):
                The model to fit.
            xdata:
                Abscissa. Lists, tuples and arrays are converted to a finite
                float array, anything else is handed to ``f`` untouched.
            ydata (array_like):
                Target values, must be finite.
            degrees (sequence):
                Degrees used to size the default ``p0`` and to split the
                parameter vector.
            version (str):
                With ``"C"`` the default ``p0`` has ``sum(degrees) + 2``
                entries, otherwise ``sum(degrees) + 1``.
            p0 (array_like):
                Initial parameters. Default ``None`` (a vector of ones).
            absolute_sigma (bool):
                If ``False``, the covariance is scaled by the reduced
                chi-square. Default ``False``
            method:
                Ignored: the solver is always ``leastsq`` ('lm'). Kept for
                signature compatibility.
            jac (callable):
                Optional Jacobian, called as ``jac(xdata, *params)``.
                Default ``None``
            **kwargs:
                Forwarded to ``leastsq`` (``full_output`` is intercepted,
                ``args`` is rejected).
    Returns:
        tuple: ``(popt, pcov)``, or ``(popt, pcov, infodict, errmsg, ier)``
        when ``full_output`` is truthy.
    Raises:
        ValueError: on non finite data or if ``args`` is in ``kwargs``.
        RuntimeError: if ``leastsq`` does not report success.
    """
    # Public import path only: the private modules formerly used here
    # (scipy.optimize.optimize / scipy.optimize.minpack) are deprecated
    # aliases that newer SciPy releases do not expose anymore.
    from scipy.optimize import OptimizeWarning, leastsq

    if p0 is None:
        n_params = np.sum(degrees) + (2 if version == "C" else 1)
        p0 = np.ones(n_params)
    else:
        # `p0.size` is needed below, so accept plain sequences as well.
        p0 = np.atleast_1d(p0)

    ydata = np.asarray_chkfinite(ydata, float)

    if isinstance(xdata, (list, tuple, np.ndarray)):
        # `xdata` is passed straight to the user-defined `f`, so allow
        # non-array_like `xdata`.
        xdata = np.asarray_chkfinite(xdata, float)

    func = _wrap_func(f, xdata, ydata, degrees)
    if callable(jac):
        user_jac = jac

        # Same wrapping SciPy applies in `curve_fit` when no sigma
        # transform is involved.
        def jac(params):
            return user_jac(xdata, *params)

    if 'args' in kwargs:
        raise ValueError("'args' is not a supported keyword argument.")

    # Remove full_output from kwargs, otherwise we're passing it in twice.
    return_full = kwargs.pop('full_output', False)
    res = leastsq(func, p0, Dfun=jac, full_output=1, **kwargs)
    popt, pcov, infodict, errmsg, ier = res
    ysize = len(infodict['fvec'])
    cost = np.sum(infodict['fvec'] ** 2)
    if ier not in [1, 2, 3, 4]:
        raise RuntimeError("Optimal parameters not found: " + errmsg)

    warn_cov = False
    if pcov is None:
        # indeterminate covariance
        pcov = zeros((len(popt), len(popt)), dtype=float)
        pcov.fill(inf)
        warn_cov = True
    elif not absolute_sigma:
        if ysize > p0.size:
            # Scale by the reduced chi-square (cost / degrees of freedom).
            s_sq = cost / (ysize - p0.size)
            pcov = pcov * s_sq
        else:
            pcov.fill(inf)
            warn_cov = True

    if warn_cov:
        RationalWarning.warn('Covariance of the parameters could not be estimated',
                             category=OptimizeWarning)

    if return_full:
        return popt, pcov, infodict, errmsg, ier
    return popt, pcov


def fit_rational_to_base_function(rational_func, ref_func, x, degrees=(5, 4), version="A"):
    """
    Fit the coefficients of ``rational_func`` so that it matches
    ``ref_func`` on ``x``.

    Arguments:
            rational_func (callable):
                Called as ``rational_func(x, first_params, second_params)``.
            ref_func (callable):
                The function to approximate, evaluated once on ``x``.
            x (array):
                The points on which the fit is performed.
            degrees (tuple):
                Degrees forwarded to ``_curve_fit``; ``degrees[0] + 1`` is
                the size of the first returned array.
                Default ``(5, 4)``
            version (str):
                Forwarded to ``_curve_fit``.
                Default ``"A"``
    Returns:
        tuple: two arrays, the ``degrees[0] + 1`` first fitted parameters
        and the remaining ones.
    """
    target = ref_func(x)
    fit_result = _curve_fit(rational_func, x, target, degrees=degrees,
                            version=version, maxfev=10000000)
    fitted = fit_result[0]
    split_at = degrees[0] + 1
    first_part = np.array(fitted[:split_at])
    second_part = np.array(fitted[split_at:])
    return first_part, second_part


def find_closest_equivalent(rational_func, new_func, x):
    """
    Compute the parameters a, b, c, and d that minimize the distance between
    the rational function and ``a * new_func(c * x + d) + b`` on the range
    `x`.

    Arguments:
            rational_func (callable):
                The rational function to consider.
            new_func (callable):
                The function you want to fit to rational.
            x (array):
                The range on which the curves of the functions are fitted
                together.
    Returns:
        tuple: ((a, b, c, d), dist) with:
        a, b, c, d: the parameters to adjust the function
        (vertical and horizontal scales and bias);
        dist: The final distance (euclidean norm of the difference)
        between the rational function and the fitted one
    """
    # Local import, like the other SciPy usages of this module. The unused
    # torch import that used to sit here was dropped so that the function
    # does not require torch.
    from scipy.optimize import curve_fit

    initials = np.array([1., 0., 1., 0.])  # a, b, c, d
    y = rational_func(x)

    def equivalent_func(x_array, a, b, c, d):
        return a * new_func(c * x_array + d) + b

    params = curve_fit(equivalent_func, x, y, initials)
    a, b, c, d = params[0]
    final_func_output = np.array(equivalent_func(x, a, b, c, d))
    final_distance = np.sqrt(((y - final_func_output)**2).sum())
    return (a, b, c, d), final_distance
class Snapshot():
    """
    Snapshot to save, display, and export images of rational functions.
    Makes it easy to generate animations of the function through time, ... etc.

    Arguments:
            name (str):
                The name of Snapshot.
            rational (Rational):
                A rational function to save
            fitted_function (bool):
                If ``True`` (and the rational has a distribution), the best
                fitted function of the rational and its parameters are
                stored in the snapshot.
                Default ``True``
            other_func (callable):
                another function to be plotted or a list of other callable
                functions or a dictionary with the function name as key
                and the callable as value.
                Default ``None``
    """
    _HIST_WARNED = False
    _SEABORN_WARNED = False
    _SCIPY_WARNED = False

    def __init__(self, name, rational, fitted_function=True, other_func=None):
        self.name = name
        self.rational = rational.numpy()
        self.use_kde = rational.use_kde
        self.range = None
        self.histogram = None
        self.rat_name = rational.func_name
        if rational.distribution is not None and \
           not rational.distribution.is_empty:
            from copy import deepcopy
            # Keep a private copy: the live distribution is cleared below.
            self.histogram = deepcopy(rational.distribution)
            msg = "Automatically clearing the distribution after snapshot"
            RationalWarning.warn(msg)
            rational.clear_hist()
        if fitted_function and rational.distribution is not None:
            self.best_fitted_function = rational.best_fitted_function
            self.best_fitted_function_params = \
                rational.best_fitted_function_params
        else:
            self.best_fitted_function = None
            self.best_fitted_function_params = None
        self.other_func = other_func

    @staticmethod
    def _func_label(func):
        """Return ``func.__name__`` when available, else ``str(func)``."""
        if '__name__' in dir(func):
            return func.__name__
        return str(func)

    def _resolve_x(self, x, tolerance, announce_range=False):
        """
        Pick the abscissa used by `show` and `borders`.

        Priority: the given ``x`` (converted to a float array), then
        ``self.range``, then the bins of the stored histogram (trimmed with
        ``tolerance``), then ``np.arange(-3, 3, 0.01)``.
        """
        if x is not None:
            # Convert first: lists or ranges have no `dtype` attribute.
            return np.asarray(x, dtype=float)
        if self.range is not None:
            if announce_range:
                print("Snapshot: Using range from initialisation")
            return self.range
        if self.histogram is not None:
            return _cleared_arrays(self.histogram, tolerance)[1]
        return np.arange(-3, 3, 0.01)

    def _labelled_other_funcs(self, other_func):
        """
        Normalise ``other_func`` (falling back on ``self.other_func``) into
        a list of ``(label, callable)`` pairs.
        """
        if other_func is None:
            other_func = self.other_func
        if other_func is None:
            return []
        if isinstance(other_func, dict):
            return list(other_func.items())
        if not isinstance(other_func, (list, tuple)):
            other_func = [other_func]
        return [(self._func_label(func), func) for func in other_func]

    def show(self, x=None, fitted_function=True, other_func=None,
             display=True, tolerance=0.001, title=None, axis=None):
        """
        Show the function using `matplotlib`.

        Arguments:
                x (range):
                    The range to print the function on.
                    Default ``None``
                fitted_function (bool):
                    If ``True``, also plots the best fitted function stored
                    in the snapshot (if any).
                    Default ``True``
                other_func (callable):
                    another function to be plotted or a list of other
                    callable functions or a dictionary with the function
                    name as key and the callable as value.
                    Default ``None``
                display (bool):
                    Only used when ``axis`` is ``None``: if ``True``, calls
                    ``plt.show()``, otherwise returns the current figure.
                    Default ``True``
                tolerance (float):
                    Tolerance on the bins frequency. If tolerance is 0.001,
                    every frequency smaller than 0.001 will be cut out of
                    the histogram.
                    Default ``0.001``
                title (str):
                    If not `None`, title to be displayed on the figure.
                    Otherwise the snapshot name is used, unless it contains
                    "snapshot".
                    Default ``None``
                axis (matplotlib.pyplot.axis):
                    axis to be plotted on. If None, uses ``plt.gca()``.
                    Default ``None``
        Returns:
            The current matplotlib figure if ``axis`` is ``None`` and
            ``display`` is ``False``, ``None`` otherwise.
        """
        x = self._resolve_x(x, tolerance, announce_range=True)
        y_rat = self.rational(x)
        try:
            import seaborn as sns
            sns.set_style("whitegrid")
        except ImportError:
            RationalImportSeabornWarning.warn()
        # Rational
        ax = plt.gca() if axis is None else axis
        ax.plot(x, y_rat, label=f"{self.rat_name}", zorder=2)
        if fitted_function and self.best_fitted_function is not None:
            func_label = self._func_label(self.best_fitted_function)
            a, b, c, d = self.best_fitted_function_params
            y_bff = a * numpify(self.best_fitted_function, c * x + d) + b
            ax.plot(x, y_bff, "r-", label=f"Fitted {func_label}", zorder=2)
        # Histogram
        if self.histogram is not None:
            weights, bins = _cleared_arrays(self.histogram, tolerance)
            ax2 = ax.twinx()
            ax2.set_yticks([])
            try:
                import scipy.stats  # noqa: F401 (availability probe only)
                scipy_imported = True
            except ImportError:
                # Must be bound, the flag is read right below.
                scipy_imported = False
                RationalImportScipyWarning.warn()
            if self.use_kde and scipy_imported:
                if len(bins) > 5:
                    kde_curv = self.histogram.kde()(bins)
                    ax2.plot(bins, kde_curv, lw=1)
                    ax2.fill_between(bins, kde_curv, alpha=0.3)
                else:
                    print("The bin size is too big, bins contain too few "
                          "elements.\nbins:", bins)
                    ax2.bar([], [])  # in case of remove needed
            else:
                ax2.bar(bins, weights/weights.max(), width=bins[1] - bins[0],
                        linewidth=0, alpha=0.3)
            ax.set_zorder(ax2.get_zorder()+1)  # put ax in front of ax2
            ax.patch.set_visible(False)
        # Other funcs
        for func_label, func in self._labelled_other_funcs(other_func):
            ax.plot(x, numpify(func, x), label=func_label)
        ax.legend(loc='upper right')
        if title is None:
            if "snapshot" not in self.name:
                ax.set_title(self.name)
        else:
            ax.set_title(f"{title}")
        if axis is None:
            if display:
                plt.show()
            else:
                return plt.gcf()

    def borders(self, x=None, fitted_function=True, other_func=None,
                tolerance=0.001):
        """
        Returns the borders x_min, x_max, y_min, y_max of everything `show`
        would plot (rational, fitted function and other functions).

        Arguments:
                x (range):
                    The range to evaluate the functions on.
                    Default ``None``
                fitted_function (bool):
                    If ``True``, the best fitted function stored in the
                    snapshot (if any) is taken into account.
                    Default ``True``
                other_func (callable):
                    another function to be considered or a list of other
                    callable functions or a dictionary with the function
                    name as key and the callable as value.
                    Default ``None``
                tolerance (float):
                    Tolerance on the bins frequency. If tolerance is 0.001,
                    every frequency smaller than 0.001 will be cut out of
                    the histogram.
                    Default ``0.001``
        Returns:
            tuple: (x_min, x_max, y_min, y_max)
        """
        x = self._resolve_x(x, tolerance)
        y_rat = self.rational(x)
        x_min, x_max = x.min(), x.max()
        y_min, y_max = y_rat.min(), y_rat.max()
        if fitted_function and self.best_fitted_function is not None:
            a, b, c, d = self.best_fitted_function_params
            y_bff = a * numpify(self.best_fitted_function, c * x + d) + b
            y_min, y_max = min(y_min, y_bff.min()), max(y_max, y_bff.max())
        # Other funcs
        for _, func in self._labelled_other_funcs(other_func):
            y_of = numpify(func, x)
            y_min, y_max = min(y_min, y_of.min()), max(y_max, y_of.max())
        return x_min, x_max, y_min, y_max

    def save(self, x=None, fitted_function=True, other_func=None,
             path=None, tolerance=0.001, title=None, format="svg"):
        """
        Saves an image of the snapshot.

        Arguments:
                x (range):
                    The range to print the function on.
                    Default ``None``
                fitted_function (bool):
                    If ``True``, also plots the best fitted function stored
                    in the snapshot (if any).
                    Default ``True``
                other_func (callable):
                    another function to be plotted or a list of other
                    callable functions or a dictionary with the function
                    name as key and the callable as value.
                    Default ``None``
                path (str):
                    Where to save the figure. If ``None``, the snapshot
                    name is used. An already existing path is never
                    overwritten: the name is incremented instead.
                    Default ``None``
                tolerance (float):
                    Tolerance on the bins frequency. If tolerance is 0.001,
                    every frequency smaller than 0.001 will be cut out of
                    the histogram.
                    Default ``0.001``
                title (str):
                    If not `None`, title to be displayed on the figure.
                    Default ``None``
                format (str):
                    The format of the figure, appended as extension when
                    the path contains no ".".
                    Default ``svg``
        """
        fig = self.show(x, fitted_function, other_func, False, tolerance,
                        title)
        if path is None:
            path = self.name + f".{format}"
        elif "." not in path:
            path += f".{format}"
        path = _repair_path(path)
        fig.savefig(path)
        fig.clf()

    def __repr__(self):
        return f"Snapshot ({self.name})"
# def _cleared_arrays(hist, tolerance=0.001):
#     freq, bins = hist.normalize()
#     first = (freq > tolerance).argmax()
#     last = - (freq > tolerance)[::-1].argmax()
#     if last == 0:
#         return freq[first:], bins[first:]
#     return freq[first:last], bins[first:last]


def _cleared_arrays(hist, tolerance=0.001):
    """
    Trim the leading and trailing bins of ``hist`` whose weight does not
    exceed ``tolerance`` times the total weight.

    Returns:
        tuple: (weights, bins) restricted to the kept span.
    """
    weights, bins = hist.weights, hist.bins
    significant = weights > tolerance * weights.sum()
    first = significant.argmax()
    # Negative offset of the last significant bin, 0 meaning "up to the end".
    last = -significant[::-1].argmax()
    if last == 0:
        return weights[first:], bins[first:]
    return weights[first:last], bins[first:last]


def _repair_path(path):
    """
    Return a path that does not exist yet, derived from ``path``.

    While the path exists, the part before the extension is incremented
    (see `_increment_string`). The parent directory of the final path is
    created when missing.
    """
    import os
    changed = False
    if os.path.exists(path):
        print(f'Path "{path}" exists')
        changed = True
    while os.path.exists(path):
        # splitext only looks at the last path component, so dots in
        # directory names (e.g. "v1.0/fig") are left alone.
        root, ext = os.path.splitext(path)
        path = _increment_string(root) + ext
    if changed:
        print(f'Incremented, new path : "{path}"')
    directory = os.path.dirname(path)
    if directory and not os.path.exists(directory):
        print(f'Path "{directory}" does not exist, creating')
        os.makedirs(directory)
    return path


def _increment_string(string):
    """
    Increment the trailing number of ``string`` ("fig_2" -> "fig_3"), or
    append "_2" when it does not end with a digit.
    """
    if string and string[-1] in "0123456789":
        import re
        last_number = re.findall(r'\d+', string)[-1]
        return string[:-len(last_number)] + str(int(last_number) + 1)
    return string + "_2"


def _erase_suffix(string):
    """
    Drop everything from the last "_" on when ``string`` ends with a digit,
    otherwise return it unchanged.
    """
    if string and string[-1] in "0123456789":
        return "_".join(string.split("_")[:-1])
    return string


def _get_frontiers(snapshot_list, other_func=None, fitted_function=True,
                   tolerance=0.001):
    """
    Compute borders enclosing every snapshot of ``snapshot_list``.

    Returns:
        tuple: (x_min, x_max, y_min, y_max), the vertical span being
        enlarged by 10% on each side.
    """
    x_min, x_max, y_min, y_max = np.inf, -np.inf, np.inf, -np.inf
    for snap in snapshot_list:
        x_mi, x_ma, y_mi, y_ma = snap.borders(fitted_function=fitted_function,
                                              other_func=other_func,
                                              tolerance=tolerance)
        if x_mi < x_min:
            x_min = x_mi
        if y_mi < y_min:
            y_min = y_mi
        if x_ma > x_max:
            x_max = x_ma
        if y_ma > y_max:
            y_max = y_ma
    span = y_max - y_min
    return x_min, x_max, y_min - 0.1 * span, y_max + 0.1 * span


def numpify(func, x):
    """
    Call ``func`` on ``x`` and return the result as a numpy array.

    If the direct call raises a ``TypeError`` mentioning "Tensor", ``func``
    is called again on ``torch.tensor(x)`` and its output is detached and
    converted to numpy. Any other ``TypeError`` is re-raised.
    """
    try:
        return np.array(func(x))
    except TypeError as tper:
        if "Tensor" in str(tper):
            import torch
            return func(torch.tensor(x)).detach().numpy()
        print("Doesn't know how to handle this type of data")
        raise


def _get_auto_axis_layout(nb_plots):
    """
    Return a (rows, columns) grid able to hold ``nb_plots`` plots.

    An exact divisor of ``nb_plots`` (searched downwards from its integer
    square root, 1 excluded) is preferred; otherwise a near-square grid
    is returned.
    """
    if nb_plots == 1:
        return 1, 1
    mid = int(np.sqrt(nb_plots))
    for i in range(mid, 1, -1):
        if nb_plots % i == 0:
            return i, nb_plots // i
    if mid * (mid + 1) >= nb_plots:
        return mid, mid + 1
    return mid + 1, mid + 1


def _path_for_multiple(path, suffix):
    """
    Create the folder "<path without extension>_<suffix>" (incremented if
    it already exists) and return the path of the file named like ``path``
    inside that folder.
    """
    import os
    # splitext ignores dots located in directory names.
    path_root, path_ext = os.path.splitext(path)
    main_part = path_root.split("/")[-1]
    save_folder = _repair_path(f"{path_root}_{suffix}")
    os.makedirs(save_folder)
    return f"{save_folder}/{main_part}{path_ext}"