gwkokab.analysis.utils.marginals
================================

.. py:module:: gwkokab.analysis.utils.marginals


Classes
-------

.. autoapisummary::

   gwkokab.analysis.utils.marginals.PlotStyle


Functions
---------

.. autoapisummary::

   gwkokab.analysis.utils.marginals.calculate_dist_layouts
   gwkokab.analysis.utils.marginals.calculate_marginals_over_axes
   gwkokab.analysis.utils.marginals.compute_batched_marginals
   gwkokab.analysis.utils.marginals.generate_marginal_probs
   gwkokab.analysis.utils.marginals.plot_marginal_with_intervals
   gwkokab.analysis.utils.marginals.read_domains
   gwkokab.analysis.utils.marginals.remove_comoving_volume_factor
   gwkokab.analysis.utils.marginals.save_results_to_hdf5
   gwkokab.analysis.utils.marginals.write_domains


Module Contents
---------------

.. py:class:: PlotStyle

   Bases: :py:obj:`NamedTuple`

   A named tuple representing the style for plotting marginal densities: the color, the label,
   and additional keyword arguments for line plots and fill-between plots.


   .. py:attribute:: color
      :type: str


   .. py:attribute:: fill_between_kwargs
      :type: dict


   .. py:attribute:: label
      :type: str


   .. py:attribute:: line_plot_kwargs
      :type: dict


.. py:function:: calculate_dist_layouts(shaped_values: list[tuple[int, int] | int]) -> list[tuple[int, ...]]

   Calculate the layout of the distribution's support from its shaped values.

   :param shaped_values: A list of shaped values for the distribution. Each element is either a
      tuple describing a multi-dimensional support (the second element gives the number of
      dimensions) or an integer describing a one-dimensional support.
   :type shaped_values: list[tuple[int, int] | int]

   :returns: A list of tuples describing the layout of the distribution's support. Each tuple
      contains the indices of the dimensions that correspond to one component of the distribution.
   :rtype: list[tuple[int, ...]]


.. py:function:: calculate_marginals_over_axes(probs: jaxtyping.Array, domains: list[jaxtyping.Array], normalize: list[bool] | None = None) -> list[jaxtyping.Array]

   Calculate the marginal density along each axis of a joint density array.

   For each axis, this function integrates out all other dimensions to obtain the marginal
   density along that axis.

   :param probs: An array representing the joint density over multiple dimensions.
   :type probs: Array
   :param domains: A list of arrays holding the domain values of each dimension of the joint
      density. Its length should match the number of dimensions of :code:`probs`.
   :type domains: list[Array]
   :param normalize: A list of booleans indicating whether to normalize the marginal density of
      each dimension. If None, all marginal densities are normalized, by default None
   :type normalize: list[bool] | None, optional

   :returns: A list of arrays holding the marginal density of each dimension.
   :rtype: list[Array]


.. py:function:: compute_batched_marginals(model_meta_cls: type, samples_batch: jaxtyping.Array, constants: dict, variables_index: dict, domains: list[jaxtyping.Array], normalize: list[bool], batch_size: int | None = None)

   Compute the marginal densities of the model for each sample in :code:`samples_batch`.

   :param model_meta_cls: A class holding the meta-information of the model, including a method
      that constructs the model from specific parameter values.
   :type model_meta_cls: type
   :param samples_batch: A batch of samples, where each row is a single sample and each column is
      a specific parameter of the model.
   :type samples_batch: Array
   :param constants: A dictionary of constant values required to construct the model. The keys
      should match the parameter names expected by the model's constructor.
   :type constants: dict
   :param variables_index: A dictionary mapping parameter names to their corresponding column
      indices in the :code:`samples_batch` array.
      This mapping is used to extract the relevant parameter values from the samples when
      constructing the model.
   :type variables_index: dict
   :param domains: A list of arrays holding the domain values of each model parameter. Its length
      should match the number of parameters of the model.
   :type domains: list[Array]
   :param normalize: A list of booleans indicating whether to normalize the marginal density of
      each parameter.
   :type normalize: list[bool]
   :param batch_size: The size of the batch to process at a time, by default None
   :type batch_size: int | None, optional


.. py:function:: generate_marginal_probs(model_meta_cls: type, inference_data_path: str | pathlib.Path, domain_cfg: dict[str, tuple[float, float, int]], max_samples: int | None = None, batch_size: int | None = None)

   Generate marginal probability densities from the inference data and save them to the same
   HDF5 file.

   :param model_meta_cls: A class holding the meta-information of the model, including a method
      that constructs the model from specific parameter values.
   :type model_meta_cls: type
   :param inference_data_path: Path to the HDF5 file containing the inference data saved by the
      analysis. The generated marginal densities are also saved to this file.
   :type inference_data_path: str | Path
   :param domain_cfg: A dictionary mapping parameter names to their domain specifications. Each
      value is a tuple of the start, stop, and number of points of the parameter's domain.
   :type domain_cfg: dict[str, tuple[float, float, int]]
   :param max_samples: The maximum number of samples used to compute the marginal densities, by
      default None
   :type max_samples: int | None, optional
   :param batch_size: The batch size for computing the marginal densities, by default None
   :type batch_size: int | None, optional

   :raises FileNotFoundError: If the inference data file at :code:`inference_data_path` cannot be
      found.


.. py:function:: plot_marginal_with_intervals(ax: matplotlib.pyplot.Axes, filename: str, parameter: str, style: PlotStyle, component_idxs: list[int], scale: float | Callable = 1.0, weights: list[float | Callable] | None = None, normalize: bool = False)

   Plot the marginal densities of a parameter together with their confidence intervals.

   :param ax: The Matplotlib Axes on which to plot the marginal densities and confidence
      intervals.
   :type ax: plt.Axes
   :param filename: The path to the HDF5 file containing the marginal density data.
   :type filename: str
   :param parameter: The name of the parameter whose marginal densities are plotted. It should
      match a dataset in the HDF5 file under the "probs/component_{i}" groups.
   :type parameter: str
   :param style: The plotting style (color, label, and keyword arguments for the line and
      fill-between plots) used for the marginal density.
   :type style: PlotStyle
   :param component_idxs: A list of indices of the components to plot.
   :type component_idxs: list[int]
   :param scale: A scaling factor for the marginal densities. If a callable is given, it is
      evaluated with the parameters from the HDF5 file, by default 1.0
   :type scale: float | Callable, optional
   :param weights: The weight of each component's marginal density. If None, equal weights are
      assumed, by default None
   :type weights: list[float | Callable] | None, optional
   :param normalize: Whether to normalize the marginal densities, by default False
   :type normalize: bool, optional


.. py:function:: read_domains(filepath: str | pathlib.Path) -> dict[str, tuple[float, float, int]]

   Read domain specifications from an HDF5 file.

   :param filepath: The path to the HDF5 file containing the domain specifications.
   :type filepath: str | Path

   :returns: A dictionary mapping parameter names to their corresponding domain specifications.
      Each value is a tuple of the start, stop, and number of points of the parameter's domain.
   :rtype: dict[str, tuple[float, float, int]]


.. py:function:: remove_comoving_volume_factor(marginal_density: jaxtyping.Array, redshift_domain: jaxtyping.Array) -> jaxtyping.Array

   Remove the comoving volume factor from a marginal density over redshift.

   This function takes a marginal density that includes the comoving volume factor and the time
   dilation factor, and divides it by the comoving volume element and the time dilation factor to
   obtain the underlying density over redshift.

   :param marginal_density: The marginal density over redshift, including the comoving volume
      factor.
   :type marginal_density: Array
   :param redshift_domain: The redshift values at which the marginal density is evaluated. They
      are used to compute the comoving volume element.
   :type redshift_domain: Array

   :returns: The marginal density with the comoving volume and time dilation factors removed,
      i.e. the underlying density over redshift.
   :rtype: Array


.. py:function:: save_results_to_hdf5(samples: jaxtyping.Array, batched_results: list[list[list[jaxtyping.Array]]], parameters: list[str], domain_cfg: dict[str, tuple[float, float, int]], filepath: str | pathlib.Path)

   Save the computed marginal densities to an HDF5 file.

   :param samples: An array of the samples used to compute the marginal densities. Each row is a
      single sample, and each column is a specific parameter of the model.
   :type samples: Array
   :param batched_results: A nested list containing the computed marginal densities for each
      component of the model. The outer list corresponds to the components of the model, the
      inner list corresponds to the parameters of the model, and each leaf is a 2D array of shape
      (num_samples, domain_size).
   :type batched_results: list[list[Array]]
   :param parameters: A list of parameter names for the model.
   :type parameters: list[str]
   :param domain_cfg: A dictionary mapping parameter names to their corresponding domain
      specifications.
   :type domain_cfg: dict[str, tuple[float, float, int]]
   :param filepath: The path to the HDF5 file where the results will be saved.
   :type filepath: str | Path


.. py:function:: write_domains(filepath: str | pathlib.Path, domain_cfg: dict[str, tuple[float, float, int]])

   Write domain specifications to an HDF5 file.

   :param filepath: The path to the HDF5 file where the domain specifications will be saved.
   :type filepath: str | Path
   :param domain_cfg: A dictionary mapping parameter names to their corresponding domain
      specifications. Each value is a tuple of the start, stop, and number of points of the
      parameter's domain.
   :type domain_cfg: dict[str, tuple[float, float, int]]
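

The marginalization described for :code:`calculate_marginals_over_axes` can be illustrated with
a minimal NumPy sketch. It assumes that each :code:`domain_cfg` entry :code:`(start, stop, num)`
expands to a grid via :code:`numpy.linspace(start, stop, num)` and that the integrals use the
trapezoidal rule; neither detail is confirmed by this page. The functions
:code:`marginals_over_axes` and :code:`_trapezoid` and the parameter names :code:`x` and
:code:`y` are hypothetical, and the sketch is not the gwkokab implementation, which works on JAX
arrays.

.. code-block:: python

   import numpy as np


   def _trapezoid(y, x, axis):
       """Trapezoidal integral of ``y`` over the 1-D grid ``x`` along ``axis``."""
       y = np.moveaxis(y, axis, -1)  # put the integration axis last
       dx = np.diff(x)
       return np.sum(0.5 * (y[..., 1:] + y[..., :-1]) * dx, axis=-1)


   def marginals_over_axes(probs, domains, normalize=None):
       """Hypothetical NumPy stand-in for calculate_marginals_over_axes."""
       if normalize is None:
           normalize = [True] * probs.ndim  # None means: normalize every marginal
       marginals = []
       for axis in range(probs.ndim):
           marginal = probs
           # Integrate out every other axis, highest index first, so the
           # indices of the axes still to be integrated stay valid.
           for other in reversed(range(probs.ndim)):
               if other != axis:
                   marginal = _trapezoid(marginal, domains[other], axis=other)
           if normalize[axis]:
               # Rescale so the 1-D marginal integrates to one over its domain.
               marginal = marginal / _trapezoid(marginal, domains[axis], axis=0)
           marginals.append(marginal)
       return marginals


   # Hypothetical domain configuration: name -> (start, stop, number of points).
   domain_cfg = {"x": (0.0, 1.0, 11), "y": (0.0, 4.0, 5)}
   domains = [np.linspace(*spec) for spec in domain_cfg.values()]

   # Unnormalized joint density p(x, y) proportional to x and constant in y.
   probs = np.outer(domains[0], np.ones_like(domains[1]))

   p_x, p_y = marginals_over_axes(probs, domains)

With this input, :code:`p_x` is approximately :code:`2 * x` and :code:`p_y` is approximately
:code:`0.25` everywhere, because the trapezoidal rule is exact for densities that are linear in
each variable.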