import matplotlib.pyplot as plt
import math
from random import randint, uniform, random
import numpy as np
from matplotlib.animation import FuncAnimation, PillowWriter
import jax.numpy as jnp
from astropy import units as u
from astropy.coordinates import SkyCoord
[docs]
def energy_angular_momentum_plot(snapshots, code_units, filename=None):
    """
    Draw two side-by-side panels with the percentage drift of total energy and angular momentum.

    Both quantities are normalised by their first entry, i.e. ``(Q - Q[0]) / Q[0]``, and
    multiplied by 100. For the angular momentum only the component at index 2 of the last
    axis is plotted. Dashed red lines at +5 and -5 are added to each panel as a reference.

    Args:
        snapshots (odisseo.Snapshots): Snapshots object; ``times``, ``total_energy`` and
            ``angular_momentum`` are read from it.
        code_units (odisseo.CodeUnits): CodeUnits object; ``code_time`` is used to express the
            time axis in Gyr.
        filename (str, optional): Where to save the figure. If None, the figure is only shown.

    Returns:
        None
    """
    def _drift_panel(axis, elapsed, percent_drift, y_label):
        # Shared layout of one panel: curve, labels, fixed range, +/-5 guide lines.
        axis.plot(elapsed, percent_drift)
        axis.set_xlabel('Time [Gyr]')
        axis.set_ylabel(y_label)
        axis.set_ylim(-100, 100)
        axis.axhline(5, color='r', linestyle='--', label='5%')
        axis.axhline(-5, color='r', linestyle='--')
        axis.grid(linestyle='dotted')
        axis.legend()

    figure = plt.figure(figsize=(17, 5), tight_layout=True)

    # Left panel: total energy.
    energy_axis = figure.add_subplot(121)
    initial_energy = snapshots.total_energy[0]
    energy_drift = (snapshots.total_energy - initial_energy) / initial_energy
    _drift_panel(
        energy_axis,
        (snapshots.times * code_units.code_time).to(u.Gyr),
        100 * energy_drift,
        r'$(E - E_0)/E_0 \% $',
    )

    # Right panel: angular momentum, component at index 2 only.
    momentum_axis = figure.add_subplot(122)
    initial_momentum = snapshots.angular_momentum[0]
    momentum_drift = (snapshots.angular_momentum - initial_momentum) / initial_momentum
    _drift_panel(
        momentum_axis,
        (snapshots.times * code_units.code_time).to(u.Gyr),
        100 * momentum_drift[:, 2],
        r'$(L - L_0)/L_0 \% $',
    )

    if filename is not None:
        figure.savefig(filename)
    plt.show()
[docs]
def plot_last_snapshot(snapshots, code_units, rp, plotting_units_length, filename=None):
    """
    Scatter the particle positions of the final snapshot in a 3D axes.

    Positions are read from ``snapshots.states[-1, :, 0, k]`` for ``k`` in 0..2, multiplied
    by ``code_units.code_length`` and converted to ``plotting_units_length``. A red star is
    drawn at the origin and all three axes are limited to ``[-rp, +rp]`` after the same
    conversion.

    Args:
        snapshots (object): Object exposing a ``states`` array indexed as described above.
        code_units (object): Object exposing ``code_length`` for the length conversion.
        rp: Half-width of the plotted cube, given in code length units.
        plotting_units_length (object): Unit the lengths are converted to for plotting.
        filename (str, optional): Where to save the figure. If None, the figure is only shown.

    Returns:
        None
    """
    def _in_plot_units(code_values):
        # Code length -> plain numbers expressed in the plotting unit.
        return (code_values * code_units.code_length).to(plotting_units_length).value

    figure = plt.figure(figsize=(10, 10), tight_layout=True)
    axes = figure.add_subplot(111, projection='3d')

    labelled_setters = ((axes.set_xlabel, 'X'), (axes.set_ylabel, 'Y'), (axes.set_zlabel, 'Z'))
    for set_label, axis_name in labelled_setters:
        set_label(f'{axis_name} {plotting_units_length}')

    final_xyz = [_in_plot_units(snapshots.states[-1, :, 0, component]) for component in range(3)]
    axes.scatter(*final_xyz)
    axes.scatter(0, 0, 0, s=100, marker='*', color='r')

    half_extent = _in_plot_units(rp)
    for set_limits in (axes.set_xlim, axes.set_ylim, axes.set_zlim):
        set_limits(-half_extent, half_extent)

    if filename is not None:
        figure.savefig(filename)
    plt.show()
