How to fix the value of a synapse after learning

Description of problem

Hello! I’m trying to use a spiking neural network to learn the MNIST dataset. I defined an STDP synapse with a variable ‘learned’ that should be set to 0 after learning, so that the weights that encode a specific input remain fixed.

Minimal code to reproduce problem

class Simulation():

    def __init__(self):
        self.excpop = NeuronGroup(...)       # excitatory population (details elided)
        self.input_layer = NeuronGroup(...)  # input layer (details elided)
        self.Sinput = Synapses(self.input_layer, self.excpop,
                               '''w : 1
                                  dx / dt = -x / tau_p : 1 (event-driven)
                                  dy / dt = -y / tau_m : 1 (event-driven)
                                  plastic : 1 (shared)
                                  W_max : 1 (shared)
                                  n_p : 1 (shared)
                                  learned : 1''',
                               on_pre='''scx += w
                                         x += apre
                                         w = clip(w - w*n_m*y*plastic*learned/W_max, 0, W_max)''',
                               on_post='''y += apost
                                          w = clip(w + (1 - w/W_max)*n_p*(x - xtar)*plastic*learned, 0, W_max)''',
                               name='Sinput')
        self.Sinput.learned = '1'

    def train(self):
        self.net.run(...)  # run for the training duration (elided)
        # after learning
        self.Sinput['w>0.3'].learned = 0.

But it doesn’t work: the variable self.Sinput.learned stays 1 for every synapse, so the network keeps changing the weights. I don’t understand why I can’t set it to 0 with that expression.
Is this a good way to fix synapses after learning?

Hi @caccola816. This is a perfectly reasonable way of doing things, but unfortunately you ran into a confusing syntax issue (on the plus side, the fix is straightforward).

When you write self.Sinput['w>0.3'], you are creating a so-called “synaptic subgroup”, i.e. a group of all synapses that have w>0.3. Currently, such a group is not very useful: it can only be used as a stand-in for synaptic indices. It does not have the full functionality of a Synapses object; in particular, you cannot access synaptic variables via such a subgroup. If you tried something like print(self.Sinput['w>0.3'].learned), you would get an AttributeError, stating that it does not have a learned attribute.

Now, in Python, assigning to a non-existing attribute of an object will create that attribute, i.e. your “synaptic subgroup” will end up with a new attribute learned with the value 0, but this will not be in any way connected to the actual synaptic variable. I know this is all very confusing, and I am currently working on quite a big change to Brian’s syntax handling (Non-contiguous subgroups by mstimberg · Pull Request #1229 · brian-team/brian2 · GitHub) which should also take care of this issue.
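To see the Python mechanism in isolation, here is a minimal plain-Python sketch (SynapticSubgroup is just an illustrative stand-in for what the indexing returns, not Brian’s actual class):

class SynapticSubgroup:  # stand-in for the object returned by self.Sinput['w>0.3']
    pass

sub = SynapticSubgroup()
# print(sub.learned)  # would raise AttributeError: no attribute 'learned'
sub.learned = 0.      # silently creates a brand-new attribute on sub...
print(sub.learned)    # ...so this prints 0.0, but no synaptic variable was touched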

Until then, there is fortunately a different way to express this: instead of indexing the Synapses object, you can index the synaptic variable. I.e., the following will work as you’d expect:

self.Sinput.learned['w>0.3'] = 0.
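The same string-indexing syntax also works for reading a synaptic variable, which gives you a quick sanity check (assuming the synapses have been created and connected as in your code):

print(self.Sinput.learned['w>0.3'])  # should now print only zeros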

Here’s a simplified version of your code that shows how, after 1 s, all synaptic weights that were above 0.3 stop updating:

from brian2 import *
import matplotlib.pyplot as plt  # needed for the plot() method below

class Simulation:
    def __init__(self):
        self.excpop = PoissonGroup(10, rates=10*Hz)
        self.input_layer = PoissonGroup(10, rates=10*Hz)
        self.Sinput = Synapses(self.input_layer, self.excpop,
                               '''w : 1 
                                  dx / dt = -x / (20*ms) : 1 (event-driven)
                                  dy / dt = -y / (20*ms) : 1 (event-driven)
                                  W_max : 1 (constant, shared)
                                  learned : 1''',
                               on_pre='''x += 1
                                         w = clip(w + y*learned, 0, W_max)''',
                               on_post='''y -= 1
                                          w = clip(w + x*learned, 0, W_max)''',
                               name='Sinput')
        self.Sinput.connect(j='i')
        self.Sinput.w = 'rand()'  # random initial weights in [0, 1)
        self.Sinput.W_max = 1
        self.Sinput.learned = 1   # 1 = plastic, 0 = frozen
        self.weight_mon = StateMonitor(self.Sinput, 'w', record=True, dt=1*ms)

        self.net = Network(self.excpop, self.input_layer, self.Sinput, self.weight_mon)

    def train(self):
        self.net.run(1*second)
        # after learning
        self.Sinput.learned['w>0.3'] = 0.

    def test(self):
        self.net.run(1*second)

    def plot(self):
        plt.plot(self.weight_mon.t/ms, self.weight_mon.w.T)
        plt.axvline(1*second/ms, color='gray', linestyle='--')
        plt.xlabel('Time (ms)')
        plt.ylabel('Weight')
        plt.show()

if __name__ == '__main__':
    sim = Simulation()
    sim.train()
    sim.test()
    sim.plot()

(Figure 1: weight traces over time; after the dashed line at t = 1000 ms, the weights that exceeded 0.3 remain constant.)
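As an aside: your original model already defines a shared plastic flag that multiplies both weight updates, so if you ever want to freeze all synapses at once, setting that single shared value is enough:

self.Sinput.plastic = 0  # shared flag: zeroes the STDP terms for every synapse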


Yeah, now it works! :slight_smile:

Thank you very much! Indeed, I was getting very confusing results eheh