Find weights for Rational approximating new functions

If you want your rational function to approximate a function that is not already in rational/rationals_config.json, you need to run a script that automatically finds the coefficients of P and Q and adds them to that file.

This script uses NumPy, SciPy, and Matplotlib, so make sure they are installed. First, import find_weights and the function you want to approximate:

from rational.utils.find_init_weights import find_weights
import torch.nn.functional as F  # To get the tanh function

Then call find_weights and pass it the function you want to approximate. We'll use tanh in this example.

find_weights(F.tanh)

You will then be prompted for the parameters of the rational function. The example session below uses the default values:

# approximated function name: tanh
# degree of the numerator P: 5
# degree of the denominator Q: 4
# lower bound: -3
# upper bound: 3
# Rational Version: B
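The degrees set how many coefficients P and Q have, and the bounds presumably delimit the interval on which the approximation is fitted. As an illustration only, here is a minimal sketch of one way to fit such coefficients: a linearized least-squares fit in NumPy. It is not necessarily how find_weights works, and the helpers `rational_b` and `fit_rational_linear` are hypothetical. It assumes that version B computes P(x) / (1 + |b_1 x + ... + b_m x^m|), and it uses np.tanh instead of F.tanh so that it runs on NumPy arrays.

```python
import numpy as np
from numpy.polynomial import polynomial as npoly


def rational_b(x, p, q):
    """Evaluate P(x) / (1 + |b_1 x + ... + b_m x^m|).

    Assumption: this is the form of a version-B Rational. p holds a_0..a_n
    and q holds b_1..b_m, both in ascending order of powers.
    """
    numerator = npoly.polyval(x, p)
    denominator = 1.0 + np.abs(x * npoly.polyval(x, q))
    return numerator / denominator


def fit_rational_linear(func, degree_p=5, degree_q=4, lower=-3.0, upper=3.0,
                        n_points=2001):
    """Hypothetical linearized least-squares fit (not find_weights itself).

    Solves P(x) - f(x) * Q(x) ~= f(x) on a uniform grid. This assumes the
    fitted Q(x) stays non-negative, so that 1 + |Q(x)| == 1 + Q(x).
    """
    x = np.linspace(lower, upper, n_points)
    y = func(x)
    # Columns x^0..x^n for the numerator coefficients a_i.
    p_cols = np.vander(x, degree_p + 1, increasing=True)
    # Columns -f(x) * x^1..x^m for the denominator coefficients b_j.
    q_cols = -y[:, None] * np.vander(x, degree_q + 1, increasing=True)[:, 1:]
    design = np.hstack([p_cols, q_cols])
    coeffs, *_ = np.linalg.lstsq(design, y, rcond=None)
    return coeffs[: degree_p + 1], coeffs[degree_p + 1:]


# Same parameters as the default session: degrees 5 and 4 on [-3, 3].
p, q = fit_rational_linear(np.tanh)
x = np.linspace(-3.0, 3.0, 1001)
max_error = np.max(np.abs(rational_b(x, p, q) - np.tanh(x)))
print("P:", p)
print("Q:", q)
print("max |error| on [-3, 3]:", max_error)
```

Multiplying through by the denominator turns the rational fit into an ordinary linear least-squares problem. A nonlinear optimiser such as scipy.optimize.least_squares could refine the result, but that is only one possible design choice.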

After the computation, which can take some time, the coefficients found are printed:

# Found coeffient :
# P: [2.11729498e-09 9.99994250e-01 6.27633277e-07 1.07708645e-01
#  2.94655690e-08 8.71124374e-04]
# Q: [6.37690834e-07 4.41014181e-01 2.27476614e-07 1.45810399e-02]
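To see what these numbers mean, here is a minimal NumPy sketch that evaluates the rational function from the printed coefficients and measures its error against tanh on [-3, 3]. It assumes that the coefficients are listed in ascending powers and that version B computes P(x) / (1 + |b_1 x + ... + b_m x^m|); check the Rational documentation for the exact definition. The helper `rational_b` is hypothetical.

```python
import numpy as np
from numpy.polynomial import polynomial as npoly

# Coefficients printed by find_weights above (assumed ascending powers).
P_COEFFS = np.array([2.11729498e-09, 9.99994250e-01, 6.27633277e-07,
                     1.07708645e-01, 2.94655690e-08, 8.71124374e-04])
Q_COEFFS = np.array([6.37690834e-07, 4.41014181e-01, 2.27476614e-07,
                     1.45810399e-02])


def rational_b(x, p, q):
    # Assumed version-B form: P(x) / (1 + |b_1 x + b_2 x^2 + ... + b_m x^m|)
    numerator = npoly.polyval(x, p)
    denominator = 1.0 + np.abs(x * npoly.polyval(x, q))
    return numerator / denominator


# Compare against tanh on the same interval as the fit (lower=-3, upper=3).
x = np.linspace(-3.0, 3.0, 601)
max_error = np.max(np.abs(rational_b(x, P_COEFFS, Q_COEFFS) - np.tanh(x)))
print(f"max |R(x) - tanh(x)| on [-3, 3]: {max_error:.2e}")
```

Note that Q has only 4 entries for a degree-4 denominator: under this assumed form the constant term of the denominator is fixed to 1.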

You will be asked if you want to see a plot of the obtained rational function.

# Do you want a plot of the result (y/n)y

If you answer yes, a plot is shown so that you can check the accuracy of the approximation:

[Figure: plot of the obtained rational approximation of tanh (approx_tanh.png)]

Then, you will be asked whether you want to store this result for later use.

# Do you want to store them in the json file ? (y/n)y

If you answer yes, the coefficients are saved and you can now use them to initialise your rational function:

from rational.torch import Rational

rational_tanh_B = Rational("tanh", version="B")
print(rational_tanh_B.init_approximation)
# 'tanh'

We can check that we obtain the same weights as the ones found above, up to rounding in the last digits:

print(rational_tanh_B.numerator.cpu().detach().numpy())
# [2.1172950e-09 9.9999428e-01 6.2763326e-07 1.0770865e-01 2.9465570e-08
#  8.7112439e-04]
print(rational_tanh_B.denominator.cpu().detach().numpy())
# [6.3769085e-07 4.4101417e-01 2.2747662e-07 1.4581040e-02]
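The comparison can also be done programmatically. The sketch below copies both printouts into NumPy arrays and checks them with np.allclose. Treating the stored weights as float32 and using a relative tolerance of 1e-6 are assumptions chosen to absorb single-precision rounding, not values taken from the library.

```python
import numpy as np

# Coefficients printed by find_weights (double precision).
found_p = np.array([2.11729498e-09, 9.99994250e-01, 6.27633277e-07,
                    1.07708645e-01, 2.94655690e-08, 8.71124374e-04])
found_q = np.array([6.37690834e-07, 4.41014181e-01, 2.27476614e-07,
                    1.45810399e-02])

# Weights printed from rational_tanh_B (assumed to be float32).
stored_p = np.array([2.1172950e-09, 9.9999428e-01, 6.2763326e-07,
                     1.0770865e-01, 2.9465570e-08, 8.7112439e-04],
                    dtype=np.float32)
stored_q = np.array([6.3769085e-07, 4.4101417e-01, 2.2747662e-07,
                     1.4581040e-02], dtype=np.float32)

# Relative comparison only (atol=0), since some coefficients are ~1e-9.
same_p = np.allclose(stored_p, found_p, rtol=1e-6, atol=0.0)
same_q = np.allclose(stored_q, found_q, rtol=1e-6, atol=0.0)
print(same_p, same_q)
```

With the library installed, the same check can be run directly on the module, for example `np.allclose(rational_tanh_B.numerator.cpu().detach().numpy(), found_p, rtol=1e-6, atol=0.0)`.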

You now know how to add a new function to rationals_config.json!