Source code for tf_pwa.vis

import math

import matplotlib.pyplot as plt

from .particle import split_particle_type


class DotGenerator:
    """Build Graphviz DOT source describing the decay chains of a particle.

    The class attributes below are ``str.format`` templates from which the
    DOT text is assembled; ``top`` is the object whose ``chain_decay()``
    result is rendered by :meth:`get_dot_source`.
    """

    # Opening of the digraph: left-to-right layout, point-shaped nodes and
    # edges drawn without arrow heads.
    dot_head = """ digraph { rankdir=LR; node [shape=point]; edge [arrowhead=none, labelfloat=true]; """
    # Closing brace of the digraph.
    dot_tail = "}\n"
    # Puts a comma-separated list of quoted node names on the same rank.
    dot_ranksame = " {{ rank=same {} }};\n"
    # Declares a node that shows only its name (no shape).
    dot_default_node = ' "{}" [shape=none];\n'
    # Unlabelled edge between two nodes.
    dot_default_edge = ' "{}" -> "{}";\n'
    # Edge between two nodes carrying a text label.
    dot_label_edge = ' "{}" -> "{}" [label="{}"];\n'

    def __init__(self, top):
        """Remember ``top``; it must provide a ``chain_decay()`` method."""
        self.top = top

    def get_dot_source(self):
        """Return a list with one DOT source string per chain.

        The chains are whatever ``self.top.chain_decay()`` yields; each one
        is rendered with :meth:`dot_chain` using its default labelling.
        """
        return [self.dot_chain(chain) for chain in self.top.chain_decay()]

    @staticmethod
    def dot_chain(chains, has_label=True):
        """Render one decay chain as a DOT source string.

        :param chains: iterable of decay objects exposing ``core`` and
            ``outs``; it is also passed to ``split_particle_type`` to obtain
            the top and final-state particles.
        :param has_label: when true, an edge that leads into an intermediate
            particle is labelled with that particle.
        :return: the complete DOT text, from ``dot_head`` to ``dot_tail``.
        """
        tpl = DotGenerator
        top, _, outs = split_particle_type(chains)

        pieces = [tpl.dot_head]
        # Top and final-state particles get named nodes ...
        for group in (top, outs):
            pieces.extend(tpl.dot_default_node.format(p) for p in group)
        # ... and each of the two groups is pinned to a common rank.
        for group in (top, outs):
            quoted = ",".join('"{}"'.format(p) for p in group)
            pieces.append(tpl.dot_ranksame.format(quoted))

        # Decays keyed by their (non-top) mother particle, so that an edge
        # ending on such a particle can be redirected to the decay node.
        by_core = {}
        links = []
        for decay in chains:
            if decay.core in top:
                links.append((decay.core, decay))
            else:
                by_core[decay.core] = decay
            links.extend((decay, child) for child in decay.outs)

        for src, dst in links:
            if dst not in by_core:
                pieces.append(tpl.dot_default_edge.format(src, dst))
            elif has_label:
                pieces.append(
                    tpl.dot_label_edge.format(src, by_core[dst], dst)
                )
            else:
                pieces.append(tpl.dot_default_edge.format(src, by_core[dst]))

        pieces.append(tpl.dot_tail)
        return "".join(pieces)
def draw_decay_struct(decay_chain, show=False, **kwargs):
    """Render a decay chain with Graphviz.

    :param decay_chain: chain passed to ``DotGenerator.dot_chain`` to obtain
        the DOT source.
    :param show: if true call ``view()`` on the graphviz ``Source`` object,
        otherwise call ``render()``.
    :param kwargs: forwarded unchanged to ``graphviz.Source``.
    """
    # Imported lazily so that graphviz is only needed when drawing.
    from graphviz import Source

    graph = Source(DotGenerator.dot_chain(decay_chain), **kwargs)
    if not show:
        graph.render()
        return
    graph.view()
