# Source code for pyrokinetics.units

from contextlib import contextmanager
from typing import Optional

import numpy as np
import pint
from numpy.typing import ArrayLike
from scipy.constants import physical_constants
from scipy.interpolate import (
    CloughTocher2DInterpolator,
    InterpolatedUnivariateSpline,
    RectBivariateSpline,
)
from typing_extensions import TypeAlias


class PyroNormalisationError(Exception):
    """Exception raised when converting simulation units requires physical
    reference values that are not available.

    Parameters
    ----------
    system
        Target normalisation: either its name as a ``str``, an object exposing
        ``_system.name`` (such as a normalisation object), or ``None`` when no
        explicit system was requested.
    units
        The units that could not be converted.
    """

    def __init__(self, system, units):
        # Resolve the target to a plain name up front. ``None`` is passed by
        # ``PyroQuantity.to_base_units`` when no system is requested, and has no
        # ``_system`` attribute.
        if system is None or isinstance(system, str):
            name = system
        else:
            name = system._system.name
        # Forward the (simple) arguments to ``Exception`` so ``args`` is
        # populated and the exception can be pickled and re-created.
        super().__init__(name, units)
        self.system = name
        self.units = units

    def __str__(self):
        return (
            f"Cannot convert '{self.units}' to '{self.system}' normalisation. "
            f"Possibly '{self.system}' is missing physical reference values. "
            "You may need to load a kinetics or equilibrium file"
        )
class PyroContextError(Exception):
    """Exception raised when converting between physical and simulation units
    without a context that relates them.

    Parameters
    ----------
    error_msg
        Human-readable description of the failed conversion.
    """

    def __init__(self, error_msg):
        # Forward the message so ``args`` is populated and pickling works
        super().__init__(error_msg)
        self.error_msg = error_msg

    def __str__(self):
        # ``__str__`` must return a ``str`` even if a non-string was passed
        return str(self.error_msg)
class Normalisation:
    """Marker base class shared by ``SimulationNormalisation`` and
    ``ConventionNormalisation``.

    It imposes no constraints on its subclasses. Its only job is to let
    ``PyroQuantity`` recognise pyrokinetics normalisation objects, so that it
    can convert to a normalisation as well as perform ordinary unit
    conversions.
    """
