Automatic Parameter fitting with delfi and Brian2

A basic adaptation of the delfi tutorial to fit two parameters (refractory time and membrane time constant).

```Python
import brian2 as br
import matplotlib.pyplot as plt
%matplotlib inline
import numpy as np
#standalone mode compiles the whole simulation to C++ before running it
br.set_device("cpp_standalone")

eqs = '''
dv/dt = (25-v)/tau : 1 (unless refractory)
'''

rst = """
v=Vr
"""

def run_network(params, N, eqs, rst, time=20000*br.ms):
    """
    Builds and runs the network with the specified parameters.
    args:
    params: np.array([tref, tau])
    N: number of neurons
    eqs: equations governing neuronal dynamics
    rst: what to do after a spike
    time: the time for which to run the network (in a Brian2-compatible format)
    returns:
    histogram of spike times
    """

    Vthr = 20 #spiking threshold
    Vr = 10 #reset after spike

    tref = params[0]*br.ms #refractory period
    tau = params[1]*br.ms #membrane time constant
    
    #define objects of the network
    G = br.NeuronGroup(N, model=eqs, threshold="v>Vthr", reset=rst, refractory=tref, method="euler")
    
    Spikes = br.SpikeMonitor(G)
    
    net = br.Network([G, Spikes])

    net.run(time)

    print("done with a run")
    
    spiketimes = Spikes.t/br.ms
    
    #reset the standalone device so the next run_network call builds a fresh project
    br.device.reinit()
    br.device.activate()
    
    #has to be a histogram because the number of spikes can vary between runs,
    #which would otherwise cause shape-broadcasting problems
    return np.histogram(spiketimes, int(10*time/br.ms))[0]

#define the simulator
from delfi.simulator.BaseSimulator import BaseSimulator

class NetworkSim(BaseSimulator):
    def __init__(self, N, eqs, rst, time=20000*br.ms, seed=None):
        dim_param = 2  #we fit two parameters: tref and tau
        
        super().__init__(dim_param=dim_param, seed=seed)
        self.N = N
        self.eqs = eqs
        self.rst = rst
        self.time = time
        self.run_network = run_network
        
    def gen_single(self, params):
        
        assert params.ndim == 1, "params must be a one-dimensional array"
        
        #network_seed = self.gen_newseed()
        states = self.run_network(params, self.N, self.eqs, self.rst,  self.time)
        
        return {"data" : states}

#create prior distributions
import delfi.distribution as dd
seed_p = 2
#range of [tref, tau]
prior_min = np.array([0, 0.1])
prior_max = np.array([20, 100])

prior = dd.Uniform(lower=prior_min, upper=prior_max,seed=seed_p)



#generate network
import delfi.generator as dg

m = NetworkSim(N=1, eqs=eqs, rst=rst, time = 40000*br.ms)

from delfi.summarystats.Identity import Identity
s = Identity()

foo = dg.Default(model = m, prior = prior, summary = s)

#define ground truth simulation
true_params = np.array([2, 20])
labels_params = ["tref", "tau"]

obs = m.gen_single(true_params)
obs_stats = s.calc([obs])
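#obs_stats is the "observed" summary statistic that SNPE will condition on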


#meta-parameters for SNPE
seed_inf = 1

pilot_samples = 10

#training schedule
n_train = 10
n_rounds = 2

#fitting setup
minibatch = 5
epochs = 10
val_frac = 0.05

#network setup
n_hiddens = [50,50]

#convenience
prior_norm = True

#MAF parameters
density = 'maf'
n_mades = 5         # number of MADEs
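#(a masked autoregressive flow stacks several MADE networks into one flexible density estimator)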


import delfi.inference as infer

#inference object
res = infer.SNPEC(foo,
                obs=obs_stats,
                n_hiddens=n_hiddens,
                seed=seed_inf,
                pilot_samples=pilot_samples,
                n_mades=n_mades,
                prior_norm=prior_norm,
                density=density)

#train
loglik, _, posterior = res.run(
                    n_train=n_train,
                    n_rounds=n_rounds,
                    minibatch=minibatch,
                    epochs=epochs,
                    silent_fail=False,
                    proposal='prior',
                    val_frac=val_frac,
                    verbose=True)

#plot the loss
fig = plt.figure(figsize=(15,5))
plt.plot(loglik[0]['loss'],lw=2)
plt.xlabel('iteration')
plt.ylabel('loss');

#plot posterior distribution
from delfi.utils.viz import samples_nd

prior_min = foo.prior.lower
prior_max = foo.prior.upper
prior_lims = np.concatenate((prior_min.reshape(-1,1),prior_max.reshape(-1,1)),axis=1)

posterior_samples = posterior[0].gen(1000)

###################
#colors
hex2rgb = lambda h: tuple(int(h[i:i+2], 16) for i in (0, 2, 4))

#RGB colors in [0, 255]
col = {}
col['GT']      = hex2rgb('30C05D')
col['SNPE']    = hex2rgb('2E7FE8')
col['SAMPLE1'] = hex2rgb('8D62BC')
col['SAMPLE2'] = hex2rgb('AF99EF')

#convert to RGB colors in [0, 1]
for k, v in col.items():
    col[k] = tuple([i/255 for i in v])

###################
#posterior
fig, axes = samples_nd(posterior_samples,
                       limits=prior_lims,
                       ticks=prior_lims,
                       labels=labels_params,
                       #fig_size=(5,5),
                       diag='kde',
                       upper='kde',
                       hist_diag={'bins': 50},
                       hist_offdiag={'bins': 50},
                       kde_diag={'bins': 50, 'color': col['SNPE']},
                       kde_offdiag={'bins': 50},
                       points=[true_params],
                       points_offdiag={'markersize': 5},
                       points_colors=[col['GT']],
                       title='')
```

Hi. That’s great, many thanks for sharing! For the notebook, I think the easiest would be to upload it as a “gist” on gist.github.com/ and then link to it here.

PS: I edited your post to change the formatting. You can get nicely highlighted code if you wrap it in:

```Python
# the python code
```

Thank you, it looks much cleaner now.

Here’s the gist: https://gist.github.com/TheSalocin/a52f493c2d2a90d7b9b73b4caf47f22e


Thanks again. Having the content is the most important thing, but note that you can also upload the actual .ipynb file as a gist; GitHub will then render it (including plots, etc.).

Super cool! I’m going to try this out myself as soon as I get a chance.

Hi, can I ask something?
I have already worked with another package from this group, SBI, which is somewhat simpler (inference in one line of code).

As far as I know, SBI uses the joblib library for multiprocessing, and I think that integrating the Brian simulator in standalone mode with SBI would require compiling the code with different parameters for each simulation during training, and sometimes we need thousands of simulations.

So the question is: is it possible to make this faster by avoiding the compilation time, apart from using NumPy mode?

```Python
import torch
from sbi.inference import infer  # high-level one-liner interface (recent sbi versions)

def simulation_wrapper(params):
    """
    Returns summary statistics from conductance values in `params`.

    Summarizes the output of the HH simulator and converts it to `torch.Tensor`.
    """
    obs = run_HH_model(params)   # Using Brian
    summstats = torch.as_tensor(calculate_summary_statistics(obs))
    return summstats

# ... some more code ...

posterior = infer(simulation_wrapper, prior, method='SNPE',
                  num_simulations=300, num_workers=4)  # uses joblib to parallelize the simulations
```

The example @mstimberg introduced would probably work here, but I haven’t tried it yet.

Hi @Ziaeemehr. When each individual simulation is short, the compilation overhead of the standalone mode is indeed probably not worth it. Running the same simulation repeatedly but with different parameters is something that we are planning to improve in Brian; right now it is unfortunately not straightforward.

If you are simulating/fitting single neurons and not networks, you can simulate many parameters at once within a single “network”. That’s the approach we are taking in the brian2modelfitting toolbox.

For a network, you can speed up repeated simulations with the trick you linked, but this only works for sequential simulations. If you run multiple standalone simulations in parallel, you have to be careful that they do not use the same output directory. Since this directory stores the results, independent processes would overwrite each other’s results, making a big mess! There is a way to make all this work efficiently (you could copy the compiled code to another directory before running it), but this would be quite a fragile and hacky solution. I might upload an example doing this nevertheless, since it has become an issue for many users.
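To make the single-neuron approach concrete, here is a rough sketch of the idea (this is not the brian2modelfitting API; the names and parameter ranges are made up for illustration). Each candidate parameter set becomes one neuron with its own per-neuron constants, so all candidates run in a single compiled simulation:

```Python
import brian2 as br
import numpy as np

# One neuron per candidate parameter set: tau and tref become per-neuron
# constants, so a single run (and a single compilation) covers all candidates
n_candidates = 100
eqs = '''
dv/dt = (25 - v)/tau : 1 (unless refractory)
tau : second (constant)
tref : second (constant)
'''
G = br.NeuronGroup(n_candidates, eqs, threshold='v > 20', reset='v = 10',
                   refractory='tref', method='euler')
G.tau = np.random.uniform(0.1, 100, n_candidates)*br.ms   # candidate time constants
G.tref = np.random.uniform(0, 20, n_candidates)*br.ms     # candidate refractory periods
spikes = br.SpikeMonitor(G)
br.run(1*br.second)
# spikes.spike_trains() maps each candidate's index to its spike times
```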

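As for the output-directory problem: a simple option (it does not save any compilation time, but it keeps parallel runs from trampling on each other) is to let every build go to its own temporary directory:

```Python
import brian2 as br
# directory=None makes the standalone device build into a fresh temporary
# directory, so parallel processes cannot overwrite each other's results
br.set_device('cpp_standalone', directory=None)
```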