Source code for tf_pwa.transform

from .config import create_config

# Transform registry helpers returned by create_config().
# In this module, register_trans("name") is applied as a class decorator and
# get_trans(name) is used by create_trans to look up the class for a model name.
# set_trans is not used in this module.
set_trans, get_trans, register_trans = create_config()

# String alias used in annotations below for the values transforms take and return.
T = "Tensor"


class BaseTransform:
    """Base class for transforms that read entries of a dict and map them.

    Calling an instance reads the configured key(s) from a dict with
    :meth:`read` and passes the result to :meth:`call`. Subclasses must
    implement :meth:`call`; :meth:`inverse` is optional.
    """

    def __init__(self, x: "list | str", **kwargs):
        """Store the key(s) to read.

        :param x: a single key, or a list/tuple of keys, used by :meth:`read`.
        :param kwargs: accepted and ignored by the base class.
        """
        self.x = x

    def __call__(self, dic: dict) -> T:
        """Read the configured entries from ``dic`` and apply :meth:`call`."""
        x = self.read(dic)
        return self.call(x)

    def read(self, x: dict) -> T:
        """Pick the configured entries out of ``x``.

        :param x: mapping indexed by the key(s) stored in ``self.x``.
        :return: ``x[self.x]`` when ``self.x`` is a str, or a list of
            ``x[i]`` for every ``i`` in ``self.x`` when it is a list/tuple.
        :raises TypeError: if ``self.x`` is neither a str nor a list/tuple.
        """
        if isinstance(self.x, (list, tuple)):
            return [x[i] for i in self.x]
        elif isinstance(self.x, str):
            return x[self.x]
        else:
            # Message fixed: the original read "only str of list of str".
            raise TypeError("only str or list of str is supported for x")

    def call(self, x: T) -> T:
        """Apply the transform to the value(s) returned by :meth:`read`.

        :raises NotImplementedError: always, in the base class.
        """
        raise NotImplementedError()

    def inverse(self, y: T) -> T:
        """Map a transformed value back; the base class returns ``None``."""
        return None
def create_trans(item: dict) -> BaseTransform:
    """Build a transform instance from a configuration dict.

    The ``"model"`` entry (``"default"`` when absent) selects the class via
    ``get_trans``; all remaining entries are passed to that class as keyword
    arguments. The selected model name is stored on the returned object as
    ``_model_name``.

    :param item: configuration mapping; it is not modified.
    :return: the constructed transform object.
    """
    # Pop from a shallow copy: removing "model" from the caller's dict would
    # make a second call with the same config silently fall back to "default".
    params = dict(item)
    model = params.pop("model", "default")
    cls = get_trans(model)
    obj = cls(**params)
    obj._model_name = model
    return obj
@register_trans("default")
@register_trans("linear")
class LinearTrans(BaseTransform):
    """Linear transform ``k * x + b``, registered as "default" and "linear"."""

    def __init__(
        self, x: "list | str", k: float = 1.0, b: float = 0.0, **kwargs
    ):
        """Store the key(s) to read together with slope ``k`` and offset ``b``.

        Extra keyword arguments are accepted and not used.
        """
        super().__init__(x)
        self.k, self.b = k, b

    def call(self, x) -> T:
        """Return ``k * x + b``."""
        scaled = self.k * x
        return scaled + self.b

    def inverse(self, x: T) -> T:
        """Return ``(x - b) / k``, the algebraic inverse of :meth:`call`."""
        shifted = x - self.b
        return shifted / self.k
@register_trans("blind")
class BlindTrans(BaseTransform):
    """Cyclic shift of a value inside ``range`` by an offset derived from ``key``.

    :meth:`call` shifts the value by ``bias * range_size`` and wraps the
    result with ``%`` relative to the lower edge of the range;
    :meth:`inverse` applies the opposite shift with the same wrapping.
    """

    def __init__(
        self,
        x: "list | str",
        range: "list[float]",
        key: "str | int | float",
        **kwargs,
    ):
        """Store the range and draw the shift fraction ``bias`` from ``key``.

        :param x: key(s) to read, forwarded to :class:`BaseTransform`.
        :param range: two edges of the interval, in either order.
        :param key: seed for the bias. Values accepted by
            ``numpy.random.RandomState`` are used as the seed unchanged;
            other values (e.g. str or float) are hashed to a 32-bit seed.
        :param kwargs: accepted and not used.
        """
        super().__init__(x)
        self.key = key
        self.range = range
        self.range_size = abs(self.range[1] - self.range[0])
        self.start_point = min(self.range[0], self.range[1])
        import hashlib

        import numpy as np

        try:
            # Original behaviour, kept so existing integer keys give the
            # same bias as before.
            rng = np.random.RandomState(key)
        except TypeError:
            # RandomState rejects str/float seeds although the signature
            # advertises them. Derive a seed from SHA-256, which is stable
            # across processes (unlike the builtin hash() for str).
            digest = hashlib.sha256(str(key).encode("utf-8")).digest()
            rng = np.random.RandomState(int.from_bytes(digest[:4], "little"))
        self.bias = rng.random()

    def call(self, x) -> T:
        """Shift ``x`` forward by ``bias * range_size``, wrapped into the range."""
        return (
            x - self.start_point + self.bias * self.range_size
        ) % self.range_size + self.start_point

    def inverse(self, x: T) -> T:
        """Shift ``x`` back by ``bias * range_size``, wrapped into the range."""
        return (
            (x - self.start_point) - self.bias * self.range_size
        ) % self.range_size + self.start_point