Source code for tf_pwa.experimental.opt_int

import itertools

import tensorflow as tf

from tf_pwa.amp import get_decay, get_particle
from tf_pwa.data import data_split


def split_gls(dec_chain):
    """Iterate over every single-ls configuration of a decay chain.

    For each element of the Cartesian product of the per-decay
    ``get_ls_list()`` results, every decay in ``dec_chain`` is temporarily
    restricted to one value via ``set_ls([value])`` and the pair
    ``(combination, dec_chain)`` is yielded.

    The lists returned by ``get_ls_list()`` before iteration started are
    always put back with ``set_ls``, even when the consumer stops early
    (``break``, ``generator.close()``, garbage collection of a suspended
    generator) or an exception is thrown into the generator.

    Parameters
    ----------
    dec_chain
        Re-iterable collection of objects exposing ``get_ls_list()`` and
        ``set_ls(list)``.

    Yields
    ------
    tuple
        ``(combination, dec_chain)`` where ``combination`` is a tuple with
        one selected value per decay, and ``dec_chain`` is the very same
        object that was passed in (mutated in place for this step).
    """
    original_ls = [decay.get_ls_list() for decay in dec_chain]
    try:
        for combination in itertools.product(*original_ls):
            for ls, decay in zip(combination, dec_chain):
                decay.set_ls([ls])
            yield combination, dec_chain
    finally:
        # Runs on normal exhaustion and on early exit alike, so the decays
        # are never left pinned to a single ls value.
        for decay, ls_list in zip(dec_chain, original_ls):
            decay.set_ls(ls_list)
def build_sum_amplitude(dg, dec_chain, data):
    """Collect one normalised amplitude per ls combination of a decay chain.

    For every configuration produced by ``split_gls(dec_chain)`` the
    amplitude ``dg.get_amp3(data)`` is evaluated and divided by
    ``dec_chain.product_gls()`` for that configuration.

    Parameters
    ----------
    dg
        Object providing ``get_amp3(data)``.
    dec_chain
        Decay chain; iterated by ``split_gls`` and must provide
        ``product_gls()``.
    data
        Passed through unchanged to ``dg.get_amp3``.

    Returns
    -------
    list
        ``amp / product_gls`` for each ls combination, in the order that
        ``split_gls`` yields them.
    """
    cached = []
    # The selected combination itself is not needed here, only the
    # side effect split_gls has on the chain for the current step.
    for _, chain in split_gls(dec_chain):
        amp = dg.get_amp3(data)
        cached.append(amp / chain.product_gls())
    return cached
def build_int_matrix(dec, data, weight=None):
    """Build the weighted interference sums between all partial amplitudes.

    Each chain of ``dec`` is enabled on its own through
    ``dec.set_used_chains([k])`` and its amplitudes are collected from
    ``build_sum_amplitude``. The previously used chains (``dec.chains_idx``)
    are restored afterwards, also when building an amplitude raises.

    Parameters
    ----------
    dec
        Iterable decay group exposing ``chains_idx`` and
        ``set_used_chains``.
    data
        Data passed to the amplitude evaluation; must support
        ``data.get("weight", 1.0)`` when ``weight`` is not given.
    weight
        Optional per-event weight. Defaults to ``data.get("weight", 1.0)``.
        It is cast to the amplitude dtype and reshaped to
        ``[-1, 1, ..., 1]`` with one trailing singleton axis for every
        amplitude axis after the first.

    Returns
    -------
    tuple
        ``(index, ret)``: ``index`` is the list of ``(chain, j)`` keys and
        ``ret[a][b]`` is ``reduce_sum(weight * h[index[a]] *
        conj(h[index[b]]))``.

    Raises
    ------
    IndexError
        If no amplitude was produced (``index`` is empty).
    """
    hij = {}
    used_chains = dec.chains_idx
    try:
        for k, chain in enumerate(dec):
            dec.set_used_chains([k])
            for j, amp in enumerate(build_sum_amplitude(dec, chain, data)):
                hij[(chain, j)] = amp
    finally:
        # Never leave the decay group restricted to a single chain.
        dec.set_used_chains(used_chains)

    if weight is None:
        weight = data.get("weight", 1.0)
    index = list(hij.keys())
    reference = hij[index[0]]
    weight = tf.cast(weight, reference.dtype)
    # Trailing singleton axes let the weight broadcast against every
    # amplitude axis after the first one.
    n_lambda = len(reference.shape) - 1
    weight = tf.reshape(weight, [-1] + [1] * n_lambda)

    ret = []
    for i in index:
        row = []
        for j in index:
            xij = hij[i] * tf.math.conj(hij[j])
            row.append(tf.reduce_sum(weight * xij))
        ret.append(row)
    return index, ret
