Source code for sajax.geometry
"""
geometry.py — JAX rotation matrices and coordinate transforms.
Replaces the ``astropy.coordinates.matrix_utilities.rotation_matrix``
dependency from the original SAGE code with pure JAX, making all
geometry operations differentiable and JIT-compilable.
Geometry convention (identical to original SAGE):
- Observer is at z → +∞. The plane of sky is X-Y.
- The stellar rotation axis is the y-axis.
- inc_star = 90° → equator-on (observer sees the equator).
- inc_star = 0° → pole-on (observer looks at the north pole).
"""
import jax.numpy as jnp
# ---------------------------------------------------------------------------
# Primitive rotation matrices
# ---------------------------------------------------------------------------
[docs]
def rotation_matrix_y(angle_rad: float) -> jnp.ndarray:
"""
3x3 active rotation matrix around the y-axis.
Parameters
----------
angle_rad : float
Rotation angle in radians.
Returns
-------
jnp.ndarray, shape (3, 3)
"""
c = jnp.cos(angle_rad)
s = jnp.sin(angle_rad)
return jnp.array([
[ c, 0., s],
[ 0., 1., 0.],
[-s, 0., c],
])
[docs]
def rotation_matrix_x(angle_rad: float) -> jnp.ndarray:
"""
3x3 active rotation matrix around the x-axis.
Parameters
----------
angle_rad : float
Rotation angle in radians.
Returns
-------
jnp.ndarray, shape (3, 3)
"""
c = jnp.cos(angle_rad)
s = jnp.sin(angle_rad)
return jnp.array([
[1., 0., 0.],
[0., c, -s],
[0., s, c],
])
# ---------------------------------------------------------------------------
# Combined stellar-rotation + inclination transform
# ---------------------------------------------------------------------------
[docs]
def rotate_active_region(
cart: jnp.ndarray,
phase_deg: float,
inc_deg: float,
) -> jnp.ndarray:
"""
Apply stellar rotation (y-axis) then stellar inclination (x-axis)
to a Cartesian coordinate vector of an active region.
This replaces the two-step ``stellar_rotation`` + ``stellar_inc``
functions from the original SAGE code.
Parameters
----------
cart : jnp.ndarray, shape (3,)
[x, y, z] pixel-coordinate position of the active region on the
stellar sphere.
phase_deg : float
Rotational phase in degrees.
inc_deg : float
Stellar inclination in degrees (90 = equator-on, 0 = pole-on).
Returns
-------
jnp.ndarray, shape (3,)
Rotated [x, y, z] coordinates.
"""
phase_rad = jnp.deg2rad(phase_deg)
# Original SAGE applies (90 - inc_star) as the x-axis tilt
tilt_rad = jnp.deg2rad(90.0 - inc_deg)
#Getting rotation matrices
R_rot = rotation_matrix_y(phase_rad)
R_inc = rotation_matrix_x(tilt_rad)
return R_inc @ R_rot @ cart