gwkokab.utils.tools
Functions
batch_and_remainder(x, batch_size) | Calculate the batch and remainder of an array given a batch size.
Module Contents
- gwkokab.utils.tools.batch_and_remainder(x: jaxtyping.Array, batch_size: int) → Tuple[jaxtyping.Array, jaxtyping.Array]
Calculate the batch and remainder of an array given a batch size.
Copied from the JAX codebase.
- Parameters:
x (Array) – array to split into batches
batch_size (int) – number of elements per batch
- Returns:
the batched array and the remainder, i.e. the trailing elements that do not fill a complete batch
- Return type:
Tuple[Array, Array]
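
The split described above can be illustrated with a minimal sketch. It is not the gwkokab implementation: the name `batch_and_remainder_sketch` is hypothetical, NumPy stands in for JAX arrays so the snippet runs without JAX installed, and it assumes batching happens along the leading axis, with full batches stacked into shape `(num_batches, batch_size, ...)`.

```python
import numpy as np


def batch_and_remainder_sketch(x, batch_size):
    """Hypothetical stand-in for batch_and_remainder (assumed behavior)."""
    # Number of complete batches that fit along the leading axis.
    num_batches = x.shape[0] // batch_size
    total_batch_elems = num_batches * batch_size
    # Stack the complete batches: (num_batches, batch_size, *trailing dims).
    batched = x[:total_batch_elems].reshape(num_batches, batch_size, *x.shape[1:])
    # Whatever is left over does not fill a complete batch.
    remainder = x[total_batch_elems:]
    return batched, remainder


x = np.arange(10)
batched, remainder = batch_and_remainder_sketch(x, batch_size=3)
print(batched.shape)  # (3, 3)
print(remainder)      # [9]
```

With 10 elements and a batch size of 3, three full batches are formed and the last element is returned separately as the remainder. If the real function behaves as assumed, a caller could process `batched` in a vectorized loop and handle `remainder` in one final, smaller step.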