gwkokab.utils.tools
===================

.. py:module:: gwkokab.utils.tools


Functions
---------

.. autoapisummary::

   gwkokab.utils.tools.batch_and_remainder


Module Contents
---------------

.. py:function:: batch_and_remainder(x: jaxtyping.Array, batch_size: int) -> Tuple[jaxtyping.Array, jaxtyping.Array]

   Calculate the full batches and the remainder of an array for a given batch size.

   Copied from the JAX codebase.

   :param x: Array to split into batches.
   :type x: Array
   :param batch_size: Number of elements in each batch.
   :type batch_size: int

   :returns: A tuple of the batched array and the remainder.
   :rtype: Tuple[Array, Array]
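The following is a minimal sketch of the batching logic, **assuming** the function mirrors the helper that JAX uses for ``jax.lax.map(..., batch_size=...)``: it splits ``x`` along its leading axis into ``num_batches = len(x) // batch_size`` full batches of shape ``(num_batches, batch_size, ...)``, and returns the leftover rows separately. It is an illustration under that assumption, not the ``gwkokab`` implementation itself; unlike the JAX helper, it handles a single array rather than a pytree, and uses ``jax.Array`` in place of ``jaxtyping.Array``.

```python
from typing import Tuple

import jax
import jax.numpy as jnp


def batch_and_remainder(x: jax.Array, batch_size: int) -> Tuple[jax.Array, jax.Array]:
    # Sketch only: assumes batching happens along the leading axis, as in JAX.
    # Number of complete batches that fit along the leading axis.
    num_batches = x.shape[0] // batch_size
    total = num_batches * batch_size
    # Leading `total` rows reshaped to (num_batches, batch_size, *trailing_dims).
    batched = x[:total].reshape(num_batches, batch_size, *x.shape[1:])
    # Rows that do not fill a complete batch.
    remainder = x[total:]
    return batched, remainder


# Usage: 10 rows split into batches of 3 gives 3 full batches and 1 leftover row.
x = jnp.arange(10.0).reshape(10, 1)
batched, remainder = batch_and_remainder(x, batch_size=3)
print(batched.shape)    # (3, 3, 1)
print(remainder.shape)  # (1, 1)
```

A typical consumer maps a vectorized function over ``batched`` (for example, ``jax.lax.map`` over ``jax.vmap(f)``), applies ``f`` to ``remainder`` separately, and concatenates the flattened results. This bounds peak memory to one batch at a time.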