"""
Numba JIT replacements for HEALPix RING-scheme helper routines.
These functions mirror the HEALPix C++ internals and are designed to be
called from within parallel Numba kernels.
_ring_above_jit — scalar ring_above helper (nopython, no parallel).
_ring_info_jit — scalar ring layout helper: (n_pix, first_pix, phi0, dphi).
_ring_z_jit — scalar ring centre z = cos(theta) helper.
_get_interp_weights_jit — parallel (prange over N) replacement for
hp.get_interp_weights; mirrors the HEALPix C++
get_interpol algorithm exactly.
get_interp_weights_numba — public wrapper; drop-in replacement for hp.get_interp_weights.
_ring_interp_single_jit — bilinear neighbour lookup for one unit-vector query;
θ weights are computed via acos of the ring and query z (as in hp.get_interp_weights).
_ring_interp_with_angles_jit — same as _ring_interp_single_jit but also returns
(z_n, phi_n) of each of the 4 neighbours for
callers that need the neighbour sky positions
(e.g. the spin-2 Q/U frame-rotation kernel).
_pix2ang_ring_jit — scalar (theta, phi) from RING pixel index (nopython).
_pix2ang_ring_batch — parallel batch kernel over an array of pixel indices.
pix2ang_numba — public wrapper; drop-in for hp.pix2ang(nest=False).
_ang2pix_ring_jit — scalar nearest-centre (theta, phi) → RING pixel index (nopython).
_query_disc_jit — nopython query_disc: returns int64 array of RING pixel
indices within a disc, callable from inside JIT kernels.
_query_disc_into_jit — same search as _query_disc_jit, written into a caller-supplied buffer.
query_disc_numba — public wrapper; drop-in for hp.query_disc(nest=False).
_gather_ring_stencil_jit — fast Keys/Catmull-Rom stencil gather via ring walk.
Replaces _query_disc_into_jit in the bicubic hot loop,
eliminating the ~9 acos calls per (b,s) element.
"""
import math
import numpy as np
import numba
# Module-level float64 constants. Numba treats module globals as compile-time
# constants, so these are frozen into every JIT-compiled kernel below.
_TWO_PI = math.tau  # one full turn; bit-identical to 2.0 * math.pi
_INV_TWO_PI = 1.0 / _TWO_PI
# z = cos(theta) of the boundary between the HEALPix polar caps and the
# equatorial belt.
_TWO_THIRDS = 2.0 / 3.0
# ── HEALPix RING-scheme helpers (nopython, no parallel) ───────────────────────
# These three functions mirror the HEALPix C++ internals for get_interpol.
# They must NOT carry parallel=True because they are called from within a
# prange body inside _get_interp_weights_jit.
@numba.jit(nopython=True, cache=True)
def _ring_above_jit(nside, z):
    """
    Return the 1-based RING index of the last ring whose centre z exceeds *z*.

    Port of ring_above() from healpix_base.cc. The result is 0 for points
    north of ring 1 and 4*nside-1 for points south of ring 4*nside-1; callers
    clamp as they need.
    """
    abs_z = abs(z)
    if abs_z > _TWO_THIRDS:
        # Polar caps: ring number from the nearer pole grows like
        # sqrt(1 - |z|); int() acts as floor because the argument is >= 0.
        ring_from_pole = int(nside * math.sqrt(3.0 * (1.0 - abs_z)))
        return 4 * nside - ring_from_pole - 1 if z < 0.0 else ring_from_pole
    # Equatorial belt: rings are uniformly spaced in z.
    return int(nside * (2.0 - 1.5 * z))
@numba.jit(nopython=True, cache=True)
def _ring_info_jit(nside, ir, npix_total):
    """
    Describe the layout of RING-scheme ring *ir* (1-based).

    Parameters
    ----------
    nside : int
    ir : int ring index in 1 .. 4*nside-1
    npix_total : int 12 * nside * nside

    Returns
    -------
    n_pix : int number of pixels in the ring
    first_pix : int RING index of the ring's first pixel
    phi0 : float longitude of the first pixel centre [rad]
    dphi : float longitude spacing between pixel centres [rad]
    """
    if ir < nside:
        # North polar cap: ring ir holds 4*ir pixels, always half-shifted.
        n_pix = 4 * ir
        first_pix = 2 * ir * (ir - 1)
        shifted = True
    elif ir <= 3 * nside:
        # Equatorial belt: constant 4*nside pixels; the half-pixel shift
        # alternates, on when (ir - nside) is even (get_ring_info_small).
        n_pix = 4 * nside
        first_pix = 2 * nside * (nside - 1) + (ir - nside) * n_pix
        shifted = (ir - nside) % 2 == 0
    else:
        # South polar cap: mirror of the north cap, counted from the end.
        ring_from_south = 4 * nside - ir
        n_pix = 4 * ring_from_south
        first_pix = npix_total - 2 * ring_from_south * (ring_from_south + 1)
        shifted = True
    dphi = _TWO_PI / n_pix
    phi0 = 0.5 * dphi if shifted else 0.0
    return n_pix, first_pix, phi0, dphi
@numba.jit(nopython=True, cache=True)
def _ring_z_jit(nside, ir):
    """
    Return z = cos(theta) at the centre of RING-scheme ring *ir* (1-based).

    Equatorial rings (nside <= ir <= 3*nside) are evenly spaced in z. Cap
    rings follow z = 1 - i^2 / (3 nside^2), with i counted from the nearer
    pole and the sign flipped in the southern cap.
    """
    if ir >= nside and ir <= 3 * nside:
        return (2.0 / 3.0) * (2.0 - float(ir) / nside)
    in_north = ir < nside
    i_pole = float(ir) if in_north else float(4 * nside - ir)
    cap_z = 1.0 - i_pole * i_pole / (3.0 * nside * nside)
    return cap_z if in_north else -cap_z
# ── Standalone parallel interp-weights kernel ─────────────────────────────────
@numba.jit(nopython=True, parallel=True, cache=True)
def _get_interp_weights_jit(nside, theta_arr, phi_arr, pix_out, wgt_out):
    """
    Parallel Numba replacement for hp.get_interp_weights (RING scheme).

    Mirrors the HEALPix get_interpol algorithm including the north/south pole
    boundary special cases. Each of the N iterations is fully independent;
    parallelised with prange.

    Phi interpolation uses float-modulo then floor (equivalent to the
    jax_healpy reference) so that points with phi < phi0 wrap correctly.
    Float ``%`` can round a tiny negative dividend up to exactly the ring
    length; that case is folded back onto in-ring pixel 0, so every returned
    pixel index stays inside the ring it was computed for.

    Parameters
    ----------
    nside : int
    theta_arr : (N,) float64 colatitude [rad]
    phi_arr : (N,) float64 longitude [rad]; must hold at least N entries
              (indexed with theta_arr's indices, no length check here)
    pix_out : (4, N) int64 written in place
    wgt_out : (4, N) float64 written in place
    """
    npix_total = 12 * nside * nside
    N = theta_arr.shape[0]
    for i in numba.prange(N):
        theta = theta_arr[i]
        phi = phi_arr[i]
        z = math.cos(theta)
        ir_above = _ring_above_jit(nside, z)
        ir_below = ir_above + 1
        if ir_above == 0:
            # ── North-pole boundary ───────────────────────────────────────
            # Point is north of ring 1. Use ring 1 for all pixel selection;
            # the "above" pair are the pixels opposite (shift +2 of 4) the
            # straddling ring-1 pair.
            na, fpa, phi0a, dphia = _ring_info_jit(nside, 1, npix_total)
            tw = ((phi - phi0a) / dphia) % float(na)
            ip = int(tw)
            frac = tw - ip
            if ip >= na:
                # float % rounded a tiny negative up to exactly na: wrap to 0.
                ip -= na
            ip2 = (ip + 1) % na
            # "below" pixels: the two straddling ring-1 neighbours
            p2 = fpa + ip
            p3 = fpa + ip2
            # "above" pixels: opposite pixels across the pole (fpa is 0 here)
            p0 = fpa + (ip + 2) % na
            p1 = fpa + (ip2 + 2) % na
            # theta weight: theta1 = 0 at the north pole -> w = theta / theta2
            za = _ring_z_jit(nside, 1)
            ta = math.acos(za)
            w_theta = theta / ta
            nf = (1.0 - w_theta) * 0.25  # north_factor, spread equally
            pix_out[0, i] = p0
            pix_out[1, i] = p1
            pix_out[2, i] = p2
            pix_out[3, i] = p3
            wgt_out[0, i] = nf
            wgt_out[1, i] = nf
            wgt_out[2, i] = (1.0 - frac) * w_theta + nf
            wgt_out[3, i] = frac * w_theta + nf
        elif ir_below == 4 * nside:
            # ── South-pole boundary ───────────────────────────────────────
            # Point is south of the last ring (4*nside-1). Use that ring for
            # all pixel selection; the "below" pair are the opposite pixels.
            ir_last = 4 * nside - 1
            na, fpa, phi0a, dphia = _ring_info_jit(nside, ir_last, npix_total)
            tw = ((phi - phi0a) / dphia) % float(na)
            ip = int(tw)
            frac = tw - ip
            if ip >= na:
                # Same float-% rounding guard as above; without it the index
                # would be npix_total (out of bounds).
                ip -= na
            ip2 = (ip + 1) % na
            # "above" pixels: normal ring ir_last neighbours
            p0 = fpa + ip
            p1 = fpa + ip2
            # "below" pixels: opposite pixels in the same 4-pixel last ring
            p2 = (ip + 2) % na + fpa
            p3 = (ip2 + 2) % na + fpa
            # theta weight toward the south pole
            za = _ring_z_jit(nside, ir_last)
            ta = math.acos(za)
            w_theta_south = (theta - ta) / (math.pi - ta)
            sf = w_theta_south * 0.25  # south_factor
            pix_out[0, i] = p0
            pix_out[1, i] = p1
            pix_out[2, i] = p2
            pix_out[3, i] = p3
            wgt_out[0, i] = (1.0 - frac) * (1.0 - w_theta_south) + sf
            wgt_out[1, i] = frac * (1.0 - w_theta_south) + sf
            wgt_out[2, i] = sf
            wgt_out[3, i] = sf
        else:
            # ── Normal case ───────────────────────────────────────────────
            za = _ring_z_jit(nside, ir_above)
            zb = _ring_z_jit(nside, ir_below)
            ta = math.acos(za)
            tb = math.acos(zb)
            w_below = (theta - ta) / (tb - ta)
            w_above = 1.0 - w_below
            # Ring above -> pixels 0, 1
            na, fpa, phi0a, dphia = _ring_info_jit(nside, ir_above, npix_total)
            tw = ((phi - phi0a) / dphia) % float(na)
            iphia = int(tw)
            fphia = tw - iphia
            if iphia >= na:
                iphia -= na
            pix_out[0, i] = fpa + iphia
            pix_out[1, i] = fpa + (iphia + 1) % na
            wgt_out[0, i] = w_above * (1.0 - fphia)
            wgt_out[1, i] = w_above * fphia
            # Ring below -> pixels 2, 3
            nb, fpb, phi0b, dphib = _ring_info_jit(nside, ir_below, npix_total)
            tw = ((phi - phi0b) / dphib) % float(nb)
            iphib = int(tw)
            fphib = tw - iphib
            if iphib >= nb:
                iphib -= nb
            pix_out[2, i] = fpb + iphib
            pix_out[3, i] = fpb + (iphib + 1) % nb
            wgt_out[2, i] = w_below * (1.0 - fphib)
            wgt_out[3, i] = w_below * fphib
[docs]
def get_interp_weights_numba(nside, theta, phi):
    """
    Drop-in Numba replacement for ``hp.get_interp_weights(nside, theta, phi)``.

    Returns ``(pixels, weights)`` with shapes ``(4, N)`` and dtypes ``int64`` /
    ``float64``. Inputs are cast to float64, broadcast against each other
    (e.g. a scalar ``phi`` with an array ``theta``) and ravelled; ``N`` is
    the size of the broadcast shape.

    Raises
    ------
    ValueError
        If ``theta`` and ``phi`` cannot be broadcast to a common shape.
    """
    theta = np.asarray(theta, dtype=np.float64)
    phi = np.asarray(phi, dtype=np.float64)
    if theta.shape != phi.shape:
        # The JIT kernel reads phi_arr[i] for every i < len(theta) without
        # checking phi's length, so shapes must agree before the call.
        # broadcast_arrays raises ValueError for incompatible shapes.
        theta, phi = np.broadcast_arrays(theta, phi)
        # Materialise zero-stride broadcast views as contiguous arrays.
        theta = np.ascontiguousarray(theta)
        phi = np.ascontiguousarray(phi)
    theta = theta.ravel()
    phi = phi.ravel()
    N = theta.shape[0]
    pix_out = np.empty((4, N), dtype=np.int64)
    wgt_out = np.empty((4, N), dtype=np.float64)
    _get_interp_weights_jit(nside, theta, phi, pix_out, wgt_out)
    return pix_out, wgt_out
# ── Single-pixel bilinear interpolation ──────────────────────────────────────
@numba.jit(nopython=True, cache=True)
def _ring_interp_single_jit(nside, z, phi_w, npix_total):
    """Bilinear HEALPix RING neighbour lookup for one unit-vector query.

    Follows the HEALPix C++ ``get_interpol`` algorithm, including the polar
    boundary cases. The θ-linear weight is computed via ``acos`` of the query
    and ring z values. Results therefore agree with ``hp.get_interp_weights``
    up to floating-point rounding in the ring colatitudes.

    A tiny negative ``(phi_w - phi0) / dphi`` can make float ``%`` return
    exactly the ring length. That case is wrapped back to in-ring pixel 0,
    so every returned index stays inside its ring.

    Parameters
    ----------
    nside : int
    z : float cos θ of the query point, clamped to [−1, 1]
    phi_w : float longitude of the query point in [0, 2π)
    npix_total : int 12 * nside * nside (pre-computed by caller)

    Returns
    -------
    p0, p1, p2, p3 : int64 RING pixel indices of the four neighbours
    w0, w1, w2, w3 : float64 bilinear interpolation weights (sum to 1)
    """
    ir_above = _ring_above_jit(nside, z)
    ir_below = ir_above + 1
    if ir_above == 0:
        # ── North-pole boundary ───────────────────────────────────────────────
        na, fpa, phi0a, dphia = _ring_info_jit(nside, 1, npix_total)
        tw = ((phi_w - phi0a) / dphia) % float(na)
        ip_a = int(tw)
        frac = tw - ip_a
        if ip_a >= na:
            # float % rounded a tiny negative up to exactly na: wrap to 0.
            ip_a -= na
        ip_a2 = (ip_a + 1) % na
        # Pixels opposite the straddling pair across the pole (shift 2 of 4).
        p0 = fpa + (ip_a + 2) % na
        p1 = fpa + (ip_a2 + 2) % na
        p2 = fpa + ip_a
        p3 = fpa + ip_a2
        za = _ring_z_jit(nside, 1)
        ta = math.acos(za)
        # Query colatitude θ from clamped z; safe since z in [-1, 1].
        theta = math.acos(z)
        w_theta = theta / ta
        nf = (1.0 - w_theta) * 0.25
        w0 = nf
        w1 = nf
        w2 = (1.0 - frac) * w_theta + nf
        w3 = frac * w_theta + nf
    elif ir_below == 4 * nside:
        # ── South-pole boundary ───────────────────────────────────────────────
        ir_last = 4 * nside - 1
        na, fpa, phi0a, dphia = _ring_info_jit(nside, ir_last, npix_total)
        tw = ((phi_w - phi0a) / dphia) % float(na)
        ip_a = int(tw)
        frac = tw - ip_a
        if ip_a >= na:
            # Without this guard p0 would be npix_total (out of bounds).
            ip_a -= na
        ip_a2 = (ip_a + 1) % na
        p0 = fpa + ip_a
        p1 = fpa + ip_a2
        p2 = (ip_a + 2) % na + fpa
        p3 = (ip_a2 + 2) % na + fpa
        za = _ring_z_jit(nside, ir_last)
        ta = math.acos(za)
        theta = math.acos(z)
        w_theta_south = (theta - ta) / (math.pi - ta)
        sf = w_theta_south * 0.25
        w0 = (1.0 - frac) * (1.0 - w_theta_south) + sf
        w1 = frac * (1.0 - w_theta_south) + sf
        w2 = sf
        w3 = sf
    else:
        # ── Normal case — θ via acos, as in hp.get_interp_weights ─────────────
        za = _ring_z_jit(nside, ir_above)
        zb = _ring_z_jit(nside, ir_below)
        ta = math.acos(za)
        tb = math.acos(zb)
        theta = math.acos(z)
        w_below = (theta - ta) / (tb - ta)
        w_above = 1.0 - w_below
        na, fpa, phi0a, dphia = _ring_info_jit(nside, ir_above, npix_total)
        tw = ((phi_w - phi0a) / dphia) % float(na)
        iphia = int(tw)
        fphia = tw - iphia
        if iphia >= na:
            iphia -= na
        p0 = fpa + iphia
        p1 = fpa + (iphia + 1) % na
        w0 = w_above * (1.0 - fphia)
        w1 = w_above * fphia
        nb, fpb, phi0b, dphib = _ring_info_jit(nside, ir_below, npix_total)
        tw = ((phi_w - phi0b) / dphib) % float(nb)
        iphib = int(tw)
        fphib = tw - iphib
        if iphib >= nb:
            iphib -= nb
        p2 = fpb + iphib
        p3 = fpb + (iphib + 1) % nb
        w2 = w_below * (1.0 - fphib)
        w3 = w_below * fphib
    return p0, p1, p2, p3, w0, w1, w2, w3
