Source code for tod_calibrate

"""
Runtime calibration for tod_exact_gen_batched.

Calibrates three knobs jointly:
  * n_processes (P)        — worker processes (parallel over days)
  * numba_threads (T)      — threads per worker (parallel over batch via prange)
  * batch_size (B)         — samples per fused-kernel invocation

The fused kernel (a1d5d36) parallelises the entire Rodrigues+gather over
prange(B). Spawning P workers each using T = NUMBA_NUM_THREADS_DEFAULT (=all
cores) oversubscribes by P×; the new search enforces P*T ≤ N_cores.

Strategy (~30s wall time):
  Phase A — single-process throughput vs T at a fixed B.
  Phase B — for best T, sweep B around 16×T to land on the throughput plateau.
  Phase C — enumerate (P, T) with P*T ≤ N_cores; pick max  P × tp(T).
            Memory budget per process must accommodate mp_stacked (already in
            shared memory, but counted defensively) plus transient per-batch
            buffers.

Beam clustering calibration is unchanged (driven by science accuracy, not speed).
"""

import gc
import time
import numpy as np
import healpy as hp

import numba

import tod_config as config
from tod_io import load_scan_data_batch
from tod_core import precompute_rotation_vector_batch, beam_tod_batch
from tod_utils import _fmt_time, _get_memory_per_process
from tod_beam_math import compute_bell
from beam_cluster import cluster_beam_pixels

# Per-batch transient memory (bytes per sample). The fused kernel no longer
# materialises a (B, S, 3) Rodrigues buffer (commit a1d5d36); only small
# B-scaled arrays remain (pointings, weights, output).
_BYTES_PER_SAMPLE_TRANSIENT = {
    "nearest": 80,
    "bilinear": 120,
    "bicubic": 400,
    "gaussian": 400,
}
_BYTES_PER_SAMPLE_TRANSIENT_DEFAULT = 200
_MEMORY_SAFETY_FACTOR = 1.5

# Each thread needs at least this many samples to amortise prange overhead.
_MIN_SAMPLES_PER_THREAD = 256
# Above this we hit a plateau; further B growth wastes memory.
_TARGET_SAMPLES_PER_THREAD = 1024

# Probe sizing — keep total wall time near 30s.
_PROBE_TARGET_SECONDS = 1.5  # per measurement cell
_PROBE_MIN_SAMPLES = 20_000
_PROBE_MAX_SAMPLES = 400_000


# ── Memory model ─────────────────────────────────────────────────────────────


def _per_proc_static_bytes(beam_data, nside):
    """Per-worker static memory: mp_stacked for every beam file."""
    npix = 12 * nside * nside
    return sum(
        d["mp_stacked"].nbytes
        if "mp_stacked" in d
        else len(d["comp_indices"]) * npix * 4
        for d in beam_data.values()
    )


