gwkokab.utils
=============

.. py:module:: gwkokab.utils


Submodules
----------

.. toctree::
   :maxdepth: 1

   /autoapi/gwkokab/utils/exceptions/index
   /autoapi/gwkokab/utils/kernel/index
   /autoapi/gwkokab/utils/math/index
   /autoapi/gwkokab/utils/path/index
   /autoapi/gwkokab/utils/tools/index
   /autoapi/gwkokab/utils/train/index
   /autoapi/gwkokab/utils/transformations/index


Functions
---------

.. autoapisummary::

   gwkokab.utils.Mc_eta_to_m1_m2
   gwkokab.utils.beta_dist_concentrations_to_mean_variance
   gwkokab.utils.beta_dist_mean_variance_to_concentrations
   gwkokab.utils.cart_to_polar
   gwkokab.utils.cart_to_spherical
   gwkokab.utils.chi_costilt_to_chiz
   gwkokab.utils.chi_p_from_components
   gwkokab.utils.chieff
   gwkokab.utils.chirp_mass
   gwkokab.utils.delta_m
   gwkokab.utils.delta_m_to_symmetric_mass_ratio
   gwkokab.utils.eta_from_q
   gwkokab.utils.load_model
   gwkokab.utils.log_chirp_mass
   gwkokab.utils.log_planck_taper_window
   gwkokab.utils.m1_m2_chi1_chi2_costilt1_costilt2_to_chieff
   gwkokab.utils.m1_m2_chi1_chi2_costilt1_costilt2_to_chiminus
   gwkokab.utils.m1_m2_chi1z_chi2z_to_chiminus
   gwkokab.utils.m1_m2_chieff_chiminus_to_chi1z_chi2z
   gwkokab.utils.m1_q_to_m2
   gwkokab.utils.m1_times_m2
   gwkokab.utils.m2_q_to_m1
   gwkokab.utils.m_det_z_to_m_source
   gwkokab.utils.m_source_z_to_m_det
   gwkokab.utils.make_model
   gwkokab.utils.mass_ratio
   gwkokab.utils.mse_loss_fn
   gwkokab.utils.polar_to_cart
   gwkokab.utils.predict
   gwkokab.utils.read_data
   gwkokab.utils.reduced_mass
   gwkokab.utils.save_model
   gwkokab.utils.sin_tilt
   gwkokab.utils.spherical_to_cart
   gwkokab.utils.spin_costilt_from_components
   gwkokab.utils.spin_magnitude_from_components
   gwkokab.utils.symmetric_mass_ratio
   gwkokab.utils.symmetric_mass_ratio_to_delta_m
   gwkokab.utils.total_mass
   gwkokab.utils.train_regressor


Package Contents
----------------

.. py:function:: Mc_eta_to_m1_m2(Mc: jaxtyping.ArrayLike, eta: jaxtyping.ArrayLike) -> tuple[jaxtyping.ArrayLike, jaxtyping.ArrayLike]

   .. math::

      \begin{align*}
          m_1(M_c, \eta) &= \frac{M_c}{2} \eta^{-0.6} (1 + \sqrt{1 - 4\eta}) \\
          m_2(M_c, \eta) &= \frac{M_c}{2} \eta^{-0.6} (1 - \sqrt{1 - 4\eta})
      \end{align*}


.. py:function:: beta_dist_concentrations_to_mean_variance(alpha: jaxtyping.ArrayLike, beta: jaxtyping.ArrayLike, loc: jaxtyping.ArrayLike = 0.0, scale: jaxtyping.ArrayLike = 1.0) -> Tuple[jaxtyping.ArrayLike, jaxtyping.ArrayLike]

   Let :math:`\alpha` and :math:`\beta` be the shape parameters of a beta
   distribution with location :math:`a` and scale :math:`b`. This function
   returns the mean and variance of the distribution, which are given by:

   .. math::

      \mu = a+b\frac{\alpha}{\alpha + \beta}\qquad \sigma^2 = b^2\frac{\alpha \beta}{(\alpha + \beta)^2 (\alpha + \beta + 1)}

   :param alpha: The shape parameter :math:`\alpha`.
   :type alpha: ArrayLike
   :param beta: The shape parameter :math:`\beta`.
   :type beta: ArrayLike
   :param loc: The location :math:`a` of the beta distribution.
   :type loc: ArrayLike
   :param scale: The scale :math:`b` of the beta distribution.
   :type scale: ArrayLike
   :returns: The mean :math:`\mu` and variance :math:`\sigma^2` of the beta distribution.
   :rtype: Tuple[ArrayLike, ArrayLike]


.. py:function:: beta_dist_mean_variance_to_concentrations(mean: jaxtyping.ArrayLike, variance: jaxtyping.ArrayLike, loc: jaxtyping.ArrayLike = 0.0, scale: jaxtyping.ArrayLike = 1.0) -> Tuple[jaxtyping.ArrayLike, jaxtyping.ArrayLike]

   Let :math:`\mu` and :math:`\sigma^2` be the mean and variance of a beta
   distribution with location :math:`a` and scale :math:`b`. This function
   returns the shape parameters (concentrations) :math:`\alpha` and
   :math:`\beta` of the distribution, which are given by:

   .. math::

      \alpha = -\frac{\mu-a}{b} \left(\left(\frac{\mu-a}{\sigma}\right)\left(\frac{\mu-a-b}{\sigma}\right)+1\right)\qquad \beta = \alpha\left(\frac{b}{\mu-a}-1\right)

   :param mean: The mean :math:`\mu` of the beta distribution.
   :type mean: ArrayLike
   :param variance: The variance :math:`\sigma^2` of the beta distribution.
   :type variance: ArrayLike
   :param loc: The location :math:`a` of the beta distribution.
   :type loc: ArrayLike
   :param scale: The scale :math:`b` of the beta distribution.
   :type scale: ArrayLike
   :returns: The shape parameters :math:`\alpha` and :math:`\beta` of the beta distribution.
   :rtype: Tuple[ArrayLike, ArrayLike]


.. py:function:: cart_to_polar(x: jaxtyping.ArrayLike, y: jaxtyping.ArrayLike) -> tuple[jaxtyping.ArrayLike, jaxtyping.ArrayLike]

   .. math::

      \begin{align*}
          r(x, y) &= \sqrt{x^2 + y^2} \\
          \theta(x, y) &= \arctan(y/x)
      \end{align*}


.. py:function:: cart_to_spherical(x: jaxtyping.ArrayLike, y: jaxtyping.ArrayLike, z: jaxtyping.ArrayLike) -> tuple[jaxtyping.ArrayLike, jaxtyping.ArrayLike, jaxtyping.ArrayLike]

   .. math::

      \begin{align*}
          r(x, y, z) &= \sqrt{x^2 + y^2 + z^2} \\
          \theta(x, y, z) &= \arccos\left(\frac{z}{r}\right) \\
          \phi(x, y, z) &= \arctan\left(\frac{y}{x}\right)
      \end{align*}


.. py:function:: chi_costilt_to_chiz(chi: jaxtyping.ArrayLike, costilt: jaxtyping.ArrayLike) -> jaxtyping.ArrayLike

   .. math::

      \chi_z(\chi, \cos(\theta)) = \chi \cos(\theta)


.. py:function:: chi_p_from_components(a_1: jaxtyping.ArrayLike, cos_tilt_1: jaxtyping.ArrayLike, a_2: jaxtyping.ArrayLike, cos_tilt_2: jaxtyping.ArrayLike, mass_ratio: jaxtyping.ArrayLike) -> jaxtyping.ArrayLike

   .. math::

      \chi_p(a_1, \cos(\theta_1), a_2, \cos(\theta_2), q) = \max \left( a_1 \sin(\theta_1), \frac{3 + 4q}{4 + 3q} q a_2 \sin(\theta_2) \right)


.. py:function:: chieff(m1: jaxtyping.ArrayLike, m2: jaxtyping.ArrayLike, chi1z: jaxtyping.ArrayLike, chi2z: jaxtyping.ArrayLike) -> jaxtyping.ArrayLike

   .. math::

      \chi_{\text{eff}}(m_1, m_2, \chi_{1z}, \chi_{2z}) = \frac{m_1\chi_{1z} + m_2\chi_{2z}}{m_1 + m_2}


.. py:function:: chirp_mass(m1: jaxtyping.ArrayLike, m2: jaxtyping.ArrayLike) -> jaxtyping.ArrayLike

   .. math::

      M_c(m_1, m_2) = \frac{(m_1m_2)^{3/5}}{(m_1 + m_2)^{1/5}}


.. py:function:: delta_m(m1: jaxtyping.ArrayLike, m2: jaxtyping.ArrayLike) -> jaxtyping.ArrayLike

   .. math::

      \delta_m(m_1, m_2) = \frac{m_1 - m_2}{m_1 + m_2}


.. py:function:: delta_m_to_symmetric_mass_ratio(delta_m: jaxtyping.ArrayLike) -> jaxtyping.ArrayLike

   .. math::

      \eta(\delta_m) = \frac{1 - \delta_m^2}{4}


.. py:function:: eta_from_q(q: jaxtyping.ArrayLike) -> jaxtyping.ArrayLike

   .. math::

      \eta(q) = \frac{q}{(1 + q)^2}


.. py:function:: load_model(filename: str) -> Tuple[List[str], equinox.nn.MLP]

   Load a model and its names from an HDF5 file (backward-compatible format).


.. py:function:: log_chirp_mass(m1: jaxtyping.ArrayLike, m2: jaxtyping.ArrayLike) -> jaxtyping.ArrayLike

   .. math::

      \log(M_c(m_1, m_2)) = \frac{3}{5}\left(\log(m_1) + \log(m_2)\right) - \frac{1}{5}\log(m_1 + m_2)


.. py:function:: log_planck_taper_window(x: jaxtyping.ArrayLike) -> jaxtyping.ArrayLike

   If :math:`x` is the point at which to evaluate the window, the Planck taper
   window is defined as

   .. math::

      S(x)=\begin{cases}
          0 & \text{if } x \leq 0, \\
          \displaystyle\frac{1}{1+e^{\left(\frac{1}{x}+\frac{1}{x-1}\right)}} & \text{if } 0 < x < 1, \\
          1 & \text{if } x \geq 1.
      \end{cases}

   This function evaluates the log of the Planck taper window, :math:`\ln S(x)`.

   :param x: point at which to evaluate the window
   :type x: ArrayLike
   :returns: log of the window value, :math:`\ln S(x)`
   :rtype: ArrayLike


.. py:function:: m1_m2_chi1_chi2_costilt1_costilt2_to_chieff(*, m1: jaxtyping.ArrayLike, m2: jaxtyping.ArrayLike, chi1: jaxtyping.ArrayLike, chi2: jaxtyping.ArrayLike, costilt1: jaxtyping.ArrayLike, costilt2: jaxtyping.ArrayLike) -> jaxtyping.ArrayLike

   .. math::

      \chi_{\text{eff}}(m_1, m_2, \chi_1, \chi_2, \cos(\theta_1), \cos(\theta_2)) = \frac{m_1\chi_1\cos(\theta_1) + m_2\chi_2\cos(\theta_2)}{m_1 + m_2}


.. py:function:: m1_m2_chi1_chi2_costilt1_costilt2_to_chiminus(*, m1: jaxtyping.ArrayLike, m2: jaxtyping.ArrayLike, chi1: jaxtyping.ArrayLike, chi2: jaxtyping.ArrayLike, costilt1: jaxtyping.ArrayLike, costilt2: jaxtyping.ArrayLike) -> jaxtyping.ArrayLike

   .. math::

      \chi_{\text{minus}}(m_1, m_2, \chi_1, \chi_2, \cos(\theta_1), \cos(\theta_2)) = \frac{m_1\chi_1\cos(\theta_1) - m_2\chi_2\cos(\theta_2)}{m_1 + m_2}


.. py:function:: m1_m2_chi1z_chi2z_to_chiminus(m1: jaxtyping.ArrayLike, m2: jaxtyping.ArrayLike, chi1z: jaxtyping.ArrayLike, chi2z: jaxtyping.ArrayLike) -> jaxtyping.ArrayLike

   .. math::

      \chi_{\text{minus}}(m_1, m_2, \chi_{1z}, \chi_{2z}) = \frac{m_1\chi_{1z} - m_2\chi_{2z}}{m_1 + m_2}


.. py:function:: m1_m2_chieff_chiminus_to_chi1z_chi2z(m1: jaxtyping.ArrayLike, m2: jaxtyping.ArrayLike, chieff: jaxtyping.ArrayLike, chiminus: jaxtyping.ArrayLike) -> tuple[jaxtyping.ArrayLike, jaxtyping.ArrayLike]

   .. math::

      \begin{align*}
          \chi_{1z}(m_1, m_2, \chi_{\text{eff}}, \chi_{\text{minus}}) &= \frac{m_1+m_2}{2m_1} \left( \chi_{\text{eff}} + \chi_{\text{minus}} \right)\\
          \chi_{2z}(m_1, m_2, \chi_{\text{eff}}, \chi_{\text{minus}}) &= \frac{m_1+m_2}{2m_2} \left( \chi_{\text{eff}} - \chi_{\text{minus}} \right)
      \end{align*}


.. py:function:: m1_q_to_m2(m1: jaxtyping.ArrayLike, q: jaxtyping.ArrayLike) -> jaxtyping.ArrayLike

   .. math::

      m_2(m_1, q) = m_1q


.. py:function:: m1_times_m2(m1: jaxtyping.ArrayLike, m2: jaxtyping.ArrayLike) -> jaxtyping.ArrayLike

   .. math::

      m_1m_2(m_1, m_2) = m_1 m_2


.. py:function:: m2_q_to_m1(m2: jaxtyping.ArrayLike, q: jaxtyping.ArrayLike) -> jaxtyping.ArrayLike

   .. math::

      m_1(m_2, q) = \frac{m_2}{q}


.. py:function:: m_det_z_to_m_source(m_det: jaxtyping.ArrayLike, z: jaxtyping.ArrayLike) -> jaxtyping.ArrayLike

   .. math::

      m_{\text{source}}(m_{\text{det}}, z) = \frac{m_{\text{det}}}{1 + z}


.. py:function:: m_source_z_to_m_det(m_source: jaxtyping.ArrayLike, z: jaxtyping.ArrayLike) -> jaxtyping.ArrayLike

   .. math::

      m_{\text{det}}(m_{\text{source}}, z) = m_{\text{source}}(1 + z)


.. py:function:: make_model(*, key: jaxtyping.PRNGKeyArray, input_layer: int, output_layer: int, width_size: int, depth: int) -> equinox.nn.MLP

   Build an MLP with ReLU activations.


.. py:function:: mass_ratio(m1: jaxtyping.ArrayLike, m2: jaxtyping.ArrayLike) -> jaxtyping.ArrayLike

   .. math::

      q(m_1, m_2) = \frac{m_2}{m_1}


.. py:function:: mse_loss_fn(model: jaxtyping.PyTree, x: jaxtyping.Array, y: jaxtyping.Array, batch_size: Optional[int] = 256) -> jaxtyping.Array

   Mean squared error loss between the predictions of ``model`` on ``x`` and the targets ``y``.


.. py:function:: polar_to_cart(r: jaxtyping.ArrayLike, theta: jaxtyping.ArrayLike) -> tuple[jaxtyping.ArrayLike, jaxtyping.ArrayLike]

   .. math::

      \begin{align*}
          x(r, \theta) &= r \cos(\theta) \\
          y(r, \theta) &= r \sin(\theta)
      \end{align*}


.. py:function:: predict(model: jaxtyping.PyTree, x: jaxtyping.Array, batch_size: Optional[int] = 256) -> jaxtyping.Array

   Predict outputs for the inputs ``x``.


.. py:function:: read_data(data_path: str, keys: collections.abc.Sequence[str]) -> pandas.DataFrame

   Read an HDF5 dataset into a :class:`pandas.DataFrame` whose columns are the given ``keys``.


.. py:function:: reduced_mass(m1: jaxtyping.ArrayLike, m2: jaxtyping.ArrayLike) -> jaxtyping.ArrayLike

   .. math::

      M_r(m_1, m_2) = \frac{m_1m_2}{m_1 + m_2}


.. py:function:: save_model(*, filepath: str, datafilepath: str, model: equinox.nn.MLP, names: Optional[collections.abc.Sequence[str]] = None, is_log: bool = False) -> None

   Persist model weights and metadata to HDF5 (backward-compatible format).


.. py:function:: sin_tilt(costilt: jaxtyping.ArrayLike) -> jaxtyping.ArrayLike

   .. math::

      \sin(\theta) = \sqrt{1 - \cos^2(\theta)}


.. py:function:: spherical_to_cart(r: jaxtyping.ArrayLike, theta: jaxtyping.ArrayLike, phi: jaxtyping.ArrayLike) -> tuple[jaxtyping.ArrayLike, jaxtyping.ArrayLike, jaxtyping.ArrayLike]

   .. math::

      \begin{align*}
          x(r, \theta, \phi) &= r \sin(\theta) \cos(\phi) \\
          y(r, \theta, \phi) &= r \sin(\theta) \sin(\phi) \\
          z(r, \theta, \phi) &= r \cos(\theta)
      \end{align*}


.. py:function:: spin_costilt_from_components(chi_x: jaxtyping.ArrayLike, chi_y: jaxtyping.ArrayLike, chi_z: jaxtyping.ArrayLike) -> jaxtyping.ArrayLike

   .. math::

      \cos(\theta)(\chi_x, \chi_y, \chi_z) = \frac{\chi_z}{\sqrt{\chi_x^2 + \chi_y^2 + \chi_z^2}}


.. py:function:: spin_magnitude_from_components(chi_x: jaxtyping.ArrayLike, chi_y: jaxtyping.ArrayLike, chi_z: jaxtyping.ArrayLike) -> jaxtyping.ArrayLike

   .. math::

      \chi(\chi_x, \chi_y, \chi_z) = \sqrt{\chi_x^2 + \chi_y^2 + \chi_z^2}


.. py:function:: symmetric_mass_ratio(m1: jaxtyping.ArrayLike, m2: jaxtyping.ArrayLike) -> jaxtyping.ArrayLike

   .. math::

      \eta(m_1, m_2) = \frac{m_1m_2}{(m_1 + m_2)^2}


.. py:function:: symmetric_mass_ratio_to_delta_m(eta: jaxtyping.ArrayLike) -> jaxtyping.ArrayLike

   .. math::

      \delta_m(\eta) = \sqrt{1 - 4\eta}


.. py:function:: total_mass(m1: jaxtyping.ArrayLike, m2: jaxtyping.ArrayLike) -> jaxtyping.ArrayLike

   .. math::

      M(m_1, m_2) = m_1 + m_2


.. py:function:: train_regressor(*, input_keys: list[str], output_keys: list[str], width_size: int, depth: int, batch_size: int, data_path: str, checkpoint_path: Optional[str] = None, epochs: int = 50, validation_split: float = 0.2, learning_rate: float = 0.001, train_in_log: bool = False, loss_type: str = 'mse', grad_clip_norm: float = 1.0, weight_decay: float = 0.0001, use_cosine_decay: bool = True, min_lr: float = 1e-06, warmup_epochs: int = 3, seed: Optional[int] = 42) -> None

   Train an MLP regressor with stable optimization and smooth loss curves.

   .. rubric:: Notes

   - For detection probabilities in [0, 1], prefer ``loss_type="bce_logits"`` and do
     not set ``train_in_log=True``: binary cross-entropy expects probability targets,
     not log-values.
   - ``seed`` fixes the validation split, which gives a less jittery validation loss.
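
The ``chirp_mass``, ``symmetric_mass_ratio`` and ``Mc_eta_to_m1_m2`` formulas above are meant to invert one another. As a quick consistency check, here is a minimal sketch that transcribes those three formulas into plain Python on scalar floats. The lowercase helper names are hypothetical and this is not the gwkokab implementation, which operates on ``jaxtyping.ArrayLike`` inputs.

```python
import math


# Hypothetical scalar helpers transcribing the documented formulas
# (not the gwkokab functions, which accept jaxtyping.ArrayLike inputs).
def chirp_mass(m1: float, m2: float) -> float:
    # M_c = (m1 m2)^(3/5) / (m1 + m2)^(1/5)
    return (m1 * m2) ** 0.6 / (m1 + m2) ** 0.2


def symmetric_mass_ratio(m1: float, m2: float) -> float:
    # eta = m1 m2 / (m1 + m2)^2
    return m1 * m2 / (m1 + m2) ** 2


def mc_eta_to_m1_m2(mc: float, eta: float) -> tuple[float, float]:
    # m_{1,2} = (M_c / 2) eta^(-3/5) (1 +/- sqrt(1 - 4 eta))
    total = mc * eta ** -0.6  # M_c eta^(-3/5) is the total mass
    root = math.sqrt(max(0.0, 1.0 - 4.0 * eta))  # clamp tiny negative rounding
    return 0.5 * total * (1.0 + root), 0.5 * total * (1.0 - root)


mc = chirp_mass(36.0, 29.0)
eta = symmetric_mass_ratio(36.0, 29.0)
m1_rec, m2_rec = mc_eta_to_m1_m2(mc, eta)
print(f"Mc = {mc:.4f}, eta = {eta:.4f} -> m1 = {m1_rec:.4f}, m2 = {m2_rec:.4f}")
```

Because :math:`\sqrt{1-4\eta}` equals :math:`\delta_m = (m_1-m_2)/(m_1+m_2)` only when :math:`m_1 \geq m_2`, the inversion always returns the heavier mass first. The ``max(0.0, ...)`` clamp in the sketch guards against rounding pushing :math:`1-4\eta` slightly below zero for nearly equal masses.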
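
``beta_dist_concentrations_to_mean_variance`` and ``beta_dist_mean_variance_to_concentrations`` above should be exact inverses of each other. The sketch below transcribes both documented formulas into plain Python (scalar floats, hypothetical helper names, not the gwkokab code) and round-trips a set of shape parameters through the mean and variance. It uses :math:`(\mu-a)(\mu-a-b)/\sigma^2` directly, which is the same product as :math:`\left(\frac{\mu-a}{\sigma}\right)\left(\frac{\mu-a-b}{\sigma}\right)` in the documented formula.

```python
import math


def beta_mean_variance(alpha: float, beta: float, loc: float = 0.0, scale: float = 1.0) -> tuple[float, float]:
    # mu = a + b * alpha / (alpha + beta)
    # sigma^2 = b^2 * alpha * beta / ((alpha + beta)^2 * (alpha + beta + 1))
    total = alpha + beta
    mean = loc + scale * alpha / total
    variance = scale**2 * alpha * beta / (total**2 * (total + 1.0))
    return mean, variance


def beta_concentrations(mean: float, variance: float, loc: float = 0.0, scale: float = 1.0) -> tuple[float, float]:
    # alpha = -((mu - a) / b) * ((mu - a)(mu - a - b) / sigma^2 + 1)
    # beta  = alpha * (b / (mu - a) - 1)
    shifted = mean - loc
    alpha = -(shifted / scale) * (shifted * (shifted - scale) / variance + 1.0)
    beta = alpha * (scale / shifted - 1.0)
    return alpha, beta


mean, variance = beta_mean_variance(2.0, 5.0, loc=1.0, scale=3.0)
alpha, beta = beta_concentrations(mean, variance, loc=1.0, scale=3.0)
print(f"mean = {mean:.6f}, variance = {variance:.6f} -> alpha = {alpha:.6f}, beta = {beta:.6f}")
```

The inversion yields positive :math:`\alpha` and :math:`\beta` only when :math:`a < \mu < a+b` and :math:`\sigma^2 < (\mu-a)(a+b-\mu)`; a requested variance larger than that has no matching beta distribution.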
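
``cart_to_polar`` and ``cart_to_spherical`` above write the azimuth as :math:`\arctan(y/x)`, which on its own fixes the angle only up to a multiple of :math:`\pi`; recovering the quadrant also needs the signs of :math:`x` and :math:`y`. The following plain-Python round trip uses ``math.atan2`` for the azimuth (a choice made for this sketch; this page does not say which form gwkokab uses) and compares it with the bare arctangent. The helper names are hypothetical. For a spin vector, the radius and :math:`\cos\theta = z/r` are the quantities documented for ``spin_magnitude_from_components`` and ``spin_costilt_from_components``.

```python
import math


def cart_to_spherical(x: float, y: float, z: float) -> tuple[float, float, float]:
    r = math.sqrt(x * x + y * y + z * z)
    theta = math.acos(z / r)  # polar angle, in [0, pi]
    phi = math.atan2(y, x)  # quadrant-aware azimuth, in (-pi, pi]
    return r, theta, phi


def spherical_to_cart(r: float, theta: float, phi: float) -> tuple[float, float, float]:
    return (
        r * math.sin(theta) * math.cos(phi),
        r * math.sin(theta) * math.sin(phi),
        r * math.cos(theta),
    )


point = (-0.3, -0.4, 0.5)  # x and y both negative: third quadrant in the xy-plane
r, theta, phi = cart_to_spherical(*point)
recovered = spherical_to_cart(r, theta, phi)

# The bare arctan(y / x) gives (-0.3, -0.4) the same angle as (0.3, 0.4).
naive = spherical_to_cart(r, theta, math.atan(point[1] / point[0]))
print(recovered, naive)
```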
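
In ``chi_p_from_components`` the secondary's in-plane spin is weighted by :math:`q(3+4q)/(4+3q)`, which equals 1 for equal masses and decreases as :math:`q = m_2/m_1` (see ``mass_ratio``) shrinks. The plain-Python sketch below evaluates the documented formula for two mass ratios, with :math:`\sin\theta` obtained from :math:`\cos\theta` as in ``sin_tilt``. Helper names are hypothetical, inputs are scalars, and this is not the gwkokab code.

```python
import math


def sin_tilt(costilt: float) -> float:
    # sin(theta) = sqrt(1 - cos^2(theta)), non-negative for theta in [0, pi]
    return math.sqrt(1.0 - costilt**2)


def chi_p(a1: float, cos_tilt1: float, a2: float, cos_tilt2: float, q: float) -> float:
    # q = m2 / m1, so q <= 1 when m1 is the heavier component
    weight = q * (3.0 + 4.0 * q) / (4.0 + 3.0 * q)
    return max(a1 * sin_tilt(cos_tilt1), weight * a2 * sin_tilt(cos_tilt2))


# Both spins lie fully in the orbital plane (cos tilt = 0).
equal_mass = chi_p(0.5, 0.0, 0.9, 0.0, q=1.0)  # weight is 1, so the larger in-plane spin wins
unequal = chi_p(0.5, 0.0, 0.9, 0.0, q=0.5)  # weight of about 0.45 suppresses the secondary
print(equal_mass, unequal)
```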
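
``chieff`` and ``m1_m2_chi1z_chi2z_to_chiminus`` above define a linear change of variables whose inverse is ``m1_m2_chieff_chiminus_to_chi1z_chi2z``; the tilt-angle variants are the same maps with :math:`\chi_{iz} = \chi_i \cos\theta_i`, as in ``chi_costilt_to_chiz``. A minimal plain-Python sketch of the round trip follows (scalar floats; the helper names are hypothetical, not the gwkokab API).

```python
# Hypothetical scalar helpers mirroring chieff, m1_m2_chi1z_chi2z_to_chiminus
# and m1_m2_chieff_chiminus_to_chi1z_chi2z as documented above.
def chi_eff(m1: float, m2: float, chi1z: float, chi2z: float) -> float:
    return (m1 * chi1z + m2 * chi2z) / (m1 + m2)


def chi_minus(m1: float, m2: float, chi1z: float, chi2z: float) -> float:
    return (m1 * chi1z - m2 * chi2z) / (m1 + m2)


def to_chi1z_chi2z(m1: float, m2: float, chieff: float, chiminus: float) -> tuple[float, float]:
    total = m1 + m2
    chi1z = total / (2.0 * m1) * (chieff + chiminus)
    chi2z = total / (2.0 * m2) * (chieff - chiminus)
    return chi1z, chi2z


m1, m2 = 30.0, 20.0
# Aligned components from magnitudes and tilt cosines: chi_z = chi * cos(theta)
chi1z = 0.8 * 0.5
chi2z = 0.6 * -0.5
ce = chi_eff(m1, m2, chi1z, chi2z)
cm = chi_minus(m1, m2, chi1z, chi2z)
chi1z_rec, chi2z_rec = to_chi1z_chi2z(m1, m2, ce, cm)
print(f"chi_eff = {ce:.3f}, chi_minus = {cm:.3f} -> chi1z = {chi1z_rec:.3f}, chi2z = {chi2z_rec:.3f}")
```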
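
Computing :math:`\ln S(x)` for ``log_planck_taper_window`` as the log of :math:`S(x)` is fragile near :math:`x = 0`, where :math:`1/x + 1/(x-1)` is large and the exponential overflows. Writing :math:`z = 1/x + 1/(x-1)`, the documented definition gives :math:`\ln S(x) = -\ln(1 + e^{z}) = -\operatorname{softplus}(z)` on :math:`0 < x < 1`. The sketch below uses that identity in plain Python with a numerically stable softplus; it illustrates one way to evaluate the window and is not a description of gwkokab's implementation (in JAX, ``jax.nn.softplus`` provides the same primitive). The helper names are hypothetical.

```python
import math


def softplus(z: float) -> float:
    # log(1 + e^z), rearranged so exp() only sees non-positive arguments
    return max(z, 0.0) + math.log1p(math.exp(-abs(z)))


def log_planck_taper(x: float) -> float:
    # ln S(x) for the Planck taper window documented above
    if x <= 0.0:
        return -math.inf  # S(x) = 0
    if x >= 1.0:
        return 0.0  # S(x) = 1
    z = 1.0 / x + 1.0 / (x - 1.0)
    return -softplus(z)


for x in (1e-4, 0.25, 0.5, 0.75, 1.0):
    print(x, log_planck_taper(x))
```

Since :math:`z(1-x) = -z(x)`, the window satisfies :math:`S(x) + S(1-x) = 1`, which makes a convenient sanity check.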