Questions about store()

Description of problem

I set up a simple network (e.g. 100 input neurons and 10 output neurons) and used the store() command to save the network state after training (brian_train_model.b2). However, I don't know how to view the synaptic weights saved in that file. Can anyone help me?

Alternatively, can I use torch.save(net.state_dict(), 'brian_train_model.pt') to save the trained network?

Thank you.

Hi @lwj. The store command is used to store a snapshot of a network that can be restored at a later point, but it is not meant to be used as a general mechanism to store values for use elsewhere. In principle, you can unpickle the file and get the information out of it, but if the only thing you are interested in is the weights, there are easier methods. The weights and connection indices are stored as NumPy arrays, so you can use standard NumPy tools to store them to disk, e.g.:

import numpy

# Save the pre-/post-synaptic indices and the weights of the Synapses object
numpy.savez_compressed('brian_train_model.npz',
                       i=synapses.i[:], j=synapses.j[:], w=synapses.w[:])

You can then load this file with numpy.load. You could also use torch.save, but in this case it does the same as Python's built-in pickle.dump. You could e.g. use torch.save(net.get_states(), 'brian_train_model.pt'), but this will save quite a lot of unnecessary information (all state variables of all objects in the network). If synaptic weights/indices are all you need, you can instead put everything you are interested in into a dictionary ({'i': synapses.i[:], 'j': synapses.j[:], 'w': synapses.w[:]}) and then store this dictionary using torch.save or pickle.dump.
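A minimal sketch of both storage options, assuming plain NumPy arrays as stand-ins for synapses.i[:], synapses.j[:] and synapses.w[:] (in a real script they would come from the Brian Synapses object; the data and file names here are hypothetical):

```python
import pickle

import numpy as np

# Hypothetical stand-ins for synapses.i[:], synapses.j[:], synapses.w[:]
# (3 pre-synaptic and 2 post-synaptic neurons, all-to-all connectivity)
i = np.array([0, 0, 1, 1, 2, 2])
j = np.array([0, 1, 0, 1, 0, 1])
w = np.linspace(0.0, 0.01, 6)

# Option 1: compressed NumPy archive, read back with numpy.load
np.savez_compressed('weights_example.npz', i=i, j=j, w=w)
with np.load('weights_example.npz') as data:
    loaded_w = data['w']

# Option 2: a plain dictionary stored with pickle.dump
# (torch.save(state, 'weights_example.pt') would store the same dictionary)
state = {'i': i, 'j': j, 'w': w}
with open('weights_example.pkl', 'wb') as f:
    pickle.dump(state, f)
with open('weights_example.pkl', 'rb') as f:
    loaded_state = pickle.load(f)

print(loaded_w)
print(sorted(loaded_state))  # → ['i', 'j', 'w']
```

Either way, what comes back is just arrays (or a dictionary of arrays), not a network object, so there is nothing like a parameters() method on the loaded result.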

Hi @mstimberg, thanks a lot! In a Jupyter notebook, I set up a simple network, e.g.:

from brian2 import *
import numpy as np

N = 25
M = 5
taum = 10*ms
taupre = 20*ms
taupost = taupre
Ee = 0*mV
vt = -54*mV
vr = -60*mV
El = -74*mV
taue = 5*ms
F = 15
gmax = .01
dApre = .01
dApost = -dApre * taupre / taupost * 1.05
dApost *= gmax
dApre *= gmax

class Model():
    def __init__(self):
        unit = {}

        unit['PG'] = PoissonGroup(N, rates=np.zeros(N)*Hz, name='PG')

        eqs_neurons = '''
                      dv/dt = (ge * (Ee-v) + El - v) / taum : volt
                      dge/dt = -ge / taue : 1
                      '''

        unit['G'] = NeuronGroup(M, eqs_neurons, threshold='v>vt', reset='v = vr',
                                method='euler', name='G')
        unit['S'] = Synapses(unit['PG'], unit['G'],
                             '''w : 1
                                dApre/dt = -Apre / taupre : 1 (event-driven)
                                dApost/dt = -Apost / taupost : 1 (event-driven)''',
                             on_pre='''ge += w
                                    Apre += dApre
                                    w = clip(w + Apost, 0, gmax)''',
                             on_post='''Apost += dApost
                                     w = clip(w + Apre, 0, gmax)''',
                             name='S')
        unit['S'].connect()
        unit['S'].w = 'rand() * gmax'

        unit['M'] = StateMonitor(unit['S'], 'w', record=True, name='M')

        self.net = Network(unit.values())
        self.net.run(0*second)

    def __getitem__(self, key):
        return self.net[key]

    def train(self, data):
        self.net['PG'].rates = data*Hz
        self.net.run(20*second, report='text')

model = Model()
model.train(F)
np.savez_compressed('1.npz', i=model['S'].i[:], j=model['S'].j[:], w=model['S'].w[:])

weight = np.load('1.npz')
weight['w']
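Because connect() without a condition creates all-to-all connections, the saved i, j and w arrays can be assembled into a dense N × M weight matrix (25 × 5 here). A minimal sketch, using random stand-in data with the same layout instead of the real 1.npz file; to_dense is a hypothetical helper, not part of Brian:

```python
import numpy as np

def to_dense(i, j, w, n_pre, n_post):
    """Scatter synaptic weights into an (n_pre, n_post) matrix (hypothetical helper).

    Pre/post pairs without a synapse stay 0. If a pair had several synapses,
    only one of their values would end up in the matrix (not an issue with a
    plain connect(), which creates exactly one synapse per pair).
    """
    W = np.zeros((n_pre, n_post))
    W[i, j] = w
    return W

# Stand-in for the contents of '1.npz' (hypothetical random weights)
N, M, gmax = 25, 5, 0.01
rng = np.random.default_rng(0)
i = np.repeat(np.arange(N), M)   # pre-synaptic indices
j = np.tile(np.arange(M), N)     # post-synaptic indices
w = rng.random(N * M) * gmax

W = to_dense(i, j, w, N, M)
print(W.shape)  # → (25, 5)
```

With the real file, the call would be to_dense(weight['i'], weight['j'], weight['w'], N, M). Note that a torch.nn.Linear layer stores its weight with shape (out_features, in_features), so W.T is the matching orientation if these weights are copied into such a layer.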

With the approach you suggested, I get the exact synaptic weight values. However, I need to import a .pt file for simulation in MemTorch. When I use torch.save() to save the network parameters instead, e.g.:

import torch
import torch.nn as nn

class Model2(nn.Module):
    def __init__(self):
        super().__init__()

        unit = {}
        ...............................

    # Note: this overrides nn.Module.train(mode), which nn.Module.eval() calls internally
    def train(self, data):
        self.net['PG'].rates = data*Hz
        self.net.run(20*second, report='text')

model2 = Model2()
model2.train(F)
torch.save(model2.state_dict(), 'brian2_uns.pt')

the program runs normally and a .pt file is generated. But when I use torch.load() to load the .pt file, e.g.:

file = torch.load('brian2_uns.pt')
for parameter in file.parameters():
    print(parameter)

I get the following error:

AttributeError                            Traceback (most recent call last)
~\AppData\Local\Temp\ipykernel_15528\3549789619.py in <cell line: 2>()
      1 file = torch.load('brian2_uns.pt')
----> 2 for parameter in file.parameters():
      3     print(parameter)

AttributeError: 'collections.OrderedDict' object has no attribute 'parameters'

Is this caused by a bug in my program? How should I modify it to correctly generate and load the .pt file? Thanks.

Hi @lwj. I'm afraid I cannot help you much with these issues, since they are about PyTorch's interface, not about Brian itself. You are getting the error because you are storing a dictionary, but then try to access it as if it were an object that has a parameters() method. I am not a good person to ask about PyTorch, but from a cursory look at the documentation I think you'd rather need to do something like:

model2 = Model2()
saved_state = torch.load("brian2_uns.pt")
model2.load_state_dict(saved_state)
for parameter in model2.parameters():
    print(parameter)

But I am not sure whether the state_dict of your PyTorch module actually contains the state of your Brian network. Did you implement get_extra_state to do that? I think you need something like this in your nn.Module implementation:

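class Model2(nn.Module):
    ...
    def get_extra_state(self):
        # Called by state_dict(): include all writable Brian state variables
        return self.net.get_states(units=False, read_only_variables=False)

    def set_extra_state(self, state):
        # Called by load_state_dict(): write the stored values back into the network
        self.net.set_states(state, units=False)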

This internal state will not show up in the parameters() output above, since parameters() only yields nn.Parameter objects, but you should be able to check the weights in the network directly to see whether it worked. But again, I'm not sure about all this :man_shrugging: