import contextlib
import warnings
import numpy as np
import tensorflow as tf
from tf_pwa.amp.core import Variable, variable_scope
from tf_pwa.config import create_config, get_config, regist_config, temp_config
from tf_pwa.data import LazyCall, data_shape, split_generator
# Configuration key under which amplitude model classes are registered;
# the registry itself starts out as an empty dict.
AMP_MODEL = "amplitude_model"
regist_config(AMP_MODEL, dict())
def register_amp_model(name=None, f=None):
    """Register an amplitude model class in the ``AMP_MODEL`` configuration.

    Can be used as ``@register_amp_model("name")``, as a bare
    ``@register_amp_model`` decorator, or called directly as
    ``register_amp_model("name", cls)``.

    :param name: model name used in the configuration; defaults to the
        class ``__name__``
    :param f: model class; when omitted, a decorator is returned
    :return: ``f`` itself, or a decorator that registers and returns its
        argument
    """
    if f is None and callable(name):
        # Bare ``@register_amp_model`` usage: the class arrives as ``name``.
        name, f = None, name

    def regist(g):
        my_name = g.__name__ if name is None else name
        # Record the name the class is actually registered under.
        g.model_name = my_name
        config = get_config(AMP_MODEL)
        if my_name in config:
            warnings.warn("Override mode {}".format(my_name))
        config[my_name] = g
        return g

    if f is None:
        return regist
    return regist(f)
def create_amplitude(decay_group, **kwargs):
    """Build an amplitude model for ``decay_group`` from configuration.

    ``kwargs["model"]`` (default ``"default"``) selects the model:

    * a string: the name of a class registered with ``register_amp_model``;
    * a dict with one key ``{name: options}``: ``options`` (a dict, or
      ``None`` for no options) are merged into the constructor keyword
      arguments, and an ``options["model"]`` entry, if present, overrides
      the registered name that is looked up;
    * a dict with several keys: each ``{name: options}`` pair is built
      separately by a recursive call and the results are combined with
      ``ProdPDF``.

    :param decay_group: decay group passed to the model constructor
    :param kwargs: keyword arguments forwarded to the model constructor
    :return: the constructed amplitude model
    """
    mode = kwargs.get("model", "default")
    if isinstance(mode, dict):
        if len(mode) == 1:
            key, options = next(iter(mode.items()))
            # A bare ``name:`` entry in a YAML config yields ``None``.
            if options is None:
                options = {}
            kwargs.update(options)
            mode = options["model"] if "model" in options else key
        else:
            ret = {}
            for k, v in mode.items():
                kwargs["model"] = {k: v}
                ret[k] = create_amplitude(decay_group, **kwargs)
            # The product itself is not a registered model; drop the key.
            del kwargs["model"]
            return ProdPDF(decay_group, pdfs=ret, **kwargs)
    return get_config(AMP_MODEL)[mode](decay_group, **kwargs)
