Training event-based neural networks with exact gradients via Differentiable ODE Solving in JAX

The paper introduces Eventax, a JAX-based framework that resolves the trade-off between gradient bias and model flexibility in spiking neural networks by combining differentiable ODE solvers with event-based spike handling to enable exact gradient training for arbitrary neuron models.

Lukas König, Manuel Kuhn, David Kappel, Anand Subramoney

Published 2026-03-10

Imagine you are trying to teach a team of fireflies to communicate in a specific pattern to solve a puzzle. These fireflies don't talk continuously; they only flash their lights at precise moments. In computer science, a team like this is a Spiking Neural Network (SNN): each firefly is a neuron, each flash is a spike, and because nothing happens between flashes, such networks can be very energy-efficient, much like the human brain.

However, there's a big problem: How do you teach them?

The Old Problem: The "Rough Sketch" vs. The "Perfect Map"

To teach a neural network, you need to calculate "gradients": essentially, a map that tells the network which way to nudge its settings to get a better result.

  1. The "Rough Sketch" Method (Discrete Time):
    Imagine trying to track a firefly's flash by checking a clock every 10 milliseconds. If the firefly flashes at 10.04ms, you miss it or guess it happened at 10.00ms. On top of that, a flash is all-or-nothing, so its true slope is zero almost everywhere and tells the network nothing. To make the math work, researchers use "surrogate gradients," which swap in a smooth, made-up slope, like guessing the slope of a cliff by looking at a blurry photo. It's fast and works for many types of fireflies, but the map is inaccurate. You might steer the firefly in the wrong direction because your "blurry photo" lied to you.

  2. The "Perfect Map" Method (Continuous Time):
    Imagine using a super-precise GPS that knows exactly when the firefly flashes. This gives you a perfect map. But here's the catch: this GPS only works if the firefly follows a very simple, predictable flight path (like a standard "Leaky Integrate-and-Fire" model). If the firefly has a complex, wobbly flight pattern (like a biological neuron), the GPS breaks because it can't calculate the exact path without a pre-written formula.

The Trade-off: You could have speed and flexibility (but bad accuracy), or perfect accuracy (but only for simple fireflies).
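To make the "rough sketch" method concrete, here is a minimal JAX sketch of surrogate-gradient training on a fixed time grid. It is an illustration, not code from the paper: the leaky neuron, its constants, and the fast-sigmoid surrogate slope are all assumptions chosen for brevity.

```python
import jax
import jax.numpy as jnp

@jax.custom_jvp
def spike(x):
    # Forward pass: hard threshold. Its true derivative is zero almost everywhere.
    return (x > 0).astype(x.dtype)

@spike.defjvp
def spike_jvp(primals, tangents):
    (x,), (dx,) = primals, tangents
    # Backward pass: substitute a smooth stand-in slope
    # (an assumed fast-sigmoid shape; other shapes are also used in practice).
    surrogate_slope = 1.0 / (1.0 + 10.0 * jnp.abs(x)) ** 2
    return spike(x), surrogate_slope * dx

def spike_count(weight, inputs, dt=1.0, tau=10.0, threshold=1.0):
    """Count spikes of one leaky neuron simulated on a fixed time grid."""
    def step(v, x):
        v = v + dt / tau * (weight * x - v)  # Euler update, one grid tick
        s = spike(v - threshold)             # a spike can only land on a tick
        return v * (1.0 - s), s              # reset to zero after a spike
    _, spikes = jax.lax.scan(step, jnp.zeros(()), inputs)
    return spikes.sum()

inputs = jnp.ones(50)
count = spike_count(2.0, inputs)
# This gradient is built from the surrogate slope, so it is an approximation.
approx_grad = jax.grad(spike_count)(2.0, inputs)
```

The gradient returned here depends on the surrogate shape you pick, which is exactly the bias the "blurry photo" analogy describes.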

The Solution: Eventax (The "Smart Time-Traveler")

The authors of this paper built a new framework called Eventax. Think of Eventax as a super-smart time-traveling detective that solves the mystery of how to train complex fireflies perfectly.

Here is how it works, using simple analogies:

1. The "Root-Finding" Detective

Instead of checking a clock every 10ms, Eventax uses a mathematical tool called a Differential Equation Solver (specifically, one built into a library called Diffrax).

  • Analogy: Imagine you are driving a car toward a finish line. A normal driver checks the speedometer every second. Eventax is like a driver who can instantly calculate exactly when the car will cross the line, even if the car is accelerating or braking unpredictably.
  • The Magic: When a neuron is about to "fire" (flash), Eventax doesn't round to the nearest clock tick. It uses a "root-finding" algorithm to pin down the moment the spike happens, to within the solver's numerical tolerance.
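The root-finding idea can be sketched in plain JAX. This is not Eventax or Diffrax code: to stay dependency-free it swaps in a hand-written RK4 step and simple bisection, and the leaky neuron, its constants, and every function name are hypothetical. The point is only to show how a spike time can be located between two solver steps.

```python
import jax
import jax.numpy as jnp

def vector_field(v, drive, tau=10.0):
    # Assumed leaky membrane: dv/dt = (drive - v) / tau
    return (drive - v) / tau

def rk4_step(v, h, drive):
    # One classical Runge-Kutta step of size h.
    k1 = vector_field(v, drive)
    k2 = vector_field(v + 0.5 * h * k1, drive)
    k3 = vector_field(v + 0.5 * h * k2, drive)
    k4 = vector_field(v + h * k3, drive)
    return v + h / 6.0 * (k1 + 2.0 * k2 + 2.0 * k3 + k4)

def first_spike_time(drive, threshold=1.0, dt=1.0, t_max=100.0):
    # Phase 1: take full steps while the next step stays below threshold.
    # (Assumes a crossing happens before t_max.)
    def next_step_is_below(state):
        t, v = state
        return (rk4_step(v, dt, drive) < threshold) & (t < t_max)

    def advance(state):
        t, v = state
        return t + dt, rk4_step(v, dt, drive)

    start = (jnp.zeros(()), jnp.zeros(()))
    t, v = jax.lax.while_loop(next_step_is_below, advance, start)

    # Phase 2: bisect on the partial step size h in [0, dt]
    # until v(t + h) sits on the threshold.
    def bisect(_, bounds):
        lo, hi = bounds
        mid = 0.5 * (lo + hi)
        above = rk4_step(v, mid, drive) >= threshold
        return jnp.where(above, lo, mid), jnp.where(above, mid, hi)

    lo, hi = jax.lax.fori_loop(0, 40, bisect, (jnp.zeros(()), jnp.zeros(()) + dt))
    return t + 0.5 * (lo + hi)

t_spike = first_spike_time(2.0)
```

For this particular neuron a closed-form spike time exists, which makes the result easy to check. The same two-phase procedure would also work for a vector field with no closed-form solution.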

2. The "Backwards Time-Travel"

Once the firefly flashes, the network needs to learn from the mistake.

  • Analogy: Usually, to learn from a mistake, you have to rewind the tape. Eventax can rewind the tape perfectly. Because it knows the exact moment the flash happened, it can trace the cause-and-effect chain backward through time without losing any precision.
  • The Result: It calculates the "gradient" (the lesson) exactly, without any blurry guesses.
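One standard way to get an exact gradient of a spike time is the implicit function theorem: if the spike time t* solves g(t*, θ) = 0, then dt*/dθ = -(∂g/∂θ)/(∂g/∂t). The sketch below illustrates that idea in plain JAX. It is an assumption about the general technique rather than a copy of Eventax's internals, and the neuron, constants, and function names are hypothetical.

```python
import jax
import jax.numpy as jnp

def membrane_at(t, drive, tau=10.0, n_steps=64):
    # Integrate dv/dt = (drive - v) / tau from 0 to t with RK4.
    # The result is differentiable in both t and drive.
    h = t / n_steps
    f = lambda v: (drive - v) / tau

    def step(v, _):
        k1 = f(v)
        k2 = f(v + 0.5 * h * k1)
        k3 = f(v + 0.5 * h * k2)
        k4 = f(v + h * k3)
        return v + h / 6.0 * (k1 + 2.0 * k2 + 2.0 * k3 + k4), None

    v, _ = jax.lax.scan(step, jnp.zeros(()), None, length=n_steps)
    return v

def spike_time(drive, threshold=1.0, t_guess=5.0):
    drive = jnp.asarray(drive)
    g = lambda t, d: membrane_at(t, d) - threshold  # zero exactly at the spike
    dg_dt = jax.grad(g, argnums=0)

    # Locate the root with Newton's method, with gradient tracking switched off.
    d_fixed = jax.lax.stop_gradient(drive)
    newton = lambda _, t: t - g(t, d_fixed) / dg_dt(t, d_fixed)
    t_star = jax.lax.fori_loop(0, 20, newton, jnp.zeros(()) + t_guess)
    t_star = jax.lax.stop_gradient(t_star)

    # One extra Newton-style correction that IS differentiable in `drive`.
    # Its derivative equals -(dg/ddrive) / (dg/dt), the implicit-function result.
    return t_star - g(t_star, drive) / jax.lax.stop_gradient(dg_dt(t_star, drive))

t_star = spike_time(2.0)
dt_ddrive = jax.grad(spike_time)(2.0)
```

For this simple neuron the spike time has the closed form t* = τ ln(drive / (drive - threshold)), whose derivative at drive = 2 is -5, so the autodiff result can be compared against a known value. No surrogate slope appears anywhere in the computation.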

3. The "Universal Adapter"

The best part? Eventax doesn't care what kind of firefly you have.

  • Analogy: Old GPS systems only worked for cars. Eventax is like a universal adapter that works for cars, bicycles, skateboards, and even flying saucers.
  • Real-world application: The authors used Eventax to train not just simple fireflies, but complex ones that mimic real human brain cells (with dendrites and spikes) and even "Event-based Gated Recurrent Units" (which are like memory banks for time-based data).
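The "universal adapter" idea can be illustrated by writing the event loop once and passing the neuron model in as data. The interface below is hypothetical and is not Eventax's API. The Izhikevich neuron is used only as a familiar stand-in for a model without a closed-form solution; the article does not say the paper uses it. The sketch handles at most one spike per solver step.

```python
from typing import Callable, NamedTuple
import jax
import jax.numpy as jnp

class NeuronModel(NamedTuple):   # hypothetical interface for illustration
    dynamics: Callable           # (state, current) -> d(state)/dt
    condition: Callable          # state -> scalar; spike when it crosses 0 upward
    reset: Callable              # state at the spike -> state just after it
    init: jnp.ndarray

lif = NeuronModel(
    dynamics=lambda y, i: (i - y) / 10.0,
    condition=lambda y: y[0] - 1.0,
    reset=lambda y: jnp.zeros_like(y),
    init=jnp.zeros(1),
)
izhikevich = NeuronModel(        # regular-spiking parameters, assumed
    dynamics=lambda y, i: jnp.stack([
        0.04 * y[0] ** 2 + 5.0 * y[0] + 140.0 - y[1] + i,
        0.02 * (0.2 * y[0] - y[1]),
    ]),
    condition=lambda y: y[0] - 30.0,
    reset=lambda y: jnp.stack([jnp.full_like(y[0], -65.0), y[1] + 8.0]),
    init=jnp.array([-65.0, -13.0]),
)

def rk4(model, y, h, current):
    f = lambda s: model.dynamics(s, current)
    k1 = f(y)
    k2 = f(y + 0.5 * h * k1)
    k3 = f(y + 0.5 * h * k2)
    k4 = f(y + h * k3)
    return y + h / 6.0 * (k1 + 2.0 * k2 + 2.0 * k3 + k4)

def spike_times(model, current, dt, n_steps, n_bisect=30):
    """Return one entry per step: the spike time, or NaN if no spike occurred."""
    def step(carry, _):
        t, y = carry
        y_next = rk4(model, y, dt, current)
        crossed = (model.condition(y) < 0) & (model.condition(y_next) >= 0)

        # Bisect on the partial step size to locate the crossing.
        def bisect(_, bounds):
            lo, hi = bounds
            mid = 0.5 * (lo + hi)
            above = model.condition(rk4(model, y, mid, current)) >= 0
            return jnp.where(above, lo, mid), jnp.where(above, mid, hi)

        lo, hi = jax.lax.fori_loop(0, n_bisect, bisect,
                                   (jnp.zeros(()), jnp.zeros(()) + dt))
        h_spike = 0.5 * (lo + hi)
        # Reset at the spike, then integrate the remainder of the step.
        y_spike = rk4(model, y, h_spike, current)
        y_after = rk4(model, model.reset(y_spike), dt - h_spike, current)

        y_out = jnp.where(crossed, y_after, y_next)
        t_spike = jnp.where(crossed, t + h_spike, jnp.nan)
        return (t + dt, y_out), t_spike

    _, times = jax.lax.scan(step, (jnp.zeros(()), model.init), None, length=n_steps)
    return times

lif_times = spike_times(lif, current=2.0, dt=0.25, n_steps=400)
izh_times = spike_times(izhikevich, current=10.0, dt=0.25, n_steps=400)
```

Only the `NeuronModel` changes between the two calls; the solver and the event handling stay the same, which is the property the analogy is pointing at.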

What Did They Prove?

The team tested Eventax on several challenges:

  • The Yin-Yang Puzzle: A classic pattern recognition task. Eventax solved it with different types of neurons, showing that complex, biologically realistic neurons actually learned better and faster than the simple ones.
  • MNIST (Handwritten Digits): They taught the network to recognize numbers. It performed just as well as the best existing methods, proving it's not just a cool toy, but a practical tool.
  • The "Delayed XOR" Game: This is a logic puzzle where the network has to remember two inputs separated by time and decide if they are the same or different. Eventax solved this perfectly, proving it can handle memory and time effectively.
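As a rough illustration of what a delayed-XOR dataset might look like in event form, here is a small JAX sketch. The encoding is an assumption made for this example (one event per cue, on channel 0 or 1 depending on the bit) and the delay value is arbitrary; the paper's actual setup may differ.

```python
import jax
import jax.numpy as jnp

def make_delayed_xor(key, n_samples, delay=50.0):
    # Two random binary cues per sample.
    bits = jax.random.bernoulli(key, 0.5, (n_samples, 2))
    # Assumed encoding: the first cue arrives at t = 0, the second at t = delay.
    times = jnp.broadcast_to(jnp.array([0.0, delay]), (n_samples, 2))
    # The bit value selects which input channel carries the event.
    channels = bits.astype(jnp.int32)
    # Label is 1 when the two cues differ, 0 when they match.
    labels = jnp.logical_xor(bits[:, 0], bits[:, 1]).astype(jnp.int32)
    return times, channels, labels

times, channels, labels = make_delayed_xor(jax.random.PRNGKey(0), n_samples=8)
```

Because the label depends on both cues, a network can only solve the task if it carries information about the first event across the gap until the second one arrives.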

Why Should You Care?

  1. Better Brain Simulations: Scientists can now simulate complex biological neurons without simplifying them, helping us understand how the real brain works.
  2. Neuromorphic Hardware: There are new computer chips designed to run like brains (neuromorphic chips). Eventax allows engineers to design software that perfectly matches these chips, leading to super-efficient AI that uses very little electricity.
  3. No More "Guessing": By removing the need for "surrogate gradients" (the blurry photos), we get more reliable AI training.

In a Nutshell

Eventax is a new tool that lets us train "event-based" neural networks with perfect precision, regardless of how complex the neurons are. It combines the speed of modern math solvers with the flexibility of event-driven computing, bridging the gap between simple AI models and the complex, beautiful machinery of the human brain. It's like upgrading from a sketchpad to a high-definition, 4K video camera for training AI.