Source code for odisseo.integrators

from typing import Optional, Tuple, Callable, Union, List, NamedTuple
from functools import partial
from jaxtyping import jaxtyped, Array, Float, Scalar
from beartype import beartype as typechecker

import jax
import jax.numpy as jnp
from jax import vmap, jit
from jax import random
from odisseo.potentials import combined_external_acceleration_vmpa_switch
from odisseo.dynamics import direct_acc, direct_acc_laxmap, direct_acc_matrix, direct_acc_for_loop, direct_acc_sharding, no_self_gravity

from odisseo.option_classes import DIRECT_ACC, DIRECT_ACC_LAXMAP, DIRECT_ACC_MATRIX, DIRECT_ACC_FOR_LOOP, DIRECT_ACC_SHARDING, NO_SELF_GRAVITY
from odisseo.option_classes import SimulationConfig, SimulationParams
from odisseo.option_classes import DOPRI5, TSIT5, DOPRI8, SEMIIMPLICITEULER, REVERSIBLEHEUN, LEAPFROGMIDPOINT
from odisseo.option_classes import RECURSIVECHECKPOINTADJOING, FORWARDMODE

from diffrax import diffeqsolve, ODETerm, SaveAt
from diffrax import Tsit5, Dopri5, Dopri8
from diffrax import SemiImplicitEuler, ReversibleHeun, LeapfrogMidpoint
import diffrax