def get_node_layout(decay_chain):
    """Compute preliminary plot coordinates for the particles of a chain.

    The layers returned by ``decay_chain.sorted_table_layers()`` are walked
    from the last one back to index 1 (index 0 is not visited).  Every
    particle heading an entry of a layer gets the 1-based position of that
    layer in this reversed walk as its ``x``; its ``y`` defaults to 0 unless
    a previously visited decay already placed it.  The products of the first
    decay whose ``core`` equals the particle are then spread symmetrically
    around the particle's ``y``, with a spacing that shrinks as the walk
    advances.

    :param decay_chain: iterable of decays (each with ``core`` and ``outs``)
        that also provides ``sorted_table_layers()``.
    :return: ``(xs, ys)`` dictionaries. ``xs`` holds only the particles that
        head a layer entry; ``ys`` additionally holds every decay product
        that was positioned.
    :raises ValueError: if ``decay_chain`` contains no decay (``max`` of an
        empty sequence).
    """
    layers = decay_chain.sorted_table_layers()
    max_branches = max(len(decay.outs) for decay in decay_chain)
    xs = {}
    ys = {}
    for depth, layer in enumerate(layers[:0:-1], start=1):
        if layer is None:
            continue
        # Same for every decay handled at this step; larger steps give a
        # larger divisor and therefore a tighter fan of children.
        spread = max_branches**depth + 2
        for particle, _ in layer:
            ys.setdefault(particle, 0)
            xs[particle] = depth
            # Only the first decay of this particle is laid out.
            decay = next((d for d in decay_chain if d.core == particle), None)
            if decay is None:
                continue
            n_outs = len(decay.outs)
            for idx, child in enumerate(decay.outs):
                ys[child] = ys[particle] + (-idx - 0.5 + n_outs / 2) / spread
    return xs, ys
def reorder_final_particle(decay_chain, ys):
    """Recompute ``y`` positions so final particles sit on integer rows.

    The particles in ``decay_chain.outs`` are sorted by their value in
    ``ys`` and numbered 0, 1, 2, ... in that order.  Then, for every entry
    ``(particle, finals)`` in the layers from index 2 onwards, the particle
    is placed at the mean position of the particles listed in ``finals``.

    :param decay_chain: object providing ``outs`` and
        ``sorted_table_layers()``.
    :param ys: mapping from particle to its preliminary ``y`` value; used
        only as the sort key for the final particles.
    :return: a new dictionary of ``y`` positions.
    """
    layers = decay_chain.sorted_table_layers()
    ranked = sorted(decay_chain.outs, key=ys.get)
    ys_new = dict(zip(ranked, range(len(ranked))))
    for layer in layers[2:]:
        if layer is not None:
            for particle, finals in layer:
                total = sum(ys_new.get(name) for name in finals)
                ys_new[particle] = total / len(finals)
    return ys_new
def get_layout(decay_chain, xs, ys):
    """Combine coordinates into the points and line segments to draw.

    :param decay_chain: iterable of decays (each with ``core`` and ``outs``)
        providing ``top``, ``outs`` and ``sorted_table_layers()``.
    :param xs: mapping particle -> x coordinate.
    :param ys: mapping particle -> y coordinate; must contain the top
        particle, every key of ``xs`` and every final particle.
    :return: ``(lines, points)``. ``points`` maps a name to ``(x, y)`` and
        includes the artificial start point ``"__top"`` at ``x = 0``;
        ``lines`` is a list of ``(start, end)`` name pairs beginning with
        ``("__top", top)``.
    """
    layers = decay_chain.sorted_table_layers()
    root = decay_chain.top
    # All final particles share the right-most column.
    final_x = len(layers) - 1

    points = {"__top": (0, ys[root])}
    points.update((name, (xs[name], ys[name])) for name in xs)
    points.update((name, (final_x, ys[name])) for name in decay_chain.outs)

    lines = [("__top", root)]
    lines.extend(
        (decay.core, child) for decay in decay_chain for child in decay.outs
    )
    return lines, points
def get_decay_layout(decay_chain):
    """Return ``(lines, points)`` describing how to draw ``decay_chain``.

    Chains the three layout steps: ``get_node_layout`` for preliminary
    coordinates, ``reorder_final_particle`` to recompute the ``y`` values,
    and ``get_layout`` to turn the result into line segments and points.
    """
    xs, raw_ys = get_node_layout(decay_chain)
    ordered_ys = reorder_final_particle(decay_chain, raw_ys)
    return get_layout(decay_chain, xs, ordered_ys)
def plot_decay_struct(decay_chain, ax=plt):
    """Draw a decay chain as labelled arrows.

    :param decay_chain: chain passed to ``get_decay_layout`` to obtain the
        line segments and point coordinates.
    :param ax: object providing ``arrow``, ``text`` and ``axis`` methods;
        defaults to the ``matplotlib.pyplot`` module.
    """
    lines, points = get_decay_layout(decay_chain)
    for start, end in lines:
        x, y = points[start]
        x2, y2 = points[end]
        dx, dy = x2 - x, y2 - y
        name = str(end)
        ax.arrow(
            x,
            y,
            dx,
            dy,
            width=0.01,
            length_includes_head=True,
            label=name,
        )
        # The label sits just above the segment midpoint, rotated to follow
        # the arrow.  The placement was previously written twice inside an
        # if/else on the sign of the slope with identical arguments, so a
        # single call is equivalent.
        ax.text(
            (x + x2) / 2,
            (y + y2) / 2 + 0.01,
            name,
            rotation=math.atan2(dy, dx) / math.pi * 180,
            ha="center",
            va="bottom",
        )
    ax.axis("off")