Canonicalization of Batched Einstein Summations for Tuning Retrieval
We present an algorithm for normalizing \emph{Batched Einstein Summation} expressions by mapping mathematically equivalent formulations to a unique normal form. Batches of einsums that share the same Einstein notation and exhibit substantial data reuse appear frequently in finite element methods (FEM), numerical linear algebra, and computational chemistry. To exploit this temporal locality for high performance, we consider groups of einsums in batched form. Representations of equivalent batched einsums may differ due to index renaming, permutations within the batch, and the commutativity and associativity of multiplication. The lack of a canonical representation hinders the reuse of optimization and tuning knowledge in software systems. To this end, we develop a novel encoding of batched einsums as colored graphs and apply graph canonicalization to derive a normal form. In addition to the canonicalization algorithm, we propose a representation of einsums using functional array operands. We also provide a strategy to transfer transformations that operate on the normal form to \emph{functional batched einsums} with the same normal form, which is crucial for fusing surrounding computations with memory-bound einsums. We evaluate our approach against JAX and observe a geometric-mean speedup of $4.7\times$ on einsums from the TCCG benchmark suite and an FEM solver.
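To make the idea of a normal form concrete, here is a minimal sketch under stated assumptions. It is not the colored-graph canonicalization described above. It swaps in plain brute-force enumeration over a single einsum specification string. Only two of the equivalences named above are handled: consistent renaming of index letters, and reordering of input operands (commutativity). Batch permutations, shared operands, associativity, spaces, and ellipsis (`...`) are all out of scope. The function name `canonicalize` is hypothetical.

```python
from itertools import permutations


def canonicalize(spec: str) -> str:
    """Brute-force normal form of an einsum spec such as 'ij,jk->ik'.

    Assumed equivalences: consistent renaming of index letters and
    reordering of input operands. Among all operand orders, indices are
    renamed a, b, c, ... by first appearance (inputs first, then output),
    and the lexicographically smallest resulting spec is returned.
    """
    lhs, out = spec.split("->")
    operands = lhs.split(",")
    best = None
    for order in permutations(operands):
        # Rename indices by order of first appearance across inputs, then output.
        mapping = {}
        for ch in "".join(order) + out:
            if ch not in mapping:
                mapping[ch] = chr(ord("a") + len(mapping))
        inputs = ",".join("".join(mapping[c] for c in op) for op in order)
        candidate = inputs + "->" + "".join(mapping[c] for c in out)
        # Keep the smallest candidate so every equivalent spec yields the same string.
        if best is None or candidate < best:
            best = candidate
    return best


print(canonicalize("pq,qr->pr"))
print(canonicalize("jk,ij->ik"))
```

Both calls print `ab,bc->ac`, because the two specs describe the same matrix product up to renaming and operand order. By contrast, `ij,jk->ki` (the transposed result) maps to a different normal form. Enumerating operand orders grows factorially with the operand count, and this sketch ignores batch structure entirely. That motivates encoding batched einsums as colored graphs and delegating to a graph canonicalization algorithm instead.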