Description of the problem
I created a spiking neural network, and I have a training DataLoader from which I want to take the input for my network. In addition, after each run I want to update the weights offline and then run the next example with the updated weights.
Minimal code to reproduce the problem
# Project-local helper that provides the MIT-BIH ECG data.
# NOTE(review): the return values are presumably torch DataLoaders yielding
# (inputs, targets) batches -- TODO confirm against mit_bih.
from mit_bih import download_mitbih
train_loader, test_loader = download_mitbih()
# The brian2 star import supplies NeuronGroup, Synapses, TimedArray, Network,
# the monitors and the units (ms, mV, volt) used throughout the script.
from brian2 import *
import numpy as np
import matplotlib.pyplot as plt
import torch
from datetime import datetime
# Use Brian2's pure NumPy code-generation target (no C/C++ compilation).
prefs.codegen.target = "numpy"
# Seed NumPy's global RNG before any random connectivity/weights are drawn
# (get_random_20_percent_indices uses np.random), so the network is
# reproducible from run to run.
np.random.seed(25)
# Start a fresh Brian2 scope for the "magic" run(); the explicit Network
# built further down does not rely on it.
start_scope()
def get_random_20_percent_indices(n_count, exc_p=0.8, p=0.2, exc_w=1, inh_w=-1):
    """Pick a random subset of neurons and draw one weight per pick.

    ``int(p * n_count)`` distinct indices are sampled from ``range(n_count)``
    without replacement.  The first ``int(exc_p * n)`` picks get weights
    ``exc_w * U[0, 1)`` and the remaining picks get ``inh_w * U[0, 1)``.

    NOTE(review): this definition is shadowed by an identical one later in
    the file (which adds the default ``n_count=N_lsm``); keep the two in
    sync or delete one of them.

    Args:
        n_count: size of the population to sample from.
        exc_p: fraction of the picks that receive ``exc_w``-scaled weights.
        p: fraction of the population to pick.
        exc_w: scale of the first ("excitatory") group of weights.
        inh_w: scale of the second ("inhibitory") group of weights.

    Returns:
        ``(idx, weights)``: the picked indices as a NumPy array and a list
        of Python floats, one weight per index, in the same order.
    """
    population = np.arange(n_count)
    n_picked = int(p * len(population))
    idx = np.random.choice(population, n_picked, replace=False)
    n_exc = int(exc_p * len(idx))
    # np.random.rand(k) draws the same k numbers as k scalar rand() calls,
    # so the results match drawing the weights one at a time.
    weights = np.concatenate((
        exc_w * np.random.rand(n_exc),
        inh_w * np.random.rand(len(idx) - n_exc),
    ))
    return idx, weights.tolist()