class PyroQuantity(pint.UnitRegistry.Quantity):
    """Quantity class used by :class:`PyroUnitRegistry`.

    Extends the standard ``pint`` quantity so that :meth:`to` and
    :meth:`to_base_units` can target pyrokinetics :class:`Normalisation`
    objects as well as ordinary units. A conversion that silently produced
    ``NaN`` (because a reference value is unknown) raises
    :class:`PyroNormalisationError` instead.
    """

    def _replace_nan(self, value, system: Optional[str]):
        """Check the result ``value`` of converting ``self`` for failure.

        Reference units whose physical value is not known are defined as
        ``NaN`` multiples of other units, so a conversion that needed one of
        them produces ``NaN``.

        Parameters
        ----------
        value
            The converted quantity.
        system
            Target system (name, normalisation object or ``None``). Used only
            for the error message.

        Returns
        -------
        The converted quantity. If ``self`` is all zero, zeros of the same
        shape are returned instead of ``NaN``.

        Raises
        ------
        PyroNormalisationError
            If ``value`` is entirely ``NaN`` although ``self`` was not.
        """
        if not np.isnan(value).any():
            return value

        # Zero is zero in any unit (except offset units, which we don't use),
        # so a NaN from a missing reference value doesn't matter here. Build
        # the zeros from the converted magnitude so array inputs keep their
        # shape (a bare ``0.0 * units`` would collapse them to a scalar).
        if (self == 0.0).all():
            return self._REGISTRY.Quantity(
                np.zeros_like(value.magnitude), value.units
            )

        # All NaN now but not all NaN before: the conversion itself failed.
        # Otherwise some NaNs were already present in the data; pass through.
        if np.isnan(value).all() and not np.isnan(self).all():
            raise PyroNormalisationError(system, self.units)
        return value

    def to_base_units(self, system: Optional[str] = None):
        """Return this quantity in base units, optionally of a given system.

        Parameters
        ----------
        system
            Unit system (by name, or an object exposing ``_system.name``) to use
            temporarily as the registry default. ``None`` keeps the current
            default.

        Raises
        ------
        PyroNormalisationError
            If the conversion turned non-``NaN`` input into all-``NaN`` output.
        """
        with self._REGISTRY.as_system(system):
            value = super().to_base_units()
            return self._replace_nan(value, system)

    def _convert_simulation_units(self, norm):
        """Replace simulation units by their corresponding run-specific units.

        For each unit ``u`` of this quantity, a registry unit named
        ``f"{u}_{name}"`` is substituted if one exists. Here ``name`` is
        ``norm.run_name`` if ``norm`` has that attribute, otherwise
        ``norm.name``. Powers of units that end up with the same name are
        summed, and units with zero net power are dropped. The magnitude is
        unchanged.
        """
        name = norm.run_name if hasattr(norm, "run_name") else norm.name

        units = {}
        for unit, power in self._units.items():
            if (new_unit := f"{unit}_{name}") in self._REGISTRY:
                unit = new_unit
            units[unit] = units.get(unit, 0) + power

        units = {k: v for k, v in units.items() if v != 0}
        return self._REGISTRY.Quantity(
            self._magnitude, pint.util.UnitsContainer(units)
        )

    def convert_physical_units(self, norm):
        """Replace physical units by their corresponding simulation units.

        The quantity is first converted to ``norm`` (see :meth:`to`). Each
        reference unit in the result (``lref_*``, ``bref_*``, ...) is then
        replaced by the unit the normalisation's convention uses for that
        reference. The convention is ``norm.convention``, falling back to
        ``norm.default_convention.convention``. The magnitude of the converted
        quantity is kept as is.
        """
        as_physical = self.to(norm)

        convention = getattr(norm, "convention", None)
        if convention is None:
            convention = norm.default_convention.convention

        units = {}
        for unit, power in as_physical._units.items():
            if base := self._is_base_unit(unit):
                unit = str(getattr(convention, base))
            units[unit] = units.get(unit, 0) + power

        units = {k: v for k, v in units.items() if v != 0}
        return self._REGISTRY.Quantity(
            as_physical._magnitude, pint.util.UnitsContainer(units)
        )

    @staticmethod
    def _is_base_unit(unit):
        """Return the reference-unit prefix ``unit`` starts with, else ``None``.

        For example ``"lref_minor_radius"`` gives ``"lref"``.
        """
        base_units = (
            "beta_ref",
            "bref",
            "lref",
            "mref",
            "nref",
            "qref",
            "tref",
            "vref",
            "rhoref",
        )
        return next((base for base in base_units if unit.startswith(base)), None)

    def _is_physical_or_simulation_unit(self):
        """Classify the dimensionality of this quantity.

        Returns
        -------
        str
            ``"dimensionless"`` if it has no dimensions. ``"simulation"`` if
            all its dimensions are reference dimensions (``[lref]``,
            ``[bref]``, ...). ``"physical"`` if none are. ``"mixture"``
            otherwise.
        """
        base_dimensionality = {
            "[beta_ref]",
            "[bref]",
            "[lref]",
            "[mref]",
            "[nref]",
            "[qref]",
            "[tref]",
            "[vref]",
            "[rhoref]",
        }
        unit_dimensionality = list(self.dimensionality)
        n_dimensionality = len(unit_dimensionality)
        n_match = len(set(unit_dimensionality) & base_dimensionality)

        if n_dimensionality == 0:
            return "dimensionless"

        # Any reference dimension present -> at least partly simulation;
        # any non-reference dimension present -> at least partly physical
        simulation = n_match > 0
        physical = n_match != n_dimensionality

        # Edge case: ``qref`` is defined as ``elementary_charge``, which brings
        # in two physical dimensions. Treat such quantities as simulation ones.
        if n_match == n_dimensionality - 2 and "elementary_charge" in str(self.units):
            physical = False
            simulation = True

        if physical and simulation:
            return "mixture"
        if physical:
            return "physical"
        if simulation:
            return "simulation"
        raise ValueError(f"Somehow {self} is not physical or simulation unit")

    def _convert_base_units(self, norm):
        """Return this quantity's units with reference units swapped for ``norm``'s.

        Each unit starting with a reference prefix (see :meth:`_is_base_unit`)
        is replaced by ``str(getattr(norm, prefix))``. Powers are summed and
        units with zero net power dropped.
        """
        units = {}
        for unit, power in self._units.items():
            if base := self._is_base_unit(unit):
                unit = str(getattr(norm, base))
            units[unit] = units.get(unit, 0) + power

        return pint.util.UnitsContainer({k: v for k, v in units.items() if v != 0})

    def to(self, other=None, *contexts, **ctx_kwargs):
        """Return Quantity rescaled to other units or normalisation.

        Parameters
        ----------
        other
            Target units, or a :class:`Normalisation` to convert into.
        *contexts
            Extra pint contexts. When converting between physical and
            simulation units, these are also passed to
            ``_convert_simulation_units``, which accepts exactly one.
        **ctx_kwargs
            Keyword arguments forwarded to the contexts.

        Raises
        ------
        PyroNormalisationError
            If ``other`` is a :class:`Normalisation` and the value cannot be
            converted. This indicates required physical reference values are
            missing.
        PyroContextError
            If converting between physical and simulation units without a
            context and pint cannot relate the dimensions.
        """
        if isinstance(other, Normalisation):
            with self._REGISTRY.context(other.context, *contexts, **ctx_kwargs):
                as_physical = self._convert_simulation_units(other)
                value = as_physical.to(self._convert_base_units(other))
                return self._replace_nan(value, other)

        unit_type = self._is_physical_or_simulation_unit()
        output_type = self._REGISTRY.Quantity(
            1, other
        )._is_physical_or_simulation_unit()
        # Converting between different kinds of unit needs extra care
        if output_type != unit_type:
            unit_type = "mixture"

        if unit_type == "mixture":
            if contexts:
                # NOTE(review): passing more than one context raises TypeError
                # here, since _convert_simulation_units takes a single norm
                with self._REGISTRY.context(*contexts, **ctx_kwargs):
                    as_physical = self._convert_simulation_units(*contexts)
                    return as_physical.to(other, **ctx_kwargs)
            try:
                return super().to(other, **ctx_kwargs)
            except pint.errors.DimensionalityError as err:
                raise PyroContextError(
                    f"Trying to convert between physical and simulation units "
                    f"'{self}' -> '{other}' without context. Please use a context here"
                ) from err

        return super().to(other, *contexts, **ctx_kwargs)
class PyroUnitRegistry(pint.UnitRegistry):
    """Specialisation of `pint.UnitRegistry` that expands some methods to be
    aware of pyrokinetics normalisation objects.

    On construction it defines the reference units (``lref_*``, ``bref_*``,
    ``nref_*``, ...) used by pyrokinetics normalisations. Each reference unit
    gets its own dimension. Units whose physical value depends on information
    not yet available are defined as ``NaN`` multiples of a base reference
    unit.
    """

    Quantity: TypeAlias = PyroQuantity

    def __init__(self):
        super().__init__(force_ndarray=True)

        self._on_redefinition = "ignore"

        self.define("qref = elementary_charge")

        # IMAS normalises to the actual deuterium mass, so let's add the
        # physical particle masses as constants
        for unit_name, constant in (
            ("hydrogen_mass", "proton mass"),
            ("deuterium_mass", "deuteron mass"),
            ("tritium_mass", "triton mass"),
            ("electron_mass", "electron mass"),
        ):
            value, unit, _ = physical_constants[constant]
            self.define(f"{unit_name} = {value} {unit}")

        # We can immediately define reference masses in physical units.
        # WARNING: This might need refactoring to use a [mref] dimension
        # if we start having other possible reference masses
        self.define("mref_deuterium = [mref]")
        self.define(
            f"mref_electron = {self.electron_mass} / {self.deuterium_mass} mref_deuterium"
        )
        self.define(
            f"mref_hydrogen = {self.hydrogen_mass} / {self.deuterium_mass} mref_deuterium"
        )
        self.define(
            f"mref_tritium = {self.tritium_mass} / {self.deuterium_mass} mref_deuterium"
        )

        # For each normalisation unit, we create a unique dimension for
        # that unit and convention
        self.define("bref_B0 = [bref]")
        self.define("lref_minor_radius = [lref]")
        self.define("nref_electron = [nref]")
        self.define("tref_electron = [tref]")
        self.define("[vref] = [tref] ** 0.5 / [mref] ** 0.5")
        self.define("vref_nrl = [vref]")
        self.define("[rhoref] = [tref] ** 0.5 * [mref] ** 0.5 / [bref]")
        self.define("rhoref_pyro = [rhoref]")
        self.define("beta_ref_ee_B0 = [beta_ref]")

        # vrefs are related by constant, so we can always define this one
        self.define("vref_most_probable = (2**0.5) * vref_nrl")
        self.define("rhoref_gs2 = (2**0.5) * rhoref_pyro")

        # Now we define the "other" normalisation units that require more
        # information, such as bunit_over_B0 or the aspect_ratio. NaN marks
        # the conversion factor as unknown until physical values are loaded.
        self.define("bref_Bunit = NaN bref_B0")
        self.define("lref_major_radius = NaN lref_minor_radius")
        self.define("nref_deuterium = NaN nref_electron")
        self.define("tref_deuterium = NaN tref_electron")
        self.define("rhoref_unit = NaN rhoref_pyro")

        # Too many combinations of beta units, this almost certainly won't
        # scale, so just do the only one we know is used for now
        self.define("beta_ref_ee_Bunit = NaN beta_ref_ee_B0")

    def _after_init(self):
        super()._after_init()
        # Enable the Boltzmann context by default so we can always convert
        # eV to Kelvin
        self.enable_contexts("boltzmann")

    @contextmanager
    def as_system(self, system):
        """Temporarily make ``system`` the registry's default unit system.

        ``system`` may be a system name, an object exposing ``_system.name``
        (such as a normalisation), or ``None`` to leave the default unchanged.
        The previous default is restored on exit, including when the body of
        the ``with`` block raises.
        """
        old_system = self.default_system
        if system is not None:
            self.default_system = (
                system if isinstance(system, str) else system._system.name
            )
        try:
            yield
        finally:
            # Without this, an exception inside the ``with`` body would leave
            # the registry permanently switched to ``system``
            self.default_system = old_system

    def _try_transform(self, src_value, src_unit, src_dim, dst_dim):
        """Transform a value between dimensions using the active contexts.

        Returns ``(magnitude, units)`` of the transformed quantity, or ``None``
        if the active context graph has no path from ``src_dim`` to
        ``dst_dim``.
        """
        path = pint.util.find_shortest_path(self._active_ctx.graph, src_dim, dst_dim)
        if not path:
            return None

        src = self.Quantity(src_value, src_unit)
        # Apply each transformation along the path in turn
        for a, b in zip(path[:-1], path[1:]):
            src = self._active_ctx.transform(a, b, self, src)
        return src._magnitude, src._units

    def _convert(self, value, src, dst, inplace=False):
        """Convert value from some source to destination units.

        In addition to what is done by the PlainRegistry,
        converts between units with different dimensions by following
        transformation rules defined in the context.

        Parameters
        ----------
        value :
            value
        src : UnitsContainer
            source units.
        dst : UnitsContainer
            destination units.
        inplace :
             (Default value = False)

        Returns
        -------
        callable
            converted value
        """
        if not self._active_ctx:
            return super()._convert(value, src, dst, inplace)

        src_dim = self._get_dimensionality(src)
        dst_dim = self._get_dimensionality(dst)

        # Try converting the quantity with units as given
        if converted := self._try_transform(value, src, src_dim, dst_dim):
            value, src = converted
            return super()._convert(value, src, dst, inplace)

        # That wasn't possible, so now we break up the units and see
        # if we can convert them individually.
        # These are the new units resulting from any transformations
        new_units = src

        for unit, power in src.items():
            # Here, we're assuming that the transformation is based on [dim]**1,
            # while the unit in our quantity might be e.g. its inverse
            unit_uc = pint.util.UnitsContainer({unit: 1})
            unit_dim = self._get_dimensionality(unit_uc)

            # Now we try to convert between this unit and one of the
            # destination units
            for dst_part, dst_power in dst.items():
                dst_part_uc = pint.util.UnitsContainer({dst_part: 1})
                dst_part_dim = self._get_dimensionality(dst_part_uc)

                # Transform one unit of ``unit`` (to the first power) and raise
                # the resulting multiplier to the power the unit appears with.
                # This assumes the transformations are purely multiplicative
                # (no offset units).
                if converted := self._try_transform(
                    1.0, unit_uc, unit_dim, dst_part_dim
                ):
                    value_multiplier, new_unit = converted
                    if dst_power == power:
                        value = value * value_multiplier**dst_power
                    else:
                        raise pint.DimensionalityError(src, dst, src_dim, dst_dim)

                    # It worked, so we can replace the original unit
                    # with the transformed one
                    new_units = (
                        new_units
                        / pint.util.UnitsContainer({unit: power})
                        * (new_unit**dst_power)
                    )

        return super()._convert(value, new_units, dst, inplace)
class UnitSpline:
    """
    Unit-aware wrapper classes for 1D splines.

    Parameters
    ----------
    x: Arraylike
        x-coordinates to pass to SciPy splines, with units. May be an xarray
        ``DataArray`` wrapping a quantity.
    y: ArrayLike
        y-coordinates to pass to SciPy splines, with units. May be an xarray
        ``DataArray`` wrapping a quantity.
    """

    def __init__(self, x: ArrayLike, y: ArrayLike):
        from xarray import DataArray

        # Unwrap xarray containers to get at the underlying quantities
        if isinstance(x, DataArray):
            x = x.data
        if isinstance(y, DataArray):
            y = y.data

        self._x_units = x.units
        self._y_units = y.units

        x_mag = x.magnitude
        y_mag = y.magnitude

        # InterpolatedUnivariateSpline needs increasing x. Assume x is
        # monotonically increasing/decreasing and reverse both if decreasing.
        if x_mag[1] > x_mag[0]:
            self._spline = InterpolatedUnivariateSpline(x_mag, y_mag)
        else:
            self._spline = InterpolatedUnivariateSpline(x_mag[::-1], y_mag[::-1])

    def __call__(self, x: ArrayLike, derivative: int = 0) -> np.ndarray:
        """Evaluate the spline (or its ``derivative``-th derivative) at ``x``.

        ``x`` must carry units compatible with those used to build the
        spline. It is converted to those units before evaluation.
        """
        # Convert the query into the spline's own x units; stripping the
        # magnitude directly would give wrong results for e.g. cm vs m
        x_mag = x.to(self._x_units).magnitude
        u = self._y_units / self._x_units**derivative
        return self._spline(x_mag, nu=derivative) * u
class UnitSpline2D(RectBivariateSpline):
    """
    Unit-aware wrapper classes for 2D splines.

    Parameters
    ----------
    x: pint.Quantity
        x-coordinates to pass to SciPy splines, with units.
    y: pint.Quantity
        y-coordinates to pass to SciPy splines, with units.
    z: pint.Quantity
        z-coordinates to pass to SciPy splines, with units.
    """

    # NOTE(review): this subclasses RectBivariateSpline but never initialises
    # the base; all work is delegated to ``self._spline``, so inherited methods
    # other than ``__call__`` may not work on instances of this class.

    def __init__(self, x, y, z):
        self._x_units = x.units
        self._y_units = y.units
        self._z_units = z.units
        self._spline = RectBivariateSpline(x.magnitude, y.magnitude, z.magnitude)

    def __call__(self, x, y, dx=0, dy=0):
        """Evaluate the spline (or a partial derivative) at points ``(x, y)``.

        ``x`` and ``y`` must carry units compatible with those used to build
        the spline. They are converted to those units before evaluation.
        Evaluation is pointwise (``grid=False``).
        """
        # Convert queries into the spline's own units before stripping them
        x_mag = x.to(self._x_units).magnitude
        y_mag = y.to(self._y_units).magnitude
        u = self._z_units / (self._x_units**dx * self._y_units**dy)
        return self._spline(x_mag, y_mag, dx=dx, dy=dy, grid=False) * u
class UnitCloughTocher2DInterpolator:
    """
    Unit-aware 2D CloughTocher2DInterpolator

    Parameters
    ----------
    points : array_like or pint.Quantity
        ndarray of floats, shape (npoints, ndims). If it carries units, query
        coordinates that also carry units are converted to them.
    values : pint.Quantity
        ndarray of float or complex, shape (npoints, …)
    fill_value : float or pint.Quantity, optional
        Value to be used outside the convex hull of the interpolation points.
        Defaults to NaN. A quantity is converted to the units of ``values``.
    """

    def __init__(self, points, values, fill_value=np.nan):
        self._value_units = values.units
        self._point_units = getattr(points, "units", None)

        if hasattr(fill_value, "units"):
            # scipy needs a plain scalar expressed in the units of ``values``
            fill_value = np.asarray(fill_value.to(self._value_units).magnitude).item()

        # Strip units explicitly rather than relying on implicit conversion
        self._spline = CloughTocher2DInterpolator(
            getattr(points, "magnitude", points),
            values.magnitude,
            fill_value=fill_value,
        )

    def _coordinate_magnitude(self, coord):
        """Strip units from a query coordinate, converting to the point units
        first when both are known."""
        if self._point_units is not None and hasattr(coord, "units"):
            coord = coord.to(self._point_units)
        return getattr(coord, "magnitude", coord)

    def __call__(self, x, y):
        """Interpolate at ``(x, y)`` and return the result in ``values`` units."""
        return (
            self._spline(self._coordinate_magnitude(x), self._coordinate_magnitude(y))
            * self._value_units
        )
#: Default unit registry, shared by the rest of pyrokinetics
ureg = PyroUnitRegistry()