Description of problem
I am trying to implement a simple STDP rule with a bounding function between two neurons. The rule is as follows:
\frac{dw_{i,j}}{dt}=pF(w_{max}-w_{i,j})\exp(-\frac{t-t_{i}^{last}}{\tau_{p}})S_{j}(t)-dF(w_{i,j})\exp(-\frac{t-t_{j}^{last}}{\tau_{p}})S_{i}(t)
Here w_{i,j} denotes the synapse from the i-th to the j-th neuron. The function F(x)=\tanh(x/\mu) is meant to prevent unbounded growth of the synaptic weight. The functions S_{j}(t)=\sum_{n}\delta(t-t_{n}^{(j)}) are sums of delta functions at the spike times of the j-th neuron.
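To make the role of F concrete, here is a minimal plain-Python sketch (no Brian2 needed) of the two bounding factors in the rule. The helper names `potentiation_factor` and `depression_factor` are my own, and the defaults mu=0.02 and w_max=1 are simply the values used in the script further down.

```python
import math

def F(x, mu=0.02):
    """Bounding function F(x) = tanh(x / mu)."""
    return math.tanh(x / mu)

def potentiation_factor(w, w_max=1.0, mu=0.02):
    """Factor F(w_max - w) that multiplies the potentiation term."""
    return F(w_max - w, mu)

def depression_factor(w, mu=0.02):
    """Factor F(w) that multiplies the depression term."""
    return F(w, mu)

# Print both factors for a few weights between 0 and w_max
for w in (0.0, 0.01, 0.5, 0.99, 1.0):
    print(w, potentiation_factor(w), depression_factor(w))
```

With these values the factors are close to 1 over most of the range and drop to 0 only within a few mu of the corresponding bound, so potentiation switches off near w_max and depression switches off near 0.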
I had a hard time understanding the meaning of the ‘on_pre=’ and ‘on_post=’ statements in the definition of the synapses. Below, I write my thoughts and a simple program that implements the synapses defined by the equation above. Maybe somebody can confirm or correct it.
For the synapse w_{i,j}, the presynaptic neuron is the i-th and the postsynaptic neuron is the j-th. So at a spike of the i-th neuron, i.e. at a presynaptic spike, w_{i,j} should decrease. In other words, all synapses that go out from the spiking neuron should be weakened. Therefore, in the ‘on_pre=’ statement, we should write ‘w=w-d..’.
At a spike of the postsynaptic j-th neuron, the value of w_{i,j} should be increased. Therefore, in the ‘on_post=’ statement, we should write ‘w=w+p..’.
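To check this reasoning independently of Brian2, here is a rough plain-Python sketch of how I read the rule for a single synapse: each spike is handled as an event, and only the most recent spike of the other neuron contributes. The function name `simulate_pair` and its arguments are my own. Times are in seconds, and the two time constants follow the script below (taupre, taupost), although the equation writes \tau_{p} in both exponentials.

```python
import math

def simulate_pair(pre_spikes, post_spikes, w0, p=0.001, d=0.001,
                  w_max=1.0, mu=0.02, tau_pre=4e-3, tau_post=2e-3):
    """Apply the bounded STDP rule to one synapse w_ij.

    pre_spikes / post_spikes: spike times (s) of neuron i / neuron j.
    Returns the weight after all spikes have been processed.
    """
    events = sorted([(t, 'pre') for t in pre_spikes] +
                    [(t, 'post') for t in post_spikes])
    w = w0
    last_pre = None   # t_i^last
    last_post = None  # t_j^last
    for t, kind in events:
        if kind == 'pre':
            # Presynaptic spike (S_i): depress, weighted by the time
            # elapsed since the last postsynaptic spike.
            if last_post is not None:
                w -= d * math.tanh(w / mu) * math.exp(-(t - last_post) / tau_post)
            last_pre = t
        else:
            # Postsynaptic spike (S_j): potentiate, weighted by the time
            # elapsed since the last presynaptic spike.
            if last_pre is not None:
                w += p * math.tanh((w_max - w) / mu) * math.exp(-(t - last_pre) / tau_pre)
            last_post = t
    return w

# Pre fires 2 ms before post: the weight should grow
print(simulate_pair([0.0], [0.002], w0=0.5))
# Post fires 2 ms before pre: the weight should shrink
print(simulate_pair([0.002], [0.0], w0=0.5))
```

If my understanding is right, the pre branch here corresponds to the weight update in ‘on_pre=’ and the post branch to the one in ‘on_post=’.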
Minimal code to reproduce problem
from brian2 import *
import numpy as np
%matplotlib widget
start_scope()
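tau = 10*ms       # membrane time constant
g = 0.05          # jump in v_post per presynaptic spike, scaled by w
taupre = 4*ms     # decay time of the presynaptic trace apre
taupost = 2*ms    # decay time of the postsynaptic trace apost
wmax = 1          # upper bound for the weights
p = 0.001         # potentiation amplitude
d = p             # depression amplitude
muu = 0.02        # mu in F(x) = tanh(x/mu)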
eqs = '''
dv/dt = (a-v)/tau : 1
a :1
'''
G = NeuronGroup(2, eqs, threshold='v>1', reset='v = 0', method='exact')
G.a = [2, 1.8]
M = StateMonitor(G, 'v', record=[0,1])
Mspk = SpikeMonitor(G)
S = Synapses(G, G,
'''
w : 1
dapre/dt = -apre/taupre : 1 (event-driven)
dapost/dt = -apost/taupost : 1 (event-driven)
''',
on_pre='''
v_post +=g*w
apre = p
w -= tanh(w/muu)*apost
''',
on_post='''
apost = d
w += tanh((wmax-w)/muu)*apre
''')
S.connect('i!=j')
S.w[0,1]=0.09
S.w[1,0]=0.9
Mw = StateMonitor(S, 'w', record=True)
run(10000*ms)
figure(2).clear()
# Plot both weights against time in ms
step(Mw.t/ms, Mw[S[0,1]].w.reshape(-1,1))
step(Mw.t/ms, Mw[S[1,0]].w.reshape(-1,1))
legend(['$w_{01}$','$w_{10}$'])
show()