@numba.jit(nopython=True, cache=True)
def _ring_interp_with_angles_jit(nside, z, phi_w, npix_total):
    """Bilinear HEALPix RING neighbour lookup, returning neighbour angles too.

    Same math as :func:`_ring_interp_single_jit`, but additionally returns
    ``(z_n, phi_n)`` for each of the four neighbours. Intended for callers
    (e.g. the spin-2 Q/U kernel) that need the neighbour sky positions and
    would otherwise have to re-derive them from the pixel index.

    A tiny negative ``(phi_w - phi0) / dphi`` can make float ``%`` return
    exactly the ring length. That case is wrapped back to in-ring pixel 0,
    so pixel indices stay inside their ring. Each ``phi_n`` then lies in
    ``[phi0, phi0 + 2π)``.

    Parameters
    ----------
    nside : int
    z : float cos θ of the query point, clamped to [−1, 1]
    phi_w : float longitude of the query point in [0, 2π)
    npix_total : int 12 * nside * nside (pre-computed by caller)

    Returns
    -------
    p0, p1, p2, p3 : int64 RING pixel indices of the four neighbours
    w0, w1, w2, w3 : float64 bilinear interpolation weights (sum to 1)
    z_n0, z_n1, z_n2, z_n3 : float64 cos θ of each neighbour's ring
    phi_n0, phi_n1, phi_n2, phi_n3 : float64 φ of each neighbour [rad]
    """
    ir_above = _ring_above_jit(nside, z)
    ir_below = ir_above + 1
    if ir_above == 0:
        # ── North-pole boundary ───────────────────────────────────────────────
        na, fpa, phi0a, dphia = _ring_info_jit(nside, 1, npix_total)
        tw = ((phi_w - phi0a) / dphia) % float(na)
        ip_a = int(tw)
        frac = tw - ip_a
        if ip_a >= na:
            # float % rounded a tiny negative up to exactly na: wrap to 0.
            ip_a -= na
        ip_a2 = (ip_a + 1) % na
        p0 = fpa + (ip_a + 2) % na
        p1 = fpa + (ip_a2 + 2) % na
        p2 = fpa + ip_a
        p3 = fpa + ip_a2
        za = _ring_z_jit(nside, 1)
        ta = math.acos(za)
        theta = math.acos(z)
        w_theta = theta / ta
        nf = (1.0 - w_theta) * 0.25
        w0 = nf
        w1 = nf
        w2 = (1.0 - frac) * w_theta + nf
        w3 = frac * w_theta + nf
        # All four neighbours live on ring 1.
        z_n0 = za
        z_n1 = za
        z_n2 = za
        z_n3 = za
        phi_n0 = phi0a + ((ip_a + 2) % na) * dphia
        phi_n1 = phi0a + ((ip_a2 + 2) % na) * dphia
        phi_n2 = phi0a + ip_a * dphia
        phi_n3 = phi0a + ip_a2 * dphia
    elif ir_below == 4 * nside:
        # ── South-pole boundary ───────────────────────────────────────────────
        ir_last = 4 * nside - 1
        na, fpa, phi0a, dphia = _ring_info_jit(nside, ir_last, npix_total)
        tw = ((phi_w - phi0a) / dphia) % float(na)
        ip_a = int(tw)
        frac = tw - ip_a
        if ip_a >= na:
            # Without this guard p0 would be npix_total (out of bounds).
            ip_a -= na
        ip_a2 = (ip_a + 1) % na
        p0 = fpa + ip_a
        p1 = fpa + ip_a2
        p2 = fpa + (ip_a + 2) % na
        p3 = fpa + (ip_a2 + 2) % na
        za = _ring_z_jit(nside, ir_last)
        ta = math.acos(za)
        theta = math.acos(z)
        w_theta_south = (theta - ta) / (math.pi - ta)
        sf = w_theta_south * 0.25
        w0 = (1.0 - frac) * (1.0 - w_theta_south) + sf
        w1 = frac * (1.0 - w_theta_south) + sf
        w2 = sf
        w3 = sf
        # All four neighbours live on the last ring.
        z_n0 = za
        z_n1 = za
        z_n2 = za
        z_n3 = za
        phi_n0 = phi0a + ip_a * dphia
        phi_n1 = phi0a + ip_a2 * dphia
        phi_n2 = phi0a + ((ip_a + 2) % na) * dphia
        phi_n3 = phi0a + ((ip_a2 + 2) % na) * dphia
    else:
        # ── Normal case — θ via acos, as in hp.get_interp_weights ─────────────
        za = _ring_z_jit(nside, ir_above)
        zb = _ring_z_jit(nside, ir_below)
        ta = math.acos(za)
        tb = math.acos(zb)
        theta = math.acos(z)
        w_below = (theta - ta) / (tb - ta)
        w_above = 1.0 - w_below
        na, fpa, phi0a, dphia = _ring_info_jit(nside, ir_above, npix_total)
        tw = ((phi_w - phi0a) / dphia) % float(na)
        iphia = int(tw)
        fphia = tw - iphia
        if iphia >= na:
            iphia -= na
        iphia1 = (iphia + 1) % na
        p0 = fpa + iphia
        p1 = fpa + iphia1
        w0 = w_above * (1.0 - fphia)
        w1 = w_above * fphia
        nb, fpb, phi0b, dphib = _ring_info_jit(nside, ir_below, npix_total)
        tw = ((phi_w - phi0b) / dphib) % float(nb)
        iphib = int(tw)
        fphib = tw - iphib
        if iphib >= nb:
            iphib -= nb
        iphib1 = (iphib + 1) % nb
        p2 = fpb + iphib
        p3 = fpb + iphib1
        w2 = w_below * (1.0 - fphib)
        w3 = w_below * fphib
        z_n0 = za
        z_n1 = za
        z_n2 = zb
        z_n3 = zb
        phi_n0 = phi0a + iphia * dphia
        phi_n1 = phi0a + iphia1 * dphia
        phi_n2 = phi0b + iphib * dphib
        phi_n3 = phi0b + iphib1 * dphib
    return (
        p0,
        p1,
        p2,
        p3,
        w0,
        w1,
        w2,
        w3,
        z_n0,
        z_n1,
        z_n2,
        z_n3,
        phi_n0,
        phi_n1,
        phi_n2,
        phi_n3,
    )
