Source code for pisces.extensions.simulation.core.initial_conditions

"""
Core classes for facilitating simulation initial conditions.

This module defines the :class:`InitialConditions` class, which serves as a base for
all simulation initial conditions in Pisces. It allows users to insert models into
a 3D box, providing a framework for defining and manipulating initial conditions
in astrophysical simulations.

"""

import datetime
import shutil
from abc import ABC, abstractmethod
from collections.abc import Sequence
from pathlib import Path
from typing import TYPE_CHECKING, Any, Optional, Union

import numpy as np
import unyt

from pisces.models.core.base import BaseModel
from pisces.models.core.utils import inspect_model_grid, load_model
from pisces.utilities.config import ConfigManager
from pisces.utilities.io_tools import unyt_yaml
from pisces.utilities.log import LogDescriptor

if TYPE_CHECKING:
    from logging import Logger

    from pisces.geometry.grids.base import Grid
    from pisces.particles.base import ParticleDataset


class InitialConditions(ABC):
    """
    Central base class for configuring simulation initial conditions.

    This class provides the backbone of simulation initial conditions (ICs) in Pisces.
    It effectively provides a wrapper by which to place multiple models into a
    cartesian space with various orientations, velocities, etc. The initial conditions
    can then be processed by the simulation frontends to generate input files for
    various astrophysical simulation codes.

    For more detailed information about using and interacting with initial conditions,
    read :ref:`initial_conditions_overview`.
    """

    # ============================== #
    # Class Attributes               #
    # ============================== #
    # The following attributes are defined at the class level and provide
    # shared, configurable resources for all instances of InitialConditions.
    #
    # Subclasses can override these attributes to customize behavior without
    # rewriting core logic. For example:
    # - Replace `__YAML__` to change serialization format or settings
    # - Replace `logger` to use a different logging configuration or target
    #
    # These are not intended to be mutated at the instance level.
    __YAML__ = unyt_yaml
    """
    The YAML (de)serialization utility used to read/write the ``IC_CONFIG.yaml`` file.

    By default, this is set to :data:`unyt_yaml`, which is a YAML representer/dumper
    configured to correctly handle `unyt` array types (units, quantities, etc.).

    Subclasses may override this attribute to:

    - Use a different serialization backend
    - Provide custom YAML formatting options
    - Support additional non-standard Python types
    """

    logger: "Logger" = LogDescriptor(mode="ics")
    """
    Logger instance for the InitialConditions class.

    This descriptor provides a class-scoped logger configured with the ``"ics"`` mode.
    It should be used for all diagnostic, info, warning, and error output within the
    class.

    Subclasses may override this attribute to:

    - Use a different logging category or name
    - Redirect output to a custom logging handler

    Settings for the logger may be adjusted in the pisces configuration file under the
    ``ics`` section of ``logging``.
    """

    _model_file_extensions: dict[str, dict[str, str]] = {"path": {"extension": ""}, "particles": {"extension": "_p"}}
    """
    Standard file extensions for model-related files.

    This dictionary maps keys used in the model configuration (e.g., ``"path"``,
    ``"particles"``) to their corresponding file extensions. The extensions are used when
    copying or moving model files into the initial conditions directory to avoid
    filename collisions.
    """

    # ============================== #
    # Class Flags                    #
    # ============================== #
    # These flags are easily modified settings that are used throughout
    # the base class and should be easily accessible for modification
    # by subclasses.

    # Keys (in addition to "path") that every model entry must contain. They are checked
    # by ``_validate_model`` and may be changed through ``update_model``.
    _model_metadata_required_keys = []
    # Optional per-model keys that ``update_model`` is also allowed to set.
    _model_metadata_allowed_keys = ["particles"]
    # Number of spatial dimensions, exposed through the ``ndim`` property.
    _ndim = 3

    # ============================== #
    # Initialization Methods         #
    # ============================== #
    def _validate_configuration(self, *_, **__):
        """
        Validate the loaded IC configuration for structural and metadata consistency.

        The base implementation performs two checks:

        1. Both the ``metadata`` and the ``models`` sections are present in the
           configuration.
        2. The ``class_name`` stored in ``metadata`` matches the name of the current class.
           This guards against loading ICs that were written by an incompatible class.

        Individual model entries (including whether their files exist) are *not* checked
        here. ``__init__`` validates each of them separately through
        :meth:`_validate_model`.

        Subclasses
        ----------
        Subclasses can override this method to add more specialized validation rules,
        such as checking for specific metadata keys or required physical parameters.
        If overriding, consider calling ``super()._validate_configuration(*args, **kwargs)``
        to preserve the base checks.

        Raises
        ------
        ValueError
            If the ``metadata`` or ``models`` section is missing, or if the stored
            ``class_name`` does not match the current class name.
        """
        self.logger.debug("Attempting configuration validation of %s.", self)

        # Both top-level sections must exist before anything else can be inspected.
        for section in ("metadata", "models"):
            if section not in self.__config__:
                raise ValueError(f"Configuration file is missing '{section}' section.")

        # Metadata checks. Subclasses that add metadata of their own will typically
        # extend this part; the base class only verifies the class name.
        metadata = dict(self.__config__["metadata"])
        class_name = metadata.get("class_name", None)
        if class_name != self.__class__.__name__:
            raise ValueError(
                f"Configuration class name '{class_name}' does not match the current class '{self.__class__.__name__}'."
            )

        self.logger.debug("%s passed configuration validation.", self)

    @classmethod
    @abstractmethod
    def _validate_model(cls, model_name: str, model_info: dict) -> None:
        """
        Validate a single entry of the ``models`` configuration section.

        The base implementation checks that every required key is present. Required keys
        are ``"path"`` plus :attr:`_model_metadata_required_keys`. It then checks that the
        model file referenced by ``"path"`` exists on disk. The method is declared
        abstract, so concrete subclasses must implement it. They may call
        ``super()._validate_model(model_name, model_info)`` to reuse these checks.

        Parameters
        ----------
        model_name : str
            Name of the model entry (used in error messages).
        model_info : dict
            The model's configuration entry.

        Raises
        ------
        ValueError
            If a required key is missing from ``model_info``.
        FileNotFoundError
            If the model file referenced by ``"path"`` does not exist.
        """
        # Check required keys first, so that a missing "path" produces the descriptive
        # ValueError below rather than a bare KeyError from the lookup. Sorting makes
        # the reported key deterministic when several are missing.
        required_keys = set(cls._model_metadata_required_keys) | {"path"}
        for key in sorted(required_keys):
            if key not in model_info:
                raise ValueError(f"Model '{model_name}' is missing required key '{key}'.")

        model_path = Path(model_info["path"])
        if not model_path.exists():
            raise FileNotFoundError(f"Model file for '{model_name}' not found: {model_path}")
[docs] def __init__(self, directory: Union[str, Path]): """ Load an existing set of initial conditions from a target directory. This base initializer is responsible for: 1. Normalizing and storing the target directory path. 2. Verifying that the directory exists and is a valid directory on disk. 3. Locating and loading the ``IC_CONFIG.yaml`` file associated with the initial condition set. 4. Initializing the internal :class:`~pisces.utilities.config.ConfigManager` for accessing and modifying configuration data. 5. Running the default configuration validation via ``_validate_configuration``. Parameters ---------- directory : str or pathlib.Path Filesystem path to an **existing** initial condition directory that contains a valid ``IC_CONFIG.yaml`` file. Raises ------ FileNotFoundError If the provided directory does not exist or is not a directory. FileNotFoundError If no ``IC_CONFIG.yaml`` file is found in the specified directory. ValueError If the configuration file exists but is structurally invalid. Notes ----- - Subclasses overriding ``__init__`` should call ``super().__init__(directory)`` to ensure the configuration is loaded and validated before performing subclass-specific setup. - The loaded configuration is stored in :attr:`config`, a :class:`~pisces.utilities.config.ConfigManager` instance, and the absolute directory path is available as :attr:`directory`. """ # Normalize to a Path object for consistent behavior # this ensures that we can use all of the relevant Path methods. self.__directory__ = Path(directory) self.logger.info("Loading %s object from directory: %s", self.__class__.__name__, self.__directory__.absolute()) # Verify that the target directory exists and is actually a directory. # If we fail to find the directory, we raise an error. if not self.__directory__.exists(): raise FileNotFoundError(f"Directory '{self.__directory__}' does not exist.") # Locate the configuration file and ensure that it exists before proceeding. 
# We will additionally check the consistency of the configuration file contents # during the following step in the initialization process. config_path = self.__directory__ / "IC_CONFIG.yaml" if not config_path.exists(): raise FileNotFoundError(f"Configuration file '{config_path}' does not exist in the directory.") self.__config__ = ConfigManager(config_path, autosave=True) # With the configuration manager in place, we proceed to the validation # step. This passes off to ``._validate_configuration``, which could be altered # in subclasses, but essentially just checks that everything is in order. self._validate_configuration() # We now validate the models. for model_name, model_info in dict(self.__config__["models"]).items(): self.logger.debug("Validating model '%s'...", model_name) self._validate_model(model_name, model_info)
# ============================== # # Properties # # ============================== # @property def directory(self) -> Path: """ The directory where the initial conditions are stored. Returns ------- ~pathlib.Path The directory path as a `Path` object. """ return self.__directory__ @property def config(self) -> ConfigManager: """ The configuration manager for the initial conditions. Returns ------- ~pisces.utilities.config.ConfigManager The configuration manager instance containing the initial conditions data. """ return self.__config__ @property def models(self) -> dict[str, dict]: """ The processed models used in the initial conditions. Returns ------- dict A dictionary mapping model names to their properties. """ return dict(self.__config__["models"]) @property def particles(self) -> dict[str, Union[Path, None]]: """ The particle files associated with the models in the initial conditions. Returns ------- dict A dictionary mapping model names to their particle file paths as `Path`. If no particles are associated, the value is `None`. """ return {name: Path(info.get("particles", None)) for name, info in self.models.items()} @property def metadata(self) -> dict: """ The metadata associated with the initial conditions. Returns ------- dict A dictionary containing metadata information. """ return dict(self.__config__["metadata"]) @property def ndim(self) -> int: """ The number of spatial dimensions for the initial conditions. Returns ------- int The number of spatial dimensions (e.g., 3 for 3D). 
""" return self.__class__._ndim # ============================== # # MAGIC Methods # # ============================== # def __repr__(self) -> str: return f"<{self.__class__.__name__}(dir='{self.__directory__}', models={len(self.models)})>" def __str__(self) -> str: """Human-readable summary of the InitialConditions object.""" return f"{self.__class__.__name__}({self.__directory__.name})" def __len__(self) -> int: """Get the number of models in this initial condition set.""" return len(self.models) def __getitem__(self, model_name: str) -> dict: """ Retrieve a model’s configuration by name. Parameters ---------- model_name : str The model identifier. Returns ------- dict Dictionary of the model’s configuration. Raises ------ KeyError If no model with the given name exists. """ try: return self.models[model_name] except KeyError as exp: raise KeyError(f"Model '{model_name}' not found in initial conditions.") from exp def __contains__(self, model_name: str) -> bool: """Check if a model with the given name exists in this initial condition set.""" return model_name in self.models def __iter__(self): """Iterate over the names of all models in the initial conditions.""" return iter(self.models) # ============================== # # Methods - Model Interaction # # ============================== #
[docs] def load_model(self, name: str) -> BaseModel: """ Load a specific model from the initial conditions. This method retrieves the file path for the given model name from the internal configuration, verifies that the file exists, and uses :func:`load_model` to instantiate and return the model. Parameters ---------- name : str The name/identifier of the model to load. Must match one of the keys in :attr:`models`. Returns ------- ~pisces.models.core.base.BaseModel An instance of the loaded model. Raises ------ KeyError If no model with the given ``name`` exists in the configuration. FileNotFoundError If the referenced model file does not exist on disk. """ models = self.__config__.get("models", {}) if name not in models: raise KeyError(f"No model named '{name}' found in initial conditions.") model_path = Path(models[name]["path"]) if not model_path.exists(): raise FileNotFoundError(f"Model file not found for '{name}': {model_path}") return load_model(model_path)
[docs] def remove_model(self, name: str) -> None: """ Remove a model and its associated files from the initial conditions. This method: 1. Verifies that the specified model exists in the configuration. 2. Deletes the model's main file and optional particle file from disk (if they exist). 3. Removes the model entry from the configuration. Parameters ---------- name : str The name/identifier of the model to remove. Must match one of the keys in :attr:`models`. Raises ------ KeyError If no model with the given ``name`` exists in the configuration. """ models = self.__config__.get("models", {}) if name not in models: raise KeyError(f"No model named '{name}' found in initial conditions.") model_info = models[name] # Remove associated files if they exist for key in self._model_file_extensions: if key in model_info and model_info[key] is not None: try: Path(model_info[key]).unlink() except FileNotFoundError: pass # File already gone; no issue # Remove from configuration del self.__config__[f"models.{name}"] self.logger.info(f"Removed model '{name}' from initial conditions.")
[docs] def add_model( self, name: str, model: Union[str, Path, BaseModel], model_config: dict[str, Any], file_processing_mode: str = "copy", overwrite: bool = False, **kwargs, ): """ Add a new model to the initial conditions set. This method can be used to incorporate a new model into the initial conditions. In doing so, the model is validated to ensure that it satisfies the conditions for inclusion in the ICs and then is added in the same way that the original dataset was generated. This includes copying or moving the model file into the IC directory and updating the configuration file. Parameters ---------- name : str The name of the model being added. model : str, ~pathlib.Path or ~pisces.models.core.base.BaseModel The model to add, specified as either: - **Path-like (str or Path)**: Path to the model file on disk. The file will be copied or moved into the initial conditions directory according to ``file_processing_mode``. - **BaseModel instance**: An already loaded model object. Its source file path will be obtained from the model's metadata. In both cases, the model file must exist and be accessible before calling this method. model_config : tuple A tuple specifying the model configuration settings. This may vary from code to code and should be documented in the subclass. Commonly, this includes parameters such as: - ``position`` : unyt array or sequence of length ``ndim`` - ``velocity`` : unyt array or sequence of length ``ndim`` - ``orientation`` : sequence or array defining orientation vector - ``spin`` : scalar float file_processing_mode : {"copy", "move"}, default="copy" Determines how the provided model file is placed into the initial conditions directory: - ``"copy"``: The source file is copied into the IC directory, preserving the original file in its current location. - ``"move"``: The source file is moved into the IC directory, removing it from its original location. 
overwrite : bool, optional If ``True``, an existing model entry with the same ``name`` will be replaced, and any associated files will be deleted or overwritten as needed. If ``False`` (default), attempting to add a model with a duplicate name will raise a :class:`ValueError`. kwargs: Additional keyword arguments passed to the model processing function. This may include options specific to the subclass or model type. """ # Ensure that the model name does not already exist. # If it does, we need tell the user to manually remove it first. if name in self.models: if not overwrite: raise ValueError( f"Model '{name}' already exists. Set `overwrite=True` to replace it or remove it first." ) else: self.logger.warning(f"Overwriting existing model '{name}'.") self.remove_model(name) # --- Process the new model --- # # This mirrors the process we go through for # the initialization of models when creating the class. model_config = dict(model_config) model_config["model_name"] = name model_config["model"] = model _validated_model = self.__class__._validate_input_model(model, self.models, **kwargs) _validated_model_name = _validated_model.pop("model_name") # Now we need to ensure that the model gets either copied or moved # into the directory as needed. This is done with the ``_process_model`` # method. _validated_model = self.__class__._process_model( _validated_model_name, _validated_model, file_processing_mode=file_processing_mode, **kwargs ) # Add the post-validation model to the dictionary of ready-to-go # models. self.__config__["models"][_validated_model_name] = _validated_model self.logger.info(f"Added model '{name}' to initial conditions.")
[docs] def list_models(self): """ List all models currently stored in these initial conditions. Returns ------- list A list of model names (strings) present in the initial conditions. """ return list(self.models.keys())
[docs] def has_model(self, model_name: str): """ Check if a model with the given name exists in these initial conditions. Parameters ---------- model_name : str The name/identifier of the model to check. Returns ------- bool ``True`` if the model exists, ``False`` otherwise. """ return model_name in self.models
[docs] def update_model(self, model_name: str, **parameters) -> None: """ Update parameters of an existing model in the initial conditions. Parameters ---------- model_name : str The name/identifier of the model to update. parameters : dict Key-value pairs of parameters to update. Raises ------ KeyError If no model with the given ``model_name`` exists. ValueError If provided parameters are invalid or inconsistent with expected dimensions/units. """ # Ensure that we have the model to update. if model_name not in self.models: raise ValueError(f"No model named '{model_name}' found in initial conditions.") # Cycle through the keys and values of the # parameters dictionary and update. _permitted_keys = set(self._model_metadata_required_keys).union(self._model_metadata_allowed_keys) for key, value in parameters.items(): if key not in _permitted_keys: raise ValueError(f"Parameter '{key}' is not a valid model parameter. ") # Update the value self.__config__[f"models.{model_name}.{key}"] = value
[docs] def get_model_info(self, model_name: str): """ Retrieve the configuration information for a specific model. Parameters ---------- model_name : str The name/identifier of the model to retrieve. Returns ------- dict A dictionary containing the model's configuration information. Raises ------ KeyError If no model with the given ``model_name`` exists. """ if model_name not in self.models: raise KeyError(f"No model named '{model_name}' found in initial conditions.") return self.models[model_name]
[docs] def get_model_fields(self, model_name: str): """ Inspect the available fields in a stored model without fully loading it. This method uses :func:`~pisces.models.core.utils.inspect_model_fields` to open the model file, look inside its ``/FIELDS`` group, and return a list of field names and their shapes. Parameters ---------- model_name : str The name/identifier of the model to inspect. Returns ------- list of tuple A list of ``(field_name, shape)`` pairs. Raises ------ KeyError If the model name is not found in the initial conditions. FileNotFoundError If the referenced model file does not exist. """ from pisces.models.core.utils import inspect_model_fields if model_name not in self.models: raise KeyError(f"No model named '{model_name}' found in initial conditions.") model_path = Path(self.models[model_name]["path"]) if not model_path.exists(): raise FileNotFoundError(f"Model file not found for '{model_name}': {model_path}") return inspect_model_fields(model_path)
[docs] def get_model_metadata(self, model_name: str): """ Inspect the metadata of a stored model without fully loading it. This method uses :func:`~pisces.models.core.utils.inspect_model_metadata` to open the model file, look inside its ``/METADATA`` group, and return a dictionary of metadata key-value pairs. Parameters ---------- model_name : str The name/identifier of the model to inspect. Returns ------- dict A dictionary containing the model's metadata. Raises ------ KeyError If the model name is not found in the initial conditions. FileNotFoundError If the referenced model file does not exist. """ from pisces.models.core.utils import inspect_model_metadata if model_name not in self.models: raise KeyError(f"No model named '{model_name}' found in initial conditions.") model_path = Path(self.models[model_name]["path"]) if not model_path.exists(): raise FileNotFoundError(f"Model file not found for '{model_name}': {model_path}") return inspect_model_metadata(model_path)
[docs] def get_model_coordinate_system(self, model_name: str): """ Retrieve the coordinate system of a stored model without fully loading it. This method opens the model file, reads the metadata, and returns the coordinate system object. Parameters ---------- model_name : str The name/identifier of the model to inspect. Returns ------- CoordinateSystem The coordinate system object associated with the model. Raises ------ KeyError If the model name is not found in the initial conditions. FileNotFoundError If the referenced model file does not exist. """ from pisces.models.core.utils import inspect_model_coordinate_system if model_name not in self.models: raise KeyError(f"No model named '{model_name}' found in initial conditions.") model_path = Path(self.models[model_name]["path"]) if not model_path.exists(): raise FileNotFoundError(f"Model file not found for '{model_name}': {model_path}") return inspect_model_coordinate_system(model_path)
[docs] def get_model_grid(self, model_name: str) -> "Grid": """ Retrieve the grid object of a stored model without fully loading it. This method opens the model file, reads the metadata, and returns the grid object. Parameters ---------- model_name : str The name/identifier of the model to inspect. Returns ------- Grid The grid object associated with the model. Raises ------ KeyError If the model name is not found in the initial conditions. FileNotFoundError If the referenced model file does not exist. """ from pisces.models.core.utils import inspect_model_grid if model_name not in self.models: raise KeyError(f"No model named '{model_name}' found in initial conditions.") model_path = Path(self.models[model_name]["path"]) if not model_path.exists(): raise FileNotFoundError(f"Model file not found for '{model_name}': {model_path}") return inspect_model_grid(model_path)
[docs] def get_model_summary(self, model_name: str, include_fields: bool = True): """ Retrieve a combined summary of a stored model without fully loading it. Parameters ---------- model_name : str The name/identifier of the model to inspect. include_fields : bool, optional If True (default), also include the model's fields and shapes. Returns ------- dict A dictionary with keys: - ``metadata`` : dict - ``grid`` : Grid - ``coordinate_system`` : CoordinateSystem - ``fields`` : list[(name, shape)] (if ``include_fields`` is True) """ summary = { "metadata": self.get_model_metadata(model_name), "grid": self.get_model_grid(model_name), "coordinate_system": self.get_model_coordinate_system(model_name), } if include_fields: summary["fields"] = self.get_model_fields(model_name) return summary
# ================================ # # Methods - Particle Interaction # # ================================ #
[docs] def has_particles(self, model_name: str) -> bool: """ Check if a specific model has associated particles. Parameters ---------- model_name : str The name/identifier of the model to check. Returns ------- bool ``True`` if the model has associated particles, ``False`` otherwise. Raises ------ KeyError If no model with the given ``model_name`` exists. """ if model_name not in self.models: raise KeyError(f"No model named '{model_name}' found in initial conditions.") particle_path = self.models[model_name].get("particles", None) return particle_path is not None and Path(particle_path).exists()
[docs] def list_models_with_particles(self) -> list[str]: """ List all models that have associated particle datasets. Returns ------- list A list of model names (strings) that have associated particles. """ return [name for name, info in self.models.items() if info.get("particles", None) is not None]
[docs] def get_particle_path(self, model_name: str) -> Union[Path, None]: """ Retrieve the particle file path associated with a specific model. Parameters ---------- model_name : str The name/identifier of the model to inspect. Returns ------- ~pathlib.Path or None The path to the particle dataset file if it exists, otherwise `None`. Raises ------ KeyError If no model with the given ``model_name`` exists. """ if model_name not in self.models: raise KeyError(f"No model named '{model_name}' found in initial conditions.") particle_path = self.models[model_name].get("particles", None) return Path(particle_path) if particle_path is not None else None
[docs] def get_particle_path_dict(self) -> dict[str, Union[Path, None]]: """ Retrieve a dictionary mapping model names to their associated particle file paths. Returns ------- dict A dictionary where keys are model names and values are paths to the particle dataset files (or `None` if no particles are associated). """ return { name: Path(info["particles"]) if info.get("particles", None) is not None else None for name, info in self.models.items() }
[docs] def load_particles(self, model_name: str, **kwargs) -> "ParticleDataset": """ Load the particle dataset associated with a specific model. Parameters ---------- model_name : str The name/identifier of the model whose particles to load. kwargs : Additional keyword arguments to pass to the :class:`~pisces.particles.base.ParticleDataset` constructor. Returns ------- ~pisces.particles.base.ParticleDataset An instance of the loaded particle dataset. Raises ------ KeyError If no model with the given ``model_name`` exists. FileNotFoundError If the referenced particle file does not exist. ValueError If no particles are associated with the specified model. """ from pisces.particles import ParticleDataset if model_name not in self.models: raise KeyError(f"No model named '{model_name}' found in initial conditions.") particle_path = self.models[model_name].get("particles", None) if particle_path is None: raise ValueError(f"No particles associated with model '{model_name}'.") particle_path = Path(particle_path) if not particle_path.exists(): raise FileNotFoundError(f"Particle file for '{model_name}' not found: {particle_path}") return ParticleDataset(particle_path, **kwargs)
[docs] def add_particles_to_model( self, particle_path: Union[str, Path], model_name: str, file_processing_mode: str = "copy", overwrite: bool = False, ): """ Attach a particle dataset file to an existing model in the initial conditions. This method: 1. Validates that the target model exists in the configuration. 2. Checks if the specified particle file exists on disk. 3. Copies or moves the file into the initial conditions directory, naming it ``<model_name>_p.hdf5`` to avoid collisions. 4. Updates the configuration to record the particle file path. 5. Optionally overwrites any previously associated particle file if ``overwrite=True``. Parameters ---------- particle_path : str or ~pathlib.Path Path to the particle dataset file to associate with the model. model_name : str The name/identifier of the model to which the particles will be added. file_processing_mode : {"copy", "move"}, default="copy" How to handle the provided file: * ``"copy"`` – Copy the file into the IC directory (original remains intact). * ``"move"`` – Move the file into the IC directory (original is removed). overwrite : bool, optional If ``True``, any existing particle file for this model will be overwritten. If ``False`` (default), an error will be raised if a particle file is already associated with the model. Raises ------ KeyError If no model with the given ``model_name`` exists in the configuration. FileNotFoundError If the specified ``particle_path`` does not exist. ValueError If ``file_processing_mode`` is not one of ``"copy"`` or ``"move"``. FileExistsError If a particle file is already associated with the model and ``overwrite`` is False. 
""" _particle_file_extension = self.__class__._model_file_extensions["particles"]["extension"] if model_name not in self.models: raise KeyError(f"No model named '{model_name}' found in initial conditions.") particle_path = Path(particle_path) if not particle_path.exists() or not particle_path.is_file(): raise FileNotFoundError(f"Particle file '{particle_path}' does not exist.") dest_path = self.directory / f"{model_name}{_particle_file_extension}.hdf5" # Check for existing particle file existing_particle_path = self.models[model_name].get("particles") if existing_particle_path: existing_file = Path(existing_particle_path) if existing_file.exists(): if not overwrite: raise FileExistsError( f"Model '{model_name}' already has a particle file at {existing_file}. " "Use `overwrite=True` to replace it." ) else: existing_file.unlink() self.logger.warning(f"Overwriting existing particle file for model '{model_name}'.") # Copy or move the new particle file if file_processing_mode == "copy": shutil.copy2(particle_path, dest_path) elif file_processing_mode == "move": shutil.move(str(particle_path), str(dest_path)) else: raise ValueError(f"Invalid file_processing_mode '{file_processing_mode}'; must be 'copy' or 'move'.") # Update configuration self.__config__[f"models.{model_name}.particles"] = str(dest_path) self.logger.info(f"Added particles from '{dest_path}' to model '{model_name}'.")
[docs] def remove_particles_from_model(self, model_name: str, delete_particle_file: bool = True): """ Detach and optionally delete the particle dataset associated with a specific model. This method: 1. Validates that the target model exists in the configuration. 2. Checks if a particle file is associated with the model. 3. Optionally deletes the particle file from disk. 4. Updates the configuration to remove the particle file reference. Parameters ---------- model_name : str The name/identifier of the model from which to remove particles. delete_particle_file : bool, optional If ``True`` (default), the particle file will be deleted from disk. If ``False``, only the reference in the configuration will be removed. Raises ------ KeyError If no model with the given ``model_name`` exists in the configuration. ValueError If no particles are associated with the specified model. """ if model_name not in self.models: raise KeyError(f"No model named '{model_name}' found in initial conditions.") particle_path = self.models[model_name].get("particles", None) if particle_path is None: raise ValueError(f"No particles associated with model '{model_name}' to remove.") particle_path = Path(particle_path) if delete_particle_file and particle_path.exists(): particle_path.unlink() self.logger.info(f"Deleted particle file '{particle_path}' for model '{model_name}'.") # Remove reference from configuration del self.__config__[f"models.{model_name}.particles"] self.logger.info(f"Removed particle association from model '{model_name}'.")
def get_particle_count(self, model_name: str) -> dict[str, int]:
    """Inspect the number of particles in each species group for a stored model.

    This is a wrapper around :func:`~pisces.particles.utils.inspect_particle_count`.

    Parameters
    ----------
    model_name : str
        The name/identifier of the model whose particle counts should be inspected.

    Returns
    -------
    dict of str, int
        Mapping of particle species names to the number of particles in each.

    Raises
    ------
    KeyError
        If the model name is not found in the initial conditions.
    FileNotFoundError
        If no particle file is associated with the model, or if the associated
        particle file does not exist on disk.
    ValueError
        If a particle group is missing the required ``NUMBER_OF_PARTICLES`` attribute
        (raised by :func:`~pisces.particles.utils.inspect_particle_count`).
    """
    if model_name not in self.models:
        raise KeyError(f"No model named '{model_name}' found in initial conditions.")

    particle_path = self.particles[model_name]
    if particle_path is None:
        raise FileNotFoundError(f"No particle file associated with model '{model_name}'.")

    # Normalise to Path so a plain-string entry behaves the same as a Path entry.
    particle_path = Path(particle_path)
    if not particle_path.is_file():
        # Distinguish "path recorded but missing on disk" from "nothing recorded".
        raise FileNotFoundError(f"Particle file '{particle_path}' for model '{model_name}' does not exist.")

    # Imported only after validation so lookup errors are reported before the helper module is loaded.
    from pisces.particles.utils import inspect_particle_count

    return inspect_particle_count(particle_path)
