Source code for pymech.vtksuite

"""Module for converting :class:`pymech.core.HexaData` objects to vtk"""

import os
from itertools import product
from pathlib import Path

import numpy as np

from .log import logger

# Emitted once, when the module is imported: tell users that this API is
# experimental and may change.
logger.warning(
    "The module pymech.vtksuite is experimental in nature and "
    "may have some rough edges. The functions can also change in the future."
)

# tvtk is treated as an optional dependency: a failed import is only logged.
# In that case the names ``tvtk`` and ``write_data`` are never bound, so the
# functions of this module that use them fail with NameError when called.
try:
    from tvtk.api import tvtk, write_data
except ImportError:
    logger.warning("To use VTK functions,\n    pip install mayavi")


# Names exported by ``from pymech.vtksuite import *``.
__all__ = ("hexa2vtk", "writevtk")


def hexa2vtk(field, downsample=False):
    """Convert a :class:`pymech.core.HexaData` object to a `Traited VTK`_ dataset.

    The returned dataset can be manipulated with libraries which accept a VTK
    object, for example Mayavi_.

    .. _Traited VTK: https://docs.enthought.com/mayavi/tvtk/README.html

    .. todo::

        Try https://github.com/pyvista/pyvista-xarray

    Example
    -------
    This also requires you to have a GUI toolkit installed: either PyQt4,
    PySide, PySide2, PyQt5 or wxPython.

    .. code-block:: python

       import pymech as pm
       from pymech.vtksuite import hexa2vtk
       from mayavi import mlab

       field = pm.readnek("tests/nek/channel3D_0.f00001")
       dataset = hexa2vtk(field)
       mlab.pipeline.add_dataset(dataset)

    Instead of MayaVi_ you could also use something high-level like PyVista_
    to wrap the underlying VTK object and later visualize it.

    .. code-block:: python

        import pyvista as pv

        dataset = pv.wrap(dataset._vtk_obj)
        dataset.plot()

    .. _MayaVi: https://docs.enthought.com/mayavi/mayavi/mlab.html
    .. _PyVista: https://docs.pyvista.org/getting-started/index.html

    Parameters
    ----------
    field : :class:`pymech.core.HexaData`
            a dataset in nekdata format
    downsample : bool
            if True, sample every element with a stride of ``lr1 - 1`` in each
            direction, i.e. keep only the first and last point per direction;
            if False (default), keep every point

    Returns
    -------
    dataset : tvtk.tvtk_classes.unstructured_grid.UnstructuredGrid
            a VTK dataset

    Raises
    ------
    ValueError
            if ``field.var[0]`` is 0, i.e. the field carries no coordinates
            from which the grid points could be built

    """
    if field.var[0] == 0:
        # The point coordinates are mandatory for the unstructured grid. Fail
        # early with a clear message instead of hitting an unbound local
        # variable after all the loops have run.
        raise ValueError(
            "hexa2vtk requires geometry: field.var[0] is 0, so there are no "
            "coordinates to build the VTK points from"
        )

    # Sampling stride in x, y and z. The z stride is clamped to 1 so that a
    # single layer in z (lr1[2] == 1) does not produce a zero range step.
    if downsample:
        ixs = field.lr1[0] - 1
        iys = field.lr1[1] - 1
        izs = max(field.lr1[2] - 1, 1)
    else:
        ixs = 1
        iys = 1
        izs = 1

    # Indices of the points kept in each direction, and how many there are
    iix = range(0, field.lr1[0], ixs)
    nix = len(iix)
    iiy = range(0, field.lr1[1], iys)
    niy = len(iiy)
    iiz = range(0, field.lr1[2], izs)
    niz = len(iiz)

    nppel = nix * niy * niz  # points kept per element
    nepel = (nix - 1) * (niy - 1) * max((niz - 1), 1)  # VTK cells per element
    nel = field.nel * nepel  # total number of VTK cells

    if field.ndim == 3:
        nvert = 8
        cellType = tvtk.Hexahedron().cell_type
    else:
        nvert = 4
        cellType = tvtk.Quad().cell_type

    ct = np.array(nel * [cellType])  # one cell type per cell
    of = np.arange(0, nvert * nel, nvert)  # per-cell offsets handed to set_cells

    # Flat connectivity: for every cell, the vertex count followed by the
    # ids of its vertices.
    ce = np.zeros(nel * (nvert + 1))
    ce[np.arange(nel) * (nvert + 1)] = nvert

    # NOTE(review): the point arrays are sized nvert * nel, but only the first
    # field.nel * nppel rows are written below; the remaining rows stay zero.
    # Kept as is to preserve the output -- confirm whether this is intended.
    if field.var[0] != 0:
        r = np.zeros((nvert * nel, 3))
    if field.var[1] != 0:
        v = np.zeros((nvert * nel, 3))
    if field.var[2] == 1:
        p = np.zeros(nvert * nel)
    if field.var[3] == 1:
        T = np.zeros(nvert * nel)
    if field.var[4] != 0:
        S = np.zeros((nvert * nel, field.var[4]))

    # Start of the current cell inside ``ce``; advanced before each cell
    ice = -(nvert + 1)

    for iel in range(field.nel):
        elem = field.elem[iel]

        for (iz, ez), (iy, ey), (ix, ex) in product(
            enumerate(iiz), enumerate(iiy), enumerate(iix)
        ):
            # Global point index, x running fastest
            iarray = iel * nppel + ix + iy * nix + iz * (nix * niy)

            # Downsample copy into a column vector
            # NOTE(review): positions and velocities are only copied when the
            # var entry is exactly 3; any other non-zero value leaves the
            # arrays at zero -- confirm whether other values can occur.
            if field.var[0] == 3:
                r[iarray, :] = elem.pos[:, ez, ey, ex]
            if field.var[1] == 3:
                v[iarray, :] = elem.vel[:, ez, ey, ex]
            # Take component 0 explicitly: assigning a size-1 array to a
            # scalar slot is deprecated since NumPy 1.25.
            if field.var[2] == 1:
                p[iarray] = elem.pres[0, ez, ey, ex]
            if field.var[3] == 1:
                T[iarray] = elem.temp[0, ez, ey, ex]
            if field.var[4] != 0:
                S[iarray, :] = elem.scal[:, ez, ey, ex]

        if field.var[0] == 3:
            for iz, iy, ix in product(
                range(max(niz - 1, 1)), range(niy - 1), range(nix - 1)
            ):
                ice = ice + nvert + 1
                # One quad of four vertices per "face": a single face for
                # ndim == 2, the iz and iz + 1 faces for ndim == 3.
                for face in range(field.ndim - 1):
                    cell_id = iel * nppel + ix + iy * nix + (iz + face) * nix * niy
                    ce[ice + face * 4 + 1] = cell_id
                    ce[ice + face * 4 + 2] = cell_id + 1
                    ce[ice + face * 4 + 3] = cell_id + nix + 1
                    ce[ice + face * 4 + 4] = cell_id + nix

    # create the array of cells
    ca = tvtk.CellArray()
    ca.set_cells(nel, ce)
    # create the unstructured dataset
    dataset = tvtk.UnstructuredGrid(points=r)
    # set the cell types
    dataset.set_cells(ct, of, ca)
    # set the data; ``idata`` counts the point-data arrays attached so far and
    # is used as the index of the array whose name is set next
    idata = 0
    if field.var[1] != 0:
        dataset.point_data.vectors = v
        dataset.point_data.vectors.name = "vel"
        idata += 1
    if field.var[2] == 1:
        dataset.point_data.scalars = p
        dataset.point_data.scalars.name = "pres"
        idata += 1
    if field.var[3] == 1:
        dataset.point_data.add_array(T)
        dataset.point_data.get_array(idata).name = "temp"
        idata += 1
    if field.var[4] != 0:
        for ii in range(field.var[4]):
            dataset.point_data.add_array(S[:, ii])
            dataset.point_data.get_array(ii + idata).name = f"scal_{ii + 1}"

    dataset.point_data.update()

    return dataset


def writevtk(fname, data):
    """Convert ``data`` with :func:`hexa2vtk` and write it in the XML VTK format.

    The output file always carries the ``.vtp`` extension: when ``fname`` has
    a different suffix (or none), the suffix is replaced and the change is
    logged at INFO level.

    Parameters
    ----------
    fname : str
            file name
    data : :class:`pymech.core.HexaData`
            data organised after reading a file

    """
    # NOTE(review): ``.vtp`` is conventionally the VTK PolyData extension,
    # whereas hexa2vtk documents an UnstructuredGrid return value -- confirm
    # that this extension is intended.
    fname = Path(fname)
    expected_suffix = ".vtp"
    if fname.suffix != expected_suffix:
        logger.info(f"Renaming {fname} with extension .vtp")
        fname = fname.with_suffix(expected_suffix)

    write_data(hexa2vtk(data), os.fspath(fname))