Source code for tf_pwa.amp.amp

import contextlib
import warnings

import numpy as np
import tensorflow as tf

from tf_pwa.amp.core import Variable, variable_scope
from tf_pwa.config import create_config, get_config, regist_config, temp_config
from tf_pwa.data import LazyCall, data_shape, split_generator

# Configuration key under which the amplitude-model registry is stored.
# `register_amp_model` writes classes into the object stored under this key
# and `create_amplitude` looks model names up in it.
AMP_MODEL = "amplitude_model"
# Register an initially empty dict for that key.
regist_config(AMP_MODEL, {})


def register_amp_model(name=None, f=None):
    """register an amplitude model class

    Usable either directly, ``register_amp_model("x", cls)``, or as a
    decorator factory, ``@register_amp_model("x")``.

    :params name: model name used in configuration; when ``None`` the
        ``__name__`` of the class is used
    :params f: model class; when ``None`` a decorator is returned instead
    :return: the registered class, or the decorator when ``f`` is ``None``
    """

    def regist(g):
        my_name = g.__name__ if name is None else name
        # Store the name the class is actually registered under, so that
        # classes registered without an explicit name do not end up with
        # ``model_name = None``.
        g.model_name = my_name
        config = get_config(AMP_MODEL)
        if my_name in config:
            # Re-registration replaces the previous entry; warn about it.
            warnings.warn("Override mode {}".format(my_name))
        config[my_name] = g
        return g

    if f is None:
        return regist
    return regist(f)
def create_amplitude(decay_group, **kwargs):
    """Build the amplitude model selected by the ``model`` keyword.

    ``model`` (default ``"default"``) is either a registered model name or a
    dict:

    * a dict with exactly one entry ``{key: options}``: ``options`` is merged
      into ``kwargs`` and the model name is ``options["model"]`` when present,
      otherwise ``key``;
    * any other dict: one amplitude is created per entry and the results are
      combined in a ``ProdPDF``.

    :params decay_group: passed as first argument to the model class
    :params kwargs: forwarded to the model class
    :return: instance of the class registered under the resolved name
    """
    mode = kwargs.get("model", "default")
    if isinstance(mode, dict):
        if len(mode) != 1:
            # Several sub-models: build each one on its own, then multiply.
            pdfs = {}
            for sub_name, sub_options in mode.items():
                kwargs["model"] = {sub_name: sub_options}
                pdfs[sub_name] = create_amplitude(decay_group, **kwargs)
            del kwargs["model"]
            return ProdPDF(decay_group, pdfs=pdfs, **kwargs)
        key = next(iter(mode))
        options = mode[key]
        kwargs.update(options)
        mode = options["model"] if "model" in options else key
    model_class = get_config(AMP_MODEL)[mode]
    return model_class(decay_group, **kwargs)
