"""Base classes for Bayesian analysis.
"""
from __future__ import annotations
import warnings
from typing import Any, Sequence
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from scipy.stats import multivariate_normal, norm, rv_continuous, wasserstein_distance
from ..base import ColumnType, ModelBase, Numeric1DArray, ResultsBase, ColumnsType
from ..stats import joint_distribution
from ..utils import weighted_quantile
[docs]class BayesResults(ResultsBase):
"""Results of Bayesian analysis.
Inherits from :class:`multiple_inference.base.ResultsBase`.
Args:
*args (Any): Passed to :class:`multiple_inference.base.ResultsBase`.
n_samples (int): Number of samples used for approximations (ranking, likelihood
and Wasserstein distance). Defaults to 10000.
**kwargs (Any): Passed to :class:`multiple_inference.base.ResultsBase`.
Attributes:
distributions (List[scipy.stats.norm]): Marginal posterior distributions.
multivariate_distribution (scipy.stats.multivariate_normal): Joint posterior
distribution.
rank_df (pd.DataFrame): (n, n) dataframe of probabilities that column i has
rank j.
"""
_default_title = "Bayesian estimates"
def __init__(self, *args: Any, n_samples: int = 10000, **kwargs: Any):
super().__init__(*args, **kwargs)
# get the marginal (posterior) distributions, parameters, and pvalues
self.marginal_distributions, params, pvalues = [], [], []
for i in range(self.model.n_params):
dist = self.model.get_marginal_distribution(i)
self.marginal_distributions.append(dist)
params.append(dist.mean())
pvalues.append(dist.cdf(0))
self.params = np.array(params).squeeze()
self.pvalues = np.array(pvalues).squeeze()
self._n_samples = n_samples
self._sample_weight = np.full(n_samples, 1 / n_samples)
@property
def _posterior_rvs(self):
if hasattr(self, "_cached_posterior_rvs"):
return self._cached_posterior_rvs
# estimate the parameter rankings by drawing from the posterior
try:
self.joint_distribution = self.model.get_joint_distribution()
self._cached_posterior_rvs = self.joint_distribution.rvs(
size=self._n_samples
)
except NotImplementedError:
warnings.warn(
"Model does not provide a joint posterior distribution."
" I'll assume the marginal posterior distributions are independent."
" Rank estimates and likelihood and Wasserstein approximations may be"
" unreliable."
)
self._cached_posterior_rvs = joint_distribution(
self.marginal_distributions
).rvs(size=self._n_samples)
return self._cached_posterior_rvs
@property
def _reconstructed_rvs(self):
if hasattr(self, "_cached_reconstructed_rvs"):
return self._cached_reconstructed_rvs
self._cached_reconstructed_rvs = np.apply_along_axis(
lambda mean: multivariate_normal.rvs(mean, self.model.cov),
1,
self._posterior_rvs,
)
return self._cached_reconstructed_rvs
@property
def rank_df(self):
if hasattr(self, "_cached_rank_df"):
return self._cached_rank_df
argsort = np.argsort(-self._posterior_rvs, axis=1)
rank_matrix = np.array(
[
((argsort == k).T * self._sample_weight).sum(axis=1)
for k in range(self.model.n_params)
]
).T
self._cached_rank_df = pd.DataFrame(
rank_matrix,
columns=self.model.exog_names,
index=np.arange(1, self.model.n_params + 1),
)
self._cached_rank_df.index.name = "Rank"
return self._cached_rank_df
[docs] def expected_wasserstein_distance(
self, mean: Numeric1DArray = None, cov: np.ndarray = None, **kwargs: Any
) -> float:
"""Compute the Wasserstein distance metric.
This method estimates the Wasserstein distance between the observed
distribution (a joint normal characterized by ``mean`` and ``cov``) and the
distribution you would expect to observe according to this model.
Args:
mean (Numeric1DArray, optional): (# params,) array of sample conventionally
estimated means. If None, use the model's estimated means. Defaults to
None.
cov (np.ndarray, optional): (# params, # params) covaraince matrix for
conventionally estimated means. If None, use the model's estimated
covariance matrix. Defaults to None.
**kwargs (Any): Keyword arguments for ``scipy.stats.wasserstein_distance``.
Returns:
float: Expected Wasserstein distance.
"""
if mean is None and cov is None:
mean = self.params
reconstructed_rvs = self._reconstructed_rvs
else:
if cov is None:
cov = self.model.cov
reconstructed_rvs = np.apply_along_axis(
lambda mean: multivariate_normal.rvs(mean, cov), 1, self._posterior_rvs
)
distances = np.apply_along_axis(
lambda rv: wasserstein_distance(rv, mean, **kwargs), 1, reconstructed_rvs
)
return (self._sample_weight * distances).sum()
[docs] def likelihood(self, mean: Numeric1DArray = None, cov: np.ndarray = None) -> float:
"""
Args:
mean (Numeric1DArray, optional): (# params,) array of sample conventionally
estimated means. If None, use the model's estimated means. Defaults to
None.
cov (np.ndarray, optional): (# params, # params) covaraince matrix for
conventionally estimated means. If None, use the model's estimated
covariance matrix. Defaults to None.
Returns:
float: Likelihood.
"""
if mean is None:
mean = self.model.mean
if cov is None:
cov = self.model.cov
return (
self._sample_weight
* multivariate_normal.pdf(self._posterior_rvs, mean, cov)
).sum()
[docs] def line_plot(
self,
column: ColumnType = None,
alpha: float = 0.05,
title: str = None,
yname: str = None,
ax = None
):
"""Create a line plot of the prior, conventional, and posterior estimates.
Args:
column (ColumnType, optional): Selected parameter. Defaults to None.
alpha (float, optional): Sets the plot width. 0 is as wide as possible, 1 is
as narrow as possible. Defaults to .05.
title (str, optional): Plot title. Defaults to None.
yname (str, optional): Name of the dependent variable. Defaults to None.
ax (AxesSubplot, optional): Axis to write on.
Returns:
AxesSubplot: Plot.
"""
index = self.model.get_index(column)
prior = self.model.get_marginal_prior(index)
posterior = self.marginal_distributions[index]
conventional = norm(
self.model.mean[index], np.sqrt(self.model.cov[index, index])
)
xlim = np.array(
[
dist.ppf([alpha / 2, 1 - alpha / 2])
for dist in (prior, conventional, posterior)
]
).T
x = np.linspace(xlim[0].min(), xlim[1].max())
palette = sns.color_palette()
if ax is None:
_, ax = plt.subplots()
sns.lineplot(x=x, y=prior.pdf(x), label="prior", ax=ax)
ax.axvline(prior.mean(), linestyle="--", color=palette[0])
sns.lineplot(x=x, y=conventional.pdf(x), label="conventional")
ax.axvline(conventional.mean(), linestyle="--", color=palette[1])
sns.lineplot(x=x, y=posterior.pdf(x), label="posterior")
ax.axvline(posterior.mean(), linestyle="--", color=palette[2])
ax.set_title(title or self.model.exog_names[index])
ax.set_xlabel(yname or self.model.endog_names)
return ax
[docs] def rank_matrix_plot(self, title: str = None, **kwargs: Any):
"""Plot a heatmap of the rank matrix.
Args:
title (str, optional): Plot title. Defaults to None.
**kwargs (Any): Passed to ``sns.heatmap``.
Returns:
AxesSubplot: Heatmap.
"""
ax = sns.heatmap(self.rank_df, center=1 / self.model.n_params, **kwargs)
ax.set_title(title or f"{self.title} rank matrix")
return ax
[docs] def reconstruction_point_plot(
self,
yname: str = None,
xname: Sequence[str] = None,
title: str = None,
alpha: float = 0.05,
ax=None,
):
"""Create point plot of the reconstructed sample means.
Plots the distribution of sample means you would expect to see if this model
were correct.
Args:
yname (str, optional): Name of the endogenous variable. Defaults to None.
xname (Sequence[str], optional): Names of x-ticks. Defaults to None.
title (str, optional): Plot title. Defaults to None.
alpha: (float, optional): Plot the 1-alpha CI. Defaults to 0.05.
ax: (AxesSubplot, optional): Axis to write on.
Returns:
plt.axes._subplots.AxesSubplot: Plot.
"""
reconstructed_means = -np.sort(-self._reconstructed_rvs)
params = np.average(reconstructed_means, axis=0, weights=self._sample_weight)
conf_int = np.apply_along_axis(
weighted_quantile,
0,
reconstructed_means,
quantiles=[alpha / 2, 1 - alpha / 2],
sample_weight=self._sample_weight,
).T
xname = xname or np.arange(self.model.n_params)
yticks = np.arange(len(xname), 0, -1)
if ax is None:
_, ax = plt.subplots()
ax.errorbar(
x=params,
y=yticks,
xerr=[params - conf_int[:, 0], conf_int[:, 1] - params],
fmt="o",
)
ax.set_title(title or f"{self.title} reconstruction plot")
ax.set_xlabel(yname or self.model.endog_names)
ax.set_ylabel("rank")
ax.set_yticks(yticks)
ax.set_yticklabels(xname)
ax.errorbar(x=-np.sort(-self.model.mean), y=yticks, fmt="x")
return ax
def _make_summary_header(self, alpha: float) -> list[str]:
return ["coef", "pvalue (1-sided)", f"[{alpha/2}", f"{1-alpha/2}]"]
[docs]class BayesBase(ModelBase):
"""Mixin for Bayesian models.
Subclasses :class:`multiple_inference.base.ModelBase`.
"""
_results_cls = BayesResults
[docs] def get_marginal_prior(self, column: ColumnType) -> rv_continuous:
"""Get the marginal prior distribution of ``column``.
Args:
column (ColumnType): Name or index of the parameter of interest.
Returns:
rv_continuous: Prior distribution
"""
return self._get_marginal_prior(self.get_index(column))
def _get_marginal_prior(self, index: int) -> rv_continuous:
"""Private version of :meth:`self.get_marginal_prior`."""
raise NotImplementedError()
[docs] def get_marginal_distribution(self, column: ColumnType) -> rv_continuous:
"""Get the marginal posterior distribution of ``column``.
Args:
column (ColumnType): Name or index of the parameter of interest.
Returns:
rv_continuous: Posterior distribution.
"""
return self._get_marginal_distribution(self.get_index(column))
def _get_marginal_distribution(self, index: int) -> rv_continuous:
"""Private version of :meth:`self.get_marginal_distribution`."""
raise NotImplementedError()
[docs] def get_joint_prior(self, columns: ColumnsType = None):
"""Get the joint prior distribution.
Args:
columns (ColumnsType, optional): Selected columns. Defaults to None.
Returns:
rv_like: Joint distribution.
"""
return self._get_joint_prior(self.get_indices(columns))
def _get_joint_prior(self, indices: np.ndarray):
"""Private version of :meth:`self.get_joint_prior`."""
return joint_distribution([self.get_marginal_prior(i) for i in indices])
[docs] def get_joint_distribution(self, columns: ColumnsType = None):
"""Get the joint posterior distribution.
Args:
columns (ColumnsType, optional): Selected columns. Defaults to None.
Returns:
rv_like: Joint distribution.
"""
return self._get_joint_distribution(self.get_indices(columns))
def _get_joint_distribution(self, indices: np.ndarray):
"""Private version of :meth:`self.get_joint_distribution`."""
raise NotImplementedError()