Status: Experimental | GitHub | WeightWatcher
WW-PGD is a lightweight spectral projection add-on for PyTorch optimizers (SGD, Adam, AdamW, Muon, etc.).
It wraps your existing optimizer and periodically projects each layer’s weight spectrum toward a
critical heavy-tailed manifold motivated by HTSR/SETOL: the tail exponent
α is driven toward α ≈ 2, and the SETOL ERG condition, trace-log(λ) = 0 over the tail (equivalently, detX = 1), is enforced on the tail subspace.
In practice, WW-PGD uses WeightWatcher diagnostics (detX_num and num_pl_spikes) to
select the tail via a midpoint rule at each projection step (epoch or batch boundary).
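For intuition, the ERG condition says the log-eigenvalues of the tail sum to zero, which is the same as the tail eigenvalues having unit product. A quick numeric check with made-up eigenvalues (not from any real layer):

import numpy as np

# Hypothetical tail eigenvalues of a layer correlation matrix X = W^T W / N.
lam_tail = np.array([4.0, 1.0, 0.25])

print(np.log(lam_tail).sum())  # trace-log over the tail: 0.0 (ERG condition)
print(np.prod(lam_tail))       # detX over the tail: 1.0 (equivalent form)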
WW-PGD does not replace your optimizer. Instead:
At each projection step, WeightWatcher provides (per layer):
detX_num: a trace-log-based estimate of the effective tail size (the ERG / detX tail).
num_pl_spikes: the number of power-law “spike” outliers in the tail.
WW-PGD selects the working tail size using a midpoint rule:
k_mid = floor((detX_num + num_pl_spikes)/2).
As the model approaches the critical regime (α → 2), SETOL predicts these two quantities converge,
so the midpoint becomes effectively exact.
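As a minimal sketch of this selection, assuming a WeightWatcher version whose analyze() details include per-layer detX_num and num_pl_spikes columns (the exact flags required to populate them may differ; see the WeightWatcher docs):

import math
import weightwatcher as ww

watcher = ww.WeightWatcher(model=model)
details = watcher.analyze()  # per-layer DataFrame of HTSR diagnostics

for _, row in details.iterrows():
    # midpoint rule between the two tail-size estimates
    k_mid = math.floor((row["detX_num"] + row["num_pl_spikes"]) / 2)
    print(row["layer_id"], k_mid)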
For each layer weight matrix, WW-PGD then projects the tail subspace toward this critical target (α ≈ 2, detX = 1 on the tail) and blends the projected weights back into the layer; a simplified sketch follows.
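This is a conceptual sketch only, with an assumed helper name: WW-PGD's actual projection also targets α ≈ 2 and uses a Cayley-style update (cf. cayley_eta), and eigenvalue normalization by N is omitted here. The sketch only enforces detX = 1 on a tail of size k and blends:

import torch

@torch.no_grad()
def project_tail(W: torch.Tensor, k: int, blend_eta: float) -> torch.Tensor:
    # Hypothetical helper: enforce detX = 1 on the top-k (tail) eigenvalues
    # of W^T W by uniformly rescaling the top-k singular values, then blend.
    U, S, Vh = torch.linalg.svd(W, full_matrices=False)
    lam = S[:k] ** 2                        # tail eigenvalues of W^T W
    scale = lam.prod() ** (-1.0 / (2 * k))  # rescale so prod(lam) -> 1
    S_proj = S.clone()
    S_proj[:k] = S[:k] * scale
    W_proj = U @ torch.diag(S_proj) @ Vh
    return (1 - blend_eta) * W + blend_eta * W_proj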
The following plots show mean ± std across runs for: plain test accuracy, augmented test accuracy, and layer-wise α trajectories from WeightWatcher. The augmented evaluation uses mild, in-distribution perturbations (small rotation/translation + light blur), intended as a robustness proxy.
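The exact perturbation parameters live in the notebook; an illustrative torchvision pipeline in the same spirit (the values here are assumptions, not the ones used):

import torchvision.transforms as T

aug_eval = T.Compose([
    T.RandomAffine(degrees=10, translate=(0.05, 0.05)),  # small rotation/translation
    T.GaussianBlur(kernel_size=3, sigma=(0.1, 0.5)),     # light blur
    T.ToTensor(),
])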
The FashionMNIST experiments are documented in the QuickStart notebook: WW_PGD_QuickStart.ipynb.
Interpretation (early read):
pip install git+https://github.com/CalculatedContent/WW_PGD.git
import torch
import torch.nn as nn
import torch.nn.functional as F
import ww_pgd
model = nn.Linear(10, 10)  # toy model for illustration
base_opt = torch.optim.AdamW(model.parameters(), lr=1e-3)

cfg = ww_pgd.WWTailConfig(
    warmup_epochs=0,   # epochs to train before any projection
    ramp_epochs=5,     # epochs over which projection strength ramps up
    min_tail=5,        # smallest tail size that will be projected
    blend_eta=0.5,     # blend between original and projected weights
    cayley_eta=0.25,   # step size of the Cayley-style orthogonal update
)
opt = ww_pgd.WWPGDWrapper(model, base_opt, cfg)
num_epochs = 10  # example value; use your own schedule
for epoch in range(num_epochs):
    for xb, yb in loader:  # loader: your DataLoader yielding (inputs, targets)
        loss = F.cross_entropy(model(xb), yb)
        opt.zero_grad(set_to_none=True)
        loss.backward()
        opt.step()
    # project onto the critical tail manifold at the epoch boundary
    opt.apply_tail_projection(epoch=epoch, num_epochs=num_epochs)
WW-PGD can be slower than plain training because it performs spectral analysis and reconstruction. This overhead is the main reason we use warmup, ramping, and (often) epoch-boundary projections rather than per-step projections; the sketch below illustrates the kind of schedule we mean.
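A hypothetical warmup/ramp gate (not the library's API), returning a projection strength in [0, 1]:

def projection_strength(epoch: int, warmup_epochs: int, ramp_epochs: int) -> float:
    # 0 during warmup, then a linear ramp up to full strength.
    if epoch < warmup_epochs:
        return 0.0
    if ramp_epochs <= 0:
        return 1.0
    return min(1.0, (epoch - warmup_epochs + 1) / ramp_epochs)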
We are actively working on performance improvements (faster decompositions, better batching of diagnostics, and more selective projection policies). Feedback is valuable, especially for cases where: