I am trying to recreate this paper using Brian, but although I believe I have succeeded in my implementation the results are poor to say the least. Am I missing something? By setting manually half the weights in W
to 0.5 and the other half to -0.5 the results are very good, so by finding them through a minimization procedure one would expect the results to be at least as good.
# =================================
# =================================
# Encoding And Decoding Of Analog Signals
from brian2 import *
from brian2.core.functions import timestep
from scipy.optimize import minimize
from sklearn.metrics import mean_squared_error
start_scope()
# =================================
# =================================
# Define the model
# The number of half the neurons, used globally many times in loops.
n = 10
# Assuming resting membrane potential 0.
# The parameter b is the bias.
neuron_equation_on = '''
dv/dt = (R*I-v)/tau : 1 (unless refractory)
I = 3*sin(2*pi*f*t)+b : 1
f = 5*Hz: Hz
b : 1
R : 1
tau : second
'''
group_of_neurons_on = NeuronGroup(
N = n,
model = neuron_equation_on,
threshold = 'v>1',
reset = 'v=0',
refractory = 2*msecond,
method = 'euler')
group_of_neurons_on.R = [1 for i in range(n)]
group_of_neurons_on.tau = [2 for i in range(n)]*msecond
group_of_neurons_on.b = [random() for i in range(n)]
neuron_equation_off = '''
dv/dt = (R*I-v)/tau : 1 (unless refractory)
I = -3*sin(2*pi*f*t)+b : 1
f = 5*Hz: Hz
b : 1
R : 1
tau : second
'''
group_of_neurons_off = NeuronGroup(
N = n,
model = neuron_equation_off,
threshold = 'v>1',
reset = 'v=0',
refractory = 2*msecond,
method = 'euler')
group_of_neurons_off.R = [1 for i in range(n)]
group_of_neurons_off.tau = [2 for i in range(n)]*msecond
group_of_neurons_off.b = [random() for i in range(n)]
# =================================
# =================================
# Simulation
neuron_voltage_on = StateMonitor(
group_of_neurons_on,
variables='v',
record=True)
neuron_current_on = StateMonitor(
group_of_neurons_on,
variables='I',
record=True)
neuron_voltage_off = StateMonitor(
group_of_neurons_off,
variables='v',
record=True)
neuron_current_off = StateMonitor(
group_of_neurons_off,
variables='I',
record=True)
neuron_spikes_on = SpikeMonitor(group_of_neurons_on)
neuron_spikes_off = SpikeMonitor(group_of_neurons_off)
simulation_time = 1000*msecond
run(simulation_time, report = 'text')
simulation_timesteps = np.arange(0, simulation_time, defaultclock.dt)
simulation_timesteps = simulation_timesteps/1000
# =================================
# =================================
# Create spike trains
neuron_spike_train_on = np.zeros((n, timestep(simulation_time, defaultclock.dt)), dtype=int)
neuron_spike_train_off = np.zeros((n, timestep(simulation_time, defaultclock.dt)), dtype=int)
# Spike train for the ON neuron
neuron_spike_train_on[neuron_spikes_on.i, timestep(neuron_spikes_on.t, defaultclock.dt)] = 1
# Spike train for the OFF neuron
neuron_spike_train_off[neuron_spikes_off.i, timestep(neuron_spikes_off.t, defaultclock.dt)] = 1
# Check the validity of the created spike trains
# for i in range(n):
# print("---------------------------")
# print("Neuron pair number", i)
# print("---------------------------")
# print("Provided spikes of ON neurons:", neuron_spikes_on.count[i])
# print("Calculated spikes of ON neurons:", sum(neuron_spike_train_on[i]))
# print("Provided spikes of OFF neurons:", neuron_spikes_off.count[i])
# print("Calculated spikes of OFF neurons:", sum(neuron_spike_train_off[i]))
# =================================
# =================================
# Decoding
# Find optimal weights
# Filter with transfer function H = k/(Ts+1)
neuron_filter = 5*exp(-50*(simulation_timesteps)/msecond)
# The convolution later on produces output of length 2n-1
filtered_spikes_full_on = np.zeros((n, 2*len(simulation_timesteps)-1), dtype=float)
filtered_spikes_full_off = np.zeros((n, 2*len(simulation_timesteps)-1), dtype=float)
filtered_spikes_trimmed_on = np.zeros((n, len(simulation_timesteps)), dtype=float)
filtered_spikes_trimmed_off = np.zeros((n, len(simulation_timesteps)), dtype=float)
for i in range(n):
# Filter the spike trains from both on and off neuron groups.
filtered_spikes_full_on[i] = convolve(neuron_filter, neuron_spike_train_on[i])
filtered_spikes_full_off[i] = convolve(neuron_filter, neuron_spike_train_off[i])
# Trim the spike trains.
filtered_spikes_trimmed_on[i] = filtered_spikes_full_on[i][:len(simulation_timesteps)]
filtered_spikes_trimmed_off[i] = filtered_spikes_full_off[i][:len(simulation_timesteps)]
# Reconstruct the initial signal
t = linspace(0, 1, len(simulation_timesteps))
input_signal = 3*sin(2*pi*5*t)
W = np.zeros(2*n, dtype=float)
W0 = np.zeros(2*n, dtype=float)
for i in range(n):
W0[i] = 0.5
W0[n+i] = -0.5
def func(W):
reconstructed_signal = [0 for i in range(len(simulation_timesteps))]
for i in range(n):
part_on = multiply(W[i], filtered_spikes_trimmed_on[i])
part_off = multiply(W[n+i], filtered_spikes_trimmed_off[i])
reconstructed_signal = numpy.add(reconstructed_signal, part_on)
reconstructed_signal = numpy.add(reconstructed_signal, part_off)
return mean_squared_error(input_signal, reconstructed_signal)
result = minimize(func, W0)
W = result.x
reconstructed_signal = [0 for i in range(len(simulation_timesteps))]
for i in range(n):
part_on = multiply(W[i], filtered_spikes_trimmed_on[i])
part_off = multiply(W[n+i], filtered_spikes_trimmed_off[i])
reconstructed_signal = numpy.add(reconstructed_signal, part_on)
reconstructed_signal = numpy.add(reconstructed_signal, part_off)
# =================================
# =================================
# Plot results
graph, (graph1, graph2) = plt.subplots(2, figsize=(20, 10), sharex=True)
graph1.plot(simulation_timesteps/msecond, reconstructed_signal)
# Create input current
I = 3*sin(2*pi*5*simulation_timesteps/msecond)
graph2.plot(simulation_timesteps/msecond, I)
show()