Source code for tf_pwa.experimental.build_amp

import tensorflow as tf

from tf_pwa.data import data_shape
from tf_pwa.utils import time_print

from .opt_int import split_gls


[docs] def build_sum_amplitude(dg, dec_chain, data): cached = [] for i, dc in split_gls(dec_chain): amp = dg.get_amp(data) int_mc = amp m_dep = dg.get_m_dep(data) m_dep_all = 1.0 for i in m_dep: for j in i: m_dep_all *= tf.reshape(j, (-1,)) m_dep_all = tf.reshape(m_dep_all, [-1] + [1] * (len(amp.shape) - 1)) cached.append(amp / m_dep_all) return cached
[docs] def build_amp_matrix(dec, data, weight=None): hij = [] used_chains = dec.chains_idx index = [] for k, i in enumerate(dec): dec.set_used_chains([k]) tmp = [] for j, amp in enumerate(build_sum_amplitude(dec, i, data)): tmp.append(amp) hij.append(tmp) dec.set_used_chains(used_chains) # print([i.shape for i in hij.values()]) # print([[j.shape for j in i] for i in hij]) return index, hij
def _prod(ls): ret = 1 for i in ls: ret *= i return ret
[docs] def build_params_vector(dg, data): n_data = data_shape(data) m_dep = dg.get_m_dep(data) ret = [] for i in m_dep: tmp = i[0] if tmp.shape[0] == 1: tmp = tf.tile(tmp, [n_data] + [1] * (len(tmp.shape) - 1)) tmp = tf.reshape(tmp, (-1, _prod(tmp.shape[1:]))) for j in i[1:]: tmp2 = tf.reshape(j, (-1, _prod(j.shape[1:]))) tmp = tmp[:, :, None] * tmp2[:, None, :] tmp = tf.reshape(tmp, (-1, _prod(tmp.shape[1:]))) ret.append(tmp) return ret
[docs] def cached_amp2s(dg, data): _amp = cached_amp(dg, data) @time_print @tf.function def _amp2s(): # pragma: no cover # because of tf_funciton amp = _amp() amp2s = tf.math.real(amp * tf.math.conj(amp)) return tf.reduce_sum(amp2s, list(range(1, len(amp2s.shape)))) return _amp2s
[docs] def build_amp2s(dg): @tf.function(experimental_relax_shapes=True) def _amp2s(data, cached_data): n_data = data_shape(data) pv = build_params_vector(dg, data) ret = [] for i, j in zip(pv, cached_data): # print(j) # print(i.shape) a = tf.reshape(i, [-1, i.shape[1]] + [1] * (len(j[0].shape) - 1)) ret.append(tf.reduce_sum(a * tf.stack(j, axis=1), axis=1)) # print(ret) amp = tf.reduce_sum(ret, axis=0) amp2s = tf.math.real(amp * tf.math.conj(amp)) return tf.reduce_sum(amp2s, list(range(1, len(amp2s.shape)))) return _amp2s
[docs] def build_sum_angle_amplitude(dg, dec_chain, data): cached = [] for i, dc in split_gls(dec_chain): amp = dg.get_angle_amp(data) cached.append(amp) return cached
[docs] def build_angle_amp_matrix(dec, data, weight=None): hij = [] used_chains = dec.chains_idx for k, i in enumerate(dec): dec.set_used_chains([k]) tmp = [] for j, amp in enumerate(build_sum_angle_amplitude(dec, i, data)): tmp.append(amp) hij.append(tmp) dec.set_used_chains(used_chains) return list(dec), hij
[docs] def amp_matrix_as_dict(dec, hij): ret = {} for i, d in zip(hij, dec): ret[d] = i return ret
[docs] def cached_amp(dg, data, matrix_method=build_angle_amp_matrix): idx, c_amp = matrix_method(dg, data) n_data = data_shape(data) @tf.function def _amp(): pv = build_params_vector(dg, data) ret = [] for i, j in zip(pv, c_amp): a = tf.reshape(i, [n_data, -1] + [1] * (len(j[0].shape) - 1)) ret.append(tf.reduce_sum(a * tf.stack(j, axis=1), axis=1)) # print(ret) amp = tf.reduce_sum(ret, axis=0) return amp return _amp