"""
Basic Amplitude Calculations.

A partial wave analysis process has the following structure:

- DecayGroup: addition (+)
- DecayChain: multiplication (x)
- Decay, Particle (Propagator)
"""
import contextlib
import functools
import inspect
import warnings
from itertools import combinations
from pprint import pprint
import numpy as np
import sympy as sym
from tf_pwa.breit_wigner import BW, BWR, Bprime, Bprime_q2, to_complex
from tf_pwa.cg import cg_coef
from tf_pwa.config import get_config, regist_config, temp_config
from tf_pwa.data import LazyCall, data_map, data_shape, split_generator
from tf_pwa.dec_parser import load_dec_file
from tf_pwa.dfun import get_D_matrix_lambda
from tf_pwa.einsum import einsum
from tf_pwa.particle import DEFAULT_DECAY, BaseParticle, Decay
from tf_pwa.particle import DecayChain as BaseDecayChain
from tf_pwa.particle import DecayGroup as BaseDecayGroup
from tf_pwa.particle import (
_spin_int,
_spin_range,
cp_charge_group,
split_particle_type,
)
from tf_pwa.tensorflow_wrapper import tf
from tf_pwa.variable import Variable, VarsManager
# from pysnooper import snoop
# Configuration keys under which the three model registries are stored.
PARTICLE_MODEL = "particle_model"
DECAY_MODEL = "decay_model"
DECAY_CHAIN_MODEL = "decay_chain_model"

# Every registry starts empty; each key gets its own dict object.
regist_config(PARTICLE_MODEL, {})
regist_config(DECAY_MODEL, {})
regist_config(DECAY_CHAIN_MODEL, {})
def register_particle(name=None, f=None):
    """register a particle model

    Usable as a decorator (``@register_particle("name")``) or as a direct
    call (``register_particle("name", cls)``).

    :params name: model name used in configuration, defaults to the class
        ``__name__``
    :params f: Model class; when given it is registered immediately
    :return: the decorator when ``f`` is None, otherwise ``f`` itself
    """

    def _decorate(cls):
        model_name = cls.__name__ if name is None else name
        registry = get_config(PARTICLE_MODEL)
        if model_name in registry:
            # re-registration replaces the previous class; make it visible
            warnings.warn("Override model {}".format(model_name))
        registry[model_name] = cls
        cls.model_name = model_name
        return cls

    return _decorate if f is None else _decorate(f)
def register_decay(name=None, num_outs=2, f=None):
    """register a decay model

    :params name: model name used in configuration, defaults to the class
        ``__name__``
    :params num_outs: number of final particles of the decay; the registry
        key is ``(num_outs, name)``, so one name can serve several
        multiplicities
    :params f: Model class; when given it is registered immediately
    :return: the decorator when ``f`` is None, otherwise ``f`` itself
    """

    def regist(g):
        my_name = g.__name__ if name is None else name
        config = get_config(DECAY_MODEL)
        id_ = (num_outs, my_name)
        if id_ in config:
            warnings.warn("Override decay model {}".format(my_name))
        config[id_] = g
        g.model_name = my_name
        return g

    if f is None:
        return regist
    return regist(f)
def register_decay_chain(name=None, f=None):
    """register a decay chain model

    :params name: model name used in configuration, defaults to the class
        ``__name__``
    :params f: Model class; when given it is registered immediately
    :return: the decorator when ``f`` is None, otherwise ``f`` itself
    """

    def regist(g):
        my_name = g.__name__ if name is None else name
        config = get_config(DECAY_CHAIN_MODEL)
        if my_name in config:
            warnings.warn("Override decay chain model {}".format(my_name))
        config[my_name] = g
        g.model_name = my_name
        return g

    if f is None:
        return regist
    return regist(f)
# Short aliases kept for backward compatibility.
regist_particle, regist_decay = register_particle, register_decay
def get_particle_model(name):
    """Return the particle model class registered as ``name``, or None."""
    return get_config(PARTICLE_MODEL).get(name)
def get_particle_model_name(p):
    """Return the registry name whose class is exactly ``type(p)``.

    Subclasses are not matched; if no registered class is ``type(p)``,
    ``str(type(p))`` is returned instead.
    """
    cls = type(p)
    registry = get_config(PARTICLE_MODEL)
    return next(
        (model_name for model_name, model_cls in registry.items() if model_cls is cls),
        str(cls),
    )
def get_particle(*args, model="default", **kwargs):
    """method for getting particle of model

    ``model`` is a registered model name, or a dict handled by
    :func:`trans_model`. An unknown name falls back to the ``"default``"
    model with a warning. The chosen ``model`` value is stored on the
    returned particle as ``model_name``.
    """
    if isinstance(model, dict):
        model_class = trans_model(model)
    else:
        model_class = get_particle_model(model)
        if model_class is None:
            warnings.warn(
                "No model named {} found, use default instead.".format(model)
            )
            model_class = get_particle_model("default")
    particle = model_class(*args, **kwargs)
    particle.model_name = model
    return particle
def trans_model(model):
    """Build a particle model class from a symbolic expression.

    ``model`` is a dict with keys

    - ``"expr"``: an expression understood by ``sympy.simplify``;
    - ``"where"`` (optional): mapping from symbol name to either a value
      substituted into ``expr`` or a string naming a registered particle
      model.

    A free symbol missing from ``"where"`` is treated as the model named
    like the symbol. Exactly one symbol must end up bound to a model name.
    The returned class subclasses that model and applies the expression to
    its ``get_amp`` result.
    """
    # simplify once; the result is reused for both symbol discovery and subs
    expr = sym.simplify(model.get("expr"))
    var = {str(k): str(k) for k in expr.free_symbols}
    var.update(model.get("where", {}))
    model_name = [(k, v) for k, v in var.items() if isinstance(v, str)]
    assert len(model_name) == 1, "expected exactly one model symbol, got {}".format(
        model_name
    )
    var_name, name = model_name[0]
    expr2 = expr.subs({k: v for k, v in var.items() if k != var_name})
    assert len(expr2.free_symbols) == 1, str(expr2)
    fun = sym.lambdify((var_name,), expr2, "tensorflow")
    base_model = get_particle_model(name)

    class _TempModel(base_model):
        _from_trans = True

        def get_amp(self, *args, **kwargs):
            amp = super().get_amp(*args, **kwargs)
            return fun(amp)

    return _TempModel
def get_decay_model(model, num_outs=2):
    """Return the decay class registered as ``model`` for ``num_outs`` outs.

    Raises KeyError when no such model is registered.
    """
    registry = get_config(DECAY_MODEL)
    return registry[(num_outs, model)]
def get_decay(core, outs, **kwargs):
    """method for getting decay of model

    Options are merged with increasing priority: ``production_params`` of
    each final particle (later ones win), ``decay_params`` of ``core``,
    then explicit ``kwargs``. The ``model`` option (default ``"default"``)
    selects the registered class and is also passed on to it.
    """
    n_outs = len(outs)
    merged = {}
    for particle in outs:
        merged.update(getattr(particle, "production_params", {}))
    merged.update(getattr(core, "decay_params", {}))
    merged.update(kwargs)
    decay_cls = get_decay_model(merged.get("model", "default"), n_outs)
    return decay_cls(core, outs, **merged)
def get_decay_chain(decays, **kwargs):
    """method for getting decay chain of model

    ``decay_chain_params`` of each decay are merged (later ones win) and
    overridden by ``kwargs``. The ``model`` option (default ``"default"``)
    is removed and used to pick the registered chain class.
    """
    options = {}
    for dec in decays:
        options.update(getattr(dec, "decay_chain_params", {}))
    options.update(kwargs)
    chain_cls = get_config(DECAY_CHAIN_MODEL)[options.pop("model", "default")]
    return chain_cls(decays, **options)
def data_device(data):
    """Debug helper: pretty-print the ``device`` of every leaf of ``data``.

    Leaves without a ``device`` attribute show as None. ``data`` is returned
    unchanged so the call can be dropped into a pipeline.
    """
    pprint(data_map(data, lambda leaf: getattr(leaf, "device", None)))
    return data
def get_name(self, names):
    """Build a variable name from ``str(self)`` and the suffix ``names``.

    ``":"`` becomes ``"/"``, ``"+"`` becomes ``"."``, and commas, square
    brackets and spaces are dropped. For example,
    ``get_name("a:b", "c+d [1, 2]")`` gives ``"a/b_c.d12"``.
    """
    table = str.maketrans(
        {":": "/", "+": ".", ",": None, "[": None, "]": None, " ": None}
    )
    return (str(self) + "_" + names).translate(table)
def _add_var(self, names, is_complex=False, shape=(), **kwargs):
    """Create a Variable named ``get_name(self, names)``."""
    var_name = get_name(self, names)
    return Variable(var_name, shape, is_complex, **kwargs)
class AmpBase(object):
    """Base class for amplitude"""

    def get_params_head(self):
        """Prefix for this object's variable names.

        Defaults to ``str(self)``; the value is cached in ``params_head``.
        """
        head = getattr(self, "params_head", None)
        if head is None:
            self.params_head = str(self)
        return self.params_head

    def add_var(self, names, is_complex=False, shape=(), **kwargs):
        """
        default add_var method

        Entries of ``self.default_params[names]`` override ``kwargs``; a bare
        number there is shorthand for ``{"value": number}``. The new Variable
        is remembered in ``_variables_map`` under ``names``.
        """
        if not hasattr(self, "_variables_map"):
            self._variables_map = {}
        overrides = getattr(self, "default_params", {}).get(names, {})
        if isinstance(overrides, (float, int)):
            overrides = {"value": overrides}
        kwargs.update(overrides)
        var = Variable(self.get_variable_name(names), shape, is_complex, **kwargs)
        self._variables_map[names] = var
        return var

    def get_var(self, name):
        """Return the Variable created by :meth:`add_var` for ``name``, or None."""
        variables = getattr(self, "_variables_map", {})
        return variables.get(name)

    def get_variable_name(self, name=""):
        """Full variable name: params head joined with ``name`` (see get_name)."""
        return get_name(self.get_params_head(), name)

    def amp_shape(self):
        raise NotImplementedError

    def get_factor_variable(self):
        """Variables that act as overall factors; none by default."""
        return []
@contextlib.contextmanager
def variable_scope(vm=None):
    """variable name scope

    Makes ``vm`` the ``"vm"`` config value for the duration of the ``with``
    block (via ``temp_config``) and yields it. When ``vm`` is None a new
    ``VarsManager`` with the configured ``dtype`` is created.
    """
    manager = VarsManager(dtype=get_config("dtype")) if vm is None else vm
    with temp_config("vm", manager):
        yield manager
