MPAX: Mathematical Programming in JAX
We present MPAX (Mathematical Programming in JAX), an open-source first-order solver for large-scale linear programming (LP) and convex quadratic programming (QP) built natively in JAX. The primary goal of MPAX is to exploit modern machine learning infrastructure for large-scale mathematical programming, while also providing advanced mathematical programming algorithms that are easy to integrate into machine learning workflows. MPAX implements two PDHG variants, r2HPDHG for LP and rAPDHG for QP, together with diagonal preconditioning, adaptive restarts, adaptive step sizes, primal-weight updates, infeasibility detection, and feasibility polishing. Leveraging JAX's compilation and parallelization ecosystem, MPAX provides cross-hardware portability, batched solving, distributed optimization, and automatic differentiation. We evaluate MPAX on CPUs, NVIDIA GPUs, and Google TPUs, observing substantial GPU speedups over CPU baselines and competitive performance relative to GPU-based codebases on standard LP/QP benchmarks. Our numerical experiments further demonstrate MPAX's capabilities in high-throughput batched solving, near-linear multi-GPU scaling for dense LPs, and efficient end-to-end differentiable training. The solver is publicly available at https://github.com/MIT-Lu-Lab/MPAX.
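To make the ingredients above concrete, here is a minimal sketch of vanilla PDHG for a standard-form LP, min c^T x subject to Ax = b, x >= 0, written in plain JAX. It is not MPAX's API and not its r2HPDHG or rAPDHG algorithms. It omits diagonal preconditioning, restarts, adaptive step sizes, primal-weight updates, and infeasibility detection. The function name `pdhg_lp`, the fixed iteration count, and the toy problem are hypothetical choices made for illustration. The sketch shows how batched solving follows from `jax.vmap` and how differentiation follows from `jax.grad` through the unrolled iterations.

```python
import jax
import jax.numpy as jnp


def pdhg_lp(c, A, b, num_iters=2000):
    """Vanilla PDHG for min c^T x s.t. A x = b, x >= 0 (hypothetical sketch).

    Saddle-point form: min_{x >= 0} max_y  c^T x - y^T (A x - b).
    """
    m, n = A.shape
    # Constant step sizes satisfying tau * sigma * ||A||_2^2 < 1.
    step = 0.9 / jnp.linalg.norm(A, ord=2)
    tau = sigma = step

    def body(_, state):
        x, y = state
        # Primal: projected gradient step on the Lagrangian, onto x >= 0.
        x_new = jnp.maximum(x - tau * (c - A.T @ y), 0.0)
        # Dual: ascent step evaluated at the extrapolated point 2 * x_new - x.
        y_new = y + sigma * (b - A @ (2.0 * x_new - x))
        return x_new, y_new

    x0 = jnp.zeros(n, dtype=c.dtype)
    y0 = jnp.zeros(m, dtype=c.dtype)
    # Static trip count, so reverse-mode differentiation through the loop works.
    return jax.lax.fori_loop(0, num_iters, body, (x0, y0))


# Batched solving: map over cost vectors while sharing A and b (in_axes=None).
batched_solve = jax.jit(jax.vmap(pdhg_lp, in_axes=(0, None, None)))

A = jnp.array([[1.0, 1.0]])
b = jnp.array([1.0])
C = jnp.array([
    [1.0, 2.0],  # optimum x = (1, 0), objective 1
    [2.0, 3.0],  # optimum x = (1, 0), objective 2
    [3.0, 1.0],  # optimum x = (0, 1), objective 1
])
xs, ys = batched_solve(C, A, b)
objs = jnp.sum(C * xs, axis=1)


# Differentiation: the gradient of the optimal value with respect to c should
# approach the optimal x (envelope theorem), here obtained by backpropagating
# through the unrolled PDHG iterations.
def opt_value(c):
    x, _ = pdhg_lp(c, A, b)
    return c @ x


grad_c = jax.grad(opt_value)(jnp.array([1.0, 2.0]))
```

Unrolling a fixed number of iterations keeps the example short. It also makes `jax.grad` applicable directly. A production solver would instead stop on a convergence criterion and might use implicit differentiation rather than backpropagating through every iteration.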