Oh right, I forgot that BF16 allocates the same number of sign and exponent bits as F32, so we can convert from BF16 to F32 by just padding zeros to the fraction bits. I'd guess conversion in the opposite direction would require rounding rather than truncation if we want to preserve as much intelligence as possible, so something like (https://hhhhhojeihsu.github.io/tensorflow_1.8_woboq/tensorfl...).
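For my own reference, here's roughly what I'd expect that rounding direction to look like in plain C (my own function names; NaN handling omitted, which would need a separate check), next to the trivial pad-with-zeros direction:

    #include <stdint.h>
    #include <string.h>

    /* BF16 -> F32: pad 16 zero fraction bits. */
    static inline float bf16_to_f32(uint16_t h) {
        uint32_t bits = (uint32_t)h << 16;
        float f;
        memcpy(&f, &bits, sizeof f);
        return f;
    }

    /* F32 -> BF16 with round-to-nearest-even instead of truncation. */
    static inline uint16_t f32_to_bf16(float f) {
        uint32_t bits;
        memcpy(&bits, &f, sizeof bits);
        /* Rounding bias: 0x7FFF plus the LSB of the kept half, so ties
           round to even. NaNs would need a separate check first. */
        uint32_t bias = 0x7FFF + ((bits >> 16) & 1);
        return (uint16_t)((bits + bias) >> 16);
    }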
In any case, a constraint of my exercise was to compare a hand-written inference implementation to existing inference engines (calm/llama.cpp) which both use FP16 for weight and KV cache quantization by default rather than BF16, so a BF16 backend will be an interesting addition but not so relevant to the blog post.
Yes, we can totally use cuBLAS, and it's likely what I'd go with in a production implementation for any kernels I wasn't planning on fusing, but as mentioned in the post, it defeats the point of the exercise and is usually less fun than doing things from scratch. That said, I do plan on exploring for myself at some point what kind of gain I could get from switching the existing hand-rolled ops to cuBLAS/cuDNN (and then trying to match that).
Thanks for linking your prefill code with GEMMs! Actually, your code looks pretty readable and well-annotated. It'll be cool to dig in, as your CPU backend seems better developed than mine. Cool observation about how using batched GEMM instead of GEMV speeds things up in the single-batch forward() function. I am curious about your intuition behind the reduction in memory bandwidth ("if multiple cores read the same locations in xb..."): is this about xb being kept in the L3 cache, which is shared between cores on Zen 3? I wonder if you could validate this idea by adapting your handwritten matmul() function to operate on 2 matrices and 2 outputs at the same time.
+1, rounding f32 to bf16 is helpful. For the other direction, the approach we take in Highway/gemma.cpp is to load a full vector of bf16, then shift/AND to isolate the odd/even elements and convert to float. These can execute two per cycle, whereas promoting 16->32 bit is often just one per cycle (though a different port than FMA).
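Roughly, that trick looks like this outside of Highway's wrappers, in raw AVX2 (a sketch with a made-up function name, not the actual gemma.cpp code; note the two outputs come back deinterleaved into even- and odd-indexed elements):

    #include <immintrin.h>

    /* Decode 16 packed bf16 (one 256-bit vector) into two f32 vectors:
       even-indexed and odd-indexed elements, deinterleaved. Viewed as
       32-bit lanes, each lane holds bf16[2i] in its low half and
       bf16[2i+1] in its high half (little-endian). */
    static inline void bf16x16_to_f32(__m256i packed, __m256 *even, __m256 *odd) {
        /* Shift the low bf16 into the high half of the lane: now a valid f32. */
        *even = _mm256_castsi256_ps(_mm256_slli_epi32(packed, 16));
        /* Mask off the low half: the high bf16 is already a valid f32. */
        *odd = _mm256_castsi256_ps(_mm256_and_si256(packed, _mm256_set1_epi32((int)0xFFFF0000)));
    }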
I plan to write my own GEMM implementations so I can do kernel fusion, but I am using Nvidia's and Intel's GEMM implementations until I have performant substitutes of my own; so far, my initial attempts have not been performance competitive. Intel and Nvidia both use tiling algorithms, and I learned (unintentionally) from profiling that Intel is doing a memcpy trick that is presumably similar to the one described here:
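For context, my mental model of that trick is roughly the following (a toy sketch, not Intel's code or the linked description; tile sizes and names are made up): copy the active panel of B into a small contiguous scratch buffer so the innermost loops stream over packed, cache-resident memory instead of striding across the full matrix.

    #include <string.h>

    #define MC 128   /* illustrative tile sizes; real values are tuned per cache level */
    #define KC 256
    #define NC 64

    /* C[M][N] += A[M][K] * B[K][N], row-major.
       Bpack must hold at least KC*NC floats of scratch space. */
    static void gemm_tiled(int M, int N, int K,
                           const float *A, const float *B, float *C,
                           float *Bpack) {
        for (int jc = 0; jc < N; jc += NC) {
            int nb = (N - jc < NC) ? N - jc : NC;
            for (int pc = 0; pc < K; pc += KC) {
                int kb = (K - pc < KC) ? K - pc : KC;
                /* The "memcpy trick": pack the kb x nb panel of B contiguously. */
                for (int p = 0; p < kb; p++)
                    memcpy(Bpack + p * nb, B + (pc + p) * N + jc, nb * sizeof(float));
                for (int ic = 0; ic < M; ic += MC) {
                    int mb = (M - ic < MC) ? M - ic : MC;
                    /* Micro-kernel territory: a real implementation would use a
                       register-blocked SIMD kernel here instead of scalar loops. */
                    for (int i = 0; i < mb; i++)
                        for (int p = 0; p < kb; p++) {
                            float a = A[(ic + i) * K + (pc + p)];
                            for (int j = 0; j < nb; j++)
                                C[(ic + i) * N + (jc + j)] += a * Bpack[p * nb + j];
                        }
                }
            }
        }
    }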
As for the annotations, some might be outdated. Much like the GPU version, I had published the CPU version at the request of another OSS developer and before I had reached the point where I would make a cleanup pass to address any stale/obsolete/missing comments. I made some effort to have good comments for my own benefit during the current stage of development, but I was not disciplined enough to make a cleanup pass unnecessary. That is partly why I claimed it is messy. I am happy to hear that they look good even before I have done a cleanup pass.
As for xb being in L3, that is part of it. I suspect that Intel has multiple threads doing "tiles" from both GEMM (GEMV expressed as GEMM) operations in parallel, and the reads and prefetches on xb are being coalesced by the L3 cache. I also had the idea of adapting my handwritten matmul to operate on 2 outputs at the same time, and even tried it, but it did not work. At the time, I thought this was due to register spilling because I ran out of ymm registers on Zen 3. Presumably, this experiment could be attempted on Zen 4 using the additional ymm registers that AVX-512 gives to AVX2.
However, two further thoughts occurred to me while writing this response. The first is that this experiment could be done on Zen 3 by doing 6 rows at a time per matrix instead of 8. The second is that this would throw a larger number of parallel linear accesses at the core's memory prefetcher than what I suspect Intel does, so it might fail to work even if my underlying idea about how Intel gets better performance is accurate. Honestly, I had felt that 9 linear accesses per loop iteration (8 rows + xb) was already asking too much, so I was happily surprised to see Zen 3 handle it.
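To make the register arithmetic concrete, here is roughly the 6+6 variant I have in mind (hypothetical names, AVX2, assuming row-major weights and n a multiple of 8): 12 accumulators plus the shared xb load fit within the 16 ymm registers, whereas 8+8 accumulators alone already exceed them.

    #include <immintrin.h>
    #include <stddef.h>

    static inline float hsum256(__m256 v) {
        __m128 lo = _mm256_castps256_ps128(v);
        __m128 hi = _mm256_extractf128_ps(v, 1);
        __m128 s = _mm_add_ps(lo, hi);
        s = _mm_hadd_ps(s, s);
        s = _mm_hadd_ps(s, s);
        return _mm_cvtss_f32(s);
    }

    /* Fused GEMV over two weight matrices sharing the same input vector xb,
       6 rows of each per call. Register budget on Zen 3 (16 ymm): 12
       accumulators + 1 shared xb + temporaries for the row loads, versus
       16 accumulators alone with 8 rows per matrix, which spills. */
    static void matmul2_6rows(const float *W1, const float *W2, const float *xb,
                              float *out1, float *out2, int row0, int n) {
        __m256 acc1[6], acc2[6];
        for (int r = 0; r < 6; r++) {
            acc1[r] = _mm256_setzero_ps();
            acc2[r] = _mm256_setzero_ps();
        }
        for (int k = 0; k < n; k += 8) {   /* 13 linear streams: 12 rows + xb */
            __m256 x = _mm256_loadu_ps(xb + k);
            for (int r = 0; r < 6; r++) {
                acc1[r] = _mm256_fmadd_ps(_mm256_loadu_ps(W1 + (size_t)(row0 + r) * n + k), x, acc1[r]);
                acc2[r] = _mm256_fmadd_ps(_mm256_loadu_ps(W2 + (size_t)(row0 + r) * n + k), x, acc2[r]);
            }
        }
        for (int r = 0; r < 6; r++) {
            out1[row0 + r] = hsum256(acc1[r]);
            out2[row0 + r] = hsum256(acc2[r]);
        }
    }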
Right now, some GPU-based experiments are at the top of my mental to-do list. I will likely circle back to the CPU side to try to understand how to get this performance out of my own GEMM/GEMV kernels at some point after those.