"""IO utilities for Pisces.
This module contains a number of helpful IO operations for Pisces which are used
frequently in various parts of the project.
"""
import json
from abc import ABC, abstractmethod
from collections.abc import Callable, Sequence
from pathlib import Path
from typing import TYPE_CHECKING, Any, Union
import numpy as np
import unyt
from ruamel.yaml.comments import CommentedMap
if TYPE_CHECKING:
from ruamel.yaml import YAML
from ruamel.yaml.constructor import Constructor
from ruamel.yaml.nodes import Node
from ruamel.yaml.representer import Representer
# ----------------------------------------------- #
# HDF5 File Management and Serialization Tools #
# ----------------------------------------------- #
# This module provides a JSON-based serialization utility for storing complex
# Python objects as HDF5 attributes. It supports custom types like `unyt` arrays,
# quantities, and units, allowing them to be serialized to JSON strings that can
# be safely stored as attributes in HDF5 files.
class HDF5Serializer:
    """
    A JSON-based serialization utility for storing complex Python objects as HDF5 attributes.

    This class provides a standardized mechanism for serializing and
    deserializing Python objects — including extended types like `unyt` arrays,
    quantities, and units — to and from JSON strings that can be safely stored
    as attributes in HDF5 files.

    Custom types are registered via the `__REGISTRY__` class variable, which maps
    Python types to a tuple of:

    - A unique string tag identifying the type in serialized form,
    - A serialization function (returning a JSON-serializable dict),
    - A deserialization function (taking that dict and reconstructing the object).

    Built-in JSON-compatible types (e.g., int, float, str, list, dict) are serialized
    directly via `json.dumps`. All deserialized outputs are reconstructed either
    via the registry or returned as-is for base types.

    This class is intended to be subclassed or extended to support additional
    custom types.
    """

    # ----------------------------------------------- #
    # Serialization Methods                           #
    # ----------------------------------------------- #
    @staticmethod
    def serialize_unyt(o: Union[unyt.unyt_array, unyt.unyt_quantity]) -> dict:
        """
        Serialize a unyt object into a dictionary.

        Parameters
        ----------
        o : unyt.array.unyt_array or unyt.array.unyt_quantity
            The unyt object to serialize.

        Returns
        -------
        dict
            A dictionary representation of the unyt object. The ``"value"``
            entry is JSON-native (a plain Python scalar for quantities, a
            nested list for arrays) so the payload can go straight to
            ``json.dumps``.

        Raises
        ------
        TypeError
            If ``o`` is neither a unyt array nor a unyt quantity.
        """
        # BUGFIX: ``unyt_quantity`` is a subclass of ``unyt_array``, so the
        # more specific type must be tested first; otherwise the quantity
        # branch is unreachable dead code.
        if isinstance(o, unyt.unyt_quantity):
            # ``.item()`` unwraps the numpy scalar into a JSON-native Python
            # scalar (a bare ``o.value`` would be a numpy scalar, which
            # ``json.dumps`` rejects).
            return {"value": o.value.item(), "units": str(o.units)}
        if isinstance(o, unyt.unyt_array):
            return {"value": o.value.tolist(), "units": str(o.units)}
        raise TypeError(
            f"Unsupported type for serialization: {type(o)}. Expected unyt.unyt_array or unyt.unyt_quantity."
        )

    @staticmethod
    def serialize_unyt_unit(o: unyt.Unit) -> dict:
        """
        Serialize a unyt unit into a dictionary.

        Parameters
        ----------
        o : unyt.Unit
            The unyt unit to serialize.

        Returns
        -------
        dict
            A dictionary representation of the unyt unit (its string form).
        """
        return {"value": str(o)}

    # ----------------------------------------------- #
    # Deserialization Methods                         #
    # ----------------------------------------------- #
    @staticmethod
    def deserialize_unyt_unit(o: dict) -> unyt.Unit:
        """
        Deserialize a dictionary into a unyt unit.

        Parameters
        ----------
        o : dict
            The dictionary to deserialize (must contain a ``"value"`` key).

        Returns
        -------
        unyt.Unit
            A unyt unit object.
        """
        return unyt.Unit(o["value"])

    @staticmethod
    def deserialize_unyt(o: dict) -> Union[unyt.unyt_array, unyt.unyt_quantity]:
        """
        Deserialize a dictionary into a unyt object.

        Parameters
        ----------
        o : dict
            The dictionary to deserialize (must contain ``"value"`` and
            ``"units"`` keys).

        Returns
        -------
        unyt.array.unyt_array or unyt.array.unyt_quantity
            A unyt object.

        Raises
        ------
        ValueError
            If the ``"units"`` or ``"value"`` key is missing.
        """
        if "units" not in o:
            raise ValueError("Missing 'units' key in the dictionary for deserialization.")
        units = unyt.Unit(o["units"])
        # A list payload round-trips to an array; a scalar payload to a quantity.
        # This mirrors how `serialize_unyt` encodes the two types.
        if "value" in o and isinstance(o["value"], list):
            return unyt.unyt_array(o["value"], units)
        elif "value" in o:
            return unyt.unyt_quantity(o["value"], units)
        else:
            raise ValueError("Missing 'value' key in the dictionary for deserialization.")

    # ----------------------------------------------- #
    # Class Variables / Registry                      #
    # ----------------------------------------------- #
    # NOTE: the quantity entry precedes the array entry so that isinstance
    # dispatch in `_to_jsonable` matches the more specific subclass first.
    # The values are staticmethod objects, which are directly callable on
    # Python >= 3.10 — TODO confirm that is the project's minimum version.
    __REGISTRY__: dict[type, tuple[str, Callable[[Any], dict], Callable[[dict], Any]]] = {
        unyt.unyt_quantity: ("unyt_quantity", serialize_unyt, deserialize_unyt),
        unyt.unyt_array: ("unyt_array", serialize_unyt, deserialize_unyt),
        unyt.Unit: ("unyt_unit", serialize_unyt_unit, deserialize_unyt_unit),
    }

    @classmethod
    def _to_jsonable(cls, obj: Any, _path: str = "root") -> Any:
        """Recursively convert `obj` into a JSON-serializable Python structure.

        Notes
        -----
        - This returns only JSON-native container/scalar
          types: dict, list, str, int, float, bool, None.
        - Custom types registered in `__REGISTRY__` are encoded
          as dicts with a `"tag"` plus payload.
        - Strings are treated as primitives (never iterated like sequences).
        - Actual JSON encoding (quoting/escaping) is done by `json.dumps` elsewhere.
        """
        # Start by processing matches to custom (registered) serialization types.
        # If the object is an instance of any registered type, serialize it and inject a "tag".
        for _type, (tag, serializer, _) in cls.__REGISTRY__.items():
            if isinstance(obj, _type):
                payload = serializer(obj)
                if not isinstance(payload, dict):
                    raise ValueError(f"{_path}: custom serializer for {type(obj)} must return a dict")
                # Merge the payload with a required "tag" to identify the type on load.
                return {"tag": tag, **payload}
        # These are directly JSON-serializable and don't need any transformation.
        if obj is None or isinstance(obj, (bool, int, float, str)):
            return obj
        # Convert numpy scalars to Python scalars, and numpy ndarrays to nested lists.
        if isinstance(obj, np.generic):
            return obj.item()
        if isinstance(obj, np.ndarray):
            return obj.tolist()
        # Now start handling instances where recursive checking becomes necessary.
        # HDF5 attribute keys must be strings, so enforce that here.
        if isinstance(obj, dict):
            out = {}
            for key, value in obj.items():
                if not isinstance(key, str):
                    raise TypeError(f"{_path}: HDF5 attribute keys must be strings; got key type {type(key)}")
                # Recurse into each value; add a breadcrumb path for helpful error messages.
                out[key] = cls._to_jsonable(value, _path=f"{_path}.{key}")
            return out
        # Convert any non-string, non-bytes sequence to a plain JSON list.
        if isinstance(obj, Sequence) and not isinstance(obj, (str, bytes, bytearray)):
            return [cls._to_jsonable(v, _path=f"{_path}[{i}]") for i, v in enumerate(obj)]
        # Store small binary blobs as base64-encoded strings with a distinct tag for round-tripping.
        if isinstance(obj, (bytes, bytearray)):
            import base64

            b64 = base64.b64encode(bytes(obj)).decode("ascii")
            return {"tag": "bytes_b64", "value": b64}
        # Anything else can't be represented in JSON without a custom serializer.
        raise ValueError(f"{_path}: cannot serialize object of unsupported type {type(obj)}")

    @classmethod
    def _from_jsonable(cls, obj: Any, _path: str = "root") -> Any:
        """Recursively reconstruct Python objects from JSON-serializable structures.

        Mirrors `_to_jsonable`:

        - Dicts with a `"tag"` key are dispatched to a registered deserializer,
          or handled specially (e.g., `"bytes_b64"`).
        - Plain dicts/lists are recursively walked.
        - Primitives (None/bool/int/float/str) pass through unchanged.
        """
        # ---- 1) Tagged custom objects ------------------------------------------------------------
        # Objects serialized by registry serializers include a "tag" key.
        if isinstance(obj, dict) and "tag" in obj:
            tag = obj["tag"]
            # Try custom deserializers from the registry first.
            for _type, (reg_tag, _, deserializer) in cls.__REGISTRY__.items():  # noqa: PERF102
                if tag == reg_tag:
                    # Pass the full dict to the registered deserializer.
                    return deserializer(obj)
            # Built-in special tag for base64-encoded bytes.
            if tag == "bytes_b64":
                import base64

                try:
                    return base64.b64decode(obj["value"])
                except Exception as exp:
                    raise ValueError(f"{_path}: failed to decode base64 bytes") from exp
            # Unknown tag.
            raise ValueError(f"{_path}: unrecognized serialization tag '{tag}'")
        # ---- 2) Plain mapping → dict (recurse over values) --------------------------------------
        # Keys are expected to be strings (enforced on the serialize path).
        if isinstance(obj, dict):
            out = {}
            for k, v in obj.items():
                # We do not enforce key type here; serialize path already did.
                out[k] = cls._from_jsonable(v, _path=f"{_path}.{k}")
            return out
        # ---- 3) JSON list → Python list (recurse) ------------------------------------------------
        # Include the element index in the breadcrumb, matching `_to_jsonable`.
        if isinstance(obj, list):
            return [cls._from_jsonable(v, _path=f"{_path}[{i}]") for i, v in enumerate(obj)]
        # ---- 4) Primitives (None/bool/int/float/str) --------------------------------------------
        # These are already JSON-native; return as-is.
        return obj

    # -------- public API -------- #
    @classmethod
    def serialize_data(cls, data: Any) -> str:
        """Recursively serialize `data` to a JSON string for HDF5 attributes.

        Raises
        ------
        ValueError
            If any value in `data` cannot be converted to a JSON-native form.
        """
        try:
            jsonable = cls._to_jsonable(data)
            return json.dumps(jsonable)
        except Exception as exp:
            raise ValueError(f"Failed to serialize data (type={type(data)}): {exp}") from exp

    @classmethod
    def deserialize_data(cls, data: str) -> Any:
        """Recursively deserialize a JSON string produced by `serialize_data`.

        Raises
        ------
        ValueError
            If `data` is not valid JSON or contains an unknown tag.
        """
        try:
            parsed = json.loads(data)
        except Exception as exp:
            raise ValueError(f"Failed to parse JSON string: {exp}") from exp
        return cls._from_jsonable(parsed)

    @classmethod
    def serialize_dict(cls, data: dict) -> dict:
        """Serialize each value of `data` independently (keys must be strings).

        Each value becomes its own JSON string via `serialize_data`, so each
        can be stored as a separate HDF5 attribute.
        """
        if not isinstance(data, dict):
            raise TypeError("serialize_dict expects a dict")
        return {k: cls.serialize_data(v) for k, v in data.items()}

    @classmethod
    def deserialize_dict(cls, data: dict) -> dict:
        """Deserialize each value of a dict produced by `serialize_dict`."""
        if not isinstance(data, dict):
            raise TypeError("deserialize_dict expects a dict")
        return {k: cls.deserialize_data(v) for k, v in data.items()}