def simple_deepcopy(dic):
    """Recursively copy nested dict / list / tuple containers.

    Containers (including subclasses) are rebuilt as plain ``dict``,
    ``list`` or ``tuple``; every other object is shared, not copied.
    """
    if isinstance(dic, dict):
        return {key: simple_deepcopy(val) for key, val in dic.items()}
    if isinstance(dic, (list, tuple)):
        copied = [simple_deepcopy(val) for val in dic]
        return copied if isinstance(dic, list) else tuple(copied)
    return dic
def simple_cache_fun(f):
    """Cache the first result of method ``f`` on the instance.

    The value is stored as attribute ``simple_cached_<f.__name__>``; later
    calls return it whatever arguments they pass.
    """
    attr = "simple_cached_" + f.__name__

    @functools.wraps(f)
    def g(self, *args, **kwargs):
        if not hasattr(self, attr):
            value = f(self, *args, **kwargs)
            setattr(self, attr, value)
        return getattr(self, attr)

    return g
def get_relative_p(m_0, m_1, m_2):
    """relative momentum for 0 -> 1 + 2

    ``m_0`` is clamped to the threshold ``m_1 + m_2`` when it is not above
    it, so the returned momentum is 0 there.
    """
    m_sum = m_1 + m_2
    m_diff = m_1 - m_2
    if hasattr(m_sum, "dtype"):
        # align the dtype of m_0 with the daughter masses
        m_0 = tf.convert_to_tensor(m_0, dtype=m_sum.dtype)
    m_eff = tf.where(m_0 > m_sum, m_0, m_sum)
    # (m^2 - (m1+m2)^2) (m^2 - (m1-m2)^2) = 4 m^2 q^2
    four_m2_q2 = (
        (m_eff - m_sum) * (m_eff + m_sum) * (m_eff - m_diff) * (m_eff + m_diff)
    )
    return tf.sqrt(four_m2_q2) / (2 * m_eff)
def get_relative_p2(m_0, m_1, m_2):
    """relative momentum for 0 -> 1 + 2, squared

    Returns :math:`q^2 = (m_0^2-(m_1+m_2)^2)(m_0^2-(m_1-m_2)^2)/(2m_0)^2`.
    Unlike :func:`get_relative_p` there is no threshold clamp, so the result
    is negative below threshold.
    """
    m_sum = m_1 + m_2
    m_diff = m_1 - m_2
    if hasattr(m_sum, "dtype"):
        m_0 = tf.convert_to_tensor(m_0, dtype=m_sum.dtype)
    numerator = (m_0 - m_sum) * (m_0 + m_sum) * (m_0 - m_diff) * (m_0 + m_diff)
    return numerator / (2 * m_0) ** 2
def _ad_hoc(m0, m_max, m_min):
    r"""ad-hoc formula

    .. math::
        m_0^{eff} = m^{min} + \frac{m^{max} - m^{min}}{2}(1+tanh \frac{m_0 - \frac{m^{max} + m^{min}}{2}}{m^{max} - m^{min}})
    """
    half_range = (m_max - m_min) / 2
    # (2 m0 - (max + min)) / (4 * half_range) == (m0 - mid) / (max - min)
    scaled = (2 * m0 - (m_max + m_min)) / half_range / 4
    return m_min + half_range * (1 + tf.tanh(scaled))
@regist_particle("BWR")
@regist_particle("default")
class Particle(BaseParticle, AmpBase):
    """
    .. math::
        R(m) = \\frac{1}{m_0^2 - m^2 - i m_0 \\Gamma(m)}

    Argand diagram

    .. plot::

        >>> import matplotlib.pyplot as plt
        >>> plt.clf()
        >>> from tf_pwa.utils import plot_particle_model
        >>> axis = plot_particle_model("BWR")

    Pole position

    .. plot::

        >>> import matplotlib.pyplot as plt
        >>> plt.clf()
        >>> from tf_pwa.utils import plot_pole_function
        >>> axis = plot_pole_function("BWR")

    Options: ``running_width=False`` uses the constant-width ``BW``;
    ``bw_l`` sets the ``l`` of the running width (default: lowest ``l`` of
    the first decay); ``width_norm=True`` multiplies the amplitude by the
    width; ``params_head`` overrides the prefix of variable names.
    """

    def __init__(
        self,
        *args,
        running_width=True,
        bw_l=None,
        width_norm=False,
        params_head=None,
        **kwargs
    ):
        super(Particle, self).__init__(*args, **kwargs)
        self.running_width = running_width
        self.bw_l = bw_l
        self.width_norm = width_norm
        # honour the argument (it used to be dropped and always set to None)
        self.params_head = params_head

    def init_params(self):
        """Turn numeric mass/width into Variables, created with ``fix=True``."""
        self.d = 3.0  # barrier parameter passed to BWR
        if self.mass is None:
            self.mass = self.add_var("mass", fix=True)
        elif not isinstance(self.mass, Variable):
            self.mass = self.add_var("mass", value=self.mass, fix=True)
        if self.width is not None and not isinstance(self.width, Variable):
            self.width = self.add_var("width", value=self.width, fix=True)

    def is_fixed_shape(self):
        """True when every Variable attribute of this particle is fixed."""
        return all(
            v.is_fixed() for v in self.__dict__.values() if isinstance(v, Variable)
        )

    def _infer_bw_l(self):
        """Return ``bw_l``, defaulting it to the lowest l of the first decay."""
        if self.bw_l is None:
            self.bw_l = min(self.decay[0].get_l_list())
        return self.bw_l

    def get_amp(self, data, data_c, **kwargs):
        """Lineshape at ``data["m"]``; running width uses ``data_c["|q|"]``/``["|q0|"]``."""
        mass = self.get_mass()
        width = self.get_width()
        if width is None:
            # no width: constant amplitude 1
            return tf.ones_like(data["m"])
        if not self.running_width:
            ret = BW(data["m"], mass, width)
        else:
            q = data_c["|q|"]
            q0 = data_c["|q0|"]
            ret = BWR(data["m"], mass, width, q, q0, self._infer_bw_l(), self.d)
        if self.width_norm:
            c_width = tf.complex(width, tf.zeros_like(width))
            return tf.cast(c_width, ret.dtype) * ret
        return ret

    def __call__(self, m):
        """Evaluate the lineshape at mass ``m``.

        The momenta are computed from the masses of the two daughters of the
        first decay of this particle.
        """
        mass = self.get_mass()
        m1 = self.decay[0].outs[0].get_mass()
        m2 = self.decay[0].outs[1].get_mass()
        q = get_relative_p(m, m1, m2)
        q0 = get_relative_p(mass, m1, m2)
        return self.get_amp(
            {"m": m}, {"|q|": q, "|q|2": q**2, "|q0|": q0, "|q0|2": q0**2}
        )

    def amp_shape(self):
        return ()

    def get_mass(self):
        """Return the mass, calling it first if it is a Variable/callable."""
        if self.mass is None:
            warnings.warn(
                f"The mass of {self} is None, may be you should calculate amplitude first to infer mass"
            )
        if callable(self.mass):
            return self.mass()
        return self.mass

    def get_width(self):
        """Return the width, calling it first if it is a Variable/callable."""
        if callable(self.width):
            return self.width()
        return self.width

    def get_factor(self):
        return None

    def get_sympy_var(self):
        """Sympy symbols ``(m, m0, g0, m1, m2)`` used by the pole functions."""
        return sym.var("m m0 g0 m1 m2")

    def get_subdecay_mass(self, idx=0):
        """Masses of the final particles of decay number ``idx``."""
        return [i.get_mass() for i in self.decay[idx].outs]

    def get_num_var(self):
        """Numeric ``(mass, width, m1, m2)`` matching :meth:`get_sympy_var`[1:]."""
        mass = self.get_mass()
        width = self.get_width()
        m1, m2 = self.get_subdecay_mass()
        return mass, width, m1, m2

    def solve_pole(self, init=None, sheet=0, return_complex=True):
        """Find a root of :meth:`get_sympy_dom` in the complex plane.

        The default start is ``mass - i*width/2``. A user ``init`` is used as
        ``Re(init) - i*Im(init)``; NOTE(review): this flips the sign of the
        imaginary part relative to the default start — confirm the intended
        convention.
        """
        mass = self.get_mass()
        width = self.get_width()
        if init is None:
            init_pole = float(mass) - sym.I * float(width) / 2
        else:
            init_pole = float(np.real(init)) - sym.I * float(np.imag(init))
        from tf_pwa.formula import create_complex_root_sympy_tfop

        var = self.get_sympy_var()
        f = self.get_sympy_dom(*var, sheet=sheet)
        g = create_complex_root_sympy_tfop(f, var[1:], var[0], init_pole)
        ret = g(*self.get_num_var())
        if not return_complex:
            ret = tf.math.real(ret), tf.math.imag(ret)
        return ret

    def get_sympy_dom(self, m, m0, g0, m1=None, m2=None, sheet=0):
        """Symbolic expression (``BW_dom``/``BWR_dom``) whose root is the pole.

        ``sheet`` is accepted for subclasses and unused here.

        :raises NotImplementedError: when the particle has no width.
        """
        if self.get_width() is None:
            # was ``raise NotImplemented`` which actually raises TypeError
            raise NotImplementedError(
                "no pole equation for {} without width".format(self)
            )
        from tf_pwa.formula import BW_dom, BWR_dom

        if not self.running_width or m1 is None or m2 is None:
            return BW_dom(m, m0, g0)
        return BWR_dom(m, m0, g0, self._infer_bw_l(), m1, m2)

    def pole_function(self, sheet=0, modules="numpy"):
        """Numeric function of the complex mass built from :meth:`get_sympy_dom`."""
        from tf_pwa.formula import create_numpy_function

        var = self.get_sympy_var()
        f = self.get_sympy_dom(*var, sheet=sheet)
        val = self.get_num_var()
        return create_numpy_function(f, var[1:], val, var[0], modules=modules)
@regist_particle("x")
class ParticleX(Particle):
    """simple particle model for mass, (used in expr)

    .. math::
        R(m) = m

    .. plot::

        >>> import matplotlib.pyplot as plt
        >>> plt.clf()
        >>> from tf_pwa.utils import plot_particle_model
        >>> axis = plot_particle_model("x")
    """

    def __call__(self, m):
        return self.get_amp(dict(m=m))

    def get_amp(self, data, *args, **kwargs):
        """Return ``data["m"]`` as a complex tensor with zero imaginary part."""
        mass = data["m"]
        return tf.complex(mass, tf.zeros_like(mass))
