GD-1

In this notebook we reproduce the results of Alvey et al. (2024) on the GD-1 stream.

# import os
# os.environ["CUDA_VISIBLE_DEVICES"] = "0"
from autocvd import autocvd
autocvd(num_gpus = 1)

import jax 
import jax.numpy as jnp
from jax import jit, random
import equinox as eqx
from jax.sharding import Mesh, PartitionSpec, NamedSharding

# jax.config.update("jax_enable_x64", True)
import matplotlib.pyplot as plt


import numpy as np
from astropy import units as u
from astropy import constants as c

import odisseo
from odisseo import construct_initial_state
from odisseo.integrators import leapfrog
from odisseo.dynamics import direct_acc, DIRECT_ACC, DIRECT_ACC_LAXMAP, DIRECT_ACC_FOR_LOOP, DIRECT_ACC_MATRIX
from odisseo.option_classes import SimulationConfig, SimulationParams, MNParams, NFWParams, PlummerParams, PSPParams, MN_POTENTIAL, NFW_POTENTIAL, PSP_POTENTIAL
from odisseo.initial_condition import Plummer_sphere, ic_two_body, sample_position_on_sphere, inclined_circular_velocity, sample_position_on_circle, inclined_position
from odisseo.utils import center_of_mass
from odisseo.time_integration import time_integration
from odisseo.units import CodeUnits
from odisseo.visualization import create_3d_gif, create_projection_gif, energy_angular_momentum_plot
from odisseo.potentials import MyamotoNagai, NFW

from odisseo.utils import halo_to_gd1_velocity_vmap, halo_to_gd1_vmap, projection_on_GD1


# Reset to matplotlib's default style FIRST and only then apply the custom sizes:
# `plt.style.use` overwrites rcParams, so calling it after `rcParams.update`
# silently discarded every font size configured here.
plt.style.use('default')
plt.rcParams.update({
    'font.size': 20,
    'axes.labelsize': 20,
    'xtick.labelsize': 13,
    'ytick.labelsize': 13,
    'legend.fontsize': 15,
})

# 

Setting up the simulation parameters and configuration

from odisseo.option_classes import DIFFRAX_BACKEND, TSIT5

# Code units: length unit 10 kpc, mass unit 1e4 Msun, time unit 3 Gyr, G = 1.
code_length = 10 * u.kpc
code_mass = 1e4 * u.Msun
code_time = 3 * u.Gyr
G = 1
code_units = CodeUnits(code_length, code_mass, G=1, unit_time=code_time)

# Numerical set-up: 10k particles, 1000 stored snapshots, direct-summation
# self-gravity with 0.1 pc softening, NFW + Miyamoto-Nagai + PSP external
# potentials, diffrax backend with the Tsit5 solver and a non-fixed time step.
config = SimulationConfig(
    N_particles=10_000,
    return_snapshots=True,
    num_snapshots=1000,
    num_timesteps=1000,
    external_accelerations=(NFW_POTENTIAL, MN_POTENTIAL, PSP_POTENTIAL),
    acceleration_scheme=DIRECT_ACC_MATRIX,
    softening=(0.1 * u.pc).to(code_units.code_length).value,
    integrator=DIFFRAX_BACKEND,
    fixed_timestep=False,
    diffrax_solver=TSIT5,
)  # default values

# Physical parameters: a 3 Gyr run, the progenitor (Plummer sphere) and the
# three Milky Way potential components, all converted to code units.
params = SimulationParams(
    t_end=(3 * u.Gyr).to(code_units.code_time).value,
    Plummer_params=PlummerParams(
        Mtot=(10**4.05 * u.Msun).to(code_units.code_mass).value,
        a=(100 * u.pc).to(code_units.code_length).value,
    ),
    MN_params=MNParams(
        M=(68_193_902_782.346756 * u.Msun).to(code_units.code_mass).value,
        a=(3.0 * u.kpc).to(code_units.code_length).value,
        b=(0.280 * u.kpc).to(code_units.code_length).value,
    ),
    NFW_params=NFWParams(
        Mvir=(4.3683325e11 * u.Msun).to(code_units.code_mass).value,
        r_s=(16.0 * u.kpc).to(code_units.code_length).value,
    ),
    PSP_params=PSPParams(
        M=4501365375.06545 * u.Msun.to(code_units.code_mass),
        alpha=1.8,
        r_c=(1.9 * u.kpc).to(code_units.code_length).value,
    ),
    G=code_units.G,
)


