Source code for sample_based_tod_generation_gridint

import os
import time
import multiprocessing
from multiprocessing.shared_memory import SharedMemory
from functools import partial

import numpy as np
import healpy as hp

import tod_config as config
from tod_io import load_scan_information, open_scan_day
from tod_core import precompute_rotation_vector_batch, beam_tod_batch
from tod_calibrate import calibrate_runtime, calibrate_beam_clustering
from tod_utils import _get_ncpus, _fmt_time, _should_print_batch
from tod_pipeline_helpers import (
    prepare_beam_data,
    apply_beam_clustering,
    apply_hwp_modulation,
    resolve_spin2_skip_threshold,
    save_runtime_calibration,
    save_clustering_calibration,
)

# Ask the OpenMP / OpenBLAS runtimes for a single thread each.
# NOTE(review): these are assigned after `import numpy` above, so a BLAS that
# was already loaded in this process may not pick them up — confirm whether
# the intent is only for child processes to inherit them.
os.environ.update({"OMP_NUM_THREADS": "1", "OPENBLAS_NUM_THREADS": "1"})

# ── Config ────────────────────────────────────────────────────────────────────
# Module-level aliases for the tod_config values used throughout this file.
folder_scan, folder_tod_output = config.FOLDER_SCAN, config.FOLDER_TOD_OUTPUT
beam_files = [
    config.beam_file_I,
    config.beam_file_Q,
    config.beam_file_U,
]
start_day, end_day = config.start_day, config.end_day

interp_mode = config.beam_interp_method

# ── Worker-global state (populated by _worker_init in each spawned process) ───
# The per-beam stacked sky map (mp_stacked) is attached from shared memory by
# _worker_init; apart from that the worker only needs the map's nside.
_g_nside = _g_beam_data = None  # int nside / beam_data dict incl. mp_stacked
_g_shm_handles = []  # SharedMemory handles kept referenced for worker lifetime


# ── Pool initialiser ─────────────────────────────────────────────────────────


def _worker_init(beam_data_static, nside, beam_shm_descs, n_threads):
    """
    Pool initialiser: runs once inside every worker process.

    Re-opens, by name, each SharedMemory block described in
    ``beam_shm_descs`` and wraps its buffer in a numpy array without copying.
    Every opened handle is appended to ``_g_shm_handles`` so that it stays
    referenced for the rest of the worker's life; the array views built here
    are backed by those handles' buffers.

    Parameters
    ----------
    beam_data_static : dict  — per-beam dicts holding everything except
                               ``mp_stacked`` (this part travels by pickle)
    nside            : int   — HEALPix nside of the input sky map
    beam_shm_descs   : dict  — {beam key: {'name', 'shape', 'dtype'}}
                               describing the block that backs mp_stacked
    n_threads        : int | None — Numba thread count for this worker;
                               ``None`` or a non-positive value leaves
                               Numba's own setting untouched
    """
    global _g_nside, _g_beam_data, _g_shm_handles

    wants_thread_cap = n_threads is not None and n_threads > 0
    if wants_thread_cap:
        # Imported lazily so numba is only touched when a cap is requested.
        import numba

        numba.set_num_threads(int(n_threads))

    _g_nside = int(nside)

    # Rebuild the full per-beam dicts: pickled fields + shared-memory view.
    _g_beam_data = attached = {}
    for beam_key, static_fields in beam_data_static.items():
        layout = beam_shm_descs[beam_key]
        handle = SharedMemory(name=layout["name"])
        # Register the handle before creating the view that depends on it.
        _g_shm_handles.append(handle)
        sky_view = np.ndarray(
            layout["shape"], dtype=layout["dtype"], buffer=handle.buf
        )
        attached[beam_key] = {**static_fields, "mp_stacked": sky_view}


# ── TOD generation ────────────────────────────────────────────────────────────


