import copy
import functools
import random
import yaml
from tf_pwa.amp import (
DecayChain,
DecayGroup,
HelicityDecay,
get_decay,
get_decay_chain,
get_particle,
split_particle_type,
)
from .base_config import BaseConfig
def set_min_max(dic, name, name_min, name_max):
    """Fill ``dic[name]`` with a random value taken between two bounds.

    The value is ``random.random() * (max - min) + min`` where the bounds
    are read from ``dic[name_min]`` and ``dic[name_max]``.  ``dic`` is
    modified in place; nothing is returned.

    :param dic: dictionary to update.
    :param name: key to create.
    :param name_min: key holding the lower bound.
    :param name_max: key holding the upper bound.
    """
    # An explicit value always wins over the random draw.
    if name in dic:
        return
    # Both bounds are required; otherwise leave ``dic`` untouched.
    if name_min not in dic or name_max not in dic:
        return
    lower = dic[name_min]
    upper = dic[name_max]
    dic[name] = random.random() * (upper - lower) + lower
def decay_cut_ls(decay):
    """Cut on the list returned by ``decay.get_ls_list()``.

    Only ``HelicityDecay`` instances are checked; any other object passes.

    :param decay: decay object to test.
    :return: ``(flag, message)``.  ``flag`` is ``False`` when the decay is
        a ``HelicityDecay`` whose ``get_ls_list()`` is empty, in which case
        ``message`` describes the failure; otherwise ``(True, "")``.
    """
    if not isinstance(decay, HelicityDecay):
        return True, ""
    if len(decay.get_ls_list()) != 0:
        return True, ""
    return False, f"{decay} ls not aviable {decay.get_ls_list()}"
def decay_cut_mass(decay):
    """Cut decays whose mother mass is below the sum of the daughter masses.

    Only ``HelicityDecay`` instances are checked; any other object passes.
    A decay also passes when the mass of ``decay.core`` or of any particle
    in ``decay.outs`` is ``None``, because the comparison cannot be made.

    :param decay: decay object to test.
    :return: ``(flag, message)``.  ``flag`` is ``False`` when
        ``decay.core.mass`` is smaller than the sum of the ``mass`` of
        ``decay.outs``; ``message`` then describes the failure.  Otherwise
        ``(True, "")``.
    """
    if not isinstance(decay, HelicityDecay):
        return True, ""
    out_masses = [j.mass for j in decay.outs]
    # Unknown masses cannot be compared: accept the decay.  The original
    # code evaluated ``True, ""`` here without returning it and then
    # failed on the ``None`` comparison below.
    if decay.core.mass is None or any(m is None for m in out_masses):
        return True, ""
    if decay.core.mass < sum(out_masses):
        return (
            False,
            f"{decay} mass break {decay.core.mass} < {out_masses}",
        )
    return True, ""
