# Source code for the tf_pwa.amp.amp module (amplitude-model classes and their registry).

import contextlib
import warnings

import numpy as np
import tensorflow as tf

from tf_pwa.amp.core import Variable, variable_scope
from tf_pwa.config import create_config, get_config, regist_config, temp_config
from tf_pwa.data import LazyCall, data_shape, split_generator

# Config key of the registry mapping model names (e.g. "default") to the
# amplitude-model classes added by ``register_amp_model`` and looked up by
# ``create_amplitude``.
AMP_MODEL = "amplitude_model"
# Register that config entry with an empty dict when this module is imported.
regist_config(AMP_MODEL, {})


def register_amp_model(name=None, f=None):
    """Register an amplitude-model class in the ``AMP_MODEL`` registry.

    Usable as a decorator factory (``@register_amp_model("name")``) or as a
    direct call (``register_amp_model("name", cls)``).

    :param name: Key used to select the model, i.e. the value passed as
        ``model=`` to ``create_amplitude``. Defaults to the class ``__name__``.
    :param f: Class to register. If ``None``, a decorator is returned instead.
    :return: ``f`` itself when given; otherwise a decorator that registers
        its argument and returns it unchanged.
    """

    def regist(g, _stacklevel=2):
        my_name = g.__name__ if name is None else name
        config = get_config(AMP_MODEL)
        if my_name in config:
            # Point the warning at the code that performed the registration
            # (decorator site or direct caller), not at this helper.
            warnings.warn(
                "Override mode {}".format(my_name), stacklevel=_stacklevel
            )
        config[my_name] = g
        return g

    if f is None:
        return regist
    # One extra frame (this function) sits between the caller and regist.
    return regist(f, _stacklevel=3)
def create_amplitude(decay_group, **kwargs):
    """Instantiate the registered amplitude model selected by ``model``.

    :param decay_group: Passed as the first positional argument to the
        model class.
    :param kwargs: Forwarded unchanged to the model constructor (including
        ``model`` itself). ``model`` defaults to ``"default"``.
    :return: An instance of the class registered under ``model``.
    :raises KeyError: If no class is registered under that name.
    """
    mode = kwargs.get("model", "default")
    models = get_config(AMP_MODEL)
    # Keep the try narrow so a KeyError raised inside the constructor is not
    # misreported as an unknown model name.
    try:
        model_cls = models[mode]
    except KeyError:
        raise KeyError(
            "Unknown amplitude model {!r}; registered models: {}".format(
                mode, list(models)
            )
        ) from None
    return model_cls(decay_group, **kwargs)
class AbsPDF:
    """Base class for callable models backed by a variable manager.

    Subclasses must provide ``init_params(name)``, which is called while the
    variable scope is active, and ``pdf(data)``, which does the evaluation.
    """

    def __init__(
        self,
        *args,
        name="",
        vm=None,
        polar=None,
        use_tf_function=False,
        no_id_cached=False,
        jit_compile=False,
        **kwargs
    ):
        """
        :param args: Accepted for signature compatibility; ignored.
        :param name: Passed to ``init_params``; stored as ``self.name``.
        :param vm: Variable manager handed to ``variable_scope``; the object
            yielded by that scope becomes ``self.vm``.
        :param polar: If not ``None``, assigned to ``vm.polar`` before
            ``init_params`` runs.
        :param use_tf_function: If true, ``cached_fun`` wraps ``self.pdf``
            with ``WrapFun``; otherwise ``cached_fun`` is ``self.pdf``.
        :param no_id_cached: If true, ``__call__`` always tries ``cached_fun``
            rather than only for data objects it has seen before.
        :param jit_compile: Forwarded to ``WrapFun``.
        :param kwargs: Stored as ``self.extra_kwargs`` for subclasses.
        """
        self.name = name
        with variable_scope(vm) as vm:
            if polar is not None:
                vm.polar = polar
            self.init_params(name)
        self.vm = vm
        self.no_id_cached = no_id_cached
        # ids of data objects already evaluated once by __call__.
        self.f_data = []
        if use_tf_function:
            from tf_pwa.experimental.wrap_function import WrapFun

            self.cached_fun = WrapFun(self.pdf, jit_compile=jit_compile)
        else:
            self.cached_fun = self.pdf
        self.extra_kwargs = kwargs

    def get_params(self, trainable_only=False):
        """Return the variable values as reported by ``vm.get_all_dic``."""
        return self.vm.get_all_dic(trainable_only)

    def set_params(self, var):
        """Set variable values via ``vm.set_all``."""
        self.vm.set_all(var)

    @contextlib.contextmanager
    def temp_params(self, var):
        """Temporarily apply ``var``; restore the previous values on exit.

        The previous values are restored even if the with-body raises.
        Yields ``var``.
        """
        params = self.get_params()
        self.set_params(var)
        try:
            yield var
        finally:
            self.set_params(params)

    @contextlib.contextmanager
    def mask_params(self, var):
        """Delegate to the ``vm.mask_params(var)`` context manager."""
        with self.vm.mask_params(var):
            yield

    @property
    def variables(self):
        """``vm.variables``."""
        return self.vm.variables

    @property
    def trainable_variables(self):
        """``vm.trainable_variables``."""
        return self.vm.trainable_variables

    def cached_available(self):
        """Whether ``__call__`` may use ``cached_fun``; always true here."""
        return True

    def __call__(self, data, cached=False):
        """Evaluate the model on ``data``.

        ``LazyCall`` inputs are evaluated first. The first call with a given
        data object (identified by ``id``) runs ``pdf`` directly and records
        the id. Later calls with the same object, or every call when
        ``no_id_cached`` is set, go through ``cached_fun`` if
        ``cached_available()`` is true; otherwise they also fall back to
        ``pdf``.

        :param cached: Unused; kept for backward compatibility.
        """
        if isinstance(data, LazyCall):
            data = data.eval()
        # NOTE(review): ids can be reused once an object is garbage
        # collected, and f_data grows without bound.
        if id(data) in self.f_data or self.no_id_cached:
            if self.cached_available():
                return self.cached_fun(data)
        else:
            self.f_data.append(id(data))
        return self.pdf(data)