[docs]
def plot_orbit(snapshots, ax_lim, code_units, plotting_units_length, config, filename=None):
    """
    Plot the trajectory of each particle in 3D, with a marker at its final position.

    For particle ``i`` the path is read from ``snapshots.states[:, i, 0, k]`` (``k`` in 0..2),
    multiplied by ``code_units.code_length`` and converted to ``plotting_units_length``. Each
    particle gets its own colour from matplotlib's "tab10" colormap.

    Args:
        snapshots (object): Object exposing a ``states`` array indexed as described above.
        ax_lim (float): Half-width of the plotted cube, in code length units.
        code_units (object): Object exposing ``code_length`` for the length conversion.
        plotting_units_length (object): Unit the lengths are converted to for plotting.
        config (object): Configuration object; ``N_particles`` is the number of orbits drawn.
        filename (str, optional): Where to save the figure. If None, the figure is only shown.

    Raises:
        AssertionError: If ``config.N_particles`` is not smaller than 10.

    Returns:
        None
    """
    # Raised explicitly (rather than via ``assert``) so the guard also holds under
    # ``python -O``; the exception type and message are unchanged for callers.
    if not config.N_particles < 10:
        raise AssertionError("Too many particles! ")

    def _in_plot_units(code_values):
        # Code length -> plain numbers expressed in the plotting unit.
        return (code_values * code_units.code_length).to(plotting_units_length).value

    fig = plt.figure(figsize=(10, 10))
    ax = fig.add_subplot(111, projection='3d')
    ax.set_xlabel(f'X {plotting_units_length}')
    ax.set_ylabel(f'Y {plotting_units_length}')
    ax.set_zlabel(f'Z {plotting_units_length}')

    # The limit is the same for all three axes, so convert it only once.
    half_width = _in_plot_units(ax_lim)
    ax.set_xlim(-half_width, half_width)
    ax.set_ylim(-half_width, half_width)
    ax.set_zlim(-half_width, half_width)

    colors = plt.get_cmap("tab10").colors
    for i in range(config.N_particles):
        # Full trajectory of particle i.
        path_xyz = [_in_plot_units(snapshots.states[:, i, 0, k]) for k in range(3)]
        ax.plot(*path_xyz, label=f'Particle {i}', color=colors[i])
        # Marker at the position stored in the last snapshot.
        last_xyz = [_in_plot_units(snapshots.states[-1, i, 0, k]) for k in range(3)]
        ax.scatter(*last_xyz, s=25, marker='o', color=colors[i])
    ax.legend()

    if filename is not None:
        fig.savefig(filename)
    plt.show()