def get_particle_species(self, model_name: str) -> list[str]:
    """List the particle species present in the stored particle dataset for a given model.

    This is a wrapper around :func:`~pisces.particles.utils.inspect_species`.

    Parameters
    ----------
    model_name : str
        The name/identifier of the model whose particle species should be inspected.

    Returns
    -------
    list of str
        A list of particle species names (HDF5 group names) present in the particle file.

    Raises
    ------
    KeyError
        If the model name is not found in the initial conditions.
    FileNotFoundError
        If no particle file is associated with the model, or if the associated
        particle file does not exist on disk.
    """
    if model_name not in self.models:
        raise KeyError(f"No model named '{model_name}' found in initial conditions.")

    particle_path = self.particles[model_name]
    if particle_path is None:
        raise FileNotFoundError(f"No particle file associated with model '{model_name}'.")

    # Normalise to Path so a plain-string entry behaves the same as a Path entry.
    particle_path = Path(particle_path)
    if not particle_path.is_file():
        # Distinguish "path recorded but missing on disk" from "nothing recorded".
        raise FileNotFoundError(f"Particle file '{particle_path}' for model '{model_name}' does not exist.")

    # Imported only after validation so lookup errors are reported before the helper module is loaded.
    from pisces.particles.utils import inspect_species

    return inspect_species(particle_path)
def get_particle_fields(self, model_name: str) -> dict[str, list[tuple[str, tuple[int, ...]]]]:
    """
    Inspect the available fields for each particle species in a model.

    This is a wrapper around :func:`~pisces.particles.utils.inspect_fields`.

    Parameters
    ----------
    model_name : str
        The name/identifier of the model whose particle fields should be inspected.

    Returns
    -------
    dict
        Mapping of particle species names to lists of ``(field_name, element_shape)`` tuples.

    Raises
    ------
    KeyError
        If the model name is not found in the initial conditions.
    FileNotFoundError
        If no particle file is associated with the model, or if the associated
        particle file does not exist on disk.
    """
    if model_name not in self.models:
        raise KeyError(f"No model named '{model_name}' found in initial conditions.")

    particle_path = self.particles[model_name]
    if particle_path is None:
        raise FileNotFoundError(f"No particle file associated with model '{model_name}'.")

    # Normalise to Path so a plain-string entry behaves the same as a Path entry.
    particle_path = Path(particle_path)
    if not particle_path.is_file():
        # Distinguish "path recorded but missing on disk" from "nothing recorded".
        raise FileNotFoundError(f"Particle file '{particle_path}' for model '{model_name}' does not exist.")

    # Imported only after validation so lookup errors are reported before the helper module is loaded.
    from pisces.particles.utils import inspect_fields

    return inspect_fields(particle_path)
