Source code for pyrokinetics.gk_code.stella
from __future__ import annotations
import warnings
from copy import copy
from pathlib import Path
from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple
import f90nml
import numpy as np
from cleverdict import CleverDict
from ..constants import deuterium_mass, electron_mass, pi
from ..file_utils import FileReader
from ..local_geometry import (
LocalGeometry,
LocalGeometryMiller,
MetricTerms,
default_miller_inputs,
)
from ..local_species import LocalSpecies
from ..normalisation import SimulationNormalisation as Normalisation
from ..normalisation import convert_dict, ureg
from ..numerics import Numerics
from ..templates import gk_templates
from ..typing import PathLike
from ..units import PyroContextError
from .gk_input import GKInput
from .gk_output import Coords, Eigenvalues, Fields, Fluxes, GKOutput, Moments
if TYPE_CHECKING:
import xarray as xr
[docs]
class GKInputSTELLA(GKInput, FileReader, file_type="STELLA", reads=GKInput):
"""
Class that can read STELLA input files, and produce
Numerics, LocalSpecies, and LocalGeometry objects
"""
code_name = "STELLA"
default_file_name = "input.in"
norm_convention = "stella"
_parameters_physics = "parameters_physics"
_parameters_params = "parameters_physics"
_parameters_numerical = "parameters_numerical"
_legacy_stella = False
pyro_stella_miller = {
"rho": ["millergeo_parameters", "rhoc"],
"Rmaj": ["millergeo_parameters", "rmaj"],
"q": ["millergeo_parameters", "qinp"],
"kappa": ["millergeo_parameters", "kappa"],
"shat": ["millergeo_parameters", "shat"],
"shift": ["millergeo_parameters", "shift"],
"beta_prime": ["millergeo_parameters", "betaprim"],
}
pyro_stella_miller_defaults = {
"rho": 0.5,
"Rmaj": 3.0,
"q": 1.5,
"kappa": 1.0,
"shat": 0.0,
"shift": 0.0,
"beta_prime": 0.0,
}
pyro_stella_species = {
"mass": "mass",
"z": "z",
"dens": "dens",
"temp": "temp",
"inverse_lt": "tprim",
"inverse_ln": "fprim",
}
def _set_legacy_stella(self, flag: bool):
"""
Set dictionay flags used to access namelist to legacy values or not
"""
if flag:
self._legacy_stella = True
self._parameters_numerical = "knobs"
self._parameters_params = "parameters"
self._parameters_physics = "physics_flags"
else:
self._legacy_stella = False
self._parameters_numerical = "parameters_numerical"
self._parameters_params = "parameters_physics"
self._parameters_physics = "parameters_physics"
[docs]
def read_from_file(
self, filename: PathLike, detect_norm: bool = True
) -> Dict[str, Any]:
"""
Reads STELLA input file into a dictionary
"""
result = super().read_from_file(filename, detect_norm=detect_norm)
if {"knobs", "parameters", "physics_flags"}.intersection(result.keys()):
warnings.warn(
"The keys 'knobs'/'parameters'/'physics_flags' were found in the input file suggesting this is a "
"legacy input file, please update this to the latest version to suppress this warning"
)
self._set_legacy_stella(True)
return result
[docs]
def read_str(self, input_string: str, detect_norm: bool = True) -> Dict[str, Any]:
"""
Reads STELLA input file given as string
Uses default read_str, which assumes input_string is a Fortran90 namelist
"""
result = super().read_str(input_string, detect_norm=detect_norm)
if {"knobs", "parameters", "physics_flags"}.intersection(result.keys()):
warnings.warn(
"The keys 'knobs'/'parameters'/'physics_flags' were found in the input file suggesting this is a "
"legacy input file, please update this to the latest version to suppress this warning"
)
self._set_legacy_stella(True)
return result
[docs]
def read_dict(self, input_dict: dict, detect_norm: bool = True) -> Dict[str, Any]:
"""
Reads STELLA input file given as dict
Uses default read_dict, which assumes input is a dict
"""
return super().read_dict(input_dict, detect_norm=detect_norm)
[docs]
def verify_file_type(self, filename: PathLike):
"""
Ensure this file is a valid stella input file, and that it contains sufficient
info for Pyrokinetics to work with
"""
# The following keys are not strictly needed for a stella input file,
# but they are needed by Pyrokinetics
expected_keys = [
# "parameters_numerical",
"zgrid_parameters",
"geo_knobs",
"millergeo_parameters",
# "parameters_physics",
"species_knobs",
"kt_grids_knobs",
]
self.verify_expected_keys(filename, expected_keys)
[docs]
def write(
self,
filename: PathLike,
float_format: str = "",
local_norm=None,
code_normalisation: str = None,
):
if local_norm is None:
local_norm = Normalisation("write")
if code_normalisation is None:
code_normalisation = self.code_name.lower()
convention = getattr(local_norm, code_normalisation)
for name, namelist in self.data.items():
self.data[name] = convert_dict(namelist, convention)
super().write(filename, float_format=float_format)
[docs]
def is_nonlinear(self) -> bool:
try:
is_box = self.data["kt_grids_knobs"]["grid_option"] == "box"
is_nonlinear = self.data[self._parameters_physics]["nonlinear"]
return is_box and is_nonlinear
except KeyError:
return False
[docs]
def add_flags(self, flags) -> None:
"""
Add extra flags to STELLA input file
"""
super().add_flags(flags)
[docs]
def get_local_geometry(self) -> LocalGeometry:
"""
Returns local geometry. Delegates to more specific functions
"""
if hasattr(self, "convention"):
convention = self.convention
else:
norms = Normalisation("get_local_geometry")
convention = getattr(norms, self.norm_convention)
stella_eq = self.data["geo_knobs"]["geo_option"]
if stella_eq not in ["miller"]:
raise NotImplementedError(
f"stella equilibrium option {stella_eq} not implemented"
)
local_geometry = self.get_local_geometry_miller()
local_geometry.B0 = (
self.data["millergeo_parameters"]["rgeo"]
/ self.data["millergeo_parameters"]["rmaj"]
)
local_geometry.dpsidr *= local_geometry.B0
local_geometry.normalise(norms=convention)
local_geometry.Fpsi = local_geometry.get_f_psi()
local_geometry.FF_prime = local_geometry.get_f_prime() * local_geometry.Fpsi
return local_geometry
[docs]
def get_local_geometry_miller(self) -> LocalGeometryMiller:
"""
Load Basic Miller object from stella file
"""
miller_data = default_miller_inputs()
for (pyro_key, (stella_param, stella_key)), stella_default in zip(
self.pyro_stella_miller.items(), self.pyro_stella_miller_defaults.values()
):
miller_data[pyro_key] = self.data[stella_param].get(
stella_key, stella_default
)
rho = miller_data["rho"]
kappa = miller_data["kappa"]
miller_data["delta"] = np.sin(self.data["millergeo_parameters"].get("tri", 0.0))
miller_data["s_kappa"] = (
self.data["millergeo_parameters"].get("kapprim", 0.0) * rho / kappa
)
miller_data["s_delta"] = (
self.data["millergeo_parameters"].get("triprim", 0.0) * rho
)
# convert from stella normalisation to pyrokinetics normalisation of beta_prime
miller_data["beta_prime"] *= -2.0
miller_data["ip_ccw"] = 1
miller_data["bt_ccw"] = 1
# must construct using from_gk_data as we cannot determine bunit_over_b0 here
return LocalGeometryMiller.from_gk_data(miller_data)
[docs]
def get_local_species(self):
"""
Load LocalSpecies object from stella file
"""
# Dictionary of local species parameters
local_species = LocalSpecies()
ion_count = 0
ne_norm, Te_norm = self.get_ne_te_normalisation()
# get the reference collision frequency from the stella data
# ready for conversion to species-specific collision frequencies
# in the pyrokinetics internal format
vnew_ref = self.data[self._parameters_params]["vnew_ref"]
# Load each species into a dictionary
for i_sp in range(self.data["species_knobs"]["nspec"]):
species_data = CleverDict()
stella_key = f"species_parameters_{i_sp + 1}"
stella_data = self.data[stella_key]
for pyro_key, stella_key in self.pyro_stella_species.items():
species_data[pyro_key] = stella_data[stella_key]
# normalisation factor to get into GS2 convention
normfac = (
species_data.dens
* (species_data.z**4)
/ (np.sqrt(species_data.mass) * (species_data.temp**1.5))
)
species_data.nu = vnew_ref * normfac
# assume rotation not implemented in stella
species_data.omega0 = 0.0 * ureg.vref_most_probable / ureg.lref_minor_radius
# assume no isolated PVG term in stella
species_data.domega_drho = (
0.0 * ureg.vref_most_probable / ureg.lref_minor_radius**2
)
if species_data.z == -1:
name = "electron"
else:
ion_count += 1
name = f"ion{ion_count}"
species_data.name = name
# normalisations
species_data.dens *= ureg.nref_electron / ne_norm
species_data.mass *= ureg.mref_deuterium
species_data.nu *= ureg.vref_most_probable / ureg.lref_minor_radius
species_data.temp *= ureg.tref_electron / Te_norm
species_data.z *= ureg.elementary_charge
species_data.inverse_lt *= ureg.lref_minor_radius**-1
species_data.inverse_ln *= ureg.lref_minor_radius**-1
# Add individual species data to dictionary of species
local_species.add_species(name=name, species_data=species_data)
local_species.normalise()
if "zeff" in self.data[self._parameters_params]:
local_species.zeff = (
self.data[self._parameters_params]["zeff"] * ureg.elementary_charge
)
else:
local_species.zeff = 1.0 * ureg.elementary_charge
return local_species
def _read_range_grid(self):
range_options = self.data["kt_grids_range_parameters"]
nky = range_options.get("naky", 1)
ky_min = range_options.get("aky_min", 0.0)
ky_max = range_options.get("aky_max", 0.0)
spacing_option = range_options.get("kyspacing_option", "linear")
if spacing_option == "default":
spacing_option = "linear"
ky_space = np.linspace if spacing_option == "linear" else np.logspace
ky = ky_space(ky_min, ky_max, nky)
return {
"nky": nky,
"nkx": 1,
"ky": ky,
"kx": np.array([0.0]),
"theta0": 0.0,
}
def _read_box_grid(self):
box = self.data["kt_grids_box_parameters"]
keys = box.keys()
grid_data = {}
# Set up ky grid
if "ny" in keys:
grid_data["nky"] = int((box["ny"] - 1) / 3 + 1)
else:
raise RuntimeError(f"ky grid details not found in {keys}")
if "y0" in keys:
if box["y0"] < 0.0:
grid_data["ky"] = -box["y0"]
else:
grid_data["ky"] = 1 / box["y0"]
else:
raise RuntimeError(f"Min ky details not found in {keys}")
if "nx" in keys:
grid_data["nkx"] = int(2 * (box["nx"] - 1) / 3 + 1)
else:
raise RuntimeError("kx grid details not found in {keys}")
shat_params = self.pyro_stella_miller["shat"]
shat = self.data[shat_params[0]][shat_params[1]]
if abs(shat) > 1e-6:
jtwist_default = max(int(2 * pi * shat + 0.5), 1)
jtwist = box.get("jtwist", jtwist_default)
grid_data["kx"] = grid_data["ky"] * shat * 2 * pi / jtwist
else:
grid_data["kx"] = 2 * pi / box["x0"]
return grid_data
def _read_grid(self):
"""Read the perpendicular wavenumber grid"""
grid_option = self.data["kt_grids_knobs"].get("grid_option", "range")
GRID_READERS = {
"default": self._read_range_grid,
"range": self._read_range_grid,
"box": self._read_box_grid,
}
try:
reader = GRID_READERS[grid_option]
except KeyError:
valid_options = ", ".join(f"'{option}'" for option in GRID_READERS)
raise ValueError(
f"Unknown stella 'kt_range_knobs::grid_option', '{grid_option}'. Expected one of {valid_options}"
)
return reader()
[docs]
def get_numerics(self) -> Numerics:
"""Gather numerical info (grid spacing, time steps, etc)"""
if hasattr(self, "convention"):
convention = self.convention
else:
norms = Normalisation("get_numerics")
convention = getattr(norms, self.norm_convention)
numerics_data = {}
# Set no. of fields
numerics_data["phi"] = (
self.data[self._parameters_numerical].get("fphi", 0.0) > 0.0
)
numerics_data["apar"] = self.data[self._parameters_physics].get(
"include_apar", False
)
numerics_data["bpar"] = self.data[self._parameters_physics].get(
"include_bpar", False
)
# Set time stepping
delta_time = self.data[self._parameters_numerical].get("delt", 0.005)
numerics_data["delta_time"] = delta_time
if "tend" in self.data[self._parameters_numerical]:
numerics_data["max_time"] = self.data[self._parameters_numerical]["tend"]
else:
numerics_data["max_time"] = (
self.data[self._parameters_numerical].get("nstep", 50000) * delta_time
)
numerics_data["nonlinear"] = self.is_nonlinear()
numerics_data.update(self._read_grid())
# z grid
numerics_data["ntheta"] = self.data["zgrid_parameters"]["nzed"]
numerics_data["nperiod"] = self.data["zgrid_parameters"]["nperiod"]
# Velocity grid
numerics_data["nenergy"] = self.data["vpamu_grids_parameters"]["nvgrid"]
numerics_data["npitch"] = self.data["vpamu_grids_parameters"]["nmu"]
numerics_data["beta"] = self._get_beta()
numerics_data["gamma_exb"] = self.data[self._parameters_params].get(
"g_exb", 0.0
)
return Numerics(**numerics_data).with_units(convention)
[docs]
def get_reference_values(self, local_norm: Normalisation) -> Dict[str, Any]:
"""
Reads in normalisation values from input file
"""
if "normalisations_knobs" not in self.data.keys():
return {}
norms = {}
norms["tref_electron"] = (
self.data["normalisations_knobs"]["tref"] * local_norm.units.eV
)
norms["nref_electron"] = (
self.data["normalisations_knobs"]["nref"] * local_norm.units.meter**-3
)
norms["bref_B0"] = (
self.data["normalisations_knobs"]["bref"] * local_norm.units.tesla
)
norms["lref_minor_radius"] = (
self.data["normalisations_knobs"]["aref"] * local_norm.units.meter
)
return norms
def _detect_normalisation(self):
"""
Determines the necessary inputs and passes information to the base method _set_up_normalisation.
The following values are needed
default_references: dict
Dictionary containing default reference values for the
gk_code: str
GK code
electron_density: float
Electron density from GK input
electron_temperature: float
Electron density from GK input
e_mass: float
Electron mass from GK input
electron_index: int
Index of electron in list of data
found_electron: bool
Flag on whether electron was found
densities: ArrayLike
List of species densities
temperatures: ArrayLike
List of species temperature
reference_density_index: ArrayLike
List of indices where the species has a density of 1.0
reference_temperature_index: ArrayLike
List of indices where the species has a temperature of 1.0
major_radius: float
Normalised major radius from GK input
rgeo_rmaj: float
Ratio of Geometric and flux surface major radius
minor_radius: float
Normalised minor radius from GK input
"""
default_references = {
"nref_species": "electron",
"tref_species": "electron",
"mref_species": "deuterium",
"bref": "B0",
"lref": "minor_radius",
"ne": 1.0,
"te": 1.0,
"rgeo_rmaj": 1.0,
"vref": "most_probable",
"rhoref": "gs2",
"raxis_rmaj": None,
}
reference_density_index = []
reference_temperature_index = []
densities = []
temperatures = []
masses = []
found_electron = False
e_mass = None
electron_temperature = None
electron_density = None
electron_index = None
# Load each species into a dictionary
for i_sp in range(self.data["species_knobs"]["nspec"]):
species_key = f"species_parameters_{i_sp + 1}"
dens = self.data[species_key]["dens"]
temp = self.data[species_key]["temp"]
mass = self.data[species_key]["mass"]
# Find all reference values
if self.data[species_key]["z"] == -1:
electron_density = dens
electron_temperature = temp
e_mass = mass
electron_index = i_sp
found_electron = True
if np.isclose(dens, 1.0):
reference_density_index.append(i_sp)
if np.isclose(temp, 1.0):
reference_temperature_index.append(i_sp)
densities.append(dens)
temperatures.append(temp)
masses.append(mass)
adiabatic_electron_flags = ["iphi00=2", "field-line-average-term"]
if (
not found_electron
and self.data[self._parameters_physics]["adiabatic_option"]
in adiabatic_electron_flags
):
found_electron = True
electron_density = 1.0 / self.data[self._parameters_physics].get(
"nine", 1.0
)
electron_temperature = 1.0 / self.data[self._parameters_physics].get(
"tite", 1.0
)
e_mass = (electron_mass / deuterium_mass).m
n_species = self.data["species_knobs"]["nspec"]
electron_index = n_species + 1
if np.isclose(electron_density, 1.0):
reference_density_index.append(n_species + 1)
if np.isclose(electron_temperature, 1.0):
reference_temperature_index.append(n_species + 1)
rgeo_rmaj = (
self.data["millergeo_parameters"]["rgeo"]
/ self.data["millergeo_parameters"]["rmaj"]
)
major_radius = self.data["millergeo_parameters"]["rmaj"]
minor_radius = 1.0
super()._set_up_normalisation(
default_references=default_references,
gk_code=self.code_name.lower(),
electron_density=electron_density,
electron_temperature=electron_temperature,
e_mass=e_mass,
electron_index=electron_index,
found_electron=found_electron,
densities=densities,
temperatures=temperatures,
reference_density_index=reference_density_index,
reference_temperature_index=reference_temperature_index,
major_radius=major_radius,
rgeo_rmaj=rgeo_rmaj,
minor_radius=minor_radius,
)
[docs]
def set(
self,
local_geometry: LocalGeometry,
local_species: LocalSpecies,
numerics: Numerics,
local_norm: Normalisation = None,
template_file: Optional[PathLike] = None,
code_normalisation: Optional[str] = None,
**kwargs,
):
"""
Set self.data using LocalGeometry, LocalSpecies, and Numerics.
These may be obtained via another GKInput file, or from Equilibrium/Kinetics
objects.
"""
# If self.data is not already populated, fill in defaults from a given
# template file. If this is not provided by the user, fall back to the
# default.
if self.data is None:
if template_file is None:
template_file = gk_templates["STELLA"]
self.read_from_file(template_file)
if local_norm is None:
local_norm = Normalisation("set")
if code_normalisation is None:
code_normalisation = self.norm_convention
convention = getattr(local_norm, code_normalisation)
# Set Miller Geometry bits
if not isinstance(local_geometry, LocalGeometryMiller):
raise NotImplementedError(
f"LocalGeometry type {local_geometry.__class__.__name__} for stella not supported yet"
)
# Ensure Miller settings
self.data["geo_knobs"]["geo_option"] = "miller"
# Assign Miller values to input file
for key, val in self.pyro_stella_miller.items():
self.data[val[0]][val[1]] = local_geometry[key]
self.data["millergeo_parameters"]["rgeo"] = local_geometry.Rmaj
# get stella normalised beta_prime
self.data["millergeo_parameters"]["betaprim"] = -0.5 * local_geometry.beta_prime
self.data["millergeo_parameters"]["kapprim"] = (
local_geometry.s_kappa * local_geometry.kappa / local_geometry.rho
)
self.data["millergeo_parameters"]["tri"] = np.arcsin(local_geometry.delta)
self.data["millergeo_parameters"]["triprim"] = (
local_geometry["s_delta"] / local_geometry.rho
)
# Set local species bits
n_species = local_species.nspec
self.data["species_knobs"]["nspec"] = local_species.nspec
self.data["species_knobs"]["species_option"] = "stella"
stored_species = len(
[key for key in self.data.keys() if "species_parameters_" in key]
)
extra_species = stored_species - n_species
if extra_species > 0:
for i_sp in range(extra_species):
stella_key = f"species_parameters_{i_sp + 1 + n_species}"
if stella_key in self.data:
self.data.pop(stella_key)
for iSp, name in enumerate(local_species.names):
# add new outer params for each species
species_key = f"species_parameters_{iSp + 1}"
if species_key not in self.data:
self.data[species_key] = copy(self.data["species_parameters_1"])
if name == "electron":
self.data[species_key]["type"] = "electron"
else:
self.data[species_key]["type"] = "ion"
for key, val in self.pyro_stella_species.items():
self.data[species_key][val] = local_species[name][key]
if "electron" in local_species.names:
if local_species.electron.domega_drho.m != 0:
warnings.warn(
"stella does not support PVG term so this is not included"
)
self.data[self._parameters_params]["zeff"] = local_species.zeff
beta_ref = convention.beta if local_norm else 0.0
self.data[self._parameters_params]["beta"] = (
numerics.beta if numerics.beta is not None else beta_ref
)
# set the reference collision frequency
specref = self.data["species_parameters_1"]
normfac = (
(specref["z"] ** 4)
* specref["dens"]
/ (np.sqrt(specref["mass"]) * (specref["temp"] ** 1.5))
)
nameref = local_species.names[0]
vnew_ref = local_species[nameref]["nu"].to(convention)
# convert to the reference parameter from the species parameter of species 1
self.data[self._parameters_params]["vnew_ref"] = vnew_ref / normfac
# Set numerics bits
self.data["dissipation"]["include_collisions"] = (
True if vnew_ref > 0.0 else False
)
# other parameters from the dissipation namelist related to collisions are
# collisions_implicit = True/False
# collision_model = "dougherty"/"fokker-planck"
# Set no. of fields
self.data[self._parameters_numerical]["fphi"] = 1.0 if numerics.phi else 0.0
self.data[self._parameters_physics]["include_apar"] = numerics.apar
self.data[self._parameters_physics]["include_bpar"] = numerics.bpar
# Set time stepping
self.data[self._parameters_numerical]["delt"] = numerics.delta_time
# self.data[self._parameters_numerical]["nstep"] = int(numerics.max_time / numerics.delta_time)
self.data[self._parameters_numerical]["tend"] = int(numerics.max_time.m)
if numerics.nky == 1:
self.data["kt_grids_knobs"]["grid_option"] = "range"
if "kt_grids_range_parameters" not in self.data.keys():
self.data["kt_grids_range_parameters"] = {}
try:
ky = (
numerics.ky[0]
* (1 * convention.bref / local_norm.stella.bref).to_base_units()
)
except IndexError:
ky = (
numerics.ky
* (1 * convention.bref / local_norm.stella.bref).to_base_units()
)
self.data["kt_grids_range_parameters"]["aky_min"] = ky
self.data["kt_grids_range_parameters"]["aky_max"] = ky
self.data["kt_grids_range_parameters"]["theta0_min"] = numerics.theta0
self.data["kt_grids_range_parameters"]["theta0_max"] = numerics.theta0
self.data["kt_grids_range_parameters"]["naky"] = 1
self.data["kt_grids_range_parameters"]["nakx"] = 1
self.data["zgrid_parameters"]["nperiod"] = numerics.nperiod
else:
self.data["kt_grids_knobs"]["grid_option"] = "box"
if "kt_grids_box_parameters" not in self.data.keys():
self.data["kt_grids_box_parameters"] = {}
self.data["kt_grids_box_parameters"]["nx"] = int(
((numerics.nkx - 1) * 3 / 2) + 1
)
self.data["kt_grids_box_parameters"]["ny"] = int(
((numerics.nky - 1) * 3) + 1
)
self.data["kt_grids_box_parameters"]["y0"] = -numerics.ky
# Currently forces NL sims to have nperiod = 1
self.data["zgrid_parameters"]["nperiod"] = 1
shat = local_geometry.shat
if abs(shat) < 1e-6:
self.data["kt_grids_box_parameters"]["x0"] = 2 * pi / numerics.kx
else:
if numerics.kx == 0:
self.data["kt_grids_box_parameters"]["jtwist"] = 1
else:
self.data["kt_grids_box_parameters"]["jtwist"] = int(
(numerics.ky * shat * 2 * pi / numerics.kx) + 0.1
)
self.data["zgrid_parameters"]["nzed"] = numerics.ntheta
self.data["vpamu_grids_parameters"]["nvgrid"] = numerics.nenergy
self.data["vpamu_grids_parameters"]["nmu"] = numerics.npitch
self.data[self._parameters_params]["g_exb"] = numerics.gamma_exb
self.data[self._parameters_physics]["nonlinear"] = numerics.nonlinear
if not local_norm:
return
try:
(1 * convention.tref).to("keV")
si_units = True
except PyroContextError:
si_units = False
if si_units:
if "normalisations_knobs" not in self.data.keys():
self.data["normalisations_knobs"] = f90nml.Namelist()
self.data["normalisations_knobs"]["tref"] = (1 * convention.tref).to("eV")
self.data["normalisations_knobs"]["nref"] = (1 * convention.nref).to(
"meter**-3"
)
self.data["normalisations_knobs"]["mref"] = (1 * convention.mref).to(
"atomic_mass_constant"
)
self.data["normalisations_knobs"]["bref"] = (1 * convention.bref).to(
"tesla"
)
self.data["normalisations_knobs"]["aref"] = (1 * convention.lref).to(
"meter"
)
self.data["normalisations_knobs"]["vref"] = (1 * convention.vref).to(
"meter/second"
)
self.data["normalisations_knobs"]["qref"] = 1 * convention.qref
self.data["normalisations_knobs"]["rhoref"] = (1 * convention.rhoref).to(
"meter"
)
for name, namelist in self.data.items():
self.data[name] = convert_dict(namelist, convention)
[docs]
def get_ne_te_normalisation(self):
found_electron = False
# Load each species into a dictionary
for i_sp in range(self.data["species_knobs"]["nspec"]):
stella_key = f"species_parameters_{i_sp + 1}"
if (
self.data[stella_key]["z"] == -1
and self.data[stella_key]["type"] == "electron"
):
ne = self.data[stella_key]["dens"]
Te = self.data[stella_key]["temp"]
found_electron = True
break
adiabatic_electron_flags = ["iphi00=2", "field-line-average-term"]
if (
not found_electron
and self.data[self._parameters_physics]["adiabatic_option"]
in adiabatic_electron_flags
):
found_electron = True
ne = 1.0 / self.data[self._parameters_physics].get("nine", 1.0)
Te = 1.0 / self.data[self._parameters_physics].get("tite", 1.0)
if not found_electron:
raise TypeError(
"Pyro currently only supports electron species with charge = -1"
)
return ne, Te
def _get_beta(self):
beta_default = 0.0
return self.data[self._parameters_params].get("beta", beta_default)
[docs]
class GKOutputReaderSTELLA(FileReader, file_type="STELLA", reads=GKOutput):
[docs]
def read_from_file(
self,
filename: PathLike,
norm: Normalisation,
output_convention: str = "pyrokinetics",
downsize: int = 1,
load_fields=True,
load_fluxes=True,
load_moments=False,
) -> GKOutput:
raw_data, gk_input, input_str = self._get_raw_data(filename)
coords = self._get_coords(raw_data, gk_input, downsize)
fields = self._get_fields(raw_data) if load_fields else None
fluxes = self._get_fluxes(raw_data, gk_input, coords) if load_fluxes else None
moments = (
self._get_moments(raw_data, gk_input, coords) if load_moments else None
)
eigenvalues = None
if not fields and coords["linear"]:
eigenvalues = self._get_eigenvalues(raw_data, coords["time_divisor"])
# Assign units and return GKOutput
convention = getattr(norm, gk_input.norm_convention)
norm.default_convention = output_convention.lower()
gk_input.convention = convention
field_dims = ("theta", "kx", "ky", "time")
flux_dims = ("species", "kx", "ky", "time")
moment_dims = ("species", "kx", "ky", "time")
return GKOutput(
coords=Coords(
time=coords["time"],
kx=coords["kx"],
ky=coords["ky"],
theta=coords["zed"],
energy=coords["vpa"],
pitch=coords["mu"],
species=coords["species"],
field=coords["field"],
).with_units(convention),
norm=norm,
fields=(
Fields(**fields, dims=field_dims).with_units(convention)
if fields
else None
),
fluxes=(
Fluxes(**fluxes, dims=flux_dims).with_units(convention)
if fluxes
else None
),
moments=(
Moments(**moments, dims=moment_dims).with_units(convention)
if moments
else None
),
eigenvalues=(
Eigenvalues(**eigenvalues).with_units(convention)
if eigenvalues
else None
),
linear=coords["linear"],
gk_code="STELLA",
input_file=input_str,
normalise_flux_moment=True,
output_convention=output_convention,
input_convention=convention.name,
jacobian=coords["jacobian"],
)
[docs]
def verify_file_type(self, filename: PathLike):
import xarray as xr
try:
# warnings.filterwarnings("error")
data = xr.open_dataset(filename)
except RuntimeWarning:
warnings.resetwarnings()
raise RuntimeError("Error occurred reading stella output file")
warnings.resetwarnings()
if "software_name" in data.attrs:
if data.attrs["software_name"] != "stella":
raise RuntimeError(
f"file '{filename}' has wrong 'software_name' for a stella file"
)
elif "code_info" in data.data_vars:
if data["code_info"].long_name != "stella":
raise RuntimeError(
f"file '{filename}' has wrong 'code_info' for a stella file"
)
elif "stella_help" in data.attrs.keys():
pass
else:
raise RuntimeError(f"file '{filename}' missing expected stella attributes")
[docs]
@staticmethod
def infer_path_from_input_file(filename: PathLike) -> Path:
"""
Gets path by removing ".in" and replacing it with ".out.nc"
"""
filename = Path(filename)
return filename.parent / (filename.stem + ".out.nc")
@staticmethod
def _get_raw_data(filename: PathLike) -> Tuple[xr.Dataset, GKInputSTELLA, str]:
import xarray as xr
raw_data = xr.open_dataset(filename)
# Read input file from netcdf, store as GKInputSTELLA
input_file = raw_data["input_file"]
if input_file.shape == ():
# New diagnostics, input file stored as bytes
# - Stored within numpy 0D array, use [()] syntax to extract
# - Convert bytes to str by decoding
# - \n is represented as character literals '\' 'n'. Replace with '\n'.
input_str = input_file.data[()].decode("utf-8").replace(r"\n", "\n")
else:
# Old diagnostics (and eventually the single merged diagnostics)
# input file stored as array of bytes
if isinstance(input_file.data[0], np.ndarray):
input_str = "\n".join(
("".join(np.char.decode(line)).strip() for line in input_file.data)
)
else:
input_str = "\n".join(
(line.decode("utf-8") for line in input_file.data)
)
gk_input = GKInputSTELLA()
gk_input.read_str(input_str)
return raw_data, gk_input, input_str
@staticmethod
def _get_coords(
raw_data: xr.Dataset, gk_input: GKInputSTELLA, downsize: int
) -> Dict[str, Any]:
# ky coords
ky = raw_data["ky"].data
# time coords
time_divisor = 1
time = raw_data["t"].data / time_divisor
# kx coords
# Shift kx=0 to middle of array
kx = np.fft.fftshift(raw_data["kx"].data)
# zed coords
zed = raw_data["zed"].data
# vpa coords
vpa = raw_data["vpa"].data
# mu coords
mu = raw_data["mu"].data
# moment coords
fluxes = ["particle", "heat", "momentum"]
moments = ["density", "temperature", "upar", "spitzer2"]
# field coords
# stella is hardcoded to require phi, only apar and bpar are optional
field_vals = {"phi": True}
for field, default in zip(["apar", "bpar"], [False, False]):
try:
field_vals[field] = gk_input.data[gk_input._parameters_physics][
f"include_{field}"
]
except KeyError:
field_vals[field] = default
fields = [field for field, val in field_vals.items() if val > 0]
# species coords
# TODO is there some way to get this info without looking at the input data?
species = []
ion_num = 0
for idx in range(gk_input.data["species_knobs"]["nspec"]):
if gk_input.data[f"species_parameters_{idx + 1}"]["z"] == -1:
species.append("electron")
else:
ion_num += 1
species.append(f"ion{ion_num}")
local_geometry = gk_input.get_local_geometry()
metric_terms = MetricTerms(local_geometry, ntheta=len(zed) * 4)
theta_mod = np.mod(zed, 2 * np.pi)
Jacobian = np.interp(
theta_mod,
metric_terms.regulartheta,
metric_terms.Jacobian,
period=2 * np.pi,
)
return {
"time": time,
"kx": kx,
"ky": ky,
"zed": zed,
"vpa": vpa,
"mu": mu,
"linear": gk_input.is_linear(),
"time_divisor": time_divisor,
"field": fields,
"moment": moments,
"flux": fluxes,
"species": species,
"downsize": downsize,
"jacobian": Jacobian,
}
@staticmethod
def _get_fields(raw_data: xr.Dataset) -> Dict[str, np.ndarray]:
"""
to have fields written out versus time, we must set
&stella_diagnostics_knobs
write_phi_vs_time = .true.
write_apar_vs_time = .true.
write_bpar_vs_time = .true.
/
at the same time, we must also set
&physics_flags
include_apar = .true.
include_bpar = .true.
/
to include apar and bpar in the simulation
"""
field_names = ("phi", "apar", "bpar")
results = {}
# Loop through all fields and add field if it exists
for field_name in field_names:
key = f"{field_name}_vs_t"
if key not in raw_data:
continue
# raw_field has coords (t, tube, zed, kx, ky, real/imag).
# We wish to transpose that to (real/imag,zed,kx,ky,t)
# Selecting first index in tube
field = raw_data[key].transpose("tube", "ri", "zed", "kx", "ky", "t").data
field = field[0, 0, ...] + 1j * field[0, 1, ...]
# Adjust fields to account for differences in defintions/normalisations
# A||_stella = 0.5 * A||_gs2
# B||_stella = B||_gs2 * B
# infer from GS2 script that no adjustments required here
# Shift kx=0 to middle of axis
field = np.fft.fftshift(field, axes=1)
results[field_name] = field
return results
@staticmethod
def _get_moments(
raw_data: Dict[str, Any],
gk_input: GKInputSTELLA,
coords: Dict[str, Any],
) -> Dict[str, np.ndarray]:
"""
Sets 3D moments over time.
The moment coordinates should be (moment, theta, kx, species, ky, time)
"""
raise NotImplementedError
@staticmethod
def _get_fluxes(
raw_data: xr.Dataset,
gk_input: GKInputSTELLA,
coords: Dict,
) -> Dict[str, np.ndarray]:
"""
For stella to print fluxes(t,species) to the netcdf file, at the present time we must
be using the branch https://github.com/stellaGK/stella/tree/development/apar2plusbpar.
Otherwise the fluxes are automatically written to the ascii text files.
To make the fluxes as a function of ky and kx be written to the netcdf file, set
&stella_diagnostics_knobs
write_kspectra = .true.
/
Flux contributions as a function of kx ky and z are available in stella with
&stella_diagnostics_knobs
write_fluxes_kxkyz = .true.
/
These are not supported to be read here as they are a function of tubes and zed in addition
to ky, kx.
"""
fluxes_dict = {
"particle": "pflux_vs_s",
"heat": "qflux_vs_s",
"momentum": "vflux_vs_s",
}
fluxes_dict_old = {
"particle": "pflx",
"heat": "qflx",
"momentum": "vflx",
}
# Get species names from input file
species = []
ion_num = 0
for idx in range(gk_input.data["species_knobs"]["nspec"]):
if gk_input.data[f"species_parameters_{idx+1}"]["z"] == -1:
species.append("electron")
else:
ion_num += 1
species.append(f"ion{ion_num}")
results = {}
coord_names = ["flux", "species", "kx", "ky", "time"]
fluxes = np.zeros([len(coords[name]) for name in coord_names])
for iflux, (pyro_flux, stella_flux) in enumerate(fluxes_dict.items()):
# total fluxes
flux_key = f"{stella_flux}"
# flux contributions by kx ky (averaged over z)
vskxky_key = f"{stella_flux}_vs_kxky"
# flux constributions by kx ky z
# vskxkyz_key = f"{stella_flux}_kxky"
# commented out as not (yet) supported
# if vskxkyz_key in raw_data.data_vars:
# key = vskxkyz_key
# flux = raw_data[key].transpose("species", "tube", "zed", "kx", "ky", "t")
if vskxky_key in raw_data.data_vars:
key = vskxky_key
flux = raw_data[key].transpose("species", "kx", "ky", "t")
elif fluxes_dict_old[pyro_flux] in raw_data.data_vars:
# coordinates from raw are (t,species)
# convert to (species, ky, t)
flux = raw_data[fluxes_dict_old[pyro_flux]]
flux = flux.expand_dims("ky").transpose("species", "ky", "t")
flux = flux.expand_dims("kx").transpose("species", "kx", "ky", "t")
elif flux_key in raw_data.data_vars:
# coordinates from raw are (t,species)
# convert to (species, ky, t)
flux = raw_data[flux_key]
flux = flux.expand_dims("ky").transpose("species", "ky", "t")
flux = flux.expand_dims("kx").transpose("species", "kx", "ky", "t")
else:
continue
fluxes[iflux, ...] = flux.data
if gk_input.is_linear() and gk_input.data["stella_diagnostics_knobs"].get(
"flux_norm", True
):
jacob = raw_data["jacob"].data
grho = raw_data["grho"].data
theta = raw_data["zed"].data
theta_append = 2 * theta[-1] - theta[-2]
dtheta = np.diff(theta, append=theta_append)
flux_norm = np.sum(jacob * dtheta) / np.sum(jacob * dtheta * grho)
else:
flux_norm = 1.0
for iflux, flux in enumerate(coords["flux"]):
if not np.all(fluxes[iflux, ...] == 0):
results[flux] = fluxes[iflux, ...] / flux_norm
return results
@staticmethod
def _get_eigenvalues(
raw_data: xr.Dataset, time_divisor: float
) -> Dict[str, np.ndarray]:
# should only be called if no field data were found
mode_frequency = raw_data.omega.isel(ri=0).transpose("kx", "ky", "t")
growth_rate = raw_data.omega.isel(ri=1).transpose("kx", "ky", "t")
return {
"mode_frequency": mode_frequency.data / time_divisor,
"growth_rate": growth_rate.data / time_divisor,
}