Source code for tf_pwa.generator.base_generator

import numpy as np

from tf_pwa.config import create_config
from tf_pwa.generator import BaseGenerator

# Module-level registry helpers produced by create_config().
# Within this file only get_generator is used: it is called as
# get_generator(model, None) and a non-None result is instantiated as a class.
# NOTE(review): presumably set_/register_generator populate the same
# name -> generator-class mapping -- confirm against tf_pwa.config.create_config.
set_generator, get_generator, register_generator = create_config()


class Simple1DGenerator(BaseGenerator):
    """Generator that fills one named variable by calling a sampling function.

    :param name: key under which the generated values are returned
    :param func: callable invoked as ``func(**params, size=(N,))``
    :param params: mapping of keyword arguments forwarded to ``func``
    """

    def __init__(self, name, func, params):
        self.name = name
        self.func = func
        self.params = params

    def generate(self, N):
        """Return ``{name: func(**params, size=(N,))}``.

        :param N: number of values requested; passed to ``func`` as the
            one-element shape tuple ``(N,)``
        :return: dict with a single entry keyed by ``self.name``
        """
        # Keyword order is kept as-is so a "size" entry in params still
        # raises the usual duplicate-keyword TypeError.
        return {self.name: self.func(**self.params, size=(N,))}
class DefaultGenerator(BaseGenerator):
    """Generator that fills one named variable with a constant value.

    :param name: key under which the generated values are returned
    :param value: constant that every generated entry is set to
    """

    def __init__(self, name, value):
        self.name = name
        self.value = value

    def generate(self, N):
        """Return ``{name: np.ones((N,)) * value}``.

        :param N: length of the generated array
        :return: dict with a single entry keyed by ``self.name``
        """
        shape = (N,)
        constant = np.ones(shape) * self.value
        return {self.name: constant}
def create_simple_generator(name, params):
    """Build a generator for variable ``name`` from a configuration mapping.

    The ``"model"`` entry of ``params`` (``"default"`` when absent) selects
    the generator, tried in this order:

    1. a class returned by ``get_generator(model, None)``; it is instantiated
       with the generator keyword arguments only (``name`` is not passed);
    2. an attribute of ``np.random`` with that name, wrapped in
       :class:`Simple1DGenerator`;
    3. the literal ``"default"``, giving a :class:`DefaultGenerator` whose
       value is ``params["default"]``.

    The generator keyword arguments are ``params`` without the ``"model"``,
    ``"default"`` and ``"dtype"`` keys; if a ``"params"`` key remains, its
    value replaces them entirely.

    :param name: name of the generated variable
    :param params: configuration mapping (shallow-copied, not modified)
    :return: the constructed generator object
    :raises ValueError: if ``model`` matches none of the cases above
    :raises KeyError: if ``model`` is ``"default"`` and ``params`` has no
        ``"default"`` entry
    """
    params = params.copy()
    model = params.get("model", "default")

    # Strip the keys that steer this factory; the rest goes to the generator.
    control_keys = ["model", "default", "dtype"]
    gen_params = {}
    for key, value in params.items():
        if key not in control_keys:
            gen_params[key] = value
    if "params" in gen_params:
        gen_params = gen_params["params"]

    # A class found in the registry takes precedence over the built-in
    # fallbacks.
    model_class = get_generator(model, None)
    if model_class is not None:
        return model_class(**gen_params)

    if hasattr(np.random, model):
        return Simple1DGenerator(name, getattr(np.random, model), gen_params)

    if model == "default":
        return DefaultGenerator(name, value=params["default"])

    raise ValueError("not support model: {}".format(model))