Implementation of Structural Plasticity

Hi everyone,
I am trying to implement a structural plasticity model using Brian2. The model is a classical sparsely connected E-I model with two plastic connection types, E to E and I to E. Since these connections are updated by STDP, it can of course happen that the weight of a connection vanishes. In this case the connection should be pruned and a new random connection should be created somewhere in the network.

Description of problem / What I am currently using

I solved this problem by making the plastic connections fully connected and adding a new variable representing the status of each synapse. Using the network_operation decorator, I check for weights below a predefined threshold. Pruning a synapse is performed by setting its status to 0 and, at the same time, activating a new synapse.

I need to run a lot of simulations of such networks with various plasticity rules, so runtime is quite important. This implementation is rather slow, taking about 2 hours for 40 s of simulated time in a model of 400 E and 100 I neurons.

I also tried making the connections sparse from the beginning and then dynamically connecting new synapses inside the network_operation. While this did not raise any error, it stopped synaptic updates entirely.

So my question is, is there a good way to implement this other than the way described above?

Minimal code to reproduce problem

"""Shared network parameters"""
NE = 400
NI = NE // 4
input_num = 40
input_freq = 30  # Hz
gmax = 20.0
lr = 1e-2
epsilon = 0.1
gl = 10 * nS
er = -80 * mV
el = -60 * mV
tau_gaba = 10.0 * ms
tau_ampa = 5.0 * ms
vt = -50 * mV
memc = 200 * pfarad

# Neuron model
eqs_neurons = '''
    dv/dt = (-gl * (v - el) - (g_ampa * v + g_gaba * (v - er))) / memc : volt (unless refractory)
    dg_ampa/dt = -g_ampa / tau_ampa : siemens
    dg_gaba/dt = -g_gaba / tau_gaba : siemens
'''

# Neuron group
neurons = b2.NeuronGroup(NE + NI, model=eqs_neurons, threshold='v > vt',
                         reset='v=el', refractory=5 * ms, method='euler')
neurons.v = el
Pe = neurons[:NE]
Pi = neurons[NE:]

# EE plasticity parameters
ee_alpha_pre = 0.6
ee_alpha_post = -0.5
ee_Aplus = 3
ee_tauplus_stdp = 5 * ms
ee_tauminus_stdp = 20 * ms
ee_Aminus = -1.0

# Synapse model with a `syn_status` flag
synapse_model = '''
    w : 1  # Weight
    syn_status : integer  # Active or inactive status (1 or 0)
    dee_trace_pre_plus/dt = -ee_trace_pre_plus / ee_tauplus_stdp : 1 (event-driven)
    dee_trace_pre_minus/dt = -ee_trace_pre_minus / ee_tauminus_stdp : 1 (event-driven)
    dee_trace_post_plus/dt = -ee_trace_post_plus / ee_tauplus_stdp : 1 (event-driven)
    dee_trace_post_minus/dt = -ee_trace_post_minus / ee_tauminus_stdp : 1 (event-driven)
'''

# Define EE synapses; sparsity is emulated via the `syn_status` flag
con_ee = b2.Synapses(Pe, Pe, model=synapse_model,
                     on_pre='''
                         g_ampa += w * nS * syn_status  # Only contribute if active
                         ee_trace_pre_plus += 1.0
                         ee_trace_pre_minus += 1.0
                         w = clip(w + lr * (ee_alpha_pre + ee_Aplus * ee_trace_post_plus + ee_Aminus * ee_trace_post_minus), 0, gmax)
                     ''',
                     on_post='''
                         ee_trace_post_plus += 1
                         ee_trace_post_minus += 1
                         w = clip(w + lr * (ee_alpha_post + ee_Aplus * ee_trace_pre_plus + ee_Aminus * ee_trace_pre_minus), 0, gmax)
                     ''')

con_ee.connect(condition='i != j')  # Fully connect, except self-connections
con_ee.w = np.random.uniform(low=0.1, high=0.2, size=len(con_ee.w))
con_ee.syn_status = np.random.choice([0, 1], size=len(con_ee.w), p=[1 - epsilon, epsilon])  # Introduce initial sparsity
# EI and II synapses
con_ei = b2.Synapses(Pe, Pi, model='w : 1', on_pre="g_ampa += w * nS")
con_ii = b2.Synapses(Pi, Pi, model='w : 1', on_pre="g_gaba += w * nS")
con_ei.connect(p=epsilon, condition='i != j')
con_ii.connect(p=epsilon, condition='i != j')
con_ei.w = np.random.uniform(low=0.4, high=0.8, size=len(con_ei.w))
con_ii.w = np.random.uniform(low=0.4, high=0.8, size=len(con_ii.w))

# IE Plasticity parameters and model
ie_alpha_pre = 0
ie_alpha_post = 0
ie_Aplus = 0.5
ie_tauplus_stdp = 15 * ms
ie_tauminus_stdp = 3 * ms
ie_Aminus = -1.0


synapse_model = '''
    w : 1  # Weight
    syn_status : integer  # Active or inactive status (1 or 0)
    die_trace_pre_plus/dt = -ie_trace_pre_plus / ie_tauplus_stdp : 1 (event-driven)
    die_trace_pre_minus/dt = -ie_trace_pre_minus / ie_tauminus_stdp : 1 (event-driven)
    die_trace_post_plus/dt = -ie_trace_post_plus / ie_tauplus_stdp : 1 (event-driven)
    die_trace_post_minus/dt = -ie_trace_post_minus / ie_tauminus_stdp : 1 (event-driven)
'''
con_ie = b2.Synapses(Pi, Pe, model=synapse_model,
                     on_pre='''
                         g_gaba += w * nS * syn_status  # Only contribute if active
                         ie_trace_pre_plus += 1.0
                         ie_trace_pre_minus += 1.0
                         w = clip(w + lr * (ie_alpha_pre + ie_Aplus * ie_trace_post_plus + ie_Aminus * ie_trace_post_minus), 0, gmax)
                     ''',
                     on_post='''
                         ie_trace_post_plus += 1
                         ie_trace_post_minus += 1
                         w = clip(w + lr * (ie_alpha_post + ie_Aplus * ie_trace_pre_plus + ie_Aminus * ie_trace_pre_minus), 0, gmax)
                     ''')
con_ie.connect(condition='i != j')  # Fully connect, except self-connections
con_ie.w = np.random.uniform(low=0.8, high=1.0, size=len(con_ie.w))
con_ie.syn_status = np.random.choice([0, 1], size=len(con_ie.w), p=[1 - epsilon, epsilon])

# Input group
P = b2.PoissonGroup(500, 30 * Hz)
S = b2.Synapses(P, Pe, on_pre="g_ampa += 0.4 * nS")
S.connect(p=0.05)
P2 = b2.PoissonGroup(500, 30 * Hz)
S2 = b2.Synapses(P2, Pi, on_pre="g_ampa += 0.4 * nS")
S2.connect(p=0.05)
# Threshold
threshold = 0.005


@b2.network_operation(dt=50 * ms)
def structural_plasticity():
    for k,synapse in enumerate([con_ee, con_ie]):
        source, target = synapse.i[:], synapse.j[:]
        for i in range(len(source)):
            if synapse.w[i] < threshold and synapse.syn_status[i] == 1:
                synapse.syn_status[i] = 0  # Deactivate weak connection
                inactive_indices = np.where((synapse.i == source[i]) & (synapse.syn_status == 0))[0]
                if inactive_indices.size > 0:
                    new_index = np.random.choice(inactive_indices)
                    synapse.syn_status[new_index] = 1
                    if k == 0:
                        synapse.w[new_index] = 0.15
                    elif k == 1:
                        synapse.w[new_index] = 0.9

# Spike monitors (referenced in the Network below)
MPe = b2.SpikeMonitor(Pe)
MPi = b2.SpikeMonitor(Pi)
network = b2.Network(neurons, P, S, P2, S2, con_ee, con_ei, con_ii, con_ie, MPe, MPi)
network.add(structural_plasticity)
network.run(5 * second, report='text')

Thanks in advance,
Jan

Hi @JanHuehne, great question! Optimizing a network_operation is not always obvious, and we should certainly document it better. But first, to get one thing out of the way: dynamically connecting new synapses during a run, as in your second attempt, cannot work.

While network_operation has full access to the state variables of a network (and in principle also to its internal variables), as a rule of thumb it does not work for anything that affects the structure of the network (adding synapses, replacing/removing/adding objects in the network, etc.). The main reason is that some preparation is done during the before_run phase; in particular, this sets up the data structures for synaptic propagation. These are no longer up to date if you add synapses after the start of a run, and things will break. In principle it would be possible to also update the internal state of the network, but this would be quite fiddly and could break with future updates to the Brian simulator.
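
For illustration, this is the kind of pattern that will not work (a sketch, not code from this thread; connect() itself is a regular Synapses method, the problem is calling it mid-run):

@b2.network_operation(dt=50 * ms)
def broken_rewiring():
    # The propagation data structures were prepared during before_run and
    # are not updated here, so this synapse will never transmit spikes
    # (or the simulation breaks in less obvious ways)
    con_ee.connect(i=0, j=1)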

Coming back to your original solution: I think it is the best general approach for now. At first, I wasn't 100% sure whether the network_operation itself is really the issue (since it is only executed every 50 ms), or whether the problem is performing all these "null operations" for the non-existing synapses. It turns out that in a small network like yours, it really is the network_operation that slows things down.
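
For reference, such a profile can be obtained like this (a sketch, using the network object from your script):

network.run(5 * second, report='text', profile=True)
print(b2.profiling_summary(network))

Here's the beginning of the resulting output: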

Profiling summary
=================
networkoperation                    120.15 s    97.24 %
synapses_post                         0.48 s     0.39 %
poissongroup_1_spike_thresholder      0.36 s     0.29 %
synapses_pre                          0.32 s     0.26 %
synapses_3_pre                        0.32 s     0.26 %
...

Ok, there’s definitely room for optimization here :laughing: When I profile the network_operation line-by-line (with line_profiler), I get:

Timer unit: 1e-06 s

Total time: 260.393 s
File: /home/mstimberg/scratch/structual_plasticity.py
Function: structural_plasticity at line 124

Line #      Hits         Time  Per Hit   % Time  Line Contents
==============================================================
   124                                           @b2.network_operation(dt=50 * ms)
   125                                           @profile
   126                                           def structural_plasticity():
   127       300       1049.5      3.5      0.0      for k,synapse in enumerate([con_ee, con_ie]):
   128       200     776965.0   3884.8      0.3          source, target = synapse.i[:], synapse.j[:]
   129  19950200    2788742.0      0.1      1.1          for i in range(len(source)):
   130  19950000  256151272.9     12.8     98.4              if synapse.w[i] < threshold and synapse.syn_status[i] == 1:
   131      1439     100721.4     70.0      0.0                  synapse.syn_status[i] = 0  # Deactivate weak connection
   132      1439     361308.1    251.1      0.1                  inactive_indices = np.where((synapse.i == source[i]) & (synapse.syn_status == 0))[0]
   133      1439       1574.8      1.1      0.0                  if inactive_indices.size > 0:
   134      1439      82969.4     57.7      0.0                      new_index = np.random.choice(inactive_indices)
   135      1439      84789.0     58.9      0.0                      synapse.syn_status[new_index] = 1
   136      1439        370.7      0.3      0.0                      if k == 0:
   137      1439      43198.3     30.0      0.0                          synapse.w[new_index] = 0.15
   138                                                               elif k == 1:
   139                                                                   synapse.w[new_index] = 0.9

The problem here is that you are looping over every synapse and then accessing its weight via synapse.w[i] (and the same for syn_status). Every such access goes through Brian's very general indexing machinery, which allows for 1- or 2-dimensional indexing, string expressions, etc. This is not that slow in general, but doing it for hundreds of thousands of synapses quickly adds up. It is quite a bit faster to get all synaptic weights at once and then index the resulting numpy array.
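
In isolation, the change looks like this (a sketch using the names from your code):

# Slow: every element access goes through Brian's indexing machinery
#     if synapse.w[i] < threshold and synapse.syn_status[i] == 1: ...

# Faster: copy the synaptic state once per call into plain numpy arrays
weights = synapse.w[:]
syn_status = synapse.syn_status[:]
for i in range(len(weights)):
    if weights[i] < threshold and syn_status[i] == 1:
        ...  # pruning/creation as before

With only this change, the line-by-line profile becomes: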

Timer unit: 1e-06 s

Total time: 9.11284 s
File: /home/mstimberg/scratch/structual_plasticity.py
Function: structural_plasticity at line 124

Line #      Hits         Time  Per Hit   % Time  Line Contents
==============================================================
   124                                           @b2.network_operation(dt=50 * ms)
   125                                           @profile
   126                                           def structural_plasticity():
   127       300        540.8      1.8      0.0      for k,synapse in enumerate([con_ee, con_ie]):
   128       200     726285.9   3631.4      8.0          source, target = synapse.i[:], synapse.j[:]
   129       200       2539.3     12.7      0.0          weights = synapse.w[:]
   130       200       2343.7     11.7      0.0          syn_status = synapse.syn_status[:]
   131  19950200    2805344.1      0.1     30.8          for i in range(len(source)):
   132  19950000    5226548.2      0.3     57.4              if weights[i] < threshold and syn_status[i] == 1:
   133      1726      56821.5     32.9      0.6                  synapse.syn_status[i] = 0  # Deactivate weak connection
   134      1726     152672.5     88.5      1.7                  inactive_indices = np.where((source == source[i]) & (synapse.syn_status == 0))[0]
   135      1726        768.9      0.4      0.0                  if inactive_indices.size > 0:
   136      1726      30496.6     17.7      0.3                      new_index = np.random.choice(inactive_indices)
   137      1726      61363.2     35.6      0.7                      synapse.syn_status[new_index] = 1
   138      1726        366.4      0.2      0.0                      if k == 0:
   139      1726      46750.2     27.1      0.5                          synapse.w[new_index] = 0.15
   140                                                               elif k == 1:
   141                                                                   synapse.w[new_index] = 0.9

This simple change reduces the total runtime from a few minutes (note that things run slower while the profiler is enabled) to a few seconds! But we can actually do quite a bit better: instead of looping over each synapse, we can loop over each source neuron and handle all synapses coming from that neuron with vectorized numpy operations. We also do not need to search for the synapses coming from a given neuron, since all synaptic arrays are ordered by the source index. E.g. for the 400 × 399 connections from E to E (no self-connections), the first 399 entries in synapse.w are the synapses coming from neuron 0, the next 399 entries those coming from neuron 1, and so on. Taking advantage of this in code leads us to:

@b2.network_operation(dt=50 * ms)
def structural_plasticity():
    for k,synapse in enumerate([con_ee, con_ie]):
        weights = synapse.w[:]
        syn_status = synapse.syn_status[:]
        n_targets = len(synapse.target) - 1  # no self-connections
        for source_index in range(len(synapse.source)):            
            # We do not need to search for the source index, synaptic variables are ordered by it
            source_weights = weights[n_targets*source_index:n_targets*(source_index+1)]
            source_syn_status = syn_status[n_targets*source_index:n_targets*(source_index+1)]
            to_update_indices = np.nonzero((source_weights < threshold) & (source_syn_status == 1))[0]
            if len(to_update_indices):
                synapse.syn_status[n_targets*source_index + to_update_indices] = 0  # Deactivate weak connection
                inactive_indices = np.nonzero(source_syn_status == 0)[0]
                new_indices = np.random.choice(inactive_indices, len(to_update_indices), replace=False)
                synapse.syn_status[n_targets*source_index + new_indices] = 1
                if k == 0:
                    synapse.w[n_targets*source_index + new_indices] = 0.15
                elif k == 1:
                    synapse.w[n_targets*source_index + new_indices] = 0.9
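
Since this relies on the ordering assumption, a cheap sanity check before the run does not hurt (an untested sketch for con_ee; the analogous check for con_ie uses NI source neurons):

# Each of the NE source neurons should have exactly NE - 1 outgoing
# synapses, stored contiguously and ordered by the source index
assert np.array_equal(con_ee.i[:], np.repeat(np.arange(NE), NE - 1))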

These changes bring the time spent in network_operation down to 0.3s for a 5s simulation, which sounds much more reasonable :blush: On my machine, this makes the simulation faster than realtime, so a 40s simulation should take well under a minute. Please make sure that the code above is actually correct, though. I only verified that it runs without raising an error, but not that it does the same thing as the previous code…

Hope that gets you going

Amazing, thanks a lot. This definitely speeds things up. I am now thinking of extending this type of structural plasticity such that a new synapse does not necessarily have to share its source neuron with the pruned one; instead, a completely random connection can be added anywhere in the network.

@b2.network_operation(dt=50 * ms)
def structural_plasticity():
    for k,synapse in enumerate([con_ee, con_ie]):
        weights = synapse.w[:]
        syn_status = synapse.syn_status[:]
        n_targets = len(synapse.target) - 1  # no self-connections
        for source_index in range(len(synapse.source)):            
            # We do not need to search for the source index, synaptic variables are ordered by it
            source_weights = weights[n_targets*source_index:n_targets*(source_index+1)]
            source_syn_status = syn_status[n_targets*source_index:n_targets*(source_index+1)]
            to_update_indices = np.nonzero((source_weights < threshold) & (source_syn_status == 1))[0]
            if len(to_update_indices):
                synapse.syn_status[n_targets*source_index + to_update_indices] = 0  # Deactivate weak connections
                # We might choose another source neuron to create a new connection
                # (draw from the actual source population, i.e. NE or NI neurons)
                new_sources = np.random.choice(len(synapse.source), len(to_update_indices), replace=True)
                new_sources, numbers = np.unique(new_sources, return_counts=True)
                for new_source, number in zip(new_sources, numbers):
                    new_source_syn_status = syn_status[n_targets*new_source:n_targets*(new_source+1)]
                    inactive_indices = np.nonzero(new_source_syn_status == 0)[0]
                    new_indices = np.random.choice(inactive_indices, number, replace=False)
                    synapse.syn_status[n_targets*new_source + new_indices] = 1
                    if k == 0:
                        synapse.w[n_targets*new_source + new_indices] = 0.15
                    elif k == 1:
                        synapse.w[n_targets*new_source + new_indices] = 0.9

I still don't think this is the optimal implementation. I guess an alternative would be to first gather the number of pruned synapses using batched numpy operations, then sample the new source neurons (allowing repeated draws), and finally use batched processing again to create the new connections for each sampled source neuron. The problem is that one does not want to disable pruned synapses in the first pass only to reactivate them by chance in the second pass.

Hi @JanHuehne. Actually, if there is no restriction on the source/target neuron, i.e. pruning a synapse can give rise to a new synapse anywhere, the code becomes quite a lot simpler, since you do not need to loop at all. You can determine all the synapses to prune in one batched operation, and do the same for the creation of the new synapses. The trick to avoid pruning and re-creating a synapse in the same operation is to first collect the indices of the synapses to prune, but to delay the actual pruning until after the new synapses have been chosen. This way, the to-be-pruned synapses are still marked as active and will not be considered by the synapse-generation step. Something like this should work (just writing down the code, I did not test it):

@b2.network_operation(dt=50 * ms)
def structural_plasticity():
    for k,synapse in enumerate([con_ee, con_ie]):
        weights = synapse.w[:]
        syn_status = synapse.syn_status[:]
        to_remove_indices = np.nonzero((weights < threshold) & (syn_status == 1))[0]
        if len(to_remove_indices):
            inactive_indices = np.nonzero(syn_status == 0)[0]
            new_indices = np.random.choice(inactive_indices, len(to_remove_indices), replace=False)
            # Prune + generate
            synapse.syn_status[to_remove_indices] = 0
            synapse.syn_status[new_indices] = 1
            if k == 0:
                synapse.w[new_indices] = 0.15
            elif k == 1:
                synapse.w[new_indices] = 0.9
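
Since every pruned synapse is replaced by exactly one new one, the number of active synapses should stay constant; a quick (untested) way to check that invariant during a test run:

n_active_before = int(np.sum(con_ee.syn_status[:]))
network.run(1 * second)
assert int(np.sum(con_ee.syn_status[:])) == n_active_before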

Hope that works for you!