"""Core classes for particle data support in Pisces.
This module defines the base class for particle datasets, providing a standard interface
for reading, writing, and interacting with particle datasets in the Pisces framework.
"""
from datetime import datetime
from pathlib import Path
import h5py
import numpy as np
import unyt
from pisces.utilities.io_tools import HDF5Serializer
class ParticleDataset:
"""Base class for particle datasets in Pisces.
This class provides the standard interface for reading, writing, and interacting with particle datasets
in the Pisces framework. It parallels the standard HDF5 formats used in codes like AREPO and GADGET,
but includes some additional flexibility and features.
Details regarding the expected structure of the dataset and how to work with particle
datasets can be found at :ref:`particles_overview`.
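Examples
--------
A minimal usage sketch; the file name and the group/field names below are
illustrative and assume a dataset previously written by this class:
>>> ds = ParticleDataset("my_particles.h5", mode="r")
>>> groups = ds.particle_groups            # e.g. ['baryons', 'dark_matter']
>>> mass = ds["baryons.particle_mass"]     # unyt array carrying the stored units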
"""
metadata_serializer: type[HDF5Serializer] = HDF5Serializer
"""type[~pisces.utilities.io_tools.HDF5Serializer]: The HDF5 serializer class used for reading and writing metadata.
This serializer class is responsible for converting metadata types into
formats which are compatible with HDF5 attributes. By default, this is the
:py:class:`~pisces.utilities.io_tools.HDF5Serializer` class, which handles
serialization of common Python types (e.g., dict, list, str) into HDF5 attributes
along with unyt arrays, quantities, and units.
Developers may extend or replace this serializer to support custom types as needed.
"""
# --------------------------------- #
# Class Constants / Flags #
# --------------------------------- #
# These flags provide standard names for built-in access to specific
# fields in the dataset. This ensures that those areas of the code are
# not hardcoded with string literals. These can be overridden in
# subclasses if the dataset uses different naming conventions.
# --- Field Name Conventions --- #
_POSITION_FIELD_NAME: str = "particle_position"
"""str: The standard name for the position field in each particle group."""
_VELOCITY_FIELD_NAME: str = "particle_velocity"
"""str: The standard name for the velocity field in each particle group."""
_MASS_FIELD_NAME: str = "particle_mass"
"""str: The standard name for the mass field in each particle group."""
_ID_FIELD_NAME: str = "particle_id"
"""str: The standard name for the unique identifier field in each particle group."""
# --- Class Settings --- #
_ID_POLICY: str = "global"
"""str: The policy for assigning unique particle IDs.
This can be either "global" or "per_group":
- "global": Particle IDs are unique across all groups in the dataset.
- "per_group": Particle IDs are only unique within each group.
This setting affects how particle IDs are interpreted and managed within the dataset.
Default is "global".
"""
# --- Metadata Requirements ---#
__REQUIRED_GLOBAL_METADATA__: list[str] = ["CLASS_NAME", "GEN_TIME"]
"""list of str: The required global metadata attributes for this dataset.
If these are not all present in the global metadata on load, then
the dataset will raise an :py:class:`OSError` during validation.
"""
__REQUIRED_GROUP_METADATA__: list[str] = ["NUMBER_OF_PARTICLES"]
"""list of str: The required group metadata attributes for each particle group.
If these are not all present in the group metadata, then
the dataset will raise a :py:class:`ValueError` during validation.
"""
@classmethod
def _serialized(cls, obj):
"""Serialize an object using the dataset's metadata serializer.
Parameters
----------
obj : Any
The object to serialize.
Returns
-------
Any
The serialized representation of the object.
"""
if isinstance(obj, dict):
return cls.metadata_serializer.serialize_dict(obj)
else:
return cls.metadata_serializer.serialize_data(obj)
@classmethod
def _deserialized(cls, obj):
"""Deserialize an object using the dataset's metadata serializer.
Parameters
----------
obj : Any
The object to deserialize.
Returns
-------
Any
The deserialized representation of the object.
"""
if isinstance(obj, dict):
return cls.metadata_serializer.deserialize_dict(obj)
else:
return cls.metadata_serializer.deserialize_data(obj)
# -------------------------------------- #
# Initialization and Validation Methods #
# -------------------------------------- #
# These methods are responsible for initializing the dataset.
def __validate__(self):
"""Validate that this dataset meets a minimum set of format requirements.
The following steps are performed to check the dataset structure:
- Check the **global metadata**:
We look through all of the __REQUIRED_GLOBAL_METADATA__ attributes
to ensure that they are all present in the global metadata.
- Check the **group metadata**:
For each of the particle groups which DOESN'T have the
``NOT_PARTICLE_GROUP`` attribute, we check that the __REQUIRED_GROUP_METADATA__
attributes are present in the group metadata.
- Check the **number of particles**:
For each dataset in each particle group, we check that the number of particles
matches the ``NUMBER_OF_PARTICLES`` attribute in the group metadata.
This method can be extended in subclasses to implement additional validation logic
or constraints specific to the dataset type. It is called automatically during
initialization to ensure that the dataset is in a valid state before any operations
are performed.
"""
# CHECKING GLOBAL METADATA:
# Ensure that all required global metadata attributes are present
# and that the CLASS_NAME flag is set to the correct class name.
# The global metadata is ALREADY DESERIALIZED.
_glob_metadata = self.get_global_metadata()
# Ensure required metadata is present.
if any(required_key not in _glob_metadata for required_key in self.__REQUIRED_GLOBAL_METADATA__):
missing_keys = [key for key in self.__REQUIRED_GLOBAL_METADATA__ if key not in _glob_metadata]
raise OSError(f"Missing required global metadata keys: {', '.join(missing_keys)}")
# Check that this is the correct loading class.
_expected_class_name = _glob_metadata.get("CLASS_NAME")
if _expected_class_name != self.__class__.__name__:
raise OSError(
f"Expected global metadata CLASS_NAME to be '{self.__class__.__name__}', "
f"but found '{_expected_class_name}'. This file may not be a valid "
f"{self.__class__.__name__} dataset."
)
# CHECKING PARTICLE GROUPS METADATA:
# Cycle through all particle groups and validate their metadata.
for group_name in self.__handle__.keys():
_group_handle = self.__handle__[group_name]
# Check if the group actually has the `NOT_PARTICLE_GROUP` attribute. If
# so, we just skip it straight up.
if "NOT_PARTICLE_GROUP" in _group_handle.attrs:
continue
# Otherwise, we need to validate the metadata.
_group_metadata = self.get_group_metadata(group_name)
# Ensure that all required group metadata attributes are present.
if any(required_key not in _group_metadata for required_key in self.__REQUIRED_GROUP_METADATA__):
missing_keys = [key for key in self.__REQUIRED_GROUP_METADATA__ if key not in _group_metadata]
raise ValueError(f"Group '{group_name}' is missing required metadata keys: {', '.join(missing_keys)}")
# Finally, check that the number of particles in each
# dataset matches the number of particles specified in the metadata.
num_particles = _group_metadata.get("NUMBER_OF_PARTICLES")
for dataset_name in _group_handle.keys():
dataset_handle = _group_handle[dataset_name]
# Check that the dataset has the correct number of particles.
if dataset_handle.shape[0] != num_particles:
raise ValueError(
f"Dataset '{dataset_name}' in group '{group_name}' has {dataset_handle.shape[0]} "
f"particles, but 'NUMBER_OF_PARTICLES' metadata indicates {num_particles}."
)
def __init__(self, path: str | Path, mode="r+"):
"""Initialize a :class:`ParticleDataset` from a file on disk.
This constructor opens the specified HDF5 file and validates
the global metadata to ensure that it conforms to the expected
format / structure.
Parameters
----------
path : str or ~pathlib.Path
The path to the HDF5 file containing the particle data.
This can be a string or a :class:`pathlib.Path` object.
If the path does not exist, a `FileNotFoundError` is raised.
mode: str, optional
The mode in which to open the HDF5 file. Defaults to "r+" (read/write mode).
The available modes are:
- "r": Read-only mode. The file must exist.
- "r+": Read/write mode. The file must exist.
- "w": Write mode. Creates a new file or truncates an existing file.
- "w-": Write mode, but fails if the file already exists.
- "x": Exclusive creation mode. Fails if the file already exists.
Notes
-----
At this level, initialization consists of only the following 4 steps:
1. Set the path to the HDF5 file and check that it exists.
2. Open the HDF5 file in the specified mode and create the handle
reference to the file.
3. Load the global metadata from the file using the serializer.
4. Validate the dataset structure by calling the ``.__validate__`` method.
Subclasses may extend this method to add custom behavior beyond these steps.
Additionally, the ``.__post_init__`` method is called after initialization,
allowing for further customization or setup that is specific to the subclass.
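Examples
--------
A minimal sketch; ``my_particles.h5`` is an illustrative, pre-existing file:
>>> with ParticleDataset("my_particles.h5", mode="r") as ds:
...     n_total = ds.total_particles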
"""
# Set the path and open the handle to the HDF5 file.
self.__path__ = Path(path)
if not self.__path__.exists():
raise FileNotFoundError(f"Particle dataset file not found: {self.__path__}")
self.__handle__ = h5py.File(self.__path__, mode=mode)
# Load the global metadata from disk via
# the serializer.
self.__global_metadata__ = self.get_global_metadata()
# Check that the file is a valid particle dataset. This defers
# to the __validate__ method to ensure that the file structure
# and metadata conform to the expected format. This can be overridden
# in subclasses to implement custom validation logic.
self.__validate__()
# Pass on to the post init method.
self.__post_init__()
def __post_init__(self):
"""Post-initialization hook for the :class:`ParticleDataset` class.
This method is called after the dataset has been initialized and validated.
It can be overridden in subclasses to perform additional setup or
configuration that is specific to the subclass implementation.
"""
pass
# ------------------------------------ #
# Properties #
# ------------------------------------ #
# --- Basic Attributes --- #
@property
def path(self) -> Path:
"""The path to the HDF5 file containing the particle dataset.
The path is always stored as a :class:`pathlib.Path`, regardless of whether it
was provided as a string or a Path at initialization.
Returns
-------
~pathlib.Path
The path to the HDF5 file.
"""
return self.__path__
@property
def handle(self) -> h5py.File:
"""The HDF5 file handle for the particle dataset.
This property provides direct access to the underlying HDF5 file handle,
allowing for low-level operations if needed. It is recommended to use
higher-level methods and properties for most use cases.
Returns
-------
h5py.File
The HDF5 file handle.
"""
return self.__handle__
@property
def global_metadata(self) -> dict:
"""Global metadata attributes at the root level of the HDF5 file.
This includes attributes such as creation time, cosmological parameters,
and dataset-wide configuration flags. The returned dictionary is a copy of the
metadata cached at initialization; use :meth:`reload_global_metadata` to refresh it from disk.
Returns
-------
dict
A dictionary of all global HDF5 attributes.
"""
# We return a copy to prevent weird editing attempts.
return self.__global_metadata__.copy()
@property
def particle_groups(self) -> list[str]:
"""Names of all particle groups present in the dataset.
This excludes any HDF5 groups that are marked with the attribute
``NOT_PARTICLE_GROUP``.
Returns
-------
list of str
The names of valid particle groups.
"""
groups = []
for name, group in self.__handle__.items():
if isinstance(group, h5py.Group) and "NOT_PARTICLE_GROUP" not in group.attrs:
groups.append(name)
return groups
@property
def fields(self) -> list[str]:
"""List of all fields (datasets) available in the dataset, in dot notation.
This property returns a list of all field names across all particle groups,
using the format ``group_name.field_name``. This allows direct access via
indexing, e.g., ``ds["baryons.particle_velocity"]``.
Returns
-------
list of str
A sorted list of all field names in dot notation.
"""
field_names = []
for group_name in self.particle_groups:
group = self.__handle__[group_name]
field_names.extend(f"{group_name}.{field}" for field in group.keys())
return sorted(field_names)
@property
def num_particles(self) -> dict[str, int]:
"""Number of particles in each particle group.
This property returns a dictionary mapping each particle group name to the
number of particles it contains, as specified by the ``NUMBER_OF_PARTICLES``
attribute in each group's metadata.
All groups must define this attribute; otherwise, a ValueError is raised.
Returns
-------
dict
A dictionary mapping group names to the number of particles in each group.
Raises
------
ValueError
If any group is missing the ``NUMBER_OF_PARTICLES`` attribute.
"""
counts = {}
for group_name in self.particle_groups:
metadata = self.get_group_metadata(group_name)
if "NUMBER_OF_PARTICLES" not in metadata:
raise ValueError(f"Group '{group_name}' is missing required 'NUMBER_OF_PARTICLES' attribute.")
counts[group_name] = metadata["NUMBER_OF_PARTICLES"]
return counts
@property
def total_particles(self) -> int:
"""Total number of particles across all particle groups.
This property sums the number of particles in each group as specified by
the ``NUMBER_OF_PARTICLES`` attribute in each group's metadata.
Returns
-------
int
The total number of particles across all groups.
"""
return sum(self.num_particles.values())
# ------------------------------------ #
# Dunder Methods #
# ------------------------------------ #
def __str__(self) -> str:
"""Return a human-readable string representation of the ParticleDataset.
This includes the class name, the file name, and the total number of particles.
Returns
-------
str
A formatted string describing the dataset.
"""
return f"<{self.__class__.__name__} @ {self.path.name} | N = {self.total_particles}>"
def __repr__(self) -> str:
"""Return a detailed string representation for debugging.
This includes the class name, file path, total particle count,
and number of groups.
Returns
-------
str
A concise technical summary of the dataset.
"""
return (
f"<{self.__class__.__name__}("
f"path={repr(str(self.path))}, "
f"groups={len(self.particle_groups)}, "
f"total_particles={self.total_particles})>"
)
def __del__(self):
"""Destructor for the ParticleDataset class. Closes the HDF5 file handle if it is open."""
if hasattr(self, "__handle__") and self.__handle__ is not None:
self.__handle__.close()
del self.__handle__
def __getitem__(self, key: str) -> unyt.unyt_array:
"""Get a particle field by its name in dot notation.
Parameters
----------
key : str
The field name in the format ``"group_name.field_name"``.
Returns
-------
unyt.array.unyt_array
The data for the specified field, converted to a unyt array.
Raises
------
KeyError
If the specified field does not exist in the dataset.
"""
try:
group_name, field_name = key.split(".")
except ValueError as err:
raise KeyError(f"Invalid field name format: '{key}'. Expected 'group_name.field_name'.") from err
if not self.__contains__(key):
raise KeyError(f"Field '{key}' does not exist in the dataset.")
return self.get_particle_field(group_name, field_name)
def __contains__(self, key: str) -> bool:
"""Check if a particle field exists in the dataset.
Parameters
----------
key : str
The field name in the format ``"group_name.field_name"``.
Returns
-------
bool
True if the field exists, False otherwise.
"""
try:
group_name, field_name = key.split(".")
except ValueError as err:
raise KeyError(f"Invalid field name format: '{key}'. Expected 'group_name.field_name'.") from err
return f"{group_name}/{field_name}" in self.__handle__
def __enter__(self):
return self
def __exit__(self, exc_type, exc_value, traceback):
self.__del__()
def __len__(self) -> int:
"""Return the number of particle fields in the dataset."""
return len(self.fields)
def __iter__(self):
"""Iterate over all particle fields in dot notation."""
return iter(self.fields)
def __dir__(self):
"""Return a list of all attributes and methods of the ParticleDataset."""
return list(super().__dir__()) + self.fields
# ------------------------------------ #
# Metadata Management Methods #
# ------------------------------------ #
def get_global_metadata(self) -> dict:
"""Get the global metadata attributes of the dataset.
This method retrieves all global attributes at the root level of the HDF5 file
and returns them as a dictionary. It is used to access metadata such as creation date,
cosmological parameters, and dataset-wide configuration flags.
Returns
-------
dict
A dictionary containing all global metadata attributes.
"""
return self.metadata_serializer.deserialize_dict(dict(self.__handle__.attrs))
def get_group_metadata(self, group_name: str) -> dict:
"""Get the metadata attributes for a specific particle group.
This method retrieves all attributes associated with a given particle group
and returns them as a dictionary. It is used to access metadata such as the
number of particles in the group and any additional attributes defined by the user.
Parameters
----------
group_name : str
The name of the particle group whose metadata is to be retrieved.
Returns
-------
dict
A dictionary containing all metadata attributes for the specified group.
Raises
------
KeyError
If the specified group does not exist in the dataset.
"""
if group_name not in self.particle_groups:
raise KeyError(f"Particle group '{group_name}' does not exist in the dataset.")
return self.metadata_serializer.deserialize_dict(dict(self.__handle__[group_name].attrs))
def get_field_metadata(self, group_name: str, field_name: str) -> dict:
"""Get the metadata attributes for a specific field in a particle group.
This method retrieves all attributes associated with a given field in a particle group
and returns them as a dictionary. It is used to access metadata such as units, data type,
and any additional attributes defined by the user.
Parameters
----------
group_name : str
The name of the particle group containing the field.
field_name : str
The name of the field whose metadata is to be retrieved.
Returns
-------
dict
A dictionary containing all metadata attributes for the specified field.
Raises
------
KeyError
If the specified group or field does not exist in the dataset.
"""
dataset_handle = self.get_particle_field_handle(group_name, field_name)
return self.metadata_serializer.deserialize_dict(dict(dataset_handle.attrs))
def reload_global_metadata(self):
"""Reload the global metadata from the HDF5 file.
This method refreshes the global metadata attributes by re-reading them from the
HDF5 file. It is useful if the metadata has been modified externally or if you want
to ensure you have the latest version of the metadata.
"""
self.__global_metadata__ = self.get_global_metadata()
def update_global_metadata(self, metadata: dict):
"""Update the global metadata attributes of the dataset.
This method allows you to modify or add global metadata attributes at the root level
of the HDF5 file. It updates the attributes in memory and writes them back to the file.
Parameters
----------
metadata : dict
A dictionary containing the metadata attributes to update or add.
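Examples
--------
A minimal sketch (keys and values are illustrative); ``ds`` is an open, writable dataset:
>>> import unyt
>>> ds.update_global_metadata({"SIMULATION_NAME": "cluster_A", "BOX_SIZE": unyt.unyt_quantity(10.0, "Mpc")})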
"""
# Serialize the metadata dictionary
serialized_meta = self.metadata_serializer.serialize_dict(metadata)
# Now write.
for key, value in serialized_meta.items():
self.__handle__.attrs[key] = value
# Reload global metadata.
self.reload_global_metadata()
def update_group_metadata(self, group_name: str, metadata: dict):
"""Update the metadata attributes for a specific particle group.
This method allows you to modify or add metadata attributes for a given particle group.
It updates the attributes in memory and writes them back to the HDF5 file.
Parameters
----------
group_name : str
The name of the particle group whose metadata is to be updated.
metadata : dict
A dictionary containing the metadata attributes to update or add.
Raises
------
KeyError
If the specified group does not exist in the dataset.
"""
group_handle = self.get_particle_group_handle(group_name)
# Serialize the metadata dictionary
serialized_meta = self.metadata_serializer.serialize_dict(metadata)
# Now write.
for key, value in serialized_meta.items():
group_handle.attrs[key] = value
def update_field_metadata(self, group_name: str, field_name: str, metadata: dict):
"""Update the metadata attributes for a specific field in a particle group.
This method allows you to modify or add metadata attributes for a given field in a particle group.
It updates the attributes in memory and writes them back to the HDF5 file.
Parameters
----------
group_name : str
The name of the particle group containing the field.
field_name : str
The name of the field whose metadata is to be updated.
metadata : dict
A dictionary containing the metadata attributes to update or add.
Raises
------
KeyError
If the specified group or field does not exist in the dataset.
"""
dataset_handle = self.get_particle_field_handle(group_name, field_name)
# Serialize the metadata dictionary
serialized_meta = self.metadata_serializer.serialize_dict(metadata)
# Now write.
for key, value in serialized_meta.items():
dataset_handle.attrs[key] = value
def delete_global_metadata_keys(self, *keys: str):
"""
Delete one or more metadata keys from the global metadata.
Parameters
----------
*keys : str
The names of the metadata keys to delete from the global metadata.
Notes
-----
This method will not permit you to remove a required global metadata key.
"""
for key in keys:
if key in self.__REQUIRED_GLOBAL_METADATA__:
raise ValueError("Cannot delete required global metadata key: " + key)
if key in self.__handle__.attrs:
del self.__handle__.attrs[key]
self.reload_global_metadata()
def delete_group_metadata_keys(self, group_name: str, *keys: str):
"""Delete one or more metadata keys from a specific particle group.
Parameters
----------
group_name : str
The name of the particle group from which to delete metadata keys.
*keys : str
The names of the metadata keys to delete from the group metadata.
Notes
-----
This method will not permit you to remove a required group metadata key.
"""
group_handle = self.get_particle_group_handle(group_name)
for key in keys:
if key in self.__REQUIRED_GROUP_METADATA__:
raise ValueError("Cannot delete required group metadata key: " + key)
if key in group_handle.attrs:
del group_handle.attrs[key]
def delete_field_metadata_keys(self, group_name: str, field_name: str, *keys: str):
"""Delete one or more metadata keys from a specific field in a particle group.
Parameters
----------
group_name : str
The name of the particle group containing the field.
field_name : str
The name of the field from which to delete metadata keys.
*keys : str
The names of the metadata keys to delete from the field metadata.
Notes
-----
You may not remove ``"UNITS"`` from the field metadata, as this is required.
"""
field_handle = self.get_particle_field_handle(group_name, field_name)
for key in keys:
if key == "UNITS":
raise ValueError("Cannot delete required field metadata key: 'UNITS'")
if key in field_handle.attrs:
del field_handle.attrs[key]
# ------------------------------------ #
# Data Access Methods #
# ------------------------------------ #
# --- HDF5 Group and Dataset Accessors --- #
def get_particle_group_handle(self, group_name: str) -> h5py.Group:
"""Get the HDF5 group handle for a specific particle group.
Parameters
----------
group_name : str
The name of the particle group to retrieve.
Returns
-------
h5py.Group
The HDF5 group handle for the specified particle group.
Raises
------
KeyError
If the specified group does not exist in the dataset.
"""
if group_name not in self.particle_groups:
raise KeyError(f"Particle group '{group_name}' does not exist in the dataset.")
return self.__handle__[group_name]
def get_particle_field_handle(self, group_name: str, field_name: str):
"""Get the HDF5 dataset handle for a specific field in a particle group.
Parameters
----------
group_name : str
The name of the particle group containing the field.
field_name : str
The name of the field to retrieve.
Returns
-------
h5py.Dataset
The HDF5 dataset handle for the specified field.
Raises
------
KeyError
If the specified group or field does not exist in the dataset.
"""
group_handle = self.get_particle_group_handle(group_name)
if field_name not in group_handle:
raise KeyError(f"Field '{field_name}' does not exist in group '{group_name}'.")
return group_handle[field_name]
# --- Field Accessors --- #
def get_particle_field(self, group_name: str, field_name: str) -> unyt.unyt_array:
"""Get the particle field data as a unyt array.
Parameters
----------
group_name : str
The name of the particle group containing the field.
field_name : str
The name of the field to retrieve.
Returns
-------
unyt.array.unyt_array
The data for the specified field, converted to a unyt array.
Raises
------
KeyError
If the specified group or field does not exist in the dataset.
"""
dataset_handle = self.get_particle_field_handle(group_name, field_name)
return unyt.unyt_array(dataset_handle[...], units=self.get_field_units(group_name, field_name))
def get_particle_fields(self, fields: list[str]) -> dict[str, unyt.unyt_array]:
"""Get multiple particle fields as a dictionary of unyt arrays.
Parameters
----------
fields : List[str]
A list of field names in the format ``"group_name.field_name"``.
Returns
-------
dict
A dictionary mapping field names to their data as unyt arrays.
Raises
------
KeyError
If any specified field does not exist in the dataset.
"""
field_data = {}
for field in fields:
group_name, field_name = field.split(".")
field_data[field] = self.get_particle_field(group_name, field_name)
return field_data
def get_field_units(self, group_name: str, field_name: str) -> unyt.Unit:
"""Get the units of a specific particle field.
Parameters
----------
group_name : str
The name of the particle group containing the field.
field_name : str
The name of the field whose units are to be retrieved.
Returns
-------
~unyt.unit_object.Unit
The units of the specified field as a unyt Unit object.
"""
dataset_handle = self.get_particle_field_handle(group_name, field_name)
return unyt.Unit(self._deserialized(dataset_handle.attrs.get("UNITS", "")))
# ------------------------------------ #
# Modification Methods #
# ------------------------------------ #
def copy(self, output_path: str | Path, overwrite: bool = False, **kwargs) -> "ParticleDataset":
"""Create a full copy of this particle dataset at a new location.
This method replicates the entire contents of the HDF5 file, including all
particle groups, fields, field metadata, and global attributes, into a new file.
It returns a new :class:`ParticleDataset` instance pointing to the copied file.
Parameters
----------
output_path : str or Path
The path to the new HDF5 file to create.
overwrite : bool, optional
If True, overwrite the file at `output_path` if it already exists. Defaults to False.
**kwargs
Additional keyword arguments passed to the constructor of the copied dataset.
Returns
-------
ParticleDataset
A new dataset instance pointing to the copied file.
Raises
------
FileExistsError
If `output_path` exists and `overwrite` is False.
IsADirectoryError
If `output_path` is a directory.
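Examples
--------
A minimal sketch; the output path is illustrative:
>>> ds_backup = ds.copy("backups/particles_copy.h5", overwrite=True)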
"""
output_path = Path(output_path)
if output_path.exists():
if output_path.is_dir():
raise IsADirectoryError(f"Cannot copy dataset to a directory: {output_path}")
if not overwrite:
raise FileExistsError(f"File already exists at {output_path}. Use overwrite=True to overwrite.")
output_path.unlink()
output_path.parent.mkdir(parents=True, exist_ok=True)
# Open the output file and copy all content, including
# the root-level attributes (global metadata).
with h5py.File(output_path, "w") as f_out:
for key, value in self.handle.attrs.items():
f_out.attrs[key] = value
for name in self.handle.keys():
self.handle.copy(source=name, dest=f_out, name=name)
return self.__class__(output_path, **kwargs)
def add_particle_type(self, name: str, num_particles: int, metadata: dict = None, **kwargs):
"""Add a new particle group to the dataset.
This method creates a new HDF5 group representing a particle type (e.g., ``"baryons"``, ``"dark_matter"``)
and assigns the required metadata attribute ``NUMBER_OF_PARTICLES`` to the group. Optional metadata
can be added via the `metadata` dictionary or keyword arguments. This method is intended to be general
and extensible for use in simulation initialization, preprocessing, or structured data generation.
It is safe to override in subclasses that implement additional constraints or need to attach simulation-specific
annotations, provenance, or physical properties to each group.
Parameters
----------
name : str
The name of the new particle group. This must be unique within the dataset and conform to
HDF5 naming rules (alphanumeric, no slashes).
num_particles : int
The number of particles in the new group. This value is stored in the group's metadata
under the ``NUMBER_OF_PARTICLES`` key and is required for downstream field shape validation.
metadata : dict, optional
A dictionary of metadata attributes to attach to the group. Keys must be strings and values
must be serializable by HDF5 (e.g., int, float, str). These attributes will be written in
addition to ``NUMBER_OF_PARTICLES``.
**kwargs
Additional metadata attributes provided as keyword arguments. These are merged with `metadata`
and override any overlapping keys. Use this to quickly attach single attributes.
Raises
------
ValueError
If a group with the specified name already exists in the dataset.
Notes
-----
- All metadata is stored as HDF5 attributes on the group object.
- This method does **not** allocate any fields or datasets; it only creates the group and metadata.
- Subclasses may override this method to add simulation-specific metadata keys or validation logic.
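Examples
--------
A minimal sketch; the group names and metadata keys are illustrative:
>>> ds.add_particle_type("baryons", num_particles=1000, metadata={"SOFTENING_KPC": 1.0})
>>> ds.add_particle_type("dark_matter", 5000, GENERATOR="my_ic_code")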
"""
if name in self.particle_groups:
raise ValueError(f"Particle group '{name}' already exists.")
group = self.__handle__.create_group(name)
group.attrs["NUMBER_OF_PARTICLES"] = self._serialized(num_particles)
# Merge metadata from both dict and kwargs, prioritizing kwargs
metadata = metadata or {}
merged = {**metadata, **kwargs}
for key, value in merged.items():
group.attrs[key] = value
def add_particle_field(
self,
group_name: str,
field_name: str,
data: unyt.unyt_array | np.ndarray,
metadata: dict = None,
overwrite: bool = False,
):
"""Add a new field (dataset) to an existing particle group.
This method creates and writes a new field to the specified particle group
within the dataset. The data can be a :class:`numpy.ndarray` or a :class:`~unyt.array.unyt_array`.
Units (if present) will be stored in the dataset metadata. Additional metadata
may also be included via the `metadata` argument.
Parameters
----------
group_name : str
The name of the particle group to which the field will be added. This must
be a pre-existing group in the particle dataset.
field_name : str
The name of the new field to create (e.g., ``"particle_mass"``). If the `field_name` is
already specified in the group, it will raise an error unless `overwrite` is True.
data : ~unyt.array.unyt_array or ~numpy.ndarray
The data to write to the new field. The leading dimension must match the number of particles
in the group.
metadata : dict, optional
Optional dictionary of metadata to attach to the dataset. This is generally used
by subclasses of the base class to add type specific metadata.
overwrite : bool, optional
Whether to overwrite an existing dataset with the same name. Defaults to False.
Raises
------
KeyError
If the group does not exist.
ValueError
If the field already exists and `overwrite` is False, or if the field's
leading dimension does not match the group's particle count.
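Examples
--------
A minimal sketch for a group holding 1000 particles; ``ds`` is an open, writable dataset:
>>> import numpy as np
>>> import unyt
>>> masses = unyt.unyt_array(np.full(1000, 1.0e10), "Msun")
>>> ds.add_particle_field("baryons", "particle_mass", masses)
>>> positions = np.zeros((1000, 3))  # plain ndarrays are stored with dimensionless units
>>> ds.add_particle_field("baryons", "particle_position", positions, overwrite=True)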
"""
# Ensure that the group exists and that the field name is valid /
# correctly handle the overwrite behavior.
group = self.get_particle_group_handle(group_name)
group_attrs = self.get_group_metadata(group_name)
if field_name in group.keys():
# The field already exists. Our behavior depends on the `overwrite` flag.
if not overwrite:
raise ValueError("Field already exists. Set `overwrite=True` to replace it.")
else:
del group[field_name]
# Determine the number of particles expected and
# ensure that the data matches this shape.
num_particles = group_attrs["NUMBER_OF_PARTICLES"]
data = np.atleast_1d(data)
if data.shape[0] != num_particles:
raise ValueError(
f"Number of particles in field {field_name} does not match number of particles in group {group_name}"
)
# Validation has been completed and we can therefore now
# proceed with writing the field to the group.
units = getattr(data, "units", "")
if isinstance(data, unyt.unyt_array):
dset = group.create_dataset(field_name, data=data.d, dtype=data.dtype)
else:
dset = group.create_dataset(field_name, data=data, dtype=data.dtype)
# Handle the metadata.
dset.attrs["UNITS"] = self.metadata_serializer.serialize_data(units)
if metadata is not None:
dset.attrs.update(self.metadata_serializer.serialize_dict(metadata))
def remove_particle_group(self, group_name: str):
"""Remove a particle group from the dataset.
This method deletes the specified particle group and all its associated fields
from the dataset. It is a destructive operation and cannot be undone.
Parameters
----------
group_name : str
The name of the particle group to remove.
Raises
------
KeyError
If the specified group does not exist in the dataset.
"""
if group_name not in self.particle_groups:
raise KeyError(f"Particle group '{group_name}' does not exist in the dataset.")
del self.__handle__[group_name]
def remove_particle_field(self, group_name: str, field_name: str):
"""Remove a specific field from a particle group.
This method deletes the specified field from the given particle group.
It is a destructive operation and cannot be undone.
Parameters
----------
group_name : str
The name of the particle group containing the field to remove.
field_name : str
The name of the field to remove.
Raises
------
KeyError
If the specified group or field does not exist in the dataset.
"""
group_handle = self.get_particle_group_handle(group_name)
if field_name not in group_handle:
raise KeyError(f"Field '{field_name}' does not exist in group '{group_name}'.")
del group_handle[field_name]
def extend_group(
self,
group_name: str,
num_particles: int,
fields: dict[str, np.ndarray | unyt.unyt_array] = None,
):
"""Extend a particle group by adding new particles and updating all fields.
This method appends `num_particles` new entries to each existing field in the group.
For fields provided in the `fields` dictionary, the new particle values are appended.
For fields not provided, the new values are filled with NaN (if supported).
The group's ``NUMBER_OF_PARTICLES`` attribute is updated accordingly.
Parameters
----------
group_name : str
The name of the group to extend.
num_particles : int
The number of new particles to append to the group.
fields : dict, optional
A dictionary mapping field names to new particle data arrays of shape
(num_particles, ...) to append. Any fields not specified here will be
extended with `NaN` fill values if their dtype supports it.
Raises
------
KeyError
If the specified group does not exist.
ValueError
If the new data for any field has incompatible shape.
TypeError
If an existing field cannot be filled with NaNs and no new data is provided.
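Examples
--------
A minimal sketch appending 100 particles; fields not supplied are padded with NaN:
>>> import numpy as np
>>> import unyt
>>> new_masses = unyt.unyt_array(np.full(100, 1.0e10), "Msun")
>>> ds.extend_group("baryons", 100, fields={"particle_mass": new_masses})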
"""
# Ensure access to the group and that the
# number of new particles is non zero.
group = self.get_particle_group_handle(group_name)
group_attrs = self.get_group_metadata(group_name)
if not isinstance(num_particles, (int, np.integer)) or num_particles <= 0:
raise ValueError("`num_particles` must be a positive integer.")
# Create the field dictionary and
# modify the group attribute.
fields = fields or {}
old_particle_count = group_attrs["NUMBER_OF_PARTICLES"]
new_particle_count = old_particle_count + num_particles
# Make corrections to the fields.
for field_name in group.keys():
# In order to correct the fields, the first step is
# to extract the previously existing data and its
# relevant metadata since we'll need to delete it to
# continue.
_original_field_array = group[field_name][...]
_metadata = dict(group[field_name].attrs)
# Start by extending the dataset to accommodate the new particles.
# this is always performed the same way regardless whether we
# have the field data or not.
_new_field_array = np.zeros((new_particle_count,) + _original_field_array.shape[1:])
_new_field_array[:old_particle_count, ...] = _original_field_array
if field_name in fields:
# Coerce the field data to an unyt array so that
# we have an easier time manipulating the units.
# This will assign dimensionless units to empty arrays.
field_data = unyt.unyt_array(fields[field_name])
# If we have a field that we are going to insert, we need to
# check the shape and handle the units to ensure that everything
# behaves correctly.
if field_data.shape != (num_particles,) + _original_field_array.shape[1:]:
raise ValueError(
f"Field '{field_name}' data must have shape "
f"({num_particles}, ...) to match existing field shape."
)
try:
field_data = field_data.to_value(self.get_field_units(group_name, field_name))
except Exception as exp:
raise TypeError(f"Cannot convert field '{field_name}' data to existing units: {exp}") from exp
# Fill the remaining data with the correct values.
_new_field_array[old_particle_count:, ...] = field_data
else:
_new_field_array[old_particle_count:, ...] = np.nan
# Now that the _new_field_array is filled, we need to
# delete and replace the existing dataset with the new one.
del group[field_name]
dset = group.create_dataset(field_name, data=_new_field_array, dtype=_new_field_array.dtype)
dset.attrs["UNITS"] = _metadata.get("UNITS", "")
for key, value in _metadata.items():
if key != "UNITS":
dset.attrs[key] = value
def concatenate_inplace(self, *others: "ParticleDataset", groups=None):
"""Concatenate another :class:`ParticleDataset` into this one, extending specified groups.
This method appends the particle data from `other` to this dataset for the specified groups.
If `groups` is None, all groups in `other` are concatenated. The number of particles in each
group is updated accordingly.
Parameters
----------
*others : list of ParticleDataset
The datasets to concatenate into this one.
groups : list of str, optional
Names of groups to concatenate. If None, all groups are concatenated.
Raises
------
KeyError
If a specified group does not exist in either dataset.
ValueError
If the datasets have incompatible shapes for concatenation.
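Examples
--------
A minimal sketch merging one (illustrative) dataset into another:
>>> ds_a = ParticleDataset("part_a.h5")
>>> ds_b = ParticleDataset("part_b.h5", mode="r")
>>> ds_a.concatenate_inplace(ds_b, groups=["baryons"])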
"""
for other in others:
# Select the groups of `other` that we're going to add to
# our own groups. If `groups` is None, we use all of this
# dataset's groups (without clobbering the caller's argument).
other_groups = groups if groups is not None else other.particle_groups
# Iterate through all of the selected groups so
# that we can concatenate each of them.
for group in other_groups:
# Check if the group is already present in the new dataset
# or if it needs to be added.
if group in self.particle_groups:
# This group is already present. We need to
# concatenate the data. This will be a little bit
# trickier than the missing group case.
# Extract all the fields from the old group and
# begin the procedure of extending the group.
group_fields = [k.split(".", 1)[1] for k in other.fields if k.startswith(group + ".")]
old_fields = {field: other.get_particle_field(group, field) for field in group_fields}
self.extend_group(group, other.num_particles[group], fields=old_fields)
else:
# We don't already have the group so we can just
# copy the group directly across.
source_group = other.handle[group]
self.handle.copy(source=source_group, dest=self.handle, name=group)
def reduce_group(
self,
group_name: str,
mask: np.ndarray | unyt.unyt_array,
):
"""Reduce a particle group by applying a boolean mask.
This method filters all fields in the specified group by the given mask,
retaining only those particles where the mask is `True`. All other particles
are discarded. The group's ``NUMBER_OF_PARTICLES`` attribute is updated accordingly.
This is a destructive operation.
Parameters
----------
group_name : str
Name of the particle group to apply the mask to.
mask : array_like of bool
Boolean array of shape (N,) where N is the number of particles in the group.
Must be 1D and have exactly one element per particle.
Raises
------
KeyError
If the specified group does not exist.
ValueError
If the mask has an incorrect shape or is not boolean.
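Examples
--------
A minimal sketch keeping only particles above an (illustrative) mass cut:
>>> import numpy as np
>>> masses = ds.get_particle_field("baryons", "particle_mass")
>>> mask = np.asarray(masses.to_value("Msun") > 1.0e9)
>>> ds.reduce_group("baryons", mask)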
"""
group = self.get_particle_group_handle(group_name)
group_attrs = self.get_group_metadata(group_name)
n = group_attrs["NUMBER_OF_PARTICLES"]
mask = np.asarray(mask)
if mask.shape != (n,) or mask.dtype != bool:
raise ValueError(f"Mask must be a 1D boolean array of shape ({n},)")
new_count = int(np.count_nonzero(mask))
for field_name in list(group.keys()):
old_data = group[field_name][...]
new_data = old_data[mask]
# Preserve metadata and overwrite dataset
metadata = dict(group[field_name].attrs)
del group[field_name]
dset = group.create_dataset(field_name, data=new_data, dtype=new_data.dtype)
for key, value in metadata.items():
dset.attrs[key] = value
group.attrs["NUMBER_OF_PARTICLES"] = self.metadata_serializer.serialize_data(new_count)
def rename_field(self, group_name: str, old_name: str, new_name: str):
"""Rename a field within a particle group.
This method renames the dataset (field) `old_name` to `new_name` in the specified group.
Metadata is preserved during the renaming. This operation is destructive and cannot
be undone.
Parameters
----------
group_name : str
The name of the particle group containing the field.
old_name : str
The current name of the field to rename.
new_name : str
The new name to assign to the field.
Raises
------
KeyError
If the specified group or field does not exist.
ValueError
If the new field name already exists in the group.
"""
group = self.get_particle_group_handle(group_name)
if old_name not in group:
raise KeyError(f"Field '{old_name}' does not exist in group '{group_name}'.")
if new_name in group:
raise ValueError(f"Field '{new_name}' already exists in group '{group_name}'.")
# Extract existing data and metadata
data = group[old_name][...]
metadata = dict(group[old_name].attrs)
# Create the new dataset with the same data and metadata
dset = group.create_dataset(new_name, data=data, dtype=data.dtype)
for key, value in metadata.items():
dset.attrs[key] = value
# Remove the old field
del group[old_name]
def offset_particle_positions(self, offset: unyt.unyt_array, groups: list[str] = None):
"""Apply a constant offset to particle positions in specified groups.
The method adds the given offset vector to the particle position field.
This is the correct way to shift particle coordinates around via translation.
.. note::
The name of the particle position field is assumed to be that specified by
the class's ``_POSITION_FIELD_NAME`` attribute. If your dataset does not have
this field, you will need to manually apply the offset to the appropriate field
or rename the field.
Parameters
----------
offset : unyt.array.unyt_array
A vector specifying the offset to apply. Must have units compatible
with the ``particle_position`` field(s). The `offset` may be any 1D array; however,
it must match the shape of the particle positions. Thus, if the particles are in 3D space,
the `offset` must be a 3-element vector.
groups : list of str, optional
Names of groups to apply the offset to. If None, all particle groups are used.
Raises
------
ValueError
If the length of `offset` does not match the dimensionality of the position field.
Notes
-----
To ensure that this method functions properly across subclasses with various naming
conventions, we require that the position field be named according to the class
attribute ``_POSITION_FIELD_NAME``. Subclasses may change this attribute to match
a particular naming convention.
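Examples
--------
A minimal sketch shifting all particles by 5 Mpc along the x-axis (3D positions assumed):
>>> import unyt
>>> ds.offset_particle_positions(unyt.unyt_array([5.0, 0.0, 0.0], "Mpc"))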
"""
# Ensure that the offset gets cast to an unyt array so
# that it at least has unit attributes. We will check for
# unit consistency later.
offset = unyt.unyt_array(offset)
# Handle the groups.
if groups is None:
groups = self.particle_groups
# Now for each of the groups, we're going to
# cycle through, apply the offset, and continue.
# If we run into a shape issue, we raise an error.
for group in groups:
field_key = f"{group}.{self.__class__._POSITION_FIELD_NAME}"
if field_key not in self:
continue
# Obtain the handle and the units.
handle = self.get_particle_field_handle(group, self.__class__._POSITION_FIELD_NAME)
units = self.get_field_units(group, self.__class__._POSITION_FIELD_NAME)
# Check the shape.
if handle.shape[-1] != len(offset):
raise ValueError(f"Offset must match the shape of particle positions in group '{group}'.")
# Apply the offset.
handle[...] += offset.to_value(units)
def offset_particle_velocities(self, offset: unyt.unyt_array, groups: list[str] = None):
"""Apply a constant offset to particle velocities in specified groups.
The method adds the given offset vector to the particle velocity field.
This is the correct way to apply a uniform velocity shift (boost) to the particles.
.. note::
The name of the particle velocity field is assumed to be that specified by
the class's ``_VELOCITY_FIELD_NAME`` attribute. If your dataset does not have
this field, you will need to manually apply the offset to the appropriate field
or rename the field.
Parameters
----------
offset : unyt.array.unyt_array
A vector specifying the velocity offset to apply. Must have units compatible
with the ``particle_velocity`` field(s). The `offset` may be any 1D array; however,
it must match the shape of the particle velocities. Thus, if the particles are in 3D space,
the `offset` must be a 3-element vector.
groups : list of str, optional
Names of groups to apply the offset to. If None, all particle groups are used.
Raises
------
ValueError
If `offset` is not the correct shape.
"""
# Ensure the offset has unit information.
offset = unyt.unyt_array(offset)
# Determine the list of groups to modify.
if groups is None:
groups = self.particle_groups
for group in groups:
field_key = f"{group}.{self.__class__._VELOCITY_FIELD_NAME}"
if field_key not in self:
continue
# Obtain the handle and the units.
handle = self.get_particle_field_handle(group, self.__class__._VELOCITY_FIELD_NAME)
units = self.get_field_units(group, self.__class__._VELOCITY_FIELD_NAME)
# Confirm shape match.
if handle.shape[-1] != len(offset):
raise ValueError(f"Offset must match the shape of particle velocities in group '{group}'.")
# Apply the velocity offset.
handle[...] += offset.to_value(units)
def apply_linear_transformation(
self,
matrix: np.ndarray,
groups: list[str] = None,
fields: tuple[str, ...] = None,
):
r"""Apply a linear transformation matrix to vector fields in specified particle groups.
This method performs an in-place matrix transformation on each specified vector field
(e.g., ``particle_position``, ``particle_velocity``) in one or more particle groups. It is
useful for performing operations such as coordinate rotation, scaling, reflection, or shear.
Each particle's vector field :math:`\mathbf{x}_i` is updated according to:
.. math::
\mathbf{x}_i \rightarrow \mathbf{A} \cdot \mathbf{x}_i
where :math:`\mathbf{A}` is the transformation matrix and :math:`\mathbf{x}_i` is the
vector value (e.g., position or velocity) of the :math:`i`-th particle.
Parameters
----------
matrix : array_like
A 2D NumPy array of shape :math:`(D, D)` representing the linear transformation
to apply. The dimension :math:`D` must match the last axis of each target field.
groups : list of str, optional
List of particle group names to which the transformation will be applied. If None,
all groups in the dataset are used.
fields : tuple of str, optional
Tuple of field names (e.g., ``particle_position``, ``particle_velocity``) to transform.
Default is ``("particle_position", "particle_velocity")``.
Raises
------
ValueError
If `matrix` is not square or its shape does not match the vector dimensionality
of the fields being transformed.
KeyError
If a specified field is not present in the given group(s).
Notes
-----
- The transformation is performed in-place and modifies the original field values.
- Fields that are not present in a group are skipped silently.
- This operation assumes that vector fields are stored with shape :math:`(N, D)`, where
:math:`N` is the number of particles and :math:`D` is the number of spatial dimensions.
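Examples
--------
A minimal sketch applying a 90-degree rotation about the z-axis (3D vector fields assumed):
>>> import numpy as np
>>> theta = np.pi / 2
>>> R = np.array([[np.cos(theta), -np.sin(theta), 0.0],
...               [np.sin(theta),  np.cos(theta), 0.0],
...               [0.0,            0.0,           1.0]])
>>> ds.apply_linear_transformation(R, groups=["baryons"])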
"""
# Set the default fields to all of the standard vector fields we
# expect.
if fields is None:
fields = (self.__class__._POSITION_FIELD_NAME, self.__class__._VELOCITY_FIELD_NAME)
matrix = np.asarray(matrix)
if matrix.ndim != 2 or matrix.shape[0] != matrix.shape[1]:
raise ValueError(f"Transformation matrix must be square (D, D), got shape {matrix.shape}.")
if groups is None:
groups = self.particle_groups
for group in groups:
for field in fields:
field_key = f"{group}.{field}"
if field_key not in self:
continue # Skip if the field is not present
handle = self.get_particle_field_handle(group, field)
if handle.shape[-1] != matrix.shape[0]:
raise ValueError(
f"Field '{field}' in group '{group}' has vector dimension {handle.shape[-1]}, "
f"which does not match transformation matrix shape {matrix.shape}."
)
# Apply transformation using Einstein summation (broadcast-safe)
transformed = np.einsum("ij,nj->ni", matrix, handle[...])
handle[...] = transformed
def rotate_particles(
self,
norm: np.ndarray,
angle: float,
groups: list[str] = None,
fields: tuple[str, ...] = None,
):
r"""Rotate vector fields in specified particle groups around a given axis.
This method rotates each specified vector field (e.g., ``particle_position``, ``particle_velocity``)
around the axis defined by `norm` by a given angle. The rotation is applied uniformly across
all particles in the specified groups.
The transformation uses the **Rodrigues' rotation formula**, which constructs a rotation matrix
for an axis–angle pair. For a unit vector :math:`\hat{n}` and angle :math:`\theta`, the formula is:
.. math::
R = I + \sin\theta [\hat{n}]_\times + (1 - \cos\theta) [\hat{n}]_\times^2
where :math:`[\hat{n}]_\times` is the skew-symmetric matrix of the axis vector.
For more details, see: `Rodrigues Formula <https://en.wikipedia.org/wiki/Rodrigues%27_rotation_formula>`__.
Parameters
----------
norm : array_like
A 3-element vector representing the rotation axis. This does not need to be normalized;
it will be internally converted to a unit vector.
angle : float
The rotation angle in radians.
groups : list of str, optional
The particle groups to apply the rotation to. If None, all particle groups are used.
fields : tuple of str, optional
The vector fields to rotate. Defaults to ("particle_position", "particle_velocity").
Raises
------
ValueError
If the axis is not a 3-element vector, or if the angle is invalid.
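Examples
--------
A minimal sketch rotating positions and velocities by 45 degrees about the z-axis:
>>> import numpy as np
>>> ds.rotate_particles(norm=[0.0, 0.0, 1.0], angle=np.pi / 4)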
"""
# Set the default fields to all of the standard vector fields we
# expect.
if fields is None:
fields = (self.__class__._POSITION_FIELD_NAME, self.__class__._VELOCITY_FIELD_NAME)
# Normalize axis
axis = np.asarray(norm, dtype=float)
if axis.shape != (3,):
raise ValueError("Rotation axis must be a 3-element vector.")
axis /= np.linalg.norm(axis)
# Compute Rodrigues rotation matrix
x, y, z = axis
K = np.array([[0, -z, y], [z, 0, -x], [-y, x, 0]])
R = np.eye(3) + np.sin(angle) * K + (1 - np.cos(angle)) * (K @ K)
self.apply_linear_transformation(R, groups=groups, fields=fields)
def reorient_particles(
self,
direction_vector: np.ndarray,
spin: float,
groups: list[str] = None,
fields: tuple[str, ...] = None,
):
"""
Reorient particles given a direction vector and spin angle.
This method aligns the +z axis with the specified `direction_vector` and then applies
a spin rotation about the new axis. The transformation is applied to the specified
vector fields (e.g., ``particle_position``, ``particle_velocity``) in the specified particle groups.
The transformation uses a two-step process:
1. **Alignment**: Compute the minimal rotation that aligns the +z axis with the
normalized `direction_vector`.
2. **Spin**: Apply a rotation by `spin` radians about the new z-axis (aligned with `direction_vector`).
The final rotation matrix is the composition of the alignment and spin rotations.
Parameters
----------
direction_vector : array_like
Target direction to become the new +z axis. Need not be unit, but must be nonzero.
spin : float
Spin angle (in radians) applied about the new axis after alignment.
groups : list of str, optional
Particle groups to transform. If None, all groups are used.
fields : tuple of str, optional
Vector fields to transform. Defaults to ``("particle_position", "particle_velocity")``.
Returns
-------
numpy.ndarray
The 3x3 rotation matrix applied: ``R = R_spin @ R_align``.
"""
# Set the default fields to all of the standard vector fields we
# expect.
if fields is None:
fields = (self.__class__._POSITION_FIELD_NAME, self.__class__._VELOCITY_FIELD_NAME)
# Validate and normalize the direction vector for
# the target axis.
direction_vector = np.asarray(direction_vector, dtype=float)
if direction_vector.ndim != 1 or direction_vector.shape[0] != 3:
raise ValueError("Direction vector must be a 3-element vector.")
dv_norm = np.linalg.norm(direction_vector)
if dv_norm <= 1e-8:
raise ValueError("Direction vector must be nonzero.")
_dv = direction_vector / dv_norm # unit target axis
# Construct the necessary data to construct the
# relevant rotation matrices.
z_axis = np.array([0.0, 0.0, 1.0])
cross = np.cross(z_axis, _dv)
sin_phi = np.linalg.norm(cross) # = sin(tilt)
cos_phi = float(np.dot(z_axis, _dv)) # = cos(tilt)
if sin_phi < 1e-12 and cos_phi > 0.0:
# Already aligned
R_align = np.eye(3)
elif sin_phi < 1e-12 and cos_phi < 0.0:
# Exactly opposite: rotate pi about any axis ⟂ z (choose x)
Kx = np.array([[0, 0, 0], [0, 0, -1], [0, 1, 0]]) # skew([1,0,0])
R_align = np.eye(3) + np.sin(np.pi) * Kx + (1 - np.cos(np.pi)) * (Kx @ Kx)
else:
# General case: axis = unit(cross), angle = atan2(sin_phi, cos_phi)
rot_axis = cross / sin_phi
x, y, z = rot_axis
K = np.array([[0, -z, y], [z, 0, -x], [-y, x, 0]])
# Use sin/cos of angle directly via sin_phi/cos_phi (avoids recomputing trig)
R_align = np.eye(3) + sin_phi * K + (1 - cos_phi) * (K @ K)
# After alignment, the new z-axis is exactly _dv.
# Step 2: Spin by `spin` about _dv (which is world-space axis after R_align)
x, y, z = _dv
Kspin = np.array([[0, -z, y], [z, 0, -x], [-y, x, 0]])
s = np.sin(spin)
c = np.cos(spin)
R_spin = np.eye(3) + s * Kspin + (1 - c) * (Kspin @ Kspin)
# Compose to "align then spin" (active, column-vector convention)
R = R_spin @ R_align
# Apply to requested particle groups/fields
self.apply_linear_transformation(R, groups=groups, fields=fields)
return R
def cut_particles_to_bbox(
self,
bbox: unyt.unyt_array,
groups: list[str] = None,
center: unyt.unyt_array = None,
):
"""
Cut out particles outside the specified bounding box.
This method takes a **cartesian** bounding box defined by `bbox` and
checks the positions of all particles in all groups. Particles that lie
outside of the bounding box are removed from the dataset.
Parameters
----------
bbox: ~unyt.array.unyt_array
A 2D unyt array of shape (2, D) defining the bounding box in D dimensions.
The first row specifies the minimum corner, and the second row specifies
the maximum corner. Units must be compatible with particle positions.
groups: list of str, optional
List of particle group names to apply the cut to. If None, all groups are used.
center: ~unyt.array.unyt_array, optional
An optional center point to offset the bounding box. If provided, the bounding
box is shifted by this center before applying the cut. Units must be compatible
with particle positions.
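Examples
--------
A minimal sketch keeping only particles inside a 10 Mpc cube centered on the origin:
>>> import unyt
>>> bbox = unyt.unyt_array([[-5.0, -5.0, -5.0], [5.0, 5.0, 5.0]], "Mpc")
>>> ds.cut_particles_to_bbox(bbox)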
"""
# Validate the particle positions field so that
# we can uniformly access it as a dictionary.
groups = groups if groups is not None else self.particle_groups
# manage center if it needs to be managed.
if center is None:
center = unyt.unyt_array([0.0] * bbox.shape[1], bbox.units)
else:
center = unyt.unyt_array(center)
for particle_type in self.particle_groups:
# Check if we are processing this group.
if particle_type not in groups:
continue
# Extract the position array for this particle type
# so that we can determine the dimension and eventually
# obtain the mask.
position_field_handle = self.get_particle_field_handle(particle_type, self._POSITION_FIELD_NAME)
position_field_units = self.get_field_units(particle_type, self._POSITION_FIELD_NAME)
ndim = position_field_handle.shape[-1]
# Check that the number of dimensions is compatible with the
# bounding box.
if bbox.shape != (2, ndim):
raise ValueError(
f"Bounding box shape {bbox.shape} is incompatible with "
f"particle positions of dimension {ndim} in group '{particle_type}'."
)
if center.shape != (ndim,):
raise ValueError(
f"Center shape {center.shape} is incompatible with "
f"particle positions of dimension {ndim} in group '{particle_type}'."
)
# We now want to make a local copy of the bounding box in the correct
# units for the position field.
_bbox_unitless = bbox.to_value(position_field_units)
_center_unitless = center.to_value(position_field_units)
# We now generate the mask against the bounding box.
mask = np.all(
[
(_bbox_unitless[0, _k] <= position_field_handle[:, _k] - _center_unitless[_k])
& (_bbox_unitless[1, _k] >= position_field_handle[:, _k] - _center_unitless[_k])
for _k in range(ndim)
],
axis=0,
)
# With the mask, we can now reduce the particle group accordingly.
self.reduce_group(particle_type, mask)
def add_particle_ids(self, groups: list[str] = None, policy: str = None, overwrite: bool = False, **kwargs):
"""
Add unique particle IDs to specified groups.
Parameters
----------
groups: list of str, optional
List of particle group names to add particle IDs to. If None, all groups are used.
This can be used both to specify the groups that should be given an ID field and
(for ``policy='global'``) to specify the order in which the IDs are assigned to each
group.
policy : {"global", "per_group"}, optional
The policy for assigning particle IDs. Options are:
- "global": Assign unique IDs across all specified groups, ensuring no duplicates.
IDs are assigned sequentially starting from 1, in the order of groups provided.
- "per_group": Assign unique IDs within each group, starting from 1 for each group.
This allows for duplicate IDs across different groups.
Default is "global".
overwrite: bool, optional
Whether to overwrite existing ID fields if they already exist. Defaults to False.
kwargs:
Additional keyword arguments passed to the ID generation function. The following
are recognized:
- ``start_id``: int, optional
The starting ID number for the first group (default is 1).
- ``dtype``: numpy dtype, optional
The numpy dtype to use for the ID field (default is np.uint32).
Notes
-----
Behind the scenes, this method will do two things:
1. Create a "particle id" group (named following the class's ``_ID_FIELD_NAME`` attribute) which contains
an ordered list of all of the particle IDs for that particle type. Depending on the policy,
these ids will either start from 1 for each group or will be globally unique across all groups.
2. Each group will get a ``PIDOFF`` attribute, which indicates the PID of the very first particle in
that group.
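Examples
--------
A minimal sketch; ``ds`` and the group names are illustrative placeholders:
>>> ds.add_particle_ids(groups=["dark_matter", "gas"], policy="global")  # doctest: +SKIP
>>> ds.get_group_metadata("gas")["PIDOFF"]  # doctest: +SKIP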
"""
# Validate the input information, set up the groups, pull the ID
# policy, and the starting ID.
groups = groups if groups is not None else self.particle_groups
if any(grp not in self.particle_groups for grp in groups):
raise ValueError("All specified groups must exist in the dataset.")
policy = policy if policy is not None else self.__class__._ID_POLICY
if policy not in ("global", "per_group"):
raise ValueError("Policy must be either 'global' or 'per_group'.")
start_id = kwargs.get("start_id", 1)
dtype = kwargs.get("dtype", np.uint32)
# Now for each group in the list of groups, we go through and add the
# relevant field. If we encounter an existing field and overwrite is False,
# we raise an error.
_offset = start_id
for group in groups:
# Extract the number of particles in this group
num_particles = self.get_group_metadata(group)["NUMBER_OF_PARTICLES"]
# Generate the IDs based on the policy.
_id_array = np.arange(_offset, _offset + num_particles, dtype=dtype)
self.add_particle_field(
group_name=group,
field_name=self.__class__._ID_FIELD_NAME,
data=unyt.unyt_array(_id_array, ""),
overwrite=overwrite,
)
self.update_group_metadata(group, {"PIDOFF": _offset})
if policy == "global":
_offset += num_particles
else:
pass
def update_particle_ids(self, groups: list[str] = None, policy: str = None, **kwargs):
"""
Update (reset) particle IDs for groups that already have an ID field.
This method goes through each of the specified groups and checks for the existence
of a particle ID field. If the field exists, it is overwritten with a new set of IDs
according to the specified policy. If the field does not exist, the group is skipped
silently.
For ``policy='global'``, IDs are assigned sequentially across all specified groups. As such,
even if a group does not have an existing ID field, its total number of particles is still
counted toward the global ID assignment.
Parameters
----------
groups: list of str, optional
List of particle group names whose particle IDs should be updated. If None, all groups are used.
This can be used both to select which groups are updated and (for ``policy='global'``)
to specify the order in which IDs are assigned across groups.
policy : {"global", "per_group"}, optional
The policy for assigning particle IDs. Options are:
- "global": Assign unique IDs across all specified groups, ensuring no duplicates.
IDs are assigned sequentially starting from 1, in the order of groups provided.
- "per_group": Assign unique IDs within each group, starting from 1 for each group.
This allows for duplicate IDs across different groups.
Default is "global".
kwargs:
Additional keyword arguments passed to the ID generation function. The following
are recognized:
- ``start_id``: int, optional
The starting ID number for the first group (default is 1).
- ``dtype``: numpy dtype, optional
The numpy dtype to use for the ID field (default is np.uint32).
Notes
-----
- Groups without an existing ID field are skipped silently.
- This operation **overwrites** the existing ID field for selected groups.
- By default (``start_id=1``), IDs are assigned as **1-based** positive integers, consistent with
Gadget and most downstream analysis tools.
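Examples
--------
A minimal sketch; ``ds`` is an illustrative placeholder for an open dataset:
>>> # Re-number IDs globally across two groups, starting from 1:
>>> ds.update_particle_ids(groups=["dark_matter", "gas"], policy="global", start_id=1)  # doctest: +SKIP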
"""
# Validate the input information, set up the groups, pull the ID
# policy, and the starting ID.
groups = groups if groups is not None else self.particle_groups
if any(grp not in self.particle_groups for grp in groups):
raise ValueError("All specified groups must exist in the dataset.")
policy = policy if policy is not None else self.__class__._ID_POLICY
if policy not in ("global", "per_group"):
raise ValueError("Policy must be either 'global' or 'per_group'.")
start_id = kwargs.get("start_id", 1)
dtype = kwargs.get("dtype", np.uint32)
# Track running offset for global assignment
offset = start_id
id_field = self.__class__._ID_FIELD_NAME
for group in groups:
n = int(self.get_group_metadata(group)["NUMBER_OF_PARTICLES"])
if f"{group}.{id_field}" in self:
# Group has an existing ID field → reset it
if policy == "global":
ids = np.arange(offset, offset + n, dtype=dtype)
self.update_group_metadata(group, {"PIDOFF": offset})
else: # per_group
ids = np.arange(start_id, start_id + n, dtype=dtype)
self.update_group_metadata(group, {"PIDOFF": start_id})
self.add_particle_field(
group_name=group,
field_name=id_field,
data=unyt.unyt_array(ids, ""), # dimensionless
overwrite=True,
)
# For global policy: always advance the offset,
# even if the group did not have an ID field.
if policy == "global":
offset += n
# ------------------------------------- #
# Generation Methods #
# ------------------------------------- #
@classmethod
def build_particle_dataset(
cls,
path: str | Path,
*args,
fields: dict[str, unyt.unyt_array] = None,
overwrite: bool = False,
**kwargs,
):
"""Create a new :class:`ParticleDataset` HDF5 file with the given fields.
This method initializes a new HDF5 file, organizes the provided fields into
groups based on their dot notation names (e.g., ``"baryons.particle_mass"``),
and writes each field to the appropriate group. Each group must have fields
with the same number of particles (i.e., matching leading dimension).
Global metadata (such as creation date) is also written.
This is the standard factory method for creating new Pisces-compatible particle datasets.
Parameters
----------
path : str or pathlib.Path
The target path for the new HDF5 file.
fields : dict of {str: unyt.array.unyt_array}, optional
A dictionary mapping dot-notation field names to unyt arrays.
If None or empty, a valid file with metadata but no particle data is created.
overwrite : bool, optional
Whether to overwrite the file if it already exists. Defaults to False.
*args, **kwargs:
Additional positional and keyword arguments passed to the dataset constructor.
Returns
-------
ParticleDataset
An instance of the newly created dataset.
Raises
------
FileExistsError
If the file exists and `overwrite` is False.
IsADirectoryError
If `path` is a directory.
ValueError
If any field name is not in dot notation or group fields mismatch in particle count.
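Examples
--------
A minimal sketch; the file name and field values are illustrative:
>>> import numpy as np
>>> import unyt
>>> fields = {
...     "gas.particle_position": unyt.unyt_array(np.random.rand(100, 3), "Mpc"),
...     "gas.particle_mass": unyt.unyt_array(np.ones(100), "Msun"),
... }
>>> ds = ParticleDataset.build_particle_dataset(
...     "example_particles.h5", fields=fields, overwrite=True
... )  # doctest: +SKIP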
"""
path = Path(path)
# --- Path validation ---
if path.is_dir():
raise IsADirectoryError(f"Path is a directory: {path}")
elif path.exists() and not overwrite:
raise FileExistsError(f"File already exists: {path}. Use overwrite=True to replace it.")
elif path.exists() and overwrite:
path.unlink()
path.parent.mkdir(parents=True, exist_ok=True)
# --- Create HDF5 file ---
with h5py.File(path, "w") as f:
# Add required global metadata
initial_metadata = {
"GEN_TIME": datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
"CLASS_NAME": cls.__name__,
}
for k, v in cls.metadata_serializer.serialize_dict(initial_metadata).items():
f.attrs[k] = v
if fields:
group_registry: dict[str, int] = {}
for full_field_name, data in fields.items():
# Validate field name format
try:
group_name, field_name = full_field_name.split(".")
except ValueError as err:
raise ValueError(
f"Invalid field name '{full_field_name}'. Expected format 'group.field'."
) from err
# Validate data type
if not isinstance(data, unyt.unyt_array):
raise TypeError(f"Field '{full_field_name}' must be a unyt_array.")
# Create or validate the group
if group_name not in group_registry:
group = f.create_group(group_name)
group_registry[group_name] = data.shape[0]
group.attrs["NUMBER_OF_PARTICLES"] = cls.metadata_serializer.serialize_data(data.shape[0])
else:
if data.shape[0] != group_registry[group_name]:
raise ValueError(
f"Inconsistent particle count for field '{full_field_name}'. "
f"Expected {group_registry[group_name]}, got {data.shape[0]}"
)
# Create dataset and write units
dset = f[group_name].create_dataset(field_name, data=data, dtype=data.dtype)
dset.attrs["UNITS"] = cls.metadata_serializer.serialize_data(data.units)
# Return a validated ParticleDataset instance
return cls(path, *args, **kwargs)