Hi everyone,
I am trying to implement a model with structural plasticity using Brian2. The model is a classical sparsely connected E-I network with two plastic projections, E to E and I to E. Since these connections are updated by STDP, the weight of a connection can of course vanish. In that case the connection should be pruned and a new random connection should be created in the network.
Description of problem / What I am currently using
I solved this problem by making the EE (and IE) connections all-to-all and adding a new variable representing the status of each synapse (active or inactive). Using the `network_operation` decorator, I check for weights below a predefined threshold. A synapse is pruned by setting its status to 0 while, at the same time, activating a new synapse.
I need to run many simulations of such networks with various plasticity rules, so runtime matters. This implementation is rather slow: it takes about 2 hours for 40 s of simulated time in a model with 400 E and 100 I neurons.
I also tried creating sparse connections from the beginning and then dynamically connecting new synapses inside the network operation. This did not raise any error, but it stopped synaptic updates entirely.
So my question is, is there a good way to implement this other than the way described above?
Minimal code to reproduce problem
# Shared network parameters
NE = 400  # number of excitatory neurons
NI = NE // 4  # number of inhibitory neurons
input_num = 40
input_freq = 30  # Hz
gmax = 20.0  # upper clip for plastic weights (dimensionless; scaled by nS when used)
lr = 1e-2  # STDP learning rate
epsilon = 0.1  # connection probability / initial fraction of active synapses
gl = 10 * nS  # leak conductance
er = -80 * mV  # GABA reversal potential
el = -60 * mV  # leak reversal potential, also initial and reset value of v
e_ampa = 0 * mV  # AMPA reversal potential (previously implicit as `g_ampa * v`)
tau_gaba = 10.0 * ms
tau_ampa = 5.0 * ms
vt = -50 * mV  # spike threshold
memc = 200 * pfarad  # membrane capacitance
# Neuron model: conductance-based leaky integrate-and-fire with AMPA and GABA
# conductances that decay exponentially.
eqs_neurons = '''
dv/dt = (-gl * (v - el) - (g_ampa * (v - e_ampa) + g_gaba * (v - er))) / memc : volt (unless refractory)
dg_ampa/dt = -g_ampa / tau_ampa : siemens
dg_gaba/dt = -g_gaba / tau_gaba : siemens
'''
# Neuron group: the first NE neurons are excitatory, the remaining NI inhibitory.
neurons = b2.NeuronGroup(NE + NI, model=eqs_neurons, threshold='v > vt',
                         reset='v=el', refractory=5 * ms, method='euler')
neurons.v = el
Pe = neurons[:NE]
Pi = neurons[NE:]
# EE plasticity parameters
ee_alpha_pre = 0.6
ee_alpha_post = -0.5
ee_Aplus = 3
ee_tauplus_stdp = 5 * ms
ee_tauminus_stdp = 20 * ms
ee_Aminus = -1.0
# Synapse model with a `syn_status` flag (1 = active, 0 = inactive).
# NOTE: the STDP update of `w` in on_pre/on_post is not gated by `syn_status`,
# so inactive synapses keep adapting their weight; only the g_ampa increment
# is multiplied by `syn_status`.
synapse_model = '''
w : 1 # Weight
syn_status : integer # Active or inactive status (1 or 0)
dee_trace_pre_plus/dt = -ee_trace_pre_plus / ee_tauplus_stdp : 1 (event-driven)
dee_trace_pre_minus/dt = -ee_trace_pre_minus / ee_tauminus_stdp : 1 (event-driven)
dee_trace_post_plus/dt = -ee_trace_post_plus / ee_tauplus_stdp : 1 (event-driven)
dee_trace_post_minus/dt = -ee_trace_post_minus / ee_tauminus_stdp : 1 (event-driven)
'''
# E->E synapses: created all-to-all (except self-connections); `syn_status`
# selects which of them currently belong to the sparse network.
con_ee = b2.Synapses(Pe, Pe, model=synapse_model,
                     on_pre='''
g_ampa += w * nS * syn_status # Only contribute if active
ee_trace_pre_plus += 1.0
ee_trace_pre_minus += 1.0
w = clip(w + lr * (ee_alpha_pre + ee_Aplus * ee_trace_post_plus +ee_Aminus * ee_trace_post_minus), 0, gmax)
''',
                     on_post='''
ee_trace_post_plus += 1
ee_trace_post_minus += 1
w = clip(w + lr * (ee_alpha_post + ee_Aplus * ee_trace_pre_plus + ee_Aminus * ee_trace_pre_minus), 0, gmax)
''')
con_ee.connect(condition='i != j')  # Fully connect, except self-connections
con_ee.w = np.random.uniform(low=0.1, high=0.2, size=len(con_ee.w))
# Introduce initial sparsity: each synapse starts active with probability
# epsilon. Integer values match the `integer` type declared for `syn_status`.
con_ee.syn_status = np.random.choice([0, 1], size=len(con_ee.w), p=[1 - epsilon, epsilon])
# EI and II synapses (static weights, random sparse connectivity)
con_ei = b2.Synapses(Pe, Pi, model='w : 1', on_pre="g_ampa += w * nS")
con_ii = b2.Synapses(Pi, Pi, model='w : 1', on_pre="g_gaba += w * nS")
# E->I links two different populations, so `i != j` would not exclude
# self-connections; it would only drop the arbitrary pairs E_k -> I_k.
con_ei.connect(p=epsilon)
# I->I is recurrent: here `i != j` does exclude self-connections.
con_ii.connect(p=epsilon, condition='i != j')
con_ei.w = np.random.uniform(low=0.4, high=0.8, size=len(con_ei.w))
con_ii.w = np.random.uniform(low=0.4, high=0.8, size=len(con_ii.w))
# IE Plasticity parameters and model
ie_alpha_pre = 0
ie_alpha_post = 0
ie_Aplus = 0.5
ie_tauplus_stdp = 15 * ms
ie_tauminus_stdp = 3 * ms
ie_Aminus = -1.0
# Synapse model with a `syn_status` flag (1 = active, 0 = inactive).
# NOTE: as for E->E, the STDP update of `w` is not gated by `syn_status`;
# only the g_gaba increment is multiplied by it.
synapse_model = '''
w : 1 # Weight
syn_status : integer # Active or inactive status (1 or 0)
die_trace_pre_plus/dt = -ie_trace_pre_plus / ie_tauplus_stdp : 1 (event-driven)
die_trace_pre_minus/dt = -ie_trace_pre_minus / ie_tauminus_stdp : 1 (event-driven)
die_trace_post_plus/dt = -ie_trace_post_plus / ie_tauplus_stdp : 1 (event-driven)
die_trace_post_minus/dt = -ie_trace_post_minus / ie_tauminus_stdp : 1 (event-driven)
'''
con_ie = b2.Synapses(Pi, Pe, model=synapse_model,
                     on_pre='''
g_gaba += w*nS * syn_status
ie_trace_pre_plus += 1.0
ie_trace_pre_minus += 1.0
w = clip(w + lr * (ie_alpha_pre +ie_Aplus * ie_trace_post_plus +ie_Aminus * ie_trace_post_minus), 0, gmax)
''',
                     on_post='''
ie_trace_post_plus += 1
ie_trace_post_minus += 1
w = clip(w + lr * (ie_alpha_post + ie_Aplus * ie_trace_pre_plus + ie_Aminus * ie_trace_pre_minus), 0, gmax)
'''
                     )
