from __future__ import annotations
import copy
import json
import os
import pathlib
import warnings
from contextlib import contextmanager
from functools import reduce
from itertools import product
import numpy as np
import pint
import xarray as xr
from pint import Quantity
from .dataset_wrapper import DatasetWrapper
from .gk_code import GKInput
from .normalisation import ConventionNormalisation
from .pyro import Pyro
from .units import ureg
def _serialize_path(path: pathlib.Path, base: pathlib.Path) -> str:
"""Convert absolute path to relative-for-JSON."""
try:
return os.path.relpath(path, base)
except ValueError:
return str(path)
def _resolve_path(path: str | pathlib.Path, base: pathlib.Path) -> pathlib.Path:
"""Resolve possibly-relative path against a base directory."""
path = pathlib.Path(path)
if path.is_absolute():
return path
return (base / path).resolve()
# ---- time handling ----
[docs]
def reduce_time(
da,
*,
time_mode,
tolerance_time_range=None,
):
"""
Reduce a DataArray over time according to the chosen policy.
time_mode:
"last" → take final time
"average" → average over tolerance_time_range
"""
if "time" not in da.dims:
return da
if time_mode == "last":
return da.isel(time=-1).drop_vars("time", errors="ignore")
if time_mode == "average":
if tolerance_time_range is None:
raise ValueError("tolerance_time_range required for time averaging")
t = da["time"]
t_max = float(t.max())
t_min = t_max * tolerance_time_range
return (
da.sel(time=slice(t_min, t_max))
.mean(dim="time")
.drop_vars("time", errors="ignore")
)
raise ValueError(f"Unknown time_mode={time_mode}")
# ---- xarray selection ----
[docs]
def select_kx_ky_time(
da,
*,
kx_min,
sum_ky=False,
time_mode,
tolerance_time_range=None,
):
if "kx" in da.dims:
da = da.sel(kx=kx_min)
da = reduce_time(
da,
time_mode=time_mode,
tolerance_time_range=tolerance_time_range,
)
if sum_ky and "ky" in da.dims:
da = da.sum(dim="ky")
return da
# ---- error handling ----
[docs]
def handle_failed_run(buffers, templates, gk_file, error=None):
import warnings
warnings.warn(
f"Failed to load GK output for {gk_file}: {type(error).__name__}: {error}",
RuntimeWarning,
stacklevel=2,
)
for name, buf in buffers.items():
buf.append(templates[name])
[docs]
def normalize_failed_runs(buffers: dict[str, list]) -> None:
"""
Replace None placeholders (from early failures) with NaN-filled DataArrays
once a reference shape is available.
"""
for name, values in buffers.items():
ref = next((v for v in values if v is not None), None)
if ref is None:
continue
buffers[name] = [xr.full_like(ref, np.nan) if v is None else v for v in values]
# ---- dataset assembly ----
[docs]
def add_quantity(ds, name, arrays, base_shape, scan_coords):
if not any(isinstance(a, xr.DataArray) for a in arrays):
return ds
last = arrays[-1]
for dim in scan_coords:
if dim in last.dims:
last = last.squeeze(dim, drop=True)
arrays = [
a.squeeze(dim, drop=True) if isinstance(a, xr.DataArray) else a
for a in arrays
]
shape = base_shape + last.shape
raw = []
units = None
for a in arrays:
data = a.data
if hasattr(data, "magnitude"):
raw.append(data.magnitude)
units = data.units
else:
raw.append(np.asarray(data))
stacked = np.stack(raw).reshape(shape)
if units is not None:
stacked = stacked * units
dims = tuple(scan_coords.keys()) + last.dims
arr = xr.DataArray(
stacked,
dims=dims,
coords={
**scan_coords,
**last.coords,
},
)
arr = arr.reset_coords(drop=True)
ds[name] = arr
return ds
[docs]
class PyroScan:
"""
Creates a dictionary of pyro objects in pyro_dict
Need a templates pyro object
Dict of parameters to scan through
{ param : [values], }
"""
JSON_ATTRS = [
"value_fmt",
"value_separator",
"parameter_separator",
"parameter_dict",
"file_name",
"base_directory",
"runfile_dict",
"p_prime_type",
"parameter_map",
]
[docs]
def __init__(
self,
pyro=None,
parameter_dict=None,
p_prime_type=0,
value_fmt=".2f",
value_separator="_",
parameter_separator="/",
file_name=None,
base_directory=".",
load_default_parameter_keys=True,
pyroscan_json=None,
runfile_dict=None,
load_base_pyro=False,
):
# Mapping from parameter to location in Pyro
self.parameter_map = {}
# Need to intialise and pyro_dict pyroscan_json before base_directory
self.pyro_dict = {}
self.pyroscan_json = {}
self.parameter_func = {}
self.base_directory = base_directory
# Format values/parameters
self.value_fmt = value_fmt
self.value_separator = value_separator
if parameter_separator in ["/", "\\"]:
self.parameter_separator = os.path.sep
else:
self.parameter_separator = parameter_separator
self.runfile_dict = runfile_dict or {}
self.base_directory = pathlib.Path(base_directory).resolve()
if file_name is not None:
self.file_name = file_name
elif pyro is not None:
self.file_name = GKInput._factory[pyro.gk_code].default_file_name
if parameter_dict is None:
self.parameter_dict = {}
else:
self.parameter_dict = parameter_dict
self.p_prime_type = p_prime_type
# Load in pyroscan json if there
if pyroscan_json is not None:
pyroscan_json = pathlib.Path(pyroscan_json).resolve()
with open(pyroscan_json) as f:
self.pyroscan_json = json.load(f)
json_dir = pyroscan_json.parent
for key, value in self.pyroscan_json.items():
if key == "parameter_dict":
self.parameter_dict = {
k: (
np.asarray(raw[0]) * ureg(raw[1])
if isinstance(raw, list)
and len(raw) == 2
and isinstance(raw[1], str)
else raw
)
for k, raw in value.items()
}
self.pyroscan_json["parameter_dict"] = self.parameter_dict
continue
elif key == "base_directory":
# Resolve relative path against JSON location
resolved = _resolve_path(value, json_dir)
# User override still wins
if base_directory != ".":
resolved = pathlib.Path(base_directory).resolve()
setattr(self, key, resolved)
continue
setattr(self, key, value)
else:
self.pyroscan_json = {attr: getattr(self, attr) for attr in self.JSON_ATTRS}
if pyro is not None:
if not isinstance(pyro, Pyro):
raise TypeError("pyro must be a Pyro instance")
self.base_pyro = pyro
elif self.file_name is None:
raise ValueError(
"file_name must be specified or in json if pyro is not given"
)
elif load_base_pyro:
pyro_base = pathlib.Path(pyroscan_json).resolve().parent
in_loc = pyro_base / "pyroscan_base.input"
self.base_pyro = Pyro(gk_file=in_loc)
# Restore normalisation reference values if they were saved.
# This is needed when the base input file type (e.g. TGLF)
# cannot store normalisations that the original pyro had
# (e.g. from a GENE run).
norms_file = pyro_base / "pyroscan_norms.json"
if norms_file.exists():
self.base_pyro.read_reference_values(norms_file)
else:
raise ValueError("Either provide a pyro object or enable load_base_pyro")
if (
load_default_parameter_keys and pyroscan_json is None
): # if parameter keys are loaded from json there is no need to set defaults
self.load_default_parameter_keys()
# Canonicalise freshly-supplied parameter_dict values into pyrokinetics
# simulation units so run-directory names are consistent regardless of
# gk_code convention. On reload the JSON is authoritative — its values
# already reflect whatever unit was serialised, so we skip conversion
# there (new-format JSONs are already in pyro sim units, and old-format
# fixtures pair their magnitudes with their on-disk directory names).
if pyroscan_json is None:
# Pass the pyrokinetics convention explicitly: base_pyro.norms.default_convention
# is set to whichever gk_code was read (gs2/gene/cgyro/...), which would
# otherwise make run-directory names code-dependent.
pyro_convention = self.base_pyro.norms.pyrokinetics
for name, values in list(self.parameter_dict.items()):
if hasattr(values, "convert_physical_units"):
self.parameter_dict[name] = values.convert_physical_units(
pyro_convention
)
# Get len of values for each parameter
self.value_size = [len(value) for value in self.parameter_dict.values()]
self.pyro_dict = dict(
self.create_single_run(run) for run in self.outer_product()
)
self.run_directories = [pyro.run_directory for pyro in self.pyro_dict.values()]
[docs]
def create_single_run(self, parameters: dict):
"""
Create a new Pyro instance from the PyroScan base with new run parameters
"""
name = self.format_single_run_name(parameters)
new_run = copy.deepcopy(self.base_pyro)
new_run.gk_file = self.base_directory / name / self.file_name
new_run.run_parameters = copy.deepcopy(parameters)
return name, new_run
[docs]
def write(
self,
file_name=None,
base_directory=None,
template_file=None,
relative_path=True,
):
"""
Creates and writes GK input files for parameters in scan
"""
if file_name is not None:
self.file_name = file_name
if base_directory is not None:
self.base_directory = pathlib.Path(base_directory)
# Set run directories
self.run_directories = [
self.base_directory / run_dir for run_dir in self.pyro_dict.keys()
]
self.base_directory.mkdir(parents=True, exist_ok=True)
# Dump json file with pyroscan data
json_file = self.base_directory / "pyroscan.json"
json_data = dict(self.pyroscan_json)
if relative_path:
json_data["base_directory"] = "."
else:
json_data["base_directory"] = str(self.base_directory)
# Convert parameter_dict values to generic simulation units so the
# JSON carries run-independent unit names; reloading pairs them back
# with physical values via pyroscan_norms.json.
if "parameter_dict" in json_data:
norms = self.base_pyro.norms.pyrokinetics
json_data["parameter_dict"] = {
name: (
values.convert_physical_units(norms)
if hasattr(values, "convert_physical_units")
else values
)
for name, values in json_data["parameter_dict"].items()
}
with open(json_file, "w+") as f:
json.dump(json_data, f, cls=NumpyEncoder)
self.update_self_parameters()
# Iterate through all runs and write output
for parameter, run_dir, pyro in zip(
self.outer_product(), self.run_directories, self.pyro_dict.values()
):
# Write input file
pyro.write_gk_file(
file_name=run_dir / self.file_name, template_file=template_file
)
self.base_pyro.write_gk_file(
file_name=self.base_directory / "pyroscan_base.input"
)
# Save normalisation reference values so they can be restored
# when reloading from a base input file that cannot store them
# (e.g. TGLF inputs generated from a GENE run)
try:
self.base_pyro.write_reference_values(
self.base_directory / "pyroscan_norms.json"
)
except Exception:
warnings.warn(
"Could not save normalisation reference values to "
"pyroscan_norms.json. Normalisation-dependent units "
"may not be available when reloading this PyroScan.",
stacklevel=2,
)
[docs]
def update_self_parameters(
self,
):
"""
Updates all pyro object parameters based on pyro_dict values
"""
for parameter, run_dir, pyro in zip(
self.outer_product(), self.run_directories, self.pyro_dict.values()
):
# Param value for each run written accordingly
for param, value in parameter.items():
# Get attribute name and keys where param is stored in Pyro
attr_name, keys_to_param = self.parameter_map[param]
# Get attribute in Pyro storing the parameter
pyro_attr = getattr(pyro, attr_name)
# Set the value given the Pyro attribute and location of parameter
set_in_dict(pyro_attr, keys_to_param, value)
if param in self.parameter_func.keys():
func, kwargs = self.parameter_func[param]
func(pyro, **kwargs)
[docs]
def add_parameter_key(
self, parameter_key=None, parameter_attr=None, parameter_location=None
):
"""
parameter_key: string to access variable
parameter_attr: string of attribute storing value in pyro
parameter_location: list of strings showing path to value in pyro
"""
if parameter_key is None:
raise ValueError("Need to specify parameter key")
if parameter_attr is None:
raise ValueError("Need to specify parameter attr")
if parameter_location is None:
raise ValueError("Need to specify parameter location")
dict_item = {parameter_key: [parameter_attr, parameter_location]}
self.parameter_map.update(dict_item)
# Get attribute name and keys where param is stored in Pyro
pyro_attr = getattr(self.base_pyro, parameter_attr)
if parameter_key in self.parameter_dict:
value = self.parameter_dict[parameter_key]
if not hasattr(value, "units"):
units = getattr(
get_from_dict(pyro_attr, parameter_location[:-1])[
parameter_location[-1]
],
"units",
1,
)
if units != 1:
warnings.warn(
f"Adding units [{units}] to {parameter_key} as it has not been "
"specified. To suppress this warning please add units"
)
self.parameter_dict[parameter_key] = value * units
self.pyroscan_json["parameter_map"] = self.parameter_map
[docs]
def add_parameter_func(
self, parameter_key=None, parameter_func=None, parameter_kwargs=None
):
"""
Applies function `parameter_func(pyro, **kwargs)` on pyro object each time after
parameter_key is set in a scan
parameter_key: string to access variable
parameter_func: function that take in a pyro object applies modification
parameter_kwargs: Dictionary of kwargs to apply to function
"""
self.parameter_func[parameter_key] = (parameter_func, parameter_kwargs)
[docs]
def load_default_parameter_keys(self):
"""
Loads default parameters name into parameter_map
{param : ["attribute", ["key_to_location_1", "key_to_location_2" ]] }
for example
{'electron_temp_gradient': [
"local_species", ['electron','inverse_lt']] }
"""
self.parameter_map = {}
# ky
parameter_key = "ky"
parameter_attr = "numerics"
parameter_location = ["ky"]
self.add_parameter_key(parameter_key, parameter_attr, parameter_location)
# Electron temperature gradient
parameter_key = "electron_temp_gradient"
parameter_attr = "local_species"
parameter_location = ["electron", "inverse_lt"]
self.add_parameter_key(parameter_key, parameter_attr, parameter_location)
# Electron density gradient
parameter_key = "electron_dens_gradient"
parameter_attr = "local_species"
parameter_location = ["electron", "inverse_ln"]
self.add_parameter_key(parameter_key, parameter_attr, parameter_location)
# Deuterium temperature gradient
parameter_key = "deuterium_temp_gradient"
parameter_attr = "local_species"
parameter_location = ["deuterium", "inverse_lt"]
self.add_parameter_key(parameter_key, parameter_attr, parameter_location)
# Deuterium density gradient
parameter_key = "deuterium_dens_gradient"
parameter_attr = "local_species"
parameter_location = ["deuterium", "inverse_ln"]
self.add_parameter_key(parameter_key, parameter_attr, parameter_location)
# ExB shear
parameter_key = "gamma_exb"
parameter_attr = "numerics"
parameter_location = ["gamma_exb"]
self.add_parameter_key(parameter_key, parameter_attr, parameter_location)
# Elongation
parameter_key = "kappa"
parameter_attr = "local_geometry"
parameter_location = ["kappa"]
self.add_parameter_key(parameter_key, parameter_attr, parameter_location)
[docs]
def load_gk_output(
self,
output_convention="pyrokinetics",
tolerance_time_range=0.8,
netcdf_file=None,
load_fields=True,
load_fluxes=True,
load_moments=False,
sum_ky=True,
drop_nan=False,
**kwargs,
):
"""
Loads PyroScanGKOutput into self.gk_output
Parameters
----------
output_convention: str default 'pyrokinetics'
ConventionNormalisation to convert output to
tolerance_time_range: float default 0.8
Time window over which to calculate growth rate tolerance
netcdf_file: PathLike default None
If supplied then load PyroScanGKOutput from existing netCDF
load_fields (bool, default True) – Flag to load fields or not
load_fluxes (bool, default True) – Flag to load fluxes or not
load_moments (bool, default False) – Flag to load moments or not
drop_nan (bool, default False) – If NaNs are found in the output then that data is dropped. Off by default
**kwargs – Arguments to pass to the GKOutputReader.
Returns
-------
None
"""
# Load from netCDF is supplied
if netcdf_file is not None:
convention = getattr(self.base_pyro.norms, output_convention)
gk_output = PyroScanGKOutput.from_netcdf(netcdf_file)
gk_output.to(convention, convention.context)
self.gk_output = gk_output
return
parameter_dict = {}
coords = {}
attrs = {}
coord_units = {}
for name, values in self.parameter_dict.items():
# Unitless scan parameters are stored as plain arrays/lists.
if isinstance(values, Quantity):
vals = np.asarray(values.magnitude)
unit = values.units
else:
vals = np.asarray(values)
unit = ureg.dimensionless
parameter_dict[name] = ((name,), vals)
coords[name] = vals
attrs[name + "_units"] = str(unit)
coord_units[name] = unit
output_shape = tuple(len(v) for v in self.parameter_dict.values())
load_specs = {
"linear": {
"scalars": ["growth_rate", "mode_frequency", "eigenfunctions"],
"extras": ["growth_rate_tolerance"],
"fluxes": [],
"fields": [],
},
"nonlinear": {
"scalars": [],
"extras": [],
"fluxes": [],
"fields": [],
},
}
time_policy = {
"linear": {
"scalars": "last",
"fluxes": "last",
"fields": "last",
},
"nonlinear": {
"scalars": "last",
"fluxes": "average",
"fields": "average",
},
}
if self.base_pyro.gk_code == "TGLF":
load_specs["nonlinear"]["scalars"].extend(["growth_rate", "mode_frequency"])
load_specs["nonlinear"]["extras"].extend(["growth_rate_tolerance"])
if load_fluxes:
load_specs["linear"]["fluxes"].extend(["particle", "heat", "momentum"])
load_specs["nonlinear"]["fluxes"].extend(["particle", "heat", "momentum"])
if load_fields:
load_specs["linear"]["fields"].extend(["phi", "bpar", "apar"])
load_specs["nonlinear"]["fields"].extend(["phi", "bpar", "apar"])
regime = "nonlinear" if self.base_pyro.numerics.nonlinear else "linear"
spec = load_specs[regime]
time_policy = time_policy[regime]
buffers = {
name: []
for name in (
spec["scalars"] + spec["extras"] + spec["fluxes"] + spec["fields"]
)
}
for i, pyro in enumerate(self.pyro_dict.values()):
run_buffers = {name: None for name in buffers}
try:
pyro.load_gk_output(
output_convention=output_convention,
load_fields=load_fields,
load_fluxes=load_fluxes,
load_moments=load_moments,
drop_nan=drop_nan,
**kwargs,
)
data = pyro.gk_output.data
kx_min = float(np.min(np.abs(data.kx)))
# removes growth_rate_tolerance from nonlinear codes with no time TGLF
if (
"mode" not in pyro.gk_output.dims
and "growth_rate_tolerance" in spec["extras"]
):
run_buffers["growth_rate_tolerance"] = None
if (
"mode" not in pyro.gk_output.dims
and "growth_rate_tolerance" in spec["extras"]
):
run_buffers["growth_rate_tolerance"] = (
pyro.gk_output.get_growth_rate_tolerance(
tolerance_time_range
).sel(kx=kx_min)
)
for name in spec["scalars"]:
if name in pyro.gk_output:
run_buffers[name] = select_kx_ky_time(
pyro.gk_output[name],
kx_min=kx_min,
time_mode=time_policy["scalars"],
)
for name in spec["fluxes"]:
if name in pyro.gk_output:
run_buffers[name] = select_kx_ky_time(
pyro.gk_output[name],
kx_min=kx_min,
sum_ky=sum_ky,
time_mode=time_policy["fluxes"],
tolerance_time_range=tolerance_time_range,
)
data = data.isel(ky=[0]).squeeze()
pyro.gk_output.data = data
for name in spec["fields"]:
if name in pyro.gk_output:
run_buffers[name] = select_kx_ky_time(
pyro.gk_output[name],
kx_min=kx_min,
time_mode=time_policy["fields"],
tolerance_time_range=tolerance_time_range,
)
for name, value in run_buffers.items():
buffers[name].append(value)
except (
FileNotFoundError,
OSError,
IndexError,
RuntimeError,
KeyError,
ValueError,
) as e:
warnings.warn(
f"Unable to load gk_output for {pyro.gk_file} "
f"[{type(e).__name__}: {e}]"
)
for name in buffers:
ref = next((x for x in buffers[name] if x is not None), None)
if ref is None:
buffers[name].append(None)
else:
buffers[name].append(ref * np.nan)
finally:
if hasattr(pyro, "gk_output"):
pyro.gk_output = None
for name, arrays in buffers.items():
ref = next((x for x in arrays if x is not None), None)
if ref is None:
continue
buffers[name] = [ref * np.nan if x is None else x for x in arrays]
if all(all(x is None for x in arrays) for arrays in buffers.values()):
raise FileNotFoundError("Unable to load any gk_output files in this scan")
ds = xr.Dataset(parameter_dict)
for name, arrays in buffers.items():
ds = add_quantity(ds, name, arrays, output_shape, coords)
for coord, units in coord_units.items():
ds[coord] = ds[coord].assign_attrs(units=units)
self.gk_output = PyroScanGKOutput(ds)
self.gk_output.to(getattr(self.base_pyro.norms, output_convention))
@property
def gk_code(self):
# NOTE: In previous versions, this would return a GKCode class. Now it only
# returns a string.
# The setter has been replaced by the function 'convert_gk_code'
return self.base_pyro.gk_code
[docs]
def convert_gk_code(self, gk_code: str) -> None:
"""
Converts all gyrokinetics codes to the code type 'gk_code'. This can be any
viable GKInput type (GS2, CGYRO, GENE,...)
"""
self.base_pyro.convert_gk_code(gk_code)
for pyro in self.pyro_dict.values():
pyro.convert_gk_code(gk_code)
@property
def base_directory(self):
return self._base_directory
@base_directory.setter
def base_directory(self, value):
"""
Sets the base_directory
"""
self._base_directory = pathlib.Path(value).absolute()
self.pyroscan_json["base_directory"] = self._base_directory
# Set base_directory in copies of pyro
for key, pyro in self.pyro_dict.items():
pyro.gk_file = self.base_directory / key / pyro.gk_file.name
@property
def file_name(self):
return self._file_name
@file_name.setter
def file_name(self, value):
"""
Sets the file_name
"""
self.pyroscan_json["file_name"] = value
self._file_name = value
[docs]
def outer_product(self):
"""
Creates generator of outer product for all parameter permutations
"""
return (
dict(zip(self.parameter_dict, x))
for x in product(*self.parameter_dict.values())
)
[docs]
def get_from_dict(data_dict, map_list):
"""
Gets item in dict given location as a list of string
"""
return reduce(get_attr_or_item, map_list, data_dict)
[docs]
def get_attr_or_item(obj, value):
if hasattr(obj, value):
return getattr(obj, value)
elif value in obj.keys():
return obj[value]
else:
raise ValueError(f"{obj} has not got {value} as a key or attribute")
[docs]
def set_in_dict(data_dict, map_list, value):
"""
Sets item in dict given location as a list of string
"""
get_from_dict(data_dict, map_list[:-1])[map_list[-1]] = copy.deepcopy(value)
[docs]
@contextmanager
def cd(newdir):
prevdir = os.getcwd()
os.chdir(os.path.expanduser(newdir))
try:
yield
finally:
os.chdir(prevdir)
[docs]
class NumpyEncoder(json.JSONEncoder):
"""Numpy/pint-aware JSON encoder. Quantities are stored as
``[magnitude, unit_str]`` with unit names left fully qualified
(including any instance-specific normalisation suffix)."""
[docs]
def default(self, obj):
if isinstance(obj, np.ndarray):
return obj.tolist()
if isinstance(obj, pathlib.Path):
return str(obj)
if isinstance(obj, np.integer):
return int(obj)
if isinstance(obj, np.floating):
return float(obj)
if isinstance(obj, pint.Quantity):
return [obj.m, str(obj.units)]
return json.JSONEncoder.default(self, obj)
[docs]
class PyroScanGKOutput(DatasetWrapper):
[docs]
def __init__(self, dataset: xr.Dataset):
data_vars = dataset.data_vars
coords = dataset.coords
attrs = dataset.attrs
# Hand over to underlying dataset
super().__init__(data_vars=data_vars, coords=coords, attrs=attrs)
[docs]
def to(self, norms: ConventionNormalisation, *contexts):
"""
Parameters
----------
norms : ConventionNormalisation
Normalisation convention to convert to
Returns
-------
GKOutput with units from norms
"""
for data_var in self.data_vars:
self[data_var].data = self[data_var].data.to(norms, *contexts)
# Coordinates with units not supported in xarray need to manually change
new_coords = {}
for coord in self.coords:
if hasattr(self[coord], "units"):
if self[coord].units is None:
continue
new_coord = (self[coord].data * self[coord].units).to(norms, *contexts)
new_coords[coord] = (
coord,
new_coord.m,
{"units": new_coord.units},
)
self.data = self.data.assign_coords(coords=new_coords)
[docs]
def unwrap(self):
"""Return the underlying xarray.Dataset."""
return self._dataset