# Sample the Plummer-sphere particles; they are shifted onto the orbit later.
key = random.PRNGKey(500000000)
positions, velocities, mass = Plummer_sphere(key=key, params=params, config=config)

# The progenitor's centre of mass is a single particle integrated backwards in
# time (negative t_end) from its final position and velocity.
config_com = config._replace(N_particles=1)
params_com = params._replace(t_end=-params.t_end)

# Final CoM position [kpc] and velocity [km/s], converted to code units.
pos_com_final = jnp.array([[11.8, 0.79, 6.4]]) * u.kpc.to(code_units.code_length)
vel_com_final = jnp.array([[109.5, -254.5, -90.3]]) * (u.km / u.s).to(code_units.code_velocity)
mass_com = jnp.array([params_com.Plummer_params.Mtot])
final_state_com = construct_initial_state(pos_com_final, vel_com_final)  # (N_particles x 2 x 3)

# Backward evolution; snapshots_com.states is (N_snapshots x N_particles x 2 x 3).
snapshots_com = time_integration(final_state_com, mass_com, config_com, params_com)

# The last snapshot of the backward run is where the forward simulation starts.
pos_com = snapshots_com.states[-1, :, 0]
vel_com = snapshots_com.states[-1, :, 1]

# ---- Centre-of-mass orbit plots ----
# The backward run stores the final state first (index 0) and the initial
# state last (index -1).
to_kpc = code_units.code_length.to(u.kpc)
com_xyz_kpc = snapshots_com.states[:, 0, 0] * to_kpc  # CoM position per snapshot, kpc
com_start_kpc = com_xyz_kpc[-1]
com_end_kpc = com_xyz_kpc[0]

# 3D view of the orbit.
fig = plt.figure(figsize=(15, 10), tight_layout=True)
ax = fig.add_subplot(111, projection='3d')
ax.scatter(com_start_kpc[0], com_start_kpc[1], com_start_kpc[2], c='r', label='Initial position')
ax.scatter(com_end_kpc[0], com_end_kpc[1], com_end_kpc[2], c='b', label='Final position')
ax.plot(com_xyz_kpc[:, 0], com_xyz_kpc[:, 1], com_xyz_kpc[:, 2], 'k-', label='CoM trajectory')
ax.set_xlabel("X [kpc]")
ax.set_ylabel("Y [kpc]")
ax.set_zlabel("Z [kpc]")
ax.legend()
plt.show()

# Same orbit projected on the x-y plane.
fig = plt.figure(figsize=(10, 10), tight_layout=True)
ax = fig.add_subplot(111)
ax.plot(com_xyz_kpc[:, 0], com_xyz_kpc[:, 1], 'k-', label='CoM trajectory')
ax.plot(com_start_kpc[0], com_start_kpc[1], 'ro', label='Initial position')
ax.plot(com_end_kpc[0], com_end_kpc[1], 'bo', label='Final position')
ax.set_xlabel("X [kpc]")
ax.set_ylabel("Y [kpc]")
ax.legend()

# Sanity check: conservation of energy and angular momentum along the orbit.
energy_angular_momentum_plot(snapshots_com, code_units,)
../_images/7dc221b913a2ced79c6e72cca8c48d5d597294dc63e49b6e307f92c2c3eeab6d.png ../_images/d830cd17d7af62b64d8fd3baf73e0d239d128d5b4db0a50dacdf155e860b9e15.png ../_images/ae0111f8bc4dab20357f6594a02bb7ebad59651f5c2da292f5ebf9d9363a830e.png

Dwarf galaxy position and velocity

# Move the Plummer sphere onto the progenitor's back-integrated CoM phase-space point.
positions += pos_com
velocities += vel_com
initial_state_stream = construct_initial_state(positions, velocities)

# Forward N-body evolution of the disrupting progenitor.
snapshots = time_integration(initial_state_stream, mass, config, params)
# 3D view of the stream at several times, the progenitor orbit and the MW disk.
kpc_per_code_length = code_units.code_length.to(u.kpc)

fig = plt.figure(figsize=(15, 15), tight_layout=True,)
ax = fig.add_subplot(111, projection='3d')

