Any possibility of a jax runtime?

I imagine the functional programming requirement of jax would preclude a drop-in replacement for the numpy fallback, but how feasible would it be to implement a jax runtime? Would this provide a simpler path to running on the GPU and maybe even training parameters with gradient descent like jaxley?

Hi @kjohnsen. This topic came up a few times (see e.g. Implement surrogate gradient descent method · Issue #1207 · brian-team/brian2 · GitHub and PyTorch/Tensorflow backend? · Issue #1014 · brian-team/brian2 · GitHub), but I think it would be a major effort that would probably have to come from someone… that is not me :slight_smile:
I don’t have very in-depth knowledge of JAX, but I can see some difficulties in mapping its approach onto Brian’s. As you say, it uses a functional approach with immutable arrays, whereas Brian keeps a fixed array for each state variable that is modified in place. I guess we could work around this by wrapping the JAX code in numpy code that overwrites the state variable array with the result of JAX’s computation, but this would 1) add overhead and 2) probably break all the automatic gradient functionality?
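To make the concern concrete, here is a minimal pure-Python sketch of that wrapping pattern (all names are illustrative, not Brian internals): the backend computes the next state functionally, and a thin wrapper copies the result back into the fixed state array so that everything holding a reference to it still sees the update.

```python
# Sketch of wrapping a functional (JAX-style) update around Brian's
# in-place state arrays. Names are illustrative, not Brian internals.

def step(v, dt, tau):
    # Pure update rule for dv/dt = -v/tau: returns a NEW list and
    # never mutates its input (this is the part jax.jit could compile).
    return [x + dt * (-x / tau) for x in v]

# Brian-style fixed state array; monitors etc. hold a reference to it.
v = [0.0, 0.5, 1.0]

new_v = step(v, dt=0.1, tau=10.0)
v[:] = new_v  # overwrite in place, so existing references stay valid
```

The in-place copy on the last line is exactly the step that would have to sit outside JAX's traced computation, which is why it adds overhead and would likely cut the automatic-differentiation graph.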

Maybe @dan has an opinion on this? One way to play around with this would be to see whether Dan’s examples here could easily use JAX instead of numba: GitHub - thesamovar/cuba_with_and_without_brian: Simple example of converting a Brian script into pure Python using numba

That person wouldn’t be me either :sweat_smile:. Hopefully someone can do it someday! I forgot to mention that another hoped-for advantage would be support for things like network operations (which at the moment are preventing me from using brian2cuda everywhere).

Yes, network_operation is the big thing we can only have in runtime mode, but it might be difficult to combine with JAX as well, given that network_operations are all about updating the global state variables that are then picked up automatically by all other operations.

Out of curiosity: what are you currently using network_operation for? It might be possible to replace them by some clever use of existing features like run_regularly, and/or user-defined functions…

I’m using network operations to enable simulated closed-loop control: whenever a sample is taken, recording devices execute arbitrary code and pass the measurements to the closed-loop processor. That processor computes a control signal for stimulation devices and stores it in a buffer to deliver after some delay (simulating latency). The stimulator devices can then also execute arbitrary code when updating state variables; all of this is driven by a single network operation here.
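The latency-buffer part of that loop can be sketched in a few lines of plain Python (the names and the toy controller here are made up for illustration): signals computed now are queued and only delivered a fixed number of samples later.

```python
from collections import deque

# Hypothetical sketch of the delay buffer described above.
LATENCY_STEPS = 3  # control latency, in sampling periods

# Pre-fill with a default so the first few deliveries mean "no stimulation".
buffer = deque([0.0] * LATENCY_STEPS)

def process_sample(measurement):
    """Compute a control signal now; deliver the one computed 3 samples ago."""
    control = 1.0 - measurement   # toy controller: drive toward setpoint 1.0
    buffer.append(control)        # queue the fresh signal
    return buffer.popleft()       # hand out the delayed one
```

The first `LATENCY_STEPS` calls return the default 0.0; after that, each call returns the signal computed three samples earlier.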

For examples of the code that devices run, here are light, LFP, and the microscope.

I’ve come to the conclusion this would be too big of a headache to implement in native Brian equations, at least at this point. Maybe it would have been worth it to do it from the beginning.

I don’t know Jax I’m afraid. I think other backends are possible, including PyTorch, JAX, etc., but it’s a serious time commitment to implement them in a useful way.


Hi, I agree that this is quite a complex setup that would be rather difficult to express with equations and custom C++ functions.

Not sure that it applies to your use case, but I realized that one feasible way to link standalone code (C++ or CUDA) to arbitrary Python code is to use pipes, which is rather straightforward with bash under Linux (and I guess similar for macOS).
E.g. this dummy code communicates every 5ms with another process by printing out the current values of the membrane potential and reading in some kind of external input calculated based on these values (or their history, or whatever):

# cpp_pipes.py
import matplotlib.pyplot as plt
from brian2 import *
set_device('cpp_standalone')

G = NeuronGroup(5, '''dv/dt = (-v + ext(i, t))/ (10*ms) : 1''', method='euler')
G.v = 'rand()'


@implementation('cpp', '''
double ext(int i, double t) {
    static double last_t = -1;
    static double values[10];  // hard-coded size; must be >= N
    if (t - last_t > 0.005) { // if last update was more than 5ms ago
        // write current values of v to stdout
        for (int i = 0; i < %N%; i++) {
            std::cout << brian::%V_VARIABLE%[i] << " ";
        }
        std::cout << std::endl;
        std::cout.flush();
        // read in ext values from stdin
        for (int i = 0; i < %N%; i++) {
            std::cin >> values[i];
        }
        last_t = t;
    }
    return values[i];
}
'''.replace('%V_VARIABLE%', device.get_array_name(G.v.variable)).replace("%N%", str(len(G))))
@check_units(i=1, t=second, result=1)
def ext(i, t):
    pass

state_mon = StateMonitor(G, 'v', record=True)

run(100*ms, report='stderr')

plt.plot(state_mon.t/ms, state_mon.v.T)
plt.show()

This generic Python script reads in membrane potential values and writes out input values. Here, it switches on the input for neurons whose membrane potential is below 0.5:

# input_provider.py
import numpy as np
import sys
if __name__ == '__main__':
    while True:
        # Read in the current membrane potential values
        try:
            line = input()
        except EOFError:
            break
        values = np.array([float(d) for d in line.split()])
        # Example for some closed-loop interaction:
        ext_inputs = np.zeros(len(values))
        ext_inputs[values < 0.5] = 1

        # Send the inputs to stdout
        print(' '.join([str(i) for i in ext_inputs]))
        sys.stdout.flush()

With bash, you can then start the two processes and have them communicate via two pipes, giving you a cheap-and-dirty form of inter-process communication:

$ mkfifo fifo
$ python -u cpp_pipes.py < fifo | python -u input_provider.py > fifo
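For anyone not on bash, the same loop can be wired up from Python with `subprocess` and `os.pipe`. This is a sketch; the two inline scripts below are tiny stand-ins for `cpp_pipes.py` and `input_provider.py`, just to show the plumbing:

```python
import os
import subprocess
import sys

# Stand-in for cpp_pipes.py: write state values, read the computed inputs
# back, and report them on stderr so the parent can see the round trip.
SIM = """\
import sys
print("0.2 0.8", flush=True)   # membrane potentials -> provider
reply = input()                # computed inputs <- provider
print(reply, file=sys.stderr)
"""

# Stand-in for input_provider.py: input is 1 where v < 0.5, else 0.
PROVIDER = """\
line = input()
vals = [float(x) for x in line.split()]
print(" ".join("1.0" if v < 0.5 else "0.0" for v in vals), flush=True)
"""

r1, w1 = os.pipe()  # sim stdout -> provider stdin (the shell's "|")
r2, w2 = os.pipe()  # provider stdout -> sim stdin (the fifo)

sim = subprocess.Popen([sys.executable, "-u", "-c", SIM],
                       stdin=r2, stdout=w1,
                       stderr=subprocess.PIPE, text=True)
provider = subprocess.Popen([sys.executable, "-u", "-c", PROVIDER],
                            stdin=r1, stdout=w2)
for fd in (r1, w1, r2, w2):
    os.close(fd)  # close the parent's copies so EOF can propagate

_, reply = sim.communicate()
provider.wait()
print(reply.strip())  # -> 1.0 0.0
```

Unlike the fifo version this works without a shell, at the cost of managing the file descriptors yourself.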

Given that the C++ code only asks for new input values every 5 ms, you get this oscillation around the 0.5 threshold value:

I should probably move this example somewhere else where it is a bit more discoverable instead of being at the end of a rather unrelated discussion :sweat_smile:

Interesting! So this would allow me to implement the closed-loop part without network operations. I would still need to implement quite a bit of nontrivial code for reading out the recording device measurements and applying the stimulation device updates, but it’s something for me to chew on.
