Source code for rational.utils.find_init_weights

"""
find_init_weights.py
====================================
Find the weights of a rational function that approximates a specific activation function
"""

import json
import numpy as np
from .utils import fit_rational_to_base_function
import torch
import os
from rational.numpy.rationals import Rational_version_A, Rational_version_B, \
    Rational_version_C, Rational_version_N


def plot_result(x_array, rational_array, target_array,
                original_func_name="Original function"):
    """Show the rational approximation and the target function on one figure.

    Arguments:
        x_array: Abscissa values shared by both curves.
        rational_array: Values of the rational approximation at `x_array`.
        target_array: Values of the approximated function at `x_array`.
        original_func_name (str): Legend label used for the target curve.
    """
    # Imported lazily so matplotlib is only needed when a plot is requested.
    from matplotlib import pyplot

    curves = (
        (rational_array, "Rational approx"),
        (target_array, original_func_name),
    )
    for y_values, curve_label in curves:
        pyplot.plot(x_array, y_values, label=curve_label)
    pyplot.legend()
    pyplot.grid()
    pyplot.show()


def append_to_config_file(params, approx_name, w_params, d_params, overwrite=None):
    """Store fitted weights in ``rationals_config.json``.

    The config file is looked up in the parent directory of the directory
    that contains this module. It maps a rational name
    (``Rational_version_<version><nd>/<dd>``) to a dictionary of
    approximated-function name -> stored parameters.

    Arguments:
        params (dict): Must provide the keys ``"version"``, ``"nd"``,
            ``"dd"``, ``"ub"`` and ``"lb"``.
        approx_name (str): Name of the approximated function; it is
            lower-cased before being used as a key.
        w_params: Numerator weights; must provide ``tolist()``.
        d_params: Denominator weights; must provide ``tolist()``.
        overwrite (bool): Whether an already stored approximation may be
            replaced. If ``None``, the user is asked interactively (only
            when an entry already exists).

    Returns:
        None
    """
    rational_full_name = f'Rational_version_{params["version"]}{params["nd"]}/{params["dd"]}'
    cfd = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
    config_path = f'{cfd}/rationals_config.json'
    with open(config_path) as json_file:
        rationals_dict = json.load(json_file)  # rational_version -> approx_func
    approx_name = approx_name.lower()

    # Reuse the existing entry for this rational so that storing (or
    # overwriting) one approximation never discards the other ones.
    stored_approximations = rationals_dict.setdefault(rational_full_name, {})

    if approx_name in stored_approximations:
        if overwrite is None:
            # Prompt text kept byte-for-byte, including its run of padding
            # spaces between the two sentences.
            prompt = (f'Rational_{params["version"]} approximation of {approx_name} already exist. '
                      + " " * 34
                      + "\nDo you want to replace it ? (y/n)")
            overwrite = input(prompt) in ["y", "yes"]
        if not overwrite:
            print("Parameters not stored")
            return

    stored_approximations[approx_name] = {
        "init_w_numerator": w_params.tolist(),
        "init_w_denominator": d_params.tolist(),
        "ub": params["ub"], "lb": params["lb"],
    }
    with open(config_path, 'w') as outfile:
        json.dump(rationals_dict, outfile, indent=1)
    print("Parameters stored in rationals_config.json")



def typed_input(text, type, choice_list=None):
    """Prompt the user until the answer converts to `type` and is allowed.

    Arguments:
        text (str): Prompt passed to ``input``.
        type (callable): Converter applied to the raw answer (e.g. ``int``,
            ``float``, ``str``). A ``ValueError`` raised by the converter
            triggers a new prompt.
        choice_list: Optional container of accepted values. If given, the
            converted answer must be a member of it.

    Returns:
        The converted answer.
    """
    # Precondition on the caller (not on user input).
    assert isinstance(text, str)
    while True:
        try:
            typed_inp = type(input(text))
        except ValueError:
            print(f"Please provide an type: {type}")
            continue
        # Explicit membership test instead of `assert`, so the answer is
        # still validated when Python runs with assertions disabled (-O).
        if choice_list is not None and typed_inp not in choice_list:
            print(f"Please provide a value within {choice_list}")
            continue
        return typed_inp


