gwkokab.models.utils
====================

.. py:module:: gwkokab.models.utils


Classes
-------

.. autoapisummary::

   gwkokab.models.utils.DoublyTruncatedPowerLaw
   gwkokab.models.utils.ExtendedSupportTransformedDistribution
   gwkokab.models.utils.JointDistribution
   gwkokab.models.utils.LazyJointDistribution
   gwkokab.models.utils.ScaledMixture


Functions
---------

.. autoapisummary::

   gwkokab.models.utils.doubly_truncated_power_law_cdf
   gwkokab.models.utils.doubly_truncated_power_law_icdf
   gwkokab.models.utils.doubly_truncated_power_law_log_norm_constant
   gwkokab.models.utils.doubly_truncated_power_law_log_prob


Package Contents
----------------

.. py:class:: DoublyTruncatedPowerLaw(alpha: jaxtyping.ArrayLike, low: jaxtyping.ArrayLike, high: jaxtyping.ArrayLike, *, validate_args: Optional[bool] = None)

   Bases: :py:obj:`numpyro.distributions.Distribution`

   Power-law distribution with index :math:`\alpha`, truncated to the interval
   :math:`[a, b]`. Its probability density is

   .. math::

      f(x; \alpha, a, b) = \frac{x^{\alpha}}{Z(\alpha, a, b)}, \qquad a \leq x \leq b,

   where :math:`a` and :math:`b` are the lower and upper bounds, respectively, and
   :math:`Z(\alpha, a, b)` is the normalization constant

   .. math::

      Z(\alpha, a, b) =
      \begin{cases}
         \log(b) - \log(a) & \text{if } \alpha = -1, \\
         \frac{b^{1 + \alpha} - a^{1 + \alpha}}{1 + \alpha} & \text{otherwise}.
      \end{cases}

   :param alpha: index of the power law
   :param low: lower bound :math:`a` of the distribution
   :param high: upper bound :math:`b` of the distribution

   .. py:method:: cdf(value: jaxtyping.ArrayLike) -> jaxtyping.ArrayLike

      Cumulative distribution function.

      For :math:`\alpha \neq -1`:

      .. math::

         F(x) = \frac{x^{1 + \alpha} - a^{1 + \alpha}}{b^{1 + \alpha} - a^{1 + \alpha}}

      For :math:`\alpha = -1`:

      .. math::

         F(x) = \frac{\log(x) - \log(a)}{\log(b) - \log(a)}

      These expressions were derived with Wolfram Alpha.

   .. py:method:: icdf(q: jaxtyping.ArrayLike) -> jaxtyping.ArrayLike

      Inverse cumulative distribution function (quantile function).

      For :math:`\alpha \neq -1`:

      .. math::

         F^{-1}(q) = \left(a^{1 + \alpha} + q \left(b^{1 + \alpha} - a^{1 + \alpha}\right)\right)^{\frac{1}{1 + \alpha}}

      For :math:`\alpha = -1`:

      .. math::

         F^{-1}(q) = a \left(\frac{b}{a}\right)^{q}

      These expressions were derived with Wolfram Alpha.

   .. py:method:: log_prob(value: jaxtyping.ArrayLike) -> jaxtyping.ArrayLike

      Logarithm of the probability density,
      :math:`\log f(x) = \alpha \log(x) - \log Z(\alpha, a, b)`, written out per case.

      For :math:`\alpha \neq -1`:

      .. math::

         \log f(x) = \alpha \log(x) + \log\left(\frac{1 + \alpha}{b^{1 + \alpha} - a^{1 + \alpha}}\right)

      For :math:`\alpha = -1`:

      .. math::

         \log f(x) = -\log(x) - \log\left(\log(b) - \log(a)\right)

      These expressions were derived with Wolfram Alpha.

   .. py:method:: sample(key: jax.dtypes.prng_key, sample_shape: tuple[int, Ellipsis] = ()) -> jaxtyping.ArrayLike

      Returns a sample from the distribution with shape
      ``sample_shape + batch_shape + event_shape``. When ``sample_shape`` is
      non-empty, the leading dimensions (of size ``sample_shape``) of the returned
      sample are filled with iid draws from the distribution instance.

      :param jax.random.key key: the PRNG key used for sampling.
      :param tuple sample_shape: the sample shape for the distribution.
      :return: an array of shape ``sample_shape + batch_shape + event_shape``
      :rtype: numpy.ndarray

   .. py:method:: support() -> numpyro.distributions.constraints.Constraint

      The support of this distribution. Subclasses can override this as a class
      attribute or as a property.


.. py:class:: ExtendedSupportTransformedDistribution(base_distribution: Distribution, transforms: Union[numpyro.distributions.transforms.Transform, list[numpyro.distributions.transforms.Transform]], *, validate_args: Optional[bool] = None)

   Bases: :py:obj:`numpyro.distributions.TransformedDistribution`

   Returns a distribution instance obtained by applying a sequence of transforms
   to a base distribution.
   For an example, see :class:`~numpyro.distributions.LogNormal` and
   :class:`~numpyro.distributions.HalfNormal`.

   :param base_distribution: the base distribution to which the transforms are applied.
   :param transforms: a single transform or a list of transforms.
   :param validate_args: whether to enable validation of distribution parameters
       and of arguments to the ``.log_prob`` method.

   .. py:property:: support

      The support of this distribution. Subclasses can override this as a class
      attribute or as a property.


.. py:class:: JointDistribution(*marginal_distributions: numpyro.distributions.Distribution, flatten_method: Optional[Literal['deep', 'shallow']] = None, support: Optional[numpyro.distributions.constraints.Constraint] = None, validate_args: Optional[bool] = None)

   Bases: :py:obj:`numpyro.distributions.Distribution`

   Construct a joint distribution from one or more marginal distributions. You may
   pass individual ``Distribution`` instances or nest them inside other
   :class:`JointDistribution` instances.
   The ``flatten_method`` argument controls whether nested joint distributions are
   flattened into a single flat list of marginals.

   :param marginal_distributions: One or more ``Distribution`` objects (or nested
       :class:`JointDistribution` instances) that form the components of the joint
       distribution.
   :type marginal_distributions: *Distribution
   :param flatten_method: If ``"shallow"``, one level of nested
       :class:`JointDistribution` instances is flattened. If ``"deep"``, all levels
       of nested :class:`JointDistribution` instances are flattened recursively.
       If ``None`` (default), the nesting is preserved as-is.
   :type flatten_method: Optional[Literal["deep", "shallow"]], optional
   :param support: The constraint object representing the support of the joint
       distribution. If not provided, it is computed from the supports of the
       marginals.
   :type support: Optional[constraints.Constraint], optional
   :param validate_args: Whether to validate distribution parameters and inputs.
       Default is ``None``.
   :type validate_args: Optional[bool], optional
   :raises ValueError: If no marginal distributions are provided.

   .. rubric:: Example

   .. code::

      >>> from numpyro.distributions import Normal
      >>> from gwkokab.models.utils import JointDistribution
      >>> A = Normal(0, 1)
      >>> B = Normal(1, 1)
      >>> C = Normal(2, 1)
      >>> D = Normal(3, 1)
      >>> E = Normal(4, 1)
      >>> jd = JointDistribution(
      ...     A, JointDistribution(B, JointDistribution(C, D)), E
      ... )
      >>> len(jd.marginal_distributions)  # No flattening (default)
      3
      >>> jd = JointDistribution(
      ...     A,
      ...     JointDistribution(B, JointDistribution(C, D)),
      ...     E,
      ...     flatten_method="shallow",
      ... )
      >>> len(jd.marginal_distributions)  # Shallow flattening
      4
      >>> jd = JointDistribution(
      ...     A,
      ...     JointDistribution(B, JointDistribution(C, D)),
      ...     E,
      ...     flatten_method="deep",
      ... )
      >>> len(jd.marginal_distributions)  # Deep flattening
      5

   .. py:method:: log_prob(value: jaxtyping.Array) -> jaxtyping.Array

      Evaluates the log probability density for a batch of samples given by
      ``value``.

      :param value: A batch of samples from the distribution.
      :return: an array whose shape is ``value.shape`` with the trailing
          ``len(self.event_shape)`` event dimensions removed
      :rtype: ArrayLike

   .. py:method:: marginal_log_probs(value: jaxtyping.Array) -> jaxtyping.Array

   .. py:method:: sample(key: jaxtyping.PRNGKeyArray, sample_shape: tuple[int, Ellipsis] = ())

      Returns a sample from the distribution with shape
      ``sample_shape + batch_shape + event_shape``. When ``sample_shape`` is
      non-empty, the leading dimensions (of size ``sample_shape``) of the returned
      sample are filled with iid draws from the distribution instance.

      :param jax.random.key key: the PRNG key used for sampling.
      :param tuple sample_shape: the sample shape for the distribution.
      :return: an array of shape ``sample_shape + batch_shape + event_shape``
      :rtype: numpy.ndarray

   .. py:method:: support() -> numpyro.distributions.constraints.Constraint

      The support of the joint distribution.

   .. py:attribute:: marginal_distributions
      :type: collections.abc.Sequence[numpyro.distributions.Distribution]
      :value: ()

   .. py:attribute:: shaped_values
      :type: collections.abc.Sequence[int | Tuple[int, int]]
      :value: ()


.. py:class:: LazyJointDistribution(*marginal_distributions: Union[numpyro.distributions.Distribution, jax.tree_util.Partial], dependencies: Dict[int, Dict[str, int]], partial_order: List[int], dependencies_event_shape: Optional[List[Tuple[int, Ellipsis]]] = None, flatten_method: Optional[Literal['deep', 'shallow']] = None, support: Optional[numpyro.distributions.constraints.Constraint] = None, validate_args: Optional[bool] = None)

   Bases: :py:obj:`numpyro.distributions.Distribution`

   Construct a joint distribution from one or more marginal distributions. You may
   pass individual ``Distribution`` instances or nest them inside other
   :class:`LazyJointDistribution` instances. The ``flatten_method`` argument
   controls whether nested joint distributions are flattened into a single flat
   list of marginals.

   :param marginal_distributions: One or more marginal distributions. Each marginal
       can be an instance of ``numpyro.distributions.Distribution`` or a
       ``jax.tree_util.Partial`` that returns a ``numpyro.distributions.Distribution``
       when called with its arguments.
   :type marginal_distributions: *Union[Distribution, jax.tree_util.Partial]
   :param dependencies: A dictionary that maps the index of each marginal
       distribution given as a ``jax.tree_util.Partial`` to another dictionary,
       which maps the names of its dependency parameters to the indices of the
       marginal distributions they depend on. This specifies which variables each
       lazy variable depends on.
   :type dependencies: Dict[int, Dict[str, int]]
   :param partial_order: A list of integers defining a partial order for the lazy
       variables, such that elements appearing earlier in the list do not depend
       on elements appearing later. This ensures that, when sampling from the joint
       distribution, the lazy variables are sampled in an order that respects
       their dependencies.
   :type partial_order: List[int]
   :param dependencies_event_shape: A list of event shapes for the dependencies of
       each marginal distribution. This is used to validate the shapes of the
       dependency variables when constructing
   :type dependencies_event_shape: Optional[List[Tuple[int, ...]]], optional
   :param flatten_method: Currently not used. If ``"shallow"``, one level of nested
       :class:`LazyJointDistribution` instances is flattened. If ``"deep"``, all
       levels of nested :class:`LazyJointDistribution` instances are flattened
       recursively. If ``None`` (default), the nesting is preserved as-is.
   :type flatten_method: Optional[Literal["deep", "shallow"]], optional
   :param support: The constraint object representing the support of the joint
       distribution. If not provided, it is computed from the supports of the
       marginals.
   :type support: Optional[constraints.Constraint], optional
   :param validate_args: Whether to validate distribution parameters and inputs.
       Default is ``None``.
   :type validate_args: Optional[bool], optional
   :raises ValueError: If no marginal distributions are provided.

   .. py:method:: log_prob(value: jaxtyping.Array) -> jaxtyping.Array

      Evaluates the log probability density for a batch of samples given by
      ``value``.

      :param value: A batch of samples from the distribution.
      :return: an array whose shape is ``value.shape`` with the trailing
          ``len(self.event_shape)`` event dimensions removed
      :rtype: ArrayLike

   .. py:method:: sample(key: jaxtyping.PRNGKeyArray, sample_shape: tuple[int, Ellipsis] = ())

      Returns a sample from the distribution with shape
      ``sample_shape + batch_shape + event_shape``. When ``sample_shape`` is
      non-empty, the leading dimensions (of size ``sample_shape``) of the returned
      sample are filled with iid draws from the distribution instance.

      :param jax.random.key key: the PRNG key used for sampling.
      :param tuple sample_shape: the sample shape for the distribution.
      :return: an array of shape ``sample_shape + batch_shape + event_shape``
      :rtype: numpy.ndarray

   .. py:method:: support() -> numpyro.distributions.constraints.Constraint

      The support of the joint distribution.

   .. py:attribute:: dependencies

   .. py:attribute:: marginal_distributions
      :type: collections.abc.Sequence[Union[numpyro.distributions.Distribution, jax.tree_util.Partial]]
      :value: ()

   .. py:attribute:: partial_order

   .. py:attribute:: shaped_values
      :type: collections.abc.Sequence[Union[int, Tuple[int, int]]]
      :value: ()


.. py:class:: ScaledMixture(log_scales: jaxtyping.Array, component_distributions: List[numpyro.distributions.Distribution], *, support: Optional[numpyro.distributions.constraints.Constraint] = None, validate_args: Optional[bool] = None)

   Bases: :py:obj:`numpyro.distributions.Distribution`

   A finite mixture of component distributions from different families. This is a
   generalization of :class:`~numpyro.distributions.Mixture` in which the
   component distributions are scaled by a set of rates, supplied on the log scale
   as ``log_scales``.

   **Example**

   .. code::

      >>> import jax
      >>> import jax.random as jrd
      >>> import numpyro.distributions as dist
      >>> from gwkokab.models.utils import ScaledMixture
      >>> log_scales = jrd.uniform(jrd.key(42), (3,), minval=0, maxval=5)
      >>> component_dists = [
      ...     dist.Normal(loc=0.0, scale=1.0),
      ...     dist.Normal(loc=-0.5, scale=0.3),
      ...     dist.Normal(loc=0.6, scale=1.2),
      ... ]
      >>> mixture = ScaledMixture(log_scales, component_dists)
      >>> mixture.sample(jax.random.key(42)).shape
      ()

   .. py:method:: cdf(samples)

      The cumulative distribution function.

      :param samples: samples from this distribution.
      :return: output of the cumulative distribution function evaluated at
          ``samples``.
      :raises NotImplementedError: if a component distribution does not implement
          the ``cdf`` method.

   .. py:method:: component_cdf(samples)

   .. py:method:: component_log_probs(value: jaxtyping.ArrayLike) -> jaxtyping.ArrayLike

   .. py:method:: component_sample(key, sample_shape=())

   .. py:method:: log_prob(value, intermediates=None)

      Evaluates the log probability density for a batch of samples given by
      ``value``.

      :param value: A batch of samples from the distribution.
      :return: an array whose shape is ``value.shape`` with the trailing
          ``len(self.event_shape)`` event dimensions removed
      :rtype: ArrayLike

   .. py:method:: sample(key, sample_shape=())

      Returns a sample from the distribution with shape
      ``sample_shape + batch_shape + event_shape``. When ``sample_shape`` is
      non-empty, the leading dimensions (of size ``sample_shape``) of the returned
      sample are filled with iid draws from the distribution instance.

      :param jax.random.key key: the PRNG key used for sampling.
      :param tuple sample_shape: the sample shape for the distribution.
      :return: an array of shape ``sample_shape + batch_shape + event_shape``
      :rtype: numpy.ndarray

   .. py:method:: sample_with_intermediates(key, sample_shape=())

      A version of ``sample`` that also returns the sampled component indices.

      :param jax.random.PRNGKey key: the PRNG key used for sampling.
      :param tuple sample_shape: the sample shape for the distribution.
      :return: A 2-element tuple with the samples from the distribution and the
          indices of the sampled components.
      :rtype: tuple

   .. py:method:: support()

      The support of this distribution. Subclasses can override this as a class
      attribute or as a property.

   .. py:property:: component_distributions

      The list of component distributions in the mixture.

      :return: The list of component distributions
      :rtype: list[Distribution]

   .. py:property:: component_mean

   .. py:property:: component_variance

   .. py:property:: is_discrete

   .. py:attribute:: log_scales

   .. py:property:: mean

      Mean of the distribution.

   .. py:property:: mixture_dim

   .. py:property:: mixture_size

      The number of components in the mixture.

   .. py:property:: variance

      Variance of the distribution.


.. py:function:: doubly_truncated_power_law_cdf(x, alpha, low, high)

.. py:function:: doubly_truncated_power_law_icdf(q, alpha, low, high)

.. py:function:: doubly_truncated_power_law_log_norm_constant(alpha, low, high)

.. py:function:: doubly_truncated_power_law_log_prob(x, alpha, low, high)
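
The module-level helpers ``doubly_truncated_power_law_cdf``,
``doubly_truncated_power_law_icdf``, ``doubly_truncated_power_law_log_norm_constant``
and ``doubly_truncated_power_law_log_prob`` are listed without docstrings. As a
rough guide, the following is a minimal standalone ``jax.numpy`` sketch of the
truncated power-law formulas given for :class:`DoublyTruncatedPowerLaw`, plus
inverse-transform sampling through the quantile function. It is not the gwkokab
implementation: the names ``log_norm_constant``, ``log_prob``, ``cdf``, ``icdf``
and ``sample`` are hypothetical, and the ``jnp.isclose`` test for
:math:`\alpha = -1`, the clipping in ``cdf`` and the ``-inf`` returned outside
:math:`[a, b]` are assumptions made for illustration.

.. code:: python

   import jax
   import jax.numpy as jnp


   def _is_log_case(alpha):
       # alpha == -1 needs logarithms instead of powers (assumed tolerance check).
       return jnp.isclose(alpha, -1.0)


   def _safe_beta(alpha):
       # beta = 1 + alpha, replaced by 1 when alpha == -1 so the unused
       # power-law branch of jnp.where never computes 0 / 0.
       return jnp.where(_is_log_case(alpha), 1.0, 1.0 + alpha)


   def log_norm_constant(alpha, low, high):
       """Hypothetical helper: log Z(alpha, a, b) with a = low, b = high."""
       beta = _safe_beta(alpha)
       z_pow = (high**beta - low**beta) / beta
       z_log = jnp.log(high) - jnp.log(low)
       return jnp.log(jnp.where(_is_log_case(alpha), z_log, z_pow))


   def log_prob(x, alpha, low, high):
       """alpha * log(x) - log Z on [low, high]; -inf outside (assumption)."""
       x = jnp.asarray(x)
       lp = alpha * jnp.log(x) - log_norm_constant(alpha, low, high)
       inside = (x >= low) & (x <= high)
       return jnp.where(inside, lp, -jnp.inf)


   def cdf(x, alpha, low, high):
       """F(x); x is clipped to [low, high], so F is 0 below and 1 above."""
       beta = _safe_beta(alpha)
       x = jnp.clip(x, low, high)
       f_pow = (x**beta - low**beta) / (high**beta - low**beta)
       f_log = (jnp.log(x) - jnp.log(low)) / (jnp.log(high) - jnp.log(low))
       return jnp.where(_is_log_case(alpha), f_log, f_pow)


   def icdf(q, alpha, low, high):
       """Quantile function F^{-1}(q) for q in [0, 1]."""
       beta = _safe_beta(alpha)
       x_pow = (low**beta + q * (high**beta - low**beta)) ** (1.0 / beta)
       x_log = low * (high / low) ** q
       return jnp.where(_is_log_case(alpha), x_log, x_pow)


   def sample(key, alpha, low, high, shape=()):
       """Inverse-transform sampling: push Uniform(0, 1) draws through icdf."""
       u = jax.random.uniform(key, shape)
       return icdf(u, alpha, low, high)


   # Example: 10,000 draws from a power law with alpha = -2.35 on [5, 100].
   draws = sample(jax.random.key(0), -2.35, 5.0, 100.0, (10_000,))

Because ``jnp.where`` evaluates both of its branches, the sketch replaces
:math:`1 + \alpha` by 1 in the power-law branch when :math:`\alpha = -1`;
otherwise that branch would compute :math:`0/0` and could produce NaN gradients
even though its value is discarded.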