class BaseAmplitudeModel(AbsPDF):
    """Amplitude model whose parameters and amplitude come from a decay group."""

    def __init__(self, decay_group, **kwargs):
        # Must be set before AbsPDF.__init__, which calls init_params().
        self.decay_group = decay_group
        super().__init__(**kwargs)
        res = decay_group.resonances
        self.used_res = res
        self.res = res

    def init_params(self, name=""):
        """Delegate parameter initialization to ``decay_group.init_params``."""
        self.decay_group.init_params(name)

    def __del__(self):
        # Explicitly drop cached_fun (self.pdf or a WrapFun around it).
        if hasattr(self, "cached_fun"):
            del self.cached_fun

    def cache_data(self, data, split=None, batch=None):
        """Return ``data`` unchanged or split into chunks with ``split_generator``.

        :param split: Used only when ``batch`` is ``None``; then
            ``batch = ceil(n / split)`` where ``n = data_shape(data)``.
        :param batch: Passed to ``split_generator`` (presumably the chunk
            size; confirm against ``split_generator``).
        :return: ``data`` if both ``split`` and ``batch`` are ``None``,
            otherwise a list of the chunks produced by ``split_generator``.
        """
        if split is None and batch is None:
            return data
        n = data_shape(data)
        if batch is None:
            # Ceiling division: aim for at most ``split`` chunks.
            batch = (n + split - 1) // split
        return list(split_generator(data, batch))

    def set_used_res(self, res):
        """Delegate to ``decay_group.set_used_res``."""
        self.decay_group.set_used_res(res)

    @contextlib.contextmanager
    def temp_used_res(self, res):
        """Delegate to the ``decay_group.temp_used_res`` context manager."""
        with self.decay_group.temp_used_res(res):
            yield

    def set_used_chains(self, used_chains):
        """Delegate to ``decay_group.set_used_chains``."""
        self.decay_group.set_used_chains(used_chains)

    def partial_weight(self, data, combine=None):
        """Evaluate ``pdf`` once per group of enabled chains.

        :param data: Input data; ``LazyCall`` inputs are evaluated first.
        :param combine: List of chain-index lists, one weight per entry.
            Defaults to one single-chain entry per chain.
        :return: List of ``pdf(data)`` results, in the order of ``combine``.

        The originally used chains are restored afterwards, even if ``pdf``
        raises.
        """
        if isinstance(data, LazyCall):
            data = data.eval()
        if combine is None:
            combine = [[i] for i in range(len(self.decay_group.chains))]
        o_used_chains = self.decay_group.chains_idx
        weights = []
        try:
            for chains in combine:
                self.decay_group.set_used_chains(chains)
                weights.append(self.pdf(data))
        finally:
            self.decay_group.set_used_chains(o_used_chains)
        return weights

    def partial_weight_interference(self, data):
        """Delegate to ``decay_group.partial_weight_interference``."""
        return self.decay_group.partial_weight_interference(data)

    def chains_particle(self):
        """Delegate to ``decay_group.chains_particle``."""
        return self.decay_group.chains_particle()

    def cached_available(self):
        """Allow ``cached_fun`` only when ``decay_group.not_full`` is falsy."""
        return not self.decay_group.not_full

    def pdf(self, data):
        """Return ``decay_group.sum_amp(data)``."""
        return self.decay_group.sum_amp(data)

    def factor_iteration(self, deep=2):
        """Yield the items of ``decay_group.factor_iteration(deep)``."""
        yield from self.decay_group.factor_iteration(deep)

    @contextlib.contextmanager
    def temp_total_gls_one(self):
        """Temporarily set ``mask_factor = True`` on every chain and decay.

        Covers each chain in ``decay_group`` and each item inside each chain.
        Previous values (``False`` when the attribute was unset) are restored
        on exit, including when the with-body raises.
        """
        mask_part = []
        for chain in self.decay_group:
            mask_part.append(chain)
            mask_part.extend(chain)
        old_mask = [getattr(i, "mask_factor", False) for i in mask_part]
        for i in mask_part:
            i.mask_factor = True
        try:
            yield
        finally:
            for i, old in zip(mask_part, old_mask):
                i.mask_factor = old
