Source code for tf_pwa.config_loader.mix_config

import functools

from tf_pwa.amp import AmplitudeModel
from tf_pwa.amp.core import Variable, variable_scope
from tf_pwa.data import EvalLazy
from tf_pwa.model import FCN

from .multi_config import MultiConfig


class MixAmplitude(AmplitudeModel):
    """Combine several amplitude models into a single callable PDF.

    Attributes that are not found on this object are looked up on the
    amplitude selected by ``base_idx`` (see ``__getattr__``).

    Note: the base-class initializer is not invoked here; only the
    attributes assigned in ``__init__`` exist on a fresh instance.
    """

    def __init__(
        self, amps, same_data=False, vm=None, base_idx=0, no_scale=False
    ):
        """Store the component amplitudes and create the ``scale`` variable.

        :param amps: Sequence of amplitude models. ``amps[0].vm`` is used as
            the variable manager of the mixture.
        :param same_data: If true, every amplitude is evaluated on the same
            data object; otherwise amplitude ``i`` uses ``data["datas"][i]``.
        :param vm: Accepted for interface compatibility; this constructor
            does not read it.
        :param base_idx: Index of the amplitude used for ``partial_weight``
            and for forwarded attribute access.
        :param no_scale: If true, ``pdf`` returns the weighted sum without
            the final normalisation by the summed weights.
        """
        self.amps = amps
        self.vm = amps[0].vm
        self.same_data = same_data
        self.base_idx = base_idx
        self.cached_fun = []
        self.no_scale = no_scale
        with variable_scope(self.vm):
            # One scale factor per amplitude; the first is fixed to 1.
            self.scale = Variable("scale", shape=(len(self.amps),))
            self.scale.set_fix_idx(0, fix_vals=1)
            self.scale.set_value(1.0)

    def partial_weight(self, data, combine):
        """Return ``partial_weight`` of the base amplitude.

        When ``same_data`` is false, the base amplitude's own entry
        ``data["datas"][base_idx]`` is used as its input.
        """
        if not self.same_data:
            data = data["datas"][self.base_idx]
        return self.amps[self.base_idx].partial_weight(data, combine)

    def __getattr__(self, name):
        """Forward unknown attribute lookups to the base amplitude.

        :raises AttributeError: If ``amps`` or ``base_idx`` themselves are
            not set on this instance.
        """
        # __getattr__ only runs after normal lookup failed.  If the lookup
        # that failed is for "amps" or "base_idx" (e.g. on an instance built
        # without __init__, as copy/pickle do), delegating below would
        # re-enter __getattr__ for the same name and recurse without end.
        if name in ("amps", "base_idx"):
            raise AttributeError(name)
        return getattr(self.amps[self.base_idx], name)

    def pdf(self, data):
        """Evaluate the mixed PDF on ``data``.

        For each amplitude ``i`` the term
        ``amp_i(data_i) * weight_i * scale[i]`` is accumulated, where
        ``weight_i`` is ``data_i.get("weight", 1)``.  Unless ``no_scale`` is
        set, the sum is divided by the summed weights and multiplied by the
        number of amplitudes.
        """
        total = 0
        weight_sum = 0
        scale_var = self.scale()
        for idx, amp in enumerate(self.amps):
            # Per-amplitude data is only used when it is actually present.
            if not self.same_data and "datas" in data:
                data_i = data["datas"][idx]
            else:
                data_i = data
            w = data_i.get("weight", 1)
            total = total + amp(data_i) * w * scale_var[idx]
            weight_sum = weight_sum + w
        if self.no_scale:
            return total
        return total / weight_sum * len(self.amps)

    def __call__(self, data):
        """Alias for :meth:`pdf`."""
        return self.pdf(data)
