Source code for tf_pwa.experimental.wrap_function

import numpy as np
import tensorflow as tf

from tf_pwa.config import get_config


def _wrap_struct(dic, first_none=True):
    if isinstance(dic, dict):
        return {
            k: _wrap_struct(dic[k], first_none) for k in sorted(dic.keys())
        }
    if isinstance(dic, list):
        return [_wrap_struct(v, first_none) for v in dic]
    if isinstance(dic, tuple):
        return tuple([_wrap_struct(v, first_none) for v in dic])
    if isinstance(dic, (tf.Tensor, np.ndarray)):
        shape = dic.shape
        if first_none:
            shape = (None, *shape[1:])
        return tf.TensorSpec(shape, dtype=dic.dtype)
    return dic


def _flatten(dic):
    """Yield the tensor-like leaves of a nested structure in a fixed order.

    Dict values are visited in sorted-key order and list/tuple items in
    sequence order, recursively. Leaves that are ``tf.Tensor``,
    ``np.ndarray`` or ``tf.TensorSpec`` are yielded; every other leaf is
    skipped.
    """
    if isinstance(dic, dict):
        children = (dic[key] for key in sorted(dic.keys()))
    elif isinstance(dic, (list, tuple)):
        children = iter(dic)
    else:
        # Leaf: emit it only when it is tensor-like.
        if isinstance(dic, (tf.Tensor, np.ndarray, tf.TensorSpec)):
            yield dic
        return
    for child in children:
        yield from _flatten(child)


class Count:
    """Mutable integer counter that can be shared across recursive calls.

    ``_nest`` passes a single instance down its recursion so that each
    tensor-like leaf advances the same counter.
    """

    def __init__(self, idx=0):
        # Honor the requested starting value; it used to be ignored and the
        # counter always started at 0.
        self.idx = idx

    def add(self, value=1):
        """Increase the counter by ``value`` (default 1)."""
        self.idx += value
def _nest(dic, value, idx=None): if idx is None: idx = Count(0) if isinstance(dic, dict): return {k: _nest(v, value, idx) for k, v in dic.items()} if isinstance(dic, list): return [_nest(v, value, idx) for v in dic] if isinstance(dic, tuple): return tuple([_nest(v, value, idx) for v in dic]) if isinstance(dic, (tf.Tensor, np.ndarray, tf.TensorSpec)): idx.add() return value[(idx.idx - 1) % len(value)] return dic
class WrapFun:
    """Cache traced ``tf.function`` concrete functions for a Python callable.

    Calling the wrapper does the following:

    1. Converts ``(args, kwargs)`` into a spec structure with
       ``_wrap_struct``, so every tensor/array leaf gets a leading ``None``
       dimension.
    2. Traces ``f`` once per distinct spec structure.
    3. Calls the cached concrete function with the flattened tensor leaves.

    Non-tensor arguments are baked into the trace as constants.

    Args:
        f: the callable to wrap.
        jit_compile: forwarded to ``tf.function``.
    """

    def __init__(self, f, jit_compile=False):
        self.f = f
        self.cached_f = {}  # signature string -> concrete function
        self.struct = {}  # signature string -> spec structure used to trace
        self.jit_compile = jit_compile

    def __call__(self, *args, **kwargs):
        struct = _wrap_struct((args, kwargs))
        # Key on the whole spec structure: container layout, tensor
        # rank/dtype and the non-tensor constants baked into the trace.
        # Keying on the number of tensor leaves alone reused a stale trace
        # for calls that differ in any of those.
        key = repr(struct)
        if key not in self.cached_f:
            self.struct[key] = struct

            def _g(*x):
                new_args, new_kwargs = _nest(struct, x)
                return self.f(*new_args, **new_kwargs)

            traced = tf.function(_g, jit_compile=self.jit_compile)
            self.cached_f[key] = traced.get_concrete_function(
                *_flatten(struct)
            )
        specs = list(_flatten(struct))
        # Convert non-tensor leaves (e.g. ndarrays) to the dtype recorded in
        # their spec, so the call matches the signature that was traced.
        flat = [
            x if isinstance(x, tf.Tensor) else tf.convert_to_tensor(x, spec.dtype)
            for x, spec in zip(_flatten((args, kwargs)), specs)
        ]
        return self.cached_f[key](*flat)