class AbsPDF:
    """Base class of callable PDFs whose parameters live in a variable manager.

    ``init_params(name)`` and ``pdf(data)`` are called by this class but not
    defined on it; subclasses have to provide them.
    """

    def __init__(
        self,
        *args,
        name="",
        vm=None,
        polar=None,
        use_tf_function=False,
        no_id_cached=False,
        jit_compile=False,
        **kwargs,
    ):
        """
        :params name: name passed to ``init_params``
        :params vm: variable manager handed to ``variable_scope``
        :params polar: when not ``None``, assigned to ``vm.polar``
        :params use_tf_function: wrap ``pdf`` in ``WrapFun`` for repeated calls
        :params no_id_cached: disable the id-based switch to ``cached_fun``
        :params jit_compile: forwarded to ``WrapFun``
        :params kwargs: stored unchanged in ``self.extra_kwargs``
        """
        self.name = name
        with variable_scope(vm) as vm:
            if polar is not None:
                vm.polar = polar
            self.init_params(name)
        self.vm = vm
        self.no_id_cached = no_id_cached
        # ids of data objects that have been evaluated at least once
        self.f_data = []
        if use_tf_function:
            from tf_pwa.experimental.wrap_function import WrapFun

            self.cached_fun = WrapFun(self.pdf, jit_compile=jit_compile)
        else:
            self.cached_fun = self.pdf
        self.extra_kwargs = kwargs

    def get_params(self, trainable_only=False):
        """Return ``self.vm.get_all_dic(trainable_only)``."""
        return self.vm.get_all_dic(trainable_only)

    def set_params(self, var):
        """Set parameters through ``self.vm.set_all(var)``."""
        self.vm.set_all(var)

    @contextlib.contextmanager
    def temp_params(self, var):
        """Context manager: use parameters ``var`` inside the ``with`` block.

        The parameters read before entering are written back on exit, also
        when the body raises.
        """
        params = self.get_params()
        self.set_params(var)
        try:
            yield var
        finally:
            self.set_params(params)

    @contextlib.contextmanager
    def mask_params(self, var):
        """Context manager delegating to ``self.vm.mask_params(var)``."""
        with self.vm.mask_params(var):
            yield

    @property
    def variables(self):
        return self.vm.variables

    @property
    def trainable_variables(self):
        return self.vm.trainable_variables

    def cached_available(self):
        """Whether ``cached_fun`` may be used; always ``True`` here."""
        return True

    def __call__(self, data, cached=False):
        """Evaluate the PDF on ``data``.

        The first call with a given data object goes through ``pdf``; later
        calls with the same object (same ``id``) go through ``cached_fun``
        when ``cached_available()`` is true and ``no_id_cached`` is false.

        :params data: data, or a ``LazyCall`` which is evaluated first
        :params cached: currently unused
        """
        if isinstance(data, LazyCall):
            data = data.eval()
        data_id = id(data)
        seen = data_id in self.f_data
        if seen and not self.no_id_cached:
            if self.cached_available():
                return self.cached_fun(data)
        elif not seen:
            # Record each id only once so the list does not grow on every
            # call when ``no_id_cached`` is set.
            self.f_data.append(data_id)
        return self.pdf(data)
class BaseAmplitudeModel(AbsPDF):
    """PDF backed by a decay group; most methods delegate to it."""

    def __init__(self, decay_group, **kwargs):
        """
        :params decay_group: object providing the amplitude and its parameters
        :params kwargs: forwarded to ``AbsPDF.__init__``
        """
        # Must be set first: the base constructor calls ``init_params``,
        # which uses ``self.decay_group``.
        self.decay_group = decay_group
        super().__init__(**kwargs)
        res = decay_group.resonances
        self.used_res = res
        self.res = res

    def init_params(self, name=""):
        """Initialise parameters via ``decay_group.init_params(name)``."""
        self.decay_group.init_params(name)

    def __del__(self):
        if hasattr(self, "cached_fun"):
            del self.cached_fun

    def cache_data(self, data, split=None, batch=None):
        """Return ``data`` unchanged, or split into a list of pieces.

        :params data: data to split
        :params split: used to derive ``batch`` as ``ceil(n / split)`` when
            ``batch`` is not given (``n`` is ``data_shape(data)``)
        :params batch: second argument passed to ``split_generator``
        :return: ``data`` itself when both ``split`` and ``batch`` are
            ``None``, otherwise ``list(split_generator(data, batch))``
        """
        # NOTE(review): prints every item of ``inner`` of every element of
        # the decay group; looks like debugging output, kept so that the
        # printed output is unchanged.
        for i in self.decay_group:
            for j in i.inner:
                print(j)
        if split is None and batch is None:
            return data
        n = data_shape(data)
        if batch is None:
            # groups of `split` items, `batch` groups in total
            batch = (n + split - 1) // split
        return list(split_generator(data, batch))

    def set_used_res(self, res):
        self.decay_group.set_used_res(res)

    @contextlib.contextmanager
    def temp_used_res(self, res, only=False):
        """Context manager delegating to ``decay_group.temp_used_res``."""
        with self.decay_group.temp_used_res(res, only=only):
            yield

    def set_used_chains(self, used_chains):
        self.decay_group.set_used_chains(used_chains)

    def partial_weight(self, data, combine=None):
        """Evaluate ``pdf`` once per chain combination.

        :params data: data, or a ``LazyCall`` which is evaluated first
        :params combine: list of chain-index lists; defaults to one
            single-chain list per entry of ``decay_group.chains``
        :return: list with one ``pdf(data)`` result per entry of ``combine``
        """
        if isinstance(data, LazyCall):
            data = data.eval()
        if combine is None:
            combine = [[i] for i in range(len(self.decay_group.chains))]
        o_used_chains = self.decay_group.chains_idx
        weights = []
        try:
            for i in combine:
                self.decay_group.set_used_chains(i)
                weights.append(self.pdf(data))
        finally:
            # Restore the original selection even if ``pdf`` raised.
            self.decay_group.set_used_chains(o_used_chains)
        return weights

    def partial_weight_interference(self, data):
        return self.decay_group.partial_weight_interference(data)

    def chains_particle(self):
        return self.decay_group.chains_particle()

    def cached_available(self):
        """``cached_fun`` is usable unless ``decay_group.not_full`` is set."""
        return not self.decay_group.not_full

    def pdf(self, data):
        """Return ``decay_group.sum_amp(data)``."""
        ret = self.decay_group.sum_amp(data)
        return ret

    def factor_iteration(self, deep=2):
        """Yield the items of ``decay_group.factor_iteration(deep)``."""
        for i in self.decay_group.factor_iteration(deep):
            yield i

    @contextlib.contextmanager
    def temp_total_gls_one(self):
        """Context manager: set ``mask_factor = True`` on every element of
        the decay group and on every item of those elements.

        The previous values (``False`` where the attribute was missing) are
        written back on exit, also when the body raises.
        """
        mask_part = []
        for i in self.decay_group:
            mask_part.append(i)
            for j in i:
                mask_part.append(j)
        old_mask = [getattr(i, "mask_factor", False) for i in mask_part]
        for i in mask_part:
            i.mask_factor = True
        try:
            yield
        finally:
            for i, j in zip(mask_part, old_mask):
                i.mask_factor = j
@register_amp_model("default")
class AmplitudeModel(BaseAmplitudeModel):
    """Model registered as ``"default"``."""

    def partial_weight(self, data, combine=None):
        """Return ``decay_group.partial_weight(data, combine)``.

        A ``LazyCall`` passed as ``data`` is evaluated first.
        """
        evaluated = data.eval() if isinstance(data, LazyCall) else data
        return self.decay_group.partial_weight(evaluated, combine)
@register_amp_model("cached_amp")
class CachedAmpAmplitudeModel(BaseAmplitudeModel):
    """Model registered as ``"cached_amp"``; reads ``data["cached_amp"]``."""

    def pdf(self, data):
        """Combine the parameter vectors with the cached amplitude parts.

        For every used chain, the parameter vector from
        ``build_params_vector`` is multiplied with the stacked cached parts
        and summed over axis 1; the per-chain results are summed and passed
        to ``decay_group.sum_with_polarization``.
        """
        from tf_pwa.experimental.build_amp import build_params_vector

        def contract(coeffs, parts):
            # Append singleton axes so ``coeffs`` broadcasts against the
            # stacked parts, then sum over the stacked axis.
            n_extra = len(parts[0].shape) - 1
            coeffs = tf.reshape(coeffs, [-1, coeffs.shape[1]] + [1] * n_extra)
            return tf.reduce_sum(coeffs * tf.stack(parts, axis=1), axis=1)

        # Result unused; the call is kept so behaviour stays identical.
        data_shape(data)
        all_cached = data["cached_amp"]
        params_vectors = build_params_vector(self.decay_group, data)
        used_cached = [all_cached[k] for k in self.decay_group.chains_idx]
        per_chain = [
            contract(coeffs, parts)
            for coeffs, parts in zip(params_vectors, used_cached)
        ]
        total = tf.reduce_sum(per_chain, axis=0)
        return self.decay_group.sum_with_polarization(total)
@register_amp_model("cached_shape")
class CachedShapeAmplitudeModel(BaseAmplitudeModel):
    """Model registered as ``"cached_shape"``; reads ``data["cached_amp"]``.

    Chains listed in ``cached_shape_idx`` are handled separately from the
    others in ``pdf``.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Optional explicit list, taken from the extra keyword arguments.
        self.cached_shape_idx = self.extra_kwargs.get("cached_shape_idx", None)

    def get_cached_shape_idx(self):
        """Return the chain indices treated as having a cached shape.

        When not given explicitly, these are the entries of
        ``decay_group.chains_idx`` whose chain has no decay with
        ``decay.core.is_fixed_shape()`` false. The result is stored in
        ``self.cached_shape_idx`` and printed once.
        """
        if self.cached_shape_idx is not None:
            return self.cached_shape_idx
        not_fixed = set()
        for idx, decay_chain in enumerate(self.decay_group):
            for decay in decay_chain:
                if not decay.core.is_fixed_shape():
                    not_fixed.add(idx)
        ret = [i for i in self.decay_group.chains_idx if i not in not_fixed]
        self.cached_shape_idx = ret
        print("cached shape idx", ret)
        return ret

    def pdf(self, data):
        """Sum the amplitude parts of all used chains.

        Chains outside ``cached_shape_idx`` use ``build_params_vector`` from
        ``build_amp`` (evaluated with ``data``); chains inside use the one
        from ``opt_int`` (evaluated without ``data``). The sum is passed to
        ``decay_group.sum_with_polarization``.
        """
        from tf_pwa.experimental.build_amp import build_params_vector
        from tf_pwa.experimental.opt_int import build_params_vector as bv2

        # Result unused; the call is kept so behaviour stays identical.
        data_shape(data)
        cached_data = data["cached_amp"]
        cached_shape_idx = self.get_cached_shape_idx()
        old_chains_idx = self.decay_group.chains_idx

        # amp parts without cached shape
        used_chains_idx = [
            i for i in old_chains_idx if i not in cached_shape_idx
        ]
        self.decay_group.set_used_chains(used_chains_idx)
        try:
            pv = build_params_vector(self.decay_group, data)
            partial_cached_data = [cached_data[i] for i in used_chains_idx]
        finally:
            # Always restore the caller's chain selection.
            self.decay_group.set_used_chains(old_chains_idx)
        ret = []
        for i, j in zip(pv, partial_cached_data):
            a = tf.reshape(i, [-1, i.shape[1]] + [1] * (len(j[0].shape) - 1))
            ret.append(tf.reduce_sum(a * tf.stack(j, axis=1), axis=1))

        # amp parts with cached shape
        cached_shape_idx2 = [
            i for i in cached_shape_idx if i in old_chains_idx
        ]
        partial_cached_data2 = [cached_data[i] for i in cached_shape_idx2]
        pv2 = bv2(self.decay_group, concat=False)
        pv2 = [pv2[i] for i in cached_shape_idx2]
        for i, j in zip(pv2, partial_cached_data2):
            # Here the reshape uses ``i.shape[0]`` and ``j`` is used without
            # stacking, unlike the loop above.
            a = tf.reshape(i, [-1, i.shape[0]] + [1] * (len(j[0].shape) - 1))
            ret.append(tf.reduce_sum(a * j, axis=1))
        amp = tf.reduce_sum(ret, axis=0)
        return self.decay_group.sum_with_polarization(amp)
@register_amp_model("base_factor")
class FactorAmplitudeModel(BaseAmplitudeModel):
    """Model registered as ``"base_factor"``.

    Combines the factors from ``decay_group.get_m_dep`` with the angular
    amplitudes (``data["cached_angle"]`` when present, otherwise
    ``decay_group.get_factor_angle_amp``).
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def get_amp_list(self, data):
        """Return one amplitude per chain with all factors contracted.

        For each factor, the leading non-batch part of the angular amplitude
        is reshaped to ``(-1, k, rest)`` with ``k = factor.shape[-1]``,
        multiplied by the factor and summed over the ``k`` axis.
        """
        mass_factors = self.decay_group.get_m_dep(data)
        angle_amps = (
            data["cached_angle"]
            if "cached_angle" in data
            else self.decay_group.get_factor_angle_amp(data)
        )
        amp_list = []
        for factors, angle in zip(mass_factors, angle_amps):
            current = angle
            for factor in factors:
                flat_size = np.prod(current.shape[1:])
                if len(factor.shape) == 1:
                    # Rank-1 factor: give it a trailing axis of size 1.
                    factor = tf.expand_dims(factor, axis=-1)
                k = factor.shape[-1]
                current = tf.reshape(current, (-1, k, flat_size // k))
                weighted = current * tf.expand_dims(factor, axis=-1)
                current = tf.reduce_sum(weighted, axis=-2)
            amp_list.append(current)
        return amp_list

    def get_amp_list_part(self, data):
        """Like ``get_amp_list`` but without summing over the factor axes.

        After each factor the amplitude has shape
        ``(-1, head_size, k, rest)`` where ``head_size`` is the product of
        the ``shape[-1]`` of the factors applied before.
        """
        mass_factors = self.decay_group.get_m_dep(data)
        angle_amps = (
            data["cached_angle"]
            if "cached_angle" in data
            else self.decay_group.get_factor_angle_amp(data)
        )
        amp_list = []
        for factors, angle in zip(mass_factors, angle_amps):
            current = angle
            head_size = 1
            for factor in factors:
                flat_size = np.prod(current.shape[1:])
                if len(factor.shape) == 1:
                    factor = tf.expand_dims(factor, axis=-1)
                k = factor.shape[-1]
                new_shape = (-1, head_size, k, flat_size // k // head_size)
                current = tf.reshape(current, new_shape)
                head_size *= k
                broadcast_factor = tf.expand_dims(
                    tf.expand_dims(factor, axis=-1), axis=1
                )
                current = current * broadcast_factor
            amp_list.append(current)
        return amp_list

    def pdf(self, data):
        """Sum the per-chain amplitudes and apply
        ``decay_group.sum_with_polarization``."""
        total = tf.reduce_sum(self.get_amp_list(data), axis=0)
        return self.decay_group.sum_with_polarization(total)
@register_amp_model("p4_directly")
class P4DirectlyAmplitudeModel(BaseAmplitudeModel):
    """Model registered as ``"p4_directly"``.

    Computes the angles from ``data["p4"]`` and evaluates a reference
    amplitude (model ``base_model``) on the result.
    """

    def __init__(self, *args, base_model="default", **kwargs):
        """
        :params base_model: value of ``model`` used to create ``ref_amp``
        """
        # ``ref_amp`` has to exist before the base constructor runs, since
        # that calls ``init_params`` which uses it.
        self.ref_amp = create_amplitude(
            *args, **{**kwargs, "model": base_model}
        )
        super().__init__(*args, **kwargs)

    def init_params(self, *args, **kwargs):
        """Initialise own parameters, then those of ``ref_amp``."""
        super().init_params(*args, **kwargs)
        self.ref_amp.init_params(*args, **kwargs)

    def cal_angle(self, p4):
        """Return ``cal_angle_from_momentum(p4, decay_group, ...)``.

        The listed options are forwarded when present in
        ``extra_kwargs["all_config"]``.
        """
        from tf_pwa.cal_angle import cal_angle_from_momentum

        all_config = self.extra_kwargs["all_config"]
        forwarded = (
            "center_mass",
            "r_boost",
            "random_z",
            "align_ref",
            "only_left_angle",
        )
        options = {k: all_config[k] for k in forwarded if k in all_config}
        return cal_angle_from_momentum(p4, self.decay_group, **options)

    def pdf(self, data):
        """Evaluate ``ref_amp`` on the computed angles merged with ``data``;
        entries of ``data`` take precedence."""
        angles = self.cal_angle(data["p4"])
        merged = {**angles, **data}
        return self.ref_amp(merged)
@register_amp_model("simple_mlp")
class MLPModel(BaseAmplitudeModel):
    """Model registered as ``"simple_mlp"``: a stack of dense layers applied
    to the variables returned by ``HelicityAngle.find_variable``."""

    def __init__(
        self, *args, n_hidden=10, n_layers=2, activation="softplus", **kwargs
    ):
        """
        :params n_hidden: int (same width for ``n_layers - 1`` hidden
            layers) or a sequence of hidden widths
        :params n_layers: only used when ``n_hidden`` is an int
        :params activation: name of a function in ``tf.nn``
        """
        self.n_hidden = (
            [n_hidden] * (n_layers - 1)
            if isinstance(n_hidden, int)
            else n_hidden
        )
        self.n_layers = len(self.n_hidden) + 1
        self.activation = getattr(tf.nn, activation)
        super().__init__(*args, **kwargs)

    def init_params(self, name=""):
        """Create weights ``W{i}`` and biases ``b{i}`` on ``decay_group.top``.

        The first layer takes ``3 * len(decay_chain)`` inputs and the last
        layer has a single output.
        """
        self.decay_chain = self.decay_group[0]
        from tf_pwa.data_trans.helicity_angle import HelicityAngle

        self.ha = HelicityAngle(self.decay_chain)
        self.top = self.decay_group.top
        n_decay = len(self.decay_chain)
        self.Ws = []
        self.Bs = []
        last_layer = self.n_layers - 1
        for layer in range(self.n_layers):
            n_input = n_decay * 3 if layer == 0 else self.n_hidden[layer - 1]
            n_output = 1 if layer == last_layer else self.n_hidden[layer]
            weight = self.top.add_var(f"W{layer}", shape=(n_input, n_output))
            self.Ws.append(weight)
            bias = self.top.add_var(f"b{layer}", shape=(n_output,))
            self.Bs.append(bias)

    def pdf(self, data):
        """Apply every layer (matmul, bias, activation) and return
        component 0 of the last axis; the activation is applied after the
        last layer as well."""
        mass, costheta, phi = self.ha.find_variable(data)
        chain_mass = [mass[decay.core] for decay in self.decay_chain]
        features = tf.nest.flatten([chain_mass, costheta, phi])
        hidden = tf.stack(features, axis=-1)
        for layer in range(self.n_layers):
            hidden = self.activation(
                tf.matmul(hidden, self.Ws[layer]()) + self.Bs[layer]()
            )
        return hidden[..., 0]
class ProdPDF(BaseAmplitudeModel):
    """Product of several PDFs given as a mapping ``name -> pdf``."""

    def __init__(self, *args, pdfs, **kwargs):
        """
        :params pdfs: mapping of sub-PDFs; stored as ``self.pdfs``
        """
        super().__init__(*args, **kwargs)
        self.pdfs = pdfs

    def partial_weight(self, data, combine=None):
        """Return, per combination, the product of the sub-PDFs' partial
        weights.

        The number of combinations is taken from the first sub-PDF.
        """
        per_pdf = [
            sub_pdf.partial_weight(data, combine=combine)
            for sub_pdf in self.pdfs.values()
        ]
        n_combinations = len(per_pdf[0])
        return [
            tf.reduce_prod([weights[idx] for weights in per_pdf], axis=0)
            for idx in range(n_combinations)
        ]

    def pdf(self, data):
        """Return the product of ``pdf(data)`` over all sub-PDFs."""
        values = [sub_pdf.pdf(data) for sub_pdf in self.pdfs.values()]
        return tf.reduce_prod(values, axis=0)
@register_amp_model("constant")
class ConstantPDF(BaseAmplitudeModel):
    """Model registered as ``"constant"``."""

    def pdf(self, data):
        """Return ones shaped like ``data.get_weight()``."""
        weight = data.get_weight()
        return tf.ones_like(weight)