# Eight snapshots evenly spaced from the first to the last stored one. The upper
# bound is the last *valid* index. `config.num_snapshots` is one past the end when
# that many snapshots are stored, and JAX silently clamps such out-of-bounds reads.
last_snapshot = snapshots.states.shape[0] - 1
for i in np.linspace(0, last_snapshot, 8, dtype=int):
    ax.scatter(snapshots.states[i, :, 0, 0] * kpc_per_code_length,
               snapshots.states[i, :, 0, 1] * kpc_per_code_length,
               snapshots.states[i, :, 0, 2] * kpc_per_code_length,
               label=f"t={(snapshots.times[i]*code_units.code_time).to(u.Gyr):.2f}",
               s=50)

# Progenitor centre-of-mass orbit.
ax.plot(snapshots_com.states[:, 0, 0, 0] * kpc_per_code_length,
        snapshots_com.states[:, 0, 0, 1] * kpc_per_code_length,
        snapshots_com.states[:, 0, 0, 2] * kpc_per_code_length,
        'k-', alpha=0.5, label='CoM trajectory')

# Milky Way disk: a filled circle of radius 8 kpc in the z = 0 plane,
# triangulated from random points. The RNG is seeded so the figure is reproducible.
disk_radius = 8  # kpc
n_disk_points = 500
rng = np.random.default_rng(0)
disk_theta = rng.uniform(0, 2 * np.pi, n_disk_points)
disk_rho = disk_radius * np.sqrt(rng.uniform(0, 1, n_disk_points))  # sqrt -> uniform areal density
disk_x = disk_rho * np.cos(disk_theta)
disk_y = disk_rho * np.sin(disk_theta)
ax.plot_trisurf(disk_x, disk_y, np.zeros_like(disk_x),
                color="lightgray", alpha=0.7, linewidth=0, label="MW Disk")

ax.set_xlabel('X [kpc]', fontsize=20)
ax.set_ylabel('Y [kpc]', fontsize=20)
ax.set_zlabel('Z [kpc]', fontsize=20)
ax.legend(fontsize=20, ncol=5)
ax.tick_params(axis='both', which='major', labelsize=17)
fig.savefig("gd1_orbit.pdf", bbox_inches='tight')

# Conservation check for the full N-body run.
energy_angular_momentum_plot(snapshots, code_units,)
../_images/db1bd48f04532ed5a9eab2ba0e72350f9ccc122cdbcc3076f6123f4d7a92db86.png ../_images/415db35baa5c49c6e367d4c652efd92341d4cfd0fa71567cc2a8b705914233b4.png
# Top-down (x, y) view: CoM orbit, its end points and the final stream particles.
length_to_kpc = code_units.code_length.to(u.kpc)
com_x = snapshots_com.states[:, 0, 0, 0] * length_to_kpc
com_y = snapshots_com.states[:, 0, 0, 1] * length_to_kpc

fig = plt.figure(figsize=(10, 10), tight_layout=True)
ax = fig.add_subplot(111)
ax.plot(com_x, com_y, 'r-', label='CoM trajectory')
ax.plot(com_x[-1], com_y[-1], 'ro', label='Initial position')
ax.plot(com_x[0], com_y[0], 'bo', label='Final position')
ax.scatter(snapshots.states[-1, :, 0, 0] * length_to_kpc,
           snapshots.states[-1, :, 0, 1] * length_to_kpc,
           s=0.1, label='Final positions of the stream')
ax.set_xlabel("X [kpc]")
ax.set_ylabel("Y [kpc]")
ax.legend()
<matplotlib.legend.Legend at 0x7399cc15e390>
../_images/cadb3b88993998993b3072d5360c93a5f9b71d077646179e8bb9d90cfbe84df5.png
# Zoomed (x, y) view of the same quantities.
conversion = code_units.code_length.to(u.kpc)
com_track_x = snapshots_com.states[:, 0, 0, 0] * conversion
com_track_y = snapshots_com.states[:, 0, 0, 1] * conversion

