gwkokab.analysis.core.flowMC_base
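Attributes
- flowMC_arg_parser
Classes
- FlowMCBase: AnalysisBase is a class which contains all the common functionality among the different analyses.
- Local_Global_Sampler_Bundle: A bundle that uses a Rational Quadratic Spline as a normalizing flow model and the Metropolis Adjusted Langevin Algorithm as a local sampler.
- Sampler: Top level API that the users primarily interact with.
Module Contents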
- class gwkokab.analysis.core.flowMC_base.FlowMCBase(*, analysis_name: str, check_leaks: bool, debug_nans: bool, model: numpyro.distributions.distribution.Distribution | collections.abc.Callable[Ellipsis, numpyro.distributions.distribution.Distribution], poisson_mean_filename: str, prior_filename: str, profile_memory: bool, sampler_cfg, variance_cut_threshold: float | None)
Bases: gwkokab.analysis.core.analysis_base.AnalysisBase
AnalysisBase is a class which contains all the common functionality among the different analyses.
It is not meant to be used directly, but rather to be subclassed by the specific analyses.
- class gwkokab.analysis.core.flowMC_base.Local_Global_Sampler_Bundle(rng_key: jaxtyping.PRNGKeyArray, n_chains: int, n_dims: int, logpdf: Callable[[jaxtyping.Float[jaxtyping.Array, n_dim], dict], jaxtyping.Float], n_local_steps: int, n_global_steps: int, n_training_loops: int, n_production_loops: int, n_epochs: int, local_sampler_name: Literal['mala', 'hmc'] = 'mala', step_size: float = 0.1, mass_matrix: jaxtyping.Array = 1.0, n_leapfrog: int = 10, chain_batch_size: int = 0, rq_spline_hidden_units: list[int] = [32, 32], rq_spline_n_bins: int = 8, rq_spline_n_layers: int = 4, rq_spline_range: tuple[float, float] = (-10.0, 10.0), learning_rate: float = 0.001, batch_size: int = 10000, n_max_examples: int = 10000, history_window: int = 100, local_thinning: int = 1, global_thinning: int = 1, n_NFproposal_batch_size: int = 10000, verbose: bool = False)
Bases: flowMC.resource_strategy_bundle.base.ResourceStrategyBundle
A bundle that uses a Rational Quadratic Spline as a normalizing flow model and the Metropolis Adjusted Langevin Algorithm as a local sampler (MALA is the default; per the signature, local_sampler_name also accepts 'hmc').
This is the base algorithm described in https://www.pnas.org/doi/full/10.1073/pnas.2109420119
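The local sampler named above can be illustrated with a minimal, dependency-free sketch of one Metropolis Adjusted Langevin step. This is not flowMC's implementation: it is written in plain Python for a scalar state, the helper names (`mala_step`, `run_chain`) are hypothetical, and the proposal convention (drift of `step_size**2 / 2` times the gradient, noise scaled by `step_size`) is an assumption chosen for the example.

```python
import math
import random


def mala_step(x, logpdf, grad_logpdf, step_size, rng):
    """One Metropolis-adjusted Langevin step for a scalar state x."""

    def drift_mean(z):
        # Langevin drift: shift z along the gradient of the log density.
        return z + 0.5 * step_size**2 * grad_logpdf(z)

    def log_q(to, frm):
        # Log density (up to a constant) of proposing `to` from `frm`.
        return -((to - drift_mean(frm)) ** 2) / (2.0 * step_size**2)

    proposal = drift_mean(x) + step_size * rng.gauss(0.0, 1.0)
    log_alpha = (
        logpdf(proposal) - logpdf(x) + log_q(x, proposal) - log_q(proposal, x)
    )
    # Metropolis-Hastings correction: accept with probability min(1, alpha).
    if rng.random() < math.exp(min(0.0, log_alpha)):
        return proposal, True
    return x, False


def standard_normal_logpdf(z):
    # Unnormalized log density of N(0, 1), used as a toy target.
    return -0.5 * z * z


def standard_normal_grad(z):
    return -z


def run_chain(n_steps, step_size=0.9, seed=0):
    """Run a single chain on the toy target; return samples and acceptance rate."""
    rng = random.Random(seed)
    x, samples, accepted = 0.0, [], 0
    for _ in range(n_steps):
        x, was_accepted = mala_step(
            x, standard_normal_logpdf, standard_normal_grad, step_size, rng
        )
        samples.append(x)
        accepted += int(was_accepted)
    return samples, accepted / n_steps
```

In the bundle, the role of `step_size` here corresponds loosely to the `step_size` argument, and the number of such steps per loop to `n_local_steps`; the exact parameterization used by flowMC may differ.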
- resources
- strategies
- strategy_order = []
- class gwkokab.analysis.core.flowMC_base.Sampler(n_dim: int, n_chains: int, rng_key: jaxtyping.PRNGKeyArray, resources: None | dict[str, flowMC.resource.base.Resource] = None, strategies: None | dict[str, flowMC.strategy.base.Strategy] = None, strategy_order: None | list[str] = None, resource_strategy_bundles: None | flowMC.resource_strategy_bundle.base.ResourceStrategyBundle = None, **kwargs)
Top level API that the users primarily interact with.
- Parameters:
n_dim (int) – Dimension of the parameter space.
n_chains (int) – Number of chains to sample.
rng_key (PRNGKeyArray) – Jax PRNGKey.
logpdf (Callable[[Float[Array, "n_dim"], dict], Float]) – Log probability function.
resources (dict[str, Resource]) – Resources to be used by the sampler.
strategies (dict[str, Strategy]) – Strategies to be used by the sampler.
verbose (bool) – Whether to print out progress. Defaults to False.
logging (bool) – Whether to log the progress. Defaults to True.
outdir (str) – Directory to save the logs. Defaults to "./outdir/".
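The `logpdf` annotation above describes a function that takes a single position of shape `(n_dim,)` plus a `dict` of data and returns a scalar. A minimal sketch of a function with that shape, assuming JAX is installed; the Gaussian target and the `"mean"` and `"scale"` keys are hypothetical and chosen only for illustration:

```python
import jax
import jax.numpy as jnp


def gaussian_logpdf(x, data):
    """Unnormalized log density of an axis-aligned Gaussian.

    x has shape (n_dim,); data is a dict of arrays (keys are illustrative).
    """
    residual = (x - data["mean"]) / data["scale"]
    return -0.5 * jnp.sum(residual**2)


data = {"mean": jnp.array([1.0, -2.0]), "scale": jnp.array([0.5, 2.0])}
x = jnp.array([1.5, 0.0])

value = gaussian_logpdf(x, data)  # scalar log density at x
# Gradient-based local samplers such as MALA or HMC need the function to be
# differentiable with respect to x; jax.grad differentiates the first argument.
gradient = jax.grad(gaussian_logpdf)(x, data)  # shape (n_dim,)
```

Keeping the function free of Python side effects and built from `jax.numpy` operations is what lets it be differentiated and compiled.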
- abstractmethod deserialize()
Deserialize the sampler object.
- sample(initial_position: jaxtyping.Float[jaxtyping.Array, n_chains n_dim], data: dict, n_local_steps_per_loop: int, n_global_steps_per_loop: int, labels: List[str])
Sample from the posterior using the local sampler.
- Parameters:
initial_position (Float[Array, "n_chains n_dim"]) – Initial position of each chain.
data (dict) – Data to be used by the likelihood functions.
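Per the signature, `initial_position` carries one row per chain, so its shape is `(n_chains, n_dim)`. The sketch below, assuming JAX is installed, builds such an array and checks that a per-position log density evaluates on every row. The toy `logpdf`, the `"mean"` key, and the use of `jax.vmap` for the check are illustrative assumptions, not a description of how the sampler batches chains internally.

```python
import jax
import jax.numpy as jnp

n_chains, n_dim = 8, 3

# Split the key so the parent key can still be handed to the sampler.
rng_key = jax.random.PRNGKey(0)
rng_key, init_key = jax.random.split(rng_key)

# One starting point per chain, drawn from a standard normal.
initial_position = jax.random.normal(init_key, shape=(n_chains, n_dim))

data = {"mean": jnp.zeros(n_dim)}  # hypothetical data dictionary


def logpdf(x, data):
    # Toy target: unnormalized standard normal centred on data["mean"].
    return -0.5 * jnp.sum((x - data["mean"]) ** 2)


# Map over the chain axis of the positions while sharing the same data dict.
log_probs = jax.vmap(logpdf, in_axes=(0, None))(initial_position, data)
```

Checking that every entry of `log_probs` is finite before sampling is a cheap way to catch starting points that fall outside the support of the target.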
- abstractmethod serialize()
Serialize the sampler object.
- rng_key: jaxtyping.PRNGKeyArray
- gwkokab.analysis.core.flowMC_base.flowMC_arg_parser