class NullHDF5Serializer(HDF5Serializer):
    """A no-op serializer that performs no serialization/deserialization.

    Every public hook hands its input back unchanged. Useful when attribute
    values are already storable as-is and no JSON round-trip is wanted.
    """

    # ----------------------------------------------- #
    # Class Variables / Registry                      #
    # ----------------------------------------------- #
    # The null serializer recognizes no custom types at all.
    __REGISTRY__: dict[type, tuple[str, Callable[[Any], dict], Callable[[dict], Any]]] = {}

    # -------- public API -------- #
    @classmethod
    def serialize_data(cls, data: Any) -> Any:
        """Identity pass-through: return ``data`` untouched."""
        return data

    @classmethod
    def deserialize_data(cls, data: Any) -> Any:
        """Identity pass-through: return ``data`` untouched."""
        return data

    @classmethod
    def serialize_dict(cls, data: dict) -> dict:
        """Identity pass-through: return the mapping untouched."""
        return data

    @classmethod
    def deserialize_dict(cls, data: dict) -> dict:
        """Identity pass-through: return the mapping untouched."""
        return data
# ----------------------------------------------- #
# YAML Reader/Writer tools #
# ----------------------------------------------- #
class _YAMLHandler(ABC):
    """
    Abstract base class for defining custom YAML representers and constructors.

    Subclasses must define:

    - ``__tag__``: a YAML tag string (e.g., ``"!unyt_quantity"``)
    - ``__type__``: the Python type to associate with this handler
    - ``to_yaml()``: a static method serializing the Python object to a YAML node
    - ``from_yaml()``: a static method deserializing a YAML node to a Python object

    :meth:`register` then wires both callbacks into a ``ruamel.yaml.YAML``
    instance, giving each type-specific (de)serializer a uniform entry point.

    Example
    -------
    class MyTypeHandler(_YAMLHandler):
        __tag__ = "!my_type"
        __type__ = MyType

        @staticmethod
        def to_yaml(representer, obj):
            return representer.represent_mapping(MyTypeHandler.__tag__, {
                "x": obj.x,
                "y": obj.y,
            })

        @staticmethod
        def from_yaml(loader, node):
            data = CommentedMap()
            loader.construct_mapping(node, maptyp=data, deep=True)
            return MyType(data["x"], data["y"])

    yaml = YAML()
    MyTypeHandler.register(yaml)
    """

    # Subclasses override both of these; `register` refuses to run otherwise.
    __tag__: str = None
    __type__: type = None

    @staticmethod
    @abstractmethod
    def to_yaml(representer: "Representer", obj: Any) -> "Node":
        """
        Convert a Python object to a YAML node.

        Parameters
        ----------
        representer : ruamel.yaml.representer.Representer
            The YAML representer used to build the node.
        obj : Any
            The Python object to serialize.

        Returns
        -------
        ruamel.yaml.nodes.Node
            A YAML node representing the serialized object.
        """
        ...

    @staticmethod
    @abstractmethod
    def from_yaml(loader: "Constructor", node: "Node") -> Any:
        """
        Convert a YAML node to a Python object.

        Parameters
        ----------
        loader : ruamel.yaml.constructor.Constructor
            The YAML loader used to construct the Python object.
        node : ruamel.yaml.nodes.Node
            The YAML node to deserialize.

        Returns
        -------
        Any
            The deserialized Python object.
        """
        ...

    @classmethod
    def register(cls, yaml: "YAML") -> None:
        """
        Attach this handler's representer and constructor to *yaml*.

        Parameters
        ----------
        yaml : ruamel.yaml.YAML
            The YAML instance to register the handler with.

        Raises
        ------
        ValueError
            If the subclass left ``__tag__`` or ``__type__`` undefined.
        """
        if None in (cls.__tag__, cls.__type__):
            raise ValueError(f"{cls.__name__} must define both __tag__ and __type__.")
        # A multi-representer also covers subclasses of ``__type__``.
        yaml.representer.add_multi_representer(cls.__type__, cls.to_yaml)
        yaml.constructor.add_constructor(cls.__tag__, cls.from_yaml)