def generate_particles(
    self, model_name: str, num_particles: dict[str, int], overwrite: bool = False, **kwargs
) -> None:
    """
    Generate particles for a stored model using its internal particle generation method.

    The model is loaded and checked for a ``generate_particles`` method. That method is then
    asked to write particles to ``<IC directory>/<model_name><particles extension>.hdf5``.
    The resulting path is recorded (as a string) under ``models.<model_name>.particles``
    in the configuration.

    Parameters
    ----------
    model_name : str
        The name/identifier of the model for which to generate particles.
    num_particles : dict of str, int
        Mapping of particle species names to the number of particles to generate for each.
    overwrite : bool
        If ``True``, any existing particle file for this model is replaced. If ``False``
        (default), an error is raised if a particle file is already associated with the model.
    kwargs:
        Additional keyword arguments to pass to the model's ``generate_particles`` method.

    Raises
    ------
    KeyError
        If the specified model is not found in the initial conditions.
    FileExistsError
        If the model already has particles and ``overwrite`` is ``False``.
    NotImplementedError
        If the loaded model does not implement a ``generate_particles`` method. Existing
        particles are left untouched in this case, even when ``overwrite=True``.
    """
    if model_name not in self.models:
        raise KeyError(f"No model named '{model_name}' found in initial conditions.")

    config_key = f"models.{model_name}.particles"
    has_existing_particles = self.__config__.get(config_key, None) is not None
    if has_existing_particles and not overwrite:
        raise FileExistsError(
            f"Model '{model_name}' already has associated particles. "
            "Set `overwrite=True` to regenerate and replace them."
        )

    # Load the model and confirm it can generate particles *before* discarding any
    # existing particle file, so an unsupported model never causes data loss.
    model = self.load_model(model_name)
    if not hasattr(model, "generate_particles"):
        raise NotImplementedError(f"The model '{model_name}' does not support particle generation.")

    if has_existing_particles:
        self.remove_particles_from_model(model_name, delete_particle_file=True)

    particle_extension = self.__class__._model_file_extensions["particles"]["extension"]
    particle_path = self.__directory__ / f"{model_name}{particle_extension}.hdf5"

    # noinspection PyUnresolvedReferences
    model.generate_particles(particle_path, num_particles, **kwargs)

    # Stored as a string, matching how ``add_particles_to_model`` records particle paths.
    self.__config__[config_key] = str(particle_path)
    self.logger.info(f"Generated particles for model '{model_name}' with counts: {num_particles}")
# ============================== #
# Generator Methods              #
# ============================== #
# These classmethods build the skeleton of an initial conditions directory. They are
# the hooks subclasses override to specialise behaviour for a specific simulation code:
#
# - ``_process_metadata`` builds the ``metadata`` section of ``IC_CONFIG.yaml``.
# - ``_validate_input_model`` checks and normalises one user-supplied model definition.
# - ``_process_model`` copies or moves a validated model's files into the IC directory
#   and returns the config-ready model dictionary.
@classmethod
def _process_metadata(cls, directory: Path, **kwargs) -> dict:
    """
    Generate core metadata for the initial conditions set.

    The base implementation records only invariant structural information required for
    all IC sets. Subclasses are expected to extend this to add simulation- or
    geometry-specific keys (e.g., ``ndim``, coordinate system, cosmology parameters).

    Parameters
    ----------
    directory : ~pathlib.Path
        The directory where ICs are being created.
    **kwargs :
        Ignored in the base implementation. Subclasses may interpret these values
        when extending this method.

    Returns
    -------
    dict
        A dictionary with a single ``"metadata"`` key whose value contains:

        - ``created_at`` : ISO 8601 timestamp (UTC, with ``+00:00`` offset)
        - ``class_name`` : str, name of the class creating the ICs
        - ``directory`` : str, absolute resolved path
    """
    # ``datetime.timezone.utc`` is the same object as ``datetime.UTC`` but also exists before Python 3.11.
    timestamp = datetime.datetime.now(datetime.timezone.utc).isoformat()
    return {
        "metadata": {
            "created_at": timestamp,
            "class_name": cls.__name__,
            "directory": str(directory.resolve()),
        }
    }


@classmethod
@abstractmethod
def _validate_input_model(cls, model: dict, existing_models: dict, **kwargs) -> dict:
    """
    Validate and normalize a single input model definition.

    This method enforces a consistent structure for model dictionaries before they are
    incorporated into an initial conditions set. It is the *primary extension point*
    for subclasses that need specialized rules (e.g., different dimensionalities, extra
    fields, spherical coordinates). Subclasses typically call it via ``super()`` and
    then validate their own keys.

    .. rubric:: Validation performed here

    - ``model_name`` and ``model`` are present, together with every key listed in
      :attr:`_model_metadata_required_keys`.
    - No keys other than ``model_name``, ``model``, the required keys and
      :attr:`_model_metadata_allowed_keys` are present.
    - ``model_name`` is converted to ``str`` and made unique against
      ``existing_models`` by appending ``_1``, ``_2``, ... when needed.
    - ``model`` (a string/Path or a :class:`~pisces.models.core.base.BaseModel`) is
      resolved to a file on disk and stored as ``path``; the ``model`` key is removed.

    Anything else (positions, velocities, particle files, coordinate systems, ...) is
    left to subclasses and to :meth:`_process_model`.

    The input dictionary is not modified; a validated shallow copy is returned.

    Parameters
    ----------
    model : dict
        Raw model specification provided by the user.
    existing_models : dict
        Already-validated models keyed by name. Used to make ``model_name`` unique.
    **kwargs :
        Extra options passed down from higher-level calls. Unused here; subclasses
        may use them to specialize validation.

    Returns
    -------
    dict
        A validated model dictionary containing ``model_name`` (str), ``path``
        (:class:`~pathlib.Path`) and the remaining user-supplied keys.

    Raises
    ------
    ValueError
        If required keys are missing or unexpected keys are present.
    TypeError
        If ``model`` is not a string, Path, or BaseModel instance.
    FileNotFoundError
        If the referenced model file does not exist.
    """
    # Work on a shallow copy: keys are rewritten and popped below, and the caller's
    # dictionary should remain reusable (e.g. for a second IC set).
    model = dict(model)

    # --- Check Required Keys [Invariant] --- #
    # ``model_name`` and ``model`` are always required in addition to subclass keys,
    # so a missing one surfaces as a ValueError rather than a bare KeyError later on.
    required_keys = [*cls._model_metadata_required_keys, "model_name", "model"]
    for required_key in required_keys:
        if required_key not in model:
            raise ValueError(f"Model definition is missing required key: {required_key}.")

    # --- Check Only Optional / Required Keys [Invariant] --- #
    allowed_keys = {
        "model_name",
        "model",
        *cls._model_metadata_required_keys,
        *cls._model_metadata_allowed_keys,
    }
    for key in model:
        if key not in allowed_keys:
            raise ValueError(f"Model definition contains unexpected key: {key}.")

    # --- Model Name Uniqueness [Invariant] --- #
    base_name = str(model["model_name"])
    unique_name = base_name
    suffix = 0
    while unique_name in existing_models:
        suffix += 1
        unique_name = f"{base_name}_{suffix}"
    model["model_name"] = unique_name

    # --- Model Processing [Invariant] --- #
    # Resolve the model specification to a path on disk.
    attached_model = model.pop("model")
    if isinstance(attached_model, (str, Path)):
        model_path = Path(attached_model)
    elif isinstance(attached_model, BaseModel):
        model_path = Path(attached_model.__path__)
    else:
        raise TypeError(
            f"Model must be a string, Path, or BaseModel instance, not {type(attached_model).__name__}."
        )

    # ``is_file`` is False for missing paths as well as for directories.
    if not model_path.is_file():
        raise FileNotFoundError(f"Model file '{model_path}' does not exist or is not a file.")
    model["path"] = model_path

    return model


@classmethod
def _process_model(
    cls, directory: Path, model_name: str, model_info: dict, file_processing_mode: str = "copy", **kwargs
) -> dict:
    """
    Finalize a validated model for inclusion in the IC directory.

    For every key in :attr:`_model_file_extensions` that is present in ``model_info``,
    the referenced file is copied or moved to
    ``directory / f"{model_name}{extension}{suffix}"`` and the entry is replaced by
    the new path. Subclasses may extend this to perform additional tasks.

    Parameters
    ----------
    directory : ~pathlib.Path
        The target directory where the model files should be placed.
    model_name : str
        Unique identifier for the model (made unique upstream in
        :meth:`_validate_input_model`).
    model_info : dict
        The validated model dictionary. It is updated in place and returned.
    file_processing_mode : {"copy", "move"}, default="copy"
        How to place files into the IC directory:

        * ``"copy"`` – Copy the source file, preserving the original.
        * ``"move"`` – Move the source file, removing the original.
    **kwargs :
        Extra options forwarded from higher-level calls. Unused here.

    Returns
    -------
    dict
        The finalized model dictionary.

    Raises
    ------
    FileNotFoundError
        If a referenced source file does not exist.
    ValueError
        If ``file_processing_mode`` is invalid. This is checked before any file is touched.
    """
    # Reject a bad mode up front so no file is copied/moved before the error is raised.
    if file_processing_mode not in ("copy", "move"):
        raise ValueError(f"Invalid file_processing_mode '{file_processing_mode}'; must be 'copy' or 'move'.")

    for file_key, info in cls._model_file_extensions.items():
        if file_key not in model_info:
            continue

        extension = info.get("extension", "")
        src_path = Path(model_info[file_key])
        if not src_path.exists():
            raise FileNotFoundError(f"File for key '{file_key}' not found for '{model_name}': {src_path}")

        dest_path = directory / f"{model_name}{extension}{src_path.suffix}"
        cls.logger.debug("Placing '%s' file for model '%s' at %s", file_key, model_name, dest_path)
        if file_processing_mode == "copy":
            shutil.copy2(src_path, dest_path)
        else:
            shutil.move(str(src_path), str(dest_path))
        model_info[file_key] = dest_path

    cls.logger.debug(
        "Processed model '%s': stored at %s [mode=%s]",
        model_name,
        model_info["path"],
        file_processing_mode,
    )
    return model_info
