Hi @mstimberg, thanks a lot! In a Jupyter notebook, I set up a simple network, e.g.:
# Group sizes.
N = 25  # number of Poisson input sources (size of the 'PG' group)
M = 5   # number of neurons in the 'G' group

# Neuron and synapse parameters. A Brian2 quantity is a number multiplied by a
# unit (``10*ms``); writing ``10ms`` is a Python syntax error.
taum = 10*ms      # divisor of dv/dt in the neuron equations
taupre = 20*ms    # decay time constant of the Apre trace
taupost = taupre  # decay time constant of the Apost trace
Ee = 0*mV         # Ee in the ge * (Ee - v) term of dv/dt
vt = -54*mV       # spike threshold (threshold='v>vt')
vr = -60*mV       # reset value (reset='v = vr')
El = -74*mV       # El term of dv/dt
taue = 5*ms       # decay time constant of ge

F = 15  # input rate handed to Model.train(), which multiplies it by Hz

# Plasticity parameters.
gmax = .01   # upper clip bound of the synaptic weight w
dApre = .01  # amount added to Apre on each presynaptic spike
# Amount added to Apost on each postsynaptic spike (negative).
dApost = -dApre * taupre / taupost * 1.05
# Scale both trace increments by gmax.
dApost *= gmax
dApre *= gmax
class Model():
    """Brian2 network: a Poisson input group driving a spiking neuron group.

    The two groups are connected all-to-all by synapses whose weight ``w`` is
    updated from the ``Apre``/``Apost`` traces on pre- and postsynaptic spikes
    and clipped to ``[0, gmax]``. All objects are stored in ``self.net`` and
    can be looked up by name via ``model[name]``.
    """

    def __init__(self):
        """Create the groups, synapses and monitor, and wrap them in a Network."""
        unit = {}
        # Input rates start at zero; train() sets the real rates later.
        unit['PG'] = PoissonGroup(N, rates=np.zeros(N)*Hz, name='PG')
        eqs_neurons = '''
        dv/dt = (ge * (Ee-v) + El - v) / taum : volt
        dge/dt = -ge / taue : 1
        '''
        unit['G'] = NeuronGroup(M, eqs_neurons, threshold='v>vt',
                                reset='v = vr', method='euler', name='G')
        unit['S'] = Synapses(unit['PG'], unit['G'],
                             '''w : 1
                                dApre/dt = -Apre / taupre : 1 (event-driven)
                                dApost/dt = -Apost / taupost : 1 (event-driven)''',
                             on_pre='''ge += w
                                       Apre += dApre
                                       w = clip(w + Apost, 0, gmax)''',
                             on_post='''Apost += dApost
                                        w = clip(w + Apre, 0, gmax)''',
                             name='S')
        # connect() without arguments connects every source to every target.
        unit['S'].connect()
        unit['S'].w = 'rand() * gmax'
        # Record w of every synapse.
        unit['M'] = StateMonitor(unit['S'], 'w', record=True, name='M')
        self.net = Network(unit.values())
        # Zero-length run; presumably done to initialise the network up front.
        self.net.run(0*second)

    def __getitem__(self, key):
        """Return the network object named ``key`` (e.g. 'PG', 'G', 'S', 'M')."""
        return self.net[key]

    def train(self, data):
        """Set the Poisson input rates to ``data`` (in Hz) and run for 20 s.

        ``data`` is multiplied by ``Hz`` and assigned to the rates of 'PG'.
        Simulation progress is reported as text.
        """
        self.net['PG'].rates = data*Hz
        self.net.run(20*second, report='text')
# Build the network, train it with a constant input rate of F Hz, then store
# the synaptic indices and learned weights in a compressed .npz archive.
model = Model()
model.train(F)
np.savez_compressed('1.npz', i=model['S'].i[:], j=model['S'].j[:], w=model['S'].w[:])
# Read the archive back; as the last expression of a notebook cell this
# displays the saved weight array.
weight = np.load('1.npz')
weight['w']
I can get the exact synaptic weight values with the approach you suggested. However, I need to import a '.pt' file for simulation in MemTorch. When I use torch.save() to save the network parameters instead, e.g.:
class Model2(nn.Module):
    """Variant of the Brian2 model that derives from ``torch.nn.Module``.

    NOTE(review): nothing visible here registers an ``nn.Parameter`` or a
    buffer, so ``state_dict()`` is presumably empty and will not contain the
    Brian2 synaptic weights - confirm against the elided constructor code.
    NOTE(review): ``train(data)`` overrides ``nn.Module.train(mode)``, which
    ``nn.Module.eval()`` calls internally - confirm this is intended.
    """

    def __init__(self):
        """Initialise the ``nn.Module`` base, then build the network."""
        super().__init__()
        unit = {}
        # ... (rest of the network construction elided in the original post) ...

    def train(self, data):
        """Set the Poisson input rates to ``data`` (in Hz) and run for 20 s."""
        self.net['PG'].rates = data*Hz
        self.net.run(20*second, report='text')
# Train the nn.Module-based model and save its state_dict to a .pt file.
model2 = Model2()
model2.train(F)
torch.save(model2.state_dict(), 'brian2_uns.pt')
the program runs normally and a '.pt' file is generated, but when I use torch.load() to load the '.pt' file, e.g.:
# NOTE(review): the file was written from model.state_dict(), so torch.load()
# returns that mapping (an OrderedDict), not a Module. It has no .parameters();
# iterate ``file.items()`` instead. The call is kept as posted because it
# reproduces the error reported below.
file = torch.load('brian2_uns.pt')
for parameter in file.parameters():
    print(parameter)
I get the following error:
AttributeError                            Traceback (most recent call last)
~\AppData\Local\Temp\ipykernel_15528\3549789619.py in <cell line: 2>()
      1 file = torch.load('brian2_uns.pt')
----> 2 for parameter in file.parameters():
      3     print(parameter)
AttributeError: 'collections.OrderedDict' object has no attribute 'parameters'
Is this caused by a bug in my program, or by something else? How should I modify my program to correctly generate and import the .pt file? Thanks.