# Simulation resolution; also used as the sample spacing of the ECG trace.
dt = 0.1 * ms
defaultclock.dt = dt
# Simulated time per training example.
single_exeample_time = 500 * ms
# One ECG trace from the first training batch.
batch_ecg, _ = next(iter(train_loader))
ecg = batch_ecg[0].numpy() if isinstance(batch_ecg, torch.Tensor) else batch_ecg[0]
# `stimulus` is looked up by name when the network is built/run, so it has to
# exist before the first run.  Initialise it with a real trace rather than an
# empty TimedArray, which would have no value to return if ever evaluated.
# NOTE(review): assumes each sample is a 1-D trace; a (1, L) channel-first
# sample would produce a 2-D TimedArray that must be called as stimulus(t, i).
stimulus = TimedArray(np.array(ecg, order='C'), dt=dt)
# Input neuron params
# A single neuron that turns the ECG into spikes: it is driven by the absolute
# sample-to-sample change of the module-level TimedArray `stimulus`.
N_input = 1
v_rest_input = 0*volt  # resting/reset potential of the input neuron
v_threshold_input = 3*mV
g_l_input = 0.1  # dimensionless divisor scaling the input term I
tau_input = 5*ms
# I_previous and I_current read `stimulus` one time step apart (t-dt vs t;
# inside the equation string `dt` is Brian2's simulation step, set to the
# module-level dt via defaultclock), so I is the absolute first difference
# of the signal expressed in volts.
# NOTE: `stimulus` is indexed by the absolute simulation time t, so every new
# run continues reading the trace where the clock currently is.
input_eqs = '''
dv/dt = (v_rest_input - v + I/g_l_input) / tau_input : volt
I = abs(I_previous - I_current) * volt : volt
I_previous = stimulus(t-dt) : 1
I_current = stimulus(t) : 1
'''
input_group = NeuronGroup(N_input, input_eqs, threshold='v>v_threshold_input', reset='v=v_rest_input', method='euler')
input_group.v = 'v_rest_input'
print(f"[{datetime.now()}] Input neuron group created")
# LSM neuron params
# Reservoir ("liquid") of leaky integrate-and-fire neurons.  v relaxes toward
# v_rs_lsm with time constant tau_lsm; all input arrives as instantaneous
# jumps of v from the synapses.
N_lsm = 1000
# NOTE(review): N_lsm_exc / N_lsm_inh are not used in the visible code; the
# excitatory/inhibitory split is made per connection (sign of the weight in
# get_random_20_percent_indices), not per neuron.
N_lsm_exc = int(0.8 * N_lsm)
N_lsm_inh = N_lsm - N_lsm_exc
v_th_lsm = -55*mV
v_rs_lsm = -65*mV
tau_lsm = 5*ms
g_l_lsm = 1  # NOTE(review): not referenced in the visible code
refreac_lsm = 4*ms  # refractory period after each spike
# NOTE(review): the equation is not marked "(unless refractory)", so v keeps
# integrating (and receiving synaptic jumps) during the refractory period;
# only spiking is blocked -- confirm this is intended.
lsm_eqs = '''
dv/dt = (v_rs_lsm - v) / tau_lsm : volt
'''
lsm_group = NeuronGroup(N_lsm, lsm_eqs, threshold='v>v_th_lsm', reset='v=v_rs_lsm', refractory=refreac_lsm, method='euler')
lsm_group.v = 'v_rs_lsm'
print(f"[{datetime.now()}] LSM neuron group created")
import numpy as np
def get_random_20_percent_indices(n_count=N_lsm, exc_p=0.8, p=0.2, exc_w=1, inh_w=-1):
    """Pick a random subset of neurons and draw one weight per pick.

    ``int(p * n_count)`` distinct indices are sampled from ``range(n_count)``
    without replacement.  The first ``int(exc_p * n)`` picks get weights
    ``exc_w * U[0, 1)`` and the remaining picks get ``inh_w * U[0, 1)``.

    This is the definition the synapse-building code actually calls (it
    replaces an identical earlier one); ``n_count`` defaults to the LSM
    size ``N_lsm``.

    Args:
        n_count: size of the population to sample from.
        exc_p: fraction of the picks that receive ``exc_w``-scaled weights.
        p: fraction of the population to pick.
        exc_w: scale of the first ("excitatory") group of weights.
        inh_w: scale of the second ("inhibitory") group of weights.

    Returns:
        ``(idx, weights)``: the picked indices as a NumPy array and a list
        of Python floats, one weight per index, in the same order.
    """
    population = np.arange(n_count)
    n_picked = int(p * len(population))
    idx = np.random.choice(population, n_picked, replace=False)
    n_exc = int(exc_p * len(idx))
    # np.random.rand(k) draws the same k numbers as k scalar rand() calls,
    # so the results match drawing the weights one at a time.
    weights = np.concatenate((
        exc_w * np.random.rand(n_exc),
        inh_w * np.random.rand(len(idx) - n_exc),
    ))
    return idx, weights.tolist()
# Input -> LSM synapses: the single input neuron projects to a random 20 % of
# the reservoir; each presynaptic spike adds 10*w mV to the target's v.
w_i_lsm = 0.6
S_input_lsm = Synapses(input_group, lsm_group, 'w : 1', on_pre='v_post += 10*w*mV')
# NOTE(review): inh_w is passed as +w_i_lsm, so the "inhibitory" share of
# these weights is positive as well -- every input synapse is excitatory.
# Use inh_w=-w_i_lsm if inhibitory input was intended.
indeces, weigths = get_random_20_percent_indices(exc_w=w_i_lsm, inh_w=w_i_lsm)
# One vectorised connect call instead of one call per target: the scalar
# source index 0 is broadcast against all targets, and the synapses are
# created in the order of `indeces`, matching `weigths`.
S_input_lsm.connect(i=0, j=indeces)
S_input_lsm.w = weigths
print(f"[{datetime.now()}] Input -> LSM synapses created")
# LSM -> LSM recurrent synapses: every reservoir neuron projects to a random
# 20 % of the reservoir.  Per source, 80 % of the picks get weights in
# [0, 0.3) and the rest in (-0.5, 0]; each spike adds w mV to the target.
S_lsm_lsm = Synapses(lsm_group, lsm_group, 'w : 1', on_pre='v_post += w*mV')
pre_idx, post_idx, weigths = [], [], []
for i in range(N_lsm):
    post_indices, post_weights = get_random_20_percent_indices(exc_w=0.3, inh_w=-0.5)
    pre_idx.extend([i] * len(post_indices))
    post_idx.extend(post_indices)
    weigths.extend(post_weights)