fig = plt.figure(figsize=(6, 8), tight_layout=True)
ax = fig.add_subplot(111)
ax.plot(com_track_x, com_track_y, 'r-', label='CoM trajectory')
ax.plot(com_track_x[-1], com_track_y[-1], 'ro', label='Initial position')
ax.plot(com_track_x[0], com_track_y[0], 'bo', label='Final position')
ax.scatter(snapshots.states[-1, :, 0, 0] * conversion,
           snapshots.states[-1, :, 0, 1] * conversion,
           s=0.1, label='Final positions of the stream')
ax.set_xlabel("X [kpc]")
ax.set_ylabel("Y [kpc]")
ax.set_xlim(-5, 15)
ax.set_ylim(-25, 20)
ax.legend()
<matplotlib.legend.Legend at 0x7399cc6ec800>
../_images/cbaf3cc44a4b3cefb9269db74ba69987ded52523c638b167cf8bbfdb1aebbee5.png
# Phase-space state of every particle in the last snapshot.
final_state = snapshots.states[-1].copy()
final_positions = final_state[:, 0] * code_units.code_length.to(u.kpc)             # kpc
final_velocities = final_state[:, 1] * code_units.code_velocity.to(u.kpc / u.Myr)  # kpc/Myr

Projection onto the GD-1 coordinate frame

# Project the final stream onto GD-1 coordinates; columns unpacked below.
s = projection_on_GD1(final_state, code_units=code_units,)
R = s[:, 0]           # kpc
phi1 = s[:, 1]        # deg
phi2 = s[:, 2]        # deg
vR = s[:, 3]          # km/s
v1_cosphi2 = s[:, 4]  # mas/yr
v2 = s[:, 5]          # mas/yr

fig = plt.figure(figsize=(25, 5), tight_layout=True)

# Sky positions. Raw strings keep the LaTeX backslashes from being read as
# (invalid) Python escape sequences.
ax = fig.add_subplot(131)
ax.scatter(phi1, phi2, s=0.1)
ax.set_xlabel(r"$\phi_1$ [deg]")
ax.set_ylabel(r"$\phi_2$ [deg]")
ax.set_xlim(-120, 70)
ax.set_ylim(-8, 2)

# Proper motions (milliarcseconds per year).
ax = fig.add_subplot(132)
ax.scatter(v1_cosphi2, v2, s=0.1)
ax.set_xlabel(r"$v_{\phi_1}\cos(\phi_2)$ [mas/yr]")
ax.set_ylabel(r"$v_{\phi_2}$ [mas/yr]")
ax.set_xlim(-2., 1.0)
ax.set_ylim(-0.10, 0.10)

# Distance vs radial velocity.
ax = fig.add_subplot(133)
ax.scatter(R, vR, s=0.1)
ax.set_xlabel("R [kpc]")
ax.set_ylabel("$v_R$ [km/s]")
ax.set_xlim(6, 20)
ax.set_ylim(-250, 250)
(-250.0, 250.0)
../_images/a13e632cf51e364fd41f7d789a3772d0c75d2ebfe2b574eddae601d655c3e847.png
import matplotlib.colors as colors

fig = plt.figure(figsize=(30, 5), tight_layout=True)

# Bin edges (n_bins + 1 edges per axis); ranges match the scatter plots above.
phi1_bins = jnp.linspace(-120, 70, 65)    # 64 bins
phi2_bins = jnp.linspace(-8, 2, 33)       # 32 bins
v1_bins = jnp.linspace(-2., 1.0, 65)      # 64 bins
v2_bins = jnp.linspace(-0.10, 0.10, 33)   # 32 bins
R_bins = jnp.linspace(6, 20, 65)          # 64 bins
vR_bins = jnp.linspace(-250, 250, 33)     # 32 bins

# Meshgrids of bin EDGES (not centers), as pcolormesh expects for flat shading.
PHI1, PHI2 = jnp.meshgrid(phi1_bins, phi2_bins, indexing='ij')
V1, V2 = jnp.meshgrid(v1_bins, v2_bins, indexing='ij')
R_GRID, VR_GRID = jnp.meshgrid(R_bins, vR_bins, indexing='ij')

# Counts are shown as log10. `np.ma.log10` masks empty bins instead of evaluating
# log10(0), so they are left blank without "divide by zero" warnings.
ax = fig.add_subplot(131)
counts1 = jnp.histogram2d(phi1, phi2, bins=[phi1_bins, phi2_bins])[0]
im1 = ax.pcolormesh(PHI1, PHI2, np.ma.log10(np.asarray(counts1)), cmap='coolwarm')
ax.set_xlabel(r"$\phi_1$ [deg]")
ax.set_ylabel(r"$\phi_2$ [deg]")
ax.set_xlim(-120, 70)
ax.set_ylim(-8, 2)
plt.colorbar(im1, ax=ax, label=r'$\log_{10}(\text{N}_{\text{stars}})$',  )

