# Source code for tf_pwa.tensorflow_wrapper

import os
import warnings

import tensorflow as tf

# Default configurations. ``setdefault`` is used so that a value the user has
# already exported in the environment takes precedence over these defaults.
# os.environ["CUDA_VISIBLE_DEVICES"] = "0"
# Level "1" is meant to hide TensorFlow's C++ INFO messages.
# NOTE(review): this runs after ``import tensorflow`` at the top of the file;
# confirm the level still takes effect for the messages it should hide.
os.environ.setdefault("TF_CPP_MIN_LOG_LEVEL", "1")
# os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE" # for Mac

# pylint: disable=no-member
# Work out the installed TensorFlow major version. This is best-effort: if the
# version string cannot be read or parsed for any reason, assume TensorFlow 2.
try:
    tf_version = int(str(tf.__version__).partition(".")[0])
except Exception:  # pylint: disable=broad-except
    tf_version = 2

if tf_version < 2:
    # TensorFlow 1.x: turn on eager execution, then rebind the module-level
    # name ``tf`` to ``tensorflow.compat.v2`` for the code that follows.
    tf.compat.v1.enable_eager_execution()
    import tensorflow.compat.v2 as tf  # pragma pylint: disable=import-error


def set_gpu_mem_growth():
    """Enable memory growth on every physical GPU TensorFlow reports.

    Calls ``tf.config.experimental.set_memory_growth(gpu, True)`` for each
    device returned by ``tf.config.experimental.list_physical_devices("GPU")``
    and does nothing when no GPU is found.

    This is best-effort. A ``RuntimeError`` (raised by TensorFlow when the
    GPUs have already been initialized) or a ``ValueError`` is reported with
    :func:`warnings.warn` instead of being propagated.
    """
    gpus = tf.config.experimental.list_physical_devices("GPU")
    if not gpus:
        return
    try:
        # Currently, memory growth needs to be the same across GPUs.
        for gpu in gpus:
            tf.config.experimental.set_memory_growth(gpu, True)
        # The result is not used.
        # NOTE(review): the call is kept in case something relies on it
        # initializing the logical devices; confirm, and drop it if not.
        tf.config.experimental.list_logical_devices("GPU")
    except (RuntimeError, ValueError) as e:
        # Memory growth must be set before GPUs have been initialized.
        warnings.warn("Could not enable GPU memory growth: {}".format(e))
# Enable GPU memory growth by default (variable unset) or when
# TF_PWA_GPU_FULL_MEM is exactly "0". Any other value skips the call and
# leaves TensorFlow's default GPU memory behaviour in place.
if os.environ.get("TF_PWA_GPU_FULL_MEM", "0") == "0":
    set_gpu_mem_growth()
class Module:
    """Bare attribute container used as a lightweight namespace.

    ``regist_function`` attaches registered objects, and nested ``Module``
    instances for dotted names, to instances of this class.
    """


# Root namespace that ``regist_function`` registers into by default.
tensorflow_wrapper = Module()
def regist_function(name, var=None, base_mod=tensorflow_wrapper):
    """Attach an object to ``base_mod`` under the dotted attribute path ``name``.

    Intermediate components of ``name`` (everything before the last ``.``) that
    do not exist yet are created as empty :class:`Module` namespaces. This
    happens immediately, even when the function is used as a decorator. If the
    final attribute already exists, a warning is emitted and the attribute is
    overwritten.

    Args:
        name: dotted path such as ``"a.b.f"``, relative to ``base_mod``.
        var: object to register. When ``None``, a decorator is returned
            instead, so the function can be used as ``@regist_function(name)``.
        base_mod: root object to register into. Defaults to the module-level
            ``tensorflow_wrapper`` namespace.

    Returns:
        ``var`` itself when it is given. Otherwise, the registering decorator,
        which returns the decorated object unchanged.
    """
    *parents, leaf = name.split(".")
    target = base_mod
    # Walk (and create where missing) the namespace chain down to the parent
    # of the final attribute.
    for part in parents:
        if not hasattr(target, part):
            setattr(target, part, Module())
        target = getattr(target, part)

    def _register(obj):
        if hasattr(target, leaf):
            warnings.warn("{} already exists.".format(name))
        setattr(target, leaf, obj)
        return obj

    return _register if var is None else _register(var)
# @regist_function("cross", base_mod=tf)
def numpy_cross(a, b):
    """Cross product via ``tf.linalg.cross`` with NumPy-style broadcasting.

    ``tf.linalg.cross`` expects both operands to have the same shape. Each
    input is therefore first broadcast against the other by adding a zero
    tensor shaped like the other operand.

    Args:
        a: tensor-like first operand, converted with ``tf.convert_to_tensor``.
        b: tensor-like second operand, converted the same way. Presumably both
            end in a length-3 axis and share a dtype, since the additions below
            do not promote dtypes. TODO confirm with callers.

    Returns:
        ``tf.linalg.cross`` applied to the two broadcast operands.
    """
    a = tf.convert_to_tensor(a)
    b = tf.convert_to_tensor(b)
    # Both zero tensors are built from the *unbroadcast* inputs, so each
    # operand picks up the other's shape.
    return tf.linalg.cross(a + tf.zeros_like(b), b + tf.zeros_like(a))
# Make ``tf.math.atan2`` also available under the NumPy-style name
# ``tf.arctan2``. Because ``base_mod`` is the ``tf`` module object itself, the
# attribute is set on that module.
# regist_function("sum", tf.reduce_sum, base_mod=tf)
regist_function("arctan2", var=tf.math.atan2, base_mod=tf)
# from .jax_wrapper import tf