class AbsPDF:
    """Base class for a density whose parameters live in a variable manager.

    Subclasses provide ``init_params(name)`` (called from ``__init__``
    inside ``variable_scope``) and ``pdf(data)`` (called from ``__call__``).
    """

    def __init__(
        self,
        *args,
        name="",
        vm=None,
        polar=None,
        use_tf_function=False,
        no_id_cached=False,
        jit_compile=False,
        **kwargs,
    ):
        """
        :param name: name passed to ``init_params``
        :param vm: variable manager passed to ``variable_scope``
        :param polar: if not None, assigned to ``vm.polar`` before the
            parameters are created
        :param use_tf_function: wrap ``pdf`` with ``WrapFun`` for repeated
            calls on the same data object
        :param no_id_cached: disable the id-based reuse of ``cached_fun``
        :param jit_compile: forwarded to ``WrapFun``
        :param kwargs: kept in ``self.extra_kwargs`` for subclasses
        """
        self.name = name
        with variable_scope(vm) as vm:
            if polar is not None:
                vm.polar = polar
            self.init_params(name)
        self.vm = vm
        self.no_id_cached = no_id_cached
        # ids of data objects already passed to __call__
        self.f_data = []
        if use_tf_function:
            from tf_pwa.experimental.wrap_function import WrapFun

            self.cached_fun = WrapFun(self.pdf, jit_compile=jit_compile)
        else:
            self.cached_fun = self.pdf
        self.extra_kwargs = kwargs

    def get_params(self, trainable_only=False):
        """Return the parameter dict from the variable manager."""
        return self.vm.get_all_dic(trainable_only)

    def set_params(self, var):
        """Set parameters from a dict via the variable manager."""
        self.vm.set_all(var)

    @contextlib.contextmanager
    def temp_params(self, var):
        """Temporarily set parameters ``var``; restore the old ones on exit.

        The previous values are restored even if the ``with`` body raises.
        """
        params = self.get_params()
        self.set_params(var)
        try:
            yield var
        finally:
            self.set_params(params)

    @contextlib.contextmanager
    def mask_params(self, var):
        """Delegate to the variable manager's ``mask_params`` context."""
        with self.vm.mask_params(var):
            yield

    @property
    def variables(self):
        return self.vm.variables

    @property
    def trainable_variables(self):
        return self.vm.trainable_variables

    def cached_available(self):
        """Whether ``cached_fun`` may be used; subclasses may restrict it."""
        return True

    def __call__(self, data, cached=False):
        """Evaluate the density on ``data``.

        A ``LazyCall`` is evaluated first. The first call with a given data
        object (identified by ``id``) uses ``pdf``; later calls with the same
        object use ``cached_fun`` unless ``no_id_cached`` is set or
        ``cached_available()`` is False.

        :param cached: accepted for backward compatibility; ignored
        """
        if isinstance(data, LazyCall):
            data = data.eval()
        data_id = id(data)
        if data_id in self.f_data:
            if not self.no_id_cached and self.cached_available():
                return self.cached_fun(data)
        else:
            # record each id once so the list does not grow with duplicates
            self.f_data.append(data_id)
        return self.pdf(data)
class BaseAmplitudeModel(AbsPDF):
    """Amplitude model whose density is computed by a decay group.

    ``pdf(data)`` returns ``decay_group.sum_amp(data)``; the resonance and
    chain selection methods forward to the decay group.
    """

    def __init__(self, decay_group, **kwargs):
        # Must be set before AbsPDF.__init__, which calls self.init_params(),
        # and init_params() uses self.decay_group.
        self.decay_group = decay_group
        super().__init__(**kwargs)
        res = decay_group.resonances
        self.used_res = res
        self.res = res

    def init_params(self, name=""):
        """Create the decay group's parameters using ``name``."""
        self.decay_group.init_params(name)

    def __del__(self):
        # cached_fun is self.pdf or a wrapper around it, i.e. a reference
        # back to self; drop it explicitly.
        if hasattr(self, "cached_fun"):
            del self.cached_fun

    def cache_data(self, data, split=None, batch=None):
        """Optionally split ``data`` into chunks.

        :param data: input data
        :param split: number of chunks; only used when ``batch`` is None
        :param batch: chunk size passed to ``split_generator``
        :return: ``data`` unchanged when both ``split`` and ``batch`` are
            None, otherwise a list of chunks
        """
        if split is None and batch is None:
            return data
        if batch is None:
            # ceil(n / split): chunk size giving at most ``split`` chunks
            n = data_shape(data)
            batch = (n + split - 1) // split
        return list(split_generator(data, batch))

    def set_used_res(self, res):
        self.decay_group.set_used_res(res)

    @contextlib.contextmanager
    def temp_used_res(self, res, only=False):
        with self.decay_group.temp_used_res(res, only=only):
            yield

    def set_used_chains(self, used_chains):
        self.decay_group.set_used_chains(used_chains)

    def partial_weight(self, data, combine=None):
        """Evaluate ``pdf`` with only selected chains switched on.

        :param data: input data (a ``LazyCall`` is evaluated first)
        :param combine: list of chain-index groups; defaults to one group
            per chain
        :return: list with one ``pdf`` value per group in ``combine``
        """
        if isinstance(data, LazyCall):
            data = data.eval()
        if combine is None:
            combine = [[i] for i in range(len(self.decay_group.chains))]
        o_used_chains = self.decay_group.chains_idx
        weights = []
        try:
            for chains in combine:
                self.decay_group.set_used_chains(chains)
                weights.append(self.pdf(data))
        finally:
            # restore the chain selection even if pdf() raises
            self.decay_group.set_used_chains(o_used_chains)
        return weights

    def partial_weight_interference(self, data):
        return self.decay_group.partial_weight_interference(data)

    def chains_particle(self):
        return self.decay_group.chains_particle()

    def cached_available(self):
        """``cached_fun`` may only be reused when the decay group is full."""
        return not self.decay_group.not_full

    def pdf(self, data):
        return self.decay_group.sum_amp(data)

    def factor_iteration(self, deep=2):
        yield from self.decay_group.factor_iteration(deep)

    @contextlib.contextmanager
    def temp_total_gls_one(self):
        """Temporarily set ``mask_factor = True`` on every chain and decay.

        Previous values (``False`` when unset) are restored on exit, even if
        the ``with`` body raises.
        """
        mask_part = []
        for chain in self.decay_group:
            mask_part.append(chain)
            mask_part.extend(chain)
        old_mask = [getattr(i, "mask_factor", False) for i in mask_part]
        for i in mask_part:
            i.mask_factor = True
        try:
            yield
        finally:
            for i, old in zip(mask_part, old_mask):
                i.mask_factor = old