ax = fig.add_subplot(132)
counts2 = jnp.histogram2d(v1_cosphi2, v2, bins=[v1_bins, v2_bins])[0]
im2 = ax.pcolormesh(V1, V2, np.ma.log10(np.asarray(counts2)), cmap='coolwarm')
ax.set_xlabel(r"$v_{\phi_1}\cos(\phi_2)$ [mas/yr]")
ax.set_ylabel(r"$v_{\phi_2}$ [mas/yr]")
ax.set_xlim(-2., 1.0)
ax.set_ylim(-0.10, 0.10)
plt.colorbar(im2, ax=ax, label=r'$\log_{10}(\text{N}_{\text{stars}})$')

ax = fig.add_subplot(133)
counts3 = jnp.histogram2d(R, vR, bins=[R_bins, vR_bins])[0]
im3 = ax.pcolormesh(R_GRID, VR_GRID, np.ma.log10(np.asarray(counts3)), cmap='coolwarm')
ax.set_xlabel("R [kpc]")
ax.set_ylabel("$v_R$ [km/s]")
ax.set_xlim(6, 20)
ax.set_ylim(-250, 250)
plt.colorbar(im3, ax=ax, label=r'$\log_{10}(\text{N}_{\text{stars}})$')
/tmp/ipykernel_2170897/3570116978.py:21: RuntimeWarning: divide by zero encountered in log10
  im1 = ax.pcolormesh(PHI1, PHI2, np.log10(counts1), cmap='coolwarm')
/tmp/ipykernel_2170897/3570116978.py:31: RuntimeWarning: divide by zero encountered in log10
  im2 = ax.pcolormesh(V1, V2, np.log10(counts2), cmap='coolwarm')
/tmp/ipykernel_2170897/3570116978.py:40: RuntimeWarning: divide by zero encountered in log10
  im3 = ax.pcolormesh(R_GRID, VR_GRID, np.log10(counts3), cmap='coolwarm')
<matplotlib.colorbar.Colorbar at 0x739af417bf80>
../_images/8bdcaec908fbda4a956f87988c1349a7036c1c9d38bf3f30a8b02cf064dc7159.png
# On-sky velocity field of the stream: arrows in (phi1, phi2), colored by vR.
fig = plt.figure(figsize=(10, 4), tight_layout=True)
ax = fig.add_subplot(111)
# Divide out cos(phi2) to turn v_phi1*cos(phi2) back into v_phi1.
vector_field = ax.quiver(phi1,
                         phi2,
                         v1_cosphi2 / jnp.cos(jnp.deg2rad(phi2)),
                         v2,
                         vR,
                         scale=25,)
plt.colorbar(vector_field, ax=ax, label='vR [km/s]')
# Raw strings: the LaTeX backslashes are not Python escape sequences.
ax.set_xlabel(r"$\phi_1$ [deg]")
ax.set_ylabel(r"$\phi_2$ [deg]")
ax.set_xlim(-120, 70)
ax.set_ylim(-8, 2)
(-8.0, 2.0)
../_images/12e4d250c77666497d5354bd74747f2e5b910d655ffb31c28fc640d003d35996.png
# np.savez('/export/data/vgiusepp/odisseo_data/data_fix_position/true.npz',
#          x=s,
#          theta=np.array([params.t_end * code_units.code_time.to(u.Gyr),
#                          params.Plummer_params.Mtot * code_units.code_mass.to(u.Msun),
#                          params.Plummer_params.a * code_units.code_length.to(u.kpc),
#                          params.NFW_params.Mvir * code_units.code_mass.to(u.Msun),
#                          params.NFW_params.r_s * code_units.code_length.to(u.kpc),
#                          params.MN_params.M * code_units.code_mass.to(u.Msun),
#                          params.MN_params.a * code_units.code_length.to(u.kpc),]))

Dependence on the Plummer parameters

