Source code for tf_pwa.experimental.extra_function

import functools

import tensorflow as tf


def extra_function(f0=None, using_numpy=True):
    """Wrap a function so that TensorFlow differentiates it numerically.

    The wrapped function can be a numpy function or a ``numba.vectorize``
    function. Its first and second derivatives are estimated with central
    finite differences, so it can be used inside ``tf.GradientTape``.

    Can be used directly (``extra_function(f)``) or as a decorator factory
    (``@extra_function(using_numpy=False)``).

    >>> import numpy as np
    >>> sin2 = extra_function(np.sin)
    >>> a = tf.Variable([1.0,2.0], dtype="float64")
    >>> with tf.GradientTape(persistent=True) as tape0:
    ...     with tf.GradientTape(persistent=True) as tape:
    ...         b = sin2(a)
    ...     g, = tape.gradient(b, [a,])
    ...
    >>> h, = tape0.gradient(g, [a,])
    >>> assert np.allclose(np.sin([1.0,2.0]), b.numpy())
    >>> assert np.allclose(np.cos([1.0,2.0]), g.numpy())
    >>> assert np.sum(np.abs(-np.sin([1.0,2.0]) - h.numpy())) < 1e-3

    The numerical accuracy of the second derivative is limited (see the
    looser tolerance in the example above).

    :param f0: Function to wrap. When ``None``, a decorator is returned
        instead of a wrapped function.
    :param using_numpy: When true, inputs exposing a ``numpy`` method are
        converted with ``x.numpy()`` before being passed to the function.
    :return: The wrapped function, or a decorator when ``f0`` is ``None``.
    """

    def _wrapper(f):
        # Finite-difference step, keyed by the input's dtype name.
        # Any other dtype raises KeyError when the gradient is evaluated.
        delta_x = {"float64": 1e-6, "float32": 1e-3}

        def _as_input(x):
            # Hand the wrapped function a numpy array when requested and possible.
            if using_numpy and hasattr(x, "numpy"):
                return x.numpy()
            return x

        @tf.custom_gradient
        def _grad(x, **kwargs):
            x = _as_input(x)
            h = delta_x[x.dtype.name]
            f_u = f(x + h, **kwargs)
            f_d = f(x - h, **kwargs)
            f_0 = f(x, **kwargs)

            def _hess(dg):
                # Central second difference: (f(x+h) + f(x-h) - 2 f(x)) / h^2.
                tmp = (f_u + f_d - 2 * f_0) / h / h
                return dg * tmp

            # Central first difference: (f(x+h) - f(x-h)) / (2 h).
            return (f_u - f_d) / 2 / h, _hess

        @tf.custom_gradient
        @functools.wraps(f)
        def _f(x, **kwargs):
            def _g2(dy):
                # The upstream gradient is multiplied by the derivative, which
                # presumably assumes f acts elementwise on x.
                return dy * _grad(x, **kwargs)

            # Forward the keyword arguments here as well, so the value and the
            # finite-difference derivative come from the same call of f.
            f_0 = f(_as_input(x), **kwargs)
            return f_0, _g2

        return _f

    if f0 is None:
        return _wrapper
    return _wrapper(f0)