Training Guide
ZeroProofML training separates smooth optimization from strict deployment behavior. The model may train on projective tuples and regularized gradients, but validation and deployment should decode with the same SCM masks that production will use.
Training Workflow
- Build a model that emits either SCM-aware outputs or projective `(P, Q)` tuples.
- Lift targets into projective form when using rational heads.
- Combine fit, margin, sign, and rejection terms.
- Pick a gradient policy for singular paths.
- Monitor coverage, denominator statistics, and validation masks.
- Freeze `tau_infer` from held-out data before export.
Projective Head Pattern
A common PyTorch model emits numerator and denominator tensors:
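```python
import torch
import torch.nn as nn


class RationalHead(nn.Module):
    """Small MLP backbone with a two-output head for projective (P, Q) tuples."""

    def __init__(self, input_dim: int):
        super().__init__()
        self.backbone = nn.Sequential(
            nn.Linear(input_dim, 64),
            nn.ReLU(),
            nn.Linear(64, 64),
            nn.ReLU(),
        )
        # One output channel for the numerator P, one for the denominator Q.
        self.head = nn.Linear(64, 2)

    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        # Split the last dimension into P and Q; no division happens here.
        p, q = self.head(self.backbone(x)).unbind(dim=-1)
        return p, q
```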
Training losses work on (P, Q) without directly dividing by Q.
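The split between training on `(P, Q)` and deploying with strict masks can be made concrete with a small decoding step. The sketch below is a minimal plain-PyTorch illustration, not the ZeroProofML decoder: the function name `decode_projective`, the rule that `|Q| < tau` marks an output as bottom, and the use of NaN as a placeholder value at bottom positions are all assumptions.

```python
import torch


def decode_projective(p: torch.Tensor, q: torch.Tensor, tau: float):
    """Hypothetical strict decoder returning (values, bottom_mask).

    Entries with |q| < tau are flagged as bottom and their values are set to
    NaN, so callers must consult the mask rather than trust the value.
    """
    bottom_mask = q.abs() < tau
    # Swap small denominators for 1 before dividing so the finite path never
    # produces inf/NaN; bottom entries are overwritten right after.
    safe_q = torch.where(bottom_mask, torch.ones_like(q), q)
    values = torch.where(bottom_mask, torch.full_like(p, float("nan")), p / safe_q)
    return values, bottom_mask


p = torch.tensor([2.0, 1.0, 3.0])
q = torch.tensor([4.0, 1e-6, -1.0])
values, bottom = decode_projective(p, q, tau=1e-4)
# values holds 0.5, nan, -3.0; bottom holds False, True, False
```

Returning the mask alongside the values keeps callers from treating the NaN placeholders as predictions; deployment code would branch on the mask.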
Target Lifting
Finite targets are lifted into homogeneous coordinates:
y -> (Y_n, Y_d) = (y, 1)
Bottom or censored targets can be represented with denominator zero according to the task. The public target helpers live under zeroproofml.training.
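The document does not name the helpers under `zeroproofml.training`, so the following sketch uses a hypothetical `lift_targets` function written in plain PyTorch. It encodes finite targets as `(y, 1)` and bottom targets as `(1, 0)`; setting the numerator to 1 for bottom targets is an assumed convention that avoids the degenerate pair `(0, 0)`, and a given task may choose differently.

```python
import torch


def lift_targets(y: torch.Tensor, is_bottom: torch.Tensor):
    """Hypothetical helper: lift targets into homogeneous (Y_n, Y_d) pairs.

    Finite targets become (y, 1); bottom/censored targets become (1, 0),
    i.e. a zero denominator marks them as lying at infinity.
    """
    y_n = torch.where(is_bottom, torch.ones_like(y), y)
    y_d = torch.where(is_bottom, torch.zeros_like(y), torch.ones_like(y))
    return y_n, y_d


# NaN is used here as an assumed raw marker for censored targets.
y = torch.tensor([0.5, float("nan"), -2.0])
is_bottom = torch.isnan(y)
y_n, y_d = lift_targets(y, is_bottom)
# y_n holds 0.5, 1.0, -2.0; y_d holds 1.0, 0.0, 1.0
```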
Loss Stack
ZeroProofML v0.5.1 combines four ideas:
| Term | Purpose |
|---|---|
| Implicit loss | Fits P/Q to the target without direct division |
| Margin loss | Pushes denominators away from the training threshold |
| Sign consistency | Aligns projective orientation so poles have stable direction |
| Rejection loss | Penalizes coverage below a target rate |
The implicit fit term uses a cross-product shape:
E = (P * Y_d - Q * Y_n)^2
Because the term never divides by Q, the loss and its gradients stay finite even as Q approaches zero.
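A small plain-PyTorch comparison shows the difference near a pole. Both function names below are hypothetical and are not the internals of `SCMTrainingLoss`; the point is only that the cross-product form keeps bounded gradients where a direct `(P / Q - y)^2` loss does not.

```python
import torch


def implicit_fit_loss(p, q, y_n, y_d):
    """Cross-product fit E = (P * Y_d - Q * Y_n)^2, averaged over the batch."""
    return ((p * y_d - q * y_n) ** 2).mean()


def direct_fit_loss(p, q, y):
    """Naive alternative for comparison: divides by Q explicitly."""
    return ((p / q - y) ** 2).mean()


# Finite target y = 3 lifted to (3, 1); the model sits near a pole (tiny q).
y = torch.tensor([3.0])
y_n, y_d = y, torch.ones_like(y)

p1 = torch.tensor([1.0], requires_grad=True)
q1 = torch.tensor([1e-8], requires_grad=True)
implicit_fit_loss(p1, q1, y_n, y_d).backward()

p2 = torch.tensor([1.0], requires_grad=True)
q2 = torch.tensor([1e-8], requires_grad=True)
direct_fit_loss(p2, q2, y).backward()

print(q1.grad.item())  # roughly -6: bounded
print(q2.grad.item())  # roughly -2e24: dominated by the 1 / q**2 factor
```

The cross-product form alone is also minimized by the degenerate prediction `P = Q = 0`; a term that pushes `|Q|` away from zero, such as the margin loss, can keep training away from that collapse.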
Example wiring:
```python
from zeroproofml.losses import LossConfig, SCMTrainingLoss

loss_fn = SCMTrainingLoss(
    LossConfig(
        tau_train=1e-4,
        target_coverage=0.95,
        lambda_margin=0.1,
        lambda_sign=1.0,
        lambda_rej=0.01,
    )
)
```
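The loss table does not give formulas for the margin, sign, or rejection terms. As an assumption for illustration only, a margin term could be a squared hinge that is zero once `|Q|` clears `tau_train` and grows as `|Q|` falls below it. The `margin_loss` function below is hypothetical and may not match `SCMTrainingLoss`.

```python
import torch


def margin_loss(q: torch.Tensor, tau_train: float) -> torch.Tensor:
    """Hypothetical squared-hinge margin: penalize only |Q| below tau_train."""
    shortfall = torch.relu(tau_train - q.abs())
    return (shortfall ** 2).mean()


q_safe = torch.tensor([0.5, -2.0, 1.0])
q_near_pole = torch.tensor([0.5, 5e-5, -2e-5], requires_grad=True)

print(margin_loss(q_safe, tau_train=1e-4))  # zero: every |Q| clears the threshold

loss = margin_loss(q_near_pole, tau_train=1e-4)  # positive: two denominators are too small
loss.backward()
print(q_near_pole.grad)  # 0 for the safe entry; signs point away from zero for the others
```

The hinge's gradient points away from zero, so gradient descent increases `|Q|` on both sides of the pole while leaving already-safe denominators untouched.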
Gradient Policies
Gradient policies control how backward passes treat bottom paths.
| Policy | Typical use |
|---|---|
| `CLAMP` | Default for SCM-only graphs; zero bottom gradients and clamp finite gradients |
| `PROJECT` | Projective rational heads near poles |
| `REJECT` | Learn through coverage/rejection signals rather than local singular gradients |
| `PASSTHROUGH` | Debugging only |
```python
from zeroproofml.autodiff.policies import GradientPolicy, gradient_policy

# `loss` is the scalar training loss computed for the current batch.
with gradient_policy(GradientPolicy.PROJECT):
    loss.backward()
```
Prefer explicit policy choices in rational heads. PASSTHROUGH is useful for diagnostics, but it is usually the wrong production default.
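To show what the `CLAMP` row means mechanically, here is a minimal sketch using PyTorch's `Tensor.register_hook`. It is not how `zeroproofml.autodiff.policies` implements the policy: the helper name `clamp_policy_hook`, the bottom rule `|q| < 1e-4`, and the clamp bound of 1.0 are assumptions.

```python
import torch


def clamp_policy_hook(bottom_mask: torch.Tensor, max_abs: float = 1.0):
    """Hypothetical CLAMP-style hook: zero gradients on bottom entries and
    clamp the remaining (finite-path) gradients to [-max_abs, max_abs]."""

    def hook(grad: torch.Tensor) -> torch.Tensor:
        grad = torch.where(bottom_mask, torch.zeros_like(grad), grad)
        return grad.clamp(-max_abs, max_abs)

    return hook


q = torch.tensor([1e-6, 0.5, 2.0], requires_grad=True)
bottom_mask = q.detach().abs() < 1e-4  # assumed bottom rule
q.register_hook(clamp_policy_hook(bottom_mask, max_abs=1.0))

loss = (1.0 / q).sum()  # raw gradient is -1 / q**2 per entry
loss.backward()
print(q.grad)  # bottom entry zeroed, -4.0 clamped to -1.0, -0.25 unchanged
```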
Trainer Configuration
SCMTrainer provides the reference training loop:
```python
from zeroproofml.training import SCMTrainer, TrainingConfig

config = TrainingConfig(
    max_epochs=50,
    mixed_precision=True,
    tau_train_min=5e-5,
    tau_train_max=2e-4,
    coverage_threshold=0.90,
    coverage_patience=5,
)

# model, optimizer, and train_loader are your own nn.Module, torch.optim
# optimizer, and DataLoader; loss_fn is the SCMTrainingLoss built earlier.
trainer = SCMTrainer(
    model=model,
    optimizer=optimizer,
    loss_fn=loss_fn,
    train_loader=train_loader,
    config=config,
)
trainer.fit()
```
Important controls:
| Setting | Meaning |
|---|---|
| `tau_train_min`, `tau_train_max` | Training-time denominator threshold range |
| `mixed_precision` / `use_amp` | Automatic mixed precision (AMP) support where appropriate |
| `coverage_threshold`, `coverage_patience` | Coverage-based early stop |
| `val_loader` | Per-epoch validation with `val_`-prefixed metrics |
| `gradient_policy` | Global policy override for training steps |
| `loss_curriculum` | Optional per-epoch loss-weight schedule |
| `log_hook` | Callback for JSONL metrics |
Coverage Control
Coverage is the finite-output rate, i.e. the fraction of outputs that decode to a finite (non-bottom) value:
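```python
from zeroproofml.losses import coverage, rejection_loss

# decoded: decoded model outputs; bottom_mask: boolean tensor marking bottom entries.
cov = coverage(outputs=decoded, is_bottom=bottom_mask)
rej = rejection_loss(bottom_mask, target_coverage=0.95)
```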
Use rejection loss when a model solves difficult cases by returning bottom too often. Keep it balanced: high coverage is not useful if the accepted finite predictions are unsafe or numerically unstable.
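The coverage arithmetic is simple enough to check by hand. The sketch below computes coverage from a boolean bottom mask and a hinge-style shortfall against a target rate; both helpers are hypothetical stand-ins for `coverage` and `rejection_loss`, whose exact definitions are not given here. A boolean mask carries no gradient, so the hinge only illustrates the value being penalized, not how a trainable rejection loss propagates gradients.

```python
import torch


def coverage_from_mask(is_bottom: torch.Tensor) -> float:
    # Coverage = fraction of outputs that decoded to a finite (non-bottom) value.
    return 1.0 - is_bottom.float().mean().item()


def coverage_shortfall(is_bottom: torch.Tensor, target_coverage: float) -> float:
    # Hypothetical hinge: zero at or above the target, linear below it.
    return max(0.0, target_coverage - coverage_from_mask(is_bottom))


is_bottom = torch.tensor([False, True, False, False, True])  # 2 of 5 outputs are bottom
print(coverage_from_mask(is_bottom))        # about 0.6
print(coverage_shortfall(is_bottom, 0.95))  # about 0.35
print(coverage_shortfall(is_bottom, 0.50))  # 0.0: target already met
```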
Logging
The stable logging path is JSONL:
```python
from zeroproofml.training import TrainingConfig
from zeroproofml.utils.logging import JsonlLogger

config = TrainingConfig(
    log_hook=JsonlLogger("runs/scm_train_metrics.jsonl"),
)
```
Trainer records use the `zeroproofml.metric_log` schema and include fields such as:

- `loss`
- `coverage`
- `bottom_frac`
- `tau_train`
- `denom_abs_min`
- `denom_abs_mean`
- validation metrics with `val_` prefixes
TensorBoard and DataFrame helpers are available in optional extras, but JSONL is the release-facing artifact format.
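JSONL can be written and inspected with nothing but the standard library. The sketch below appends and reads back records one JSON object per line. `append_jsonl` is a hypothetical helper, not `JsonlLogger`; it does not implement the `zeroproofml.metric_log` schema, the `epoch` and `val_coverage` keys are assumed names, and the metric values are placeholders rather than results from a real run.

```python
import json
import tempfile
from pathlib import Path


def append_jsonl(path: Path, record: dict) -> None:
    # Append one JSON object per line so earlier epochs are never rewritten.
    with path.open("a", encoding="utf-8") as f:
        f.write(json.dumps(record) + "\n")


log_path = Path(tempfile.mkdtemp()) / "scm_train_metrics.jsonl"

# Placeholder values for illustration only.
for epoch in range(2):
    append_jsonl(log_path, {
        "epoch": epoch,
        "loss": 1.0 / (epoch + 1),
        "coverage": 0.90 + 0.02 * epoch,
        "bottom_frac": 0.10 - 0.02 * epoch,
        "tau_train": 1e-4,
        "denom_abs_min": 3e-4,
        "denom_abs_mean": 0.8,
        "val_coverage": 0.89 + 0.02 * epoch,
    })

# Read the log back: one dict per line, ready for plotting or a DataFrame.
records = [json.loads(line) for line in log_path.read_text(encoding="utf-8").splitlines()]
print([r["coverage"] for r in records])
```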
Dtype Guidance
Use float64 when training or evaluation spends substantial time near singular surfaces, where denominators approach zero:
```python
import torch

torch.set_default_dtype(torch.float64)
```
Mixed precision can still be useful away from denominators, but rational heads and strict thresholds should be validated carefully before relying on AMP in safety-sensitive paths.
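A short experiment shows why precision matters near a pole. The denominator below is the distance `x - 1` for a point `1e-9` past a pole at `x = 1`; this toy denominator is an illustration, not a ZeroProofML model.

```python
import torch

x = 1.0 + 1e-9  # a Python float (float64) just past a pole at x = 1
pole = 1.0

q32 = torch.tensor(x, dtype=torch.float32) - pole
q64 = torch.tensor(x, dtype=torch.float64) - pole

print(q32.item())  # 0.0: float32 rounds 1 + 1e-9 to 1, so the denominator vanishes
print(q64.item())  # about 1e-9: float64 keeps the distance to the pole
```

float32 has a machine epsilon of about 1.2e-7 (`torch.finfo(torch.float32).eps`), so offsets far below that, relative to the operands, are lost; float64's epsilon is about 2.2e-16.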
Training Checklist
- Use projective tuples for rational heads that must learn near poles.
- Compute fit losses without direct division where possible.
- Track coverage and denominator statistics per epoch.
- Keep `tau_train` and `tau_infer` conceptually separate.
- Validate strict inference on a held-out split before export.
- Save JSONL logs and checkpoint metadata with each serious run.
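The last checklist item can be sketched as saving thresholds and the log location next to the weights. The metadata keys are hypothetical, the `nn.Linear` stands in for a real model, and the `tau_infer` value is a placeholder: this guide says to freeze it from held-out data but does not prescribe how it is chosen.

```python
import json
import tempfile
from pathlib import Path

import torch
import torch.nn as nn

model = nn.Linear(4, 2)  # stand-in for a (P, Q) head

checkpoint = {
    "model_state": model.state_dict(),
    # Hypothetical metadata keys; record whatever the run actually used.
    "metadata": {
        "tau_train_min": 5e-5,
        "tau_train_max": 2e-4,
        "tau_infer": 1e-4,  # placeholder for the value frozen on held-out data
        "dtype": str(torch.get_default_dtype()),
        "metrics_log": "runs/scm_train_metrics.jsonl",
    },
}

out_dir = Path(tempfile.mkdtemp())
torch.save(checkpoint, out_dir / "model.pt")
# Also write the metadata as plain JSON so it can be inspected without torch.
(out_dir / "model_meta.json").write_text(json.dumps(checkpoint["metadata"], indent=2))

restored = torch.load(out_dir / "model.pt", weights_only=True)
```

Keeping a JSON copy of the metadata beside the checkpoint lets reviewers confirm which `tau_infer` shipped without loading the model.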