WW-PGD: WeightWatcher Projected Gradient Descent

Status: Experimental  |  GitHub  |  WeightWatcher

Abstract

WW-PGD is a lightweight spectral projection add-on for PyTorch optimizers (SGD, Adam, AdamW, Muon, etc.). It wraps your existing optimizer and periodically projects each layer’s weight spectrum toward a critical heavy-tailed manifold motivated by HTSR/SETOL: the tail exponent α is driven toward α ≈ 2, and the SETOL ERG condition, namely that the trace of log(λ) over the tail equals 0 (equivalently detX = 1), is enforced on the tail subspace. In practice, WW-PGD uses WeightWatcher diagnostics (detX_num and num_pl_spikes) to select the tail via a midpoint rule at each projection step (epoch or batch boundary).


1 · What WW-PGD Does

1.1 “Add-on” optimizer design

WW-PGD does not replace your optimizer. Instead, it wraps the existing optimizer (SGD, Adam, AdamW, Muon, etc.), lets it take its normal gradient steps, and periodically applies a spectral projection to each layer’s weights at epoch or batch boundaries.

1.2 Tail selection via WeightWatcher midpoint rule

At each projection step, WeightWatcher provides two per-layer tail-size diagnostics:
  • detX_num: the number of tail eigenvalues selected by the SETOL detX = 1 (ERG) condition.
  • num_pl_spikes: the number of eigenvalues assigned to the power-law tail by the HTSR fit.

WW-PGD selects the working tail size using a midpoint rule: k_mid = floor((detX_num + num_pl_spikes)/2). As the model approaches the critical regime (α → 2), SETOL predicts these two quantities converge, so the midpoint becomes effectively exact.
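For reference, the rule is simple enough to state in a few lines. The sketch below is ours, not part of the ww_pgd API; it assumes the diagnostics come from a WeightWatcher analyze() call (the detX flag shown is an assumption about how to enable the detX columns; WW-PGD computes these diagnostics internally).

import math
import weightwatcher as ww

def midpoint_tail_size(detX_num: int, num_pl_spikes: int, min_tail: int = 1) -> int:
    # k_mid = floor((detX_num + num_pl_spikes) / 2), with a floor on the tail size
    return max(min_tail, math.floor((detX_num + num_pl_spikes) / 2))

watcher = ww.WeightWatcher(model=model)   # `model` is your PyTorch model
details = watcher.analyze(detX=True)      # per-layer diagnostics as a DataFrame

for layer_id, row in details.iterrows():
    print(layer_id, midpoint_tail_size(int(row["detX_num"]), int(row["num_pl_spikes"])))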

1.3 Projection target: HTSR + SETOL critical conditions

The projection targets the two critical conditions described in the abstract: the HTSR tail exponent is driven toward α ≈ 2, and the SETOL ERG condition is enforced on the tail subspace.
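Written out, with λ_1, …, λ_k the tail eigenvalues of a layer’s spectrum (our notation, following the terminology above), the ERG condition is

    Σ_{i=1..k} log λ_i = 0   ⇔   Π_{i=1..k} λ_i = detX = 1,

i.e. the product of the tail eigenvalues equals one, which is exactly what the retraction step in 1.4 enforces.
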
1.4 How the projection is applied (high level)

For each layer weight matrix, WW-PGD:

  1. Computes an SVD / eigen-spectrum (only the tail is modified).
  2. Constructs a rank-ordered power-law template for the tail (target α schedule, never targeting α < 2).
  3. Applies a stable Cayley-style update in log-eigenvalue space.
  4. Retracts to satisfy the ERG trace-log condition on the tail.
  5. Reconstructs the weight matrix and blends it back into the model.
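
To make the five steps concrete, here is a minimal, self-contained sketch for a single 2-D weight matrix, assuming the tail size k has already been chosen by the midpoint rule. The function name, its arguments, and the plain damped update in log-eigenvalue space (standing in for the library’s Cayley-style step) are illustrative, not ww_pgd internals.

import torch

def project_tail(W: torch.Tensor, k: int, alpha_target: float = 2.0,
                 cayley_eta: float = 0.25, blend_eta: float = 0.5) -> torch.Tensor:
    """Illustrative tail projection for one 2-D weight matrix."""
    # 1. Spectrum: only the k largest eigenvalues (the tail) are modified.
    U, S, Vh = torch.linalg.svd(W, full_matrices=False)
    lam = S ** 2                               # eigenvalues corresponding to the singular values
    tail = lam[:k]

    # 2. Rank-ordered power-law template: lambda_(i) ~ i^(-1/(alpha-1)),
    #    scaled to the current largest eigenvalue; alpha_target is never < 2.
    ranks = torch.arange(1, k + 1, dtype=lam.dtype, device=lam.device)
    template = ranks ** (-1.0 / (alpha_target - 1.0)) * tail[0]

    # 3. Damped update in log-eigenvalue space (stand-in for the Cayley-style step).
    log_tail = (1 - cayley_eta) * torch.log(tail) + cayley_eta * torch.log(template)

    # 4. Retraction: shift so that sum(log lambda) = 0 on the tail, i.e. detX = 1.
    log_tail = log_tail - log_tail.mean()

    # 5. Reconstruct the weights and blend them back into the model.
    S_new = S.clone()
    S_new[:k] = torch.exp(0.5 * log_tail)      # singular values = sqrt(eigenvalues)
    W_proj = U @ torch.diag(S_new) @ Vh
    return (1 - blend_eta) * W + blend_eta * W_proj

In the library itself, this whole step is triggered through WWPGDWrapper.apply_tail_projection (see 3.2).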

2 · Results & Figures

2.1 MNIST: Plain vs Augmented Test

The following plots show mean ± std across runs for: plain test accuracy, augmented test accuracy, and layer-wise α trajectories from WeightWatcher. The augmented evaluation uses mild, in-distribution perturbations (small rotation/translation + light blur), intended as a robustness proxy.
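For concreteness, an augmented-test transform of this kind can be written with torchvision; the parameter values below are placeholders, not the exact settings used for these plots.

from torchvision import transforms

# Mild, in-distribution perturbations used as a robustness proxy:
# small rotation/translation plus light blur (values are placeholders).
augmented_eval_tf = transforms.Compose([
    transforms.RandomAffine(degrees=10, translate=(0.05, 0.05)),
    transforms.GaussianBlur(kernel_size=3, sigma=(0.1, 1.0)),
    transforms.ToTensor(),
])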

Figure 1. MNIST plain test accuracy (mean ± std): Baseline vs WW-PGD.
Figure 2. MNIST augmented test accuracy (mean ± std): Baseline vs WW-PGD.
Figure 3. Layer-wise HTSR exponent α (mean ± std): WW-PGD tends to stabilize α trajectories toward the critical regime.

2.2 FashionMNIST summary (QuickStart notebook)

The FashionMNIST experiments are documented in the QuickStart notebook: WW_PGD_QuickStart.ipynb.

Final results (epoch 35, mean ± std)
  • Baseline: plain = 98.05% ± 0.13, augmented = 96.24% ± 0.17
  • WW-PGD: plain = 97.99% ± 0.17, augmented = 96.23% ± 0.20

Interpretation (early read): WW-PGD matches the baseline within one standard deviation on both plain and augmented accuracy, so these runs show no measurable accuracy cost from the projection, but also no clear robustness gain yet.


3 · How to Use WW-PGD

3.1 Install

pip install git+https://github.com/CalculatedContent/WW_PGD.git

3.2 Minimal usage (wrap any optimizer)

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset

import ww_pgd

# Toy model and data so the example runs end-to-end (replace with your own).
model = nn.Linear(10, 10)
loader = DataLoader(
    TensorDataset(torch.randn(256, 10), torch.randint(0, 10, (256,))),
    batch_size=32,
    shuffle=True,
)
num_epochs = 10

# Any standard optimizer works as the base optimizer.
base_opt = torch.optim.AdamW(model.parameters(), lr=1e-3)

# Tail-projection configuration (see 3.3 Practical knobs).
cfg = ww_pgd.WWTailConfig(
    warmup_epochs=0,
    ramp_epochs=5,
    min_tail=5,
    blend_eta=0.5,
    cayley_eta=0.25,
)

# Wrap the base optimizer with WW-PGD.
opt = ww_pgd.WWPGDWrapper(model, base_opt, cfg)

for epoch in range(num_epochs):
    for xb, yb in loader:
        loss = F.cross_entropy(model(xb), yb)
        opt.zero_grad(set_to_none=True)
        loss.backward()
        opt.step()

    # Project the tail spectrum once per epoch (epoch-boundary projection).
    opt.apply_tail_projection(epoch=epoch, num_epochs=num_epochs)

3.3 Practical knobs

  • warmup_epochs / ramp_epochs: delay the first projection and ramp its strength over the early epochs.
  • min_tail: a floor on the tail size chosen by the midpoint rule.
  • blend_eta: how strongly the projected weights are blended back into the model.
  • cayley_eta: step size of the Cayley-style update in log-eigenvalue space.
  • Projection frequency: epoch-boundary projections (as in 3.2) keep overhead low; batch-boundary projections are also possible but more expensive.

4 · Performance Notes & Feedback

WW-PGD can be slower than plain training because it performs spectral analysis and reconstruction. This overhead is the main reason we use warmup, ramping, and (often) epoch-boundary projections rather than per-step projections.

We are actively working on performance improvements (faster decompositions, better batching of diagnostics, and more selective projection policies). Feedback is valuable—especially cases where: