import numpy as np
import matplotlib.pyplot as plt
import tools21cm as t2c
from ..structs import TemporalCube, Parameters
from ..structs.statistics import StatisticsEstimator
# ---------------------------------------------------------------------------
# Internal helpers
# ---------------------------------------------------------------------------
def _mean(grid_or_stats, field: str, parameters: Parameters | None = None):
    """Return ``(z, values)`` for *field* from either a TemporalCube or a StatisticsEstimator.

    Args:
        grid_or_stats (TemporalCube | StatisticsEstimator): Data source.
        field (str): Field name, either short (``'dTb'``, ``'xal'``, ...) or
            with the TemporalCube ``Grid_`` prefix (``'Grid_dTb'``). Both forms
            are accepted for both kinds of data source.
        parameters (Parameters, optional): Currently unused; kept for
            signature symmetry with the other helpers.

    Returns:
        tuple: ``(z, values)``, the redshifts and the global mean of *field*.
    """
    # Normalize so callers may pass either 'dTb' or 'Grid_dTb'.
    short_name = field.removeprefix('Grid_')
    if isinstance(grid_or_stats, StatisticsEstimator):
        r = grid_or_stats.results
        return r['z'], r[f'mean_{short_name}']
    grid = grid_or_stats
    # TemporalCube fields are addressed with the 'Grid_' prefix elsewhere in
    # this module (e.g. global_mean('Grid_dTb'), grid.Grid_dTb), so use the
    # same spelling here instead of the bare short name.
    return grid.z[:], grid.global_mean(f'Grid_{short_name}')
def _ps_dTb(grid_or_stats, parameters: Parameters | None, k_value=None, k_index=1):
    """Collect the dTb power spectrum as ``(z, k_bins, Delta^2, mean_dTb)``.

    ``k_value`` and ``k_index`` are accepted for call-site symmetry only and
    are not used here.

    For a StatisticsEstimator the precomputed results are passed through
    unchanged; its stored ``ps_dTb`` is already Delta^2. For a TemporalCube
    the spectrum is computed with ``power_spectrum`` and converted to
    Delta^2 = P(k) k^3 <dTb>^2 / (2 pi^2), with one row per redshift.
    """
    if not isinstance(grid_or_stats, StatisticsEstimator):
        cube = grid_or_stats
        redshifts = cube.z[:]
        mean_dTb = cube.global_mean('Grid_dTb')
        raw_ps, k_bins = cube.power_spectrum(cube.Grid_dTb, parameters)
        # Broadcast k along columns and <dTb> along rows.
        k_cubed = k_bins[np.newaxis, :] ** 3
        mean_squared = mean_dTb[:, np.newaxis] ** 2
        delta_sq = raw_ps * k_cubed * mean_squared / (2 * np.pi ** 2)
        return redshifts, k_bins, delta_sq, mean_dTb

    stats = grid_or_stats.results
    return stats['z'], stats['k'], stats['ps_dTb'], stats['mean_dTb']
# ---------------------------------------------------------------------------
# Public draw functions
# ---------------------------------------------------------------------------
[docs]
def draw_dTb_signal(ax: plt.Axes, grid, label=None, color=None, **kwargs):
    """Draw the globally averaged differential brightness temperature versus redshift.

    Args:
        ax (matplotlib.axes.Axes): Target axes.
        grid (TemporalCube | StatisticsEstimator): Source of the mean dTb history.
        label (str, optional): Text for the legend entry.
        color (str|tuple, optional): Colour of the curve.
        **kwargs: Passed through unchanged to ``ax.plot``.
    """
    redshifts, dtb_history = _mean(grid, 'dTb')
    ax.plot(redshifts, dtb_history, color=color, label=label, **kwargs)
    # Leave a small margin on the low-redshift side of the axis.
    z_lo = redshifts.min() - 0.2
    z_hi = redshifts.max()
    ax.set_xlim(z_lo, z_hi)
    ax.set_xlabel('z')
    ax.set_ylabel(r'$dT_b$ [mK]')
[docs]
def draw_x_alpha_signal(ax: plt.Axes, grid, label=None, color=None, **kwargs):
    """Draw the mean Lyman-alpha coupling coefficient versus redshift on a log y-axis.

    Args:
        ax (matplotlib.axes.Axes): Target axes.
        grid (TemporalCube | StatisticsEstimator): Source of the mean x_alpha history.
        label (str, optional): Text for the legend entry.
        color (str|tuple, optional): Colour of the curve.
        **kwargs: Passed through unchanged to ``ax.semilogy``.
    """
    redshifts, coupling = _mean(grid, 'xal')
    ax.semilogy(redshifts, coupling, color=color, label=label, **kwargs)
    # Leave a small margin on the low-redshift side of the axis.
    z_lo = redshifts.min() - 0.2
    z_hi = redshifts.max()
    ax.set_xlim(z_lo, z_hi)
    ax.set_xlabel('z')
    ax.set_ylabel(r'$x_\alpha$')
