Hello everyone,
I am currently trying to replicate NALSM from the paper “Increasing Liquid State Machine Performance with Edge-of-Chaos Dynamics Organized by Astrocyte-modulated Plasticity.” The authors propose an STDP rule with an adaptive depression learning rate to train the LSM (formula 7 in the figure) and implemented the model in TensorFlow.
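For reference, this is the plasticity rule exactly as I wrote it in my code below, where A_ (written A_{-} here) is the value that should come from the astro population. It is my simplified version and not necessarily identical to formula 7 in the paper:

```latex
\begin{aligned}
\frac{dT_{\mathrm{pre}}}{dt} &= -\frac{T_{\mathrm{pre}}}{\tau}, \qquad
\frac{dT_{\mathrm{post}}}{dt} = -\frac{T_{\mathrm{post}}}{\tau}, \qquad \tau = 10\,\mathrm{ms},\\
\text{presynaptic spike:}\quad T_{\mathrm{pre}} &\leftarrow T_{\mathrm{pre}} + 0.1, \qquad w \leftarrow w - A_{-}(t)\,T_{\mathrm{post}},\\
\text{postsynaptic spike:}\quad T_{\mathrm{post}} &\leftarrow T_{\mathrm{post}} + 0.1, \qquad w \leftarrow w + 0.15\,T_{\mathrm{pre}}.
\end{aligned}
```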
In simple terms, the global depression learning rate A_ is controlled by the numbers of input spikes and liquid spikes. I therefore tried to set up a virtual neuron population (astro) that receives connections from the input and liquid neurons and computes A_ (as shown in the figure below).
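To make clear what the astro population in my code is supposed to compute, here is a plain-Python sketch (no Brian2 involved). It assumes the values from `eqs_astro` and the two `on_pre` statements in my code: a 1 ms Euler step, a relaxation time of 100 ms towards 0.15, and a jump of -0.01 per input spike and +0.01 per liquid spike. Passing spike trains as sets of time-step indices is just a hypothetical input format chosen for this example.

```python
def simulate_astro(inp_spikes, liq_spikes, n_steps, dt=1.0, tau_a=100.0,
                   a_rest=0.15, step=0.01):
    """Euler integration of dA/dt = (a_rest - A) / tau_a with spike-driven jumps.

    inp_spikes / liq_spikes: sets of time-step indices at which the single
    input / liquid neuron fires (hypothetical format). Times are in ms.
    Returns the value of A after each step.
    """
    a = a_rest
    trace = []
    for t in range(n_steps):
        a += dt * (a_rest - a) / tau_a  # leak towards the resting value
        if t in inp_spikes:
            a -= step  # an input spike lowers A_
        if t in liq_spikes:
            a += step  # a liquid spike raises A_
        trace.append(a)
    return trace
```

For example, `simulate_astro({0}, set(), 2)` gives approximately `[0.14, 0.1401]`, and an input and a liquid spike in the same step cancel out.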
However, how can I make the synapses access (or link to) the A_ variable inside astro? Does Brian2 support this kind of implementation? My complete code is attached below. Thank you very much for any suggestions.
```python
from brian2 import (
    NeuronGroup,
    Synapses,
    Network,
    SpikeMonitor,
    StateMonitor,
    PoissonGroup
)
from brian2 import ms, Hz
from brian2 import defaultclock, collect
import matplotlib.pyplot as plt

defaultclock.dt = 1. * ms
tau = 10. * ms

# input and liquid neurons
neuron_inp = PoissonGroup(1, 400 * Hz)
neuron_liq = PoissonGroup(1, 400 * Hz)

# syn with STDP
eqs_synapse = '''
dTpre / dt = -Tpre / tau : 1 (event-driven)
dTpost / dt = -Tpost / tau : 1 (event-driven)
w : 1
Aastro : 1
'''
eqs_synapse_on_pre = '''
Tpre += 0.1
w -= Tpost * Aastro
'''
eqs_synapse_on_post = '''
Tpost += 0.1
w += Tpre * 0.15
'''
syns = Synapses(neuron_inp, neuron_liq, eqs_synapse,
                on_pre=eqs_synapse_on_pre, on_post=eqs_synapse_on_post)
syns.connect()
syns.w = 0.5
st_syns = StateMonitor(syns, ['Tpre', 'Tpost', 'w', 'Aastro'], record=True)

eqs_astro = '''
dAastro / dt = (-Aastro + 0.15) / (100. * ms) : 1
'''
G = NeuronGroup(1, eqs_astro, method='euler')
G.Aastro = 0.15
# syns.Aastro = linked_var(G, 'Aastro') # link
syn_inp = Synapses(neuron_inp, G, '', on_pre='Aastro -= 0.01')
syn_inp.connect()
syn_liq = Synapses(neuron_liq, G, '', on_pre='Aastro += 0.01')
syn_liq.connect()

sp_inp = SpikeMonitor(neuron_inp, record=True)
sp_liq = SpikeMonitor(neuron_liq, record=True)
st_asto = StateMonitor(G, ['Aastro'], record=True)

net = Network(collect())
net.run(21 * ms)

fig, axs = plt.subplots(4, 2, figsize=(7, 5), sharex='all')
axs[0, 0].plot(sp_inp.t / ms, sp_inp.i, '.k', label='input spikes')
axs[0, 0].legend()
axs[1, 0].plot(sp_liq.t / ms, sp_liq.i, '.k', label='liquid spikes')
axs[1, 0].legend()
axs[3, 0].plot(st_asto.t / ms, st_asto.Aastro[0], label='A_ from astro')
axs[3, 0].legend()
axs[0, 1].plot(st_syns.t / ms, st_syns.Tpre[0], label='Tpre')
axs[0, 1].legend()
axs[1, 1].plot(st_syns.t / ms, st_syns.Tpost[0], label='Tpost')
axs[1, 1].legend()
axs[2, 1].plot(st_syns.t / ms, st_syns.w[0], label='w')
axs[2, 1].legend()
axs[3, 1].plot(st_syns.t / ms, st_syns.Aastro[0], label='A_ from syn')
axs[3, 1].legend()
axs[3, 1].set(xlim=[-0.5, 20.5])
plt.tight_layout()
plt.show()
```
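To pin down the behaviour I am after, here is a plain-Python reference loop (again not Brian2 code) in which the synapse reads the *current* value of the global A_ whenever the input neuron spikes. It rests on assumptions that are my own choices and not necessarily Brian2's scheduling. Within one 1 ms step, the astro variable is updated before the synapse reads it, and a presynaptic spike is processed before a postsynaptic one. The traces decay exponentially between steps. The parameter values are the ones from my code above. A Brian2 version could be checked against this loop. One possible (unverified) workaround in Brian2's runtime mode would be a `network_operation` that copies `G.Aastro[0]` into `syns.Aastro` on every time step. As far as I know, network operations are not available in standalone mode.

```python
import math


def run_stdp_with_astro(inp_spikes, liq_spikes, n_steps, dt=1.0, tau=10.0,
                        tau_a=100.0, a_rest=0.15, a_plus=0.15, w0=0.5):
    """Single-synapse STDP whose depression amplitude is a shared global A_.

    inp_spikes / liq_spikes: sets of step indices (hypothetical format).
    Returns (list of w after each step, final value of A_).
    """
    t_pre = t_post = 0.0
    a = a_rest
    w = w0
    decay = math.exp(-dt / tau)  # exact exponential decay of the traces per step
    ws = []
    for t in range(n_steps):
        # 1) astro update (assumed to happen first): leak plus spike-driven jumps
        a += dt * (a_rest - a) / tau_a
        if t in inp_spikes:
            a -= 0.01
        if t in liq_spikes:
            a += 0.01
        # 2) STDP traces decay
        t_pre *= decay
        t_post *= decay
        # 3) presynaptic spike: depression uses the current global A_
        if t in inp_spikes:
            t_pre += 0.1
            w -= t_post * a
        # 4) postsynaptic spike: potentiation with a fixed amplitude
        if t in liq_spikes:
            t_post += 0.1
            w += t_pre * a_plus
        ws.append(w)
    return ws, a
```

With a liquid spike at step 0 and an input spike at step 5, `run_stdp_with_astro({5}, {0}, 6)` depresses `w` by `0.1 * exp(-0.5)` times the astro value at step 5, which has partly relaxed back from 0.16 and then dropped by 0.01.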