[docs]
def plot_sky_projection(snapshots, code_units, plotting_units_length, filename=None):
    """
    Plot the particles of the last snapshot on an Aitoff sky map.

    Positions are read from ``snapshots.states[-1, :, 0, k]`` (``k`` in 0..2), converted to
    ``plotting_units_length`` and shifted to an observer hard-coded at ``(-8, 0, 0) kpc``.
    The longitude is ``arctan2(y, x)`` and the latitude ``arcsin(z / distance)`` of the
    shifted positions.

    Args:
        snapshots (odisseo.Snapshots): Snapshots object containing the simulation data.
        code_units (odisseo.CodeUnits): CodeUnits object; ``code_length`` is used for the
            length conversion.
        plotting_units_length (astropy.units.Unit): Length unit used for the intermediate
            positions. It must be convertible to kpc because the observer offset is given
            in kpc.
        filename (str, optional): Where to save the figure. If None, the figure is only shown.

    Returns:
        None
    """
    # Cartesian positions of the last snapshot as length quantities.
    x, y, z = (
        (snapshots.states[-1, :, 0, k] * code_units.code_length).to(plotting_units_length)
        for k in range(3)
    )

    # Observer's position at (-8, 0, 0) kpc
    x_obs, y_obs, z_obs = -8 * u.kpc, 0 * u.kpc, 0 * u.kpc

    # Shift to observer's frame
    x_rel = x - x_obs
    y_rel = y - y_obs
    z_rel = z - z_obs

    # Angles as seen from the observer.
    distance = np.sqrt(x_rel**2 + y_rel**2 + z_rel**2)
    lon = np.arctan2(y_rel, x_rel).to(u.deg)
    lat = np.arcsin(z_rel / distance).to(u.deg)

    # Wrap longitude into [-180, 180) so it fits the range of the sky map.
    lon_wrapped = (lon + 180 * u.deg) % (360 * u.deg) - 180 * u.deg

    fig = plt.figure(figsize=(20, 10))
    ax = fig.add_subplot(111, projection='aitoff')
    # Matplotlib's geographic projections take their data in radians (only the tick
    # labels are shown in degrees), so plain radian values are handed over here.
    ax.scatter(lon_wrapped.to(u.rad).value, lat.to(u.rad).value, s=1, color='blue', alpha=0.5)
    ax.set_xlabel("Galactic Longitude l (deg)")
    ax.set_ylabel("Galactic Latitude b (deg)")
    ax.set_title("Sky Projection in Galactic Coordinates")
    ax.grid(True, linestyle="--", alpha=0.5)

    if filename is not None:
        fig.savefig(filename)
    plt.show()
[docs]
def create_3d_gif(snapshots, ax_lim, code_units, plotting_units_length, plot_units_time, filename=None):
    """
    Animate the particle positions in 3D, one frame per snapshot, over a static star backdrop.

    The backdrop points come from ``build_spiral_arms``, ``build_core_stars`` and ``haze``
    (using the module-level ``SCALE`` and ``arms_info``) and are generated once per call.
    Their coordinates are passed to the axes as they are, without any unit conversion.
    Each frame clears the axes, redraws the backdrop and scatters
    ``snapshots.states[frame, :, 0, k]`` (``k`` in 0..2) converted to ``plotting_units_length``.

    Args:
        snapshots (object): Object exposing ``states`` and ``times``.
        ax_lim (float): Half-width of the plotted cube, in code length units.
        code_units (object): Object exposing ``code_length`` and ``code_time``.
        plotting_units_length (astropy.units.Unit): Unit used for the plotted lengths.
        plot_units_time (astropy.units.Unit): Unit used for the time shown in the title.
        filename (str, optional): Where to save the GIF (10 fps, Pillow writer).
            If None, nothing is saved.

    Returns:
        None
    """
    figure = plt.figure(figsize=(10, 10))
    axes = figure.add_subplot(111, projection='3d')
    axes.set_axis_off()

    # Backdrop point lists; generated in this order once and reused for every frame.
    leading_arm, trailing_arm = build_spiral_arms(b=-0.3, arms_info=arms_info)
    core_stars = build_core_stars(SCALE)
    inner_haze_stars = haze(SCALE, r_mult=2, z_mult=0.5, density=5)
    outer_haze_stars = haze(SCALE, r_mult=1, z_mult=0.3, density=5)
    backdrop_layers = (leading_arm, trailing_arm, core_stars, inner_haze_stars, outer_haze_stars)
    color_milky_way = 'w'

    def draw_backdrop():
        # One scatter call per layer, all sharing the same faint-dot styling.
        for layer in backdrop_layers:
            axes.scatter(*zip(*layer), c=color_milky_way, marker='.', s=1, alpha=0.5)

    def label_axes():
        axes.set_xlabel(f'X {plotting_units_length}')
        axes.set_ylabel(f'Y {plotting_units_length}')
        axes.set_zlabel(f'Z {plotting_units_length}')

    draw_backdrop()

    # Empty placeholder artists handed back by init().
    particle_placeholder = axes.scatter([], [], [], c='b')
    center_placeholder = axes.scatter([], [], [], c='r', marker='*')

    def init():
        label_axes()
        return particle_placeholder, center_placeholder

    def update(frame):
        axes.clear()
        draw_backdrop()
        label_axes()

        half_width = (ax_lim * code_units.code_length).to(plotting_units_length).value
        for set_limits in (axes.set_xlim, axes.set_ylim, axes.set_zlim):
            set_limits(-half_width, half_width)

        frame_time = (snapshots.times[frame] * code_units.code_time).to(plot_units_time)
        axes.set_title(f'Time: {frame_time:.2f} ')

        frame_xyz = [
            (snapshots.states[frame, :, 0, component] * code_units.code_length).to(plotting_units_length).value
            for component in range(3)
        ]
        particle_artist = axes.scatter(*frame_xyz, c='b', s=1)
        # No central marker is drawn per frame; the placeholder created before the
        # animation is what gets returned alongside the particle artist.
        return particle_artist, center_placeholder

    # One frame per stored snapshot.
    animation = FuncAnimation(figure, update, frames=range(0, len(snapshots.states), 1), init_func=init, blit=False)

    if filename is not None:
        animation.save(filename, writer=PillowWriter(fps=10))
[docs]
def create_3d_gif_velocitycoding(snapshots, ax_lim, code_units, plotting_units_length, plot_units_time, vmin=None, vmax=None, filename=None):
    """
    Animate the particle positions in 3D, colouring each particle by the norm of its velocity.

    Each frame scatters ``snapshots.states[frame, :, 0, k]`` (``k`` in 0..2) converted to
    ``plotting_units_length`` and colours the points by the norm of
    ``snapshots.states[frame, :, 1]``. The norms are used as stored; no unit conversion is
    applied to them. A red star is drawn at the origin and a colorbar is attached to the
    figure. The colour scale is fixed to ``[vmin, vmax]`` for every frame so that colours
    are comparable across the animation.

    Args:
        snapshots (object): Object exposing ``states`` and ``times``.
        ax_lim (float): Half-width of the plotted cube, in code length units.
        code_units (object): Object exposing ``code_length`` and ``code_time``.
        plotting_units_length (astropy.units.Unit): Unit used for the plotted lengths.
        plot_units_time (astropy.units.Unit): Unit used for the time shown in the title.
        vmin (float, optional): Lower end of the colour scale. If None, the smallest norm
            over all frames and particles is used.
        vmax (float, optional): Upper end of the colour scale. If None, the largest norm
            over all frames and particles is used.
        filename (str, optional): Where to save the GIF (10 fps, Pillow writer).
            If None, nothing is saved.

    Returns:
        None
    """
    # Create a figure for plotting
    fig = plt.figure(figsize=(10, 10))
    ax = fig.add_subplot(111, projection='3d')

    # Fill in whichever colour limit was not provided from the data of all frames.
    if vmin is None or vmax is None:
        all_velocity_norms = jnp.linalg.norm(snapshots.states[:, :, 1], axis=2)
        if vmin is None:
            vmin = float(jnp.min(all_velocity_norms))
        if vmax is None:
            vmax = float(jnp.max(all_velocity_norms))

    # The axis limit does not depend on the frame, so convert it once.
    half_width = (ax_lim * code_units.code_length).to(plotting_units_length).value

    # Colorbar of the previous frame; replaced on every update.
    cbar = None

    # Empty placeholder artists handed back by init().
    scatter1 = ax.scatter([], [], [])
    scatter2 = ax.scatter([], [], [], c='r', marker='*')

    def init():
        ax.set_xlabel(f'X {plotting_units_length}')
        ax.set_ylabel(f'Y {plotting_units_length}')
        ax.set_zlabel(f'Z {plotting_units_length}')
        return scatter1, scatter2

    def update(frame):
        nonlocal cbar
        ax.clear()
        ax.set_xlabel(f'X {plotting_units_length}')
        ax.set_ylabel(f'Y {plotting_units_length}')
        ax.set_zlabel(f'Z {plotting_units_length}')
        ax.set_xlim(-half_width, half_width)
        ax.set_ylim(-half_width, half_width)
        ax.set_zlim(-half_width, half_width)
        ax.set_title(f'Time: {(snapshots.times[frame] * code_units.code_time).to(plot_units_time):.2f} ')

        velocity_norms = jnp.linalg.norm(snapshots.states[frame, :, 1], axis=1)
        particles = ax.scatter(
            (snapshots.states[frame, :, 0, 0] * code_units.code_length).to(plotting_units_length).value,
            (snapshots.states[frame, :, 0, 1] * code_units.code_length).to(plotting_units_length).value,
            (snapshots.states[frame, :, 0, 2] * code_units.code_length).to(plotting_units_length).value,
            c=velocity_norms,
            # Without explicit limits matplotlib rescales the colours to each frame's
            # own min/max, which would make vmin/vmax ineffective.
            vmin=vmin,
            vmax=vmax,
            s=1,
        )
        center = ax.scatter(0, 0, 0, c='r', s=100, marker='*')

        # Replace the previous colorbar so they do not pile up in the figure.
        if cbar is not None:
            cbar.remove()
        cbar = fig.colorbar(particles, ax=ax, shrink=0.8, pad=0.1)
        return particles, center

    # Create the animation, one frame per stored snapshot.
    anim = FuncAnimation(fig, update, frames=range(0, len(snapshots.states), 1), init_func=init, blit=False)

    if filename is not None:
        # Save the animation as a GIF
        anim.save(filename, writer=PillowWriter(fps=10))