def tod_exact_gen_batched(
    beam_data,
    day_index,
    nside,
    batch_size,
    process_name=None,
    z_skip_threshold=-1.0,
    fsamp=None,
):
    """Generate TOD for a single observation day using batched processing.

    Opens the scan files once via :func:`~tod_io.open_scan_day` (avoiding
    repeated ``open``/``mmap`` syscalls per batch), then processes the day in
    ``ceil(n_samples / batch_size)`` batches.  Each batch computes rotation
    vectors, calls :func:`~tod_core.beam_tod_batch` for every beam entry,
    accumulates the per-component results and, when ``config.hwp_enabled`` is
    set, applies HWP modulation before storing the batch.

    Args:
        beam_data (dict): Pre-loaded beam data from :func:`prepare_beam_data`.
            Must be non-empty and include ``'mp_stacked'`` for each entry.
            The pointing reference (``'ra'``, ``'dec'``) is taken from the
            first entry.
        day_index (int): Zero-based index of the observation day.
        nside (int): HEALPix nside of the input sky map.
        batch_size (int): Number of detector samples per processing batch;
            clamped to ``[1, n_samples]``.  Typically the value chosen by
            :func:`~tod_calibrate.calibrate_runtime`.
        process_name (str | None): Label for log messages (e.g. the
            ``multiprocessing.Process`` name).  Defaults to ``None``.
        z_skip_threshold (float): Forwarded unchanged to
            :func:`~tod_core.beam_tod_batch`.  Defaults to ``-1.0``.
        fsamp (float | None): Forwarded to :func:`apply_hwp_modulation`; only
            used when ``config.hwp_enabled`` is true.

    Returns:
        numpy.ndarray: TOD array of shape ``(3, n_samples)`` with dtype
        ``tod_config.precision_dtype``.  Axis 0 is the Stokes component
        index ``[I, Q, U]``.  A day with zero samples yields a ``(3, 0)``
        array.

    Raises:
        ValueError: If ``beam_data`` is empty.
    """
    if not beam_data:
        # Without this guard an empty dict leaks a bare StopIteration (with
        # an empty message) from next(iter(...)) below.
        raise ValueError("beam_data is empty: at least one beam entry is required")

    prefix = f"[{process_name}] " if process_name else ""

    # Open mmaps once for the whole day — avoids re-opening 3 files per batch,
    # which at small batch sizes would otherwise dominate I/O overhead.
    theta_mmap, phi_mmap, psi_mmap = open_scan_day(folder_scan, day_index)
    n_samples = len(phi_mmap)

    first_bf = next(iter(beam_data))
    ra0, dec0 = beam_data[first_bf]["ra"], beam_data[first_bf]["dec"]

    _cx, _cy = config.beam_center_x, config.beam_center_y
    beam_center_idx = (_cx, _cy) if (_cx is not None and _cy is not None) else None

    batch_size = max(1, min(batch_size, n_samples))
    n_batches = (n_samples + batch_size - 1) // batch_size
    print(
        prefix
        + f"Day {day_index}: {n_samples} samples, batch_size={batch_size}, "
        + f"n_batches={n_batches}"
    )

    dtype = config.precision_dtype
    tod_day = np.zeros((3, n_samples), dtype=dtype)
    start_time = time.time()

    for batch_idx in range(n_batches):
        bs = batch_idx * batch_size
        be = min(bs + batch_size, n_samples)

        # Progress line with an ETA extrapolated from the batches done so far.
        if _should_print_batch(batch_idx, n_batches):
            elapsed = time.time() - start_time
            if batch_idx > 0:
                eta = elapsed / batch_idx * (n_batches - batch_idx)
                eta_str = _fmt_time(eta)
            else:
                eta_str = "..."
            print(
                prefix
                + f"Batch {batch_idx + 1}/{n_batches} samples {bs}-{be - 1} ETA {eta_str}"
            )

        # Copy this batch's pointing out of the day-long arrays.
        theta_b = np.array(theta_mmap[bs:be], dtype=dtype)
        phi_b = np.array(phi_mmap[bs:be], dtype=dtype)
        psi_b = np.array(psi_mmap[bs:be], dtype=dtype)

        rot_vecs, betas = precompute_rotation_vector_batch(
            ra0, dec0, phi_b, theta_b, center_idx=beam_center_idx
        )
        psis_b = -betas + psi_b

        # Sum the contribution of every beam entry into the batch buffer.
        tod_batch = np.zeros((3, be - bs), dtype=dtype)
        for data in beam_data.values():
            contrib = beam_tod_batch(
                nside,
                None,  # sky map is consumed through data["mp_stacked"]
                data,
                rot_vecs,
                phi_b,
                theta_b,
                psis_b,
                interp_mode=interp_mode,
                z_skip_threshold=z_skip_threshold,
            )
            for comp, vals in contrib.items():
                tod_batch[comp] += vals

        if config.hwp_enabled:
            apply_hwp_modulation(
                tod_batch,
                day_index=day_index,
                sample_start=bs,
                fsamp=fsamp,
                f_hwp=config.hwp_rotation_frequency_hz,
                phi0=config.hwp_initial_phase_rad,
            )

        tod_day[:, bs:be] = tod_batch

    total = time.time() - start_time
    # n_batches is 0 for an empty day; avoid dividing by it.
    per_batch = total / n_batches if n_batches else 0.0
    print(
        prefix
        + f"Done — {n_samples} samples in {_fmt_time(total)} ({per_batch:.2f}s/batch)"
    )
    return tod_day
