Source code for tf_pwa.experimental.extra_data

import warnings

import numpy as np

from tf_pwa.amp import get_particle
from tf_pwa.cal_angle import cal_angle_from_momentum
from tf_pwa.config_loader.data import SimpleData, register_data_mode
from tf_pwa.data import data_shape


@register_data_mode("simple_npz")
class NpzData(SimpleData):
    """Data mode (``"simple_npz"``) that reads momenta from ``.npz`` archives.

    Each archive holds one array per particle, stored under the keys given
    by the ``particle_p`` config entry (falling back to the data order).
    Optional ``weight`` and ``charge_conjugation`` arrays are read if
    present and default to ones otherwise.
    """

    def get_particle_p(self):
        """Return the npz keys holding each particle's momentum array.

        Uses the ``particle_p`` config entry when set, otherwise
        :meth:`get_dat_order`.
        """
        order = self.dic.get("particle_p", None)
        if order is None:
            return self.get_dat_order()
        return order

    def get_data(self, idx) -> dict:
        """Return data set ``idx``, using the cache when it holds an entry."""
        if self.cached_data is not None:
            data = self.cached_data.get(idx, None)
            if data is not None:
                return data
        files = self.get_data_file(idx)
        weight_sign = self.get_weight_sign(idx)
        # Pass by keyword: positionally the sign would land in ``weights``.
        return self.load_data(files, weights_sign=weight_sign)

    @staticmethod
    def _read_npz_files(files, keys):
        """Read ``keys`` plus optional weight/charge arrays from npz file(s).

        :param files: a single path, or a list/tuple of paths whose arrays
            are concatenated along the first axis.
        :param keys: npz keys of the momentum arrays to read (non-empty).
        :return: dict mapping each key, ``"weight"`` and
            ``"charge_conjugation"`` to an array; the two optional entries
            default to ones when absent from a file.
        :raises ValueError: if ``files`` is an empty list/tuple.
        """
        if isinstance(files, (list, tuple)):
            file_list = list(files)
        else:
            file_list = [files]
        if not file_list:
            raise ValueError("no npz file given")
        optional = ("weight", "charge_conjugation")
        columns = {k: [] for k in list(keys) + list(optional)}
        for name in file_list:
            # ``with`` closes the archive's file handle once arrays are read.
            with np.load(name) as npz_data:
                for k in keys:
                    columns[k].append(npz_data[k])
                n_events = np.shape(columns[keys[0]][-1])[0]
                for k in optional:
                    if k in npz_data:
                        columns[k].append(npz_data[k])
                    else:
                        columns[k].append(np.ones((n_events,)))
        return {
            k: v[0] if len(v) == 1 else np.concatenate(v, axis=0)
            for k, v in columns.items()
        }

    def load_data(
        self, files, weights=None, weights_sign=1, charge=None
    ) -> dict:
        """Load npz file(s) and compute the angle variables.

        :param files: path or list of paths to ``.npz`` archives; ``None``
            returns ``None``.
        :param weights: accepted for signature compatibility with
            ``SimpleData.load_data``; ignored (weights come from the
            archive's ``weight`` entry).
        :param weights_sign: factor multiplied into the per-event weights.
        :param charge: accepted for signature compatibility; ignored.
        :return: the dict from ``cal_angle_from_momentum`` with ``weight``
            and ``charge_conjugation`` entries added.
        """
        if files is None:
            return None
        order = self.get_dat_order()
        # zip truncates to the shorter list, so only keys paired with a
        # particle in ``order`` are read from the archive.
        pairs = [(str(k), v) for k, v in zip(self.get_particle_p(), order)]
        arrays = self._read_npz_files(files, [k for k, _ in pairs])
        p = {get_particle(str(v)): arrays[k] for k, v in pairs}
        data = cal_angle_from_momentum(
            p,
            self.decay_struct,
            center_mass=self.dic.get("center_mass", True),
            r_boost=self.dic.get("r_boost", False),
            random_z=self.dic.get("random_z", False),
        )
        data["weight"] = arrays["weight"] * weights_sign
        data["charge_conjugation"] = arrays["charge_conjugation"]
        return data
@register_data_mode("multi_npz")
class MultiNpzData(NpzData):
    """Data mode (``"multi_npz"``) whose data sets are lists of groups.

    :meth:`get_data` returns a list with one loaded entry per group. A
    config value whose first element is not a list is treated as a single
    group.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Number of groups, fixed by the first data set loaded other than
        # ``phsp_noeff`` (which is always expected to have one group).
        self._Ngroup = 0

    def get_data(self, idx) -> list:
        """Return data set ``idx`` as a list of per-group results.

        Warns if the group count differs from that of earlier data sets;
        ``phsp_noeff`` is exempt from this check.
        """
        if self.cached_data is not None:
            data = self.cached_data.get(idx, None)
            if data is not None:
                return data
        files = self.get_data_file(idx)
        if files is None:
            return None
        if not isinstance(files[0], list):
            files = [files]
        weight_sign = self.get_weight_sign(idx)
        # Pass by keyword: positionally the sign would land in ``weights``.
        ret = [self.load_data(i, weights_sign=weight_sign) for i in files]
        # ``phsp_noeff`` must neither set nor be compared to the group count.
        if idx != "phsp_noeff":
            if self._Ngroup == 0:
                self._Ngroup = len(ret)
            elif self._Ngroup != len(ret):
                warnings.warn("not the same data group")
        return ret

    def get_phsp_noeff(self):
        """Return the single ``phsp_noeff`` group, or the first ``phsp`` group.

        :raises ValueError: if ``phsp_noeff`` is configured with a number of
            groups other than one.
        """
        if "phsp_noeff" in self.dic:
            phsp_noeff = self.get_data("phsp_noeff")
            # Explicit check instead of ``assert`` so it survives ``python -O``.
            if len(phsp_noeff) != 1:
                raise ValueError(
                    "'phsp_noeff' must have exactly one data group, got {}".format(
                        len(phsp_noeff)
                    )
                )
            return phsp_noeff[0]
        warnings.warn(
            "No data file as 'phsp_noeff', using the first 'phsp' file instead."
        )
        return self.get_data("phsp")[0]