from typing import Union, NamedTuple, Tuple
from functools import partial
import jax
from jax import jit
import jax.numpy as jnp
from odisseo.dynamics import direct_acc, direct_acc_laxmap, direct_acc_matrix, direct_acc_for_loop, direct_acc_sharding, no_self_gravity
from odisseo.potentials import combined_external_acceleration_vmpa_switch
from odisseo.option_classes import SimulationConfig, SimulationParams
from odisseo.option_classes import DIRECT_ACC, DIRECT_ACC_LAXMAP, DIRECT_ACC_MATRIX, DIRECT_ACC_FOR_LOOP, DIRECT_ACC_SHARDING, NO_SELF_GRAVITY
from odisseo.units import CodeUnits
from astropy import units as u
from astropy import constants as c
from jaxtyping import jaxtyped
from beartype import beartype as typechecker
[docs]
@partial(jax.jit, )
@jaxtyped(typechecker=typechecker)
def center_of_mass(state: jnp.ndarray,
                   mass: jnp.ndarray) -> jnp.ndarray:
    """
    Return the center of mass (mass-weighted mean position) of the system.

    Args:
        state (jnp.ndarray): Array of shape (N_particles, 2, 3); ``state[:, 0]`` holds the
            particle positions and ``state[:, 1]`` the particle velocities.
        mass (jnp.ndarray): Array of shape (N_particles,) with the mass of each particle.

    Returns:
        jnp.ndarray: The center of mass position, one value per spatial component.
    """
    positions = state[:, 0]
    # Broadcast each particle's mass across its spatial components, then sum over particles.
    return jnp.sum(positions * mass[:, jnp.newaxis], axis=0) / jnp.sum(mass)
###### Calculation of conserved quantities ######
[docs]
@jit
@jaxtyped(typechecker=typechecker)
def E_kin(state: jnp.ndarray,
          mass: jnp.ndarray) -> jnp.ndarray:
    """
    Return the kinetic energy of each particle in the system.

    Args:
        state (jnp.ndarray): Array of shape (N_particles, 2, 3); ``state[:, 0]`` holds the
            particle positions and ``state[:, 1]`` the particle velocities.
        mass (jnp.ndarray): Array of shape (N_particles,) with the mass of each particle.

    Returns:
        jnp.ndarray: Array of shape (N_particles,) with ``0.5 * m_i * |v_i|**2`` for each
        particle (not summed over the system).
    """
    # |v_i|^2: square the velocity components and sum over the spatial axis.
    speed_squared = jnp.sum(state[:, 1]**2, axis=1)
    return 0.5 * (speed_squared * mass)
[docs]
@partial(jax.jit, static_argnames=['config'])
@jaxtyped(typechecker=typechecker)
def E_pot(state: jnp.ndarray,
          mass: jnp.ndarray,
          config: SimulationConfig,
          params: SimulationParams, ) -> jnp.ndarray:
    """
    Return the potential energy of each particle in the system.

    The self-gravity potential comes from the acceleration routine selected by
    ``config.acceleration_scheme`` (called with ``return_potential=True``). When
    ``config.external_accelerations`` is non-empty, the potential of the external fields
    is added on top.

    Args:
        state (jnp.ndarray): Array of shape (N_particles, 2, 3); ``state[:, 0]`` holds the
            particle positions and ``state[:, 1]`` the particle velocities.
        mass (jnp.ndarray): Array of shape (N_particles,) representing the masses of the particles.
        config (SimulationConfig): Configuration object containing simulation parameters
            (static under jit).
        params (SimulationParams): Parameters object containing physical parameters for the simulation.

    Returns:
        jnp.ndarray: Per-particle potential energy, ``pot * mass`` from self-gravity plus
        ``external_pot * mass`` when external accelerations are configured.

    Raises:
        ValueError: If ``config.acceleration_scheme`` is not one of the supported schemes.
            Because ``config`` is static, this is raised at trace time.
    """
    if config.acceleration_scheme == DIRECT_ACC:
        acc_fn = direct_acc
    elif config.acceleration_scheme == DIRECT_ACC_LAXMAP:
        acc_fn = direct_acc_laxmap
    elif config.acceleration_scheme == DIRECT_ACC_MATRIX:
        acc_fn = direct_acc_matrix
    elif config.acceleration_scheme == DIRECT_ACC_FOR_LOOP:
        acc_fn = direct_acc_for_loop
    elif config.acceleration_scheme == DIRECT_ACC_SHARDING:
        acc_fn = direct_acc_sharding
    elif config.acceleration_scheme == NO_SELF_GRAVITY:
        acc_fn = no_self_gravity
    else:
        # Previously this fell through and failed later with an UnboundLocalError on `pot`.
        raise ValueError(
            f"Unsupported acceleration_scheme {config.acceleration_scheme!r}: "
            "cannot compute the potential energy."
        )
    _, pot = acc_fn(state, mass, config, params, return_potential=True)

    # NOTE(review): if `pot` is the full per-particle potential phi_i = -sum_j G m_j / r_ij,
    # summing `pot * mass` over all particles counts every pair twice; confirm whether the
    # acceleration routines already apply the factor 1/2.
    self_Epot = pot * mass

    external_Epot = 0.
    if len(config.external_accelerations) > 0:
        _, external_pot = combined_external_acceleration_vmpa_switch(state, config, params, return_potential=True)
        external_Epot = external_pot * mass
    return self_Epot + external_Epot
[docs]
@partial(jax.jit, static_argnames=['config'])
@jaxtyped(typechecker=typechecker)
def E_tot(state: jnp.ndarray,
          mass: jnp.ndarray,
          config: SimulationConfig,
          params: SimulationParams, ) -> jnp.ndarray:
    """
    Return the total (kinetic + potential) energy of each particle in the system.

    Args:
        state (jnp.ndarray): Array of shape (N_particles, 2, 3); ``state[:, 0]`` holds the
            particle positions and ``state[:, 1]`` the particle velocities.
        mass (jnp.ndarray): Array of shape (N_particles,) representing the masses of the particles.
        config (SimulationConfig): Configuration object containing simulation parameters
            (static under jit).
        params (SimulationParams): Parameters object containing physical parameters for the simulation.

    Returns:
        jnp.ndarray: Per-particle total energy, ``E_kin(state, mass) + E_pot(state, mass, config, params)``.
        Sum it to obtain the energy of the whole system.
    """
    return E_kin(state, mass) + E_pot(state, mass, config, params)
[docs]
@partial(jax.jit, )
@jaxtyped(typechecker=typechecker)
def Angular_momentum(state: jnp.ndarray,
                     mass: jnp.ndarray) -> jnp.ndarray:
    """
    Return the angular momentum of each particle, taken about the coordinate origin.

    Args:
        state (jnp.ndarray): Array of shape (N_particles, 2, 3); ``state[:, 0]`` holds the
            particle positions and ``state[:, 1]`` the particle velocities.
        mass (jnp.ndarray): Array of shape (N_particles,) representing the masses of the particles.

    Returns:
        jnp.ndarray: Array of shape (N_particles, 3) with ``m_i * (x_i x v_i)`` for each
        particle (not summed over the system).
    """
    positions, velocities = state[:, 0], state[:, 1]
    # Scale each per-particle cross product by that particle's mass.
    return jnp.cross(positions, velocities) * mass[:, jnp.newaxis]
#### Projection: this section is adapted from the sstrax repo (https://github.com/undark-lab/sstrax/blob/main/sstrax/projection.py), with the code_units conversion added #####
[docs]
@jax.jit
def halo_to_sun(Xhalo: jnp.ndarray) -> jnp.ndarray:
    """
    Conversion from simulation frame to cartesian frame centred at Sun

    The Sun is placed on the x axis at 8 kpc; the x axis is flipped so that it points
    from the Sun towards the simulation origin, while y and z are copied unchanged.

    Args:
        Xhalo: 3d position (x [kpc], y [kpc], z [kpc]) in simulation frame
    Returns:
        3d position (x_s [kpc], y_s [kpc], z_s [kpc]) in Sun frame
    Examples
    --------
    >>> halo_to_sun(jnp.array([1.0, 2.0, 3.0]))
    """
    sun_distance_kpc = 8.0  # Distance from the Sun to the Galactic Centre in kpc
    return jnp.array([sun_distance_kpc - Xhalo[0], Xhalo[1], Xhalo[2]])
[docs]
@jax.jit
def sun_to_gal(Xsun: jnp.ndarray) -> jnp.ndarray:
    """
    Conversion from sun cartesian frame to galactic co-ordinates

    Args:
        Xsun: 3d position (x_s [kpc], y_s [kpc], z_s [kpc]) in Sun frame
    Returns:
        3d position (r [kpc], b [rad], l [rad]) in galactic frame
    Examples
    --------
    >>> sun_to_gal(jnp.array([1.0, 2.0, 3.0]))
    """
    x_s, y_s, z_s = Xsun[0], Xsun[1], Xsun[2]
    distance = jnp.linalg.norm(Xsun)
    # Latitude from the elevation above the x-y plane, longitude from the in-plane direction.
    latitude = jnp.arcsin(z_s / distance)
    longitude = jnp.arctan2(y_s, x_s)
    return jnp.stack([distance, latitude, longitude])
