Einsums are the regexes of tensor programming. They should be avoided at all costs IMO. Ideally we would write native loops that get auto-vectorized into einsums, backed by a CUDA/PTX-emitting code generator. But for some reason neither PyTorch nor JAX/TF took this route, and now we are here.
Some of the einsum expressions I have seen for grouped multi-headed/grouped-query attention are mind-boggling, and they get shipped to prod.
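For anyone who hasn't seen one: here's a minimal sketch of what a grouped-query-attention score computation looks like as an einsum versus the explicit loops (all shapes and names here are made up for illustration, not from any particular codebase):

```python
import numpy as np

# Hypothetical GQA shapes: H query heads share G key/value groups.
B, H, G, T, D = 2, 8, 2, 16, 64            # batch, query heads, kv groups, seq len, head dim
q = np.random.randn(B, G, H // G, T, D)    # queries, grouped so heads-per-group is explicit
k = np.random.randn(B, G, T, D)            # one key head per group, shared within the group

# The one-liner: every letter is an axis you must hold in your head.
# b=batch, g=group, h=head-in-group, t/s=query/key positions, d=head dim.
scores = np.einsum('bghtd,bgsd->bghts', q, k)

# The same computation as explicit loops, for comparison.
ref = np.zeros((B, G, H // G, T, T))
for b in range(B):
    for g in range(G):
        for h in range(H // G):
            ref[b, g, h] = q[b, g, h] @ k[b, g].T  # (T, D) x (D, T) -> (T, T)

assert np.allclose(scores, ref)
```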
JAX kind of did take this route, no? The main issue is that it's hard or impossible to compile arbitrary Python loops to GPU kernels. It's also maybe not the most ergonomic solution, which is why shorthand like einsum exists. An einsum can be much clearer than a loop precisely because what it can do is so much more limited.
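To illustrate: a batched matmul as an einsum versus the loop version (a toy sketch with arbitrary shapes):

```python
import numpy as np

A = np.random.randn(4, 3, 5)
B = np.random.randn(4, 5, 2)

# The signature alone tells you everything: b is a batch axis,
# j is contracted away, and the output is laid out as (b, i, k).
out = np.einsum('bij,bjk->bik', A, B)

# The loop version: you have to read the whole body to see it's a matmul,
# and nothing stops the body from doing something an einsum can't express.
ref = np.zeros((4, 3, 2))
for b in range(4):
    for i in range(3):
        for k in range(2):
            for j in range(5):
                ref[b, i, k] += A[b, i, j] * B[b, j, k]

assert np.allclose(out, ref)
```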
JAX tries to be a functional language with a Python front end. The problem is that if you are outside Google and don't really understand the XLA compiler, you are screwed.
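For what it's worth, the trace-and-compile model is easy to demo (a minimal sketch using the public jax.jit API; the function here is just a toy):

```python
import jax
import jax.numpy as jnp

@jax.jit  # traces the Python function once, then hands the trace to XLA to compile
def scale_and_sum(x):
    return jnp.sum(x * 2.0)

x = jnp.arange(8.0)
print(scale_and_sum(x))  # subsequent calls run the compiled XLA computation
```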