class MixConfig(MultiConfig):
    """Multi-config whose amplitudes are wrapped in one ``MixAmplitude``."""

    def __init__(
        self,
        *args,
        total_same=False,
        same_data=False,
        no_scale=False,
        **kwargs,
    ):
        """Initialise the parent config and remember the mixing options.

        :param total_same: Forwarded to the parent initializer.
        :param same_data: If true, data is taken from the first config only.
        :param no_scale: Passed on to ``MixAmplitude`` in ``get_amplitude``.
        """
        super().__init__(*args, total_same=total_same, **kwargs)
        self.cached_amps = None
        self.no_scale = no_scale
        self.same_data = same_data

    def get_data(self, name):
        """Return the data set ``name``, merged across configs if needed.

        With ``same_data`` the first config's data is returned unchanged.
        Otherwise, if the first config has no such data, ``None`` is
        returned; else each element ``i`` of the result is a dict holding
        the ``i``-th entry of every config under ``"datas"`` and the mean of
        their ``"weight"`` values under ``"weight"``.
        """
        if self.same_data:
            return self.configs[0].get_data(name)

        per_config = [cfg.get_data(name) for cfg in self.configs]
        first = per_config[0]
        if first is None:
            return first

        n_config = len(per_config)
        merged = []
        for group_idx in range(len(first)):
            group = [samples[group_idx] for samples in per_config]
            entry = {"datas": group}
            entry["weight"] = sum(item["weight"] for item in group) / n_config
            merged.append(entry)
        return merged

    def get_data_rec(self, name):
        """Return data ``name + "_rec"``, or data ``name`` if that is None."""
        rec = self.get_data(name + "_rec")
        return self.get_data(name) if rec is None else rec

    def get_phsp_plot(self, tail=""):
        """Return ``"phsp_plot" + tail`` data, else ``"phsp" + tail`` data."""
        plot_sample = self.get_data("phsp_plot" + tail)
        return self.get_data("phsp" + tail) if plot_sample is None else plot_sample

    def get_all_data(self):
        """Return ``(data, phsp, bg, inmc)`` with one entry per data group.

        The number of groups is taken from ``data`` and stored in
        ``self._Ngroup``.  Missing ``bg`` / ``inmc`` become lists of
        ``None``; all four sequences are asserted to have the same length.
        """
        data, phsp, bg, inmc = (
            self.get_data(key) for key in ("data", "phsp", "bg", "inmc")
        )
        n_group = len(data)
        self._Ngroup = n_group
        assert len(phsp) == n_group
        bg = [None] * n_group if bg is None else bg
        inmc = [None] * n_group if inmc is None else inmc
        assert len(bg) == n_group
        assert len(inmc) == n_group
        return data, phsp, bg, inmc

    def get_amplitude(self, vm=None, name="", base_idx=0):
        """Build a ``MixAmplitude`` over the (cached) component amplitudes.

        The component amplitudes are created once via ``get_amplitudes`` and
        reused on later calls; ``vm`` is only consulted on that first call.
        ``name`` is not used here.
        """
        amps = self.cached_amps
        if amps is None:
            amps = self.cached_amps = self.get_amplitudes(vm=vm)
        return MixAmplitude(
            amps,
            self.same_data,
            base_idx=base_idx,
            no_scale=self.no_scale,
        )

    @functools.lru_cache()
    def _get_model(self, vm=None, name=""):
        """Return the model built by the first config for the mixed amplitude.

        Results are memoised by ``functools.lru_cache`` on the call
        arguments (including ``self``).
        """
        mixed_amp = self.get_amplitude(vm=vm, name=name)
        return self.configs[0]._get_model(vm=vm, name=name, amp=mixed_amp)

    def get_fcn(self, datas=None, vm=None, batch=65000):
        """Return the FCN from the last config, using the mixed model.

        :param datas: Pre-loaded data; ``get_all_data()`` is used when None.
        :param vm: Passed to ``_get_model``; the FCN itself receives the
            ``vm`` of the first returned model.
        :param batch: Forwarded unchanged.
        """
        all_data = self.get_all_data() if datas is None else datas
        model = self._get_model(vm=vm)
        last_config = self.configs[-1]
        return last_config.get_fcn(
            all_data=all_data, vm=model[0].vm, batch=batch, model=model
        )

    def plot_partial_wave(self, *args, base_idx=0, **kwargs):
        """Run the ``base_idx`` config's ``plot_partial_wave`` on mixed data.

        The amplitude, data, background and phase-space samples of this
        mixed config are supplied explicitly.  Unless ``same_data`` is set,
        ``data_index_prefix`` is ``("datas", base_idx)``; otherwise it is
        the empty tuple.
        """
        amp = self.get_amplitude(base_idx=base_idx)
        data = self.get_data_rec("data")
        bg = self.get_data_rec("bg")
        phsp = self.get_phsp_plot()
        phsp_rec = self.get_phsp_plot("_rec")
        data_index_prefix = () if self.same_data else ("datas", base_idx)
        target = self.configs[base_idx]
        target.plot_partial_wave(
            *args,
            amp=amp,
            data=data,
            phsp=phsp,
            bg=bg,
            phsp_rec=phsp_rec,
            data_index_prefix=data_index_prefix,
            **kwargs,
        )