# A single connect call for all synapses instead of one per source neuron;
# the (source-major) synapse order matches `weigths`.
# NOTE(review): self-connections (i == j) can be drawn here; filter them out
# of pre_idx/post_idx/weigths if autapses are not wanted.
S_lsm_lsm.connect(i=pre_idx, j=post_idx)
S_lsm_lsm.w = weigths
print(f"[{datetime.now()}] LSM -> LSM synapses created")
# Output neuron params
# One leaky integrate-and-fire readout neuron per target class: train_network
# uses the class label directly as an index into these neurons.
N_output = 5
v_th_out = -55*mV
v_rs_out = -65*mV
tau_out = 5*ms
g_l_out = 1  # NOTE(review): not referenced in the visible code
out_eqs = '''
dv/dt = (v_rs_out - v) / tau_out : volt
'''
# Unlike the LSM group, no refractory period is configured here.
output_group = NeuronGroup(N_output, out_eqs, threshold='v>v_th_out', reset='v=v_rs_out', method='euler')
output_group.v = 'v_rs_out'
print(f"[{datetime.now()}] Output neuron group created")
# LSM -> output synapses: each output neuron listens to a random 20 % of the
# reservoir.  These are the only weights changed by train_network, which
# looks them up by the explicit name 'S_lsm_out'.
# NOTE(review): the other synapse groups add w*mV, but here w is in volts, so
# a weight of 0.3 is a 300 mV jump (the reset-to-threshold gap is 10 mV) and
# train_network's 0.01 steps are 10 mV each -- confirm the unit is intended.
S_lsm_out = Synapses(lsm_group, output_group, 'w : 1', on_pre='v_post += w*volt', name='S_lsm_out')
pre_idx, post_idx, weigths = [], [], []
for f in range(N_output):
    lsm_indices, lsm_weights = get_random_20_percent_indices(exc_w=0.3, inh_w=-1)
    pre_idx.extend(lsm_indices)
    post_idx.extend([f] * len(lsm_indices))
    weigths.extend(lsm_weights)
# A single connect call instead of one per synapse; the (target-major) synapse
# order matches `weigths`.
S_lsm_out.connect(i=pre_idx, j=post_idx)
S_lsm_out.w = weigths
print(f"[{datetime.now()}] LSM -> output synapses created")
# Record every state variable and every spike of each neuron group.
# NOTE(review): the StateMonitors are not read anywhere in the visible code,
# yet lsm_group_ST alone records all 1000 LSM neurons at every 0.1 ms step
# (5000 samples per neuron per 500 ms example) and keeps accumulating across
# runs -- a likely memory problem over a long training loop.  Consider
# dropping them or recording only a few neurons/variables.
input_group_ST = StateMonitor(input_group, variables=True, record=True)
input_group_SP = SpikeMonitor(input_group, record=True)
lsm_group_ST = StateMonitor(lsm_group, variables=True, record=True)
lsm_group_SP = SpikeMonitor(lsm_group, record=True)
output_group_ST = StateMonitor(output_group, variables=True, record=True)
output_group_SP = SpikeMonitor(output_group, record=True)
print(f"[{datetime.now()}] Monitors created")
# An explicit Network so that network.run() simulates exactly these objects.
network = Network([input_group, lsm_group, output_group, S_input_lsm, S_lsm_lsm, S_lsm_out, input_group_ST, input_group_SP, lsm_group_ST, lsm_group_SP, output_group_ST, output_group_SP])
# Zero-length run: builds the network without advancing simulated time.
network.run(0*ms)
print(f"[{datetime.now()}] Network built")
def print_nn():
    """Plot spike rasters of the input, output and LSM populations.

    Reads the module-level spike monitors and shows every spike they
    currently hold: spike time (ms) on the x axis, neuron index on the y axis.
    """
    panels = (
        (input_group_SP, 'Input Neuron Spikes'),
        (output_group_SP, 'Output Neuron Spikes'),
        (lsm_group_SP, 'LSM Neuron Spikes'),
    )
    fig, axs = plt.subplots(nrows=len(panels), ncols=1, sharex=True)
    for ax, (monitor, title) in zip(axs, panels):
        ax.plot(monitor.t/ms, monitor.i, '.')
        # x holds times and y holds neuron indices, so label them that way.
        ax.set_title(title)
        ax.set_ylabel('Neuron index')
    # The time axis is shared, so only the bottom panel needs its label.
    axs[-1].set_xlabel('Time (ms)')
    fig.tight_layout()
    plt.show()
