Training Guide

ZeroProofML training separates smooth optimization from strict deployment behavior. The model may train on projective tuples and regularized gradients, but validation should decode with the same SCM masks that production will use.

Training Workflow

  1. Build a model that emits either SCM-aware outputs or projective (P, Q) tuples.
  2. Lift targets into projective form when using rational heads.
  3. Combine fit, margin, sign, and rejection terms.
  4. Pick a gradient policy for singular paths.
  5. Monitor coverage, denominator statistics, and validation masks.
  6. Freeze tau_infer from held-out data before export.

Projective Head Pattern

A common PyTorch model emits numerator and denominator tensors:

import torch
import torch.nn as nn

class RationalHead(nn.Module):
    def __init__(self, input_dim: int):
        super().__init__()
        self.backbone = nn.Sequential(
            nn.Linear(input_dim, 64),
            nn.ReLU(),
            nn.Linear(64, 64),
            nn.ReLU(),
        )
        self.head = nn.Linear(64, 2)

    def forward(self, x):
        p, q = self.head(self.backbone(x)).unbind(dim=-1)
        return p, q

Training losses work on (P, Q) without directly dividing by Q.
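Division only happens at decode time. The following is a minimal sketch of what strict decoding could look like, assuming bottom is flagged whenever |Q| falls below a threshold. The names `strict_decode` and `tau` are illustrative and are not taken from the zeroproofml API.

```python
import torch

def strict_decode(p: torch.Tensor, q: torch.Tensor, tau: float):
    """Hypothetical helper (not library API): return (values, is_bottom)."""
    is_bottom = q.abs() < tau
    # Swap in a safe denominator on masked entries so no inf/nan is produced.
    safe_q = torch.where(is_bottom, torch.ones_like(q), q)
    values = torch.where(is_bottom, torch.zeros_like(p), p / safe_q)
    return values, is_bottom

p = torch.tensor([1.0, 2.0, 3.0])
q = torch.tensor([2.0, 1e-6, -4.0])
values, is_bottom = strict_decode(p, q, tau=1e-4)
```

Callers should always read `values` together with `is_bottom`; the zero written into masked entries is a placeholder, not a prediction.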

Target Lifting

Finite targets are lifted into homogeneous coordinates:

y -> (Y_n, Y_d) = (y, 1)

Bottom or censored targets can be represented with denominator zero according to the task. The public target helpers live under zeroproofml.training.
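As an illustration of the lifting rule, here is a plain PyTorch sketch. It assumes bottom targets are marked by a boolean mask and encoded as (1, 0); both the encoding and the name `lift_targets` are assumptions made for this example, and the helpers in zeroproofml.training may differ.

```python
import torch

def lift_targets(y: torch.Tensor, is_bottom: torch.Tensor):
    """Hypothetical lifting: finite y -> (y, 1); bottom -> (1, 0) by assumption."""
    y_n = torch.where(is_bottom, torch.ones_like(y), y)
    y_d = torch.where(is_bottom, torch.zeros_like(y), torch.ones_like(y))
    return y_n, y_d

# NaN is used here only as a convenient marker for a bottom target.
y = torch.tensor([0.5, float("nan"), -2.0])
is_bottom = torch.isnan(y)
y_n, y_d = lift_targets(y, is_bottom)
```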

Loss Stack

The ZeroProofML v0.5.1 loss stack combines four terms:

| Term | Purpose |
| --- | --- |
| Implicit loss | Fits P/Q to the target without direct division |
| Margin loss | Pushes denominators away from the training threshold |
| Sign consistency | Aligns projective orientation so poles have stable direction |
| Rejection loss | Penalizes coverage below a target rate |

The implicit fit term uses a cross-product shape:

E = (P * Y_d - Q * Y_n)^2

Because the residual never divides by Q, its gradients stay finite even as Q approaches zero.
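The cross-product residual can be written directly in PyTorch. This sketch is not the library's implementation; `implicit_loss` is an illustrative name, and the mean reduction is an assumption.

```python
import torch

def implicit_loss(p, q, y_n, y_d):
    """Mean squared cross-product residual; q is never used as a divisor."""
    return ((p * y_d - q * y_n) ** 2).mean()

p = torch.tensor([1.0, 0.5], requires_grad=True)
q = torch.tensor([1e-8, 2.0], requires_grad=True)  # first denominator is nearly zero
y_n = torch.tensor([3.0, 1.0])
y_d = torch.tensor([1.0, 1.0])

loss = implicit_loss(p, q, y_n, y_d)
loss.backward()  # p.grad and q.grad are finite and of moderate size
```

For comparison, a direct (P/Q - y)^2 loss has a gradient with respect to Q that grows like 1/Q^3 when P is nonzero, which is what the implicit form avoids.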

Example wiring:

from zeroproofml.losses import LossConfig, SCMTrainingLoss

loss_fn = SCMTrainingLoss(
    LossConfig(
        tau_train=1e-4,
        target_coverage=0.95,
        lambda_margin=0.1,
        lambda_sign=1.0,
        lambda_rej=0.01,
    )
)
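To make the margin term concrete, one plausible form is a squared hinge on |Q|. This is an assumption for illustration only; the term that SCMTrainingLoss actually computes may be shaped differently, and `margin_loss` is a hypothetical name.

```python
import torch

def margin_loss(q: torch.Tensor, tau_train: float) -> torch.Tensor:
    """Hypothetical hinge: zero once |q| >= tau_train, quadratic below it."""
    return torch.relu(tau_train - q.abs()).pow(2).mean()

# A large tau_train is used here purely so the numbers are easy to read.
q = torch.tensor([0.5, 0.05, -0.02], requires_grad=True)
loss = margin_loss(q, tau_train=0.1)
loss.backward()
```

With this shape, a gradient descent step moves each small denominator away from zero in the direction of its current sign, and denominators already outside the margin receive no gradient.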

Gradient Policies

Gradient policies control how backward passes treat bottom paths.

| Policy | Typical use |
| --- | --- |
| CLAMP | Default for SCM-only graphs; zero bottom gradients and clamp finite gradients |
| PROJECT | Projective rational heads near poles |
| REJECT | Learn through coverage/rejection signals rather than local singular gradients |
| PASSTHROUGH | Debugging only |

from zeroproofml.autodiff.policies import GradientPolicy, gradient_policy

with gradient_policy(GradientPolicy.PROJECT):
    loss.backward()

Prefer explicit policy choices in rational heads. PASSTHROUGH is useful for diagnostics, but it is usually the wrong production default.
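To show the kind of behavior a CLAMP-style policy describes, the sketch below uses a standard PyTorch tensor hook to zero gradients on bottom entries and clamp the rest. It is a stand-in written for this guide, not the zeroproofml implementation; `clamp_bottom_grads` and `max_abs` are hypothetical names.

```python
import torch

def clamp_bottom_grads(x: torch.Tensor, is_bottom: torch.Tensor, max_abs: float = 1.0):
    """Hypothetical hook: zero gradients on bottom entries, clamp the others."""
    def hook(grad):
        grad = torch.where(is_bottom, torch.zeros_like(grad), grad)
        return grad.clamp(-max_abs, max_abs)
    x.register_hook(hook)
    return x

q = torch.tensor([1e-6, 0.1, 2.0], requires_grad=True)
is_bottom = q.detach().abs() < 1e-4
clamp_bottom_grads(q, is_bottom)

# A deliberately singular computation: d(1/q)/dq = -1/q^2 is huge near zero.
(1.0 / q).sum().backward()
```

After the backward pass, the entry flagged as bottom has a zero gradient, the large finite gradient is clipped to the bound, and the well-conditioned entry passes through unchanged.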

Trainer Configuration

SCMTrainer provides the reference training loop:

from zeroproofml.training import SCMTrainer, TrainingConfig

config = TrainingConfig(
    max_epochs=50,
    mixed_precision=True,
    tau_train_min=5e-5,
    tau_train_max=2e-4,
    coverage_threshold=0.90,
    coverage_patience=5,
)

trainer = SCMTrainer(
    model=model,
    optimizer=optimizer,
    loss_fn=loss_fn,
    train_loader=train_loader,
    config=config,
)
trainer.fit()

Important controls:

| Setting | Meaning |
| --- | --- |
| tau_train_min, tau_train_max | Training-time denominator threshold range |
| mixed_precision / use_amp | AMP support where appropriate |
| coverage_threshold, coverage_patience | Coverage-based early stop |
| val_loader | Per-epoch validation with val_ metrics |
| gradient_policy | Global policy override for training steps |
| loss_curriculum | Optional per-epoch loss-weight schedule |
| log_hook | Callback for JSONL metrics |
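The coverage-based early stop can be pictured as a small counter. The sketch below assumes that training stops once coverage has stayed below coverage_threshold for coverage_patience consecutive epochs; that reading, the class name `CoverageEarlyStop`, and the coverage values are assumptions for illustration, and SCMTrainer may apply a different rule.

```python
class CoverageEarlyStop:
    """Hypothetical helper: stop after `patience` consecutive low-coverage epochs."""

    def __init__(self, threshold: float, patience: int):
        self.threshold = threshold
        self.patience = patience
        self.bad_epochs = 0

    def should_stop(self, coverage: float) -> bool:
        # Reset the counter whenever coverage recovers.
        self.bad_epochs = self.bad_epochs + 1 if coverage < self.threshold else 0
        return self.bad_epochs >= self.patience

stopper = CoverageEarlyStop(threshold=0.90, patience=3)
history = [0.95, 0.88, 0.91, 0.85, 0.84, 0.83]  # made-up per-epoch coverage
stop_epoch = next((i for i, c in enumerate(history) if stopper.should_stop(c)), None)
```

In this toy history the single dip at epoch 1 does not trigger a stop because coverage recovers at epoch 2; only the sustained drop at the end does.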

Coverage Control

Coverage is the fraction of outputs that decode to finite values rather than bottom:

from zeroproofml.losses import coverage, rejection_loss

cov = coverage(outputs=decoded, is_bottom=bottom_mask)
rej = rejection_loss(bottom_mask, target_coverage=0.95)

Use rejection loss when a model solves difficult cases by returning bottom too often. Keep it balanced: high coverage is not useful if the accepted finite predictions are unsafe or numerically unstable.
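A hard bottom mask carries no gradient, so a trainable rejection term needs some differentiable surrogate for coverage. The sketch below uses a sigmoid on |Q| - tau as one possible surrogate. The surrogate, the `temperature` parameter, and the function names are assumptions and do not describe how zeroproofml.losses computes these quantities.

```python
import torch

def soft_coverage(q: torch.Tensor, tau: float, temperature: float) -> torch.Tensor:
    """Differentiable stand-in for the finite-output rate (an assumption)."""
    return torch.sigmoid((q.abs() - tau) / temperature).mean()

def rejection_penalty(cov: torch.Tensor, target_coverage: float) -> torch.Tensor:
    """Zero at or above the target, quadratic below it."""
    return torch.relu(target_coverage - cov).pow(2)

q = torch.tensor([0.5, 0.3, 1e-3, 2e-3], requires_grad=True)
hard_cov = (q.detach().abs() >= 1e-2).float().mean()  # 2 of 4 outputs are finite
penalty = rejection_penalty(soft_coverage(q, tau=1e-2, temperature=1e-2), 0.95)
penalty.backward()
```

The gradient falls almost entirely on the two denominators below the threshold and pushes them upward, which is the "return bottom less often" pressure described above.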

Logging

The stable logging path is JSONL:

from zeroproofml.training import TrainingConfig
from zeroproofml.utils.logging import JsonlLogger

config = TrainingConfig(
    log_hook=JsonlLogger("runs/scm_train_metrics.jsonl")
)

Trainer records use the zeroproofml.metric_log schema and include fields such as:

  • loss
  • coverage
  • bottom_frac
  • tau_train
  • denom_abs_min
  • denom_abs_mean
  • validation metrics with val_ prefixes

TensorBoard and DataFrame helpers are available in optional extras, but JSONL is the release-facing artifact format.
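JSONL is simple enough to read and write with the standard library alone. The sketch below shows the file shape (one JSON object per line) and a quick scan for low-coverage epochs. It does not reproduce JsonlLogger or the zeroproofml.metric_log schema; the helper names and the metric values are made up for illustration.

```python
import json
import os
import tempfile

def append_record(path: str, record: dict) -> None:
    """Append one metrics record as a single JSON line."""
    with open(path, "a", encoding="utf-8") as f:
        f.write(json.dumps(record) + "\n")

def read_records(path: str) -> list:
    """Load every non-empty line back into a dict."""
    with open(path, encoding="utf-8") as f:
        return [json.loads(line) for line in f if line.strip()]

path = os.path.join(tempfile.mkdtemp(), "metrics.jsonl")
append_record(path, {"epoch": 0, "loss": 0.42, "coverage": 0.91, "tau_train": 1e-4})
append_record(path, {"epoch": 1, "loss": 0.31, "coverage": 0.94, "tau_train": 1e-4})

records = read_records(path)
low_coverage = [r["epoch"] for r in records if r["coverage"] < 0.93]
```

Because each line is independent, a partially written run can still be parsed up to its last complete record.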

Dtype Guidance

Use float64 when a meaningful share of inputs falls close to singular surfaces:

import torch
torch.set_default_dtype(torch.float64)

Mixed precision can still be useful away from denominators, but rational heads and strict thresholds should be validated carefully before relying on AMP in safety-sensitive paths.
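A toy example shows why precision matters near a pole. The denominator below equals x in exact arithmetic, but float32 cannot resolve 1 + 1e-8, so the computed value collapses to zero. The function `denominator` is a contrived illustration, not a pattern from the library.

```python
import torch

def denominator(x: torch.Tensor) -> torch.Tensor:
    """Toy denominator Q(x) = (1 + x) - 1, equal to x in exact arithmetic."""
    return (1.0 + x) - 1.0

x32 = torch.tensor(1e-8, dtype=torch.float32)
x64 = torch.tensor(1e-8, dtype=torch.float64)

q32 = denominator(x32)  # exactly 0.0: the small term is lost to rounding
q64 = denominator(x64)  # stays close to 1e-8
```

With a strict threshold such as 1e-4 both values would decode to bottom, but denominator statistics like denom_abs_min would differ by dtype, and thresholds tuned on one dtype may not transfer to the other.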

Training Checklist

  • Use projective tuples for rational heads that must learn near poles.
  • Compute fit losses without direct division where possible.
  • Track coverage and denominator statistics per epoch.
  • Keep tau_train and tau_infer conceptually separate.
  • Validate strict inference on a held-out split before export.
  • Save JSONL logs and checkpoint metadata with each serious run.
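One way to think about freezing tau_infer is as a search over candidate thresholds on a held-out split. The sketch below picks the smallest candidate for which every accepted held-out prediction stays within an error tolerance, and reports the resulting coverage. The selection rule, the name `pick_tau_infer`, and the toy tensors are all assumptions; the library may choose tau_infer differently.

```python
import torch

def pick_tau_infer(p, q, y, candidates, max_error):
    """Hypothetical rule: smallest tau whose accepted held-out errors stay within max_error."""
    for tau in sorted(candidates):
        accepted = q.abs() >= tau
        if not accepted.any():
            break  # larger thresholds would reject everything as well
        err = (p[accepted] / q[accepted] - y[accepted]).abs().max().item()
        if err <= max_error:
            return tau, accepted.float().mean().item()
    return None, 0.0

# Toy held-out outputs: the two smallest denominators give poor predictions.
p = torch.tensor([1.0, 2.0, 1e-3, 3e-3])
q = torch.tensor([2.0, 4.0, 1e-4, 1e-3])
y = torch.tensor([0.5, 0.5, 1.0, 1.0])

tau_infer, held_out_coverage = pick_tau_infer(
    p, q, y, candidates=[1e-5, 5e-4, 1e-2], max_error=0.1
)
```

Whatever rule is used, the chosen value should be recorded alongside the checkpoint so that export and production decode with the same threshold.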