gwkokab.analysis.utils.marginals

Classes

PlotStyle

A named tuple representing the style for plotting marginal densities, including the color, label, and keyword arguments for line and fill-between plots.

Functions

calculate_dist_layouts(→ list[tuple[int, ...]])

Calculate the layout of the distribution's support based on its shaped values.

calculate_marginals_over_axes(→ list[jaxtyping.Array])

Calculate marginal densities for each axis of a joint density array.

compute_batched_marginals(model_meta_cls, ...[, ...])

Compute marginal densities for a batch of parameter samples.

generate_marginal_probs(model_meta_cls, ...[, ...])

Generate marginal probability densities from saved inference data and write them to the same HDF5 file.

plot_marginal_with_intervals(ax, filename, parameter, ...)

Plot marginal densities with confidence intervals for a specified parameter.

read_domains(→ dict[str, tuple[float, float, int]])

Read domain specifications from an HDF5 file.

remove_comoving_volume_factor(→ jaxtyping.Array)

Remove the comoving volume factor from a marginal density over redshift.

save_results_to_hdf5(samples, batched_results, ...)

Save the computed marginal densities to an HDF5 file.

write_domains(filepath, domain_cfg)

Write domain specifications to an HDF5 file.

Module Contents

class gwkokab.analysis.utils.marginals.PlotStyle

Bases: NamedTuple

A named tuple representing the style for plotting marginal densities, including the color, label, and additional keyword arguments for line plots and fill-between plots.

color: str
fill_between_kwargs: dict
label: str
line_plot_kwargs: dict

gwkokab.analysis.utils.marginals.calculate_dist_layouts(shaped_values: list[tuple[int, int] | int]) → list[tuple[int, ...]]

Calculate the layout of the distribution’s support based on its shaped values.

Parameters:

shaped_values (list[tuple[int, int] | int]) – A list of shaped values for the distribution. Each element can be either a tuple representing a multi-dimensional support (with the second element indicating the number of dimensions), or an integer representing a one-dimensional support.

Returns:

A list of tuples representing the layout of the distribution’s support. Each tuple contains the indices of the dimensions that correspond to a particular component of the distribution.

Return type:

list[tuple[int, …]]
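To make the return value concrete, here is a minimal sketch of one plausible reading of the description above. It assumes that only the dimension count matters: an integer entry occupies one dimension, a tuple (_, n) occupies n consecutive dimensions, and indices are assigned left to right. The name dist_layouts_sketch is hypothetical; this is not the library's implementation.

```python
from __future__ import annotations


def dist_layouts_sketch(shaped_values: list[tuple[int, int] | int]) -> list[tuple[int, ...]]:
    """Hypothetical sketch: assign consecutive dimension indices to each component."""
    layouts: list[tuple[int, ...]] = []
    offset = 0  # index of the next unassigned dimension
    for value in shaped_values:
        # Assumption: a tuple's second element is its number of dimensions,
        # and an integer entry stands for a one-dimensional support.
        ndim = value[1] if isinstance(value, tuple) else 1
        layouts.append(tuple(range(offset, offset + ndim)))
        offset += ndim
    return layouts


print(dist_layouts_sketch([1, (0, 2), 1]))  # [(0,), (1, 2), (3,)]
```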

gwkokab.analysis.utils.marginals.calculate_marginals_over_axes(probs: jaxtyping.Array, domains: list[jaxtyping.Array], normalize: list[bool] | None = None) → list[jaxtyping.Array]

Calculate marginal densities for each axis of a joint density array.

For each axis in turn, this function integrates the joint density over all other dimensions to obtain the marginal density of that axis.

Parameters:
  • probs (Array) – An array representing the joint density over multiple dimensions.

  • domains (list[Array]) – A list of arrays representing the domain values for each dimension of the joint density. The length of this list should match the number of dimensions in probs.

  • normalize (list[bool] | None, optional) – A list of booleans indicating whether to normalize the marginal densities for each dimension. If None, all dimensions will be normalized, by default None

Returns:

A list of arrays representing the marginal densities for each dimension.

Return type:

list[Array]
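The procedure above can be sketched as follows. This is a minimal illustration, assuming the trapezoidal rule as the quadrature (the library may integrate differently); it is written with NumPy for portability, and jax.numpy provides the same moveaxis, diff, and sum operations. The names marginals_over_axes_sketch and _trapezoid are hypothetical.

```python
import numpy as np


def _trapezoid(y: np.ndarray, x: np.ndarray, axis: int) -> np.ndarray:
    """Trapezoidal integral of y over the 1-D grid x along the given axis."""
    y = np.moveaxis(y, axis, -1)
    dx = np.diff(x)
    return np.sum(0.5 * (y[..., 1:] + y[..., :-1]) * dx, axis=-1)


def marginals_over_axes_sketch(probs, domains, normalize=None):
    """Hypothetical sketch: for each axis, integrate out every other axis."""
    ndim = probs.ndim
    if len(domains) != ndim:
        raise ValueError("need one domain array per dimension of probs")
    if normalize is None:
        normalize = [True] * ndim  # default: normalize every marginal
    marginals = []
    for axis in range(ndim):
        marginal = probs
        # Integrate from the highest index down, so the axes still to be
        # integrated (all lower-indexed) keep their positions.
        for other in reversed(range(ndim)):
            if other != axis:
                marginal = _trapezoid(marginal, domains[other], axis=other)
        if normalize[axis]:
            marginal = marginal / _trapezoid(marginal, domains[axis], axis=0)
        marginals.append(marginal)
    return marginals


x = np.linspace(0.0, 1.0, 11)
y = np.linspace(0.0, 2.0, 21)
probs = np.ones((x.size, y.size))  # constant (unnormalized) joint density
m_x, m_y = marginals_over_axes_sketch(probs, [x, y])
```

With this constant joint density, the normalized marginals are flat at 1.0 over x and 0.5 over y; passing normalize=[False, False] instead returns the raw integrals (2.0 and 1.0).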

gwkokab.analysis.utils.marginals.compute_batched_marginals(model_meta_cls: type, samples_batch: jaxtyping.Array, constants: dict, variables_index: dict, domains: list[jaxtyping.Array], normalize: list[bool], batch_size: int | None = None)

Compute marginal densities of the model for a batch of parameter samples.

Parameters:
  • model_meta_cls (type) – A class representing the meta-information of the model, which includes a method for constructing the model given specific parameters.

  • samples_batch (Array) – A batch of samples, where each row corresponds to a single sample and each column corresponds to a specific parameter of the model.

  • constants (dict) – A dictionary of constant values required for constructing the model. The keys should match the parameter names expected by the model’s constructor.

  • variables_index (dict) – A dictionary mapping parameter names to their corresponding column indices in the samples_batch array. This mapping is used to extract the relevant parameter values from the samples when constructing the model.

  • domains (list[Array]) – A list of arrays representing the domain values for each parameter of the model. The length of this list should match the number of parameters in the model.

  • normalize (list[bool]) – A list of booleans indicating whether to normalize the marginal density of each parameter.

  • batch_size (int | None, optional) – The size of the batch to process at a time, by default None
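The roles of constants, variables_index, and batch_size can be illustrated with a minimal sketch. The helper names build_model_kwargs and iter_batches, the parameter names alpha, mmin, and mmax, and the reading of batch_size=None as "process all samples in one batch" are assumptions for illustration, not the library's internals.

```python
from __future__ import annotations

from collections.abc import Iterator

import numpy as np


def build_model_kwargs(sample: np.ndarray, constants: dict, variables_index: dict) -> dict:
    """Hypothetical helper: merge fixed constants with the values of one sample."""
    kwargs = dict(constants)  # copy so the caller's dict is not modified
    for name, column in variables_index.items():
        kwargs[name] = sample[column]
    return kwargs


def iter_batches(samples: np.ndarray, batch_size: int | None = None) -> Iterator[np.ndarray]:
    """Hypothetical helper: yield consecutive blocks of at most batch_size rows.

    Assumes batch_size=None means all samples are processed in one batch.
    """
    n = samples.shape[0]
    step = batch_size if batch_size is not None else max(n, 1)
    for start in range(0, n, step):
        yield samples[start : start + step]


samples = np.arange(12.0).reshape(6, 2)  # 6 samples, 2 sampled parameters
kwargs = build_model_kwargs(samples[1], {"alpha": 2.0}, {"mmin": 0, "mmax": 1})
sizes = [batch.shape[0] for batch in iter_batches(samples, batch_size=4)]
```

Here kwargs evaluates to {"alpha": 2.0, "mmin": 2.0, "mmax": 3.0}, and sizes is [4, 2] because the last batch holds the remaining rows.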

gwkokab.analysis.utils.marginals.generate_marginal_probs(model_meta_cls: type, inference_data_path: str | pathlib.Path, domain_cfg: dict[str, tuple[float, float, int]], max_samples: int | None = None, batch_size: int | None = None)

Generate marginal probability densities from saved inference data and write them to the HDF5 file at inference_data_path.

Parameters:
  • model_meta_cls (type) – A class representing the meta-information of the model, which includes a method for constructing the model given specific parameters.

  • inference_data_path (str | Path) – Path to the HDF5 file containing the inference data saved by the analysis. The generated marginal densities are also saved to this file.

  • domain_cfg (dict[str, tuple[float, float, int]]) – A dictionary mapping parameter names to their corresponding domain specifications. Each value in the dictionary should be a tuple containing the start, stop, and number of points for the domain of the parameter.

  • max_samples (int | None, optional) – The maximum number of samples to use for computing marginal densities, by default None

  • batch_size (int | None, optional) – The batch size for computing marginal densities, by default None

Raises:

FileNotFoundError – If the required samples file cannot be found at inference_data_path.
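Each domain_cfg entry is a (start, stop, number of points) triple. A minimal sketch of turning such a dictionary into one grid per parameter follows; it assumes evenly spaced grids that include both endpoints, as numpy.linspace produces. The helper domains_from_cfg and the parameter names mass_1 and redshift are hypothetical.

```python
import numpy as np


def domains_from_cfg(
    domain_cfg: dict[str, tuple[float, float, int]], parameters: list[str]
) -> list[np.ndarray]:
    """Hypothetical helper: build one evenly spaced grid per parameter, in order."""
    grids = []
    for name in parameters:
        start, stop, num = domain_cfg[name]
        # Assumption: both endpoints are included, as with numpy.linspace.
        grids.append(np.linspace(start, stop, int(num)))
    return grids


domain_cfg = {"mass_1": (5.0, 100.0, 96), "redshift": (0.001, 1.0, 100)}
grids = domains_from_cfg(domain_cfg, ["mass_1", "redshift"])
```

With 96 points between 5 and 100, the mass grid has a spacing of exactly 1.0.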

gwkokab.analysis.utils.marginals.plot_marginal_with_intervals(ax: matplotlib.pyplot.Axes, filename: str, parameter: str, style: PlotStyle, component_idxs: list[int], scale: float | Callable = 1.0, weights: list[float | Callable] | None = None, normalize: bool = False)

Plot marginal densities with confidence intervals for a specified parameter.

Parameters:
  • ax (plt.Axes) – The Matplotlib Axes object on which to plot the marginal densities and confidence intervals.

  • filename (str) – The path to the HDF5 file containing the marginal density data.

  • parameter (str) – The name of the parameter for which to plot the marginal densities. This should correspond to a dataset in the HDF5 file under the “probs/component_{i}” groups.

  • style (PlotStyle) – The PlotStyle specifying the color, label, and line-plot and fill-between keyword arguments used for the plot.

  • component_idxs (list[int]) – A list of indices specifying which components to plot.

  • scale (float | Callable, optional) – A scaling factor for the marginal densities. If a callable is provided, it will be evaluated with the parameters from the HDF5 file, by default 1.0

  • weights (list[float | Callable] | None, optional) – The weights for each component’s marginal density. If None, equal weights are assumed, by default None

  • normalize (bool, optional) – Whether to normalize the marginal densities, by default False
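The numerical core of such a plot can be sketched as follows: combine the per-sample densities of the selected components with their weights, apply the scale, and summarize the samples at each grid point by quantiles. This is a minimal sketch under stated assumptions: each component's densities are taken to be an array of shape (num_samples, domain_size), "equal weights" is read as a weight of 1.0 per component, scale is a plain float, and the 5%, 50%, and 95% quantiles are an illustrative interval choice. The function name interval_band is hypothetical.

```python
import numpy as np


def interval_band(component_densities, weights=None, scale=1.0, quantiles=(0.05, 0.5, 0.95)):
    """Hypothetical sketch: weighted sum of component densities, then quantiles.

    component_densities: list of arrays, each of shape (num_samples, domain_size).
    Returns an array of shape (len(quantiles), domain_size).
    """
    if weights is None:
        weights = [1.0] * len(component_densities)  # assumed meaning of "equal weights"
    total = scale * sum(w * d for w, d in zip(weights, component_densities))
    # Quantiles are taken across samples, independently at each grid point.
    return np.quantile(total, quantiles, axis=0)


# 10 samples on a 4-point grid; sample i has the constant value i + 1.
densities = np.tile(np.arange(1.0, 11.0)[:, None], (1, 4))
band = interval_band([densities, densities], weights=[0.5, 0.5])
lower, median, upper = band
```

The three rows could then be drawn with Matplotlib's Axes.plot (median) and Axes.fill_between (lower to upper), passing the PlotStyle color, label, line_plot_kwargs, and fill_between_kwargs as keyword arguments.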

gwkokab.analysis.utils.marginals.read_domains(filepath: str | pathlib.Path) → dict[str, tuple[float, float, int]]

Read domain specifications from an HDF5 file.

Parameters:

filepath (str | Path) – The path to the HDF5 file containing the domain specifications.

Returns:

A dictionary mapping parameter names to their corresponding domain specifications. Each value in the dictionary is a tuple containing the start, stop, and number of points for the domain of the parameter.

Return type:

dict[str, tuple[float, float, int]]

gwkokab.analysis.utils.marginals.remove_comoving_volume_factor(marginal_density: jaxtyping.Array, redshift_domain: jaxtyping.Array) → jaxtyping.Array

Remove the comoving volume factor from a marginal density over redshift.

This function takes a marginal density over redshift that includes both the comoving volume factor and the time dilation factor, and divides out both factors to obtain the underlying density.

Parameters:
  • marginal_density (Array) – The marginal density over redshift that includes the comoving volume and time dilation factors.

  • redshift_domain (Array) – The array of redshift values corresponding to the marginal density. This is used to compute the comoving volume element.

Returns:

The marginal density with the comoving volume and time dilation factors removed, representing the underlying density over redshift.

Return type:

Array
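Under the convention commonly used in gravitational-wave population analyses (assumed here; the exact convention of this module is not stated above), the redshift density carries the differential comoving volume dV_c/dz and a factor 1/(1+z) for cosmological time dilation between source and detector frames. Writing p(z) for the input marginal_density and ψ(z) for the returned density (symbols introduced for illustration), removing both factors amounts to:

```latex
p(z) \propto \frac{1}{1+z}\,\frac{\mathrm{d}V_c}{\mathrm{d}z}\,\psi(z)
\quad\Longrightarrow\quad
\psi(z) \propto \frac{(1+z)\,p(z)}{\mathrm{d}V_c/\mathrm{d}z}
```

Both sides are evaluated pointwise on redshift_domain, which is why the redshift grid is needed to compute dV_c/dz.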

gwkokab.analysis.utils.marginals.save_results_to_hdf5(samples: jaxtyping.Array, batched_results: list[list[list[jaxtyping.Array]]], parameters: list[str], domain_cfg: dict[str, tuple[float, float, int]], filepath: str | pathlib.Path)

Save the computed marginal densities to an HDF5 file.

Parameters:
  • samples (Array) – An array of samples used for computing the marginal densities. Each row corresponds to a single sample, and each column corresponds to a specific parameter of the model.

  • batched_results (list[list[list[Array]]]) – A nested list containing the computed marginal densities for each component of the model. The outer list corresponds to the components of the model, the inner list corresponds to the parameters of the model, and each leaf is a 2D array of shape (num_samples, domain_size).

  • parameters (list[str]) – A list of parameter names for the model.

  • domain_cfg (dict[str, tuple[float, float, int]]) – A dictionary mapping parameter names to their corresponding domain specifications.

  • filepath (str | Path) – The path to the HDF5 file where the results will be saved.

gwkokab.analysis.utils.marginals.write_domains(filepath: str | pathlib.Path, domain_cfg: dict[str, tuple[float, float, int]])

Write domain specifications to an HDF5 file.

Parameters:
  • filepath (str | Path) – The path to the HDF5 file where the domain specifications will be saved.

  • domain_cfg (dict[str, tuple[float, float, int]]) – A dictionary mapping parameter names to their corresponding domain specifications. Each value in the dictionary is a tuple containing the start, stop, and number of points for the domain of the parameter.