def train_network(target_idx, t_start=None):
    """Adjust the LSM -> output weights after one training example.

    Counts the output spikes emitted since ``t_start``.  If the target
    neuron produced less than 80 % of them, every LSM -> output weight is
    shifted according to its postsynaptic neuron:

    * synapses onto the target neuron: ``+0.01 * (0.8 - ratio)`` (> 0);
    * synapses onto any other output ``j``: ``-0.01 * count_j / total``.

    Args:
        target_idx: index of the output neuron that should win.
        t_start: only spikes at or after this time (a Brian2 time quantity)
            are counted.  Defaults to the start of the most recent
            ``single_exeample_time`` of simulation, i.e. the example that was
            just run, so spikes from earlier examples no longer leak into the
            ratio (the monitor keeps all of them).
    """
    target_ratio = 0.8
    if t_start is None:
        # Half a step of slack guards against rounding in network.t.
        t_start = network.t - single_exeample_time - 0.5 * defaultclock.dt
    spike_t = output_group_SP.t[:]
    spike_i = np.asarray(output_group_SP.i[:])
    recent = np.asarray(spike_t >= t_start)
    out_counts = np.bincount(spike_i[recent], minlength=N_output)
    total_spikes = out_counts.sum()
    ratio = out_counts[target_idx] / total_spikes if total_spikes > 0 else 0
    if ratio >= target_ratio:
        print(f"[{datetime.now()}] Target ratio {ratio} reached, no weight update needed.")
        return

    # The change depends only on the postsynaptic neuron: target synapses
    # always grow (ratio < target_ratio here) and the others never grow,
    # whatever the sign of the current weight.
    if total_spikes > 0:
        deltas = -0.01 * (out_counts / total_spikes)
    else:
        deltas = np.zeros(N_output)
    deltas[target_idx] = 0.01 * (target_ratio - ratio)

    synapses = network['S_lsm_out']
    # One vectorised read/write instead of N_lsm * N_output individual
    # Brian2 synapse lookups.
    synapses.w[:] = synapses.w[:] + deltas[synapses.j[:]]
    print(f"[{datetime.now()}] Output counts: {out_counts}, total spikes: {total_spikes}, ratio for target {target_idx}: {ratio}")
# Snapshot of the freshly built network: t = 0, resting potentials, empty
# monitors, initial weights.  `stimulus` is read at the absolute time t, so
# without rewinding, example k (0-based) would start at t = k * 500 ms.  It
# would then read its trace from that offset -- usually past the end, where a
# TimedArray keeps returning the last value -- and never see its own input.
network.store('initial')
for batch_idx, (inputs, targets) in enumerate(train_loader):
    for idx, sample in enumerate(inputs):
        print(f"[{datetime.now()}] Processing batch {batch_idx}, sample {idx}")
        # Rewind time, state and monitors, but carry the learned
        # LSM -> output weights over (restore() would reset them too).
        learned_w = np.array(S_lsm_out.w[:])
        network.restore('initial')
        S_lsm_out.w[:] = learned_w
        ecg = sample.numpy() if isinstance(sample, torch.Tensor) else np.asarray(sample)
        # Rebinding the module-level name is enough: Brian2 resolves
        # `stimulus` again from the caller's namespace on every run().
        stimulus = TimedArray(np.array(ecg, order='C'), dt=dt)
        network.run(single_exeample_time)
        train_network(targets[idx].item())
        # print_nn()
    # NOTE(review): only the first batch is processed (debugging shortcut?);
    # remove this break to train on the whole loader.
    break