Description of problem
Hello and good day. I am trying to run a simulation of, let's say, 50k Hodgkin-Huxley-type neurons. However, I might need to simulate the system for long times (50 s of biological time). Even without the synapses, the run times are impractically long on my 16-core Ubuntu system (3 d 4 h 10 min). Is there a way to parallelize the code (on CPU or GPU, so that the differential equations are integrated on different cores), or otherwise deal with this issue?
Minimal code to reproduce problem
######### IMPORTS ##############
from brian2 import *
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
######### Differential Eqns & constants ##############
# Model definition: simulation time step, conductances, membrane capacitance,
# reversal potentials and the equation string `eqs` handed to NeuronGroup.
start_scope()
# Integration time step of the default clock.
defaultclock.dt = 0.01*ms
# Conductances. The suffix of each name matches the current in `eqs` that
# uses it: gnaf -> inaf, gnas -> inas, gkv -> ikv, gka -> ika, gks -> iks;
# gl scales the (v - el) term of dv/dt.
gnaf = 140*nS
gnas = 12*nS
gkv = 6*nS
gka = 58.1*nS
gks = 8.11*nS
gl = 0.39*nS
# Membrane capacitance: divides the summed currents in dv/dt.
Cm = 4*pF
# Reversal potentials: ena is used by the two sodium currents (inaf, inas),
# ek by the three potassium currents (ikv, ika, iks), el by the gl term.
ena = 58*mV
ek = -81*mV
el = -65*mV
# Equation string for the neuron model. Each current has the form
# g * (gating variables) * (v - reversal); each gating variable x relaxes
# towards x_inf(v) with a voltage-dependent time constant tau_x(v).
# `Iapp` is not defined here; the script assigns it before each run() call.
# NOTE(review): the commented-out dv/dt line inside the string uses ms**2 for
# the slow potassium current, while the active `iks` uses mks**3 -- confirm
# which exponent is intended.
eqs = '''
#dv/dt = (Iapp - gnaf*m**3*h*(v-ena) - gkv*n**4*(v-ek) - gka*ma**3*ha*(v-ek) - gks*ms**2*(v-ek) - gl*(v-el))/Cm : volt
dv/dt = (Iapp -inaf -inas -ikv -iks -ika - gl*(v- el))/Cm : volt
# #fast potassium current ka
ika = gka*mka**3*hka*(v-ek) : amp
dhka/dt = (hkainf - hka)/tau_hka : 1
tau_hka = 2.5*ms + (tau_hka0/(1 + exp(( v -(-62*mV))/(16*mV))) ): second
tau_hka0 = ( (90*ms - 2.5*ms) / (1 + exp(( v -(52*mV))/(15*mV))) ) : second
hkainf = 1/(1 + exp((v - (-74.7*mV))/(7*mV))) : 1
dmka/dt = (mkainf - mka)/tau_mka : 1
tau_mka = 0.35*ms + (tau_mka0/(1 + exp(( v -(-20*mV))/(12*mV))) ) : second
tau_mka0 = ( (1.65*ms - 0.35*ms) / (1 + exp(( v -(20*mV))/(20*mV))) ) : second
mkainf = 1/(1 + exp((-20.1*mV - v)/(16.1*mV))) : 1
#slow potassium current ks
iks = gks*mks**3*hks*(v-ek) : amp
dhks/dt = (hksinf - hks)/tau_hks : 1
tau_hks = 150*ms + ( (200*ms - 150*ms) / (1 + exp(( v -(52*mV))/(15*mV))) ) : second
hksinf = 1/(1 + exp((v - (-74.7*mV))/(7*mV))) : 1
dmks/dt = (mksinf - mks)/tau_mks : 1
tau_mks = 0.5*ms + ( (5.0*ms - 0.5*ms) / (1 + exp(( v -(20*mV))/(20*mV))) ) : second
mksinf = 1/(1 + exp((-20.1*mV - v)/(16.1*mV))) : 1
#potassium current kv
ikv = gkv*mkv**4*(v-ek) : amp
dmkv/dt = (mkvinf - mkv)/tau_mkv : 1
tau_mkv = 1.85*ms + ( (3.53*ms - 1.85*ms) / (1 + exp(( v -(45*mV))/(13.71*mV))) ) : second
mkvinf = 1/(1 + exp((-37.6*mV - v)/(27.24*mV))) : 1
#slow sodium current
inas = gnas*mnas**3*hnas*(v-ena) : amp
dhnas/dt = (hnasinf - hnas)/tau_hnas : 1
tau_hnas = 1.9*ms + ( (12.24*ms - 1.9*ms) / (1 + exp(( v -(-32.6*mV))/(8.0*mV))) ) : second
hnasinf = 1/(1 + exp((v - (-51.4*mV))/(5.9*mV))) : 1
dmnas/dt = (mnasinf - mnas)/tau_mnas : 1
tau_mnas = 0.093*ms + ( (0.83*ms - 0.093*ms) / (1 + exp(( v -(-20.3*mV))/(6.45*mV))) ) : second
mnasinf = 1/(1 + exp((-30.1*mV - v)/(6.65*mV))) : 1
#fast sodium current
inaf = gnaf*mnaf**3*hnaf*(v-ena) : amp
dhnaf/dt = (hnafinf - hnaf)/tau_hnaf : 1
tau_hnaf = 0.12*ms + ( (1.66*ms - 0.12*ms) / (1 + exp(( v -(-8.03*mV))/(8.69*mV))) ) : second
hnafinf = 1/(1 + exp((v - (-51.4*mV))/(5.9*mV))) : 1
dmnaf/dt = (mnafinf - mnaf)/tau_mnaf : 1
tau_mnaf = 0.093*ms + ( (0.83*ms - 0.093*ms) / (1 + exp(( v -(-20.3*mV))/(6.45*mV))) ) : second
mnafinf = 1/(1 + exp((-30.1*mV - v)/(6.65*mV))) : 1
'''
######### RUNNER ##############
# Build the population, set initial conditions, attach a voltage monitor and
# run the simulation in two phases (no input current, then constant current).
nKC = 50000
neuron = NeuronGroup(nKC, eqs, method='rk4')

