
Analysing variability and compensation mechanisms with conditional distributions

A central advantage of sbi over parameter search methods such as genetic algorithms is that the posterior captures all models that can reproduce experimental data. This allows us to analyse whether parameters can be variable or have to be narrowly tuned, and to analyse compensation mechanisms between different parameters. See also Marder and Taylor, 2006 for further motivation to identify all models that capture experimental data.

In this tutorial, we show how to use the posterior distribution to identify whether parameters can be variable or have to be finely tuned, and how to find potential compensation mechanisms between model parameters. To investigate this, we extract conditional distributions from the posterior inferred with sbi.
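For reference, the standard definition of such a conditional: for a posterior over \(\theta = (\theta_1, \dots, \theta_n)\), the conditional of a pair \((\theta_i, \theta_j)\) is obtained by clamping all other parameters to fixed values \(\theta^*_{\setminus ij}\) (for example, those of a single posterior sample) and renormalizing over the remaining two:

\[
p(\theta_i, \theta_j \mid \theta_{\setminus ij} = \theta^*_{\setminus ij}, x_o) = \frac{p(\theta_i, \theta_j, \theta^*_{\setminus ij} \mid x_o)}{\int\!\!\int p(\theta'_i, \theta'_j, \theta^*_{\setminus ij} \mid x_o)\, d\theta'_i\, d\theta'_j}
\]

A narrow conditional along \(\theta_i\) indicates that \(\theta_i\) must be finely tuned once the other parameters are fixed, while a strong correlation within the conditional points to compensation between \(\theta_i\) and \(\theta_j\).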

Note: the original version of this notebook is available at https://github.com/mackelab/sbi/blob/main/tutorials/07_conditional_distributions.ipynb in the sbi repository.

Main syntax

import torch
import matplotlib.pyplot as plt
from sbi.analysis import conditional_pairplot, conditional_corrcoeff

# `posterior` is an sbi posterior with a default observation x_o already set
# (here over two parameters, matching the two rows of `limits`).

# Plot slices through the posterior, i.e. conditionals, at one posterior sample.
_ = conditional_pairplot(
    density=posterior,
    condition=posterior.sample((1,)),
    limits=torch.tensor([[-2., 2.], [-2., 2.]]),
)

# Compute the matrix of correlation coefficients of the slices.
cond_coeff_mat = conditional_corrcoeff(
    density=posterior,
    condition=posterior.sample((1,)),
    limits=torch.tensor([[-2., 2.], [-2., 2.]]),
)
plt.imshow(cond_coeff_mat, clim=[-1, 1])
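The idea behind a conditional correlation coefficient can be sketched without sbi, under the assumption that it is computed from a density slice evaluated on a grid: clamp all but two parameters to the values of a conditioning point, evaluate the unnormalized density on a grid over the remaining two within `limits`, normalize, and take the Pearson correlation of the resulting 2D distribution. This is a NumPy illustration, not sbi's implementation; `gaussian_log_prob` is a hypothetical stand-in for `posterior.log_prob`, and `conditional_corr_on_grid` is a hypothetical helper.

```python
import numpy as np


def gaussian_log_prob(theta, cov):
    """Hypothetical stand-in for `posterior.log_prob`: unnormalized zero-mean Gaussian."""
    precision = np.linalg.inv(cov)
    return -0.5 * np.einsum("...i,ij,...j->...", theta, precision, theta)


def conditional_corr_on_grid(log_prob, condition, dims, limits, resolution=200):
    """Hypothetical helper: Pearson correlation of the 2D conditional over `dims`,
    with all other parameters clamped to their values in `condition`."""
    i, j = dims
    xs = np.linspace(limits[i][0], limits[i][1], resolution)
    ys = np.linspace(limits[j][0], limits[j][1], resolution)
    grid_x, grid_y = np.meshgrid(xs, ys, indexing="ij")

    # Copy the conditioning point to every grid node, then overwrite dims i and j.
    theta = np.broadcast_to(condition, grid_x.shape + (len(condition),)).copy()
    theta[..., i] = grid_x
    theta[..., j] = grid_y

    # Normalize the density over the grid (the uniform grid spacing cancels out).
    log_p = log_prob(theta)
    weights = np.exp(log_p - log_p.max())
    weights /= weights.sum()

    # Weighted means, variances and covariance of the 2D conditional.
    mean_x = (weights * grid_x).sum()
    mean_y = (weights * grid_y).sum()
    cov_xy = (weights * (grid_x - mean_x) * (grid_y - mean_y)).sum()
    var_x = (weights * (grid_x - mean_x) ** 2).sum()
    var_y = (weights * (grid_y - mean_y) ** 2).sum()
    return cov_xy / np.sqrt(var_x * var_y)


# Toy density: theta1, theta2 independent N(0, 1); theta3 = theta1 + theta2 + noise of variance 0.1.
cov = np.array([[1.0, 0.0, 1.0],
                [0.0, 1.0, 1.0],
                [1.0, 1.0, 2.1]])
limits = [(-4.0, 4.0)] * 3
condition = np.zeros(3)  # theta3 is clamped to 0; the entries for theta1, theta2 are overwritten

marginal_corr = cov[0, 1] / np.sqrt(cov[0, 0] * cov[1, 1])
conditional_corr = conditional_corr_on_grid(
    lambda t: gaussian_log_prob(t, cov), condition, dims=(0, 1), limits=limits
)
print(f"marginal corr(theta1, theta2):     {marginal_corr:.2f}")
print(f"conditional corr given theta3 = 0: {conditional_corr:.2f}")
```

In this toy density, \(\theta_1\) and \(\theta_2\) are marginally uncorrelated, but once \(\theta_3\) is fixed they have to compensate for each other. For a Gaussian the exact conditional correlation is \(-1/1.1 \approx -0.91\) for any value of \(\theta_3\), and the grid estimate should land close to it.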

Analysing variability and compensation mechanisms in a toy example

Below, we use a simple toy example to demonstrate the features described above. For an application of these features to a neuroscience problem, see Figure 6 in Gonçalves, Lueckmann, Deistler et al., 2019.

from sbi import utils
from sbi.analysis import pairplot, conditional_pairplot, conditional_corrcoeff
import torch
import numpy as np

import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D  # registers the '3d' projection on older matplotlib versions
from matplotlib import animation, rc
from IPython.display import HTML, Image

_ = torch.manual_seed(0)

Let’s say we have used SNPE to obtain a posterior distribution over three parameters. In this tutorial, we just load the posterior from a file:

from toy_posterior_for_07_cc import ExamplePosterior
posterior = ExamplePosterior()

First, we specify the experimental observation \(x_o\) at which we want to evaluate and sample the posterior \(p(\theta|x_o)\):

x_o = torch.ones(1, 20)  # simulator output was 20-dimensional
posterior.set_default_x(x_o)

As always, we can inspect the posterior marginals with the pairplot() function:

posterior_samples = posterior.sample((5000,))

fig, ax = pairplot(
    samples=posterior_samples,
    limits=torch.tensor([[-2., 2.]]*3),
    upper=['kde'],
    diag=['kde'],
    figsize=(5,5)
)


The 1D and 2D marginals of the posterior fill almost the entire parameter space! Also, the Pearson correlation coefficient matrix of the marginal distribution shows only weak interactions (low correlations):

# Convert the torch samples to NumPy; rows of the transpose are parameters.
corr_matrix_marginal = np.corrcoef(posterior_samples.numpy().T)
fig, ax = plt.subplots(1, 1, figsize=(4, 4))
im = ax.imshow(corr_matrix_marginal, clim=[-1, 1], cmap='PiYG')
_ = fig.colorbar(im)


It might be tempting to conclude that the experimental data barely constrains our parameters and that almost all parameter combinations can reproduce the experimental data. As we will show below, this is not the case.
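Why can weak marginal correlations coexist with tight constraints? Here is a minimal, hypothetical 2D illustration that is unrelated to the tutorial's toy posterior: if samples lie on a thin ring, each marginal covers a wide range and the Pearson correlation is near zero, yet fixing one parameter leaves the other almost no freedom.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical 2D "posterior": samples concentrated on a thin ring of radius 1.5.
n = 5000
angle = rng.uniform(0.0, 2.0 * np.pi, size=n)
radius = 1.5 + 0.02 * rng.standard_normal(n)
samples = np.stack([radius * np.cos(angle), radius * np.sin(angle)], axis=1)

# The marginal Pearson correlation is close to zero ...
marginal_corr = np.corrcoef(samples.T)[0, 1]

# ... yet once theta1 is (approximately) fixed, theta2 is confined to two narrow bands.
near_theta1 = np.abs(samples[:, 0] - 1.0) < 0.05
theta2_given_theta1 = samples[near_theta1, 1]

print(f"marginal correlation: {marginal_corr:.3f}")
print(f"theta2 given theta1 ~ 1: |theta2| in "
      f"[{np.abs(theta2_given_theta1).min():.2f}, {np.abs(theta2_given_theta1).max():.2f}]")
```

The narrow window around \(\theta_1 = 1\) is a crude, sample-based stand-in for the density slices that `conditional_pairplot` draws.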

Because our toy posterior has only three parameters, we can plot posterior samples in a 3D plot:

rc('animation', html='html5')

# Set up the figure and a 3D axis with fixed limits on all three parameters.
fig = plt.figure(figsize=(6, 6))
ax = fig.add_subplot(111, projection='3d')
ax.set_xlim((-2, 2))
ax.set_ylim((-2, 2))
ax.set_zlim((-2, 2))

# Draw a subset of the posterior samples once; the animation only rotates the view.
num_samples_vis = 1000
ax.scatter(
    posterior_samples[:num_samples_vis, 0],
    posterior_samples[:num_samples_vis, 1],
    posterior_samples[:num_samples_vis, 2],
    zdir='z', s=15, c='#2171b5', depthshade=False,
)

def animate(angle):
    # Rotate the camera: 20 degrees elevation, `angle` degrees azimuth.
    ax.view_init(20, angle)

# Rotating the view changes the whole axis, so blitting is disabled.
anim = animation.FuncAnimation(fig, animate, frames=range(0, 360, 5), interval=150)

plt.close()
HTML(anim.to_html5_video())