gwkokab.analysis.core.flowMC_base

Attributes

Classes

FlowMCBase

AnalysisBase is a class which contains all the common functionality among the different analyses.

Local_Global_Sampler_Bundle

A bundle that uses a Rational Quadratic Spline as a normalizing flow model and the Metropolis Adjusted Langevin Algorithm as a local sampler.

Sampler

Top level API that the users primarily interact with.

Module Contents

class gwkokab.analysis.core.flowMC_base.FlowMCBase(*, analysis_name: str, check_leaks: bool, debug_nans: bool, model: numpyro.distributions.distribution.Distribution | collections.abc.Callable[Ellipsis, numpyro.distributions.distribution.Distribution], poisson_mean_filename: str, prior_filename: str, profile_memory: bool, sampler_cfg, variance_cut_threshold: float | None)

Bases: gwkokab.analysis.core.analysis_base.AnalysisBase

AnalysisBase is a class which contains all the common functionality among the different analyses.

It is not meant to be used directly, but rather to be subclassed by the specific analyses.

driver(*, logpdf: Callable[[jaxtyping.Array, Dict[str, Any]], jaxtyping.Array], priors: gwkokab.models.utils.JointDistribution, data: Dict[str, Any], labels: List[str]) -> None

class gwkokab.analysis.core.flowMC_base.Local_Global_Sampler_Bundle(rng_key: jaxtyping.PRNGKeyArray, n_chains: int, n_dims: int, logpdf: Callable[[jaxtyping.Float[jaxtyping.Array, n_dim], dict], jaxtyping.Float], n_local_steps: int, n_global_steps: int, n_training_loops: int, n_production_loops: int, n_epochs: int, local_sampler_name: Literal['mala', 'hmc'] = 'mala', step_size: float = 0.1, mass_matrix: jaxtyping.Array = 1.0, n_leapfrog: int = 10, chain_batch_size: int = 0, rq_spline_hidden_units: list[int] = [32, 32], rq_spline_n_bins: int = 8, rq_spline_n_layers: int = 4, rq_spline_range: tuple[float, float] = (-10.0, 10.0), learning_rate: float = 0.001, batch_size: int = 10000, n_max_examples: int = 10000, history_window: int = 100, local_thinning: int = 1, global_thinning: int = 1, n_NFproposal_batch_size: int = 10000, verbose: bool = False)

Bases: flowMC.resource_strategy_bundle.base.ResourceStrategyBundle

A bundle that uses a Rational Quadratic Spline as a normalizing flow model and the Metropolis Adjusted Langevin Algorithm as a local sampler.

This is the base algorithm described in https://www.pnas.org/doi/full/10.1073/pnas.2109420119
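To make the role of the local sampler concrete, the following is a minimal JAX sketch of a single Metropolis Adjusted Langevin step for one chain. It is an illustration under stated assumptions, not the implementation used by flowMC or gwkokab: the names `log_prob` and `mala_step` are hypothetical, the target is a standard normal chosen only for the example, and the step size reuses the documented default of 0.1.

```python
import jax
import jax.numpy as jnp


def log_prob(x):
    # Hypothetical target for illustration: standard normal in n_dim dimensions.
    return -0.5 * jnp.sum(x**2)


def mala_step(rng_key, x, step_size):
    """One Metropolis Adjusted Langevin step for a single chain (sketch)."""
    key_noise, key_accept = jax.random.split(rng_key)
    logp_x, grad_x = jax.value_and_grad(log_prob)(x)

    # Langevin proposal: drift along the gradient plus Gaussian noise.
    noise = jax.random.normal(key_noise, x.shape)
    y = x + 0.5 * step_size**2 * grad_x + step_size * noise
    logp_y, grad_y = jax.value_and_grad(log_prob)(y)

    # Log density (up to a constant) of proposing `to` when starting at `frm`.
    def log_q(to, frm, grad_frm):
        diff = to - frm - 0.5 * step_size**2 * grad_frm
        return -0.5 * jnp.sum(diff**2) / step_size**2

    # Metropolis-Hastings correction for the asymmetric proposal.
    log_alpha = logp_y + log_q(x, y, grad_y) - logp_x - log_q(y, x, grad_x)
    accept = jnp.log(jax.random.uniform(key_accept)) < log_alpha
    return jnp.where(accept, y, x), accept


x0 = jnp.zeros(3)
x1, accepted = mala_step(jax.random.PRNGKey(0), x0, step_size=0.1)
```

In the algorithm referenced above, local steps of this kind are interleaved with global proposals drawn from the normalizing flow; only the local step is sketched here.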

resources
strategies
strategy_order = []

class gwkokab.analysis.core.flowMC_base.Sampler(n_dim: int, n_chains: int, rng_key: jaxtyping.PRNGKeyArray, resources: None | dict[str, flowMC.resource.base.Resource] = None, strategies: None | dict[str, flowMC.strategy.base.Strategy] = None, strategy_order: None | list[str] = None, resource_strategy_bundles: None | flowMC.resource_strategy_bundle.base.ResourceStrategyBundle = None, **kwargs)

Top level API that the users primarily interact with.

Parameters:
  • n_dim (int) – Dimension of the parameter space.

  • n_chains (int) – Number of chains to sample.

  • rng_key (PRNGKeyArray) – Jax PRNGKey.

  • logpdf (Callable[[Float[Array, "n_dim"], dict], Float]) – Log probability function.

  • resources (dict[str, Resource]) – Resources to be used by the sampler.

  • strategies (dict[str, Strategy]) – Strategies to be used by the sampler.

  • verbose (bool) – Whether to print out progress. Defaults to False.

  • logging (bool) – Whether to log the progress. Defaults to True.

  • outdir (str) – Directory to save the logs. Defaults to "./outdir/".
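As an illustration of the documented `logpdf` type, a callable that takes one position of shape `(n_dim,)` plus a `data` dict and returns a scalar, here is a minimal JAX sketch. The Gaussian target and the `"mean"` and `"sigma"` keys are assumptions made for the example and are not part of gwkokab or flowMC. The sketch also builds an array of shape `(n_chains, n_dim)`, the shape `sample` expects for `initial_position`, and checks the function on every chain with `jax.vmap`.

```python
import jax
import jax.numpy as jnp


def logpdf(x, data):
    # Hypothetical Gaussian log-density centred on data["mean"].
    # x has shape (n_dim,); the return value is a scalar.
    residual = (x - data["mean"]) / data["sigma"]
    return -0.5 * jnp.sum(residual**2)


n_chains, n_dim = 4, 2
data = {"mean": jnp.zeros(n_dim), "sigma": 1.0}

# One starting point per chain, shape (n_chains, n_dim).
rng_key = jax.random.PRNGKey(0)
positions = jax.random.normal(rng_key, (n_chains, n_dim))

# Evaluate every chain at once; `data` is shared, so it is not mapped.
log_probs = jax.vmap(logpdf, in_axes=(0, None))(positions, data)
```

Checking that the function returns one finite scalar per chain before handing it to a sampler tends to catch shape mistakes early.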

abstractmethod deserialize()

Deserialize the sampler object.

sample(initial_position: jaxtyping.Float[jaxtyping.Array, n_chains n_dim], data: dict, n_local_steps_per_loop: int, n_global_steps_per_loop: int, labels: List[str])

Sample from the posterior using the local sampler.

Parameters:
  • initial_position (Device Array) – Initial position.

  • data (dict) – Data to be used by the likelihood functions.

abstractmethod serialize()

Serialize the sampler object.

logging: bool = True
n_chains: int
n_dim: int
outdir: str = './outdir/'
resources: dict[str, flowMC.resource.base.Resource]
rng_key: jaxtyping.PRNGKeyArray
strategies: dict[str, flowMC.strategy.base.Strategy]
strategy_order: list[str] | None
verbose: bool = False

gwkokab.analysis.core.flowMC_base.flowMC_arg_parser