Brian2 implementation of MNIST training with STDP (an attempt)

Description of problem

Hi all, I am trying to implement my own version of MNIST training with STDP. The setup is simple: 784 input neurons connected to a single output neuron modelled as a leaky integrate-and-fire (LIF) neuron. My aim is to make this network learn the weights of class ‘0’, meaning that after training, the visualised weights should look like a ‘0’, as shown in papers in the literature. After training, I move on to a testing phase on the training-set images of classes ‘0’, ‘1’ and ‘2’: I turn off the weight updates and check whether the neuron fires more spikes for class 0 than for classes 1 and 2. It does fire less for class 1, but for class 2 it fires almost the same number of spikes as for class 0, so I am not sure whether this is correct. Can an expert advise on this? The link to the Colab notebook is below.
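For the weight-visualisation step, here is a minimal plain-Python sketch (not Brian2 code) that turns the 784 trained weights back into a 28×28 picture. It assumes the weights are stored in the same row-major pixel order that MNIST images are usually flattened in; `weights_to_image` and `ascii_render` are hypothetical helper names.

```python
def weights_to_image(weights, side=28):
    """Reshape a flat list of side*side weights into rows (assumes row-major order)."""
    if len(weights) != side * side:
        raise ValueError(f"expected {side * side} weights, got {len(weights)}")
    return [list(weights[r * side:(r + 1) * side]) for r in range(side)]


def ascii_render(image, levels=" .:-=+*#%@"):
    """Render a 2D weight map as text; denser characters mean larger weights."""
    flat = [v for row in image for v in row]
    lo, hi = min(flat), max(flat)
    span = (hi - lo) or 1.0  # avoid division by zero for a constant map
    return "\n".join(
        "".join(levels[int((v - lo) / span * (len(levels) - 1))] for v in row)
        for row in image
    )


# Usage, with the trained weights as a flat sequence of 784 numbers:
# print(ascii_render(weights_to_image(trained_weights)))
```

If training on ‘0’ worked, the rendered map should show a ring of dense characters.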

Minimal code to reproduce problem

https://drive.google.com/file/d/19zaQ4SPHBz-rQI1nVrKPzujseUq-L7Yz/view?usp=sharing

What you have already tried

I have been stuck for many days looking at how other people do this, but most of the code I found contains bugs, or is not modular or easily readable for beginners :(((. Could an expert point me to a good reference? Not the Diehl 2015 paper, because its code is very complicated… Yes, I have searched the whole website, and so far no one has shown GOOD code for solving classification problems with STDP.

Expected output (if relevant)

I want the network to fire more spikes ONLY for the class it was trained on, but that does not seem to be the case. Am I doing something wrong? How can I improve this? Please help me…

Actual output (if relevant)

See the Google Colab notebook.

Full traceback of error (if relevant)

NIL

THANK YOU SO MUCH kind people!

Hi. I moved this post to the #science category, since it is a more general modelling question and not only about Brian itself. I don’t have much time to look into this myself, but your Colab notebook does not seem to be shared publicly, so no one will be able to have a look at it.

This one

Please let me know if you have issues opening this

I get “There was an error loading this notebook. Ensure that the file is accessible and try again.
Invalid Credentials” with the Google Colaboratory link, but the earlier Google Drive link (and then “Open With Google Colaboratory”) works.

A few remarks on the approach, without being an expert on the machine-learning side of it: just from the implementation side, I think the weights are too strong and the stimulation for each input is too short. You are getting close to 200 output spikes in 300 ms, which is as fast as this neuron can spike with its refractory period of 1.5 ms (300 ms / 1.5 ms = 200). It basically spikes whatever input it gets. Usually one presents the stimuli for a longer time so that there can be some variation in the output rate, I think. Also, you probably need a pause between the individual stimuli so that the membrane potential can return to its resting state before the next stimulus.
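The numbers behind these remarks can be checked with a few lines of plain Python (a sketch, not Brian2 code). The saturation bound is just the presentation time divided by the refractory period. The pause length follows from the LIF leak: without input, the membrane relaxes exponentially towards rest with time constant τ_m. The threshold, resting potential, τ_m and the 350 ms / 150 ms presentation and rest durations below are illustrative assumptions, not values from the notebook.

```python
import math


def max_spikes(duration_ms, refractory_ms):
    """Upper bound on output spikes: the neuron fires again the moment its
    refractory period ends (ignores the time needed to re-reach threshold)."""
    return int(duration_ms // refractory_ms)


def residual_v(v_start_mV, v_rest_mV, tau_m_ms, rest_ms):
    """Membrane potential after `rest_ms` without input, from the LIF leak:
    v(t) = v_rest + (v_start - v_rest) * exp(-t / tau_m)."""
    return v_rest_mV + (v_start_mV - v_rest_mV) * math.exp(-rest_ms / tau_m_ms)


def presentation_schedule(n_stimuli, present_ms, rest_ms):
    """(start, end) windows for each stimulus, with a silent rest period after each."""
    windows, t = [], 0.0
    for _ in range(n_stimuli):
        windows.append((t, t + present_ms))
        t += present_ms + rest_ms
    return windows


# Case from the thread: ~200 spikes in 300 ms with a 1.5 ms refractory period
print(max_spikes(300, 1.5))  # 200, i.e. the neuron is saturated
# Hypothetical values: threshold -52 mV, rest -65 mV, tau_m = 10 ms, 50 ms pause
print(round(residual_v(-52.0, -65.0, 10.0, 50.0), 2))  # -64.91
print(presentation_schedule(3, 350.0, 150.0))
```

A pause of about five membrane time constants leaves less than 0.1 mV of the previous depolarization, so each stimulus starts from (almost) the same state.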
But changing this will not make the model work the way you want it to, I think. If you only present one class of stimuli, the network will not learn to discriminate it from other stimuli. It might more or less work for a class like 0, which is bigger than most other digits, but think of e.g. learning on class 3: presenting an 8 instead will only add more stimulation (it’s a 3 plus some additional pixels).
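To make the 3-versus-8 argument concrete: for a single output neuron with non-negative weights, the input drive is Σ w_i x_i, so any stimulus whose active pixels include all pixels of the trained class gets at least as much drive. The 5×5 bitmaps and weight values below are made-up toy data, not MNIST.

```python
# Toy 5x5 bitmaps (hypothetical): the "8" contains every pixel of the "3".
THREE = ["#####",
         "....#",
         "#####",
         "....#",
         "#####"]
EIGHT = ["#####",
         "#...#",
         "#####",
         "#...#",
         "#####"]


def to_vector(bitmap):
    """Flatten a bitmap into 0/1 pixel values."""
    return [1 if ch == "#" else 0 for row in bitmap for ch in row]


def drive(weights, x):
    """Total excitatory input to a single output neuron: sum of w_i * x_i."""
    return sum(w * xi for w, xi in zip(weights, x))


three, eight = to_vector(THREE), to_vector(EIGHT)
# Idealized weights after training on "3": strong on its pixels, small elsewhere.
weights = [1.0 if p else 0.1 for p in three]

print(drive(weights, three))  # 17.0
print(drive(weights, eight))  # larger: the two extra pixels add 0.1 each
```

Even with perfectly learned weights (0 outside the "3"), the "8" would only tie; it can never get less input than the trained digit.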
The approach in the Diehl et al. (2015) paper was a different one: they had several output neurons that were inhibiting each other and presented all classes to them. The result was some kind of unsupervised learning, where individual neurons were competing via the inhibition. After the learning, they looked at which neurons were the most responsive for each class, then assigned labels to the neurons accordingly and used them for classification.
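A very reduced, rate-based sketch of the competition idea, assuming a hard winner-take-all as a stand-in for the mutual inhibition. This swaps in a classic competitive-learning rule; it is neither STDP nor the Diehl et al. network. On each presentation only the most strongly driven neuron updates its weights, moving them towards the normalized input, so different neurons can end up specialising on different patterns. `train_wta` and its parameters are hypothetical.

```python
import random


def train_wta(patterns, n_neurons, epochs=20, lr=0.1, seed=0):
    """Hard winner-take-all competitive learning on binary patterns.

    Rate-based stand-in for mutual inhibition: only the most strongly
    driven neuron learns on each presentation (NOT the spiking STDP model).
    """
    rng = random.Random(seed)
    dim = len(patterns[0])
    weights = []
    for _ in range(n_neurons):
        row = [rng.random() for _ in range(dim)]
        total = sum(row)
        weights.append([w / total for w in row])  # each row starts summing to 1
    for _ in range(epochs):
        order = list(range(len(patterns)))
        rng.shuffle(order)
        for idx in order:
            x = patterns[idx]
            drives = [sum(w * xi for w, xi in zip(row, x)) for row in weights]
            winner = max(range(n_neurons), key=drives.__getitem__)
            # "Inhibition": only the winner moves towards the normalized input
            total = sum(x) or 1
            target = [xi / total for xi in x]
            weights[winner] = [w + lr * (t - w) for w, t in zip(weights[winner], target)]
    return weights


# Two toy binary patterns (made up), presented repeatedly
patterns = [[1, 1, 1, 0, 0, 0], [0, 0, 0, 1, 1, 1]]
for row in train_wta(patterns, n_neurons=2):
    print([round(v, 2) for v in row])
```

Each update is a convex combination of the old weights and a normalized target, so every weight row keeps summing to 1; this keeps a single neuron from winning just by growing its weights. Whether every neuron ends up on a different pattern still depends on the initialisation, since a neuron that never wins never learns.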

In my personal opinion (and again, I’m not an expert on all that), STDP on MNIST is really a basic proof of concept, but not the most interesting use case of STDP. There is no temporal structure in the input data (though there are variants like N-MNIST that are a bit more interesting in that regard), so you could probably get exactly the same result with a simpler learning rule (and, for that matter, with rate-based neurons instead of spiking neurons). You have to invest quite a bit of effort to make spiking networks perform as well as much simpler standard ANNs on this kind of problem.
In contrast, STDP on temporal spike patterns seems more biologically relevant and is straightforward to get working. I’m thinking of studies like this one: Spike Timing Dependent Plasticity Finds the Start of Repeating Patterns in Continuous Spike Trains
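For contrast, the basic ingredient that makes STDP sensitive to timing is the pair-based learning window: with Δt = t_post - t_pre, the weight is potentiated when the presynaptic spike comes first and depressed otherwise, with an exponentially decaying magnitude. A minimal sketch; the amplitudes, time constants, all-to-all pairing and hard bounds are illustrative assumptions and not necessarily what the cited study used.

```python
import math


def stdp_dw(delta_t_ms, a_plus=0.01, a_minus=0.0105, tau_plus_ms=20.0, tau_minus_ms=20.0):
    """Pair-based STDP window, delta_t = t_post - t_pre (all values illustrative)."""
    if delta_t_ms > 0:  # pre before post: potentiation
        return a_plus * math.exp(-delta_t_ms / tau_plus_ms)
    if delta_t_ms < 0:  # post before pre: depression
        return -a_minus * math.exp(delta_t_ms / tau_minus_ms)
    return 0.0


def apply_stdp(w, pre_spikes_ms, post_spikes_ms, w_max=1.0):
    """Sum the window over all pre/post spike pairs and clip w to [0, w_max]."""
    dw = sum(stdp_dw(t_post - t_pre)
             for t_pre in pre_spikes_ms for t_post in post_spikes_ms)
    return min(w_max, max(0.0, w + dw))


# An input that fires 5 ms before the output spike is strengthened,
# one that fires 5 ms after it is weakened.
print(apply_stdp(0.5, [10.0], [15.0]) > 0.5)  # True
print(apply_stdp(0.5, [20.0], [15.0]) < 0.5)  # True
```

Inputs that repeatedly fire just before the output spike accumulate potentiation, which is why STDP can lock onto a repeating spike pattern.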

Just my 2 cents!


Thanks for the replies.
Anyway, the reason there are so many views here (at least 52) is that this topic is pretty interesting, and it would be good for the scientific community to share some reference code that is simple and reproducible. So please do! Thanks very much!

Hi @mstimberg, thank you for your useful points. I have a question about them. I’m also trying to do pattern classification using an SNN.
You said:

The approach in the Diehl et al. (2015) paper was a different one: they had several output neurons that were inhibiting each other and presented all classes to them. The result was some kind of unsupervised learning, where individual neurons were competing via the inhibition. After the learning, they looked which neurons were the most responsive for each class, and then they assigned labels to the neurons accordingly and used them for classification.

I’m struggling to identify which neuron fits my input best after unsupervised learning. Do you mean that I should count the spikes of each output neuron for each input, find the neuron with the maximum number of spikes for each input, and then see which output neuron wins most often for that specific input category?
Afterwards, can I do the testing based on these assigned output labels?

Yes, something like that (there are various ways to deal with the details, and for non-trivial tasks you would rather train a network on the output of the unsupervised learning). The easiest approach I can think of: for each input, see which output neuron responds with the most spikes. Then, for each input category (i.e. digit), count which neuron was the “winning” neuron most often in this category. You then consider this neuron to be the neuron representing the digit. In the test phase, correct classification means that the neuron representing the tested digit fires more than the other neurons.
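The label-assignment procedure described above could look like this in plain Python, assuming you have already recorded a spike count per output neuron for each presented sample (e.g. collected with a spike monitor during each presentation). `assign_neurons` and `is_correct` are hypothetical helper names, and treating ties in the test phase as errors is a choice made here.

```python
from collections import Counter


def assign_neurons(spike_counts, labels):
    """Map each class label to the output neuron that 'won' most often for it.

    spike_counts[i][j] = number of spikes of output neuron j for sample i.
    """
    wins = {}
    for counts, label in zip(spike_counts, labels):
        winner = max(range(len(counts)), key=counts.__getitem__)  # most spikes
        wins.setdefault(label, Counter())[winner] += 1
    return {label: c.most_common(1)[0][0] for label, c in wins.items()}


def is_correct(counts, true_label, neuron_for_class):
    """Test phase: correct if the neuron representing the true class fires
    strictly more than every other neuron (ties count as errors here)."""
    j = neuron_for_class[true_label]
    return all(counts[j] > c for k, c in enumerate(counts) if k != j)


# Toy spike counts (made up): 3 output neurons, 5 training samples
train_counts = [[9, 1, 0], [8, 2, 1], [0, 7, 1], [1, 6, 2], [0, 1, 5]]
train_labels = [0, 0, 1, 1, 2]
mapping = assign_neurons(train_counts, train_labels)
print(mapping)  # {0: 0, 1: 1, 2: 2}
print(is_correct([10, 2, 3], 0, mapping))  # True
```

The test accuracy is then the fraction of test samples for which `is_correct` returns True, computed with weight updates switched off.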
Note that before using this approach, you first have to make sure that after assigning the output neurons to digits, you actually get one output neuron for each digit. A single neuron could be the “winning” neuron for more than one input category; in that case, you’d have to come up with something more complex to assign output neurons to digits.
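The check for one neuron per digit can be automated, for example by inverting the class-to-neuron mapping and flagging neurons that were assigned to more than one class. A small sketch with hypothetical names, taking a dict such as `{digit: neuron_index}`:

```python
from collections import defaultdict


def shared_neurons(neuron_for_class):
    """Return {neuron: [classes]} for every neuron assigned to more than one class."""
    classes_by_neuron = defaultdict(list)
    for label, neuron in sorted(neuron_for_class.items()):
        classes_by_neuron[neuron].append(label)
    return {n: labels for n, labels in classes_by_neuron.items() if len(labels) > 1}


def is_one_to_one(neuron_for_class):
    """True if no output neuron represents more than one class."""
    return not shared_neurons(neuron_for_class)


print(shared_neurons({0: 4, 1: 7, 2: 4}))  # {4: [0, 2]}
```

If the result is non-empty, the simple "most frequent winner" assignment is ambiguous for those digits.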
