Source code for tf_pwa.amp.preprocess

import warnings

import tensorflow as tf

from tf_pwa.cal_angle import (
    CalAngleData,
    cal_angle_from_momentum,
    parity_trans,
)
from tf_pwa.config import create_config, get_config, regist_config, temp_config
from tf_pwa.data import HeavyCall, data_index, data_strip

# Config key of the registry that maps preprocessor model names to the
# classes/factories registered through ``register_preprocessor``.
PREPROCESSOR_MODEL = "preprocessor_model"
# Create the (initially empty) registry at import time, before the
# ``@register_preprocessor`` decorators in this module insert into it.
regist_config(PREPROCESSOR_MODEL, {})


def register_preprocessor(name=None, f=None):
    """Register a preprocessor class under a model name.

    Usable as a decorator factory (``@register_preprocessor("name")`` or
    ``@register_preprocessor()``) or called directly as
    ``register_preprocessor("name", cls)``.

    :param name: model name used in the configuration; defaults to the
        registered object's ``__name__``.
    :param f: preprocessor class (or factory) to register. When omitted, a
        decorator is returned instead.
    :return: ``f`` itself when given, otherwise a decorator that registers
        its argument and returns it unchanged.
    """

    def regist(g):
        my_name = g.__name__ if name is None else name
        config = get_config(PREPROCESSOR_MODEL)
        if my_name in config:
            # Re-registering replaces the previous entry; make that visible.
            warnings.warn("Override mode {}".format(my_name))
        config[my_name] = g
        return g

    if f is None:
        return regist
    return regist(f)
def create_preprocessor(decay_group, **kwargs):
    """Build a preprocessor from a (possibly nested) ``model`` specification.

    ``model`` (taken from ``kwargs``, default ``"default"``) may be:

    * a ``str``: name of a model registered with ``register_preprocessor``;
      the registered object is called as
      ``cls(decay_group, model=name, **kwargs)``;
    * a ``dict`` with exactly one key, ``{name: options}``: ``options`` are
      merged over ``kwargs`` (options win) and ``name`` is resolved as above;
      ``options`` may be ``None`` for "no extra options";
    * a ``list``/``tuple`` of the above: each entry is built with the same
      ``kwargs`` and the results are chained in a ``SeqPreProcessor``.

    :param decay_group: decay structure passed to every preprocessor.
    :raises ValueError: if a dict specification does not have exactly one key.
    :raises KeyError: if a model name is not registered.
    :raises TypeError: for any other type of ``model``.
    """
    model = kwargs.pop("model", "default")
    if isinstance(model, (tuple, list)):
        return SeqPreProcessor(
            [
                create_preprocessor(decay_group, model=model_i, **kwargs)
                for model_i in model
            ]
        )
    if isinstance(model, dict):
        # Explicit check instead of ``assert``: asserts vanish under -O,
        # which would silently pick the first key.
        if len(model) != 1:
            raise ValueError(
                "model dict must have exactly one key, got {}".format(
                    list(model.keys())
                )
            )
        name = next(iter(model))
        options = model[name]
        new_kwargs = kwargs.copy()
        if options is not None:
            new_kwargs.update(options)
        return create_preprocessor(decay_group, model=name, **new_kwargs)
    if isinstance(model, str):
        return get_config(PREPROCESSOR_MODEL)[model](
            decay_group, model=model, **kwargs
        )
    raise TypeError("not support model type : {}".format(type(model)))
@register_preprocessor("default")
class BasePreProcessor(HeavyCall):
    """Default preprocessor: compute angle data from four-momenta.

    :param decay_struct: decay structure handed to
        ``cal_angle_from_momentum``.
    :param root_config: owning configuration object (stored for subclasses).
    :param model: model name this instance was created for.
    :param data_type: optional collection of data types this preprocessor
        applies to; for other types the input is returned unchanged.
        ``None`` applies it to every type.
    :param kwargs: extra options, kept in ``self.kwargs``. ``cp_trans`` and
        the ``cal_angle_from_momentum`` options listed in ``call`` are read.
    """

    def __init__(
        self,
        decay_struct,
        root_config=None,
        model="default",
        data_type=None,
        **kwargs,
    ):
        self.decay_struct = decay_struct
        self.kwargs = kwargs
        self.model = model
        self.root_config = root_config
        self.data_type = data_type

    def __call__(self, x, **kwargs):
        """Run ``call`` unless ``kwargs["data_type"]`` is filtered out."""
        data_type = kwargs.get("data_type", "data")
        if self.data_type is not None and data_type not in self.data_type:
            return x
        return self.call(x, **kwargs)

    def call(self, x, **kwargs):
        """Compute angles from ``x["p4"]`` with ``cal_angle_from_momentum``.

        Input that already contains ``"particle"`` is returned unchanged.
        With the ``cp_trans`` option, ``parity_trans`` is applied to every
        momentum first (using ``x["extra"]["charge_conjugation"]`` if given).
        Entries of ``x["extra"]`` are copied into the result.
        """
        if "particle" in x:
            return x
        p4 = x["p4"]
        if self.kwargs.get("cp_trans", False):
            charges = x.get("extra", {}).get("charge_conjugation", None)
            p4 = {k: parity_trans(v, charges) for k, v in p4.items()}
        # Separate dict so the per-call ``kwargs`` parameter is not rebound.
        angle_options = {
            k: self.kwargs[k]
            for k in (
                "center_mass",
                "r_boost",
                "random_z",
                "align_ref",
                "only_left_angle",
            )
            if k in self.kwargs
        }
        ret = cal_angle_from_momentum(p4, self.decay_struct, **angle_options)
        # TODO: rethink of extra, duplicate with lazy call
        for k, v in x.get("extra", {}).items():
            ret[k] = v
        return ret
def list_to_tuple(data):
    """Recursively convert nested lists into nested tuples.

    Only ``list`` instances are converted (and descended into); any other
    object, including tuples, is returned as is.
    """
    if not isinstance(data, list):
        return data
    return tuple(map(list_to_tuple, data))
class SeqPreProcessor(BasePreProcessor):
    """Chain of preprocessors applied one after another.

    :param preprocessors: sequence of callables ``f(x, **kwargs) -> x``; the
        output of each one is the input of the next.
    """

    def __init__(self, preprocessors):
        self.preprocessors = preprocessors

    def __call__(self, x, **kwargs):
        result = x
        for step in self.preprocessors:
            result = step(result, **kwargs)
        return result
@register_preprocessor("cached_amp")
class CachedAmpPreProcessor(BasePreProcessor):
    """Default preprocessing plus cached angular amplitude matrices.

    After the default angle calculation, the matrix built by
    ``build_angle_amp_matrix`` for the root amplitude's decay group is stored
    (as nested tuples) under ``"cached_amp"``. The options ``no_angle`` and
    ``no_p4`` then strip ``"ang"``/``"aligned_angle"`` and ``"p"`` entries.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # The amplitude (hence root_config) is needed for the decay group.
        self.amp = self.root_config.get_amplitude()
        self.decay_group = self.amp.decay_group
        self.no_angle = self.kwargs.get("no_angle", False)
        self.no_p4 = self.kwargs.get("no_p4", False)

    def build_cached(self, x):
        """Store the angular amplitude matrix of ``x`` in ``x["cached_amp"]``."""
        from tf_pwa.experimental.build_amp import build_angle_amp_matrix

        _, angle_amp = build_angle_amp_matrix(self.decay_group, x)
        x["cached_amp"] = list_to_tuple(angle_amp)
        return x

    def strip_data(self, x):
        """Remove angle and/or momentum variables according to the options."""
        removed = []
        if self.no_angle:
            removed.extend(["ang", "aligned_angle"])
        if self.no_p4:
            removed.append("p")
        return data_strip(x, removed) if removed else x

    def call(self, x, **kwargs):
        data = super().call(x, **kwargs)
        return self.strip_data(self.build_cached(data))
@register_preprocessor("cached_shape")
class CachedShapePreProcessor(CachedAmpPreProcessor):
    """``cached_amp`` variant that also multiplies in cached shape factors.

    For each chain index from ``amp.get_cached_shape_idx()``, the cached
    angular amplitudes of that chain are multiplied by the matching entry of
    ``build_params_vector`` (evaluated inside ``amp.temp_total_gls_one()``).
    """

    def build_cached(self, x):
        from tf_pwa.experimental.build_amp import build_params_vector

        cached_shape_idx = self.amp.get_cached_shape_idx()
        x = super().build_cached(x)
        old_cached_amp = list(x["cached_amp"])
        dec = self.decay_group
        used_chains = dec.chains_idx
        dec.set_used_chains(cached_shape_idx)
        try:
            with self.amp.temp_total_gls_one():
                pv = build_params_vector(dec, x)
            for k, i in zip(cached_shape_idx, pv):
                tmp = old_cached_amp[k]
                # m_dep * angle_amp: reshape the shape factor to
                # (-1, i.shape[1], 1, ...) so it broadcasts against the
                # amplitudes stacked along axis 1.
                a = tf.reshape(
                    i, [-1, i.shape[1]] + [1] * (len(tmp[0].shape) - 1)
                )
                old_cached_amp[k] = a * tf.stack(tmp, axis=1)
        finally:
            # Restore the shared decay group's active chains even on error.
            dec.set_used_chains(used_chains)
        x["cached_amp"] = list_to_tuple(old_cached_amp)
        return x
@register_preprocessor("cached_angle")
class CachedAnglePreProcessor(BasePreProcessor):
    """Default preprocessing plus cached angle amplitudes.

    The result of ``decay_group.get_factor_angle_amp`` on the processed data
    is stored (as nested tuples) under ``"cached_angle"``.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.amp = self.root_config.get_amplitude()
        self.decay_group = self.amp.decay_group
        # NOTE(review): read here but not used by ``call``; confirm intent.
        self.no_angle = self.kwargs.get("no_angle", False)
        self.no_p4 = self.kwargs.get("no_p4", False)

    def call(self, x, **kwargs):
        data = super().call(x)
        data["cached_angle"] = list_to_tuple(
            self.decay_group.get_factor_angle_amp(data)
        )
        return data