@register_amp_model("default")
class AmplitudeModel(BaseAmplitudeModel):
    """Amplitude model registered under the name ``"default"``."""

    def partial_weight(self, data, combine=None):
        """Delegate partial weights to ``decay_group.partial_weight``.

        ``LazyCall`` inputs are evaluated first; ``combine`` is passed through
        unchanged.
        """
        resolved = data.eval() if isinstance(data, LazyCall) else data
        return self.decay_group.partial_weight(resolved, combine)
@register_amp_model("cached_amp")
class CachedAmpAmplitudeModel(BaseAmplitudeModel):
    """Amplitude model that combines precomputed per-chain amplitudes.

    Reads ``data["cached_amp"]``, indexed by chain index, and weights it with
    the parameter vectors from ``build_params_vector``.
    """

    def pdf(self, data):
        """Sum the parameter-weighted cached amplitudes of the used chains."""
        from tf_pwa.experimental.build_amp import build_params_vector

        cached_data = data["cached_amp"]
        pv = build_params_vector(self.decay_group, data)
        partial_cached_data = [
            cached_data[i] for i in self.decay_group.chains_idx
        ]
        ret = []
        for params, amps in zip(pv, partial_cached_data):
            # Reshape params to [-1, params.shape[1], 1, ...] so it broadcasts
            # against the cached amplitudes stacked along axis 1, then
            # contract that axis.
            a = tf.reshape(
                params,
                [-1, params.shape[1]] + [1] * (len(amps[0].shape) - 1),
            )
            ret.append(tf.reduce_sum(a * tf.stack(amps, axis=1), axis=1))
        amp = tf.reduce_sum(ret, axis=0)
        return self.decay_group.sum_with_polarization(amp)
