
Hand-rolling BitNet b1.58 in Rust: from autograd to tensor cores

Why I wrote every layer myself, what AVX-512 actually buys you on Zen 4, and the cuBLAS int8 GEMM saga that ended in a 3000-kernel-launches-per-step reality check

rust machine-learning cuda simd performance

Published 2026-05-06 · 13 min read

Why bother

Burn and candle exist and they are excellent. Neither will teach you why the gradient of a ternary quantiser is the identity, or what it actually costs to launch a CUDA kernel, because both hide the answer behind a layer of abstraction. I wanted to know.

So I wrote BitNet b1.58 in pure Rust with no third-party ML dependencies. Just std for the default build. Optional cudarc behind a feature flag for the GPU back-end. Tensor type, autograd, ternary quantisation, transformer block, training loop, KV-cached inference, three on-disk export formats, and a CUDA back-end that runs real ternary training on Ada tensor cores via cuBLAS int8 GEMM. 139 tests on the default build, more with --features cuda enabled, and zero warnings on cargo build --release.

A lot of it works. One key part doesn't, and that is itself an instructive story.

What BitNet b1.58 actually is

BitNet b1.58 is a ternary-weight transformer. Each weight is one of {-1, 0, +1}. Activations are quantised to INT8 per row. The forward pass is discrete; the backward pass uses a straight-through estimator (STE), which is what makes the whole thing trainable.
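Per-row INT8 activation quantisation can be sketched as follows. This is a minimal absmax variant with a symmetric [-127, 127] range. The function names and the epsilon are illustrative assumptions, and the project's absmax_int8 in bitlinear.rs may clip and round differently.

```rust
/// Quantise one activation row to INT8 with a per-row absmax scale.
/// Returns the quantised values and a scale such that x ≈ q * scale.
/// Hypothetical helper for illustration; not the project's `absmax_int8`.
fn absmax_int8_row(row: &[f32]) -> (Vec<i8>, f32) {
    // Largest magnitude in the row decides how much range each step covers.
    let max_abs = row.iter().fold(0.0_f32, |m, v| m.max(v.abs()));
    // Guard against an all-zero row (the epsilon is an assumption).
    let scale = max_abs.max(1e-8) / 127.0;
    let q = row
        .iter()
        .map(|v| (v / scale).round().clamp(-127.0, 127.0) as i8)
        .collect();
    (q, scale)
}

/// Map the INT8 values back to f32 using the same per-row scale.
fn dequantise_row(q: &[i8], scale: f32) -> Vec<f32> {
    q.iter().map(|&v| v as f32 * scale).collect()
}
```

Each row carries its own scale, so a single outlier only costs precision within its own row.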

The STE in two lines: in the forward, you quantise. In the backward, you pretend you didn't. The gradient flows through the quantiser as if it were the identity. It should not work. It does, because the quantiser is a piecewise-constant projection of an underlying f32 master weight that is updated continuously, and the gradient signal accumulates in the masters until a master crosses a threshold and the ternary value flips.

In practice this is a Var method on the autograd tape that wraps absmean_ternary from bitlinear.rs:

// Forward: gamma * W_q (quant + dequant via absmean_ternary).
// Backward: identity passthrough - the STE.
pub fn quantise_weights_ste(self) -> Var<'t> {
    let w_val = self.value();
    let (w_q, gamma) = absmean_ternary(&w_val);                  // {-1, 0, +1} + scalar gamma
    let w_eff_data: Vec<f32> = w_q.data.iter().map(|v| v * gamma).collect();
    let w_eff = Tensor { data: w_eff_data, shape: w_q.shape.clone() };

    // Identity backward: just route the gradient straight back to the master weight.
    let p0 = self.id;
    let backward = Box::new(move |g: &Tensor| -> Vec<(NodeId, Tensor)> {
        vec![(p0, g.clone())]
    });
    let id = self.tape.push(Node {
        value: w_eff,
        grad: RefCell::new(Tensor::zeros(w_q.shape)),
        backward: Some(backward),
        parents: vec![self.id],
    });
    Var { tape: self.tape, id }
}

You keep two copies of the weights at all times during training: the f32 masters, which the optimiser updates, and the ternary projection, which is used in the forward and exported at the end. The compression ratio only materialises when training is done.
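To make the master-versus-projection relationship concrete, here is a minimal standalone sketch of an absmean quantiser on a flat slice, plus a toy loop that nudges one master weight until its ternary value flips. The function names, the epsilon, and the constant per-step nudge (standing in for an optimiser update) are assumptions for illustration. This is not the code in bitlinear.rs.

```rust
/// Absmean ternary quantiser: gamma = mean(|w|), w_q = clamp(round(w / gamma), -1, 1).
/// Hypothetical flat-slice version; the project's `absmean_ternary` works on Tensor.
fn absmean_ternary_sketch(w: &[f32]) -> (Vec<i8>, f32) {
    let gamma = w.iter().map(|v| v.abs()).sum::<f32>() / w.len().max(1) as f32;
    let q = w
        .iter()
        .map(|v| (v / (gamma + 1e-8)).round().clamp(-1.0, 1.0) as i8)
        .collect();
    (q, gamma)
}

/// Add `delta` to one master weight per step and report the first step at
/// which its ternary projection changes. The constant `delta` is a stand-in
/// for a real optimiser update.
fn steps_until_flip(
    masters: &mut [f32],
    idx: usize,
    delta: f32,
    max_steps: usize,
) -> Option<usize> {
    let start = absmean_ternary_sketch(masters).0[idx];
    for step in 1..=max_steps {
        masters[idx] += delta; // the master moves continuously
        if absmean_ternary_sketch(masters).0[idx] != start {
            return Some(step); // the projection moves in a jump
        }
    }
    None
}
```

With masters [1.0, -1.0, 0.3] and a nudge of 0.06, the third weight projects to 0 after one step and to +1 after two: the first update is invisible in the forward pass, but it is not lost.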

What I'd been trying

Before this I had read three papers on quantisation, used Burn for one toy training run, and worked through a couple of Karpathy walkthroughs. I knew the shape but not the substance. The specific things I wanted to feel in my hands:

  • How a tape-based autograd works, in the small, with no metaprogramming.
  • Where SIMD actually helps in a transformer forward pass and where it doesn't.
  • What the per-step cost of a CUDA kernel launch actually is, measured.
  • Whether a hand-rolled implementation can hit numerical parity with a GPU implementation, and what parity even means when the GPU is doing parallel reductions.

I deliberately picked constraints that would make the project simpler. f32 throughout, no FP16 or BF16. No third-party ML deps. Single crate. Character-level vocabulary. Two transformer blocks for the M4-M10 demo, six for the Shakespeare run, eight for Shakespeare-large. The Shakespeare config lands around 5M parameters (hidden_dim 192, twelve heads, head_dim 16); Shakespeare-large bumps that to ~8.5M (hidden_dim 256, eight blocks). The demo itself is much smaller, with hidden_dim 16, and exists for the integration tests rather than for actual training. This is all small enough that you can train end to end in 8 to 15 minutes on CPU and reason about every line of every file.

The architecture

src/
  tensor.rs        row-major f32, AVX-512/AVX2/scalar matmul
  autograd.rs      tape-based reverse-mode, STE quantisers, RoPE, RMSNorm
  bitlinear.rs     absmean_ternary + absmax_int8 quantiser primitives
  attention.rs     multi-head SDPA with causal mask + RoPE
  ffn.rs           SwiGLU position-wise
  block.rs         RMSNorm + attention + FFN + residuals
  model.rs         Model, ModelConfig, parameter visitor
  optim.rs         AdamW, gradient clipping, cosine LR with warmup
  inference_kv.rs  KV-cached generator, ~50-100x faster per token
  device.rs        per-op traits so a generic helper compiles for CPU + CUDA
  cuda.rs          NVRTC kernels, cuBLAS, CudaTensor, end-to-end forward
  export.rs        f32-with-masters, ternary i8, base-3 packed

Three points worth pulling out.

  • The tape-based autograd is about 200 lines of core machinery. Each op pushes a Node onto a RefCell<Vec<Node>> during the forward, and the backward walks the tape in reverse. STE is a special case: forward quantises, backward pretends the input is the output. Once the tape works, every other op in the project is "implement forward, register a backward closure" and you stop thinking about it.
  • The device module is what makes the CPU and CUDA back-ends share code. Every op (MatMul, Add, Mul, MulScalar, Transpose2D, Softmax, CausalMask, Rope, Silu, RmsNorm) is a trait. A generic block_inference<T> with the union of those traits as a where-clause compiles for both Tensor (CPU) and CudaTensor (GPU) without duplication. The design has worked very well (monomorphisation is exactly the optimisation you want here) and I no longer have to maintain two copies of the model code.
  • Three export formats. F32 with masters is the only one you can resume training from cleanly. Ternary i8 is one byte per weight, 2.92x compression. Base-3 packed exploits the fact that 3^5 = 243 < 256, so you can fit five ternary weights into a single byte. That's about 1.6 bits per weight and gives a 6.02x compression for the demo model. Embeddings stay f32 in all formats.
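The base-3 packing from the last bullet can be sketched in a few lines. Each ternary weight is shifted from {-1, 0, +1} to a digit in {0, 1, 2}, and five digits form one number in the range 0..=242. The digit order (least significant first) and the function names are assumptions; the project's on-disk layout may differ.

```rust
/// Pack ternary weights five to a byte: byte = sum of (t_i + 1) * 3^i for i in 0..5.
/// The largest possible value is 2 * (1 + 3 + 9 + 27 + 81) = 242, which fits in a u8.
fn pack_trits(trits: &[i8]) -> Vec<u8> {
    trits
        .chunks(5)
        .map(|chunk| {
            // Fold from the last digit down so the first trit ends up least significant.
            chunk.iter().rev().fold(0u8, |acc, &t| {
                debug_assert!((-1..=1).contains(&t));
                acc * 3 + (t + 1) as u8
            })
        })
        .collect()
}

/// Inverse of `pack_trits`. `len` is the original weight count, needed because
/// the final byte may hold fewer than five weights.
fn unpack_trits(bytes: &[u8], len: usize) -> Vec<i8> {
    let mut out = Vec::with_capacity(len);
    for &byte in bytes {
        let mut b = byte;
        for _ in 0..5 {
            if out.len() == len {
                break;
            }
            out.push((b % 3) as i8 - 1);
            b /= 3;
        }
    }
    out
}
```

Seven weights pack into two bytes, and 8 bits / 5 weights is where the 1.6 bits per weight figure comes from.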

SIMD and the AVX-512 surprise

The CPU matmul has a runtime-detected three-way fallback ladder. Widest path first: AVX-512 (16 f32 per inner-loop step) on Zen 4 and Sapphire Rapids, falling back to AVX2 (8 f32 per step), and below that to a scalar AXPY. Multi-threaded across output rows via std::thread::scope.
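As a rough illustration of the row-parallel structure, here is a scalar AXPY matmul split across output rows with std::thread::scope. It is a sketch under assumptions: the function name, the thread-count parameter, and the chunking policy are illustrative, and the SIMD inner loops are left out.

```rust
use std::thread;

/// Row-major C (m x n) = A (m x k) @ B (k x n), parallel over output rows.
/// Hypothetical helper; the inner loop is the scalar AXPY form.
fn matmul_rows_parallel(
    a: &[f32],
    b: &[f32],
    m: usize,
    k: usize,
    n: usize,
    threads: usize,
) -> Vec<f32> {
    assert_eq!(a.len(), m * k);
    assert_eq!(b.len(), k * n);
    let mut c = vec![0.0_f32; m * n];
    if m == 0 || n == 0 {
        return c;
    }
    let threads = threads.max(1);
    let rows_per_thread = (m + threads - 1) / threads; // ceiling division
    thread::scope(|s| {
        // Each thread owns a disjoint block of output rows, so no locking is needed.
        for (chunk_idx, c_chunk) in c.chunks_mut(rows_per_thread * n).enumerate() {
            let row0 = chunk_idx * rows_per_thread;
            s.spawn(move || {
                for (r, c_row) in c_chunk.chunks_mut(n).enumerate() {
                    let a_row = &a[(row0 + r) * k..(row0 + r + 1) * k];
                    for (p, &a_val) in a_row.iter().enumerate() {
                        // AXPY: c_row += a_val * (row p of B).
                        let b_row = &b[p * n..(p + 1) * n];
                        for (c_val, &b_val) in c_row.iter_mut().zip(b_row) {
                            *c_val += a_val * b_val;
                        }
                    }
                }
            });
        }
    });
    c
}
```

Because every output element is accumulated by one thread in a fixed order, the result does not depend on the thread count.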

I expected AVX-512 to be the fast path everywhere it was available. It isn't, on Zen 4: in my benchmarks AVX-512 underperforms AVX2 there, and as far as I can tell the cause is memory bandwidth. The 16-wide path asks for twice the data per instruction, but the L2 and the memory controller cannot feed it fast enough at the matmul shapes I'm running. Zen 4 also executes 512-bit operations on 256-bit datapaths, so the wider path has little raw arithmetic advantage to begin with. AVX2 finishes nearly as much work per cycle and thrashes the cache less.

The fix is an opt-out env var. BITNET_MATMUL_SIMD=avx2 skips the AVX-512 path on Zen 4. Detection is at runtime via is_x86_feature_detected!, so the same binary picks the right path on Sapphire Rapids and Zen 4 without recompilation.
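The detection-plus-override logic might look roughly like this. The enum and function names are hypothetical. Only the BITNET_MATMUL_SIMD=avx2 value comes from the text above, so no other override values are modelled.

```rust
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum SimdPath {
    Avx512,
    Avx2,
    Scalar,
}

/// Widest path the running CPU supports, checked at runtime.
fn detect_widest() -> SimdPath {
    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
    {
        if std::arch::is_x86_feature_detected!("avx512f") {
            return SimdPath::Avx512;
        }
        if std::arch::is_x86_feature_detected!("avx2") {
            return SimdPath::Avx2;
        }
    }
    SimdPath::Scalar
}

/// Apply the opt-out: "avx2" steps down from AVX-512 and is ignored otherwise,
/// so the override can never select a path the CPU lacks.
fn choose_path(detected: SimdPath, requested: Option<&str>) -> SimdPath {
    match (requested, detected) {
        (Some("avx2"), SimdPath::Avx512) => SimdPath::Avx2,
        _ => detected,
    }
}

/// Read the environment variable once and resolve the path to use.
fn matmul_path_from_env() -> SimdPath {
    let requested = std::env::var("BITNET_MATMUL_SIMD").ok();
    choose_path(detect_widest(), requested.as_deref())
}
```

Keeping the pure decision in choose_path separate from the detection makes the override testable on any machine.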

The other thing I had to be careful about: all three matmul paths use a separate multiply and add, not FMA. This is intentional. FMA rounds once per multiply-add where a separate multiply and add round twice, so if one path uses FMA and another doesn't, the results can differ in the last bit at each accumulation step and the two paths produce different values for the same inputs. I want bit-identical output across the SIMD ladder so that swapping in AVX-512 doesn't change the loss curve. The performance cost is small; the consistency win is worth it.
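The rounding difference is easy to reproduce in isolation. The sketch below uses f32::mul_add from the standard library as the fused form. The specific inputs are chosen so that the product lands exactly halfway between two f32 values, and the function names are illustrative.

```rust
/// Separate multiply and add: the product is rounded, then the sum is rounded.
fn unfused(a: f32, b: f32, c: f32) -> f32 {
    a * b + c
}

/// Fused multiply-add: a * b + c is computed exactly and rounded once.
fn fused(a: f32, b: f32, c: f32) -> f32 {
    a.mul_add(b, c)
}

/// Returns (unfused, fused) for a case where the two disagree.
fn rounding_example() -> (f32, f32) {
    let a = 1.0_f32 + 1.0 / 4096.0; // 1 + 2^-12
    let c = -(1.0_f32 + 1.0 / 2048.0); // -(1 + 2^-11)
    // Exact a * a = 1 + 2^-11 + 2^-24. The 2^-24 term is half an f32 ULP at 1.0,
    // so the unfused product rounds to 1 + 2^-11 and the sum cancels to zero.
    // The fused form keeps the 2^-24 term until its single final rounding.
    (unfused(a, a, c), fused(a, a, c))
}
```

Here the unfused result is 0.0 and the fused result is 2^-24. In a long dot product such differences compound, which is why mixing the two forms across SIMD paths breaks bit-identical output.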

The CUDA back-end

Three phases got the GPU path to a working forward pass.

  1. Phase 1. cuBLAS sgemm for matmul, plus 9 hand-rolled NVRTC kernels for everything else: add, mul, mul_scalar, transpose_2d, causal_mask, softmax, rope, silu, rmsnorm. Each kernel is a small chunk of device code compiled at runtime by NVRTC. cudarc 0.19 makes this less awful than it sounds.
  2. Phase 2. The per-op trait architecture in device.rs. CudaTensor implements MatMul + Add + Mul + ... and the generic helpers in the model code compose without modification. block_inference<T> runs on either back-end, picked by the type.
  3. Phase 3. CudaModel end-to-end forward. Embed -> blocks -> RMSNorm -> LM head matmul, all device-resident. NOT bit-identical to CPU, because parallel reduction reorders sums and cuBLAS picks its own internal tile schedule. The actual numerical agreement was within 1e-3 per op, 5e-3 per block, 2e-2 end to end. Acceptable.
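Why reordered sums break bit-identity can be shown on four numbers. The sketch below compares a left-to-right sum with a pairwise (tree) sum, which is the shape a parallel reduction takes, and adds a tolerance check of the kind a CPU-versus-GPU parity test could use. All names are hypothetical and the tolerance is up to the caller.

```rust
/// Left-to-right accumulation, as a simple CPU loop would do it.
fn sum_sequential(xs: &[f32]) -> f32 {
    xs.iter().fold(0.0_f32, |acc, &x| acc + x)
}

/// Pairwise (tree) accumulation, the order a parallel reduction typically uses.
fn sum_pairwise(xs: &[f32]) -> f32 {
    match xs.len() {
        0 => 0.0,
        1 => xs[0],
        n => {
            let (left, right) = xs.split_at(n / 2);
            sum_pairwise(left) + sum_pairwise(right)
        }
    }
}

/// Largest elementwise absolute difference between two equally sized buffers.
fn max_abs_diff(a: &[f32], b: &[f32]) -> f32 {
    assert_eq!(a.len(), b.len());
    a.iter()
        .zip(b)
        .fold(0.0_f32, |m, (x, y)| m.max((x - y).abs()))
}

/// Parity in the tolerance sense: no element differs by more than `tol`.
fn within_tolerance(a: &[f32], b: &[f32], tol: f32) -> bool {
    max_abs_diff(a, b) <= tol
}
```

For [1e8, 1.0, -1e8, 1.0] the sequential sum is 1.0 and the pairwise sum is 0.0, while the exact answer is 2.0. Neither order is wrong; they simply round at different points, so parity has to be stated as a tolerance.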

So far, so good. Phase 4 added per-op backward kernels. Phase 5 is where it got interesting.

Phase 5: real ternary training on tensor cores

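The point of BitNet b1.58 is INT8 inference and INT8-or-better training. Phase 5.a added STE quantiser kernels and a BitLinear trait. Phase 5.b rewrote CudaTensor::matmul to call cublasGemmEx with CUDA_R_8I inputs, CUDA_R_32I accumulator, CUBLAS_COMPUTE_32I compute, and CUBLAS_GEMM_DEFAULT_TENSOR_OP algorithm selection. That last constant is the traditional way to ask for tensor-core dispatch. In cuBLAS 11 and later it is deprecated, and the library picks tensor-core kernels on its own whenever the types, shapes, and alignment allow, so on Ada it is the int8 input types and the int32 compute type that do the real work.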

The four constants in one call site (via cudarc's gemm_ex wrapper):

unsafe {
    gemm_ex(
        blas_handle,
        sys::cublasOperation_t::CUBLAS_OP_N,
        sys::cublasOperation_t::CUBLAS_OP_N,
        n_i, m_i, k_i,
        &alpha_h as *const i32 as *const c_void,
        b_ptr as *const c_void,         sys::cudaDataType::CUDA_R_8I,  n_i,
        a_ptr as *const c_void,         sys::cudaDataType::CUDA_R_8I,  k_i,
        &beta_h as *const i32 as *const c_void,
        c_ptr as *mut c_void,           sys::cudaDataType::CUDA_R_32I, n_i,
        sys::cublasComputeType_t::CUBLAS_COMPUTE_32I,
        sys::cublasGemmAlgo_t::CUBLAS_GEMM_DEFAULT_TENSOR_OP,
    )
    .expect("cublasGemmEx int8 failed");
}

(Row-major-via-column-major adapter: ask cuBLAS for C_col = B_col @ A_col of shape (N, M); the bytes that come out match the desired C_row = A_row @ B_row. Same trick the f32 MatMul impl uses.)
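The adapter is easier to trust after seeing it on the CPU. Below is a naive column-major GEMM with int8 inputs and an int32 accumulator, standing in for the cuBLAS call, and a wrapper that passes the operands in swapped order exactly as the call site above does. Function names are hypothetical.

```rust
/// Naive column-major GEMM: C (m x n) = A (m x k) @ B (k x n).
/// Element (i, j) of a column-major matrix with leading dimension ld is at i + j * ld.
fn gemm_col_major(
    m: usize, n: usize, k: usize,
    a: &[i8], lda: usize,
    b: &[i8], ldb: usize,
    c: &mut [i32], ldc: usize,
) {
    for j in 0..n {
        for i in 0..m {
            let mut acc = 0_i32;
            for p in 0..k {
                acc += a[i + p * lda] as i32 * b[p + j * ldb] as i32;
            }
            c[i + j * ldc] = acc;
        }
    }
}

/// Row-major C (m x n) = A (m x k) @ B (k x n) using only the column-major routine.
/// Row-major bytes read as column-major are the transpose, and C^T = B^T @ A^T,
/// so pass B first, then A, with dimensions (n, m, k) and leading dims (n, k, n).
fn matmul_row_major(a_row: &[i8], b_row: &[i8], m: usize, k: usize, n: usize) -> Vec<i32> {
    let mut c = vec![0_i32; m * n];
    gemm_col_major(n, m, k, b_row, n, a_row, k, &mut c, n);
    c
}
```

No data is transposed or copied at any point; only the argument order and the leading dimensions change.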

It worked. Phase 5.a tests passed with the int8 path active. cuda-shakespeare ran end to end and the loss trajectory matched the CPU reference run.

There is one subtlety in the int8 path. cuBLAS int8 GEMM has stricter stride alignment requirements than f32 sgemm, and my lm_head projection has n = vocab = 65, which doesn't satisfy the alignment. The tensor-core GEMM call fails for that one shape. Solution: shape-fallback to f32 sgemm when stride alignment fails. The fallback is per-call, not per-tensor, so it only kicks in on the LM head projection. The rest of the model goes through int8 tensor cores.
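A per-call fallback of this kind could be expressed as a small shape predicate. The sketch below is an assumption-heavy illustration: the post does not state the exact alignment rule, so the alignment is a parameter, and the value 4 used in the usage note is a guess rather than something taken from the project or quoted from cuBLAS.

```rust
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum GemmPath {
    Int8TensorCore,
    F32Sgemm,
}

/// Decide the GEMM path for one call from its shape alone.
/// With the row-major-via-column-major adapter, the leading dimensions passed
/// to cuBLAS are n (for B and C) and k (for A), so those are what get checked.
/// `ld_alignment` is an ASSUMED requirement, supplied by the caller.
fn pick_gemm_path(n: usize, k: usize, ld_alignment: usize) -> GemmPath {
    let aligned = ld_alignment > 0 && n % ld_alignment == 0 && k % ld_alignment == 0;
    if aligned {
        GemmPath::Int8TensorCore
    } else {
        GemmPath::F32Sgemm
    }
}
```

Assuming an alignment of 4, the LM head with n = 65 and k = 192 falls back to f32, while a hidden-to-hidden projection with n = k = 192 stays on the int8 path.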

So I had real BitNet ternary training running on Ada tensor cores. And then I measured it.

The reality check

v0.13 config, 5M parameters, sequence length 64, batch 16

  CPU per-step (Zen 4, AVX2):              ~180 ms
  GPU per-step (Ada, cuBLAS int8 GEMM):    ~300 ms
  Estimated per-step launch overhead:      ~6 - 8 ms (~3000 launches)

The GPU was slower than the CPU. By a substantial margin.

I spent two days hunting for a kernel bug before I realised what the actual problem was. There is no kernel bug. There is, instead, a wall I had walked into face-first. The problem is per-step launch overhead. A single training step at this config involves something like 3000 kernel launches across the forward, the backward, and the optimiser update. Each launch has a fixed per-call cost - a few microseconds, but with 3000 calls per step that adds up to 6 to 8 ms of pure launch overhead per step, plus the latency of synchronising host and device after each one. At small batch sizes the GPU is starved. The tensor cores spend more time waiting for the next launch than they do computing.
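The arithmetic behind the launch-overhead estimate is worth writing down. In the sketch below the per-launch cost of 2 µs is an assumption back-solved from the 6 ms figure (8 ms corresponds to roughly 2.7 µs); it is not a measured value, and the function names are illustrative.

```rust
/// Total launch overhead per step in milliseconds.
fn launch_overhead_ms(launches: u32, per_launch_us: f64) -> f64 {
    launches as f64 * per_launch_us / 1000.0
}

/// Share of a step's wall-clock time spent on that overhead.
fn overhead_fraction(overhead_ms: f64, step_ms: f64) -> f64 {
    overhead_ms / step_ms
}
```

With 3000 launches at an assumed 2 µs each, the overhead is 6 ms, or about 2% of a 300 ms step. On these numbers the pure launch cost is only a floor: most of the gap to the 180 ms CPU step would have to come from the host-device synchronisation after each launch and from kernels too small to keep the GPU busy, which is the starvation described above.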

This is the GPU equivalent of "calling a function 10 million times in a hot loop". The fix is either to fuse kernels (one big launch instead of many small ones), to use CUDA graphs (record a sequence of launches once, replay them with one host-side call), or to make each launch do enough work that the per-call overhead amortises - which means bigger batches.

I knew about kernel launch overhead. I had read the docs. What I hadn't internalised is how brutal it is at small batch sizes with a model that has roughly 60 ops per layer and six layers. The CPU has effectively zero per-call overhead because the whole forward is just function calls in the same process.

This was not a failure of the project. It is the project teaching me something I now know in my bones. But the headline "real BitNet training on Ada tensor cores" is true and the headline "GPU is faster than CPU at this scale" is false, and I want to be precise about both.

What I'd do differently

The list is short and concrete.

  • Plan for kernel fusion from the start. The 9-NVRTC-kernel design was a great Phase 1, but the right architecture for Phase 4+ is one fused kernel per logical block. The tape would record a fused-block op, and the generic helpers would dispatch to the fused version on CUDA and the unfused version on CPU.
  • Use CUDA graphs once a step is stable. Once your training loop's launch sequence is fixed, recording it once and replaying gives you most of the benefit of fusion with much less code. I will do this in Phase 6.
  • Larger batches by default on GPU. The 5M-param config was sized for CPU comfort. The right GPU config has batch size 4 to 8 times larger, which makes each kernel do enough work that the launch overhead stops dominating.
  • Per-step profiling earlier. I had Nsight Systems on hand. I should have run it on day one of Phase 4, not after I noticed the wall-clock numbers were off.

The bits that worked

The CPU forward and backward all work, including the SIMD ladder, the autograd, AdamW with cosine LR and warmup, and gradient clipping. End-to-end training on TinyShakespeare takes 8 to 15 minutes on CPU. Best val_ppl so far is 4.869 at 30,000 cumulative steps. The KV-cached inference path gives roughly 50-100x speedup per token versus recomputing the full forward each step. All three export formats round-trip cleanly through export::import. The 6.02x packed format has been useful for keeping checkpoint files small.

The CUDA back-end, despite being slower than the CPU at small batch sizes, demonstrates that you can drive Ada tensor cores from Rust through cudarc and cublasGemmEx with int8 inputs and int32 accumulator without leaving Rust. That alone took several iterations to get right and the result is something I can build on.

What's next

Phase 6 is kernel fusion plus CUDA graphs. The expected outcome is GPU faster than CPU by roughly the ratio you would expect from raw FLOP throughput once launch overhead is amortised - somewhere between 5x and 20x at the v0.13 config. After that, larger configs to actually exploit the tensor-core throughput, because the whole point of BitNet b1.58 is that ternary weights and INT8 activations let you run much bigger models at the same memory footprint.

Code is at github.com/tidynest/bitnet-toy. 139 tests on the default build, architecture walkthrough in docs/ARCHITECTURE.md, training recipe in docs/TRAINING.md. If you spot a place where the autograd is doing something wrong, please open an issue. This is a learning project; the whole point is finding out where I went off the rails.