# ── HEALPix pix2ang (RING scheme, scalar) ────────────────────────────────────
@numba.jit(nopython=True, cache=True)
def _pix2ang_ring_jit(nside, pix):
    """
    Centre angles of one RING-scheme pixel.

    Scalar counterpart of hp.pix2ang(nside, pix, nest=False). The ring and
    in-ring offset are recovered per zone (north cap, equatorial belt, south
    cap). The longitude uses _ring_info_jit so the half-pixel shift convention
    matches the rest of this module.

    Parameters
    ----------
    nside : int
    pix : int RING pixel index

    Returns
    -------
    theta : float colatitude [rad]
    phi : float longitude [rad]
    """
    npix = 12 * nside * nside
    n_cap_pix = 2 * nside * (nside - 1)  # pixel count of either polar cap
    if pix < n_cap_pix:
        # North cap: ring r (1-based) begins at pixel 2*r*(r-1).
        ring = int(0.5 * (1.0 + math.sqrt(1.0 + 2.0 * pix)))
        offset = pix - 2 * ring * (ring - 1)
    elif pix >= npix - n_cap_pix:
        # South cap: count backwards from the last pixel (npix-1 -> 0).
        from_south = npix - pix - 1
        ring_s = int(0.5 * (1.0 + math.sqrt(1.0 + 2.0 * from_south)))
        ring = 4 * nside - ring_s
        offset = pix - (npix - 2 * ring_s * (ring_s + 1))
    else:
        # Equatorial belt: every ring holds exactly 4*nside pixels.
        belt_index = pix - n_cap_pix
        ring_width = 4 * nside
        ring = belt_index // ring_width + nside
        offset = belt_index % ring_width
    _n_ring, _first, phi_start, phi_step = _ring_info_jit(nside, ring, npix)
    theta = math.acos(_ring_z_jit(nside, ring))
    return theta, phi_start + offset * phi_step
@numba.jit(nopython=True, parallel=True, cache=True)
def _pix2ang_ring_batch(nside, pix_arr, theta_out, phi_out):
    """Write the centre (theta, phi) of RING pixel pix_arr[k] into theta_out[k] / phi_out[k], in parallel."""
    n_in = pix_arr.shape[0]
    for k in numba.prange(n_in):
        th, ph = _pix2ang_ring_jit(nside, pix_arr[k])
        theta_out[k] = th
        phi_out[k] = ph
[docs]
def pix2ang_numba(nside, pix, nest=False):
    """
    Numba-backed stand-in for ``hp.pix2ang(nside, pix, nest=False)``.

    *pix* is converted to a flat int64 array. The result is a pair of float64
    arrays ``(theta, phi)``, each with one entry per input pixel. Only the
    RING scheme is implemented; ``nest=True`` raises ``ValueError``.
    """
    if nest:
        raise ValueError("pix2ang_numba only supports nest=False (RING scheme)")
    pixels = np.asarray(pix, dtype=np.int64).ravel()
    theta = np.empty(pixels.shape[0], dtype=np.float64)
    phi = np.empty_like(theta)
    _pix2ang_ring_batch(nside, pixels, theta, phi)
    return theta, phi