@register_amp_model("cached_shape")
class CachedShapeAmplitudeModel(BaseAmplitudeModel):
    """Cached-amplitude model that also caches fixed-shape chains.

    Chains whose decays all report ``core.is_fixed_shape()`` are treated as
    "cached shape" chains and use ``opt_int.build_params_vector``. The other
    chains use ``build_amp.build_params_vector`` with the data.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Optional explicit override of the cached-shape chain indices.
        self.cached_shape_idx = self.extra_kwargs.get("cached_shape_idx", None)

    def get_cached_shape_idx(self):
        """Return the cached-shape chain indices, computing them on first use.

        A chain qualifies when none of its decays has a non-fixed shape. Only
        indices in ``decay_group.chains_idx`` at the time of the first call
        are kept. The result is memoized on ``self.cached_shape_idx``.
        """
        if self.cached_shape_idx is not None:
            return self.cached_shape_idx
        non_fixed = set()
        for idx, decay_chain in enumerate(self.decay_group):
            for decay in decay_chain:
                if not decay.core.is_fixed_shape():
                    non_fixed.add(idx)
        ret = [i for i in self.decay_group.chains_idx if i not in non_fixed]
        self.cached_shape_idx = ret
        print("cached shape idx", ret)
        return ret

    def pdf(self, data):
        """Sum the cached amplitudes of the non-cached-shape and cached-shape chains."""
        from tf_pwa.experimental.build_amp import build_params_vector
        from tf_pwa.experimental.opt_int import build_params_vector as bv2

        cached_data = data["cached_amp"]
        cached_shape_idx = self.get_cached_shape_idx()
        old_chains_idx = self.decay_group.chains_idx

        # Part 1: chains without a cached shape. Temporarily restrict the
        # decay group to them, and always restore the original selection.
        used_chains_idx = [
            i for i in old_chains_idx if i not in cached_shape_idx
        ]
        self.decay_group.set_used_chains(used_chains_idx)
        try:
            pv = build_params_vector(self.decay_group, data)
        finally:
            self.decay_group.set_used_chains(old_chains_idx)
        partial_cached_data = [cached_data[i] for i in used_chains_idx]

        ret = []
        for params, amps in zip(pv, partial_cached_data):
            a = tf.reshape(
                params,
                [-1, params.shape[1]] + [1] * (len(amps[0].shape) - 1),
            )
            ret.append(tf.reduce_sum(a * tf.stack(amps, axis=1), axis=1))

        # Part 2: cached-shape chains that are currently enabled.
        cached_shape_idx2 = [i for i in cached_shape_idx if i in old_chains_idx]
        partial_cached_data2 = [cached_data[i] for i in cached_shape_idx2]
        pv2 = bv2(self.decay_group, concat=False)
        pv2 = [pv2[i] for i in cached_shape_idx2]
        for params, amps in zip(pv2, partial_cached_data2):
            a = tf.reshape(
                params,
                [-1, params.shape[0]] + [1] * (len(amps[0].shape) - 1),
            )
            ret.append(tf.reduce_sum(a * amps, axis=1))

        amp = tf.reduce_sum(ret, axis=0)
        return self.decay_group.sum_with_polarization(amp)
@register_amp_model("base_factor")
class FactorAmplitudeModel(BaseAmplitudeModel):
    """Amplitude model built from mass-dependent factors and angular amplitudes."""

    def _m_dep_and_angle_amp(self, data):
        """Return ``(m_dep, angle_amp)``; prefers ``data["cached_angle"]`` if present."""
        m_dep = self.decay_group.get_m_dep(data)
        if "cached_angle" in data:
            angle_amp = data["cached_angle"]
        else:
            angle_amp = self.decay_group.get_factor_angle_amp(data)
        return m_dep, angle_amp

    def get_amp_list(self, data):
        """Return one amplitude per chain with all mass factors contracted.

        For each factor, the angular amplitude is reshaped to
        ``(-1, k, rest)`` (``k`` = the factor's last dimension), multiplied by
        the factor, and summed over the ``k`` axis. Rank-1 factors get a
        trailing axis of size 1 first.
        """
        m_dep, angle_amp = self._m_dep_and_angle_amp(data)
        ret = []
        for factors, angle in zip(m_dep, angle_amp):
            tmp = angle
            for i in factors:
                total_size = np.prod(tmp.shape[1:])
                if len(i.shape) == 1:
                    i = tf.expand_dims(i, axis=-1)
                tmp = tf.reshape(
                    tmp, (-1, i.shape[-1], total_size // i.shape[-1])
                )
                tmp = tmp * tf.expand_dims(i, axis=-1)
                tmp = tf.reduce_sum(tmp, axis=-2)
            ret.append(tmp)
        return ret

    def get_amp_list_part(self, data):
        """Like ``get_amp_list`` but keep every factor axis uncontracted.

        Each factor is broadcast-multiplied in turn, and the already-processed
        factor axes are folded into a growing "head" dimension at axis 1.
        """
        m_dep, angle_amp = self._m_dep_and_angle_amp(data)
        ret = []
        for factors, angle in zip(m_dep, angle_amp):
            tmp = angle
            head_size = 1
            for i in factors:
                total_size = np.prod(tmp.shape[1:])
                if len(i.shape) == 1:
                    i = tf.expand_dims(i, axis=-1)
                tmp = tf.reshape(
                    tmp,
                    (
                        -1,
                        head_size,
                        i.shape[-1],
                        total_size // i.shape[-1] // head_size,
                    ),
                )
                head_size *= i.shape[-1]
                tmp = tmp * tf.expand_dims(tf.expand_dims(i, axis=-1), axis=1)
            ret.append(tmp)
        return ret

    def pdf(self, data):
        """Sum the per-chain amplitudes and apply ``sum_with_polarization``."""
        amp = tf.reduce_sum(self.get_amp_list(data), axis=0)
        return self.decay_group.sum_with_polarization(amp)
@register_amp_model("p4_directly")
class P4DirectlyAmplitudeModel(BaseAmplitudeModel):
    """Amplitude model that computes angles from ``data["p4"]`` on every call."""

    # Options copied from ``extra_kwargs["all_config"]`` into
    # ``cal_angle_from_momentum`` when present.
    _CAL_ANGLE_KEYS = (
        "center_mass",
        "r_boost",
        "random_z",
        "align_ref",
        "only_left_angle",
    )

    def cal_angle(self, p4):
        """Run ``cal_angle_from_momentum`` with options taken from ``all_config``.

        Every option is optional. A missing or ``None`` ``all_config`` is
        treated as empty rather than raising.
        """
        from tf_pwa.cal_angle import cal_angle_from_momentum

        all_config = self.extra_kwargs.get("all_config") or {}
        kwargs = {k: all_config[k] for k in self._CAL_ANGLE_KEYS if k in all_config}
        return cal_angle_from_momentum(p4, self.decay_group, **kwargs)

    def pdf(self, data):
        """Evaluate ``sum_amp`` on computed angles merged with ``data``."""
        new_data = self.cal_angle(data["p4"])
        # Entries already in ``data`` take precedence over computed ones.
        return self.decay_group.sum_amp({**new_data, **data})