[docs]
@jax.jit
def gal_to_equat(Xgal: jnp.ndarray) -> jnp.ndarray:
    """
    Conversion from galactic co-ordinates to equatorial co-ordinates

    Args:
        Xgal: 3d position (r [kpc], b [rad], l [rad]) in galactic frame
    Returns:
        3d position (r [kpc], alpha [rad], delta [rad]) in equatorial frame,
        with alpha wrapped to [0, 2*pi)
    Examples
    --------
    >>> gal_to_equat(jnp.array([1.0, 2.0, 3.0]))
    """
    # Orientation constants of the galactic frame: declination / right ascension of the
    # north galactic pole, and (as used in the rotation below) the galactic longitude of
    # the north celestial pole.
    dNGPdeg = 27.12825118085622
    lNGPdeg = 122.9319185680026
    aNGPdeg = 192.85948
    dNGP = dNGPdeg * jnp.pi / 180.0
    lNGP = lNGPdeg * jnp.pi / 180.0
    aNGP = aNGPdeg * jnp.pi / 180.0
    r = Xgal[0]
    b = Xgal[1]
    l = Xgal[2]
    sb = jnp.sin(b)
    cb = jnp.cos(b)
    sl = jnp.sin(lNGP - l)
    cl = jnp.cos(lNGP - l)
    # cs = cos(delta) * sin(alpha - aNGP), cc = cos(delta) * cos(alpha - aNGP)
    cs = cb * sl
    cc = jnp.cos(dNGP) * sb - jnp.sin(dNGP) * cb * cl
    # arctan2 keeps the quadrant: arctan(cs / cc) is off by pi whenever cc < 0 (e.g. towards
    # the galactic anticentre) and is singular at cc == 0. Wrap alpha into [0, 2*pi).
    alpha = jnp.mod(jnp.arctan2(cs, cc) + aNGP, 2.0 * jnp.pi)
    delta = jnp.arcsin(jnp.sin(dNGP) * sb + jnp.cos(dNGP) * cb * cl)
    return jnp.array([r, alpha, delta])
[docs]
@jax.jit
def equat_to_gd1cart(Xequat: jnp.ndarray) -> jnp.ndarray:
    """
    Conversion from equatorial co-ordinates to cartesian GD1 co-ordinates

    Each output component is r times a fixed linear combination of the equatorial unit
    vector (cos(alpha) cos(delta), sin(alpha) cos(delta), sin(delta)).

    Args:
        Xequat: 3d position (r [kpc], alpha [rad], delta [rad]) in equatorial frame
    Returns:
        3d position (x_gd1 [kpc], y_gd1 [kpc], z_gd1 [kpc]) in cartesian GD1 frame
    Examples
    --------
    >>> equat_to_gd1cart(jnp.array([1.0, 2.0, 3.0]))
    """
    radius, alpha, delta = Xequat[0], Xequat[1], Xequat[2]
    cos_a, sin_a = jnp.cos(alpha), jnp.sin(alpha)
    cos_d, sin_d = jnp.cos(delta), jnp.sin(delta)

    x_gd1 = radius * (-0.4776303088 * cos_a * cos_d - 0.1738432154 * sin_a * cos_d + 0.8611897727 * sin_d)
    y_gd1 = radius * (0.510844589 * cos_a * cos_d - 0.8524449229 * sin_a * cos_d + 0.111245042 * sin_d)
    z_gd1 = radius * (0.7147776536 * cos_a * cos_d + 0.4930681392 * sin_a * cos_d + 0.4959603976 * sin_d)
    return jnp.array([x_gd1, y_gd1, z_gd1])
[docs]
@jax.jit
def gd1cart_to_gd1(Xgd1cart: jnp.ndarray) -> jnp.ndarray:
    """
    Conversion from cartesian GD1 co-ordinates to angular GD1 co-ordinates

    Args:
        Xgd1cart: 3d position (x_gd1 [kpc], y_gd1 [kpc], z_gd1 [kpc]) in cartesian GD1 frame
    Returns:
        3d position (r [kpc], phi1 [rad], phi2 [rad]) in angular GD1 frame
    Examples
    --------
    >>> gd1cart_to_gd1(jnp.array([1.0, 2.0, 3.0]))
    """
    x, y, z = Xgd1cart[0], Xgd1cart[1], Xgd1cart[2]
    radius = jnp.linalg.norm(Xgd1cart)
    # phi1: azimuth in the x-y plane; phi2: elevation out of that plane.
    return jnp.stack([radius, jnp.arctan2(y, x), jnp.arcsin(z / radius)])
