import os
import re
from functools import partial
from pathlib import Path

import numpy as np
import xarray as xr
from xarray.core.utils import Frozen

from .neksuite import readnek
# Public names exported by ``from <this module> import *``.
__all__ = ("open_dataset", "open_mfdataset")
nek_ext_pattern = re.compile(
    r"""
    .*          # zero or more characters (directories and file stem)
    \.          # character "."
    f           # character "f"
    (\d{5}|ld)  # 5 digits or the characters "ld"
    """,
    re.VERBOSE,
)


def can_open_nek_dataset(path):
    """A regular expression check of the file extension.

    The *whole* path must match :data:`nek_ext_pattern`, i.e. the name has to
    end with a Nek field-file extension.

    .. hint::

        - Would not match: .f90 .f .fort .f0000 .f123456 .f00001.tar.gz
        - Would match: .fld .f00001 .f12345

    Parameters
    ----------
    path : str or path-like
        Path to check; it is converted with :func:`str` before matching.

    Returns
    -------
    re.Match or None
        A (truthy) match object if ``path`` ends with a Nek field-file
        extension, otherwise ``None``.
    """
    # fullmatch (not match): a start-anchored match would also accept names
    # with trailing characters after the extension, e.g. "case.f00001.tar.gz".
    return nek_ext_pattern.fullmatch(str(path))
def open_dataset(path, **kwargs):
    """Open a single field file as an :class:`xarray.Dataset`.

    Parameters
    ----------
    path : str
        Path to a field file (only Nek files are supported at the moment.)
    kwargs : dict
        Keyword arguments passed on to the compatible open function.

    Raises
    ------
    NotImplementedError
        If ``path`` is not recognised as a supported field file.
    """
    # Guard clause: reject anything that is not a recognised Nek field file.
    if not can_open_nek_dataset(path):
        suffix = Path(path).suffix
        raise NotImplementedError(f"Filetype: {suffix} is not supported.")
    return _open_nek_dataset(path, **kwargs)
# Thin preset wrapper around xarray.open_mfdataset; the keyword arguments below
# become defaults that callers may override (functools.partial semantics).
open_mfdataset = partial(
    xr.open_mfdataset, combine="nested", concat_dim="time", engine="pymech"
)
open_mfdataset.__doc__ = """Helper function for opening multiple files as an
:class:`xarray.Dataset`. See :func:`xarray.open_mfdataset` for documentation on
parameters.

This is :func:`xarray.open_mfdataset` with the following keyword arguments
preset; each one can be overridden by passing it explicitly:

- ``combine="nested"``
- ``concat_dim="time"``
- ``engine="pymech"``
"""
def _open_nek_dataset(path, drop_variables=None):
    """Interface for converting Nek field files into xarray_ datasets.

    Every element of the field is wrapped in a :class:`_NekDataStore`, loaded
    as its own small dataset, and the per-element datasets are then merged into
    a single dataset by their coordinates.

    Parameters
    ----------
    path : str
        Path to a Nek field file; it is passed to :func:`readnek`.
    drop_variables : str or iterable of str, optional
        Variable(s) to remove from the resulting dataset. Names that are not
        present in the dataset are ignored.

    Returns
    -------
    xarray.Dataset

    Raises
    ------
    OSError
        If :func:`readnek` returns an integer instead of a field.
    NotImplementedError
        If the per-element datasets cannot be built or combined, which usually
        means the mesh is not a cartesian box mesh.

    .. _xarray: https://docs.xarray.dev/en/stable/
    """
    field = readnek(path)
    # readnek signals failure with an int return value rather than raising.
    if isinstance(field, int):
        raise OSError(f"Failed to load {path}")

    elem_stores = [_NekDataStore(elem) for elem in field.elem]
    try:
        elem_dsets = [
            xr.Dataset.load_store(store).set_coords(store.axes) for store in elem_stores
        ]
        # Combining is inside the try as well: an unsupported mesh layout makes
        # combine_by_coords raise ValueError, which deserves the same hint.
        # Per-element attrs (boundary conditions, curvature) are dropped since
        # they differ between elements.
        # See: https://github.com/MITgcm/xmitgcm/pull/200
        ds = xr.combine_by_coords(elem_dsets, combine_attrs="drop")
    except ValueError as err:
        raise NotImplementedError(
            "Opening dataset failed because you probably tried to open a field file "
            "with an unsupported mesh. "
            "The `pymech.open_dataset` function currently works only with cartesian "
            "box meshes. For more details on this, see "
            "https://github.com/eX-Mech/pymech/issues/31"
        ) from err

    # Attach the snapshot time read from the file as the ``time`` coordinate.
    ds.coords.update({"time": field.time})
    if drop_variables:
        # Ignore names that are absent (e.g. "pressure" in a file without
        # pressure) so one drop list works across heterogeneous files.
        ds = ds.drop_vars(drop_variables, errors="ignore")
    return ds
