A transformer project in pure JAX (using only jnp, no jax.nn), which I will occasionally update with new features (e.g. BlockAttention at some point).
This is mainly an academic exercise to improve my understanding of transformers and JAX. It almost certainly does not follow JAX best practices.
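To give a rough idea of what "only jnp, no jax.nn" implies: building blocks such as softmax have to be written by hand instead of calling `jax.nn.softmax`. The snippet below is a minimal sketch of that style. The function names `softmax` and `attention` are hypothetical and chosen for illustration; they are not necessarily what this codebase uses.

```python
import jax.numpy as jnp


def softmax(x, axis=-1):
    # Hand-written softmax (stands in for jax.nn.softmax).
    # Subtracting the max keeps exp() from overflowing.
    shifted = x - jnp.max(x, axis=axis, keepdims=True)
    exps = jnp.exp(shifted)
    return exps / jnp.sum(exps, axis=axis, keepdims=True)


def attention(q, k, v):
    # Scaled dot-product attention over the last two axes.
    # q, k, v: arrays of shape (..., seq_len, d_k)
    d_k = q.shape[-1]
    scores = q @ jnp.swapaxes(k, -1, -2) / jnp.sqrt(d_k)
    weights = softmax(scores, axis=-1)  # each row sums to 1
    return weights @ v, weights


# Example usage with dummy inputs: (batch, seq_len, d_k)
q = k = v = jnp.ones((2, 3, 4))
out, weights = attention(q, k, v)
print(out.shape, weights.shape)
```

No masking, multiple heads, or learned projections are included here; the point is only to show the jnp-only flavour.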
If you want to contribute something, that's great, and I'd be very pleased, so long as you explain slowly and clearly what your code does (like I'm a particularly dim golden retriever).
P.S. I accept no liability if this codebase is used in creating an AI overlord that enslaves humanity. :)