gwkokab.utils
Submodules
Functions
Package Contents
- gwkokab.utils.Mc_eta_to_m1_m2(Mc: jaxtyping.ArrayLike, eta: jaxtyping.ArrayLike) → tuple[jaxtyping.ArrayLike, jaxtyping.ArrayLike]
- \[\begin{split}\begin{align*} m_1(M_c, \eta) &= \frac{M_c}{2} \eta^{-0.6} (1 + \sqrt{1 - 4\eta}) \\ m_2(M_c, \eta) &= \frac{M_c}{2} \eta^{-0.6} (1 - \sqrt{1 - 4\eta}) \end{align*}\end{split}\]
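As a consistency check of the formula above, the following minimal scalar sketch (plain Python with the standard-library `math` module, not the library's JAX implementation) maps component masses to \(M_c\) and \(\eta\) using the `chirp_mass` and `symmetric_mass_ratio` formulas documented on this page, then inverts them. The helper names are hypothetical stand-ins.

```python
import math


def chirp_mass(m1: float, m2: float) -> float:
    # M_c = (m1 m2)^{3/5} / (m1 + m2)^{1/5}
    return (m1 * m2) ** 0.6 / (m1 + m2) ** 0.2


def symmetric_mass_ratio(m1: float, m2: float) -> float:
    # eta = m1 m2 / (m1 + m2)^2, at most 1/4 (equal masses)
    return m1 * m2 / (m1 + m2) ** 2


def mc_eta_to_m1_m2(mc: float, eta: float) -> tuple[float, float]:
    # (M_c / 2) eta^{-3/5} is half the total mass; sqrt(1 - 4 eta) = (m1 - m2) / M.
    half_total = 0.5 * mc * eta ** -0.6
    root = math.sqrt(1.0 - 4.0 * eta)
    return half_total * (1.0 + root), half_total * (1.0 - root)


mc, eta = chirp_mass(36.0, 29.0), symmetric_mass_ratio(36.0, 29.0)
m1, m2 = mc_eta_to_m1_m2(mc, eta)
print(m1, m2)  # close to 36.0 and 29.0
```

Because the square root is taken as non-negative, the inverse always returns \(m_1 \ge m_2\).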
- gwkokab.utils.beta_dist_concentrations_to_mean_variance(alpha: jaxtyping.ArrayLike, beta: jaxtyping.ArrayLike, loc: jaxtyping.ArrayLike = 0.0, scale: jaxtyping.ArrayLike = 1.0) → Tuple[jaxtyping.ArrayLike, jaxtyping.ArrayLike]
Let \(\alpha\) and \(\beta\) be the shape parameters of a beta distribution with location \(a\) and scale \(b\). This function returns the mean and variance of the distribution, which are given by:
\[\mu = a+b\frac{\alpha}{\alpha + \beta}\qquad \sigma^2 = b^2\frac{\alpha \beta}{(\alpha + \beta)^2 (\alpha + \beta + 1)}\]
- Parameters:
alpha (ArrayLike) – The shape parameter \(\alpha\).
beta (ArrayLike) – The shape parameter \(\beta\).
loc (ArrayLike) – The location \(a\) of the beta distribution.
scale (ArrayLike) – The scale \(b\) of the beta distribution.
- Returns:
The mean \(\mu\) and variance \(\sigma^2\) of the beta distribution.
- Return type:
Tuple[ArrayLike, ArrayLike]
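A minimal scalar sketch of the mapping above in plain Python; the function name is a hypothetical stand-in for the library routine, which operates on JAX arrays.

```python
def beta_concentrations_to_mean_variance(
    alpha: float, beta: float, loc: float = 0.0, scale: float = 1.0
) -> tuple[float, float]:
    total = alpha + beta
    # mu = a + b * alpha / (alpha + beta)
    mean = loc + scale * alpha / total
    # sigma^2 = b^2 * alpha * beta / ((alpha + beta)^2 (alpha + beta + 1))
    variance = scale**2 * alpha * beta / (total**2 * (total + 1.0))
    return mean, variance


# alpha = 2, beta = 3 on [1, 3]: mean 1.8, variance 0.16 (up to rounding)
print(beta_concentrations_to_mean_variance(2.0, 3.0, loc=1.0, scale=2.0))
```

The location only shifts the mean, while the scale multiplies the mean offset by \(b\) and the variance by \(b^2\).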
- gwkokab.utils.beta_dist_mean_variance_to_concentrations(mean: jaxtyping.ArrayLike, variance: jaxtyping.ArrayLike, loc: jaxtyping.ArrayLike = 0.0, scale: jaxtyping.ArrayLike = 1.0) → Tuple[jaxtyping.ArrayLike, jaxtyping.ArrayLike]
Let \(\mu\) and \(\sigma^2\) be the mean and variance of a beta distribution with location \(a\) and scale \(b\). This function returns the shape parameters (concentrations) \(\alpha\) and \(\beta\) of the distribution, which are given by:
\[\alpha = -\frac{\mu-a}{b} \left(\left(\frac{\mu-a}{\sigma}\right)\left(\frac{\mu-a-b}{\sigma}\right)+1\right)\qquad \beta = \alpha\left(\frac{b}{\mu-a}-1\right)\]
- Parameters:
mean (ArrayLike) – The mean \(\mu\) of the beta distribution.
variance (ArrayLike) – The variance \(\sigma^2\) of the beta distribution.
loc (ArrayLike) – The location \(a\) of the beta distribution.
scale (ArrayLike) – The scale \(b\) of the beta distribution.
- Returns:
The shape parameters \(\alpha\) and \(\beta\) of the beta distribution.
- Return type:
Tuple[ArrayLike, ArrayLike]
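The inverse mapping can be sketched the same way in plain Python (hypothetical function name, scalar inputs only). The example inverts the values produced by the forward formula for \(\alpha = 2\), \(\beta = 3\), \(a = 1\), \(b = 2\).

```python
import math


def beta_mean_variance_to_concentrations(
    mean: float, variance: float, loc: float = 0.0, scale: float = 1.0
) -> tuple[float, float]:
    shifted = mean - loc  # mu - a
    sigma = math.sqrt(variance)
    # alpha = -((mu - a) / b) * (((mu - a) / sigma) * ((mu - a - b) / sigma) + 1)
    alpha = -(shifted / scale) * ((shifted / sigma) * ((shifted - scale) / sigma) + 1.0)
    # beta = alpha * (b / (mu - a) - 1)
    beta = alpha * (scale / shifted - 1.0)
    return alpha, beta


# Mean 1.8 and variance 0.16 correspond to alpha = 2, beta = 3 with loc = 1, scale = 2.
print(beta_mean_variance_to_concentrations(1.8, 0.16, loc=1.0, scale=2.0))
```

Both concentrations are positive only when \(a < \mu < a + b\) and \(\sigma^2 < (\mu - a)(a + b - \mu)\); outside that region at least one returned value is non-positive and does not describe a valid beta distribution.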
- gwkokab.utils.cart_to_polar(x: jaxtyping.ArrayLike, y: jaxtyping.ArrayLike) → tuple[jaxtyping.ArrayLike, jaxtyping.ArrayLike]
- \[\begin{split}\begin{align*} r(x, y) &= \sqrt{x^2 + y^2} \\ \theta(x, y) &= \arctan(y/x) \end{align*}\end{split}\]
- gwkokab.utils.cart_to_spherical(x: jaxtyping.ArrayLike, y: jaxtyping.ArrayLike, z: jaxtyping.ArrayLike) → tuple[jaxtyping.ArrayLike, jaxtyping.ArrayLike, jaxtyping.ArrayLike]
- \[\begin{split}\begin{align*} r(x, y, z) &= \sqrt{x^2 + y^2 + z^2} \\ \theta(x, y, z) &= \arccos\left(\frac{z}{r}\right) \\ \phi(x, y, z) &= \arctan\left(\frac{y}{x}\right) \end{align*}\end{split}\]
- gwkokab.utils.chi_costilt_to_chiz(chi: jaxtyping.ArrayLike, costilt: jaxtyping.ArrayLike) → jaxtyping.ArrayLike
- \[\chi_z(\chi, \cos(\theta)) = \chi \cos(\theta)\]
- gwkokab.utils.chi_p_from_components(a_1: jaxtyping.ArrayLike, cos_tilt_1: jaxtyping.ArrayLike, a_2: jaxtyping.ArrayLike, cos_tilt_2: jaxtyping.ArrayLike, mass_ratio: jaxtyping.ArrayLike) → jaxtyping.ArrayLike
- \[\chi_p(a_1, \cos(\theta_1), a_2, \cos(\theta_2), q) = \max \left( a_1 \sin(\theta_1), \frac{3 + 4q}{4 + 3q} q a_2 \sin(\theta_2) \right)\]
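A minimal scalar sketch of the \(\chi_p\) formula above in plain Python (hypothetical helper name; the library operates on JAX arrays). It assumes \(q = m_2/m_1 \le 1\), the convention of `mass_ratio` on this page, and obtains \(\sin\theta_i\) from \(\cos\theta_i\) as in `sin_tilt`.

```python
import math


def chi_p(a1: float, cos_tilt1: float, a2: float, cos_tilt2: float, q: float) -> float:
    # In-plane spin magnitudes a_i sin(theta_i), with sin from cos as in sin_tilt.
    perp1 = a1 * math.sqrt(1.0 - cos_tilt1**2)
    perp2 = a2 * math.sqrt(1.0 - cos_tilt2**2)
    # The secondary's in-plane spin is down-weighted by q (3 + 4q) / (4 + 3q).
    weight = q * (3.0 + 4.0 * q) / (4.0 + 3.0 * q)
    return max(perp1, weight * perp2)


print(chi_p(0.5, 0.0, 0.9, 0.0, 1.0))  # prints 0.9 (equal masses: weight is 1)
print(chi_p(0.5, 0.0, 0.9, 0.0, 0.5))  # prints 0.5 (weight ~0.45, primary dominates)
```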
- gwkokab.utils.chieff(m1: jaxtyping.ArrayLike, m2: jaxtyping.ArrayLike, chi1z: jaxtyping.ArrayLike, chi2z: jaxtyping.ArrayLike) → jaxtyping.ArrayLike
- \[\chi_{\text{eff}}(m_1, m_2, \chi_{1z}, \chi_{2z}) = \frac{m_1\chi_{1z} + m_2\chi_{2z}}{m_1 + m_2}\]
- gwkokab.utils.chirp_mass(m1: jaxtyping.ArrayLike, m2: jaxtyping.ArrayLike) → jaxtyping.ArrayLike
- \[M_c(m_1, m_2) = \frac{(m_1m_2)^{3/5}}{(m_1 + m_2)^{1/5}}\]
- gwkokab.utils.delta_m(m1: jaxtyping.ArrayLike, m2: jaxtyping.ArrayLike) → jaxtyping.ArrayLike
- \[\delta_m(m_1, m_2) = \frac{m_1 - m_2}{m_1 + m_2}\]
- gwkokab.utils.delta_m_to_symmetric_mass_ratio(delta_m: jaxtyping.ArrayLike) → jaxtyping.ArrayLike
- \[\eta(\delta_m) = \frac{1 - \delta_m^2}{4}\]
- gwkokab.utils.eta_from_q(q: jaxtyping.ArrayLike) → jaxtyping.ArrayLike
- \[\eta(q) = \frac{q}{(1 + q)^2}\]
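The three routes to \(\eta\) documented on this page (from the component masses, from \(q\), and from \(\delta_m\)) should agree. A minimal plain-Python sketch that checks this, with hypothetical scalar helpers mirroring the formulas:

```python
def symmetric_mass_ratio(m1: float, m2: float) -> float:
    return m1 * m2 / (m1 + m2) ** 2


def eta_from_q(q: float) -> float:
    return q / (1.0 + q) ** 2


def delta_m(m1: float, m2: float) -> float:
    return (m1 - m2) / (m1 + m2)


def delta_m_to_symmetric_mass_ratio(dm: float) -> float:
    return (1.0 - dm**2) / 4.0


m1, m2 = 30.0, 10.0
eta_masses = symmetric_mass_ratio(m1, m2)
eta_q = eta_from_q(m2 / m1)  # q = m2 / m1, as in mass_ratio
eta_dm = delta_m_to_symmetric_mass_ratio(delta_m(m1, m2))
print(eta_masses, eta_q, eta_dm)  # all three agree up to rounding
```

Since \(\eta(1/q) = \eta(q)\) and \(\eta(-\delta_m) = \eta(\delta_m)\), \(\eta\) carries no information about which component is heavier.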
- gwkokab.utils.load_model(filename: str) → Tuple[List[str], equinox.nn.MLP]
Load model and names from HDF5 (backward-compatible).
- gwkokab.utils.log_chirp_mass(m1: jaxtyping.ArrayLike, m2: jaxtyping.ArrayLike) → jaxtyping.ArrayLike
- \[\log(M_c(m_1, m_2)) = 3/5\times (\log(m_1) + \log(m_2)) - \log(m_1 + m_2)/5\]
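One reason a log-space form is useful is that it never forms the product \(m_1 m_2\). A minimal plain-Python sketch (hypothetical scalar helpers, standard-library `math` only) comparing it with the direct `chirp_mass` formula, including a case where the direct product overflows double precision:

```python
import math


def chirp_mass(m1: float, m2: float) -> float:
    return (m1 * m2) ** 0.6 / (m1 + m2) ** 0.2


def log_chirp_mass(m1: float, m2: float) -> float:
    # log M_c = (3/5)(log m1 + log m2) - (1/5) log(m1 + m2)
    return 0.6 * (math.log(m1) + math.log(m2)) - 0.2 * math.log(m1 + m2)


print(math.exp(log_chirp_mass(36.0, 29.0)), chirp_mass(36.0, 29.0))  # agree to rounding
# Here m1 * m2 overflows to inf, but the log-space form stays finite.
print(chirp_mass(1e200, 1e200), log_chirp_mass(1e200, 1e200))
```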
- gwkokab.utils.log_planck_taper_window(x: jaxtyping.ArrayLike) → jaxtyping.ArrayLike
If \(x\) is the point at which to evaluate the window, then the Planck taper window is defined as
\[\begin{split}S(x)=\begin{cases} 0 & \text{if } x < 0, \\ \displaystyle\frac{1}{1+e^{\left(\frac{1}{x}+\frac{1}{x-1}\right)}} & \text{if } 0 \leq x \leq 1, \\ 1 & \text{if } x > 1, \\ \end{cases}\end{split}\]
This function evaluates the log of the Planck taper window, \(\ln{S(x)}\).
- Parameters:
x (ArrayLike) – point at which to evaluate the window
- Returns:
The log of the window value, \(\ln{S(x)}\).
- Return type:
ArrayLike
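Evaluating \(\ln S(x)\) naively overflows near \(x = 0\), where the exponent \(1/x + 1/(x-1)\) becomes large. A minimal plain-Python sketch (hypothetical helper name) writes \(\ln S = -\ln(1 + e^{z})\) as a stable softplus. The endpoints \(x = 0\) and \(x = 1\), where the middle formula is singular, are assigned their limiting values \(S = 0\) and \(S = 1\); this endpoint handling is an assumption of the sketch, not something stated for the library.

```python
import math


def log_planck_taper(x: float) -> float:
    # Flat regions (and the singular endpoints, taken at their limits): S = 0 or S = 1.
    if x <= 0.0:
        return -math.inf
    if x >= 1.0:
        return 0.0
    z = 1.0 / x + 1.0 / (x - 1.0)
    # ln S = -softplus(z), with softplus(z) = max(z, 0) + log1p(exp(-|z|)) to avoid overflow.
    return -(max(z, 0.0) + math.log1p(math.exp(-abs(z))))


print(log_planck_taper(0.5))    # -ln 2, since z = 0 at the midpoint
print(log_planck_taper(1e-3))   # about -999, finite although exp(z) would overflow
```

Because \(z(1 - x) = -z(x)\), the window satisfies \(S(x) + S(1 - x) = 1\) on \((0, 1)\).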
- gwkokab.utils.m1_m2_chi1_chi2_costilt1_costilt2_to_chieff(*, m1: jaxtyping.ArrayLike, m2: jaxtyping.ArrayLike, chi1: jaxtyping.ArrayLike, chi2: jaxtyping.ArrayLike, costilt1: jaxtyping.ArrayLike, costilt2: jaxtyping.ArrayLike) → jaxtyping.ArrayLike
- \[\chi_{\text{eff}}(m_1, m_2, \chi_1, \chi_2, \cos(\theta_1), \cos(\theta_2)) = \frac{m_1\chi_1\cos(\theta_1) + m_2\chi_2\cos(\theta_2)}{m_1 + m_2}\]
- gwkokab.utils.m1_m2_chi1_chi2_costilt1_costilt2_to_chiminus(*, m1: jaxtyping.ArrayLike, m2: jaxtyping.ArrayLike, chi1: jaxtyping.ArrayLike, chi2: jaxtyping.ArrayLike, costilt1: jaxtyping.ArrayLike, costilt2: jaxtyping.ArrayLike) → jaxtyping.ArrayLike
- \[\chi_{\text{minus}}(m_1, m_2, \chi_1, \chi_2, \cos(\theta_1), \cos(\theta_2)) = \frac{m_1\chi_1\cos(\theta_1) - m_2\chi_2\cos(\theta_2)}{m_1 + m_2}\]
- gwkokab.utils.m1_m2_chi1z_chi2z_to_chiminus(m1: jaxtyping.ArrayLike, m2: jaxtyping.ArrayLike, chi1z: jaxtyping.ArrayLike, chi2z: jaxtyping.ArrayLike) → jaxtyping.ArrayLike
- \[\chi_{\text{minus}}(m_1, m_2, \chi_{1z}, \chi_{2z}) = \frac{m_1\chi_{1z} - m_2\chi_{2z}}{m_1 + m_2}\]
- gwkokab.utils.m1_m2_chieff_chiminus_to_chi1z_chi2z(m1: jaxtyping.ArrayLike, m2: jaxtyping.ArrayLike, chieff: jaxtyping.ArrayLike, chiminus: jaxtyping.ArrayLike) → tuple[jaxtyping.ArrayLike, jaxtyping.ArrayLike]
- \[\begin{split}\begin{align*} \chi_{1z}(m_1, m_2, \chi_{\text{eff}}, \chi_{\text{minus}}) &= \frac{m_1+m_2}{2m_1} \left( \chi_{\text{eff}} + \chi_{\text{minus}} \right)\\ \chi_{2z}(m_1, m_2, \chi_{\text{eff}}, \chi_{\text{minus}}) &= \frac{m_1+m_2}{2m_2} \left( \chi_{\text{eff}} - \chi_{\text{minus}} \right) \end{align*}\end{split}\]
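Adding and subtracting \(\chi_{\text{eff}}\) and \(\chi_{\text{minus}}\) isolates \(2 m_1 \chi_{1z}/(m_1+m_2)\) and \(2 m_2 \chi_{2z}/(m_1+m_2)\), which is where the inverse above comes from. A minimal plain-Python round-trip sketch (hypothetical scalar helpers mirroring the `chieff`, `m1_m2_chi1z_chi2z_to_chiminus`, and inverse formulas):

```python
def chieff(m1: float, m2: float, chi1z: float, chi2z: float) -> float:
    return (m1 * chi1z + m2 * chi2z) / (m1 + m2)


def chiminus(m1: float, m2: float, chi1z: float, chi2z: float) -> float:
    return (m1 * chi1z - m2 * chi2z) / (m1 + m2)


def chieff_chiminus_to_chi1z_chi2z(
    m1: float, m2: float, chi_eff: float, chi_minus: float
) -> tuple[float, float]:
    total = m1 + m2
    # chi_eff + chi_minus = 2 m1 chi1z / M and chi_eff - chi_minus = 2 m2 chi2z / M
    return total / (2.0 * m1) * (chi_eff + chi_minus), total / (2.0 * m2) * (chi_eff - chi_minus)


m1, m2, chi1z, chi2z = 30.0, 20.0, 0.4, -0.2
ce, cm = chieff(m1, m2, chi1z, chi2z), chiminus(m1, m2, chi1z, chi2z)
print(chieff_chiminus_to_chi1z_chi2z(m1, m2, ce, cm))  # close to (0.4, -0.2)
```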
- gwkokab.utils.m1_q_to_m2(m1: jaxtyping.ArrayLike, q: jaxtyping.ArrayLike) → jaxtyping.ArrayLike
- \[m_2(m_1, q) = m_1q\]
- gwkokab.utils.m1_times_m2(m1: jaxtyping.ArrayLike, m2: jaxtyping.ArrayLike) → jaxtyping.ArrayLike
- \[m_1m_2(m_1, m_2) = m_1 m_2\]
- gwkokab.utils.m2_q_to_m1(m2: jaxtyping.ArrayLike, q: jaxtyping.ArrayLike) → jaxtyping.ArrayLike
- \[m_1(m_2, q) = \frac{m_2}{q}\]
- gwkokab.utils.m_det_z_to_m_source(m_det: jaxtyping.ArrayLike, z: jaxtyping.ArrayLike) → jaxtyping.ArrayLike
- \[m_{\text{source}}(m_{\text{det}}, z) = \frac{m_{\text{det}}}{1 + z}\]
- gwkokab.utils.m_source_z_to_m_det(m_source: jaxtyping.ArrayLike, z: jaxtyping.ArrayLike) → jaxtyping.ArrayLike
- \[m_{\text{det}}(m_{\text{source}}, z) = m_{\text{source}}(1 + z)\]
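The two redshift conversions above are exact inverses. A minimal plain-Python sketch (hypothetical scalar helpers) for a 30-unit source-frame mass at redshift \(z = 0.5\):

```python
def m_source_z_to_m_det(m_source: float, z: float) -> float:
    # Detector-frame mass is redshifted by a factor (1 + z).
    return m_source * (1.0 + z)


def m_det_z_to_m_source(m_det: float, z: float) -> float:
    return m_det / (1.0 + z)


m_det = m_source_z_to_m_det(30.0, 0.5)
print(m_det, m_det_z_to_m_source(m_det, 0.5))  # prints 45.0 30.0
```

Because both masses scale by the same factor, \(q\) and \(\eta\) are unchanged by this conversion, while \(M_c\), being homogeneous of degree one in the masses, scales by \(1 + z\) like the masses themselves.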
- gwkokab.utils.make_model(*, key: jaxtyping.PRNGKeyArray, input_layer: int, output_layer: int, width_size: int, depth: int) → equinox.nn.MLP
Build an MLP with ReLU activations.
- gwkokab.utils.mass_ratio(m1: jaxtyping.ArrayLike, m2: jaxtyping.ArrayLike) → jaxtyping.ArrayLike
- \[q(m_1, m_2) = \frac{m_2}{m_1}\]
- gwkokab.utils.mse_loss_fn(model: jaxtyping.PyTree, x: jaxtyping.Array, y: jaxtyping.Array, batch_size: int | None = 256) → jaxtyping.Array
Mean squared error loss.
- gwkokab.utils.polar_to_cart(r: jaxtyping.ArrayLike, theta: jaxtyping.ArrayLike) → tuple[jaxtyping.ArrayLike, jaxtyping.ArrayLike]
- \[\begin{split}\begin{align*} x(r, \theta) &= r \cos(\theta) \\ y(r, \theta) &= r \sin(\theta) \end{align*}\end{split}\]
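The \(\theta = \arctan(y/x)\) formula documented for `cart_to_polar` only returns angles in \((-\pi/2, \pi/2)\), so a round trip through `polar_to_cart` loses the quadrant for points with \(x < 0\). The minimal plain-Python sketch below swaps in the quadrant-aware `math.atan2`; whether the library itself uses `arctan2` is not stated on this page, and the helper names are hypothetical.

```python
import math


def cart_to_polar(x: float, y: float) -> tuple[float, float]:
    # math.atan2 returns an angle in [-pi, pi] that respects the signs of x and y.
    return math.hypot(x, y), math.atan2(y, x)


def polar_to_cart(r: float, theta: float) -> tuple[float, float]:
    return r * math.cos(theta), r * math.sin(theta)


r, theta = cart_to_polar(-1.0, 1.0)
print(polar_to_cart(r, theta))  # close to (-1.0, 1.0)

# arctan(y / x) alone gives -pi/4 here, which maps back to the wrong quadrant.
naive_theta = math.atan(1.0 / -1.0)
print(polar_to_cart(r, naive_theta))  # close to (1.0, -1.0)
```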
- gwkokab.utils.predict(model: jaxtyping.PyTree, x: jaxtyping.Array, batch_size: int | None = 256) → jaxtyping.Array
Predict outputs for inputs x.
- gwkokab.utils.read_data(data_path: str, keys: collections.abc.Sequence[str]) → pandas.DataFrame
Read an HDF5 dataset into a DataFrame whose columns are the given keys.
- gwkokab.utils.reduced_mass(m1: jaxtyping.ArrayLike, m2: jaxtyping.ArrayLike) → jaxtyping.ArrayLike
- \[M_r(m_1, m_2) = \frac{m_1m_2}{m_1 + m_2}\]
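The reduced mass equals the symmetric mass ratio times the total mass, \(M_r = \eta M\), since \(\frac{m_1 m_2}{m_1+m_2} = \frac{m_1 m_2}{(m_1+m_2)^2}(m_1+m_2)\). A minimal plain-Python check with hypothetical scalar helpers mirroring `reduced_mass`, `symmetric_mass_ratio`, and `total_mass`:

```python
def reduced_mass(m1: float, m2: float) -> float:
    return m1 * m2 / (m1 + m2)


def symmetric_mass_ratio(m1: float, m2: float) -> float:
    return m1 * m2 / (m1 + m2) ** 2


def total_mass(m1: float, m2: float) -> float:
    return m1 + m2


m1, m2 = 30.0, 10.0
print(reduced_mass(m1, m2), symmetric_mass_ratio(m1, m2) * total_mass(m1, m2))  # prints 7.5 7.5
```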
- gwkokab.utils.save_model(*, filepath: str, datafilepath: str, model: equinox.nn.MLP, names: collections.abc.Sequence[str] | None = None, is_log: bool = False) → None
Persist model weights and metadata to HDF5 (backward-compatible format).
- gwkokab.utils.sin_tilt(costilt: jaxtyping.ArrayLike) → jaxtyping.ArrayLike
- \[\sin(\theta) = \sqrt{1 - \cos^2(\theta)}\]
- gwkokab.utils.spherical_to_cart(r: jaxtyping.ArrayLike, theta: jaxtyping.ArrayLike, phi: jaxtyping.ArrayLike) → tuple[jaxtyping.ArrayLike, jaxtyping.ArrayLike, jaxtyping.ArrayLike]
- \[\begin{split}\begin{align*} x(r, \theta, \phi) &= r \sin(\theta) \cos(\phi) \\ y(r, \theta, \phi) &= r \sin(\theta) \sin(\phi) \\ z(r, \theta, \phi) &= r \cos(\theta) \end{align*}\end{split}\]
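A minimal plain-Python round-trip sketch between the `cart_to_spherical` and `spherical_to_cart` formulas (hypothetical helper names). As in the polar case, it swaps in `math.atan2` for \(\phi\) so that points with \(x < 0\) keep their quadrant; that choice is an assumption of the sketch, not a statement about the library.

```python
import math


def cart_to_spherical(x: float, y: float, z: float) -> tuple[float, float, float]:
    r = math.sqrt(x * x + y * y + z * z)
    # theta = arccos(z / r) is the polar angle; undefined at the origin (r = 0).
    theta = math.acos(z / r)
    # atan2 picks the correct quadrant for phi, unlike arctan(y / x) alone.
    phi = math.atan2(y, x)
    return r, theta, phi


def spherical_to_cart(r: float, theta: float, phi: float) -> tuple[float, float, float]:
    return (
        r * math.sin(theta) * math.cos(phi),
        r * math.sin(theta) * math.sin(phi),
        r * math.cos(theta),
    )


print(spherical_to_cart(*cart_to_spherical(-1.0, -1.0, 0.5)))  # close to (-1.0, -1.0, 0.5)
```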
- gwkokab.utils.spin_costilt_from_components(chi_x: jaxtyping.ArrayLike, chi_y: jaxtyping.ArrayLike, chi_z: jaxtyping.ArrayLike) → jaxtyping.ArrayLike
- \[\cos(\theta)(\chi_x, \chi_y, \chi_z) = \frac{\chi_z}{\sqrt{\chi_x^2 + \chi_y^2 + \chi_z^2}}\]
- gwkokab.utils.spin_magnitude_from_components(chi_x: jaxtyping.ArrayLike, chi_y: jaxtyping.ArrayLike, chi_z: jaxtyping.ArrayLike) → jaxtyping.ArrayLike
- \[\chi(\chi_x, \chi_y, \chi_z) = \sqrt{\chi_x^2 + \chi_y^2 + \chi_z^2}\]
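The spin magnitude and tilt cosine above combine with `chi_costilt_to_chiz` and `sin_tilt` to give the aligned and in-plane spin parts. A minimal plain-Python sketch with hypothetical scalar helpers; note that in this sketch a zero spin vector makes the tilt cosine undefined and raises `ZeroDivisionError`.

```python
import math


def spin_magnitude_from_components(chi_x: float, chi_y: float, chi_z: float) -> float:
    return math.sqrt(chi_x**2 + chi_y**2 + chi_z**2)


def spin_costilt_from_components(chi_x: float, chi_y: float, chi_z: float) -> float:
    # Undefined for a non-spinning object (all components zero).
    return chi_z / spin_magnitude_from_components(chi_x, chi_y, chi_z)


chi_x, chi_y, chi_z = 0.0, 0.3, 0.4
chi = spin_magnitude_from_components(chi_x, chi_y, chi_z)    # approximately 0.5
costilt = spin_costilt_from_components(chi_x, chi_y, chi_z)  # approximately 0.8
aligned = chi * costilt                          # chi cos(theta) recovers chi_z
in_plane = chi * math.sqrt(1.0 - costilt**2)     # chi sin(theta), the in-plane magnitude
print(chi, costilt, aligned, in_plane)
```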
- gwkokab.utils.symmetric_mass_ratio(m1: jaxtyping.ArrayLike, m2: jaxtyping.ArrayLike) → jaxtyping.ArrayLike
- \[\eta(m_1, m_2) = \frac{m_1m_2}{(m_1 + m_2)^2}\]
- gwkokab.utils.symmetric_mass_ratio_to_delta_m(eta: jaxtyping.ArrayLike) → jaxtyping.ArrayLike
- \[\delta_m(\eta) = \sqrt{1 - 4\eta}\]
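Since \(\eta\) is symmetric in the two masses, the square root above can only recover \(|\delta_m|\); taking the positive root corresponds to the \(m_1 \ge m_2\) convention. A minimal plain-Python sketch (hypothetical scalar helpers) showing the sign that is lost when the labels are swapped:

```python
import math


def delta_m(m1: float, m2: float) -> float:
    return (m1 - m2) / (m1 + m2)


def symmetric_mass_ratio(m1: float, m2: float) -> float:
    return m1 * m2 / (m1 + m2) ** 2


def symmetric_mass_ratio_to_delta_m(eta: float) -> float:
    # Positive root: assumes m1 >= m2.
    return math.sqrt(1.0 - 4.0 * eta)


# With m1 < m2, delta_m is negative but the eta round trip returns the positive root.
print(delta_m(10.0, 30.0), symmetric_mass_ratio_to_delta_m(symmetric_mass_ratio(10.0, 30.0)))  # prints -0.5 0.5
```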
- gwkokab.utils.total_mass(m1: jaxtyping.ArrayLike, m2: jaxtyping.ArrayLike) → jaxtyping.ArrayLike
- \[M(m_1, m_2) = m_1 + m_2\]
- gwkokab.utils.train_regressor(*, input_keys: list[str], output_keys: list[str], width_size: int, depth: int, batch_size: int, data_path: str, checkpoint_path: str | None = None, epochs: int = 50, validation_split: float = 0.2, learning_rate: float = 0.001, train_in_log: bool = False, loss_type: str = 'mse', grad_clip_norm: float = 1.0, weight_decay: float = 0.0001, use_cosine_decay: bool = True, min_lr: float = 1e-06, warmup_epochs: int = 3, seed: int | None = 42) → None
Train an MLP regressor with stable optimization and smooth loss curves.
Notes
For detection probabilities in [0, 1], prefer loss_type="bce_logits" and do not set train_in_log=True (BCE expects probability targets, not log-values).
Setting seed fixes the validation split, which gives a less jittery validation loss.
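The advice about BCE targets can be illustrated with a minimal plain-Python sketch of a numerically stable binary cross-entropy on logits. It is an assumption, not confirmed by this page, that loss_type="bce_logits" computes a loss of this standard form; the function name `bce_with_logits` is hypothetical.

```python
import math


def bce_with_logits(logit: float, target: float) -> float:
    # Equals -[y log sigmoid(z) + (1 - y) log(1 - sigmoid(z))], rewritten as
    # softplus(z) - y z with softplus(z) = max(z, 0) + log1p(exp(-|z|)) to avoid overflow.
    return max(logit, 0.0) - logit * target + math.log1p(math.exp(-abs(logit)))


print(bce_with_logits(0.0, 0.5))     # log(2), about 0.693
print(bce_with_logits(1000.0, 1.0))  # finite, where a naive sigmoid-then-log form hits log(0)
# A negative target (such as a log-probability) makes the loss unbounded below:
print(bce_with_logits(-100.0, -2.0), bce_with_logits(-1000.0, -2.0))
```

For a target \(y \in [0, 1]\) the minimum over the logit \(z\) is the binary entropy of \(y\), so the loss is bounded below. For a negative target the \(-yz\) term decreases without bound as \(z \to -\infty\), so minimizing it drives the logits away instead of fitting the data.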