Winner Take All mechanism using lateral inhibition

Description of problem

Hello everyone,
I’ve been trying to implement a winner-take-all mechanism similar to the one used in this paper. There, the author implemented a winner-take-all mechanism that makes the neurons fire sequentially, one at a time.

In my code, each neuron has two types of connections, excitatory and inhibitory. The excitatory connections link the input (28x28 inputs) to the neurons, and the inhibitory connections link each neuron to all the other neurons (but not to itself). Each time a neuron fires, it increases the variable inhIN of all its post-synaptic neurons by 1 through the inhibitory synapses (only if they are not already inhibited, since the code adds inhibited_post). The variable inhibited is equal to 1 if inhIN is (approximately) 0, and 0 otherwise. On the excitatory connections, each time a pre-synaptic neuron/input fires, the potential of the post-synaptic neurons is increased as ‘v_post += w*inhibited_post’. My idea is that a neuron inhibits all its post-synaptic neurons (via the inhibitory synapses) by increasing their inhIN by 1, which results in inhibited = 0 and therefore v_post += w*0, i.e. no change. In other words, it makes the post-synaptic neurons disregard the input. This does happen when I feed the network, but not entirely as intended. See ‘Actual output’ for a brief explanation of what isn’t working as expected.

My question is: how can I make the neurons in the group fire one at a time? Should I adopt a different neuron model?

Any help would be much appreciated!

Minimal code to reproduce problem

wmax = 1
Apre = 0.4
taupre = 10 * ms
taupost = 10 * ms
Apost = -Apre * 0.2
tau_inh = 5 * ms

eqs_neurons = '''
dv/dt = (-v/tau)  : 1 (unless refractory)
dinhIN/dt = (-inhIN/tau_inh) : 1
inhibited = int(not(inhIN > 0.001)) : 1
'''
neurons = NeuronGroup(Nneurons, eqs_neurons, threshold='v>1', refractory=5*ms, reset='v=0', method='exact')


(...)	

spikes = intensity_to_temporal(train_X)
In = SpikeGeneratorGroup(28*28, [0], [0*ms] , sorted = True)

(...)

inhibitory = Synapses(neurons, neurons,
	on_pre='''inhIN_post = inhIN_post + inhibited_post
	''')

inhibitory.connect('i!=j')

stdp = Synapses(In, neurons, '''
                w : 1
                dapre/dt = -apre/taupre : 1 (event-driven)
                dapost/dt = -apost/taupost : 1 (event-driven)
                ''',
                on_pre='''
                v_post += w*inhibited_post
                apre += Apre*inhibited_post
                ''',
                on_post='''
                apost += Apost
                w = w + w * (wmax - w)*(apre+apost)
                ''')
stdp.connect()


stdp.w = 'rand()/20'

(...)
# feed the neuron group with the MNIST dataset
(...)
		
spi = StateMonitor(neurons, ['v','inhIN', 'inhibited'], record=True)
spi_dts = SpikeMonitor(neurons)

run(pattern_dt)

		
for ni in range(Nneurons):
    plt.plot(spi.t/ms, spi.v[ni], label=f'Neuron {ni} voltage')
    plt.fill_between(spi.t/ms, 0, 1, where=(spi.inhibited[ni]==0), facecolor=colors[ni], alpha=0.1, hatch=hat[ni], label=f'Neuron {ni} inhibited')
    plt.plot(spi.t/ms, spi.inhIN[ni], label=f'inhIN [{ni}]')
for spik in spi_dts.t:
    plt.axvline(spik/ms, ls=':')
plt.xlabel('Time (ms)')
plt.legend()
plt.show()

What you have already tried

I have tried decreasing defaultclock.dt to 0.01*ms (the default is 0.1*ms).

Expected output (if relevant)

I want each neuron to fire on its own, inhibiting the other (post-synaptic) neurons when it fires. The plot below shows an actual output that matches this expected behaviour, but I don’t always get this result (rate encoding was used in this simulation).
Attention: the legend is partially incorrect. The blue and green lines correspond to the potentials of neurons 0 and 1, respectively.

Actual output (if relevant)

  • I intend to set the inhibition duration (the time during which inhIN > 0, i.e. inhibited = 0) through the variable tau_inh. But when I plot these variables during the simulation, the inhibition periods are longer than the value set for this variable. Also, to update the value of inhibited, I had to use the condition ‘inhIN > 0.001’ instead of ‘inhIN > 0’, because otherwise the inhibition period would be even longer. For the simulation shown below, tau_inh and pattern_dt (the duration of the pattern) are 0.5 ms and 10 ms, respectively.

  • Neurons still fire at the same time. For the simulation shown below, I’m using temporal encoding (each value is converted into a spike time; higher values correspond to earlier spikes) on the MNIST dataset. This results in several inputs firing at the same time (pixels with the same value get the same spike time).
    Attention: the legend is partially incorrect. The blue and green lines correspond to the potentials of neurons 0 and 1, respectively.


I can’t seem to edit the post above so I’ll post the full code here in case anyone wants to try it out.
In each iteration, I feed an image from the MNIST dataset (28x28 inputs) to a group of Nneurons neurons for pattern_dt, and then let them rest for time_to_rest. The encoding scheme can be chosen when running the script, e.g. ‘python3 cSTDP.py temporal’ or ‘python3 cSTDP.py rate’ to convert the image into spikes using temporal or rate encoding, respectively. The script plots the variables v, inhIN and inhibited, which are described in the post above.

import sys

import matplotlib.pyplot as plt
import numpy as np
from brian2 import *
from keras.datasets import mnist



(train_X, train_y), (test_X, test_y) = mnist.load_data() 

training_size = 100	
window_size = 28	
pattern_dt = 10*ms	

Nneurons = 2		

tau = 10*ms
taupre=5*ms	
taupost=3*ms
wmax = 1
Apre = 1
Apost = -Apre*0.2

tauinh=0.5*ms

time_to_rest = 10*ms

eqs = '''
dv/dt = (-v/tau)  : 1 (unless refractory)
dinhIN/dt = (-inhIN/tauinh) : 1
inhibited = int(not(inhIN > 0.001)) : 1
'''

reset = '''
v = 0
'''


def intensity_to_rates(training_set):

	rates_ = [[[] for column in range(28)] for image in range(training_size)]

	for i, img in enumerate(training_set[0:training_size]):
		for img_column in range(28):
			rates_[i][img_column] = (img[img_column, :]/255)*200 + 2
	
	return rates_

def intensity_to_temporal(training_set):

	rates_ = [[[] for column in range(28)] for image in range(training_size)]

	for i, img in enumerate(training_set[0:training_size]):
		for img_row in range(28):
			rates_[i][img_row] = ((255 - img[img_row, :])/255)*pattern_dt/ms
	return rates_



if __name__ == "__main__":

	#clear_cache('cython')
	
	rate_encoding = False
	
	for a, arg in enumerate(sys.argv[1:]):
		if arg == 'rate':
			rate_encoding = True
		elif arg == 'iter':
			training_size = int(sys.argv[(a+1)+1])
		
		 
	neurons = NeuronGroup(Nneurons, eqs, threshold='v>1', refractory=5*ms, reset=reset, method='exact')
	
	inhibitory = Synapses(neurons, neurons,
		on_pre={'inhibit' : '''inhIN_post = inhIN_post + inhibited_post
		'''})
	
	inhibitory.inhibit.when='before_thresholds'
	
	inhibitory.connect('i!=j')
	
	
	train_X = [x for i,x in enumerate(train_X) if (train_y[i]==0)]
	
	if rate_encoding:
		rates = intensity_to_rates(train_X)
		In = PoissonGroup(28*window_size, rates=0*Hz)  # rates are set per pattern in the loop below
	else:	
		spikes = intensity_to_temporal(train_X)
		In = SpikeGeneratorGroup(28*window_size, [0], [0*ms] , sorted = True)		
	
	
	
	
	
	stdp = Synapses(In, neurons,'''
                 w : 1
                 dapre/dt = -apre/taupre : 1 (event-driven)
                 dapost/dt = -apost/taupost : 1 (event-driven)
                 ''',
                 on_pre='''
                 v_post += w*inhibited_post
                 apre += Apre*inhibited_post
                 ''',
                 on_post='''
                 apost += Apost
                 w = w + w * (wmax - w)*(apre+apost)
                 ''')
	stdp.connect()
	
	stdp.w = 'rand()/20'
	
	total_t = 0
	
	colors = ['red','blue','green','grey', 'bisque', 'mistyrose', 'coral', 'powderblue', 'pink', 'palegreen']
	hat = ['/', '-', '+', 'x', 'o', 'O', '.', '*', '/', '-', '+', 'x', 'o', 'O', '.', '*']
	fire_dts = [[[] for g in range(training_size)] for k in range(Nneurons)]
	
	axs_obj = [[] for _ in range(Nneurons)]
	
	
	for tr in range(training_size):
	
		state = StateMonitor(stdp, ['w'], record=True)
		
		
		for pattern in range(int(28/window_size)):
		
			spi = StateMonitor(neurons, ['v','inhIN', 'inhibited'], record=True)
			spi_dts = SpikeMonitor(neurons)
			
			if rate_encoding:			
		
				r = np.array(rates[tr]).T[pattern*window_size:(pattern+1)*window_size].T	
							
				In.rates = [item for sublist in r for item in sublist]*Hz			
			
			else:
			
				s = np.array(spikes[tr]).T[pattern*window_size:(pattern+1)*window_size].T
				inp = [(c+r*window_size, (column + total_t)*ms) for r, row in enumerate(s) for c, column in enumerate(row)]
				ids = [i[0] for i in inp]
				ts = [i[1] for i in inp]
				In.set_spikes(ids, ts, sorted = False)					
			
			run(pattern_dt)
			for i, idx in enumerate(spi_dts.i):
				fire_dts[idx][tr].append(spi_dts.t[i]/ms - total_t)
			
			run(time_to_rest)
			total_t = total_t + pattern_dt/ms + time_to_rest/ms
			

			for ni in range(Nneurons):
				plt.plot(spi.t/ms, spi.v[ni], label=f'Neuron {ni} voltage')
				plt.fill_between(spi.t/ms, 0, 1, where=(spi.inhibited[ni]==0), facecolor=colors[ni], alpha=0.1, hatch=hat[ni], label=f'Neuron {ni} inhibited')
				plt.plot(spi.t/ms, spi.inhIN[ni], label=f'inhIN [{ni}]')
			for spik in spi_dts.t:
				plt.axvline(spik/ms, ls=':')
			plt.xlabel('')
			plt.legend()
			plt.show()
				
			del state
			del spi
			del spi_dts
	
	plt.show()

Hi @diogohmsilva. Thanks for sharing the runnable code. Regarding the duration of the inhibition: you are modelling the inhibitory current with an exponential decay, so the constant tauinh does not set the total duration but the time constant of the decay (i.e. the time until the signal decays to 1/e of its original value). You can use a more artificial inhibitory current that stays up for a certain time and then goes down, i.e. a rectangular current:

inhibitory = Synapses(neurons, neurons,
                      on_pre={'up': 'inhIN_post += 1', 'down': 'inhIN_post -= 1'},
                      delay={'down': tauinh})

When a spike arrives, inhIN is increased by 1, and after a delay of tauinh, it is reduced by 1.
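To see why the plotted inhibition lasts longer than tauinh with the exponential model, one can compute how long an exponentially decaying inhIN stays above the 0.001 cut-off used in the inhibited expression. A minimal sketch in plain Python, assuming a single increment of inhIN to 1 and ignoring the discretization to defaultclock.dt; the helper name exp_inhibition_duration is hypothetical:

```python
import math


def exp_inhibition_duration(tau_inh_ms, cutoff, initial=1.0):
    """Time (ms) for initial * exp(-t / tau_inh) to decay down to `cutoff`.

    For a purely exponential inhIN, this is how long
    inhibited = int(not(inhIN > cutoff)) stays at 0 after inhIN is set to `initial`.
    """
    return tau_inh_ms * math.log(initial / cutoff)


tau_inh_ms = 0.5  # tauinh in the full script is 0.5*ms

# Cut-off used in the model: inhIN > 0.001
print(exp_inhibition_duration(tau_inh_ms, 0.001))
# Cut-off at 1/e of the initial value: recovers tau_inh itself
print(exp_inhibition_duration(tau_inh_ms, math.exp(-1)))
```

With a cut-off of 0.001 the inhibition lasts 0.5*ln(1000), about 3.45 ms, roughly seven times tauinh. As the cut-off approaches 0 the duration grows without bound (in exact arithmetic), which is consistent with the observation that ‘inhIN > 0’ made the period even longer. With the rectangular current, the duration is tauinh by construction.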

Now, regardless of the parameters or the scheduling you set with when, etc., this model can never completely prevent two neurons from firing within the same time step. This is because the thresholding phase, which determines which neurons fire, is run for all neurons as a block: if two neurons happen to cross the threshold in the same step, both will emit spikes and inhibit the other neurons. The only way to fix this would be a network_operation (or a user-defined function written in Cython, but that’s much more complicated) that intervenes after the thresholding stage and forces only a single neuron to spike. Usually, models with a WTA mechanism (like the one you referenced) simply ignore this problem, since it is very unlikely that two neurons spike in exactly the same time step.

This is not the case for your model, since your transformation into spike times will make most input spikes arrive in the first or last time step of the pattern window (corresponding to black or white pixels, I guess). The weights to the neurons might be random, but if each neuron sums up hundreds of random weights, the summed results will be almost the same.
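The point about summed random weights can be checked numerically. A minimal NumPy sketch, assuming weights drawn like stdp.w = 'rand()/20' (uniform on [0, 0.05)) for 28*28 inputs and two neurons, and assuming, for illustration only, that 600 of the inputs spike in the same time step (this number is hypothetical):

```python
import numpy as np

rng = np.random.default_rng(seed=0)
n_inputs, n_neurons = 28 * 28, 2
n_synchronous = 600  # hypothetical number of inputs spiking in the same step

# Same distribution as stdp.w = 'rand()/20'
w = rng.uniform(0.0, 1.0 / 20, size=(n_neurons, n_inputs))

# With "delta" synapses, each synchronous input spike adds w to v_post,
# so the jump in v within one time step is the sum over those inputs.
active = rng.choice(n_inputs, size=n_synchronous, replace=False)
jump = w[:, active].sum(axis=1)

relative_difference = abs(jump[0] - jump[1]) / jump.mean()
print(jump, relative_difference)
```

Both jumps are around 600 × 0.025 = 15, far above the threshold of 1, and they typically differ by only a few percent, so both neurons cross the threshold in the same time step.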
I can think of some ways you could deal with this:

  • change your encoding into spike times: currently, you encode space into “input lines” and intensity into spike times, but you could also add an encoding of space into the spike times (e.g. pixels on the left spike earlier than pixels on the right). A number of people have worked on spike-time encodings of the MNIST dataset; in the extreme case, you could use the spikes recorded from a spike-based camera that scans the displayed digit with simulated saccades: Garrick Orchard - N-MNIST
  • Make your weights smaller, so that they are just enough to evoke a spike (actually, in that case you’d probably not even need an inhibition mechanism to only make a single neuron spike, but it has the disadvantage that you might not get any spikes for some presented digits).
  • use synaptic current input instead of the “delta synapse” implementation that directly increases the membrane potential. With your current model, there is no difference between a strong super-threshold input and an input that is barely super-threshold – both will make a neuron spike. With a synaptic current input, a strong input will make the neuron spike very soon, and a weak input will only make it spike after a certain delay.
  • Add random delays to your inputs (not sure that’s really a good idea)

Hope that gives you some ideas, best
Marcel

Hi @mstimberg,

Thank you for the clarification.
I suspected this would be the case (because the thresholding phase which determines which neuron fires is run for all neurons as a block), so I tried to come up with another idea.

My strategy now is to have two potentials that receive the same input, with one of them (the real one, v) delayed by defaultclock.dt. The potential that is not delayed, pseudo_v, will be used to determine which neuron had the maximum voltage when multiple neurons fire at the same time. Here is the code.
Ultimately, I want to be able to determine which neuron had the maximum voltage one time step before the real voltage receives the input, so that I can prevent simultaneous firing.

import matplotlib.pyplot as plt
from matplotlib.pyplot import ion 
import os

from brian2 import *

inp = PoissonGroup(1, rates=250*Hz)

clock_dt = defaultclock.dt


eqs_G = '''
dv/dt = -v/(tau_m) : 1
dpseudo_v/dt = -pseudo_v/(tau_m) : 1
dinhIN/dt = -inhIN/tau_inh : 1
inhibited = int(not(inhIN > 0.001)) : 1
tau_inh : second
tau_m : second
winner : boolean
'''

neurons = NeuronGroup(2, eqs_G, threshold='v>1 and winner',
                reset='v = 0; winner = False',
                refractory='5*ms',
                events={'IPSP' : 'pseudo_v > 1',
                        'inhibition' : 'inhIN > 1/(exp(1))',
                        'take_all' : 'inhIN > 1/(exp(1))'})

@check_units(idx=1, result=bool)
def wta(idx):
    return idx == np.argmax(neurons.v)

wta = Function(wta)


neurons.tau_m = '50*ms'
neurons.tau_inh = '1*ms'
neurons.winner = 'False'

neurons.run_on_event('inhibition', 'pseudo_v= v * exp(-clock_dt/tau_m)')
neurons.run_on_event('take_all', 'winner=wta(i)')
            

inpS = Synapses(inp, neurons, '''w : 1''',
                on_pre='''v_post += w*inhibited_post''')
inpS.connect()
inpS.w = '0.2'
inpS.delay = clock_dt

inpS_early = Synapses(inp, neurons, '''w : 1''',
                on_pre='''pseudo_v_post += w*inhibited_post''')
inpS_early.connect()
inpS_early.w = '0.2'


inhibitory = Synapses(neurons, neurons,
                  on_pre={'IPSPpath': 'inhIN_post += inhibited_post'},                
                  on_event={'IPSPpath': 'IPSP'}
                  )

inhibitory.connect('i!=j')

ss = StateMonitor(neurons, ['v', 'pseudo_v', 'inhIN'], record=True)

run(100*ms)

for n in range(2):    
    plt.plot(ss.t/ms, ss.v[n], label=f'v[{n}]')
    plt.plot(ss.t/ms, ss.pseudo_v[n], label=f'pseudo_v[{n}]')
    plt.plot(ss.t/ms, ss.inhIN[n], label=f'inhIN[{n}]')


plt.legend()
plt.show()

del ss

When I try to run this I get the error "NotImplementedError: Cannot use function wta: 'No implementation available for target cython. Available implementations: ’ ". A similar implementation would be
neurons.run_on_event('take_all', 'winner = (i ==np.argmax(A.v))')
which also doesn’t seem to work.

Is it not possible to use user provided functions with run_on_event?

Your suggestions brought up some questions.

  • ‘change your encoding into spike times’ - I thought about this earlier, e.g. encoding the row and column of each pixel, but I’m not sure how to do it. With one variable (like pixel intensity), I generate a spike at time pattern_duration x (value_max - value)/value_max. With two variables, e.g. pixel intensity and index (row x 28 + column), the scheme I can think of would divide the pattern duration into value_max windows (one for each possible pixel intensity), where each window has number_rows x number_columns possible spike times (one for each input). Or the opposite: number_rows x number_columns windows, each with value_max possible spike times. Is this what you mean when you say ‘you could also add an encoding of space into the spike times’? Also, what do you mean by ‘you encode space into “input lines”’?

  • use synaptic current - I’ve tried it but I still don’t fully understand it. Will try again.

  • ‘Add random delays to your inputs (not sure that’s really a good idea)’ - I also tried this before, and it turns out the delays sometimes ‘help’ neurons fire at the same time.

Hi. As I mentioned in my earlier post, using a custom function to deal with this is a bit complicated; using a network_operation would be easier. You are getting the error message because the model code is translated into Cython code, and you only provide a Python implementation for your function. But again, this complicates things here, since functions are meant to do calculations for a single neuron/synapse, not across neurons like your np.argmax. Your second variant cannot run, since the code you provide in the strings is not Python code: it is “Brian code” in a special syntax (a subset of Python syntax) that later gets translated into Cython code.

Here’s a simple example (without the inhibition etc.), that shows how you can use a network_operation to restrict spiking to a single neuron in each time step:

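from brian2 import *
import numpy as np

tau_m = 50*ms
neurons = NeuronGroup(2, '''dv/dt = -v/tau_m : 1
                            is_winner : boolean''',
                      threshold='v > 1 and is_winner',
                      reset='v=0')

# Runs every time step just before thresholding: only the neuron with the
# highest membrane potential is allowed to cross the threshold.
@network_operation(when='before_thresholds')
def determine_winner():
    neurons.is_winner[:] = (neurons.i[:] == np.argmax(neurons.v_[:]))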
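The effect of this gating can be illustrated outside Brian. A minimal NumPy sketch of a single thresholding step (it mirrors the logic of the network_operation above but is not Brian's implementation; the function name step_threshold is hypothetical): without gating, every neuron above threshold spikes in the same step, while with the argmax gate at most one does.

```python
import numpy as np


def step_threshold(v, threshold=1.0, wta=False):
    """Return a boolean spike mask for one time step.

    Without WTA, every neuron above threshold spikes (block thresholding).
    With WTA, only the neuron with the highest v is allowed to spike.
    """
    above = v > threshold
    if not wta:
        return above
    winner = np.zeros_like(above)      # boolean array, all False
    winner[np.argmax(v)] = True        # the "is_winner" flag
    return above & winner


v = np.array([1.3, 1.2, 0.5])          # two neurons above threshold in the same step
print(step_threshold(v))               # neurons 0 and 1 both spike
print(step_threshold(v, wta=True))     # only neuron 0 spikes
```

Note that np.argmax returns the first maximum, so exact ties always favour the neuron with the lowest index; a random tie-break would avoid that bias.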

There are many possibilities, but shifting the window would be one option. Maybe not for both column and row, but only for one of them? E.g. if you only encode the column in time, you could have a window from 0 to 10 ms for the first column, 1 to 11 ms for the second column, and so on. There are many other possibilities, but all of them are of course quite arbitrary.
When I said “you encode space into input lines”, I meant that each of your pixels has a single connection/synapse (an “input line”). This probably makes sense, but in principle there is an infinite number of ways to code things up. You might also consider not emitting spikes for the 0 values at all, instead of making them spike at the end of the interval. See e.g. https://www.frontiersin.org/articles/10.3389/fnins.2021.712667/full
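The shifted-window idea, combined with skipping 0-valued pixels, can be sketched as follows. This is a minimal NumPy sketch under stated assumptions: only the column is encoded in time, the window length (10 ms) and shift per column (1 ms) follow the example above, the intensity-to-latency mapping reuses the script's (255 - value)/255 formula, and the function name shifted_window_encoding is hypothetical:

```python
import numpy as np


def shifted_window_encoding(img, window_ms=10.0, shift_ms=1.0, skip_zero=True):
    """Encode a 2D uint8 image as (indices, times_ms).

    Intensity -> latency inside a window (brighter pixels spike earlier);
    column -> start of the window (column c starts at c * shift_ms).
    """
    rows, cols = img.shape
    indices, times = [], []
    for r in range(rows):
        for c in range(cols):
            value = int(img[r, c])
            if skip_zero and value == 0:
                continue  # no spike at all for black pixels
            latency = window_ms * (255 - value) / 255
            indices.append(r * cols + c)  # same index layout as row*28 + column
            times.append(c * shift_ms + latency)
    return np.array(indices), np.array(times)


img = np.zeros((28, 28), dtype=np.uint8)
img[0, 0] = 255   # two equally bright pixels in different columns
img[5, 3] = 255
indices, times = shifted_window_encoding(img)
print(indices, times)  # same intensity, but different spike times
```

Without the shift, the two bright pixels would spike at the same time. The resulting arrays could be passed to SpikeGeneratorGroup.set_spikes as indices and times*ms, offset by the current simulation time as the script does with total_t.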

The idea is that instead of increasing v_post directly, you change your model to something like Example: CUBA — Brian 2 2.5.4 documentation and increase the synaptic current (ge in that example) instead. This way, stronger inputs will lead to earlier spikes.
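Why stronger inputs lead to earlier spikes can be sketched with a forward-Euler integration of a single current-based neuron in plain Python (not Brian). The model tau_m dv/dt = -v + I, tau_s dI/dt = -I, the time constants (tau_m = 20 ms, tau_s = 5 ms) and the weights are assumptions chosen for illustration, not values from the thread or from the CUBA example; the function name first_spike_time is hypothetical:

```python
def first_spike_time(w, tau_m=20.0, tau_s=5.0, threshold=1.0, dt=0.01, t_max=100.0):
    """Time (ms) of the first threshold crossing after a synaptic current
    jump of size w at t=0, or None if v never reaches the threshold."""
    v, current = 0.0, w  # the input spike increases the current, not v
    for step in range(int(t_max / dt)):
        # Forward-Euler step of tau_m dv/dt = -v + I and tau_s dI/dt = -I
        v += dt * (-v + current) / tau_m
        current += dt * (-current) / tau_s
        if v > threshold:
            return (step + 1) * dt
    return None


for w in (5.0, 10.0, 50.0):
    print(w, first_spike_time(w))
```

With a delta synapse, both w = 10 and w = 50 would push v over threshold in the very step the input arrives. With the current-based input, the strong input crosses threshold after about 0.4 ms and the weak one after about 2.8 ms, while w = 5 never reaches threshold. In Brian, this would correspond to something like dv/dt = (I_syn - v)/tau_m and dI_syn/dt = -I_syn/tau_s with on_pre='I_syn_post += w' (I_syn is a hypothetical variable name).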

This seems to work, thanks!