"""
planet.py — Keplerian planet orbit and pixel-level transit geometry for sajax.
This module is a standalone companion to sajax/core.py. It can be used
independently to compute transit light curves, or integrated with sajax
via ``build_combined_model`` / ``compute_combined_light_curve`` (defined in
core.py) to correctly model active-region crossing events — i.e. cases where
the planet occults a starspot or facula during transit.
Architecture
------------
The module is intentionally *geometry-only*: it computes where the planet is
on the sky at each epoch and which stellar-disc pixels it occults. The flux
integration (limb darkening, active-region weighting) is handled by the
existing sajax machinery in core.py. This clean separation means that the
transit model inherits sajax's full limb-darkening parametrisation
automatically — no extra parameters are required.
Orbital convention (Winn 2010 / Eastman et al. 2013)
------------------------------------------------------
X — sky-plane east-west (positive east)
Y — sky-plane north-south (positive north, foreshortened by cos i)
Z — line-of-sight toward observer (Z > 0 ⟹ planet in front of star)
All sky positions are in units of the stellar radius R*.
Minimum parameter set
---------------------
t0 : mid-transit epoch [days]
period : orbital period [days]
a_over_rstar : semimajor axis / R* (dimensionless)
May be derived from stellar density via
``stellar_density_to_a_over_rstar()``.
inclination : orbital inclination [rad] (π/2 rad = 90° = exactly edge-on)
ecc : orbital eccentricity [0, 1)
omega_peri : argument of periastron [rad]
(ω = 0 → periapsis at the ascending node;
ω = π/2 (90°) → periapsis at inferior conjunction,
i.e. the planet transits at periastron)
k : planet-to-star radius ratio Rp / R*
Limb darkening
--------------
The same LDC law stored in the sajax model dict is applied automatically
to occulted pixels — no separate transit LDC parameters are required.
Public API
----------
``_kepler(M, ecc)`` — differentiable Kepler solver
``planet_sky_position(...)`` — single-epoch sky coords (X, Y, Z)
``compute_planet_sky_positions(...)`` — vectorised over an array of times
``_compute_planet_mask(...)`` — per-pixel occultation mask
``build_transit_model(...)`` — pre-compute positions for all times
``stellar_density_to_a_over_rstar()`` — unit-conversion convenience
``a_over_rstar_to_stellar_density()`` — inverse of the conversion above
"""
from __future__ import annotations
import numpy as np
import jax.numpy as jnp
from jax import vmap, nn as jax_nn
# ---------------------------------------------------------------------------
# 1. Kepler's equation solver (differentiable, JIT-safe)
# ---------------------------------------------------------------------------
def _kepler(M: jnp.ndarray, ecc: jnp.ndarray) -> tuple[jnp.ndarray, jnp.ndarray]:
"""
Solve Kepler's equation M = E - e sin E for the eccentric anomaly E,
then convert to the true anomaly f and return (sin f, cos f).
Implementation details
~~~~~~~~~~~~~~~~~~~~~~
* Symmetry fold: M is mapped into [0, π) then restored afterwards,
which halves the domain and removes sign ambiguity.
* Starter: E0 = M + e sin M (good for e ≲ 0.5; adequate for e < 0.9).
* Refinement: 6 Halley iterations (3rd-order convergence) — the residual
drops from O(e²) to < 1e-15 in ≤ 4 steps even at e = 0.95.
* All operations are JAX primitives. The fixed unrolled iteration graph
is fully differentiable via JAX's default automatic differentiation.
No ``custom_jvp`` hook is needed; the iteration count is small enough
that the unrolled gradient does not cause numerical issues.
Parameters
----------
M : mean anomaly [rad] — scalar or array
ecc : orbital eccentricity [0, 1) — scalar
Returns
-------
sinf, cosf : sin and cos of the true anomaly (same shape as M)
"""
# Wrap into [0, 2π) and exploit the symmetry sin(2π − M) = −sin(M)
M = M % (2.0 * jnp.pi)
flip = M > jnp.pi
M_ = jnp.where(flip, 2.0 * jnp.pi - M, M) # now in [0, π)
# Initial guess
E = M_ + ecc * jnp.sin(M_)
# Halley's method: f = E − e sin E − M, f′ = 1 − e cos E, f′′ = e sin E
# ΔE = −f / (f′ − f·f′′ / (2 f′)) = −f·f′ / (f′² − f·f′′/2)
for _ in range(6):
sE = jnp.sin(E)
cE = jnp.cos(E)
f = E - ecc * sE - M_
fp = 1.0 - ecc * cE
fpp = ecc * sE
E = E - f * fp / (fp * fp - 0.5 * f * fpp)
# Restore the original half-plane
E = jnp.where(flip, 2.0 * jnp.pi - E, E)
# Eccentric to true anomaly via the standard formulae
cE = jnp.cos(E)
sE = jnp.sin(E)
denom = 1.0 - ecc * cE
sinf = jnp.sqrt(jnp.maximum(1.0 - ecc ** 2, 0.0)) * sE / denom
cosf = (cE - ecc) / denom
return sinf, cosf
# ---------------------------------------------------------------------------
# 2. Sky-plane position of the planet at a single epoch
# ---------------------------------------------------------------------------
[docs]
def planet_sky_position(
    time: jnp.ndarray,
    t0: float,
    period: float,
    a_over_rstar: float,
    inclination: float,
    ecc: float,
    omega_peri: float,
) -> tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]:
    """
    Return the planet's position (X, Y, Z) relative to the star, in units
    of the stellar radius R*, at the given epoch.

    Parameters
    ----------
    time : observation epoch [same units as t0 / period, e.g. days]
    t0 : mid-transit epoch (inferior conjunction)
    period : orbital period
    a_over_rstar : semimajor axis / R* (dimensionless)
    inclination : orbital inclination [rad] (π/2 = edge-on)
    ecc : eccentricity [0, 1)
    omega_peri : argument of periastron [rad], measured from the
        ascending node to periapsis

    Returns
    -------
    X, Y, Z : coordinates in units of R*
        X = −r cos(ω + f)
        Y =  r sin(ω + f) cos i
        Z =  r sin(ω + f) sin i   (Z > 0 ⟹ planet in front of the star)

    Notes
    -----
    The projected separation from the stellar centre is sqrt(X^2 + Y^2);
    the planet and stellar discs overlap when it is below 1 + k, with
    k = Rp / R*.

    NOTE(review): Winn (2010) eq. 2 appears to define Y with a leading
    minus sign; the expression used here is the north–south mirror of
    that. Confirm that this is the convention expected by the callers.
    """
    # True anomaly at inferior conjunction: ω + f = π/2.
    f_conj = 0.5 * jnp.pi - omega_peri

    # Matching eccentric anomaly from
    #   tan(E/2) = sqrt((1 − e)/(1 + e)) · tan(f/2),
    # written with arctan2 so the quadrant is resolved.
    half_f_conj = 0.5 * f_conj
    E_conj = 2.0 * jnp.arctan2(
        jnp.sqrt(1.0 - ecc) * jnp.sin(half_f_conj),
        jnp.sqrt(1.0 + ecc) * jnp.cos(half_f_conj),
    )

    # Mean anomaly at conjunction (Kepler's equation) fixes the time of
    # periastron passage relative to t0.
    M_conj = E_conj - ecc * jnp.sin(E_conj)
    t_periastron = t0 - (period / (2.0 * jnp.pi)) * M_conj

    # Mean anomaly at the requested epoch, then the true anomaly.
    mean_anomaly = (2.0 * jnp.pi / period) * (time - t_periastron)
    sinf, cosf = _kepler(mean_anomaly, ecc)

    # Star–planet distance: r = a (1 − e²) / (1 + e cos f).
    r = a_over_rstar * (1.0 - ecc ** 2) / (1.0 + ecc * cosf)

    # cos(ω + f) and sin(ω + f) from the angle-addition identities, so the
    # angle f itself never has to be reconstructed from (sin f, cos f).
    cos_w = jnp.cos(omega_peri)
    sin_w = jnp.sin(omega_peri)
    cos_u = cosf * cos_w - sinf * sin_w
    sin_u = sinf * cos_w + cosf * sin_w

    # Projection onto the observer's frame.
    r_sin_u = r * sin_u
    X = r * (-cos_u)
    Y = r_sin_u * jnp.cos(inclination)
    Z = r_sin_u * jnp.sin(inclination)
    return X, Y, Z