[docs] @partial(jax.jit, static_argnames=['config']) @jaxtyped(typechecker=typechecker) def leapfrog(state: jnp.ndarray, mass: jnp.ndarray, dt: Scalar, config: SimulationConfig, params: SimulationParams): """Simple implementation of a symplectic Leapfrog (Verlet) integrator for N-body simulations. Args: state (jax.numpy.ndarray): The state of the particles, where the first column represents positions and the second column represents velocities. mass (jax.numpy.ndarray): The mass of the particles. dt (float): Time-step for current integration. config (object): Configuration object containing the acceleration scheme and external accelerations. params (dict): Additional parameters for the acceleration functions. Returns: jax.numpy.ndarray: The updated state of the particles. """ if config.acceleration_scheme == DIRECT_ACC: acc_func = direct_acc elif config.acceleration_scheme == DIRECT_ACC_LAXMAP: acc_func = direct_acc_laxmap elif config.acceleration_scheme == DIRECT_ACC_MATRIX: acc_func = direct_acc_matrix elif config.acceleration_scheme == DIRECT_ACC_FOR_LOOP: acc_func = direct_acc_for_loop elif config.acceleration_scheme == DIRECT_ACC_SHARDING: acc_func = direct_acc_sharding elif config.acceleration_scheme == NO_SELF_GRAVITY: acc_func = no_self_gravity add_external_acceleration = len(config.external_accelerations) > 0 acc = acc_func(state, mass, config, params) # Check additional accelerations if add_external_acceleration: acc = acc + combined_external_acceleration_vmpa_switch(state, config, params) # removing half-step velocity state = state.at[:, 0].set(state[:, 0] + state[:, 1]*dt + 0.5*acc*(dt**2)) acc2 = acc_func(state, mass, config, params) if add_external_acceleration: acc2 = acc2 + combined_external_acceleration_vmpa_switch(state, config, params) state = state.at[:, 1].set(state[:, 1] + 0.5*(acc + acc2)*dt) return state
[docs] @partial(jax.jit, static_argnames=['config']) @jaxtyped(typechecker=typechecker) def RungeKutta4(state: jnp.ndarray, mass: jnp.ndarray, dt: Scalar, config: SimulationConfig, params: SimulationParams): """Simple implementation of a 4th order Runge-Kutta integrator for N-body simulations. Args: state (jax.numpy.ndarray): The state of the particles, where the first column represents positions and the second column represents velocities. mass (jax.numpy.ndarray): The mass of the particles. dt (float): Time-step for current integration. config (object): Configuration object containing the acceleration scheme and external accelerations. params (dict): Additional parameters for the acceleration functions. Returns: jax.numpy.ndarray: The updated state of the particles. """ if config.acceleration_scheme == DIRECT_ACC: acc_func = direct_acc elif config.acceleration_scheme == DIRECT_ACC_LAXMAP: acc_func = direct_acc_laxmap elif config.acceleration_scheme == DIRECT_ACC_MATRIX: acc_func = direct_acc_matrix elif config.acceleration_scheme == DIRECT_ACC_FOR_LOOP: acc_func = direct_acc_for_loop elif config.acceleration_scheme == DIRECT_ACC_SHARDING: acc_func = direct_acc_sharding elif config.acceleration_scheme == NO_SELF_GRAVITY: acc_func = no_self_gravity add_external_acceleration = len(config.external_accelerations) > 0 k1r = state[:, 1] * dt k1v = acc_func(state, mass, config, params) * dt state_2 = state.copy() state_2 = state_2.at[:, 0].set(state[:, 0] + 0.5*k1r) acc2 = acc_func(state_2, mass, config, params) if add_external_acceleration: acc2 = acc2 + combined_external_acceleration_vmpa_switch(state, config, params) k2r = (state[:, 1] + 0.5*k1v) * dt k2v = acc2 * dt state_3 = state.copy() state_3 = state_3.at[:, 0].set(state[:, 0] + 0.5*k2r) acc3 = acc_func(state_3, mass, config, params) if add_external_acceleration: acc3 = acc3 + combined_external_acceleration_vmpa_switch(state, config, params) k3r = (state[:, 1] + 0.5*k2v) * dt k3v = acc3 * dt state_4 = state.copy() state_4 = state_4.at[:, 0].set(state[:, 0] + k3r) acc4 = acc_func(state_4, mass, config, params) if add_external_acceleration: acc4 = acc4 + combined_external_acceleration_vmpa_switch(state, config, params) k4r = (state[:, 1] + k3v) * dt k4v = acc4 * dt state = state.at[:, 0].set(state[:, 0] + (k1r + 2*k2r + 2*k3r + k4r)/6) state = state.at[:, 1].set(state[:, 1] + (k1v + 2*k2v + 2*k3v + k4v)/6) return state
[docs] @partial(jax.jit, static_argnames=['config']) @jaxtyped(typechecker=typechecker) def diffrax_solver(state: jnp.ndarray, mass: jnp.ndarray, dt: Scalar, config: SimulationConfig, params: SimulationParams,) -> jnp.ndarray: """Diffrax backhend Args: state (jax.numpy.ndarray): The state of the particles, where the first column represents positions and the second column represents velocities. mass (jax.numpy.ndarray): The mass of the particles. dt (float): Time-step for current integration. config (object): Configuration object containing the acceleration scheme and external accelerations. params (dict): Additional parameters for the acceleration functions. Returns: jax.numpy.ndarray: The updated state of the particles. """ def vector_field(t, y, args): """ Vector field function for the ODE solver. Args: t (float): Time variable. y (jax.numpy.ndarray): State vector. args (tuple): Additional arguments for the acceleration function. Returns: jax.numpy.ndarray: The updated state of the particles. """ pos_x, pos_y, pos_z, vel_x, vel_y, vel_z = y # Unpack the args mass = args positions = jnp.stack((pos_x, pos_y, pos_z), axis=1) velocities = jnp.stack((vel_x, vel_y, vel_z), axis=1) state = jnp.stack((positions, velocities), axis=1) d_xpos = vel_x d_ypos = vel_y d_zpos = vel_z d_vx = acc_func(state, mass, config, params)[:, 0] + external_acc_func(state, config, params)[:, 0] d_vy = acc_func(state, mass, config, params)[:, 1] + external_acc_func(state, config, params)[:, 1] d_vz = acc_func(state, mass, config, params)[:, 2] + external_acc_func(state, config, params)[:, 2] d_y = jnp.array([d_xpos, d_ypos, d_zpos, d_vx, d_vy, d_vz]) return d_y def f(t, y, args): """ Vector field for the transform of positions """ return y def g(t, y, args): """ Vector field for the transform of velocities args is the mass """ state = jnp.zeros((config.N_particles, 2, 3)) state = state.at[:, 0].set(y) return acc_func(state, args, config, params) + external_acc_func(state, config, params) if config.acceleration_scheme == DIRECT_ACC: acc_func = direct_acc elif config.acceleration_scheme == DIRECT_ACC_LAXMAP: acc_func = direct_acc_laxmap elif config.acceleration_scheme == DIRECT_ACC_MATRIX: acc_func = direct_acc_matrix elif config.acceleration_scheme == DIRECT_ACC_FOR_LOOP: acc_func = direct_acc_for_loop elif config.acceleration_scheme == DIRECT_ACC_SHARDING: acc_func = direct_acc_sharding elif config.acceleration_scheme == NO_SELF_GRAVITY: acc_func = no_self_gravity add_external_acceleration = len(config.external_accelerations) > 0 if add_external_acceleration: external_acc_func = combined_external_acceleration_vmpa_switch else: external_acc_func = lambda state, config, params: jnp.zeros_like(state[:, 0]) if config.diffrax_solver == DOPRI5: solver = Dopri5() term = ODETerm(vector_field) elif config.diffrax_solver == TSIT5: solver = Tsit5() term = ODETerm(vector_field) elif config.diffrax_solver == DOPRI8: solver = Dopri8() term = ODETerm(vector_field) # Symplectic methods elif config.diffrax_solver == SEMIIMPLICITEULER: solver = SemiImplicitEuler() term = (ODETerm(f), ODETerm(g)) elif config.diffrax_solver == REVERSIBLEHEUN: solver = ReversibleHeun() term = ODETerm(vector_field) elif config.diffrax_solver == LEAPFROGMIDPOINT: solver = LeapfrogMidpoint() term = ODETerm(vector_field) #adjoint method if config.diffrax_adjoint_method == RECURSIVECHECKPOINTADJOING: adjoint = diffrax.RecursiveCheckpointAdjoint() elif config.diffrax_adjoint_method == FORWARDMODE: adjoint = diffrax.ForwardMode() if config.diffrax_solver != SEMIIMPLICITEULER: t0 = 0.0 dt0 = dt # t1 = dt #in the fixed number of timesteps case we want to integrate only one step t1 = jnp.where(config.fixed_timestep, dt, params.t_end) y0 = jnp.array([state[:, 0, 0], state[:, 0, 1], state[:, 0, 2], state[:, 1, 0], state[:, 1, 1], state[:, 1, 2]]) args = mass if config.return_snapshots: if not config.fixed_timestep: saveat = SaveAt(ts=jnp.linspace(0, t1, config.num_snapshots, endpoint=True), t1=False) #we put t1=False to avoid duplication of the last snapshot else: saveat = SaveAt(t1=True) else: saveat = SaveAt(t1=True) sol = diffeqsolve( terms = term, solver = solver, t0 = t0, t1 = t1, dt0 = dt0, saveat = saveat, adjoint = adjoint, y0 = y0, max_steps = 100_000, args=args,) if config.return_snapshots: pos = jnp.stack((sol.ys[:,0,:], sol.ys[:,1,:], sol.ys[:,2,:]), axis=2) vel = jnp.stack((sol.ys[:,3,:], sol.ys[:,4,:], sol.ys[:,5,:]), axis=2) return jnp.stack((pos, vel), axis=2) else: pos = jnp.stack((sol.ys[0][0], sol.ys[0][1], sol.ys[0][2]), axis=1) vel = jnp.stack((sol.ys[0][3], sol.ys[0][4], sol.ys[0][5]), axis=1) else: t0 = 0.0 dt0 = dt t1 = dt #in the fixed number of timesteps case we want to integrate only one step y0 = jnp.array([state[:, 0, 0], state[:, 0, 1], state[:, 0, 2]]), jnp.array([state[:, 1, 0], state[:, 1, 1], state[:, 1, 2]]) args = mass sol = diffeqsolve( terms = term, solver = solver, t0 = t0, t1 = t1, dt0 = dt0, saveat = saveat, y0 = y0, adjoint = diffrax.RecursiveCheckpointAdjoint(), args=args,) pos = jnp.stack((sol.ys[0][0], sol.ys[0][1], sol.ys[0][2]), axis=1) vel = jnp.stack((sol.ys[1][0], sol.ys[1][1], sol.ys[1][2]), axis=1) return jnp.stack((pos, vel), axis=1)