[docs]
def create_projection_gif(snapshots, ax_lim, code_units, plotting_units_length, plot_units_time, filename=None):
    """
    Animate three 2D projections (X-Y, X-Z, Y-Z) of the particle positions, one frame per snapshot.

    For every frame the three panels are cleared and redrawn: particle coordinates are read
    from ``snapshots.states[frame, :, 0, k]`` (``k`` in 0..2), converted to
    ``plotting_units_length`` and scattered in blue, with a red star at the origin. The
    figure title shows the snapshot time converted to ``plot_units_time``.

    Args:
        snapshots (object): Object exposing ``states`` and ``times``.
        ax_lim (float): Half-width of each panel, in code length units.
        code_units (object): Object exposing ``code_length`` and ``code_time``.
        plotting_units_length (astropy.units.Unit): Unit used for the plotted lengths.
        plot_units_time (astropy.units.Unit): Unit used for the time shown in the title.
        filename (str, optional): Where to save the GIF (10 fps, Pillow writer).
            If None, nothing is saved.

    Returns:
        None
    """
    figure, panels = plt.subplots(1, 3, figsize=(20, 5))

    # One entry per panel: (horizontal component, vertical component, x label, y label).
    views = ((0, 1, 'X', 'Y'), (0, 2, 'X', 'Z'), (1, 2, 'Y', 'Z'))

    # Empty placeholder artists (particles, centre marker) per panel, returned by init().
    placeholders = []
    for panel in panels:
        placeholders.append(panel.scatter([], [], c='b'))
        placeholders.append(panel.scatter([], [], c='r', marker='*'))
    placeholders = tuple(placeholders)

    def half_width():
        return (ax_lim * code_units.code_length).to(plotting_units_length).value

    def init():
        for panel in panels:
            panel.set_xlim(-half_width(), half_width())
            panel.set_ylim(-half_width(), half_width())
        for panel, (_, _, x_name, y_name) in zip(panels, views):
            panel.set_xlabel(f'{x_name} {plotting_units_length}')
            panel.set_ylabel(f'{y_name} {plotting_units_length}')
        return placeholders

    def update(frame):
        figure.suptitle(f'Time: {(snapshots.times[frame]*code_units.code_time).to(plot_units_time):.2f}')
        drawn = []
        for panel, (horizontal, vertical, x_name, y_name) in zip(panels, views):
            panel.clear()
            panel.set_xlim(-half_width(), half_width())
            panel.set_ylim(-half_width(), half_width())
            panel.set_xlabel(f'{x_name} {plotting_units_length}')
            panel.set_ylabel(f'{y_name} {plotting_units_length}')
            panel.grid(linestyle='dotted')
            drawn.append(panel.scatter(
                (snapshots.states[frame, :, 0, horizontal] * code_units.code_length).to(plotting_units_length).value,
                (snapshots.states[frame, :, 0, vertical] * code_units.code_length).to(plotting_units_length).value,
                c='b', s=1))
            drawn.append(panel.scatter(0, 0, c='r', s=100, marker='*'))
        return tuple(drawn)

    # One frame per stored snapshot.
    animation = FuncAnimation(figure, update, frames=range(0, len(snapshots.states), 1), init_func=init, blit=False)

    if filename is not None:
        animation.save(filename, writer=PillowWriter(fps=10))
