gwkokab.utils

Submodules

Functions

Mc_eta_to_m1_m2(→ tuple[jaxtyping.ArrayLike, ...)

beta_dist_concentrations_to_mean_variance(...)

Convert the shape parameters \(\alpha\) and \(\beta\) of a beta distribution to its mean and variance.

beta_dist_mean_variance_to_concentrations(...)

Convert the mean \(\mu\) and variance \(\sigma^2\) of a beta distribution to its shape parameters \(\alpha\) and \(\beta\).

cart_to_polar(→ tuple[jaxtyping.ArrayLike, ...)

cart_to_spherical(→ tuple[jaxtyping.ArrayLike, ...)

chi_costilt_to_chiz(→ jaxtyping.ArrayLike)

chi_p_from_components(→ jaxtyping.ArrayLike)

chieff(→ jaxtyping.ArrayLike)

chirp_mass(→ jaxtyping.ArrayLike)

delta_m(→ jaxtyping.ArrayLike)

delta_m_to_symmetric_mass_ratio(→ jaxtyping.ArrayLike)

eta_from_q(→ jaxtyping.ArrayLike)

load_model(→ Tuple[List[str], equinox.nn.MLP])

Load model and names from HDF5 (backward-compatible).

log_chirp_mass(→ jaxtyping.ArrayLike)

log_planck_taper_window(→ jaxtyping.ArrayLike)

Evaluate the log of the Planck taper window, \(\ln S(x)\).

m1_m2_chi1_chi2_costilt1_costilt2_to_chieff(...)

m1_m2_chi1_chi2_costilt1_costilt2_to_chiminus(...)

m1_m2_chi1z_chi2z_to_chiminus(→ jaxtyping.ArrayLike)

m1_m2_chieff_chiminus_to_chi1z_chi2z(...)

m1_q_to_m2(→ jaxtyping.ArrayLike)

m1_times_m2(→ jaxtyping.ArrayLike)

m2_q_to_m1(→ jaxtyping.ArrayLike)

m_det_z_to_m_source(→ jaxtyping.ArrayLike)

m_source_z_to_m_det(→ jaxtyping.ArrayLike)

make_model(→ equinox.nn.MLP)

Build an MLP with ReLU activations.

mass_ratio(→ jaxtyping.ArrayLike)

mse_loss_fn(→ jaxtyping.Array)

Mean squared error loss.

polar_to_cart(→ tuple[jaxtyping.ArrayLike, ...)

predict(→ jaxtyping.Array)

Predict outputs for inputs x.

read_data(→ pandas.DataFrame)

Read an HDF5 dataset into a DataFrame whose columns are the given keys.

reduced_mass(→ jaxtyping.ArrayLike)

save_model(→ None)

Persist model weights and metadata to HDF5 (backward-compatible format).

sin_tilt(→ jaxtyping.ArrayLike)

spherical_to_cart(→ tuple[jaxtyping.ArrayLike, ...)

spin_costilt_from_components(→ jaxtyping.ArrayLike)

spin_magnitude_from_components(→ jaxtyping.ArrayLike)

symmetric_mass_ratio(→ jaxtyping.ArrayLike)

symmetric_mass_ratio_to_delta_m(→ jaxtyping.ArrayLike)

total_mass(→ jaxtyping.ArrayLike)

train_regressor(→ None)

Train an MLP regressor with stable optimization and smooth loss curves.

Package Contents

gwkokab.utils.Mc_eta_to_m1_m2(Mc: jaxtyping.ArrayLike, eta: jaxtyping.ArrayLike) → tuple[jaxtyping.ArrayLike, jaxtyping.ArrayLike]
\[\begin{split}\begin{align*} m_1(M_c, \eta) &= \frac{M_c}{2} \eta^{-0.6} (1 + \sqrt{1 - 4\eta}) \\ m_2(M_c, \eta) &= \frac{M_c}{2} \eta^{-0.6} (1 - \sqrt{1 - 4\eta}) \end{align*}\end{split}\]
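A round trip makes a quick sanity check: compute \(M_c\) and \(\eta\) from a pair of component masses, then invert them with the formulas above. The sketch re-implements the documented formulas with `jax.numpy` and does not call `gwkokab.utils`. The `*_sketch` helper names are hypothetical. It assumes \(m_1 \geq m_2\), because the \(+\sqrt{1-4\eta}\) branch always assigns the heavier mass to \(m_1\).

```python
import jax.numpy as jnp


def chirp_mass_sketch(m1, m2):
    # M_c = (m1 m2)^(3/5) / (m1 + m2)^(1/5)
    return (m1 * m2) ** 0.6 / (m1 + m2) ** 0.2


def eta_sketch(m1, m2):
    # eta = m1 m2 / (m1 + m2)^2
    return m1 * m2 / (m1 + m2) ** 2


def mc_eta_to_m1_m2_sketch(mc, eta):
    # Guard against tiny negative values from round-off near eta = 1/4 (equal masses).
    root = jnp.sqrt(jnp.maximum(1.0 - 4.0 * eta, 0.0))
    # m_{1,2} = (M_c / 2) eta^(-3/5) (1 +/- sqrt(1 - 4 eta))
    prefactor = 0.5 * mc * eta**-0.6
    return prefactor * (1.0 + root), prefactor * (1.0 - root)


m1 = jnp.array([30.0, 10.0, 50.0])
m2 = jnp.array([20.0, 10.0, 5.0])
m1_rec, m2_rec = mc_eta_to_m1_m2_sketch(chirp_mass_sketch(m1, m2), eta_sketch(m1, m2))
```

If \(m_2 > m_1\) on input, the recovered pair comes back with the labels swapped.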
gwkokab.utils.beta_dist_concentrations_to_mean_variance(alpha: jaxtyping.ArrayLike, beta: jaxtyping.ArrayLike, loc: jaxtyping.ArrayLike = 0.0, scale: jaxtyping.ArrayLike = 1.0) → Tuple[jaxtyping.ArrayLike, jaxtyping.ArrayLike]

Let \(\alpha\) and \(\beta\) be the shape parameters of a beta distribution with location \(a\) and scale \(b\). This function returns the mean and variance of the distribution, which are given by:

\[\mu = a+b\frac{\alpha}{\alpha + \beta}\qquad \sigma^2 = b^2\frac{\alpha \beta}{(\alpha + \beta)^2 (\alpha + \beta + 1)}\]
Parameters:
  • alpha (ArrayLike) – The shape parameter \(\alpha\).

  • beta (ArrayLike) – The shape parameter \(\beta\).

  • loc (ArrayLike) – The location \(a\) of the beta distribution.

  • scale (ArrayLike) – The scale \(b\) of the beta distribution.

Returns:

The mean \(\mu\) and variance \(\sigma^2\) of the beta distribution.

Return type:

Tuple[ArrayLike, ArrayLike]
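To check these formulas numerically, the sketch below re-implements them with `jax.numpy`. The name `beta_mean_variance_sketch` is hypothetical, and the sketch does not call gwkokab. Beta(2, 3) on \([0, 1]\) should give \(\mu = 2/5\) and \(\sigma^2 = 1/25\). Applying `loc` and `scale` should map the mean to \(a + b\mu\) and multiply the variance by \(b^2\).

```python
import jax.numpy as jnp


def beta_mean_variance_sketch(alpha, beta, loc=0.0, scale=1.0):
    total = alpha + beta
    # mu = a + b * alpha / (alpha + beta)
    mean = loc + scale * alpha / total
    # sigma^2 = b^2 * alpha * beta / ((alpha + beta)^2 * (alpha + beta + 1))
    variance = scale**2 * alpha * beta / (total**2 * (total + 1.0))
    return mean, variance


alpha = jnp.array([2.0, 5.0])
beta = jnp.array([3.0, 1.0])
# Beta(2, 3) on [0, 1] has mean 2/5 and variance 6 / (25 * 6) = 1/25.
mean, var = beta_mean_variance_sketch(alpha, beta)
# The same shapes stretched onto [1, 3] (loc a = 1, scale b = 2).
mean_s, var_s = beta_mean_variance_sketch(alpha, beta, loc=1.0, scale=2.0)
```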

gwkokab.utils.beta_dist_mean_variance_to_concentrations(mean: jaxtyping.ArrayLike, variance: jaxtyping.ArrayLike, loc: jaxtyping.ArrayLike = 0.0, scale: jaxtyping.ArrayLike = 1.0) → Tuple[jaxtyping.ArrayLike, jaxtyping.ArrayLike]

Let \(\mu\) and \(\sigma^2\) be the mean and variance of a beta distribution with location \(a\) and scale \(b\). This function returns the shape parameters (concentrations) \(\alpha\) and \(\beta\) of the distribution, which are given by:

\[\alpha = -\frac{\mu-a}{b} \left(\left(\frac{\mu-a}{\sigma}\right)\left(\frac{\mu-a-b}{\sigma}\right)+1\right)\qquad \beta = \alpha\left(\frac{b}{\mu-a}-1\right)\]
Parameters:
  • mean (ArrayLike) – The mean \(\mu\) of the beta distribution.

  • variance (ArrayLike) – The variance \(\sigma^2\) of the beta distribution.

  • loc (ArrayLike) – The location \(a\) of the beta distribution.

  • scale (ArrayLike) – The scale \(b\) of the beta distribution.

Returns:

The shape parameters \(\alpha\) and \(\beta\) of the beta distribution.

Return type:

Tuple[ArrayLike, ArrayLike]
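The sketch below inverts the moments back to concentrations using the formulas above. Note that \(\left(\frac{\mu-a}{\sigma}\right)\left(\frac{\mu-a-b}{\sigma}\right) = (\mu-a)(\mu-a-b)/\sigma^2\), so only the variance is needed. The helper name is hypothetical and this is not the gwkokab implementation. The test inputs are the moments of Beta(2, 3), with and without `loc=1, scale=2`.

```python
import jax.numpy as jnp


def beta_concentrations_sketch(mean, variance, loc=0.0, scale=1.0):
    shifted = mean - loc  # mu - a
    # alpha = -((mu - a) / b) * ((mu - a)(mu - a - b) / sigma^2 + 1)
    alpha = -(shifted / scale) * (shifted * (shifted - scale) / variance + 1.0)
    # beta = alpha * (b / (mu - a) - 1)
    beta = alpha * (scale / shifted - 1.0)
    return alpha, beta


# Mean 0.4 and variance 0.04 on [0, 1] correspond to Beta(2, 3).
alpha, beta = beta_concentrations_sketch(jnp.array(0.4), jnp.array(0.04))
# Moments of the same shape on [1, 3]: mean 1 + 2 * 0.4, variance 2^2 * 0.04.
alpha_s, beta_s = beta_concentrations_sketch(jnp.array(1.8), jnp.array(0.16), loc=1.0, scale=2.0)
```

Both concentrations come out positive only when \(a < \mu < a + b\) and \(\sigma^2 < (\mu - a)(a + b - \mu)\). Inputs outside that range do not correspond to any beta distribution.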

gwkokab.utils.cart_to_polar(x: jaxtyping.ArrayLike, y: jaxtyping.ArrayLike) → tuple[jaxtyping.ArrayLike, jaxtyping.ArrayLike]
\[\begin{split}\begin{align*} r(x, y) &= \sqrt{x^2 + y^2} \\ \theta(x, y) &= \arctan(y/x) \end{align*}\end{split}\]
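Written literally, \(\theta = \arctan(y/x)\) fixes the angle only up to a multiple of \(\pi\): the points \((1, 1)\) and \((-1, -1)\) get the same value. A quadrant-aware version normally uses `arctan2`. The sketch below does that and round-trips through the `polar_to_cart` formulas documented in this module. Whether `gwkokab.utils.cart_to_polar` uses `arctan2` internally is an assumption, not something this page confirms. The helper names are hypothetical.

```python
import jax.numpy as jnp


def cart_to_polar_sketch(x, y):
    r = jnp.sqrt(x**2 + y**2)
    # arctan2 uses the signs of both x and y, so theta lands in the correct quadrant.
    theta = jnp.arctan2(y, x)
    return r, theta


def polar_to_cart_sketch(r, theta):
    return r * jnp.cos(theta), r * jnp.sin(theta)


x = jnp.array([1.0, -1.0, -1.0, 1.0])
y = jnp.array([1.0, 1.0, -1.0, -1.0])
r, theta = cart_to_polar_sketch(x, y)
x_rec, y_rec = polar_to_cart_sketch(r, theta)
# Plain arctan(y / x) maps (1, 1) and (-1, -1) to the same angle.
naive = jnp.arctan(y / x)
```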
gwkokab.utils.cart_to_spherical(x: jaxtyping.ArrayLike, y: jaxtyping.ArrayLike, z: jaxtyping.ArrayLike) → tuple[jaxtyping.ArrayLike, jaxtyping.ArrayLike, jaxtyping.ArrayLike]
\[\begin{split}\begin{align*} r(x, y, z) &= \sqrt{x^2 + y^2 + z^2} \\ \theta(x, y, z) &= \arccos\left(\frac{z}{r}\right) \\ \phi(x, y, z) &= \arctan\left(\frac{y}{x}\right) \end{align*}\end{split}\]
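The spherical formulas can be checked the same way, by converting to \((r, \theta, \phi)\) and back with the `spherical_to_cart` formulas documented in this module. As in the polar case, the sketch computes \(\phi\) with `arctan2` so that the azimuth lands in the correct quadrant. This is an assumption about the intended behavior, and the helper names are hypothetical. The sketch also assumes \(r > 0\), since \(\arccos(z/r)\) is undefined at the origin.

```python
import jax.numpy as jnp


def cart_to_spherical_sketch(x, y, z):
    r = jnp.sqrt(x**2 + y**2 + z**2)
    theta = jnp.arccos(z / r)  # polar angle in [0, pi]
    phi = jnp.arctan2(y, x)  # quadrant-aware azimuth in (-pi, pi]
    return r, theta, phi


def spherical_to_cart_sketch(r, theta, phi):
    return (
        r * jnp.sin(theta) * jnp.cos(phi),
        r * jnp.sin(theta) * jnp.sin(phi),
        r * jnp.cos(theta),
    )


x = jnp.array([1.0, -2.0, 0.5])
y = jnp.array([2.0, -1.0, -3.0])
z = jnp.array([3.0, 0.5, -1.0])
x_rec, y_rec, z_rec = spherical_to_cart_sketch(*cart_to_spherical_sketch(x, y, z))
```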
gwkokab.utils.chi_costilt_to_chiz(chi: jaxtyping.ArrayLike, costilt: jaxtyping.ArrayLike) → jaxtyping.ArrayLike
\[\chi_z(\chi, \cos(\theta)) = \chi \cos(\theta)\]
gwkokab.utils.chi_p_from_components(a_1: jaxtyping.ArrayLike, cos_tilt_1: jaxtyping.ArrayLike, a_2: jaxtyping.ArrayLike, cos_tilt_2: jaxtyping.ArrayLike, mass_ratio: jaxtyping.ArrayLike) → jaxtyping.ArrayLike
\[\chi_p(a_1, \cos(\theta_1), a_2, \cos(\theta_2), q) = \max \left( a_1 \sin(\theta_1), \frac{3 + 4q}{4 + 3q} q a_2 \sin(\theta_2) \right)\]
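The sketch below evaluates \(\chi_p\) from spin magnitudes and tilt cosines, with \(\sin\theta = \sqrt{1-\cos^2\theta}\) as in `sin_tilt`. It assumes \(q = m_2/m_1 \leq 1\), consistent with `mass_ratio` in this module. The helper name is hypothetical. With purely in-plane spins (\(\cos\theta = 0\)), equal masses let the larger secondary spin dominate. At \(q = 0.5\), the secondary term shrinks to \(\frac{5}{5.5}\cdot 0.5 \cdot 0.8 \approx 0.364\), so the primary spin wins. Aligned or anti-aligned spins give \(\chi_p = 0\).

```python
import jax.numpy as jnp


def chi_p_sketch(a_1, cos_tilt_1, a_2, cos_tilt_2, mass_ratio):
    # sin(theta) = sqrt(1 - cos^2(theta)), non-negative for tilts in [0, pi]
    sin_tilt_1 = jnp.sqrt(1.0 - cos_tilt_1**2)
    sin_tilt_2 = jnp.sqrt(1.0 - cos_tilt_2**2)
    q = mass_ratio
    weight = (3.0 + 4.0 * q) / (4.0 + 3.0 * q)
    return jnp.maximum(a_1 * sin_tilt_1, weight * q * a_2 * sin_tilt_2)


# Primary spin 0.5 and secondary spin 0.8, both in the orbital plane, at q = 1 and q = 0.5.
chi_p = chi_p_sketch(0.5, 0.0, 0.8, 0.0, jnp.array([1.0, 0.5]))
```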
gwkokab.utils.chieff(m1: jaxtyping.ArrayLike, m2: jaxtyping.ArrayLike, chi1z: jaxtyping.ArrayLike, chi2z: jaxtyping.ArrayLike) → jaxtyping.ArrayLike
\[\chi_{\text{eff}}(m_1, m_2, \chi_{1z}, \chi_{2z}) = \frac{m_1\chi_{1z} + m_2\chi_{2z}}{m_1 + m_2}\]
gwkokab.utils.chirp_mass(m1: jaxtyping.ArrayLike, m2: jaxtyping.ArrayLike) → jaxtyping.ArrayLike
\[M_c(m_1, m_2) = \frac{(m_1m_2)^{3/5}}{(m_1 + m_2)^{1/5}}\]
gwkokab.utils.delta_m(m1: jaxtyping.ArrayLike, m2: jaxtyping.ArrayLike) → jaxtyping.ArrayLike
\[\delta_m(m_1, m_2) = \frac{m_1 - m_2}{m_1 + m_2}\]
gwkokab.utils.delta_m_to_symmetric_mass_ratio(delta_m: jaxtyping.ArrayLike) → jaxtyping.ArrayLike
\[\eta(\delta_m) = \frac{1 - \delta_m^2}{4}\]
gwkokab.utils.eta_from_q(q: jaxtyping.ArrayLike) → jaxtyping.ArrayLike
\[\eta(q) = \frac{q}{(1 + q)^2}\]
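Because \(\eta\) depends only on the ratio of the two masses, three routes should agree: \(\eta(q)\), \(\eta(m_1, q\,m_1)\), and \(\eta(\delta_m)\) with \(\delta_m\) computed from the same pair. The sketch re-implements the documented formulas to check this. The helper names are hypothetical.

```python
import jax.numpy as jnp


def symmetric_mass_ratio_sketch(m1, m2):
    return m1 * m2 / (m1 + m2) ** 2


def eta_from_q_sketch(q):
    return q / (1.0 + q) ** 2


def delta_m_sketch(m1, m2):
    return (m1 - m2) / (m1 + m2)


def delta_m_to_eta_sketch(delta_m):
    return (1.0 - delta_m**2) / 4.0


m1 = jnp.array([35.0, 12.0, 80.0])
q = jnp.array([0.9, 0.5, 0.1])
m2 = m1 * q  # same relation as m1_q_to_m2

eta_direct = symmetric_mass_ratio_sketch(m1, m2)
eta_via_q = eta_from_q_sketch(q)
eta_via_delta = delta_m_to_eta_sketch(delta_m_sketch(m1, m2))
```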
gwkokab.utils.load_model(filename: str) → Tuple[List[str], equinox.nn.MLP]

Load model and names from HDF5 (backward-compatible).

gwkokab.utils.log_chirp_mass(m1: jaxtyping.ArrayLike, m2: jaxtyping.ArrayLike) → jaxtyping.ArrayLike
\[\log(M_c(m_1, m_2)) = 3/5\times (\log(m_1) + \log(m_2)) - \log(m_1 + m_2)/5\]
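The log form is the chirp-mass formula with the logarithm distributed over each factor. The sketch below checks that numerically and also checks the equal-mass special case \(M_c = m / 2^{1/5}\), which follows from setting \(m_1 = m_2 = m\). The helper names are hypothetical re-implementations, not the gwkokab functions.

```python
import jax.numpy as jnp


def chirp_mass_sketch(m1, m2):
    return (m1 * m2) ** 0.6 / (m1 + m2) ** 0.2


def log_chirp_mass_sketch(m1, m2):
    # log M_c = (3/5)(log m1 + log m2) - (1/5) log(m1 + m2)
    return 0.6 * (jnp.log(m1) + jnp.log(m2)) - 0.2 * jnp.log(m1 + m2)


m1 = jnp.array([10.0, 36.0, 1.4])
m2 = jnp.array([10.0, 29.0, 1.3])
log_mc = log_chirp_mass_sketch(m1, m2)
# For equal masses m1 = m2 = m, M_c = m / 2^(1/5).
equal_mass_mc = 10.0 / 2.0**0.2
```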
gwkokab.utils.log_planck_taper_window(x: jaxtyping.ArrayLike) → jaxtyping.ArrayLike

If \(x\) is the point at which to evaluate the window, then the Planck taper window is defined as,

\[\begin{split}S(x)=\begin{cases} 0 & \text{if } x \leq 0, \\ \displaystyle\frac{1}{1+e^{\left(\frac{1}{x}+\frac{1}{x-1}\right)}} & \text{if } 0 < x < 1, \\ 1 & \text{if } x \geq 1. \end{cases}\end{split}\]

This function evaluates the log of the Planck taper window \(\ln{S(x)}\).

Parameters:

x (ArrayLike) – The point at which to evaluate the window.

Returns:

The log of the window value, \(\ln S(x)\).

Return type:

ArrayLike
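A numerically careful way to evaluate the middle branch is \(\ln S(x) = -\ln\!\left(1 + e^{u}\right) = -\operatorname{softplus}(u)\) with \(u = \frac{1}{x} + \frac{1}{x-1}\). `jax.nn.softplus` avoids overflowing \(e^{u}\) as \(x \to 0^{+}\). The sketch below is one such implementation, not necessarily gwkokab's. It assumes the boundary values are the limits of the middle branch, \(\ln S = -\infty\) for \(x \leq 0\) and \(0\) for \(x \geq 1\). It also feeds a harmless placeholder to the formula outside \((0, 1)\), so the masked branch never divides by zero (the usual "double `where`" pattern for keeping gradients finite). The helper name is hypothetical.

```python
import jax
import jax.numpy as jnp


def log_planck_taper_window_sketch(x):
    x = jnp.asarray(x, dtype=jnp.float32)
    inside = (x > 0.0) & (x < 1.0)
    # Placeholder 0.5 outside (0, 1) keeps 1/x and 1/(x - 1) finite in the unused branch.
    x_safe = jnp.where(inside, x, 0.5)
    # ln S = -ln(1 + exp(1/x + 1/(x - 1))) = -softplus(1/x + 1/(x - 1))
    interior = -jax.nn.softplus(1.0 / x_safe + 1.0 / (x_safe - 1.0))
    return jnp.where(x <= 0.0, -jnp.inf, jnp.where(x >= 1.0, 0.0, interior))


x = jnp.array([-1.0, 0.0, 0.25, 0.5, 0.75, 1.0, 2.0])
out = log_planck_taper_window_sketch(x)
```

At \(x = 1/2\) the exponent vanishes, so \(S = 1/2\). The window is also antisymmetric about that point, \(S(x) + S(1 - x) = 1\), which gives a convenient check.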

gwkokab.utils.m1_m2_chi1_chi2_costilt1_costilt2_to_chieff(*, m1: jaxtyping.ArrayLike, m2: jaxtyping.ArrayLike, chi1: jaxtyping.ArrayLike, chi2: jaxtyping.ArrayLike, costilt1: jaxtyping.ArrayLike, costilt2: jaxtyping.ArrayLike) → jaxtyping.ArrayLike
\[\chi_{\text{eff}}(m_1, m_2, \chi_1, \chi_2, \cos(\theta_1), \cos(\theta_2)) = \frac{m_1\chi_1\cos(\theta_1) + m_2\chi_2\cos(\theta_2)}{m_1 + m_2}\]
gwkokab.utils.m1_m2_chi1_chi2_costilt1_costilt2_to_chiminus(*, m1: jaxtyping.ArrayLike, m2: jaxtyping.ArrayLike, chi1: jaxtyping.ArrayLike, chi2: jaxtyping.ArrayLike, costilt1: jaxtyping.ArrayLike, costilt2: jaxtyping.ArrayLike) → jaxtyping.ArrayLike
\[\chi_{\text{minus}}(m_1, m_2, \chi_1, \chi_2, \cos(\theta_1), \cos(\theta_2)) = \frac{m_1\chi_1\cos(\theta_1) - m_2\chi_2\cos(\theta_2)}{m_1 + m_2}\]
gwkokab.utils.m1_m2_chi1z_chi2z_to_chiminus(m1: jaxtyping.ArrayLike, m2: jaxtyping.ArrayLike, chi1z: jaxtyping.ArrayLike, chi2z: jaxtyping.ArrayLike) → jaxtyping.ArrayLike
\[\chi_{\text{minus}}(m_1, m_2, \chi_{1z}, \chi_{2z}) = \frac{m_1\chi_{1z} - m_2\chi_{2z}}{m_1 + m_2}\]
gwkokab.utils.m1_m2_chieff_chiminus_to_chi1z_chi2z(m1: jaxtyping.ArrayLike, m2: jaxtyping.ArrayLike, chieff: jaxtyping.ArrayLike, chiminus: jaxtyping.ArrayLike) → tuple[jaxtyping.ArrayLike, jaxtyping.ArrayLike]
\[\begin{split}\begin{align*} \chi_{1z}(m_1, m_2, \chi_{\text{eff}}, \chi_{\text{minus}}) &= \frac{m_1+m_2}{2m_1} \left( \chi_{\text{eff}} + \chi_{\text{minus}} \right)\\ \chi_{2z}(m_1, m_2, \chi_{\text{eff}}, \chi_{\text{minus}}) &= \frac{m_1+m_2}{2m_2} \left( \chi_{\text{eff}} - \chi_{\text{minus}} \right) \end{align*}\end{split}\]
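The inverse follows from adding and subtracting the two definitions. From those, \(\chi_{\text{eff}} + \chi_{\text{minus}} = 2m_1\chi_{1z}/(m_1+m_2)\) and \(\chi_{\text{eff}} - \chi_{\text{minus}} = 2m_2\chi_{2z}/(m_1+m_2)\). The sketch below round-trips a few aligned-spin components through the forward and inverse formulas. The helper names are hypothetical re-implementations.

```python
import jax.numpy as jnp


def chieff_sketch(m1, m2, chi1z, chi2z):
    return (m1 * chi1z + m2 * chi2z) / (m1 + m2)


def chiminus_sketch(m1, m2, chi1z, chi2z):
    return (m1 * chi1z - m2 * chi2z) / (m1 + m2)


def chi1z_chi2z_sketch(m1, m2, chieff, chiminus):
    total = m1 + m2
    chi1z = total / (2.0 * m1) * (chieff + chiminus)
    chi2z = total / (2.0 * m2) * (chieff - chiminus)
    return chi1z, chi2z


m1 = jnp.array([30.0, 12.0])
m2 = jnp.array([25.0, 3.0])
chi1z = jnp.array([0.4, -0.7])
chi2z = jnp.array([-0.2, 0.9])
chi1z_rec, chi2z_rec = chi1z_chi2z_sketch(
    m1, m2, chieff_sketch(m1, m2, chi1z, chi2z), chiminus_sketch(m1, m2, chi1z, chi2z)
)
```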
gwkokab.utils.m1_q_to_m2(m1: jaxtyping.ArrayLike, q: jaxtyping.ArrayLike) → jaxtyping.ArrayLike
\[m_2(m_1, q) = m_1q\]
gwkokab.utils.m1_times_m2(m1: jaxtyping.ArrayLike, m2: jaxtyping.ArrayLike) → jaxtyping.ArrayLike
\[m_1m_2(m_1, m_2) = m_1 m_2\]
gwkokab.utils.m2_q_to_m1(m2: jaxtyping.ArrayLike, q: jaxtyping.ArrayLike) → jaxtyping.ArrayLike
\[m_1(m_2, q) = \frac{m_2}{q}\]
gwkokab.utils.m_det_z_to_m_source(m_det: jaxtyping.ArrayLike, z: jaxtyping.ArrayLike) → jaxtyping.ArrayLike
\[m_{\text{source}}(m_{\text{det}}, z) = \frac{m_{\text{det}}}{1 + z}\]
gwkokab.utils.m_source_z_to_m_det(m_source: jaxtyping.ArrayLike, z: jaxtyping.ArrayLike) → jaxtyping.ArrayLike
\[m_{\text{det}}(m_{\text{source}}, z) = m_{\text{source}}(1 + z)\]
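Redshifting multiplies both component masses by the same factor \(1 + z\). Any mass combination that is homogeneous of degree one picks up that same factor, which includes the chirp mass: \((\lambda m_1 \lambda m_2)^{3/5} / (\lambda M)^{1/5} = \lambda M_c\). Dimensionless ratios such as \(\eta\) are unchanged. The sketch checks the round trip and both scaling properties. The helper names are hypothetical re-implementations of the documented formulas.

```python
import jax.numpy as jnp


def m_source_z_to_m_det_sketch(m_source, z):
    return m_source * (1.0 + z)


def m_det_z_to_m_source_sketch(m_det, z):
    return m_det / (1.0 + z)


def chirp_mass_sketch(m1, m2):
    return (m1 * m2) ** 0.6 / (m1 + m2) ** 0.2


def symmetric_mass_ratio_sketch(m1, m2):
    return m1 * m2 / (m1 + m2) ** 2


z = jnp.array([0.1, 0.5, 2.0])
m1_src = jnp.array([30.0, 30.0, 30.0])
m2_src = jnp.array([20.0, 20.0, 20.0])
m1_det = m_source_z_to_m_det_sketch(m1_src, z)
m2_det = m_source_z_to_m_det_sketch(m2_src, z)
```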
gwkokab.utils.make_model(*, key: jaxtyping.PRNGKeyArray, input_layer: int, output_layer: int, width_size: int, depth: int) → equinox.nn.MLP

Build an MLP with ReLU activations.

gwkokab.utils.mass_ratio(m1: jaxtyping.ArrayLike, m2: jaxtyping.ArrayLike) → jaxtyping.ArrayLike
\[q(m_1, m_2) = \frac{m_2}{m_1}\]
gwkokab.utils.mse_loss_fn(model: jaxtyping.PyTree, x: jaxtyping.Array, y: jaxtyping.Array, batch_size: int | None = 256) → jaxtyping.Array

Mean squared error loss.

gwkokab.utils.polar_to_cart(r: jaxtyping.ArrayLike, theta: jaxtyping.ArrayLike) → tuple[jaxtyping.ArrayLike, jaxtyping.ArrayLike]
\[\begin{split}\begin{align*} x(r, \theta) &= r \cos(\theta) \\ y(r, \theta) &= r \sin(\theta) \end{align*}\end{split}\]
gwkokab.utils.predict(model: jaxtyping.PyTree, x: jaxtyping.Array, batch_size: int | None = 256) → jaxtyping.Array

Predict outputs for inputs x.
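The signature suggests that `batch_size` bounds how many samples are evaluated at once, with `None` meaning all at once. The sketch below shows one way to do that. It is a guess at the behavior, not gwkokab's code. It assumes the model is a callable PyTree applied to a single unbatched sample, as `equinox.nn.MLP` is, so `jax.vmap` adds the batch axis. The `Linear` class and `predict_sketch` are hypothetical stand-ins. The same chunking idea would apply to the `batch_size` argument of `mse_loss_fn`.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class Linear(NamedTuple):
    """Hypothetical stand-in for an MLP: a callable PyTree that acts on one sample."""

    weight: jax.Array
    bias: jax.Array

    def __call__(self, x):
        return self.weight @ x + self.bias


def predict_sketch(model, x, batch_size=256):
    # The model maps one input vector to one output vector, so vmap adds the batch axis.
    batched_model = jax.vmap(model)
    if batch_size is None:
        return batched_model(x)
    # Evaluate fixed-size chunks to bound peak memory, then stitch the results together.
    chunks = [batched_model(x[i:i + batch_size]) for i in range(0, x.shape[0], batch_size)]
    return jnp.concatenate(chunks, axis=0)


kx, kw = jax.random.split(jax.random.PRNGKey(0))
x = jax.random.normal(kx, (1000, 3))
model = Linear(weight=jax.random.normal(kw, (2, 3)), bias=jnp.ones(2))
y_chunked = predict_sketch(model, x, batch_size=256)
y_full = predict_sketch(model, x, batch_size=None)
```

Chunking changes only memory use. The last chunk may be shorter than `batch_size`, and the concatenated result matches a single full-batch call.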

gwkokab.utils.read_data(data_path: str, keys: collections.abc.Sequence[str]) → pandas.DataFrame

Read an HDF5 dataset into a DataFrame whose columns are the given keys.

gwkokab.utils.reduced_mass(m1: jaxtyping.ArrayLike, m2: jaxtyping.ArrayLike) → jaxtyping.ArrayLike
\[M_r(m_1, m_2) = \frac{m_1m_2}{m_1 + m_2}\]
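Comparing this with `symmetric_mass_ratio` gives \(M_r = \eta M\): the reduced mass is the symmetric mass ratio times the total mass. It is also always smaller than the lighter component, and equals \(m/2\) for equal masses. The sketch checks these properties with hypothetical re-implementations of the two formulas.

```python
import jax.numpy as jnp


def reduced_mass_sketch(m1, m2):
    return m1 * m2 / (m1 + m2)


def symmetric_mass_ratio_sketch(m1, m2):
    return m1 * m2 / (m1 + m2) ** 2


m1 = jnp.array([1.4, 10.0, 60.0])
m2 = jnp.array([1.3, 10.0, 6.0])
mu = reduced_mass_sketch(m1, m2)
```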
gwkokab.utils.save_model(*, filepath: str, datafilepath: str, model: equinox.nn.MLP, names: collections.abc.Sequence[str] | None = None, is_log: bool = False) → None

Persist model weights and metadata to HDF5 (backward-compatible format).

gwkokab.utils.sin_tilt(costilt: jaxtyping.ArrayLike) → jaxtyping.ArrayLike
\[\sin(\theta) = \sqrt{1 - \cos^2(\theta)}\]
gwkokab.utils.spherical_to_cart(r: jaxtyping.ArrayLike, theta: jaxtyping.ArrayLike, phi: jaxtyping.ArrayLike) → tuple[jaxtyping.ArrayLike, jaxtyping.ArrayLike, jaxtyping.ArrayLike]
\[\begin{split}\begin{align*} x(r, \theta, \phi) &= r \sin(\theta) \cos(\phi) \\ y(r, \theta, \phi) &= r \sin(\theta) \sin(\phi) \\ z(r, \theta, \phi) &= r \cos(\theta) \end{align*}\end{split}\]
gwkokab.utils.spin_costilt_from_components(chi_x: jaxtyping.ArrayLike, chi_y: jaxtyping.ArrayLike, chi_z: jaxtyping.ArrayLike) → jaxtyping.ArrayLike
\[\cos(\theta)(\chi_x, \chi_y, \chi_z) = \frac{\chi_z}{\sqrt{\chi_x^2 + \chi_y^2 + \chi_z^2}}\]
gwkokab.utils.spin_magnitude_from_components(chi_x: jaxtyping.ArrayLike, chi_y: jaxtyping.ArrayLike, chi_z: jaxtyping.ArrayLike) → jaxtyping.ArrayLike
\[\chi(\chi_x, \chi_y, \chi_z) = \sqrt{\chi_x^2 + \chi_y^2 + \chi_z^2}\]
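The spin helpers fit together as follows. From Cartesian components we get the magnitude \(\chi\) and \(\cos\theta\). Then \(\chi\cos\theta\) (as in `chi_costilt_to_chiz`) recovers \(\chi_z\), and \(\chi\sin\theta\) recovers the in-plane magnitude \(\sqrt{\chi_x^2 + \chi_y^2}\). The sketch re-implements these formulas with hypothetical helper names. Unlike the documented `sin_tilt` formula, it clamps \(1 - \cos^2\theta\) at zero. That guard is an addition of this sketch, against round-off pushing \(|\cos\theta|\) slightly above 1.

```python
import jax.numpy as jnp


def spin_magnitude_sketch(chi_x, chi_y, chi_z):
    return jnp.sqrt(chi_x**2 + chi_y**2 + chi_z**2)


def spin_costilt_sketch(chi_x, chi_y, chi_z):
    return chi_z / spin_magnitude_sketch(chi_x, chi_y, chi_z)


def sin_tilt_sketch(costilt):
    # Non-negative because the tilt angle lies in [0, pi]; the clamp guards against round-off.
    return jnp.sqrt(jnp.maximum(1.0 - costilt**2, 0.0))


chi_x = jnp.array([0.3, 0.0, -0.2])
chi_y = jnp.array([0.4, 0.0, 0.1])
chi_z = jnp.array([0.0, 0.5, -0.5])

chi = spin_magnitude_sketch(chi_x, chi_y, chi_z)
costilt = spin_costilt_sketch(chi_x, chi_y, chi_z)
```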
gwkokab.utils.symmetric_mass_ratio(m1: jaxtyping.ArrayLike, m2: jaxtyping.ArrayLike) → jaxtyping.ArrayLike
\[\eta(m_1, m_2) = \frac{m_1m_2}{(m_1 + m_2)^2}\]
gwkokab.utils.symmetric_mass_ratio_to_delta_m(eta: jaxtyping.ArrayLike) → jaxtyping.ArrayLike
\[\delta_m(\eta) = \sqrt{1 - 4\eta}\]
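Because the square root is non-negative, this formula returns \(|\delta_m|\). It inverts `delta_m` only when the masses are ordered so that \(m_1 \geq m_2\). If the labels are swapped, \(\delta_m\) flips sign but \(\eta\) does not, so the sign is lost. The sketch demonstrates this with hypothetical re-implementations of the formulas.

```python
import jax.numpy as jnp


def delta_m_sketch(m1, m2):
    return (m1 - m2) / (m1 + m2)


def symmetric_mass_ratio_sketch(m1, m2):
    return m1 * m2 / (m1 + m2) ** 2


def symmetric_mass_ratio_to_delta_m_sketch(eta):
    # The square root is non-negative, so this always returns |delta_m|.
    return jnp.sqrt(1.0 - 4.0 * eta)


m1 = jnp.array([30.0, 20.0])
m2 = jnp.array([20.0, 30.0])  # the second pair has its labels swapped
delta = delta_m_sketch(m1, m2)
delta_rec = symmetric_mass_ratio_to_delta_m_sketch(symmetric_mass_ratio_sketch(m1, m2))
```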
gwkokab.utils.total_mass(m1: jaxtyping.ArrayLike, m2: jaxtyping.ArrayLike) → jaxtyping.ArrayLike
\[M(m_1, m_2) = m_1 + m_2\]
gwkokab.utils.train_regressor(*, input_keys: list[str], output_keys: list[str], width_size: int, depth: int, batch_size: int, data_path: str, checkpoint_path: str | None = None, epochs: int = 50, validation_split: float = 0.2, learning_rate: float = 0.001, train_in_log: bool = False, loss_type: str = 'mse', grad_clip_norm: float = 1.0, weight_decay: float = 0.0001, use_cosine_decay: bool = True, min_lr: float = 1e-06, warmup_epochs: int = 3, seed: int | None = 42) → None

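Train an MLP regressor. Gradient-norm clipping (grad_clip_norm), weight decay (weight_decay), learning-rate warmup (warmup_epochs), and optional cosine decay (use_cosine_decay, min_lr) are used to stabilize optimization and smooth the loss curves.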

Notes

  • For detection probabilities in [0, 1], prefer loss_type="bce_logits" and do NOT set train_in_log=True: BCE expects probability targets, not log-transformed values.

  • seed fixes the train/validation split, which makes the validation loss less jittery.