# ── Per-day worker (used by multiprocessing pool) ─────────────────────────────


def _process_day(day_index, batch_size, Nb, z_skip_threshold=-1.0, fsamp=None):
    """
    Pool task: generate one day's TOD and write it to disk.

    The beam data and nside are read from the module-level globals that
    _worker_init fills in, so nothing large is pickled per task.  The result
    is saved as ``tod_day_<day_index>.npy`` inside ``folder_tod_output``.

    Returns a ``(day_index, ok, error)`` tuple: ``(day_index, True, None)``
    on success, or ``(day_index, False, str(exception))`` when anything in
    the generate-and-save sequence raises an ``Exception``.
    """
    worker_name = multiprocessing.current_process().name
    print(f"[{worker_name}] Processing day {day_index + 1}/{Nb}")

    # Failures are reported back to the parent rather than raised, so one bad
    # day does not abort the rest of the pool's work.
    try:
        tod = tod_exact_gen_batched(
            _g_beam_data,
            day_index,
            _g_nside,
            batch_size,
            process_name=worker_name,
            z_skip_threshold=z_skip_threshold,
            fsamp=fsamp,
        )
        out_path = os.path.join(folder_tod_output, f"tod_day_{day_index}.npy")
        np.save(out_path, tod)
        print(f"[{worker_name}] Saved {out_path}")
    except Exception as exc:
        print(f"[{worker_name}] Error on day {day_index}: {exc}")
        return day_index, False, str(exc)
    return day_index, True, None


