from typing import Optional, Tuple, Callable, Union, List, NamedTuple
from functools import partial
from jaxtyping import jaxtyped
from beartype import beartype as typechecker
import jax
import jax.numpy as jnp
from jax import vmap, jit, lax
from jax import random
import jax.scipy.special as jsp
from scipy.special import j0, j1, jv
import numpy as np
import scipy
from odisseo.option_classes import SimulationConfig, SimulationParams
[docs]
@partial(jax.jit, static_argnames=['config', 'return_potential'])
@jaxtyped(typechecker=typechecker)
def combined_external_acceleration_vmpa_switch(state: jnp.ndarray,
config: SimulationConfig,
params: SimulationParams,
return_potential=False):
"""
Compute the total acceleration of all particles due to all external potentials.
Vectorized way
Args:
state (jnp.ndarray): Array of shape (N_particles,2,3) representing the positions and velocities of the particles.
config (NamedTuple): Configuration parameters.
params (NamedTuple): Simulation parameters.
return_potential (bool): If True, also returns the total potential energy of all external potentials.
Returns:
jnp.ndarray: Total acceleration of all particles due to all external potentials if return_potential is False.
Tuple: Total acceleration and total potential energy of all particles due to all external potentials if return_potential is True.
"""
total_external_acceleration = jnp.zeros_like(state[:, 0])
total_external_potential = jnp.zeros_like(config.N_particles)
state_tobe_vmap = jnp.repeat(state[jnp.newaxis, ...], repeats=len(config.external_accelerations), axis=0)
if return_potential:
# The POTENTIAL_LIST NEEDS TO BE IN THE SAME ORDER AS THE INTEGER VALUES
POTENTIAL_LIST = [lambda state: NFW(state, config=config, params=params, return_potential=True),
lambda state: point_mass(state, config=config, params=params, return_potential=True),
lambda state: MyamotoNagai(state, config=config, params=params, return_potential=True),
lambda state: PowerSphericalPotentialwCutoff(state, config=config, params=params, return_potential=True),
lambda state: logarithmic_potential(state, config=config, params=params, return_potential=True),
lambda state: TriaxialNFW(state, config=config, params=params, return_potential=True),
lambda state: Thin_MN3DiskPotential(state, config=config, params=params, return_potential=True),
lambda state: Thick_MN3DiskPotential(state, config=config, params=params, return_potential=True),
lambda state: TwoPowerTriaxialPotential(state, config=config, params=params, return_potential=True),
]
vmap_function = vmap(lambda i, state: lax.switch(i, POTENTIAL_LIST, state))
external_acc, external_pot = vmap_function(jnp.array(config.external_accelerations), state_tobe_vmap)
total_external_acceleration = jnp.sum(external_acc, axis=0)
total_external_potential = jnp.sum(external_pot, axis=0)
return total_external_acceleration, total_external_potential
else:
POTENTIAL_LIST = [lambda state: NFW(state, config=config, params=params, return_potential=False),
lambda state: point_mass(state, config=config, params=params, return_potential=False),
lambda state: MyamotoNagai(state, config=config, params=params, return_potential=False),
lambda state: PowerSphericalPotentialwCutoff(state, config=config, params=params, return_potential=False),
lambda state: logarithmic_potential(state, config=config, params=params, return_potential=False),
lambda state: TriaxialNFW(state, config=config, params=params, return_potential=False),
lambda state: Thin_MN3DiskPotential(state, config=config, params=params, return_potential=False),
lambda state: Thick_MN3DiskPotential(state, config=config, params=params, return_potential=False),
lambda state: TwoPowerTriaxialPotential(state, config=config, params=params, return_potential=False),
]
vmap_function = vmap(lambda i, state: lax.switch(i, POTENTIAL_LIST, state))
external_acc = vmap_function(jnp.array(config.external_accelerations), state_tobe_vmap)
total_external_acceleration = jnp.sum(external_acc, axis=0)
return total_external_acceleration
[docs]
@partial(jax.jit, static_argnames=['config', 'return_potential'])
@jaxtyped(typechecker=typechecker)
def NFW(state: jnp.ndarray,
config: SimulationConfig,
params: SimulationParams,
return_potential=False):
"""
Compute acceleration of all particles due to a NFW profile.
Args:
state (jnp.ndarray): Array of shape (N_particles, 2, 3) representing the positions and velocities of the particles.
config (NamedTuple): Configuration parameters.
params (NamedTuple): Simulation parameters.
return_potential (bool, optional): If True, also returns the potential energy of the NFW profile. Defaults to False.
Returns:
Tuple[jnp.ndarray, jnp.ndarray]:
- Acceleration (jnp.ndarray): Acceleration of all particles due to NFW external potential.
- Potential (jnp.ndarray): Potential energy of all particles due to NFW external potential. Returned only if return_potential is True.
"""
# params_NFW = params.NFW_params
# r = jnp.linalg.norm(state[:, 0], axis=1)
# NUM = (params_NFW.r_s+r)*jnp.log(1+r/params_NFW.r_s) - r
# DEN = r*r*r*(params_NFW.r_s+r)*params_NFW.d_c
# @jit
# def acceleration(state):
# return - params.G * params_NFW.Mvir*NUM[:, jnp.newaxis]/DEN[:, jnp.newaxis] * state[:, 0]
# @jit
# def potential(state):
# return - params.G * params_NFW.Mvir*jnp.log(1+r/params_NFW.r_s)/(r*params_NFW.d_c)
# acc = acceleration(state)
# if return_potential:
# pot = potential(state)
# return acc, pot
# else:
# return acc
params_NFW = params.NFW_params
M = params_NFW.Mvir
r_s = params_NFW.r_s
r = jnp.linalg.norm(state[:, 0], axis=1)
@jit
def potential(r):
r"""Potential for the NFW model.
$$ \Phi(r) = -\frac{G m}{r_s} \frac{r_s}{r} \log(1 + \frac{r}{r_s}) $$
where $m$ is the characteristic mass and $r_s$ is the scale radius.
"""
x = r / r_s
phi0 = -params.G * M / r_s
return phi0 * jnp.log(1 + x) / x
@jit
def mass_enclosed(r):
r"""Enclosed mass for the NFW model.
$$ M(<r) = \frac{m}{\ln(1 + x) - \frac{x}{1 + x}} $$
where $x = r / r_s$ is the dimensionless radius and $m$ is the
characteristic mass.
"""
x = r / r_s
return M * (jnp.log(1 + x) - x / (1 + x))
@jit
def acceleration(r):
return - params.G * mass_enclosed(r)[:, None] * state[:, 0] / (r**3)[:, None]
# @jit
# def acceleration(r):
# rad = jnp.linalg.norm(state[:, 0], axis=1)
# dimless_prefactor = (
# 8**2 * (rad / (r_s + rad) - jnp.log((r_s + rad)/r_s)
# / (rad**2 * (8./ (r_s + 8.)) - jnp.log((r_s+8.)/r_s))
# ))
# direction = (1/rad)[:, None ] * state[:, 0]
# ftot = (0.000001045940172532453 * 220**2 / 8.) * 1
# return - 0.35 * ftot * dimless_prefactor[:, None] * direction
#calculate the acceleration
acc = acceleration(r)
if return_potential:
pot = potential(r)
return acc, pot
else:
return acc
[docs]
@partial(jax.jit, static_argnames=['config', 'return_potential'])
@jaxtyped(typechecker=typechecker)
def point_mass(state: jnp.ndarray,
config: SimulationConfig,
params: SimulationParams,
return_potential=False):
"""
Compute acceleration of all particles due to a point mass potential.
Args:
state (jnp.ndarray): Array of shape (N_particles, 2, 3) representing the positions and velocities of the particles.
config (NamedTuple): Configuration parameters.
params (NamedTuple): Simulation parameters.
return_potential (bool, optional): If True, also returns the potential energy of the point mass potential. Defaults to False.
Returns:
Tuple[jnp.ndarray, jnp.ndarray]:
- Acceleration (jnp.ndarray): Acceleration of all particles due to point mass external potential.
- Potential (jnp.ndarray): Potential energy of all particles due to point mass external potential. Returned only if return_potential is True.
"""
params_point_mass = params.PointMass_params
r = jnp.linalg.norm(state[:, 0], axis=1)
@jit
def acceleration(state):
return - params.G * params_point_mass.M * state[:, 0] / (r**3)[:, None]
@jit
def potential(r):
return - params.G * params_point_mass.M / r
acc = acceleration(state)
if return_potential:
pot = potential(r)
return acc, pot
else:
return acc
[docs]
@partial(jax.jit, static_argnames=['config', 'return_potential'])
@jaxtyped(typechecker=typechecker)
def MyamotoNagai(state: jnp.ndarray,
config: SimulationConfig,
params: SimulationParams,
return_potential=False):
"""
Compute acceleration of all particles due to a MyamotoNagai disk profile.
Args:
state (jnp.ndarray): Array of shape (N_particles, 2, 3) representing the positions and velocities of the particles.
config (NamedTuple): Configuration parameters.
params (NamedTuple): Simulation parameters.
return_potential (bool, optional): If True, also returns the potential energy of the MyamotoNagai profile. Defaults to False.
Returns:
Tuple[jnp.ndarray, jnp.ndarray]:
- Acceleration (jnp.ndarray): Acceleration of all particles due to MyamotoNagai external potential.
- Potential (jnp.ndarray): Potential energy of all particles due to MyamotoNagai external potential. Returned only if return_potential is True.
"""
params_MN = params.MN_params
z2 = state[:, 0, 2]**2
b = params_MN.b
a = params_MN.a
M = params_MN.M
Dz = (a+(z2+b**2)**0.5)
D = jnp.linalg.norm(state[:, 0, :2], axis=1)**2 + Dz**2
K = - params.G * params_MN.M / D**(3/2)
@jit
def acceleration(pos):
ax = K * pos[:, 0]
ay = K * pos[:, 1]
az = K * pos[:, 2] * Dz / (z2 + b**2)**0.5
return jnp.stack([ax, ay, az], axis=1)
@jit
def potential(pos):
return - params.G * params_MN.M / jnp.sqrt(D)
pos = state[:, 0]
acc = acceleration(pos)
if return_potential:
pot = potential(pos)
return acc, pot
else:
return acc
[docs]
@partial(jax.jit, static_argnames=['return_potential'])
@jaxtyped(typechecker=typechecker)
def call_MyamotoNagai(state: jnp.ndarray,
M: Union[float, jnp.ndarray],
a: Union[float, jnp.ndarray],
b: Union[float, jnp.ndarray],
params: SimulationParams,
return_potential=False):
"""
Compute acceleration of all particles due to a MyamotoNagai disk profile. It is used as base function for MN3 approximation of douoble exponential disk.
This function exposes directly the a, b and M parameters intstead of calling the params of the simulation
Args:
state (jnp.ndarray): Array of shape (N_particles, 2, 3) representing the positions and velocities of the particles.
config (NamedTuple): Configuration parameters.
params (NamedTuple): Simulation parameters.
return_potential (bool, optional): If True, also returns the potential energy of the MyamotoNagai profile. Defaults to False.
Returns:
Tuple[jnp.ndarray, jnp.ndarray]:
- Acceleration (jnp.ndarray): Acceleration of all particles due to MyamotoNagai external potential.
- Potential (jnp.ndarray): Potential energy of all particles due to MyamotoNagai external potential. Returned only if return_potential is True.
"""
z2 = state[:, 0, 2]**2
Dz = (a+(z2+b**2)**0.5)
D = jnp.linalg.norm(state[:, 0, :2], axis=1)**2 + Dz**2
K = - params.G * M / D**(3/2)
@jit
def acceleration(pos):
ax = K * pos[:, 0]
ay = K * pos[:, 1]
az = K * pos[:, 2] * Dz / (z2 + b**2)**0.5
return jnp.stack([ax, ay, az], axis=1)
@jit
def potential(pos):
return - params.G * M / jnp.sqrt(D)
pos = state[:, 0]
acc = acceleration(pos)
if return_potential:
pot = potential(pos)
return acc, pot
else:
return acc
[docs]
@partial(jax.jit, static_argnames=['config', 'return_potential'])
@jaxtyped(typechecker=typechecker)
def PowerSphericalPotentialwCutoff(state: jnp.ndarray,
config: SimulationConfig,
params: SimulationParams,
return_potential=False):
"""
Compute acceleration of all particles due to a power spherical potential with cutoff.
taken from galax: https://github.com/GalacticDynamics/galax/blob/main/src/galax/potential/_src/builtin/powerlawcutoff.py#L35
Args:
state (jnp.ndarray): Array of shape (N_particles, 2, 3) representing the positions and velocities of the particles.
config (NamedTuple): Configuration parameters.
params (NamedTuple): Simulation parameters.
return_potential (bool, optional): If True, also returns the potential energy of the power spherical potential. Defaults to False.
Returns:
Tuple[jnp.ndarray, jnp.ndarray]:
- Acceleration (jnp.ndarray): Acceleration of all particles due to power spherical external potential.
- Potential (jnp.ndarray): Potential energy of all particles due to power spherical external potential. Returned only if return_potential is True.
"""
@partial(jax.jit)
def _safe_gamma_inc(a, x):
return jax.scipy.special.gammainc(a, x) * jax.scipy.special.gamma(a)
params_PSP = params.PSP_params
M = params_PSP.M
alpha = params_PSP.alpha
r_c = params_PSP.r_c
pos = state[:, 0]
@jit
def rho(radius):
return (1/radius)**alpha * jnp.exp(-(radius/r_c)**2)
@jit
def potential(pos):
r = jnp.linalg.norm(pos)
a = alpha/2
s2 = (r/r_c)**2
GM = params.G * M
# pot_value = - GM * (
# (a - 1.5) * _safe_gamma_inc(1.5 - 1, s2) / (r * jax.scipy.special.gamma(2.5 - a))
# + _safe_gamma_inc(1 - a, s2) / (r_c * jax.scipy.special.gamma(1.5 - a)))
# return jnp.squeeze(pot_value)
den = jsp.gamma(1.5 -a)
L1 = _safe_gamma_inc(1.5 - a, s2)
L2 = _safe_gamma_inc(1 - a, s2)
pot = -GM / den * ( L1 / r + (jsp.gamma(1 - a) - L2) / r_c )
return jnp.squeeze(pot)
@jit
def acceleration(pos):
return -jax.vmap(jax.grad((potential)))(pos)
# compute the acceleration
acc = acceleration(pos)
if return_potential:
pot = jax.vmap(potential)(pos)
return acc, pot
else:
return acc
[docs]
@partial(jax.jit, static_argnames=['config', 'return_potential'])
@jaxtyped(typechecker=typechecker)
def logarithmic_potential(state: jnp.ndarray,
config: SimulationConfig,
params: SimulationParams,
return_potential=False):
"""
Compute acceleration of all particles due to a logarithmic potential.
Args:
state (jnp.ndarray): Array of shape (N_particles, 2, 3) representing the positions and velocities of the particles.
config (NamedTuple): Configuration parameters.
params (NamedTuple): Simulation parameters.
return_potential (bool, optional): If True, also returns the potential energy of the logarithmic potential. Defaults to False.
Returns:
Tuple[jnp.ndarray, jnp.ndarray]:
- Acceleration (jnp.ndarray): Acceleration of all particles due to logarithmic external potential.
- Potential (jnp.ndarray): Potential energy of all particles due to logarithmic external potential. Returned only if return_potential is True.
"""
r = jnp.sqrt(state[:, 0, 0]**2 + state[:, 0, 1]**2)
z = state[:, 0, 2]
v2_0 = params.Logarithmic_params.v0**2
q2 = params.Logarithmic_params.q**2
@jit
def potential(state):
return - v2_0/2 * jnp.log(r**2 + (z**2/q2))
@jit
def acceleration(state):
DEN = r**2 + (z**2/q2)
ax = - v2_0 * state[:, 0, 0] / DEN
ay = - v2_0 * state[:, 0, 1] / DEN
az = - v2_0 * z * (1/q2) / DEN
return jnp.stack([ax, ay, az], axis=1)
acc = acceleration(state)
if return_potential:
pot = potential(state)
return acc, pot
else:
return acc
[docs]
@partial(jax.jit, static_argnames=['config', 'return_potential'])
@jaxtyped(typechecker=typechecker)
def TriaxialNFW(state: jnp.ndarray,
config: SimulationConfig,
params: SimulationParams,
return_potential=False):
"""
Compute acceleration of all particles due to a Triaxial NFW profile.
This code is heavily inspired by the implementation in galax: https://github.com/GalacticDynamics/galax/blob/main/src/galax/potential/_src/builtin/nfw/triaxial.py
The density is given by:
rho(xi) = rho_0 / (xi/r_s) / (1 + xi/r_s)^2
where:
xi^2 = x^2 + y^2/q1^2 + z^2/q2^2
Args:
state (jnp.ndarray): Array of shape (N_particles, 2, 3) representing the positions and velocities of the particles.
config (NamedTuple): Configuration parameters.
params (NamedTuple): Simulation parameters.
return_potential (bool, optional): If True, also returns the potential energy. Defaults to False.
Returns:
Tuple[jnp.ndarray, jnp.ndarray]:
- Acceleration (jnp.ndarray): Acceleration of all particles due to Triaxial NFW external potential.
- Potential (jnp.ndarray): Potential energy of all particles. Returned only if return_potential is True.
"""
params_TNFW = params.TriaxialNFW_params
M = params_TNFW.Mvir
r_s = params_TNFW.r_s
q1 = params_TNFW.q1 # y-axis flattening
q2 = params_TNFW.q2 # z-axis flattening
# Gauss-Legendre quadrature (order 50)
integration_order = config.glorder
x_, w_ = np.polynomial.legendre.leggauss(integration_order)
x_gl, w_gl = jnp.asarray(x_, dtype=float), jnp.asarray(w_, dtype=float)
# Change interval from [-1, 1] to [0, 1]
x_gl = 0.5 * (x_gl + 1)
w_gl = 0.5 * w_gl
# Central density: rho_0 = M / (4 * pi * r_s^3)
rho0 = M / (4 * jnp.pi * r_s**3)
q1sq = q1**2
q2sq = q2**2
@jit
def ellipsoid_surface(pos, s2):
"""Compute xi^2 on the ellipsoid surface."""
# xi^2(tau) = x^2/(1+tau) + y^2/(q1^2+tau) + z^2/(q2^2+tau)
# with tau = 1/s^2 - 1, this becomes:
return s2 * (
pos[0]**2
+ pos[1]**2 / (1 + (q1sq - 1) * s2)
+ pos[2]**2 / (1 + (q2sq - 1) * s2)
)
@jit
def potential_single(pos):
"""Compute potential for a single particle."""
def integrand(s):
s2 = s**2
xi = jnp.sqrt(ellipsoid_surface(pos, s2)) / r_s
delta_psi_factor = 2.0 / (1.0 + xi)
denom = jnp.sqrt(((q1sq - 1) * s2 + 1) * ((q2sq - 1) * s2 + 1))
return delta_psi_factor / denom
# Gauss-Legendre integration
integral = jnp.sum(w_gl * vmap(integrand)(x_gl))
return -2.0 * jnp.pi * params.G * rho0 * r_s**2 * q1 * q2 * integral
@jit
def acceleration_single(pos):
"""Compute acceleration for a single particle via gradient of potential."""
return -jax.grad(potential_single)(pos)
pos = state[:, 0]
acc = vmap(acceleration_single)(pos)
if return_potential:
pot = vmap(potential_single)(pos)
return acc, pot
else:
return acc
[docs]
@partial(jax.jit, static_argnames=['config', 'return_potential'])
@jaxtyped(typechecker=typechecker)
def Thin_MN3DiskPotential(state: jnp.ndarray,
config: SimulationConfig,
params: SimulationParams,
return_potential=False):
"""
Compute acceleration and potential of all particles due to a thin disk approximated by 3 Miyamoto-Nagai potentials.
Inspired by: https://gala.adrian.pw/en/latest/_modules/gala/potential/potential/builtin/core.html#MN3ExponentialDiskPotential.
Original paper: `Smith et al. (2015) <https://ui.adsabs.harvard.edu/abs/2015MNRAS.448.2934S/abstract>`
Args:
state (jnp.ndarray): (N_particles, 2, 3) positions and velocities.
config (NamedTuple): Configuration parameters.
params (NamedTuple): Simulation parameters.
return_potential (bool): If True, also returns the potential.
Returns:
jnp.ndarray: Acceleration (N_particles, 3)
jnp.ndarray: Potential (N_particles,) if return_potential is True
"""
params_ThinMN3Disk = params.ThinMN3Disk_params
m = params_ThinMN3Disk.M
h_R = params_ThinMN3Disk.hr
h_z = params_ThinMN3Disk.hz
hzR = h_z / h_R
sech2_z = config.sech2_z
MN3_positive_density = config.MN3_positive_density
_K_pos_dens = jnp.array(
[
[0.0036, -0.0330, 0.1117, -0.1335, 0.1749],
[-0.0131, 0.1090, -0.3035, 0.2921, -5.7976],
[-0.0048, 0.0454, -0.1425, 0.1012, 6.7120],
[-0.0158, 0.0993, -0.2070, -0.7089, 0.6445],
[-0.0319, 0.1514, -0.1279, -0.9325, 2.6836],
[-0.0326, 0.1816, -0.2943, -0.6329, 2.3193],
]
)
_K_neg_dens = jnp.array(
[
[-0.0090, 0.0640, -0.1653, 0.1164, 1.9487],
[0.0173, -0.0903, 0.0877, 0.2029, -1.3077],
[-0.0051, 0.0287, -0.0361, -0.0544, 0.2242],
[-0.0358, 0.2610, -0.6987, -0.1193, 2.0074],
[-0.0830, 0.4992, -0.7967, -1.2966, 4.4441],
[-0.0247, 0.1718, -0.4124, -0.5944, 0.7333],
]
)
K = jnp.where(MN3_positive_density, _K_pos_dens, _K_neg_dens)
b_hR = jnp.where(sech2_z, -0.033 * hzR**3 + 0.262 * hzR**2 + 0.659 * hzR, -0.269 * hzR**3 + 1.08 * hzR**2 + 1.092 * hzR)
x = jnp.vander(jnp.array([b_hR]), N=5)[0]
param_vec = K @ x
_ms = param_vec[:3] * m
_as = param_vec[3:] * h_R
_b = b_hR * h_R
_b = jnp.broadcast_to(_b, _ms.shape) #needed for vmap
c_only = {}
for i in range(3):
c_only[f"m{i + 1}"] = _ms[i]
c_only[f"a{i + 1}"] = _as[i]
c_only[f"b{i + 1}"] = _b
acc_total = jax.vmap(lambda m, a, b: call_MyamotoNagai(state, m, a, b, params, return_potential=False))(
_ms, _as, _b
).sum(axis=0)
if return_potential:
pot_total = jax.vmap(lambda m, a, b: call_MyamotoNagai(state, m, a, b, params, return_potential=True))(
_ms, _as, _b
)[1].sum(axis=0)
return acc_total, pot_total
else:
return acc_total
[docs]
@partial(jax.jit, static_argnames=['config', 'return_potential'])
@jaxtyped(typechecker=typechecker)
def Thick_MN3DiskPotential(state: jnp.ndarray,
config: SimulationConfig,
params: SimulationParams,
return_potential=False):
"""
Compute acceleration and potential of all particles due to a thin disk approximated by 3 Miyamoto-Nagai potentials.
Inspired by: https://gala.adrian.pw/en/latest/_modules/gala/potential/potential/builtin/core.html#MN3ExponentialDiskPotential.
Original paper: `Smith et al. (2015) <https://ui.adsabs.harvard.edu/abs/2015MNRAS.448.2934S/abstract>`
Args:
state (jnp.ndarray): (N_particles, 2, 3) positions and velocities.
config (NamedTuple): Configuration parameters.
params (NamedTuple): Simulation parameters.
return_potential (bool): If True, also returns the potential.
Returns:
jnp.ndarray: Acceleration (N_particles, 3)
jnp.ndarray: Potential (N_particles,) if return_potential is True
"""
params_ThickMN3Disk = params.ThickMN3Disk_params
m = params_ThickMN3Disk.M
h_R = params_ThickMN3Disk.hr
h_z = params_ThickMN3Disk.hz
hzR = h_z / h_R
sech2_z = config.sech2_z
MN3_positive_density = config.MN3_positive_density
_K_pos_dens = jnp.array(
[
[0.0036, -0.0330, 0.1117, -0.1335, 0.1749],
[-0.0131, 0.1090, -0.3035, 0.2921, -5.7976],
[-0.0048, 0.0454, -0.1425, 0.1012, 6.7120],
[-0.0158, 0.0993, -0.2070, -0.7089, 0.6445],
[-0.0319, 0.1514, -0.1279, -0.9325, 2.6836],
[-0.0326, 0.1816, -0.2943, -0.6329, 2.3193],
]
)
_K_neg_dens = jnp.array(
[
[-0.0090, 0.0640, -0.1653, 0.1164, 1.9487],
[0.0173, -0.0903, 0.0877, 0.2029, -1.3077],
[-0.0051, 0.0287, -0.0361, -0.0544, 0.2242],
[-0.0358, 0.2610, -0.6987, -0.1193, 2.0074],
[-0.0830, 0.4992, -0.7967, -1.2966, 4.4441],
[-0.0247, 0.1718, -0.4124, -0.5944, 0.7333],
]
)
K = jnp.where(MN3_positive_density, _K_pos_dens, _K_neg_dens)
b_hR = jnp.where(sech2_z, -0.033 * hzR**3 + 0.262 * hzR**2 + 0.659 * hzR, -0.269 * hzR**3 + 1.08 * hzR**2 + 1.092 * hzR)
x = jnp.vander(jnp.array([b_hR]), N=5)[0]
param_vec = K @ x
_ms = param_vec[:3] * m
_as = param_vec[3:] * h_R
_b = b_hR * h_R
_b = jnp.broadcast_to(_b, _ms.shape)
acc_total = jax.vmap(lambda m, a, b: call_MyamotoNagai(state, m, a, b, params, return_potential=False))(
_ms, _as, _b
).sum(axis=0)
if return_potential:
pot_total = jax.vmap(lambda m, a, b: call_MyamotoNagai(state, m, a, b, params, return_potential=True))(
_ms, _as, _b
)[1].sum(axis=0)
return acc_total, pot_total
else:
return acc_total
[docs]
@partial(jax.jit, static_argnames=['config', 'return_potential'])
@jaxtyped(typechecker=typechecker)
def TwoPowerTriaxialPotential(state: jnp.ndarray,
config: SimulationConfig,
params: SimulationParams,
return_potential=False):
"""
General triaxial two-power-law potential:
rho(x,y,z) = amp/(4*pi*a^3) * 1/(m/a)^alpha * 1/(1+m/a)^(beta-alpha)
m^2 = x^2 + y^2/b^2 + z^2/c^2
Args:
state: (N_particles, 2, 3) positions and velocities.
config: Configuration parameters.
params: Simulation parameters.
return_potential: If True, also returns the potential.
Returns:
acc: (N_particles, 3) acceleration
pot: (N_particles,) potential (if return_potential)
"""
p = params.TwoPowerTriaxial_params
rho = p.rho
a = p.a
alpha = p.alpha
beta = p.beta
b = p.b
c = p.c
b2 = b**2
c2 = c**2
# Gauss-Legendre quadrature (order 50)
glorder = config.glorder
x_, w_ = np.polynomial.legendre.leggauss(glorder)
x_gl, w_gl = jnp.asarray(x_, dtype=float), jnp.asarray(w_, dtype=float)
x_gl = 0.5 * (x_gl + 1)
w_gl = 0.5 * w_gl
# Normalization
# norm = params.G * rho / (4 * jnp.pi * a**3) # we use directly the normalization rho
# norm = params.G *rho
norm = params.G * rho
def safe_hyp2f1(a, b, c, z):
# Transformation: z_new = z / (z - 1)
# This maps z in (-inf, -1] to z_new in [0.5, 1)
z_new = z / (z - 1.0)
transformed_val = jnp.pow(1.0 - z, -a) * jsp.hyp2f1(a, c - b, c, z_new)
# Use jnp.where to choose the transformation only when z < -0.9
# (We use -0.9 to stay well away from the boundary of the unit circle)
return jnp.where(z < -0.9, transformed_val, jsp.hyp2f1(a, b, c, z))
@jit
def mfunc(pos):
return jnp.sqrt(pos[0]**2 + pos[1]**2 / b2 + pos[2]**2 / c2)
@jit
def _psi_inf():
# psi_inf = gamma(beta-2) * gamma(3-alpha) / gamma(beta-alpha)
return jsp.gamma(beta - 2.0) * jsp.gamma(3.0 - alpha) / jsp.gamma(beta - alpha)
psi_inf = _psi_inf()
twominusalpha = 2.0 - alpha
threeminusalpha = 3.0 - alpha
betaminusalpha = beta - alpha
@jit
def psi(m):
# See galpy: _psi
# If twominusalpha == 0:
# -2 a^2 (a/m)^(beta-alpha) / (beta-alpha) * hyp2f1(b-a, b-a, b-a+1, -a/m)
# else:
# -2 a^2 [psi_inf - (m/a)^(2-alpha)/(2-alpha) * hyp2f1(2-alpha, beta-alpha, 3-alpha, -m/a)]
# val_z = -m / a
# # This will print every time the JIT-compiled function is executed
# jax.debug.print("Current m: {m_val}, Argument z: {z_val}", m_val=m, z_val=val_z)
# res = jsp.hyp2f1(twominusalpha, betaminusalpha, threeminusalpha, val_z)
# # Check for NaNs or Infs
# jax.debug.print("hyp2f1 result: {r}", r=res)
def branch():
# return -2.0 * a**2 * (a / m) ** betaminusalpha / betaminusalpha * jsp.hyp2f1(
# betaminusalpha, betaminusalpha, betaminusalpha + 1, -a / m
# )
return (
-2.0
* a**2
* (a / m) ** betaminusalpha
/ betaminusalpha
# * jsp.hyp2f1(
* safe_hyp2f1(
betaminusalpha,
betaminusalpha,
betaminusalpha + 1,
-a / m,
)
)
def main():
# return -2.0 * a**2 * (
# psi_inf
# - (m / a) ** twominusalpha / twominusalpha * jsp.hyp2f1(
# twominusalpha, betaminusalpha, threeminusalpha, -m / a
# )
# )
return (
-2.0
* a**2
* (
psi_inf
- (m / a) ** twominusalpha
/ twominusalpha
# * jsp.hyp2f1(
* safe_hyp2f1(
twominusalpha,
betaminusalpha,
threeminusalpha,
-m / a,
)
)
)
return jax.lax.cond(jnp.abs(twominusalpha) < 1e-10, branch, main)
@jit
def dens(m):
return (a / m) ** alpha / (1.0 + m / a) ** betaminusalpha
@jit
def force_integral(pos, i):
# Integrate over s in [0,1]
def integrand(s):
t = 1.0 / s**2 - 1.0
m = jnp.sqrt(
pos[0] ** 2 / (1.0 + t)
+ pos[1] ** 2 / (b2 + t)
+ pos[2] ** 2 / (c2 + t)
)
numer = (
pos[0] / (1.0 + t) * (i == 0)
+ pos[1] / (b2 + t) * (i == 1)
+ pos[2] / (c2 + t) * (i == 2)
)
denom = jnp.sqrt((1.0 + (b2 - 1.0) * s**2) * (1.0 + (c2 - 1.0) * s**2))
return dens(m) * numer / denom
return jnp.sum(w_gl * jax.vmap(integrand)(x_gl))
@jit
def potential_integral(pos):
def integrand(s):
t = 1.0 / s**2 - 1.0
# m = jnp.sqrt(
# pos[0] ** 2 / (1.0 + t)
# + pos[1] ** 2 / (b2 + t)
# + pos[2] ** 2 / (c2 + t)
# )
# denom = jnp.sqrt((1.0 + (b2 - 1.0) * s**2) * (1.0 + (c2 - 1.0) * s**2))
# return psi(m) / denom
return psi(
jnp.sqrt(pos[0]**2.0 / (1.0 + t) + pos[1]**2.0 / (b2 + t) + pos[2]**2.0 / (c2 + t))
) / jnp.sqrt((1.0 + (b2 - 1.0) * s**2.0) * (1.0 + (c2 - 1.0) * s**2.0))
return jnp.sum(w_gl * jax.vmap(integrand)(x_gl))
@jit
def acc_and_pot_single(pos):
acc = -4.0 * jnp.pi * b * c * norm * jnp.array( #is norm correect here?
# acc = -4.0 * jnp.pi * b * c * jnp.array(
[force_integral(pos, 0), force_integral(pos, 1), force_integral(pos, 2)]
)
pot = 2.0 * jnp.pi * b * c * norm * potential_integral(pos) #is norm correect here?
# pot = 2.0 * jnp.pi * b * c * potential_integral(pos)
# pot = potential_integral(pos)
return acc, pot
pos = state[:, 0]
acc, pot = jax.vmap(acc_and_pot_single)(pos)
if return_potential:
return acc, pot
else:
return acc