Source code for tf_pwa.config_loader.config_loader

import contextlib
import copy
import functools
import itertools
import json
import os
import re
import time
import warnings

import matplotlib.pyplot as plt
import numpy as np
import sympy as sy
import yaml
from scipy.interpolate import UnivariateSpline, interp1d
from scipy.optimize import BFGS, basinhopping, minimize

from tf_pwa.adaptive_bins import AdaptiveBound, cal_chi2
from tf_pwa.amp import (
    DecayChain,
    DecayGroup,
    HelicityDecay,
    create_amplitude,
    get_decay,
    get_particle,
)
from tf_pwa.applications import (
    cal_hesse_correct,
    cal_hesse_error,
    corr_coef_matrix,
    fit,
    fit_fractions,
    force_pos_def,
    num_hess_inv_3point,
)
from tf_pwa.cal_angle import prepare_data_from_decay
from tf_pwa.data import (
    ReadData,
    data_index,
    data_merge,
    data_shape,
    data_split,
    data_to_numpy,
    load_data,
    save_data,
)
from tf_pwa.fit import FitResult
from tf_pwa.fit_improve import minimize as my_minimize
from tf_pwa.model import FCN, CombineFCN, MixLogLikehoodFCN, Model, Model_new
from tf_pwa.model.cfit import Model_cfit, Model_cfit_cached, ModelCfitExtended
from tf_pwa.model.opt_int import ModelCachedAmp, ModelCachedInt
from tf_pwa.particle import split_particle_type
from tf_pwa.root_io import has_uproot, save_dict_to_root
from tf_pwa.tensorflow_wrapper import tf
from tf_pwa.utils import time_print
from tf_pwa.variable import Variable, VarsManager

from .base_config import BaseConfig
from .data import load_data_mode
from .decay_config import DecayConfig


class ConfigLoader(BaseConfig):
    """class for loading config.yml"""

    def __init__(self, file_name, vm=None, share_dict=None):
        if share_dict is None:
            share_dict = {}
        super().__init__(file_name, share_dict)
        self.config["data"] = self.config.get("data", {})
        self.multi_gpu = self.config["data"].get("multi_gpu", False)
        self.share_dict = share_dict
        self.decay_config = DecayConfig(self.config, share_dict)
        self.dec = self.decay_config.dec
        self.particle_map, self.particle_property = (
            self.decay_config.particle_map,
            self.decay_config.particle_property,
        )
        self.top, self.finals = self.decay_config.top, self.decay_config.finals
        self.full_decay = self.decay_config.full_decay
        self.decay_struct = self.decay_config.decay_struct
        if vm is None:
            vm = VarsManager(multi_gpu=self.multi_gpu)
        self.vm = vm
        self.amps = {}
        self.cached_data = None
        self.bound_dic = {}
        self.gauss_constr_dic = {}
        self.init_value = {}
        self.plot_params = PlotParams(
            self.config.get("plot", {}), self.decay_struct
        )
        self._neglect_when_set_params = []
        self.inv_he = None
        self._Ngroup = 1
        self.cached_fcn = {}
        self.extra_constrains = {}
        self.resolution_size = self.config.get("data", {}).get(
            "resolution_size", 1
        )
        self.chains_id_method = "auto"
        self.chains_id_method_table = {}
        self.data = load_data_mode(
            self["data"], self.decay_struct, config=self
        )

    @staticmethod
    def load_config(file_name, share_dict={}):
        if isinstance(file_name, dict):
            return copy.deepcopy(file_name)
        if isinstance(file_name, str):
            if file_name in share_dict:
                return ConfigLoader.load_config(share_dict[file_name])
            with open(file_name) as f:
                ret = yaml.load(f, yaml.FullLoader)
            return ret
        raise TypeError("not support config {}".format(type(file_name)))
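
    # Example (a minimal sketch; "config.yml" is an assumed file name, not
    # shipped with this module):
    #
    #     config = ConfigLoader("config.yml")
    #     # a dict or an entry of `share_dict` may be passed instead of a path
    #     raw = ConfigLoader.load_config("config.yml")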

    def get_data_file(self, idx):
        if idx in self.config["data"]:
            ret = self.config["data"][idx]
        else:
            ret = None
        return ret

    def get_dat_order(self, standard=False):
        order = self.config["data"].get("dat_order", None)
        if order is None:
            order = list(self.decay_struct.outs)
        else:
            order = [get_particle(str(i)) for i in order]
        if not standard:
            return order

        re_map = self.decay_struct.get_chains_map()

        def particle_item():
            for j in re_map:
                for k, v in j.items():
                    for s, l in v.items():
                        yield s, l

        new_order = []
        for i in order:
            for s, l in particle_item():
                if str(l) == str(i):
                    new_order.append(s)
                    break
            else:
                new_order.append(i)
        return new_order

    @functools.lru_cache()
    def get_data(self, idx):
        return self.data.get_data(idx)

    def load_cached_data(self, file_name=None):
        return self.data.load_cached_data(file_name)

    def save_cached_data(self, data, file_name=None):
        self.data.save_cached_data(data, file_name=file_name)

    def get_all_data(self):
        datafile = ["data", "phsp", "bg", "inmc"]
        self.load_cached_data()
        data, phsp, bg, inmc = [self.get_data(i) for i in datafile]
        self._Ngroup = len(data)
        assert len(phsp) == self._Ngroup
        if bg is None:
            bg = [None] * self._Ngroup
        if inmc is None:
            inmc = [None] * self._Ngroup
        assert len(bg) == self._Ngroup
        assert len(inmc) == self._Ngroup
        self.save_cached_data(dict(zip(datafile, [data, phsp, bg, inmc])))
        return data, phsp, bg, inmc
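
    # Example: every dataset comes back as a list with one entry per data
    # group, so single-group configs still index with [0] (a sketch, assuming
    # the files listed in the `data` section exist):
    #
    #     data, phsp, bg, inmc = config.get_all_data()
    #     n_data = data_shape(data[0])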

    def get_data_index(self, sub, name):
        return self.plot_params.get_data_index(sub, name)

    def get_phsp_noeff(self):
        if "phsp_noeff" in self.config["data"]:
            phsp_noeff = self.get_data("phsp_noeff")
            assert len(phsp_noeff) == 1
            return phsp_noeff[0]
        warnings.warn(
            "No data file as 'phsp_noeff', using the first 'phsp' file instead."
        )
        return self.get_data("phsp")[0]

    def get_phsp_plot(self, tail=""):
        if "phsp_plot" + tail in self.config["data"]:
            assert len(self.config["data"]["phsp_plot" + tail]) == len(
                self.config["data"]["phsp"]
            )
            return self.get_data("phsp_plot" + tail)
        return self.get_data("phsp" + tail)

    def get_data_rec(self, name):
        ret = self.get_data(name + "_rec")
        if ret is None:
            ret = self.get_data(name)
        return ret

    def get_decay(self, full=True):
        if full:
            return self.full_decay
        else:
            return self.decay_struct

    @functools.lru_cache()
    def get_amplitude(self, vm=None, name=""):
        amp_config = self.config.get("data", {})
        use_tf_function = amp_config.get("use_tf_function", False)
        no_id_cached = amp_config.get("no_id_cached", False)
        jit_compile = amp_config.get("jit_compile", False)
        amp_model = amp_config.get("amp_model", "default")
        cached_shape_idx = amp_config.get("cached_shape_idx", None)
        decay_group = self.full_decay
        self.check_valid_jp(decay_group)
        if vm is None:
            vm = self.vm
        if vm in self.amps:
            return self.amps[vm]
        amp = create_amplitude(
            decay_group,
            vm=vm,
            name=name,
            use_tf_function=use_tf_function,
            no_id_cached=no_id_cached,
            jit_compile=jit_compile,
            model=amp_model,
            cached_shape_idx=cached_shape_idx,
            all_config=amp_config,
        )
        self.add_constraints(amp)
        self.amps[vm] = amp
        return amp
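
    # Example (sketch): the amplitude is cached per VarsManager, so repeated
    # calls return the same object, and it can be evaluated on loaded data:
    #
    #     amp = config.get_amplitude()
    #     assert amp is config.get_amplitude()
    #     weights = amp(config.get_data("phsp")[0])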

    def eval_amplitude(self, *p, extra=None):
        extra = {} if extra is None else extra
        if len(p) == len(self.decay_struct.outs):
            data = self.data.cal_angle(p, **extra)
        elif len(p) == 1:
            data = self.data.cal_angle(p[0], **extra)
        elif len(p) == 0:
            data = self.data.cal_angle(**extra)
        else:
            # raising a bare string is invalid in Python 3
            raise ValueError("Not all data")
        amp = self.get_amplitude()
        return amp(data)

    def check_valid_jp(self, decay_group):
        for decay_chain in decay_group:
            for dec in decay_chain:
                if isinstance(dec, HelicityDecay):
                    dec.check_valid_jp()

    def add_constraints(self, amp):
        constrains = self.config.get("constrains", {})
        if constrains is None:
            constrains = {}
        self.add_decay_constraints(amp, constrains.get("decay", {}))
        self.add_particle_constraints(amp, constrains.get("particle", {}))
        self.add_fix_var_constraints(amp, constrains.get("fix_var", {}))
        self.add_free_var_constraints(amp, constrains.get("free_var", []))
        self.add_var_range_constraints(amp, constrains.get("var_range", {}))
        self.add_var_equal_constraints(amp, constrains.get("var_equal", []))
        self.add_pre_trans_constraints(amp, constrains.get("pre_trans", None))
        self.add_from_trans_constraints(
            amp, constrains.get("from_trans", None)
        )
        self.add_gauss_constr_constraints(
            amp, constrains.get("gauss_constr", {})
        )
        for k, v in self.extra_constrains.items():
            v(amp, constrains.get(k, {}))

    def register_extra_constrains(self, name, f=None):
        """
        add extra_constrains
        """

        def _reg(g):
            self.extra_constrains[name] = g
            return g

        if f is None:
            return _reg
        else:
            return _reg(f)
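
    # Example: handle a custom `constrains: {my_fix: {...}}` section of
    # config.yml (the section name "my_fix" is illustrative). Registered
    # handlers run inside add_constraints, so register them before the first
    # get_amplitude() call:
    #
    #     @config.register_extra_constrains("my_fix")
    #     def my_fix(amp, dic):
    #         for name, value in dic.items():
    #             amp.vm.set_fix(name, value)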

    def add_fix_var_constraints(self, amp, dic=None):
        if dic is None:
            dic = {}
        for k, v in dic.items():
            print("fix var: ", k, "=", v)
            amp.vm.set_fix(k, v)

    def add_free_var_constraints(self, amp, dic=None):
        if dic is None:
            dic = []
        for k in dic:
            print("free var: ", k)
            amp.vm.set_fix(k, unfix=True)

    def add_var_range_constraints(self, amp, dic=None):
        if dic is None:
            dic = {}
        for k, v in dic.items():
            print("variable range: ", k, " in ", v)
            self.bound_dic[k] = v

    def add_var_equal_constraints(self, amp, dic=None):
        if dic is None:
            dic = []
        for k in dic:
            print("same value:", k)
            amp.vm.set_same(k)

    def add_pre_trans_constraints(self, amp, dic=None):
        if dic is None:
            return
        from tf_pwa.transform import create_trans

        for k, v in dic.items():
            print("transform:", k, v)
            v["x"] = v.get("x", k)
            trans = create_trans(v)
            amp.vm.pre_trans[k] = trans

    def add_from_trans_constraints(self, amp, dic=None):
        if dic is None:
            return
        var_equal = []
        pre_trans = {}
        new_var = []
        for k, v in dic.items():
            x = v.pop("x", None)
            if x is not None:
                if isinstance(x, list) and k != x[0]:
                    new_var += x
                    var_equal.append([x[0], k])
                elif isinstance(x, str) and x != k:
                    new_var.append(x)
                    var_equal.append([x, k])
                else:
                    raise TypeError("x should be str or list")
            else:
                x = k
            v["x"] = x
            pre_trans[k] = v
        for i in new_var:
            if i not in amp.vm.variables:
                amp.vm.add_real_var(i)
        ConfigLoader.add_var_equal_constraints(self, amp, var_equal)
        ConfigLoader.add_pre_trans_constraints(self, amp, pre_trans)

    def add_decay_constraints(self, amp, dic=None):
        if dic is None:
            dic = {}
        fix_total_idx = dic.get("fix_chain_idx", 0)
        fix_total_val = dic.get("fix_chain_val", np.random.uniform(0, 2))
        fix_decay = amp.decay_group.get_decay_chain(fix_total_idx)
        # fix which total factor
        fix_decay.total.set_fix_idx(fix_idx=0, fix_vals=(fix_total_val, 0.0))
        decay_d = dic.get("decay_d", None)
        if decay_d is not None:
            if isinstance(decay_d, (float, int)):
                decay_d = [decay_d] * len(amp.decay_group[0])
            if isinstance(decay_d, (list, tuple)):
                for i in amp.decay_group:
                    for d, j in zip(decay_d, i):
                        if hasattr(j.core, "d"):
                            j.core.d = d
                        if hasattr(j, "d"):
                            j.d = d
            elif isinstance(decay_d, dict):
                for i in amp.decay_group:
                    for d, j in zip(decay_d, i):
                        if j.core.name in decay_d:
                            d = decay_d.get(j.core.name)
                        if hasattr(j.core, "d"):
                            j.core.d = d
                        if hasattr(j, "d"):
                            j.d = d
            else:
                raise ValueError("decay_d should be list or dict")

    def add_gauss_constr_constraints(self, amp, dic=None):
        dic = {} if dic is None else dic
        self.gauss_constr_dic.update(dic)

    def free_for_extended(self, amp):
        constrains = self.config.get("constrains", {})
        if constrains is None:
            constrains = {}
        dic = constrains.get("decay", {})
        if dic is None:
            dic = {}
        fix_total_idx = dic.get("fix_chain_idx", 0)
        fix_decay = amp.decay_group.get_decay_chain(fix_total_idx)
        var = fix_decay.total
        var.vm.set_fix(var.name + "_0r", unfix=True)

    def add_particle_constraints(self, amp, dic=None):
        if dic is None:
            dic = {}

        res_dec = {}
        for d in amp.decay_group:
            for p_i in d.inner:
                i = p_i.name
                res_dec[i] = d

                prefix_map = {
                    "m0": "mass",
                    "g0": "width",
                    "m_": "mass_",
                    "g_": "width_",
                }
                particle_config = self.decay_config.particle_property[i]
                params_dic = particle_config.get("params", None)
                if params_dic is None:
                    params_dic = {}
                for name in list(particle_config):
                    for prefix_i in prefix_map.keys():
                        if name.startswith(prefix_i):
                            name2 = (
                                prefix_map[prefix_i] + name[len(prefix_i) :]
                            )
                            params_dic[name2] = particle_config[name]
                    for prefix_i in prefix_map.values():
                        if name.startswith(prefix_i):
                            params_dic[name] = particle_config[name]
                variable_prefix = p_i.get_variable_name()
                set_prefix_constrains(self.vm, p_i, params_dic, self)

                simple_map = {"m": "mass", "g": "width"}
                gauss_constr = particle_config.get("gauss_constr", None)
                if gauss_constr is not None:
                    assert isinstance(gauss_constr, dict)
                    for k, v in gauss_constr.items():
                        if v:
                            name = simple_map.get(k, k)
                            full_name = variable_prefix + name
                            var0 = self.vm.get(full_name)
                            self.gauss_constr_dic[full_name] = (
                                float(var0),
                                v,
                            )
                        else:
                            raise Exception(
                                f"Need sigma of {k} of {p_i} when adding gaussian constraint"
                            )

                if isinstance(p_i.mass, Variable) or isinstance(
                    p_i.width, Variable
                ):
                    if "float" in particle_config and particle_config["float"]:
                        if "m" in particle_config["float"]:
                            p_i.mass.freed()  # set_fix(i+'_mass',unfix=True)
                            if "mass_max" in params_dic:
                                upper = params_dic["mass_max"]
                            # elif m_sigma is not None:
                            #     upper = self.config["particle"][i]["m0"] + 10 * m_sigma
                            else:
                                upper = None
                            if "mass_min" in params_dic:
                                lower = params_dic["mass_min"]
                            # elif m_sigma is not None:
                            #     lower = self.config["particle"][i]["m0"] - 10 * m_sigma
                            else:
                                lower = None
                            self.bound_dic[str(p_i.mass)] = (lower, upper)
                        else:
                            self._neglect_when_set_params.append(
                                str(p_i.mass)
                            )
                        if "g" in particle_config["float"]:
                            p_i.width.freed()  # amp.vm.set_fix(i+'_width',unfix=True)
                            if "width_max" in params_dic:
                                upper = params_dic["width_max"]
                            # elif g_sigma is not None:
                            #     upper = self.config["particle"][i]["g0"] + 10 * g_sigma
                            else:
                                upper = None
                            if "width_min" in params_dic:
                                lower = params_dic["width_min"]
                            # elif g_sigma is not None:
                            #     lower = self.config["particle"][i]["g0"] - 10 * g_sigma
                            else:
                                lower = None
                            self.bound_dic[str(p_i.width)] = (lower, upper)
                        else:
                            self._neglect_when_set_params.append(
                                str(p_i.width)
                            )
                    else:
                        self._neglect_when_set_params.append(
                            i + "_mass"
                        )  # p_i.mass.name
                        self._neglect_when_set_params.append(
                            i + "_width"
                        )  # p_i.width.name

                # share helicity variables
                if "coef_head" in particle_config:
                    coef_head = particle_config["coef_head"]
                    if coef_head in res_dec:
                        d_coef_head = res_dec[coef_head]
                        for j, h in zip(d, d_coef_head):
                            # use == rather than `is`: string identity is not
                            # guaranteed for equal strings
                            if i in [str(jj) for jj in j.outs] or i == str(
                                j.core
                            ):
                                h.g_ls.sameas(j.g_ls)
                        # share total radius
                        d_coef_head.total.r_shareto(d.total)
                    else:
                        particle_config["coef_head"] = i

        equal_params = dic.get("equal", {})
        for k, v in equal_params.items():
            for vi in v:
                a = []
                for i in amp.decay_group.resonances:
                    if str(i) in vi:
                        a.append(i)
                a0 = a.pop(0)
                arg = getattr(a0, k)
                for i in a:
                    arg_i = getattr(i, k)
                    if isinstance(arg_i, Variable):
                        arg_i.sameas(arg)

    @functools.lru_cache()
    def _get_model(self, vm=None, name=""):
        amp = self.get_amplitude(vm=vm, name=name)
        model_name = self.config["data"].get("model", "auto")
        w_bkg, w_inmc = self._get_bg_weight()
        model = []
        if model_name == "cfit":
            print("using cfit")
            bg_function = self.config["data"].get("bg_function", None)
            eff_function = self.config["data"].get("eff_function", None)
            w_bkg = self.config["data"]["bg_frac"]
            if not isinstance(w_bkg, list):
                w_bkg = [w_bkg]
            if self.config["data"].get("extended", False):
                self.free_for_extended(amp)
            for wb in w_bkg:
                if self.config["data"].get("extended", False):
                    model.append(
                        ModelCfitExtended(amp, wb, bg_function, eff_function)
                    )
                elif self.config["data"].get("cached_amp", False):
                    model.append(
                        Model_cfit_cached(amp, wb, bg_function, eff_function)
                    )
                else:
                    model.append(
                        Model_cfit(
                            amp,
                            wb,
                            bg_function,
                            eff_function,
                            resolution_size=self.resolution_size,
                        )
                    )
        elif "inmc" in self.config["data"]:
            float_wmc = self.config["data"].get(
                "float_inmc_ratio_in_pdf", False
            )
            if not isinstance(float_wmc, list):
                float_wmc = [float_wmc] * self._Ngroup
            assert len(float_wmc) == self._Ngroup
            for wb, wi, fw in zip(w_bkg, w_inmc, float_wmc):
                model.append(Model_new(amp, wb, wi, fw))
        elif self.config["data"].get("cached_int", False):
            for wb in w_bkg:
                model.append(ModelCachedInt(amp, wb))
        elif self.config["data"].get("cached_amp", False):
            for wb in w_bkg:
                model.append(ModelCachedAmp(amp, wb))
        elif model_name not in ["auto", "default"]:
            from tf_pwa.model.model import get_nll_model

            extended = self.config["data"].get("extended", False)
            if extended:
                self.free_for_extended(amp)
            NewModel = get_nll_model(model_name)
            nll_params = copy.copy(self.config.get("nll_model", {}))
            params = {}
            for i in getattr(NewModel, "required_params", []):
                if i in self.config["data"]:
                    params[i] = self.config["data"][i]
                elif i in nll_params:
                    params[i] = nll_params.pop(i)
                else:
                    raise IndexError(
                        "not found required params {} for nll model".format(i)
                    )
            nll_params = self.config.get("nll_model", {})
            for idx, wb in enumerate(w_bkg):
                new_params = {k: v for k, v in nll_params.items()}
                for k, v in params.items():
                    if isinstance(v, list):
                        new_params[k] = v[idx]
                    else:
                        new_params[k] = v
                model.append(NewModel(amp, w_bkg=wb, **new_params))
        else:
            extended = self.config["data"].get("extended", False)
            if extended:
                self.free_for_extended(amp)
            for wb in w_bkg:
                model.append(
                    Model(
                        amp,
                        wb,
                        resolution_size=self.resolution_size,
                        extended=extended,
                    )
                )
        return model

    def _get_bg_weight(self, data=None, bg=None, display=True):
        w_bkg = self.config["data"].get("bg_weight", 0.0)
        if not isinstance(w_bkg, list):
            w_bkg = [w_bkg] * self._Ngroup
        assert len(w_bkg) == self._Ngroup
        w_inmc = self.config["data"].get("inject_ratio", 0.0)
        if not isinstance(w_inmc, list):
            w_inmc = [w_inmc] * self._Ngroup
        assert len(w_inmc) == self._Ngroup
        weight_scale = self.config["data"].get("weight_scale", False)  # ???
        if weight_scale:
            data = data if data is not None else self.get_data("data")
            bg = bg if bg is not None else self.get_data("bg")
            tmp = []
            for wb, dt, sb in zip(w_bkg, data, bg):
                if isinstance(wb, str):
                    wb = self.data.load_weight_file(wb)
                tmp.append(wb * data_shape(dt) / data_shape(sb))
            w_bkg = tmp
            if display:
                print("background weight:", w_bkg)
        else:
            tmp = []
            for wb in w_bkg:
                if isinstance(wb, str):
                    wb = self.data.load_weight_file(wb)
                tmp.append(wb)
            w_bkg = tmp
        return w_bkg, w_inmc

    def get_fcn(self, all_data=None, batch=65000, vm=None, name=""):
        if all_data is None:
            if vm in self.cached_fcn:
                return self.cached_fcn[vm]
            data, phsp, bg, inmc = self.get_all_data()
        else:
            data, phsp, bg, inmc = all_data
            self._Ngroup = len(data)
            if inmc is None:
                inmc = [None] * self._Ngroup
            if bg is None:
                bg = [None] * self._Ngroup
        model = self._get_model(vm=vm, name=name)
        fcns = []
        # print(self.config["data"].get("using_mix_likelihood", False))
        if self.config["data"].get("using_mix_likelihood", False):
            print("  Using Mix Likelihood")
            fcn = MixLogLikehoodFCN(
                model,
                data,
                phsp,
                bg=bg,
                batch=batch,
                gauss_constr=self.gauss_constr_dic,
            )
            if all_data is None:
                self.cached_fcn[vm] = fcn
            return fcn
        for idx, (md, dt, mc, sb, ij) in enumerate(
            zip(model, data, phsp, bg, inmc)
        ):
            if self.config["data"].get("model", "auto") == "cfit":
                fcns.append(
                    FCN(
                        md,
                        dt,
                        mc,
                        batch=batch,
                        inmc=ij,
                        gauss_constr=self.gauss_constr_dic,
                    )
                )
            else:
                fcns.append(
                    FCN(
                        md,
                        dt,
                        mc,
                        bg=sb,
                        batch=batch,
                        inmc=ij,
                        gauss_constr=self.gauss_constr_dic,
                    )
                )
        if len(fcns) == 1:
            fcn = fcns[0]
        else:
            fcn = CombineFCN(fcns=fcns, gauss_constr=self.gauss_constr_dic)
        if all_data is None:
            self.cached_fcn[vm] = fcn
        return fcn

    def get_ndf(self):
        amp = self.get_amplitude()
        args_name = amp.vm.trainable_vars
        return len(args_name)

    @staticmethod
    def reweight_init_value(amp, phsp, ns=None):
        """reset decay chain total and make the integration to be ns"""
        total = [i.total for i in amp.decay_group]
        n_phsp = data_shape(phsp)
        weight = np.array(phsp.get("weight", [1] * n_phsp))
        sw = np.sum(weight)
        if ns is None:
            ns = [1] * len(total)
        elif isinstance(ns, (int, float)):
            ns = [ns / len(total)] * len(total)
        for i in total:
            i.set_rho(1.0)
        pw = amp.partial_weight(phsp)
        for i, w, ni in zip(total, pw, ns):
            i.set_rho(np.sqrt(ni / np.sum(weight * w) * sw))

    @time_print
    def fit(
        self,
        data=None,
        phsp=None,
        bg=None,
        inmc=None,
        batch=65000,
        method="BFGS",
        check_grad=False,
        improve=False,
        reweight=False,
        maxiter=None,
        jac=True,
        print_init_nll=True,
        callback=None,
        grad_scale=1.0,
        gtol=1e-3,
    ):
        if data is None and phsp is None:
            data, phsp, bg, inmc = self.get_all_data()
            fcn = self.get_fcn(batch=batch)
        else:
            fcn = self.get_fcn([data, phsp, bg, inmc], batch=batch)
        if self.config["data"].get("lazy_call", False):
            print_init_nll = False
        amp = self.get_amplitude()
        print("decay chains included: ")
        for i in self.full_decay:
            # fall back to None for decays without get_ls_list
            ls_list = [getattr(j, "get_ls_list", lambda: None)() for j in i]
            print("  ", i, " ls: ", *ls_list)
        if reweight:
            ConfigLoader.reweight_init_value(
                amp, phsp[0], ns=data_shape(data[0])
            )
        print("\n########### initial parameters")
        print(json.dumps(amp.get_params(), indent=2), flush=True)
        if print_init_nll:
            print("initial NLL: ", fcn({}))  # amp.get_params()))
        # fit configure
        self.fit_params = fit(
            fcn=fcn,
            method=method,
            bounds_dict=self.bound_dic,
            check_grad=check_grad,
            improve=False,
            maxiter=maxiter,
            jac=jac,
            callback=callback,
            grad_scale=grad_scale,
            gtol=gtol,
        )
        if self.fit_params.hess_inv is not None:
            self.inv_he = self.fit_params.hess_inv
        return self.fit_params
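
    # Typical flow (a sketch; file names are assumptions):
    #
    #     config = ConfigLoader("config.yml")
    #     fit_result = config.fit(batch=65000, method="BFGS")
    #     print(fit_result.min_nll)          # FitResult from tf_pwa.fit
    #     config.save_params("final_params.json")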

    def reinit_params(self):
        vm = self.get_amplitude().vm
        vm.refresh_vars(init_val=self.init_value, bound_dic=self.bound_dic)

    def fitNtimes(self, N, *args, **kwargs):
        for i in range(N):
            self.reinit_params()
            fit_result = self.fit(*args, **kwargs)
            fit_pars = json.dumps(fit_result.params, indent=2)
            print(fit_pars, flush=True)

    def get_params_error(
        self,
        params=None,
        data=None,
        phsp=None,
        bg=None,
        inmc=None,
        batch=10000,
        using_cached=False,
        method=None,
        force_pos=True,
        correct_params=None,
    ):
        """
        calculate parameters error
        """
        if params is None:
            params = {}
        if correct_params is None:
            correct_params = []
        if method is None:
            method = "correct"
        if hasattr(params, "params"):
            params = getattr(params, "params")
        if not using_cached:
            if data is None:
                data, phsp, bg, inmc = self.get_all_data()
            fcn = self.get_fcn([data, phsp, bg, inmc], batch=batch)
        if using_cached and self.inv_he is not None:
            hesse_error = np.sqrt(np.fabs(self.inv_he.diagonal())).tolist()
        elif method == "3-point":
            self.inv_he = num_hess_inv_3point(fcn, params)
            diag_he = self.inv_he.diagonal()
            hesse_error = np.sqrt(np.fabs(diag_he)).tolist()
        elif method == "correct":
            h = cal_hesse_correct(fcn, params, correct_params)
            if force_pos:
                self.inv_he = force_pos_def(h)
            else:
                self.inv_he = np.linalg.pinv(h)
            diag_he = self.inv_he.diagonal()
            hesse_error = np.sqrt(np.fabs(diag_he)).tolist()
        else:
            hesse_error, self.inv_he = cal_hesse_error(
                fcn,
                params,
                check_posi_def=True,
                save_npy=True,
                force_pos=force_pos,
            )
        # print("parameters order")
        # print(fcn.model.Amp.vm.trainable_vars)
        # print("error matrix:")
        # print(self.inv_he)
        # print("correlation matrix:")
        # print(corr_coef_matrix(self.inv_he))
        print("hesse_error:", hesse_error)
        err = dict(zip(self.vm.trainable_vars, hesse_error))
        if hasattr(self, "fit_params"):
            self.fit_params.set_error(err)
        return err
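
    # Example (sketch): compute HESSE-style errors after a fit; the inverse
    # Hessian is cached in self.inv_he for reuse by cal_fitfractions():
    #
    #     errors = config.get_params_error(fit_result)  # {name: sigma, ...}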

    @classmethod
    def register_function(cls, name=None):
        def _f(f):
            my_name = name
            if my_name is None:
                my_name = f.__name__
            if hasattr(cls, my_name):
                warnings.warn("override function {}".format(my_name))
            setattr(cls, my_name, f)
            return f

        return _f
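
    # Example: attach a new method to ConfigLoader itself (the method name
    # "n_chains" is made up for illustration):
    #
    #     @ConfigLoader.register_function()
    #     def n_chains(self):
    #         return len(self.full_decay)
    #
    #     print(config.n_chains())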

    def get_chain(self, idx):
        decay_group = self.full_decay
        return decay_group.get_decay_chain(idx)

    def cal_fitfractions(
        self,
        params={},
        mcdata=None,
        res=None,
        exclude_res=[],
        batch=25000,
        method="old",
    ):
        if hasattr(params, "params"):
            params = getattr(params, "params")
        if mcdata is None:
            mcdata = self.get_phsp_noeff()
        if self.config["data"].get("lazy_call", False):
            method = "new"
        amp = self.get_amplitude()
        if res is None:
            res = sorted(
                list(set([str(i) for i in amp.res]) - set(exclude_res))
            )
        frac_and_err = fit_fractions(
            amp, mcdata, self.inv_he, params, batch, res, method=method
        )
        return frac_and_err
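
    # Example (sketch): fit-fraction uncertainties use the inverse Hessian
    # cached by get_params_error():
    #
    #     frac, frac_err = config.cal_fitfractions(params=fit_result.params)
    #     for name in frac:
    #         print(name, frac[name], "+/-", frac_err.get(name, 0.0))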

    def cal_signal_yields(self, params={}, mcdata=None, batch=25000):
        if hasattr(params, "params"):
            params = getattr(params, "params")
        if mcdata is None:
            mcdata = self.get_data("phsp")
        extended = self.config["data"].get("extended", False)
        amp = self.get_amplitude()
        fracs = [
            fit_fractions(amp, i, self.inv_he, params, batch) for i in mcdata
        ]
        data = self.get_data("data")
        bg = self.get_data("bg")
        if bg is None:
            # fixed: start from an empty list; the original pre-filled it
            # with bare counts, mixing ints with (N, err) tuples
            N_total = []
            for i in data:
                N_data = data_shape(i)
                N_total.append((N_data, np.sqrt(N_data)))
        else:
            bg_weight, _ = self._get_bg_weight(data, bg)
            N_total = []
            for i, j, w in zip(data, bg, bg_weight):
                N_data = data_shape(i)
                N_bg = data_shape(j)
                N_total.append(
                    (N_data - w * N_bg, np.sqrt(N_data + w * w * N_bg))
                )

        N_sig_s = []
        for frac_e, N_e in zip(fracs, N_total):
            frac, frac_err = frac_e
            N, N_err = N_e
            if extended:
                N_err = 0.0
            N_sig = {}
            for i in frac:
                N_sig[i] = (
                    frac[i] * N,
                    np.sqrt(
                        (N * frac_err.get(i, 0.0)) ** 2
                        + (N_err * frac[i]) ** 2
                    ),
                )
            N_sig_s.append(N_sig)
        return N_sig_s

    def likelihood_profile(self, var, var_min, var_max, N=100):
        params = self.get_params()
        var0 = params[var]
        delta_var = (var_max - var_min) / N
        vm = self.get_amplitude().vm
        unfix = var in vm.get_all_dic(True)
        nlls_up = []
        vars_up = []
        while var0 <= var_max:
            vm.set_fix(var, var0)
            fit_result = self.fit()
            vars_up.append(var0)
            nlls_up.append(fit_result.min_nll)
            var0 += delta_var
        self.set_params(params)
        var0 = params[var] - delta_var
        vars_down = []
        nlls_down = []
        while var0 >= var_min:
            vm.set_fix(var, var0)
            fit_result = self.fit()
            vars_down.append(var0)
            nlls_down.append(fit_result.min_nll)
            var0 -= delta_var
        self.set_params(params)
        vm.set_fix(var, params[var], unfix=unfix)
        return vars_down[::-1] + vars_up, nlls_down[::-1] + nlls_up
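
    # Example (sketch; the variable name is illustrative): scan the NLL
    # versus one parameter, refitting all other parameters at each point:
    #
    #     xs, nlls = config.likelihood_profile("R_BC_mass", 4.1, 4.3, N=20)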

    def get_params(self, trainable_only=False):
        return self.get_amplitude().get_params(trainable_only)

    def set_params(self, params, neglect_params=None):
        if isinstance(params, str):
            if params == "":
                return False
            try:
                with open(params) as f:
                    params = yaml.safe_load(f)
            except Exception as e:
                print(e)
                return False
        else:
            neglect_params = []
        if hasattr(params, "params"):
            params = params.params
        if isinstance(params, dict):
            if "value" in params:
                params = params["value"]
        amplitude = self.get_amplitude()
        ret = params.copy()
        if neglect_params is None:
            neglect_params = self._neglect_when_set_params
        if len(neglect_params) != 0:
            for v in params:
                # fixed: check the effective neglect list and report the
                # single neglected name, not the whole list
                if v in neglect_params:
                    warnings.warn(
                        "Neglect {} when setting params.".format(v)
                    )
                    del ret[v]
        amplitude.set_params(ret)
        return True

    def save_params(self, file_name):
        params = self.get_params()
        val = {k: float(v) for k, v in params.items()}
        with open(file_name, "w") as f:
            json.dump(val, f, indent=2)

    @contextlib.contextmanager
    def params_trans(self):
        with self.vm.error_trans(self.inv_he) as f:
            yield f
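
    # Example (a sketch, assuming the object yielded by
    # VarsManager.error_trans supports indexing by variable name and
    # get_error; variable names are illustrative): propagate fit errors
    # to a derived quantity:
    #
    #     with config.params_trans() as pt:
    #         delta_m = pt["R1_mass"] - pt["R2_mass"]
    #     print(delta_m, "+/-", pt.get_error(delta_m))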

    @contextlib.contextmanager
    def mask_params(self, params):
        with self.vm.mask_params(params):
            yield

    def attach_fix_params_error(self, params: dict, V_b=None) -> np.ndarray:
        """
        The minimal condition

        .. math::
            -\\frac{\\partial \\ln L(a,b)}{\\partial a} = 0,

        can be treated as an implicit function :math:`a(b)`. The gradient is

        .. math::
            \\frac{\\partial a}{\\partial b} = - \\left(\\frac{\\partial^2 \\ln L(a,b)}{\\partial a \\partial a}\\right)^{-1}
            \\frac{\\partial^2 \\ln L(a,b)}{\\partial a \\partial b}.

        The uncertainties from :math:`b` with error matrix :math:`V_b` propagate to :math:`a` as

        .. math::
            V_a = \\frac{\\partial a}{\\partial b} V_b \\left(\\frac{\\partial a}{\\partial b}\\right)^T.

        This matrix will be added to config.inv_he.

        """
        fcn = self.get_fcn()
        new_params = list(params)
        for i in new_params:
            fcn.vm.set_fix(i, unfix=True)
        all_params = list(fcn.vm.trainable_vars)
        old_params = [i for i in all_params if i not in new_params]
        _, _, hess = fcn.nll_grad_hessian()
        hess = data_to_numpy(hess)
        for i in new_params:
            fcn.vm.set_fix(i)
        idx_a = np.array([all_params.index(i) for i in old_params])
        idx_b = np.array([all_params.index(i) for i in new_params])
        hess_aa = hess[idx_a][:, idx_a]
        hess_ab = hess[idx_a][:, idx_b]
        hess_aa = np.stack(hess_aa)
        hess_ab = np.stack(hess_ab)
        grad = np.dot(np.linalg.inv(hess_aa), hess_ab)
        if V_b is None:
            V_b = np.diag(list(params.values())) ** 2
        V = np.dot(np.dot(grad, V_b), grad.T)
        if self.inv_he is None:
            old_V = 0.0
        else:
            old_V = self.inv_he
        new_V = old_V + V
        self.inv_he = new_V
        return V
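
    # Example (sketch): `params` maps fixed parameter names to their external
    # sigmas; the propagated covariance is added to config.inv_he before the
    # errors are read back from the cache:
    #
    #     config.attach_fix_params_error({"R_width": 0.01})
    #     errors = config.get_params_error(fit_result, using_cached=True)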

    def batch_sum_var(self, *args, **kwargs):
        return self.vm.batch_sum_var(*args, **kwargs)

    def save_tensorflow_model(self, dir_name):
        class CustomModule(tf.Module):
            def __init__(self, config_name, share_dict, final_params):
                self.config = ConfigLoader(config_name, share_dict=share_dict)
                self.amp = self.config.get_amplitude()
                self.config.set_params(final_params)
                self.all_variables = self.amp.vm.variables

            @tf.function()
            def __call__(self, *p):
                data = self.config.data.cal_angle(p)
                return self.amp(data)

        module = CustomModule(self.config, self.share_dict, self.get_params())
        n_p = len(self.get_dat_order())
        input_p = [tf.TensorSpec([None, 4], tf.float64) for i in range(n_p)]
        call = module.__call__.get_concrete_function(*input_p)
        tf.saved_model.save(
            module, dir_name, signatures={"serving_default": call}
        )
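
    # Example (sketch): the exported SavedModel can be reloaded without
    # tf_pwa and called with one (N, 4) float64 tensor per final-state
    # particle, in get_dat_order() order (directory name is illustrative):
    #
    #     module = tf.saved_model.load("amp_model")
    #     amp_values = module(p1, p2, p3)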

def set_prefix_constrains(vm, base, params_dic, self):
    prefix = base.get_variable_name()
    p_list = []
    for v in params_dic:
        vname = v
        for tail in ["_range", "_sigma", "_free", "_constr", "_min", "_max"]:
            if v.endswith(tail):
                vname = v[: -len(tail)]
                break
        if vname not in p_list:
            # print(vname, v)
            p_list.append(vname)
            vv = base.get_var(vname)
            # print(vv, prefix + vname)
            # if isinstance(vv, Variable):  # getattr(p_i, vname)
            if vv is None:
                continue
            p_sigma = params_dic.get(vname + "_sigma", None)
            if vname in params_dic and params_dic[vname] is not None:
                p_value = params_dic[vname]
                vv.set_value(p_value)
                if p_sigma is None:
                    self.init_value[vname] = p_value
                else:
                    self.init_value[vname] = [p_value, p_sigma]
            else:
                p_value = None
            p_free = params_dic.get(vname + "_free", None)
            if p_free:
                vv.freed()
            elif p_free is False:
                vv.fixed()
            p_range = vname + "_range"
            if p_range in params_dic and params_dic[p_range] is not None:
                lower, upper = params_dic[p_range]
                self.bound_dic[vv.name] = (lower, upper)
                # vm.set_bound({vv.name: (lower, upper)})
            else:
                lower = params_dic.get(vname + "_min")
                upper = params_dic.get(vname + "_max")
                # print(lower, upper)
                if lower is not None or upper is not None:
                    self.bound_dic[vv.name] = (lower, upper)
                    # vm.set_bound({vv.name: (lower, upper)})
            # elif p_sigma is not None and p_value is not None:
            #     p_10sigma = 10 * p_sigma
            #     self.bound_dic[vv.name] = (
            #         p_value - p_10sigma,
            #         p_value + p_10sigma,
            #     )
            p_constr = vname + "_constr"
            if p_constr in params_dic and params_dic[p_constr] is not None:
                if params_dic[p_constr]:
                    if p_value is None:
                        raise Exception(
                            "Need central value of {0} of {1} when adding gaussian constraint".format(
                                vname, prefix
                            )
                        )
                    if p_sigma is None:
                        raise Exception(
                            "Need sigma of {0} of {1} when adding gaussian constraint".format(
                                vname, prefix
                            )
                        )
                    self.gauss_constr_dic[vv.name] = (
                        params_dic[vname],
                        p_sigma,
                    )

def validate_file_name(s):
    rstr = r"[\/\\\:\*\?\"\<\>\|]"  # '/ \ : * ? " < > |'
    name = re.sub(rstr, "_", s)
    return name

class PlotParams(dict):
    def __init__(self, plot_config, decay_struct):
        self.config = plot_config
        self.defaults_config = {}
        self.defaults_config.update(self.config.get("config", {}))
        self.decay_struct = decay_struct
        chain_map = self.decay_struct.get_chains_map()
        self.re_map = {}
        for i in chain_map:
            for _, j in i.items():
                for k, v in j.items():
                    self.re_map[v] = k
        self.params = []
        for i in self.get_mass_vars():
            self.params.append(i)
        for i in self.get_angle_vars():
            self.params.append(i)
        for i in self.get_angle_vars(True):
            self.params.append(i)
        for i in self.get_index_vars():
            self.params.append(i)
        for i in self.get_extra_vars():
            self.params.append(i)

    def get_data_index(self, sub, name):
        dec = self.decay_struct.topology_structure()
        if sub == "mass":
            p = get_particle(name)
            return "particle", self.re_map.get(p, p), "m"
        if sub == "p":
            p = get_particle(name)
            return "particle", self.re_map.get(p, p), "p"
        if sub == "angle":
            name_i = name.split("/")
            de_i = self.decay_struct.get_decay_chain(name_i)
            p = get_particle(name_i[-1])
            for i in de_i:
                if p in i.outs:
                    de = i
                    break
            else:
                raise IndexError("not found such decay {}".format(name))
            return (
                "decay",
                de_i.standard_topology(),
                self.re_map.get(de, de),
                self.re_map.get(p, p),
                "ang",
            )
        if sub == "aligned_angle":
            name_i = name.split("/")
            de_i = self.decay_struct.get_decay_chain(name_i)
            p = get_particle(name_i[-1])
            for i in de_i:
                if p in i.outs:
                    de = i
                    break
            else:
                raise IndexError("not found such decay {}".format(name))
            return (
                "decay",
                de_i.standard_topology(),
                self.re_map.get(de, de),
                self.re_map.get(p, p),
                "aligned_angle",
            )
        if sub == "index":
            name_i = name.split("/")
            return name_i
        raise ValueError("unknown sub {}".format(sub))

    def read_plot_config(self, v):
        upper_ylim = v.get("upper_ylim", None)
        xrange = v.get("range", None)
        units = v.get("units", "")
        bins = v.get("bins", self.defaults_config.get("bins", 50))
        legend = v.get("legend", self.defaults_config.get("legend", True))
        legend_outside = v.get(
            "legend_outside",
            self.defaults_config.get("legend_outside", False),
        )
        yscale = v.get("yscale", self.defaults_config.get("yscale", "linear"))
        return {
            "upper_ylim": upper_ylim,
            "legend": legend,
            "legend_outside": legend_outside,
            "range": xrange,
            "bins": bins,
            "units": units,
            "yscale": yscale,
        }

    def get_mass_vars(self):
        mass = self.config.get("mass", {})
        x = sy.symbols("x")
        for k, v in mass.items():
            id_ = v.get("id", k)
            display = v.get("display", "M({})".format(k))
            trans = v.get("trans", None)
            if trans is None:
                trans = lambda x: x
            else:
                trans = sy.sympify(trans)
                trans = sy.lambdify(x, trans, modules="numpy")
            units = v.get("units", "GeV")
            common_config = self.read_plot_config(v)
            idx = (
                "particle",
                self.re_map.get(get_particle(id_), get_particle(id_)),
                "m",
            )
            yield {
                **common_config,
                "units": units,
                "name": "m_" + k,
                "display": display,
                "idx": idx,
                "trans": trans,
                "readdata": ReadData(idx, trans),
            }
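
    # Example of a `plot: mass:` entry consumed above (YAML, shown as a
    # comment; the particle name "R_BC" is illustrative):
    #
    #     plot:
    #       mass:
    #         R_BC:
    #           display: "$M(BC)$"
    #           trans: x**2      # any sympy expression in x
    #           units: GeV
    #           bins: 60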

    def get_angle_vars(self, is_align=False):
        if not is_align:
            ang = self.config.get("angle", {})
        else:
            ang = self.config.get("aligned_angle", {})
        for k, i in ang.items():
            id_ = i.get("id", k)
            names = id_.split("/")
            name = names[0]
            number_decay = True
            if len(names) > 1:
                try:
                    count = int(names[-1])
                except ValueError:
                    number_decay = False
            else:
                count = 0
            if number_decay:
                decay_chain, decay = None, None
                part = self.re_map.get(get_particle(name), get_particle(name))
                for decs in self.decay_struct:
                    for dec in decs:
                        if dec.core == get_particle(name):
                            decay = dec.core.decay[count]
                            for j in self.decay_struct:
                                if decay in j:
                                    decay_chain = j.standard_topology()
                decay = self.re_map.get(decay, decay)
                part = decay.outs[0]
            else:
                _, decay_chain, decay, part, _ = self.get_data_index(
                    "angle", id_
                )
            for j, v in i.items():
                display = v.get("display", j)
                theta = j
                trans = lambda x: x
                if "cos" in j:
                    theta = j[4:-1]
                    trans = np.cos
                common_config = self.read_plot_config(v)
                if is_align:
                    ang_type = "aligned_angle"
                else:
                    ang_type = "ang"
                name_id = validate_file_name(k + "_" + j)
                if is_align:
                    name_id = "aligned_" + name_id
                idx = ("decay", decay_chain, decay, part, ang_type, theta)
                yield {
                    **common_config,
                    "name": name_id,
                    "display": display,
                    "idx": idx,
                    "trans": trans,
                    "readdata": ReadData(idx, trans),
                }

    def get_extra_vars(self):
        from tf_pwa.formula import build_expr_function

        dic = self.config.get("extra_vars", {})
        for k, v in dic.items():
            expr = v["expr"]
            where = v.get("where", {})
            f_expr, used_var = build_expr_function(expr)
            var_f = []
            for i in used_var:
                idx = where.get(i, i)
                if isinstance(idx, (list, tuple)):
                    idx = idx[0], "/".join(idx[1])
                    # fixed: append the reader instead of overwriting the list
                    var_f.append(ReadData(self.get_data_index(*idx)))
                elif isinstance(idx, str):
                    for j in self.params:
                        if j["name"] == idx:
                            var_f.append(j["readdata"])
                else:
                    raise TypeError("unknown variables for trans")

            # default arguments bind the loop variables per iteration,
            # so each yielded readdata keeps its own readers and expression
            def readdata(x, var_f=var_f, used_var=used_var, f_expr=f_expr):
                var = [data_to_numpy(i(x)) for i in var_f]
                return f_expr(**dict(zip(used_var, var)))

            id_ = v.get("id", k)
            display = v.get("display", str(expr))
            common_config = self.read_plot_config(v)
            yield {
                **common_config,
                "name": k,
                "display": display,
                "readdata": readdata,
            }

    def get_index_vars(self):
        dic = self.config.get("index", {})
        for k, v in dic.items():
            idx = self.get_data_index("index", k)
            id_ = v.get("id", k)
            display = v.get("display", str(k))
            trans = v.get("trans", None)
            if trans is None:
                trans = lambda x: x
            else:
                trans = sy.sympify(trans)
                x = sy.symbols("x")
                trans = sy.lambdify(x, trans, modules="numpy")
            common_config = self.read_plot_config(v)
            readdata = ReadData(idx, trans)
            yield {
                **common_config,
                "name": k,
                "display": display,
                "readdata": readdata,
            }

    def get_params(self, params=None):
        if params is None:
            return self.params
        if isinstance(params, str):
            params = [params]
        params_list = []
        for i in self.params:
            if i["display"] in params:
                params_list.append(i)
        return params_list