Source code for tf_pwa.params_trans

import contextlib

import numpy as np
import tensorflow as tf


class ParamsTrans:
    """Propagate the errors of fit parameters to derived quantities.

    Derived quantities are computed inside ``with obj.trans():`` so that a
    persistent ``tf.GradientTape`` records them. Afterwards, their gradients
    with respect to ``vm.trainable_variables`` are combined with
    ``err_matrix`` by linear error propagation (``sqrt(g^T E g)``).

    Example::

        pt = ParamsTrans(vm, err_matrix)
        with pt.trans():
            y = f(...)  # built from the variables managed by ``vm``
        y_err = pt.get_error(y)

    Args:
        vm: Variable manager. This class uses its ``trainable_variables``
            (sequence of variables), ``variables`` (indexable by key) and
            ``mask_params(params)`` context manager.
        err_matrix: Matrix ``E`` for ``vm.trainable_variables``, in the same
            order, with shape ``(n, n)`` for ``n`` trainable variables.
            Presumably their covariance matrix (TODO confirm with callers).
            Only needed by :meth:`get_error` and :meth:`get_error_matrix`.
    """

    def __init__(self, vm, err_matrix=None):
        self.vm = vm
        self.err_matrix = err_matrix
        # Set by ``trans()``; reset to None when a ``get_*`` call releases it.
        self.tape = None

    @contextlib.contextmanager
    def trans(self):
        """Record operations in the ``with`` body on a persistent tape.

        Yields ``self``. The tape is stored on ``self.tape`` only after the
        body exits, so call the ``get_*`` methods after leaving the block.
        """
        with tf.GradientTape(persistent=True) as tape:
            yield self
        self.tape = tape

    @contextlib.contextmanager
    def mask_params(self, params):
        """Delegate to ``self.vm.mask_params(params)`` for the ``with`` body."""
        with self.vm.mask_params(params):
            yield

    def _require_tape(self):
        """Return the recorded tape, or raise if none is available."""
        tape = getattr(self, "tape", None)
        if tape is None:
            # AttributeError matches what using a released tape raised before.
            raise AttributeError(
                "no gradient tape recorded; compute the values inside "
                "`with obj.trans():` first (pass keep=True to reuse a tape)"
            )
        return tape

    def _release_tape(self, keep):
        """Drop the persistent tape unless the caller asked to keep it."""
        if not keep:
            # A persistent tape holds its resources until it is collected.
            self.tape = None

    def get_grad(self, val, keep=False):
        """Gradient of ``val`` with respect to ``vm.trainable_variables``.

        Variables that ``val`` does not depend on get a zero gradient. The
        per-variable gradients are stacked along a new last axis.

        Args:
            val: Value (or values) recorded on the tape.
            keep: Keep the tape for further calls instead of releasing it.

        Returns:
            ``tf.Tensor`` of stacked gradients.
        """
        tape = self._require_tape()
        grad = tape.gradient(
            val,
            self.vm.trainable_variables,
            unconnected_gradients="zero",
        )
        grad = tf.stack(grad, axis=-1)
        self._release_tape(keep)
        return grad

    def get_error(self, vals, keep=False):
        """Propagated error of ``vals``.

        Args:
            vals: A ``tf.Tensor`` recorded on the tape, or a (possibly nested)
                list, tuple or dict of them. Containers are rebuilt with the
                same type or keys, holding the errors.
            keep: Keep the tape for further calls instead of releasing it.

        Returns:
            For a scalar tensor, ``sqrt(g^T E g)``, where ``g`` is its
            gradient. For any other tensor, the element-wise errors
            ``sqrt(diag(J E J^T))`` reshaped to ``vals.shape``.

        Raises:
            ValueError: If no error matrix was provided.
            TypeError: If ``vals`` (or a nested item) has an unsupported type.
        """
        if self.err_matrix is None:
            raise ValueError("no error matrix provided")
        if isinstance(vals, (list, tuple)):
            # Children keep the tape; it is released once, at this level.
            ret = type(vals)([self.get_error(v, keep=True) for v in vals])
        elif isinstance(vals, dict):
            ret = {k: self.get_error(v, keep=True) for k, v in vals.items()}
        elif isinstance(vals, tf.Tensor):
            tape = self._require_tape()
            variables = self.vm.trainable_variables
            if len(vals.shape) == 0:  # scalar
                grad = tape.gradient(
                    vals, variables, unconnected_gradients="zero"
                )
                grad = tf.stack(grad, axis=-1)
                ret = tf.sqrt(
                    tf.reduce_sum(
                        tf.linalg.matvec(self.err_matrix, grad) * grad
                    )
                )
            else:  # tensor of rank >= 1
                jac = tape.jacobian(
                    vals, variables, unconnected_gradients="zero"
                )
                # Put the variable axis last, then one row per element of
                # ``vals``.
                # NOTE(review): assumes every trainable variable is a scalar.
                jac = tf.reshape(
                    tf.stack(jac, axis=-1), (-1, len(variables))
                )
                # diag(J E J^T)_i = sum_k (J E)_ik * J_ik; this avoids building
                # the full (m, m) matrix just to read its diagonal.
                variance = tf.reduce_sum(
                    tf.matmul(jac, self.err_matrix) * jac, axis=-1
                )
                ret = tf.reshape(tf.sqrt(variance), vals.shape)
        else:
            raise TypeError(
                f"unsuported type {type(vals)}, use tensor instead"
            )
        self._release_tape(keep)
        return ret

    def get_error_matrix(self, vals, keep=False):
        """Propagated error matrix ``J E J^T`` of ``vals`` as a NumPy array.

        Args:
            vals: Either a list or tuple of values recorded on the tape (row
                ``i`` of ``J`` is the gradient of ``vals[i]``), or a
                ``tf.Tensor`` (row ``i`` is the gradient of its ``i``-th
                element in row-major order).
            keep: Keep the tape for further calls instead of releasing it.

        Returns:
            ``np.ndarray`` of shape ``(m, m)``, where ``m`` is the number of
            rows of ``J``.

        Raises:
            ValueError: If no error matrix was provided.
            TypeError: If ``vals`` is not a list, tuple or ``tf.Tensor``.
        """
        if self.err_matrix is None:
            raise ValueError("no error matrix provided")
        variables = self.vm.trainable_variables
        if isinstance(vals, (list, tuple)):
            tape = self._require_tape()
            rows = [
                tape.gradient(v, variables, unconnected_gradients="zero")
                for v in vals
            ]
            # Shape (len(vals), n_vars): one gradient row per value.
            grad = np.stack(rows)
        elif isinstance(vals, tf.Tensor):
            tape = self._require_tape()
            jac = tape.jacobian(vals, variables, unconnected_gradients="zero")
            # Each jac[k] has the shape of ``vals`` (for scalar variables).
            # The variable axis must go LAST so that rows index ``vals``;
            # stacking on axis 0 and reshaping would scramble the rows.
            grad = np.stack(jac, axis=-1)
        else:
            raise TypeError(
                f"unsupported type {type(vals)}, use list, tuple or tensor"
                " instead"
            )
        self._release_tape(keep)
        grad = grad.reshape((-1, len(variables)))
        return np.dot(np.dot(grad, self.err_matrix), grad.T)

    def __getitem__(self, key):
        """Return ``self.vm.variables[key]``."""
        return self.vm.variables[key]