sajax.core
core.py — JAX-accelerated stellar active region light-curve engine.
This module is a complete rewrite of SAGE1/sage.py in JAX.
Key differences from the original NumPy/SciPy implementation
- No wavelength loop. The original code iterated over wavelengths with a Python for loop. Here the entire spectral axis is handled by jax.vmap, which maps the single-channel computation across all wavelengths in one vectorised call.
- No phase loop. The original code iterated over rotational phases with a Python loop. Here all phases are computed in a single jax.vmap call; this is the main source of speedup over the original code.
- No scatter-index active-region placement. The original code located active-region pixels via integer scatter indices (fancy indexing with .astype(int)), which is not differentiable and incompatible with jit. SAJAX instead computes an analytic angular-distance mask over the full pixel arrays using jnp.where, which is fully vectorised and differentiable.
- No class state mutation. The original sage_class.rotate_star() mutated self.phases_rot inside a loop, a latent bug. SAJAX uses pure functions throughout.
- No astropy dependency for geometry. Rotation matrices are implemented directly in JAX (see geometry.py).
- No transit-geometry parameters. The original SAGE grid was sized using planet_pixel_size, radiusratio, and semimajor, artefacts of its transit-fitting origin. SAJAX replaces these with a single stellar_grid_size parameter: the stellar radius in pixels. No planet is required.
- Pre-masked grid. build_stellar_grid applies the stellar disc mask immediately and returns 1D arrays containing only the in-disc pixels. No starmask is ever passed to JAX functions; the mask is implicit in the data shape. The only 2D reconstruction happens at output time for star_maps, using the stored flat indices.
- Differentiable end-to-end. All operations are JAX-native, so jax.grad and jax.jacobian work on the full pipeline, which is useful for gradient-based retrieval.
- Phase oversampling. Real observations integrate photons over a finite exposure time. When an active region crosses the stellar limb, the discrete pixel grid can produce sharp discontinuities in the light curve. The oversample parameter (default 1, i.e. off) spreads each requested phase into multiple sub-exposures and averages the result, mimicking finite-exposure integration and smoothing limb-crossing artefacts.
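The geometry changes listed above (an analytic angular-distance mask built with jnp.where instead of scatter indices, and a single jax.vmap over phases instead of a Python loop) can be illustrated with a minimal, self-contained sketch. It is not SAJAX's implementation: toy_disc, ar_unit_vector and disc_flux are hypothetical helpers, the axis convention and the μ weighting are assumptions, and the disc is sampled on a coarse grid. Note that with a hard threshold like this one, gradients with respect to the region's position vanish almost everywhere; the gradient computed here is with respect to the active-region flux.

```python
import jax
import jax.numpy as jnp


def toy_disc(n_side=41):
    # Hypothetical helper: 1D in-disc pixel coordinates in units of R*.
    # Boolean indexing runs eagerly here, outside any jit/vmap trace.
    c = jnp.linspace(-1.0, 1.0, n_side)
    xx, yy = jnp.meshgrid(c, c)
    inside = xx**2 + yy**2 < 1.0
    x, y = xx[inside], yy[inside]
    z = jnp.sqrt(1.0 - x**2 - y**2)  # line-of-sight component, equal to mu
    return x, y, z


def ar_unit_vector(lat_deg, long_deg, phase, inc_deg):
    # Active-region centre on the unit sphere, in the observer frame.
    # Assumed convention: z points at the observer, y is the spin axis at inc = 90 deg.
    lat = jnp.deg2rad(lat_deg)
    lon = jnp.deg2rad(long_deg) + 2.0 * jnp.pi * phase
    px = jnp.cos(lat) * jnp.sin(lon)
    py = jnp.sin(lat)
    pz = jnp.cos(lat) * jnp.cos(lon)
    tilt = jnp.deg2rad(90.0 - inc_deg)  # rotate about x to tip the pole towards us
    return px, py * jnp.cos(tilt) - pz * jnp.sin(tilt), py * jnp.sin(tilt) + pz * jnp.cos(tilt)


def disc_flux(phase, f_active, f_quiet, lat, lon, size, x, y, z, inc=90.0):
    px, py, pz = ar_unit_vector(lat, lon, phase, inc)
    cos_dist = x * px + y * py + z * pz            # cos(angular distance to AR centre)
    in_ar = cos_dist > jnp.cos(jnp.deg2rad(size))  # analytic mask, no scatter indices
    pix = jnp.where(in_ar, f_active, f_quiet)       # per-pixel flux
    return jnp.sum(pix * z)                         # toy limb darkening: I proportional to mu


x, y, z = toy_disc()
phases = jnp.linspace(0.0, 1.0, 50, endpoint=False)

# A single vmap over phases replaces the Python phase loop.
lc = jax.vmap(disc_flux, in_axes=(0,) + (None,) * 8)(phases, 0.3, 1.0, 20.0, 0.0, 15.0, x, y, z)

# Gradient with respect to the active-region flux at phase 0.
dflux_dfa = jax.grad(disc_flux, argnums=1)(0.0, 0.3, 1.0, 20.0, 0.0, 15.0, x, y, z)
```

At phase 0 the region faces the observer and darkens the disc; at phase 0.5 it is on the far side, so the toy light curve is brighter there.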
JIT compilation
Do NOT apply jit to evaluate_light_curve directly: it contains Python-level control flow on model metadata. Instead, the inner function _compute_all_phases is the hot path and is safe to JIT once its metadata arguments are marked static:
    from jax import jit

    _compute_all_phases_jit = jit(
        _compute_all_phases,
        static_argnames=[
            "star_pixel_rad", "total_pixels", "ldc_mode",
            "ar_overlap_mode", "plot_map_wavelength", "n",
        ],
    )
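As a hedged illustration of why metadata such as ldc_mode must be listed in static_argnames, the toy function below branches in Python on a string. toy_intensity is hypothetical and only stands in for a hot-path function like _compute_all_phases; the jax.jit and static_argnames usage itself is standard JAX.

```python
import jax
import jax.numpy as jnp


def toy_intensity(mu, coeffs, ldc_mode):
    # Hypothetical stand-in for a hot-path function with Python-level branching.
    # The `if` runs at trace time, so ldc_mode must be a static (hashable) argument.
    if ldc_mode == "linear":
        return 1.0 - coeffs[0] * (1.0 - mu)
    if ldc_mode == "quadratic":
        return 1.0 - coeffs[0] * (1.0 - mu) - coeffs[1] * (1.0 - mu) ** 2
    raise ValueError(f"unknown ldc_mode: {ldc_mode}")


# Each distinct ldc_mode value triggers one compilation; later calls reuse the cache.
toy_intensity_jit = jax.jit(toy_intensity, static_argnames=["ldc_mode"])

mu = jnp.linspace(0.0, 1.0, 5)
coeffs = jnp.array([0.4, 0.2])
profile = toy_intensity_jit(mu, coeffs, ldc_mode="quadratic")

# Without static_argnames, jit rejects the string argument.
try:
    jax.jit(toy_intensity)(mu, coeffs, "quadratic")
    rejected = False
except TypeError:
    rejected = True
```

Integer shape parameters such as n or total_pixels presumably play the same role: they fix array shapes at trace time, so changing them forces a recompile.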
Attributes
Functions
| Function | Description |
|---|---|
| build_stellar_grid | Pre-compute the static stellar pixel grid, masked to the stellar disc. |
| build_model | Pre-build all static model arrays. Call this once before MCMC. |
| evaluate_light_curve | Evaluate the light curve for a given set of active region parameters. |
| compute_light_curve | Convenience wrapper: build model and evaluate in one call. |
| build_combined_model | Build a combined stellar-activity + planetary-transit sajax model. |
| compute_combined_light_curve | Convenience wrapper: build a combined stellar + transit model and evaluate it in one call. |
Module Contents
- sajax.core.build_stellar_grid(stellar_grid_size: int, ve: float) -> dict
Pre-compute the static stellar pixel grid, masked to the stellar disc.
The mask is applied here once so that all downstream JAX functions receive 1D arrays containing only the in-disc pixels — no starmask is ever passed around.
- Parameters:
stellar_grid_size (int) – Stellar radius in pixels. This is the single resolution knob: higher values give a finer grid at the cost of n² memory and compute. Values of 100-300 are typical.
ve (float) – Stellar equatorial velocity [km/s].
- Returns:
dict with keys:
  - n: full grid side length (always odd)
  - star_pixel_rad: stellar radius in pixels (= stellar_grid_size)
  - total_pixels: number of in-disc pixels
  - flat_indices: (total_pixels,) int indices into the flattened (n, n) grid; used to reconstruct 2D maps at output
  - x: (total_pixels,) x pixel coordinates [in-disc only]
  - y: (total_pixels,) y pixel coordinates [in-disc only]
  - mu: (total_pixels,) limb-darkening μ = cos θ [in-disc only]
  - vel: (total_pixels,) Doppler factor Δv/c [in-disc only]
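A minimal NumPy sketch of what a pre-masked grid with these keys could look like, plus the flat_indices round trip used for 2D output. toy_stellar_grid and to_2d are hypothetical; the side length n = 2R + 1, the pixel-centred disc and the equator-on, rigidly rotating velocity field are assumptions, not details taken from build_stellar_grid.

```python
import numpy as np

C_KM_S = 299_792.458  # speed of light [km/s]


def toy_stellar_grid(stellar_grid_size: int, ve: float) -> dict:
    # Hypothetical illustration, not sajax's implementation.
    # Assumptions: n = 2 R + 1 (odd, centred on a pixel) and an equator-on star
    # in rigid rotation, so the line-of-sight velocity is ve * x / R.
    R = stellar_grid_size
    n = 2 * R + 1
    coords = np.arange(n) - R                          # pixel coordinates centred on 0
    xx, yy = np.meshgrid(coords, coords)               # each (n, n)
    r2 = (xx**2 + yy**2) / R**2                        # squared radius in units of R*
    flat_indices = np.flatnonzero(r2.ravel() <= 1.0)   # in-disc pixels of the flattened grid
    x = xx.ravel()[flat_indices].astype(float)
    y = yy.ravel()[flat_indices].astype(float)
    mu = np.sqrt(1.0 - r2.ravel()[flat_indices])       # cos(theta)
    vel = ve * (x / R) / C_KM_S                        # Doppler factor dv/c
    return {"n": n, "star_pixel_rad": R, "total_pixels": flat_indices.size,
            "flat_indices": flat_indices, "x": x, "y": y, "mu": mu, "vel": vel}


def to_2d(values: np.ndarray, grid: dict) -> np.ndarray:
    # Scatter 1D in-disc values back into an (n, n) map; off-disc pixels become NaN.
    out = np.full(grid["n"] ** 2, np.nan)
    out[grid["flat_indices"]] = values
    return out.reshape(grid["n"], grid["n"])


grid = toy_stellar_grid(10, 2.0)
mu_map = to_2d(grid["mu"], grid)
```

Keeping only 1D in-disc arrays means every downstream operation touches roughly π/4 of the n × n pixels, and the mask never needs to be carried through JAX code.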
- sajax.core.build_model(wavelength: numpy.ndarray, flux_quiet: numpy.ndarray, params: dict, phases_rot: numpy.ndarray, stellar_grid_size: int, ve: float, ldc_mode: LdcMode = 'quadratic', ar_overlap_mode: ArOverlapMode = 'hottest_wins', plot_map_wavelength: float | None = None, oversample: int = 1) -> dict
Pre-build all static model arrays. Call this once before MCMC.
Everything that does not change between MCMC steps is computed here in NumPy and stored in the returned model dict. The only quantities that vary per step (flux_active, ar_lat, ar_long, ar_size) are intentionally excluded and passed to evaluate_light_curve instead.
- Parameters:
wavelength (array_like, shape (nwave,))
flux_quiet (array_like, shape (nwave,))
params (dict) – Model parameters. Recognised keys:
  - inc_star (float, optional): Stellar inclination in degrees (default: 90.0). 90° = equator-on, 0° = pole-on.
  - ldc_coeffs (list of float or list of array (nwave,)): Limb-darkening coefficients for the chosen ldc_mode:
    - "linear": [u]
    - "quadratic": [u1, u2]
    - "power2": [c, alpha]
    - "kipping3": [c1, c2, c3]
    - "nonlinear4": [c1, c2, c3, c4]
    Each element may be a scalar (broadcast to all wavelengths) or an array of length nwave. For "quadratic" mode only, u1 and u2 are also accepted as separate keys (legacy interface).
  - mu_profile (array-like, optional): Monotonically increasing μ grid points for ldc_mode="intensity_profile" (default: [0, 1]).
  - I_profile (array-like, shape (nwave, n_mu_pts), optional): Specific intensity at each (wavelength, μ) grid point. Required when ldc_mode="intensity_profile".
phases_rot (array_like, shape (nphase,))
stellar_grid_size (int)
ve (float)
ldc_mode (str) – Limb-darkening law: "linear", "quadratic", "power2", "kipping3", "nonlinear4", or "intensity_profile".
ar_overlap_mode ({"hottest_wins", "coldest_wins"}, optional) – Rule for resolving overlapping active regions:
  - "hottest_wins": an overlap pixel uses the flux of the hottest (highest-flux) active region.
  - "coldest_wins": an overlap pixel uses the flux of the coldest (lowest-flux) active region.
  Default: "hottest_wins".
plot_map_wavelength (float, optional)
oversample (int, optional) – Number of sub-exposures per phase point. Each requested phase is spread into oversample uniformly spaced sub-phases spanning one phase step, and the resulting fluxes are averaged. This mimics finite-exposure integration and smooths limb-crossing artefacts. Default: 1 (no oversampling).
- Returns:
dict; pass directly to evaluate_light_curve.
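The ldc_mode options above correspond to standard limb-darkening laws. The sketch below writes out common textbook forms so the coefficient lists are easier to interpret. toy_limb_darkening is hypothetical and the library's exact conventions may differ; "kipping3" is left out because its parameterisation is not spelled out here.

```python
import jax.numpy as jnp


def toy_limb_darkening(mu, ldc_mode, coeffs=(), mu_profile=None, I_profile=None):
    # Hypothetical sketch of normalised intensity laws I(mu)/I(1).
    # Coefficients may be scalars or (nwave,) arrays; pass mu with a trailing
    # axis, e.g. shape (npix, 1), to broadcast to (npix, nwave).
    one_minus_mu = 1.0 - mu
    if ldc_mode == "linear":
        (u,) = coeffs
        return 1.0 - u * one_minus_mu
    if ldc_mode == "quadratic":
        u1, u2 = coeffs
        return 1.0 - u1 * one_minus_mu - u2 * one_minus_mu**2
    if ldc_mode == "power2":
        c, alpha = coeffs
        return 1.0 - c * (1.0 - mu**alpha)
    if ldc_mode == "nonlinear4":
        # Four-parameter nonlinear law: 1 - sum_k c_k (1 - mu^(k/2)), k = 1..4
        return 1.0 - sum(ck * (1.0 - mu ** (k / 2)) for k, ck in enumerate(coeffs, start=1))
    if ldc_mode == "intensity_profile":
        # Tabulated profile at a single wavelength, linearly interpolated (not normalised).
        return jnp.interp(mu, jnp.asarray(mu_profile), jnp.asarray(I_profile))
    raise ValueError(f"unsupported ldc_mode: {ldc_mode}")


mu = jnp.linspace(0.0, 1.0, 11)[:, None]                     # (npix, 1)
u1 = jnp.array([0.5, 0.3])                                   # per-wavelength u1, nwave = 2
I = toy_limb_darkening(mu, "quadratic", coeffs=(u1, 0.2))    # scalar u2 is broadcast
```

Mixing an array-valued u1 with a scalar u2 mirrors the rule that each coefficient may be either a scalar or a length-nwave array.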
- sajax.core.evaluate_light_curve(model: dict, flux_active: jax.numpy.ndarray, ar_lat: jax.numpy.ndarray, ar_long: jax.numpy.ndarray, ar_size: jax.numpy.ndarray) -> dict
Evaluate the light curve for a given set of active region parameters.
This function is pure JAX: all inputs may be JAX arrays or tracers, making it fully compatible with jit, vmap, and JAX-native samplers such as emcee_jax (ensemble sampling) or blackjax (including gradient-based samplers such as NUTS). When the model was built with oversample > 1, the computation runs on the oversampled phase grid and the results are averaged back to the original phase grid before returning.
- Parameters:
model (dict) – Pre-built model dict returned by build_model (or build_combined_model).
flux_active (jnp.ndarray, shape (nar, nwave) or (nwave,)) – Active-region flux spectrum:
  - If (nar, nwave): each active region gets its own spectrum.
  - If (nwave,): broadcasts to all active regions (legacy mode).
ar_lat (jnp.ndarray, shape (nar,)) – active region latitudes in degrees. Must be in [-90, 90].
ar_long (jnp.ndarray, shape (nar,)) – active region longitudes in degrees. Must be in [0, 360).
ar_size (jnp.ndarray, shape (nar,)) – active region angular radii in degrees.
- Returns:
dict with keys:
  - lc: (nphase_original,) normalised broadband light curve
  - epsilon: (nphase_original, nwave) contamination factor ε(λ)
  - star_maps: (nphase_original, n, n) stellar flux map per phase (when oversampling is active, each map is taken from the first sub-exposure of its phase)
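To make the ar_overlap_mode rule and the two accepted flux_active shapes concrete, here is a sketch of per-pixel spectrum assignment. resolve_pixel_flux is a hypothetical helper; in particular, ranking regions by wavelength-summed flux to decide which one is "hottest" is an assumption of this sketch, not documented behaviour.

```python
import jax.numpy as jnp


def resolve_pixel_flux(in_ar, flux_active, flux_quiet, mode="hottest_wins"):
    """Hypothetical sketch of per-pixel spectra with overlapping active regions.

    in_ar:       (nar, npix) boolean masks, True where a pixel lies in that region
    flux_active: (nar, nwave) per-region spectra, or (nwave,) shared by all regions
    flux_quiet:  (nwave,) quiet-photosphere spectrum
    returns:     (npix, nwave) spectrum assigned to every pixel
    """
    nar, nwave = in_ar.shape[0], flux_quiet.shape[0]
    flux_active = jnp.broadcast_to(flux_active, (nar, nwave))  # legacy (nwave,) input
    # Assumption: "hottest"/"coldest" is ranked by wavelength-summed flux.
    brightness = flux_active.sum(axis=1)                       # (nar,)
    if mode == "hottest_wins":
        score = jnp.where(in_ar, brightness[:, None], -jnp.inf)
        winner = jnp.argmax(score, axis=0)                     # (npix,)
    elif mode == "coldest_wins":
        score = jnp.where(in_ar, brightness[:, None], jnp.inf)
        winner = jnp.argmin(score, axis=0)
    else:
        raise ValueError(f"unknown ar_overlap_mode: {mode}")
    covered = in_ar.any(axis=0)                                # pixel in any region?
    return jnp.where(covered[:, None], flux_active[winner], flux_quiet[None, :])


in_ar = jnp.array([[True, True, False, False],
                   [False, True, True, False]])
flux_active = jnp.array([[0.5, 0.5], [1.5, 1.5]])   # a dark spot and a bright facula
flux_quiet = jnp.array([1.0, 1.0])
hot = resolve_pixel_flux(in_ar, flux_active, flux_quiet, "hottest_wins")
cold = resolve_pixel_flux(in_ar, flux_active, flux_quiet, "coldest_wins")
```

Only the second pixel lies in both regions, so it is the only one whose spectrum depends on the overlap rule.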
- sajax.core.compute_light_curve(wavelength: numpy.ndarray, flux_quiet: numpy.ndarray, flux_active: numpy.ndarray, params: dict, ar_lat: numpy.ndarray, ar_long: numpy.ndarray, ar_size: numpy.ndarray, phases_rot: numpy.ndarray, stellar_grid_size: int, ve: float, ldc_mode: LdcMode = 'quadratic', ar_overlap_mode: ArOverlapMode = 'hottest_wins', plot_map_wavelength: float | None = None, oversample: int = 1) -> dict
Convenience wrapper: build model and evaluate in one call.
Equivalent to:
    model = build_model(wavelength, flux_quiet, params, phases_rot,
                        stellar_grid_size, ve, ldc_mode, ar_overlap_mode,
                        plot_map_wavelength, oversample)
    result = evaluate_light_curve(model, flux_active, ar_lat, ar_long, ar_size)
Use build_model + evaluate_light_curve directly when running MCMC so the grid is built only once.
- Parameters:
wavelength (array_like, shape (nwave,))
flux_quiet (array_like, shape (nwave,))
flux_active (array_like, shape (nar, nwave) or (nwave,))
params (dict)
ar_lat (array_like, shape (nar,))
ar_long (array_like, shape (nar,))
ar_size (array_like, shape (nar,))
phases_rot (array_like, shape (nphase,))
stellar_grid_size (int)
ve (float)
ldc_mode (str)
ar_overlap_mode (str)
plot_map_wavelength (float, optional)
oversample (int, optional) – Number of sub-exposures per phase point (default: 1, i.e. no oversampling).
- Return type:
dict with keys lc, epsilon, star_maps as NumPy arrays.
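The build-once, evaluate-many pattern recommended above can be sketched with toy stand-ins. toy_build_model and toy_evaluate are hypothetical placeholders for build_model and evaluate_light_curve; the point is that the static model dict is captured by a jitted log-likelihood whose only traced argument is the parameter vector, so nothing static is rebuilt per step.

```python
import jax
import jax.numpy as jnp
import numpy as np


def toy_build_model(phases):
    # Hypothetical stand-in for build_model: static arrays, computed once.
    return {"phases": jnp.asarray(phases)}


def toy_evaluate(model, contrast, long_deg):
    # Hypothetical stand-in for evaluate_light_curve: pure jnp on per-step parameters.
    angle = 2.0 * jnp.pi * model["phases"] + jnp.deg2rad(long_deg)
    visibility = jnp.maximum(jnp.cos(angle), 0.0)   # foreshortened, zero when hidden
    return {"lc": 1.0 - 0.01 * contrast * visibility}


model = toy_build_model(np.linspace(0.0, 1.0, 100, endpoint=False))  # once, outside the loop
observed = toy_evaluate(model, 0.7, 40.0)["lc"]
sigma = 1e-3


@jax.jit
def log_likelihood(theta):
    # Only theta is traced; `model` and `observed` are captured as constants.
    lc = toy_evaluate(model, theta[0], theta[1])["lc"]
    return -0.5 * jnp.sum(((lc - observed) / sigma) ** 2)


grad_log_likelihood = jax.jit(jax.grad(log_likelihood))

# A sampler would call these many times; the model is never rebuilt.
ll_true = log_likelihood(jnp.array([0.7, 40.0]))
ll_off = log_likelihood(jnp.array([0.5, 40.0]))
g = grad_log_likelihood(jnp.array([0.6, 30.0]))
```

With the real API, the same structure would capture the output of build_model in the closure and call evaluate_light_curve inside it; whether that closure jits cleanly depends on the Python-level control flow mentioned under JIT compilation.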
- sajax.core.build_combined_model(wavelength: numpy.ndarray, flux_quiet: numpy.ndarray, params: dict, times: numpy.ndarray, P_rot: float, transit_params: dict, stellar_grid_size: int, ve: float, ldc_mode: LdcMode = 'quadratic', ar_overlap_mode: ArOverlapMode = 'hottest_wins', plot_map_wavelength: float | None = None, oversample: int = 1) -> dict
Build a combined stellar-activity + planetary-transit sajax model.
This is the entry point for modelling active-region crossing events: the planet mask is applied at the individual pixel level, so if the planet occults a starspot or facula the resulting anomaly in the light curve is computed correctly.
Compared to multiplying independent stellar and transit light curves, this function correctly handles:
- Planet occulting a spot (spot-crossing anomaly).
- Planet occulting a facula (facula-crossing anomaly).
- The transit depth varying with the local stellar surface brightness (limb darkening) along the planet's path.
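A simplified sketch of why pixel-level occultation matters: a planet crossing a dark spot blocks less light than the quiet-star transit depth predicts, so multiplying independent stellar and transit curves underestimates the in-transit flux. planet_sky_position and occulted_flux are hypothetical helpers; the circular orbit, the sky-plane sign convention and the spot drawn flat at disc centre are simplifying assumptions.

```python
import numpy as np


def planet_sky_position(t, t0, period, a_over_rstar, inclination):
    # Hypothetical helper assuming a circular orbit (ecc = 0); inclination in radians.
    # Returns the planet centre in units of R* and whether it is in front of the star.
    phase = 2.0 * np.pi * (t - t0) / period
    xp = a_over_rstar * np.sin(phase)
    yp = -a_over_rstar * np.cos(inclination) * np.cos(phase)
    return xp, yp, np.cos(phase) > 0.0


def occulted_flux(pixel_flux, x, y, xp, yp, k, in_front):
    # Pixel-level occultation: drop every pixel under the planet disc of radius k.
    blocked = in_front & ((x - xp) ** 2 + (y - yp) ** 2 < k**2)
    return np.sum(np.where(blocked, 0.0, pixel_flux))


# Pre-masked disc in units of R*, with a quadratic limb-darkening law.
c = np.linspace(-1.0, 1.0, 201)
xx, yy = np.meshgrid(c, c)
inside = xx**2 + yy**2 < 1.0
x, y = xx[inside], yy[inside]
mu = np.sqrt(1.0 - x**2 - y**2)
quiet = 1.0 - 0.4 * (1.0 - mu) - 0.2 * (1.0 - mu) ** 2
spot = x**2 + y**2 < 0.1**2                    # projected spot at disc centre
spotted = np.where(spot, 0.5 * quiet, quiet)   # spot at half the local brightness

# Mid-transit on a central chord, planet the same size as the spot.
xp, yp, front = planet_sky_position(0.0, t0=0.0, period=3.0, a_over_rstar=10.0,
                                    inclination=np.pi / 2)
k = 0.1
f_quiet_out, f_quiet_in = quiet.sum(), occulted_flux(quiet, x, y, xp, yp, k, front)
f_spot_out, f_spot_in = spotted.sum(), occulted_flux(spotted, x, y, xp, yp, k, front)

f_product = f_spot_out * (f_quiet_in / f_quiet_out)  # independent curves multiplied
f_pixel = f_spot_in                                  # planet occults the spot itself
```

The difference f_pixel - f_product is the spot-crossing bump; with a bright facula instead of a dark spot, the sign reverses.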
- Parameters:
wavelength (numpy.ndarray) – (nwave,) wavelength array [nm]
flux_quiet (numpy.ndarray) – (nwave,) quiet-star flux spectrum
params (dict) – Stellar model params dict (same as build_model)
times (numpy.ndarray) – (ntime,) absolute observation times [days]
P_rot (float) – Stellar rotation period [days]
transit_params (dict) – Dict with keys (all required unless noted):
  - t0: mid-transit epoch [days]
  - period: orbital period [days]
  - a_over_rstar: semimajor axis / R★ (dimensionless)
  - inclination: orbital inclination [rad]
  - k: planet-to-star radius ratio Rp / R★
  - ecc: eccentricity (default 0.0)
  - omega_peri: argument of periastron [rad] (default 0.0)
stellar_grid_size (int) – Stellar radius in pixels
ve (float) – Equatorial velocity [km/s]
ldc_mode (str) – Limb-darkening law (same options as build_model)
ar_overlap_mode (str) – Active-region overlap rule
plot_map_wavelength (float, optional) – Wavelength for 2D map output [nm]
oversample (int, optional) – Sub-exposure count per phase point (default 1, i.e. no oversampling)
- Return type:
model dict; pass directly to evaluate_light_curve.
Notes
The oversampling is applied consistently to both the stellar rotation phase grid and the orbital time grid. For each original time t_i with time step dt, oversample sub-times are generated spanning [t_i - dt/2, t_i + dt/2), exactly mirroring _make_oversampled_phases.
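The sub-time construction described in these Notes can be sketched as follows, together with the averaging step that maps sub-exposure results back to the original grid. make_oversampled_times and average_back are hypothetical; a uniform cadence is assumed, and bin midpoints are used so that oversample=1 reproduces the input times. The private _make_oversampled_phases may place its sub-samples differently.

```python
import numpy as np


def make_oversampled_times(times, oversample):
    # Hypothetical sketch, not the library's _make_oversampled_phases.
    # Assumes a uniform cadence dt; bin midpoints keep every sub-time inside
    # [t_i - dt/2, t_i + dt/2) and make oversample=1 return `times` unchanged.
    times = np.asarray(times, dtype=float)
    dt = times[1] - times[0] if times.size > 1 else 0.0
    offsets = ((np.arange(oversample) + 0.5) / oversample - 0.5) * dt  # (oversample,)
    return (times[:, None] + offsets[None, :]).ravel()                 # (ntime * oversample,)


def average_back(values, oversample):
    # Average sub-exposure results back onto the original grid (first axis).
    values = np.asarray(values)
    return values.reshape(-1, oversample, *values.shape[1:]).mean(axis=1)


times = np.arange(5) * 0.1
sub = make_oversampled_times(times, 4)
spectra = np.stack([sub, 2.0 * sub], axis=1)   # stand-in (ntime * oversample, nwave) output
```

Deriving the rotation phases from the same sub-times (for example t / P_rot, with an assumed zero point) keeps the stellar and orbital grids aligned by construction.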
- sajax.core.compute_combined_light_curve(wavelength: numpy.ndarray, flux_quiet: numpy.ndarray, flux_active: numpy.ndarray, params: dict, ar_lat: numpy.ndarray, ar_long: numpy.ndarray, ar_size: numpy.ndarray, times: numpy.ndarray, P_rot: float, transit_params: dict, stellar_grid_size: int, ve: float, ldc_mode: LdcMode = 'quadratic', ar_overlap_mode: ArOverlapMode = 'hottest_wins', plot_map_wavelength: float | None = None, oversample: int = 1) -> dict
Convenience wrapper: build a combined stellar + transit model and evaluate it in one call.
Equivalent to:
    model = build_combined_model(wavelength, flux_quiet, params, times, P_rot,
                                 transit_params, stellar_grid_size, ve, ldc_mode,
                                 ar_overlap_mode, plot_map_wavelength, oversample)
    result = evaluate_light_curve(model, flux_active, ar_lat, ar_long, ar_size)
Use build_combined_model + evaluate_light_curve directly when running MCMC so the grid is built only once.
- Parameters:
All parameters match build_combined_model and evaluate_light_curve. See their docstrings for details.
transit_params (dict) – t0, period, a_over_rstar, inclination, k, and optionally ecc, omega_peri.
- Returns:
dict with keys lc, epsilon, star_maps (same as compute_light_curve).