# plt.style.use('dark_background')
# Radius of the galactic disc (scaling factor) used for the backdrop star points:
# it is passed to build_core_stars() and haze(), and is the arm scale in arms_info.
# NOTE(review): the inline hint recommends 200 - 700 but the value in use is 26.
# With r = 26, build_spiral_stars() gets int(0.030 * 26) == 0, i.e. no positional
# fuzz is applied to the arms. Confirm which range is intended.
SCALE = 26 # Use range of 200 - 700.
[docs]
def build_spiral_stars(b, r, rot_fac, fuz_fac):
    """Return a list of 1000 (x, y, z) points along a logarithmic spiral.

    Point ``i`` sits at angle ``theta = radians(i)`` with radius ``r * exp(b * theta)``,
    rotated by ``pi * rot_fac``. A random integer in ``[-fuzz, fuzz]`` times ``fuz_fac`` is
    subtracted from x and from y, where ``fuzz = int(0.030 * abs(r))``. z is drawn
    uniformly from ``+/- SCALE / (SCALE * 3)`` using the module-level ``SCALE``.

    Args:
        b: Constant for spiral direction and "openness".
        r: Scale factor (galactic disc radius); its sign is kept in the radius.
        rot_fac: Factor to rotate the spiral arm, in units of pi.
        fuz_fac: Multiplier applied to the random integer shift of x and y.

    Returns:
        list: 1000 ``(x, y, z)`` tuples.
    """
    jitter = int(0.030 * abs(r))  # Scalable amount to shift locations.
    z_half_range = SCALE / (SCALE * 3)
    rotation = math.pi * rot_fac

    stars = []
    for degree in range(1000):
        theta = math.radians(degree)
        radial = r * math.exp(b * theta)
        angle = theta - rotation
        # Draw order (x shift, y shift, z) matters for reproducibility under a seed.
        x = radial * math.cos(angle) - randint(-jitter, jitter) * fuz_fac
        y = radial * math.sin(angle) - randint(-jitter, jitter) * fuz_fac
        z = uniform(-z_half_range, z_half_range)
        stars.append((x, y, z))
    return stars