class DecayConfig(BaseConfig):
    """Build decay structures from a configuration dictionary.

    The constructor reads the ``"decay"`` and ``"particle"`` entries of the
    configuration (plus the optional ``"decay_chain"`` and ``"data"``
    entries) and builds two ``DecayGroup`` objects:

    * ``full_decay``: built with the particle map expanded and the
      configured cuts applied;
    * ``decay_struct``: built with an empty particle map and no cuts.
    """

    # Registry of decay-chain level cuts: name -> callable(chain) returning
    # ``(flag, message)``.
    decay_chain_cut_list = {}
    # Registry of single-decay cuts: name -> callable(decay) returning
    # ``(flag, message)``.  Always accessed through ``DecayConfig.`` below,
    # because instances shadow this name with a list of selected cut names.
    decay_cut_list = {
        "ls_cut": decay_cut_ls,
        "mass_cut": decay_cut_mass,
    }

    def __init__(self, dic, share_dict=None):
        """Parse ``dic`` and build ``full_decay`` and ``decay_struct``.

        :param dic: configuration dictionary.  ``"decay"`` and
            ``"particle"`` are required; ``"decay_chain"`` and ``"data"``
            are optional.
        :param share_dict: optional mapping used to resolve ``$include``
            names before falling back to reading a file.
        """
        # Use a fresh dict per instance: a literal ``{}`` default would be
        # a single object stored on (and shared by) every instance.
        if share_dict is None:
            share_dict = {}
        self.config = copy.deepcopy(dic)
        self.decay_chain_config = dic.get("decay_chain", {})
        self.data_config = dic.get("data", {})
        self.share_dict = share_dict
        # Configuration key -> keyword name used when creating particles.
        self.particle_key_map = {
            "Par": "P",
            "m0": "mass",
            "g0": "width",
            "J": "J",
            "P": "P",
            "spins": "spins",
            "bw": "model",
            "model": "model",
            "bw_l": "bw_l",
            "running_width": "running_width",
        }
        # Names of the chain-level cuts and of the single-decay cuts; the
        # latter default to the former.
        self.cut_list = self.data_config.get("decay_chain_cut", ["ls_cut"])
        self.decay_cut_list = self.data_config.get("decay_cut", self.cut_list)
        self.decay_key_map = {"model": "model"}
        self.dec = self.decay_item(self.config["decay"])
        (
            self.particle_map,
            self.particle_property,
            self.top,
            self.finals,
        ) = self.particle_item(self.config["particle"], share_dict)
        self.full_decay = DecayGroup(
            self.get_decay_struct(
                self.dec,
                self.particle_map,
                self.particle_property,
                self.top,
                self.finals,
                self.decay_chain_config,
            )
        )
        if self.data_config.get("cp_trans", True):
            self.disable_allow_cc(self.full_decay)
        self.decay_struct = DecayGroup(
            self.get_decay_struct(
                self.dec, {}, self.particle_property, process_cut=False
            )
        )
        identical_particles = self.data_config.get("identical_particles", None)
        cp_particles = self.data_config.get("cp_particles", None)
        if identical_particles is not None:
            self.decay_struct.identical_particles = identical_particles
            self.full_decay.identical_particles = identical_particles
        if cp_particles is not None:
            self.decay_struct.cp_particles = cp_particles
            self.full_decay.cp_particles = cp_particles

    @staticmethod
    def load_config(file_name, share_dict=None):
        """Load a configuration from a dict, a shared name or a YAML file.

        :param file_name: a dict (returned as a deep copy) or a string.  A
            string found in ``share_dict`` is resolved through it;
            otherwise it is opened as a file and parsed with
            ``yaml.safe_load``.
        :param share_dict: optional mapping of names to configurations.
        :return: the loaded configuration; ``{}`` when the YAML file is
            empty.
        :raises TypeError: if ``file_name`` is neither a dict nor a string.
        """
        if share_dict is None:
            share_dict = {}
        if isinstance(file_name, dict):
            return copy.deepcopy(file_name)
        if isinstance(file_name, str):
            if file_name in share_dict:
                return DecayConfig.load_config(share_dict[file_name])
            with open(file_name) as f:
                ret = yaml.safe_load(f)
            if ret is None:
                ret = {}
            return ret
        raise TypeError("not support config {}".format(type(file_name)))

    def get_decay(self, full=True):
        """Return ``full_decay`` when ``full`` is true, else ``decay_struct``."""
        if full:
            return self.full_decay
        return self.decay_struct

    @staticmethod
    def _list2decay(core, outs):
        """Split one decay entry into particle names and parameters.

        Dict items of ``outs`` are merged into ``params`` (later keys
        override earlier ones); every other item is kept as an outgoing
        particle.

        :return: ``{"core": core, "outs": [...], "params": {...}}``.
        """
        parts = []
        params = {}
        for j in outs:
            if isinstance(j, dict):
                params.update(j)
            else:
                parts.append(j)
        return {"core": core, "outs": parts, "params": params}

    @staticmethod
    def decay_item(decay_dict):
        """Convert the ``decay`` configuration into a list of decay dicts.

        Each value is either one decay (a list of outgoing items) or, when
        every item is itself a list, several alternative decays of the same
        core particle.

        :param decay_dict: mapping of core particle name to decay entry.
        :return: list of dicts as produced by ``_list2decay``.
        """
        decs = []
        for core, outs in decay_dict.items():
            if all(isinstance(i, list) for i in outs):
                for i in outs:
                    decs.append(DecayConfig._list2decay(core, i))
            else:
                decs.append(DecayConfig._list2decay(core, outs))
        return decs

    @staticmethod
    def _do_include_dict(d, o, share_dict=None):
        """Merge the included configuration ``o`` into ``d`` in place.

        Keys missing from ``d`` are copied from the include.  When both
        sides hold a dict for the same key, the result is the included dict
        updated with the entries of ``d`` (``d`` takes precedence).  In all
        other cases the existing value of ``d`` is kept.
        """
        s = DecayConfig.load_config(o, share_dict)
        for key, included in s.items():
            if key not in d:
                d[key] = included
            elif isinstance(d[key], dict) and isinstance(included, dict):
                # The ``isinstance(included, dict)`` guard avoids calling
                # ``update`` on a non-dict included value.
                included.update(d[key])
                d[key] = included

    @staticmethod
    def particle_item_list(particle_list):
        """Split the particle configuration into a map and properties.

        A list value defines a particle map entry: string items are
        appended to the entry of that particle, dict items are processed
        recursively and merged into the result.  A dict value is stored as
        the properties of that particle.

        :param particle_list: mapping of particle name to list or dict.
        :return: ``(particle_map, particle_property)``.
        :raises ValueError: for values or list items of any other type.
        """
        particle_map = {}
        particle_property = {}
        for particle, candidate in particle_list.items():
            if isinstance(candidate, list):  # particle map
                if len(candidate) == 0:
                    particle_map[particle] = []
                for i in candidate:
                    if isinstance(i, str):
                        particle_map.setdefault(particle, []).append(i)
                    elif isinstance(i, dict):
                        map_i, pro_i = DecayConfig.particle_item_list(i)
                        for k, v in map_i.items():
                            particle_map.setdefault(k, []).extend(v)
                        particle_property.update(pro_i)
                    else:
                        raise ValueError(
                            "value of particle map {} is {}".format(i, type(i))
                        )
            elif isinstance(candidate, dict):
                particle_property[particle] = candidate
            else:
                raise ValueError(
                    "value of particle {} is {}".format(
                        particle, type(candidate)
                    )
                )
        return particle_map, particle_property

    @staticmethod
    def particle_item(particle_list, share_dict=None):
        """Parse the ``particle`` configuration section.

        The special keys ``$top``, ``$finals`` and ``$include`` are removed
        from ``particle_list`` (which is modified in place).  Includes are
        merged first, then the remaining entries are split with
        ``particle_item_list``.  When ``$top`` or ``$finals`` is a dict,
        its entries are added to the particle properties.

        :param particle_list: the ``particle`` configuration mapping.
        :param share_dict: optional mapping used to resolve include names.
        :return: ``(particle_map, particle_property, top, finals)``.
        :raises ValueError: if ``$include`` is neither a string nor a list.
        """
        top = particle_list.pop("$top", None)
        finals = particle_list.pop("$finals", None)
        includes = particle_list.pop("$include", None)
        if includes:
            if isinstance(includes, list):
                for i in includes:
                    DecayConfig._do_include_dict(
                        particle_list, i, share_dict=share_dict
                    )
            elif isinstance(includes, str):
                DecayConfig._do_include_dict(
                    particle_list, includes, share_dict=share_dict
                )
            else:
                raise ValueError(
                    "$include must be string or list of string not {}".format(
                        type(includes)
                    )
                )
        particle_map, particle_property = DecayConfig.particle_item_list(
            particle_list
        )
        if isinstance(top, dict):
            particle_property.update(top)
        if isinstance(finals, dict):
            particle_property.update(finals)
        return particle_map, particle_property, top, finals

    def rename_params(self, params, is_particle=True):
        """Return a copy of ``params`` with configuration keys renamed.

        Keys are translated through ``particle_key_map`` (or
        ``decay_key_map`` when ``is_particle`` is false); unknown keys are
        kept unchanged.
        """
        key_map = self.particle_key_map if is_particle else self.decay_key_map
        return {key_map.get(k, k): v for k, v in params.items()}

    def decay_chain_cut(self, decays):
        """Filter decay chains with the chain-level cuts in ``cut_list``.

        Names that are not registered in
        ``DecayConfig.decay_chain_cut_list`` are ignored.  The first
        failing cut removes the chain and prints the reason.

        :param decays: iterable of decay chains.
        :return: list of the chains that passed every cut.
        """
        ret = []
        for i in decays:
            flag = True
            for name in self.cut_list:
                if name not in DecayConfig.decay_chain_cut_list:
                    continue
                f = DecayConfig.decay_chain_cut_list[name]
                new_flag, msg = f(i)
                flag = flag and new_flag
                if not flag:
                    print(
                        "remove decay chain",
                        i,
                        "by",
                        name,
                        "\n\tbecause of",
                        msg,
                    )
                    break
            if flag:
                ret.append(i)
        return ret

    def decay_cut(self, decays):
        """Filter decay chains with the single-decay cuts.

        Every decay of every chain is tested against the cuts named in
        ``self.decay_cut_list`` (the ``decay_cut`` setting, which defaults
        to the chain cut names).  Names that are not registered in
        ``DecayConfig.decay_cut_list`` are ignored.  When a decay fails, the
        reason is printed, the decay is detached from its core particle's
        ``decay`` list and from the ``creators`` of its outgoing particles,
        and the whole chain is dropped.

        :param decays: iterable of decay chains.
        :return: list of the chains in which every decay passed.
        """
        ret = []
        for decay_chain in decays:
            flag = True
            for i in decay_chain:
                # Use the configured single-decay cut names; iterating
                # ``self.cut_list`` here ignored the ``decay_cut`` setting.
                for name in self.decay_cut_list:
                    if name not in DecayConfig.decay_cut_list:
                        continue
                    f = DecayConfig.decay_cut_list[name]
                    new_flag, msg = f(i)
                    flag = flag and new_flag
                    if not flag:
                        print(
                            "remove decay chain",
                            decay_chain,
                            "by",
                            name,
                            "\n\tbecause of",
                            msg,
                        )
                        break
                if not flag:
                    if i in i.core.decay:
                        i.core.decay.remove(i)
                    for j in i.outs:
                        if i in j.creators:
                            j.creators.remove(i)
                    break
            if flag:
                ret.append(decay_chain)
        return ret

    def get_decay_struct(
        self,
        decay,
        particle_map=None,
        particle_params=None,
        top=None,
        finals=None,
        chain_params=None,
        process_cut=True,
    ):
        """get decay structure for decay dict

        :param decay: list of decay dicts with ``"core"``, ``"outs"`` and
            ``"params"`` keys, as produced by ``decay_item``.
        :param particle_map: mapping of a name to the list of names that
            replace it; a name without entry stands for itself.
        :param particle_params: mapping of particle name to parameters,
            renamed with ``rename_params`` and passed to ``get_particle``.
        :param top: initial particle: ``None`` (deduced from the decays), a
            name, a one-element list, or a one-key dict.
        :param finals: final particles: ``None`` (deduced from the decays)
            or a list/dict of names.
        :param chain_params: optional mapping; its ``"$all"`` entry is
            passed as keyword arguments to ``get_decay_chain``.
        :param process_cut: apply ``decay_cut`` and ``decay_chain_cut`` to
            the result when true.
        :return: list of decay chains built with ``get_decay_chain``.
        :raises TypeError: if ``finals`` has an unsupported type.
        :raises RuntimeError: if no decay chain is left.
        """
        particle_map = particle_map if particle_map is not None else {}
        particle_params = (
            particle_params if particle_params is not None else {}
        )
        chain_params = chain_params if chain_params is not None else {}

        base_particle_set = {}
        particle_set = {}

        def add_particle(name, _id):
            # One particle object per "<name>:<index>", created on demand.
            name = "{}:{}".format(name, _id)
            if name in particle_set:
                return particle_set[name]
            names = name.split(":")
            params = particle_params.get(names[0], {})
            params = self.rename_params(params)
            set_min_max(params, "mass", "m_min", "m_max")
            set_min_max(params, "width", "g_min", "g_max")
            part = get_particle(name, **params)
            particle_set[name] = part
            return part

        def add_base_particle(name):
            # One parameter-less particle object per plain name.
            if name in base_particle_set:
                return base_particle_set[name]
            part = get_particle(name)
            base_particle_set[name] = part
            return part

        def wrap_particle(name):
            name_list = particle_map.get(name, [name])
            return [add_base_particle(i) for i in name_list]

        def all_combine(out):
            # Yield every way of picking one item from each list of ``out``.
            if len(out) < 1:
                yield []
            else:
                for i in out[0]:
                    for j in all_combine(out[1:]):
                        yield [i] + j

        # ``get_decay`` below is the module-level function imported from
        # ``tf_pwa.amp``, not the method of this class.
        decs = []
        new_decay_params = {}
        for dec in decay:
            core = wrap_particle(dec["core"])
            outs = [wrap_particle(j) for j in dec["outs"]]
            for i in core:
                for j in all_combine(outs):
                    dec_i = get_decay(i, j)
                    new_decay_params[dec_i] = dec["params"]
                    decs.append(dec_i)

        decay_list = {}

        def add_decay(a, b, params):
            # One decay object per (core, outs) pair.
            b = tuple(b)
            if (a, b) not in decay_list:
                decay_list[(a, b)] = get_decay(a, b, **params)
            return decay_list[(a, b)]

        top_tmp, finals_tmp = set(), set()
        if top is None or finals is None:
            top_tmp, _, finals_tmp = split_particle_type(decs)
        if top is None:
            top_tmp = list(top_tmp)
            assert len(top_tmp) == 1, "not only one top particle"
            top = top_tmp[0]
        else:
            if isinstance(top, list):
                assert len(top) == 1, "only one initial supported"
                top = top[0]
            if isinstance(top, str):
                top = base_particle_set[top]
            elif isinstance(top, dict):
                keys = list(top.keys())
                assert len(keys) == 1
                top = base_particle_set[keys.pop()]
            else:
                top = base_particle_set[str(top)]
        if finals is None:
            finals = list(finals_tmp)
        elif isinstance(finals, (list, dict)):
            finals = [base_particle_set[i] for i in finals]
        else:
            raise TypeError("{}: {}".format(finals, type(finals)))

        dec_chain = top.chain_decay()
        ret = []
        for i in dec_chain:
            # Keep only the chains that end in exactly the requested finals.
            if sorted(DecayChain(i).outs) == sorted(finals):
                all_params = chain_params.get("$all", {})
                # Number repeated particle names within a chain so that
                # each occurrence gets its own "<name>:<index>" particle.
                count_input = {}
                count_output = {}
                all_dec = []
                for dec in i:
                    count_input[dec.core.name] = (
                        count_input.get(dec.core.name, -1) + 1
                    )
                    core = add_particle(
                        dec.core.name, count_input[dec.core.name]
                    )
                    outs = []
                    for j in dec.outs:
                        count_output[j.name] = count_output.get(j.name, -1) + 1
                        out = add_particle(j.name, count_output[j.name])
                        outs.append(out)
                    dec_i = add_decay(core, outs, new_decay_params[dec])
                    all_dec.append(dec_i)
                dec_c = get_decay_chain(all_dec, **all_params)
                ret.append(dec_c)
        if process_cut:
            ret = self.decay_cut(ret)
            ret = self.decay_chain_cut(ret)
        if len(ret) == 0:
            raise RuntimeError("not decay chain aviable, check you config.yml")
        return ret

    def disable_allow_cc(self, decay_group):
        """Set ``allow_cc = False`` on every decay that has that attribute."""
        for decay_chain in decay_group:
            for decay in decay_chain:
                if hasattr(decay, "allow_cc"):
                    decay.allow_cc = False