@register_preprocessor("p4_directly")
class P4DirectlyPreProcessor(BasePreProcessor):
    """Pass only the raw four-momenta ``{"p4": x["p4"]}`` through.

    No angles are computed. The class has its own name so it does not
    rebind (and hide) the module-level ``CachedAmpPreProcessor``.
    """

    def call(self, x, **kwargs):
        return {"p4": x["p4"]}
@register_preprocessor("add_dalitz_var")
class AddDalitzVarPreProcessor(BasePreProcessor):
    """Add Dalitz-plot variables of three final particles to the data.

    Stores in ``x["dalitz_var"]``: ``s12``, ``s13``, ``s23`` (squared
    invariant masses of the pairs), ``m0`` (invariant mass of all three) and
    ``m1``, ``m2``, ``m3`` (individual masses).

    :param particles: three identifiers accepted by
        ``decay_struct.get_particle``; defaults to ``decay_struct.outs``.
    :raises ValueError: if not exactly three particles are given.
    """

    def __init__(self, *args, particles=None, **kwargs):
        super().__init__(*args, **kwargs)
        if particles is None:
            particles = self.decay_struct.outs
        # Explicit check instead of ``assert`` (stripped under -O).
        if len(particles) != 3:
            raise ValueError("Dalitz plot requires 3 final particles")
        self.particles = [self.decay_struct.get_particle(i) for i in particles]
        self.decay = self.find_decay(self.particles)
        top_map = self.decay.topology_map(self.decay.standard_topology())
        # Keys used to look the chosen particles up in x["particle"].
        self.index_particles = [top_map[i] for i in self.particles]

    def find_decay(self, particles):
        """Return the decay chain used for the topology mapping.

        Always the first chain of ``decay_struct``; ``particles`` is
        currently unused.
        """
        return self.decay_struct[0]

    def call(self, x, **kwargs):
        from tf_pwa.angle import LorentzVector as lv

        p1, p2, p3 = [x["particle"][i]["p"] for i in self.index_particles]
        x["dalitz_var"] = {
            "s12": lv.M2(p1 + p2),
            "s13": lv.M2(p1 + p3),
            "s23": lv.M2(p3 + p2),
            "m0": lv.M(p1 + p2 + p3),
            "m1": lv.M(p1),
            "m2": lv.M(p2),
            "m3": lv.M(p3),
        }
        return x