[docs]
@jax.jit
def halo_to_gd1(Xhalo: jnp.ndarray) -> jnp.ndarray:
    """
    Composed conversion from simulation frame co-ordinates to angular GD1 co-ordinates

    Chain: simulation frame -> Sun-centred cartesian -> galactic -> equatorial
    -> cartesian GD1 -> angular GD1.

    Args:
        Xhalo: 3d position (x [kpc], y [kpc], z [kpc]) in simulation frame
    Returns:
        3d position (r [kpc], phi1 [rad], phi2 [rad]) in angular GD1 frame
    Examples
    --------
    >>> halo_to_gd1(jnp.array([1.0, 2.0, 3.0]))
    """
    return gd1cart_to_gd1(
        equat_to_gd1cart(
            gal_to_equat(
                sun_to_gal(
                    halo_to_sun(Xhalo)
                )
            )
        )
    )
# Jacobian of halo_to_gd1 (forward-mode), used to transform simulation-frame velocities
# into the angular GD1 frame.
jacobian_halo_to_gd1 = jax.jit(jax.jacfwd(halo_to_gd1))

# halo_to_gd1 mapped over a leading batch axis of positions.
halo_to_gd1_vmap = jax.jit(jax.vmap(halo_to_gd1, in_axes=0))
[docs]
@jax.jit
def equat_to_gd1(Xequat: jnp.ndarray) -> jnp.ndarray:
    """
    Composed conversion from equatorial frame co-ordinates to angular GD1 co-ordinates

    Args:
        Xequat: 3d position (r [kpc], alpha [rad], delta [rad]) in equatorial frame
    Returns:
        3d position (r [kpc], phi1 [rad], phi2 [rad]) in angular GD1 frame
    Examples
    --------
    >>> equat_to_gd1(jnp.array([1.0, 2.0, 3.0]))
    """
    # equatorial (spherical) -> cartesian GD1 -> angular GD1
    Xgd1cart = equat_to_gd1cart(Xequat)
    Xgd1 = gd1cart_to_gd1(Xgd1cart)
    return Xgd1
# Jacobian of equat_to_gd1 (forward-mode), used to transform equatorial-frame velocities
# into the angular GD1 frame.
jacobian_equat_to_gd1 = jax.jit(jax.jacfwd(equat_to_gd1))
[docs]
@jax.jit
def equat_to_gd1_velocity(Xequat: jnp.ndarray, Vequat: jnp.ndarray) -> jnp.ndarray:
    """
    Velocity conversion from equatorial frame co-ordinates to angular GD1 co-ordinates

    The velocity is pushed forward through the Jacobian of ``equat_to_gd1`` evaluated at
    ``Xequat``.

    Args:
        Xequat: 3d position (r [kpc], alpha [rad], delta [rad]) in equatorial frame
        Vequat: 3d velocity (v_r [kpc/Myr], v_alpha [rad/Myr], v_delta [rad/Myr]) in equatorial frame
    Returns:
        3d velocity (v_r [kpc/Myr], v_phi1 [rad/Myr], v_phi2 [rad/Myr]) in angular GD1 frame
    Examples
    --------
    >>> equat_to_gd1_velocity(jnp.array([1.0, 2.0, 3.0]), jnp.array([1.0, 2.0, 3.0]))
    """
    jac = jacobian_equat_to_gd1(Xequat)
    return jac @ Vequat
[docs]
@jax.jit
def halo_to_gd1_velocity(Xhalo: jnp.ndarray, Vhalo: jnp.ndarray) -> jnp.ndarray:
    """
    Velocity conversion from simulation frame co-ordinates to angular GD1 co-ordinates

    The velocity is pushed forward through the Jacobian of ``halo_to_gd1`` evaluated at
    ``Xhalo``. The angular components are therefore time derivatives of the GD1 angles
    (angular rates), not linear velocities.

    Args:
        Xhalo: 3d position (x [kpc], y [kpc], z [kpc]) in simulation frame
        Vhalo: 3d velocity (v_x [kpc/Myr], v_y [kpc/Myr], v_z [kpc/Myr]) in simulation frame
    Returns:
        3d velocity (v_r [kpc/Myr], v_phi1 [rad/Myr], v_phi2 [rad/Myr]) in angular GD1 frame
    Examples
    --------
    >>> halo_to_gd1_velocity(jnp.array([1.0, 2.0, 3.0]), jnp.array([1.0, 2.0, 3.0]))
    """
    return jnp.matmul(jacobian_halo_to_gd1(Xhalo), Vhalo)