# ── HEALPix ang2pix (RING scheme, scalar) ────────────────────────────────────
@numba.jit(nopython=True, cache=True)
def _ang2pix_ring_jit(nside, theta, phi):
    """
    Nearest-centre RING-scheme pixel index for (theta, phi) [rad].

    Returns the pixel, among a small candidate set, whose centre is
    geometrically closest to (theta, phi), i.e. has the largest cos of the
    angular distance.

    NOTE(review): HEALPix pixel boundaries are not the great-circle bisectors
    between neighbouring pixel centres. This nearest-centre rule should agree
    with hp.ang2pix(nside, theta, phi, nest=False) away from pixel edges, but
    may choose a neighbouring pixel for points very close to an edge. Confirm
    against healpy before relying on exact equality.

    Algorithm
    ---------
    1. Use _ring_above_jit to find the ring just north of the query. Clamp it
       so that both it and the next ring south are valid (1 .. 4*nside-1).
    2. In each of the two rings, ip = int(phi * n_pix / (2π)) % n_pix is the
       φ-cell [ip·dphi, (ip+1)·dphi) containing phi. The nearest centre in
       that ring is then ip (shifted ring, centres at (k+½)·dphi) or one of
       ip / ip+1 (unshifted ring, centres at k·dphi).
    3. Evaluate ip and ip+1 in both rings (4 candidates) and keep the one with
       the largest cos(angular distance).
    """
    npix_total = 12 * nside * nside
    z = math.cos(theta)
    sin_th = math.sin(theta)
    phi_w = phi % _TWO_PI  # wrap into [0, 2π]
    # ── Two candidate global rings bracketing z ───────────────────────────────
    ir_above = _ring_above_jit(nside, z)
    # After this clamp both ir_above and ir_above + 1 lie in [1, 4*nside - 1],
    # so no per-ring validity check is needed below.
    if ir_above < 1:
        ir_above = 1
    elif ir_above > 4 * nside - 2:
        ir_above = 4 * nside - 2
    ir_below = ir_above + 1
    # ── For each candidate ring find the best-phi pixel ───────────────────────
    best_pix = -1
    best_cos = -2.0  # maximise cos(angular_dist) == minimise distance
    for ir_g in (ir_above, ir_below):
        n_pix, first_pix, phi0, dphi = _ring_info_jit(nside, ir_g, npix_total)
        z_c = _ring_z_jit(nside, ir_g)
        sin_z_c = math.sqrt(max(0.0, 1.0 - z_c * z_c))
        ip_base = int(phi_w * n_pix / _TWO_PI) % n_pix
        for ip_try in (ip_base, (ip_base + 1) % n_pix):
            phi_c = phi0 + ip_try * dphi
            cos_d = sin_th * sin_z_c * math.cos(phi_w - phi_c) + z * z_c
            if cos_d > best_cos:
                best_cos = cos_d
                best_pix = first_pix + ip_try
    return best_pix
# ── HEALPix query_disc (RING scheme) ─────────────────────────────────────────
@numba.jit(nopython=True, cache=True)
def _query_disc_jit(nside, theta_q, phi_q, radius_rad, inclusive):
    """
    Pixel indices within angular radius *radius_rad* of (theta_q, phi_q).

    Intended as a nopython counterpart of
    hp.query_disc(nside, vec, radius, inclusive=..., nest=False).
    Returns a 1-D int64 array of RING pixel indices (order not guaranteed).
    This function is nopython-safe and can be called from within prange bodies
    or other JIT-compiled kernels.

    Algorithm
    ---------
    1. When inclusive=True, widen *radius_rad* by an approximate maximum pixel
       angular radius, sqrt(π / (3·nside²)). Selection is always a
       pixel-centre-inside-(widened)-disc test. For inclusive=True this
       approximates healpy's overlap test rather than reproducing it.
    2. If the (widened) radius is >= π, return every pixel.
    3. Find the ring band [ir_min, ir_max] whose z-centres lie inside the
       disc's colatitude range, by calling _ring_above_jit on its limits.
    4. For each ring, compute the phi half-width of the disc cross-section
       (from the spherical-cap intersection formula) and map it to a pixel-
       index range; handle wrap-around via modulo.

    Parameters
    ----------
    nside : int
    theta_q : float disc-centre colatitude [rad]
    phi_q : float disc-centre longitude [rad]
    radius_rad : float disc radius [rad]
    inclusive : bool if True enlarge radius by ~max pixel radius

    Returns
    -------
    result : (M,) int64 RING pixel indices inside the disc (a view of a
             larger scratch array, trimmed to the M pixels found)
    """
    npix_total = 12 * nside * nside
    z_q = math.cos(theta_q)
    sin_th_q = math.sqrt(max(0.0, 1.0 - z_q * z_q))
    # Approximate max pixel angular radius: sqrt(π / (3·nside²))
    if inclusive:
        search_rad = radius_rad + math.sqrt(math.pi / (3.0 * nside * nside))
    else:
        search_rad = radius_rad
    if search_rad >= math.pi:
        # Disc covers the whole sphere.
        return np.arange(npix_total, dtype=np.int64)
    cos_search = math.cos(search_rad)
    # Ring-index band whose z-centres intersect the widened disc:
    # first ring south of theta_lo .. last ring north of theta_hi.
    theta_lo = max(0.0, theta_q - search_rad)
    theta_hi = min(math.pi, theta_q + search_rad)
    ir_min = max(1, _ring_above_jit(nside, math.cos(theta_lo)) + 1)
    ir_max = min(4 * nside - 1, _ring_above_jit(nside, math.cos(theta_hi)))
    if ir_min > ir_max:
        return np.empty(0, dtype=np.int64)
    # Conservative upper bound: every ring in the band has at most 4*nside pixels.
    max_pix = 4 * nside * (ir_max - ir_min + 1) + 8
    if max_pix > npix_total:
        max_pix = npix_total
    result = np.empty(max_pix, dtype=np.int64)
    count = 0
    for ir in range(ir_min, ir_max + 1):
        z_r = _ring_z_jit(nside, ir)
        sin_th_r = math.sqrt(max(0.0, 1.0 - z_r * z_r))
        n_p, fp, phi0, dphi = _ring_info_jit(nside, ir, npix_total)
        denom = sin_th_q * sin_th_r
        if denom < 1e-12:
            # Disc centre (numerically) at a pole: every pixel of a ring in
            # the band is at the same distance, so take the whole ring.
            for ip in range(n_p):
                result[count] = fp + ip
                count += 1
            continue
        # cos(dphi_half) from the spherical-cap intersection formula:
        # cos(d) = sin(θ_q)·sin(θ_r)·cos(Δφ) + cos(θ_q)·cos(θ_r) = cos(search_rad)
        x = (cos_search - z_q * z_r) / denom
        if x > 1.0:
            continue  # ring too far from disc centre
        if x <= -1.0:
            for ip in range(n_p):  # entire ring inside disc
                result[count] = fp + ip
                count += 1
            continue
        dphi_half = math.acos(x)
        # Pixel-index range within ring (may be negative or > n_p, handled by %).
        # Exact ceil/floor with no tolerance: a centre lying exactly on the
        # disc edge can be dropped by rounding. NOTE(review): the buffer
        # variant _query_disc_into_jit widens this range by 1e-10, so the two
        # can disagree on such boundary pixels — confirm which is intended.
        ip_lo = int(math.ceil((phi_q - dphi_half - phi0) / dphi))
        ip_hi = int(math.floor((phi_q + dphi_half - phi0) / dphi))
        if ip_hi - ip_lo + 1 >= n_p:
            # Range spans the whole ring; emit each pixel once.
            for ip in range(n_p):
                result[count] = fp + ip
                count += 1
        else:
            for ip_idx in range(ip_lo, ip_hi + 1):
                result[count] = fp + ip_idx % n_p
                count += 1
    return result[:count]