# Distribution of each particle's initial distance from the progenitor centre (code units).
initial_offsets = snapshots.states[0, :, 0] - snapshots_com.states[-1, :, 0]
plt.hist(jnp.linalg.norm(initial_offsets, axis=1))
(array([9.968e+03, 2.100e+01, 6.000e+00, 2.000e+00, 1.000e+00, 0.000e+00,
        0.000e+00, 1.000e+00, 0.000e+00, 1.000e+00]),
 array([5.59819222e-04, 1.97572753e-01, 3.94585669e-01, 5.91598630e-01,
        7.88611531e-01, 9.85624433e-01, 1.18263745e+00, 1.37965035e+00,
        1.57666326e+00, 1.77367616e+00, 1.97068906e+00]),
 <BarContainer object of 10 artists>)
../_images/f30da75fa076ddf0c75d9869c60f0bdc70199cb65f4134de71f8d0270d91fc30.png
from matplotlib.ticker import MaxNLocator

colormap = 'coolwarm'
fontsize = 30
to_kpc = code_units.code_length.to(u.kpc)

# Initial distance of each particle from the progenitor centre, computed once.
# Only particles starting within 0.1 code lengths of the centre are plotted.
dist0_code = jnp.linalg.norm(snapshots.states[0, :, 0] - snapshots_com.states[-1, :, 0], axis=1)
mask = dist0_code < 0.1
dist0 = dist0_code * to_kpc
vmin = float(np.asarray(dist0[mask]).min())
vmax = float(np.asarray(dist0[mask]).max())

# Pass only `layout=`: combining it with `constrained_layout=` triggers a UserWarning.
fig = plt.figure(figsize=(15, 7), layout='tight')
gs = fig.add_gridspec(nrows=2, ncols=2, height_ratios=[0.07, 1.0], hspace=0.0, wspace=0.0)

ax1 = fig.add_subplot(gs[1, 0])
ax2 = fig.add_subplot(gs[1, 1])

# Left panel: first snapshot, colored by initial distance; white circle = progenitor.
sc1 = ax1.scatter(
    snapshots.states[0, :, 0, 0][mask] * to_kpc,
    snapshots.states[0, :, 0, 1][mask] * to_kpc,
    s=5, c=dist0[mask], cmap=colormap, vmin=vmin, vmax=vmax
)
ax1.scatter(
    snapshots_com.states[-1, :, 0, 0] * to_kpc,
    snapshots_com.states[-1, :, 0, 1] * to_kpc,
    s=50, facecolors='white', edgecolors='black', marker='o'
)
ax1.set_xlabel("X [kpc]", fontsize=fontsize)
ax1.set_ylabel("Y [kpc]", fontsize=fontsize)

# Right panel: last snapshot, same particles and the same color scale.
sc2 = ax2.scatter(
    snapshots.states[-1, :, 0, 0][mask] * to_kpc,
    snapshots.states[-1, :, 0, 1][mask] * to_kpc,
    s=5, c=dist0[mask], cmap=sc1.cmap, norm=sc1.norm
)
ax2.scatter(
    snapshots_com.states[0, :, 0, 0] * to_kpc,
    snapshots_com.states[0, :, 0, 1] * to_kpc,
    s=50, facecolors='white', edgecolors='black', marker='o', label='Progenitor'
)
ax2.set_xlabel("X [kpc]", fontsize=fontsize)
ax2.yaxis.set_ticks_position('right')
ax2.legend(fontsize=fontsize-5, loc='upper right')

# Shared colorbar spanning both columns, placed on top.
cax = fig.add_subplot(gs[0, :])
cb = fig.colorbar(sc1, cax=cax, orientation='horizontal', )
cb.set_label(r"$r_{Progenitor}$ [kpc]", fontsize=fontsize, labelpad=14)
cb.ax.xaxis.set_label_position('top')
cb.ax.xaxis.set_ticks_position('top')
cb.ax.tick_params(top=True, bottom=False, labeltop=True, labelbottom=False, labelsize=fontsize-2)

# Time annotations.
ax1.text(0.02, 0.03, "$t = 0$ Gyr", transform=ax1.transAxes,
         ha="left", va="bottom", fontsize=fontsize+5, color="k", zorder=10)
ax2.text(0.02, 0.03, "$t = 3$ Gyr", transform=ax2.transAxes,
         ha="left", va="bottom", fontsize=fontsize+5, color="k", zorder=10)