# ── Main ──────────────────────────────────────────────────────────────────────
def main(n_cpu_ceiling):
    """Run the full TOD generation for the configured range of days.

    Steps, in order: load the scan metadata and sky map, load the beam data,
    optionally calibrate and apply beam clustering, stack the sky-map
    components per beam, resolve the spin-2 skip threshold, obtain the
    runtime settings (cached in ``tod_config`` or freshly calibrated), and
    finally generate one ``tod_day_<i>.npy`` file per day — through a worker
    pool backed by shared memory when more than one process is requested,
    otherwise in-process.

    Args:
        n_cpu_ceiling (int): Upper bound on CPUs, forwarded to
            :func:`~tod_calibrate.calibrate_runtime`.  Unused when cached
            calibration values are taken from ``tod_config``.

    Returns:
        None.  Per-day failures in pool mode are printed in the final
        summary rather than raised; in single-process mode exceptions
        propagate.
    """
    t0 = time.time()
    Nb, fsamp = load_scan_information(folder_scan)

    # Clamp the configured day range to [0, Nb]; unset bounds mean "all".
    start = max(start_day or 0, 0)
    end = min(end_day or Nb, Nb)
    days = range(start, end)

    os.makedirs(folder_tod_output, exist_ok=True)

    # Load the sky map here (inside main / under __name__ guard) so that
    # spawned worker processes — which re-import this module — never execute
    # this line themselves.
    print(
        f"Loading sky map (precision={config.precision}, "
        f"fields={list(config.map_fields)})..."
    )
    _raw = hp.read_map(config.path_to_map, field=tuple(config.map_fields))
    if len(config.map_fields) == 1:
        # A single field comes back as a bare array; normalise to a tuple.
        _raw = (_raw,)
    MP = {
        c: np.asarray(m).astype(config.precision_dtype)
        for c, m in zip(config.map_fields, _raw)
    }

    # ── Load exact beam data (clustering applied separately below) ─────────────
    print("Loading beam data...")
    beam_data = prepare_beam_data(beam_files)

    # ── Beam pixel clustering ──────────────────────────────────────────────────
    # Calibration: sweep (tail_fraction, K) grid on a probe batch to find the
    # best setting within the configured error tolerance.  Runs only when
    # clustering_calibration_enabled=True.
    if config.clustering_calibration_enabled:
        print("Running beam clustering calibration …")
        best_tf, best_K = calibrate_beam_clustering(
            beam_data,
            folder_scan=folder_scan,
            probe_day=start,
            mp=MP,
            error_threshold=config.clustering_error_threshold,
            interp_mode=interp_mode,
        )
        save_clustering_calibration(best_tf, best_K)
        # Update in-memory config so clustering is applied this run too
        config.n_beam_clusters = best_K
        config.beam_cluster_tail_fraction = best_tf

    if config.n_beam_clusters is not None:
        print(
            f"Applying beam clustering "
            f"(tail_fraction={config.beam_cluster_tail_fraction}, "
            f"n_clusters={config.n_beam_clusters}) …"
        )
        apply_beam_clustering(
            beam_data,
            n_clusters=config.n_beam_clusters,
            tail_fraction=config.beam_cluster_tail_fraction,
        )

    # Stack sky-map components per beam entry into a contiguous (C, N) array
    # in the active precision.  The Numba gather kernel requires this layout.
    for data in beam_data.values():
        data["mp_stacked"] = np.ascontiguousarray(
            np.stack([MP[c] for c in data["comp_indices"]])  # (C, N_hp)
        )

    z_skip_threshold = resolve_spin2_skip_threshold(
        beam_data, config.spin2_skip_tolerance
    )

    # ── Runtime settings: cached from config, or calibrated now ────────────────
    use_cached = not config.calibration_enabled
    if use_cached:
        ncpus = config.calibration_n_processes
        n_threads = config.calibration_numba_threads
        batch_size = config.calibration_batch_size
        print(
            f"Using cached calibration: n_processes={ncpus}, "
            f"numba_threads={n_threads}, batch_size={batch_size}"
        )
    else:
        print("Calibrating runtime (n_processes × numba_threads × batch_size)...")
        _cx, _cy = config.beam_center_x, config.beam_center_y
        ncpus, n_threads, batch_size = calibrate_runtime(
            beam_data,
            folder_scan,
            probe_day=start,
            mp=MP,
            n_cpu_ceiling=n_cpu_ceiling,
            max_processes_user=config.n_processes,
            interp_mode=interp_mode,
            center_idx=(_cx, _cy) if (_cx is not None and _cy is not None) else None,
            z_skip_threshold=z_skip_threshold,
        )
        save_runtime_calibration(ncpus, n_threads, batch_size)

    print(
        f"Processing days {start}-{end - 1} ({len(days)} days, "
        f"{ncpus} workers × {n_threads} threads)"
    )

    nside = hp.get_nside(next(iter(MP.values())))

    if ncpus > 1:
        # ── Allocate shared memory ─────────────────────────────────────────
        # Workers consume the sky map exclusively via mp_stacked; only nside
        # is needed otherwise.  One block per beam entry for mp_stacked
        # (C, npix) — dtype follows the stacked array.
        beam_shms = {}
        beam_shm_descs = {}
        try:
            # Allocation sits inside the try so that a failure part-way
            # through (e.g. on the second block) still releases the blocks
            # that were already created.
            for bf, data in beam_data.items():
                ms = data["mp_stacked"]
                shm = SharedMemory(create=True, size=ms.nbytes)
                # Track the block before copying so cleanup always covers it.
                beam_shms[bf] = shm
                np.ndarray(ms.shape, dtype=ms.dtype, buffer=shm.buf)[:] = ms
                beam_shm_descs[bf] = {
                    "name": shm.name,
                    "shape": ms.shape,
                    "dtype": ms.dtype,
                }

            # Everything except mp_stacked stays in the pickle payload that
            # is handed to each worker's initialiser.
            _SHARED_KEYS = {"mp_stacked"}
            beam_data_static = {
                bf: {k: v for k, v in data.items() if k not in _SHARED_KEYS}
                for bf, data in beam_data.items()
            }

            worker = partial(
                _process_day,
                batch_size=batch_size,
                Nb=Nb,
                z_skip_threshold=z_skip_threshold,
                fsamp=fsamp,
            )

            with multiprocessing.Pool(
                processes=ncpus,
                initializer=_worker_init,
                initargs=(beam_data_static, nside, beam_shm_descs, n_threads),
            ) as pool:
                results = pool.map(worker, days)
        finally:
            # Release shared memory only after all workers have finished.
            for shm in beam_shms.values():
                shm.close()
                shm.unlink()

        failed = [r for r in results if not r[1]]
        print(f"\nDone — {len(results) - len(failed)}/{len(results)} days OK")
        for day, _, err in failed:
            print(f"  Day {day} failed: {err}")
    else:
        if n_threads is not None and n_threads > 0:
            import numba

            numba.set_num_threads(int(n_threads))

        for day_index in days:
            tod_day = tod_exact_gen_batched(
                beam_data,
                day_index,
                nside,
                batch_size,
                process_name="main",
                z_skip_threshold=z_skip_threshold,
                fsamp=fsamp,
            )
            output_file = os.path.join(folder_tod_output, f"tod_day_{day_index}.npy")
            np.save(output_file, tod_day)

    print(f"\nTotal run time: {(time.time() - t0) / 60:.2f}m")
if __name__ == "__main__":
    # Script entry point: select the multiprocessing start method from the
    # config before main() creates any pool, then run with the CPU ceiling
    # reported by _get_ncpus().
    multiprocessing.set_start_method(config.mp_start_method)
    main(n_cpu_ceiling=_get_ncpus())