def build_int_matrix_batch(dec, data, batch=65000):
    """Accumulate ``build_int_matrix`` over chunks of ``data``.

    ``data`` is cut into pieces with ``data_split(data, batch)``; the
    matrix of every piece is computed separately and the pieces are added
    together with ``tf.reduce_sum(..., axis=0)``.

    Returns
    -------
    tuple
        ``(index, matrix)`` where ``index`` comes from the last processed
        chunk (``None`` if ``data_split`` yields nothing) and ``matrix``
        is the element-wise sum over all chunks.
    """
    index = None
    partial_matrices = []
    for chunk in data_split(data, batch):
        index, chunk_matrix = build_int_matrix(dec, chunk)
        partial_matrices.append(chunk_matrix)
    summed = tf.reduce_sum(partial_matrices, axis=0)
    return index, summed
def build_params_vector(dec, concat=True):
    """Build the combined factor vector of every chain in ``dec``.

    For each chain, ``get_all_factor()`` is flattened into one vector by
    ``gls_combine``.

    Parameters
    ----------
    dec
        Iterable of chains providing ``get_all_factor()``.
    concat
        When truthy (default) the per-chain vectors are joined with
        ``tf.concat(..., axis=0)``; otherwise the list of per-chain
        vectors is returned as is.
    """
    per_chain = [gls_combine(chain.get_all_factor()) for chain in dec]
    if not concat:
        return per_chain
    return tf.concat(per_chain, axis=0)
def build_params_matrix(dec):
    """Return the outer product of the parameter vector with its conjugate.

    The vector comes from ``build_params_vector(dec)``; entry ``[a, b]`` of
    the result is ``pv[a] * conj(pv[b])``.
    """
    params = build_params_vector(dec)
    as_column = params[:, None]
    conj_as_row = tf.math.conj(params)[None, :]
    return as_column * conj_as_row
def gls_combine(fs):
    """Fold a sequence of factor vectors into one flattened outer product.

    Starting from ``fs[0]``, every following factor is multiplied in as an
    outer product (``acc[:, None] * factor[None, :]``) and the result is
    reshaped back to one dimension with ``tf.reshape(..., (-1,))``. With a
    single factor, that factor is returned unchanged.

    Parameters
    ----------
    fs
        Non-empty, indexable and sliceable sequence of factors.
        NOTE(review): factors are presumably 1-D tensors - confirm.
    """
    combined = fs[0]
    for factor in fs[1:]:
        outer = combined[:, None] * factor[None, :]
        combined = tf.reshape(outer, (-1,))
    return combined
def cached_int_mc(dec, data, batch=65000):
    """Create a function that evaluates the integral from a cached matrix.

    The interference matrix is computed once, up front, with
    ``build_int_matrix_batch(dec, data, batch)``. The returned callable
    only rebuilds the parameter matrix from ``dec`` on each call and
    contracts it with the cached matrix.

    Returns
    -------
    callable
        A ``tf.function`` taking no arguments that returns the real part of
        ``reduce_sum(build_params_matrix(dec) * int_matrix)``.
    """
    # Only the matrix is needed; the index returned alongside it is unused.
    _, int_matrix = build_int_matrix_batch(dec, data, batch)

    @tf.function
    def int_mc():
        params_matrix = build_params_matrix(dec)
        total = tf.reduce_sum(params_matrix * int_matrix)
        return tf.math.real(total)

    return int_mc