Pattern Matching task takes a long time

Description of problem

I’m working on a Pattern Matching task that uses SNNs to construct logic gates such as “and”, “or”, etc., and then uses these basic gates to match a pattern.

The Gate code:

def generate_logic_gate_block(gate_type):
    v_reset = -80 * mV
    lif_model = '''dv/dt = (v_reset - v)/tau_m : volt'''
    synapses_eqs = '''w : 1 (constant)'''
    synapses_pre_action = '''v_post +=  w * (v_th - v_reset + 5*mV)'''

    if gate_type == 'not':
        N_in = 2
        N_out = 2
    else:
        N_in = 4
        N_mid = 4
        N_out = 2

    input_group = NeuronGroup(N_in, model='v : volt', threshold='v >= v_th', reset='v = v_reset')
    input_group.v = v_reset

    if gate_type == 'not':
        middle_group = None  # the 'not' gate has no middle layer (avoids a NameError on return)
    else:
        middle_group = NeuronGroup(N=N_mid, model=lif_model, threshold='v > v_th', reset='v = v_reset',
                                   refractory=5 * ms, method='exact')
        middle_group.v = v_reset

    output_group = NeuronGroup(N=N_out, model=lif_model, threshold='v > v_th', reset='v = v_reset', refractory=5 * ms,
                               method='exact')
    output_group.v = v_reset

    if gate_type == 'not':
        i2o = Synapses(input_group, output_group, on_pre='v_post = v_th + 5 * mV')
        i2o.connect(i=[0, 1], j=[1, 0])
        net = Network(input_group, output_group, i2o)  # , output_mon, output_state_mon
    else:

        i2m = Synapses(input_group, middle_group, on_pre='v_post = v_post + (v_th - v_reset) * 0.5 + 5 * mV')
        i2m.connect(i=[0, 2, 0, 3, 1, 2, 1, 3], j=[0, 0, 1, 1, 2, 2, 3, 3])
        m2o = Synapses(middle_group, output_group, model=synapses_eqs, on_pre=synapses_pre_action)
        m2o.connect()
        m2o.w = weight_dic[gate_type]  # per-gate weight vector, defined elsewhere in the script
        net = Network(input_group, middle_group, output_group, i2m, m2o)  # , output_mon, output_state_mon

    return net, input_group, middle_group, output_group  # , output_mon, output_state_mon

To create multiple gates:

def generate_logic_gate_block_list(gate_dic: dict):
    gate_network = Network()
    head_layer_list = []
    mid_layer_list = []
    tail_layer_list = []

    for key in gate_dic:
        for _ in range(gate_dic[key]):
            net, head_layer, mid_layer, tail_layer = generate_logic_gate_block(key)
            gate_network.add(net)
            head_layer_list.append(head_layer)
            mid_layer_list.append(mid_layer)
            tail_layer_list.append(tail_layer)
    return gate_network, head_layer_list, mid_layer_list, tail_layer_list

The task is to match the MNIST dataset, using the average image of each digit class as the pattern.

The main code is:

# parameters
v_th = -50 * mV  # -54 * mV
v_reset = -80 * mV  # -65 * mV
tau_m = 1 * ms
N_spike = 2
N_in = 4
N_mid = 4
N_out = 2
duration = 784 * ms  # 784
delta_t, defaultclock.dt = 3 * 0.1 * ms, 0.1 * ms
on_pre = 'v_post = v_th'

# pattern spike ini
encoding_threshold = 0.3
pattern_net, spike_group, spike_mon = pattern_spike_generate(encoding_threshold=encoding_threshold)
task_network = Network(pattern_net)

# create gate
gate_num = 10
gate_dic = {'and': gate_num, 'xor': gate_num}
task_net, head_layer, _, tail_layer = generate_logic_gate_block_list(gate_dic)
task_network.add(task_net)

# synapses with pattern
for enum in range(len(spike_group)):
   synapse_and = Synapses(spike_group[enum], head_layer[enum], on_pre=on_pre)
   synapse_and.connect(i=[0, 1], j=[0, 1])
   synapse_xor = Synapses(spike_group[enum], head_layer[enum + gate_num], on_pre=on_pre)
   synapse_xor.connect(i=[0, 1], j=[0, 1])
   task_network.add(synapse_and, synapse_xor)

# test data preparation
task = 'test'
pic_bin, pic_label = data_preparation(task)
spike_time = np.arange(0, 784, 1) * ms
correct_count = 0
and_spike = np.zeros(10)
xor_spike = np.zeros(10)
task_count = len(pic_label)
wrong_indices_list = []

# store the network
task_network.store(name='clf_MNIST', filename='clf_MNIST.network')

print('==========Start===========')
for pic_index in range(task_count):
   # restore the network
   task_network.restore(name='clf_MNIST', filename='clf_MNIST.network')
   print('======= %d / %d =========' % (pic_index + 1, task_count))
   pic_spike_group = SpikeGeneratorGroup(N=N_spike, indices=pic_bin[pic_index], times=spike_time, sorted=True)
   pic_spike_mon = SpikeMonitor(pic_spike_group)
   task_network.add(pic_spike_group, pic_spike_mon)

   # synapses with single picture
   synapse_and_list = []
   synapse_xor_list = []
   for enum in range(len(spike_group)):
       synapse_and_pic = Synapses(pic_spike_group, head_layer[enum], on_pre=on_pre)
       synapse_and_pic.connect(i=[0, 1], j=[2, 3])
       synapse_and_list.append(synapse_and_pic)

       synapse_xor_pic = Synapses(pic_spike_group, head_layer[enum + gate_num], on_pre=on_pre)
       synapse_xor_pic.connect(i=[0, 1], j=[2, 3])
       synapse_xor_list.append(synapse_xor_pic)
       task_network.add((synapse_and_pic, synapse_xor_pic))

   # monitor
   spike_mon_and_list = []
   spike_mon_xor_list = []
   for enum in range(len(spike_group)):
       spike_mon_and_output = SpikeMonitor(tail_layer[enum])
       spike_mon_xor_output = SpikeMonitor(tail_layer[enum + gate_num])
       spike_mon_and_list.append(spike_mon_and_output)
       spike_mon_xor_list.append(spike_mon_xor_output)
       task_network.add(spike_mon_and_output, spike_mon_xor_output)

   # run simulation
   task_network.run(duration, report='text')  # , profile=True

   # remove
   task_network.remove(pic_spike_group, pic_spike_mon)
   for enum in range(len(spike_group)):
       task_network.remove(synapse_and_list[enum], synapse_xor_list[enum])
       task_network.remove(spike_mon_and_list[enum], spike_mon_xor_list[enum])

print('==============Finish=================')

My CPU is an Intel i7-9760H and I am using the Cython backend, but processing one picture (784 ms of simulated time) takes ~15 s. The whole training dataset (60,000 pictures) would take ~10 days.

======= 4201 / 60000 =========
Starting simulation at t=0. s for a duration of 0.784 s
0.5126 s (65%) simulated in 10s, estimated 5s remaining.
0.784 s (100%) simulated in 15s

Is there anything I can do to cut down the simulation time? :pleading_face:

Hi. My main recommendation would be to reduce the number of objects, because each of them adds a little overhead (in particular for the compilation at the start of the run). As a general rule, neurons with the same equations should all be in the same NeuronGroup. You can then either use subgroups or directly use the indices to connect things correctly. This will also let you greatly reduce the number of SpikeMonitor and Synapses objects you need – remember, each of these objects generates a Cython file which gets converted to C++ and then compiled, so redundant objects (in the sense that they have the same code) are very wasteful. On top of the extra compilation time, simulating the objects will also be slower due to the per-object overhead.
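
For illustration, here is a minimal sketch of the merging idea (this is not your exact architecture – I am just assuming two output neurons per gate and reusing the parameter values from your script):

# Minimal sketch: all gate output neurons share one NeuronGroup and one
# SpikeMonitor, instead of one group and one monitor per gate
from brian2 import *
import numpy as np

v_th, v_reset, tau_m = -50 * mV, -80 * mV, 1 * ms
n_gates = 20  # e.g. 10 'and' + 10 'xor' blocks
lif_model = '''dv/dt = (v_reset - v)/tau_m : volt
               gate_index : integer (constant)  # label: which gate this neuron belongs to
               '''

# 2 output neurons per gate, all in a single group
output_group = NeuronGroup(2 * n_gates, model=lif_model, threshold='v > v_th',
                           reset='v = v_reset', refractory=5 * ms, method='exact')
output_group.v = v_reset
output_group.gate_index = np.repeat(np.arange(n_gates), 2)

# one monitor records everything; the spikes of gate k are the ones with
# output_mon.i // 2 == k (or use subgroups such as output_group[2*k:2*k + 2])
output_mon = SpikeMonitor(output_group)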

A personal recommendation that I find useful in this kind of setting, to avoid getting lost in index calculations: when you create your initial NeuronGroups, add additional “label” parameters to your equations. Without having looked at this in detail, in your case this could be something along the lines of:

'''...
gate_index : integer (constant)
input_index : integer (constant)
'''

(this could also be e.g. gate_type : integer (constant) etc.)
If you then set this correctly for each neuron, your later connection code becomes really straightforward and you can connect all the synapses in one go with something like:

synapses.connect('gate_index_pre == gate_index_post and input_index_pre == input_index_post')

Sometimes it might help to split things up into multiple connect calls (which each add connections on top of each other).
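
Putting these ideas together, a rough sketch (the label names and values here are only examples, and the connect condition will need refining, or a couple of extra connect calls, to reproduce the exact wiring inside each gate):

# Sketch of label-based connections: all gates' input and middle neurons live
# in two big groups, and a single Synapses object wires them up for all gates
from brian2 import *
import numpy as np

v_th, v_reset, tau_m = -50 * mV, -80 * mV, 1 * ms
n_gates = 20

input_eqs = '''v : volt
               gate_index : integer (constant)
               input_index : integer (constant)
               '''
lif_eqs = '''dv/dt = (v_reset - v)/tau_m : volt
             gate_index : integer (constant)
             '''

inputs = NeuronGroup(4 * n_gates, model=input_eqs, threshold='v >= v_th', reset='v = v_reset')
middles = NeuronGroup(4 * n_gates, model=lif_eqs, threshold='v > v_th',
                      reset='v = v_reset', refractory=5 * ms, method='exact')
inputs.v = v_reset
middles.v = v_reset
inputs.gate_index = np.repeat(np.arange(n_gates), 4)   # gate 0, 0, 0, 0, 1, 1, ...
inputs.input_index = np.tile(np.arange(4), n_gates)    # input 0, 1, 2, 3 within each gate
middles.gate_index = np.repeat(np.arange(n_gates), 4)

# one Synapses object instead of one per gate; the string condition keeps the
# connections within the same gate (add further conditions or connect calls to
# reproduce the exact input-to-middle pattern)
i2m = Synapses(inputs, middles, on_pre='v_post += (v_th - v_reset) * 0.5 + 5 * mV')
i2m.connect('gate_index_pre == gate_index_post')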

Related to the above recommendations: try to avoid creating/removing objects between runs (again, each of them needs to be compiled for each run, even though Brian can sometimes avoid it if it comes across the same code). To keep monitor content for later use, store the data from .i and .t rather than the SpikeMonitor object itself. You can reuse the same SpikeGeneratorGroup for different patterns by changing its pattern with set_spikes. Finally, if you need to enable/disable objects like synapses, you can either set the Synapses object's .active attribute to disable all the synapses in the object, or do something more fine-grained by adding an is_active : boolean (constant) parameter to the synaptic equations and referring to it as part of the on_pre statement. Similarly, if you want to switch off certain neurons in your input layer, you can add the same kind of parameter to your NeuronGroup equations and reformulate your threshold as something like v > v_th and is_active.
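
As a rough sketch of that reuse pattern (assuming the merged output_group/output_mon from the first sketch, and reusing pic_bin, spike_time, N_spike, duration, task_count and task_network from your main script):

# one SpikeGeneratorGroup and one SpikeMonitor for the whole test loop;
# only the spike pattern changes between runs
pic_spike_group = SpikeGeneratorGroup(N=N_spike, indices=pic_bin[0],
                                      times=spike_time, sorted=True)
task_network.add(pic_spike_group, output_mon)
task_network.store(name='clf_MNIST', filename='clf_MNIST.network')

results = []
for pic_index in range(task_count):
    task_network.restore(name='clf_MNIST', filename='clf_MNIST.network')
    # swap in the next picture instead of creating a new SpikeGeneratorGroup
    pic_spike_group.set_spikes(indices=pic_bin[pic_index], times=spike_time, sorted=True)
    task_network.run(duration)
    # keep only the spike data, not the monitor object itself
    results.append((np.array(output_mon.i), np.array(output_mon.t / ms)))

# to silence a Synapses object for one run instead of removing it, you could
# also set its .active attribute to False before calling run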

It is a bit of work, but you should end up with only a handful of objects and much faster code!


@mstimberg Thanks for this kind advice, Marcel! Very useful!

When I first built the network, I only considered constructing the basic gate structure and used many redundant NeuronGroups. I’ll try to use bigger groups and use subgroups to separate them.

I hadn’t noticed the set_spikes method before; it makes changing the pattern much more convenient!

I’m going to check the other details, and thank you again!
