Source code for odisseo.initial_condition

from beartype.typing import Optional, Tuple, Callable, Union, List, NamedTuple
from functools import partial
from jaxtyping import jaxtyped, PRNGKeyArray
from beartype import beartype as typechecker

import jax
import jax.numpy as jnp
from jax import vmap, random
import numpy as np
from multiprocessing import Pool

from odisseo.option_classes import SimulationConfig, SimulationParams


[docs] @partial(jax.jit, static_argnames=['config']) @jaxtyped(typechecker=typechecker) def Plummer_sphere(key: PRNGKeyArray, config: SimulationConfig, params: SimulationParams) -> Tuple: """ Create initial conditions for a Plummer sphere. The sampling of velocities is done by inverse fitting the cumulative distribution function of the Plummer sphere. Args: key (jax.random.PRNGKey): Random key. config (NamedTuple): Configuration NamedTuple containing the number of particles (N_particles). params (NamedTuple): Parameters NamedTuple containing: Returns: tuple: A tuple containing: positions (jnp.array): Array of shape (N_particles, 3) representing the positions of the particles. velocities (jnp.array): Array of shape (N_particles, 3) representing the velocities of the particles. masses (jnp.array): Array of shape (N_particles,) representing the masses of the particles. """ key_r, key_phi, key_sin_i, key_u, key_phi_v, key_sin_i_v= random.split(key, 6) r = jnp.sqrt( params.Plummer_params.a**2 / (random.uniform(key=key_r, shape=(config.N_particles,))**(-2/3) -1 ) ) phi = random.uniform(key=key_phi, shape=(config.N_particles,), minval=0, maxval=2*jnp.pi) sin_i = random.uniform(key=key_sin_i, shape=(config.N_particles,), minval=-1, maxval=1) positions = jnp.array([r*jnp.cos(jnp.arcsin(sin_i))*jnp.cos(phi), r*jnp.cos(jnp.arcsin(sin_i))*jnp.sin(phi), r*sin_i]).T potential = - params.G * params.Plummer_params.Mtot / jnp.sqrt( jnp.linalg.norm(positions, axis=1)**2 + params.Plummer_params.a**2) velocities_escape = jnp.sqrt(-2*potential ) def G(q): """ Normalize Cumulative distribution function of q=v/v_escape for a Plummer sphere. The assosiate unormalized probability distribution function assosiated with it is g(q) = (1-q)**(7/2) * q**2 Args: q: (float) Velocity ratio v/v_escape. Returns: float: Normalized cumulative distribution function. """ # return 1287/16 * ((-2*(1-q)**(9/2))*(99*q**2+36*q+8)/1287 +16/1287) return 1/(jnp.pi*7/512) * (q*jnp.sqrt(1 - q**2)*(-384*q**8 + 1488*q**6 - 2104*q**4 + 1210*q**2 - 105) + 105*jnp.asin(q))/3840 # Invere fitting q = jnp.linspace(0, 1, 500) y = G(q) u = random.uniform(key=key_u, shape=(config.N_particles,)) samples = jnp.interp(u, y, q) velocities_modulus = samples * velocities_escape # Generate random angles for the velocity phi_v = random.uniform(key=key_phi_v, shape=(config.N_particles,), minval=0, maxval=2*jnp.pi) sin_i_v = random.uniform(key=key_sin_i_v, shape=(config.N_particles,), minval=-1, maxval=1) velocities = velocities_modulus[:, None]*jnp.array([jnp.cos(jnp.arcsin(sin_i_v))*jnp.cos(phi_v), jnp.cos(jnp.arcsin(sin_i_v))*jnp.sin(phi_v), sin_i_v]).T # return jnp.array(positions), jnp.array(velocities), params.Plummer_params.Mtot/config.N_particles*jnp.ones(config.N_particles) return jnp.array(positions), jnp.array(velocities), 1/config.N_particles*jnp.ones(config.N_particles)
[docs] @partial(jax.jit, static_argnames=['config']) @jaxtyped(typechecker=typechecker) def Plummer_sphere_reparam(noise: jnp.ndarray, config: SimulationConfig, params: SimulationParams) -> Tuple: """ Reparameterized Plummer sphere generation. Args: noise (jnp.ndarray): Pre-sampled uniform random numbers of shape (N_particles, 6) where each column corresponds to: [0]: radial sampling [1]: position azimuthal angle [2]: position polar angle [3]: velocity magnitude sampling [4]: velocity azimuthal angle [5]: velocity polar angle config (SimulationConfig): Configuration containing N_particles params (SimulationParams): Parameters containing Plummer_params, G Returns: tuple: Same as original Plummer_sphere function """ # Extract the 6 noise components (each should be uniform [0,1]) noise_r = noise[:, 0] noise_phi = noise[:, 1] noise_sin_i = noise[:, 2] noise_u = noise[:, 3] noise_phi_v = noise[:, 4] noise_sin_i_v = noise[:, 5] # Position sampling (deterministic transformations of noise) r = jnp.sqrt(params.Plummer_params.a**2 / (noise_r**(-2/3) - 1)) phi = noise_phi * 2 * jnp.pi # Scale [0,1] to [0, 2π] sin_i = 2 * noise_sin_i - 1 # Scale [0,1] to [-1, 1] positions = jnp.array([r*jnp.cos(jnp.arcsin(sin_i))*jnp.cos(phi), r*jnp.cos(jnp.arcsin(sin_i))*jnp.sin(phi), r*sin_i]).T potential = -params.G * params.Plummer_params.Mtot / jnp.sqrt(jnp.linalg.norm(positions, axis=1)**2 + params.Plummer_params.a**2) velocities_escape = jnp.sqrt(-2*potential) def G(q): """Same as your original G function""" return 1/(jnp.pi*7/512) * (q*jnp.sqrt(1 - q**2)*(-384*q**8 + 1488*q**6 - 2104*q**4 + 1210*q**2 - 105) + 105*jnp.asin(q))/3840 # Inverse fitting (unchanged) q = jnp.linspace(0, 1, 500) y = G(q) samples = jnp.interp(noise_u, y, q) # Use noise_u instead of random sampling velocities_modulus = samples * velocities_escape # Velocity direction sampling phi_v = noise_phi_v * 2 * jnp.pi # Scale [0,1] to [0, 2π] sin_i_v = 2 * noise_sin_i_v - 1 # Scale [0,1] to [-1, 1] velocities = velocities_modulus[:, None]*jnp.array([jnp.cos(jnp.arcsin(sin_i_v))*jnp.cos(phi_v), jnp.cos(jnp.arcsin(sin_i_v))*jnp.sin(phi_v), sin_i_v]).T return jnp.array(positions), jnp.array(velocities), 1/config.N_particles*jnp.ones(config.N_particles)
[docs] @partial(jax.jit) def ic_two_body(mass1: Union[float, jnp.ndarray], mass2: Union[float, jnp.ndarray], rp: Union[float, jnp.ndarray], e: Union[float, jnp.ndarray], params: SimulationParams) -> Tuple: """ Create initial conditions for a two-body system. By default, the two bodies will be placed along the x-axis at the closest distance rp. Depending on the input eccentricity, the two bodies can be in a circular (e < 1), parabolic (e = 1), or hyperbolic orbit (e > 1). Args: mass1 (float): Mass of the first body [nbody units]. mass2 (float): Mass of the second body [nbody units]. rp (float): Closest orbital distance [nbody units]. e (float): Eccentricity. config (NamedTuple): Configuration NamedTuple. params (NamedTuple): Parameters NamedTuple. Returns: tuple: A tuple containing: - pos (jnp.ndarray): Positions of the particles. - vel (jnp.ndarray): Velocities of the particles. - mass (jnp.ndarray): Masses of the particles. """ Mtot=mass1+mass2 # if e==1.: # vrel=jnp.sqrt(params.G * 2*Mtot/rp) # else: # a=rp/(1-e) # vrel=jnp.sqrt(params.G * Mtot*(2./rp-1./a)) def circular_orbit(): return jnp.sqrt(params.G * 2*Mtot/rp) def elliptical_orbit(): a=rp/(1-e) return jnp.sqrt(params.G * Mtot*(2./rp-1./a)) vrel = jax.lax.cond( e == 1., circular_orbit, elliptical_orbit, ) v1 = -params.G*mass2/Mtot * vrel v2 = params.G*mass1/Mtot * vrel pos = jnp.array([[0.,0.,0.],[rp,0.,0.]]) vel = jnp.array([[0.,v1,0.],[0.,v2,0.]]) mass = jnp.array([mass1, mass2]) return pos, vel, mass
[docs] @partial(jax.jit, static_argnames=['num_samples']) @jaxtyped(typechecker=typechecker) def sample_position_on_sphere(key: PRNGKeyArray, r_p: float, num_samples: int = 1): """ Sample uniform positions on a sphere of radius r_p. Args: key (jax.random.PRNGKey): JAX random key for sampling. r_p (float): Radius of the sphere. num_samples (int): Number of samples to generate. Deafult is 1. Returns: jnp.ndarray: Sampled positions (num_samples, 3). """ subkey1, subkey2 = random.split(key) # Sample phi uniformly in [0, 2π] phi = random.uniform(subkey1, shape=(num_samples,), minval=0, maxval=2*jnp.pi) # Sample cos(theta) uniformly in [-1, 1] to ensure uniform distribution on the sphere costheta = random.uniform(subkey2, shape=(num_samples,), minval=-1, maxval=1) theta = jnp.arccos(costheta) # Convert to theta # Convert to Cartesian coordinates x = r_p * jnp.sin(theta) * jnp.cos(phi) y = r_p * jnp.sin(theta) * jnp.sin(phi) z = r_p * jnp.cos(theta) return jnp.stack([x, y, z], axis=-1)
[docs] @partial(jax.jit, static_argnames=['num_samples']) @jaxtyped(typechecker=typechecker) def sample_position_on_circle(key: PRNGKeyArray, r_p: float, num_samples: int =1): """ Sample uniform positions on a sphere of radius r_p. Args: key (jax.random.PRNGKey): JAX random key for sampling. r_p (float): Radius of the sphere. num_samples (int): Number of samples to generate. Returns: jnp.ndarray: Sampled positions (num_samples, 3). """ subkey1, subkey2 = random.split(key) # Sample phi uniformly in [0, 2π] phi = random.uniform(subkey1, shape=(num_samples,), minval=0, maxval=2*jnp.pi) # Sample cos(theta) uniformly in [-1, 1] to ensure uniform distribution on the sphere theta = jnp.radians(90) # Convert to theta # Convert to Cartesian coordinates x = r_p * jnp.sin(theta) * jnp.cos(phi) y = r_p * jnp.sin(theta) * jnp.sin(phi) z = jnp.zeros_like(x) return jnp.stack([x, y, z], axis=-1)
[docs] @partial(jax.jit,) @jaxtyped(typechecker=typechecker) def inclined_position(position: jnp.ndarray, inclination: jnp.ndarray): """ Convert position on the xy-plane to an inclined orbit. Args: position (jnp.ndarray): (x, y, z) position of the Plummer sphere. inclination (jnp.ndarray): Inclination angle in radians. Returns: jnp.ndarray: (x, y, z) position of the Plummer sphere after inclination. """ x, y, z = position.T phi = jnp.arctan2(y, x)[0] # Azimuthal angle # Rotation matrix around x-axis by inclination R_x = jnp.array([ [1, 0, 0], [0, jnp.cos(inclination), -jnp.sin(inclination)], [0, jnp.sin(inclination), jnp.cos(inclination)] ]) # Rotate position around x-axis rotated_position = R_x @ position.T return rotated_position.T
[docs] @partial(jax.jit,) @jaxtyped(typechecker=typechecker) def inclined_circular_velocity(position: jnp.ndarray, v_c: jnp.ndarray, inclination: jnp.ndarray): """ Convert circular velocity module on the xy plane to an inclined orbit Cartesian components. Args: position (jnp.ndarray): (x, y, z) position of the Plummer sphere. v_c (float): Circular velocity (km/s). inclination (float): Inclination angle in radians. Returns: jnp.ndarray: (v_x, v_y, v_z) velocity components. """ x, y, z = position.T phi = jnp.arctan2(y, x) # Azimuthal angle # Initial velocity vector velocity = jnp.zeros(3) velocity = velocity.at[1].set(v_c[0]) # Rotation matrix around z-axis by phi R_z = jnp.array([ [jnp.cos(phi[0]), -jnp.sin(phi[0]), 0], [jnp.sin(phi[0]), jnp.cos(phi[0]), 0], [0, 0, 1] ]) # Rotate velocity around z-axis velocity = R_z @ velocity # Rotation matrix around x-axis by inclination R_x = jnp.array([ [1, 0, 0], [0, jnp.cos(inclination), -jnp.sin(inclination)], [0, jnp.sin(inclination), jnp.cos(inclination)] ]) # Rotate velocity around x-axis rotated_velocity = R_x @ velocity return rotated_velocity