Source code for tod_core

"""
Core numerical routines for sample-based TOD generation.

All functions are stateless and take only arrays as arguments.

Rotation kernels (in tod_rotations.py)
---------------------------------------
_rodrigues_jit                — fused double Rodrigues rotation (recenter + pol. roll),
                                materialises a (B, S, 3) buffer.  Used by the
                                numpy-fallback path and by tests.
_rodrigues_apply_one_jit      — scalar per-(b, s) fused Rodrigues; inlined into
                                the production gather kernels so the (B, S, 3)
                                intermediate is never materialised.
_rotation_params              — per-sample scalars needed by the Rodrigues kernels.
_recenter_and_rotate          — fused recenter + pol-roll wrapper.
precompute_rotation_vector_batch — Rodrigues vectors and pol. angle offsets for a batch.

Gather/accumulate kernels
-------------------------
_gather_accum_jit          — scalar bilinear accumulation from pre-computed pixels/weights
                             (in tod_bilinear.py; used by tests).
_gather_accum_fused_jit    — fully fused Rodrigues + bilinear gather + per-b
                             direct-mapped spin-2 cache + accumulation
                             (in tod_bilinear.py).
_gather_accum_nearest_jit  — fused Rodrigues + nearest-pixel gather + accumulation
                             (in tod_nearest.py).

HEALPix RING helpers (in numba_healpy.py)
-----------------------------------------
_ring_above_jit, _ring_info_jit, _ring_z_jit,
_get_interp_weights_jit, get_interp_weights_numba
"""

import numpy as np
import healpy as hp

import tod_config as config
from numba_healpy import get_interp_weights_numba
from tod_rotations import (
    _rotation_params,
    _recenter_and_rotate,
    precompute_rotation_vector_batch,
)
from tod_nearest import (
    _gather_accum_nearest_jit,
)
from tod_bilinear import (
    _gather_accum_jit,
    _gather_accum_fused_jit,
)


[docs] def beam_tod_batch( nside, mp, data, rot_vecs, phi_b, theta_b, psis_b, interp_mode="bilinear", z_skip_threshold=-1.0, ): """Accumulate the TOD contribution of one beam entry for a batch of samples. Uses Numba JIT kernels that fuse the Rodrigues rotation into the gather + accumulation step: no ``(B, S, 3)`` intermediate is materialised on the production ``mp_stacked`` path, and there is no S-tile loop — one kernel call per beam entry per batch. Args: nside (int): HEALPix ``nside`` of the sky map. mp (list[numpy.ndarray]): Sky map components ``[I, Q, U]``. Each element is a 1-D ``float32`` array of length ``12 * nside**2``. Used only on the numpy-fallback path (when ``mp_stacked`` is not provided). data (dict): Beam data entry as returned by :func:`prepare_beam_data`. Required keys: ``'vec_orig'``, ``'beam_vals'``, ``'comp_indices'``. Production path additionally requires ``'mp_stacked'``. rot_vecs (numpy.ndarray): Rodrigues rotation vectors from :func:`precompute_rotation_vector_batch`, shape ``(B, 3)``. phi_b (numpy.ndarray): Boresight longitude [rad], shape ``(B,)``. theta_b (numpy.ndarray): Boresight colatitude [rad], shape ``(B,)``. psis_b (numpy.ndarray): Combined rotation angle ``psi_b - beta`` [rad], shape ``(B,)``. interp_mode (str): Sky-map interpolation strategy. One of: * ``'bilinear'`` *(default)* — 4-pixel bilinear HEALPix interpolation with spin-2 Q/U frame correction. * ``'nearest'`` — single nearest-pixel lookup; fastest, no pixel mixing. (``'gaussian'`` and ``'bicubic'`` are available on their respective branches.) z_skip_threshold (float): Per-``b`` spin-2 skip cutoff on ``|cos θ_pts|``. Boresight samples with ``|bz| > z_skip_threshold`` apply the full Q/U frame correction; samples in the equatorial band (``|bz| <= z_skip_threshold``) bypass it. ``-1.0`` (default) disables the optimisation — spin-2 is always applied, bit-identical to the un-optimised path. Pass the value returned by :func:`tod_bilinear.compute_spin2_skip_z_threshold` to enable. Returns: dict[int, numpy.ndarray]: Mapping from Stokes component index to a ``(B,)`` ``float32`` array containing the beam-weighted sky-map accumulation for that component over the batch. """ B = phi_b.shape[0] vec_orig = data["vec_orig"] # (S, 3) beam_vals = data["beam_vals"] # (S,) S = vec_orig.shape[0] comp_indices = data["comp_indices"] C = len(comp_indices) mp_stacked = data.get("mp_stacked") # (C, N) float32, or None # Q and U channel positions within the C-dim of mp_stacked. # Convention: input map fields are [T, Q, U] at indices [0, 1, 2]. c_q = comp_indices.index(1) if 1 in comp_indices else -1 c_u = comp_indices.index(2) if 2 in comp_indices else -1 use_nearest = interp_mode == "nearest" if interp_mode not in ("nearest", "bilinear"): raise ValueError( f"interp_mode {interp_mode!r} not available on main branch; " "switch to the 'gaussian' or 'bicubic' branch" ) axes, cos_a, sin_a, ax_pts, cos_p, sin_p = _rotation_params( rot_vecs, phi_b, theta_b, psis_b ) _dt = config.precision_dtype vec_orig_f32 = np.ascontiguousarray(vec_orig, dtype=_dt) beam_vals_f32 = np.ascontiguousarray(beam_vals, dtype=_dt) if mp_stacked is not None: tod_arr = np.zeros((C, B), dtype=np.float64) if use_nearest: _gather_accum_nearest_jit( vec_orig_f32, axes, cos_a, sin_a, ax_pts, cos_p, sin_p, nside, mp_stacked, beam_vals_f32, B, S, tod_arr, c_q, c_u, float(z_skip_threshold), ) else: _gather_accum_fused_jit( vec_orig_f32, axes, cos_a, sin_a, ax_pts, cos_p, sin_p, nside, mp_stacked, beam_vals_f32, B, S, tod_arr, c_q, c_u, float(z_skip_threshold), ) return {comp: tod_arr[i].astype(_dt) for i, comp in enumerate(comp_indices)} # ── Fallback: healpy-based gather when mp_stacked is not provided. # Not on the production hot path — materialises the (B, S, 3) rotated # vector buffer via the batch Rodrigues kernel. vec_rot = _recenter_and_rotate(vec_orig_f32, rot_vecs, phi_b, theta_b, psis_b) theta_flat, phi_flat = hp.vec2ang(vec_rot.reshape(-1, 3).astype(np.float64)) pixels, weights = get_interp_weights_numba(nside, theta_flat, phi_flat) mp_gathered = np.stack([mp[c][pixels] for c in comp_indices]) mp_flat = np.einsum("ckn,kn->cn", mp_gathered, weights) tod_chunk = mp_flat.reshape(C, B, S) @ beam_vals_f32 return {comp: tod_chunk[i].astype(_dt) for i, comp in enumerate(comp_indices)}