@numba.jit(nopython=True, cache=True)
def _query_disc_into_jit(nside, theta_q, phi_q, radius_rad, inclusive, out_buf):
    """
    Like _query_disc_jit but writes pixel indices into a pre-allocated buffer
    instead of allocating a new array. Returns the count M of pixels found.

    Differences from _query_disc_jit: the per-ring phi index range is widened
    by ±1e-10 (in units of pixel spacing) before ceil/floor. A pixel centre
    lying exactly on the disc edge is therefore included here but may be
    dropped there.

    Capacity: no capacity check is performed on out_buf. It must hold every
    pixel written — npix_total entries when the (widened) radius is >= π,
    otherwise at most 4*nside entries per ring in the selected band.

    Parameters
    ----------
    nside : int
    theta_q : float disc-centre colatitude [rad]
    phi_q : float disc-centre longitude [rad]
    radius_rad : float disc radius [rad]
    inclusive : bool if True enlarge radius by ~max pixel radius
    out_buf : (max_M,) int64 caller-allocated scratch buffer

    Returns
    -------
    M : int number of pixels written into out_buf[:M]
    """
    npix_total = 12 * nside * nside
    z_q = math.cos(theta_q)
    sin_th_q = math.sqrt(max(0.0, 1.0 - z_q * z_q))
    # Approximate max pixel angular radius: sqrt(π / (3·nside²))
    if inclusive:
        search_rad = radius_rad + math.sqrt(math.pi / (3.0 * nside * nside))
    else:
        search_rad = radius_rad
    if search_rad >= math.pi:
        # Whole sphere: out_buf must have room for npix_total entries.
        for i in range(npix_total):
            out_buf[i] = i
        return npix_total
    cos_search = math.cos(search_rad)
    # Ring band whose z-centres fall inside the disc's colatitude range.
    theta_lo = max(0.0, theta_q - search_rad)
    theta_hi = min(math.pi, theta_q + search_rad)
    ir_min = max(1, _ring_above_jit(nside, math.cos(theta_lo)) + 1)
    ir_max = min(4 * nside - 1, _ring_above_jit(nside, math.cos(theta_hi)))
    if ir_min > ir_max:
        return 0
    count = 0
    for ir in range(ir_min, ir_max + 1):
        z_r = _ring_z_jit(nside, ir)
        sin_th_r = math.sqrt(max(0.0, 1.0 - z_r * z_r))
        n_p, fp, phi0, dphi = _ring_info_jit(nside, ir, npix_total)
        denom = sin_th_q * sin_th_r
        if denom < 1e-12:
            # Disc centre (numerically) at a pole: take the whole ring.
            for ip in range(n_p):
                out_buf[count] = fp + ip
                count += 1
            continue
        # Spherical-cap intersection: cos(Δφ_half) for this ring.
        x = (cos_search - z_q * z_r) / denom
        if x > 1.0:
            continue  # ring misses the disc
        if x <= -1.0:
            # Entire ring inside the disc.
            for ip in range(n_p):
                out_buf[count] = fp + ip
                count += 1
            continue
        dphi_half = math.acos(x)
        # ±1e-10 keeps centres that sit exactly on the edge despite rounding.
        ip_lo = int(math.ceil((phi_q - dphi_half - phi0) / dphi - 1e-10))
        ip_hi = int(math.floor((phi_q + dphi_half - phi0) / dphi + 1e-10))
        if ip_hi - ip_lo + 1 >= n_p:
            # Range spans the whole ring; emit each pixel once.
            for ip in range(n_p):
                out_buf[count] = fp + ip
                count += 1
        else:
            for ip_idx in range(ip_lo, ip_hi + 1):
                out_buf[count] = fp + ip_idx % n_p
                count += 1
    return count