# ---------------------------------------------------------------------------
# 3. Vectorised positions over an array of times
# ---------------------------------------------------------------------------
[docs]
def compute_planet_sky_positions(
    times: jnp.ndarray,
    t0: float,
    period: float,
    a_over_rstar: float,
    inclination: float,
    ecc: float,
    omega_peri: float,
) -> jnp.ndarray:
    """
    Evaluate ``planet_sky_position`` at every epoch in ``times``.

    Parameters
    ----------
    times : (ntime,) array of observation epochs. The array is converted
        to float32 before any arithmetic is done on it.
    t0, period, a_over_rstar, inclination, ecc, omega_peri :
        orbital parameters, forwarded unchanged to ``planet_sky_position``

    Returns
    -------
    xyz : (ntime, 3) array — columns are [X, Y, Z] in units of R*
    """
    epochs = jnp.asarray(times, dtype=jnp.float32)

    def _xyz_at(epoch):
        # One epoch → length-3 vector [X, Y, Z].
        x, y, z = planet_sky_position(
            epoch, t0, period, a_over_rstar, inclination, ecc, omega_peri,
        )
        return jnp.stack((x, y, z))

    # Map over the leading (time) axis → (ntime, 3).
    return vmap(_xyz_at)(epochs)
# ---------------------------------------------------------------------------
# 4. Per-pixel transit mask on the sajax stellar grid
# ---------------------------------------------------------------------------
def _compute_planet_mask(
x_disc: jnp.ndarray, # (total_pixels,) pixel x coordinates
y_disc: jnp.ndarray, # (total_pixels,) pixel y coordinates
star_pixel_rad: float,
X: jnp.ndarray, # planet sky-plane x [R*]
Y: jnp.ndarray, # planet sky-plane y [R*]
Z: jnp.ndarray, # planet line-of-sight [R*] — Z > 0 ⟹ transit
k: float, # Rp / R*
) -> jnp.ndarray:
"""
Boolean mask over in-disc pixels: ``True`` where the pixel is occulted
by the planet at this epoch.
The mask is non-zero only when Z > 0 (planet in front of the star).
Pixels inside the planet disc contribute zero flux; if those pixels
coincide with an active region, the spot-crossing anomaly emerges
automatically.
Parameters
----------
x_disc, y_disc : in-disc pixel coordinates [pixels]
star_pixel_rad : stellar radius in pixels
X, Y : planet sky position [R*]
Z : planet line-of-sight position [R*]
k : planet-to-star radius ratio
Returns
-------
jnp.ndarray, shape (total_pixels,), dtype bool_
"""
# Normalise pixel coordinates to stellar radii
xn = x_disc / star_pixel_rad
yn = y_disc / star_pixel_rad
# Squared sky-plane distance from planet centre to each pixel
d2 = (xn - X) ** 2 + (yn - Y) ** 2
# Soft disc mask: sigmoid boundary so gradients flow w.r.t. k and planet position.
# Transition width fixed at 1/10 pixel (matching _compute_ar_mask convention).
# Using 0.1*k instead caused a +~330 ppm systematic bias in transit depth for k~0.1
# because the ε²/r curvature correction is non-negligible when ε/r ~ 10%.
d = jnp.sqrt(d2 + 1e-8)
softness = 1.0 / (10.0 * star_pixel_rad)
disc_mask = jax_nn.sigmoid((k - d) / softness)
# Hard Z gate: planet in front of the star is topologically binary.
z_gate = jnp.where(Z > 0.0, 1.0, 0.0)
# Use jnp.where instead of `if k == 0.0` so this stays JAX-traceable when k
# is a sampled parameter (tracer) inside a numpyro / JAX-jit context.
return jnp.where(k > 0.0, disc_mask * z_gate, jnp.zeros_like(disc_mask))
# ---------------------------------------------------------------------------
# 5. build_transit_model — pre-compute positions for all (oversampled) epochs
# ---------------------------------------------------------------------------
[docs]
def build_transit_model(
    times: np.ndarray,
    t0: float,
    period: float,
    a_over_rstar: float,
    inclination: float,
    ecc: float = 0.0,
    omega_peri: float = 0.0,
    k: float = 0.1,
) -> dict:
    """
    Pre-compute the planet's sky-plane position at every epoch in ``times``.

    Parameters
    ----------
    times : (ntime,) array of observation epochs [days]
        Must be the **oversampled** time array when oversampling
        is active.
    t0 : mid-transit epoch [days]
    period : orbital period [days]
    a_over_rstar : semimajor axis / R* (dimensionless)
    inclination : orbital inclination [rad]
    ecc : eccentricity (default: 0.0 = circular)
    omega_peri : argument of periastron [rad] (default: 0.0)
    k : planet-to-star radius ratio Rp / R* (default: 0.1)

    Returns
    -------
    dict with keys
        ``planet_xyz`` : (ntime, 3) jnp.ndarray — planet (X, Y, Z) per epoch
        ``k`` : float — planet-to-star radius ratio

    Notes
    -----
    The position computation runs on float32 epochs. Absolute timestamps
    do not survive that cast: the float32 spacing is about 0.25 d near
    BJD 2.45e6 and about 20 s near 2000 d. When ``times``, ``t0`` and
    ``period`` are all ordinary host values (Python / NumPy), the offset
    from ``t0`` is therefore formed and phase-folded into
    [-period/2, period/2) in float64 first, and only these small offsets
    are handed on (with a reference epoch of 0). The orbit depends on
    time only through ``time - t0`` modulo ``period``, so the result is
    the same orbit, evaluated without the timestamp rounding.

    If any of the three is a JAX array or tracer, the values are kept
    inside JAX and passed through as given.
    """
    if any(isinstance(v, jnp.ndarray) for v in (times, t0, period)):
        # JAX-side inputs (possibly traced): no host-side arithmetic.
        epochs = jnp.asarray(times, dtype=jnp.float32)
        ref_epoch = t0
    else:
        # Host-side inputs: remove t0 and fold on the period in float64,
        # so the later float32 cast only sees values of order period/2.
        half_period = 0.5 * period
        offsets = np.asarray(times, dtype=np.float64) - t0
        epochs = np.mod(offsets + half_period, period) - half_period
        ref_epoch = 0.0

    xyz = compute_planet_sky_positions(
        epochs, ref_epoch, period, a_over_rstar, inclination, ecc, omega_peri,
    )  # (ntime, 3)

    return dict(
        planet_xyz=xyz,
        k=float(k),
    )
# ---------------------------------------------------------------------------
# 6. Unit-conversion convenience
# ---------------------------------------------------------------------------
# Newtonian gravitational constant in CGS units, used by the
# Kepler's-third-law conversions between stellar density and a / R*.
_G_cgs = 6.674_08e-8 # cm^3 g^-1 s^-2
[docs]
def stellar_density_to_a_over_rstar(
    rho_star_gcc: float,
    period_days: float,
) -> float:
    """
    Scaled semimajor axis a / R* from the mean stellar density and the
    orbital period, using Kepler's third law
    (Seager & Mallén-Ornelas 2003):

        a / R* = ( G ρ★ P^2 / (3π) )^(1/3)

    Parameters
    ----------
    rho_star_gcc : mean stellar density [g cm^-3]
    period_days : orbital period [days]

    Returns
    -------
    a_over_rstar : float (dimensionless)
    """
    # G is in CGS, so the period has to be in seconds.
    period_s = period_days * 86_400.0
    a_over_rstar_cubed = _G_cgs * rho_star_gcc * period_s ** 2 / (3.0 * np.pi)
    return float(a_over_rstar_cubed ** (1.0 / 3.0))
[docs]
def a_over_rstar_to_stellar_density(
    a_over_rstar: float,
    period_days: float,
) -> float:
    """
    Mean stellar density from the scaled semimajor axis and the orbital
    period — the inverse of ``stellar_density_to_a_over_rstar``:

        ρ★ = 3π / (G P^2) · (a / R*)^3

    Parameters
    ----------
    a_over_rstar : semimajor axis / R* (dimensionless)
    period_days : orbital period [days]

    Returns
    -------
    rho_star_gcc : mean stellar density [g cm^-3]
    """
    # G is in CGS, so the period has to be in seconds.
    period_s = period_days * 86_400.0
    numerator = 3.0 * np.pi * a_over_rstar ** 3
    denominator = _G_cgs * period_s ** 2
    return float(numerator / denominator)