# halo_to_gd1_velocity mapped over a shared leading batch axis of positions and velocities.
halo_to_gd1_velocity_vmap = jax.jit(jax.vmap(halo_to_gd1_velocity, in_axes=0))
[docs]
@jax.jit
def halo_to_gd1_all(Xhalo: jnp.ndarray, Vhalo: jnp.ndarray) -> jnp.ndarray:
    """
    Position and velocity conversion from simulation frame co-ordinates to angular GD1 co-ordinates

    Args:
        Xhalo: 3d position (x [kpc], y [kpc], z [kpc]) in simulation frame
        Vhalo: 3d velocity (v_x [kpc/Myr], v_y [kpc/Myr], v_z [kpc/Myr]) in simulation frame
    Returns:
        6d phase space (r [kpc], phi1 [rad], phi2 [rad], v_r [kpc/Myr], v_phi1 [rad/Myr],
        v_phi2 [rad/Myr]) in angular GD1 frame
    Examples
    --------
    >>> halo_to_gd1_all(jnp.array([1.0, 2.0, 3.0]), jnp.array([1.0, 2.0, 3.0]))
    """
    # Positions first, then the Jacobian-transformed velocities.
    return jnp.concatenate((halo_to_gd1(Xhalo), halo_to_gd1_velocity(Xhalo, Vhalo)))
# halo_to_gd1_all mapped over a shared leading batch axis of positions and velocities,
# giving one 6d GD1 phase-space row per particle.
gd1_projection_vmap = jax.jit(jax.vmap(halo_to_gd1_all, in_axes=0))
[docs]
@partial(jax.jit, static_argnames=['code_units'])
def projection_on_GD1(final_state, code_units: CodeUnits) -> jnp.ndarray:
    """
    Project a simulation snapshot onto the angular GD1 stream frame in observational units.

    Args:
        final_state (jnp.ndarray): Array of shape (N_particles, 2, 3) in code units;
            ``final_state[:, 0]`` holds positions and ``final_state[:, 1]`` velocities.
        code_units (CodeUnits): Unit system used to convert code lengths to kpc and code
            velocities to kpc/Myr. It is a static jit argument, so it must be hashable.

    Returns:
        jnp.ndarray: Array of shape (N_particles, 6) with columns
        (r [kpc], phi1 [deg], phi2 [deg], v_r [km/s], mu_phi1 * cos(phi2) [mas/yr], mu_phi2 [mas/yr]).
    """
    final_positions, final_velocities = final_state[:, 0], final_state[:, 1]
    final_positions = final_positions * code_units.code_length.to(u.kpc)
    final_velocities = final_velocities * code_units.code_velocity.to(u.kpc / u.Myr)

    # First map onto the GD1 stream frame; this needs kpc and kpc/Myr units.
    gd1_positions = halo_to_gd1_vmap(final_positions)  # r [kpc], phi1 [rad], phi2 [rad]
    # The Jacobian of halo_to_gd1 already yields angular rates:
    # v_r [kpc/Myr], dphi1/dt [rad/Myr], dphi2/dt [rad/Myr].
    gd1_velocities = halo_to_gd1_velocity_vmap(final_positions, final_velocities)

    kpc_per_Myr_to_km_per_s = (u.kpc / u.Myr).to(u.km / u.s)
    # 1 rad = 2.0626480624709636e8 mas and 1 Myr = 1e6 yr.
    rad_per_Myr_to_mas_per_yr = 2.0626480624709636e8 / 1e6

    r = gd1_positions[:, 0]
    phi2_rad = gd1_positions[:, 2]

    v_r = gd1_velocities[:, 0] * kpc_per_Myr_to_km_per_s
    # The angular rates must NOT be divided by the distance: that would give
    # rad/(Myr kpc), not a proper motion. The phi1 rate is scaled by cos(phi2)
    # (phi2 still in radians here) to give the conventional mu_phi1* = mu_phi1 cos(phi2).
    mu_phi1_cosphi2 = gd1_velocities[:, 1] * jnp.cos(phi2_rad) * rad_per_Myr_to_mas_per_yr
    mu_phi2 = gd1_velocities[:, 2] * rad_per_Myr_to_mas_per_yr

    phi1_deg = jnp.rad2deg(gd1_positions[:, 1])
    phi2_deg = jnp.rad2deg(phi2_rad)
    return jnp.stack((r, phi1_deg, phi2_deg, v_r, mu_phi1_cosphi2, mu_phi2), axis=1)