Source code for tf_pwa.fitfractions

import functools

import numpy as np
import tensorflow as tf

from tf_pwa.data import LazyCall, data_split


def _eval_integral(
    f, data, var, weight=None, args=(), no_grad=False, kwargs=None
):
    kwargs = {} if kwargs is None else kwargs
    weight = 1.0 if weight is None else weight
    if no_grad:
        fx = f(data, *args, **kwargs) * weight
        ret = tf.reduce_sum(fx)
        ret_grad = np.zeros((len(var),))
    else:
        with tf.GradientTape() as tape:
            fx = f(data, *args, **kwargs) * weight
            ret = tf.reduce_sum(fx)
        ret_grad = tape.gradient(ret, var, unconnected_gradients="zero")
        if len(ret_grad) == 0:
            ret_grad = np.array([])
        else:
            ret_grad = np.stack([i.numpy() for i in ret_grad])
    int_square = tf.reduce_sum(fx**2)
    return ret.numpy(), ret_grad, int_square.numpy()


[docs] def force_list(x): if isinstance(x, (list, tuple)): return x return [x]
[docs] class FitFractions: def __init__(self, amp, res): self.amp = amp self.var = amp.trainable_variables self.n_var = len(amp.trainable_variables) self.res = res self.cached_int = {} self.cached_grad = {} self.cached_square = {} self.cached_int_total = 0.0 self.cached_grad_total = np.zeros((self.n_var,)) self.error_matrix = np.diag(np.zeros((self.n_var,))) self.init_res_table()
[docs] def init_res_table(self): n = len(self.res) self.cached_int_total = 0.0 self.cached_grad_total = np.zeros((self.n_var,)) for i in range(n): for j in range(i, -1, -1): if i == j: name = str(self.res[i]) else: name = (str(self.res[i]), str(self.res[j])) self.cached_int[name] = 0.0 self.cached_square[name] = 0.0 self.cached_grad[name] = np.zeros_like((self.n_var,))
[docs] def integral(self, mcdata, *args, batch=None, no_grad=False, **kwargs): self.init_res_table() if batch is None: self.append_int(mcdata, *args, no_grad=no_grad, **kwargs) else: for data_i in data_split(mcdata, batch): self.append_int(data_i, *args, no_grad=no_grad, **kwargs)
[docs] def append_int(self, mcdata, *args, weight=None, no_grad=False, **kwargs): # print(data, data_shape(data)) if isinstance(mcdata, LazyCall): mcdata = mcdata.eval() if weight is None: weight = mcdata.get("weight", 1.0) int_mc, g_int_mc, _ = _eval_integral( self.amp, mcdata, var=self.var, weight=weight, args=args, kwargs=kwargs, ) self.cached_int_total += int_mc self.cached_grad_total += g_int_mc cahced_res = self.amp.used_res amp_tmp = self.amp fl = force_list for i in range(len(self.res)): for j in range(i, -1, -1): if i == j: name = str(self.res[i]) amp_tmp.set_used_res(fl(self.res[i])) else: name = (str(self.res[i]), str(self.res[j])) amp_tmp.set_used_res(fl(self.res[i]) + fl(self.res[j])) int_tmp, g_int_tmp, int_square = _eval_integral( amp_tmp, mcdata, var=self.var, weight=weight, args=args, kwargs=kwargs, ) self.cached_int[name] = self.cached_int[name] + int_tmp self.cached_grad[name] = self.cached_grad[name] + g_int_tmp self.cached_square[name] = ( self.cached_square[name] + int_square ) self.amp.set_used_res(cahced_res)
[docs] def get_frac_grad(self, sum_diag=True): n = len(self.res) int_mc = self.cached_int_total g_int_mc = self.cached_grad_total fit_frac = {} g_fit_frac = {} for i in range(n): name = str(self.res[i]) int_tmp = self.cached_int[name] g_int_tmp = self.cached_grad[name] fit_frac[name] = int_tmp / int_mc gij = g_int_tmp / int_mc - (int_tmp / int_mc) * g_int_mc / int_mc g_fit_frac[name] = gij for j in range(i - 1, -1, -1): name = (str(self.res[i]), str(self.res[j])) int_tmp = self.cached_int[name] g_int_tmp = self.cached_grad[name] fit_frac[name] = ( (int_tmp / int_mc) - fit_frac[str(self.res[i])] - fit_frac[str(self.res[j])] ) gij = ( g_int_tmp / int_mc - (int_tmp / int_mc) * g_int_mc / int_mc - g_fit_frac[str(self.res[i])] - g_fit_frac[str(self.res[j])] ) # print(name,gij.tolist()) g_fit_frac[name] = gij if sum_diag: fit_frac["sum_diag"] = sum([fit_frac[str(i)] for i in self.res]) g_fit_frac["sum_diag"] = sum( [g_fit_frac[str(i)] for i in self.res] ) return fit_frac, g_fit_frac
[docs] def get_frac(self, error_matrix=None, sum_diag=True, add_int_error=False): if error_matrix is None: error_matrix = self.error_matrix fit_frac, g_fit_frac = self.get_frac_grad(sum_diag=sum_diag) if error_matrix is None: return fit_frac, {} fit_frac_err = {} for k, v in g_fit_frac.items(): e = np.sqrt(np.dot(np.dot(error_matrix, v), v)) fit_frac_err[k] = e if add_int_error and k in self.cached_square: scale = 1 / self.cached_int_total**2 fit_frac_err[k] = np.sqrt(e**2 + self.cached_square[k] * scale) return fit_frac, fit_frac_err
def __iter__(self): return iter(self.get_frac())
[docs] def get_frac_diag_sum(self, error_matrix=None): if error_matrix is None: error_matrix = self.error_matrix sd = 0.0 sd_g = np.zeros((self.n_var,)) for i in self.cached_int: if isinstance(i, str): sd = sd + self.cached_int[i] sd_g += self.cached_grad[i] sd_e = np.sqrt(np.dot(np.dot(error_matrix, sd_g), sd_g)) return sd, sd_e
[docs] def nll_grad(f, var, args=(), kwargs=None, options=None): kwargs = kwargs if kwargs is not None else {} options = options if options is not None else {} @functools.wraps(f) def f_w(): with tf.GradientTape() as tape: ret = f(*args, **kwargs) g = tape.gradient(ret, var, unconnected_gradients="zero", **options) return ret, g return f_w
[docs] def cal_fitfractions(amp, mcdata, res=None, batch=None, args=(), kwargs=None): r""" defination: .. math:: FF_{i} = \frac{\int |A_i|^2 d\Omega }{ \int |\sum_{i}A_i|^2 d\Omega } \approx \frac{\sum |A_i|^2 }{\sum|\sum_{i} A_{i}|^2} interference fitfraction: .. math:: FF_{i,j} = \frac{\int 2Re(A_i A_j*) d\Omega }{ \int |\sum_{i}A_i|^2 d\Omega } = \frac{\int |A_i +A_j|^2 d\Omega }{ \int |\sum_{i}A_i|^2 d\Omega } - FF_{i} - FF_{j} gradients (for error transfer): .. math:: \frac{\partial }{\partial \theta_i }\frac{f(\theta_i)}{g(\theta_i)} = \frac{\partial f(\theta_i)}{\partial \theta_i} \frac{1}{g(\theta_i)} - \frac{\partial g(\theta_i)}{\partial \theta_i} \frac{f(\theta_i)}{g^2(\theta_i)} """ kwargs = kwargs if kwargs is not None else {} var = amp.trainable_variables # allvar = [i.name for i in var] cahced_res = amp.used_res if res is None: res = list(amp.res) n_res = len(res) fitFrac = {} err_fitFrac = {} g_fitFrac = [None] * n_res amp.set_used_res(res) weight = 1.0 if batch is not None: weight = mcdata.get("weight", 1.0) mcdata = list(data_split(mcdata, batch)) if not isinstance(weight, float): weight = list(data_split(weight, batch)) int_mc, g_int_mc = sum_gradient( amp, mcdata, var=var, weight=weight, args=args, kwargs=kwargs ) for i in range(n_res): for j in range(i, -1, -1): amp_tmp = amp if i == j: name = "{}".format(res[i]) amp_tmp.set_used_res([res[i]]) else: name = (str(res[i]), str(res[j])) amp_tmp.set_used_res([res[i], res[j]]) int_tmp, g_int_tmp = sum_gradient( amp_tmp, mcdata, var=var, weight=weight, args=args, kwargs=kwargs, ) if i == j: fitFrac[name] = int_tmp / int_mc gij = ( g_int_tmp / int_mc - (int_tmp / int_mc) * g_int_mc / int_mc ) g_fitFrac[i] = gij else: fitFrac[name] = ( (int_tmp / int_mc) - fitFrac["{}".format(res[i])] - fitFrac["{}".format(res[j])] ) gij = ( g_int_tmp / int_mc - (int_tmp / int_mc) * g_int_mc / int_mc - g_fitFrac[i] - g_fitFrac[j] ) # print(name,gij.tolist()) err_fitFrac[name] = gij amp.set_used_res(cahced_res) return fitFrac, err_fitFrac
[docs] def cal_fitfractions_no_grad( amp, mcdata, res=None, batch=None, args=(), kwargs=None ): r""" calculate fit fractions without gradients. """ kwargs = kwargs if kwargs is not None else {} var = amp.trainable_variables # allvar = [i.name for i in var] cahced_res = amp.used_res if res is None: res = list(amp.res) n_res = len(res) fitFrac = {} amp.set_used_res(res) weight = 1.0 if batch is not None: weight = mcdata.get("weight", 1.0) mcdata = list(data_split(mcdata, batch)) if not isinstance(weight, float): weight = list(data_split(weight, batch)) int_mc = sum_no_gradient( amp, mcdata, var=var, weight=weight, args=args, kwargs=kwargs ) for i in range(n_res): for j in range(i, -1, -1): amp_tmp = amp if i == j: name = "{}".format(res[i]) amp_tmp.set_used_res([res[i]]) else: name = "{}x{}".format(res[i], res[j]) amp_tmp.set_used_res([res[i], res[j]]) int_tmp = sum_no_gradient( amp_tmp, mcdata, var=var, weight=weight, args=args, kwargs=kwargs, ) if i == j: fitFrac[name] = int_tmp / int_mc else: fitFrac[name] = ( (int_tmp / int_mc) - fitFrac["{}".format(res[i])] - fitFrac["{}".format(res[j])] ) amp.set_used_res(cahced_res) return fitFrac
[docs] def sum_gradient( amp, data, var, weight=1.0, func=lambda x: x, grad=True, args=(), kwargs=None, ): kwargs = kwargs if kwargs is not None else {} # n_variables = len(var) if isinstance(weight, float): # 给data乘weight weight = [weight] * len(data) nll_list = [] g_list = [] def f(d, w): amp2s = amp(d, *args, **kwargs) # amp是振幅表达式 l_a = func(amp2s) # ampPDF转换成每组数据的NLL表达式 return tf.reduce_sum(tf.cast(w, l_a.dtype) * l_a) for d, w in zip(data, weight): if grad: # 是否提供grad表达式 p_nll, a = nll_grad( f, var, args=(d, w) )() # 调用上面定义的nll_grad,返回f(d,w)和f对var的导数 g_list.append([i.numpy() for i in a]) else: p_nll = f(d, w) nll_list.append(p_nll.numpy()) nll = sum(nll_list) if grad: g = np.array(g_list).sum(0) return nll, g # nll值和各var的导数g值 return nll
sum_no_gradient = functools.partial(sum_gradient, grad=False)