@classmethod
def create_ics(cls, directory: Union[str, Path], *models: dict, file_processing_mode: str = "copy", **kwargs):
    """
    Create a new initial conditions (IC) directory and return an :class:`InitialConditions` instance.

    This is the main entry point for building an IC set from scratch. It performs all
    required filesystem setup, model validation, and configuration writing.

    .. rubric:: Workflow

    1. **Pre-checks**: if the target exists it must be a directory. If it is non-empty,
       ``overwrite=True`` is required in ``kwargs``.
    2. **Model validation**: every model definition (dict) is validated via
       :meth:`_validate_input_model` *before* the filesystem is modified. An invalid
       definition therefore never wipes an existing directory or leaves it half-populated.
    3. **Directory setup**: the directory is created, or cleared when overwriting.
    4. **Model processing**: files are copied or moved into the IC directory via
       :meth:`_process_model`.
    5. **Configuration file**: metadata from :meth:`_process_metadata` and the processed
       model information are written to ``IC_CONFIG.yaml`` for later reload.

    Parameters
    ----------
    directory : str or ~pathlib.Path
        Target directory for the IC set. Must be empty unless ``overwrite=True`` is provided.
        All of the model files and any connected files will be copied / moved into this
        directory so that it becomes the centralized location for the IC set.
    *models : dict
        Model definitions. The expected keys in each model may vary from subclass to
        subclass, but at a minimum we expect:

        - ``"model_name"`` : str
            The unique name/identifier for this model.
        - ``"model"`` : str, Path, or BaseModel
            The model specification, either as a path to a model file or as a loaded model.
        - ``"position"`` : unyt_array with length units
            The position of the model in the simulation volume (Cartesian ICs).
        - ``"velocity"`` : unyt_array with length/time units
            The bulk velocity of the model in the simulation volume (Cartesian ICs).
    file_processing_mode: {"copy", "move"}, default="copy"
        How to handle the provided model files:

        * ``"copy"`` – Copy the files into the IC directory (originals remain intact).
        * ``"move"`` – Move the files into the IC directory (originals are removed).
    **kwargs :
        Additional keyword arguments which are subclass dependent. ``overwrite`` is read
        here, and all kwargs are forwarded to the validation/processing/metadata hooks.

    Returns
    -------
    InitialConditions
        A fully initialized instance pointing to the new directory.

    Raises
    ------
    NotADirectoryError
        If ``directory`` exists but is not a directory.
    FileExistsError
        If the directory exists and is non-empty, and ``overwrite`` is False.
    FileNotFoundError
        If a referenced model file does not exist.
    ValueError
        If a model definition is missing required keys or is otherwise invalid.

    Notes
    -----
    - This method is not intended to be overridden by subclasses. Instead, override the
      following hooks to customize behavior:

      * :meth:`_validate_input_model`
      * :meth:`_process_model`
      * :meth:`_process_metadata`

    - The returned instance is ready for immediate use in frontend converters (e.g., Gadget, AREPO).
    """
    cls.logger.info("Creating initial conditions in directory: %s", directory)
    directory = Path(directory)
    overwrite = kwargs.get("overwrite", False)

    # --- Directory Pre-checks [INVARIANT] --- #
    # Only *inspect* the target here; nothing is deleted until every model has validated.
    if directory.exists() and not directory.is_dir():
        raise NotADirectoryError(f"`{directory}` exists and is not a directory.")
    directory_existed = directory.exists()
    has_content = directory_existed and any(directory.iterdir())
    if has_content and not overwrite:
        raise FileExistsError(
            f"Directory `{directory}` is not empty. Use `overwrite=True` to replace its contents."
        )

    # --- Model Validation --- #
    # Validate all models first so a bad definition fails before files are removed,
    # copied or moved. ``validated_models`` doubles as the uniqueness registry.
    validated_models = {}
    for model in models:
        validated = cls._validate_input_model(model, validated_models, **kwargs)
        validated_name = validated.pop("model_name")
        validated_models[validated_name] = validated

    # --- Directory Setup [INVARIANT] --- #
    if has_content:
        shutil.rmtree(directory)
        directory.mkdir(parents=True, exist_ok=True)
        cls.logger.debug("Overwrote existing directory: %s", directory)
    elif not directory_existed:
        directory.mkdir(parents=True, exist_ok=True)
        cls.logger.debug("Created new directory: %s", directory)

    # --- Model Processing --- #
    processed_models = {}
    for model_name, model_info in validated_models.items():
        processed_models[model_name] = cls._process_model(
            directory, model_name, model_info, file_processing_mode=file_processing_mode, **kwargs
        )

    # --- Create the IC_CONFIG File --- #
    # Build the metadata before opening the file so a failing hook cannot leave an empty config behind.
    metadata = cls._process_metadata(directory, **kwargs)
    metadata["models"] = processed_models
    with open(directory / "IC_CONFIG.yaml", "w", encoding="utf-8") as config_file:
        cls.__YAML__.dump(metadata, config_file)

    cls.logger.info("Initial conditions created successfully in %s", directory)
    return cls(directory)
class InitialConditions1DSpherical(InitialConditions):
    """
    Initial conditions for 1D spherical simulations.

    This subclass enforces that all models are defined in a strictly spherical coordinate
    system with only a radial dependence (``r`` axis). Bulk positions and velocities are
    not used—models are implicitly located at the origin with zero translational motion.

    Use this class when constructing spherically symmetric ICs, e.g. radial gas/halo profiles.

    The class ensures that:

    - The model grid uses
      :class:`~pisces.geometry.coordinates.coordinate_systems.SphericalCoordinateSystem`.
    - The grid has exactly one active axis (``r``).
    - Only ``model_name`` and ``model`` are required keys; particles may be added.

    Subclasses may extend validation to enforce additional spherical-specific metadata
    (e.g., radial boundary conditions, outer cutoff radii).
    """

    # ============================== #
    # Class Flags                    #
    # ============================== #
    # ``model_name`` and ``model`` are always accepted by the generic key validation in
    # ``InitialConditions._validate_input_model``; only ``particles`` is allowed on top.
    _model_metadata_required_keys = []
    _model_metadata_allowed_keys = ["particles"]
    _ndim = 1

    # ============================== #
    # Initialization Methods         #
    # ============================== #
    @classmethod
    def _validate_model(cls, model_name: str, model_info: dict) -> None:
        # No spherical-specific checks on stored models; defer to the base class.
        super()._validate_model(model_name, model_info)

    # ============================== #
    # Generator Methods              #
    # ============================== #
    @classmethod
    def _validate_input_model(cls, model: dict, existing_models: dict, **kwargs) -> dict:
        """
        Validate a model definition for a 1D spherical IC set.

        Generic validation (required/allowed keys, a unique ``model_name``, resolving
        ``model`` to an on-disk ``path``) is delegated to
        :meth:`InitialConditions._validate_input_model`. On top of that, the model's grid
        must use ``SphericalCoordinateSystem`` and have ``r`` as its only active axis.

        Parameters
        ----------
        model : dict
            Raw model specification (``model_name``, ``model`` and optionally ``particles``).
            The caller's dictionary is not modified.
        existing_models : dict
            Already-validated models keyed by name, used to make ``model_name`` unique.
        **kwargs :
            Forwarded to the base-class validation.

        Returns
        -------
        dict
            The validated model dictionary, including ``model_name`` and ``path``.

        Raises
        ------
        ValueError
            If the grid is not spherical or depends on axes other than ``r``.
            Key-validation failures from the base class also propagate.
        TypeError, FileNotFoundError
            Propagated from the base-class resolution of ``model``.
        """
        # Reuse the base-class checks, whose allowed-key set includes ``model_name`` and
        # ``model``. Validate a shallow copy so the caller's dictionary stays untouched.
        model = super()._validate_input_model(dict(model), existing_models, **kwargs)

        # --- Ensure Spherical Coordinate System --- #
        model_grid = inspect_model_grid(model["path"])
        coordinate_system_name = type(model_grid.coordinate_system).__name__
        if coordinate_system_name != "SphericalCoordinateSystem":
            raise ValueError(
                f"Model '{model['model_name']}' must use a spherical coordinate system, "
                f"not {coordinate_system_name}."
            )
        if set(model_grid.active_axes) != {"r"}:
            raise ValueError(
                f"Model '{model['model_name']}' must have only a radial grid dependence, not {model_grid.active_axes}."
            )

        return model
class InitialConditionsCartesian(InitialConditions, ABC):
    """
    Abstract base class for Cartesian initial conditions.

    Provides a general framework for simulations defined in Cartesian coordinates with
    ``ndim`` active spatial dimensions. This class enforces the presence and dimensionality
    of ``position`` and ``velocity`` vectors and provides convenience accessors for them.

    Key features:

    - ``position`` and ``velocity`` are required keys of every model definition;
      ``particles`` is optional.
    - Ensures positions are length vectors of shape ``(ndim,)``.
    - Ensures velocities are length/time vectors of shape ``(ndim,)``.
    - Defines ``model_positions`` and ``model_velocities`` properties that return the
      stored vectors as unyt arrays in ``m`` and ``km/s`` respectively.

    Subclasses specify the dimensionality by setting ``_ndim`` and may extend validation
    to include orientation, spin, or other Cartesian-specific keys.
    """

    # ============================== #
    # Class Flags                    #
    # ============================== #
    # ``position`` and ``velocity`` must be listed: ``_validate_input_model`` below reads
    # both unconditionally. The base-class key check also rejects any key that is neither
    # required, allowed, nor ``model_name``/``model``, so omitting them rejects every model.
    _model_metadata_required_keys = ["position", "velocity"]
    _model_metadata_allowed_keys = ["particles"]
    _ndim = 3

    # ============================== #
    # Initialization Methods         #
    # ============================== #
    @classmethod
    def _validate_model(cls, model_name: str, model_info: dict) -> None:
        # No Cartesian-specific checks on stored models; defer to the base class.
        super()._validate_model(model_name, model_info)

    # ============================== #
    # Properties                     #
    # ============================== #
    @staticmethod
    def _as_unyt_vector(value: Any, units: str) -> unyt.unyt_array:
        """
        Return ``value`` as a unyt array expressed in ``units``.

        Existing unyt arrays are converted with ``.to(units)``. Anything else (lists,
        tuples, plain arrays) is wrapped and assumed to already be in ``units``.
        """
        # NOTE(review): depending on the unyt version, ``unyt_array(arr, units=...)`` may
        # relabel an existing unyt array rather than convert it, so convert explicitly.
        if isinstance(value, unyt.unyt_array):
            return value.to(units)
        return unyt.unyt_array(value, units=units)

    @property
    def model_positions(self) -> dict[str, unyt.unyt_array]:
        """
        The positions of the models in the initial conditions.

        Returns
        -------
        dict
            A dictionary mapping model names to their position vectors as
            `unyt.array.unyt_array` in metres.
        """
        return {name: self._as_unyt_vector(info["position"], "m") for name, info in self.models.items()}

    @property
    def model_velocities(self) -> dict[str, unyt.unyt_array]:
        """
        The velocities of the models in the initial conditions.

        Returns
        -------
        dict
            A dictionary mapping model names to their velocity vectors as
            `unyt.array.unyt_array` in km/s.
        """
        return {name: self._as_unyt_vector(info["velocity"], "km/s") for name, info in self.models.items()}

    # ============================== #
    # Generator Methods              #
    # ============================== #
    @classmethod
    @abstractmethod
    def _validate_input_model(cls, model: dict, existing_models: dict, **kwargs) -> dict:
        """
        Validate a Cartesian model definition.

        Runs the generic validation of the parent class. It then checks that ``position``
        has length units, that ``velocity`` has length/time units, and that both are
        vectors of shape ``(ndim,)``. Subclasses extend this to validate further keys.

        Parameters
        ----------
        model : dict
            Raw model specification. The caller's dictionary is not modified.
        existing_models : dict
            Already-validated models keyed by name, used to make ``model_name`` unique.
        **kwargs :
            Forwarded to the base-class validation.

        Returns
        -------
        dict
            The validated model with ``position``/``velocity`` stored as unyt arrays.

        Raises
        ------
        TypeError
            If ``position``/``velocity`` have the wrong physical dimensions.
        ValueError
            If ``position``/``velocity`` are not ``(ndim,)`` vectors.
            Key-validation failures from the base class also propagate.
        """
        # Validate a shallow copy so the caller's dictionary is not modified in place.
        model = super()._validate_input_model(dict(model), existing_models, **kwargs)

        # --- Parameter Processing [Extensible] --- #
        position = unyt.unyt_array(model["position"])
        if position.units.dimensions != unyt.dimensions.length:
            raise TypeError(f"Position must have length units, not {position.units.dimensions}.")
        if position.shape != (cls._ndim,):
            raise ValueError(f"Position must be a {cls._ndim}D vector, not shape {position.shape}.")
        model["position"] = position

        velocity = unyt.unyt_array(model["velocity"])
        if velocity.units.dimensions != (unyt.dimensions.length / unyt.dimensions.time):
            raise TypeError(f"Velocity must have length/time units, not {velocity.units.dimensions}.")
        if velocity.shape != (cls._ndim,):
            raise ValueError(f"Velocity must be a {cls._ndim}D vector, not shape {velocity.shape}.")
        model["velocity"] = velocity

        return model

    # ============================== #
    # Physics Methods                #
    # ============================== #
