Source code for tf_pwa.model.custom

import numpy as np
import tensorflow as tf

from tf_pwa.data import data_shape, split_generator
from tf_pwa.variable import SumVar

from .model import Model, register_nll_model

"""
Custom nll model
"""


[docs] class BaseCustomModel(Model):
[docs] def value_and_grad(self, fun): all_var = self.Amp.trainable_variables n_var = len(all_var) def _fun(*args, **kwargs): with tf.GradientTape(persistent=True) as tape: y = fun(*args, **kwargs) dy = tf.nest.map_structure( lambda x: tf.stack( tape.gradient(x, all_var, unconnected_gradients="zero") ), y, ) del tape return y, dy return _fun
[docs] def nll( self, data, mcdata, weight: tf.Tensor = 1.0, batch=None, bg=None, mc_weight=1.0, ): int_mc = self.eval_normal_factors(mcdata, mc_weight) nll = self.eval_nll_part(data, weight, int_mc, idx=0) return nll
[docs] def eval_normal_factors(self, mcdata, weight=None): return []
[docs] def eval_nll_part(self, data, weight=None, norm=None, idx=0): raise NotImplementedError("")
def _fast_int_mc_grad(self, data): if self.Amp.vm.strategy is not None: return self._fast_int_mc_grad_multi(data) else: return self.value_and_grad(self.eval_normal_factors)( data[0], data[1] ) def _fast_nll_part_grad(self, data, int_mc=None, idx=0): if int_mc is None: all_var = self.Amp.trainable_variables n_var = len(all_var) int_mc = SumVar([np.array(1.0)], [np.zeros((n_var,))], all_var) if self.Amp.vm.strategy is not None: return self._fast_nll_part_grad_multi( data, int_mc.value, int_mc.grad, idx ) else: return self.value_and_grad( lambda: self.eval_nll_part(data[0], data[1], int_mc(), idx) )() @tf.function def _fast_int_mc_grad_multi(self, ia): strategy = self.Amp.vm.strategy a, b = ia vm = self.Amp.vm per_replica_losses = vm.strategy.run( self.value_and_grad(self.eval_normal_factors), args=(a, b) ) tmp = vm.strategy.reduce( tf.distribute.ReduceOp.SUM, per_replica_losses, axis=None ) return tmp @tf.function def _fast_nll_part_grad_multi(self, ia, int_mc_x, int_mc_g, idx): strategy = self.Amp.vm.strategy a, b = ia int_mc = SumVar(int_mc_x, int_mc_g, self.Amp.trainable_variables) vm = self.Amp.vm per_replica_losses = vm.strategy.run( self.value_and_grad( lambda i0, i1: self.eval_nll_part(i0, i1, int_mc(), idx) ), args=(a, b), ) tmp = vm.strategy.reduce( tf.distribute.ReduceOp.SUM, per_replica_losses, axis=None ) return tmp
[docs] def nll_grad_batch(self, data, mcdata, weight, mc_weight): all_var = self.Amp.trainable_variables n_var = len(all_var) int_mc = None # SumVar(0., np.zeros((n_var,)), all_var) for i, j in zip(mcdata, mc_weight): a, grad = self._fast_int_mc_grad((i, j)) tmp = SumVar(a, grad, all_var) if int_mc is None: int_mc = tmp else: int_mc = int_mc + tmp ret = 0.0 ret_grad = 0.0 for idx, (i, j) in enumerate(zip(data, weight)): a, grads = self._fast_nll_part_grad((i, j), int_mc, idx) ret = ret + a ret_grad = ret_grad + tf.stack(grads) return ret, ret_grad
[docs] def nll_grad_hessian( self, data, mcdata, weight=1.0, batch=24000, bg=None, mc_weight=1.0 ): all_var = self.Amp.trainable_variables n_var = len(all_var) int_mc = None for i, j in zip( split_generator(mcdata, batch_size=batch), split_generator(mc_weight, batch_size=batch), ): tmp = SumVar.from_call_with_hess( lambda: self.eval_normal_factors(i, j), all_var ) if int_mc is None: int_mc = tmp else: int_mc = int_mc + tmp ret = 0.0 ret_grad = 0.0 ret_hess = 0.0 for idx, (i, j) in enumerate( zip( split_generator(data, batch_size=batch), split_generator(weight, batch_size=batch), ) ): with tf.GradientTape(persistent=True) as tape0: with tf.GradientTape() as tape: if int_mc is None: a = self.eval_nll_part(i, j, None, idx=idx) else: a = self.eval_nll_part(i, j, int_mc(), idx=idx) grads = tape.gradient(a, all_var, unconnected_gradients="zero") hess = [ tape0.gradient(gi, all_var, unconnected_gradients="zero") for gi in grads ] del tape0 ret = ret + a ret_grad = ret_grad + tf.stack(grads) ret_hess = ret_hess + tf.stack(hess) return ret, ret_grad, ret_hess
[docs] @register_nll_model("simple") class SimpleNllModel(BaseCustomModel):
[docs] def eval_normal_factors(self, mcdata, weight): return [tf.reduce_sum(self.Amp(mcdata) * weight)]
[docs] def eval_nll_part(self, data, weight, norm, idx=0): nll = -tf.reduce_sum(weight * tf.math.log(self.Amp(data))) nll_norm = tf.reduce_sum(weight) * tf.math.log(norm[0]) return nll + nll_norm
[docs] @register_nll_model("simple_clip") class SimpleClipNllModel(SimpleNllModel):
[docs] def eval_nll_part(self, data, weight, norm, idx=0): from .model import clip_log nll = -tf.reduce_sum(weight * clip_log(self.Amp(data))) nll_norm = tf.reduce_sum(weight) * clip_log(norm[0]) return nll + nll_norm
[docs] @register_nll_model("simple_cfit") class SimpleCFitModel(BaseCustomModel): required_params = ["bg_frac"]
[docs] def eval_normal_factors(self, mcdata, weight): amp = self.Amp(mcdata) * weight * mcdata.get("eff_value", 1.0) a = tf.reduce_sum(amp) bg = weight * mcdata.get("bg_value", 1.0) b = tf.reduce_sum(bg) return [a, b]
[docs] def eval_nll_part(self, data, weight, norm, idx=0): bg_frac = self.bg_frac pdf = (1 - bg_frac) * self.Amp(data) * data.get( "err_value", 1.0 ) / norm[0] + bg_frac * data.get("bg_value", 1.0) / norm[1] nll = -tf.reduce_sum(weight * tf.math.log(pdf)) return nll
[docs] @register_nll_model("simple_chi2") class SimpleChi2Model(BaseCustomModel): """ fit amp = weight directly. Required set extended = True. """
[docs] def eval_nll_part(self, data, weight, norm, idx=0): nll = 0.5 * tf.reduce_sum((weight - self.Amp(data)) ** 2) return nll
@register_nll_model("constr_frac") class SimpleNllFracModel(BaseCustomModel): required_params = [ "constr_frac" ] # {name: {"res":[name], mask_params: {}, "value": 0.1, "sigma": 0.01}} def eval_normal_factors(self, mcdata, weight): int_mc = tf.reduce_sum(self.Amp(mcdata) * weight) ret = [int_mc] if self.constr_frac is None: self.constr_frac = {} for k, v in self.constr_frac.items(): res = v.get("res", k) mask_params = v.get("mask_params", {}) with self.Amp.temp_used_res(res): with self.Amp.mask_params(mask_params): int_i = tf.reduce_sum(self.Amp(mcdata) * weight) ret.append(int_i) return ret def eval_nll_part(self, data, weight, norm, idx=0): nll = -tf.reduce_sum(weight * tf.math.log(self.Amp(data))) nll_norm = tf.reduce_sum(weight) * tf.math.log(norm[0]) ret = nll + nll_norm if idx == 0: for i, (k, v) in enumerate(self.constr_frac.items()): value = v["value"] sigma = v["sigma"] ret = ( ret + 0.5 * ((norm[i + 1] / norm[0] - value) / sigma) ** 2 ) return ret
[docs] @register_nll_model("cfit_constr_frac") class SimpleNllFracModel(BaseCustomModel): required_params = [ "constr_frac", "bg_frac", ] # {name: {"res":[name], mask_params: {}, "value": 0.1, "sigma": 0.01}}
[docs] def eval_normal_factors(self, mcdata, weight): int_mc = tf.reduce_sum( self.Amp(mcdata) * mcdata.get("eff_value", 1.0) * weight ) ret = [int_mc] if self.constr_frac is None: self.constr_frac = {} for k, v in self.constr_frac.items(): res = v.get("res", k) mask_params = v.get("mask_params", {}) with self.Amp.temp_used_res(res): with self.Amp.mask_params(mask_params): int_i = tf.reduce_sum(self.Amp(mcdata) * weight) ret.append(int_i) bg = weight * mcdata.get("bg_value", 1.0) b = tf.reduce_sum(bg) ret.append(b) return ret
[docs] def eval_nll_part(self, data, weight, norm, idx=0): bg_frac = self.bg_frac pdf = (1 - bg_frac) * self.Amp(data) * data.get( "eff_value", 1.0 ) / norm[0] + bg_frac * data.get("bg_value", 1.0) / norm[-1] ret = -tf.reduce_sum(weight * tf.math.log(pdf)) if idx == 0: for i, (k, v) in enumerate(self.constr_frac.items()): value = v["value"] sigma = v["sigma"] ret = ( ret + 0.5 * ((norm[i + 1] / norm[0] - value) / sigma) ** 2 ) return ret