class UnytArrayHandler(_YAMLHandler):
    """YAML (de)serialization handler for ``unyt.unyt_array``."""

    __tag__ = "!unyt_array"
    __type__ = unyt.unyt_array

    @staticmethod
    def to_yaml(representer, obj):
        """Represent the array as a mapping of raw values plus a unit string."""
        payload = {"value": obj.d.tolist(), "units": str(obj.units)}
        return representer.represent_mapping(UnytArrayHandler.__tag__, payload)

    @staticmethod
    def from_yaml(loader, node):
        """Rebuild a ``unyt_array`` from the tagged mapping."""
        mapping = CommentedMap()
        loader.construct_mapping(node, maptyp=mapping, deep=True)
        return unyt.unyt_array(mapping["value"], mapping["units"])
class UnytQuantityHandler(_YAMLHandler):
    """YAML (de)serialization handler for ``unyt.unyt_quantity``."""

    __tag__ = "!unyt_quantity"
    __type__ = unyt.unyt_quantity

    @staticmethod
    def to_yaml(representer, obj):
        """Represent the quantity as a scalar value plus a unit string."""
        payload = {"value": obj.value.item(), "units": str(obj.units)}
        return representer.represent_mapping(UnytQuantityHandler.__tag__, payload)

    @staticmethod
    def from_yaml(loader, node):
        """Rebuild a ``unyt_quantity`` from the tagged mapping."""
        mapping = CommentedMap()
        loader.construct_mapping(node, maptyp=mapping, deep=True)
        return unyt.unyt_quantity(mapping["value"], mapping["units"])
