from .config import create_config
# Transform registry built by ``create_config``. In this file
# ``register_trans(name)`` is used as a class decorator, and ``get_trans(name)``
# returns the class that ``create_trans`` instantiates. ``set_trans`` is not
# used in the code visible here (presumably re-exported -- TODO confirm).
set_trans, get_trans, register_trans = create_config()
# Annotation alias kept as a string forward reference, so "Tensor" does not
# have to be importable in this module (presumably torch.Tensor -- confirm).
T = "Tensor"
[docs]
def create_trans(item: dict) -> BaseTransform:
    """Instantiate a registered transform from a config mapping.

    ``item["model"]`` names the transform class (``"default"`` when absent)
    and is looked up with ``get_trans``. Every remaining key is passed to
    the class constructor as a keyword argument. The chosen name is stored
    on the new object as ``_model_name``.

    Args:
        item: Transform configuration. It is copied, not modified, so the
            same mapping can safely be passed to ``create_trans`` repeatedly.

    Returns:
        The constructed transform instance.
    """
    # Work on a copy: popping "model" straight from the caller's dict made a
    # second call with the same config silently fall back to "default".
    params = dict(item)
    model = params.pop("model", "default")
    cls = get_trans(model)
    obj = cls(**params)
    obj._model_name = model
    return obj
[docs]
@register_trans("default")
@register_trans("linear")
class LinearTrans(BaseTransform):
    """Affine transform ``y = k * x + b`` together with its inverse.

    Registered under both ``"linear"`` and ``"default"``. The latter is the
    name ``create_trans`` uses when a config has no ``"model"`` entry.
    """

    def __init__(
        self, x: "list | str", k: float = 1.0, b: float = 0.0, **kwargs
    ):
        # ``x`` is forwarded to BaseTransform as-is; any extra keyword
        # arguments are accepted and ignored.
        super().__init__(x)
        self.k, self.b = k, b

    def call(self, x) -> T:
        """Apply the forward map ``k * x + b``."""
        scaled = self.k * x
        return scaled + self.b

    def inverse(self, x: T) -> T:
        """Undo :meth:`call` via ``(x - b) / k``; only defined for ``k != 0``."""
        shifted = x - self.b
        return shifted / self.k
[docs]
@register_trans("blind")
class BlindTrans(BaseTransform):
    """Key-derived cyclic shift of values inside a fixed interval.

    Values are offset by ``bias * range_size`` and wrapped back into
    ``[start_point, start_point + range_size)``. ``bias`` (in ``[0, 1)``) is
    drawn from a NumPy ``RandomState`` seeded from ``key``, so the same key
    always reproduces the same offset and :meth:`inverse` can undo it.
    Registered as ``"blind"`` (presumably used to blind values -- confirm).
    """

    def __init__(
        self,
        x: "list | str",
        range: "list[float]",  # shadows the builtin, but is a public kwarg name
        key: "str | int | float",
        **kwargs,
    ):
        """Set up the interval and the key-derived shift.

        Args:
            x: Forwarded unchanged to ``BaseTransform``.
            range: Two interval endpoints, in either order.
            key: Seed material for the shift.
            **kwargs: Accepted and ignored.

        Raises:
            ValueError: If both endpoints of ``range`` are equal, because a
                zero-width interval would make the modulo in :meth:`call`
                and :meth:`inverse` divide by zero.
        """
        super().__init__(x)
        self.key = key
        self.range = range
        self.range_size = abs(self.range[1] - self.range[0])
        self.start_point = min(self.range[0], self.range[1])
        if self.range_size == 0:
            raise ValueError(
                f"BlindTrans requires a range of non-zero width, got {range!r}"
            )

        import numpy as np

        try:
            # Keys RandomState accepts directly (non-negative ints < 2**32,
            # integer arrays) keep their original seeding, so values blinded
            # with earlier versions can still be inverted.
            rng = np.random.RandomState(key)
        except (TypeError, ValueError):
            # str/float keys (allowed by the annotation) are rejected by
            # RandomState; seed from a stable hash of the key instead.
            rng = np.random.RandomState(self._hashed_seed(key))
        self.bias = rng.random()

    @staticmethod
    def _hashed_seed(key) -> int:
        """Derive a reproducible 32-bit RandomState seed from any ``key``.

        Uses SHA-256 of ``str(key)`` rather than the built-in ``hash()``,
        whose value for strings changes between interpreter runs and would
        make the shift impossible to reproduce later.
        """
        import hashlib

        digest = hashlib.sha256(str(key).encode("utf-8")).digest()
        return int.from_bytes(digest[:4], "little")

    def call(self, x) -> T:
        """Shift ``x`` by ``bias * range_size`` and wrap it into the range."""
        return (
            x - self.start_point + self.bias * self.range_size
        ) % self.range_size + self.start_point

    def inverse(self, x: T) -> T:
        """Undo :meth:`call`.

        Exact (up to float rounding) only for inputs that were already inside
        the range before :meth:`call`, since :meth:`call` wraps values that
        lie outside it.
        """
        return (
            (x - self.start_point) - self.bias * self.range_size
        ) % self.range_size + self.start_point