class SimpleResonances(Particle):
    """Base for models defined by a plain function of the mass.

    Subclasses implement ``__call__``; :meth:`get_amp` only gathers its
    arguments (mass, width and momenta) from the data dicts.
    """

    def __init__(self, *args, **kwargs):
        # initialised before the parent constructor runs
        self.params = {}
        super(SimpleResonances, self).__init__(*args, **kwargs)

    def __call__(self, m, m0=None, g0=None, q=None, q0=None, **kwargs):
        raise NotImplementedError

    def get_amp(self, *args, **kwargs):
        """Call ``self(m, m0=..., g0=..., q=..., q0=...)``.

        ``q``/``q0`` are None without a second dict and default to 1.0 when
        the second dict lacks ``"|q|"``/``"|q0|"``.
        """
        m = args[0]["m"]
        if len(args) >= 2:
            extra = args[1]
            q = extra.get("|q|", 1.0)
            q0 = extra.get("|q0|", 1.0)
        else:
            q = q0 = None
        return self(
            m, m0=self.get_mass(), g0=self.get_width(), q=q, q0=q0, **kwargs
        )
class FloatParams(float):
    """Annotation marker used by :func:`simple_resonance`.

    A model-function argument annotated with ``FloatParams`` becomes a fit
    parameter even when it is not listed in ``params``.
    """
def simple_resonance(name, fun=None, params=None):
    """convert simple fun f(m) into a resonances model

    :params name: model name used in configuration
    :params fun: Model function; when None a decorator is returned
    :params params: arguments name list for parameters

    Arguments of ``fun`` listed in ``params`` or annotated with
    :class:`FloatParams` become Variables of the model: fixed to the
    particle attribute / function default when one exists, free otherwise.
    """
    if params is None:
        params = {}

    def _wrapper(f):
        argspec = inspect.getfullargspec(f)
        if argspec.defaults is None:
            defaults = {}
        else:
            # defaults align with the trailing arguments
            defaults = dict(zip(argspec.args[::-1], argspec.defaults[::-1]))

        @register_particle(name)
        class _R(SimpleResonances):
            def init_params(self):
                if "m0" in argspec.args and "g0" in argspec.args:
                    super(_R, self).init_params()
                self.params = {}
                for i in argspec.args:
                    tp = argspec.annotations.get(i, None)
                    if i in params or tp is FloatParams:
                        val = getattr(self, i, defaults.get(i, None))
                        if val is None:
                            self.params[i] = self.add_var(i)
                        else:
                            self.params[i] = self.add_var(
                                i, value=val, fix=True
                            )

            def is_fixed_shape(self):
                return super().is_fixed_shape() and all(
                    v.is_fixed() for v in self.params.values()
                )

            def __call__(self, m, **kwargs):
                my_kwargs = {}
                for i in argspec.args:
                    if i in kwargs:
                        my_kwargs[i] = kwargs[i]
                    elif i in self.params:
                        my_kwargs[i] = self.params[i]()
                    elif hasattr(self, i):
                        my_kwargs[i] = getattr(self, i)
                ret = f(m, **my_kwargs)
                return tf.cast(ret, tf.complex128)

            __call__.__doc__ = f.__doc__

            # Own override so the docstring belongs to this model only;
            # assigning to ``_R.get_amp.__doc__`` would rewrite the shared
            # SimpleResonances.get_amp function for every model.
            def get_amp(self, *args, **kwargs):
                return super(_R, self).get_amp(*args, **kwargs)

            get_amp.__doc__ = f.__doc__

        return _R

    if fun is None:
        return _wrapper
    return _wrapper(fun)
class AmpDecay(Decay, AmpBase):
    """base class for decay with amplitude"""

    def get_params_head(self):
        """Variable-name prefix ``"<core>-><out1>+<out2>..."`` (cached)."""
        if getattr(self, "params_head", None) is None:
            core_head = self.core.get_params_head()
            out_heads = "+".join([p.get_params_head() for p in self.outs])
            self.params_head = "{}->{}".format(core_head, out_heads)
        return self.params_head

    def amp_shape(self):
        """Number of helicities of core and of each out, as a tuple."""
        return tuple([len(self.core.spins)] + [len(p.spins) for p in self.outs])

    # @simple_cache_fun
    def amp_index(self, base_map):
        """Look up core and outs in ``base_map``, in that order."""
        return [base_map[self.core]] + [base_map[p] for p in self.outs]

    def n_helicity_inner(self):
        """Inner helicity count per out: ``2J+1`` if ``helicity_inner_full``."""
        full = getattr(self, "helicity_inner_full", False)
        return [_spin_int(2 * p.J + 1) if full else len(p.spins) for p in self.outs]

    def list_helicity_inner(self):
        """Inner helicity values per out: all of ``-J..J`` if ``helicity_inner_full``."""
        full = getattr(self, "helicity_inner_full", False)
        return [tuple(_spin_range(-p.J, p.J)) if full else p.spins for p in self.outs]
@regist_decay("default")
@regist_decay("gls-bf")
class HelicityDecay(AmpDecay):
    r"""default decay model

    The total amplitude is

    .. math::
        A = H_{\lambda_{B},\lambda_{C}}^{A \rightarrow B+C} D^{J_A*}_{\lambda_{A}, \lambda_{B}-\lambda_{C}} (\varphi,\theta,0)

    The helicity coupling is

    .. math::
        H_{\lambda_{B},\lambda_{C}}^{A \rightarrow B+C} =
        \sum_{ls} g_{ls} \sqrt{\frac{2l+1}{2 J_{A}+1}} \langle l 0; s \delta|J_{A} \delta\rangle \langle J_{B} \lambda_{B} ;J_{C} -\lambda_{C} | s \delta \rangle q^{l} B_{l}'(q, q_0, d)

    The fit parameters are :math:`g_{ls}`.

    There are some options

    (1). `has_bprime=False` will remove the :math:`B_{l}'(q, q_0, d)` part.

    (2). `has_barrier_factor=False` will remove the :math:`q^{l} B_{l}'(q, q_0, d)` part.

    (3). `barrier_factor_norm=True` will replace :math:`q^l` with :math:`(q/q_{0})^l`
    (applied together with the :math:`B_{l}'` part).

    (4). `below_threshold=True` will replace the core mass used to calculate
    :math:`q` and :math:`q_0`, where it is below :math:`m_1 + m_2`, with

    .. math::
        m_0^{eff} = m^{min} + \frac{m^{max} - m^{min}}{2}(1+tanh \frac{m_0 - \frac{m^{max} + m^{min}}{2}}{m^{max} - m^{min}})

    (5). `l_list=[l1, l2]` and `ls_list=[[l1, s1], [l2, s2]]` options give the list of all possible LS used in the decay.

    (6). `no_q0=True` will set the :math:`q_0=1`.

    (7). `has_ql=False` removes the :math:`q^{l}` factor and
    `barrier_factor_mass=True` divides by :math:`m^{l}` of the core mass.
    """

    def __init__(
        self,
        *args,
        has_barrier_factor=True,
        l_list=None,
        barrier_factor_mass=False,
        has_ql=True,
        has_bprime=True,
        aligned=False,
        allow_cc=True,
        ls_list=None,
        barrier_factor_norm=False,
        params_polar=None,
        below_threshold=False,
        force_min_l=False,
        params_head=None,
        no_q0=False,
        helicity_inner_full=False,
        ls_selector=None,
        **kwargs
    ):
        super(HelicityDecay, self).__init__(*args, **kwargs)
        self.has_barrier_factor = has_barrier_factor
        self.l_list = l_list
        self.barrier_factor_mass = barrier_factor_mass
        self.has_ql = has_ql
        self.has_bprime = has_bprime
        self.aligned = aligned
        self.allow_cc = allow_cc
        self.force_min_l = force_min_l
        self.single_gls = False
        self.ls_index = None  # positions of ls_list inside total_ls (set_ls)
        self.total_ls = None  # ls list before any set_ls restriction
        self.barrier_factor_norm = barrier_factor_norm
        self.below_threshold = below_threshold
        self.ls_list = None
        if ls_list is not None:
            self.ls_list = tuple([tuple(i) for i in ls_list])
        self.params_polar = params_polar
        self.mask_factor = False
        self.params_head = params_head
        self.no_q0 = no_q0
        self.helicity_inner_full = helicity_inner_full
        self.ls_selector = ls_selector

    def get_params_head(self):
        if self.params_head is None:
            core = self.core.get_params_head()
            outs = [i.get_params_head() for i in self.outs]
            self.params_head = "{}->{}".format(core, "+".join(outs))
        return self.params_head

    def check_valid_jp(self):
        """Raise ValueError when no (l, s) combination is allowed."""
        if len(self.get_ls_list()) == 0:
            if not self.p_break:
                raise ValueError(
                    """invalid spin parity for {}, maybe you should set `p_break: True` for weak decay""".format(
                        self
                    )
                )
            raise ValueError("invalid spin parity for {}".format(self))

    def set_ls(self, ls):
        """Restrict the decay to the (l, s) pairs in ``ls``.

        ``total_ls`` keeps the list from before the first restriction;
        ``ls_index`` holds the positions of ``ls`` in it, or None when ``ls``
        has as many entries as ``total_ls``.
        """
        if self.total_ls is None:
            self.total_ls = self.get_ls_list()
        self.ls_list = tuple([tuple(i) for i in ls])
        self.single_gls = len(ls) == 1
        total_ls = self.total_ls
        if len(total_ls) == len(ls):
            self.ls_index = None
            return
        self.ls_index = [total_ls.index(i) for i in self.ls_list]

    def init_params(self):
        """Create the complex couplings ``g_ls``; entry 0 fixed to (1.0, 0.0)."""
        self.d = 3.0
        ls = self.get_ls_list()
        self.g_ls = self.add_var(
            "g_ls", is_complex=True, polar=self.params_polar, shape=(len(ls),)
        )
        try:
            self.g_ls.set_fix_idx(fix_idx=0, fix_vals=(1.0, 0.0))
        except Exception as e:
            # best effort (presumably fails when ls is empty); report only
            print(e, self, self.get_ls_list())

    def get_factor_variable(self):
        return [(self.g_ls,)]

    def get_factor(self):
        return self.get_g_ls()

    def mask_factor_vars(self):
        return self.g_ls.factor_names()

    def factor_iter_names(self, deep=1, extra=()):
        """Yield dicts setting all but one coupling name to 0.0.

        With ``extra`` decays, each dict is merged with every dict of
        ``extra[0].factor_iter_names(deep, extra=extra[1:])``.
        """
        if deep == 0:
            # NOTE(review): there is no return here, so the masks below are
            # still yielded after the empty dict — confirm this is intended.
            yield {}
        if len(extra) > 0:
            for j in extra[0].factor_iter_names(deep, extra=extra[1:]):
                all_var = self.mask_factor_vars()
                for i in all_var:
                    a = {k: 0.0 for k in all_var if k != i}
                    yield {**a, **j}
        else:
            all_var = self.mask_factor_vars()
            for i in self.mask_factor_vars():
                yield {k: 0.0 for k in all_var if k != i}

    def _get_particle_mass(self, p, data, from_data=False):
        """Mass of ``p``: per-event ``data[p]["m"]`` if ``from_data``, else p's mass.

        A particle without mass gets the mean of ``data[p]["m"]`` assigned.
        """
        if from_data and p in data:
            return data[p]["m"]
        if p.mass is None:
            p.mass = tf.reduce_mean(data[p]["m"])
            warnings.warn(
                "no mass for particle {}, set it to {}".format(p, p.mass)
            )
        return p.get_mass()

    def _get_decay_masses(self, data, from_data=False):
        """Return ``(m0, m1, m2)`` of core and the two outs.

        With ``below_threshold`` a core mass below ``m1 + m2`` is replaced by
        the smooth :func:`_ad_hoc` effective mass.
        """

        def _get_mass(p):
            return self._get_particle_mass(p, data, from_data)

        m0 = _get_mass(self.core)
        m1 = _get_mass(self.outs[0])
        m2 = _get_mass(self.outs[1])
        if self.below_threshold:
            creator = self.core.creators[0]
            m3 = _get_mass([i for i in creator.outs if i != self.core][0])
            m_eff = _ad_hoc(m0, _get_mass(creator.core) - m3, m1 + m2)
            m0 = tf.where(m0 < m1 + m2, m_eff, m0)
        return m0, m1, m2

    def get_relative_momentum(self, data, from_data=False):
        """Breakup momentum ``q`` (see :func:`get_relative_p`)."""
        return get_relative_p(*self._get_decay_masses(data, from_data))

    def get_relative_momentum2(self, data, from_data=False):
        """Squared breakup momentum ``q^2`` (see :func:`get_relative_p2`)."""
        return get_relative_p2(*self._get_decay_masses(data, from_data))

    def get_cg_matrix(self, out_sym=False):
        ls = self.get_ls_list()
        return self._get_cg_matrix(
            ls, out_sym=out_sym, helicity_inner_full=self.helicity_inner_full
        )

    # NOTE: lru_cache on a method keys on ``self`` and keeps instances alive.
    @functools.lru_cache()
    def _get_cg_matrix(
        self, ls, out_sym=False, helicity_inner_full=False
    ):  # CG factor inside H
        """
        [(l,s),(lambda_b,lambda_c)]

        .. math::
            \sqrt{\frac{ 2 l + 1 }{ 2 j_a + 1 }}
            \langle j_b, j_c, \lambda_b, - \lambda_c | s, \lambda_b - \lambda_c \rangle
            \langle l, s, 0, \lambda_b - \lambda_c | j_a, \lambda_b - \lambda_c \rangle

        With ``out_sym`` the entries are exact sympy values in nested lists.
        """
        m = len(ls)
        ja = self.core.J
        jb = self.outs[0].J
        jc = self.outs[1].J
        n = self.n_helicity_inner()  # require helicity_inner_full
        ret = np.zeros(shape=(m, *n))
        sqrt = np.sqrt
        my_cg_coef = cg_coef
        if out_sym:
            from sympy.physics.quantum.cg import CG

            def sint(x):
                # exact (half-)integer from a float spin value
                return sym.simplify(sym.sign(x) * _spin_int(abs(x) * 2)) / 2

            def sqrt(x):
                return sym.sqrt(sint(x))

            def my_cg_coef(j1, j2, m1, m2, j3, m3):
                return CG(*list(map(sint, [j1, m1, j2, m2, j3, m3]))).doit()

            # sympy values cannot be stored in a float ndarray
            ret = ret.tolist()
        hel_inner = self.list_helicity_inner()
        for i, (l, s) in enumerate(ls):
            for i1, lambda_b in enumerate(hel_inner[0]):
                for i2, lambda_c in enumerate(hel_inner[1]):
                    delta = lambda_b - lambda_c
                    ret[i][i1][i2] = (
                        sqrt(2 * l + 1)
                        / sqrt(2 * ja + 1)
                        * my_cg_coef(jb, jc, lambda_b, -lambda_c, s, delta)
                        * my_cg_coef(l, s, 0, delta, ja, delta)
                    )
        return ret

    def build_ls2hel_eq(self):
        """Sympy equations ``H_{lb}_{lc} = sum_ls cg * g_{l}_{s}`` (zero rows skipped)."""
        cg_matrix = self.get_cg_matrix(out_sym=True)
        gls = []
        for l, s in self.get_ls_list():
            gls.append(sym.Symbol("g_{}_{}".format(l, s)))
        hel = []
        eqs = []
        for ib, lb in enumerate(self.outs[0].spins):
            for ic, lc in enumerate(self.outs[1].spins):
                tmp = sym.Symbol("H_{}_{}".format(lb, lc))
                rhs = 0
                for idx, gi in enumerate(gls):
                    rhs = rhs + cg_matrix[idx][ib][ic] * gi
                if rhs == 0:
                    continue
                eq = sym.Eq(tmp, sym.simplify(rhs))
                hel.append(tmp)
                eqs.append(eq)
        return [gls, hel, eqs]

    def build_simple_data(self):
        """Minimal ``data``/``data_p`` dicts: nominal masses, zero angles."""
        data_p = {self.core: {"m": self.core.get_mass()}}
        data = {}
        zero = np.array(0.0)
        for i in self.outs:
            data_p[i] = {"m": i.get_mass()}
            data[i] = {"ang": {"alpha": zero, "beta": zero, "gamma": zero}}
        return {"data": data, "data_p": data_p}

    def _ls_to_helicity(self, m_dep):
        """Contract ls amplitudes ``m_dep`` with the CG matrix over the ls axis."""
        cg_trans = tf.cast(self.get_cg_matrix(), m_dep.dtype)
        n_ls = len(self.get_ls_list())
        m_dep = tf.reshape(m_dep, (-1, n_ls, 1, 1))
        cg_trans = tf.reshape(cg_trans, (n_ls, *self.n_helicity_inner()))
        return tf.reduce_sum(m_dep * cg_trans, axis=1)

    def _apply_charge_conjugation(self, H, kwargs, n_extra_axes=2):
        """Reverse both helicity axes of ``H`` where charge_conjugation <= 0.

        ``n_extra_axes`` is the number of axes of ``H`` after the event axis
        (2 for ``(n, h_b, h_c)``, 3 for ``(n, n_ls, h_b, h_c)``).
        """
        if not self.allow_cc:
            return H
        all_data = kwargs.get("all_data", {})
        charge = all_data.get("charge_conjugation", None)
        if charge is None:
            return H
        cond = tf.reshape(charge, (-1,) + (1,) * n_extra_axes) > 0
        return tf.where(cond, H, H[..., ::-1, ::-1])

    def get_helicity_amp(self, data, data_p, **kwargs):
        """Helicity couplings ``H``, reshaped to ``(n, 1, h_b, h_c)``."""
        m_dep = self.get_ls_amp(data, data_p, **kwargs)
        H = self._ls_to_helicity(m_dep)
        H = self._apply_charge_conjugation(H, kwargs)
        return tf.reshape(H, (-1, 1, *self.n_helicity_inner()))

    def get_angle_helicity_amp(self, data, data_p, **kwargs):
        """As :meth:`get_helicity_amp` with all g_ls and barrier factors = 1."""
        m_dep = self.get_angle_ls_amp(data, data_p, **kwargs)
        H = self._ls_to_helicity(m_dep)
        H = self._apply_charge_conjugation(H, kwargs)
        return tf.reshape(H, (-1, 1, *self.n_helicity_inner()))

    def get_factor_H(self, data, data_p, **kwargs):  # -> (n, n_ls, h1, h2)
        """CG terms per ls, not summed: shape ``(n, n_ls, h_b, h_c)``."""
        m_dep = self.get_angle_ls_amp(data, data_p, **kwargs)  # (n,l)
        cg_trans = tf.cast(self.get_cg_matrix(), m_dep.dtype)
        n_ls = len(self.get_ls_list())
        m_dep = tf.reshape(m_dep, (-1, n_ls, 1, 1))
        cg_trans = tf.reshape(cg_trans, (n_ls, *self.n_helicity_inner()))
        return m_dep * cg_trans  # (n, n_ls, h1, h2)

    def get_factor_angle_helicity_amp(self, data, data_p, **kwargs):
        """Per-ls helicity factors, reshaped to ``(n, n_ls, 1, h_b, h_c)``."""
        H = self.get_factor_H(data, data_p, **kwargs)
        # H has an extra ls axis, so the charge needs three broadcast axes
        H = self._apply_charge_conjugation(H, kwargs, n_extra_axes=3)
        return tf.reshape(
            H,
            (
                -1,
                H.shape[-3],
                1,
                *self.n_helicity_inner(),
            ),
        )

    def get_g_ls(self):
        """Stacked couplings (restricted by ``ls_index``); ones if masked."""
        gls = self.g_ls()
        if self.ls_index is None:
            ret = tf.stack(gls)
        else:
            ret = tf.stack([gls[k] for k in self.ls_index])
        if self.mask_factor:
            return tf.ones_like(ret)
        return ret

    def get_ls_amp_org(self, data, data_p, **kwargs):
        """ls amplitudes using ``|q|`` (older variant of :meth:`get_ls_amp`).

        Stores ``data["|q0|"]`` and, if missing, ``data["|q|"]``.
        """
        g_ls = self.get_g_ls()
        q0 = self.get_relative_momentum(data_p, False)
        data["|q0|"] = q0
        if "|q|" in data:
            q = data["|q|"]
        else:
            q = self.get_relative_momentum(data_p, True)
            data["|q|"] = q
        if self.has_barrier_factor:
            bf = self.get_barrier_factor(data_p[self.core]["m"], q, q0, self.d)
            mag = g_ls
            m_dep = mag * tf.cast(bf, mag.dtype)
        else:
            m_dep = tf.reshape(g_ls, (1, -1))
        return m_dep

    def get_ls_amp(self, data, data_p, **kwargs):
        """``g_ls`` times barrier factors computed from squared momenta.

        Stores ``data["|q0|2"]`` and, if missing, ``data["|q|2"]``.
        """
        g_ls = self.get_g_ls()
        q0 = self.get_relative_momentum2(data_p, False)
        data["|q0|2"] = q0
        if "|q|2" in data:
            q = data["|q|2"]
        else:
            q = self.get_relative_momentum2(data_p, True)
            data["|q|2"] = q
        if self.has_barrier_factor:
            bf = self.get_barrier_factor2(
                data_p[self.core]["m"], q, q0, self.d
            )
            mag = g_ls
            bf = to_complex(bf)
            m_dep = mag * tf.cast(bf, mag.dtype)
        else:
            m_dep = tf.reshape(g_ls, (1, -1))
        return m_dep

    def get_angle_g_ls(self):
        """Couplings replaced by ones (same selection as :meth:`get_g_ls`)."""
        gls = tf.ones_like(self.g_ls())
        if self.ls_index is None:
            return gls
        return tf.stack([gls[k] for k in self.ls_index])

    def get_angle_ls_amp(self, data, data_p, **kwargs):
        return self.get_angle_g_ls()

    def get_barrier_factor(self, mass, q, q0, d):
        """``q^l B_l'(q, q0, d)`` (or ``q^l``) for each l, shape ``(n, n_l)``."""
        ls = self.get_l_list()
        ret = []
        for l in ls:
            if self.force_min_l:
                l = min(ls)
            if self.has_bprime:
                tmp = q**l * tf.cast(Bprime(l, q, q0, d), dtype=q.dtype)
            else:
                tmp = q**l
            ret.append(tf.reshape(tmp, (-1, 1)))
        ret = tf.concat(ret, axis=-1)
        mass_dep = self.get_barrier_factor_mass(mass)
        return ret * mass_dep

    def get_barrier_factor2(self, mass, q2, q02, d):
        """Barrier factors from squared momenta ``q2``/``q02``, shape ``(n, n_l)``."""
        ls = self.get_l_list()
        if self.no_q0:
            q02 = tf.ones_like(q02)
        ret = []
        for l in ls:
            if self.force_min_l:
                l = min(ls)
            if self.has_bprime:
                bp = Bprime_q2(l, q2, q02, d)
                if self.has_ql:
                    tmp = q2 ** (l / 2) * tf.cast(bp, dtype=q2.dtype)
                else:
                    tmp = tf.ones_like(q2) * tf.cast(bp, dtype=q2.dtype)
                if self.barrier_factor_norm:
                    tmp = tmp / tf.cast(tf.abs(q02), tmp.dtype) ** (l / 2)
            else:
                if self.has_ql:
                    tmp = q2 ** (l / 2)
                else:
                    tmp = tf.ones_like(q2)
            ret.append(tf.reshape(tmp, (-1, 1)))
        ret = tf.concat(ret, axis=-1)
        mass_dep = self.get_barrier_factor_mass(mass)
        return ret * mass_dep

    def get_barrier_factor_mass(self, mass):
        """``1 / mass^l`` per l when ``barrier_factor_mass``, else 1.0."""
        if not self.barrier_factor_mass:
            return 1.0
        ls = tf.convert_to_tensor(self.get_l_list(), dtype=mass.dtype)
        return 1.0 / tf.pow(tf.expand_dims(mass, -1), ls)

    def add_algin(self, ret, data):
        """Rotate out helicities by ``data[out]["aligned_angle"]`` if ``aligned``.

        Returns a new tensor; the caller must use the return value.
        """
        a = self.core
        b = self.outs[0]
        if not self.aligned:
            return ret
        n_inner = self.n_helicity_inner()
        for j, particle in enumerate(self.outs):
            if particle.J == 0:
                continue
            ang = data[particle].get("aligned_angle", None)
            if ang is None and not getattr(self, "helicity_inner_full", False):
                continue
            dt = get_D_matrix_lambda(
                ang,
                particle.J,
                self.list_helicity_inner()[j],
                particle.spins,
            )
            dt_shape = [-1, 1, 1, 1, 1]
            dt_shape[j + 2] = len(self.list_helicity_inner()[j])
            dt_shape[j + 3] = len(particle.spins)
            dt = tf.reshape(dt, dt_shape)
            # Axis j+2 of ret still holds inner helicities of out j; use the
            # inner count (differs from len(spins) with helicity_inner_full).
            if j >= 1:
                D_shape = [-1, len(a.spins), len(b.spins), n_inner[1], 1]
            else:
                D_shape = [-1, len(a.spins), n_inner[0], 1, n_inner[-1]]
            ret = tf.reshape(ret, D_shape)
            ret = dt * ret
            ret = tf.reduce_sum(ret, axis=j + 2)
        return ret

    def _helicity_times_d(self, H, data):
        """Multiply couplings ``H`` with the D-matrix of ``data[outs[0]]["ang"]``."""
        a = self.core
        ang = data[self.outs[0]]["ang"]
        D_conj = get_D_matrix_lambda(
            ang, a.J, a.spins, *self.list_helicity_inner()
        )
        H = tf.reshape(H, (-1, 1, *self.n_helicity_inner()))
        H = tf.cast(H, dtype=D_conj.dtype)
        return H * tf.stop_gradient(D_conj)

    def get_amp(self, data, data_p, **kwargs):
        """Amplitude ``H * D`` with alignment applied when ``aligned``."""
        ang = data[self.outs[0]]["ang"]  # fail early, as before, on missing angles
        del ang
        H = self.get_helicity_amp(data, data_p, **kwargs)
        ret = self._helicity_times_d(H, data)
        # the result of add_algin used to be discarded
        return self.add_algin(ret, data)

    def get_angle_amp(self, data, data_p, **kwargs):
        """As :meth:`get_amp` with unit couplings and barrier factors."""
        ang = data[self.outs[0]]["ang"]
        del ang
        H = self.get_angle_helicity_amp(data, data_p, **kwargs)
        ret = self._helicity_times_d(H, data)
        return self.add_algin(ret, data)

    def get_factor_angle_amp(self, data, data_p, **kwargs):
        """Per-ls angular amplitude; alignment is not supported.

        :raises NotImplementedError: when ``aligned`` is set.
        """
        a = self.core
        ang = data[self.outs[0]]["ang"]
        D_conj = get_D_matrix_lambda(
            ang, a.J, a.spins, *self.list_helicity_inner()
        )
        H = self.get_factor_angle_helicity_amp(data, data_p, **kwargs)
        H = tf.cast(H, dtype=D_conj.dtype)
        D_conj = tf.reshape(D_conj, (-1, 1, *D_conj.shape[1:]))
        ret = H * tf.stop_gradient(D_conj)
        if self.aligned:
            # was ``raise NotImplemented`` (which raises TypeError)
            raise NotImplementedError(
                "get_factor_angle_amp does not support aligned=True"
            )
        return ret

    def get_m_dep(self, data, data_p, **kwargs):
        return self.get_ls_amp(data, data_p, **kwargs)

    def get_total_ls_list(self):
        """ls list from before any :meth:`set_ls` restriction."""
        if self.total_ls is None:
            self.total_ls = self.get_ls_list()
        return self.total_ls

    def get_factor_m_dep(self, data, data_p, **kwargs):
        return self.get_ls_amp(data, data_p, **kwargs)

    def get_ls_list(self):
        """get possible ls for decay, with l_list filter possible l

        The result is cached in ``self.ls_list``. ``ls_selector`` ("weight"
        or "qr") prunes the list before the ``l_list`` filter is applied.
        """
        if self.ls_list is not None:
            return self.ls_list
        ls_list = super(HelicityDecay, self).get_ls_list()
        if self.ls_selector == "weight":
            print("using ls_selector", self.ls_selector, "for", self)
            from tf_pwa.cov_ten_ir import ls_selector_weight

            ls_list = tuple(ls_selector_weight(self, ls_list))
        if self.ls_selector == "qr":
            print("using ls_selector", self.ls_selector, "for", self)
            ls_list = tuple(ls_selector_qr(self, ls_list))
        self.ls_list = ls_list
        if self.l_list is None:
            return self.ls_list
        self.ls_list = tuple((l, s) for l, s in self.ls_list if l in self.l_list)
        return self.ls_list