# Scale factor, rotation factor, and fuzz factor for each spiral arm, in that order:
# build_spiral_arms() passes entry[0] as r, entry[1] as rot_fac and entry[2] as fuz_fac
# to build_spiral_stars().
# Entries come in pairs (one pair per line below). build_spiral_arms() collects the
# entries at even positions into its "trailing" list and those at odd positions into
# its "leading" list.
arms_info = [(SCALE, 1, 1.5), (SCALE, 0.91, 1.5),
(-SCALE, 1, 1.5), (-SCALE, -1.09, 1.5),
(-SCALE, 0.5, 1.5), (-SCALE, 0.4, 1.5),
(-SCALE, -0.5, 1.5), (-SCALE, -0.6, 1.5)]
[docs]
def build_spiral_arms(b, arms_info):
    """Build the spiral-arm points and split them into leading and trailing lists.

    Every entry of ``arms_info`` is turned into one arm via ``build_spiral_stars``. Arms at
    odd positions in ``arms_info`` are appended to the leading list, arms at even positions
    to the trailing list.

    Args:
        b: Constant for spiral direction and "openness", forwarded to every arm.
        arms_info: Sequence of entries whose items 0, 1 and 2 are the scale, rotation
            and fuzz factors of one arm.

    Returns:
        tuple: ``(leading_arms, trailing_arms)``, two flat lists of points.
    """
    trailing_points = []
    leading_points = []
    # Indexed by position parity: 0 -> trailing, 1 -> leading.
    by_parity = (trailing_points, leading_points)

    for position, arm_info in enumerate(arms_info):
        by_parity[position % 2].extend(
            build_spiral_stars(b=b, r=arm_info[0], rot_fac=arm_info[1], fuz_fac=arm_info[2])
        )
    return leading_points, trailing_points
[docs]
def spherical_coords(num_pts, radius):
    """Return a list of random 3D points clustered around the origin.

    Each point is drawn from a standard normal distribution per coordinate and multiplied
    by ``radius``; the third coordinate is then additionally multiplied by 0.02, which
    flattens the cloud along z.

    Args:
        num_pts: Number of points to draw.
        radius: Multiplier applied to every coordinate of every point.

    Returns:
        list: ``num_pts`` lists of three values each.
    """
    def _draw_point():
        point = np.random.normal(0, 1, 3)
        point *= radius
        point[2] *= 0.02  # Reduce z range for matplotlib default z-scale.
        return list(point)

    return [_draw_point() for _ in range(num_pts)]
[docs]
def build_core_stars(scale_factor):
    """Return one list of point coordinates for the galactic core.

    Two clouds from ``spherical_coords`` are concatenated: 3000 points with spread
    ``scale_factor / 15``, followed by 750 points with that spread divided by 2.5.

    Args:
        scale_factor: Galactic disc radius the core spread is derived from.

    Returns:
        list: The outer-cloud points followed by the inner-cloud points.
    """
    rim_count = 3000
    spread = scale_factor / 15
    # (number of points, spread) for the outer cloud, then the denser inner one.
    clouds = ((rim_count, spread), (int(rim_count / 4), spread / 2.5))

    stars = []
    for count, cloud_spread in clouds:
        stars.extend(spherical_coords(count, cloud_spread))
    return stars
[docs]
def haze(scale_factor, r_mult, z_mult, density):
    """Generate random (x, y, z) points scattered over a disc.

    For each point a radial fraction ``sqrt(random())`` and an angle in ``[0, 2*pi)`` are
    drawn; x and y are the resulting disc coordinates times ``scale_factor``, rounded to
    the nearest integer and then divided by ``r_mult``. z is uniform in ``[-1, 1)`` times
    ``z_mult``.

    Args:
        scale_factor: Galactic disc radius; also sets the point count together with density.
        r_mult: Divisor applied to the rounded x and y values.
        z_mult: Scalar for z values.
        density: Multiplier for the number of points; ``scale_factor * density`` are made.

    Returns:
        list: ``scale_factor * density`` tuples ``(x, y, z)``.
    """
    def _sample():
        # Draw order (radial fraction, angle, z) matters for reproducibility under a seed.
        radial_fraction = math.sqrt(random())
        angle = uniform(0, 2 * math.pi)
        px = round(radial_fraction * math.cos(angle) * scale_factor) / r_mult
        py = round(radial_fraction * math.sin(angle) * scale_factor) / r_mult
        pz = np.random.uniform(-1, 1) * z_mult
        return (px, py, pz)

    return [_sample() for _ in range(0, scale_factor * density)]