sajax.core

core.py — JAX-accelerated stellar active region light-curve engine.

This module is a complete rewrite of SAGE1/sage.py in JAX.

Key differences from the original NumPy/SciPy implementation

  1. No wavelength loop. The original code iterated over wavelengths with a Python for loop. Here the entire spectral axis is handled by jax.vmap, which maps the single-channel computation across all wavelengths in parallel.

  2. No phase loop. The original code iterated over rotational phases with a Python loop. Here all phases are computed in a single jax.vmap call — this is the main source of speedup over the original code.

  3. No scatter-index active region placement. The original code located active region pixels via integer scatter indices (fancy indexing with .astype(int)), which is not differentiable and incompatible with jit. SAJAX instead computes an analytic angular-distance mask over the full pixel arrays using jnp.where, which is fully vectorised and differentiable.

  4. No class state mutation. The original sage_class.rotate_star() mutated self.phases_rot inside a loop — a latent bug. SAJAX uses pure functions throughout.

  5. No astropy dependency for geometry. Rotation matrices are implemented directly in JAX (see geometry.py).

  6. No transit-geometry parameters. The original SAGE grid was sized using planet_pixel_size, radiusratio, and semimajor — artifacts of its transit-fitting origin. SAJAX replaces these with a single stellar_grid_size parameter: the stellar radius in pixels. No planet required.

  7. Pre-masked grid. build_stellar_grid applies the stellar disc mask immediately and returns 1D arrays containing only the in-disc pixels. No starmask is ever passed to JAX functions — the mask is implicit in the data shape. The only 2D reconstruction happens at output time for star_maps, using stored flat indices.

  8. Differentiable end-to-end. All operations are JAX-native, so jax.grad / jax.jacobian work on the full pipeline — useful for gradient-based retrieval.

  9. Phase oversampling. Real observations integrate photons over a finite exposure time. When an active region crosses the stellar limb, the discrete pixel grid can produce sharp discontinuities in the light curve. The oversample parameter (default 1, i.e. off) spreads each requested phase into multiple sub-exposures and averages the result, mimicking finite-exposure integration and smoothing limb-crossing artefacts.
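Items 1 to 3 above can be illustrated with a minimal JAX sketch. It is not the SAJAX implementation: the function names (`region_mask`, `flux_one_channel`, `flux_one_phase`), the toy surface points, and the use of surface latitude and longitude in radians without any disc projection are assumptions made for illustration. It shows an angular-distance mask applied with `jnp.where`, and nested `jax.vmap` calls standing in for the wavelength and phase loops.

```python
import jax
import jax.numpy as jnp


def region_mask(lat, lon, ar_lat, ar_lon, ar_size):
    """Boolean mask of surface points within ar_size of the region centre (radians)."""
    # Spherical law of cosines gives cos(angular distance) to the centre.
    cos_d = (jnp.sin(lat) * jnp.sin(ar_lat)
             + jnp.cos(lat) * jnp.cos(ar_lat) * jnp.cos(lon - ar_lon))
    # Comparing cosines avoids arccos; no integer indexing is involved.
    return cos_d >= jnp.cos(ar_size)


def flux_one_channel(f_quiet, f_active, mask):
    """Summed flux of all points for a single wavelength channel."""
    return jnp.sum(jnp.where(mask, f_active, f_quiet))


def flux_one_phase(phase, lat, lon, f_quiet, f_active, ar_lat, ar_lon, ar_size):
    """Flux in every channel at one rotational phase (phase in [0, 1))."""
    mask = region_mask(lat, lon, ar_lat, ar_lon + 2.0 * jnp.pi * phase, ar_size)
    # vmap over the spectral axis; the mask is shared by all channels.
    return jax.vmap(flux_one_channel, in_axes=(0, 0, None))(f_quiet, f_active, mask)


# Toy inputs: three surface points, two wavelength channels, two phases.
lat = jnp.array([0.0, 0.0, 0.5])
lon = jnp.array([0.0, 1.0, 0.0])
f_quiet = jnp.array([1.0, 2.0])
f_active = jnp.array([0.5, 1.0])
phases = jnp.array([0.0, 0.5])

# vmap over phases: only the first argument is mapped.
all_phases = jax.vmap(flux_one_phase, in_axes=(0,) + (None,) * 7)
spectra = all_phases(phases, lat, lon, f_quiet, f_active, 0.0, 0.0, 0.2)
print(spectra.shape)  # (2, 2), i.e. (nphase, nwave)
```

At phase 0 the region covers the first surface point, so that point contributes the active flux; at phase 0.5 the region has rotated away and every point contributes the quiet flux.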

JIT compilation

Do NOT jit(evaluate_light_curve) directly — it contains Python-level control flow on model metadata. Instead, the inner _compute_all_phases is the hot path and is safe to JIT via:

from jax import jit

_compute_all_phases_jit = jit(
    _compute_all_phases,
    static_argnames=[
        "star_pixel_rad", "total_pixels", "ldc_mode",
        "ar_overlap_mode", "plot_map_wavelength", "n",
    ],
)
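Why some arguments must be static can be seen in a small, self-contained sketch. The function `fill_map` below is hypothetical and is not part of sajax; it only mimics the two situations that typically force `static_argnames`: a Python string used in ordinary `if` control flow, and an integer that determines an array shape.

```python
import jax
import jax.numpy as jnp


def fill_map(values, flat_indices, mode, n):
    """Scatter 1D in-disc values into an (n, n) map (toy example)."""
    # `mode` is a Python string used in ordinary control flow, so it must be static.
    if mode == "uniform":
        values = jnp.ones_like(values)
    # `n` fixes an array shape, so it must also be known at trace time.
    grid = jnp.zeros(n * n).at[flat_indices].set(values)
    return grid.reshape(n, n)


fill_map_jit = jax.jit(fill_map, static_argnames=["mode", "n"])

out = fill_map_jit(jnp.array([2.0, 3.0]), jnp.array([0, 4]), mode="as_given", n=3)
print(out.shape)  # (3, 3)
```

Each distinct value of a static argument triggers a fresh compilation, so static arguments are best reserved for quantities that stay fixed across an entire sampling run.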

Attributes

Functions

build_stellar_grid(→ dict)

Pre-compute the static stellar pixel grid, masked to the stellar disc.

build_model(→ dict)

Pre-build all static model arrays. Call this once before MCMC.

evaluate_light_curve(→ dict)

Evaluate the light curve for a given set of active region parameters.

compute_light_curve(→ dict)

Convenience wrapper: build model and evaluate in one call.

build_combined_model(→ dict)

Build a combined stellar-activity + planetary-transit sajax model.

compute_combined_light_curve(→ dict)

Convenience wrapper: build a combined stellar + transit model and evaluate it in one call.

Module Contents

sajax.core.LdcMode
sajax.core.ArOverlapMode
sajax.core.build_stellar_grid(stellar_grid_size: int, ve: float) → dict

Pre-compute the static stellar pixel grid, masked to the stellar disc.

The mask is applied here once so that all downstream JAX functions receive 1D arrays containing only the in-disc pixels — no starmask is ever passed around.

Parameters:
  • stellar_grid_size (int) – Stellar radius in pixels. This is the single resolution knob: higher values give a finer grid at the cost of n² memory and compute. Values of 100-300 are typical.

  • ve (float) – Stellar equatorial velocity [km/s].

Returns:

  • dict with keys:

  • n – full grid side length (always odd)

  • star_pixel_rad – stellar radius in pixels (= stellar_grid_size)

  • total_pixels – number of in-disc pixels

  • flat_indices – (total_pixels,) int indices into the flattened (n, n) grid; used to reconstruct 2D maps at output

  • x – (total_pixels,) x pixel coordinates [in-disc only]

  • y – (total_pixels,) y pixel coordinates [in-disc only]

  • mu – (total_pixels,) limb-darkening cos θ [in-disc only]

  • vel – (total_pixels,) Doppler factor Δv/c [in-disc only]
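The returned layout can be sketched in NumPy as follows. This is an illustrative reconstruction, not the actual build_stellar_grid: the function name `toy_stellar_grid`, the choice n = 2 * stellar_grid_size + 1, the inclusive disc edge, and the equator-on solid-body velocity field are all assumptions.

```python
import numpy as np

C_KMS = 299792.458  # speed of light [km/s]


def toy_stellar_grid(stellar_grid_size, ve):
    """Build a disc-masked pixel grid and return only the in-disc pixels (toy)."""
    n = 2 * stellar_grid_size + 1                 # assumed: odd side length
    coords = np.arange(n) - stellar_grid_size     # pixel offsets from disc centre
    xx, yy = np.meshgrid(coords, coords)
    r2 = (xx**2 + yy**2) / stellar_grid_size**2   # squared radius in stellar radii

    # Apply the disc mask once and keep flat indices for later 2D reconstruction.
    flat_indices = np.flatnonzero(r2 <= 1.0)
    x = xx.ravel()[flat_indices]
    y = yy.ravel()[flat_indices]
    mu = np.sqrt(1.0 - r2.ravel()[flat_indices])  # cos(theta) on a unit sphere

    # Assumed: solid-body rotation seen equator-on, so v_los scales with x.
    vel = ve * (x / stellar_grid_size) / C_KMS

    return {
        "n": n,
        "star_pixel_rad": stellar_grid_size,
        "total_pixels": flat_indices.size,
        "flat_indices": flat_indices,
        "x": x, "y": y, "mu": mu, "vel": vel,
    }


grid = toy_stellar_grid(stellar_grid_size=2, ve=2.0)
print(grid["n"], grid["total_pixels"])
```

Because every per-pixel array already has length total_pixels, downstream functions can sum over the whole array without consulting a mask.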

sajax.core.build_model(wavelength: numpy.ndarray, flux_quiet: numpy.ndarray, params: dict, phases_rot: numpy.ndarray, stellar_grid_size: int, ve: float, ldc_mode: LdcMode = 'quadratic', ar_overlap_mode: ArOverlapMode = 'hottest_wins', plot_map_wavelength: float | None = None, oversample: int = 1) → dict

Pre-build all static model arrays. Call this once before MCMC.

Everything that does not change between MCMC steps is computed here in NumPy and stored in the returned model dict. The only quantities that vary per step — flux_active, ar_lat, ar_long, ar_size — are intentionally excluded and passed to evaluate_light_curve instead.

Parameters:
  • wavelength (array_like, shape (nwave,))

  • flux_quiet (array_like, shape (nwave,))

  • params (dict) – Model parameters. Recognised keys:

      • inc_star (float, optional) – Stellar inclination in degrees (default: 90.0). 90° = equator-on, 0° = pole-on.

      • ldc_coeffs (list of float or list of array(nwave,)) – Limb-darkening coefficients for the chosen ldc_mode:

          • "linear": [u]
          • "quadratic": [u1, u2]
          • "power2": [c, alpha]
          • "kipping3": [c1, c2, c3]
          • "nonlinear4": [c1, c2, c3, c4]

        Each element may be a scalar (broadcast to all wavelengths) or an array of length nwave. For "quadratic" mode only, u1 and u2 are also accepted as separate keys (legacy interface).

      • mu_profile (array-like, optional) – Monotonically increasing μ grid points for ldc_mode="intensity_profile" (default: [0, 1]).

      • I_profile (array-like, shape (nwave, n_mu_pts), optional) – Specific intensity at each (wavelength, μ) grid point. Required when ldc_mode="intensity_profile".

  • phases_rot (array_like, shape (nphase,))

  • stellar_grid_size (int)

  • ve (float)

  • ldc_mode (str)

  • ar_overlap_mode ({"hottest_wins", "coldest_wins"}, optional) – Rule for resolving overlapping active regions. With "hottest_wins", an overlap pixel uses the flux of the hottest (highest-flux) active region; with "coldest_wins", it uses the flux of the coldest (lowest-flux) active region. Default: "hottest_wins".

  • plot_map_wavelength (float, optional)

  • oversample (int, optional) – Number of sub-exposures per phase point. Each requested phase is spread into oversample uniformly spaced sub-phases spanning one phase step, and the resulting fluxes are averaged. This mimics finite-exposure integration and smooths limb-crossing artefacts. Default: 1 (no oversampling).

Returns:

  • dict – pass directly to evaluate_light_curve
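As an illustration of how scalar and per-wavelength ldc_coeffs could be handled, the sketch below evaluates the standard quadratic limb-darkening law, I(μ) = 1 − u1(1 − μ) − u2(1 − μ)², on a 1D μ array. The helper name `quadratic_ld` and its broadcasting strategy are assumptions; how sajax normalises or stores the profile internally is not shown in this reference.

```python
import numpy as np


def quadratic_ld(mu, u1, u2, nwave):
    """Quadratic limb-darkening profile with shape (nwave, npix) (toy helper)."""
    mu = np.asarray(mu, dtype=float)
    # Scalars are broadcast to every wavelength; arrays must have length nwave.
    u1 = np.broadcast_to(np.asarray(u1, dtype=float), (nwave,))
    u2 = np.broadcast_to(np.asarray(u2, dtype=float), (nwave,))
    one_minus_mu = 1.0 - mu[None, :]
    return 1.0 - u1[:, None] * one_minus_mu - u2[:, None] * one_minus_mu**2


mu = np.array([1.0, 0.5, 0.0])  # disc centre, mid-disc, limb

# Scalar coefficients: the same profile in both channels.
grey = quadratic_ld(mu, 0.4, 0.2, nwave=2)

# Per-wavelength coefficients: one profile per channel.
chromatic = quadratic_ld(mu, [0.4, 0.6], [0.2, 0.0], nwave=2)
print(grey.shape, chromatic.shape)
```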

sajax.core.evaluate_light_curve(model: dict, flux_active: jax.numpy.ndarray, ar_lat: jax.numpy.ndarray, ar_long: jax.numpy.ndarray, ar_size: jax.numpy.ndarray) → dict

Evaluate the light curve for a given set of active region parameters.

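The array inputs (flux_active, ar_lat, ar_long, ar_size) may be JAX arrays or tracers, so the function composes with vmap, jax.grad, and JAX-based samplers such as emcee_jax or blackjax. The module notes under JIT compilation advise applying jit to the inner _compute_all_phases rather than to this function, because it contains Python-level control flow on model metadata.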

When the model was built with oversample > 1, the computation runs on the oversampled phase grid and the results are averaged back to the original phase grid before returning.

Parameters:
  • model (dict) – Pre-built model dict returned by build_model.

  • flux_active (jnp.ndarray, shape (nar, nwave) or (nwave,)) – Active-region flux spectrum. If the shape is (nar, nwave), each active region gets its own spectrum. If the shape is (nwave,), the spectrum is broadcast to all active regions (legacy mode).

  • ar_lat (jnp.ndarray, shape (nar,)) – Active-region latitudes in degrees. Must be in [-90, 90].

  • ar_long (jnp.ndarray, shape (nar,)) – Active-region longitudes in degrees. Must be in [0, 360).

  • ar_size (jnp.ndarray, shape (nar,)) – Active-region angular radii in degrees.

Returns:

  • dict with keys:

  • lc – (nphase_original,) normalised broadband light curve

  • epsilon – (nphase_original, nwave) contamination factor ε(λ)

  • star_maps – (nphase_original, n, n) stellar flux map per phase (maps are from the first sub-exposure of each phase when oversampling is active)
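The 2D reconstruction behind star_maps can be sketched as a scatter of the 1D in-disc values back into an (n, n) array through stored flat indices. The helper `to_2d_map` and the NaN fill for off-disc pixels are assumptions for illustration; the fill value sajax actually uses is not stated here.

```python
import numpy as np


def to_2d_map(values, flat_indices, n, fill=np.nan):
    """Scatter 1D in-disc values into an (n, n) map; off-disc pixels get `fill`."""
    flat = np.full(n * n, fill, dtype=float)
    flat[flat_indices] = values
    return flat.reshape(n, n)


# Toy disc: radius 2 pixels on a 5 x 5 grid.
n, radius = 5, 2
coords = np.arange(n) - radius
xx, yy = np.meshgrid(coords, coords)
flat_indices = np.flatnonzero(xx**2 + yy**2 <= radius**2)

# Per-pixel flux for the in-disc pixels only, here simply their running index.
flux_1d = np.arange(flat_indices.size, dtype=float)
star_map = to_2d_map(flux_1d, flat_indices, n)
print(star_map.shape)  # (5, 5)
```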

sajax.core.compute_light_curve(wavelength: numpy.ndarray, flux_quiet: numpy.ndarray, flux_active: numpy.ndarray, params: dict, ar_lat: numpy.ndarray, ar_long: numpy.ndarray, ar_size: numpy.ndarray, phases_rot: numpy.ndarray, stellar_grid_size: int, ve: float, ldc_mode: LdcMode = 'quadratic', ar_overlap_mode: ArOverlapMode = 'hottest_wins', plot_map_wavelength: float | None = None, oversample: int = 1) → dict

Convenience wrapper: build model and evaluate in one call.

Equivalent to:

model  = build_model(wavelength, flux_quiet, params, phases_rot,
                     stellar_grid_size, ve, ldc_mode, ar_overlap_mode,
                     plot_map_wavelength, oversample)
result = evaluate_light_curve(model, flux_active,
                              ar_lat, ar_long, ar_size)

Use build_model + evaluate_light_curve directly when running MCMC so the grid is built only once.

Parameters:
  • wavelength (array_like, shape (nwave,))

  • flux_quiet (array_like, shape (nwave,))

  • flux_active (array_like, shape (nar, nwave) or (nwave,))

  • params (dict)

  • ar_lat (array_like, shape (nar,))

  • ar_long (array_like, shape (nar,))

  • ar_size (array_like, shape (nar,))

  • phases_rot (array_like, shape (nphase,))

  • stellar_grid_size (int)

  • ve (float)

  • ldc_mode (str)

  • ar_overlap_mode (str)

  • plot_map_wavelength (float, optional)

  • oversample (int, optional) – Number of sub-exposures per phase point (default: 1).

Return type:

dict with keys lc, epsilon, star_maps as NumPy arrays.

sajax.core.build_combined_model(wavelength: numpy.ndarray, flux_quiet: numpy.ndarray, params: dict, times: numpy.ndarray, P_rot: float, transit_params: dict, stellar_grid_size: int, ve: float, ldc_mode: LdcMode = 'quadratic', ar_overlap_mode: ArOverlapMode = 'hottest_wins', plot_map_wavelength: float | None = None, oversample: int = 1) → dict

Build a combined stellar-activity + planetary-transit sajax model.

This is the entry point for modelling active-region crossing events: the planet mask is applied at the individual pixel level, so if the planet occults a starspot or facula the resulting anomaly in the light curve is computed correctly.

Compared to multiplying independent stellar and transit light curves, this function correctly handles:

  • Planet occulting a spot (spot-crossing anomaly).

  • Planet occulting a facula (facula-crossing anomaly).

  • The varying limb-darkening depth of the transit as a function of the stellar surface brightness profile.
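The difference between pixel-level occultation and a product of independent light curves can be seen in a toy NumPy calculation. Everything below is an assumption made for illustration (uniform disc without limb darkening, a single circular spot with contrast 0.4, integer pixel coordinates, the helper `transit_depth`); it is not the sajax planet-mask code.

```python
import numpy as np

R = 50                                   # stellar radius in pixels
coords = np.arange(-R, R + 1)
xx, yy = np.meshgrid(coords, coords)
in_disc = xx**2 + yy**2 <= R**2
x, y = xx[in_disc], yy[in_disc]          # 1D integer pixel coordinates

# Surface brightness: quiet photosphere = 1, one dark spot = 0.4.
surface = np.ones(x.size)
in_spot = (x - 15)**2 + y**2 <= 8**2     # spot centred at (15, 0), radius 8 px
surface[in_spot] = 0.4


def transit_depth(xp, yp, rp_pix=5):
    """Fractional flux drop with the planet centred at pixel (xp, yp)."""
    occulted = (x - xp)**2 + (y - yp)**2 <= rp_pix**2
    f_out = surface.sum()                # out-of-transit flux (spot included)
    f_in = surface[~occulted].sum()      # occulted pixels removed one by one
    return (f_out - f_in) / f_out


depth_quiet = transit_depth(-15, 0)      # planet over quiet photosphere
depth_on_spot = transit_depth(15, 0)     # planet fully inside the spot
print(depth_on_spot / depth_quiet)
```

The planet (radius 5 px, so k = 0.1 here) hides the same number of pixels in both positions, but over the spot it hides dimmer pixels, so the transit is shallower by the spot contrast. A product of independent stellar and transit curves would assign the same depth to both positions.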

Parameters:
  • wavelength ((nwave,) wavelength array [nm])

  • flux_quiet ((nwave,) quiet-star flux spectrum)

  • params (stellar model params dict (same as build_model))

  • times ((ntime,) absolute observation times [days])

  • P_rot (stellar rotation period [days])

  • transit_params (dict with keys (all required unless a default is given):)

      • t0 – mid-transit epoch [days]
      • period – orbital period [days]
      • a_over_rstar – semimajor axis / R★ (dimensionless)
      • inclination – orbital inclination [rad]
      • k – planet-to-star radius ratio Rp / R★
      • ecc – eccentricity (default 0.0)
      • omega_peri – argument of periastron [rad] (default 0.0)

  • stellar_grid_size (stellar radius in pixels)

  • ve (equatorial velocity [km/s])

  • ldc_mode (limb-darkening law (same options as build_model))

  • ar_overlap_mode (active-region overlap rule)

  • plot_map_wavelength (wavelength for 2D map output [nm])

  • oversample (sub-exposure count per phase point (default 1))

Return type:

model dict — pass directly to evaluate_light_curve

Notes

The oversampling is applied consistently to both the stellar rotation phase grid and the orbital time grid. For each original time t_i with time step dt, oversample sub-times are generated spanning [t_i - dt/2, t_i + dt/2), exactly mirroring _make_oversampled_phases.
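A minimal NumPy sketch of this kind of sub-time generation and averaging is given below. It is not _make_oversampled_phases: the helper names, the use of the median spacing as dt, and the placement of sub-times at the midpoints of equal sub-intervals (so that oversample = 1 returns the original times) are assumptions.

```python
import numpy as np


def oversampled_times(times, oversample):
    """Return an (ntime, oversample) array of sub-times around each time stamp."""
    times = np.asarray(times, dtype=float)
    dt = np.median(np.diff(times))       # assumed: roughly uniform cadence
    # Midpoints of `oversample` equal sub-intervals of [-dt/2, +dt/2).
    offsets = ((np.arange(oversample) + 0.5) / oversample - 0.5) * dt
    return times[:, None] + offsets[None, :]


def average_back(values_sub):
    """Average sub-exposure values back onto the original time grid."""
    return values_sub.mean(axis=1)


times = np.array([0.0, 0.1, 0.2])
sub = oversampled_times(times, oversample=4)

# Toy signal, linear in time, evaluated on the fine grid then averaged.
lc = average_back(2.0 * sub + 1.0)
print(sub.shape, lc.shape)  # (3, 4) (3,)
```

For a signal that is linear in time, the average over symmetric sub-times reproduces the value at the exposure centre; the smoothing only changes the result where the signal curves or jumps within one exposure, as at a limb crossing.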

sajax.core.compute_combined_light_curve(wavelength: numpy.ndarray, flux_quiet: numpy.ndarray, flux_active: numpy.ndarray, params: dict, ar_lat: numpy.ndarray, ar_long: numpy.ndarray, ar_size: numpy.ndarray, times: numpy.ndarray, P_rot: float, transit_params: dict, stellar_grid_size: int, ve: float, ldc_mode: LdcMode = 'quadratic', ar_overlap_mode: ArOverlapMode = 'hottest_wins', plot_map_wavelength: float | None = None, oversample: int = 1) → dict

Convenience wrapper: build a combined stellar + transit model and evaluate it in one call.

Equivalent to:

model  = build_combined_model(wavelength, flux_quiet, params, times, P_rot,
                     transit_params, stellar_grid_size, ve, ldc_mode, ar_overlap_mode,
                     plot_map_wavelength, oversample)

result = evaluate_light_curve(model, flux_active,
                              ar_lat, ar_long, ar_size)

Use build_combined_model + evaluate_light_curve directly when running MCMC so the grid is built only once.

Parameters:
  • All parameters match build_combined_model and evaluate_light_curve. See their docstrings for details.

  • transit_params (dict) – t0, period, a_over_rstar, inclination, k, and optionally ecc, omega_peri.

Returns:

  • dict with keys lc, epsilon, star_maps (same as compute_light_curve).