
Inference on Hodgkin-Huxley model: tutorial

In this tutorial, we use sbi to do inference on a Hodgkin-Huxley model from neuroscience (Hodgkin and Huxley, 1952). We will learn two parameters (\(\bar g_{Na}\), \(\bar g_K\)) from a current-clamp recording that we generate synthetically (in practice, this would be an experimental observation).

Note: you can find the original version of this notebook at https://github.com/mackelab/sbi/blob/main/examples/00_HH_simulator.ipynb in the sbi repository.

First, we import the basic packages.

import numpy as np
import torch

# visualization
import matplotlib as mpl
import matplotlib.pyplot as plt

# sbi
from sbi import utils as utils
from sbi import analysis as analysis
from sbi.inference.base import infer
# remove top and right axis from plots
mpl.rcParams['axes.spines.right'] = False
mpl.rcParams['axes.spines.top'] = False

Different required components

Before running inference, let us define the different required components:

  1. observed data
  2. simulator
  3. prior over model parameters

1. Observed data

Let us assume we current-clamped a neuron and recorded the following voltage trace:


In fact, this voltage trace was not measured experimentally but synthetically generated by simulating a Hodgkin-Huxley model with particular parameters (\(\bar g_{Na}\),\(\bar g_K\)). We will come back to this point later in the tutorial.

2. Simulator

We would like to infer the posterior over the two parameters (\(\bar g_{Na}\),\(\bar g_K\)) of a Hodgkin-Huxley model, given the observed electrophysiological recording above. The model has channel kinetics as in Pospischil et al. 2008, and is defined by the following set of differential equations (parameters of interest highlighted in orange):

\[ \scriptsize \begin{align} C_m\frac{dV}{dt}&=g_{\text{l}}\left(E_{\text{l}}-V\right)+ \color{orange}{\bar{g}_{Na}}m^3h\left(E_{Na}-V\right)+ \color{orange}{\bar{g}_{K}}n^4\left(E_K-V\right)+ \bar{g}_Mp\left(E_K-V\right)+ I_{inj}+ \sigma\eta\left(t\right)\\ \frac{dq}{dt}&=\frac{q_\infty\left(V\right)-q}{\tau_q\left(V\right)},\;q\in\{m,h,n,p\} \end{align} \]

Above, \(V\) represents the membrane potential, \(C_m\) is the membrane capacitance, \(g_{\text{l}}\) is the leak conductance, \(E_{\text{l}}\) is the leak reversal potential, \(\bar{g}_c\) is the density of channels of type \(c\) (\(\text{Na}^+\), \(\text{K}^+\), M), \(E_c\) is the reversal potential of \(c\), (\(m\), \(h\), \(n\), \(p\)) are the respective channel gating kinetic variables, and \(\sigma \eta(t)\) is the intrinsic neural noise. The right-hand side of the voltage dynamics is composed of a leak current, a voltage-dependent \(\text{Na}^+\) current, a delayed-rectifier \(\text{K}^+\) current, a slow voltage-dependent \(\text{K}^+\) current responsible for spike-frequency adaptation, an injected current \(I_{\text{inj}}\), and the noise term. The dynamics of each gating variable \(q\) are fully determined by the membrane potential \(V\) through its steady-state value \(q_{\infty}(V)\) and time constant \(\tau_{q}(V)\) (details in Pospischil et al. 2008).
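The gating equations can be integrated numerically. One simple (though not the most accurate) choice is the forward-Euler step \(q \leftarrow q + \Delta t\,(q_\infty(V) - q)/\tau_q(V)\). The sketch below only illustrates that update. It assumes the voltage is clamped, so \(q_\infty\) and \(\tau_q\) are constants. The values are hypothetical rather than the Pospischil et al. rate functions, and HHsimulator may use a different integration scheme. At clamped voltage the exact solution is known, which lets us check the numerical result.

```python
import numpy as np

def euler_gate_step(q, q_inf, tau_q, dt):
    """One forward-Euler step of dq/dt = (q_inf - q) / tau_q."""
    return q + dt * (q_inf - q) / tau_q

def relax_gate(q0, q_inf, tau_q, dt, n_steps):
    """Integrate a gating variable for n_steps at a clamped voltage (constant q_inf, tau_q)."""
    q = np.empty(n_steps + 1)
    q[0] = q0
    for k in range(n_steps):
        q[k + 1] = euler_gate_step(q[k], q_inf, tau_q, dt)
    return q

# Hypothetical values: the gate starts closed and relaxes towards 0.8 with tau = 2 ms.
dt, tau_q, q_inf, q0 = 0.01, 2.0, 0.8, 0.0
q = relax_gate(q0=q0, q_inf=q_inf, tau_q=tau_q, dt=dt, n_steps=1000)

# Closed-form solution at clamped voltage: q(t) = q_inf + (q0 - q_inf) * exp(-t / tau_q)
t = np.arange(q.size) * dt
q_exact = q_inf + (q0 - q_inf) * np.exp(-t / tau_q)
max_err = np.max(np.abs(q - q_exact))
```

The Euler error shrinks proportionally to `dt`. The step is only stable for `dt < 2 * tau_q`, which matters for the fast sodium gates.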

The input current \(I_{\text{inj}}\) is defined as

from HH_helper_functions import syn_current

I, t_on, t_off, dt, t, A_soma = syn_current()
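The implementation of syn_current() is not shown in this tutorial. A common current-clamp protocol, assumed here, is a step current that is zero except between an onset time t_on and an offset time t_off. The hypothetical helper step_current below sketches such a protocol. Its duration, step times, and amplitude are illustrative and are not the values returned by syn_current().

```python
import numpy as np

def step_current(duration, dt, t_on, t_off, amplitude):
    """Hypothetical step protocol: `amplitude` for t_on <= t < t_off, zero elsewhere."""
    n_steps = int(round(duration / dt))   # integer step count avoids float-arange surprises
    t = np.arange(n_steps) * dt           # time axis in ms
    I = np.where((t >= t_on) & (t < t_off), amplitude, 0.0)
    return I, t

# Hypothetical protocol: 120 ms in total, step from 10 ms to 110 ms.
I_step, t_step = step_current(duration=120.0, dt=0.01, t_on=10.0, t_off=110.0, amplitude=1.0)
```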

The Hodgkin-Huxley simulator is given by:

from HH_helper_functions import HHsimulator

Putting the input current and the simulator together:

def run_HH_model(params):

    params = np.asarray(params)

    # input current, time step
    I, t_on, t_off, dt, t, A_soma = syn_current()

    t = np.arange(0, len(I), 1)*dt

    # initial voltage
    V0 = -70

    states = HHsimulator(V0, params.reshape(1, -1), dt, t, I)

    return dict(data=states.reshape(-1), time=t, dt=dt, I=I.reshape(-1))

To get an idea of the output of the Hodgkin-Huxley model, let us generate some voltage traces for different parameters (\(\bar g_{Na}\),\(\bar g_K\)), given the input current \(I_{\text{inj}}\):

# three sets of (g_Na, g_K)
params = np.array([[50., 1.],[4., 1.5],[20., 15.]])

num_samples = len(params[:,0])
sim_samples = np.zeros((num_samples, len(I)))
for i in range(num_samples):
    sim_samples[i,:] = run_HH_model(params=params[i,:])['data']
# colors for traces
col_min = 2
num_colors = num_samples+col_min
cm1 = mpl.cm.Blues
col1 = [cm1(1.*i/num_colors) for i in range(col_min,num_colors)]

fig = plt.figure(figsize=(7,5))
gs = mpl.gridspec.GridSpec(2, 1, height_ratios=[4, 1])
ax = plt.subplot(gs[0])
for i in range(num_samples):
    plt.plot(t,sim_samples[i,:],color=col1[i],lw=2)
plt.ylabel('voltage (mV)')
ax.set_xticks([])
ax.set_yticks([-80, -20, 40])

ax = plt.subplot(gs[1])
plt.plot(t,I*A_soma*1e3,'k', lw=2)
plt.xlabel('time (ms)')
plt.ylabel('input (nA)')

ax.set_xticks([0, max(t)/2, max(t)])
ax.set_yticks([0, 1.1*np.max(I*A_soma*1e3)])
ax.yaxis.set_major_formatter(mpl.ticker.FormatStrFormatter('%.2f'))
plt.show()

As can be seen, the voltage traces can be quite diverse for different parameter values.

Often, we are not interested in matching the exact trace, but only in matching certain features of it. In this Hodgkin-Huxley example, the summary features are:

- the number of spikes,
- the mean and standard deviation of the resting potential, and
- the first four moments of the voltage trace: mean, standard deviation, skewness, and kurtosis.

The function calculate_summary_statistics(), imported below, computes these statistics from the output of the Hodgkin-Huxley simulator.

from HH_helper_functions import calculate_summary_statistics
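calculate_summary_statistics() lives in HH_helper_functions and is not reproduced here. The hypothetical function summary_stats_sketch below illustrates how the listed features could be computed with NumPy alone. It makes two assumptions:

- spikes are upward crossings of a fixed threshold (0 mV here);
- the resting potential is the voltage before stimulus onset t_on.

The actual helper may define or normalize these features differently.

```python
import numpy as np

def summary_stats_sketch(v, t, t_on, threshold=0.0):
    """Hypothetical summary statistics of a voltage trace `v` (mV) sampled at times `t` (ms)."""
    v = np.asarray(v, dtype=float)
    t = np.asarray(t, dtype=float)

    # Number of spikes: count upward crossings of `threshold`.
    above = v >= threshold
    n_spikes = np.count_nonzero(~above[:-1] & above[1:])

    # Resting potential: samples recorded before stimulus onset.
    rest = v[t < t_on]
    rest_mean, rest_std = rest.mean(), rest.std()

    # First four moments of the whole trace (skewness/kurtosis from standardized values).
    mean, std = v.mean(), v.std()
    z = (v - mean) / std
    skewness = np.mean(z**3)
    kurtosis = np.mean(z**4)  # non-excess kurtosis: equals 3 for a Gaussian

    return np.array([n_spikes, rest_mean, rest_std, mean, std, skewness, kurtosis])

# Usage on a synthetic trace: -70 mV baseline with three single-sample "spikes" to +30 mV.
t_demo = np.arange(1000) * 0.1
v_demo = np.full(1000, -70.0)
v_demo[[300, 500, 700]] = 30.0
stats = summary_stats_sketch(v_demo, t_demo, t_on=10.0)
```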

Lastly, we define a function that performs all of the above steps at once. The function simulation_wrapper takes conductance values, runs the Hodgkin-Huxley model, and returns the summary statistics.

def simulation_wrapper(params):
    """
    Returns summary statistics from conductance values in `params`.

    Summarizes the output of the HH simulator and converts it to `torch.Tensor`.
    """
    obs = run_HH_model(params)
    summstats = torch.as_tensor(calculate_summary_statistics(obs))
    return summstats

sbi accepts any Python callable as a simulator. This gives sbi the flexibility to use simulators built on external packages, e.g., Brian (http://briansimulator.org/), NEST (https://www.nest-simulator.org/), or NEURON (https://neuron.yale.edu/neuron/). External simulators do not even need to be Python-based, as long as they store their outputs in a format that Python can read. All that is necessary is to wrap your simulator of choice in a Python callable that takes a parameter set and returns the summary statistics to which we want to fit the parameters.
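Below is a minimal sketch of such a wrapper. It assumes a hypothetical external program that reads parameters from one text file and writes a voltage trace to another. The callable exchanges data with that program through temporary files and subprocess. So that the sketch runs anywhere, a short Python script stands in for the external simulator. The file format, the command, and the two-statistic summarize function are all illustrative assumptions.

```python
import os
import subprocess
import sys
import tempfile

import numpy as np

def external_simulator_wrapper(params, command, summarize):
    """Run an external program on `params` via text files and return summary statistics."""
    with tempfile.TemporaryDirectory() as tmp:
        param_file = os.path.join(tmp, "params.txt")
        out_file = os.path.join(tmp, "trace.txt")
        np.savetxt(param_file, np.atleast_1d(np.asarray(params, dtype=float)))
        # The external program is expected to read `param_file` and write `out_file`.
        subprocess.run(command + [param_file, out_file], check=True)
        trace = np.atleast_1d(np.loadtxt(out_file))
    return summarize(trace)

# Stand-in "external simulator" (hypothetical): writes a constant trace p[0] + p[1].
STAND_IN_SCRIPT = (
    "import sys, numpy as np\n"
    "p = np.atleast_1d(np.loadtxt(sys.argv[1]))\n"
    "np.savetxt(sys.argv[2], p[0] * np.ones(5) + p[1])\n"
)
command = [sys.executable, "-c", STAND_IN_SCRIPT]

def summarize(trace):
    """Illustrative summary: mean and standard deviation of the trace."""
    return np.array([trace.mean(), trace.std()])

stats_ext = external_simulator_wrapper([2.0, 3.0], command, summarize)
```

Converting the result with torch.as_tensor, as simulation_wrapper does above, would make such a wrapper usable as an sbi simulator.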

3. Prior over model parameters

Now that we have the simulator, we need to define a prior over the model parameters (\(\bar g_{Na}\), \(\bar g_K\)), which in this case is chosen to be a uniform distribution:

prior_min = [.5,1e-4]
prior_max = [80.,15.]
prior = utils.torchutils.BoxUniform(low=torch.as_tensor(prior_min), 
                                    high=torch.as_tensor(prior_max))
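Conceptually, a box-uniform prior is a product of independent uniform distributions, one per parameter. Its density is constant, \(1/\prod_i(\text{high}_i - \text{low}_i)\), inside the box and zero outside. The hypothetical NumPy class BoxUniformSketch below mirrors that idea for illustration only. The tutorial itself uses sbi's BoxUniform, which is a torch distribution.

```python
import numpy as np

class BoxUniformSketch:
    """Hypothetical NumPy stand-in for a box-uniform prior (independent uniforms per dimension)."""

    def __init__(self, low, high):
        self.low = np.asarray(low, dtype=float)
        self.high = np.asarray(high, dtype=float)
        # Log of the box volume; the density inside the box is 1 / volume.
        self.log_volume = np.sum(np.log(self.high - self.low))

    def sample(self, n, rng):
        """Draw n parameter sets, one row per sample."""
        return rng.uniform(self.low, self.high, size=(n, self.low.size))

    def log_prob(self, theta):
        """Constant log-density inside the box, -inf outside."""
        theta = np.atleast_2d(theta)
        inside = np.all((theta >= self.low) & (theta <= self.high), axis=1)
        return np.where(inside, -self.log_volume, -np.inf)

prior_sketch = BoxUniformSketch(low=[.5, 1e-4], high=[80., 15.])
rng = np.random.default_rng(0)
theta_prior = prior_sketch.sample(1000, rng)
lp = prior_sketch.log_prob([[50., 5.], [100., 5.]])  # second point lies outside the box
```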

Inference

Now that we have all the required components, we can run inference with SNPE (sequential neural posterior estimation) to estimate which parameters produce simulated activity that matches the observed trace.

posterior = infer(simulation_wrapper, prior, method='SNPE', 
                  num_simulations=300, num_workers=4)

Neural network successfully converged after 233 epochs.

Note that sbi can run simulations in parallel (here via num_workers=4). If you experience problems with parallelization, try setting num_workers=1, and please report the error as a GitHub issue.
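SNPE works in two steps. It first trains a neural conditional density estimator \(q(\theta \mid x)\) on pairs \((\theta, x)\) simulated from the prior. It then evaluates that estimator at the observed summary statistics.

To make this simulate, fit, and condition workflow concrete without sbi or a neural network, the sketch below swaps in the simplest conditional density estimator: a linear-Gaussian model \(q(\theta \mid x) = \mathcal{N}(W^\top [x; 1], \Sigma)\), fitted by maximum likelihood (least squares). It runs on a hypothetical toy simulator \(x = \theta + \text{noise}\). This is not what sbi does internally. In particular, it ignores the prior bounds and cannot represent non-Gaussian posteriors.

```python
import numpy as np

def toy_simulator(theta, rng):
    """Hypothetical toy simulator: each statistic is its parameter plus Gaussian noise."""
    return theta + 0.1 * rng.standard_normal(theta.shape)

def fit_linear_gaussian(theta, x):
    """Maximum-likelihood fit of q(theta | x) = N(W^T [x, 1], Sigma)."""
    X = np.hstack([x, np.ones((x.shape[0], 1))])   # append a bias column
    W, *_ = np.linalg.lstsq(X, theta, rcond=None)  # least squares = ML estimate of the mean map
    resid = theta - X @ W
    Sigma = resid.T @ resid / theta.shape[0]       # ML estimate of the residual covariance
    return W, Sigma

def conditional_mean(W, x_o):
    """Mean of the fitted q(theta | x_o) for a single observation x_o."""
    return np.append(x_o, 1.0) @ W

rng = np.random.default_rng(0)
low, high = np.array([0.5, 1e-4]), np.array([80.0, 15.0])

# Step 1: simulate a training set (theta_i, x_i) with theta_i drawn from the uniform prior.
theta_train = rng.uniform(low, high, size=(2000, 2))
x_train = toy_simulator(theta_train, rng)

# Step 2: fit the conditional density estimator q(theta | x).
W, Sigma = fit_linear_gaussian(theta_train, x_train)

# Step 3: condition on an "observed" x_o generated from known parameters.
x_o = toy_simulator(np.array([50.0, 5.0]), rng)
post_mean = conditional_mean(W, x_o)
```

Because the estimator is trained on the whole prior range, it can be conditioned on any new observation without re-simulating. The posterior returned by infer above is amortized in the same way.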

Coming back to the observed data

As mentioned at the beginning of the tutorial, the observed data are generated by the Hodgkin-Huxley model with a set of known parameters (\(\bar g_{Na}\),\(\bar g_K\)). To illustrate how to compute the summary statistics of the observed data, let us regenerate the observed data:

# true parameters and respective labels
true_params = np.array([50., 5.])
labels_params = [r'$g_{Na}$', r'$g_{K}$']
observation_trace = run_HH_model(true_params)
observation_summary_statistics = calculate_summary_statistics(observation_trace)

As already shown above, the observed voltage trace looks as follows:

fig = plt.figure(figsize=(7,5))
gs = mpl.gridspec.GridSpec(2, 1, height_ratios=[4, 1])
ax = plt.subplot(gs[0])
plt.plot(observation_trace['time'],observation_trace['data'])
plt.ylabel('voltage (mV)')
plt.title('observed data')
plt.setp(ax, xticks=[], yticks=[-80, -20, 40])

ax = plt.subplot(gs[1])
plt.plot(observation_trace['time'],I*A_soma*1e3,'k', lw=2)
plt.xlabel('time (ms)')
plt.ylabel('input (nA)')

ax.set_xticks([0, max(observation_trace['time'])/2, max(observation_trace['time'])])
ax.set_yticks([0, 1.1*np.max(I*A_soma*1e3)])
ax.yaxis.set_major_formatter(mpl.ticker.FormatStrFormatter('%.2f'))
plt.show()

Analysis of the posterior given the observed data

After running the inference algorithm, let us inspect the inferred posterior distribution over the parameters (\(\bar g_{Na}\), \(\bar g_K\)), given the observed trace. To do so, we first draw samples (i.e., parameter sets consistent with the observed data) from the posterior:

samples = posterior.sample((10000,), 
                           x=observation_summary_statistics)
fig, axes = analysis.pairplot(samples,
                           limits=[[.5,80], [1e-4,15.]],
                           ticks=[[.5,80], [1e-4,15.]],
                           figsize=(5,5),
                           points=true_params,
                           points_offdiag={'markersize': 6},
                           points_colors='r');

As can be seen, the ground-truth parameters (red) lie in a high-probability region of the inferred posterior. Now, let us draw a parameter set from the posterior, simulate the Hodgkin-Huxley model with it, and compare the simulation with the observed data:

# Draw a sample from the posterior and convert to numpy for plotting.
posterior_sample = posterior.sample((1,), 
                                    x=observation_summary_statistics).numpy()
fig = plt.figure(figsize=(7,5))

# plot observation
t = observation_trace['time']
y_obs = observation_trace['data']
plt.plot(t, y_obs, lw=2, label='observation')

# simulate and plot samples
x = run_HH_model(posterior_sample)
plt.plot(t, x['data'], '--', lw=2, label='posterior sample')

plt.xlabel('time (ms)')
plt.ylabel('voltage (mV)')

ax = plt.gca()
handles, labels = ax.get_legend_handles_labels()
ax.legend(handles[::-1], labels[::-1], bbox_to_anchor=(1.3, 1), 
          loc='upper right')

ax.set_xticks([0, 60, 120])
ax.set_yticks([-80, -20, 40]);

As can be seen, the simulation from the posterior sample closely resembles the observed data, suggesting that SNPE captured the observed data well in this simple case.
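The pairplot gives a visual check that the ground truth lies in a high-probability region. One way to make that check quantitative is to compute central marginal credible intervals from the posterior samples and test whether each true parameter falls inside its interval. The sketch below is hypothetical: synthetic Gaussian "samples" stand in for samples.numpy() so that it runs without sbi. With the tutorial's objects, one would pass samples.numpy() and true_params instead. Marginal intervals say nothing about correlations between \(\bar g_{Na}\) and \(\bar g_K\), which the pairplot does show.

```python
import numpy as np

def marginal_credible_intervals(samples, level=0.95):
    """Central marginal credible intervals: one [lower, upper] row per parameter."""
    alpha = (1.0 - level) / 2.0
    lower = np.quantile(samples, alpha, axis=0)
    upper = np.quantile(samples, 1.0 - alpha, axis=0)
    return np.stack([lower, upper], axis=1)

def contains_truth(intervals, true_values):
    """Boolean per parameter: does the interval contain the true value?"""
    true_values = np.asarray(true_values, dtype=float)
    return (intervals[:, 0] <= true_values) & (true_values <= intervals[:, 1])

# Hypothetical stand-in for posterior samples: Gaussian around (50, 5).
rng = np.random.default_rng(1)
fake_samples = rng.normal(loc=[50.0, 5.0], scale=[3.0, 0.5], size=(10000, 2))
intervals = marginal_credible_intervals(fake_samples)
inside = contains_truth(intervals, [50.0, 5.0])
```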

References

A. L. Hodgkin and A. F. Huxley. A quantitative description of membrane current and its application to conduction and excitation in nerve. The Journal of Physiology, 117(4):500–544, 1952.

M. Pospischil, M. Toledo-Rodriguez, C. Monier, Z. Piwkowska, T. Bal, Y. Frégnac, H. Markram, and A. Destexhe. Minimal Hodgkin-Huxley type models for different classes of cortical and thalamic neurons. Biological Cybernetics, 99(4-5), 2008.