def ls_selector_qr(decay, ls_list):
    """Pick a linearly independent subset of ``ls_list`` with a QR decomposition.

    The rows of the matrix are the helicity pairs ``(h1, h2)`` with
    ``|h1 - h2| <= J``; without ``p_break``, a pair is skipped when
    ``(-h1, -h2)`` is already present. The columns are the ``(l, s)``
    entries, and each element is the product of the two CG coefficients.
    For every row of R, the ``(l, s)`` of its first nonzero column is
    returned.
    """
    from sympy import Matrix
    from sympy.physics.quantum.cg import CG

    core = decay.core
    out1 = decay.outs[0]
    out2 = decay.outs[1]
    hel_list = []
    for h1 in out1.spins:
        for h2 in out2.spins:
            if abs(h1 - h2) > core.J:
                continue
            if not decay.p_break and (-h1, -h2) in hel_list:
                continue
            hel_list.append((h1, h2))
    rows = []
    for h1, h2 in hel_list:
        delta = h1 - h2
        rows.append(
            [
                (
                    CG(l, 0, s, delta, core.J, delta)
                    * CG(out1.J, h1, out2.J, -h2, s, delta)
                ).doit()
                for l, s in ls_list
            ]
        )
    _, r = Matrix(rows).QRdecomposition()
    pivots = []
    for row in range(r.rows):
        col = 0
        while col < r.cols and r[row, col] == 0:
            col += 1
        pivots.append(col)
    return [ls_list[i] for i in pivots]
@regist_decay("default", 3)
@regist_decay("AngSam3", 3)
class AngSam3Decay(AmpDecay, AmpBase):
    """Three-body decay ``A -> B + C + D``.

    The couplings ``G_mu`` (one per mu in ``-J_A..J_A``) are contracted
    with the D-matrix of ``data["ang"]``. The result is tiled over all
    final helicities, so it does not depend on them.
    """

    def init_params(self):
        n_mu = _spin_int(2 * self.core.J + 1)
        self.gi = self.add_var("G_mu", is_complex=True, shape=(n_mu,))
        try:
            self.gi.set_fix_idx(fix_idx=0, fix_vals=(1.0, 0.0))
        except Exception as err:
            print(err)

    def get_amp(self, data, data_extra=None, **kwargs):
        core = self.core
        b, c, d = self.outs[0], self.outs[1], self.outs[2]
        couplings = tf.stack(self.gi())
        D_conj = get_D_matrix_lambda(
            data["ang"], core.J, core.spins, tuple(_spin_range(-core.J, core.J))
        )
        amp = tf.reduce_sum(tf.cast(couplings, D_conj.dtype) * D_conj, axis=-1)
        amp = tf.reshape(amp, (-1, len(core.spins), 1, 1, 1))
        return tf.tile(amp, [1, 1, len(b.spins), len(c.spins), len(d.spins)])