[docs]
def draw_Temp_signal(ax: plt.Axes, grid, label=None, color=None, **kwargs):
    """Draw the mean gas kinetic temperature versus redshift on a log y-axis.

    Args:
        ax (matplotlib.axes.Axes): Target axes.
        grid (TemporalCube | StatisticsEstimator): Source of the mean T_k history.
        label (str, optional): Text for the legend entry.
        color (str|tuple, optional): Colour of the curve.
        **kwargs: Passed through unchanged to ``ax.semilogy``.
    """
    redshifts, kinetic_temp = _mean(grid, 'Temp')
    ax.semilogy(redshifts, kinetic_temp, color=color, label=label, **kwargs)
    # Leave a small margin on the low-redshift side of the axis.
    z_lo = redshifts.min() - 0.2
    z_hi = redshifts.max()
    ax.set_xlim(z_lo, z_hi)
    ax.set_ylabel(r'$T_{k}$ [K]')
    ax.set_xlabel('z')
[docs]
def draw_xHII_signal(ax: plt.Axes, grid, label=None, color=None, **kwargs):
    """Draw the mean ionized hydrogen fraction versus redshift on a linear y-axis.

    Args:
        ax (matplotlib.axes.Axes): Target axes.
        grid (TemporalCube | StatisticsEstimator): Source of the mean x_HII history.
        label (str, optional): Text for the legend entry.
        color (str|tuple, optional): Colour of the curve.
        **kwargs: Passed through unchanged to ``ax.plot``.
    """
    redshifts, ionized_fraction = _mean(grid, 'xHII')
    ax.plot(redshifts, ionized_fraction, color=color, label=label, **kwargs)
    # Leave a small margin on the low-redshift side of the axis.
    z_lo = redshifts.min() - 0.2
    z_hi = redshifts.max()
    ax.set_xlim(z_lo, z_hi)
    ax.set_ylabel(r'$x_{\mathrm{HII}}$')
    ax.set_xlabel('z')
[docs]
def draw_dTb_power_spectrum_of_z(ax: plt.Axes, grid, parameters: Parameters = None, label=None, color=None, k_index=1, k_value=None, **kwargs):
    """Plot the evolution of the dTb power spectrum at a fixed k.

    Obtains the 21-cm power spectrum Delta^2_21(k, z) (in mK^2) for every
    snapshot and plots the column of the requested k-bin as a function of
    redshift.

    Args:
        ax (matplotlib.axes.Axes): Axis to draw on.
        grid (TemporalCube | StatisticsEstimator): Data source.
        parameters (Parameters, optional): Required when *grid* is a
            :class:`~beorn.structs.TemporalCube`; not needed for
            :class:`~beorn.structs.statistics.StatisticsEstimator`.
        label (str, optional): Legend label.
        color (str|tuple, optional): Line color.
        k_index (int, optional): Index of the k-bin to plot (ignored when *k_value* is set).
        k_value (float, optional): Target k value in Mpc^-1; the nearest finite bin is used.
        **kwargs: Additional keyword arguments forwarded to ``ax.semilogy``.

    Returns:
        float: The k value actually plotted.
    """
    z_range, k_bins, ps_c, _ = _ps_dTb(grid, parameters, k_value=k_value, k_index=k_index)
    k_bins = np.asarray(k_bins)
    if k_value is None:
        ki = k_index
    else:
        # nanargmin rather than argmin: if any k-bin is NaN (e.g. a bin with no
        # modes), argmin would return that NaN position instead of the nearest bin.
        ki = int(np.nanargmin(np.abs(k_bins - k_value)))
    k = k_bins[ki]
    ax.semilogy(z_range, ps_c[:, ki], label=label, color=color, **kwargs)
    ax.set_ylim(1e-1, 1e3)
    ax.set_ylabel(r'$\Delta_\mathrm{21}^{{2}}$ [mK]$^2$')
    ax.set_xlabel('z')
    ax.set_xlim(z_range.min() - 0.2, z_range.max())
    print(f'k={k:.2f} Mpc$^{{-1}}$')
    return k