def compute_center_of_mass(
    self,
    models: Union[str, list[str]] = "all",
    masses: Optional[dict[str, "unyt.unyt_quantity"]] = None,
):
    r"""Compute the mass-weighted center of mass (COM) position for one or more models in the initial conditions.

    The COM is calculated as:

    .. math::

        \mathbf{R}_{\mathrm{COM}} = \frac{\sum_i m_i \mathbf{r}_i}{\sum_i m_i}

    where ``m_i`` is the total mass of model ``i`` and ``r_i`` is its position.

    Mass values are obtained in the following order:

    1. If ``masses`` is provided, use those values (must map model name → scalar mass).
    2. Otherwise, attempt to read ``total_mass`` from the model's metadata via
       :meth:`get_model_metadata`.

    Bare numbers (from either source) are interpreted as solar masses.

    Parameters
    ----------
    models : str or list of str, optional
        Which models to include in the COM calculation. If ``"all"``, all models in
        :attr:`models` are used. A single model name is treated as a list of one.
    masses : dict of str, ~unyt.array.unyt_quantity, optional
        Optional mapping from model name to its total mass. If provided for a given
        model, this value overrides the ``total_mass`` from metadata.

    Returns
    -------
    unyt.array.unyt_array
        The COM position vector with shape ``(ndim,)`` in units of ``pc``.

    Raises
    ------
    TypeError
        If ``models`` is not ``"all"``, a string, or a sequence of strings.
    KeyError
        If any specified model name is not found in the initial conditions.
    ValueError
        If a required mass is missing from both ``masses`` and the model metadata,
        or if the total mass is not positive.
    """
    known_models = self.models

    # Normalise ``models`` into an explicit list of names.
    if models == "all":
        model_names = list(known_models.keys())
    elif isinstance(models, str):
        model_names = [models]
    elif isinstance(models, Sequence):
        model_names = list(models)
    else:
        raise TypeError("`models` must be 'all', a string, or a list of strings.")

    # ``model_positions`` builds a dict over *every* model on each access, so fetch it
    # once rather than once per loop iteration.
    positions = self.model_positions

    total_mass_sum = 0.0 * unyt.Unit("Msun")
    weighted_position_sum = unyt.unyt_array([0.0] * self.ndim, units="Msun*pc")

    for name in model_names:
        if name not in known_models:
            raise KeyError(f"No model named '{name}' found in initial conditions.")

        # User-supplied masses take precedence over metadata.
        if (masses is not None) and (name in masses):
            mass = masses[name]
        else:
            metadata = self.get_model_metadata(name)
            if "total_mass" not in metadata:
                raise ValueError(
                    f"Model '{name}' is missing 'total_mass' in metadata and no override was provided."
                )
            mass = metadata["total_mass"]
        if not isinstance(mass, unyt.unyt_quantity):
            mass = unyt.unyt_quantity(mass, "Msun")

        weighted_position_sum += mass * positions[name]
        total_mass_sum += mass

    if total_mass_sum <= 0 * unyt.Unit("Msun"):
        raise ValueError("Total mass is zero; cannot compute center of mass.")

    return weighted_position_sum / total_mass_sum
def compute_center_of_mass_frame_positions(
    self,
    models: Union[str, list[str]] = "all",
    masses: Optional[dict[str, "unyt.unyt_quantity"]] = None,
    com_position: Optional["unyt.unyt_array"] = None,
) -> dict[str, "unyt.unyt_array"]:
    """
    Compute the positions of models in the center-of-mass (COM) frame.

    This subtracts the (mass-weighted) COM position from each model's stored position.

    Parameters
    ----------
    models : str or list of str, optional
        Models to include in the computation. Defaults to ``"all"`` for all models.
    masses : dict of str, ~unyt.array.unyt_quantity, optional
        Optional mapping from model name to its total mass, used for the COM calculation.
        Overrides ``total_mass`` in metadata if provided.
    com_position : ~unyt.array.unyt_array, optional
        Precomputed COM position vector. If not given, it is computed using
        :meth:`compute_center_of_mass` with the same ``models`` and ``masses``.
        A unyt array is converted to metres; a plain sequence is assumed to be in metres.

    Returns
    -------
    dict of str, ~unyt.array.unyt_array
        Mapping of model names to their position vectors in the COM frame (metres).

    Raises
    ------
    TypeError
        If ``models`` is not ``"all"``, a string, or a sequence of strings.
    KeyError
        If any specified model name is not found.
    ValueError
        If mass data is missing for a model and COM computation is required.
    """
    known_models = self.models

    # Normalise ``models`` into an explicit list of names.
    if models == "all":
        model_names = list(known_models.keys())
    elif isinstance(models, str):
        model_names = [models]
    elif isinstance(models, Sequence):
        model_names = list(models)
    else:
        raise TypeError("`models` must be 'all', a string, or a list of strings.")

    for name in model_names:
        if name not in known_models:
            raise KeyError(f"No model named '{name}' found in initial conditions.")

    if com_position is None:
        com_position = self.compute_center_of_mass(model_names, masses=masses)

    # ``compute_center_of_mass`` returns pc, so the COM must be *converted* to metres.
    # NOTE(review): depending on the unyt version, ``unyt_array(arr, units="m")`` may
    # relabel an existing unyt array instead of converting it, so convert explicitly.
    if isinstance(com_position, unyt.unyt_array):
        com_position = com_position.to("m")
    else:
        com_position = unyt.unyt_array(com_position, units="m")

    # ``model_positions`` builds a dict over every model on each access; fetch it once.
    positions = self.model_positions
    return {name: positions[name] - com_position for name in model_names}
def compute_center_of_mass_velocity(
    self,
    models: Union[str, list[str]] = "all",
    masses: Optional[dict[str, "unyt.unyt_quantity"]] = None,
):
    r"""
    Return the mass-weighted center-of-mass (COM) velocity of selected models.

    The quantity computed is

    .. math::

        \mathbf{V}_{\mathrm{COM}} = \frac{\sum_i m_i \mathbf{v}_i}{\sum_i m_i},

    where ``m_i`` is the total mass of model ``i`` and ``v_i`` is its stored
    velocity vector (from :attr:`model_velocities`).

    The mass of each model is resolved as follows:

    1. The entry in ``masses``, when one exists for that model.
    2. Otherwise, the ``total_mass`` value from :meth:`get_model_metadata`.

    Plain numbers (no units) are interpreted as solar masses in both cases.

    Parameters
    ----------
    models : str or list of str, optional
        ``"all"`` (default) selects every model in :attr:`models`. A single name
        is treated as a one-element list. Any other sequence is used as given.
    masses : dict of str, ~unyt.array.unyt_quantity, optional
        Per-model mass overrides that take precedence over the metadata.

    Returns
    -------
    ~unyt.array.unyt_array
        The COM velocity vector, of length ``ndim``. The momentum accumulator is
        held in ``Msun*km/s``, so the result is expressed in ``km/s``.

    Raises
    ------
    TypeError
        If ``models`` is not ``"all"``, a string, or a sequence.
    KeyError
        If a selected name is not a known model.
    ValueError
        If a mass cannot be resolved, or if the summed mass is not positive.
    """
    if models == "all":
        selected = list(self.models.keys())
    elif isinstance(models, str):
        selected = [models]
    elif isinstance(models, Sequence):
        selected = list(models)
    else:
        raise TypeError("`models` must be 'all', a string, or a list of strings.")

    def _mass_of(model_name):
        # Explicit overrides win; otherwise fall back to the metadata value.
        if masses is not None and model_name in masses:
            value = masses[model_name]
        else:
            meta = self.get_model_metadata(model_name)
            if "total_mass" not in meta:
                raise ValueError(
                    f"Model '{model_name}' is missing 'total_mass' in metadata and no override was provided."
                )
            value = meta["total_mass"]
        if isinstance(value, unyt.unyt_quantity):
            return value
        return unyt.unyt_quantity(value, "Msun")

    mass_total = 0.0 * unyt.Unit("Msun")
    momentum = unyt.unyt_array([0.0] * self.ndim, units="Msun*km/s")

    for model_name in selected:
        if model_name not in self.models:
            raise KeyError(f"No model named '{model_name}' found in initial conditions.")
        m = _mass_of(model_name)
        v = self.model_velocities[model_name]
        momentum += m * v
        mass_total += m

    if mass_total <= 0 * unyt.Unit("Msun"):
        raise ValueError("Total mass is zero; cannot compute center-of-mass velocity.")

    return momentum / mass_total
def compute_center_of_mass_frame_velocities(
    self,
    models: Union[str, list[str]] = "all",
    masses: Optional[dict[str, "unyt.unyt_quantity"]] = None,
    com_velocity: Optional["unyt.unyt_array"] = None,
) -> dict[str, "unyt.unyt_array"]:
    """
    Compute the velocities of models in the center-of-mass (COM) frame.

    The mass-weighted COM velocity is subtracted from each selected model's
    stored velocity (as returned by :attr:`model_velocities`). The stored
    configuration is not modified.

    Parameters
    ----------
    models : str or list of str, optional
        Models to include in the computation. ``"all"`` (default) selects every
        model; a single name is treated as a one-element list.
    masses : dict of str, ~unyt.array.unyt_quantity, optional
        Optional mapping from model name to its total mass. It is used only when
        the COM velocity has to be computed, and it overrides ``total_mass`` in
        the metadata.
    com_velocity : ~unyt.array.unyt_array, optional
        Precomputed COM velocity vector. A :class:`~unyt.array.unyt_array` keeps
        its own units, which are converted during the subtraction. A plain
        array-like (no units) is interpreted in km/s. If omitted, the COM
        velocity is computed with :meth:`compute_center_of_mass_velocity` using
        the same ``models`` and ``masses``.

    Returns
    -------
    dict of str, ~unyt.array.unyt_array
        Mapping of model names to their velocity vectors in the COM frame.

    Raises
    ------
    TypeError
        If ``models`` is not ``"all"``, a string, or a sequence of strings.
    KeyError
        If any specified model name is not found.
    ValueError
        If mass data is missing for a model and COM computation is required.
    """
    # Normalize the model selection into a list of names.
    if models == "all":
        model_names = list(self.models.keys())
    elif isinstance(models, str):
        model_names = [models]
    elif isinstance(models, Sequence):
        model_names = list(models)
    else:
        raise TypeError("`models` must be 'all', a string, or a list of strings.")

    # Fail fast on unknown names before doing any COM work.
    for name in model_names:
        if name not in self.models:
            raise KeyError(f"No model named '{name}' found in initial conditions.")

    if com_velocity is None:
        com_velocity = self.compute_center_of_mass_velocity(model_names, masses=masses)
    elif not isinstance(com_velocity, unyt.unyt_array):
        # Bare array-likes carry no units; keep the historical km/s interpretation.
        com_velocity = unyt.unyt_array(com_velocity, units="km/s")
    # NOTE: an existing unyt_array is deliberately not re-wrapped with
    # ``units=...``. That would relabel (not convert) the values, silently
    # corrupting a COM velocity given in e.g. pc/Myr.

    velocities = self.model_velocities
    return {name: velocities[name] - com_velocity for name in model_names}
