gwkokab.analysis.core.flowMC_base
=================================

.. py:module:: gwkokab.analysis.core.flowMC_base


Attributes
----------

.. autoapisummary::

   gwkokab.analysis.core.flowMC_base.flowMC_arg_parser


Classes
-------

.. autoapisummary::

   gwkokab.analysis.core.flowMC_base.FlowMCBase
   gwkokab.analysis.core.flowMC_base.Local_Global_Sampler_Bundle
   gwkokab.analysis.core.flowMC_base.Sampler


Module Contents
---------------

.. py:class:: FlowMCBase(*, analysis_name: str, check_leaks: bool, debug_nans: bool, model: Union[numpyro.distributions.distribution.Distribution, collections.abc.Callable[Ellipsis, numpyro.distributions.distribution.Distribution]], poisson_mean_filename: str, prior_filename: str, profile_memory: bool, sampler_cfg, variance_cut_threshold: float | None)

   Bases: :py:obj:`gwkokab.analysis.core.analysis_base.AnalysisBase`

   Base class that contains the functionality common to the different analyses.
   It is not meant to be used directly; instead, it is subclassed by each specific
   analysis.

   .. py:method:: driver(*, logpdf: Callable[[jaxtyping.Array, Dict[str, Any]], jaxtyping.Array], priors: gwkokab.models.utils.JointDistribution, data: Dict[str, Any], labels: List[str]) -> None


.. py:class:: Local_Global_Sampler_Bundle(rng_key: jaxtyping.PRNGKeyArray, n_chains: int, n_dims: int, logpdf: Callable[[jaxtyping.Float[jaxtyping.Array, n_dim], dict], jaxtyping.Float], n_local_steps: int, n_global_steps: int, n_training_loops: int, n_production_loops: int, n_epochs: int, local_sampler_name: Literal['mala', 'hmc'] = 'mala', step_size: float = 0.1, mass_matrix: jaxtyping.Array = 1.0, n_leapfrog: int = 10, chain_batch_size: int = 0, rq_spline_hidden_units: list[int] = [32, 32], rq_spline_n_bins: int = 8, rq_spline_n_layers: int = 4, rq_spline_range: tuple[float, float] = (-10.0, 10.0), learning_rate: float = 0.001, batch_size: int = 10000, n_max_examples: int = 10000, history_window: int = 100, local_thinning: int = 1, global_thinning: int = 1, n_NFproposal_batch_size: int = 10000, verbose: bool = False)

   Bases: :py:obj:`flowMC.resource_strategy_bundle.base.ResourceStrategyBundle`

   A bundle that uses a rational-quadratic spline as the normalizing flow model and,
   by default, the Metropolis-adjusted Langevin algorithm (MALA) as the local sampler.
   ``local_sampler_name`` also accepts ``'hmc'`` (Hamiltonian Monte Carlo), which is
   configured through ``step_size``, ``mass_matrix``, and ``n_leapfrog``.
   This is the base algorithm described in
   https://www.pnas.org/doi/full/10.1073/pnas.2109420119.

   .. py:attribute:: resources

   .. py:attribute:: strategies

   .. py:attribute:: strategy_order
      :value: []


.. py:class:: Sampler(n_dim: int, n_chains: int, rng_key: jaxtyping.PRNGKeyArray, resources: None | dict[str, flowMC.resource.base.Resource] = None, strategies: None | dict[str, flowMC.strategy.base.Strategy] = None, strategy_order: None | list[str] = None, resource_strategy_bundles: None | flowMC.resource_strategy_bundle.base.ResourceStrategyBundle = None, **kwargs)

   Top-level API that users primarily interact with.

   :param n_dim: Dimension of the parameter space.
   :type n_dim: int
   :param n_chains: Number of chains to sample.
   :type n_chains: int
   :param rng_key: JAX PRNG key.
   :type rng_key: PRNGKeyArray
   :param logpdf: Log probability function.
   :type logpdf: Callable[[Float[Array, "n_dim"], dict], Float]
   :param resources: Resources to be used by the sampler.
   :type resources: dict[str, Resource]
   :param strategies: Strategies to be used by the sampler.
   :type strategies: dict[str, Strategy]
   :param verbose: Whether to print out progress. Defaults to ``False``.
   :type verbose: bool
   :param logging: Whether to log the progress. Defaults to ``True``.
   :type logging: bool
   :param outdir: Directory in which to save the logs. Defaults to ``"./outdir/"``.
   :type outdir: str

   .. py:method:: deserialize()
      :abstractmethod:

      Deserialize the sampler object.

   .. py:method:: sample(initial_position: jaxtyping.Float[jaxtyping.Array, n_chains n_dim], data: dict, n_local_steps_per_loop: int, n_global_steps_per_loop: int, labels: List[str])

      Sample from the posterior, taking ``n_local_steps_per_loop`` local-sampler
      steps and ``n_global_steps_per_loop`` global steps in each loop.

      :param initial_position: Initial positions of the chains, of shape ``(n_chains, n_dim)``.
      :type initial_position: Float[Array, "n_chains n_dim"]
      :param data: Data to be used by the likelihood functions.
      :type data: dict

   .. py:method:: serialize()
      :abstractmethod:

      Serialize the sampler object.

   .. py:attribute:: logging
      :type: bool
      :value: True

   .. py:attribute:: n_chains
      :type: int

   .. py:attribute:: n_dim
      :type: int

   .. py:attribute:: outdir
      :type: str
      :value: './outdir/'

   .. py:attribute:: resources
      :type: dict[str, flowMC.resource.base.Resource]

   .. py:attribute:: rng_key
      :type: jaxtyping.PRNGKeyArray

   .. py:attribute:: strategies
      :type: dict[str, flowMC.strategy.base.Strategy]

   .. py:attribute:: strategy_order
      :type: Optional[list[str]]

   .. py:attribute:: verbose
      :type: bool
      :value: False


.. py:data:: flowMC_arg_parser
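

Local-global sampling sketch
----------------------------

``Local_Global_Sampler_Bundle`` alternates blocks of local steps (MALA by default)
with global steps whose proposals come from a normalizing flow, following the PNAS
reference cited for that class. The snippet below is a minimal single-chain sketch of
that alternation on a toy problem. It is not the flowMC or gwkokab implementation,
and every name in it is hypothetical. It assumes a standard-normal ``log_target``.
It replaces the trained rational-quadratic spline flow with a fixed isotropic Gaussian
(``flow_log_prob``), so flow training (``n_training_loops``, ``n_epochs``) is omitted.
It also uses the MALA proposal ``x' = x + (h**2 / 2) * grad log p(x) + h * noise``,
which may scale ``step_size`` differently from flowMC.

.. code-block:: python

   import jax
   import jax.numpy as jnp


   def log_target(x):
       # Hypothetical target density: standard normal, log p(x) up to a constant.
       return -0.5 * jnp.sum(x**2)


   def flow_log_prob(x, scale):
       # Hypothetical stand-in for a trained normalizing flow:
       # an isotropic Gaussian N(0, scale**2 I) with a normalized log density.
       d = x.shape[-1]
       return (-0.5 * jnp.sum((x / scale) ** 2)
               - d * jnp.log(scale) - 0.5 * d * jnp.log(2.0 * jnp.pi))


   @jax.jit
   def mala_step(key, x, step_size):
       """One Metropolis-adjusted Langevin (MALA) step."""
       grad_fn = jax.grad(log_target)
       key_noise, key_accept = jax.random.split(key)
       # Langevin proposal: gradient drift plus Gaussian noise of scale step_size.
       mean_fwd = x + 0.5 * step_size**2 * grad_fn(x)
       x_new = mean_fwd + step_size * jax.random.normal(key_noise, x.shape)
       mean_bwd = x_new + 0.5 * step_size**2 * grad_fn(x_new)
       # The proposal is asymmetric, so both transition densities enter the ratio.
       log_q_fwd = -jnp.sum((x_new - mean_fwd) ** 2) / (2.0 * step_size**2)
       log_q_bwd = -jnp.sum((x - mean_bwd) ** 2) / (2.0 * step_size**2)
       log_alpha = log_target(x_new) - log_target(x) + log_q_bwd - log_q_fwd
       accept = jnp.log(jax.random.uniform(key_accept)) < log_alpha
       return jnp.where(accept, x_new, x)


   @jax.jit
   def global_step(key, x, scale):
       """One independence Metropolis-Hastings step with the flow as proposal."""
       key_prop, key_accept = jax.random.split(key)
       # Draw a proposal from the (stand-in) flow, independent of the current state.
       x_new = scale * jax.random.normal(key_prop, x.shape)
       log_alpha = (log_target(x_new) - log_target(x)
                    + flow_log_prob(x, scale) - flow_log_prob(x_new, scale))
       accept = jnp.log(jax.random.uniform(key_accept)) < log_alpha
       return jnp.where(accept, x_new, x)


   def local_global_sample(key, x0, n_loops=200, n_local_steps=10,
                           n_global_steps=5, step_size=0.5, flow_scale=1.5):
       """Alternate blocks of local MALA steps and global flow steps (one chain)."""
       x = x0
       samples = []
       for _ in range(n_loops):
           for _ in range(n_local_steps):
               key, subkey = jax.random.split(key)
               x = mala_step(subkey, x, step_size)
               samples.append(x)
           for _ in range(n_global_steps):
               key, subkey = jax.random.split(key)
               x = global_step(subkey, x, flow_scale)
               samples.append(x)
       return jnp.stack(samples)

Calling ``local_global_sample(jax.random.PRNGKey(0), jnp.zeros(2))`` returns an array
of shape ``(3000, 2)`` (200 loops of 10 local plus 5 global steps), whose per-dimension
mean and variance should be close to 0 and 1. Because the flow proposal does not depend
on the current state, its acceptance ratio uses the flow density at both the current and
the proposed point; with a well-trained flow, this is what allows chains to jump between
distant regions that local steps alone would explore slowly. The real bundle also runs
many chains in parallel, which this sketch omits.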