template_synthesis.jax.utilities.jax_utils module

Utility functions and classes required by the JAX implementation.

Copyright (C) 2022 Gordian Edenhofer
SPDX-License-Identifier: GPL-2.0+ OR BSD-2-Clause
Authors: Gordian Edenhofer, Philipp Frank

class template_synthesis.jax.utilities.jax_utils.ModelMeta(name, bases, dict_, /, **kwargs)

Bases: ABCMeta

Register all derived classes as PyTrees in JAX using metaprogramming.

For every dataclasses.Field with a metadata entry named "static", the field is either hidden from JAX (kept as static auxiliary data) or exposed to JAX (traced as a PyTree leaf), depending on that entry's value.
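A minimal sketch of the mechanism described above, under stated assumptions: `ModelMetaSketch`, `Affine` and `total_scale` are hypothetical names, the metaclass applies `dataclasses.dataclass` to every class it creates, and a field is hidden from JAX only when its metadata contains `static=True` (the default used by this module may differ). Hidden fields travel as auxiliary data; all other fields become PyTree children that JAX traces.

```python
import abc
import dataclasses

import jax
import jax.numpy as jnp


class ModelMetaSketch(abc.ABCMeta):
    """Hypothetical metaclass: every class it creates becomes a dataclass PyTree."""

    def __new__(mcs, name, bases, dict_, /, **kwargs):
        cls = super().__new__(mcs, name, bases, dict_, **kwargs)
        # eq=False avoids generating an __eq__ that would compare arrays.
        cls = dataclasses.dataclass(eq=False)(cls)

        def flatten(obj):
            dyn_names, dyn_vals, static_items = [], [], []
            for f in dataclasses.fields(obj):
                value = getattr(obj, f.name)
                # Assumption: only metadata static=True hides a field from JAX.
                if f.metadata.get("static", False):
                    static_items.append((f.name, value))
                else:
                    dyn_names.append(f.name)
                    dyn_vals.append(value)
            # Children are traced by JAX; the aux data must be hashable.
            return tuple(dyn_vals), (tuple(dyn_names), tuple(static_items))

        def unflatten(aux, children):
            dyn_names, static_items = aux
            # Bypass __init__: JAX may rebuild objects with placeholder leaves.
            obj = object.__new__(cls)
            for k, v in zip(dyn_names, children):
                object.__setattr__(obj, k, v)
            for k, v in static_items:
                object.__setattr__(obj, k, v)
            return obj

        jax.tree_util.register_pytree_node(cls, flatten, unflatten)
        return cls


class Affine(metaclass=ModelMetaSketch):
    scale: jax.Array
    # Marked static: stays out of the list of leaves JAX traces.
    label: str = dataclasses.field(default="affine", metadata={"static": True})


@jax.jit
def total_scale(model):
    # Only model.scale is traced; model.label is part of the tree structure.
    return model.scale.sum()
```

Registering inside the metaclass means every subclass is registered automatically when it is defined, with no per-class decorator.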

class template_synthesis.jax.utilities.jax_utils.PyTreeString(str)

Bases: object

lower()
startswith(rhs)
tree_flatten()
classmethod tree_unflatten(aux, _)
upper()
template_synthesis.jax.utilities.jax_utils.hide_strings(a)
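The entries above carry no docstrings. One plausible reading, sketched here as an assumption rather than the module's actual code: `PyTreeString` wraps a string as a PyTree node with no leaves, keeping the text in the auxiliary data so that strings can ride along inside structures passed to `jax.jit` or `jax.tree_util.tree_map`, and `hide_strings` wraps every string leaf of a tree in that way. `PyTreeStringSketch`, `hide_strings_sketch` and `double` are hypothetical names, and whether `lower`/`upper` return wrapped or plain strings is also assumed.

```python
import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class, tree_leaves, tree_map


@register_pytree_node_class
class PyTreeStringSketch:
    """Hypothetical stand-in for PyTreeString: a string with zero PyTree leaves."""

    def __init__(self, s: str):
        self._str = s

    def tree_flatten(self):
        # No children; the string is stored as hashable auxiliary data.
        return (), (self._str,)

    @classmethod
    def tree_unflatten(cls, aux, _):
        # The second argument holds the (empty) children and is ignored.
        return cls(*aux)

    # String helpers delegate to the wrapped str (assumed behaviour).
    def lower(self):
        return PyTreeStringSketch(self._str.lower())

    def upper(self):
        return PyTreeStringSketch(self._str.upper())

    def startswith(self, rhs):
        return self._str.startswith(rhs)

    def __str__(self):
        return self._str


def hide_strings_sketch(a):
    # Plain strings are ordinary leaves to tree_map; replace each one with a
    # wrapper that contributes no leaves and leave all other leaves untouched.
    return tree_map(lambda x: PyTreeStringSketch(x) if isinstance(x, str) else x, a)


tree = hide_strings_sketch({"name": "Sky", "x": jnp.ones(3)})


@jax.jit
def double(t):
    # The wrapped string passes through jit as static tree structure.
    return {"name": t["name"], "x": 2.0 * t["x"]}


out = double(tree)
```

Passing the raw string "Sky" as a leaf to `jax.jit` would fail because a string is not a valid JAX type; after wrapping, the only leaf left is the array.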
template_synthesis.jax.utilities.jax_utils.jax_zip(x: Array | ndarray | bool | number | bool | int | float | complex, y: Array | ndarray | bool | number | bool | int | float | complex) → list[tuple]

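Equivalent of Python's built-in zip function for JAX array-like inputs; returns a list of tuples pairing corresponding elements of x and y.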
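A minimal sketch, assuming `jax_zip` pairs elements along the leading axis of both inputs and, like built-in `zip`, stops at the shorter one; `jax_zip_sketch` is a hypothetical name and the real function may treat mismatched lengths differently.

```python
import jax.numpy as jnp


def jax_zip_sketch(x, y) -> list[tuple]:
    """Hypothetical sketch: pair x[i] with y[i] along the leading axis."""
    x, y = jnp.asarray(x), jnp.asarray(y)
    # Like built-in zip, stop at the shorter of the two inputs.
    n = min(x.shape[0], y.shape[0])
    return [(x[i], y[i]) for i in range(n)]


pairs = jax_zip_sketch(jnp.array([1, 2, 3]), jnp.array([4, 5, 6, 7]))
```

Each tuple holds 0-d JAX arrays rather than Python scalars, so convert with `int(...)` or `float(...)` when plain numbers are needed.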

template_synthesis.jax.utilities.jax_utils.jit_repeat(a: Array | ndarray | bool | number | bool | int | float | complex, repeats: Array | ndarray | bool | number | bool | int | float | complex, axis: int | None, total_repeat_length: int | None)

Jit-compiled version of jax.numpy.repeat.
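A sketch of how such a wrapper can be built with `jax.jit`, assuming `axis` and `total_repeat_length` are marked static; `jit_repeat_sketch` is a hypothetical name. Under `jit`, `jnp.repeat` needs an explicit `total_repeat_length` whenever `repeats` is a traced array, because the output shape must be known at compile time.

```python
from functools import partial

import jax
import jax.numpy as jnp


@partial(jax.jit, static_argnames=("axis", "total_repeat_length"))
def jit_repeat_sketch(a, repeats, axis=None, total_repeat_length=None):
    # axis and total_repeat_length are static, so each distinct value triggers
    # its own compilation; a and repeats are traced.
    return jnp.repeat(a, repeats, axis=axis, total_repeat_length=total_repeat_length)


# repeats is a traced array here, so total_repeat_length must be given.
out = jit_repeat_sketch(jnp.array([1, 2, 3]), jnp.array([1, 2, 0]), axis=0, total_repeat_length=3)
```

Making the shape-determining arguments static trades recompilation on new values for a fixed, compile-time-known output shape.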