Encoding And Decoding Of Analog Signals With Brian

I am trying to recreate this paper using Brian, but although I believe my implementation is correct, the results are poor to say the least. Am I missing something? If I manually set half of the weights in W to 0.5 and the other half to -0.5, the results are very good, so finding the weights through a minimization procedure should give results that are at least as good.

# =================================
# =================================
# Encoding And Decoding Of Analog Signals

from brian2 import *
from brian2.core.functions import timestep
from scipy.optimize import minimize
from sklearn.metrics import mean_squared_error
import numpy as np
import matplotlib.pyplot as plt

start_scope()

# =================================
# =================================
# Define the model

# Number of neurons per group (the ON and OFF groups each have n neurons); used globally in the loops below.

n = 10

# Assuming resting membrane potential 0.
# The parameter b is the bias.

neuron_equation_on = '''
    dv/dt = (R*I-v)/tau : 1 (unless refractory)
    I = 3*sin(2*pi*f*t)+b : 1
    f = 5*Hz: Hz
    b : 1
    R : 1
    tau : second
    '''

group_of_neurons_on = NeuronGroup(
    N = n,
    model = neuron_equation_on,
    threshold = 'v>1',
    reset = 'v=0',
    refractory = 2*msecond,
    method = 'euler')

group_of_neurons_on.R = 1
group_of_neurons_on.tau = 2*msecond
group_of_neurons_on.b = np.random.rand(n)  # random bias for each neuron

neuron_equation_off = '''
    dv/dt = (R*I-v)/tau : 1 (unless refractory)
    I = -3*sin(2*pi*f*t)+b : 1
    f = 5*Hz: Hz
    b : 1
    R : 1
    tau : second
    '''

group_of_neurons_off = NeuronGroup(
    N = n,
    model = neuron_equation_off,
    threshold = 'v>1',
    reset = 'v=0',
    refractory = 2*msecond,
    method = 'euler')

group_of_neurons_off.R = 1
group_of_neurons_off.tau = 2*msecond
group_of_neurons_off.b = np.random.rand(n)  # random bias for each neuron

# =================================
# =================================
# Simulation

neuron_voltage_on = StateMonitor(
    group_of_neurons_on,
    variables='v',
    record=True)

neuron_current_on = StateMonitor(
    group_of_neurons_on,
    variables='I',
    record=True)

neuron_voltage_off = StateMonitor(
    group_of_neurons_off,
    variables='v',
    record=True)

neuron_current_off = StateMonitor(
    group_of_neurons_off,
    variables='I',
    record=True)

neuron_spikes_on = SpikeMonitor(group_of_neurons_on)
neuron_spikes_off = SpikeMonitor(group_of_neurons_off)

simulation_time = 1000*msecond

run(simulation_time, report = 'text')

# Time axis of the simulation in seconds (dimensionless array, one entry per time step)
simulation_timesteps = np.arange(timestep(simulation_time, defaultclock.dt)) * float(defaultclock.dt)

# =================================
# =================================
# Create spike trains

neuron_spike_train_on = np.zeros((n, timestep(simulation_time, defaultclock.dt)), dtype=int)
neuron_spike_train_off = np.zeros((n, timestep(simulation_time, defaultclock.dt)), dtype=int)

# Spike trains for the ON neurons
neuron_spike_train_on[neuron_spikes_on.i, timestep(neuron_spikes_on.t, defaultclock.dt)] = 1
# Spike trains for the OFF neurons
neuron_spike_train_off[neuron_spikes_off.i, timestep(neuron_spikes_off.t, defaultclock.dt)] = 1

# Check the validity of the created spike trains

# for i in range(n):
#     print("---------------------------")
#     print("Neuron pair number", i)
#     print("---------------------------")
#     print("Provided spikes of ON neurons:", neuron_spikes_on.count[i])
#     print("Calculated spikes of ON neurons:", sum(neuron_spike_train_on[i]))
#     print("Provided spikes of OFF neurons:", neuron_spikes_off.count[i])
#     print("Calculated spikes of OFF neurons:", sum(neuron_spike_train_off[i]))

# =================================
# =================================
# Decoding

# Find optimal weights

# Decoding filter with transfer function H = k/(Ts+1), i.e. an exponential
# kernel 5*exp(-t/T) in the time domain, here with T = 20 ms
neuron_filter = 5*np.exp(-50*simulation_timesteps)

# The full convolution of two length-N sequences below has length 2*N-1, so allocate accordingly
filtered_spikes_full_on = np.zeros((n, 2*len(simulation_timesteps)-1), dtype=float)
filtered_spikes_full_off = np.zeros((n, 2*len(simulation_timesteps)-1), dtype=float)

filtered_spikes_trimmed_on = np.zeros((n, len(simulation_timesteps)), dtype=float)
filtered_spikes_trimmed_off = np.zeros((n, len(simulation_timesteps)), dtype=float)

for i in range(n):

    # Filter the spike trains from both on and off neuron groups.
    filtered_spikes_full_on[i] = convolve(neuron_filter, neuron_spike_train_on[i])
    filtered_spikes_full_off[i] = convolve(neuron_filter, neuron_spike_train_off[i])
    
    # Trim the spike trains.
    filtered_spikes_trimmed_on[i] = filtered_spikes_full_on[i][:len(simulation_timesteps)]
    filtered_spikes_trimmed_off[i] = filtered_spikes_full_off[i][:len(simulation_timesteps)]

# Reconstruct the initial signal

# Target signal to reconstruct (same time axis as the simulation)
input_signal = 3*np.sin(2*np.pi*5*simulation_timesteps)

# Initial guess for the decoding weights: +0.5 for the ON neurons, -0.5 for the OFF neurons
W0 = np.concatenate([0.5*np.ones(n), -0.5*np.ones(n)])

def func(W):

    # Reconstruction: weighted sum of the filtered ON and OFF spike trains
    reconstructed_signal = np.zeros(len(simulation_timesteps))

    for i in range(n):
        reconstructed_signal += W[i] * filtered_spikes_trimmed_on[i]
        reconstructed_signal += W[n+i] * filtered_spikes_trimmed_off[i]

    # Mean squared error between the target signal and the reconstruction
    return mean_squared_error(input_signal, reconstructed_signal)

result = minimize(func, W0)

W = result.x

# Reconstruct the signal with the optimized weights
reconstructed_signal = np.zeros(len(simulation_timesteps))

for i in range(n):
    reconstructed_signal += W[i] * filtered_spikes_trimmed_on[i]
    reconstructed_signal += W[n+i] * filtered_spikes_trimmed_off[i]

# =================================
# =================================
# Plot results

graph, (graph1, graph2) = plt.subplots(2, figsize=(20, 10), sharex=True)

# Reconstructed signal (time axis in seconds)
graph1.plot(simulation_timesteps, reconstructed_signal)

# Input signal, for comparison with the reconstruction
I = 3*np.sin(2*np.pi*5*simulation_timesteps)
graph2.plot(simulation_timesteps, I)

plt.show()

Hi. Very interesting project! I did not have time to read the paper, so maybe they deal with the issue I mention below in some other way. I’d be interested in knowing more about this (but then, I should probably just read the paper…).

So I agree at first sight the 0.5/-0.5 decoding looks good, but it is actually a very bad solution if you calculate the mean squared error. This is first of all because the amplitude is off: the sine wave is between about -200 and 200 instead of -3 and 3. Even if you correct the amplitude, you can see that this simple decoding does not account for the shift introduced by the neuronal integration and the decoding kernel. Below is a simulation where I plot everything on top of each other (dotted black line is the target, the orange line is the simple solution where I divided your W0 weights by 60 to correct the scale, and blue is the optimized solution). As you can see the optimized solution looks “poor”, but it actually shifts the reconstructed signal a bit to the left in a number of places, reducing its difference to the target signal.
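In case it is useful, here is roughly how such a comparison can be reproduced with the variables from your script; the decode helper below is just for illustration, and the 1/60 factor is only my ad-hoc scale correction, nothing derived:

def decode(weights):
    # Weighted sum of the filtered ON and OFF spike trains
    signal = np.zeros(len(simulation_timesteps))
    for i in range(n):
        signal += weights[i] * filtered_spikes_trimmed_on[i]
        signal += weights[n+i] * filtered_spikes_trimmed_off[i]
    return signal

plt.plot(simulation_timesteps, input_signal, 'k:', label='target')
plt.plot(simulation_timesteps, decode(W0/60), label='simple weights scaled by 1/60')
plt.plot(simulation_timesteps, decode(result.x), label='optimized weights')
plt.legend()
plt.show()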

If you think about possible solutions to this problem, you see that given the spike trains, the algorithm cannot do much. There is some variability in the onset of the spike trains due to the random bias, which you can exploit to shift the reconstructed signal to the left (by weighting the early spike trains more strongly), but due to the way the bias works, there is a symmetry that you cannot get rid of: the neurons that spike early because of a strong bias will also be the last to stop firing. This means that by accurately reconstructing the rising phase of the signal you will also introduce errors in the falling phase of the signal.

The only solution to this problem that I can think of right now (but the paper might propose a different one) is to fine-tune the bias and the strength of the signal so that each neuron fires very sparsely (e.g. even a single spike per period), with each neuron firing at a different relative phase. Then, with many neurons (i.e. n would need to be much bigger than 10), you should be able to reconstruct the signal accurately, even if with a lag. But then, this might be cheating, since if you end up with many spike trains that basically have single spikes all over the place, you could reconstruct any signal, not only the sinusoid that was given as the input…
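Purely as an untested sketch of that idea (all numbers below are guesses, and the 3*sin(...) amplitude in the equations would also have to be reduced so that the drive stays close to threshold), one could use many more neuron pairs and spread the biases deterministically instead of drawing them at random:

n = 100  # many more neuron pairs than the original 10; set before the NeuronGroups are created

# Spread the biases evenly so that each neuron reaches threshold around a
# different phase of the input; the range below is only a starting point
bias_values = np.linspace(0.0, 0.95, n)
group_of_neurons_on.b = bias_values
group_of_neurons_off.b = bias_values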

Very insightful answer. Thank you very much! Regarding the amplitude of the reconstructed sine wave, I only used 0.5 and -0.5 to test the shape, since I tried many different sine waves. Adjusting the amplitude manually is easy; the problem is getting that to happen through the optimization of the weights. In the paper the authors use a different error for the minimization, but I had seen the MSE used in other sources, so I tried that first. As for the shift, the authors of the paper mention that the reconstructed signal is indeed shifted, which, as they say, implies that the procedure has linear properties. I will look into what you said, thanks again!

Ah, if it is ok to reconstruct a shifted version of the signal, then most of my discussion is not relevant :upside_down_face: Then the only issue is to find an error function that measures what you actually want to minimize (maybe simply the mean squared error with respect to a shifted version of the signal?).
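Something like the following could replace func in your script; this is only a sketch, and the name shift_tolerant_mse and the 50 ms search window are arbitrary choices of mine, not anything from the paper:

def shift_tolerant_mse(W):
    # Reconstruction as before: weighted sum of the filtered spike trains
    reconstructed = np.zeros(len(simulation_timesteps))
    for i in range(n):
        reconstructed += W[i] * filtered_spikes_trimmed_on[i]
        reconstructed += W[n+i] * filtered_spikes_trimmed_off[i]

    # Allow the reconstruction to lag the target by up to ~50 ms and take
    # the smallest error over all considered lags
    max_shift = int(50e-3 / float(defaultclock.dt))
    errors = [mean_squared_error(input_signal[:len(input_signal)-s],
                                 reconstructed[s:])
              for s in range(max_shift)]
    return min(errors)

result = minimize(shift_tolerant_mse, W0)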

It is not irrelevant at all; you made many good points. For reference, the error they use is the following:

Since they don’t actually perform any other procedure, maybe the error function is the reason for their successful implementation.