gwkokab.utils.tools

Functions

batch_and_remainder(x, batch_size) → Tuple[jaxtyping.Array, jaxtyping.Array]

Calculate the batched array and the remainder of an array for a given batch size.

Module Contents

gwkokab.utils.tools.batch_and_remainder(x: jaxtyping.Array, batch_size: int) → Tuple[jaxtyping.Array, jaxtyping.Array]

Calculate the batched array and the remainder of an array for a given batch size.

Copied from the JAX codebase.

Parameters:
  • x (Array) – the array to split into batches

  • batch_size (int) – the number of elements in each batch

Returns:

the batched array and the remainder (the elements that do not fill a complete batch)

Return type:

Tuple[Array, Array]
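The following is a minimal sketch of what this helper plausibly computes. It assumes the function mirrors the private helper JAX uses to implement the `batch_size` option of `jax.lax.map`: the leading axis is cut into as many complete batches of `batch_size` as fit, those are reshaped to `(num_batches, batch_size, ...)`, and the leftover rows form the remainder. The name `batch_and_remainder_sketch` is hypothetical. This is not gwkokab's source, and the real function may differ (for example, in how it handles pytrees).

```python
from typing import Tuple

import jax.numpy as jnp
from jax import Array


def batch_and_remainder_sketch(x: Array, batch_size: int) -> Tuple[Array, Array]:
    """Hypothetical re-implementation: split ``x`` along its leading axis."""
    # Number of complete batches that fit in the leading axis.
    num_batches = x.shape[0] // batch_size
    total = num_batches * batch_size
    # Rows that fill complete batches, reshaped to (num_batches, batch_size, ...).
    batched = x[:total].reshape(num_batches, batch_size, *x.shape[1:])
    # Leftover rows that do not fill a complete batch.
    remainder = x[total:]
    return batched, remainder


x = jnp.arange(10)
batched, remainder = batch_and_remainder_sketch(x, batch_size=3)
print(batched.shape)  # (3, 3)
print(remainder.shape)  # (1,)
```

Under this assumption, a length-10 array with `batch_size=3` yields three full batches of three elements and a remainder holding the single last element. Trailing axes are carried through unchanged, so a `(7, 2)` array with `batch_size=2` gives a `(3, 2, 2)` batched array and a `(1, 2)` remainder.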