Source code for rational.mxnet.rationals

"""
Rational Activation Functions for MXNET
=======================================

This module allows you to create Rational Neural Networks using Learnable
Rational activation functions with MXNET networks.
"""
import mxnet as mx
from mxnet import initializer
from mxnet.gluon import HybridBlock

from rational.utils.get_weights import get_parameters
from rational.mxnet.versions import _version_a, _version_b, _version_c, _version_d
from rational._base.rational_base import Rational_base


class Rational(Rational_base, HybridBlock):
    """
    Rational Activation Function, inheriting from ``mxnet.gluon.HybridBlock``.

    Arguments:
            approx_func (str):
                The name of the approximated function for initialisation. \n
                The different functions are available in
                `rational.rationals_config.json`. \n
                Default: ``leaky_relu``
            degrees (tuple of int):
                The degrees of the numerator (P) and denominator (Q).\n
                Default ``(5, 4)``
            cuda (bool):
                whether to execute on cuda device.\n
                NOTE: THIS PARAMETER IS CURRENTLY NOT CONSIDERED.\n
                CUDA GPUS ARE USED WHEN IT IS POSSIBLE
            version (str):
                Version of Rational to use. Rational(x) = P(x)/Q(x),
                where
                P(x) = (a_0 + a_1 * x + a_2 * x^2 + ... + a_n * x^n) and \n
                `A`: Q(x) = (1 + \\|b_0 * x\\| + \\| b_1 * x^2\\| + ... + \\| b_m * x^{m+1}\\|)\n
                `B`: Q(x) = (1 + \\|b_0 * x + b_1 * x^2 + ... + b_m * x^{m + 1}\\|)\n
                `C`: Q(x) = (0.1 + \\|b_0 + b_1 * x + b_2 * x^2 + ... + b_m * x^m\\|)\n
                `D`: like `B` with noised coefficients b_i\n
                Default ``A``
            trainable (bool):
                Whether the weights are trainable, i.e, if they are updated
                during backward pass. \n
                Default ``True``
            name (str):
                Name handed to the parent constructor. When ``None``, the
                value of ``approx_func`` is used instead.\n
                Default ``None``
            **kwargs:
                Accepted for call-site compatibility; not used by this
                constructor.

    Returns:
        HybridBlock:
            Rational hybrid block

    Raises:
        ValueError: if ``version`` is not one of ``A``, ``B``, ``C``, ``D``.
    """

    def __init__(self, approx_func='leaky_relu', degrees=(5, 4), cuda=False,
                 version='A', trainable=True, name=None, **kwargs):
        if name is None:
            name = approx_func
        # Only the name is forwarded to the parent constructor; ``kwargs``
        # is intentionally not passed on.
        super().__init__(name)

        # Read the initial coefficient configuration from external files.
        w_numerator, w_denominator = get_parameters(
            version, degrees, approx_func)

        # Convert the coefficient lists to mxnet arrays.
        w_numerator = mx.nd.array(w_numerator)
        w_denominator = mx.nd.array(w_denominator)

        # Remember the number of coefficients in numerator and denominator:
        # they are needed during symbolic execution but cannot be retrieved
        # at that later stage.
        num_len = len(w_numerator)
        den_len = len(w_denominator)
        self.numerator_length = num_len
        self.denominator_length = den_len

        # ``training`` records whether the coefficients are trainable; it is
        # handed to the version function on every forward call (version D
        # needs this information).
        self.training = trainable
        self.degrees = degrees
        self.version = version
        self.init_approximation = approx_func

        # NOTE: ``cuda`` is deliberately ignored; no context is selected here.

        # Register and configure the weights (numerator and denominator
        # coefficients). Gradients are only requested when trainable.
        grad_req = 'write' if trainable else 'null'
        with self.name_scope():
            self.numerator = self.params.get(
                name='w_numerator',
                shape=(num_len,),
                init=initializer.Constant(w_numerator),
                grad_req=grad_req,
                differentiable=trainable)
            self.denominator = self.params.get(
                name='w_denominator',
                shape=(den_len,),
                init=initializer.Constant(w_denominator),
                grad_req=grad_req,
                differentiable=trainable)

        # Select the implementation matching the requested version.
        implementations = {
            'A': _version_a,
            'B': _version_b,
            'C': _version_c,
            'D': _version_d,
        }
        self.rational_func = implementations.get(version)
        if self.rational_func is None:
            raise ValueError(
                "rational activation function version %s not implemented"
                % version)

    def hybrid_forward(self, F, x, numerator, denominator):
        """
        Apply the selected rational function version to ``x``.

        Arguments:
                F:
                    Backend namespace supplied by Gluon when the block is
                    called; forwarded unchanged to the version function.
                x:
                    Input of the activation function.
                numerator:
                    Current value of the registered ``w_numerator`` parameter.
                denominator:
                    Current value of the registered ``w_denominator``
                    parameter.

        Returns:
            Whatever the selected version function returns for these
            arguments.
        """
        return self.rational_func(F, x, numerator, denominator,
                                  self.training,
                                  self.numerator_length,
                                  self.denominator_length)

    def numpy(self):
        """
        Returns a numpy version of this activation function.
        """
        # Imported here rather than at module level, as in the original code.
        from rational.numpy import Rational as Rational_numpy
        rational_n = Rational_numpy(self.init_approximation, self.degrees,
                                    self.version)
        # Overwrite the freshly initialised coefficients with the current
        # parameter values of this block, as plain Python lists.
        rational_n.numerator = self.numerator.data().asnumpy().tolist()
        rational_n.denominator = self.denominator.data().asnumpy().tolist()
        return rational_n

    @property
    def device(self):
        """
        String form of ``mx.context.current_context()``.

        This reports the context that is current at call time; it does not
        inspect the parameters of this block.
        """
        return str(mx.context.current_context())