# At most ~4 major ticks per axis, with larger tick labels.
for ax in (ax1, ax2):
    ax.xaxis.set_major_locator(MaxNLocator(nbins=4))
    ax.yaxis.set_major_locator(MaxNLocator(nbins=4))
    ax.tick_params(axis='both', which='major', labelsize=fontsize-2)

fig.savefig("gd1_stream_evolution.pdf", dpi=300, bbox_inches='tight')
/tmp/ipykernel_2170897/4256544036.py:7: UserWarning: The Figure parameters 'layout' and 'constrained_layout' cannot be used together. Please use 'layout' only.
  fig = plt.figure(figsize=(15, 7), constrained_layout=False, layout='tight')
../_images/95780f6a9022ccc8f230ba1f979053c7ed0bcf44c104a55efa434a41e0567676.png

Dependence on the external-potential parameters

# The parameter scan only needs each run's final state, so snapshots are turned off
# for both the stream and the centre-of-mass configurations.
config_com = config_com._replace(return_snapshots=False, )
config = config._replace(return_snapshots=False,
                         N_particles=10_000, )

@jit
def run_simulation(params_for_sims):
    """Simulate the GD-1 progenitor in a Milky Way with a modified NFW halo.

    Parameters
    ----------
    params_for_sims : array-like of length 2
        ``(Mvir, r_s)`` of the NFW halo, in code units.

    Returns
    -------
    The final stream state returned by ``time_integration`` (snapshots are
    disabled in ``config`` for this scan).

    The centre of mass is first integrated backwards from its final position
    and velocity. A Plummer sphere (same ``key`` as the reference run) is then
    placed there and integrated forwards. Reads the module-level ``params``,
    ``config``, ``config_com``, ``key`` and ``code_units``.
    """
    Mvir, r_s = params_for_sims
    sim_params = params._replace(NFW_params=NFWParams(Mvir=Mvir, r_s=r_s,),)
    com_params = sim_params._replace(t_end=-sim_params.t_end)  # backwards in time

    # 1) Back-integrate the centre of mass from its final phase-space point.
    com_mass = jnp.array([params.Plummer_params.Mtot])
    com_state = construct_initial_state(
        jnp.array([[11.8, 0.79, 6.4]]) * u.kpc.to(code_units.code_length),
        jnp.array([[109.5, -254.5, -90.3]]) * (u.km / u.s).to(code_units.code_velocity),
    )
    com_start = time_integration(com_state, com_mass, config=config_com, params=com_params)

    # 2) Drop the Plummer sphere on that point and integrate it forwards.
    stream_pos, stream_vel, stream_mass = Plummer_sphere(key=key, params=sim_params, config=config)
    stream_state = construct_initial_state(stream_pos + com_start[:, 0],
                                           stream_vel + com_start[:, 1], )
    return time_integration(stream_state, stream_mass, config=config, params=sim_params)


# 2x2 grid of halo variations: Mvir and r_s each scaled by 0.8 or 1.2.
# Row order: (0.8, 0.8), (0.8, 1.2), (1.2, 0.8), (1.2, 1.2).
new_params = jnp.array([
    [params.NFW_params.Mvir * mvir_scale, params.NFW_params.r_s * rs_scale]
    for mvir_scale in (0.8, 1.2)
    for rs_scale in (0.8, 1.2)
])

# Batch the four simulations by vectorising run_simulation over the rows.
new_final_states = jax.vmap(run_simulation)(new_params)
# ...existing code...
from matplotlib.ticker import MaxNLocator
import numpy as np
import matplotlib
import matplotlib.pyplot as plt