@register_amp_model("default")
class AmplitudeModel(BaseAmplitudeModel):
    """Default model: partial weights are computed by the decay group itself."""

    def partial_weight(self, data, combine=None):
        """Forward to ``decay_group.partial_weight`` after resolving LazyCall."""
        resolved = data.eval() if isinstance(data, LazyCall) else data
        return self.decay_group.partial_weight(resolved, combine)
@register_amp_model("cached_amp")
class CachedAmpAmplitudeModel(BaseAmplitudeModel):
    """Model that combines precomputed per-chain amplitude pieces.

    ``data["cached_amp"][k]`` holds the cached pieces of chain ``k``; they are
    weighted by ``build_params_vector`` and summed over the used chains.
    """

    def pdf(self, data):
        from tf_pwa.experimental.build_amp import build_params_vector

        cached_data = data["cached_amp"]
        pv = build_params_vector(self.decay_group, data)
        partial_cached_data = [
            cached_data[i] for i in self.decay_group.chains_idx
        ]
        ret = []
        for params, pieces in zip(pv, partial_cached_data):
            # reshape to [-1, n_pieces, 1, ...] so it broadcasts against the
            # pieces stacked on axis 1
            a = tf.reshape(
                params,
                [-1, params.shape[1]] + [1] * (len(pieces[0].shape) - 1),
            )
            ret.append(tf.reduce_sum(a * tf.stack(pieces, axis=1), axis=1))
        amp = tf.reduce_sum(ret, axis=0)
        return self.decay_group.sum_with_polarization(amp)