class UnytUnitHandler(_YAMLHandler):
    """YAML (de)serialization handler for ``unyt.Unit``."""

    __tag__ = "!unyt_unit"
    __type__ = unyt.Unit

    @staticmethod
    def to_yaml(representer, obj):
        """Represent the unit as a tagged scalar of its string form."""
        return representer.represent_scalar(UnytUnitHandler.__tag__, str(obj))

    @staticmethod
    def from_yaml(loader, node):
        """Rebuild a ``unyt.Unit`` from its scalar string form."""
        return unyt.Unit(loader.construct_scalar(node))
class NumpyArrayHandler(_YAMLHandler):
    """YAML (de)serialization handler for ``numpy.ndarray``."""

    __tag__ = "!ndarray"
    __type__ = np.ndarray

    @staticmethod
    def to_yaml(representer, obj: np.ndarray):
        """Represent the array with dtype and shape recorded for safe rebuild."""
        payload = {"dtype": str(obj.dtype), "shape": obj.shape, "data": obj.tolist()}
        return representer.represent_mapping(NumpyArrayHandler.__tag__, payload)

    @staticmethod
    def from_yaml(loader, node):
        """Rebuild an ndarray, validating its dtype/shape metadata."""
        mapping = CommentedMap()
        loader.construct_mapping(node, maptyp=mapping, deep=True)
        # All three metadata keys are required for a faithful reconstruction.
        if any(key not in mapping for key in ("dtype", "shape", "data")):
            raise ValueError(f"Invalid ndarray YAML mapping: {mapping}")
        arr = np.array(mapping["data"], dtype=np.dtype(mapping["dtype"]))
        # Enforce the recorded shape whenever the nested-list shape disagrees.
        if tuple(arr.shape) != tuple(mapping["shape"]):
            try:
                arr = arr.reshape(mapping["shape"])
            except Exception as err:
                raise ValueError(f"Shape mismatch when reconstructing ndarray: {err}") from err
        return arr
