Source code for arviz_stats.loo.loo
"""Pareto-smoothed importance sampling LOO (PSIS-LOO-CV)."""
from arviz_base import rcParams
from arviz_stats.loo.helper_loo import (
_compute_loo_results,
_get_r_eff,
_prepare_loo_inputs,
)
[docs]
def loo(data, pointwise=None, var_name=None, reff=None):
    """Compute PSIS-LOO-CV for a fitted model.

    Pareto-smoothed importance sampling leave-one-out cross-validation
    (PSIS-LOO-CV) is used to estimate the expected log pointwise predictive
    density (elpd). The standard error of the estimate and the effective
    number of parameters are calculated as well. See [1]_ and [2]_ for a
    description of the method.

    Parameters
    ----------
    data : DataTree or InferenceData
        Input data. It should contain the posterior and the log_likelihood groups.
    pointwise : bool, optional
        If True the pointwise predictive accuracy will be returned. Defaults to
        ``rcParams["stats.ic_pointwise"]``.
    var_name : str, optional
        The name of the variable in log_likelihood groups storing the pointwise log
        likelihood data to use for loo computation.
    reff : float, optional
        Relative MCMC efficiency, ``ess / n`` i.e. number of effective samples
        divided by the number of actual samples. Computed from trace by default.

    Returns
    -------
    ELPDData
        Object with the following attributes:

        - **elpd**: expected log pointwise predictive density
        - **se**: standard error of the elpd
        - **p**: effective number of parameters
        - **n_samples**: number of samples
        - **n_data_points**: number of data points
        - **warning**: True if the estimated shape parameter of Pareto distribution
          is greater than ``good_k``.
        - **elpd_i**: :class:`~xarray.DataArray` with the pointwise predictive
          accuracy, only if ``pointwise=True``
        - **pareto_k**: array of Pareto shape values, only if ``pointwise=True``
        - **good_k**: For a sample size S, the threshold is computed as
          ``min(1 - 1/log10(S), 0.7)``
        - **approx_posterior**: True if approximate posterior was used.

    Examples
    --------
    Calculate LOO of a model:

    .. ipython::

        In [1]: from arviz_stats import loo
           ...: from arviz_base import load_arviz_data
           ...: data = load_arviz_data("centered_eight")
           ...: loo_data = loo(data)
           ...: loo_data

    Return the pointwise values:

    .. ipython::

        In [2]: loo_data.elpd_i

    See Also
    --------
    :func:`compare` : Compare models based on their ELPD.
    :func:`arviz_plots.plot_compare`: Summary plot for model comparison.

    References
    ----------
    .. [1] Vehtari et al. *Practical Bayesian model evaluation using leave-one-out
        cross-validation and WAIC*. Statistics and Computing. 27(5) (2017)
        https://doi.org/10.1007/s11222-016-9696-4
        arXiv preprint https://arxiv.org/abs/1507.04544.

    .. [2] Vehtari et al. *Pareto Smoothed Importance Sampling*.
        Journal of Machine Learning Research, 25(72) (2024)
        https://jmlr.org/papers/v25/19-556.html
        arXiv preprint https://arxiv.org/abs/1507.02646
    """
    inputs = _prepare_loo_inputs(data, var_name)

    # Fall back to the global configuration only when the caller gave no preference.
    if pointwise is None:
        pointwise = rcParams["stats.ic_pointwise"]

    # A user-supplied relative efficiency takes precedence over the one derived from data.
    reff = _get_r_eff(data, inputs.n_samples) if reff is None else reff

    # Pareto smoothing of the importance weights over the sampling dimensions.
    psis_log_weights, khat = inputs.log_likelihood.azstats.psislw(
        r_eff=reff,
        dim=inputs.sample_dims,
    )

    # This entry point always reports approx_posterior=False.
    return _compute_loo_results(
        log_likelihood=inputs.log_likelihood,
        var_name=inputs.var_name,
        pointwise=pointwise,
        sample_dims=inputs.sample_dims,
        n_samples=inputs.n_samples,
        n_data_points=inputs.n_data_points,
        log_weights=psis_log_weights,
        pareto_k=khat,
        approx_posterior=False,
    )