sajax.geometry
==============

.. py:module:: sajax.geometry

.. autoapi-nested-parse::

   geometry.py — JAX rotation matrices and coordinate transforms.

   Replaces the ``astropy.coordinates.matrix_utilities.rotation_matrix``
   dependency from the original SAGE code with pure JAX, making all
   geometry operations differentiable and JIT-compilable.
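
   The sketch below shows in a minimal way what "differentiable and JIT-compilable" means in practice. The helper ``rotated_x_coordinate`` is hypothetical and not part of ``sajax.geometry``. It builds a y-axis rotation with ``jax.numpy``, assuming the standard right-handed active convention, and is then compiled with ``jax.jit`` and differentiated with ``jax.grad``:

   .. code-block:: python

      import jax
      import jax.numpy as jnp


      def rotated_x_coordinate(angle_rad):
          """x-coordinate of the point (1, 0, 0) after a rotation about y.

          Hypothetical helper; the sign convention is assumed and may differ
          from the one used by sajax.geometry.
          """
          c, s = jnp.cos(angle_rad), jnp.sin(angle_rad)
          rot_y = jnp.array([[c, 0.0, s],
                             [0.0, 1.0, 0.0],
                             [-s, 0.0, c]])
          return (rot_y @ jnp.array([1.0, 0.0, 0.0]))[0]


      # Compile the function once with XLA...
      fast = jax.jit(rotated_x_coordinate)
      # ...and obtain its derivative with respect to the angle.
      d_dangle = jax.grad(rotated_x_coordinate)

      print(fast(0.3))      # cos(0.3)
      print(d_dangle(0.3))  # -sin(0.3)

   Because the whole computation stays in ``jax.numpy``, the same function could also be batched with ``jax.vmap``.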

   Geometry convention (identical to original SAGE):
       - The observer is at z → +∞, so the plane of sky is the x-y plane.
       - The stellar rotation axis is the y-axis.
       - inc_star = 90°  →  equator-on  (observer sees the equator).
       - inc_star =  0°  →  pole-on     (observer looks at the north pole).
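
   The sketch below illustrates what "observer at z → +∞" implies. The view is an orthographic projection along z, so a point's plane-of-sky position is just its (x, y), and the point lies on the visible hemisphere when z > 0. The helper ``sky_projection`` is hypothetical and not part of ``sajax.geometry``.

   .. code-block:: python

      import jax.numpy as jnp


      def sky_projection(cart):
          """Return the plane-of-sky (x, y) position and a visibility flag.

          Hypothetical helper: with the observer at z -> +inf the projection
          is orthographic, so z only decides which hemisphere the point is on.
          """
          cart = jnp.asarray(cart)
          return cart[:2], cart[2] > 0


      xy, visible = sky_projection(jnp.array([0.6, 0.0, 0.8]))
      print(xy, visible)  # front hemisphere: visible is True

   Points on the sphere with z = 0 lie exactly on the stellar limb.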



Functions
---------

.. autoapisummary::

   sajax.geometry.rotation_matrix_y
   sajax.geometry.rotation_matrix_x
   sajax.geometry.rotate_active_region


Module Contents
---------------

.. py:function:: rotation_matrix_y(angle_rad: float) -> jax.numpy.ndarray

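   Return the 3×3 active rotation matrix about the y-axis. The matrix rotates
   a vector by ``angle_rad`` about y while the coordinate frame stays fixed.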

   :param angle_rad: Rotation angle in radians.
   :type angle_rad: float

   :rtype: jnp.ndarray, shape (3, 3)
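
   Below is a minimal sketch of such a matrix. It assumes the standard right-handed active convention, and ``sajax`` may use the opposite sign for the angle. ``rotation_matrix_y_sketch`` is a hypothetical stand-in, not the module's own implementation:

   .. code-block:: python

      import jax.numpy as jnp


      def rotation_matrix_y_sketch(angle_rad):
          """Right-handed active rotation about +y (assumed convention)."""
          c, s = jnp.cos(angle_rad), jnp.sin(angle_rad)
          return jnp.array([[c, 0.0, s],
                            [0.0, 1.0, 0.0],
                            [-s, 0.0, c]])


      R = rotation_matrix_y_sketch(jnp.pi / 2)
      # A quarter turn about +y carries +z onto +x and leaves y unchanged.
      print(R @ jnp.array([0.0, 0.0, 1.0]))  # approximately [1, 0, 0]
      # Any rotation matrix is orthogonal: R^T R = I.
      print(jnp.allclose(R.T @ R, jnp.eye(3), atol=1e-6))  # True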


.. py:function:: rotation_matrix_x(angle_rad: float) -> jax.numpy.ndarray

   Return the 3×3 active rotation matrix about the x-axis.

   :param angle_rad: Rotation angle in radians.
   :type angle_rad: float

   :rtype: jnp.ndarray, shape (3, 3)
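
   Below is a minimal sketch of the x-axis counterpart, again assuming the standard right-handed active convention. ``rotation_matrix_x_sketch`` is hypothetical. A quarter turn about +x carries the stellar rotation axis (+y) onto the line of sight (+z), which is the pole-on geometry described in the module convention:

   .. code-block:: python

      import jax.numpy as jnp


      def rotation_matrix_x_sketch(angle_rad):
          """Right-handed active rotation about +x (assumed convention)."""
          c, s = jnp.cos(angle_rad), jnp.sin(angle_rad)
          return jnp.array([[1.0, 0.0, 0.0],
                            [0.0, c, -s],
                            [0.0, s, c]])


      R = rotation_matrix_x_sketch(jnp.pi / 2)
      # +y (rotation axis) is carried onto +z (toward the observer).
      print(R @ jnp.array([0.0, 1.0, 0.0]))  # approximately [0, 0, 1]
      # A proper rotation has determinant +1.
      print(jnp.linalg.det(R))               # approximately 1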


.. py:function:: rotate_active_region(cart: jax.numpy.ndarray, phase_deg: float, inc_deg: float) -> jax.numpy.ndarray

   Apply the stellar rotation (about the y-axis), then the stellar inclination
   (about the x-axis), to the Cartesian position vector of an active region.

   This combines the two-step ``stellar_rotation`` and ``stellar_inc``
   functions from the original SAGE code into a single call.

   :param cart: Position [x, y, z] of the active region on the stellar sphere,
                in pixel coordinates.
   :type cart: jnp.ndarray, shape (3,)
   :param phase_deg: Rotational phase in degrees.
   :type phase_deg: float
   :param inc_deg: Stellar inclination in degrees (90 = equator-on, 0 = pole-on).
   :type inc_deg: float

   :returns: Rotated [x, y, z] coordinates, in the same units as ``cart``.
   :rtype: jnp.ndarray, shape (3,)
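
   Below is a minimal sketch of the two-step transform. It rests on stated assumptions, not on the ``sajax`` source: right-handed active rotations, the phase applied as a positive angle about +y, and an inclination tilt of ``90 - inc_deg`` degrees about +x. The tilt is chosen so that ``inc_deg = 90`` keeps the rotation axis in the plane of sky and ``inc_deg = 0`` points the north pole (+y) at the observer (+z), matching the module's geometry convention. ``rotate_active_region_sketch`` and its helpers are hypothetical, and the real function's sign conventions may differ.

   .. code-block:: python

      import jax
      import jax.numpy as jnp


      def _rot_y(a):
          # Assumed right-handed active rotation about +y.
          c, s = jnp.cos(a), jnp.sin(a)
          return jnp.array([[c, 0.0, s], [0.0, 1.0, 0.0], [-s, 0.0, c]])


      def _rot_x(a):
          # Assumed right-handed active rotation about +x.
          c, s = jnp.cos(a), jnp.sin(a)
          return jnp.array([[1.0, 0.0, 0.0], [0.0, c, -s], [0.0, s, c]])


      @jax.jit
      def rotate_active_region_sketch(cart, phase_deg, inc_deg):
          """Rotate about the stellar axis (y), then tilt it toward the observer (x).

          Hypothetical: the tilt angle 90 - inc_deg is an assumption chosen to
          reproduce the equator-on / pole-on convention of this module.
          """
          phase = jnp.deg2rad(phase_deg)
          tilt = jnp.deg2rad(90.0 - inc_deg)
          return _rot_x(tilt) @ (_rot_y(phase) @ cart)


      pole = jnp.array([0.0, 1.0, 0.0])
      print(rotate_active_region_sketch(pole, 0.0, 90.0))  # equator-on: ~[0, 1, 0]
      print(rotate_active_region_sketch(pole, 0.0, 0.0))   # pole-on: ~[0, 0, 1]

      # Track one region over several phases in a single batched call.
      phases = jnp.linspace(0.0, 360.0, 5)
      track = jax.vmap(rotate_active_region_sketch, in_axes=(None, 0, None))(
          jnp.array([0.0, 0.0, 1.0]), phases, 90.0)
      # track has shape (5, 3): one rotated position per phase

   Because the sketch is written entirely in ``jax.numpy``, ``jax.vmap`` can map it over an array of phases without a Python loop. This is one convenient way to follow an active region through a full rotation.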