[docs]
class AmpDecayChain(BaseDecayChain, AmpBase):
    """Decay chain carrying amplitude options.

    ``is_cp`` and ``aligned`` are stored before the base class is
    initialised, so they are already present during base-class setup.
    """

    def __init__(self, *args, is_cp=False, aligned=True, **kwargs):
        self.is_cp = is_cp
        self.aligned = aligned
        super(AmpDecayChain, self).__init__(*args, **kwargs)
        self.need_amp_particle = True
        self.mask_factor = False

    def get_params_head(self):
        """Return ``"[head1, head2, ...]"`` built from each decay; cached."""
        cached = getattr(self, "params_head", None)
        if cached is None:
            heads = [decay.get_params_head() for decay in self]
            cached = "[{}]".format(", ".join(heads))
            self.params_head = cached
        return cached
[docs]
@register_decay_chain("default")
class DecayChain(AmpDecayChain):
    """A list of Decay as a chain decay"""

    def init_params(self, name=""):
        """Create the complex overall coupling ``<name>total`` of this chain.

        The variable is created CP-aware when ``self.is_cp`` is set.
        """
        self.total = self.add_var(
            name + "total", is_complex=True, is_cp=self.is_cp, shape=[1]
        )

    def get_factor_variable(self):
        """Return ``[(total, *decay_factors, *particle_factors)]``.

        Only truthy results of ``get_factor_variable()`` of the decays and of
        the inner particles are included.
        """
        factors = []
        for decay in self:
            tmp = decay.get_factor_variable()
            if tmp:
                factors.append(tmp)
        for particle in self.inner:
            tmp = particle.get_factor_variable()
            if tmp:
                factors.append(tmp)
        return [tuple([self.total] + factors)]

    def get_factor(self):  # (total, decay1, decay2, ...)
        """Combine the chain total with every non-``None`` decay/particle factor.

        Each factor is multiplied onto a new trailing axis. For 1-D factors
        (presumably the usual case; not checked here) this is an outer
        product.
        """
        decay_factor = [i.get_factor() for i in self]
        particle_factor = [i.get_factor() for i in self.inner]
        all_factor = [
            i for i in decay_factor + particle_factor if i is not None
        ]
        ret = self.get_amp_total()
        for i in all_factor:
            ret = tf.expand_dims(ret, axis=-1) * tf.cast(i, ret.dtype)
        return ret

    def factor_iteration(self, deep=1):
        """Yield factor-name masks, applying each with ``vm.mask_params``."""
        if deep == 0:
            yield {}
        else:
            all_decay = list(self)
            for j in all_decay[0].factor_iter_names(
                deep=deep, extra=all_decay[1:]
            ):
                with self.total.vm.mask_params(j):
                    yield j

    def get_amp_total(self, charge=1):
        """Return the stacked total coupling (all ones if ``mask_factor``)."""
        if self.mask_factor:
            return tf.ones_like(tf.stack(self.total(charge)))
        return tf.stack(self.total(charge))

    def product_gls(self):
        """Return the product of the total coupling and all decay g_ls."""
        ret = self.get_all_factor()
        return tf.reduce_prod(ret)

    def get_all_factor(self):
        """Return ``[total, g_ls of decay 1, g_ls of decay 2, ...]``."""
        ret = [self.get_amp_total()]
        for i in self:
            ret.append(i.get_g_ls())
        return ret

    def get_cp_amp_total(self, charge=1):
        """Return the total coupling for the given charge.

        For non-CP chains this is ``get_amp_total()``. Otherwise it is
        element-wise ``total(+1)`` where ``charge > 0`` and ``total(-1)``
        elsewhere.
        """
        if not self.is_cp:
            return self.get_amp_total()
        total_pos = self.get_amp_total(1)
        total_neg = self.get_amp_total(-1)
        return tf.where(charge > 0, total_pos, total_neg)

    @staticmethod
    def _charge_conjugation(all_data):
        """Read ``charge_conjugation`` (default 1), tolerating ``all_data=None``."""
        if all_data is None:
            return 1
        return all_data.get("charge_conjugation", 1)

    @staticmethod
    def _einsum_subscripts(indices, final_indices):
        """Build an ``"...ab,...bc->...ac"`` style einsum spec from index lists."""
        inputs = ",".join("".join(["..."] + list(i)) for i in indices)
        return "{}->{}".format(inputs, final_indices)

    def _add_aligned_d_matrices(
        self, data_c, base_map, amp_d, indices, final_indices
    ):
        """Append the alignment D-matrices of this chain to an einsum operand list.

        This only applies when ``self.aligned`` is set. For every decay ``i``
        and every outgoing particle ``j`` of it with ``J != 0``, a D-matrix of
        ``data_c[i][j]["aligned_angle"]`` (gradient stopped) is appended to
        ``amp_d``. The index pair ``[base_map[j], base_map[j].upper()]`` is
        appended to ``indices``. Particles without an aligned angle are
        skipped unless the decay sets ``helicity_inner_full``.

        ``amp_d`` and ``indices`` are modified in place. The returned output
        index string has each aligned particle's lower-case label replaced by
        its upper-case label.
        """
        if not self.aligned:
            return final_indices
        for i in self:
            for idxj, j in enumerate(i.outs):
                if j.J == 0:
                    continue
                ang = data_c[i][j].get("aligned_angle", None)
                if ang is None and not getattr(
                    i, "helicity_inner_full", False
                ):
                    continue
                dt = get_D_matrix_lambda(
                    ang, j.J, i.list_helicity_inner()[idxj], j.spins
                )
                amp_d.append(tf.stop_gradient(dt))
                idx = [base_map[j], base_map[j].upper()]
                indices.append(idx)
                final_indices = final_indices.replace(*idx)
        return final_indices

    def get_amp(self, data_c, data_p, all_data=None, base_map=None):
        """Full amplitude of this chain with helicity indices ``[top, *outs]``.

        The decay amplitudes, the (CP-aware) total coupling times the
        inner-particle propagators, and, when aligned, the alignment
        D-matrices are contracted with einsum.
        """
        base_map = self.get_base_map(base_map)
        amp_d = []
        indices = []
        final_indices = "".join(["..."] + self.amp_index(base_map))
        for i in self:
            indices.append(i.amp_index(base_map))
            amp_d.append(i.get_amp(data_c[i], data_p, all_data=all_data))
        if self.need_amp_particle:
            rs = self.get_amp_particle(data_p, data_c, all_data=all_data)
            total = self.get_cp_amp_total(
                charge=self._charge_conjugation(all_data)
            )
            if rs is not None:
                total = total * tf.cast(rs, total.dtype)
            amp_d.append(total)
            indices.append([])
        final_indices = self._add_aligned_d_matrices(
            data_c, base_map, amp_d, indices, final_indices
        )
        idx_s = self._einsum_subscripts(indices, final_indices)
        try:
            ret = einsum(idx_s, *amp_d)
        except Exception:
            # fall back to TensorFlow's einsum if the custom one cannot handle it
            ret = tf.einsum(idx_s, *amp_d)
        return ret

    def get_angle_amp(self, data_c, data_p, all_data=None, base_map=None):
        """Angular part of the amplitude (no couplings/propagators), via einsum."""
        base_map = self.get_base_map(base_map)
        amp_d = []
        indices = []
        final_indices = "".join(["..."] + self.amp_index(base_map))
        for i in self:
            indices.append(i.amp_index(base_map))
            amp_d.append(i.get_angle_amp(data_c[i], data_p, all_data=all_data))
        final_indices = self._add_aligned_d_matrices(
            data_c, base_map, amp_d, indices, final_indices
        )
        idx_s = self._einsum_subscripts(indices, final_indices)
        return tf.einsum(idx_s, *amp_d)

    def get_factor_angle_amp(
        self, data_c, data_p, all_data=None, base_map=None
    ):
        """Angular amplitude that keeps one extra factor axis per decay.

        Decay ``n`` gets the factor label ``"zyxwvutsr"[n]``; those axes
        appear in the output, right after the leading ``...``.
        """
        base_map = self.get_base_map(base_map)
        amp_d = []
        indices = []
        factor_labels = "zyxwvutsr"
        used_idx = []
        chain_idx = self.amp_index(base_map)
        for n, decay in enumerate(self):
            tmp_idx = decay.amp_index(base_map)
            indices.append([factor_labels[n], *tmp_idx])
            used_idx.append(factor_labels[n])
            amp_d.append(
                decay.get_factor_angle_amp(
                    data_c[decay], data_p, all_data=all_data
                )
            )
        final_indices = "".join(["..."] + used_idx + chain_idx)
        final_indices = self._add_aligned_d_matrices(
            data_c, base_map, amp_d, indices, final_indices
        )
        idx_s = self._einsum_subscripts(indices, final_indices)
        return tf.einsum(idx_s, *amp_d)

    def get_m_dep(self, data_c, data_p, all_data=None, base_map=None):
        """Return the list of mass-dependent factors of this chain.

        The list holds one entry per decay. When ``need_amp_particle`` is
        set, it also holds the total coupling times the inner-particle
        propagators. ``base_map`` is accepted for interface compatibility and
        is not needed here.
        """
        amp_d = []
        for i in self:
            amp_d.append(i.get_m_dep(data_c[i], data_p, all_data=all_data))
        if self.need_amp_particle:
            rs = self.get_amp_particle(data_p, data_c, all_data=all_data)
            total = self.get_cp_amp_total(self._charge_conjugation(all_data))
            if rs is not None:
                total = total * tf.cast(rs, total.dtype)
            amp_d.append(total)
        return amp_d

    def get_amp_particle(self, data_p, data_c, all_data=None):
        """Product of the amplitudes of all inner particles (1.0 if none).

        For a particle that has decays, the decay belonging to this chain is
        located. Its relative momenta (``|q|``, ``|q0|``, ``|q|2``,
        ``|q0|2``) are cached in ``data_c`` and passed to the particle.

        Raises:
            IndexError: none of the particle's decays is part of this chain.
        """
        if not self.inner:
            return 1.0
        amp_p = []
        for i in self.inner:
            if len(i.decay) >= 1:
                decay_i = next((j for j in i.decay if j in self), None)
                if decay_i is None:
                    raise IndexError(
                        "not found {} decay in {}".format(i, self)
                    )
                data_c_i = data_c[decay_i]
                # compute each momentum only once; results are stored in data_c
                for key, getter, flag in (
                    ("|q|", decay_i.get_relative_momentum, True),
                    ("|q0|", decay_i.get_relative_momentum, False),
                    ("|q|2", decay_i.get_relative_momentum2, True),
                    ("|q0|2", decay_i.get_relative_momentum2, False),
                ):
                    if key not in data_c_i:
                        data_c_i[key] = getter(data_p, flag)
                amp_p.append(i.get_amp(data_p[i], data_c_i, all_data=all_data))
            else:
                amp_p.append(i.get_amp(data_p[i], all_data=all_data))
        rs = 1.0
        for amp in amp_p:
            rs = rs * amp
        return rs

    def amp_shape(self):
        """Number of helicity states of the top particle and of each final particle."""
        ret = [len(self.top.spins)]
        for i in self.outs:
            ret.append(len(i.spins))
        return tuple(ret)

    # @simple_cache_fun
    def amp_index(self, base_map=None):
        """Einsum labels of the top particle followed by the final particles."""
        if base_map is None:
            base_map = self.get_base_map()
        ret = [base_map[self.top]]
        for i in self.outs:
            ret.append(base_map[i])
        return ret

    def get_base_map(self, base_map=None):
        """Return a copy of ``base_map`` extended with labels for top, outs and inner particles."""
        gen = index_generator(base_map)
        if base_map is None:
            base_map = {}
        ret = base_map.copy()
        if self.top not in base_map:
            ret[self.top] = next(gen)
        for i in self.outs:
            if i not in base_map:
                ret[i] = next(gen)
        for i in self.inner:
            if i not in ret:
                ret[i] = next(gen)
        return ret