class PymechXarrayBackend(xr.backends.BackendEntrypoint):
    """Xarray backend entrypoint for Nek field files."""

    open_dataset_parameters = ("filename_or_obj", "drop_variables")

    def guess_can_open(self, filename_or_obj):
        """Return ``True`` if *filename_or_obj* is a path with a Nek field-file
        extension, ``False`` otherwise."""
        # Only real paths can be judged by name. Buffers or data stores would
        # otherwise be matched against their repr(), which may contain a
        # file name ending in a Nek extension.
        if not isinstance(filename_or_obj, (str, os.PathLike)):
            return False
        return bool(can_open_nek_dataset(filename_or_obj))

    def open_dataset(
        self,
        filename_or_obj,
        *,
        drop_variables=None,
        # other backend specific keyword arguments
        # `chunks` and `cache` DO NOT go here, they are handled by xarray
    ):
        """Open a Nek field file as an :class:`xarray.Dataset`.

        Parameters
        ----------
        filename_or_obj : str or path-like
            Path to the Nek field file.
        drop_variables : str or iterable of str, optional
            Variable(s) to exclude from the returned dataset.
        """
        return _open_nek_dataset(filename_or_obj, drop_variables)
class _NekDataStore(xr.backends.common.AbstractDataStore):
    """Xarray data store wrapping one Nek field element.

    Parameters
    ----------
    elem: :class:`pymech.core.Elem`
        A Nek5000 element.
    """

    # Dimension names in array-axis order: the reverse of the x, y, z order of
    # the ``elem.pos`` / ``elem.vel`` components.
    axes = ("z", "y", "x")

    def __init__(self, elem):
        self.elem = elem

    def meshgrid_to_dim(self, mesh):
        """Reverse of np.meshgrid: collapse a cubical coordinate array into the
        sorted 1-D coordinates along one direction.

        Values are rounded to 8 decimals before de-duplication so floating-point
        noise does not create extra coordinate points.
        """
        rounded = np.round(mesh, 8)
        return np.unique(rounded)

    def get_dimensions(self):
        """Return the dimension names shared by all element variables."""
        return self.axes

    def get_attrs(self):
        """Return the element's boundary-condition and curvature metadata."""
        elem = self.elem
        return Frozen(
            dict(
                boundary_conditions=elem.bcs,
                curvature=elem.curv,
                curvature_type=elem.ccurv,
            )
        )

    def get_variables(self):
        """Generate an xarray dataset from a single element."""
        elem = self.elem
        dims = self.axes
        variables = {}

        # 1-D dimension coordinates, inserted in x, y, z order (pos[0] -> x).
        for comp, name in enumerate(reversed(dims)):
            variables[name] = self.meshgrid_to_dim(elem.pos[comp])
        # Full 3-D coordinate arrays.
        for comp, name in enumerate(("xmesh", "ymesh", "zmesh")):
            variables[name] = xr.Variable(dims, elem.pos[comp])
        # Velocity components.
        for comp, name in enumerate(("ux", "uy", "uz")):
            variables[name] = xr.Variable(dims, elem.vel[comp])

        # Optional fields are only included when the element carries data.
        if elem.pres.size:
            variables["pressure"] = xr.Variable(dims, elem.pres[0])
        if elem.temp.size:
            variables["temperature"] = xr.Variable(dims, elem.temp[0])
        if elem.scal.size:
            for idx in range(elem.scal.shape[0]):
                variables["s{:02d}".format(idx + 1)] = xr.Variable(
                    dims, elem.scal[idx]
                )

        return Frozen(variables)