# Fully connect I->E. Source and target are different populations, so there
# are no self-connections to exclude (an `i != j` condition would only drop
# the arbitrary pairs I_k -> E_k).
con_ie.connect()
con_ie.w = np.random.uniform(low=0.8, high=1.0, size=len(con_ie.w))
# Initial sparsity: each synapse starts active with probability epsilon.
# Integer values match the `integer` type declared for `syn_status`.
con_ie.syn_status = np.random.choice([0, 1], size=len(con_ie.w), p=[1 - epsilon, epsilon])
# Input group: independent Poisson drive for the E and the I population.
P = b2.PoissonGroup(500, input_freq * Hz)
S = b2.Synapses(P, Pe, on_pre="g_ampa += 0.4 * nS")
S.connect(p=0.05)
P2 = b2.PoissonGroup(500, input_freq * Hz)
# Driven by P2 (not P): otherwise P2 is unused and both populations would
# receive the same input spike trains.
S2 = b2.Synapses(P2, Pi, on_pre="g_ampa += 0.4 * nS")
S2.connect(p=0.05)
# Weight below which an active synapse counts as vanished and is pruned.
threshold = 0.005


@b2.network_operation(dt=50 * ms)
def structural_plasticity():
    """Prune vanished synapses and activate random replacements.

    Runs every 50 ms of simulated time. For each plastic projection
    (``con_ee`` and ``con_ie``), every active synapse (``syn_status == 1``)
    whose weight ``w`` has dropped below ``threshold`` is deactivated. For
    each pruned synapse, one inactive synapse with the same presynaptic
    neuron is activated and its weight reset (0.15 for E->E, 0.9 for I->E).
    The just-pruned synapses are part of the candidate pool, so a pruned
    synapse may be re-selected. In that case it is simply reset.

    State is read from and written back to Brian2 in bulk. Accessing
    ``synapse.w[i]`` element by element from a Python loop goes through
    Brian2's variable indexing on every call, which is slow for the
    all-to-all projections (NE * (NE - 1) synapses for E->E).
    """
    for synapse, w_reset in ((con_ee, 0.15), (con_ie, 0.9)):
        w = synapse.w[:]
        status = synapse.syn_status[:]
        pre = synapse.i[:]

        weak = (w < threshold) & (status == 1)
        if not weak.any():
            continue
        status[weak] = 0  # Deactivate weak connections

        # Regrow per presynaptic neuron so each neuron keeps its number of
        # active outgoing synapses.
        sources, n_pruned = np.unique(pre[weak], return_counts=True)
        for source, count in zip(sources, n_pruned):
            # Includes the `count` synapses just pruned from this source, so
            # there are always at least `count` candidates.
            candidates = np.flatnonzero((pre == source) & (status == 0))
            chosen = np.random.choice(candidates, size=count, replace=False)
            status[chosen] = 1
            w[chosen] = w_reset

        synapse.syn_status[:] = status
        synapse.w[:] = w
# Spike monitors for the excitatory and inhibitory populations; they must
# exist before being passed to the Network below.
MPe = b2.SpikeMonitor(Pe)
MPi = b2.SpikeMonitor(Pi)
network = b2.Network(neurons, P, S, P2, S2, con_ee, con_ei, con_ii, con_ie, MPe, MPi)
network.add(structural_plasticity)
network.run(5 * second, report='text')
Thanks in advance,
Jan