[docs]
class DecayGroup(BaseDecayGroup, AmpBase):
    """A Group of Decay Chains with the same final particles.

    Only chains whose indices are in ``chains_idx`` contribute to the
    amplitude. ``set_used_chains``/``set_used_res`` change that selection,
    and ``not_full`` records whether it is a strict subset of all chains.
    """

    def __init__(self, chains):
        self.chains_idx = list(range(len(chains)))
        first_chain = chains[0]
        if not isinstance(first_chain, DecayChain):
            # plain decay lists are wrapped into amplitude-aware chains
            chains = [DecayChain(i) for i in chains]
        super(DecayGroup, self).__init__(chains)
        self.not_full = False
        self.polarization = getattr(self.top, "polarization", "none")

    def init_params(self, name=""):
        """Create fit variables for resonances, chains and decays.

        A decay shared by several chains is initialised only once. With
        ``polarization == "vector"``, polarization variables are added on
        the top particle: three scalars for J=1/2, one 8-vector for J=1.
        """
        for i in self.resonances:
            i.init_params()
        inited_set = set()
        for i in self:
            i.init_params(name)
            for j in i:
                if j not in inited_set:
                    j.init_params()
                    inited_set.add(j)
        if self.polarization == "vector":
            print("add polarization vector")
            if self.top.J == 0.5:
                self.polarization_vector = [
                    self.top.add_var("polarization_px"),
                    self.top.add_var("polarization_py"),
                    self.top.add_var("polarization_pz"),
                ]
            if self.top.J == 1:
                self.polarization_vector = self.top.add_var(
                    "polarization_p", shape=(8,)
                )

    def get_factor_variable(self):
        """Concatenate ``get_factor_variable()`` of every chain."""
        ret = []
        for i in self:
            ret += i.get_factor_variable()
        return ret

    def get_factor(self):
        """List of ``get_factor()`` of every chain."""
        return [i.get_factor() for i in self]

    def factor_iteration(self, deep=2):
        """Yield ``(chain, mask)`` pairs with only that chain enabled.

        The previous chain selection is restored when iteration finishes,
        raises, or the generator is closed early.
        """
        if deep == 0:
            yield None
            return
        old_chains_idx = self.chains_idx
        try:
            for i in old_chains_idx:
                self.set_used_chains([i])
                for j in self.chains[i].factor_iteration(deep=deep - 1):
                    yield self.chains[i], j
        finally:
            self.set_used_chains(old_chains_idx)

    @staticmethod
    def _find_decay_data(data_decay, decay_chain):
        """Return the ``data_decay`` entry whose key equals the chain's standard topology.

        Raises:
            KeyError: no key compares equal to that topology.
        """
        chain_topo = decay_chain.standard_topology()
        for topo in data_decay.keys():
            if topo == chain_topo:
                return data_decay[topo]
        raise KeyError("not found {}".format(chain_topo))

    @staticmethod
    def _find_chain_map(decay_chain, chain_maps):
        """Return the particle-renaming map stored for ``decay_chain`` in ``chain_maps``.

        Raises:
            KeyError: the chain is not present in any map (previously the
                map of an unrelated chain could be reused silently).
        """
        for chains in chain_maps:
            if str(decay_chain) in [str(i) for i in chains]:
                return chains[decay_chain]
        raise KeyError("not found {}".format(decay_chain))

    def _used_chain_inputs(self, data):
        """Yield ``(decay_chain, data_c, data_p)`` for each used chain, in ``chains_idx`` order."""
        data_particle = data["particle"]
        data_decay = data["decay"]
        used_chains = tuple([self.chains[i] for i in self.chains_idx])
        chain_maps = self.get_chains_map(used_chains)
        for decay_chain in used_chains:
            maps = self._find_chain_map(decay_chain, chain_maps)
            data_decay_i = self._find_decay_data(data_decay, decay_chain)
            data_c = rename_data_dict(data_decay_i, maps)
            data_p = rename_data_dict(data_particle, maps)
            yield decay_chain, data_c, data_p

    def get_amp(self, data):
        """
        calculate the amplitude as complex number

        Returns the sum over all used chains of ``DecayChain.get_amp``.
        """
        data_particle = data["particle"]
        data_decay = data["decay"]
        used_chains = tuple([self.chains[i] for i in self.chains_idx])
        chain_maps = self.get_chains_map(used_chains)
        base_map = self.get_base_map()
        ret = []
        for chains in chain_maps:
            for decay_chain in chains:
                data_decay_i = self._find_decay_data(data_decay, decay_chain)
                data_c = rename_data_dict(data_decay_i, chains[decay_chain])
                data_p = rename_data_dict(data_particle, chains[decay_chain])
                amp = decay_chain.get_amp(
                    data_c, data_p, base_map=base_map, all_data=data
                )
                ret.append(amp)
        return tf.reduce_sum(ret, axis=0)

    def get_m_dep(self, data):
        """get mass dependent items (one list per used chain)"""
        base_map = self.get_base_map()
        return [
            decay_chain.get_m_dep(
                data_c, data_p, base_map=base_map, all_data=data
            )
            for decay_chain, data_c, data_p in self._used_chain_inputs(data)
        ]

    def get_angle_amp(self, data):
        """Angular amplitude of the used chains.

        NOTE(review): as in the original code, only the amplitude of the
        LAST used chain is returned (``None`` if no chain is used). Other
        ``get_*`` methods return a list or a sum; confirm what callers
        expect before changing this.
        """
        base_map = self.get_base_map()
        amp = None
        for decay_chain, data_c, data_p in self._used_chain_inputs(data):
            amp = decay_chain.get_angle_amp(
                data_c, data_p, base_map=base_map, all_data=data
            )
        return amp

    def get_factor_angle_amp(self, data):
        """List of ``DecayChain.get_factor_angle_amp`` for every used chain."""
        base_map = self.get_base_map()
        return [
            decay_chain.get_factor_angle_amp(
                data_c, data_p, base_map=base_map, all_data=data
            )
            for decay_chain, data_c, data_p in self._used_chain_inputs(data)
        ]

    @functools.lru_cache()
    def get_swap_factor(self, key):
        """Sign from exchanging identical fermions described by ``key[1]``.

        Groups with integer spin are ignored. Each distinct pair
        ``(m, n)`` with ``m != n`` of a half-integer-spin group flips the
        sign.
        """
        factor = 1.0
        used = []
        for i, j in zip(self.identical_particles, key[1]):
            p = self.get_particle(i[0])
            if int(p.J * 2) % 2 == 0:
                continue
            for m, n in zip(i, j):
                if (m, n) in used or (n, m) in used:
                    continue
                used.append((m, n))
                if m != n:
                    factor *= -1.0
        return factor

    @functools.lru_cache()
    def get_id_swap_transpose(self, key, n):
        """Axis permutation for the identical-particle exchange in ``key[1]``."""
        _, change = key
        trans = []
        for i, j in zip(self.identical_particles, change):
            for k, l in zip(i, j):
                trans.append((k, l))
        return self.get_swap_transpose(tuple(trans), n)

    @functools.lru_cache()
    def get_swap_transpose(self, trans, n):
        """Permutation of ``n`` axes that swaps the final-particle axes named in ``trans``.

        The leading ``n - len(outs)`` axes stay in place.
        """
        trans = dict(trans)
        # make the exchange symmetric: a -> b implies b -> a
        tmp = {v: k for k, v in trans.items()}
        tmp.update(trans)
        trans = tmp
        old_order = [str(i) for i in self.outs]
        new_order = [trans.get(i, i) for i in old_order]
        index_map = {k: i for i, k in enumerate(new_order)}
        trans_order = [index_map[str(i)] for i in self.outs]
        diff = n - len(trans_order)
        return [i for i in range(diff)] + [i + diff for i in trans_order]

    def get_amp2(self, data):
        """Amplitude plus the (anti)symmetrized identical-particle swaps in ``data["id_swap"]``."""
        amp = self.get_amp(data)
        id_swap = data.get("id_swap", {})
        for k, v in id_swap.items():
            new_data = {**data, **v}
            factor = self.get_swap_factor(k)
            amp_swap = factor * self.get_amp(new_data)
            swap_index = self.get_id_swap_transpose(k, len(amp_swap.shape))
            amp_swap = tf.transpose(amp_swap, swap_index)
            amp = amp + amp_swap
        return amp

    def get_amp3(self, data):
        """``get_amp2`` plus the CP-swapped term when ``data`` has ``"cp_swap"``.

        The swapped amplitude is transposed for exchanged particles. All
        axes after the first are reversed (``[..., ::-1, ::-1, ...]``), and
        the result is multiplied by the product of ``C`` (default -1) of
        particles mapped onto themselves.
        """
        amp = self.get_amp2(data)
        if "cp_swap" in data:
            amp_swap = self.get_amp2(data["cp_swap"])
            cg = cp_charge_group(
                [str(i) for i in self.outs],
                self.identical_particles,
                self.cp_particles,
            )
            name_map = {str(i): i for i in self.outs}
            frac = 1.0
            change = []
            for a, b in cg:
                for i, j in zip(a, b):
                    if i == j:
                        frac = frac * getattr(name_map[i], "C", -1)
                    else:
                        change.append((i, j))
            transpose = self.get_swap_transpose(
                tuple(change), len(amp_swap.shape)
            )
            p_reverse = [Ellipsis] + [
                slice(None, None, -1) for _ in range(len(amp_swap.shape) - 1)
            ]
            amp = (
                amp
                + tf.transpose(amp_swap, transpose)[tuple(p_reverse)] * frac
            )
        return amp

    def sum_amp(self, data, cached=True):
        """
        calculate the amplitude modulus square, summed over all helicities

        With ``cached=False`` the data is deep-copied first. A polarized top
        particle is handled by ``sum_amp_polarization``.
        """
        if not cached:
            data = simple_deepcopy(data)
        if self.polarization != "none":
            return self.sum_amp_polarization(data)
        return self.sum_with_polarization(self.get_amp3(data))

    def sum_with_polarization(self, amp):
        """Sum of ``|amp|^2`` over all non-batch axes, weighted by the density matrix if polarized."""
        if self.polarization != "none":
            # (i, la, lb lc ld ...)
            amp = tf.reshape(amp, (amp.shape[0], len(self.top.spins), -1))
            na, nl = amp.shape[1], amp.shape[2]
            rho = self.get_density_matrix()
            amp = tf.reshape(amp, (-1, na, 1, nl))
            # (i, la, lb lc ld ...)
            amp_c = tf.reshape(tf.math.conj(amp), (-1, na, nl))
            sum_A = (
                tf.reduce_sum(amp * tf.reshape(rho, (na, na, 1)), axis=1)
                * amp_c
            )
            return tf.reduce_sum(tf.math.real(sum_A), axis=[1, 2])
        amp2s = tf.math.real(amp * tf.math.conj(amp))
        idx = list(range(1, len(amp2s.shape)))
        return tf.reduce_sum(amp2s, idx)

    def sum_amp_polarization(self, data):
        """
        sum amplitude square with the density matrix
        .. math::
        P = \\sum_{m, m', \\cdots } A_{m, \\cdots} \\rho_{m, m'} A^{*}_{m', \\cdots}
        """
        amp = self.get_amp3(data)
        return self.sum_with_polarization(amp)

    def get_density_matrix(self):
        """Density matrix of the top particle built from the polarization variables.

        J=1/2: ``0.5 * (I + p . sigma)``. J=1: ``I + sum_k p_k lambda_k``,
        where the eight 3x3 matrices below are the Gell-Mann matrices.

        Raises:
            NotImplementedError: other spins or polarization types.
        """
        if self.polarization == "vector" and self.top.J == 0.5:
            px, py, pz = [i() for i in self.polarization_vector]
            zeros = tf.zeros_like(px)
            ones = tf.ones_like(px)
            rho00 = tf.complex(ones + pz, zeros)
            rho11 = tf.complex(ones - pz, zeros)
            rho01 = tf.complex(px, -py)
            rho10 = tf.complex(px, py)
            return 0.5 * tf.stack([[rho00, rho01], [rho10, rho11]])
        elif self.polarization == "vector" and self.top.J == 1:
            p = tf.stack(self.polarization_vector())
            p = tf.complex(p, tf.zeros_like(p))
            gi = np.array(
                [
                    [
                        [0, 0, 1, 0, 0, 0, 0, 1 / np.sqrt(3)],
                        [1, -1j, 0, 0, 0, 0, 0, 0],
                        [0, 0, 0, 1, -1j, 0, 0, 0],
                    ],
                    [
                        [1, 1j, 0, 0, 0, 0, 0, 0],
                        [0, 0, -1, 0, 0, 0, 0, 1 / np.sqrt(3)],
                        [0, 0, 0, 0, 0, 1, -1j, 0],
                    ],
                    [
                        [0, 0, 0, 1, 1j, 0, 0, 0],
                        [0, 0, 0, 0, 0, 1, 1j, 0],
                        [0, 0, 0, 0, 0, 0, 0, -2 / np.sqrt(3)],
                    ],
                ]
            )
            m = gi * p
            E = np.eye(3) + 0j
            return E + tf.reduce_sum(m, axis=-1)
        raise NotImplementedError

    # @simple_cache_fun
    def amp_index(self, gen=None, base_map=None):
        """Einsum labels of the top particle followed by the final particles."""
        if base_map is None:
            base_map = self.get_base_map()
        ret = [base_map[self.top]]
        for i in self.outs:
            ret.append(base_map[i])
        return ret

    def get_base_map(self, gen=None, base_map=None):
        """Assign einsum labels to the top particle and the final particles."""
        if gen is None:
            gen = index_generator(base_map)
        if base_map is None:
            base_map = {self.top: next(gen)}
        for i in self.outs:
            base_map[i] = next(gen)
        return base_map

    def get_res_map(self):
        """Map each inner particle to the indices of the chains containing it."""
        res_map = {}
        for i, decay in enumerate(self.chains):
            for j in decay.inner:
                res_map.setdefault(j, []).append(i)
        return res_map

    def set_used_res(self, res, only=False):
        """Select chains by resonance (name/particle) or by chain index.

        With ``only=False`` the chains containing any given resonance are
        used. With ``only=True`` every chain containing a resonance NOT in
        ``res`` is dropped. Integer entries are always added as chain
        indices.
        """
        if not isinstance(res, (list, tuple)):
            res = [res]
        res_set = set()
        idx_chains = []
        for i in res:
            if isinstance(i, str):
                res_set.add(BaseParticle(i))
            elif isinstance(i, BaseParticle):
                res_set.add(i)
            elif isinstance(i, int):
                idx_chains.append(i)
            else:
                raise TypeError(
                    "type({}) = {} not a Particle".format(i, type(i))
                )
        if not only:
            used_res = set()
            for i in res_set:
                for j, c in enumerate(self.chains):
                    if i in c.inner:
                        used_res.add(j)
            self.set_used_chains(list(used_res))
        else:
            unused_res = set(self.resonances) - res_set
            unused_decay = set()
            res_map = self.get_res_map()
            for i in unused_res:
                for j in res_map[i]:
                    unused_decay.add(j)
            used_decay = []
            for i, _ in enumerate(self.chains):
                if i not in unused_decay:
                    used_decay.append(i)
            self.set_used_chains(used_decay)
        self.add_used_chains(idx_chains)

    @contextlib.contextmanager
    def temp_used_res(self, res):
        """Temporarily restrict the used chains to ``res``.

        The previous selection, including ``not_full``, is restored even if
        the body raises.
        """
        old_idx = self.chains_idx
        self.set_used_res(res)
        try:
            yield
        finally:
            self.set_used_chains(old_idx)

    def add_used_chains(self, used_chains):
        """Append chain indices that are not already used."""
        for i in used_chains:
            assert isinstance(i, int), "not index of chains"
            if i not in self.chains_idx:
                self.chains_idx.append(i)

    def set_used_chains(self, used_chains):
        """Use exactly the given chain indices and update ``not_full``."""
        self.chains_idx = list(used_chains)
        self.not_full = len(self.chains_idx) != len(self.chains)

    def partial_weight(self, data, combine=None):
        """``sum_amp`` for each group in ``combine`` (default: each chain alone).

        The previous chain selection is restored afterwards, also on error.
        """
        if combine is None:
            combine = [[i] for i in range(len(self.chains))]
        o_used_chains = self.chains_idx
        weights = []
        try:
            for i in combine:
                self.set_used_res(i)
                weights.append(self.sum_amp(data))
        finally:
            self.set_used_chains(o_used_chains)
        return weights

    def chains_particle(self):
        """Tuple of inner particles for each chain."""
        return [tuple(i.inner) for i in self]

    def partial_weight_interference(self, data):
        """``sum_amp`` for every pair of chains, keyed by the index pair.

        The previous chain selection is restored afterwards, also on error.
        """
        o_used_chains = self.chains_idx
        weights = {}
        try:
            for i in combinations(range(len(self.chains)), 2):
                self.set_used_chains(i)
                weights[i] = self.sum_amp(data)
        finally:
            self.set_used_chains(o_used_chains)
        return weights

    def generate_phasespace(self, num=100000):
        """Generate ``num`` phase-space events; returns ``{final particle: data}``.

        Raises:
            Exception: a particle has no mass.
        """

        def get_mass(i):
            mass = i.get_mass()
            if mass is None:
                raise Exception("mass is required for particle {}".format(i))
            return mass

        top_mass = get_mass(self.top)
        final_mass = [get_mass(i) for i in self.outs]
        from tf_pwa.phasespace import PhaseSpaceGenerator

        a = PhaseSpaceGenerator(top_mass, final_mass)
        data = a.generate(num)
        return dict(zip(self.outs, data))
