gwkokab.models.utils
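Classes
- DoublyTruncatedPowerLaw: Power law distribution with \(\alpha\) index, and lower and upper bounds.
- ExtendedSupportTransformedDistribution: Returns a distribution instance obtained as a result of applying a sequence of transforms to a base distribution.
- JointDistribution: Joint distribution constructed from one or more marginal distributions.
- LazyJointDistribution: Joint distribution whose marginals may be jax.tree_util.Partial objects that depend on the values of other marginals.
- ScaledMixture: A finite mixture of component distributions from different families, scaled by a set of rates.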
Package Contents
- class gwkokab.models.utils.DoublyTruncatedPowerLaw(alpha: jaxtyping.ArrayLike, low: jaxtyping.ArrayLike, high: jaxtyping.ArrayLike, *, validate_args: bool | None = None)
Bases: numpyro.distributions.Distribution
Power law distribution with \(\alpha\) index, and lower and upper bounds. We define the power law distribution for \(a \le x \le b\) as
\[f(x; \alpha, a, b) = \frac{x^{\alpha}}{Z(\alpha, a, b)},\]
where \(a\) and \(b\) are the lower and upper bounds, respectively, and \(Z(\alpha, a, b)\) is the normalization constant, defined as
\[Z(\alpha, a, b) = \begin{cases} \log(b) - \log(a) & \text{if } \alpha = -1, \\ \frac{b^{1 + \alpha} - a^{1 + \alpha}}{1 + \alpha} & \text{otherwise}. \end{cases}\]
- Parameters:
alpha – index of the power law distribution
low – lower bound of the distribution
high – upper bound of the distribution
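As a sanity check of the definition above, the following standalone sketch evaluates \(Z(\alpha, a, b)\) and \(f(x; \alpha, a, b)\) in plain Python and integrates the density with a midpoint rule. It does not call gwkokab or NumPyro; the helper names (power_law_norm, power_law_pdf, midpoint_integral) are hypothetical, and it assumes 0 < a < b.

```python
import math


def power_law_norm(alpha: float, a: float, b: float) -> float:
    """Normalization constant Z(alpha, a, b) of x**alpha on [a, b], assuming 0 < a < b."""
    if alpha == -1.0:
        return math.log(b) - math.log(a)
    return (b ** (1.0 + alpha) - a ** (1.0 + alpha)) / (1.0 + alpha)


def power_law_pdf(x: float, alpha: float, a: float, b: float) -> float:
    """Density x**alpha / Z on [a, b]; zero outside the bounds."""
    if x < a or x > b:
        return 0.0
    return x ** alpha / power_law_norm(alpha, a, b)


def midpoint_integral(f, lo: float, hi: float, n: int = 20_000) -> float:
    """Midpoint-rule approximation of the integral of f over [lo, hi]."""
    h = (hi - lo) / n
    return h * sum(f(lo + (i + 0.5) * h) for i in range(n))


# Both branches of Z should make the density integrate to approximately 1.
for alpha in (-2.35, -1.0, 0.5):
    total = midpoint_integral(lambda x: power_law_pdf(x, alpha, 5.0, 100.0), 5.0, 100.0)
    print(f"alpha={alpha}: integral ~ {total:.6f}")
```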
- cdf(value: jaxtyping.ArrayLike) → jaxtyping.ArrayLike
Cumulative distribution function. For \(\alpha \neq -1\):
\[F(x) = \frac{x^{\alpha + 1} - a^{\alpha + 1}}{b^{\alpha + 1} - a^{\alpha + 1}}\]
For \(\alpha = -1\):
\[F(x) = \frac{\log(x) - \log(a)}{\log(b) - \log(a)}\]
These expressions were derived with Wolfram Alpha.
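The two CDF branches can be checked against hand-computed values with a short plain-Python sketch. The name power_law_cdf is hypothetical, and clipping to 0 below a and to 1 above b is an assumption about out-of-range inputs, not documented behavior of DoublyTruncatedPowerLaw.cdf.

```python
import math


def power_law_cdf(x: float, alpha: float, a: float, b: float) -> float:
    """CDF of x**alpha truncated to [a, b] (assumes 0 < a < b)."""
    # Assumption: clip to the bounds instead of extrapolating the formula.
    if x <= a:
        return 0.0
    if x >= b:
        return 1.0
    if alpha == -1.0:
        return (math.log(x) - math.log(a)) / (math.log(b) - math.log(a))
    p = alpha + 1.0
    return (x ** p - a ** p) / (b ** p - a ** p)


# alpha = 1 on [1, 3]: F(2) = (4 - 1) / (9 - 1) = 0.375
print(power_law_cdf(2.0, 1.0, 1.0, 3.0))
# alpha = -1 on [1, e**2]: F(e) = 1 / 2
print(power_law_cdf(math.e, -1.0, 1.0, math.e ** 2))
```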
- icdf(q: jaxtyping.ArrayLike) → jaxtyping.ArrayLike
Inverse cumulative distribution function. For \(\alpha \neq -1\):
\[F^{-1}(q) = \left(a^{1 + \alpha} + q (b^{1 + \alpha} - a^{1 + \alpha})\right)^{\frac{1}{1 + \alpha}}\]
For \(\alpha = -1\):
\[F^{-1}(q) = a \left(\frac{b}{a}\right)^{q}\]
These expressions were derived with Wolfram Alpha.
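The usual use of an inverse CDF is inverse transform sampling: draw \(q\) uniformly from \([0, 1)\) and map it through \(F^{-1}\). The sketch below assumes 0 < a < b, uses Python's random module rather than JAX PRNG keys, and implements the \(\alpha = -1\) branch as \(a (b/a)^q\) and every other \(\alpha\) as \((a^{1+\alpha} + q(b^{1+\alpha} - a^{1+\alpha}))^{1/(1+\alpha)}\). A round trip through the CDF confirms that the two functions invert each other. All function names are hypothetical.

```python
import math
import random


def power_law_cdf(x: float, alpha: float, a: float, b: float) -> float:
    """CDF of x**alpha truncated to [a, b], valid for a <= x <= b."""
    if alpha == -1.0:
        return (math.log(x) - math.log(a)) / (math.log(b) - math.log(a))
    p = alpha + 1.0
    return (x ** p - a ** p) / (b ** p - a ** p)


def power_law_icdf(q: float, alpha: float, a: float, b: float) -> float:
    """Inverse CDF for q in [0, 1]."""
    if alpha == -1.0:
        return a * (b / a) ** q
    p = alpha + 1.0
    return (a ** p + q * (b ** p - a ** p)) ** (1.0 / p)


# Inverse transform sampling: uniform draws mapped through the inverse CDF.
rng = random.Random(0)
draws = [power_law_icdf(rng.random(), -2.35, 5.0, 100.0) for _ in range(1000)]

# Round trip: F(F^{-1}(q)) should recover q for both branches.
for alpha in (-2.35, -1.0, 0.5):
    print(alpha, power_law_cdf(power_law_icdf(0.3, alpha, 5.0, 100.0), alpha, 5.0, 100.0))
```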
- log_prob(value: jaxtyping.ArrayLike) → jaxtyping.ArrayLike
Logarithm of the probability density function, \(\log f(x) = \alpha \log(x) - \log Z(\alpha, a, b)\). For \(\alpha \neq -1\):
\[\log f(x) = \alpha \log(x) + \log\left(\frac{\alpha + 1}{b^{\alpha + 1} - a^{\alpha + 1}}\right)\]
For \(\alpha = -1\):
\[\log f(x) = -\log(x) - \log\left(\log(b) - \log(a)\right)\]
These expressions were derived with Wolfram Alpha.
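The log density can be sketched directly from \(\log f(x) = \alpha \log(x) - \log Z(\alpha, a, b)\), which follows from the definition of \(f\). This plain-Python version assumes 0 < a < b and returns -inf outside the bounds (an assumption, not documented behavior); power_law_log_prob is a hypothetical name.

```python
import math


def power_law_log_prob(x: float, alpha: float, a: float, b: float) -> float:
    """log f(x) = alpha * log(x) - log Z(alpha, a, b); -inf outside [a, b]."""
    if x < a or x > b:
        return -math.inf
    if alpha == -1.0:
        log_z = math.log(math.log(b) - math.log(a))
    else:
        p = alpha + 1.0
        # (b**p - a**p) / p is positive for any p != 0 when a < b.
        log_z = math.log((b ** p - a ** p) / p)
    return alpha * math.log(x) - log_z


# alpha = 1 on [1, 3]: Z = 4, so f(2) = 2 / 4 and log f(2) = -log(2).
print(power_law_log_prob(2.0, 1.0, 1.0, 3.0))
```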
- sample(key: jax.dtypes.prng_key, sample_shape: tuple[int, Ellipsis] = ()) → jaxtyping.ArrayLike
Returns a sample from the distribution having shape given by sample_shape + batch_shape + event_shape. Note that when sample_shape is non-empty, leading dimensions (of size sample_shape) of the returned sample will be filled with iid draws from the distribution instance.
- Parameters:
key (jax.random.key) – the rng_key key to be used for the distribution.
sample_shape (tuple) – the sample shape for the distribution.
- Returns:
an array of shape sample_shape + batch_shape + event_shape
- Return type:
ArrayLike
- support() → numpyro.distributions.constraints.Constraint
The support of this distribution. Subclasses can override this as a class attribute or as a property.
- class gwkokab.models.utils.ExtendedSupportTransformedDistribution(base_distribution: Distribution, transforms: numpyro.distributions.transforms.Transform | list[numpyro.distributions.transforms.Transform], *, validate_args: bool | None = None)
Bases: numpyro.distributions.TransformedDistribution
Returns a distribution instance obtained as a result of applying a sequence of transforms to a base distribution. For an example, see LogNormal and HalfNormal.
- Parameters:
base_distribution – the base distribution over which to apply transforms.
transforms – a single transform or a list of transforms.
validate_args – Whether to enable validation of distribution parameters and arguments to the .log_prob method.
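A transformed distribution evaluates densities with the change-of-variables rule \(\log p_Y(y) = \log p_X(T^{-1}(y)) - \log\lvert \partial T / \partial x \rvert\), evaluated at \(x = T^{-1}(y)\). The plain-Python sketch below applies it to \(T = \exp\) on a Normal base, which is how NumPyro builds LogNormal, and compares the result with the closed-form log-normal density. It does not model whatever additional support handling ExtendedSupportTransformedDistribution provides; all function names are hypothetical.

```python
import math


def normal_log_prob(x: float, loc: float = 0.0, scale: float = 1.0) -> float:
    """Log density of Normal(loc, scale) at x."""
    z = (x - loc) / scale
    return -0.5 * z * z - math.log(scale) - 0.5 * math.log(2.0 * math.pi)


def exp_transformed_log_prob(y: float, loc: float = 0.0, scale: float = 1.0) -> float:
    """Log density of Y = exp(X), X ~ Normal(loc, scale), via change of variables."""
    x = math.log(y)  # inverse transform T^{-1}(y)
    log_abs_det_jacobian = x  # log|d exp(x) / dx| = x
    return normal_log_prob(x, loc, scale) - log_abs_det_jacobian


def lognormal_log_prob(y: float, loc: float = 0.0, scale: float = 1.0) -> float:
    """Closed-form log-normal log density, for comparison."""
    return (-((math.log(y) - loc) ** 2) / (2.0 * scale ** 2)
            - math.log(y * scale * math.sqrt(2.0 * math.pi)))


for y in (0.5, 2.0, 7.3):
    print(y, exp_transformed_log_prob(y, 0.3, 1.7), lognormal_log_prob(y, 0.3, 1.7))
```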
- property support
The support of this distribution. Subclasses can override this as a class attribute or as a property.
- class gwkokab.models.utils.JointDistribution(*marginal_distributions: numpyro.distributions.Distribution, flatten_method: Literal['deep', 'shallow'] | None = None, support: numpyro.distributions.constraints.Constraint | None = None, validate_args: bool | None = None)
Bases: numpyro.distributions.Distribution
Construct a joint distribution from one or more marginal distributions.
You may pass individual Distribution instances or nest them inside JointDistribution instances. The flatten_method argument allows flattening of nested joints into a single flat list of marginals.
- Parameters:
marginal_distributions (*Distribution) – One or more Distribution objects (or nested JointDistribution instances) that form the components of the joint distribution.
flatten_method (Optional[Literal["deep", "shallow"]], optional) – If “shallow”, one level of nested JointDistributions will be flattened. If “deep”, all levels of nested JointDistributions will be recursively flattened. If None (default), the nesting is preserved as-is.
support (Optional[constraints.Constraint], optional) – The constraint object representing the support of the joint distribution. If not provided, it is computed from the support of the marginals.
validate_args (Optional[bool], optional) – Whether to validate distribution parameters and inputs. Default is None.
- Raises:
ValueError – If no marginal distributions are provided.
Example
>>> from numpyro.distributions import Normal
>>> from gwkokab.models.utils import JointDistribution
>>> A = Normal(0, 1)
>>> B = Normal(1, 1)
>>> C = Normal(2, 1)
>>> D = Normal(3, 1)
>>> E = Normal(4, 1)
>>> jd = JointDistribution(
...     A, JointDistribution(B, JointDistribution(C, D)), E
... )
>>> len(jd.marginal_distributions)  # No flattening (default)
3
>>> jd = JointDistribution(
...     A,
...     JointDistribution(B, JointDistribution(C, D)),
...     E,
...     flatten_method="shallow",
... )
>>> len(jd.marginal_distributions)  # Shallow flattening
4
>>> jd = JointDistribution(
...     A,
...     JointDistribution(B, JointDistribution(C, D)),
...     E,
...     flatten_method="deep",
... )
>>> len(jd.marginal_distributions)  # Deep flattening
5
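The flattening behavior shown in the example can be reproduced with a small recursive helper. This is a hypothetical pure-Python model, not the gwkokab implementation: Joint stands in for JointDistribution, and strings stand in for marginal distributions.

```python
from typing import List, Literal, Optional


class Joint:
    """Hypothetical stand-in for JointDistribution that only tracks its marginals."""

    def __init__(self, *marginals, flatten_method: Optional[Literal["deep", "shallow"]] = None):
        if not marginals:
            raise ValueError("At least one marginal distribution is required.")
        self.marginals: List = flatten(list(marginals), flatten_method)


def flatten(marginals: List, method: Optional[str]) -> List:
    """Flatten nested Joint objects one level ("shallow") or fully ("deep")."""
    if method is None:
        return marginals  # preserve the nesting as-is
    out: List = []
    for m in marginals:
        if not isinstance(m, Joint):
            out.append(m)
        elif method == "shallow":
            out.extend(m.marginals)  # joints nested inside m stay nested
        else:
            out.extend(flatten(m.marginals, "deep"))  # recurse through every level
    return out


inner = Joint("B", Joint("C", "D"))
print(len(Joint("A", inner, "E").marginals))
print(len(Joint("A", inner, "E", flatten_method="shallow").marginals))
print(Joint("A", inner, "E", flatten_method="deep").marginals)
```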
- log_prob(value: jaxtyping.Array) → jaxtyping.Array
Evaluates the log probability density for a batch of samples given by value.
- Parameters:
value – A batch of samples from the distribution.
- Returns:
an array with shape value.shape[:-self.event_shape]
- Return type:
ArrayLike
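For a joint distribution of independent marginals, the log density of a concatenated sample is the sum of the marginal log densities over consecutive slices of the last axis. The sketch below shows that bookkeeping in plain Python. It assumes independence and last-axis concatenation, which this page does not state for JointDistribution.log_prob; joint_log_prob and std_normal_log_prob are hypothetical names.

```python
import math
from typing import Callable, List, Sequence

LogProb = Callable[[Sequence[float]], float]


def std_normal_log_prob(xs: Sequence[float]) -> float:
    """Log density of independent standard normals evaluated at xs."""
    return sum(-0.5 * x * x - 0.5 * math.log(2.0 * math.pi) for x in xs)


def joint_log_prob(
    value: Sequence[float], marginal_log_probs: List[LogProb], event_sizes: List[int]
) -> float:
    """Sum each marginal's log density over its slice of value."""
    if len(value) != sum(event_sizes):
        raise ValueError("value does not match the total event size of the marginals")
    total, start = 0.0, 0
    for log_prob, size in zip(marginal_log_probs, event_sizes):
        total += log_prob(value[start:start + size])
        start += size
    return total


# One scalar marginal followed by a two-dimensional marginal.
lp = joint_log_prob([0.0, 1.0, -1.0], [std_normal_log_prob, std_normal_log_prob], [1, 2])
print(lp)
```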
- sample(key: jaxtyping.PRNGKeyArray, sample_shape: tuple[int, Ellipsis] = ())
Returns a sample from the distribution having shape given by sample_shape + batch_shape + event_shape. Note that when sample_shape is non-empty, leading dimensions (of size sample_shape) of the returned sample will be filled with iid draws from the distribution instance.
- Parameters:
key (jax.random.key) – the rng_key key to be used for the distribution.
sample_shape (tuple) – the sample shape for the distribution.
- Returns:
an array of shape sample_shape + batch_shape + event_shape
- support() → numpyro.distributions.constraints.Constraint
The support of the joint distribution.
- marginal_distributions: collections.abc.Sequence[numpyro.distributions.Distribution] = ()
- shaped_values: collections.abc.Sequence[int | Tuple[int, int]] = ()
- class gwkokab.models.utils.LazyJointDistribution(*marginal_distributions: numpyro.distributions.Distribution | jax.tree_util.Partial, dependencies: Dict[int, Dict[str, int]], partial_order: List[int], dependencies_event_shape: List[Tuple[int, Ellipsis]] | None = None, flatten_method: Literal['deep', 'shallow'] | None = None, support: numpyro.distributions.constraints.Constraint | None = None, validate_args: bool | None = None)
Bases: numpyro.distributions.Distribution
Construct a joint distribution from one or more marginal distributions.
You may pass individual Distribution instances or nest them inside LazyJointDistribution instances. The flatten_method argument allows flattening of nested joints into a single flat list of marginals.
- Parameters:
marginal_distributions (*Union[Distribution, jax.tree_util.Partial]) – One or more marginal distributions. Each marginal distribution can be an instance of numpyro.distributions.Distribution or a jax.tree_util.Partial that returns a numpyro.distributions.Distribution when called with its arguments.
dependencies (Dict[int, Dict[str, int]]) – A dictionary mapping the index of each marginal distribution that is a jax.tree_util.Partial to another dictionary that maps the names of its dependency parameters to the indices of the marginal distributions they depend on. This is used to specify which variables each lazy variable depends on.
partial_order (Optional[Tuple[str, int, int]], optional) – A tuple defining a partial order for the lazy variables. Each entry in the tuple should be of the form (var_name, event_index, marginal_index), such that elements coming earlier in the tuple are not dependent on elements coming later. This is used to ensure that when sampling from the joint distribution, the lazy variables are sampled in an order that respects their dependencies.
dependencies_event_shape (Optional[List[Tuple[int, ...]]], optional) – A list of event shapes for the dependencies of each marginal distribution. This is used to validate the shapes of the dependency variables when constructing
flatten_method (Optional[Literal["deep", "shallow"]], optional) – Currently not used. If “shallow”, one level of nested LazyJointDistributions will be flattened. If “deep”, all levels of nested LazyJointDistributions will be recursively flattened. If None (default), the nesting is preserved as-is.
support (Optional[constraints.Constraint], optional) – The constraint object representing the support of the joint distribution. If not provided, it is computed from the support of the marginals.
validate_args (Optional[bool], optional) – Whether to validate distribution parameters and inputs. Default is None.
- Raises:
ValueError – If no marginal distributions are provided.
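The roles of dependencies and partial_order can be illustrated with a plain-Python sketch of ancestral sampling: marginals are visited in partial_order, and each lazy marginal is materialized from the values already drawn for the marginals it depends on. Here partial_order is taken to be a list of marginal indices, matching the List[int] annotation in the signature. Eager marginals are modeled as callables that draw a value from a random.Random, and lazy ones as factories that return such a callable. All names are hypothetical, and this is not the gwkokab implementation.

```python
import random
from typing import Callable, Dict, List


def sample_lazy_joint(
    rng: random.Random,
    marginals: List[Callable],
    dependencies: Dict[int, Dict[str, int]],
    partial_order: List[int],
) -> List[float]:
    """Draw one joint sample, visiting marginals in partial_order."""
    values: Dict[int, float] = {}
    for i in partial_order:
        sampler = marginals[i]
        if i in dependencies:
            # Lazy marginal: bind its parameters to values drawn earlier.
            # A KeyError here means partial_order does not respect the dependencies.
            kwargs = {name: values[j] for name, j in dependencies[i].items()}
            sampler = sampler(**kwargs)
        values[i] = sampler(rng)
    return [values[i] for i in range(len(marginals))]


# Marginal 0: uniform on [5, 50]; marginal 1: uniform on [5, x0], depending on marginal 0.
marginals = [
    lambda rng: rng.uniform(5.0, 50.0),
    lambda high: (lambda rng: rng.uniform(5.0, high)),
]
dependencies = {1: {"high": 0}}
rng = random.Random(1)
draws = [sample_lazy_joint(rng, marginals, dependencies, [0, 1]) for _ in range(1000)]
print(draws[0])
```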
- log_prob(value: jaxtyping.Array) → jaxtyping.Array
Evaluates the log probability density for a batch of samples given by value.
- Parameters:
value – A batch of samples from the distribution.
- Returns:
an array with shape value.shape[:-self.event_shape]
- Return type:
ArrayLike
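Evaluating the log density needs no ordering, because every dependency value is already present in value: each lazy marginal is materialized from the corresponding entries, and its log density is added to the total. In this plain-Python sketch each marginal is a log-density callable on one scalar, and a lazy marginal is a factory that takes its dependency values as keyword arguments keyed as in dependencies. The one-scalar-per-marginal layout and the names lazy_joint_log_prob and uniform_log_prob are assumptions for illustration.

```python
import math
from typing import Callable, Dict, List, Sequence


def uniform_log_prob(low: float, high: float) -> Callable[[float], float]:
    """Log density of Uniform(low, high)."""
    return lambda x: -math.log(high - low) if low <= x <= high else -math.inf


def lazy_joint_log_prob(
    value: Sequence[float],
    marginals: List[Callable],
    dependencies: Dict[int, Dict[str, int]],
) -> float:
    """Sum of marginal log densities; lazy marginals read their parameters from value."""
    total = 0.0
    for i, log_prob in enumerate(marginals):
        if i in dependencies:
            kwargs = {name: value[j] for name, j in dependencies[i].items()}
            log_prob = log_prob(**kwargs)  # materialize the lazy marginal
        total += log_prob(value[i])
    return total


marginals = [uniform_log_prob(5.0, 50.0), lambda high: uniform_log_prob(5.0, high)]
dependencies = {1: {"high": 0}}
print(lazy_joint_log_prob([25.0, 10.0], marginals, dependencies))
```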
- sample(key: jaxtyping.PRNGKeyArray, sample_shape: tuple[int, Ellipsis] = ())
Returns a sample from the distribution having shape given by sample_shape + batch_shape + event_shape. Note that when sample_shape is non-empty, leading dimensions (of size sample_shape) of the returned sample will be filled with iid draws from the distribution instance.
- Parameters:
key (jax.random.key) – the rng_key key to be used for the distribution.
sample_shape (tuple) – the sample shape for the distribution.
- Returns:
an array of shape sample_shape + batch_shape + event_shape
- support() → numpyro.distributions.constraints.Constraint
The support of the joint distribution.
- dependencies
- marginal_distributions: collections.abc.Sequence[numpyro.distributions.Distribution | jax.tree_util.Partial] = ()
- partial_order
- shaped_values: collections.abc.Sequence[int | Tuple[int, int]] = ()
- class gwkokab.models.utils.ScaledMixture(log_scales: jaxtyping.Array, component_distributions: List[numpyro.distributions.Distribution], *, support: numpyro.distributions.constraints.Constraint | None = None, validate_args: bool | None = None)
Bases: numpyro.distributions.Distribution
A finite mixture of component distributions from different families. This is a generalization of Mixture where the component distributions are scaled by a set of rates.
Example
>>> import jax
>>> import jax.random as jrd
>>> import numpyro.distributions as dist
>>> from gwkokab.models.utils import ScaledMixture
>>> log_scales = jrd.uniform(jrd.key(42), (3,), minval=0, maxval=5)
>>> component_dists = [
...     dist.Normal(loc=0.0, scale=1.0),
...     dist.Normal(loc=-0.5, scale=0.3),
...     dist.Normal(loc=0.6, scale=1.2),
... ]
>>> mixture = ScaledMixture(log_scales, component_dists)
>>> mixture.sample(jax.random.key(42)).shape
()
- cdf(samples)
The cumulative distribution function.
- Parameters:
samples – samples from this distribution.
- Returns:
output of the cumulative distribution function evaluated at samples.
- Raises:
NotImplementedError – If a component distribution does not implement the cdf method.
- log_prob(value, intermediates=None)
Evaluates the log probability density for a batch of samples given by value.
- Parameters:
value – A batch of samples from the distribution.
- Returns:
an array with shape value.shape[:-self.event_shape]
- Return type:
ArrayLike
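Assuming the rates enter multiplicatively, so that the scaled density is \(\sum_k e^{s_k} p_k(x)\) with \(s_k\) the entries of log_scales, the log density is a log-sum-exp over components. The plain-Python sketch below computes it stably by subtracting the largest term. This is one reading of "scaled by a set of rates", not a transcription of ScaledMixture.log_prob, and it ignores the intermediates argument; under this assumption the result is normalized only if the rates sum to one. Function names are hypothetical.

```python
import math
from typing import Callable, List, Sequence


def logsumexp(xs: Sequence[float]) -> float:
    """Numerically stable log(sum(exp(x) for x in xs))."""
    m = max(xs)
    if m == -math.inf:
        return -math.inf
    return m + math.log(sum(math.exp(x - m) for x in xs))


def normal_log_prob(loc: float, scale: float) -> Callable[[float], float]:
    """Log density of Normal(loc, scale)."""
    return lambda x: (-0.5 * ((x - loc) / scale) ** 2
                      - math.log(scale) - 0.5 * math.log(2.0 * math.pi))


def scaled_mixture_log_prob(
    x: float, log_scales: Sequence[float], components: List[Callable[[float], float]]
) -> float:
    """log sum_k exp(log_scales[k]) * p_k(x), assuming multiplicative rates."""
    return logsumexp([s + lp(x) for s, lp in zip(log_scales, components)])


log_scales = [math.log(2.0), math.log(3.0)]
components = [normal_log_prob(0.0, 1.0), normal_log_prob(0.0, 1.0)]
print(scaled_mixture_log_prob(0.4, log_scales, components))
```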
- sample(key, sample_shape=())
Returns a sample from the distribution having shape given by sample_shape + batch_shape + event_shape. Note that when sample_shape is non-empty, leading dimensions (of size sample_shape) of the returned sample will be filled with iid draws from the distribution instance.
- Parameters:
key (jax.random.key) – the rng_key key to be used for the distribution.
sample_shape (tuple) – the sample shape for the distribution.
- Returns:
an array of shape sample_shape + batch_shape + event_shape
- sample_with_intermediates(key, sample_shape=())
A version of sample that also returns the sampled component indices.
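One natural reading of sample_with_intermediates is two-stage sampling: pick a component index with probability proportional to its rate \(e^{s_k}\), then draw from that component and keep the index. The sketch below assumes exactly that and uses Python's random module instead of JAX keys; sample_with_indices is a hypothetical name, and the two well-separated Normal components are chosen only so the recorded indices are easy to check.

```python
import math
import random
from typing import Callable, List, Sequence, Tuple


def sample_with_indices(
    rng: random.Random,
    log_scales: Sequence[float],
    component_samplers: List[Callable[[random.Random], float]],
    n: int,
) -> Tuple[List[float], List[int]]:
    """Draw n samples and the index of the component each one came from."""
    m = max(log_scales)
    weights = [math.exp(s - m) for s in log_scales]  # proportional to exp(log_scales), overflow-safe
    indices = rng.choices(range(len(component_samplers)), weights=weights, k=n)
    samples = [component_samplers[k](rng) for k in indices]
    return samples, indices


rng = random.Random(0)
samples, indices = sample_with_indices(
    rng,
    [0.0, math.log(3.0)],  # rates 1 and 3: component 1 should be picked about 75% of the time
    [lambda r: r.gauss(-5.0, 0.1), lambda r: r.gauss(5.0, 0.1)],
    20_000,
)
print(indices.count(1) / len(indices))
```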
- support()
The support of this distribution. Subclasses can override this as a class attribute or as a property.
- property component_distributions
The list of component distributions in the mixture.
- Returns:
The list of component distributions
- Return type:
list[Distribution]
- property component_mean
- property component_variance
- property is_discrete
- log_scales
- property mean
Mean of the distribution.
- property mixture_dim
- property mixture_size
The number of components in the mixture.
- property variance
Variance of the distribution.