@register_amp_model("cached_shape")
class CachedShapeAmplitudeModel(BaseAmplitudeModel):
    """Cached-amplitude model that also caches fixed lineshapes.

    Chains listed by ``get_cached_shape_idx`` use parameter vectors from
    ``tf_pwa.experimental.opt_int``; the other used chains use
    ``tf_pwa.experimental.build_amp.build_params_vector`` on the data.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.cached_shape_idx = self.extra_kwargs.get("cached_shape_idx", None)

    def get_cached_shape_idx(self):
        """Indices of used chains whose decays all have a fixed shape.

        Computed on first use (and printed) unless given as the
        ``cached_shape_idx`` constructor option; the result is stored.
        """
        if self.cached_shape_idx is not None:
            return self.cached_shape_idx
        non_fixed = {
            idx
            for idx, decay_chain in enumerate(self.decay_group)
            for decay in decay_chain
            if not decay.core.is_fixed_shape()
        }
        ret = [i for i in self.decay_group.chains_idx if i not in non_fixed]
        self.cached_shape_idx = ret
        print("cached shape idx", ret)
        return ret

    def pdf(self, data):
        from tf_pwa.experimental.build_amp import build_params_vector
        from tf_pwa.experimental.opt_int import build_params_vector as bv2

        cached_data = data["cached_amp"]
        cached_shape_idx = self.get_cached_shape_idx()
        old_chains_idx = self.decay_group.chains_idx
        ret = []

        # Chains outside cached_shape_idx: parameter vector built from data.
        used_chains_idx = [
            i for i in old_chains_idx if i not in cached_shape_idx
        ]
        self.decay_group.set_used_chains(used_chains_idx)
        try:
            pv = build_params_vector(self.decay_group, data)
        finally:
            # restore the chain selection even if the build fails
            self.decay_group.set_used_chains(old_chains_idx)
        partial_cached_data = [cached_data[i] for i in used_chains_idx]
        for params, pieces in zip(pv, partial_cached_data):
            a = tf.reshape(
                params,
                [-1, params.shape[1]] + [1] * (len(pieces[0].shape) - 1),
            )
            ret.append(tf.reduce_sum(a * tf.stack(pieces, axis=1), axis=1))

        # Chains in cached_shape_idx that are currently used.
        cached_shape_idx2 = [
            i for i in cached_shape_idx if i in old_chains_idx
        ]
        partial_cached_data2 = [cached_data[i] for i in cached_shape_idx2]
        all_pv2 = bv2(self.decay_group, concat=False)
        pv2 = [all_pv2[i] for i in cached_shape_idx2]
        for params, pieces in zip(pv2, partial_cached_data2):
            a = tf.reshape(
                params,
                [-1, params.shape[0]] + [1] * (len(pieces[0].shape) - 1),
            )
            ret.append(tf.reduce_sum(a * pieces, axis=1))

        amp = tf.reduce_sum(ret, axis=0)
        return self.decay_group.sum_with_polarization(amp)
@register_amp_model("base_factor")
class FactorAmplitudeModel(BaseAmplitudeModel):
    """Model that contracts angular amplitudes with mass-dependent factors."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def _factor_inputs(self, data):
        """Return (m_dep, angle_amp), using ``data["cached_angle"]`` if present."""
        m_dep = self.decay_group.get_m_dep(data)
        if "cached_angle" in data:
            return m_dep, data["cached_angle"]
        return m_dep, self.decay_group.get_factor_angle_amp(data)

    def get_amp_list(self, data):
        """Per-chain amplitudes with every mass factor summed out."""
        m_dep, angle_amp = self._factor_inputs(data)
        amps = []
        for factors, angle in zip(m_dep, angle_amp):
            amp = angle
            for factor in factors:
                size = np.prod(amp.shape[1:])
                if len(factor.shape) == 1:
                    factor = tf.expand_dims(factor, axis=-1)
                n_f = factor.shape[-1]
                amp = tf.reshape(amp, (-1, n_f, size // n_f))
                amp = tf.reduce_sum(
                    amp * tf.expand_dims(factor, axis=-1), axis=-2
                )
            amps.append(amp)
        return amps

    def get_amp_list_part(self, data):
        """Per-chain amplitudes with the factor axes kept (not summed)."""
        m_dep, angle_amp = self._factor_inputs(data)
        amps = []
        for factors, angle in zip(m_dep, angle_amp):
            amp = angle
            head_size = 1
            for factor in factors:
                size = np.prod(amp.shape[1:])
                if len(factor.shape) == 1:
                    factor = tf.expand_dims(factor, axis=-1)
                n_f = factor.shape[-1]
                amp = tf.reshape(
                    amp, (-1, head_size, n_f, size // n_f // head_size)
                )
                head_size *= n_f
                amp = amp * tf.expand_dims(
                    tf.expand_dims(factor, axis=-1), axis=1
                )
            amps.append(amp)
        return amps

    def pdf(self, data):
        amp = tf.reduce_sum(self.get_amp_list(data), axis=0)
        return self.decay_group.sum_with_polarization(amp)
@register_amp_model("p4_directly")
class P4DirectlyAmplitudeModel(BaseAmplitudeModel):
    """Model evaluated from four-momenta in ``data["p4"]``.

    Angles are computed with ``cal_angle_from_momentum`` and passed, merged
    with the input data, to a reference model built by ``create_amplitude``
    with ``model=base_model``.
    """

    def __init__(self, *args, base_model="default", **kwargs):
        # ref_amp must exist before BaseAmplitudeModel.__init__, which calls
        # self.init_params() and that uses self.ref_amp.
        new_kwargs = {**kwargs, "model": base_model}
        self.ref_amp = create_amplitude(*args, **new_kwargs)
        super().__init__(*args, **kwargs)

    def init_params(self, *args, **kwargs):
        super().init_params(*args, **kwargs)
        self.ref_amp.init_params(*args, **kwargs)

    def cal_angle(self, p4):
        """Compute angle data from momenta ``p4``.

        Selected options are read from the ``all_config`` constructor option;
        when it is missing or None the defaults of
        ``cal_angle_from_momentum`` are used.
        """
        from tf_pwa.cal_angle import cal_angle_from_momentum

        all_config = self.extra_kwargs.get("all_config") or {}
        option_keys = (
            "center_mass",
            "r_boost",
            "random_z",
            "align_ref",
            "only_left_angle",
        )
        kwargs = {k: all_config[k] for k in option_keys if k in all_config}
        return cal_angle_from_momentum(p4, self.decay_group, **kwargs)

    def pdf(self, data):
        new_data = self.cal_angle(data["p4"])
        # entries already present in ``data`` take precedence
        return self.ref_amp({**new_data, **data})
@register_amp_model("simple_mlp")
class MLPModel(BaseAmplitudeModel):
    """Fully-connected network on the masses and angles of the first chain.

    :param n_hidden: hidden-layer width (int, repeated ``n_layers - 1``
        times) or a sequence of widths
    :param n_layers: total number of layers when ``n_hidden`` is an int
    :param activation: name of a function in ``tf.nn``
    """

    def __init__(
        self, *args, n_hidden=10, n_layers=2, activation="softplus", **kwargs
    ):
        # Set before BaseAmplitudeModel.__init__, whose init_params() call
        # creates the weights from these values.
        if isinstance(n_hidden, int):
            self.n_hidden = [n_hidden] * (n_layers - 1)
        else:
            self.n_hidden = n_hidden
        self.n_layers = len(self.n_hidden) + 1
        self.activation = getattr(tf.nn, activation)
        super().__init__(*args, **kwargs)

    def init_params(self, name=""):
        """Create weights ``W{i}`` and biases ``b{i}`` on the top particle."""
        from tf_pwa.data_trans.helicity_angle import HelicityAngle

        self.decay_chain = self.decay_group[0]
        self.ha = HelicityAngle(self.decay_chain)
        self.top = self.decay_group.top
        n_decay = len(self.decay_chain)
        # n_decay * 3 inputs (masses, cos(theta), phi flattened in pdf),
        # then the hidden widths, then a single output.
        sizes = [n_decay * 3] + list(self.n_hidden) + [1]
        self.Ws = []
        self.Bs = []
        for i, (n_input, n_output) in enumerate(zip(sizes[:-1], sizes[1:])):
            self.Ws.append(
                self.top.add_var(f"W{i}", shape=(n_input, n_output))
            )
            self.Bs.append(self.top.add_var(f"b{i}", shape=(n_output,)))

    def pdf(self, data):
        mass, costheta, phi = self.ha.find_variable(data)
        m = [mass[i.core] for i in self.decay_chain]
        x = tf.stack(tf.nest.flatten([m, costheta, phi]), axis=-1)
        for w, b in zip(self.Ws, self.Bs):
            # the activation is applied after every layer, including the last
            x = self.activation(tf.matmul(x, w()) + b())
        return x[..., 0]
class ProdPDF(BaseAmplitudeModel):
    """Product of several component models built on the same decay group."""

    def __init__(self, *args, pdfs, **kwargs):
        super().__init__(*args, **kwargs)
        # mapping of component name -> component model
        self.pdfs = pdfs

    def partial_weight(self, data, combine=None):
        """Per-group partial weights, multiplied across the components."""
        per_pdf = [
            component.partial_weight(data, combine=combine)
            for component in self.pdfs.values()
        ]
        n_parts = len(per_pdf[0])
        return [
            tf.reduce_prod([weights[idx] for weights in per_pdf], axis=0)
            for idx in range(n_parts)
        ]

    def pdf(self, data):
        """Product of the component densities."""
        values = [component.pdf(data) for component in self.pdfs.values()]
        return tf.reduce_prod(values, axis=0)
@register_amp_model("constant")
class ConstantPDF(BaseAmplitudeModel):
    """Model whose density is identically one."""

    def pdf(self, data):
        """Return ones shaped like ``data.get_weight()``."""
        weight = data.get_weight()
        return tf.ones_like(weight)