Source code for fields.utils.utilities
"""
Utility functions for fields.
"""
from typing import TYPE_CHECKING, Optional, Tuple
if TYPE_CHECKING:
from ._typing import Signature, SignatureInput
[docs]
def validate_rank_signature(
rank: int, signature: Optional["SignatureInput"] = None
) -> "Signature":
"""
Validate that a tensor signature is consistent with a given tensor rank. If it is not provided,
create a fully contravariant signature.
Parameters
----------
rank : int
The rank of the tensor (i.e., number of element-wise dimensions).
signature : int or list of int, optional
A sequence of values indicating index variance:
- `1` for contravariant (upper index)
- `-1` for covariant (lower index)
Returns
-------
tuple of int
The corrected / standardized tensor signature.
Raises
------
ValueError
If the signature length does not match the rank, or if any value in the
signature is not -1 or 1.
"""
if signature is None:
return (1,) * rank
if isinstance(signature, int):
if signature not in (-1, 1):
raise ValueError(f"Signature value must be 1 or -1, got: {signature}")
return (signature,) * rank
# Now assume it's a sequence
signature = tuple(signature)
if len(signature) != rank:
raise ValueError(
f"Signature length {len(signature)} does not match tensor rank {rank}."
)
if any(s not in (-1, 1) for s in signature):
raise ValueError(
f"Signature values must be either 1 (contravariant) or -1 (covariant), got: {signature}."
)
return signature
[docs]
def signature_to_tensor_class(signature: "SignatureInput") -> Tuple[int, int]:
"""
Convert a tensor signature to its :math:`(p, q)` form.
The :math:`(p, q)` notation describes the number of contravariant (upper) and
covariant (lower) indices in a tensor:
- ``p``: number of ``1``'s in the signature (contravariant indices)
- ``q``: number of ``-1``'s in the signature (covariant indices)
Parameters
----------
signature : sequence of int
The tensor signature, typically a tuple of `1` and `-1`.
Returns
-------
tuple of int
A tuple `(p, q)` where:
- `p` is the count of contravariant indices
- `q` is the count of covariant indices
Raises
------
ValueError
If the signature contains values other than `1` or `-1`.
"""
# Validate the signature.
signature = validate_rank_signature(signature)
# Compute the sums.
p = sum(1 for s in signature if s == 1)
q = sum(1 for s in signature if s == -1)
return p, q