def shift_to_COM_frame(
    self,
    models: Union[str, list[str]] = "all",
    masses: Optional[dict[str, "unyt.unyt_quantity"]] = None,
) -> None:
    """
    Move the selected models into their mutual center-of-mass (COM) frame, in place.

    The mass-weighted COM position and velocity of the selection are computed
    once, via :meth:`compute_center_of_mass` and
    :meth:`compute_center_of_mass_velocity`. They are then subtracted from each
    selected model's ``models.<name>.position`` and ``models.<name>.velocity``
    configuration entries, and the shifted values are written back into the
    configuration. Models outside the selection are left untouched.

    Parameters
    ----------
    models : str or list of str, optional
        ``"all"`` (default) selects every model. A single name is treated as a
        one-element list.
    masses : dict of str, ~unyt.array.unyt_quantity, optional
        Per-model mass overrides that take precedence over ``total_mass`` in the
        metadata.

    Raises
    ------
    TypeError
        If ``models`` is not ``"all"``, a string, or a sequence.
    KeyError
        If a selected name is not a known model.
    ValueError
        If a required mass is unavailable.
    """
    if models == "all":
        selected = list(self.models.keys())
    elif isinstance(models, str):
        selected = [models]
    elif isinstance(models, Sequence):
        selected = list(models)
    else:
        raise TypeError("`models` must be 'all', a string, or a list of strings.")

    # Both COM quantities are computed before any entry is modified.
    com_position = self.compute_center_of_mass(selected, masses=masses)
    com_velocity = self.compute_center_of_mass_velocity(selected, masses=masses)

    for model_name in selected:
        if model_name not in self.models:
            raise KeyError(f"No model named '{model_name}' found in initial conditions.")

        position_key = f"models.{model_name}.position"
        velocity_key = f"models.{model_name}.velocity"

        # The raw config values are paired with the units reported by the
        # model_positions / model_velocities properties.
        current_position = unyt.unyt_array(
            self.__config__[position_key],
            units=self.model_positions[model_name].units,
        )
        current_velocity = unyt.unyt_array(
            self.__config__[velocity_key],
            units=self.model_velocities[model_name].units,
        )

        self.__config__[position_key] = current_position - com_position
        self.__config__[velocity_key] = current_velocity - com_velocity

    self.logger.info(
        f"Shifted {len(selected)} models to COM frame: COM position={com_position}, COM velocity={com_velocity}"
    )
def compute_total_mass(
    self,
    models: Union[str, list[str]] = "all",
    masses: Optional[dict[str, "unyt.unyt_quantity"]] = None,
) -> "unyt.unyt_quantity":
    """
    Sum the total masses of the selected models.

    The mass of each model is resolved as follows:

    1. The entry in ``masses``, when one exists for that model.
    2. Otherwise, ``total_mass`` from :meth:`get_model_metadata`.

    Plain numbers (no units) are interpreted as solar masses. An empty selection
    yields ``0 Msun``.

    Parameters
    ----------
    models : str or list of str, optional
        ``"all"`` (default) selects every model in :attr:`models`. A single name
        is treated as a one-element list.
    masses : dict of str, ~unyt.array.unyt_quantity, optional
        Per-model mass overrides that take precedence over the metadata.

    Returns
    -------
    ~unyt.array.unyt_quantity
        The summed mass. The accumulator starts in ``Msun``, so the result is
        expressed in ``Msun``.

    Raises
    ------
    TypeError
        If ``models`` is not ``"all"``, a string, or a sequence.
    KeyError
        If a selected name is not a known model.
    ValueError
        If a mass is absent from both ``masses`` and the model metadata.
    """
    if models == "all":
        selected = list(self.models.keys())
    elif isinstance(models, str):
        selected = [models]
    elif isinstance(models, Sequence):
        selected = list(models)
    else:
        raise TypeError("`models` must be 'all', a string, or a list of strings.")

    def _mass_of(model_name):
        # Overrides take precedence; otherwise read the metadata value.
        if masses is not None and model_name in masses:
            value = masses[model_name]
        else:
            meta = self.get_model_metadata(model_name)
            if "total_mass" not in meta:
                raise ValueError(
                    f"Model '{model_name}' is missing 'total_mass' in metadata and no override was provided."
                )
            value = meta["total_mass"]
        if isinstance(value, unyt.unyt_quantity):
            return value
        return unyt.unyt_quantity(value, "Msun")

    running_total = 0.0 * unyt.Unit("Msun")
    for model_name in selected:
        if model_name not in self.models:
            raise KeyError(f"No model named '{model_name}' found in initial conditions.")
        running_total += _mass_of(model_name)

    return running_total
class InitialConditions1DCartesian(InitialConditionsCartesian):
    """
    Initial conditions for one-dimensional Cartesian simulations.

    This is the Cartesian specialization with ``ndim=1``: models sit on a single
    line, and each carries a one-component position and velocity. It suits toy
    problems, 1D test cases, and simplified collapse/expansion setups.

    This class does not redefine any physics helpers. The center-of-mass
    position/velocity, COM-frame shifting, and total-mass utilities described
    for Cartesian initial conditions are expected to come from the parent class.

    Each model entry must provide:

    - ``model_name`` and ``model`` (the model file reference).
    - ``position``: a length-1 vector with length units.
    - ``velocity``: a length-1 vector with velocity units.

    ``particles`` is accepted as an optional extra key.
    """

    # ------------------------------------------------------------------ #
    # Class-level configuration consumed by the base class.              #
    # Subclasses may override these to tweak behavior.                  #
    # ------------------------------------------------------------------ #
    _model_metadata_required_keys = ["position", "velocity"]
    _model_metadata_allowed_keys = ["particles"]
    _ndim = 1

    # ------------------------------------------------------------------ #
    # Validation hooks                                                    #
    # ------------------------------------------------------------------ #
    @classmethod
    def _validate_model(cls, model_name: str, model_info: dict) -> None:
        # There are no 1D-specific checks; the parent's validation is sufficient.
        super()._validate_model(model_name, model_info)

    @classmethod
    def _validate_input_model(cls, model: dict, existing_models: dict, **kwargs) -> dict:
        # A 1D model has no extra fields (e.g. no orientation), so the entry
        # validated by the parent is returned unchanged.
        return super()._validate_input_model(model, existing_models, **kwargs)
class InitialConditions2DCartesian(InitialConditionsCartesian):
    """
    Initial conditions for 2D Cartesian simulations.

    Specialized Cartesian subclass with ``ndim=2``. Models are embedded in a
    planar (x, y) geometry and may include an orientation vector to define spin
    axes or angular alignment.

    Key features:

    - Validates that positions/velocities are 2D vectors with correct units.
    - Requires ``orientation`` to be a nonzero 2-vector and stores it normalized
      to unit length.
    - Provides the ``model_orientations`` property for access.

    Models must include:

    - ``model_name``, ``model``, ``position``, ``velocity``.
    - Optional ``orientation`` (default = [0, 1]).
    - Optional ``particles``.
    """

    # ============================== #
    # Class Flags                    #
    # ============================== #
    # Settings consumed by the base class; subclasses may override them.
    _model_metadata_required_keys = ["position", "velocity"]
    _model_metadata_allowed_keys = ["particles", "orientation"]
    _ndim = 2

    # ============================== #
    # Initialization Methods         #
    # ============================== #
    @classmethod
    def _validate_model(cls, model_name: str, model_info: dict) -> None:
        # No 2D-specific checks here; defer to the Cartesian base.
        super()._validate_model(model_name, model_info)

    # ============================== #
    # Properties                     #
    # ============================== #
    @property
    def model_orientations(self) -> dict[str, np.ndarray]:
        """
        The orientations of the models in the initial conditions.

        Returns
        -------
        dict
            A dictionary mapping model names to their orientation vectors as
            float `np.ndarray`.
        """
        return {name: np.asarray(info["orientation"], dtype=float) for name, info in self.models.items()}

    # ============================== #
    # Generator Methods              #
    # ============================== #
    @classmethod
    def _validate_input_model(cls, model: dict, existing_models: dict, **kwargs) -> dict:
        """
        Validate a user-supplied model entry and fill in 2D-specific defaults.

        The entry first goes through the base-class validation. ``orientation``
        is then defaulted to ``[0, 1]``, checked to be a 2-vector with a
        non-negligible norm, and normalized to unit length.

        Returns
        -------
        dict
            The validated model entry, with ``orientation`` normalized.

        Raises
        ------
        ValueError
            If the orientation does not have shape ``(2,)`` or its norm is
            ``<= 1e-8``.
        """
        model = super()._validate_input_model(model, existing_models, **kwargs)

        # The local is checked before anything is written back, so a rejected
        # orientation does not leave the entry half-updated.
        orientation = np.asarray(model.get("orientation", np.asarray([0, 1], dtype="f8")))
        if orientation.shape != (2,):
            raise ValueError(f"Orientation must be a 2D vector, not shape {orientation.shape}.")

        norm = np.linalg.norm(orientation)
        if norm <= 1e-8:
            raise ValueError("Orientation vector cannot be the zero vector.")

        model["orientation"] = orientation / norm
        return model