[docs]
def index_generator(base_map=None):
    """Yield single-letter einsum labels that ``base_map`` does not already use.

    Letters come from ``"abcdefghjklmnopqrstuvwxyz"`` (note: no ``"i"``) in
    order. Every value of ``base_map`` is removed from that alphabet first.
    """
    letters = "abcdefghjklmnopqrstuvwxyz"
    if base_map is not None:
        for taken in base_map.values():
            letters = letters.replace(taken, "")
    yield from letters
[docs]
def rename_data_dict(data, idx_map):
    """Recursively rename dictionary keys through ``idx_map``.

    Keys missing from ``idx_map`` stay unchanged. Tuples and lists are
    rebuilt as plain ``tuple``/``list`` with their items processed. Any
    other value is returned as is.
    """
    if isinstance(data, dict):
        renamed = {}
        for key, value in data.items():
            renamed[idx_map.get(key, key)] = rename_data_dict(value, idx_map)
        return renamed
    if isinstance(data, (tuple, list)):
        items = [rename_data_dict(item, idx_map) for item in data]
        return tuple(items) if isinstance(data, tuple) else items
    return data
[docs]
def value_and_grad(f, var):
    """Evaluate ``f(var)`` and its gradient with respect to ``var``.

    The call is recorded on a ``tf.GradientTape``; the gradient is taken
    after the tape context closes. Returns ``(value, gradient)``.
    """
    with tf.GradientTape() as tape:
        value = f(var)
    gradient = tape.gradient(value, var)
    return value, gradient
[docs]
def load_decfile_particle(fname):
    """Build particles and decays from a ``.dec`` file.

    ``Particle`` commands store their ``params`` on the particle. ``Decay``
    commands create one ``Decay`` per final state and copy every key except
    ``outs`` onto it as an attribute. ``RUNNINGWIDTH`` sets
    ``running_width = True``. Particles are created once per name.

    Returns ``(top, inner, outs)`` from ``split_particle_type``.
    """
    with open(fname) as f:
        dec = list(load_dec_file(f))
    particles = {}

    def get_particles(name):
        if name not in particles:
            particles[name] = get_particle(name)
        return particles[name]

    decay = []
    for cmd, var in dec:
        if cmd == "Particle":
            get_particles(var["name"]).params = var["params"]
        elif cmd == "Decay":
            for final in var["final"]:
                outs = [get_particles(k) for k in final["outs"]]
                de = Decay(get_particles(var["name"]), outs)
                for key in final:
                    if key != "outs":
                        setattr(de, key, final[key])
                decay.append(de)
        elif cmd == "RUNNINGWIDTH":
            get_particles(var[0]).running_width = True
    top, inner, outs = split_particle_type(decay)
    return top, inner, outs
# Register HelicityDecay, with an empty option dict, under DEFAULT_DECAY.
regist_config(DEFAULT_DECAY, (HelicityDecay, dict()))