rational.torch

class rational.torch.Rational(approx_func='leaky_relu', degrees=(5, 4), cuda=None, version='A', trainable=True, train_numerator=True, train_denominator=True, name=None)

Rational activation function inherited from torch.nn.Module.

Parameters
  • approx_func (str) –

    The name of the function that the rational approximates at initialisation. The functions available for initialisation are listed in rational.rationals_config.json.

    Default leaky_relu.

  • degrees (tuple of int) –

    The degrees of the numerator (P) and denominator (Q).

    Default (5, 4)

  • cuda (bool) –

    Use GPU CUDA version.

    If None, use cuda if available on the machine

    Default None

  • version (str) –

    Version of Rational to use. Rational(x) = P(x)/Q(x)

    A: Q(x) = 1 + |b_1.x| + |b_2.x^2| + … + |b_n.x^n|

    B: Q(x) = 1 + |b_1.x + b_2.x^2 + … + b_n.x^n|

    C: Q(x) = 0.1 + |b_1.x + b_2.x^2 + … + b_n.x^n|

    D: like B with noise

    Default A

  • trainable (bool) –

    Whether the weights are trainable, i.e., whether they are updated during the backward pass.

    Default True

Returns

Rational module

Return type

Module
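
The denominator variants A, B and C listed under version can be illustrated with a minimal pure-Python sketch. This is not the library's implementation: the helper names poly and rational_value are hypothetical, and the sketch assumes that the k-th denominator coefficient multiplies x^k, as is usual for a polynomial of degree n. Version D is left out because the noise it adds is not specified here.

```python
def poly(coeffs, x):
    """Evaluate coeffs[0] + coeffs[1]*x + coeffs[2]*x**2 + ... with Horner's scheme."""
    result = 0.0
    for c in reversed(coeffs):
        result = result * x + c
    return result


def rational_value(x, a, b, version="A"):
    """Evaluate P(x)/Q(x) for the denominator variants A, B and C.

    a: numerator coefficients a_0 .. a_m
    b: denominator coefficients b_1 .. b_n (the constant term is fixed by the version)
    """
    numerator = poly(a, x)
    if version == "A":
        # Absolute value of every term separately.
        denominator = 1.0 + sum(abs(bk * x ** k) for k, bk in enumerate(b, start=1))
    elif version == "B":
        # Absolute value of the whole sum.
        denominator = 1.0 + abs(poly([0.0] + list(b), x))
    elif version == "C":
        # Like B, but with a constant term of 0.1.
        denominator = 0.1 + abs(poly([0.0] + list(b), x))
    else:
        raise ValueError("this sketch only covers versions A, B and C")
    return numerator / denominator


# P(x) = x and b = (1, -1), evaluated at x = 2
value_a = rational_value(2.0, [0.0, 1.0], [1.0, -1.0], "A")  # 2 / (1 + 2 + 4)
value_b = rational_value(2.0, [0.0, 1.0], [1.0, -1.0], "B")  # 2 / (1 + |2 - 4|)
value_c = rational_value(2.0, [0.0, 1.0], [1.0, -1.0], "C")  # 2 / (0.1 + |2 - 4|)
```

In all three variants the denominator stays strictly positive, so the function has no poles.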

best_fit(functions_list, x=None, show=False)

Compute the distance between the rational and each function in functions_list, and return the one with the minimal distance.

Parameters
  • functions_list (list of callable) – The functions you want to fit to the rational.

  • x (array) –

    The range on which the curves of the functions are fitted together.

    Default None

  • show (bool) –

    If True, plots the final fitted function and rational (using matplotlib).

    Default False

Returns

((a, b, c, d), dist) with:

a, b, c, d: the parameters to adjust the function (vertical and horizontal scales and bias)

dist: The final distance between the rational function and the fitted one

Return type

tuple
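
The selection step can be sketched in plain Python as follows. This is only an illustration under stated assumptions: the names mean_abs_distance and closest_function are hypothetical, the mean absolute difference is an assumed distance measure, and the per-function adjustment of a, b, c, d is omitted. The function target stands in for a trained rational.

```python
import math


def mean_abs_distance(f, g, xs):
    """Mean absolute difference between f and g over the sample points xs."""
    return sum(abs(f(x) - g(x)) for x in xs) / len(xs)


def closest_function(target, candidates, xs):
    """Return (candidate, distance) for the candidate closest to target on xs."""
    distances = [(mean_abs_distance(target, c, xs), i) for i, c in enumerate(candidates)]
    dist, index = min(distances)
    return candidates[index], dist


def relu(x):
    return max(0.0, x)


def leaky_relu(x, slope=0.01):
    return x if x > 0 else slope * x


def target(x):
    # Stand-in for a rational that has learned a leaky ReLU shape.
    return x if x > 0 else 0.01 * x


xs = [i / 10 for i in range(-30, 31)]
best, dist = closest_function(target, [relu, math.tanh, leaky_relu], xs)
```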

capture(name='snapshot_0', x=None, fitted_function=True, other_func=None, returns=False)

Captures a snapshot of the rational function and related data, and stores it in the snapshot_list variable (or returns it if returns=True).

Parameters
  • name (str) –

    Name of the snapshot.

    Default "snapshot_0"

  • x (range) –

    The range to print the function on.

    Default None

  • fitted_function (bool) –

    If True, displays the best fitted function if searched. Otherwise, returns it.

    Default True

  • other_func (callable) – Another function to be plotted, a list of other callable functions, or a dictionary with the function name as key and the callable as value.

  • returns (bool) –

    If True, returns the snapshot. Otherwise, saves it in self.snapshot_list

    Default False

classmethod capture_all(name='snapshot_0', x=None, fitted_function=True, other_func=None, returns=False)

Captures a snapshot of every instantiated rational function and related data, and stores it in the snapshot_list variable (or returns a list of them if returns=True).

Parameters
  • name (str) –

    Name of the snapshot.

    Default "snapshot_0"

  • x (range) –

    The range to print the function on.

    Default None

  • fitted_function (bool) –

    If True, displays the best fitted function if searched. Otherwise, returns it.

    Default True

  • other_func (callable) – Another function to be plotted, a list of other callable functions, or a dictionary with the function name as key and the callable as value.

  • returns (bool) –

    If True, returns the snapshot. Otherwise, saves it in self.snapshot_list

    Default False

export_evolution_graph(path='rational_evolution.gif', animated=True, other_func=None)

Creates and saves an animated graph of the function evolution based on the successive snapshots saved in snapshot_list.

Parameters
  • path (str) –

    Complete path with name of the figure.

    Default "rational_evolution.gif"

  • animated (bool) –

    If True, creates an animated gif; otherwise, separate files are created.

    Default True

  • other_func (callable) –

    Another function to be plotted, a list of other callable functions, or a dictionary with the function name as key and the callable as value.

    Default None

classmethod export_evolution_graphs(path='rationals_evolution.gif', together=True, layout='auto', animated=True, other_func=None)

Creates and saves an animated graph of the function evolution based on the successive snapshots saved in snapshot_list for each instantiated rational function.

Parameters
  • path (str) –

    Complete path with name of the figure.

    Default "rationals_evolution.gif"

  • together (bool) –

    If True, the graphs of all functions are stored in the same file.

    Default True

  • layout (tuple or 'auto') –

    Grid layout of the figure. If “auto”, one is generated. (see layout).

    Default "auto"

  • animated (bool) –

    If True, creates an animated gif; otherwise, separate files are created.

    Default True

  • other_func (callable) –

    Another function to be plotted, a list of other callable functions, or a dictionary with the function name as key and the callable as value.

    Default None

export_graph(path='rational_function.svg', snap_number=-1, other_func=None)

Saves one graph of the function based on the last snapshot (by default, and if available).

Parameters
  • path (str) –

    Complete path with name of the figure.

    Default "rational_function.svg"

  • snap_number (int) –

    The snap to take in snapshot_list for each function.

    Default -1 (last)

  • other_func (callable) – Another function to be plotted, a list of other callable functions, or a dictionary with the function name as key and the callable as value. Default None

classmethod export_graphs(path='rational_functions.svg', together=True, layout='auto', snap_number=-1, other_func=None)

Saves one or more graph(s) of the function based on the last snapshot (by default, and if available) for each instantiated rational function.

Parameters
  • path (str) –

    Complete path with name of the figure.

    Default "rational_functions.svg"

  • together (bool) –

    If True, the graphs of all functions are stored in the same file.

    Default True

  • layout (tuple or 'auto') – Grid layout of the figure. If “auto”, one is generated. (see layout). Default "auto"

  • snap_number (int) –

    The snap to take in snapshot_list for each function.

    Default -1 (last)

  • other_func (callable) – Another function to be plotted, a list of other callable functions, or a dictionary with the function name as key and the callable as value. Default None

fit(function, x=None, show=False)

Compute the parameters a, b, c, and d to have the neurally equivalent function of the provided one as close as possible to this rational function.

Parameters
  • function (callable) – The function you want to fit to the rational.

  • x (array) –

    The range on which the curves of the functions are fitted together.

    Default None

  • show (bool) –

    If True, plots the final fitted function and rational (using matplotlib).

    Default False

Returns

((a, b, c, d), dist) with:

a, b, c, d: the parameters to adjust the function (vertical and horizontal scales and bias)

dist: The final distance between the rational function and the fitted one

Return type

tuple
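
The idea of adjusting a function with scales and biases can be sketched as below. This is a hypothetical illustration, not the library's fitting routine: it assumes the adjusted form a * f(b * x + c) + d (the source does not state which of a, b, c, d plays which role), searches a small hand-picked grid for b and c, and solves for a and d by ordinary least squares. The names fit_vertical and fit_scaled are made up for the example.

```python
import math


def fit_vertical(f_vals, target_vals):
    """Least-squares a, d minimising sum((a * f + d - target) ** 2)."""
    n = len(f_vals)
    mean_f = sum(f_vals) / n
    mean_t = sum(target_vals) / n
    var_f = sum((v - mean_f) ** 2 for v in f_vals)
    if var_f == 0.0:
        return 0.0, mean_t  # constant function: only the bias can be fitted
    cov = sum((v - mean_f) * (t - mean_t) for v, t in zip(f_vals, target_vals))
    a = cov / var_f
    return a, mean_t - a * mean_f


def fit_scaled(function, target, xs, h_scales=(0.5, 1.0, 2.0), h_shifts=(-1.0, 0.0, 1.0)):
    """Return ((a, b, c, d), dist) for the best a * function(b * x + c) + d on xs."""
    target_vals = [target(x) for x in xs]
    best = None
    for b in h_scales:
        for c in h_shifts:
            f_vals = [function(b * x + c) for x in xs]
            a, d = fit_vertical(f_vals, target_vals)
            # Mean absolute error of the adjusted function (an assumed distance).
            dist = sum(abs(a * v + d - t) for v, t in zip(f_vals, target_vals)) / len(xs)
            if best is None or dist < best[1]:
                best = ((a, b, c, d), dist)
    return best


def target(x):
    # Stand-in for a trained rational that happens to equal a scaled tanh.
    return 3.0 * math.tanh(2.0 * x) + 0.5


xs = [i / 10 for i in range(-30, 31)]
(a, b, c, d), dist = fit_scaled(math.tanh, target, xs)
```

A real fit would use a continuous optimiser rather than a fixed grid; the grid keeps the sketch short and deterministic.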

forward(x)

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

input_retrieve_mode(auto_stop=False, max_saves=1000, bin_width=0.1)

Will retrieve the distribution of the input in self.distribution.

This will slow down the function, as it has to retrieve the input dist.

Parameters
  • auto_stop (bool) –

    If True, the retrieving will stop after max_saves calls to forward.

    Otherwise, retrieval continues until training_mode() is called.

    Default False

  • max_saves (int) –

    The number of calls to forward after which retrieving stops (used when auto_stop is True).

    Default 1000
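
The interplay of auto_stop, max_saves and bin_width can be pictured with a small stand-alone collector. This is a sketch under assumptions, not the library's code: the class InputHistogram and its attributes are hypothetical, and inputs are plain Python floats rather than tensors.

```python
import math
from collections import Counter


class InputHistogram:
    """Counts inputs in bins of fixed width, once per simulated forward call."""

    def __init__(self, bin_width=0.1, auto_stop=False, max_saves=1000):
        self.bin_width = bin_width
        self.auto_stop = auto_stop
        self.max_saves = max_saves
        self.counts = Counter()  # bin index -> number of inputs in that bin
        self.saves = 0           # number of recorded forward calls
        self.active = True

    def record(self, inputs):
        """Record the inputs of one forward call, unless retrieval has stopped."""
        if not self.active:
            return
        for value in inputs:
            self.counts[math.floor(value / self.bin_width)] += 1
        self.saves += 1
        if self.auto_stop and self.saves >= self.max_saves:
            self.active = False

    def stop(self):
        """Manual stop, playing the role that training_mode() plays in the text."""
        self.active = False


hist = InputHistogram(bin_width=0.5, auto_stop=True, max_saves=2)
hist.record([0.1, 0.2])  # both fall into bin 0, i.e. [0.0, 0.5)
hist.record([0.7])       # bin 1, i.e. [0.5, 1.0); second save, so retrieval stops
hist.record([-0.3])      # ignored: max_saves was reached
```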

numpy()

Returns a numpy version of this activation function.

classmethod save_all_inputs(save, auto_stop=False, max_saves=10000, bin_width='auto')

Have every rational save every input.

Parameters
  • save (bool) – If True, every instantiated rational function will retrieve its input; otherwise, it won’t.

  • auto_stop (bool) –

    If True, the retrieving will stop after max_saves calls to forward.

    Otherwise, retrieval continues until training_mode() is called.

    Default False

  • max_saves (int) –

    The number of calls to forward after which retrieving stops (used when auto_stop is True).

    Default 10000

  • bin_width (float or "auto") –

    The size of the histogram’s bin width to store the input in.

    If “auto”, then automatically determines the bin width to have ~100 bins.

    Default "auto"
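
One way to obtain roughly 100 bins is to divide the observed input range by the target bin count. The helper below is a hypothetical sketch of that idea; how the library actually derives the width for "auto" is not stated here.

```python
def auto_bin_width(values, target_bins=100):
    """Bin width that splits the range of values into about target_bins bins."""
    lowest, highest = min(values), max(values)
    if highest == lowest:
        return 1.0  # degenerate range: any positive width gives a single bin
    return (highest - lowest) / target_bins


width = auto_bin_width([0.0, 2.5, 10.0])  # range 10 split into 100 bins
```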

show(x=None, fitted_function=True, other_func=None, display=True, tolerance=0.001, title=None, axis=None)

Shows a graph of the function (or returns the figure if display=False).

Parameters
  • x (range) –

    The range to print the function on.

    Default None

  • fitted_function (bool) –

    If True, displays the best fitted function if searched. Otherwise, returns it.

    Default True

  • other_func (callable) – Another function to be plotted, a list of other callable functions, or a dictionary with the function name as key and the callable as value.

  • display (bool) –

    If True, displays the plot. Otherwise, returns the figure.

    Default True

  • tolerance (float) –

    If the input histogram is used, it will be pruned.

    Every bin containing less than tolerance of the total input is pruned out (this reduces noise). Default 0.001

  • title (str) – If not None, a title for the figure. Default None

  • axis (matplotlib.pyplot.axis) – Axis to be plotted on. If None, creates one automatically. Default None
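
The pruning controlled by tolerance can be sketched as follows. The function prune_histogram is a hypothetical helper, and the histogram is assumed to be a mapping from bin to count; the library's internal representation may differ.

```python
def prune_histogram(counts, tolerance=0.001):
    """Drop every bin that holds less than `tolerance` of all recorded inputs."""
    total = sum(counts.values())
    if total == 0:
        return {}
    return {bin_id: n for bin_id, n in counts.items() if n / total >= tolerance}


raw = {0: 9990, 1: 9, 2: 1}     # 10000 inputs in total
pruned = prune_histogram(raw)   # bins 1 and 2 hold less than 0.1% each
```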

classmethod show_all(x=None, fitted_function=True, other_func=None, display=True, tolerance=0.001, title=None, axes=None, layout='auto')

Shows a graph of all instantiated rational functions (or returns the figure if display=False).

Parameters
  • x (range) –

    The range to print the function on.

    Default None

  • fitted_function (bool) –

    If True, displays the best fitted function if searched. Otherwise, returns it.

    Default True

  • other_func (callable) – Another function to be plotted, a list of other callable functions, or a dictionary with the function name as key and the callable as value.

  • display (bool) –

    If True, displays the plot. Otherwise, returns the figure.

    Default True

  • tolerance (float) –

    If the input histogram is used, it will be pruned.

    Every bin containing less than tolerance of the total input is pruned out (this reduces noise). Default 0.001

  • title (str) – If not None, a title for the figure. Default None

  • axes (matplotlib.pyplot.axis) –

    One axis or a list of axes to be plotted on.

    If None, creates them automatically (see layout).

    Default None

  • layout (tuple or 'auto') –

    Grid layout of the figure. If “auto”, one is generated.

    Default "auto"

training_mode()

Stops retrieving the distribution of the input in self.distribution.