gwkokab.models.utils

Classes

DoublyTruncatedPowerLaw

Power law distribution with index \(\alpha\) and lower and upper bounds.

ExtendedSupportTransformedDistribution

Returns a distribution instance obtained as a result of applying a sequence of transforms to a base distribution.

JointDistribution

Joint distribution constructed from one or more marginal distributions.

LazyJointDistribution

Joint distribution constructed from one or more marginal distributions, some of which may be built lazily from the values of other marginals.

ScaledMixture

A finite mixture of component distributions from different families, in which the components are scaled by a set of rates.

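Functions

doubly_truncated_power_law_cdf(x, alpha, low, high)

doubly_truncated_power_law_icdf(q, alpha, low, high)

doubly_truncated_power_law_log_norm_constant(alpha, low, high)

doubly_truncated_power_law_log_prob(x, alpha, low, high)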

Package Contents

class gwkokab.models.utils.DoublyTruncatedPowerLaw(alpha: jaxtyping.ArrayLike, low: jaxtyping.ArrayLike, high: jaxtyping.ArrayLike, *, validate_args: bool | None = None)

Bases: numpyro.distributions.Distribution

Power law distribution with index \(\alpha\) and lower and upper bounds. Its density is defined as

\[f(x; \alpha, a, b) = \frac{x^{\alpha}}{Z(\alpha, a, b)}, \qquad a \le x \le b,\]

where \(a\) and \(b\) are the lower and upper bounds, respectively, and \(Z(\alpha, a, b)\) is the normalization constant, defined as

\[\begin{split}Z(\alpha, a, b) = \begin{cases} \log(b) - \log(a) & \text{if } \alpha = -1, \\ \frac{b^{1 + \alpha} - a^{1 + \alpha}}{1 + \alpha} & \text{otherwise}. \end{cases}\end{split}\]

Parameters:
  • alpha – index of the power law distribution

  • low – lower bound of the distribution

  • high – upper bound of the distribution
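
As a quick numerical sanity check of the normalization constant, the sketch below (plain Python, standard library only) compares the closed form of \(Z(\alpha, a, b)\) with a midpoint-rule integral of \(x^{\alpha}\) over \([a, b]\). The helpers `norm_constant` and `integral_of_power` are hypothetical and not part of gwkokab; the bounds \(a = 5\), \(b = 100\) and the \(\alpha\) values are arbitrary illustration choices, and \(0 < a < b\) is assumed.

```python
import math


def norm_constant(alpha: float, low: float, high: float) -> float:
    """Closed-form Z(alpha, a, b) for 0 < a < b."""
    if alpha == -1.0:
        return math.log(high) - math.log(low)
    beta = 1.0 + alpha
    # (b^beta - a^beta) and beta share a sign, so the ratio is positive
    return (high**beta - low**beta) / beta


def integral_of_power(alpha: float, low: float, high: float, n: int = 20_000) -> float:
    """Midpoint-rule approximation of the integral of x^alpha over [a, b]."""
    h = (high - low) / n
    return h * sum((low + (i + 0.5) * h) ** alpha for i in range(n))


for alpha in (-2.35, -1.0, 0.0, 1.5):
    exact = norm_constant(alpha, 5.0, 100.0)
    approx = integral_of_power(alpha, 5.0, 100.0)
    print(f"alpha={alpha:+.2f}  closed form={exact:.8g}  midpoint={approx:.8g}")
```

The exact comparison `alpha == -1.0` mirrors the two cases of \(Z\). For \(\alpha\) very close to, but not equal to, \(-1\), the general branch loses precision through cancellation in \(b^{1+\alpha} - a^{1+\alpha}\), so a production implementation may need to treat that neighbourhood specially.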

cdf(value: jaxtyping.ArrayLike) → jaxtyping.ArrayLike

Cumulative distribution function. For \(\alpha \neq -1\):

\[F(x) = \frac{x^{\alpha + 1} - a^{\alpha + 1}}{b^{\alpha + 1} - a^{\alpha + 1}}\]

For \(\alpha = -1\):

\[F(x) = \frac{\log(x) - \log(a)}{\log(b) - \log(a)}\]

The derivations were computed with Wolfram Alpha.
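
The two branches of the CDF can be written out directly. This standard-library sketch uses a hypothetical helper `cdf` (not the library's `doubly_truncated_power_law_cdf`), returns 0 below \(a\) and 1 above \(b\) as an illustrative assumption, and checks numerically that the general branch approaches the logarithmic branch as \(\alpha \to -1\).

```python
import math


def cdf(x: float, alpha: float, low: float, high: float) -> float:
    """CDF of the doubly truncated power law on [a, b]; 0 below a and 1 above b."""
    if x <= low:
        return 0.0
    if x >= high:
        return 1.0
    if alpha == -1.0:
        return (math.log(x) - math.log(low)) / (math.log(high) - math.log(low))
    beta = 1.0 + alpha
    return (x**beta - low**beta) / (high**beta - low**beta)


# The general branch tends to the logarithmic one as alpha -> -1.
exact = cdf(20.0, -1.0, 5.0, 100.0)
nearby = cdf(20.0, -1.0 + 1e-6, 5.0, 100.0)
print(exact, nearby)
```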

icdf(q: jaxtyping.ArrayLike) → jaxtyping.ArrayLike

Inverse cumulative distribution function (quantile function). For \(\alpha \neq -1\):

\[F^{-1}(q) = \left(a^{1 + \alpha} + q (b^{1 + \alpha} - a^{1 + \alpha})\right)^{\frac{1}{1 + \alpha}}\]

For \(\alpha = -1\):

\[F^{-1}(q) = a \left(\frac{b}{a}\right)^{q}\]

The derivations were computed with Wolfram Alpha.
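
Because the quantile function is available in closed form, samples can be drawn by inverse transform sampling: map uniform draws \(q \sim U(0, 1)\) through \(F^{-1}\). The standard-library sketch below implements both branches with hypothetical helpers `icdf`, `cdf` and `sample`; it illustrates the technique under the assumption \(0 < a < b\), is not the library's own `sample` method, and checks that \(F(F^{-1}(q)) \approx q\).

```python
import math
import random


def icdf(q: float, alpha: float, low: float, high: float) -> float:
    """Inverse CDF of the doubly truncated power law for q in [0, 1]."""
    if alpha == -1.0:
        # x = a (b / a)^q
        return low * (high / low) ** q
    beta = 1.0 + alpha
    # x = (a^(1+alpha) + q (b^(1+alpha) - a^(1+alpha)))^(1 / (1+alpha))
    return (low**beta + q * (high**beta - low**beta)) ** (1.0 / beta)


def cdf(x: float, alpha: float, low: float, high: float) -> float:
    """CDF on [a, b], used here only to check the round trip."""
    if alpha == -1.0:
        return math.log(x / low) / math.log(high / low)
    beta = 1.0 + alpha
    return (x**beta - low**beta) / (high**beta - low**beta)


def sample(rng: random.Random, n: int, alpha: float, low: float, high: float) -> list[float]:
    """Inverse transform sampling: push uniform draws through the inverse CDF."""
    return [icdf(rng.random(), alpha, low, high) for _ in range(n)]


rng = random.Random(0)
draws = sample(rng, 5, -2.35, 5.0, 100.0)
print(draws)
```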

log_prob(value: jaxtyping.ArrayLike) → jaxtyping.ArrayLike

Logarithm of the probability density. For \(\alpha \neq -1\):

\[\log f(x) = \alpha \log(x) + \log\left(\frac{1 + \alpha}{b^{1 + \alpha} - a^{1 + \alpha}}\right)\]

For \(\alpha = -1\):

\[\log f(x) = \alpha \log(x) - \log\left(\log(b) - \log(a)\right)\]

The derivations were computed with Wolfram Alpha.
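
In log space the density reduces to \(\alpha \log(x) - \log Z(\alpha, a, b)\), and both cases follow from the definition of \(Z\). A minimal standard-library sketch, with a hypothetical `log_prob` helper that returns \(-\infty\) outside \([a, b]\) (an assumption made for illustration):

```python
import math


def log_prob(x: float, alpha: float, low: float, high: float) -> float:
    """Log-density alpha * log(x) - log Z(alpha, a, b); -inf outside [a, b]."""
    if not (low <= x <= high):
        return -math.inf
    if alpha == -1.0:
        log_z = math.log(math.log(high) - math.log(low))
    else:
        beta = 1.0 + alpha
        log_z = math.log((high**beta - low**beta) / beta)
    # The x-dependence enters only through alpha * log(x), never through x**alpha.
    return alpha * math.log(x) - log_z


print(log_prob(20.0, -2.35, 5.0, 100.0), log_prob(20.0, -1.0, 5.0, 100.0), log_prob(1.0, -2.35, 5.0, 100.0))
```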

sample(key: jax.dtypes.prng_key, sample_shape: tuple[int, Ellipsis] = ()) → jaxtyping.ArrayLike

Returns a sample from the distribution having shape given by sample_shape + batch_shape + event_shape. Note that when sample_shape is non-empty, leading dimensions (of size sample_shape) of the returned sample will be filled with iid draws from the distribution instance.

Parameters:
  • key (jax.random.key) – the PRNG key used to draw the sample.

  • sample_shape (tuple) – the sample shape for the distribution.

Returns:

an array of shape sample_shape + batch_shape + event_shape

Return type:

numpy.ndarray

support() → numpyro.distributions.constraints.Constraint

The support of this distribution. Subclasses can override this as a class attribute or as a property.

class gwkokab.models.utils.ExtendedSupportTransformedDistribution(base_distribution: Distribution, transforms: numpyro.distributions.transforms.Transform | list[numpyro.distributions.transforms.Transform], *, validate_args: bool | None = None)

Bases: numpyro.distributions.TransformedDistribution

Returns a distribution instance obtained as a result of applying a sequence of transforms to a base distribution. For an example, see LogNormal and HalfNormal.

Parameters:
  • base_distribution – the base distribution over which to apply transforms.

  • transforms – a single transform or a list of transforms.

  • validate_args – Whether to enable validation of distribution parameters and arguments to the .log_prob method.

property support

The support of this distribution. Subclasses can override this as a class attribute or as a property.

class gwkokab.models.utils.JointDistribution(*marginal_distributions: numpyro.distributions.Distribution, flatten_method: Literal['deep', 'shallow'] | None = None, support: numpyro.distributions.constraints.Constraint | None = None, validate_args: bool | None = None)

Bases: numpyro.distributions.Distribution

Construct a joint distribution from one or more marginal distributions.

You may pass individual Distribution instances or nest them inside JointDistribution instances. The flatten_method argument allows nested joints to be flattened into a single flat list of marginals.

Parameters:
  • marginal_distributions (*Distribution) – One or more Distribution objects (or nested JointDistribution instances) that form the components of the joint distribution.

  • flatten_method (Optional[Literal["deep", "shallow"]], optional) – If “shallow”, one level of nested JointDistributions will be flattened. If “deep”, all levels of nested JointDistributions will be recursively flattened. If None (default), the nesting is preserved as-is.

  • support (Optional[constraints.Constraint], optional) – The constraint object representing the support of the joint distribution. If not provided, it is computed from the support of the marginals.

  • validate_args (Optional[bool], optional) – Whether to validate distribution parameters and inputs. Default is None.

Raises:

ValueError – If no marginal distributions are provided.

Example

>>> from numpyro.distributions import Normal
>>> from gwkokab.models.utils import JointDistribution

>>> A = Normal(0, 1)
>>> B = Normal(1, 1)
>>> C = Normal(2, 1)
>>> D = Normal(3, 1)
>>> E = Normal(4, 1)

>>> jd = JointDistribution(
...     A, JointDistribution(B, JointDistribution(C, D)), E
... )

>>> len(jd.marginal_distributions)  # No flattening (default)
3

>>> jd = JointDistribution(
...     A,
...     JointDistribution(B, JointDistribution(C, D)),
...     E,
...     flatten_method="shallow",
... )
>>> len(jd.marginal_distributions)  # Shallow flattening
4

>>> jd = JointDistribution(
...     A,
...     JointDistribution(B, JointDistribution(C, D)),
...     E,
...     flatten_method="deep",
... )
>>> len(jd.marginal_distributions)  # Deep flattening
5
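
The three flattening modes in the example above can be mimicked with a small, self-contained sketch. Here `Joint` is a hypothetical stand-in for `JointDistribution` that only stores its marginals, and strings stand in for the `Normal` marginals; the helper `flatten` illustrates the described semantics and is not gwkokab's implementation.

```python
from dataclasses import dataclass
from typing import Literal, Optional


@dataclass
class Joint:
    """Hypothetical stand-in for JointDistribution: it only holds its marginals."""

    marginals: list


def flatten(marginals: list, method: Optional[Literal["deep", "shallow"]]) -> list:
    """Return the top-level list of marginals after applying flatten_method."""
    if method is None:
        return list(marginals)  # nesting is preserved as-is
    out = []
    for m in marginals:
        if not isinstance(m, Joint):
            out.append(m)
        elif method == "shallow":
            out.extend(m.marginals)  # splice exactly one level
        else:
            out.extend(flatten(m.marginals, "deep"))  # recurse to the leaves
    return out


nested = ["A", Joint(["B", Joint(["C", "D"])]), "E"]
print([len(flatten(nested, method)) for method in (None, "shallow", "deep")])  # [3, 4, 5]
```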

log_prob(value: jaxtyping.Array) → jaxtyping.Array

Evaluates the log probability density for a batch of samples given by value.

Parameters:

value – A batch of samples from the distribution.

Returns:

an array with shape value.shape[:value.ndim - len(self.event_shape)]

Return type:

ArrayLike
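
Assuming the marginals are independent (the joint is built from marginals alone), the joint log-density is the sum of the marginal log-densities, each evaluated on its own entry of value. The standard-library sketch below assumes univariate Normal marginals that each occupy one entry of value; the helpers `normal_log_prob` and `joint_log_prob` are hypothetical.

```python
import math


def normal_log_prob(x: float, loc: float, scale: float) -> float:
    """Log-density of a univariate Normal(loc, scale)."""
    z = (x - loc) / scale
    return -0.5 * z * z - math.log(scale) - 0.5 * math.log(2.0 * math.pi)


def joint_log_prob(value: list[float], marginals: list[tuple[float, float]]) -> float:
    """Sum of marginal log-densities, with value[i] scored by the i-th (loc, scale) Normal."""
    if len(value) != len(marginals):
        raise ValueError("value must have one entry per marginal")
    return sum(normal_log_prob(v, loc, scale) for v, (loc, scale) in zip(value, marginals))


# Normal(0, 1), Normal(1, 1), Normal(2, 1), mirroring A, B, C from the example above
marginals = [(0.0, 1.0), (1.0, 1.0), (2.0, 1.0)]
print(joint_log_prob([0.0, 1.0, 2.0], marginals))  # equals 3 * log N(0 | 0, 1)
```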

marginal_log_probs(value: jaxtyping.Array) → jaxtyping.Array

sample(key: jaxtyping.PRNGKeyArray, sample_shape: tuple[int, Ellipsis] = ())

Returns a sample from the distribution having shape given by sample_shape + batch_shape + event_shape. Note that when sample_shape is non-empty, leading dimensions (of size sample_shape) of the returned sample will be filled with iid draws from the distribution instance.

Parameters:
  • key (jax.random.key) – the PRNG key used to draw the sample.

  • sample_shape (tuple) – the sample shape for the distribution.

Returns:

an array of shape sample_shape + batch_shape + event_shape

Return type:

numpy.ndarray

support() → numpyro.distributions.constraints.Constraint

The support of the joint distribution.

marginal_distributions: collections.abc.Sequence[numpyro.distributions.Distribution] = ()
shaped_values: collections.abc.Sequence[int | Tuple[int, int]] = ()
class gwkokab.models.utils.LazyJointDistribution(*marginal_distributions: numpyro.distributions.Distribution | jax.tree_util.Partial, dependencies: Dict[int, Dict[str, int]], partial_order: List[int], dependencies_event_shape: List[Tuple[int, Ellipsis]] | None = None, flatten_method: Literal['deep', 'shallow'] | None = None, support: numpyro.distributions.constraints.Constraint | None = None, validate_args: bool | None = None)

Bases: numpyro.distributions.Distribution

Construct a joint distribution from one or more marginal distributions, any of which may be lazy, i.e. built from the sampled values of other marginals.

You may pass individual Distribution instances or nest them inside LazyJointDistribution instances. The flatten_method argument allows nested joints to be flattened into a single flat list of marginals.

Parameters:
  • marginal_distributions (*Union[Distribution, jax.tree_util.Partial]) – One or more marginal distributions. Each marginal distribution can be an instance of numpyro.distributions.Distribution or a jax.tree_util.Partial that returns a numpyro.distributions.Distribution when called with its arguments.

  • dependencies (Dict[int, Dict[str, int]]) – A dictionary mapping the index of each marginal distribution that is a jax.tree_util.Partial to another dictionary that maps the names of its dependency parameters to the indices of the marginal distributions they depend on. This is used to specify which variables each lazy variable depends on.

  • partial_order (Optional[Tuple[str, int, int]], optional) – A tuple defining a partial order for the lazy variables. Each entry in the tuple should be of the form (var_name, event_index, marginal_index), such that elements coming earlier in the tuple are not dependent on elements coming later. This is used to ensure that when sampling from the joint distribution, the lazy variables are sampled in an order that respects their dependencies.

  • dependencies_event_shape (Optional[List[Tuple[int, ...]]], optional) – A list of event shapes for the dependencies of each marginal distribution. This is used to validate the shapes of the dependency variables when constructing

  • flatten_method (Optional[Literal["deep", "shallow"]], optional) – Currently not used. If “shallow”, one level of nested LazyJointDistributions will be flattened. If “deep”, all levels of nested LazyJointDistributions will be recursively flattened. If None (default), the nesting is preserved as-is.

  • support (Optional[constraints.Constraint], optional) – The constraint object representing the support of the joint distribution. If not provided, it is computed from the support of the marginals.

  • validate_args (Optional[bool], optional) – Whether to validate distribution parameters and inputs. Default is None.

Raises:

ValueError – If no marginal distributions are provided.
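
The roles of dependencies and partial_order can be illustrated with a standard-library sketch. It assumes, consistent with the List[int] annotation in the signature, that partial_order lists marginal indices so that every marginal comes after the marginals it depends on. It uses functools.partial in place of jax.tree_util.Partial and plain sampling closures in place of NumPyro distributions; the helpers `uniform` and `sample_lazy_joint` are hypothetical.

```python
import random
from functools import partial


def uniform(low: float, high: float):
    """Hypothetical stand-in for a distribution: returns a sampler for U(low, high)."""
    return lambda rng: rng.uniform(low, high)


def sample_lazy_joint(rng, marginals, dependencies, partial_order):
    """Draw one joint sample, resolving lazy marginals from already-sampled values."""
    values = {}
    for i in partial_order:
        marginal = marginals[i]
        if i in dependencies:
            # Lazy marginal: bind each named parameter to the value of marginal j
            kwargs = {name: values[j] for name, j in dependencies[i].items()}
            marginal = marginal(**kwargs)
        values[i] = marginal(rng)
    return [values[i] for i in range(len(marginals))]


marginals = [
    uniform(1.0, 2.0),      # x0 ~ U(1, 2)
    partial(uniform, 0.0),  # x1 ~ U(0, x0): `high` is filled in lazily
]
dependencies = {1: {"high": 0}}  # parameter `high` of marginal 1 <- value of marginal 0
partial_order = [0, 1]           # marginal 0 must be sampled before marginal 1

rng = random.Random(0)
print(sample_lazy_joint(rng, marginals, dependencies, partial_order))
```

Reversing partial_order to [1, 0] makes this sketch raise a KeyError, because marginal 1 would be built before the value it depends on exists.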

log_prob(value: jaxtyping.Array) → jaxtyping.Array

Evaluates the log probability density for a batch of samples given by value.

Parameters:

value – A batch of samples from the distribution.

Returns:

an array with shape value.shape[:value.ndim - len(self.event_shape)]

Return type:

ArrayLike

sample(key: jaxtyping.PRNGKeyArray, sample_shape: tuple[int, Ellipsis] = ())

Returns a sample from the distribution having shape given by sample_shape + batch_shape + event_shape. Note that when sample_shape is non-empty, leading dimensions (of size sample_shape) of the returned sample will be filled with iid draws from the distribution instance.

Parameters:
  • key (jax.random.key) – the PRNG key used to draw the sample.

  • sample_shape (tuple) – the sample shape for the distribution.

Returns:

an array of shape sample_shape + batch_shape + event_shape

Return type:

numpy.ndarray

support() → numpyro.distributions.constraints.Constraint

The support of the joint distribution.

dependencies
marginal_distributions: collections.abc.Sequence[numpyro.distributions.Distribution | jax.tree_util.Partial] = ()
partial_order
shaped_values: collections.abc.Sequence[int | Tuple[int, int]] = ()
class gwkokab.models.utils.ScaledMixture(log_scales: jaxtyping.Array, component_distributions: List[numpyro.distributions.Distribution], *, support: numpyro.distributions.constraints.Constraint | None = None, validate_args: bool | None = None)

Bases: numpyro.distributions.Distribution

A finite mixture of component distributions from different families. This is a generalization of Mixture where the component distributions are scaled by a set of rates.

Example

>>> import jax
>>> import jax.random as jrd
>>> import numpyro.distributions as dist
>>> from gwkokab.models.utils import ScaledMixture
>>> log_scales = jrd.uniform(jrd.key(42), (3,), minval=0, maxval=5)
>>> component_dists = [
...     dist.Normal(loc=0.0, scale=1.0),
...     dist.Normal(loc=-0.5, scale=0.3),
...     dist.Normal(loc=0.6, scale=1.2),
... ]
>>> mixture = ScaledMixture(log_scales, component_dists)
>>> mixture.sample(jax.random.key(42)).shape
()
cdf(samples)

The cumulative distribution function.

Parameters:

samples – samples from this distribution.

Returns:

output of the cumulative distribution function evaluated at samples.

Raises:

NotImplementedError – if a component distribution does not implement the cdf method.

component_cdf(samples)

component_log_probs(value: jaxtyping.ArrayLike) → jaxtyping.ArrayLike

component_sample(key, sample_shape=())

log_prob(value, intermediates=None)

Evaluates the log probability density for a batch of samples given by value.

Parameters:

value – A batch of samples from the distribution.

Returns:

an array with shape value.shape[:value.ndim - len(self.event_shape)]

Return type:

ArrayLike
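
Reading the description of ScaledMixture as a rate-scaled generalization of Mixture, a natural form for its log-density is \(\log \sum_k \exp(s_k)\, p_k(x)\), where \(s_k\) are the log_scales. If the scales are not renormalized, the density integrates to \(\sum_k \exp(s_k)\) rather than 1. The standard-library sketch below assumes exactly that form and uses Normal components; it illustrates the assumed formula rather than the library's code, and all helper names are hypothetical.

```python
import math


def logsumexp(xs: list[float]) -> float:
    """Numerically stable log(sum(exp(x) for x in xs))."""
    m = max(xs)
    if m == -math.inf:
        return -math.inf
    return m + math.log(sum(math.exp(x - m) for x in xs))


def normal_log_prob(x: float, loc: float, scale: float) -> float:
    """Log-density of a univariate Normal(loc, scale)."""
    z = (x - loc) / scale
    return -0.5 * z * z - math.log(scale) - 0.5 * math.log(2.0 * math.pi)


def scaled_mixture_log_prob(x, log_scales, components):
    """log sum_k exp(log_scales[k]) * p_k(x), with no renormalization of the scales."""
    terms = [s + normal_log_prob(x, loc, scale) for s, (loc, scale) in zip(log_scales, components)]
    return logsumexp(terms)


log_scales = [math.log(2.0), math.log(3.0)]
components = [(0.0, 1.0), (1.0, 0.5)]  # (loc, scale) of each Normal component

# Midpoint-rule integral over [-10, 10]: the total mass is 2 + 3 = 5, not 1.
h = 1e-3
mass = h * sum(
    math.exp(scaled_mixture_log_prob(-10.0 + (i + 0.5) * h, log_scales, components))
    for i in range(20_000)
)
print(mass)
```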

sample(key, sample_shape=())

Returns a sample from the distribution having shape given by sample_shape + batch_shape + event_shape. Note that when sample_shape is non-empty, leading dimensions (of size sample_shape) of the returned sample will be filled with iid draws from the distribution instance.

Parameters:
  • key (jax.random.key) – the PRNG key used to draw the sample.

  • sample_shape (tuple) – the sample shape for the distribution.

Returns:

an array of shape sample_shape + batch_shape + event_shape

Return type:

numpy.ndarray

sample_with_intermediates(key, sample_shape=())

A version of sample that also returns the sampled component indices.

Parameters:
  • key (jax.random.PRNGKey) – the PRNG key used to draw the sample.

  • sample_shape (tuple) – the sample shape for the distribution.

Returns:

A 2-element tuple with the samples from the distribution, and the indices of the sampled components.

Return type:

tuple
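
Under the same reading of ScaledMixture, drawing from the mixture is a two-stage process: choose a component index with probability proportional to \(\exp(s_k)\) (a softmax of log_scales), then sample from that component, returning the indices alongside the samples as sample_with_intermediates does. The standard-library sketch below assumes that component-selection rule and uses Normal components; the helper `sample_with_indices` is hypothetical.

```python
import math
import random


def sample_with_indices(rng, log_scales, components, n):
    """Draw n samples and the index of the component each sample came from."""
    m = max(log_scales)
    weights = [math.exp(s - m) for s in log_scales]  # softmax(log_scales) up to a constant
    indices = rng.choices(range(len(components)), weights=weights, k=n)
    samples = [rng.gauss(*components[k]) for k in indices]
    return samples, indices


rng = random.Random(0)
log_scales = [0.0, math.log(3.0)]       # scales 1 and 3 -> selection probabilities 1/4 and 3/4
components = [(0.0, 1.0), (10.0, 1.0)]  # (loc, scale) of each Normal component
samples, indices = sample_with_indices(rng, log_scales, components, 20_000)
print(sum(indices) / len(indices))  # fraction drawn from component 1, close to 0.75
```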

support()

The support of this distribution. Subclasses can override this as a class attribute or as a property.

property component_distributions

The list of component distributions in the mixture.

Returns:

The list of component distributions

Return type:

list[Distribution]

property component_mean
property component_variance
property is_discrete
log_scales
property mean

Mean of the distribution.

property mixture_dim
property mixture_size

The number of components in the mixture.

property variance

Variance of the distribution.

gwkokab.models.utils.doubly_truncated_power_law_cdf(x, alpha, low, high)
gwkokab.models.utils.doubly_truncated_power_law_icdf(q, alpha, low, high)
gwkokab.models.utils.doubly_truncated_power_law_log_norm_constant(alpha, low, high)
gwkokab.models.utils.doubly_truncated_power_law_log_prob(x, alpha, low, high)