Diffrax: Numerical Differential Equation Solvers in JAX
My talk will introduce "Diffrax", a new collection of ordinary/stochastic/controlled differential equation solvers, written in JAX (Python). Highlights include:
- Exceptionally high performance: similar to comparable Julia libraries and frequently ~100 times faster than equivalent PyTorch libraries.
- Numerous features: high-order solvers, implicit solvers, dense solutions, multiple adjoint methods, etc.
- Integrates directly with the JAX ecosystem, including jit/grad/vmap.
- Easily extensible with custom solvers etc.; you can also handle the stepping yourself, e.g. when writing a differentiable simulator.
- The main technical novelty is the abstractions used; in particular ODEs/SDEs/etc. are all solved in the same unified way.
GitHub: https://github.com/patrick-kidger/diffrax
(I will also briefly advertise "On Neural Differential Equations", https://arxiv.org/abs/2202.02435, a new textbook on neural differential equations.)