Fitting an AdEx-type neuron to spike times of biophysical model

Description of problem

Hi everyone,

I am trying to use the brian2modelfitting library to fit a neuron model, with an AdEx soma and a coupled passive dendritic segment, to the spike times of a biophysical model. For some reason the error I get is always quite high. I am using the SpikeFitter class with the GammaFactor metric. From what I understand (and please let me know if I’m wrong), the error should be much smaller than 1 to indicate a perfect match: ideally -1 with rate_correction=True, or 0 with rate_correction=False. In my case, I always get error values above 1. The optimization also seems to get stuck in local minima without converging to a solution that gives good spike coincidence between the AdEx model and the spike times recorded from the biophysical model. I have tried various optimizers (NevergradOptimizer, SkoptOptimizer with the ‘GBRT’ and ‘RF’ methods, etc.), but none of them seems to improve things much.
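To build some intuition for what a coincidence-based error measures, here is a standalone NumPy sketch of the coincidence factor Γ used in the spike-train comparison literature. Γ is 1 for a perfect match and close to 0 for chance-level agreement. This is not brian2modelfitting’s GammaFactor implementation: how the library turns Γ into an error value (including the rate_correction term and its sign convention) may differ, so the library source is the authority there. The function name and arguments are hypothetical, and spike times are assumed to be in seconds.

```python
import numpy as np

def coincidence_factor(model, data, delta, duration):
    """Coincidence factor Gamma between two spike trains (standalone sketch).

    model, data: spike times in seconds; delta: coincidence window (s);
    duration: length of the recording (s). Gamma is 1 for a perfect match
    and around 0 for chance-level agreement.
    """
    model = np.sort(np.asarray(model, dtype=float))
    data = np.sort(np.asarray(data, dtype=float))
    if len(model) == 0 or len(data) == 0:
        return 0.0
    # Distance from each data spike to its nearest model spike
    idx = np.searchsorted(model, data)
    before = model[np.clip(idx - 1, 0, len(model) - 1)]
    after = model[np.clip(idx, 0, len(model) - 1)]
    nearest = np.minimum(np.abs(data - before), np.abs(after - data))
    # Simplification: one model spike may be counted for several data spikes
    n_coinc = np.count_nonzero(nearest <= delta)
    model_rate = len(model) / duration
    # Coincidences expected by chance if the model were a Poisson process
    expected = 2 * delta * model_rate * len(data)
    norm = 1 - 2 * model_rate * delta
    return (n_coinc - expected) / (0.5 * (len(model) + len(data)) * norm)
```

With a 5 ms window, shifting every spike by 2 ms still gives Γ = 1, while a 50 ms shift gives a slightly negative Γ (fewer coincidences than chance). An error built from 1 - Γ would therefore rise above 1 when the model spikes are systematically misaligned with the data.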

I suspect that fitting an AdEx model to this kind of data (a long step current and the spike times of a biophysical model) may simply not work very well. On the other hand, I have seen in multiple papers that AdEx models are quite good at reproducing even subtle sub-threshold voltage fluctuations for a given input.

Perhaps I should also consider the TraceFitter, although I tried it a couple of times and the errors were much higher. Ultimately, what I am interested in is a model whose f-I curve is similar to that of the biophysical model. The parameters that SpikeFitter finds are not bad, and for some input values the f-I curve looks similar, but the fact that the error is still above 1 suggests there could be a way to get much better performance.
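Since the f-I curve is the quantity of interest, one quick way to compare model and data is to turn each spike list into a mean firing rate during the step. A minimal sketch, assuming spike times in ms and the 0.1 s to 2.1 s step used in this post (the helper name `fi_curve` is hypothetical):

```python
import numpy as np

def fi_curve(spike_lists_ms, t_start_ms=100.0, t_end_ms=2100.0):
    """Mean firing rate (Hz) during the current step, one value per spike list.

    spike_lists_ms: list of spike-time lists in ms (one per step amplitude).
    Only spikes between t_start_ms and t_end_ms are counted.
    """
    duration_s = (t_end_ms - t_start_ms) / 1000.0
    rates = []
    for spikes in spike_lists_ms:
        spikes = np.asarray(spikes, dtype=float)
        n_spikes = np.count_nonzero((spikes >= t_start_ms) & (spikes <= t_end_ms))
        rates.append(n_spikes / duration_s)
    return np.array(rates)
```

Applied to the data below, `fi_curve([spikes_90pA, spikes_100pA])` gives 4.5 Hz and 6.5 Hz (9 and 13 spikes during the 2 s step). Running the same function on the model’s spike times (converted to ms) and comparing the two arrays, e.g. with a mean absolute difference, gives a direct measure of the f-I mismatch.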

Below is my code, showing how I am approaching the problem. In reality, I load the input currents and spike times from files saved from my biophysical model simulations, but here I create them manually so the code can be run as is. The current is a step that takes a positive value between 0.1 s and 2.1 s and is 0 otherwise (total duration 2.2 s). For each step amplitude (90 pA, 100 pA, etc.), there is a list of the spike times (in ms) that the biophysical model produces in response to that input. The ranges for the AdEx parameter values are based on published studies that used AdEx to fit electrophysiological data.

Minimal code to reproduce problem

import numpy as np
from brian2 import *
from brian2modelfitting import *

# Lists of spike times corresponding to each step input (in ms)
spikes_90pA = [285.5, 429.4, 582.3, 761.7, 981.5, 1232.9, 1496.0, 1761.8, 2027.9]
spikes_100pA = [201.2, 279.1, 362.3, 458.2, 579.4, 736.5, 917.5, 1104.6, 1291.9, 1478.3, 1664.0, 1849.2, 2034.1]
spikes_110pA = [174.9, 233.3, 297.1, 372.0, 468.6, 598.3, 751.7, 910.4, 1068.6, 1225.7, 1381.9, 1537.3, 1692.2, 1846.7, 2000.8]
spikes_120pA = [160.5, 1795.9, 1659.3, 1522.4, 1385.1, 1247.1, 1108.2, 1932.2, 968.3, 685.1, 545.3, 423.4, 333.0, 265.3, 209.4, 827.2, 2068.4]
spikes_130pA = [151.0, 1793.9, 1669.8, 1545.4, 1420.6, 1295.3, 1169.4, 1042.7, 914.9, 785.8, 655.3, 525.2, 405.7, 314.0, 247.3, 194.6, 1917.8, 2041.6]
spikes_140pA = [144.3, 1812.1, 1697.5, 1582.6, 1467.4, 1351.9, 1235.9, 1119.3, 1926.5, 1001.9, 763.7, 642.5, 520.5, 403.2, 306.9, 237.3, 184.7, 883.4, 2040.8]
spikes_150pA = [139.1, 1833.4, 1726.5, 1619.4, 1512.0, 1404.4, 1296.3, 1187.8, 1078.6, 968.6, 857.7, 745.6, 632.0, 517.2, 404.1, 305.0, 231.4, 177.5, 1940.1, 2046.7]
spikes_160pA = [135.0, 1855.2, 1754.8, 1654.3, 1553.6, 1452.6, 1351.3, 1249.6, 1147.4, 1955.4, 1044.5, 836.2, 730.5, 623.4, 515.0, 406.7, 306.7, 228.6, 172.2, 940.8, 2055.5]
spikes_170pA = [131.7, 1873.6, 1778.9, 1684.1, 1589.1, 1493.8, 1398.3, 1302.5, 1206.3, 1109.5, 1012.2, 914.1, 815.1, 715.0, 613.6, 510.8, 407.4, 308.8, 227.4, 168.1, 1968.1, 2062.5]
spikes_180pA = [129.0, 1889.2, 1799.6, 1709.9, 1620.0, 1530.0, 1439.6, 1349.0, 1258.0, 1166.6, 1978.7, 1074.7, 889.1, 795.0, 699.9, 603.6, 505.9, 407.3, 311.2, 227.6, 164.9, 982.3, 2068.1]
spikes_190pA = [126.6, 1896.7, 1811.7, 1726.6, 1641.4, 1556.0, 1470.4, 1384.5, 1298.3, 1211.8, 1124.9, 1037.4, 949.4, 860.6, 771.0, 680.4, 588.6, 495.5, 401.5, 309.1, 226.1, 161.9, 1981.6, 2066.4]

spike_times = [spikes_90pA, spikes_100pA, spikes_110pA, spikes_120pA, spikes_130pA, spikes_140pA, spikes_150pA, spikes_160pA, spikes_170pA, spikes_180pA, spikes_190pA]

# Code to generate the list of stimulation arrays I used
stimulation = []
simulation_time = 2.2  # in seconds
time_resolution = 0.0001  # in seconds (0.1 ms)
start_time = 0.1  # in seconds (start of current injection)
end_time = 2.1  # in seconds (end of current injection)

# Time array (the same for every amplitude)
time_points = np.arange(0, simulation_time + time_resolution, time_resolution)
for amp_value in range(90, 200, 10):  # from 90 pA to 190 pA in steps of 10 pA
    current_amplitude = amp_value * 1e-12  # pA in amperes
    # Step current between 100 ms and 2.1 s, and 0 otherwise
    current_clamp = np.where((time_points >= start_time) & (time_points <= end_time), current_amplitude, 0)
    stimulation.append(current_clamp)

# Actual fitting procedure
# Membrane time constant taken from experimental data
C_soma_basal = 62.15571803 * pfarad # Capacitance of soma compartment
gL_soma_basal = 2.5684181 * nsiemens # Leak conductance of soma compartment
gL_rest_neuron = 6.25078264 * nsiemens # Leak conductance of passive dendrite
C_rest_neuron = 151.26893999 * pfarad # Capacitance of passive dendrite
g_rest_neuron_soma_basal = 9. * nsiemens # Coupling conductance between compartments
g_soma_basal_rest_neuron = 9. * nsiemens # Symmetric coupling conductance
DeltaT_soma_basal = 2*mV # Value most commonly found in literature

model = '''
dV_soma_basal/dt = (gL_soma_basal * (EL_soma_basal-V_soma_basal) + gL_soma_basal*DeltaT_soma_basal*exp((V_soma_basal-Vth_soma_basal)/DeltaT_soma_basal) + I_soma_basal - w_soma_basal) / C_soma_basal  :volt
dw_soma_basal/dt = (a_soma_basal * (V_soma_basal-EL_soma_basal) -w_soma_basal) / tauw_soma_basal :amp
I_soma_basal = I_ext_soma_basal + I_rest_neuron_soma_basal :amp
I_rest_neuron_soma_basal = (V_rest_neuron-V_soma_basal) * g_rest_neuron_soma_basal :amp

dV_rest_neuron/dt = (gL_rest_neuron * (EL_rest_neuron-V_rest_neuron) + I_rest_neuron) / C_rest_neuron :volt
I_rest_neuron = I_soma_basal_rest_neuron :amp
I_soma_basal_rest_neuron = (V_soma_basal-V_rest_neuron) * g_soma_basal_rest_neuron :amp

EL_rest_neuron = EL_soma_basal : volt
EL_soma_basal : volt (constant)
Vth_soma_basal : volt (constant)
a_soma_basal : siemens (constant)
tauw_soma_basal : second (constant)
b : ampere (constant)
V_reset : volt (constant)
'''

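inp_trace = np.array(stimulation)  # shape: (number of traces, number of time steps)
# Sort the spike lists (several are not in chronological order) and attach units:
# the spike times are in ms, while bare numbers would presumably be read in SI units (s)
out_trace = [np.sort(spikes) * ms for spikes in spike_times]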

n_opt = NevergradOptimizer()
metric = GammaFactor(delta=5*ms, time=2.2*second, rate_correction=True)
fitter = SpikeFitter(model=model,
                     input_var='I_ext_soma_basal',
                     input=inp_trace * amp,
                     output=out_trace,
                     dt=0.1*ms,
                     n_samples=1000,
                     method='exponential_euler',
                     param_init={'V_soma_basal': -70*mV, 'V_rest_neuron': -70*mV, 'V_reset': -70*mV},
                     reset='V_soma_basal=V_reset; w_soma_basal += b',
                     threshold='V_soma_basal > Vth_soma_basal + 5*DeltaT_soma_basal')

res, error = fitter.fit(n_rounds=20,
                        optimizer=n_opt,
                        metric=metric,
                        EL_soma_basal=[-80*mV,-60*mV],
                        Vth_soma_basal=[-60*mV,-40*mV],
                        a_soma_basal=[-5*nS,4.0*nS],
                        tauw_soma_basal=[20.0*ms,200.0*ms],
                        V_reset=[-80*mV,-60*mV],
                        b=[50*pA,200*pA])

It’s my first time working with such optimization processes, so if anyone has an idea as to what approach might work or why my approach doesn’t work, or any tips on good practices when fitting such models to data, I would really appreciate it :blush:

Thank you:)

Best wishes,
Rares

Your post reminds me that brian2modelfitting needs a bit of love: there are plenty of minor bugs, and it has been ages since its last release :crying_cat_face: That said, I don’t see anything wrong with your general approach. I’d maybe change the param_init slightly, e.g. to

param_init={'V_soma_basal': 'EL_soma_basal', 'V_rest_neuron': 'EL_soma_basal'},

This way, the initial values take the current parameters into account, which should avoid large initial transients (also, it doesn’t make sense to initialize V_reset if it is one of your fitted parameters). But this shouldn’t change anything fundamentally. Your cell’s responses seem to be quite regular with some weak adaptation, so I’d agree with your intuition that the AdEx model should be able to fit them. But I can also confirm your observation that the fitting is not doing a very good job. Actually, if you only fit the number of spikes (so basically the f/I curve), then you get a really good fit, e.g. something like this:
[Figure: result of fitting only the number of spikes]

[Doing this kind of fitting unfortunately currently needs you to write your own metric, which is a bit of a shame for such a simple operation. Let me know if you want code to do this]

When you look at the spikes generated with the best-fit parameters (here for a fit on the firing rates, but it looks similar if you run your code), you will see one clear difference between the simulated traces (blue: soma, orange: dendrite) and the observed spikes (green crosses): the simulated neuron spikes very soon after the current onset.

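[Figure: simulated soma (blue) and dendrite (orange) voltage traces for the best-fit parameters, with the observed spike times as green crosses]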

I think this prevents a good fit, since the model does not seem to be able to match both the timing of the first spike and the total number of spikes at the same time. I’d therefore suggest playing around a bit with the model parameters by hand and/or looking into the dynamics to see why this early spike happens. Maybe changing one of the parameters that are currently fixed can delay that first spike? In that case, I’d add that parameter to the list of fitted parameters and see what the algorithm comes up with.
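One way to explore the early first spike by hand, outside the fitting toolbox, is to integrate the two-compartment model from the original post directly and measure the latency of the first spike after current onset. The sketch below is a plain forward-Euler integration in SI units using the fixed values from the post; the defaults for the fitted parameters (EL, Vth, a, tauw) are hypothetical placeholders inside the fitted ranges, not fitted results. The reset and b do not affect the first spike, so they are left out.

```python
import math

def first_spike_latency(I_amp, EL=-70e-3, Vth=-50e-3, a=0.0, tauw=0.1,
                        C_s=62.15571803e-12, gL_s=2.5684181e-9,
                        C_d=151.26893999e-12, gL_d=6.25078264e-9,
                        g_c=9e-9, DeltaT=2e-3,
                        t_on=0.1, t_off=2.1, t_max=2.2, dt=2e-5):
    """Return the first-spike latency (s) after current onset, or None if no spike.

    Forward-Euler sketch of the AdEx soma + passive dendrite model, SI units.
    """
    V_s = V_d = EL              # both compartments start at rest
    w = 0.0                     # adaptation current starts at zero
    V_cut = Vth + 5 * DeltaT    # same threshold condition as in the post
    for step in range(int(round(t_max / dt))):
        t = step * dt
        I_ext = I_amp if t_on <= t <= t_off else 0.0
        I_c = g_c * (V_d - V_s)  # coupling current from dendrite into soma
        dV_s = (gL_s * (EL - V_s)
                + gL_s * DeltaT * math.exp((V_s - Vth) / DeltaT)
                + I_ext + I_c - w) / C_s
        dV_d = (gL_d * (EL - V_d) - I_c) / C_d
        dw = (a * (V_s - EL) - w) / tauw
        V_s, V_d, w = V_s + dt * dV_s, V_d + dt * dV_d, w + dt * dw
        if V_s > V_cut:
            return t + dt - t_on  # threshold crossed at the end of this step
    return None
```

For comparison, the biophysical model’s first spike comes 185.5 ms after onset at 90 pA and about 26.6 ms after onset at 190 pA. Sweeping one of the fixed values (e.g. `C_s` or `g_c`) and watching how the latency changes may hint at which parameter is worth adding to the fit.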

Hope that helps a bit, let us know how it goes!


Dear @mstimberg,

Thanks a lot for your reply :blush: It helps clarify some of the concerns I had about the fitting. Indeed, the firing rates seem to match quite well, so the high error value I am seeing could be caused by the early first spike you pointed out, which also prevents the later spike times from coinciding.

It could be that the conductance and capacitance values I set for the soma (which are not fitted during optimization but are taken from experimental data) make it hard for the fitter to match the first spike time. I will play around with these parameters and perhaps try fitting them too, while constraining the neuron to keep a realistic membrane time constant.
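As a quick check of that constraint, the soma and dendrite values in the original code both correspond to the same membrane time constant, τ = C/g_L ≈ 24.2 ms. One way to fit a capacitance while keeping τ fixed is to derive the leak conductance from it, g_L = C/τ. In the Brian equations this could be written as a subexpression such as `gL_soma_basal = C_soma_basal / tau_soma : siemens`, with `C_soma_basal` declared as a fitted `(constant)` parameter and `tau_soma` a hypothetical external constant. The arithmetic, sketched in plain Python:

```python
# Values from the original post (SI units)
C_soma, gL_soma = 62.15571803e-12, 2.5684181e-9     # F, S
C_dend, gL_dend = 151.26893999e-12, 6.25078264e-9   # F, S

tau_soma = C_soma / gL_soma   # about 24.2 ms
tau_dend = C_dend / gL_dend   # also about 24.2 ms

def gL_for_fixed_tau(C, tau=tau_soma):
    """Leak conductance that keeps tau = C / gL when the capacitance C changes."""
    return C / tau

# e.g. a candidate capacitance of 100 pF would need a leak conductance of about 4.13 nS
print(gL_for_fixed_tau(100e-12))
```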

You mentioned that to fit only the firing rates I would have to write my own metric. It would be great if you could share that code, so I could also play around with it for the various types of neurons I am trying to fit. It would probably give me a better feel for the fitting performance than spike coincidence does, at least until I find a set of parameters that fixes the early first spike. I would appreciate it!

Many thanks again for your help, I’ll let you know how things progress.

Best wishes,
Rares

Sure, here’s what I used (you can plug this in as a replacement for GammaFactor):

import numpy as np
from brian2modelfitting import SpikeMetric

class SpikeCountMetric(SpikeMetric):
    """Error based only on the number of spikes per trace."""

    def get_features(self, model_spikes, data_spikes, dt):
        # model_spikes: for each parameter combination, a list of spike trains (one per trace)
        errors = np.zeros((len(model_spikes), len(model_spikes[0])))
        data_count = np.array([len(spikes) for spikes in data_spikes])
        # Looping over parameter combinations
        for i, spikes in enumerate(model_spikes):
            # Error for each trace is the absolute difference between the number of spikes
            errors[i] = np.abs(data_count - np.array([len(s) for s in spikes]))
        return errors

    def get_errors(self, features):
        # Sum up the errors over all traces (one total per parameter combination)
        return np.sum(features, axis=1)

For each trace, get_features uses the absolute difference in the number of spikes between model and data, but you could also use the squared difference, a relative error with respect to the expected number of spikes, etc. Similarly, get_errors simply sums up the error values over all traces, but you could imagine, for example, giving a stronger weight to an error in one of the traces (e.g. to enforce that a trace with a 0 pA current gives 0 spikes).
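The variations mentioned above can be sketched as a single NumPy function. This is a hypothetical helper (`spike_count_errors` and its `mode`/`weights` arguments are made up for illustration and are not part of brian2modelfitting): it computes the per-trace features and their weighted sum in one step, so its body could be split between `get_features` and `get_errors` in a metric class like `SpikeCountMetric`.

```python
import numpy as np

def spike_count_errors(model_spikes, data_spikes, mode="abs", weights=None):
    """One error value per parameter combination, based on spike counts.

    model_spikes: for each parameter combination, a list of spike trains (one per trace)
    data_spikes: list of spike trains, one per trace
    mode: "abs" (absolute difference), "squared", or "relative" (|diff| / data count)
    weights: optional per-trace weights used when summing over traces
    """
    data_count = np.array([len(s) for s in data_spikes], dtype=float)
    model_count = np.array([[len(s) for s in sample] for sample in model_spikes],
                           dtype=float)
    diff = model_count - data_count  # shape: (n_combinations, n_traces)
    if mode == "abs":
        features = np.abs(diff)
    elif mode == "squared":
        features = diff ** 2
    elif mode == "relative":
        # max(..., 1) avoids dividing by zero for traces where the data has no spikes
        features = np.abs(diff) / np.maximum(data_count, 1.0)
    else:
        raise ValueError(f"unknown mode: {mode}")
    if weights is None:
        weights = np.ones(len(data_spikes))
    return features @ np.asarray(weights, dtype=float)
```

For instance, passing `weights` with a large value for the trace of a 0 pA input makes any spike in that trace dominate the total error.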

Let me know if anything is unclear!