@numba.jit(nopython=True, cache=True)
def _gather_ring_stencil_jit(nside, vz, ph, out_buf, z_buf, phi_buf):
    """
    Gather RING pixel indices for the Keys/Catmull-Rom bicubic stencil,
    and simultaneously populate z_buf / phi_buf with (cos θ, φ) for each
    gathered pixel — avoiding a second pass over the pixel index in the
    hot loop.

    While building the stencil this function already has ring geometry in hand
    (from _ring_info_jit and _ring_z_jit). Returning (z, phi) alongside the
    pixel index removes ~40 redundant function-call chains per (b, s) element:
    each chain would otherwise re-run _ring_info_jit + _ring_z_jit + branch
    logic to recover the same values from the pixel index.

    Replaces _query_disc_into_jit in the bicubic hot loop, eliminating the
    ~9 acos + ~4 cos calls that dominate disc-search cost.

    Geometry
    --------
    Let h_pix = sqrt(4π / npix). HEALPix ring spacing is ~0.65 h_pix in the
    equatorial belt and ~0.80 h_pix in the polar caps. The Keys north–south
    support is |yi| < 2 (h_pix).

    The query lies *between* ring ir_c = _ring_above_jit(vz) and ring ir_c+1,
    not on ir_c. Let δ be the distance to ir_c and d the distance to ir_c+1.
    - Ring ir_c-3 is 1.95 + δ h_pix to the north; ring ir_c+4 is 1.95 + d h_pix
      to the south (at the minimum 0.65 spacing). Each can fall inside the
      support when the query hugs the corresponding ring.
    - Rings ir_c-4 and ir_c+5 are always at least 2.6 h_pix away.
    The ring window is therefore the asymmetric range ir_c-3 … ir_c+4.

    The φ pixel step (dphi · sin θ) satisfies:
      · equatorial belt: step ≥ 1.14 h_pix (at the equatorial–polar boundary)
      · polar cap: step between ~1.14 h_pix and √(π/2) h_pix ≈ 1.25 h_pix
    The nearest pixel is ≤ ½ step away, so the third pixel on either side is
    ≥ 2.5 steps ≥ 2.86 h_pix away. Hence ±2 phi pixels per ring cover the
    east–west support |xi| < 2.

    Stencil: 8 rings × 5 phi pixels = 40 candidates maximum. Rings with
    n_p ≤ 4 (ir = 1 and ir = 4*nside-1; every ring when nside = 1) include
    all their pixels directly, to avoid duplicate indices from modulo
    wrapping.

    Parameters
    ----------
    nside : int
    vz : float64 cos(θ) of the query point (= vec_rot[b,s,2])
    ph : float64 longitude [rad], in [0, 2π)
    out_buf : (≥ 45,) int64 caller-allocated pixel index buffer
    z_buf : (≥ 45,) float64 caller-allocated cos(θ) buffer
    phi_buf : (≥ 45,) float64 caller-allocated φ [rad] buffer
    (no capacity check is done here; 40 entries is the maximum written)

    Returns
    -------
    M : int number of entries written into out_buf[:M] / z_buf[:M] / phi_buf[:M]
    """
    npix_total = 12 * nside * nside
    n_rings = 4 * nside - 1
    ir_center = _ring_above_jit(nside, vz)
    # _ring_above_jit returns 0 north of ring 1; clamp to [1, 4n-1].
    if ir_center < 1:
        ir_center = 1
    elif ir_center > n_rings:
        ir_center = n_rings
    # The query sits between ir_center and ir_center + 1, so the window is
    # 3 rings north of ir_center and 4 rings south of it (see Geometry).
    # Extra candidates outside the support simply receive zero Keys weight.
    ir_lo = max(1, ir_center - 3)
    ir_hi = min(n_rings, ir_center + 4)
    count = 0
    for ir in range(ir_lo, ir_hi + 1):
        n_p, fp, phi0, dphi = _ring_info_jit(nside, ir, npix_total)
        z_ring = _ring_z_jit(nside, ir)  # cos(θ) for this ring — computed once
        if n_p <= 4:
            # 4-pixel ring (first/last ring): ±2 wrapping would repeat
            # pixel indices, so include all of them directly.
            for ip in range(n_p):
                out_buf[count] = fp + ip
                z_buf[count] = z_ring
                phi_buf[count] = phi0 + ip * dphi
                count += 1
        else:
            # Nearest pixel in phi, then ±2 neighbours with wrap-around.
            ip_center = int(math.floor((ph - phi0) / dphi + 0.5)) % n_p
            for dip in range(-2, 3):
                ip_in = (ip_center + dip) % n_p
                out_buf[count] = fp + ip_in
                z_buf[count] = z_ring
                phi_buf[count] = phi0 + ip_in * dphi
                count += 1
    return count
[docs]
def query_disc_numba(nside, vec, radius_rad, inclusive=True, nest=False):
    """
    Drop-in Numba replacement for ``hp.query_disc(nside, vec, radius, ...)``.

    Parameters
    ----------
    nside : int
    vec : array-like, shape (3,) vector pointing to the disc centre; only its
          direction is used, so it need not be exactly unit length
    radius_rad : float disc radius [rad]
    inclusive : bool if True, pixels that partially overlap are included
                (default True)
    nest : bool only nest=False (RING) is supported

    Returns
    -------
    pix : (M,) int64 RING pixel indices inside the disc (order not guaranteed)

    Raises
    ------
    ValueError
        If ``nest`` is true.
    """
    if nest:
        raise ValueError("query_disc_numba only supports nest=False (RING scheme)")
    v = np.asarray(vec, dtype=np.float64).ravel()
    x, y, z = float(v[0]), float(v[1]), float(v[2])
    # atan2(rho, z) gives the colatitude without assuming |vec| == 1. Unlike
    # acos(z), it keeps full precision for directions close to the poles.
    theta_q = math.atan2(math.hypot(x, y), z)
    phi_q = math.atan2(y, x) % _TWO_PI
    return _query_disc_jit(nside, theta_q, phi_q, float(radius_rad), bool(inclusive))