def _max_batch_for_memory(mem_per_proc_gb, beam_data, nside, interp_mode):
    """Largest B that fits the per-batch transient buffers in budget."""
    bps = _BYTES_PER_SAMPLE_TRANSIENT.get(
        interp_mode, _BYTES_PER_SAMPLE_TRANSIENT_DEFAULT
    )
    static_gb = _per_proc_static_bytes(beam_data, nside) / 1e9
    transient_budget_gb = mem_per_proc_gb / _MEMORY_SAFETY_FACTOR - static_gb
    if transient_budget_gb <= 0.05:
        return 0
    return max(1, int(transient_budget_gb * 1e9 // bps))


# ── Probe runner ─────────────────────────────────────────────────────────────


def _make_probe_data(beam_data, folder_scan, probe_day, n_samples):
    theta_p, phi_p, psi_p = load_scan_data_batch(
        folder_scan, probe_day, 0, n_samples, dtype=config.precision_dtype
    )
    n = min(n_samples, len(phi_p))
    return phi_p[:n], theta_p[:n], psi_p[:n]


def _run_one(
    nside,
    mp,
    beam_data,
    ra0,
    dec0,
    phi_p,
    theta_p,
    psi_p,
    bs,
    interp_mode,
    center_idx=None,
    z_skip_threshold=-1.0,
):
    """Run probe at given batch size. Returns wall time."""
    n = len(phi_p)
    n_batches = (n + bs - 1) // bs
    t0 = time.perf_counter()
    for b in range(n_batches):
        s, e = b * bs, min((b + 1) * bs, n)
        phi_b, theta_b, psi_b = phi_p[s:e], theta_p[s:e], psi_p[s:e]
        rot_vecs, betas = precompute_rotation_vector_batch(
            ra0, dec0, phi_b, theta_b, center_idx=center_idx
        )
        psis_b = -betas + psi_b
        for data in beam_data.values():
            beam_tod_batch(
                nside,
                mp,
                data,
                rot_vecs,
                phi_b,
                theta_b,
                psis_b,
                interp_mode=interp_mode,
                z_skip_threshold=z_skip_threshold,
            )
    return time.perf_counter() - t0


def _measure_throughput(
    nside,
    mp,
    beam_data,
    ra0,
    dec0,
    phi_full,
    theta_full,
    psi_full,
    bs,
    n_threads,
    interp_mode,
    prefix="",
    center_idx=None,
    z_skip_threshold=-1.0,
):
    """Set thread count, measure throughput at given batch size.

    phi_full/theta_full/psi_full must hold _PROBE_MAX_SAMPLES samples
    (loaded once by the caller); this function slices them.
    """
    numba.set_num_threads(max(1, n_threads))

    # Adapt probe size to target ~_PROBE_TARGET_SECONDS.
    # Use a small pilot to estimate samples/sec, then size the real probe.
    pilot_n = max(bs * 2, 4_000)
    pilot_n = min(pilot_n, len(phi_full))
    phi_p, theta_p, psi_p = phi_full[:pilot_n], theta_full[:pilot_n], psi_full[:pilot_n]
    pilot_t = _run_one(
        nside,
        mp,
        beam_data,
        ra0,
        dec0,
        phi_p,
        theta_p,
        psi_p,
        bs,
        interp_mode,
        center_idx=center_idx,
        z_skip_threshold=z_skip_threshold,
    )
    rate = len(phi_p) / max(pilot_t, 1e-6)
    target_n = int(rate * _PROBE_TARGET_SECONDS)
    target_n = max(_PROBE_MIN_SAMPLES, min(target_n, len(phi_full)))
    target_n = max(target_n, bs * 4)  # at least 4 batches

    if target_n > len(phi_p):
        phi_p = phi_full[:target_n]
        theta_p = theta_full[:target_n]
        psi_p = psi_full[:target_n]

    # Take the best of 2 short runs to suppress noise.
    best_t = float("inf")
    for _ in range(2):
        t = _run_one(
            nside,
            mp,
            beam_data,
            ra0,
            dec0,
            phi_p,
            theta_p,
            psi_p,
            bs,
            interp_mode,
            center_idx=center_idx,
            z_skip_threshold=z_skip_threshold,
        )
        best_t = min(best_t, t)
    gc.collect()
    tp = len(phi_p) / best_t
    print(
        prefix + f"  T={n_threads:>3d}  B={bs:>6d}  "
        f"n={len(phi_p):>7d}  t={_fmt_time(best_t):>7s}  "
        f"tp={tp:>12,.0f} samp/s"
    )
    return tp


def _thread_candidates(n_cores):
    """Powers of 2 up to n_cores, plus n_cores itself."""
    cands = set()
    t = 1
    while t <= n_cores:
        cands.add(t)
        t *= 2
    cands.add(n_cores)
    return sorted(cands)


def _process_thread_pairs(n_cores, max_processes):
    """All (P, T) with P*T ≤ n_cores, P ≤ max_processes, T ≥ 1."""
    pairs = []
    for t in _thread_candidates(n_cores):
        max_p = min(n_cores // t, max_processes)
        for p in range(1, max_p + 1):
            if p * t <= n_cores:
                pairs.append((p, t))
    return pairs


# ── Public entry point ──────────────────────────────────────────────────────


[docs] def calibrate_runtime( beam_data, folder_scan, probe_day, mp, n_cpu_ceiling, max_processes_user, interp_mode="bilinear", prefix="", center_idx=None, z_skip_threshold=-1.0, ): """Joint (n_processes, numba_threads, batch_size) calibration. Args: beam_data: from prepare_beam_data (after clustering, with mp_stacked). folder_scan: scan directory. probe_day: any valid day index for probe data. mp: list of sky-map components. n_cpu_ceiling: hard ceiling from scheduler/affinity (_get_ncpus()). max_processes_user: user-configured n_processes (laptop cap, etc). Acts as an upper bound on P. interp_mode: 'nearest' or 'bilinear'. prefix: log prefix. Returns: (n_processes, n_threads, batch_size) """ nside = hp.get_nside(next(iter(mp.values())) if isinstance(mp, dict) else mp[0]) first_bf = next(iter(beam_data)) ra0, dec0 = beam_data[first_bf]["ra"], beam_data[first_bf]["dec"] n_cores = max(1, n_cpu_ceiling) max_p = max(1, min(max_processes_user, n_cores)) total_mem_gb = _get_memory_per_process(1) print( prefix + f"[calibrate] n_cores={n_cores} max_processes={max_p} " f"total_mem={total_mem_gb:.1f} GB interp={interp_mode}" ) # ── Phase A: throughput vs threads at a reference batch size ──────────── # Use a B large enough to feed the maximum thread count (T=n_cores). ref_bs = max(_TARGET_SAMPLES_PER_THREAD * n_cores, 4096) # Cap by memory at P=1. bs_cap_p1 = _max_batch_for_memory(total_mem_gb, beam_data, nside, interp_mode) if bs_cap_p1 < 256: raise RuntimeError( f"[calibrate] memory too tight: per-process budget " f"{total_mem_gb:.2f} GB cannot fit static beam arrays " f"({_per_proc_static_bytes(beam_data, nside) / 1e9:.2f} GB) plus " f"any reasonable batch." ) ref_bs = min(ref_bs, bs_cap_p1) # Load probe scan data once — _measure_throughput slices as needed. phi_full, theta_full, psi_full = _make_probe_data( beam_data, folder_scan, probe_day, _PROBE_MAX_SAMPLES ) print(prefix + f"[calibrate] Phase A — sweep threads at B={ref_bs}") tp_by_threads = {} for t in _thread_candidates(n_cores): tp = _measure_throughput( nside, mp, beam_data, ra0, dec0, phi_full, theta_full, psi_full, bs=ref_bs, n_threads=t, interp_mode=interp_mode, prefix=prefix, center_idx=center_idx, z_skip_threshold=z_skip_threshold, ) tp_by_threads[t] = tp # ── Phase B: at each candidate T, find a good B (only for non-trivial T) ─ # We sweep B at the top few T candidates because B-scaling can differ. print(prefix + "[calibrate] Phase B — sweep batch size at top thread counts") top_threads = sorted(tp_by_threads, key=lambda t: -tp_by_threads[t])[:3] bs_by_threads = {} tp_at_bs = {} # (T, B_chosen) -> throughput for t in top_threads: # Candidate batch sizes around the target sweet spot for this T. target = _TARGET_SAMPLES_PER_THREAD * t cands = sorted( { max(_MIN_SAMPLES_PER_THREAD * t, 256), max(target // 2, 512), target, target * 2, } ) cands = [b for b in cands if 256 <= b <= bs_cap_p1] if not cands: cands = [min(ref_bs, bs_cap_p1)] # Prepend the already-measured ref point if not already in the list best_b, best_tp = ref_bs, tp_by_threads[t] for b in cands: if b == ref_bs: tp = tp_by_threads[t] else: tp = _measure_throughput( nside, mp, beam_data, ra0, dec0, phi_full, theta_full, psi_full, bs=b, n_threads=t, interp_mode=interp_mode, prefix=prefix, center_idx=center_idx, ) tp_at_bs[(t, b)] = tp if tp > best_tp: best_tp, best_b = tp, b bs_by_threads[t] = best_b tp_by_threads[t] = best_tp # update with best B for this T # ── Phase C: enumerate (P, T) with P*T ≤ n_cores; pick best ───────────── print(prefix + "[calibrate] Phase C — score (P, T) combinations") print( prefix + f" {'P':>3s} {'T':>3s} {'B':>6s} " f"{'tp/proc':>14s} {'est total':>14s} status" ) print(prefix + " " + "-" * 60) best = None # (score, P, T, B) for p, t in _process_thread_pairs(n_cores, max_p): if t not in tp_by_threads: continue # only score Ts we measured mem_per_proc = total_mem_gb / p bs_cap = _max_batch_for_memory(mem_per_proc, beam_data, nside, interp_mode) if bs_cap < _MIN_SAMPLES_PER_THREAD * t: print( prefix + f" {p:>3d} {t:>3d} {'-':>6s} {'-':>14s} {'-':>14s} oom" ) continue # Use the B chosen in phase B for this T, capped by per-proc memory. b_choice = min(bs_by_threads.get(t, ref_bs), bs_cap) # Look up throughput at (T, b_choice) if measured, else use phase A ref. tp_per_proc = tp_at_bs.get((t, b_choice), tp_by_threads[t]) score = p * tp_per_proc marker = "" if best is None or score > best[0]: best = (score, p, t, b_choice) marker = " ←" print( prefix + f" {p:>3d} {t:>3d} {b_choice:>6d} " f"{tp_per_proc:>14,.0f} {score:>14,.0f}{marker}" ) if best is None: raise RuntimeError( "[calibrate] no viable (P, T) combination — " "memory too tight for any batch size." ) _, P, T, B = best print( prefix + f"[calibrate] → n_processes={P} numba_threads={T} " f"batch_size={B} (est total tp={best[0]:,.0f} samp/s)" ) return P, T, B
# ── Beam clustering calibration (unchanged) ────────────────────────────────── def _run_clustering_probe( nside, mp, beam_entries, rot_vecs, phi_b, theta_b, psis_b, interp_mode ): """Run beam_tod_batch for all entries and accumulate into a (3, B) array.""" B = len(phi_b) tod = np.zeros((3, B), dtype=np.float64) for data in beam_entries: contrib = beam_tod_batch( nside, mp, data, rot_vecs, phi_b, theta_b, psis_b, interp_mode=interp_mode, ) for comp, vals in contrib.items(): tod[comp] += vals return tod
[docs] def calibrate_beam_clustering( beam_data, folder_scan=None, probe_day=None, mp=None, error_threshold=1e-3, bell_lmax=None, interp_mode="bilinear", ): """Find (tail_fraction, n_clusters) maximising speedup s.t. B_ell error ≤ error_threshold. Computes reference B_ell (power_cut=1.0) from unclustered beam, then sweeps a fixed (tail_fraction × n_clusters) grid. The pair maximising speedup with relative-RMS B_ell divergence ≤ error_threshold wins; if no pair qualifies, the minimum-divergence pair is returned with a warning. """ tail_fractions = (0.005, 0.01, 0.02, 0.03, 0.05, 0.075, 0.10, 0.15, 0.20, 0.30) n_clusters_list = (10, 20, 50, 100, 200, 500, 1000, 2000) if bell_lmax is None: if mp is not None: _ref = next(iter(mp.values())) if isinstance(mp, dict) else mp[0] bell_lmax = 2 * hp.get_nside(_ref) else: bell_lmax = 500 def _bell_from_vecs(vec, bvals): theta_pix, phi_pix = hp.vec2ang(vec) dec_offset = np.pi / 2.0 - theta_pix _, bell = compute_bell( phi_pix, dec_offset, bvals.astype(np.float64), lmax=bell_lmax, power_cut=1.0, verbose=False, ) return bell print( f"[clust_calib] Computing reference B_ell (power_cut=1.0, lmax={bell_lmax}) …" ) ref_bells = {} for bf, data in beam_data.items(): ref_bells[bf] = _bell_from_vecs(data["vec_orig"], data["beam_vals"]) S_bf = {bf: data["n_sel"] for bf, data in beam_data.items()} # Pre-compute n_tail for each (beam_file, tail_fraction) to enable # short-circuiting when K_req >= n_tail (no clustering occurs). n_tail_per_bf_tf = {} for bf, data in beam_data.items(): bv = data["beam_vals"] n_tail_per_bf_tf[bf] = {} sort_idx = np.argsort(bv) cumsum = np.cumsum(bv[sort_idx]) S = len(bv) for tf in tail_fractions: n_tail = int(np.searchsorted(cumsum, tf, side="right")) n_tail_per_bf_tf[bf][tf] = max(1, min(n_tail, S - 1)) print("[clust_calib] Sweeping clustering parameters …") results = [] for tf in tail_fractions: for K_req in n_clusters_list: K_out_per_bf = {} bell_divs = [] for bf, data in beam_data.items(): n_tail = n_tail_per_bf_tf[bf][tf] if K_req >= n_tail: # Tail already fits in K_req clusters — no reduction possible. K_out_per_bf[bf] = S_bf[bf] bell_divs.append(0.0) continue vec_c, bv_c, _ = cluster_beam_pixels( data["vec_orig"], data["beam_vals"], n_clusters=K_req, tail_fraction=tf, verbose=False, ) K_out_per_bf[bf] = len(bv_c) bell_clust = _bell_from_vecs(vec_c, bv_c) bell_ref = ref_bells[bf] ref_rms = float(np.sqrt(np.mean(bell_ref**2))) bell_div = float(np.sqrt(np.mean((bell_clust - bell_ref) ** 2))) / ( ref_rms + 1e-30 ) bell_divs.append(bell_div) mean_bell_div = float(np.mean(bell_divs)) speedup = float(np.mean([S_bf[bf] / K_out_per_bf[bf] for bf in beam_data])) K_out_repr = int(np.mean(list(K_out_per_bf.values()))) results.append((tf, K_req, K_out_repr, speedup, mean_bell_div)) print() print(f"[clust_calib] error_threshold={error_threshold:.1e}") print( f"{'tail%':>6s} {'K':>5s} {'K_out':>6s} {'speedup':>8s} " f"{'B_ell div':>10s} {'status'}" ) print("-" * 56) prev_tf = None for tf, K_req, K_out, speedup, bell_div in results: if prev_tf is not None and tf != prev_tf: print("-" * 56) status = "✓" if bell_div <= error_threshold else "✗" print( f"{tf * 100:>5.1f}% {K_req:>5d} {K_out:>6d} {speedup:>8.2f}x " f"{bell_div:>10.2e} {status}" ) prev_tf = tf print("-" * 56) passing = [r for r in results if r[4] <= error_threshold] if passing: best = max(passing, key=lambda x: x[3]) best_tf, best_K_req = best[0], best[1] print( f"\n[clust_calib] Recommendation: tail_fraction={best_tf}, " f"n_clusters={best_K_req} " f"(speedup={best[3]:.2f}x, B_ell div={best[4]:.2e})" ) else: best = min(results, key=lambda x: x[4]) best_tf, best_K_req = best[0], best[1] print( f"\n[clust_calib] WARNING: no (tf, K) pair achieved B_ell div " f"<= {error_threshold:.1e}." ) print( f"[clust_calib] Returning minimum-divergence pair: " f"tail_fraction={best_tf}, n_clusters={best_K_req} " f"(B_ell div={best[4]:.2e})" ) return float(best_tf), int(best_K_req)