Source code for pyrokinetics.dataset_wrapper

from __future__ import annotations

from ast import literal_eval
from pathlib import Path
from textwrap import dedent
from typing import TYPE_CHECKING, Any, Dict, Mapping, Optional

import netCDF4 as nc

from .metadata import metadata
from .typing import PathLike
from .units import ureg

if TYPE_CHECKING:
    import xarray as xr


[docs] class DatasetWrapper: """ Base class for classes that store their data in an Xarray Dataset. Defines a number of useful functions, such as ``__getitem__`` that redirects to the underlying Dataset, and methods to read or write to disk. Ensures that the underlying Dataset contains metadata about the current session. The user may access the underlying Dataset via ``self.data``. Parameters ---------- data_vars: Optional[Dict[str, Any]] Variables to be passed to the underlying Dataset. coords: Optional[Dict[str, Any]] Coordinates to be passed to the underlying Dataset. attrs: Optional[Dict[str,Any]] Attributes to be passed to the underlying Dataset. An associated read-only property is created for each attr. title: Optional[str] Sets the 'title' attribute in the underlying Dataset. Uses the derived class name by default. """
[docs] def __init__( self, data_vars: Optional[Dict[str, Any]] = None, coords: Optional[Dict[str, Any]] = None, attrs: Optional[Dict[str, Any]] = None, title: Optional[str] = None, ) -> None: import pint_xarray # noqa import xarray as xr # Set default title if the user hasn't provided one if title is None: title = self.__class__.__name__ # Initialise attrs to an empty dict if the user hasn't provided one if attrs is None: attrs = {} # Save attribute units and strip them from the dict # Write attrs to a new dict to avoid modifying the original self._attr_units = {} new_attrs = {} for key, value in attrs.items(): if hasattr(value, "units") and hasattr(value, "magnitude"): self._attr_units[key] = value.units new_attrs[key] = value.magnitude else: new_attrs[key] = value # Save _attr_units in the dataset new_attrs["attribute_units"] = repr( {k: str(v) for k, v in self._attr_units.items()} ) # Add metadata to attrs dict obj_name = self.__class__.__name__ # name of derived class, not DatasetWrapper meta_dict = metadata(title, obj_name, netcdf4_version=nc.__version__) for key, val in meta_dict.items(): new_attrs[key] = val # Set underlying dataset self.data = xr.Dataset(data_vars=data_vars, coords=coords, attrs=new_attrs)
@property def data(self) -> xr.Dataset: """ Property for managing the underlying Xarray Dataset. The 'getter' returns the Dataset without changes, while the 'setter' uses the pint-array 'quantify' function to ensure units attributes are integrated properly. """ return self._data @data.setter def data(self, ds: xr.Dataset) -> None: self._data = ds.pint.quantify(unit_registry=ureg) @property def coords(self) -> Mapping[str, xr.DataArray]: """Redirects to underlying Xarray Dataset coords.""" return self.data.coords @property def data_vars(self) -> Mapping[str, xr.DataArray]: """Redirects to underlying Xarray Dataset data_vars.""" return self.data.data_vars @property def attrs(self) -> Dict[str, Any]: """Redirects to underlying Xarray Dataset attrs.""" return self.data.attrs @property def dims(self) -> Mapping[str, int]: """Redirects to underlying Xarray Dataset dims.""" return self.data.dims @property def sizes(self) -> Mapping[str, int]: """Redirects to underlying Xarray Dataset sizes.""" return self.data.sizes def __getitem__(self, key: str) -> Any: """Redirect indexing to self.data""" try: return self.data[key] except KeyError: raise KeyError( f"'{self.__class__.__name__}' object does not contain '{key}'" ) def __getattr__(self, name: str) -> Any: """ Redirect attribute lookup to self.data.attrs. Re-assigns units if they were stripped on initialisation. """ try: value = self.data.attrs[name] except KeyError: raise AttributeError( f"'{self.__class__.__name__}' object has no attribute '{name}'" ) if name in self._attr_units: return value * self._attr_units[name] else: return value def __repr__(self) -> str: """Returns stringified xarray Dataset from self.data""" dataset_repr = repr(self.data) my_repr = dataset_repr.replace( "<xarray.Dataset>", f"<pyrokinetics.{self.__class__.__name__}>\n(Wraps <xarray.Dataset>)", ) return my_repr def __contains__(self, name: str) -> bool: """Redirect ``x in y`` calls to the inner dataset""" return name in self.data
[docs] def to_netcdf(self, *args, **kwargs) -> None: """ Writes self.data to disk. Forwards all args to xarray.Dataset.to_netcdf. Complex data is expanded out into float arrays of shape ``[dims..., 2]``. """ data = self.data.pint.dequantify() data.pint.dequantify().to_netcdf(auto_complex=True, *args, **kwargs)
[docs] @classmethod def from_netcdf( cls, path: PathLike, *args, overwrite_metadata: bool = False, overwrite_title: Optional[str] = None, **kwargs, ): """ Initialise self.data from a netCDF file. Parameters ---------- path: PathLike Path to the netCDF file on disk. *args: Positional arguments forwarded to xarray.open_dataset. overwrite_metadata: bool, default False Take ownership of the netCDF data, overwriting attributes such as 'title', 'software_name', 'date_created', etc. overwrite_title: Optional[str] If ``overwrite_metadata`` is ``True``, this is used to set the ``title`` attribute in ``self.data``. If unset, the derived class name is used. **kwargs: Keyword arguments forwarded to xarray.open_dataset. Returns ------- Derived Instance of a derived class with self.data initialised. Derived classes which need to do more than this should override this method with their own implementation. Raises ------ RuntimeError If the netcdf is for the wrong type of object. """ import pint_xarray # noqa import xarray as xr instance = cls.__new__(cls) with xr.open_dataset(Path(path), auto_complex=True, *args, **kwargs) as dataset: if dataset.attrs.get("object_type", cls.__name__) != cls.__name__: raise RuntimeError(dedent(f"""\ netcdf of type {dataset.attrs["object_type"]} cannot be used to create objects of type {cls.__name__}. """)) if overwrite_metadata: if overwrite_title is None: title = cls.__name__ else: title = str(overwrite_title) new_metadata = metadata( title, cls.__name__, netcdf4_version=nc.__version__ ) for key, val in new_metadata.items(): dataset.attrs[key] = val instance.data = dataset # Set up attr_units attr_units_as_str = literal_eval(dataset.attribute_units) instance._attr_units = {k: ureg(v).units for k, v in attr_units_as_str.items()} return instance