from .config import create_config
# Registry accessors produced by the shared config factory. As used in this
# module, ``register_trans(name)`` is applied as a class decorator, and
# ``get_trans(name)`` returns a class that ``create_trans`` then instantiates.
(
    set_trans,
    get_trans,
    register_trans,
) = create_config()

# String stand-in for a tensor type, used in annotations as a forward
# reference rather than an imported class.
T = "Tensor"
def create_trans(item: dict) -> BaseTransform:
    """Instantiate a registered transform from a configuration mapping.

    The optional ``"model"`` key selects which registered transform class to
    build; it defaults to ``"default"`` when absent. Every remaining key is
    forwarded to that class's constructor as a keyword argument. The selected
    name is recorded on the new object as ``_model_name``.

    Args:
        item: Transform configuration. It is left unmodified: a shallow copy
            is taken before the ``"model"`` key is removed.

    Returns:
        The constructed transform instance.
    """
    # Work on a shallow copy. Popping "model" from the caller's dict would make
    # a second call with the same config silently fall back to "default".
    params = dict(item)
    model = params.pop("model", "default")
    cls = get_trans(model)
    obj = cls(**params)
    obj._model_name = model
    return obj
@register_trans("default")
@register_trans("linear")
class LinearTrans(BaseTransform):
    """Affine transform ``k * x + b`` together with its inverse ``(x - b) / k``.

    Registered under the names ``"default"`` and ``"linear"``. ``"default"`` is
    also the name ``create_trans`` falls back to when a config has no
    ``"model"`` key.

    Args:
        x: Passed unchanged to ``BaseTransform.__init__``. Presumably names the
            variable(s) this transform applies to; confirm against the base
            class.
        k: Slope (multiplicative factor). Must be non-zero for ``inverse``.
        b: Intercept (additive offset).
        **kwargs: Accepted and ignored (not forwarded to the base class).
            Presumably this lets configs carry extra keys; confirm.
    """

    def __init__(
        self, x: "list | str", k: float = 1.0, b: float = 0.0, **kwargs
    ):
        super().__init__(x)
        self.k = k
        self.b = b

    def call(self, x: T) -> T:
        """Apply the forward map ``k * x + b``."""
        return self.k * x + self.b

    def inverse(self, x: T) -> T:
        """Undo ``call`` by returning ``(x - b) / k``.

        Raises:
            ZeroDivisionError: if ``k == 0``. The forward map is then constant
                and has no inverse. The check runs up front so non-scalar inputs
                (which may yield inf/nan instead of raising) fail loudly too.
        """
        if self.k == 0:
            raise ZeroDivisionError("LinearTrans with k=0 is not invertible")
        return (x - self.b) / self.k