from brian2 import *
import numpy as np

v_th = 15 / 1000
tau = 50 * ms

# Neuron model. The dynamics are a placeholder: nothing ever drives v
# (on_pre below only touches the synapse counter), so G never reaches v_th.
# Only the weight bookkeeping between the Poisson inputs and the neuron is
# of interest here.
eqs = '''
dv/dt = -v/tau : 1
w_pool : 1  # weight reserve ("pool") of the neuron
'''

# Poisson input population.
P = PoissonGroup(5, 50 * Hz)
MP = SpikeMonitor(P)  # records the input spikes

# Postsynaptic neuron group.
G = NeuronGroup(1, model=eqs, threshold='v>v_th', method='exact')
M = StateMonitor(G, 'v', record=True)

# Synapse variables.
eqs_syn = '''
w : 1       # synaptic weight
x : 1       # number of presynaptic spikes received so far
x_prev : 1  # value of x when weight_updater last processed this synapse
'''
S = Synapses(P, G, model=eqs_syn, on_pre='''
x += 1  # count every incoming spike
''')
S.connect()
MW = StateMonitor(S, 'w', record=True)  # weight of every synapse over time
MW_POOL = StateMonitor(G, 'w_pool', record=True)  # pool of every neuron over time

# Initial values.
S.w = 0
S.x = 0
S.x_prev = 0
G.w_pool = 5

# Fraction of a neuron's pool handed to a synapse that just received a spike;
# the neuron keeps the remaining (1 - POOL_SHARE).
POOL_SHARE = 0.8


def _apply_spike_updates(weights, pool, post_idx, spiked, pool_share=POOL_SHARE):
    """Apply rule (10) and the pool transfer in place for each synapse in *spiked*.

    weights    -- float array of synaptic weights, one entry per synapse
    pool       -- float array of pool reserves, one entry per postsynaptic neuron
    post_idx   -- postsynaptic neuron index of each synapse (values of S.j)
    spiked     -- indices of the synapses that received a spike; processed in order
    pool_share -- fraction of the pool moved to the spiking synapse
    """
    for syn in spiked:
        post = post_idx[syn]
        # Weight of the spiking synapse before rule (10); every sibling synapse
        # onto the same neuron is compared against this fixed value.
        w_spike = weights[syn]
        for other in np.flatnonzero(post_idx == post):
            w_other = weights[other]
            if other == syn or w_other <= w_spike:
                continue
            # Rule (10): a stronger sibling hands part of its weight to the
            # spiking synapse; the sum of the two weights is unchanged.
            delta = w_other * (w_other - w_spike) / (3 + w_other)
            weights[other] -= delta
            weights[syn] += delta
        # Move part of the neuron's pool into the spiking synapse.
        transfer = pool_share * pool[post]
        weights[syn] += transfer
        pool[post] -= transfer


@network_operation
def weight_updater():
    """Update the weights of every synapse that received a spike since the last call.

    The spike counter x is compared elementwise with x_prev. Only synapses
    whose counter changed are processed, and each one's postsynaptic neuron
    is taken from S.j. This also works when several inputs spike in the same
    step or when G contains more than one neuron.
    """
    x_now = np.array(S.x[:])
    spiked = np.flatnonzero(x_now != np.array(S.x_prev[:]))
    if spiked.size == 0:
        return  # no new input spikes: nothing to do this step
    S.x_prev[:] = x_now  # copy values, not a view

    weights = np.array(S.w[:], dtype=float)
    pool = np.array(G.w_pool[:], dtype=float)
    _apply_spike_updates(weights, pool, np.array(S.j[:]), spiked)
    S.w[:] = weights
    G.w_pool[:] = pool


run(100 * ms)

print("pool", G.w_pool)
print("w's", S.w)
print(MP.i)

plt.figure(100)  # Drawing Figure 2.
synapse_styles = [
    ('g', 'solid'),
    ('r', 'solid'),
    ('b', 'solid'),
    ('c', 'dashed'),
    ('m', 'dashed'),
]
for k, (color, linestyle) in enumerate(synapse_styles):
    plt.plot(MW.t / ms, MW.w[k], label=f'Synapse {k + 1}',
             color=color, linestyle=linestyle)
plt.plot(MW_POOL.t / ms, MW_POOL.w_pool[0], color='k', label='POOL')
xlabel('Time (ms)')
ylabel('w')
legend()
plt.show()