template_synthesis.jax.utilities.jax_utils module
Utility functions and classes required for the JAX implementation.
# Copyright(C) 2022 Gordian Edenhofer
# SPDX-License-Identifier: GPL-2.0+ OR BSD-2-Clause
# Authors: Gordian Edenhofer, Philipp Frank
- class template_synthesis.jax.utilities.jax_utils.ModelMeta(name, bases, dict_, /, **kwargs)
Bases: ABCMeta
Register all derived classes as PyTrees in JAX using metaprogramming.
For any dataclasses.Field property whose metadata contains an entry named “static”, the property is either hidden from or exposed to JAX, depending on the value of that entry.
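The registration mechanism described above can be illustrated with a minimal sketch. This is not the implementation of ModelMeta: the names `SketchMeta` and `Scale` are hypothetical, and the sketch assumes that a field is exposed to JAX unless its metadata sets "static" to True. Exposed fields become PyTree leaves; hidden fields travel as auxiliary data.

```python
import dataclasses
from abc import ABCMeta

import jax
import jax.numpy as jnp


class SketchMeta(ABCMeta):
    """Hypothetical stand-in for ModelMeta, for illustration only."""

    def __new__(mcs, name, bases, dict_, /, **kwargs):
        cls = super().__new__(mcs, name, bases, dict_, **kwargs)
        # Turn every derived class into a dataclass so its fields can be inspected.
        cls = dataclasses.dataclass(cls)

        def flatten(obj):
            dynamic, static = [], []
            for f in dataclasses.fields(obj):
                # Assumption: fields are exposed to JAX unless marked static=True.
                target = static if f.metadata.get("static", False) else dynamic
                target.append((f.name, getattr(obj, f.name)))
            children = tuple(v for _, v in dynamic)
            aux = (tuple(k for k, _ in dynamic), tuple(static))
            return children, aux

        def unflatten(aux, children):
            dynamic_names, static = aux
            obj = object.__new__(cls)
            for k, v in zip(dynamic_names, children):
                object.__setattr__(obj, k, v)
            for k, v in static:
                object.__setattr__(obj, k, v)
            return obj

        # Register the freshly created class as a PyTree node.
        jax.tree_util.register_pytree_node(cls, flatten, unflatten)
        return cls


class Scale(metaclass=SketchMeta):
    factor: jax.Array
    # Hidden from JAX: carried along as auxiliary data, not as a leaf.
    label: str = dataclasses.field(default="scale", metadata={"static": True})


s = Scale(jnp.array([1.0, 2.0]))
leaves = jax.tree_util.tree_leaves(s)
doubled = jax.tree_util.tree_map(lambda x: 2 * x, s)
```

In this sketch only `factor` shows up among the leaves, so `tree_map` transforms it and leaves `label` untouched. Static values end up in the auxiliary data and therefore need to be hashable.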
- class template_synthesis.jax.utilities.jax_utils.PyTreeString(str)
Bases: object
- lower()
- tree_flatten()
- upper()
- template_synthesis.jax.utilities.jax_utils.jax_zip(x: Array | ndarray | bool | number | int | float | complex, y: Array | ndarray | bool | number | int | float | complex) → list[tuple]
Equivalent of Python's built-in zip function.
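The source of jax_zip is not reproduced on this page, so the following is only a sketch of zip-like behaviour for array inputs, not the library's implementation. The helper name `zip_sketch` is hypothetical, and the sketch assumes that elements are paired along the leading axis.

```python
import jax.numpy as jnp


def zip_sketch(x, y):
    """Hypothetical helper: pair the entries of x and y along the leading axis."""
    x, y = jnp.asarray(x), jnp.asarray(y)
    # Iterating over a JAX array yields its slices along axis 0.
    return list(zip(x, y))


pairs = zip_sketch([1, 2, 3], [4, 5, 6])
```

Each entry of `pairs` is a tuple of two zero-dimensional arrays, matching the list[tuple] return annotation.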
- template_synthesis.jax.utilities.jax_utils.jit_repeat(a: Array | ndarray | bool | number | int | float | complex, repeats: Array | ndarray | bool | number | int | float | complex, axis: int | None, total_repeat_length: int | None)
Jit-compiled version of jax.numpy.repeat.
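One way such a wrapper could be written is sketched below. It is an assumption about the approach, not the module's actual code: the name `repeat_sketch` is hypothetical, and treating axis and total_repeat_length as static arguments is a choice made for this example. With a traced repeats array, jax.numpy.repeat needs total_repeat_length so that the output shape is known at compile time.

```python
from functools import partial

import jax
import jax.numpy as jnp


# Assumption: axis and total_repeat_length are compile-time constants.
@partial(jax.jit, static_argnames=("axis", "total_repeat_length"))
def repeat_sketch(a, repeats, axis=None, total_repeat_length=None):
    """Hypothetical jitted wrapper around jax.numpy.repeat."""
    return jnp.repeat(
        a, repeats, axis=axis, total_repeat_length=total_repeat_length
    )


out = repeat_sketch(
    jnp.array([1, 2, 3]),
    jnp.array([2, 0, 1]),  # repeat 1 twice, drop 2, keep 3 once
    axis=0,
    total_repeat_length=3,
)
```

Here the repeat counts sum to total_repeat_length, so `out` holds the values 1, 1, 3.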