[docs]
def draw_dTb_power_spectrum_of_k(ax: plt.Axes, grid: TemporalCube, parameters: Parameters, z_index=1, z_value=None, label=None, color=None):
    """Plot the dTb power spectrum as a function of k at a given z.

    The snapshot is turned into a contrast field (dTb / <dTb> - 1) and its
    spherically averaged spectrum is computed with tools21cm. The spectrum is
    then converted to Delta^2 = P(k) k^3 <dTb>^2 / (2 pi^2) in mK^2.

    Args:
        ax (matplotlib.axes.Axes): Axis to draw on.
        grid (TemporalCube): Temporal cube providing ``Grid_dTb`` and ``z``.
        parameters (Parameters): Simulation parameters; ``simulation.kbins.size``
            sets the number of k-bins and ``simulation.Lbox`` the box size.
        z_index (int): Index of the redshift slice to analyse (ignored when
            *z_value* is set).
        z_value (float, optional): Target redshift; the nearest available
            snapshot is used for both the data and the label.
        label (str, optional): Legend label; the redshift is appended to it.
        color (str, optional): Color used for plotted lines.

    Returns:
        float: The redshift of the snapshot actually plotted.
    """
    redshifts = np.asarray(grid.z[:])
    if z_value is not None:
        # Snap to the nearest snapshot and use its index for the data as well,
        # so the plotted field matches the reported redshift.
        z_index = int(np.abs(redshifts - z_value).argmin())
    z = redshifts[z_index]
    current_grid = grid.Grid_dTb[z_index, ...]
    mean_dtb = np.mean(current_grid)
    delta_quantity = current_grid / mean_dtb - 1
    bin_number = parameters.simulation.kbins.size
    box_dims = parameters.simulation.Lbox
    ps, bins = t2c.power_spectrum.power_spectrum_1d(delta_quantity, box_dims=box_dims, kbins=bin_number)
    ps_c = ps * bins ** 3 * mean_dtb ** 2 / (2 * np.pi ** 2)
    # Avoid a literal "None (z=...)" legend entry when no label is given.
    line_label = f"z={z:.2f}" if label is None else f"{label} (z={z:.2f})"
    ax.semilogy(bins, ps_c, ls='-', label=line_label, color=color)
    ax.set_ylim(1e-1, 1e3)
    ax.set_ylabel(r'$\Delta_\mathrm{21}^{{2}}$ [mK]$^2$')
    ax.set_xlabel('k [cMpc$^{-1}$]')
    print(f'z={z:.2f}')
    return z
[docs]
def full_diff_plot(fig: plt.Figure, grid: TemporalCube, baseline_grid: TemporalCube = None, label: str = None, color: str = None):
    """Create a multi-panel comparison plot of global quantities.

    The top row (axes 0-3) shows x_alpha, T_k, x_HII and dTb for ``grid``.
    If ``baseline_grid`` is provided, the bottom row (axes 4-7) shows the
    fractional deviation (grid - baseline) / baseline of the same quantities.

    Args:
        fig (matplotlib.figure.Figure): Figure to draw on. If it has no axes,
            a 2x4 grid sharing the x-axis is created; otherwise its existing
            axes are reused (at least 8 are needed when comparing).
        grid (TemporalCube): Primary temporal cube to visualise.
        baseline_grid (TemporalCube, optional): Reference temporal cube for fractional deviation plots.
        label (str, optional): Label applied to plotted lines.
        color (str, optional): Color used for plotted lines.
    """
    if fig.axes:
        axs = fig.axes
    else:
        axs = fig.subplots(2, 4, sharex=True).flatten()

    draw_x_alpha_signal(axs[0], grid, label=label, color=color)
    draw_Temp_signal(axs[1], grid, label=label, color=color)
    draw_xHII_signal(axs[2], grid, label=label, color=color)
    draw_dTb_signal(axs[3], grid, label=label, color=color)

    if baseline_grid is None:
        return
    # Identity check: `==` on objects holding arrays can be elementwise or
    # ambiguous, and the intent is only to skip comparing a cube with itself.
    if baseline_grid is grid:
        print("Not comparing baseline grid to itself.")
        return

    z = grid.z[:]
    for ax, field, ylabel in [
        (axs[4], 'Grid_xal', r'$\Delta x_\alpha$ / $x_\alpha$'),
        (axs[5], 'Grid_Temp', r'$\Delta T_k$ / $T_k$'),
        (axs[6], 'Grid_xHII', r'$\Delta x_{\mathrm{HII}}$ / $x_{\mathrm{HII}}$'),
        (axs[7], 'Grid_dTb', r'$\Delta dT_b$ / $dT_b$'),
    ]:
        grid_value = grid.global_mean(field)
        baseline_value = baseline_grid.global_mean(field)
        # The baseline mean may be exactly zero at some redshifts; those points
        # become inf/nan without emitting divide-by-zero RuntimeWarnings.
        with np.errstate(divide='ignore', invalid='ignore'):
            deviation = (grid_value - baseline_value) / baseline_value
        ax.plot(z, deviation, color=color, label=label)
        ax.set_xlabel('z')
        ax.set_ylabel(ylabel)