Nested Learning in Practice: Geometry of the Deep Optimizer, Multi‑Clock Transformers, and Reference Code

Pronam Chatterjee

1. Why Nested Learning?

Standard deep learning draws a hard line between architecture (the layers) and optimization (the training rule). Nested Learning (NL) observes that both are stateful learners with their own update rules; the only principled difference is how often each state is updated, i.e., its frequency. Orchestrating components that update at different frequencies yields plasticity without amnesia, the core problem in continual learning.

2. Formal setup

We use standard notation. Inputs $x_t \in \mathbb{R}^{d_{\text{in}}}$. Let a module $M$ (a layer, an optimizer’s momentum, or an attention cache) be an associative memory that learns a mapping $K \to V$.

2.1 Associative memory

Definition 1 (Associative Memory). Given keys $K \subset \mathbb{R}^{d_k}$ and values $V \subset \mathbb{R}^{d_v}$, an associative memory is an operator $M: K \to V$ learned by minimizing an internal objective

\[ M^\star \;=\; \arg\min_M\; \tilde{\mathcal{L}}\big(M(K),\, V\big). \]

This perspective applies uniformly to optimizers (they compress gradient streams), attention (key–value stores), and linear layers.
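To make Definition 1 concrete, here is a minimal sketch (not taken from the accompanying files; variable names are illustrative) of a linear associative memory fitted by least squares:

```python
# Minimal sketch: a linear associative memory M(k) = M k fitted by
# least squares, i.e. minimizing || K M^T - V ||_F^2 over M.
import numpy as np

rng = np.random.default_rng(0)
d_k, d_v, n = 8, 4, 256

K = rng.normal(size=(n, d_k))                         # keys, one per row
M_true = rng.normal(size=(d_v, d_k))                  # ground-truth mapping
V = K @ M_true.T + 0.01 * rng.normal(size=(n, d_v))   # noisy values

# Ordinary least squares solves argmin_X || K X - V ||_F^2, with X = M^T.
M_hat_T, *_ = np.linalg.lstsq(K, V, rcond=None)       # shape (d_k, d_v)
recall = K @ M_hat_T                                  # memory applied to the keys
print("recall MSE:", float(np.mean((recall - V) ** 2)))
```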

2.2 Update frequency and levels

Let one data‑point update be one unit of time. For any component $A$, define its **update frequency**
\[ f_A \;=\; \text{number of updates of } A \text{ per unit time}. \]
We write $A \succ B$ (“$A$ is faster”) if (i) $f_A > f_B$, or (ii) $f_A = f_B$ but the state of $B$ at time $t$ depends on the state of $A$ at time $t$. Sorting components by $\succ$ yields *levels*; higher level ⇒ lower frequency.
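A toy sketch of this ordering (component names and frequencies are assumed, not from the article's code), ignoring the dependency tie-break in (ii):

```python
# Toy sketch (assumed component names): group components into levels by
# update frequency; higher frequency => faster => lower level index.
freqs = {"qkv_proj": 1.0, "o_proj": 1 / 8, "ffn": 1 / 64, "outer_weights": 1 / 512}

ordered = sorted(set(freqs.values()), reverse=True)          # fastest first
levels = {name: ordered.index(f) for name, f in freqs.items()}
print(levels)   # {'qkv_proj': 0, 'o_proj': 1, 'ffn': 2, 'outer_weights': 3}
```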

2.3 Deep optimizer (L2 regression view)

For a linear map $y = W x$, the local “surprise” is $g_y = \nabla_y \mathcal{L}(W; x)$. The usual outer gradient is $g_W = g_y x^\top$. NL proposes an inner L2 objective that regresses $W x$ onto $g_y$:

\[ \min_W \; \tfrac12\,\|W x - g_y\|_2^2. \]

One gradient step of size α\alpha gives

\[ W' \;=\; W - \alpha (W x - g_y)\, x^\top \;=\; W\big(I - \alpha x x^\top\big) + \alpha\, g_y x^\top. \]

Combining this with an outer step $W \leftarrow W - \eta\, g_W$ yields

\[ \boxed{\, W_{t+1} \;=\; W_t\big(I - \alpha x_t x_t^\top\big) \;-\; \eta\, \nabla_W \mathcal{L}(W_t; x_t) \,} \tag{★} \]

(For mini‑batches, replace $x x^\top$ with the Gram matrix $\tfrac{1}{B} X^\top X$.)
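A minimal NumPy sketch of update (★) in the mini-batch form; `deep_l2gd_step` is an illustrative name (not the article's `DeepL2GD` class) and assumes a squared-error outer loss:

```python
# Minimal sketch of update (★), mini-batch version, with a squared-error outer loss.
import numpy as np

def deep_l2gd_step(W, X, Y, alpha=1e-3, eta=5e-2):
    """One step of W <- W (I - alpha * G) - eta * grad_W L, with G = (1/B) X^T X.

    W: (d_out, d_in) weights; X: (B, d_in) inputs; Y: (B, d_out) targets.
    """
    B = X.shape[0]
    G = X.T @ X / B                                   # Gram matrix of the batch
    residual = X @ W.T - Y                            # (B, d_out)
    grad_W = residual.T @ X / B                       # grad of 0.5 * mean ||W x - y||^2
    W_proj = W @ (np.eye(W.shape[1]) - alpha * G)     # inner "deep optimizer" factor
    return W_proj - eta * grad_W                      # outer gradient step

rng = np.random.default_rng(0)
X = rng.normal(size=(32, 5))
Y = X @ rng.normal(size=(5, 3))
W = np.zeros((3, 5))
for _ in range(300):
    W = deep_l2gd_step(W, X, Y)
print("fit error:", float(np.mean((X @ W.T - Y) ** 2)))
```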

2.4 Continuum Memory System (CMS)


CMS composes memories with distinct clocks:

\[ y_t \;=\; \mathrm{MLP}^{(f_k)}\big(\cdots \mathrm{MLP}^{(f_1)}(x_t)\big), \qquad \theta^{(f_\ell)} \text{ updates every } C^{(\ell)} \text{ steps.} \]

With per‑level gradient accumulation, update $\theta^{(f_\ell)}$ only when $C^{(\ell)}$ divides the global step.

Figure 1. Multi‑time‑scale update schedule for three levels. Only the levels whose clocks are due take a step; the others keep accumulating gradients.
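A minimal PyTorch sketch of the CMS composition, with per-block periods stored for the trainer to consult; `ContinuumMLPChain` is an illustrative name, not the article's `CMSSequential`:

```python
# Minimal CMS-style chain of MLP blocks with per-block clock periods.
import torch
import torch.nn as nn

class ContinuumMLPChain(nn.Module):
    """Composes y = MLP^{(f_k)}( ... MLP^{(f_1)}(x) ); block l is meant to be
    stepped only every periods[l] global steps (scheduling is the trainer's job)."""
    def __init__(self, dim, depth, periods):
        super().__init__()
        assert len(periods) == depth
        self.blocks = nn.ModuleList(
            nn.Sequential(nn.Linear(dim, dim), nn.ReLU(), nn.Linear(dim, dim))
            for _ in range(depth)
        )
        self.periods = list(periods)      # C^{(l)}: update every C^{(l)} steps

    def forward(self, x):
        for blk in self.blocks:
            x = blk(x)
        return x

cms = ContinuumMLPChain(dim=16, depth=3, periods=[1, 16, 128])
print(cms(torch.randn(4, 16)).shape)      # torch.Size([4, 16])
```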

3. Geometry of the deep‑optimizer projection


The transformation $W \mapsto W(I - \alpha x x^\top)$ is a right‑side projection in input space.

Figure 2. Outputs $Wu$ (solid) vs. $W(I - \alpha x x^\top)u$ (dashed) for unit directions $u$. The largest change is along $x$; the thick ray is $Wx$.

Figure 3. Update magnitude $\|[W(I - \alpha x x^\top) - W]\,u\|_2$ vs. $\angle(u, x)$. Maximal when $u \parallel x$, minimal when $u \perp x$.

Figure 4. Batched projection with $P = I - \alpha\,\tfrac{1}{B} X^\top X$: the unit circle maps to an ellipse. Labels show shrink factors $1 - \alpha\lambda_i$ along the Gram eigenvectors.
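A small numeric check of the shrink-factor claim (illustrative, not from the accompanying files): along each eigenvector $u_i$ of the Gram matrix, $P u_i = (1 - \alpha\lambda_i)\,u_i$.

```python
# Numeric check: the batched factor P shrinks each Gram eigenvector u_i
# by exactly 1 - alpha * lambda_i.
import numpy as np

rng = np.random.default_rng(0)
B, d_in, alpha = 64, 2, 0.3

X = rng.normal(size=(B, d_in))
G = X.T @ X / B                           # Gram matrix (1/B) X^T X
lam, U = np.linalg.eigh(G)                # eigenvalues / eigenvectors of G
P = np.eye(d_in) - alpha * G              # batched projection factor

for i, (l, u) in enumerate(zip(lam, U.T)):
    print(f"lambda_{i} = {l:.3f}   ||P u_{i}|| = {np.linalg.norm(P @ u):.3f}"
          f"   1 - alpha*lambda_{i} = {1 - alpha * l:.3f}")
```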

4. Multi-Clock Training Sequence


The figure below illustrates the interaction between model blocks, accumulators, and optimizers during a training step. Each block has its own clock; at each global step, we accumulate gradients, then apply updates only for blocks whose period divides the step counter.

Figure 5. Multi‑clock training sequence showing gradient accumulation, period checking, and update scheduling. At step t = 128, all three blocks update, since their periods (1, 16, 128) all divide 128 evenly.
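A hedged PyTorch sketch of this step logic, using the three periods from Figure 5; `step_multiclock` is an illustrative name, not the article's `NLTrainer` or `step_with_clocks`:

```python
# Sketch of the multi-clock step: accumulate every step, update only due blocks.
import torch
import torch.nn as nn

def step_multiclock(optimizers, periods, loss, global_step):
    """Accumulate gradients every step; call optimizer.step() only for blocks
    whose period divides the global step, then clear those gradients."""
    loss.backward()                          # .grad accumulates until cleared
    for opt, period in zip(optimizers, periods):
        if global_step % period == 0:        # this block's clock is due
            opt.step()
            opt.zero_grad(set_to_none=True)  # reset its accumulator

# toy three-block model with periods C = (1, 16, 128), as in Figure 5
blocks = nn.ModuleList(nn.Linear(16, 16) for _ in range(3))
periods = (1, 16, 128)
opts = [torch.optim.SGD(b.parameters(), lr=1e-2) for b in blocks]

for t in range(1, 129):
    x = torch.randn(8, 16)
    for b in blocks:
        x = b(x)
    loss = (x ** 2).mean()
    step_multiclock(opts, periods, loss, t)
# at t = 128 all three blocks step, since 1, 16, and 128 all divide 128
```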

5. Implementation (runnable code)


This section summarizes the complete, ready‑to‑run modules provided as accompanying files; each file is fully commented.


Listing 1 — Minimal reference (DeepL2GD + CMS + toy continual learning)


See `nested_learning_minimal.py` in the accompanying files for the fully-commented version. The code demonstrates:

  • NLLinear: Linear layers with cached inputs for Gram-matrix computation
  • DeepL2GD: Applies projection before base optimizer update
  • CMSSequential: Multi-timescale learner with per-block periods
  • NLTrainer: Trains on sequential tasks and measures retention/plasticity

Key experiment: train on Task A, then Task B. Report accuracy on both tasks to measure how well the multi-clock system balances new learning with memory of past tasks.
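A schematic sketch of that Task A → Task B measurement with a plain model and synthetic data (the model, data, and helper names below are placeholders, not the article's `NLTrainer` API):

```python
# Schematic retention/plasticity measurement on two synthetic tasks.
import torch
import torch.nn as nn

def accuracy(model, X, y):
    with torch.no_grad():
        return (model(X).argmax(dim=1) == y).float().mean().item()

def train_one_task(model, X, y, steps=300, lr=5e-2):
    opt = torch.optim.SGD(model.parameters(), lr=lr)
    for _ in range(steps):
        opt.zero_grad()
        nn.functional.cross_entropy(model(X), y).backward()
        opt.step()

torch.manual_seed(0)
model = nn.Sequential(nn.Linear(10, 32), nn.ReLU(), nn.Linear(32, 2))
XA = torch.randn(512, 10); yA = (XA[:, 0] > 0).long()   # Task A labeling rule
XB = torch.randn(512, 10); yB = (XB[:, 1] > 0).long()   # Task B labeling rule

train_one_task(model, XA, yA)
acc_A_before = accuracy(model, XA, yA)                   # plasticity on A
train_one_task(model, XB, yB)                            # sequential shift to B
print(f"A before B: {acc_A_before:.2f} | "
      f"A after B (retention): {accuracy(model, XA, yA):.2f} | "
      f"B (plasticity): {accuracy(model, XB, yB):.2f}")
```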

Listing 2 — EMA‑smoothed deep optimizer


See `deep_l2gd_ema.py` for the fully-commented version. **Key differences from DeepL2GD**:

  • Maintains an exponential moving average of Gram matrices
  • Update rule: $G_t = (1-\beta)\, G_{\text{batch}} + \beta\, G_{t-1}$
  • Smooths out noise from small or biased minibatches
  • Useful for noisy continual-learning scenarios where gradient directions fluctuate
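A minimal sketch of the EMA Gram tracker under those assumptions; `GramEMA` is an illustrative name, not the class in `deep_l2gd_ema.py`:

```python
# Exponential moving average of per-batch Gram matrices.
import torch

class GramEMA:
    """Tracks G_t = (1 - beta) * G_batch + beta * G_{t-1}."""
    def __init__(self, dim, beta=0.9):
        self.beta = beta
        self.G = torch.zeros(dim, dim)            # G_0 = 0 (biased low at first)

    def update(self, X):                          # X: (B, d_in) cached layer inputs
        G_batch = X.T @ X / X.shape[0]
        self.G = (1 - self.beta) * G_batch + self.beta * self.G
        return self.G

ema = GramEMA(dim=5, beta=0.9)
for _ in range(50):
    G = ema.update(torch.randn(32, 5))
print(G.diagonal())   # smoothed per-feature second moments, close to 1
```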

Listing 3 — Convolutional variant (channel‑covariance approximation)

See `nl_conv.py` for the fully-commented version. Applies Nested Learning to Conv2d:

  • Channel-level covariance instead of the full spatial Gram matrix (for efficiency)
  • Input reshape $(B, C, H, W) \to (B \cdot H \cdot W, C)$ to compute channel statistics
  • Weight averaging across the kernel's spatial dimensions reduces dimensionality
  • Scaling preserves the average kernel magnitude during projection
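An illustrative sketch of the channel-covariance idea; for simplicity it applies the factor $(I - \alpha G)$ at every spatial position of the kernel rather than to the spatially averaged kernel described above, and it is not the `nl_conv.py` code:

```python
# Channel-covariance projection applied to a Conv2d weight (simplified sketch).
import torch
import torch.nn as nn

alpha = 1e-3
conv = nn.Conv2d(in_channels=8, out_channels=16, kernel_size=3, padding=1)
x = torch.randn(4, 8, 32, 32)                      # cached input, (B, C, H, W)

# Channel statistics: reshape (B, C, H, W) -> (B*H*W, C), then form a C x C Gram.
x_flat = x.permute(0, 2, 3, 1).reshape(-1, 8)      # rows are per-pixel channel vectors
G = x_flat.T @ x_flat / x_flat.shape[0]            # (C_in, C_in) channel covariance

with torch.no_grad():
    W = conv.weight                                # (C_out, C_in, kH, kW)
    P = torch.eye(8) - alpha * G                   # right-side factor on input channels
    conv.weight.copy_(torch.einsum("oikl,ij->ojkl", W, P))

print(conv.weight.shape)                           # unchanged: (16, 8, 3, 3)
```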

Listing 4 — Tiny Transformer wired with clocks

See `nl_transformer_tiny.py` for the fully-commented version. Demonstrates multi-clock Transformer:

  • MHA: Q/K/V and O projections with NLLinear input caching
  • Block: Attention + FFN with residuals
  • wire_clocks(): Assigns C=1 (fast) to attention heads, C=8 to O projection, C=64 (slow) to FFN
  • step_with_clocks(): Applies updates based on period divisibility
  • make_synth_bigrams(): Two bigram regimes for continual learning. The demo trains on Task A, then transitions to Task B under a distribution shift.

Figure 6. Multi‑clock schedule: Q/K/V update every step (C=1), the output projection more slowly (C=8), and the FFN most slowly (C=64).
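An illustrative sketch of the clock wiring in Figure 6 (assumed module names, not `nl_transformer_tiny.py`):

```python
# Assign per-module clock periods: C=1 for Q/K/V, C=8 for the output
# projection, C=64 for the FFN.
import torch.nn as nn

class ToyBlock(nn.Module):
    def __init__(self, d=32):
        super().__init__()
        self.q_proj = nn.Linear(d, d)
        self.k_proj = nn.Linear(d, d)
        self.v_proj = nn.Linear(d, d)
        self.o_proj = nn.Linear(d, d)
        self.ffn = nn.Sequential(nn.Linear(d, 4 * d), nn.GELU(), nn.Linear(4 * d, d))

def wire_clock_periods(block):
    """Return {module_name: period} for the named submodules of one block."""
    periods = {}
    for name, _ in block.named_modules():
        if any(tag in name for tag in ("q_proj", "k_proj", "v_proj")):
            periods[name] = 1       # fast, sequence-local path
        elif "o_proj" in name:
            periods[name] = 8       # slower output projection
        elif "ffn" in name:
            periods[name] = 64      # slow, knowledge-heavy path
    return periods

print(wire_clock_periods(ToyBlock()))
```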

6. Practical recipes

  • Clock assignment: fast for sequence‑local paths (Q/K/V, head), slower for knowledge‑heavy paths (O, FFN). Geometric periods $\{1,8,64,512\}$ are a robust default.
  • Deep‑optimizer strength: $\alpha\in [10^{-4}, 10^{-3}]$ to start. Larger values risk underfitting dominant directions.
  • EMA vs per‑batch: use EMA when minibatches are noisy or small; per‑batch when you want rapid adaptation.

7. Reproducibility

  • Minimal run: `python nested_learning_minimal.py` (reports Task‑A/B accuracies before/after sequential training).
  • Transformer demo: `python nl_transformer_tiny.py` (synthetic bigram regimes to simulate distribution shift).
  • Conv projection: import `NLConv2d` and call `conv_channel_projection_step` at due steps.