sajax.core
==========

.. py:module:: sajax.core

.. autoapi-nested-parse::

   core.py — JAX-accelerated stellar active region light-curve engine.

   This module is a complete rewrite of ``SAGE1/sage.py`` in JAX.

   Key differences from the original NumPy/SciPy implementation
   -------------------------------------------------------------
   1. **No wavelength loop.**
      The original code iterated over wavelengths with a Python ``for`` loop.
      Here the entire spectral axis is handled by ``jax.vmap``, which maps
      the single-channel computation across all wavelengths in parallel.

   2. **No phase loop.**
      The original code iterated over rotational phases with a Python loop.
      Here all phases are computed in a single ``jax.vmap`` call — this is
      the main source of speedup over the original code.

   3. **No scatter-index active region placement.**
      The original code located active region pixels via integer scatter indices
      (fancy indexing with ``.astype(int)``), which is not differentiable
      and incompatible with ``jit``.  SAJAX instead computes an analytic
      angular-distance mask over the full pixel arrays using ``jnp.where``,
      which is fully vectorised and differentiable.

   4. **No class state mutation.**
      The original ``sage_class.rotate_star()`` mutated ``self.phases_rot``
      inside a loop — a latent bug.  SAJAX uses pure functions throughout.

   5. **No astropy dependency for geometry.**
      Rotation matrices are implemented directly in JAX (see geometry.py).

   6. **No transit-geometry parameters.**
      The original SAGE grid was sized using ``planet_pixel_size``,
      ``radiusratio``, and ``semimajor`` — artifacts of its transit-fitting
      origin.  SAJAX replaces these with a single ``stellar_grid_size``
      parameter: the stellar radius in pixels.  No planet required.

   7. **Pre-masked grid.**
      ``build_stellar_grid`` applies the stellar disc mask immediately and
      returns 1D arrays containing only the in-disc pixels.  No starmask is
      ever passed to JAX functions — the mask is implicit in the data shape.
      The only 2D reconstruction happens at output time for ``star_maps``,
      using stored flat indices.

   8. **Differentiable end-to-end.**
      All operations are JAX-native, so ``jax.grad`` / ``jax.jacobian``
      work on the full pipeline — useful for gradient-based retrieval.

   9. **Phase oversampling.**
      Real observations integrate photons over a finite exposure time.
      When an active region crosses the stellar limb, the discrete pixel
      grid can produce sharp discontinuities in the light curve.
      The ``oversample`` parameter (default 1, i.e. off) spreads each
      requested phase into multiple sub-exposures and averages the result,
      mimicking finite-exposure integration and smoothing limb-crossing
      artefacts.
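
   The items above can be illustrated with a small, self-contained JAX sketch.
   It shows an analytic angular-distance mask built with ``jnp.where`` (item 3),
   plain rotation matrices for spin and inclination (item 5), and nested
   ``jax.vmap`` calls in place of the wavelength and phase loops (items 1 and 2).
   It ends with ``jax.grad`` through the result (item 8).  All helper names are
   hypothetical, and the star is a uniform disc without limb darkening.  The
   axis conventions (rotation axis along +y, observer along +z) are
   assumptions.  This is not the SAJAX implementation.

   .. code-block:: python

      import jax
      import jax.numpy as jnp


      def disc_unit_vectors(R):
          """Unit vectors (x, y, z) of the in-disc pixels of a disc of radius R pixels."""
          c = jnp.arange(-R, R + 1, dtype=jnp.float32) / R
          xx, yy = jnp.meshgrid(c, c)
          inside = (xx**2 + yy**2 <= 1.0).ravel()
          # Pre-mask once, outside any JAX transformation: only in-disc pixels survive.
          x, y = xx.ravel()[inside], yy.ravel()[inside]
          z = jnp.sqrt(jnp.maximum(1.0 - x**2 - y**2, 0.0))  # z = mu, pointing at the observer
          return x, y, z


      def rot_x(a):
          """Rotation matrix about the x axis."""
          c, s = jnp.cos(a), jnp.sin(a)
          return jnp.array([[1.0, 0.0, 0.0], [0.0, c, -s], [0.0, s, c]])


      def rot_y(a):
          """Rotation matrix about the y axis (the assumed stellar spin axis)."""
          c, s = jnp.cos(a), jnp.sin(a)
          return jnp.array([[c, 0.0, s], [0.0, 1.0, 0.0], [-s, 0.0, c]])


      def channel_flux(phase, f_quiet, f_active, lat, lon, size, inc, pix):
          """Disc-integrated flux in one wavelength channel at one rotational phase."""
          b, lng = jnp.deg2rad(lat), jnp.deg2rad(lon)
          centre = jnp.array([jnp.cos(b) * jnp.sin(lng), jnp.sin(b), jnp.cos(b) * jnp.cos(lng)])
          # Spin about the stellar axis, then tilt the axis by (90 deg - inc) towards us.
          centre = rot_x(jnp.deg2rad(90.0 - inc)) @ rot_y(2.0 * jnp.pi * phase) @ centre
          x, y, z = pix
          cos_dist = x * centre[0] + y * centre[1] + z * centre[2]
          in_ar = cos_dist >= jnp.cos(jnp.deg2rad(size))  # analytic mask, no scatter indices
          return jnp.where(in_ar, f_active, f_quiet).sum()


      pix = disc_unit_vectors(50)
      phases = jnp.linspace(0.0, 1.0, 8, endpoint=False)
      f_quiet = jnp.array([1.0, 1.0, 1.0])   # (nwave,) toy quiet spectrum
      f_active = jnp.array([0.3, 0.5, 0.7])  # (nwave,) toy active-region spectrum

      # Inner vmap over wavelengths, outer vmap over phases: result is (nphase, nwave).
      per_wave = jax.vmap(channel_flux, in_axes=(None, 0, 0, None, None, None, None, None))
      per_phase = jax.vmap(per_wave, in_axes=(0, None, None, None, None, None, None, None))
      flux = per_phase(phases, f_quiet, f_active, 0.0, 0.0, 20.0, 90.0, pix)
      lc = flux / (f_quiet * pix[0].size)    # normalise by the region-free flux

      # Gradient of the phase-0 flux in one channel w.r.t. the active-region flux.
      grad_fa = jax.grad(channel_flux, argnums=2)(0.0, 1.0, 0.3, 0.0, 0.0, 20.0, 90.0, pix)

   In this sketch the mask is a hard threshold, so ``grad_fa`` simply counts
   the pixels inside the active region.  Derivatives with respect to ``lat``,
   ``lon`` or ``size`` are zero wherever the mask does not change.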

   JIT compilation
   ---------------
   Do NOT call ``jit(evaluate_light_curve)`` directly: it contains
   Python-level control flow on model metadata.  Instead, JIT the inner
   ``_compute_all_phases``, which is the hot path, with its metadata
   arguments marked static::

       from jax import jit
       from sajax.core import _compute_all_phases

       _compute_all_phases_jit = jit(_compute_all_phases, static_argnames=[
           "star_pixel_rad", "total_pixels", "ldc_mode",
           "ar_overlap_mode", "plot_map_wavelength", "n",
       ])
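
   A toy example shows why metadata arguments must be static.  In the sketch
   below, ``intensity`` is a hypothetical stand-in, not a SAJAX function.  It
   branches on a Python string, so ``jit`` can only trace it if that string is
   marked static or is captured in a closure.  Static arguments must be
   hashable, so a dict such as the model dict cannot be passed as a static
   argument.

   .. code-block:: python

      import jax
      import jax.numpy as jnp


      def intensity(mu, coeffs, ldc_mode):
          """Toy limb-darkening profile; ``ldc_mode`` is a plain Python string."""
          # Python control flow: this comparison needs a concrete string, not a tracer.
          if ldc_mode == "linear":
              return 1.0 - coeffs[0] * (1.0 - mu)
          if ldc_mode == "quadratic":
              return 1.0 - coeffs[0] * (1.0 - mu) - coeffs[1] * (1.0 - mu) ** 2
          raise ValueError(f"unknown ldc_mode: {ldc_mode!r}")


      # Option 1: mark the metadata argument as static (static values must be hashable).
      intensity_jit = jax.jit(intensity, static_argnames=["ldc_mode"])

      # Option 2: capture metadata in a closure. A dict is unhashable, so it cannot be
      # a static argument, but values read from it here are Python constants when traced.
      model = {"ldc_mode": "quadratic"}  # hypothetical stand-in for a model dict
      intensity_closed = jax.jit(lambda mu, coeffs: intensity(mu, coeffs, model["ldc_mode"]))

      mu = jnp.array([1.0, 0.5, 0.0])
      coeffs = jnp.array([0.4, 0.2])
      quad = intensity_jit(mu, coeffs, ldc_mode="quadratic")
      quad_closed = intensity_closed(mu, coeffs)

   Values captured by a closure are fixed at trace time.  If the metadata
   changes, the jitted closure should be rebuilt.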



Attributes
----------

.. autoapisummary::

   sajax.core.LdcMode
   sajax.core.ArOverlapMode


Functions
---------

.. autoapisummary::

   sajax.core.build_stellar_grid
   sajax.core.build_model
   sajax.core.evaluate_light_curve
   sajax.core.compute_light_curve
   sajax.core.build_combined_model
   sajax.core.compute_combined_light_curve


Module Contents
---------------

.. py:data:: LdcMode

.. py:data:: ArOverlapMode

.. py:function:: build_stellar_grid(stellar_grid_size: int, ve: float) -> dict

   Pre-compute the static stellar pixel grid, masked to the stellar disc.

   The mask is applied here once so that all downstream JAX functions
   receive 1D arrays containing only the in-disc pixels — no starmask
   is ever passed around.

   :param stellar_grid_size: Stellar radius in pixels.  This is the single resolution knob:
                             higher values give a finer grid at the cost of n² memory and
                             compute.  Values of 100-300 are typical.
   :type stellar_grid_size: int
   :param ve: Stellar equatorial velocity [km/s].
   :type ve: float

   :returns: Dict with keys:

             * ``n``: full grid side length (always odd)
             * ``star_pixel_rad``: stellar radius in pixels (= ``stellar_grid_size``)
             * ``total_pixels``: number of in-disc pixels
             * ``flat_indices``: (total_pixels,) int indices into the flattened
               (n, n) grid; used to reconstruct 2D maps at output
             * ``x``: (total_pixels,) x pixel coordinates [in-disc only]
             * ``y``: (total_pixels,) y pixel coordinates [in-disc only]
             * ``mu``: (total_pixels,) limb-darkening μ = cos θ [in-disc only]
             * ``vel``: (total_pixels,) Doppler factor Δv/c [in-disc only]
   :rtype: dict
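
   The pre-masked layout can be reproduced in a few lines of NumPy, including
   how ``flat_indices`` rebuilds a 2D map.  ``build_grid_sketch`` and
   ``to_map`` are hypothetical names.  ``mu = sqrt(1 - r²)`` follows from the
   disc geometry.  The form of ``vel`` (rotation axis along y, no inclination
   factor, positive towards +x) is an assumption, not the SAJAX definition.

   .. code-block:: python

      import numpy as np

      C_KMS = 299_792.458  # speed of light [km/s]


      def build_grid_sketch(R, ve):
          """Hypothetical re-creation of the pre-masked grid for a star of radius R px."""
          n = 2 * R + 1                          # odd, so one pixel sits on the disc centre
          c = np.arange(n) - R                   # pixel coordinates -R .. R
          xx, yy = np.meshgrid(c, c)             # each of shape (n, n)
          r2 = (xx**2 + yy**2).ravel() / R**2    # squared radius in units of R
          flat_indices = np.flatnonzero(r2 <= 1.0)
          x = xx.ravel()[flat_indices].astype(float)
          y = yy.ravel()[flat_indices].astype(float)
          mu = np.sqrt(1.0 - r2[flat_indices])   # mu = cos(theta)
          vel = (ve / C_KMS) * x / R             # assumed: axis along y, no sin(i) factor
          return {"n": n, "star_pixel_rad": R, "total_pixels": flat_indices.size,
                  "flat_indices": flat_indices, "x": x, "y": y, "mu": mu, "vel": vel}


      def to_map(grid, values):
          """Scatter 1D in-disc values back onto the (n, n) grid; off-disc pixels are NaN."""
          out = np.full(grid["n"] ** 2, np.nan)
          out[grid["flat_indices"]] = values
          return out.reshape(grid["n"], grid["n"])


      grid = build_grid_sketch(10, ve=2.0)
      mu_map = to_map(grid, grid["mu"])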


.. py:function:: build_model(wavelength: numpy.ndarray, flux_quiet: numpy.ndarray, params: dict, phases_rot: numpy.ndarray, stellar_grid_size: int, ve: float, ldc_mode: LdcMode = 'quadratic', ar_overlap_mode: ArOverlapMode = 'hottest_wins', plot_map_wavelength: Optional[float] = None, oversample: int = 1) -> dict

   Pre-build all static model arrays.  Call this **once** before MCMC.

   Everything that does not change between MCMC steps is computed here in
   NumPy and stored in the returned model dict.  The only quantities that
   vary per step (``flux_active``, ``ar_lat``, ``ar_long`` and ``ar_size``)
   are intentionally excluded and passed to ``evaluate_light_curve`` instead.

   :param wavelength: Wavelength grid.
   :type wavelength: array_like, shape (nwave,)
   :param flux_quiet: Quiet-star flux spectrum.
   :type flux_quiet: array_like, shape (nwave,)
   :param params: Model parameters.  Recognised keys:

                  ``inc_star`` : float, optional
                      Stellar inclination in degrees (default: 90.0).
                      90° = equator-on, 0° = pole-on.

                  ``ldc_coeffs`` : list of float or list of array(nwave,)
                      Limb-darkening coefficients for the chosen ``ldc_mode``:

                      - ``"linear"``:     [u]
                      - ``"quadratic"``:  [u1, u2]
                      - ``"power2"``:     [c, alpha]
                      - ``"kipping3"``:   [c1, c2, c3]
                      - ``"nonlinear4"``: [c1, c2, c3, c4]

                      Each element may be a scalar (broadcast to all
                      wavelengths) or an array of length ``nwave``.  For
                      ``"quadratic"`` mode only, ``u1`` and ``u2`` are also
                      accepted as separate keys (legacy interface).

                  ``mu_profile`` : array_like, optional
                      Monotonically increasing μ grid points for
                      ``ldc_mode="intensity_profile"`` (default: [0, 1]).

                  ``I_profile`` : array_like, shape (nwave, n_mu_pts), optional
                      Specific intensity at each (wavelength, μ) grid point.
                      Required when ``ldc_mode="intensity_profile"``.
   :type params: dict
   :param phases_rot: Rotational phases at which to evaluate the light curve.
   :type phases_rot: array_like, shape (nphase,)
   :param stellar_grid_size: Stellar radius in pixels (see ``build_stellar_grid``).
   :type stellar_grid_size: int
   :param ve: Stellar equatorial velocity [km/s].
   :type ve: float
   :param ldc_mode: Limb-darkening law: ``"linear"``, ``"quadratic"`` (default),
                    ``"power2"``, ``"kipping3"``, ``"nonlinear4"`` or
                    ``"intensity_profile"``.
   :type ldc_mode: str
   :param ar_overlap_mode: Rule for resolving overlapping active regions:

                           - ``"hottest_wins"`` (default): an overlap pixel uses
                             the flux of the hottest (highest-flux) active region.
                           - ``"coldest_wins"``: an overlap pixel uses the flux of
                             the coldest (lowest-flux) active region.
   :type ar_overlap_mode: str, optional
   :param plot_map_wavelength: Wavelength for 2D map output [nm].
   :type plot_map_wavelength: float, optional
   :param oversample: Number of sub-exposures per phase point.  Each requested
                      phase is spread into ``oversample`` uniformly spaced
                      sub-phases spanning one phase step, and the resulting
                      fluxes are averaged.  This mimics finite-exposure
                      integration and smooths limb-crossing artefacts.
                      Default: 1 (no oversampling).
   :type oversample: int, optional

   :returns: Model dict; pass directly to ``evaluate_light_curve``.
   :rtype: dict
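
   The analytic laws listed under ``ldc_coeffs`` have standard closed forms.
   They are sketched below for one wavelength channel, together with linear
   interpolation for ``"intensity_profile"``.  ``ld_profile`` is a
   hypothetical helper.  ``"kipping3"`` is left out because its exact
   functional form is not specified here.

   .. code-block:: python

      import jax.numpy as jnp


      def ld_profile(mu, coeffs, ldc_mode, mu_grid=None, i_grid=None):
          """Intensity profile for one channel; the analytic laws give I(mu=1) = 1."""
          one_m = 1.0 - mu
          if ldc_mode == "linear":
              (u,) = coeffs
              return 1.0 - u * one_m
          if ldc_mode == "quadratic":
              u1, u2 = coeffs
              return 1.0 - u1 * one_m - u2 * one_m**2
          if ldc_mode == "power2":
              c, alpha = coeffs
              return 1.0 - c * (1.0 - mu**alpha)
          if ldc_mode == "nonlinear4":
              # I/I0 = 1 - sum_k c_k (1 - mu^(k/2)), k = 1..4
              return 1.0 - sum(ck * (1.0 - mu ** (k / 2)) for k, ck in enumerate(coeffs, start=1))
          if ldc_mode == "intensity_profile":
              # Linear interpolation on a tabulated, monotonically increasing mu grid.
              return jnp.interp(mu, mu_grid, i_grid)
          raise ValueError(f"unsupported ldc_mode: {ldc_mode!r}")

   For the analytic laws, each coefficient may also be an (nwave,) array, as
   the docstring allows.  Passing ``mu[:, None]`` then yields a (npix, nwave)
   intensity array by broadcasting.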



.. py:function:: evaluate_light_curve(model: dict, flux_active: jax.numpy.ndarray, ar_lat: jax.numpy.ndarray, ar_long: jax.numpy.ndarray, ar_size: jax.numpy.ndarray) -> dict

   Evaluate the light curve for a given set of active region parameters.

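   This function is **pure JAX**: all array inputs may be JAX arrays or
   tracers, so it composes with ``vmap``, ``jax.grad`` and JAX-based samplers
   such as ``emcee_jax`` (ensemble) or ``blackjax`` (gradient-based).  For
   ``jit``, close over the ``model`` dict rather than passing it as an
   argument, since it carries Python-level metadata.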

   When the model was built with ``oversample > 1``, the computation runs
   on the oversampled phase grid and the results are averaged back to the
   original phase grid before returning.

   :param model: Pre-built model dict returned by ``build_model``.
   :type model: dict
   :param flux_active: Active-region flux spectrum.  If shape (nar, nwave),
                       each active region gets its own spectrum.  If shape
                       (nwave,), the spectrum is broadcast to all active
                       regions (legacy mode).
   :type flux_active: jnp.ndarray, shape (nar, nwave) or (nwave,)
   :param ar_lat: Active-region latitudes in degrees.  Must be in [-90, 90].
   :type ar_lat: jnp.ndarray, shape (nar,)
   :param ar_long: Active-region longitudes in degrees.  Must be in [0, 360).
   :type ar_long: jnp.ndarray, shape (nar,)
   :param ar_size: Active-region angular radii in degrees.
   :type ar_size: jnp.ndarray, shape (nar,)

   :returns: Dict with keys:

             * ``lc``: (nphase_original,) normalised broadband light curve
             * ``epsilon``: (nphase_original, nwave) contamination factor ε(λ)
             * ``star_maps``: (nphase_original, n, n) stellar flux map per phase
               (when oversampling is active, each map is taken from the *first*
               sub-exposure of that phase)
   :rtype: dict
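
   The two ``ar_overlap_mode`` rules can be sketched for a single wavelength
   channel, together with the promotion of a legacy (nwave,) ``flux_active``
   to one spectrum per region.  The helper names and the boolean (nar, npix)
   coverage mask are assumptions for illustration.  Here "hottest" is decided
   per channel by the flux value.  An implementation could instead rank the
   regions once, for example by wavelength-integrated flux.

   .. code-block:: python

      import jax.numpy as jnp


      def resolve_overlap(in_ar, flux_active, flux_quiet, ar_overlap_mode="hottest_wins"):
          """Per-pixel flux for one wavelength channel (hypothetical helper).

          in_ar:       (nar, npix) bool, True where a pixel lies inside active region i
          flux_active: (nar,) active-region flux in this channel
          flux_quiet:  scalar quiet-star flux in this channel
          """
          fa = flux_active[:, None]  # (nar, 1), broadcasts over pixels
          if ar_overlap_mode == "hottest_wins":
              picked = jnp.where(in_ar, fa, -jnp.inf).max(axis=0)
          elif ar_overlap_mode == "coldest_wins":
              picked = jnp.where(in_ar, fa, jnp.inf).min(axis=0)
          else:
              raise ValueError(f"unknown ar_overlap_mode: {ar_overlap_mode!r}")
          # Pixels outside every active region keep the quiet-star flux.
          return jnp.where(in_ar.any(axis=0), picked, flux_quiet)


      def broadcast_flux_active(flux_active, nar):
          """Promote a legacy (nwave,) spectrum to (nar, nwave); (nar, nwave) is unchanged."""
          flux_active = jnp.asarray(flux_active)
          return jnp.broadcast_to(flux_active, (nar, flux_active.shape[-1]))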


.. py:function:: compute_light_curve(wavelength: numpy.ndarray, flux_quiet: numpy.ndarray, flux_active: numpy.ndarray, params: dict, ar_lat: numpy.ndarray, ar_long: numpy.ndarray, ar_size: numpy.ndarray, phases_rot: numpy.ndarray, stellar_grid_size: int, ve: float, ldc_mode: LdcMode = 'quadratic', ar_overlap_mode: ArOverlapMode = 'hottest_wins', plot_map_wavelength: Optional[float] = None, oversample: int = 1) -> dict

   Convenience wrapper: build model and evaluate in one call.

   Equivalent to::

       model  = build_model(wavelength, flux_quiet, params, phases_rot,
                            stellar_grid_size, ve, ldc_mode, ar_overlap_mode,
                            plot_map_wavelength, oversample)
       result = evaluate_light_curve(model, flux_active,
                                     ar_lat, ar_long, ar_size)

   Use ``build_model`` + ``evaluate_light_curve`` directly when running
   MCMC so the grid is built only once.

   :param wavelength:
   :type wavelength: array_like, shape (nwave,)
   :param flux_quiet:
   :type flux_quiet: array_like, shape (nwave,)
   :param flux_active:
   :type flux_active: array_like, shape (nar, nwave) or (nwave,)
   :param params:
   :type params: dict
   :param ar_lat:
   :type ar_lat: array_like, shape (nar,)
   :param ar_long:
   :type ar_long: array_like, shape (nar,)
   :param ar_size:
   :type ar_size: array_like, shape (nar,)
   :param phases_rot:
   :type phases_rot: array_like, shape (nphase,)
   :param stellar_grid_size:
   :type stellar_grid_size: int
   :param ve:
   :type ve: float
   :param ldc_mode:
   :type ldc_mode: str
   :param ar_overlap_mode:
   :type ar_overlap_mode: str
   :param plot_map_wavelength:
   :type plot_map_wavelength: float, optional
   :param oversample: Number of sub-exposures per phase point (default: 1, no oversampling).
   :type oversample: int, optional

   :rtype: dict with keys ``lc``, ``epsilon``, ``star_maps`` as NumPy arrays.


.. py:function:: build_combined_model(wavelength: numpy.ndarray, flux_quiet: numpy.ndarray, params: dict, times: numpy.ndarray, P_rot: float, transit_params: dict, stellar_grid_size: int, ve: float, ldc_mode: LdcMode = 'quadratic', ar_overlap_mode: ArOverlapMode = 'hottest_wins', plot_map_wavelength: Optional[float] = None, oversample: int = 1) -> dict

   Build a combined stellar-activity + planetary-transit sajax model.

   This is the entry point for modelling **active-region crossing events**:
   the planet mask is applied at the individual pixel level, so if the planet
   occults a starspot or facula, the resulting anomaly in the light curve is
   computed correctly.

   Compared to multiplying independent stellar and transit light curves, this
   function correctly handles:
     • Planet occulting a spot (spot-crossing anomaly).
     • Planet occulting a facula (facula-crossing anomaly).
     • The varying limb-darkening depth of the transit as a function of
       the stellar surface brightness profile.
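
   The pixel-level occultation can be illustrated with a toy JAX sketch.  The
   planet's sky position is computed for a **circular** orbit (``ecc`` and
   ``omega_peri`` are ignored).  Pixels under the planet disc are zeroed and
   the remaining flux is summed.  With a dark spot at disc centre, the
   normalised transit is shallower than for a spot-free disc, which is the
   spot-crossing anomaly in its simplest form.  Function names, the coordinate
   conventions and the toy star are assumptions, not the SAJAX implementation.

   .. code-block:: python

      import jax
      import jax.numpy as jnp


      def planet_sky_position(t, tp):
          """Planet centre on the sky in units of R_star (circular orbit only)."""
          phi = 2.0 * jnp.pi * (t - tp["t0"]) / tp["period"]
          x = tp["a_over_rstar"] * jnp.sin(phi)
          y = tp["a_over_rstar"] * jnp.cos(phi) * jnp.cos(tp["inclination"])
          z = tp["a_over_rstar"] * jnp.cos(phi) * jnp.sin(tp["inclination"])  # > 0: in front
          return x, y, z


      def occulted_flux(t, px, py, pix_flux, tp):
          """Disc-integrated flux with the planet masking individual pixels."""
          xp, yp, zp = planet_sky_position(t, tp)
          covered = ((px - xp) ** 2 + (py - yp) ** 2 < tp["k"] ** 2) & (zp > 0)
          return jnp.where(covered, 0.0, pix_flux).sum()


      # Uniform toy disc of radius R pixels; coordinates in units of R_star.
      R = 100
      c = jnp.arange(-R, R + 1, dtype=jnp.float32) / R
      xx, yy = jnp.meshgrid(c, c)
      inside = (xx**2 + yy**2 <= 1.0).ravel()
      px, py = xx.ravel()[inside], yy.ravel()[inside]
      quiet = jnp.ones_like(px)
      spotted = jnp.where(px**2 + py**2 < 0.2**2, 0.5, 1.0)  # dark spot at disc centre

      tp = {"t0": 0.0, "period": 3.0, "a_over_rstar": 10.0,
            "inclination": jnp.pi / 2, "k": 0.1}
      times = jnp.array([0.0, 1.5])  # mid-transit, then half an orbit later (planet behind)
      lc_fn = jax.vmap(occulted_flux, in_axes=(0, None, None, None, None))
      lc_quiet = lc_fn(times, px, py, quiet, tp) / quiet.sum()
      lc_spotted = lc_fn(times, px, py, spotted, tp) / spotted.sum()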

   :param wavelength: Wavelength array [nm].
   :type wavelength: array_like, shape (nwave,)
   :param flux_quiet: Quiet-star flux spectrum.
   :type flux_quiet: array_like, shape (nwave,)
   :param params: Stellar model parameters (same keys as ``build_model``).
   :type params: dict
   :param times: Absolute observation times [days].
   :type times: array_like, shape (ntime,)
   :param P_rot: Stellar rotation period [days].
   :type P_rot: float
   :param transit_params: Transit parameters (all required unless noted):

                          * ``t0``: mid-transit epoch [days]
                          * ``period``: orbital period [days]
                          * ``a_over_rstar``: semimajor axis / R★ (dimensionless)
                          * ``inclination``: orbital inclination [rad]
                          * ``k``: planet-to-star radius ratio Rp / R★
                          * ``ecc``: eccentricity (default 0.0)
                          * ``omega_peri``: argument of periastron [rad] (default 0.0)
   :type transit_params: dict
   :param stellar_grid_size: Stellar radius in pixels.
   :type stellar_grid_size: int
   :param ve: Equatorial velocity [km/s].
   :type ve: float
   :param ldc_mode: Limb-darkening law (same options as ``build_model``).
   :type ldc_mode: str
   :param ar_overlap_mode: Active-region overlap rule (``"hottest_wins"`` or ``"coldest_wins"``).
   :type ar_overlap_mode: str
   :param plot_map_wavelength: Wavelength for 2D map output [nm].
   :type plot_map_wavelength: float, optional
   :param oversample: Sub-exposure count per phase point (default: 1, no oversampling).
   :type oversample: int, optional

   :returns: Model dict; pass directly to ``evaluate_light_curve``.
   :rtype: dict

   .. rubric:: Notes

   The oversampling is applied *consistently* to both the stellar rotation
   phase grid and the orbital time grid.  For each original time t_i, with dt
   the step of the time grid, ``oversample`` sub-times are generated spanning
   [t_i - dt/2, t_i + dt/2), exactly mirroring ``_make_oversampled_phases``.
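
   A NumPy sketch of the sub-time construction and of the averaging back to
   the original grid follows.  The helper names are hypothetical, and ``dt``
   is taken as the median spacing of the input grid.  The offsets are centred
   on each t_i, so they lie inside [t_i - dt/2, t_i + dt/2) and average back
   to t_i.  The actual ``_make_oversampled_phases`` may place the sub-samples
   differently.

   .. code-block:: python

      import numpy as np


      def make_oversampled(times, oversample):
          """Return (ntime * oversample,) sub-times, grouped per original time."""
          times = np.asarray(times, dtype=float)
          if oversample == 1:
              return times.copy()                 # oversampling off
          dt = np.median(np.diff(times))          # assumed (near-)uniform cadence
          offsets = ((np.arange(oversample) + 0.5) / oversample - 0.5) * dt
          return (times[:, None] + offsets[None, :]).ravel()


      def average_back(values, oversample):
          """Average sub-exposure results back onto the original grid (axis 0)."""
          values = np.asarray(values)
          return values.reshape(-1, oversample, *values.shape[1:]).mean(axis=1)


      t = np.array([0.0, 1.0, 2.0])
      sub_t = make_oversampled(t, 4)  # offsets -0.375, -0.125, 0.125, 0.375 around each t_i

   The same sub-times can feed both the rotation phase and the orbital
   position, which keeps the two grids consistent as the notes require.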


.. py:function:: compute_combined_light_curve(wavelength: numpy.ndarray, flux_quiet: numpy.ndarray, flux_active: numpy.ndarray, params: dict, ar_lat: numpy.ndarray, ar_long: numpy.ndarray, ar_size: numpy.ndarray, times: numpy.ndarray, P_rot: float, transit_params: dict, stellar_grid_size: int, ve: float, ldc_mode: LdcMode = 'quadratic', ar_overlap_mode: ArOverlapMode = 'hottest_wins', plot_map_wavelength: Optional[float] = None, oversample: int = 1) -> dict

   Convenience wrapper: build a combined stellar + transit model and
   evaluate it in one call.

   Equivalent to::

       model  = build_combined_model(wavelength, flux_quiet, params, times, P_rot,
                                     transit_params, stellar_grid_size, ve,
                                     ldc_mode, ar_overlap_mode,
                                     plot_map_wavelength, oversample)

       result = evaluate_light_curve(model, flux_active,
                                     ar_lat, ar_long, ar_size)

   Use ``build_combined_model`` + ``evaluate_light_curve`` directly when
   running MCMC so the grid is built only once.

   All parameters not listed here match ``build_combined_model`` and
   ``evaluate_light_curve``; see their docstrings for details.

   :param transit_params: ``t0``, ``period``, ``a_over_rstar``, ``inclination``, ``k``,
                          and optionally ``ecc``, ``omega_peri``.
   :type transit_params: dict

   :returns: Dict with keys ``lc``, ``epsilon`` and ``star_maps`` (same as
             ``compute_light_curve``).
   :rtype: dict