@register_preprocessor("bin_index")
class AddBinIndexPreProcessor(BasePreProcessor):
    """Add a flattened multi-dimensional bin index as ``x["bin_index"]``.

    Each variable is binned uniformly in its ``(low, high)`` edges; values
    outside the range land in the first/last bin. Per-variable bins are
    combined row-major: ``idx = idx * n + bin``.

    :param binning_variables: list of index specs; each (tuple/list) spec is
        resolved with ``root_config.get_data_index(*spec)``.
    :param binning_schemes: number of bins for each variable.
    :param binning_edges: ``(low, high)`` pair for each variable.
    :raises ValueError: if an option is missing or the sizes do not match.
    :raises NotImplementedError: if a binning variable is not a tuple/list.
    """

    def __init__(
        self,
        *args,
        binning_variables=None,
        binning_schemes=None,
        binning_edges=None,
        **kwargs,
    ):
        super().__init__(*args, **kwargs)
        if (
            binning_variables is None
            or binning_schemes is None
            or binning_edges is None
        ):
            raise ValueError(
                "require binning_variables, binning_schemes and binning_edges"
            )
        self.binning_variables = binning_variables
        self.idx = []
        for i in self.binning_variables:
            if isinstance(i, (tuple, list)):
                self.idx.append(self.root_config.get_data_index(*i))
            else:
                raise NotImplementedError
        self.binning_edges = binning_edges
        self.binning_schemes = binning_schemes
        # Explicit checks instead of ``assert`` (stripped under -O).
        if len(self.binning_edges) != len(self.binning_schemes):
            raise ValueError("require same size of edges and scheme")
        if len(self.binning_edges) != len(self.idx):
            raise ValueError("require same size of edges and variables")

    def call(self, x, **kwargs):
        v = [data_index(x, i) for i in self.idx]
        idx = 0
        for vi, (low, high), n in zip(
            v, self.binning_edges, self.binning_schemes
        ):
            ratio = tf.clip_by_value((vi - low) / (high - low), 0.0, 1)
            # ratio == 1 (value >= high) would give bin ``n``, which spills
            # into the next bin of the outer variable; keep it in bin n - 1.
            n_idx = tf.minimum(tf.cast(ratio * n, tf.int32), n - 1)
            idx = idx * n + n_idx
        x["bin_index"] = idx
        return x
@register_preprocessor("add_ref_amp")
class AddRefAmpPreProcessor(BasePreProcessor):
    """Add the value of a reference amplitude model to the data.

    :param config: configuration handed to ``ConfigLoader``.
    :param params: parameters set on the reference configuration
        (default: empty dict).
    :param varname: key the amplitude is stored under (``"ref_amp"``).
    """

    def __init__(
        self, *args, config=None, params=None, varname="ref_amp", **kwargs
    ):
        super().__init__(*args, **kwargs)
        from tf_pwa.config_loader import ConfigLoader

        self.params = params if params is not None else {}
        loader = ConfigLoader(config)
        self.config = loader
        loader.set_params(self.params)
        self.ref_amp = loader.get_amplitude()
        self.varname = varname

    def call(self, x, **kwargs):
        x[self.varname] = self.ref_amp(x)
        return x
@register_preprocessor("add_ref_amp_complex")
class AddRefAmpCPreProcessor(AddRefAmpPreProcessor):
    """Like ``add_ref_amp``, but store ``decay_group.get_amp3(x)`` of the
    reference amplitude instead of the amplitude value itself."""

    def call(self, x, **kwargs):
        amp3 = self.ref_amp.decay_group.get_amp3(x)
        x[self.varname] = amp3
        return x
@register_preprocessor("repeat_values")
class RepeatValuesPreProcessor(BasePreProcessor):
    """Repeat the data with ``data_repeat`` and add a tag variable.

    For ``N`` input events the tag has ``N * len(values)`` entries ordered
    ``values[0], values[1], ..., values[0], values[1], ...``.

    :param varname: key of the added tag variable (default ``"tag"``).
    :param values: tag values; ``None`` (default) means ``[-1, 1]``.
    """

    def __init__(self, *args, varname="tag", values=None, **kwargs):
        self.varname = varname
        # None sentinel instead of a mutable default list shared by all
        # instances.
        self.values = [-1, 1] if values is None else values
        super().__init__(*args, **kwargs)

    def call(self, x, **kwargs):
        from tf_pwa.data import data_repeat, data_shape

        dtype = get_config("dtype")
        shape = data_shape(x)
        # NOTE(review): data_repeat is called without an explicit count;
        # confirm its repetition count/order matches len(self.values).
        x = data_repeat(x)
        vs = tf.cast(tf.stack(self.values), dtype=dtype)
        # (N, 1) * (len(values),) -> (N, len(values)), flattened row-major.
        v = tf.ones([shape, 1], dtype=dtype) * vs
        x[self.varname] = tf.reshape(v, (-1,))
        return x