Source code for tf_pwa.config_loader.plotter

import contextlib
import itertools
import logging
import os

import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
import yaml

from tf_pwa.config_loader import ConfigLoader
from tf_pwa.config_loader.plot import (
    LineStyleSet,
    _get_cfit_bg,
    _get_cfit_eff_phsp,
    create_chain_property,
)
from tf_pwa.data import ReadData, batch_call, data_index, data_shape
from tf_pwa.histogram import Hist1D, WeightedData

# Module-level logger. Keyed on the module name (not the file path) so that it
# takes part in the normal dotted logger hierarchy of the package.
logger = logging.getLogger(__name__)


class PlotData:
    """A dataset together with the event weights used when histogramming it.

    :param dataset: the events; passed unchanged to the observable callable
    :param weight: per-event weights, defaults to ones of length ``n_data``
    :param partial_weight: mapping from component key to per-event weights,
        defaults to an empty dict
    :param use_weighted: build ``WeightedData`` instead of ``Hist1D`` histograms
    :type use_weighted: bool, optional
    """

    def __init__(
        self, dataset, weight=None, partial_weight=None, use_weighted=False
    ):
        self.dataset = dataset
        self.n_data = data_shape(dataset)
        if weight is None:
            weight = np.ones(self.n_data)
        self.weight = weight
        if partial_weight is None:
            partial_weight = {}
        self.partial_weight = partial_weight
        # global normalisation factor applied on top of every weight
        self.scale = 1.0
        self.use_weighted = use_weighted

    def total_size(self):
        """Return the sum of the weights multiplied by ``scale``."""
        return np.sum(self.weight) * self.scale

    def __mul__(self, frac):
        """Return a shallow copy whose ``scale`` is multiplied by ``frac``.

        :raises NotImplementedError: if ``frac`` is not a ``float`` or ``int``
        """
        if not isinstance(frac, (float, int)):
            raise NotImplementedError()
        from copy import copy

        scaled = copy(self)
        scaled.scale = self.scale * frac
        return scaled

    def get_histogram(self, var, partial=None, **kwargs):
        """Histogram the observable ``var`` evaluated on the dataset.

        :param var: callable applied to ``self.dataset``
        :param partial: key into ``partial_weight``; ``None`` selects the
            total weight
        :param kwargs: forwarded to the histogram constructor
        """
        value = var(self.dataset)
        w = self.weight if partial is None else self.partial_weight[partial]
        builder = WeightedData if self.use_weighted else Hist1D.histogram
        return builder(value, weights=w * self.scale, **kwargs)
class PlotDataGroup:
    """Several ``PlotData``-like objects presented as a single sample.

    :param datasets: sequence of members; each must provide ``n_data``,
        ``total_size()`` and ``get_histogram(var, partial, **kwargs)``
    """

    def __init__(self, datasets):
        self.datasets = datasets
        # extra normalisation factor applied on top of the members' own scales
        self.scale = 1.0
        self.n_data = sum(i.n_data for i in self.datasets)

    def total_size(self):
        """Return the summed size of all members times the group ``scale``.

        Each member's own ``scale`` is honoured (via its ``total_size``), so
        the result agrees with the integral of :meth:`get_histogram`.
        """
        return sum(i.total_size() for i in self.datasets) * self.scale

    def get_histogram(self, var, partial=None, **kwargs):
        """Return the sum of the members' histograms of ``var``.

        :param var: observable callable, forwarded to every member
        :param partial: component key, forwarded to every member
        :param kwargs: forwarded to every member's ``get_histogram``
        """
        hist = [i.get_histogram(var, partial, **kwargs) for i in self.datasets]
        tmp = hist[0]
        for i in hist[1:]:
            tmp = tmp + i
        # Apply the group scale only when it is non-trivial so the default
        # path returns exactly the plain sum.
        if self.scale != 1.0:
            tmp = tmp * self.scale
        return tmp
class Frame:
    """One plotted observable: how to read it, bin it and label its axes.

    :param var: a callable applied to a dataset, or a specification handed to
        ``ReadData`` together with ``trans``
    :param x_range: ``(low, high)`` histogram range; filled in from the first
        histogram when ``None``
    :param nbins: number of bins; chosen from the sample size on the first
        histogram when ``None``
    :param name: identifier, defaults to ``str(var)``
    :param display: x-axis label, defaults to ``name``
    :param trans: transformation forwarded to ``ReadData``
    :param extra: default axis options merged into :meth:`set_axis`
    """

    def __init__(
        self,
        var,
        x_range=None,
        nbins=None,
        name=None,
        display=None,
        trans=None,
        **extra,
    ):
        self.var = var
        self.x_range = x_range
        self.nbins = nbins
        self.trans = trans
        if callable(var):
            self.var_trans = var
        else:
            self.var_trans = ReadData(var, trans)
        self.name = str(var) if name is None else name
        self.display = self.name if display is None else display
        self.extra = extra
        # when True the y axis is clamped at zero in set_axis
        self.no_below = True

    def get_histogram(self, data, partial=None, bin_scale=1):
        """Histogram this observable for ``data``.

        The histogram is built with ``bin_scale`` times the nominal number of
        bins and then multiplied by ``bin_scale``, so a finer-binned curve
        keeps the same height as the nominal binning.

        :param data: object providing ``n_data`` and
            ``get_histogram(var, partial=, range=, bins=)``
        :param partial: component key forwarded to ``data.get_histogram``
        :param bin_scale: bin refinement factor, defaults to 1
        :return: the scaled histogram
        """
        if self.nbins is None:
            # automatic binning: roughly 25 events per bin, within [5, 200]
            base_bins = max(5, min(data.n_data // 25, 200))
        else:
            base_bins = self.nbins
        # The refinement factor is applied to the bin count in BOTH cases, so
        # the "* bin_scale" on the counts and the "// bin_scale" below stay
        # consistent even when the first call already uses bin_scale != 1.
        hist = (
            data.get_histogram(
                self.var_trans,
                partial=partial,
                range=self.x_range,
                bins=base_bins * bin_scale,
            )
            * bin_scale
        )
        # remember what the first histogram settled on so later calls agree
        if self.x_range is None:
            self.x_range = hist.binning[0], hist.binning[-1]
        if self.nbins is None:
            self.nbins = (len(hist.binning) - 1) // bin_scale
        return hist

    def set_axis(self, axis, **config):
        """Apply labels, limits, scale and legend to ``axis``.

        ``x_range`` and ``nbins`` must already be set (explicitly or by a
        previous :meth:`get_histogram` call) because the bin width appears in
        the y label.

        :param axis: a matplotlib axes, or the ``pyplot`` module for the
            current axes
        :param config: options overriding the frame's ``extra`` defaults;
            recognised keys are ``units``, ``upper_ylim``, ``yscale``,
            ``has_legend`` and ``legend.fontsize``
        """
        if axis is plt:
            axis = plt.gca()
        config = {**self.extra, **config}
        delta = (self.x_range[1] - self.x_range[0]) / self.nbins
        units = config.get("units", "")
        if units == "":
            axis.set_xlabel(self.display)
            axis.set_ylabel("Events/{:.3g}".format(delta))
        else:
            axis.set_xlabel(self.display + units)
            axis.set_ylabel("Events/({:.3g} {})".format(delta, units))
        lower = 0 if self.no_below else None
        axis.set_ylim((lower, config.get("upper_ylim", None)))
        if "yscale" in config:
            axis.set_yscale(config["yscale"])
        if config.get("has_legend"):
            if "legend.fontsize" in config:
                axis.legend(fontsize=config["legend.fontsize"])
            else:
                axis.legend()
@ConfigLoader.register_function()
def get_all_frame(self):
    """Build one :class:`Frame` per entry of ``self.plot_params.get_params()``.

    :return: mapping from the entry's ``name`` to its frame
    :rtype: dict
    """
    frames = {}
    for conf in self.plot_params.get_params():
        name = conf.get("name")
        frames[name] = Frame(
            conf.get("readdata"),
            name=name,
            display=conf.get("display", name),
            x_range=conf.get("range", None),
            nbins=conf.get("bins", None),
            units=conf.get("units", ""),
            yscale=conf.get("yscale", "linear"),
            has_legend=conf.get("legend", False),
            upper_ylim=conf.get("upper_ylim", None),
        )
    return frames
class PlotAllData:
    """Data, fit projection, optional background and components of one sample.

    ``self.datasets`` maps ``"data"``, ``"fitted"`` and (when ``bg`` is given)
    ``"bg"`` to :class:`PlotData` objects. The fitted sample carries one
    partial weight per component and is scaled so that its total size equals
    the data size minus the background size.

    :param amp: amplitude model; called on ``phsp`` to obtain weights
    :param data: data sample, must provide ``get_weight()``
    :param phsp: phase-space sample, must provide ``get_weight()``
    :param bg: background sample or ``None``
    :param res: components to draw; each entry is a resonance or a list of
        resonances. ``None`` uses ``amp.partial_weight`` with one entry per
        element of ``amp.decay_group``
    :param use_weighted: forwarded to every :class:`PlotData`
    """

    def __init__(
        self, amp, data, phsp, bg=None, res=None, use_weighted=False
    ):
        self.amp = amp
        self.data = data
        self.phsp = phsp
        self.bg = bg
        self.res = res
        self.datasets = {
            "data": PlotData(
                data, weight=data.get_weight(), use_weighted=use_weighted
            ),
            "fitted": PlotData(
                phsp,
                weight=phsp.get_weight() * batch_call(amp, phsp),
                use_weighted=use_weighted,
            ),
        }
        if self.bg is not None:
            # the sign of the stored background weights is flipped for plotting
            self.datasets["bg"] = PlotData(
                bg, weight=-bg.get_weight(), use_weighted=use_weighted
            )
        if res is None:
            weights = amp.partial_weight(phsp)
            new_res = [str(i) for i in amp.decay_group]
        else:
            weights = []
            new_res = []
            used_res = amp.used_res
            try:
                for i in res:
                    if not isinstance(i, list):
                        i = [i]
                    new_res.append(tuple(i))
                    amp.set_used_res(i)
                    weights.append(amp(phsp))
            finally:
                # always put the amplitude back in its original configuration,
                # even if evaluating one of the components raised
                amp.set_used_res(used_res)
        self.datasets["fitted"].partial_weight = dict(zip(new_res, weights))
        n_sig = self.datasets["data"].total_size()
        if self.bg is not None:
            n_sig = n_sig - self.datasets["bg"].total_size()
        self.datasets["fitted"].scale = (
            n_sig / self.datasets["fitted"].total_size()
        )

    def get_all_histogram(self, var, bin_scale=3):
        """Histogram every dataset and every component with ``var``.

        Totals are stored under ``(dataset,)`` with the nominal binning;
        components are stored under ``(dataset, component)`` and use
        ``bin_scale``.

        :param var: a :class:`Frame`-like object providing ``get_histogram``
        :param bin_scale: bin refinement factor for components, defaults to 3
        :rtype: dict
        """
        ret = {}
        for k, v in self.datasets.items():
            ret[(k,)] = var.get_histogram(v)
            for k2 in v.partial_weight:
                ret[(k, k2)] = var.get_histogram(
                    v, partial=k2, bin_scale=bin_scale
                )
        return ret
def merge_hist(hists):
    """Add up several histogram dictionaries key by key.

    The keys of the first dictionary decide which entries are kept; every
    later dictionary must contain all of them.

    :param hists: non-empty iterable of dicts with addable values
    :return: a new dict holding the per-key sums
    :rtype: dict
    """
    all_hists = list(hists)
    merged = dict(all_hists[0])
    for other in all_hists[1:]:
        merged = {key: total + other[key] for key, total in merged.items()}
    return merged
class StyleSet:
    """Registry of plot styles, optionally backed by a YAML file.

    The file holds a list of mappings; the ``"id"`` entry of each mapping is
    the style key and the remaining entries are the style itself.

    :param file_name: YAML file to read the styles from, or ``None`` to start
        empty
    """

    def __init__(self, file_name):
        self.file_name = file_name
        self.all_style = {}
        if file_name is not None:
            # NOTE(review): yaml.full_load is not the safe loader; only point
            # this at style files you trust.
            with open(file_name) as f:
                loaded = yaml.full_load(f)
            if loaded is None:
                loaded = {}
            for entry in loaded:
                style_id = entry.pop("id")
                self.all_style[style_id] = entry
        # source of automatic styles for keys nobody configured
        self.prop_cycler = itertools.cycle(plt.rcParams["axes.prop_cycle"])

    def get(self, key, value=None):
        """Return the style for ``key``, registering one if it is missing.

        :param key: style identifier
        :param value: style stored when ``key`` is unknown; when ``None`` a
            style is taken from the property cycle instead
        """
        if key not in self.all_style:
            if value is None:
                self.all_style[key] = self.generate_new_style()
            else:
                self.all_style[key] = value
        return self.all_style[key]

    def set(self, key, value):
        """Merge ``value`` into the style of ``key``; new entries win."""
        current = self.all_style.get(key, {})
        self.all_style[key] = {**current, **value}

    def generate_new_style(self):
        """Return the next entry of the property cycle."""
        return next(self.prop_cycler)

    def save(self):
        """Write every style to ``self.file_name`` as a YAML list."""
        with open(self.file_name, "w") as f:
            records = [
                {"id": style_id, **style}
                for style_id, style in self.all_style.items()
            ]
            yaml.dump(records, f, indent=2)
class Plotter:
    """Draw data, total fit, background and components for a configuration.

    :param config: configuration object providing ``get_all_plotdatas``,
        ``get_all_frame`` and ``get_decay``
    :param legend_file: YAML style file handed to :class:`StyleSet`
    :param res: components to draw, forwarded to ``get_all_plotdatas``
    :param datasets: optional keyword arguments (samples) forwarded to
        ``get_all_plotdatas``
    :type datasets: dict, optional
    :param use_weighted: forwarded to ``get_all_plotdatas``
    :type use_weighted: bool, optional
    """

    def __init__(
        self,
        config,
        legend_file=None,
        res=None,
        datasets=None,
        use_weighted=False,
    ):
        self.config = config
        datasets = {} if datasets is None else datasets
        self.alldatas = config.get_all_plotdatas(
            res=res, use_weighted=use_weighted, **datasets
        )
        self.res = res
        self.chain_property = create_chain_property(config, res=res)
        self.all_frame = config.get_all_frame()
        self.style = StyleSet(legend_file)
        # Items to draw. Each item is a tuple of histogram keys whose
        # histograms are summed before drawing. Filled on the first plot.
        self.plot_item = None
        self.extra_plot_item = []
        self.hidden_plot_item = []
        # when a background exists, draw only "fit + bg" and not the bare fit
        self.small_bg = True
        # draw components with the "kde" type
        self.smooth = False
        # what the last plot_var call drew, per item
        self.draw_item = {}
        self.default_plot_style = {
            (("data",),): {
                "label": "data",
                "type": "error",
                "color": "black",
                "linestyle": "none",
                "marker": ".",
                "zorder": 100,
            },
            (("bg",),): {
                "label": "background",
                "type": "bar",
                "color": "grey",
                "alpha": 0.5,
            },
            (("fitted",), ("bg",)): {
                "label": "total fit",
                "type": "hist",
                "color": "black",
                "linestyle": "-",
                "marker": "",
                "zorder": 99,
            },
            (("fitted",),): {
                "label": "total fit",
                "type": "hist",
                "color": "black",
                "linestyle": "-",
                "marker": "",
                "zorder": 99,
            },
        }

    def get_label(self, key):
        """Unfinished label lookup; it currently always returns ``None``."""
        # NOTE(review): this method is a stub as found; the extracted
        # component key is never used.
        if len(key) == 1:
            head = key[0]
            if len(head) == 2:
                if head[0] == "fitted":
                    dic = head[1]

    def get_plot_style(self, example_hist):
        """Register a style for every standard histogram key.

        :param example_hist: histogram dict as returned by
            :meth:`get_all_hist`; only its keys are used
        :return: the plotter's :class:`StyleSet`
        """
        for k in example_hist:
            if k == ("data",):
                self.style.get((k,), self.default_plot_style[(k,)])
            elif k == ("fitted",):
                if ("bg",) in example_hist:
                    key = (("fitted",), ("bg",))
                    self.style.get(key, self.default_plot_style[key])
                    self.style.get((("fitted",),), {})
                else:
                    self.style.get(
                        (("fitted",),), self.default_plot_style[(k,)]
                    )
            elif k == ("bg",):
                self.style.get((k,), self.default_plot_style[(k,)])
            elif k[0] == "fitted":
                style = {}
                if self.smooth:
                    style["type"] = "kde"
                extra_style = self.get_res_style(k[1])
                self.style.get((k,), {**style, **extra_style})
            # Any other key is left alone here; plot_var asks the style set
            # for it when drawing, which assigns a style on demand.
        return self.style

    def get_res_style(self, key):
        """Return ``name`` and ``label`` of a component from ``chain_property``.

        :param key: a decay-chain string, or a tuple of resonances matching an
            entry of ``self.res``
        :return: dict with ``"name"`` and ``"label"``, or an empty dict when
            nothing matches
        """
        if isinstance(key, str):
            for i, j in enumerate(self.config.get_decay()):
                if str(j) == key:
                    style = self.chain_property[i]
                    return {"name": style[1], "label": style[2]}
        elif isinstance(key, tuple):
            for i, name in enumerate(self.res):
                if not isinstance(name, list):
                    name = [name]
                if tuple(name) == key:
                    style = self.chain_property[i]
                    return {"name": style[1], "label": style[2]}
        return {}

    def set_plot_item(self, example_hist):
        """Derive the default list of drawn items from the histogram keys.

        :param example_hist: histogram dict; only its keys are used
        """
        self.plot_item = [(k,) for k in example_hist]
        if ("bg",) in example_hist:
            if (("fitted",), ("bg",)) not in self.plot_item:
                self.plot_item.append((("fitted",), ("bg",)))
            if self.small_bg:
                self.plot_item.remove((("fitted",),))

    def plot_frame(self, name, idx=None, ax=plt, bin_scale=3):
        """plot frame for all partial wave

        :param name: data variable frame name
        :type name: str
        :param idx: data index, None for all data, defaults to None
        :type idx: int, optional
        :param bin_scale: smooth bin scale, defaults to 3
        :type bin_scale: float, optional
        :param ax: plot on axis ax
        :type ax: matplotlib.Axes, optional
        :return: matplotlib.Axes
        :raises IndexError: if no frame is registered under ``name``
        """
        frame = self.all_frame.get(name, None)
        if frame is None:
            raise IndexError("no such frame")
        return self.plot_var(frame, idx, ax=ax, bin_scale=bin_scale)

    def get_all_hist(self, frame, idx=None, bin_scale=3):
        """create all partial wave histogram for observation frame.

        :param frame: Function for get observation in datasets
        :type frame: Frame, or callable
        :param idx: data index, None for all data, defaults to None
        :type idx: int, optional
        :param bin_scale: smooth bin scale, defaults to 3
        :type bin_scale: float, optional
        :return: collection of histogram
        :rtype: dict
        """
        if idx is None:
            hists = merge_hist(
                i.get_all_histogram(frame, bin_scale=bin_scale)
                for i in self.alldatas
            )
        else:
            hists = self.alldatas[idx].get_all_histogram(
                frame, bin_scale=bin_scale
            )
        return hists

    def plot_var(self, frame, idx=None, ax=plt, bin_scale=3):
        """plot data observation for all partial wave

        :param frame: Function for get observation in datasets
        :type frame: Frame, or callable
        :param idx: data index, None for all data, defaults to None
        :type idx: int, optional
        :param bin_scale: smooth bin scale, defaults to 3
        :type bin_scale: float, optional
        :param ax: plot axis
        :type ax: matplotlib.Axes, optional
        :return: the ``ax`` that was passed in
        """
        if not isinstance(frame, Frame):
            frame = Frame(frame)
        hists = self.get_all_hist(frame, idx=idx, bin_scale=bin_scale)
        # called for its side effect: registers the standard styles
        self.get_plot_style(hists)
        if self.plot_item is None:
            self.set_plot_item(hists)
        no_below = True
        for i in self.extra_plot_item + self.plot_item:
            if i in self.hidden_plot_item:
                continue
            # an item is the sum of the histograms of all its keys
            hist = hists[i[0]]
            for j in i[1:]:
                hist = hist + hists[j]
            style = self.style.get(i)
            style = {k: v for k, v in style.items() if k not in ["name", "id"]}
            a = hist.draw(ax, **style)
            no_below = no_below and np.all(hist.count >= 0)
            self.draw_item[i] = a
        # only clamp the y axis at zero when nothing drawn goes negative
        frame.no_below = no_below
        frame.set_axis(ax)
        return ax

    def forzen_style(self):
        """Copy the colours and line styles actually drawn back into the styles.

        Must be called after at least one plot; afterwards the automatically
        chosen properties are fixed in ``self.style``.
        """
        assert len(self.draw_item) > 0
        import matplotlib as mpl

        for key in self.draw_item:
            style = self.style.get(key)
            plot_type = style.get("type", "hist")
            # NOTE(review): the original test was the chained comparison
            # `plot_type == "hist" in ["line", "error", "kde", "hist"]`, which
            # is true only for "hist". That effective behaviour is kept here;
            # whether "line", "error" and "kde" should be frozen too depends
            # on what Hist1D.draw returns for them -- TODO confirm.
            if plot_type == "hist":
                art = self.draw_item[key][0]
                tmp = {
                    "color": art.get_color(),
                    "linestyle": art.get_linestyle(),
                    "marker": art.get_marker(),
                }
                self.style.set(key, tmp)
            elif plot_type == "bar":
                art = self.draw_item[key][0]
                tmp = {"facecolor": mpl.colors.to_hex(art.get_facecolor())}
                self.style.set(key, tmp)

    @contextlib.contextmanager
    def old_style(self, extra_config=None, color_first=True):
        """context for base style, see matplotlib.rcParams for more configuration

        :param extra_config: new configs, defaults to None
        :type extra_config: dict, optional
        :param color_first: order of color and linestyle, defaults to True
        :type color_first: bool, optional
        """
        import matplotlib as mpl
        from cycler import cycler

        extra_config = {} if extra_config is None else extra_config
        colors = [
            "red",
            "orange",
            "purple",
            "springgreen",
            "y",
            "green",
            "blue",
            "c",
        ]
        linestyles = ["-", "--", "-.", ":"]
        marker = [",", ".", "^"]
        # the right-most cycler varies fastest
        if color_first:
            style = (
                cycler(marker=marker)
                * cycler(linestyle=linestyles)
                * cycler(color=colors)
            )
        else:
            style = (
                cycler(marker=marker)
                * cycler(color=colors)
                * cycler(linestyle=linestyles)
            )
        config = {
            "axes.prop_cycle": style,
            "legend.frameon": False,
            "legend.borderpad": 0.0,
            "legend.labelspacing": 0.1,
        }
        total_config = {**config, **extra_config}
        with mpl.rc_context(total_config):
            yield total_config

    def plot_frame_with_pull(
        self, name, idx=None, bin_scale=3, pull_config=None
    ):
        """plot frame with pull for all partial wave

        :param name: data variable frame name
        :type name: str
        :param idx: data index, None for all data, defaults to None
        :type idx: int, optional
        :param bin_scale: smooth bin scale, defaults to 3
        :type bin_scale: float, optional
        :param pull_config: pull plot style, defaults to None
        :type pull_config: dict, optional
        :return: matplotlib.Axes for plot and pull
        :raises IndexError: if no frame is registered under ``name``
        """
        frame = self.all_frame.get(name, None)
        if frame is None:
            raise IndexError("no such frame")
        hists = self.get_all_hist(frame, idx=idx, bin_scale=bin_scale)
        # upper three quarters: distribution, lower quarter: pull
        ax = plt.subplot2grid((4, 1), (0, 0), rowspan=3)
        ax2 = plt.subplot2grid((4, 1), (3, 0), rowspan=1)
        self.plot_var(frame, idx, ax=ax, bin_scale=bin_scale)
        if ("bg",) in hists:
            total_fit = hists[("bg",)] + hists[("fitted",)]
        else:
            total_fit = hists[("fitted",)]
        pull = hists[("data",)] - total_fit
        pull_config = {} if pull_config is None else pull_config
        pull.draw_pull(ax2, **pull_config)
        frame.set_axis(ax2, has_legend=False)
        ax2.axhline(y=0, color="r", linewidth=0.5)
        ax2.axhline(
            y=3,
            color="r",
            linestyle="--",
            linewidth=0.5,
        )
        ax2.axhline(
            y=-3,
            color="r",
            linestyle="--",
            linewidth=0.5,
        )
        # set_axis applied the frame's limits, label and scale; override them
        # for the pull panel
        ax2.set_ylim((-5, 5))
        ax2.set_ylabel("pull")
        ax2.set_yscale("linear")
        # the x axis is labelled on the pull panel only
        ax.set_xticks([])
        ax.set_xlabel("")
        return ax, ax2

    def save_all_frame(
        self,
        prefix="figure/",
        format="png",
        idx=None,
        plot_pull=False,
        pull_config=None,
    ):
        """Save all frame in with prefix. like ConfigLoader.plot_partial_waves

        :param prefix: prefix for file name, defaults to "figure/"
        :type prefix: str, optional
        :param format: figure format, defaults to "png"
        :type format: str, optional
        :param idx: dataset index, defaults to None
        :type idx: int, optional
        :param plot_pull: if plot pulls, defaults to False
        :type plot_pull: bool, optional
        :param pull_config: configuration for plot pulls, defaults to None
        :type pull_config: dict, optional
        """
        path = os.path.dirname(prefix)
        # a prefix without directory part (e.g. "fig_") has nothing to create;
        # os.makedirs("") would raise
        if path:
            os.makedirs(path, exist_ok=True)
        for name in self.all_frame:
            plt.clf()
            if plot_pull:
                self.plot_frame_with_pull(
                    name, idx=idx, pull_config=pull_config
                )
            else:
                # honour the requested dataset index in the no-pull case too
                self.plot_frame(name, idx=idx)
            file_name = prefix + name + "." + format
            logger.info("save figure: %s", file_name)
            plt.savefig(file_name)

    def add_ref_amp(self, ref_amp, name="reference fit"):
        """Add the projection of another amplitude as an extra drawn item.

        The reference is evaluated on each sample's fitted dataset and scaled
        to the total size of the fit. When a background dataset exists, the
        reference is drawn as "reference + bg" and the bare reference is
        hidden.

        :param ref_amp: callable returning per-event weights for a dataset
        :param name: dataset key and legend label, defaults to "reference fit"
        :type name: str, optional
        """
        for i in self.alldatas:
            base_data = i.datasets["fitted"]
            phsp = base_data.dataset
            i.datasets[name] = PlotData(
                phsp, ref_amp(phsp) * phsp.get_weight()
            )
            i.datasets[name].scale = (
                base_data.total_size() / i.datasets[name].total_size()
            )
        if "bg" in self.alldatas[0].datasets:
            key = ((name,), ("bg",))
            self.hidden_plot_item.append(((name,),))
        else:
            key = ((name,),)
        self.extra_plot_item.append(key)
        self.style.get(key, {"label": name, "linewidth": 2, "zorder": -1})
@ConfigLoader.register_function()
def get_plotter(
    self, legend_file=None, res=None, datasets=None, use_weighted=False
):
    """Create a :class:`Plotter` bound to this configuration.

    :param legend_file: YAML style file, defaults to None
    :param res: components to draw, defaults to None
    :param datasets: samples overriding the configured ones, defaults to None
    :type datasets: dict, optional
    :param use_weighted: forwarded to :class:`Plotter`, defaults to False
    :type use_weighted: bool, optional
    """
    options = {
        "res": res,
        "datasets": datasets,
        "use_weighted": use_weighted,
    }
    return Plotter(self, legend_file, **options)
@ConfigLoader.register_function()
def get_all_plotdatas(
    self, data=None, phsp=None, bg=None, res=None, use_weighted=False
):
    """Build one :class:`PlotAllData` per configured sample.

    Any sample that is not supplied is loaded with ``self.get_data``. With
    the ``cfit`` model a missing background comes from ``_get_cfit_bg`` and
    the phase-space sample is replaced by ``_get_cfit_eff_phsp``.

    :param data: list of data samples, defaults to the configured ones
    :param phsp: list of phase-space samples, defaults to the configured ones
    :param bg: list of background samples, defaults to the configured ones
    :param res: components to draw, forwarded to :class:`PlotAllData`
    :param use_weighted: forwarded to :class:`PlotAllData`
    :type use_weighted: bool, optional
    :rtype: list
    """
    amp = self.get_amplitude()
    data = self.get_data("data") if data is None else data
    # keep a caller-supplied phsp sample (it used to be replaced by ``data``)
    phsp = self.get_data("phsp") if phsp is None else phsp
    bg = self.get_data("bg") if bg is None else bg
    is_cfit = self.config["data"].get("model", "auto") == "cfit"
    if bg is None:
        if is_cfit:
            bg = _get_cfit_bg(self, data, phsp)
        else:
            # one "no background" placeholder per data sample
            bg = [None] * len(data)
    if is_cfit:
        phsp = _get_cfit_eff_phsp(self, phsp)
    return [
        PlotAllData(
            amp, data=a, phsp=b, bg=c, res=res, use_weighted=use_weighted
        )
        for a, b, c in zip(data, phsp, bg)
    ]