Description of problem
I’m dealing with a Pattern Matching task, which use SNNs to construct Logic Gates such as “and”, “or” etc. And use the basic Gates to match the pattern.
The Gate code:
def generate_logic_gate_block(gate_type):
v_reset = -80 * mV
lif_model = '''dv/dt = (v_reset - v)/tau_m : volt'''
synapses_eqs = '''w : 1 (constant)'''
synapses_pre_action = '''v_post += w * (v_th - v_reset + 5*mV)'''
if gate_type == 'not':
N_in = 2
N_out = 2
else:
N_in = 4
N_mid = 4
N_out = 2
input_group = NeuronGroup(N_in, model='v : volt', threshold='v >= v_th', reset='v = v_reset')
input_group.v = v_reset
if gate_type == 'not':
pass
else:
middle_group = NeuronGroup(N=N_mid, model=lif_model, threshold='v > v_th', reset='v = v_reset',
refractory=5 * ms, method='exact')
middle_group.v = v_reset
output_group = NeuronGroup(N=N_out, model=lif_model, threshold='v > v_th', reset='v = v_reset', refractory=5 * ms,
method='exact')
output_group.v = v_reset
if gate_type == 'not':
i2o = Synapses(input_group, output_group, on_pre='v_post = v_th + 5 * mV')
i2o.connect(i=[0, 1], j=[1, 0])
net = Network(input_group, output_group, i2o) # , output_mon, output_state_mon
else:
i2m = Synapses(input_group, middle_group, on_pre='v_post = v_post + (v_th - v_reset) * 0.5 + 5 * mV')
i2m.connect(i=[0, 2, 0, 3, 1, 2, 1, 3], j=[0, 0, 1, 1, 2, 2, 3, 3])
m2o = Synapses(middle_group, output_group, model=synapses_eqs, on_pre=synapses_pre_action)
m2o.connect()
m2o.w = weight_dic[gate_type]
net = Network(input_group, middle_group, output_group, i2m, m2o) # , output_mon, output_state_mon
return net, input_group, middle_group, output_group # , output_mon, output_state_mon
To create more Gates,
def generate_logic_gate_block_list(gate_dic: dict):
gate_network = Network()
head_layer_list = []
mid_layer_list = []
tail_layer_list = []
for key in gate_dic:
for _ in range(gate_dic[key]):
net, head_layer, mid_layer, tail_layer = generate_logic_gate_block(key)
gate_network.add(net)
head_layer_list.append(head_layer)
mid_layer_list.append(mid_layer)
tail_layer_list.append(tail_layer)
return gate_network, head_layer_list, mid_layer_list, tail_layer_list
The task is to match MNIST dataset, which use the average data of each number classes as the Pattern.
The Main code is
# parameters
v_th = -50 * mV # -54 * mV
v_reset = -80 * mV # -65 * mV
tau_m = 1 * ms
N_spike = 2
N_in = 4
N_mid = 4
N_out = 2
duration = 784 * ms # 784
delta_t, defaultclock.dt = 3 * 0.1 * ms, 0.1 * ms
on_pre = 'v_post = v_th'
# pattern spike ini
encoding_threshold = 0.3
pattern_net, spike_group, spike_mon = pattern_spike_generate(encoding_threshold=encoding_threshold)
task_network = Network(pattern_net)
# create gate
gate_num = 10
gate_dic = {'and': gate_num, 'xor': gate_num}
task_net, head_layer, _, tail_layer = generate_logic_gate_block_list(gate_dic)
task_network.add(task_net)
# synapses with pattern
for enum in range(len(spike_group)):
synapse_and = Synapses(spike_group[enum], head_layer[enum], on_pre=on_pre)
synapse_and.connect(i=[0, 1], j=[0, 1])
synapse_xor = Synapses(spike_group[enum], head_layer[enum + gate_num], on_pre=on_pre)
synapse_xor.connect(i=[0, 1], j=[0, 1])
task_network.add(synapse_and, synapse_xor)
# pattern spike ini
task = 'test'
pic_bin, pic_label = data_preparation(task)
spike_time = np.arange(0, 784, 1) * ms
correct_count = 0
and_spike = np.zeros(10)
xor_spike = np.zeros(10)
task_count = len(pic_label)
wrong_indices_list = []
# store the network
task_network.store(name='clf_MNIST', filename='clf_MNIST.network')
print('==========Start===========')
for pic_index in range(task_count):
# restore the network
task_network.restore(name='clf_MNIST', filename='clf_MNIST.network')
print('======= %d / %d =========' % (pic_index + 1, task_count))
pic_spike_group = SpikeGeneratorGroup(N=N_spike, indices=pic_bin[pic_index], times=spike_time, sorted=True)
pic_spike_mon = SpikeMonitor(pic_spike_group)
task_network.add(pic_spike_group, pic_spike_mon)
# synapses with single picture
synapse_and_list = []
synapse_xor_list = []
for enum in range(len(spike_group)):
synapse_and_pic = Synapses(pic_spike_group, head_layer[enum], on_pre=on_pre)
synapse_and_pic.connect(i=[0, 1], j=[2, 3])
synapse_and_list.append(synapse_and_pic)
synapse_xor_pic = Synapses(pic_spike_group, head_layer[enum + gate_num], on_pre=on_pre)
synapse_xor_pic.connect(i=[0, 1], j=[2, 3])
synapse_xor_list.append(synapse_xor_pic)
task_network.add((synapse_and_pic, synapse_xor_pic))
# monitor
spike_mon_and_list = []
spike_mon_xor_list = []
for enum in range(len(spike_group)):
spike_mon_and_output = SpikeMonitor(tail_layer[enum])
spike_mon_xor_output = SpikeMonitor(tail_layer[enum + gate_num])
spike_mon_and_list.append(spike_mon_and_output)
spike_mon_xor_list.append(spike_mon_xor_output)
task_network.add(spike_mon_and_output, spike_mon_xor_output)
# run simulation
task_network.run(duration, report='text') # , profile=True
# remove
task_network.remove(pic_spike_group, pic_spike_mon)
for enum in range(len(spike_group)):
task_network.remove(synapse_and_list[enum], synapse_xor_list[enum])
task_network.remove(spike_mon_and_list[enum], spike_mon_xor_list[enum])
print('==============Finish=================')
My CPU is Intel i7-9760H
, using cython
backend, but dealing with one picture(784ms) takes ~15s. The whole Train dataset(60000 #) will cost ~10 days.
======= 4201 / 60000 =========
Starting simulation at t=0. s for a duration of 0.784 s
0.5126 s (65%) simulated in 10s, estimated 5s remaining.
0.784 s (100%) simulated in 15s
Is there anything I can do to cut down the simulation time?