# Module-level placeholder, initialised to None. No code visible in this
# module rebinds it through a `global` statement.
FUNCTION = None


def find_weights(function, function_name=None, degrees=None, bounds=None,
                 version=None, plot=None, save=None, overwrite=None):
    r"""Find the weights of the numerator and the denominator of the rational function.

    Beside `function`, all parameters can be left to the default ``None``.
    In this case, the user is asked to provide the params interactively.

    Arguments:
        function (callable):
            The function to approximate (e.g. from torch.functional).
            It is called with a ``torch.Tensor``.
        function_name (str):
            The name of this function (used at Rational initialisation).
            Default ``None``
        degrees (tuple of int):
            The degrees of the numerator (P) and denominator (Q).
            Default ``None``
        bounds (tuple of int):
            The bounds ``(lower, upper)`` to approximate on (e.g. (-3,3)).
            Default ``None``
        version (str):
            Version of Rational to use. Rational(x) = P(x)/Q(x)

            `A`: Q(x) = 1 + \|b_1.x\| + \|b_2.x\| + ... + \|b_n.x\|

            `B`: Q(x) = 1 + \|b_1.x + b_2.x + ... + b_n.x\|

            `C`: Q(x) = 0.1 + \|b_1.x + b_2.x + ... + b_n.x\|

            `D`: like `B` with noise (fitted with ``Rational_version_B``)

            `N`: fitted with ``Rational_version_N``

            Default ``None``
        plot (bool):
            If True, plots the fitted and target functions.
            Default ``None``
        save (bool):
            If True, saves the weights in the config file.
            Default ``None``
        overwrite (bool):
            If True, if weights already exist for this configuration,
            they are overwritten. Forwarded to ``append_to_config_file``.
            Default ``None``

    Returns:
        tuple: (numerator, denominator)

    Raises:
        ValueError: If `version` is not one of the supported versions.
    """
    if function_name is None:
        function_name = input("approximated function name: ")

    def function_to_approx(x):
        # Convert the input to a tensor before calling `function`.
        return function(torch.tensor(x))

    if degrees is None:
        nd = typed_input("degree of the numerator P: ", int)
        dd = typed_input("degree of the denominator Q: ", int)
        degrees = (nd, dd)
    else:
        nd, dd = degrees

    if bounds is None:
        print("On what range should the function be approximated ?")
        lb = typed_input("lower bound: ", float)
        ub = typed_input("upper bound: ", float)
    else:
        lb, ub = bounds

    # Sample the approximation range with evenly spaced points.
    nb_points = 100000
    step = (ub - lb) / nb_points
    x = np.arange(lb, ub, step)

    # Version `D` is fitted with the `B` implementation.
    rational_by_version = {
        "A": Rational_version_A,
        "B": Rational_version_B,
        "C": Rational_version_C,
        "D": Rational_version_B,
        "N": Rational_version_N,
    }
    if version is None:
        version = typed_input("Rational Version: ", str,
                              list(rational_by_version))
    try:
        rational = rational_by_version[version]
    except KeyError:
        # Fail with a clear message rather than an unbound local later on.
        raise ValueError(
            f"Unknown Rational version {version!r}, "
            f"expected one of {list(rational_by_version)}"
        ) from None

    w_params, d_params = fit_rational_to_base_function(rational,
                                                       function_to_approx,
                                                       x,
                                                       degrees=degrees,
                                                       version=version)
    print(f"Found coeffient :\nP: {w_params}\nQ: {d_params}")

    if plot is None:
        plot = input("Do you want a plot of the result (y/n)") in ["y", "yes"]
    if plot:
        plot_result(x, rational(x, w_params, d_params),
                    function_to_approx(x), function_name)

    params = {"version": version, "name": function_name, "ub": ub, "lb": lb,
              "nd": nd, "dd": dd}
    if save is None:
        save = input("Do you want to store them in the json file ? (y/n)") in ["y", "yes"]
    if save:
        append_to_config_file(params, function_name, w_params, d_params,
                              overwrite)
    else:
        print("Parameters not stored")
    return w_params, d_params