"""Brian2 model of a competitive, pool-based synaptic weight update.

Five Poisson inputs (50 Hz) are connected all-to-all to a single neuron that
carries a weight reserve ``w_pool``. Whenever a synapse transmits a spike, a
network operation redistributes weight among the synapses onto the same
postsynaptic neuron and moves part of that neuron's pool into the active
synapse. Weights and the pool are recorded and plotted after a 100 ms run.
"""
from brian2 import *
import numpy as np

# Per-neuron state: the reserve of weight ("pool") available to its synapses.
eqs = '''
w_pool : 1
'''

# Input layer: 5 Poisson neurons firing at 50 Hz.
P = PoissonGroup(5, 50 * Hz)
MP = SpikeMonitor(P, record=True)  # Records input spike indices and times.

# Target layer: a single neuron holding the weight pool.
G = NeuronGroup(1, eqs)

# How the pre-/post-synaptic indices of spiking synapses are found:
# - Every synapse carries a boolean flag ``is_spike``, initialised to False.
# - The on_pre pathway sets the flag of a synapse whenever a presynaptic spike
#   is transmitted through it (e.g. with several target neurons, a spike from
#   Poisson neuron 3 reaching target neuron 2 sets the flag of synapse 3 -> 2).
# - The network operation below collects the flagged synapse indices with
#   np.where() into ``w_index``; S.i[w_index] then gives the presynaptic and
#   S.j[w_index] the postsynaptic neuron indices.
eqs_syn = '''
# Defining necessary variables for Synapse.
w : 1 # w will show weights.
is_spike : boolean
'''
S = Synapses(P, G, model=eqs_syn, on_pre='''
is_spike = True
''')
S.connect()  # All-to-all: every Poisson neuron projects onto every target neuron.

MW = StateMonitor(S, 'w', record=True)  # Records every synaptic weight.
MW_POOL = StateMonitor(G, 'w_pool', record=True)  # Records every neuron's pool.

# Initial values.
S.w = 0
G.w_pool = 5
S.is_spike = False

# Synapse objects and their postsynaptic groups, paired by list position, so
# the network operation can process more than one projection.
syn_arr = [S]
neuron_group_arr = [G]


def update_process(synapse, neuron_group, w_index):
    """Apply the weight-redistribution rule to every synapse that spiked.

    For each flagged synapse (pre -> post), each *other* synapse onto the same
    postsynaptic neuron whose weight exceeds the flagged synapse's weight (as
    read before this synapse's transfers start) donates
    ``w_o * (w_o - w) / (5 + w_o)`` to it. The flagged synapse then receives
    80% of the postsynaptic neuron's ``w_pool``, and 20% stays in the pool.

    Parameters
    ----------
    synapse : Synapses
        Synapse object with the variables ``w`` and ``is_spike``.
    neuron_group : NeuronGroup
        Postsynaptic group with the variable ``w_pool``.
    w_index : array of int
        Indices of the synapses whose ``is_spike`` flag is set.
    """
    # Clear the flags first so spikes arriving in later time steps are detected.
    synapse.is_spike[w_index] = False

    # Connectivity is fixed during this call, so read it once.
    pre_indices = np.asarray(synapse.i[:])
    post_indices = np.asarray(synapse.j[:])

    for ind_pre, ind_post in zip(pre_indices[w_index], post_indices[w_index]):
        updated_w = synapse.w[ind_pre, ind_post]
        # Presynaptic partners actually connected to this postsynaptic neuron.
        # Do not assume they are numbered 0..count-1; that only holds for
        # all-to-all connectivity. Sorted order keeps the update sequence
        # deterministic.
        competitors = np.unique(pre_indices[post_indices == ind_post]).tolist()
        for other_pre in competitors:
            if other_pre == ind_pre:
                continue
            updater_w = synapse.w[other_pre, ind_post]
            if updater_w > updated_w:  # Rule 10 from article.
                transfer = (updater_w * (updater_w - updated_w)) / (5 + updater_w)
                synapse.w[other_pre, ind_post] -= transfer
                synapse.w[ind_pre, ind_post] += transfer
        # Take 80% of the pool into the active synapse; 20% remains in reserve.
        synapse.w[ind_pre, ind_post] += 0.8 * neuron_group.w_pool[ind_post]
        neuron_group.w_pool[ind_post] = 0.2 * neuron_group.w_pool[ind_post]


@network_operation(when='after_synapses')
def weight_updater():
    """Run ``update_process`` for every projection that has flagged synapses."""
    for synapse, neuron_group in zip(syn_arr, neuron_group_arr):
        w_index = np.where(synapse.is_spike)[0]
        if len(w_index) > 0:  # A spike was transmitted in this time step.
            update_process(synapse, neuron_group, w_index)


run(100 * ms)

plt.figure(100)  # Drawing Figure 2.
plt.plot(MW.t / ms, MW.w[0], label='Synapse 1', color='g')
plt.plot(MW.t / ms, MW.w[1], label='Synapse 2', color='r')
plt.plot(MW.t / ms, MW.w[2], label='Synapse 3', color='b')
plt.plot(MW.t / ms, MW.w[3], label='Synapse 4', linestyle='dashed', color='c')
plt.plot(MW.t / ms, MW.w[4], label='Synapse 5', linestyle='dashed', color='m')
plt.plot(MW_POOL.t / ms, MW_POOL.w_pool[0], color='k', label='POOL')
xlabel('Time (ms)')
ylabel('w')
legend()
plt.show()

print(S.i)
print("pool:", G.w_pool)
print("w's:", S.w)
print('MP.i:', MP.i)
print('MP.t:', MP.t)