"""
Utility functions for interacting with :py:mod:`sympy`.
"""
from typing import Any, Callable, Dict, List, Optional, Union
import numpy as np
import sympy as sp
from pymetric.utilities.logging import pg_log, pg_params
[docs]
def lambdify_expression(
expression: Union[str, sp.Basic],
axes: List[sp.Symbol],
parameters: Optional[Dict[str, Any]] = None,
) -> Callable:
r"""
Convert a `Sympy <https://docs.sympy.org/latest/index.html>`__ expression (scalar, Matrix, or N-dimensional Array) or a string
into a callable function. The returned function depends on the given ``axes``
and optional ``parameters``, and returns either:
- A scalar or array of shape matching the input arguments (if the expression is scalar).
- A 2D or higher-dimensional NumPy array (if the expression is a :py:class:`sympy.matrices.Matrix` or N-dim Array).
The returned function utilizes numpy functions to ensure that it is relatively optimized for use on array valued
arguments.
.. note::
While :py:func:`lambdify_expression` is designed to perform relatively well on array inputs, it should
not be called in large loops as the call overhead can be quite large.
Parameters
----------
expression : str or sympy.Basic
The expression to be converted into a callable function. This may be any valid ``sympy`` expression or a string
which can be parsed directly to a valid sympy expression using :py:func:`sympy.sympify`. In ``expression`` is a
sympy matrix or other non-scalar object, each element is lambdified independently and then combined again in
the callable that is returned.
axes : list of :py:class:`~sympy.core.symbol.Symbol`
The list of coordinate symbols which will become the arguments of the resulting function. This should be a list
of valid :py:class:`~sympy.core.symbol.Symbol` objects, each corresponding to an independent variable in the expression. The
order of the ``axes`` will determine the expected order of inputs when using the resulting callable function.
parameters : dict, optional
A mapping from symbol name to numerical value (or another valid Sympy expression)
which will be substituted into the final expression. Defaults to ``None``.
Returns
-------
Callable
A function ``f(*axes)`` that evaluates the expression numerically:
- **If the expression is scalar**: returns float or NumPy array.
- **If the expression is Matrix/Array**: returns a NumPy array of shape
``(rows, cols)`` or the corresponding multi-dimensional shape. If the first
argument is an array of length N, then the result generally has shape
``(N, rows, cols, ...)`` etc.
Raises
------
ValueError
- If sympifying the expression fails.
- If there are unrecognized symbols that do not match any of the axes or parameters.
- If a constant expression cannot be converted to float.
Examples
--------
>>> import sympy as sp
>>> x = sp.Symbol('x')
>>> expr = sp.Matrix([[1, 0], [x, 1]])
>>> f = lambdify_expression(expr, axes=[x])
>>> import numpy as np
>>> print(f(2.0))
[[1. 0.]
[2. 1.]]
>>> # For an array of x-values, you'll get shape (N, 2, 2).
>>> xs = np.linspace(0, 1, 3)
>>> print(f(xs).shape)
(3, 2, 2)
Notes
-----
This operation proceeds in the following basic steps:
1. Parse and sympify string expressions.
2. Substitute known parameters.
3. Validate that any remaining free symbols are among the provided `axes`.
4. Identify whether the expression is a scalar, Matrix, or Array.
5. Recursively lambdify each entry (for Matrix/Array), or do a scalar lambdify.
6. Return a function that evaluates those sub-lambdas and aggregates them
into a NumPy array of the correct shape (for Matrix/Array) or a scalar/array (for scalar expressions).
"""
# Ensure that parameters is a dictionary and not None by default. This
# ensures that the kwarg is immutable.
parameters = parameters or {}
# Parse the expression if it is a string.
if isinstance(expression, str):
try:
expression = sp.sympify(expression)
except (sp.SympifyError, TypeError) as e:
raise ValueError(
f"Failed to parse/convert expression to a Sympy object. Expression: {expression}"
) from e
# Substitute the parameters into the bound expression and then check for any missing symbols
# which are not axes but remain in the expression after the substitution.
bound_expr = expression.subs(parameters)
free_symbols = bound_expr.free_symbols
recognized_symbols = set(axes) | {sp.Symbol(k) for k in parameters.keys()}
unrecognized = free_symbols - recognized_symbols
if unrecognized:
raise ValueError(
f"The expression contains unrecognized symbols: {unrecognized}. "
"Ensure all symbols are in 'axes' or 'parameters'."
)
# Direct the function to perform differently depending on whether
# the expression is multidimensional or not.
if isinstance(bound_expr, sp.MatrixBase):
return _lambdify_matrix(bound_expr, axes, parameters)
elif isinstance(bound_expr, sp.NDimArray):
return _lambdify_ndarray(bound_expr, axes, parameters)
else:
# It's a scalar expression => proceed with scalar logic
return _lambdify_scalar(bound_expr, axes)
def _lambdify_scalar(bound_expr: sp.Basic, axes: List[sp.Symbol]) -> Callable:
"""
Lambdify a scalar Sympy expression that depends on given axes.
If the expression is constant, return a function that broadcasts.
Otherwise, return a standard lambdified function.
"""
# Determine the ops count so that we can determine if we are interested
# in trying to check for .is_constant or if that will be too taxing.
__ops_count__ = sp.count_ops(bound_expr, visual=False)
__will_check_flag__ = __ops_count__ <= pg_params["skip_constant_checks"]
if __will_check_flag__:
# Handle constant expressions
if bound_expr.is_constant():
try:
constant_value = float(sp.simplify(bound_expr))
except (TypeError, ValueError):
raise ValueError(
f"The expression simplifies to a constant but could not be converted to float: {bound_expr}"
)
def _constant_function(*args, _cv=constant_value):
# If no args, just return the constant
if not args:
return _cv
# Otherwise broadcast to the shape of the first argument
return np.full_like(args[0], _cv)
return _constant_function
else:
# We don't perform the bound_expr check.
pg_log.debug(
"Skipping .is_constant check for bound expression: %s.", bound_expr
)
pass
# If non-constant, do a normal lambdify
try:
func = sp.lambdify(axes, bound_expr, "numpy")
except Exception as e:
raise ValueError(
f"Failed to lambdify expression with axes={axes}.\n"
f"Sympy expression: {bound_expr}"
) from e
return func
def _lambdify_matrix(
mat_expr: sp.MatrixBase, axes: List[sp.Symbol], parameters: Dict[str, Any]
) -> Callable:
"""
Recursively lambdify each element of a :py:class:`sympy.matrices.Matrix` (2D).
Return a function that, given arrays for the axes, returns a
NumPy array of shape (rows, cols) or (N, rows, cols), etc.
"""
rows, cols = mat_expr.shape
# Build per-element scalar lambdas
element_funcs = []
for i in range(rows):
row_funcs = []
for j in range(cols):
row_funcs.append(_lambdify_scalar(mat_expr[i, j], axes))
element_funcs.append(row_funcs)
def matrix_func(*args):
"""
Evaluate the (rows × cols) symbolic matrix on *args*.
Returns
-------
ndarray
* If all inputs are scalars → shape (rows, cols)
* If inputs broadcast to B → shape B + (rows, cols)
where B = (N₁, N₂, …)
"""
flat_vals = [f(*args) for row in element_funcs for f in row] # length rows*cols
flat_vals = [np.asarray(v) for v in flat_vals] # force ndarray view
b_shape = flat_vals[0].shape # () for scalars, (N₁, …) for arrays
# result of stack: b_shape + (rows*cols,)
stacked = np.stack(flat_vals, axis=-1)
result = stacked.reshape(b_shape + (rows, cols))
return result.astype(float)
return matrix_func
def _lambdify_ndarray(
nd_expr: sp.NDimArray, axes: List[sp.Symbol], parameters: Dict[str, Any]
) -> Callable:
"""
Recursively lambdify each element of a sympy N-dimensional array.
Return a function that, given arrays for the axes, returns a
NumPy array of shape matching the original n-dim shape in nd_expr.
"""
shape = nd_expr.shape
# Build a list of scalar-lambdas for each element in flatten order
sub_exprs = list(nd_expr) # Flatten into 1D
sub_funcs = [_lambdify_scalar(e, axes) for e in sub_exprs]
def array_func(*args):
# evaluate every scalar sub-expression
pieces = [f(*args) for f in sub_funcs] # list of np.ndarray *or* scalars
# make sure every piece is an array (broadcast-scalar -> 0-d array)
pieces = [np.asarray(p) for p in pieces]
# common broadcast shape, e.g. (N1, N2, N3)
b_shape = pieces[0].shape
# stack on the LAST axis, then reshape … + expr_shape
stacked = np.stack(pieces, axis=-1) # shape b_shape + (∏expr_shape,)
result = stacked.reshape(b_shape + shape)
return result.astype(float)
return array_func