class PathHandler(_YAMLHandler):
    """YAML (de)serialization handler for ``pathlib.Path``."""

    __tag__ = "!path"
    __type__ = Path

    @staticmethod
    def to_yaml(representer, obj: Path):
        """Represent the path as its components plus absolute/drive metadata."""
        payload = {"absolute": obj.is_absolute(), "parts": list(obj.parts)}
        if obj.drive:
            payload["drive"] = obj.drive
        return representer.represent_mapping(PathHandler.__tag__, payload)

    @staticmethod
    def from_yaml(loader, node):
        """Rebuild a ``Path`` from its recorded components."""
        mapping = CommentedMap()
        loader.construct_mapping(node, maptyp=mapping, deep=True)
        path = Path(*mapping.get("parts", []))
        # If the original path was absolute but the stored parts did not
        # reproduce one, fall back to resolving against the current directory.
        if mapping.get("absolute", False) and not path.is_absolute():
            path = path.resolve()
        return path
def get_unyt_compatible_yaml() -> "YAML":
    """
    Get a YAML instance configured for unyt compatibility.

    Creates a round-trip ``ruamel.yaml.YAML`` instance and registers the
    custom representers and constructors for unyt, path, and numpy types.

    Returns
    -------
    ruamel.yaml.YAML
        A YAML instance with unyt support.
    """
    from ruamel.yaml import YAML

    yaml = YAML(typ="rt")
    # Register every type-specific handler with the fresh instance.
    for handler in (
        UnytArrayHandler,
        UnytQuantityHandler,
        UnytUnitHandler,
        PathHandler,
        NumpyArrayHandler,
    ):
        handler.register(yaml)
    return yaml
def get_default_yaml() -> "YAML":
    """
    Get a default YAML instance without unyt support.

    Builds a round-trip ``ruamel.yaml.YAML`` instance with no custom
    representers or constructors registered.

    Returns
    -------
    ruamel.yaml.YAML
        A default YAML instance.
    """
    from ruamel.yaml import YAML

    yaml = YAML(typ="rt")
    return yaml
# Module-level YAML instance with the unyt/path/numpy handlers pre-registered.
unyt_yaml = get_unyt_compatible_yaml()
"""~ruamel.yaml.YAML: A YAML instance configured for unyt compatibility."""