def plot_external_potential_grid(new_final_states,
                                 snapshots,
                                 new_params,
                                 labels_fontsize=20,
                                 ticks_fontsize=20,
                                 xy_labelpad=0,
                                 n_xticks=4,
                                 n_yticks=4,
                                 colors=None,
                                 label_pos=(0.03, 0.03)):
    """Compare four perturbed-halo streams with the reference stream on a 2x2 grid.

    Each panel shows the reference stream (last snapshot of ``snapshots``) in
    faint black. On top of it is the final state of one perturbed simulation in
    its own color, plus a text box with that run's ``Mvir`` and ``r_s`` relative
    to the module-level ``params.NFW_params``.

    Parameters
    ----------
    new_final_states : sequence of final states, one per panel.
        ``new_final_states[i][:, 0, 0]`` and ``[:, 0, 1]`` are the particle
        x and y positions in code units.
    snapshots : reference-run snapshots; ``snapshots.states[-1]`` is plotted.
    new_params : array whose row ``i`` is ``(Mvir, r_s)`` for panel ``i`` (code units).
    labels_fontsize : font size of the text boxes and the figure-level axis labels.
    ticks_fontsize : font size of the tick labels.
    xy_labelpad : offset (figure fraction) applied to the "X"/"Y" figure labels.
    n_xticks, n_yticks : maximum number of major ticks per axis.
    colors : one color per panel; defaults to four samples of ``coolwarm``.
    label_pos : (x, y) of the text box in axes-fraction coordinates. Axes
        coordinates keep the box inside the panel whatever the data range.

    Returns
    -------
    (fig, axes) : the figure and its 2x2 array of axes. Only the first four
    runs are drawn.
    """
    to_kpc = code_units.code_length.to(u.kpc)
    base_Mvir = params.NFW_params.Mvir
    base_rs = params.NFW_params.r_s

    if colors is None:
        cmap = plt.get_cmap('coolwarm')
        colors = [matplotlib.colors.to_hex(cmap(s)) for s in (0.10, 0.35, 0.80, 0.90)]

    fig, axes = plt.subplots(2, 2, figsize=(15, 15), sharex=True, sharey=True)
    fig.subplots_adjust(wspace=0.0, hspace=0.0)

    # Reference stream, converted once and reused as every panel's background.
    x_ref = snapshots.states[-1, :, 0, 0] * to_kpc
    y_ref = snapshots.states[-1, :, 0, 1] * to_kpc
    xmin, xmax = float(x_ref.min()), 22.0
    ymin, ymax = -40, float(y_ref.max())

    for i, ax in enumerate(axes.flat):
        row, col = divmod(i, 2)

        # Background (reference) stream, no legend entry.
        ax.scatter(x_ref, y_ref, c='black', alpha=0.08, s=2)

        # Perturbed-halo stream for this panel.
        ax.scatter(
            new_final_states[i][:, 0, 0] * to_kpc,
            new_final_states[i][:, 0, 1] * to_kpc,
            s=2,
            color=colors[i],
        )

        Mvir_ratio = float(new_params[i, 0] / base_Mvir)
        rs_ratio = float(new_params[i, 1] / base_rs)
        ax.text(*label_pos,
                f"{Mvir_ratio:.1f}$M_{{NFW}}$ \n{rs_ratio:.1f}$r_s$",
                transform=ax.transAxes, ha='left', va='bottom',
                fontsize=labels_fontsize,
                bbox=dict(boxstyle="round", facecolor="white", alpha=0.8))

        ax.set_xlim(xmin, xmax)
        ax.set_ylim(ymin, ymax)

        # Tick labels only on the outer edges of the grid.
        if row == 0:
            ax.tick_params(labelbottom=False)
        if col == 1:
            ax.tick_params(labelleft=False)

        ax.xaxis.set_major_locator(MaxNLocator(nbins=n_xticks))
        ax.yaxis.set_major_locator(MaxNLocator(nbins=n_yticks))
        ax.tick_params(axis='both', which='major', labelsize=ticks_fontsize)

    fig.text(0.5, 0.02 + xy_labelpad, "X [kpc]", ha='center', fontsize=labels_fontsize)
    fig.text(0.02 + xy_labelpad, 0.5, "Y [kpc]", va='center', rotation='vertical', fontsize=labels_fontsize)

    return fig, axes

# Build the 2x2 halo-parameter comparison figure and save it to PDF.
grid_kwargs = dict(
    new_params=new_params,
    labels_fontsize=30,
    ticks_fontsize=25,
    xy_labelpad=0.03,
    n_xticks=3,
    n_yticks=3,
    colors=None,
)
fig, axes = plot_external_potential_grid(new_final_states, snapshots, **grid_kwargs)
fig.savefig("initial_params_dependence.pdf", dpi=300, bbox_inches='tight')
../_images/1c9d124ce5521f6223cddf7d06df6169e02947403693f5b0430f1c91b2777a4c.png