class InitialConditions3DCartesian(InitialConditionsCartesian):
    """
    Initial conditions for 3D Cartesian simulations.

    Specialized Cartesian subclass with ``ndim=3``. This is the most general
    Cartesian IC class, allowing full 3D positioning, velocities, orientations,
    and spins. It is designed for galaxy, cluster, and cosmological ICs where
    full spatial and kinematic degrees of freedom are needed.

    Key features:

    - Validates 3D position and velocity vectors.
    - Requires ``orientation`` to be a nonzero 3-vector and stores it normalized
      to unit length.
    - Coerces ``spin`` to a Python ``float``.
    - Provides convenience properties: ``model_spins`` and ``model_orientations``.

    Physics utilities include:

    - Integration of point-mass orbits with the `rebound` N-body package
      (:meth:`integrate_point_mass_orbits`).
    - Future extensions may add angular momentum alignment, merging utilities,
      or orbit fitting.

    Models must include:

    - ``model_name``, ``model``, ``position``, ``velocity``.
    - Optional ``orientation`` (default = [0, 0, 1]).
    - Optional ``spin`` (default = 0.0).
    - Optional ``particles``.
    """

    # ============================== #
    # Class Flags                    #
    # ============================== #
    # Settings consumed by the base class; subclasses may override them.
    _model_metadata_required_keys = ["position", "velocity"]
    _model_metadata_allowed_keys = ["particles", "orientation", "spin"]
    _ndim = 3

    # ============================== #
    # Initialization Methods         #
    # ============================== #
    @classmethod
    def _validate_model(cls, model_name: str, model_info: dict) -> None:
        # No 3D-specific checks here; defer to the Cartesian base.
        super()._validate_model(model_name, model_info)

    # ============================== #
    # Properties                     #
    # ============================== #
    @property
    def model_spins(self) -> dict[str, float]:
        """
        The spins of the models in the initial conditions.

        Returns
        -------
        dict
            A dictionary mapping model names to their spin values as `float`.
        """
        return {name: float(info["spin"]) for name, info in self.models.items()}

    @property
    def model_orientations(self) -> dict[str, np.ndarray]:
        """
        The orientations of the models in the initial conditions.

        Returns
        -------
        dict
            A dictionary mapping model names to their orientation vectors as
            float `np.ndarray`.
        """
        return {name: np.asarray(info["orientation"], dtype=float) for name, info in self.models.items()}

    # ============================== #
    # Generator Methods              #
    # ============================== #
    @classmethod
    def _validate_input_model(cls, model: dict, existing_models: dict, **kwargs) -> dict:
        """
        Validate a user-supplied model entry and fill in 3D-specific defaults.

        The entry first goes through the base-class validation. Then:

        - ``orientation`` defaults to ``[0, 0, 1]``, must have shape ``(3,)`` and
          a norm ``> 1e-8``, and is normalized to unit length.
        - ``spin`` defaults to ``0.0`` and is converted with ``float()``.

        Returns
        -------
        dict
            The validated model entry.

        Raises
        ------
        ValueError
            If the orientation has the wrong shape or is (numerically) zero, or
            if ``float()`` rejects the spin value (e.g. a non-numeric string).
        TypeError
            If the spin value is of a type ``float()`` cannot convert (e.g.
            ``None`` or a multi-element array).
        """
        model = super()._validate_input_model(model, existing_models, **kwargs)

        # The orientation is validated on a local before anything is written back.
        orientation = np.asarray(model.get("orientation", np.asarray([0, 0, 1], dtype="f8")))
        if orientation.shape != (3,):
            raise ValueError(f"Orientation must be a 3D vector, not shape {orientation.shape}.")
        norm = np.linalg.norm(orientation)
        if norm <= 1e-8:
            raise ValueError("Orientation vector cannot be the zero vector.")
        model["orientation"] = orientation / norm

        # float() either returns a float or raises, so a post-conversion
        # isinstance check can never fire. Instead, float()'s TypeError is
        # translated into the intended message. ValueError (bad strings)
        # propagates unchanged, as before.
        raw_spin = model.get("spin", 0.0)
        try:
            model["spin"] = float(raw_spin)
        except TypeError as exc:
            raise TypeError(f"Spin must be a float, not {type(raw_spin).__name__}.") from exc

        return model

    # ============================== #
    # Physics Methods                #
    # ============================== #
    def integrate_point_mass_orbits(
        self,
        models: Union[str, list[str]] = "all",
        masses: Optional[dict[str, "unyt.unyt_quantity"]] = None,
        t_end: "unyt.unyt_quantity" = None,
        dt: Optional["unyt.unyt_quantity"] = None,
        integrator: str = "whfast",
    ):
        """
        Integrate the point-mass equivalent orbits for selected models.

        Each selected model is treated as a single massive particle, using its
        stored position, velocity, and total mass. The particles are integrated
        as a gravitational N-body system with `rebound`, without hydrodynamics.
        The system is moved to its COM frame before integration.

        Parameters
        ----------
        models : str or list of str, default="all"
            Models to include as point masses. If ``"all"``, uses all models.
            If a single string is given, it is treated as a list of one.
        masses : dict of str, ~unyt.array.unyt_quantity, optional
            Mapping from model name to its total mass. If provided for a given
            model, overrides the ``total_mass`` value from that model's
            metadata. Plain numbers are interpreted as ``Msun``.
        t_end : ~unyt.array.unyt_quantity
            Total integration time **from the current epoch**. Must have time
            units and be positive.
        dt : ~unyt.array.unyt_quantity, optional
            Time step for the integrator. If not provided, defaults to
            ``t_end / 1000``. Must have time units and be positive.
        integrator : str, default="whfast"
            The REBOUND integrator to use. Options include:

            - ``"whfast"``: fast symplectic integrator (default)
            - ``"ias15"``: high-accuracy adaptive integrator
            - Other integrators supported by `rebound`

        Returns
        -------
        rebound.Simulation
            The `rebound` simulation object after integration to ``t_end``, in
            units of (pc, Msun, Myr). Particle indices correspond to the order of
            models in the ``models`` parameter.

        Raises
        ------
        ImportError
            If the `rebound` package is not installed.
        TypeError
            If ``models`` is not ``"all"``, a string, or a sequence of strings.
        KeyError
            If any requested model is not found in the configuration.
        ValueError
            If no models are selected, if ``t_end`` is missing, non-positive, or
            lacks time units, if ``dt`` is non-positive or lacks time units, or
            if mass data is missing.
        """
        try:
            import rebound
        except ImportError as e:
            raise ImportError(
                "The `rebound` package is required for this method. Install it via: pip install rebound"
            ) from e

        # --- Normalize model selection
        if models == "all":
            model_names = list(self.models.keys())
        elif isinstance(models, str):
            model_names = [models]
        elif isinstance(models, Sequence):
            model_names = list(models)
        else:
            raise TypeError("`models` must be 'all', a string, or a list of strings.")

        if not model_names:
            raise ValueError("No models specified for integration.")

        # --- Validate time inputs
        if t_end is None:
            raise ValueError("`t_end` must be provided as a time quantity.")
        t_end = unyt.unyt_quantity(t_end)
        if t_end.units.dimensions != unyt.dimensions.time:
            raise ValueError("`t_end` must have time units.")
        if t_end <= 0 * t_end.units:
            raise ValueError("`t_end` must be positive.")

        if dt is None:
            dt = t_end / 1000
        dt = unyt.unyt_quantity(dt)
        if dt.units.dimensions != unyt.dimensions.time:
            raise ValueError("`dt` must have time units.")
        if dt <= 0 * dt.units:
            # A zero or negative step is meaningless for forward integration to t_end.
            raise ValueError("`dt` must be positive.")

        # --- Prepare model data
        positions = self.model_positions
        velocities = self.model_velocities
        total_masses: dict[str, unyt.unyt_quantity] = {}

        for name in model_names:
            if name not in self.models:
                raise KeyError(f"No model named '{name}' found in initial conditions.")

            # Determine mass (user-specified overrides metadata)
            if masses and name in masses:
                m_val = masses[name]
                if not isinstance(m_val, unyt.unyt_quantity):
                    m_val = unyt.unyt_quantity(m_val, "Msun")
            else:
                metadata = self.get_model_metadata(name)
                if "total_mass" not in metadata:
                    raise ValueError(f"Model '{name}' is missing 'total_mass' in metadata.")
                m_val = metadata["total_mass"]
                if not isinstance(m_val, unyt.unyt_quantity):
                    m_val = unyt.unyt_quantity(m_val, "Msun")

            total_masses[name] = m_val.to("Msun")

        # --- Build REBOUND simulation (units are set before adding particles)
        sim = rebound.Simulation()
        sim.units = ("pc", "Msun", "Myr")
        sim.integrator = integrator
        sim.dt = dt.to_value("Myr")

        for name in model_names:
            pos = positions[name].to("pc").value
            vel = velocities[name].to("pc/Myr").value
            m = total_masses[name].value
            sim.add(m=m, x=pos[0], y=pos[1], z=pos[2], vx=vel[0], vy=vel[1], vz=vel[2])

        # Shift to COM frame before integration
        sim.move_to_com()

        # --- Integrate
        sim.integrate(sim.t + t_end.to_value("Myr"))
        return sim
def load_ics(path: Union[str, Path]) -> "InitialConditions":
    """
    Load an initial conditions directory from its configuration file.

    This function reads the ``IC_CONFIG.yaml`` file in the specified directory
    and takes ``metadata.class_name`` from it. That name is looked up among all
    currently defined (i.e. imported) subclasses of :class:`InitialConditions`,
    and the matching class is instantiated with the directory path.

    Parameters
    ----------
    path : str or Path
        Path to the initial conditions directory containing an
        ``IC_CONFIG.yaml`` file.

    Returns
    -------
    InitialConditions
        An instance of the appropriate subclass populated from the directory.

    Raises
    ------
    FileNotFoundError
        If the directory or its ``IC_CONFIG.yaml`` file does not exist.
    ValueError
        If ``path`` is not a directory, if the YAML cannot be parsed, if the top
        level or ``metadata`` section is not a mapping (e.g. an empty file), if
        ``metadata.class_name`` is missing, or if the class cannot be resolved.
    """
    from collections.abc import Mapping

    path = Path(path)
    if not path.exists():
        raise FileNotFoundError(f"Initial conditions directory does not exist: {path}")
    if not path.is_dir():
        raise ValueError(f"Path must be a directory, not a file: {path}")

    config_path = path / "IC_CONFIG.yaml"
    if not config_path.exists():
        raise FileNotFoundError(f"No 'IC_CONFIG.yaml' found in directory: {path}")

    try:
        with open(config_path) as f:
            config = InitialConditions.__YAML__.load(f)
    except Exception as e:
        raise ValueError(f"Failed to parse YAML configuration: {config_path}") from e

    # An empty file parses to None. Without these checks, an empty file or a
    # non-mapping top level would surface as an AttributeError rather than the
    # documented ValueError.
    if not isinstance(config, Mapping):
        raise ValueError(f"Configuration must be a mapping at the top level: {config_path}")

    # ``metadata:`` with no value yields None; treat it as an empty section.
    metadata = config.get("metadata") or {}
    if not isinstance(metadata, Mapping):
        raise ValueError(f"Configuration 'metadata' section must be a mapping: {config_path}")

    class_name = metadata.get("class_name")
    if not class_name:
        raise ValueError(f"Configuration missing required 'metadata.class_name': {config_path}")

    # Recursive subclass lookup: only subclasses whose modules have been
    # imported are visible through __subclasses__().
    def _all_subclasses(cls):
        for sub in cls.__subclasses__():
            yield sub
            yield from _all_subclasses(sub)

    subclass_map = {cls.__name__: cls for cls in _all_subclasses(InitialConditions)}
    ics_class = subclass_map.get(class_name)
    if ics_class is None:
        raise ValueError(f"Unknown InitialConditions subclass '{class_name}' in {config_path}")

    return ics_class(path)