# Random initial membrane potential, drawn independently for every neuron
# from [-80 mV, 0 mV). The expression must not be scaled by the neuron index
# `i`: that would start the highest-index neurons thousands of volts below
# rest.
neuron.v = '-80*mV*rand()'

# Gating variables: every h-type variable starts at 1, every m-type at 0.
neuron.hnaf = 1.0
neuron.mnaf = 0.0
neuron.hnas = 1.0
neuron.mnas = 0.0
neuron.mkv = 0.0
neuron.mks = 0.0
neuron.hks = 1.0
neuron.hka = 1.0
neuron.mka = 0.0

# Record the membrane potential of all neurons (gating variables can be
# added to the list of recorded variables).
# NOTE(review): record=True stores v for every neuron at every time step;
# with dt = 0.01 ms, 50 s of simulated time and 50000 neurons that is about
# 5e6 samples per neuron. Consider recording a subset of neurons if memory
# becomes the limiting factor.
M = StateMonitor(neuron, ['v'], record=True)

# `Iapp` is referenced by name in `eqs`; it is reassigned between the two
# run() calls so that each phase uses a different applied current.
# Phase 1: 50 ms without applied current.
Iapp = 0*pA
run(50*ms, report='text')
# Phase 2: 50 s with a constant 18 pA applied current.
Iapp = 18*pA
run(50*second, report='text')
######### PLOTTING AND SAVING DATA ##############
# Overlay the recorded membrane potential of the first 10 rows of the monitor.
trace_fig, trace_ax = plt.subplots(figsize=(12, 6))
time_axis_ms = M.t/ms
for row in range(10):
    trace_ax.plot(time_axis_ms, M.v[row]/mV)
trace_ax.set_xlabel('Time (ms)')
trace_ax.set_ylabel('v (mV)')
trace_ax.set_ylim(-80, 20)
plt.show()

# Persist the recorded voltages and the matching time axis as .npy files.
np.save('v.npy', M.v)
np.save('t.npy', M.t)
def spike_counter2(t, V, th, sparse=100):
    """Return the times at which a trace crosses a threshold from below.

    Both ``V`` and ``t`` are first subsampled by keeping every ``sparse``-th
    sample. A crossing is reported at subsampled index ``k`` when
    ``V[k] <= th`` and ``V[k + 1] > th``; the returned time is ``t[k]``, i.e.
    the last sample at or below the threshold.

    Args:
        t: 1-D sequence of sample times, same length as ``V``.
        V: 1-D sequence of sampled values (e.g. a membrane potential trace).
        th: Threshold, in the same units as ``V``.
        sparse: Subsampling stride applied to ``V`` and ``t`` before the
            crossings are searched.

    Returns:
        NumPy array with the subsampled time of each upward crossing (empty
        if there is none or if fewer than two samples remain).
    """
    V = np.array(V)[::sparse]
    t = np.array(t)[::sparse]
    # Compare every sample with its successor in one vectorized pass instead
    # of a Python-level loop over the samples.
    offset = V - th
    rising = (offset[:-1] <= 0.0) & (offset[1:] > 0.0)
    return t[np.flatnonzero(rising)]
def rastor_plotter(nl, t, th, dpi=100, size=(20, 10), color="black"):
    """Draw a raster plot of threshold crossings for a set of traces.

    Each row of ``nl`` is passed to ``spike_counter2`` (with its default
    subsampling) to obtain the crossing times, which are then drawn as one
    event row per trace.

    Args:
        nl: 2-D array-like with one trace per row.
        t: 1-D sequence of sample times shared by all rows. The x-axis is
            labelled 'Time (seconds)', so pass the times in seconds.
        th: Detection threshold, in the same units as the values in ``nl``.
        dpi: Resolution of the created figure.
        size: Figure size ``(width, height)`` passed to matplotlib.
        color: Color of the event markers.

    Returns:
        The matplotlib Figure the raster was drawn on.
    """
    spike_times = [spike_counter2(t, trace, th) for trace in nl]

    # Keep a handle on the new figure so that it can be returned to the
    # caller (returning the bare name `figure` would hand back the factory
    # function, not the figure that was drawn).
    fig = plt.figure(figsize=size, dpi=dpi)
    plt.eventplot(
        spike_times,
        lineoffsets=range(len(spike_times)),
        linelengths=4.0,
        color=color,
    )
    plt.xlabel('Time (seconds)')
    plt.ylabel('Neuron #')
    return fig
# Raster plot of every recorded trace. Times are passed in seconds so that
# they agree with the 'Time (seconds)' x-axis label set by rastor_plotter;
# the -20 threshold is in mV, like the voltage traces passed alongside it.
rastor_plotter(M.v/mV, M.t/second, -20, dpi=200, size=(20, 10), color="blue")
plt.savefig('rastor_plot.png')
plt.show()