sajax.planet
============

.. py:module:: sajax.planet

.. autoapi-nested-parse::

   planet.py — Keplerian planet orbit and pixel-level transit geometry for sajax.

   This module is a standalone companion to sajax/core.py.  It can be used
   independently to compute transit light curves, or integrated with sajax
   via ``build_combined_model`` / ``compute_combined_light_curve`` (defined in
   core.py) to correctly model active-region crossing events — i.e. cases where
   the planet occults a starspot or facula during transit.

   Architecture
   ------------
   The module is intentionally *geometry-only*: it computes where the planet is
   on the sky at each epoch and which stellar-disc pixels it occults.  The flux
   integration (limb darkening, active-region weighting) is handled by the
   existing sajax machinery in core.py.  This clean separation means that the
   transit model inherits sajax's full limb-darkening parametrisation
   automatically — no extra parameters are required.
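
   In the geometry-only step, the occultation test for a single pixel comes down
   to comparing the pixel's distance from the planet centre with k. The sketch
   below is a plain-Python stand-in, not the module's ``_compute_planet_mask``.
   The function name ``planet_pixel_mask``, the square pixel grid and the
   inclusive ``<=`` boundary are assumptions made for illustration.

   .. code-block:: python

      def planet_pixel_mask(xs, ys, X, Y, Z, k):
          """Hypothetical sketch: flag stellar-disc pixels hidden behind the planet.

          xs, ys  : pixel-centre coordinates in units of R*
          X, Y, Z : planet sky position in units of R*
          k       : planet-to-star radius ratio Rp / R*
          """
          in_front = Z > 0.0  # planet lies between the star and the observer
          mask = []
          for x, y in zip(xs, ys):
              on_disc = x * x + y * y <= 1.0                  # pixel lies on the stellar disc
              covered = (x - X) ** 2 + (y - Y) ** 2 <= k * k  # pixel lies inside the planet disc
              mask.append(in_front and on_disc and covered)
          return mask


      # Regular n x n grid of pixel centres spanning [-1, 1] x [-1, 1]
      n = 101
      grid = [-1.0 + 2.0 * (i + 0.5) / n for i in range(n)]
      xs = [x for x in grid for _ in grid]
      ys = [y for _ in grid for y in grid]

      mask = planet_pixel_mask(xs, ys, X=0.0, Y=0.0, Z=10.0, k=0.1)
      n_disc = sum(x * x + y * y <= 1.0 for x, y in zip(xs, ys))
      covered_fraction = sum(mask) / n_disc  # close to k**2 = 0.01 for a central transit

   For a central transit the covered fraction of disc pixels approaches k², up to
   pixelisation error. With Z < 0 (planet behind the star) the mask is empty.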

   Orbital convention  (Winn 2010 / Eastman et al. 2013)
   ------------------------------------------------------
     X  — sky-plane east-west  (positive east)
     Y  — sky-plane north-south  (positive north, foreshortened by cos i)
     Z  — line-of-sight toward observer  (Z > 0 ⟹ planet in front of star)

   All sky positions are in units of the stellar radius R*.

   Minimum parameter set
   ---------------------
     t0            : mid-transit epoch  [days]
     period        : orbital period  [days]
     a_over_rstar  : semimajor axis / R*  (dimensionless)
                     May be derived from stellar density via
                     ``stellar_density_to_a_over_rstar()``.
     inclination   : orbital inclination  [rad]   (π/2 rad = 90° is exactly edge-on)
     ecc           : orbital eccentricity  [0, 1)
     omega_peri    : argument of periastron  [rad]
                     (ω = 0 → periapsis at the ascending node;
                      ω = π/2 → periapsis at inferior conjunction, which
                      coincides with mid-transit for a circular orbit)
     k             : planet-to-star radius ratio  Rp / R*
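
   Together these parameters fix the impact parameter b, the sky-projected
   star-planet separation at inferior conjunction, which decides whether a
   transit occurs at all. The sketch below uses the standard relation from
   Winn (2010), b = (a/R*) cos i (1 − e²) / (1 + e sin ω), which is consistent
   with the ω convention above. ``impact_parameter`` is a hypothetical helper,
   not part of sajax.planet.

   .. code-block:: python

      import math


      def impact_parameter(a_over_rstar, inclination, ecc=0.0, omega_peri=0.0):
          """Hypothetical helper: sky-projected separation at conjunction, in R*."""
          # Star-planet separation at inferior conjunction, where f = pi/2 - omega
          r_conj = a_over_rstar * (1.0 - ecc**2) / (1.0 + ecc * math.sin(omega_peri))
          # Only the cos(i) component of r_conj survives the sky projection
          return r_conj * math.cos(inclination)


      k = 0.1
      b = impact_parameter(a_over_rstar=15.0, inclination=math.radians(88.0))
      transits = b < 1.0 + k     # any overlap with the stellar disc
      non_grazing = b < 1.0 - k  # planet disc fully inside the stellar disc

   For a circular orbit ω drops out, and b reduces to (a/R*) cos i.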

   Limb darkening
   --------------
   The same LDC law stored in the sajax model dict is applied automatically
   to occulted pixels — no separate transit LDC parameters are required.
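
   To make the weighting concrete, the sketch below computes the fraction of
   stellar flux hidden behind the planet on a pixel grid. Each occulted pixel is
   given the same limb-darkened intensity it contributes to the unocculted disc.
   The quadratic law and the coefficients ``u1`` and ``u2`` are assumptions for
   illustration only; sajax applies whichever law is stored in its model dict,
   and ``occulted_flux_fraction`` is not a sajax function.

   .. code-block:: python

      import math


      def occulted_flux_fraction(X, Y, k, u1=0.4, u2=0.25, n=201):
          """Hypothetical sketch: fraction of limb-darkened flux blocked by the planet."""
          total = blocked = 0.0
          for i in range(n):
              x = -1.0 + 2.0 * (i + 0.5) / n
              for j in range(n):
                  y = -1.0 + 2.0 * (j + 0.5) / n
                  rho2 = x * x + y * y
                  if rho2 >= 1.0:
                      continue  # pixel lies off the stellar disc
                  mu = math.sqrt(1.0 - rho2)  # cosine of the emergent angle
                  # Assumed quadratic limb-darkening law
                  intensity = 1.0 - u1 * (1.0 - mu) - u2 * (1.0 - mu) ** 2
                  total += intensity
                  if (x - X) ** 2 + (y - Y) ** 2 < k * k:
                      blocked += intensity  # same weight as in the unocculted disc
          return blocked / total

   With limb darkening, a central transit blocks more than k² of the flux and a
   transit near the limb blocks less, without any transit-specific coefficients.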

   Public API
   ----------
     ``_kepler(M, ecc)``                    — differentiable Kepler solver (internal)
     ``planet_sky_position(...)``           — single-epoch sky coords (X, Y, Z)
     ``compute_planet_sky_positions(...)``  — vectorised over an array of times
     ``_compute_planet_mask(...)``          — per-pixel occultation mask (internal)
     ``build_transit_model(...)``           — pre-compute positions for all times
     ``stellar_density_to_a_over_rstar()``  — unit-conversion convenience
     ``a_over_rstar_to_stellar_density()``  — inverse of the above conversion
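
   A differentiable Kepler solver is commonly written as a fixed number of Newton
   steps on Kepler's equation M = E − e sin E. The plain-Python sketch below shows
   that iteration. The starting guesses and the iteration count are assumptions,
   and the actual ``_kepler`` may differ (for example, by operating on JAX arrays).

   .. code-block:: python

      import math


      def kepler_newton(M, ecc, n_iter=30):
          """Hypothetical sketch: solve M = E - e sin(E) for the eccentric anomaly E."""
          M = M % (2.0 * math.pi)          # wrap the mean anomaly to [0, 2*pi)
          E = M if ecc < 0.8 else math.pi  # starting guess; pi is safer at high e
          for _ in range(n_iter):          # fixed iteration count, no convergence test
              E -= (E - ecc * math.sin(E) - M) / (1.0 - ecc * math.cos(E))
          return E


      E = kepler_newton(1.0, 0.3)
      residual = E - 0.3 * math.sin(E) - 1.0  # should be ~0

   A fixed iteration count avoids data-dependent control flow, which is the usual
   reason such solvers remain easy to differentiate and compile.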



Functions
---------

.. autoapisummary::

   sajax.planet.planet_sky_position
   sajax.planet.compute_planet_sky_positions
   sajax.planet.build_transit_model
   sajax.planet.stellar_density_to_a_over_rstar
   sajax.planet.a_over_rstar_to_stellar_density


Module Contents
---------------

.. py:function:: planet_sky_position(time: jax.numpy.ndarray, t0: float, period: float, a_over_rstar: float, inclination: float, ecc: float, omega_peri: float) -> tuple[jax.numpy.ndarray, jax.numpy.ndarray, jax.numpy.ndarray]

   Compute the planet's sky-plane position (X, Y, Z) in units of R*.

   :param time: observation epoch  [same units as t0 and period, e.g. days]
   :type time: jax.numpy.ndarray
   :param t0: mid-transit epoch (inferior conjunction)
   :type t0: float
   :param period: orbital period
   :type period: float
   :param a_over_rstar: semimajor axis / R*  (dimensionless; > 1 for an orbit
                        outside the star)
   :type a_over_rstar: float
   :param inclination: orbital inclination  [rad]  (π/2 = edge-on)
   :type inclination: float
   :param ecc: eccentricity  [0, 1)
   :type ecc: float
   :param omega_peri: argument of periastron  [rad], measured from the
                      ascending node to periapsis.
   :type omega_peri: float

   :returns: **X, Y, Z** -- sky-plane coordinates in units of R*:

             * X: east-west (positive east)
             * Y: north-south, projected (= r sin(ω+f) cos i, where r is the
               star-planet separation and f the true anomaly)
             * Z: toward the observer (Z > 0 ⟹ transit; Z < 0 ⟹ occultation)
   :rtype: tuple[jax.numpy.ndarray, jax.numpy.ndarray, jax.numpy.ndarray]

   .. rubric:: Notes

   The sky-plane separation from the stellar centre is sqrt(X^2 + Y^2).
   A transit (or occultation) event occurs when sqrt(X^2 + Y^2) < 1 + k,
   where k = Rp / R*.
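
   The chain from epoch to sky position can be sketched in plain Python as below,
   assuming that ``t0`` is anchored at inferior conjunction through the true
   anomaly f = π/2 − ω (Winn 2010). The sign of X follows Winn (2010), Y follows
   the formula in the return description above, and Z is chosen so that Z > 0 at
   transit. ``sky_position`` is a hypothetical re-derivation, not a sajax name,
   and the module's internal sign conventions may differ.

   .. code-block:: python

      import math


      def sky_position(time, t0, period, a_over_rstar, inclination, ecc, omega_peri):
          """Hypothetical sketch: planet (X, Y, Z) in units of R*."""
          # Eccentric and mean anomaly at conjunction, where f = pi/2 - omega
          f_conj = 0.5 * math.pi - omega_peri
          E_conj = 2.0 * math.atan2(math.sqrt(1.0 - ecc) * math.sin(0.5 * f_conj),
                                    math.sqrt(1.0 + ecc) * math.cos(0.5 * f_conj))
          M_conj = E_conj - ecc * math.sin(E_conj)
          # Mean anomaly at the requested epoch, wrapped to [0, 2*pi)
          M = (M_conj + 2.0 * math.pi * (time - t0) / period) % (2.0 * math.pi)
          # Newton iteration on Kepler's equation
          E = M if ecc < 0.8 else math.pi
          for _ in range(30):
              E -= (E - ecc * math.sin(E) - M) / (1.0 - ecc * math.cos(E))
          # True anomaly and star-planet separation
          f = 2.0 * math.atan2(math.sqrt(1.0 + ecc) * math.sin(0.5 * E),
                               math.sqrt(1.0 - ecc) * math.cos(0.5 * E))
          r = a_over_rstar * (1.0 - ecc * math.cos(E))
          X = -r * math.cos(omega_peri + f)
          Y = r * math.sin(omega_peri + f) * math.cos(inclination)
          Z = r * math.sin(omega_peri + f) * math.sin(inclination)
          return X, Y, Z

   At t = t0 the planet sits at X ≈ 0 with Z > 0 and Y equal to the impact
   parameter; half a period later (circular orbit) Z changes sign.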


.. py:function:: compute_planet_sky_positions(times: jax.numpy.ndarray, t0: float, period: float, a_over_rstar: float, inclination: float, ecc: float, omega_peri: float) -> jax.numpy.ndarray

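   Vectorised wrapper around ``planet_sky_position``.  The orbital
   parameters (``t0``, ``period``, ``a_over_rstar``, ``inclination``, ``ecc``,
   ``omega_peri``) have the same meaning and units as in that function.

   :param times: (ntime,) array of observation epochs
   :type times: jax.numpy.ndarray

   :returns: **xyz** -- (ntime, 3) array; columns are [X, Y, Z] in units of R*
   :rtype: jax.numpy.ndarray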


.. py:function:: build_transit_model(times: numpy.ndarray, t0: float, period: float, a_over_rstar: float, inclination: float, ecc: float = 0.0, omega_peri: float = 0.0, k: float = 0.1) -> dict

   Pre-compute the planet's sky-plane position at every epoch in ``times``.

   The returned dict should be stored in the sajax model dict under the key
   ``"transit"``.  The combined model builder ``build_combined_model()``
   (in core.py) does this automatically — end users typically do not need
   to call this function directly.

   :param times: (ntime,) array of observation epochs  [days].  Must be the
                 **oversampled** time array when oversampling is active
                 (see ``build_combined_model``).
   :type times: numpy.ndarray
   :param t0: mid-transit epoch  [days]
   :type t0: float
   :param period: orbital period  [days]
   :type period: float
   :param a_over_rstar: semimajor axis / R*  (dimensionless)
   :type a_over_rstar: float
   :param inclination: orbital inclination  [rad]
   :type inclination: float
   :param ecc: eccentricity  (default: 0.0 = circular)
   :type ecc: float
   :param omega_peri: argument of periastron  [rad]  (default: 0.0)
   :type omega_peri: float
   :param k: planet-to-star radius ratio  Rp / R*  (default: 0.1)
   :type k: float

   :returns: dict with keys

             * ``planet_xyz``: (ntime, 3) jnp.ndarray, planet (X, Y, Z) per epoch
             * ``k``: float, planet-to-star radius ratio
   :rtype: dict
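
   How ``build_combined_model`` constructs the oversampled grid is not documented
   here. One common scheme, evenly spaced sub-exposures per exposure whose model
   values are averaged back afterwards, can be sketched as follows.
   ``oversample_times`` and ``bin_to_exposures`` are hypothetical helpers, not
   sajax functions, and the scheme itself is an assumption.

   .. code-block:: python

      def oversample_times(times, exptime, n_sub):
          """Hypothetical helper: n_sub evenly spaced sub-exposure mid-times per exposure."""
          # Offsets are centred on each exposure mid-time and span one exposure length
          offsets = [exptime * ((j + 0.5) / n_sub - 0.5) for j in range(n_sub)]
          return [t + dt for t in times for dt in offsets]


      def bin_to_exposures(values, n_sub):
          """Hypothetical helper: average each group of n_sub model values back to one."""
          return [sum(values[i:i + n_sub]) / n_sub for i in range(0, len(values), n_sub)]


      times = [0.0, 1.0]  # exposure mid-times [days]
      t_sub = oversample_times(times, exptime=0.02, n_sub=4)
      # Under this assumed scheme, t_sub is the array passed as ``times``, and the
      # model flux evaluated on it is binned with bin_to_exposures(flux, 4).

   Because the sub-exposure offsets are symmetric, binning the oversampled times
   themselves recovers the original exposure mid-times.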


.. py:function:: stellar_density_to_a_over_rstar(rho_star_gcc: float, period_days: float) -> float

   Convert mean stellar density and orbital period to a / R* via
   Kepler's third law (Seager & Mallén-Ornelas 2003), neglecting the
   planet mass (Mp ≪ M*):

   .. math::

      \frac{a}{R_\star} = \left( \frac{G \rho_\star P^2}{3\pi} \right)^{1/3}

   :param rho_star_gcc: mean stellar density  [g cm^-3]
   :type rho_star_gcc: float
   :param period_days: orbital period  [days]
   :type period_days: float

   :returns: **a_over_rstar** -- semimajor axis / R*  (dimensionless)
   :rtype: float
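
   A stand-alone version of the conversion, assuming CGS units throughout
   (G = 6.6743 × 10⁻⁸ cm³ g⁻¹ s⁻², 1 day = 86400 s), is sketched below. The
   function name ``density_to_a_over_rstar`` is hypothetical, and the constants
   used by sajax itself may differ slightly.

   .. code-block:: python

      import math

      G_CGS = 6.6743e-8  # gravitational constant [cm^3 g^-1 s^-2] (CODATA 2018)
      SECONDS_PER_DAY = 86400.0


      def density_to_a_over_rstar(rho_star_gcc, period_days):
          """Hypothetical stand-alone version: a/R* = (G rho P^2 / (3 pi))^(1/3)."""
          period_s = period_days * SECONDS_PER_DAY  # P must be in seconds for CGS G
          return (G_CGS * rho_star_gcc * period_s**2 / (3.0 * math.pi)) ** (1.0 / 3.0)


      a_rs = density_to_a_over_rstar(1.41, 365.25)  # Sun-like density, one-year orbit

   For a Sun-like density of about 1.41 g cm⁻³ and a one-year period this gives
   a / R* of roughly 215, close to 1 au / R☉.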


.. py:function:: a_over_rstar_to_stellar_density(a_over_rstar: float, period_days: float) -> float

   Inverse of ``stellar_density_to_a_over_rstar``:

   .. math::

      \rho_\star = \frac{3\pi}{G P^2} \left( \frac{a}{R_\star} \right)^3

   :param a_over_rstar: semimajor axis / R*  (dimensionless)
   :type a_over_rstar: float
   :param period_days: orbital period  [days]
   :type period_days: float

   :returns: **rho_star_gcc** -- mean stellar density  [g cm^-3]
   :rtype: float
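
   The inverse relation can be sketched in the same way, again assuming CGS units
   and a period converted to seconds. ``a_over_rstar_to_density`` is a hypothetical
   name used only for this illustration.

   .. code-block:: python

      import math

      G_CGS = 6.6743e-8  # gravitational constant [cm^3 g^-1 s^-2]
      SECONDS_PER_DAY = 86400.0


      def a_over_rstar_to_density(a_over_rstar, period_days):
          """Hypothetical stand-alone version: rho = 3 pi (a/R*)^3 / (G P^2) in g cm^-3."""
          period_s = period_days * SECONDS_PER_DAY
          return 3.0 * math.pi * a_over_rstar**3 / (G_CGS * period_s**2)


      rho = a_over_rstar_to_density(215.0, 365.25)  # about 1.41 g cm^-3, near the solar value

   Feeding the result back through the forward relation recovers the input
   a / R* to floating-point precision, which is a convenient consistency check.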


