Inter-Layer Alignment Re-Rise Predicts Grokking in Transformers: A Pilot Study
Description
We identify an internal signal that detects impending grokking ∼3,000 training steps
before the accuracy transition without any knowledge of the learned algorithm. The
signal is a non-monotonic re-rise of inter-layer CKA (Ob01, between layers 0 and 1 of
a small transformer): after an initial drop during memorization, Ob01 begins rising again at
step ≈750 exclusively in models that will grok. This re-rise discriminates grokking from
non-grokking in 16/16 vs 0/5 seeds (Fisher exact p = 0.000049; one-sample t = 4.51,
p < 0.001), with onset consistently ∼3,297 ± 461 steps before the accuracy jump. The signal replicates across two primes (p = 97, p = 113) and across a structurally distinct modular
task: modular multiplication (6/6 seeds, onset step 750, lead 3,083 ± 186 steps), whose onset is identical to that of addition despite different underlying algorithms. A task-contrast experiment
with k-sparse parity reveals the boundary of the signal's scope: parity generalizes quickly
but Ob01 drops rather than rises, the opposite dynamic. This demarcates a meaningful
distinction: the re-rise is a signature of tasks requiring distributed inter-layer coordination (Fourier-type algorithms), not of grokking or generalization in general. Practically,
the re-rise onset provides an early-termination criterion within the modular arithmetic family: runs showing no re-rise by step 1,000 do not grok in any tested configuration, enabling
diagnosis thousands of steps before any test accuracy signal appears. We motivate Ob01 via
O(X, Y) = I(X; Y)/H(Y) and hierarchical predictive coding theory, while being explicit
that CKA is a linear proxy for O, not an estimator of it. Replacing CKA with MINE or
kNN MI estimators is the priority follow-up.
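
The measurement and the early-termination rule described above can be sketched as follows. This is a minimal illustration under stated assumptions, not the authors' implementation: it assumes the standard linear CKA formula on activation matrices of shape (samples, features), and the function names, the `min_rise` threshold, and the onset rule (the step of the running minimum once a rise is confirmed) are hypothetical choices made for the example. The last line checks that the reported Fisher exact p-value is consistent with a perfectly separated 16 vs 5 split.

```python
import math
import numpy as np


def linear_cka(x, y):
    """Linear CKA between two activation matrices of shape (n_samples, n_features)."""
    # Center each feature so the similarity ignores mean offsets.
    x = x - x.mean(axis=0, keepdims=True)
    y = y - y.mean(axis=0, keepdims=True)
    cross = np.linalg.norm(y.T @ x, "fro") ** 2
    norm_x = np.linalg.norm(x.T @ x, "fro")
    norm_y = np.linalg.norm(y.T @ y, "fro")
    return float(cross / (norm_x * norm_y))


def rerise_onset(steps, cka_values, min_rise=0.02):
    """Return the step of the running minimum once CKA has risen by
    min_rise above it, or None if no such rise is seen.
    min_rise is a hypothetical threshold, not a value from the paper."""
    running_min = float("inf")
    min_step = None
    for step, value in zip(steps, cka_values):
        if value < running_min:
            running_min, min_step = value, step
        elif value - running_min >= min_rise:
            return min_step
    return None


def no_rerise_by(steps, cka_values, deadline=1000, min_rise=0.02):
    """True if the run has reached the deadline step without any re-rise,
    i.e. it is a candidate for early termination."""
    observed = [(s, v) for s, v in zip(steps, cka_values) if s <= deadline]
    if not observed or max(steps) < deadline:
        return False  # not enough training yet to decide
    obs_steps, obs_values = zip(*observed)
    return rerise_onset(obs_steps, obs_values, min_rise) is None


# With all 16 grokking seeds showing the re-rise and all 5 non-grokking
# seeds not showing it, only one table is this extreme given the margins.
fisher_p = 1 / math.comb(21, 5)
```

In use, `linear_cka` would be evaluated on layer-0 and layer-1 activations for a fixed probe batch at each checkpoint, and `no_rerise_by` applied to the resulting curve. The threshold should be tuned against seed-to-seed noise before relying on the rule.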
Files
| Name | Size |
|---|---|
| main.pdf (md5:fb995e81efd2deaef935f0a19d813fd0) | 884.0 kB |