# Source code for swvo.io.RBMDataSet.RBMDataSet

# SPDX-FileCopyrightText: 2025 GFZ Helmholtz Centre for Geosciences
# SPDX-FileContributor: Bernhard Haas
# SPDX-FileContributor: Sahil Jhawar
#
# SPDX-License-Identifier: Apache-2.0

"""Combined RBM Dataset class supporting .mat, .pickle, and .nc file formats."""

from __future__ import annotations

import datetime as dt
from datetime import timedelta, timezone
from pathlib import Path
from typing import Any, Literal, cast

import distance
import numpy as np
from dateutil.relativedelta import relativedelta
from numpy.typing import NDArray

from swvo.io.exceptions import VariableNotFoundError
from swvo.io.RBMDataSet import (
    FileCadenceEnum,
    FolderTypeEnum,
    Instrument,
    InstrumentEnum,
    InstrumentLike,
    Mfm,
    MfmEnum,
    MfmLike,
    Satellite,
    SatelliteEnum,
    SatelliteLike,
    Variable,
    VariableEnum,
    VariableLiteral,
)
from swvo.io.RBMDataSet.custom_enums import MfmEnumLiteral
from swvo.io.RBMDataSet.utils import (
    get_file_path_any_format,
    join_var,
    load_file_any_format,
    matlab2python,
    read_all_datasets_netcdf,
)
from swvo.io.utils import enforce_utc_timezone


[docs] class RBMDataSet: """RBMDataSet class supporting .mat, .pickle, and .nc file formats. This unified class handles loading RBM (Radiation Belt Model) data from multiple file formats. It can load data either from files or from a dictionary. For file-based loading, provide `start_time`, `end_time`, and `folder_path`. For dictionary-based loading, initialize without these parameters and use `update_from_dict()`. Parameters ---------- satellite : Union[:class:`SatelliteLike`, :class:`DummyLike`] Satellite identifier as enum or string. instrument : Union[:class:`InstrumentLike`, :class:`DummyLike`] Instrument enumeration or string. mfm : Union[:class:`MfmLike`, :class:`DummyLike`] Magnetic field model enum or string. start_time : dt.datetime, optional Start time for file-based loading. end_time : dt.datetime, optional End time for file-based loading. folder_path : Path, optional Base folder path for file-based loading. preferred_extension : Literal["mat", "pickle", "nc"], optional Preferred file extension for file-based loading. Default is "pickle". verbose : bool, optional Whether to print verbose output. Default is True. enable_dict_loading : bool, optional Enable dictionary-based loading even in file mode. Default is False. 
Attributes ---------- datetime : list[dt.datetime] time : NDArray[np.float64] energy_channels : NDArray[np.float64] alpha_local : NDArray[np.float64] alpha_eq_model : NDArray[np.float64] alpha_eq_real : NDArray[np.float64] InvMu : NDArray[np.float64] InvMu_real : NDArray[np.float64] InvK : NDArray[np.float64] InvV : NDArray[np.float64] Lstar : NDArray[np.float64] Flux : NDArray[np.float64] PSD : NDArray[np.float64] MLT : NDArray[np.float64] B_SM : NDArray[np.float64] B_total : NDArray[np.float64] B_sat : NDArray[np.float64] xGEO : NDArray[np.float64] P : NDArray[np.float64] R0 : NDArray[np.float64] density : NDArray[np.float64] """ _preferred_ext: Literal["mat", "pickle", "nc"] datetime: list[dt.datetime] time: NDArray[np.float64] energy_channels: NDArray[np.float64] alpha_local: NDArray[np.float64] alpha_eq_model: NDArray[np.float64] alpha_eq_real: NDArray[np.float64] InvMu: NDArray[np.float64] InvMu_real: NDArray[np.float64] InvK: NDArray[np.float64] InvV: NDArray[np.float64] Lstar: NDArray[np.float64] Flux: NDArray[np.float64] PSD: NDArray[np.float64] MLT: NDArray[np.float64] B_SM: NDArray[np.float64] B_total: NDArray[np.float64] B_sat: NDArray[np.float64] xGEO: NDArray[np.float64] P: NDArray[np.float64] R0: NDArray[np.float64] density: NDArray[np.float64]
[docs] def __init__( self, satellite: SatelliteLike, instrument: InstrumentLike, mfm: MfmLike, start_time: dt.datetime | None = None, end_time: dt.datetime | None = None, folder_path: Path | None = None, preferred_extension: Literal["mat", "pickle", "nc"] = "nc", *, verbose: bool = True, enable_dict_loading: bool = False, ) -> None: self.possible_variables: list[str] = list(VariableLiteral.__args__) # Handle satellite conversion with special cases for GOES if isinstance(satellite, str): if satellite.lower() == "goesprimary": satellite = SatelliteEnum["GOESPrimary"] elif satellite.lower() == "goessecondary": satellite = SatelliteEnum["GOESSecondary"] else: satellite = SatelliteEnum[satellite.upper()] if isinstance(instrument, str): instrument = InstrumentEnum[instrument.upper()] if isinstance(mfm, str): mfm = MfmEnum[mfm.upper()] # Validate preferred_extension if preferred_extension not in ("mat", "pickle", "nc"): msg = f"preferred_extension must be 'mat', 'pickle', or 'nc', got '{preferred_extension}'" raise ValueError(msg) # Store the original satellite enum for properties and other attributes self._satellite = satellite self._instrument = instrument self._mfm = mfm self._verbose = verbose self._preferred_ext = preferred_extension # For dict-based loading, modify satellite properties if start_time is None and end_time is None and folder_path is None: self._file_loading_mode = False else: # File loading mode: need all parameters if start_time is None or end_time is None or folder_path is None: msg = "For file-based loading, start_time, end_time, and folder_path must all be provided" raise ValueError(msg) start_time = enforce_utc_timezone(start_time) end_time = enforce_utc_timezone(end_time) self._start_time = start_time self._end_time = end_time self._satellite = satellite self._folder_path = Path(folder_path) self._folder_type = self._satellite.folder_type self._file_path_stem = self._create_file_path_stem() self._is_nc_dataset = self._check_if_nc_dataset() 
self._file_name_stem = self._create_file_name_stem() self._file_cadence = self._satellite.file_cadence self._date_of_files = self._create_date_list() self._file_loading_mode = True self._enable_dict_loading = enable_dict_loading self._netcdf_dataset_cache: dict[Path, dict[str, Any]] = {}
def __repr__(self) -> str: return f"{self.__class__.__name__}({self._satellite}, {self._instrument}, {self._mfm})" def __str__(self) -> str: return self.__repr__() def __dir__(self) -> list[str]: return list(super().__dir__()) + [var.var_name for var in VariableEnum] def __getattr__(self, name: str) -> NDArray[np.float64]: # Avoid recursion for internal attributes if name.startswith("_"): raise AttributeError(f"'{self.__class__.__name__}' object has no attribute '{name}'") # Handle computed properties for both modes if name == "P": if len(self.MLT) == 0: # MLT not found self.P = np.asarray([]) else: self.P = ((self.MLT + 12) / 12 * np.pi) % (2 * np.pi) return self.P if name == "InvV": if len(self.InvK) == 0 or len(self.InvMu) == 0: # invariants not found self.InvV = np.asarray([]) else: inv_K_repeated = np.repeat(self.InvK[:, np.newaxis, :], self.InvMu.shape[1], axis=1) self.InvV = self.InvMu * (inv_K_repeated + 0.5) ** 2 return self.InvV # check if a sat variable is requested # if we find a similar word, suggest that to the user sat_variable = None sat_variable, levenstein_info = self.find_similar_variable(name) if sat_variable is not None and self._file_loading_mode: self._load_variable(sat_variable) return getattr(self, name) if not self._file_loading_mode and name in self.possible_variables: raise AttributeError( f"Attribute '{name}' exists in `VariableLiteral` but has not been set. " "Call `update_from_dict()` before accessing it." ) if levenstein_info["min_distance"] <= 2: msg = f"{self.__class__.__name__} object has no attribute {name}. Maybe you meant {levenstein_info['var_name']}?" else: msg = f"{self.__class__.__name__} object has no attribute {name}" raise AttributeError(msg)
[docs] def load(self, name_or_var: str | VariableEnum) -> None: """Load data into memory""" if isinstance(name_or_var, VariableEnum): getattr(self, name_or_var.var_name) else: getattr(self, name_or_var)
def find_similar_variable(self, name: str) -> tuple[None | VariableEnum, dict[str, Any]]: levenstein_info: dict[str, Any] = {"min_distance": 10, "var_name": ""} sat_variable = None for var in VariableEnum: if name == var.var_name: sat_variable = var break else: dist = distance.levenshtein(name, var.var_name) if name.lower() in var.var_name.lower(): dist = 1 if dist < levenstein_info["min_distance"]: levenstein_info["min_distance"] = dist levenstein_info["var_name"] = var.var_name return sat_variable, levenstein_info @property def satellite(self) -> SatelliteEnum | Satellite: """Returns the satellite enum.""" return self._satellite @property def instrument(self) -> InstrumentEnum | Instrument: """Returns the instrument enum.""" return self._instrument @property def mfm(self) -> MfmEnum | Mfm: """Returns the MFM enum.""" return self._mfm
[docs] def update_from_dict( self, source_dict: dict[VariableLiteral, NDArray[np.floating] | list[dt.datetime]] ) -> RBMDataSet: """Get data from data dictionary and update the object. Parameters ---------- source_dict : dict[str, VariableLiteral] Dictionary containing the data to be loaded into the object. Returns ------- RBMDataSet The updated RBMDataSet object. Raises ------ VariableNotFoundError If a key in the `source_dict` is not a valid `VariableLiteral`. RuntimeError If the RBMDataSet is in file loading mode and dictionary loading is not enabled. """ if self._file_loading_mode and not self._enable_dict_loading: msg = "RBMDataSet is in file loading mode. Cannot update from dictionary. To use dictionary-based loading, set `enable_dict_loading=True` during initialization." raise RuntimeError(msg) for key, value in source_dict.items(): _, levenstein_info = self.find_similar_variable(key) if key in self.possible_variables: setattr(self, key, value) elif levenstein_info["min_distance"] <= 2: msg = f"Key '{key}' is not a valid `VariableLiteral`. Maybe you meant '{levenstein_info['var_name']}'?" raise VariableNotFoundError(msg) else: msg = f"Key '{key}' is not a valid `VariableLiteral`." raise VariableNotFoundError(msg) return self
def get_var(self, var: VariableEnum) -> NDArray[np.float64]: return getattr(self, var.var_name) def _check_if_nc_dataset(self) -> bool: does_processed_mat_files_folder_exist = (self._file_path_stem / "Processed_Mat_Files").exists() if does_processed_mat_files_folder_exist and self._preferred_ext in ["mat", "pickle"]: return False elif does_processed_mat_files_folder_exist and self._preferred_ext == "nc": # if any .nc files are stored in the file_path_stem, we switch to nc mode return next(self._file_path_stem.glob("*.nc"), None) is not None else: # if the Processed_Mat_Files folder does not exist, it is safe to assume nc mode return True def _create_date_list(self) -> list[dt.datetime]: match self._file_cadence: case FileCadenceEnum.Daily: time_delta = timedelta(days=1) case FileCadenceEnum.Monthly: time_delta = relativedelta(months=1) case _: msg = "Encounterd invalid file cadence!" raise ValueError(msg) start_date = self._start_time.date() date_of_files = np.asarray([dt.datetime(start_date.year, start_date.month, 1, tzinfo=timezone.utc)]) while (date_of_files[-1] + time_delta) < self._end_time: date_of_files = np.append(date_of_files, date_of_files[-1] + time_delta) return list(date_of_files) def _create_file_path_stem(self) -> Path: """Create the file path stem based on format and folder type.""" if self._folder_type == FolderTypeEnum.DataServer: return self._folder_path / self._satellite.mission / self._satellite.sat_name if self._folder_type == FolderTypeEnum.SingleFolder: return self._folder_path msg = "Encountered invalid FolderTypeEnum!" 
raise ValueError(msg) def _create_file_name_stem(self) -> str: """Create the file name stem.""" return self._satellite.sat_name + "_" + self._instrument.instrument_name + "_" def get_satellite_name(self) -> str: return self._satellite.sat_name def get_satellite_and_instrument_name(self) -> str: return self._satellite.sat_name + "_" + self._instrument.instrument_name def set_file_path_stem(self, file_path_stem: Path) -> RBMDataSet: self._file_path_stem = file_path_stem return self def set_file_name_stem(self, file_name_stem: Path) -> RBMDataSet: self._file_path_stem = file_name_stem return self def set_file_cadence(self, file_cadence: FileCadenceEnum) -> RBMDataSet: self._file_cadence = file_cadence self._date_of_files = self._create_date_list() return self def get_print_name(self) -> str: return self._satellite.sat_name + " " + self._instrument.instrument_name def _load_variable(self, var: Variable | VariableEnum) -> None: """Load variable from .mat, .pickle, or .nc files.""" loaded_var_arrs: dict[str, NDArray[np.number]] = {} var_names_stored: list[str] = [] # 1. Handle Computed Values if isinstance(var, VariableEnum): if var == VariableEnum.INV_V: inv_K_repeated = np.repeat(self.InvK[:, np.newaxis, :], self.InvMu.shape[1], axis=1) self.InvV = self.InvMu * (inv_K_repeated + 0.5) ** 2 return if var == VariableEnum.P: self.P = ((self.MLT + 12) / 12 * np.pi) % (2 * np.pi) return # 2. Iterate through date ranges for date in self._date_of_files: if self._folder_type != FolderTypeEnum.DataServer: raise NotImplementedError("Only DataServer folder type is currently supported.") # Construct date string start_month = date.replace(day=1) next_month = start_month + relativedelta(months=1, days=-1) date_str = f"{start_month.strftime('%Y%m%d')}to{next_month.strftime('%Y%m%d')}" # 3. 
Handle File Pathing & Loading based on format if self._is_nc_dataset: file_name = f"{self._file_name_stem}{date_str}_{self._mfm.mfm_name}.nc" full_file_path = self._file_path_stem / file_name file_content = self._get_cached_datasets_netcdf(full_file_path) else: file_name_no_format = f"{self._file_name_stem}{date_str}_{var.mat_file_prefix}" if var.mat_has_B: file_name_no_format += f"_n4_4_{self._mfm.mfm_name}" file_name_no_format += "_ver4" full_file_path = get_file_path_any_format( self._file_path_stem, file_name_no_format, self._preferred_ext, self._is_nc_dataset ) if full_file_path is None: print(f"File not found: {file_name_no_format}") continue if self._verbose: print(f"\tLoading {full_file_path}") file_content = load_file_any_format(full_file_path) if not file_content: continue # 4. Process Datetimes raw_times = file_content["time"] if self._is_nc_dataset: # NetCDF timestamp logic datetimes = np.asarray( [dt.datetime.fromtimestamp(t.astype(np.int64), tz=dt.timezone.utc) for t in raw_times] ) else: # Matlab logic datetimes = np.asarray([matlab2python(t) for t in raw_times]) file_content["datetime"] = datetimes correct_time_idx = (datetimes >= self._start_time) & (datetimes <= self._end_time) # 5. Filter and Join Arrays for key, var_arr in file_content.items(): # Skip non-numeric metadata (excluding our new datetime) if key != "datetime" and ( not isinstance(var_arr, np.ndarray) or not np.issubdtype(var_arr.dtype, np.number) ): continue var_arr = cast("NDArray[np.number]", var_arr) # Time-dependent trimming if var_arr.shape[0] == correct_time_idx.shape[0]: var_arr = var_arr[correct_time_idx.reshape(-1), ...] joined_value = join_var(loaded_var_arrs[key], var_arr) if key in loaded_var_arrs else var_arr else: joined_value = var_arr loaded_var_arrs[key] = joined_value # type: ignore if key not in var_names_stored: var_names_stored.append(key) # 6. 
Final Assignment to Self if var.var_name not in var_names_stored: setattr(self, var.var_name, np.asarray([])) for var_name in var_names_stored: val = list(loaded_var_arrs[var_name]) if var_name == "datetime" else loaded_var_arrs[var_name] if self._is_nc_dataset: # NetCDF name mapping logic rbm_names = self._get_rbm_name_for_nc(var_name, self._mfm.mfm_name) # type: ignore if rbm_names: for name in rbm_names if isinstance(rbm_names, list) else [rbm_names]: setattr(self, name, val) else: setattr(self, var_name, val) def _get_cached_datasets_netcdf(self, file_path: Path) -> dict[str, Any]: """Return cached parsed NetCDF content for a monthly file.""" file_path = Path(file_path) if file_path not in self._netcdf_dataset_cache: if self._verbose: print(f"\tLoading {file_path}") self._netcdf_dataset_cache[file_path] = read_all_datasets_netcdf(file_path) return self._netcdf_dataset_cache[file_path] @classmethod def _get_rbm_name_for_nc( cls, var_name: str, mag_field: MfmEnumLiteral ) -> VariableLiteral | None | list[VariableLiteral]: """Map NetCDF variable names to RBM variable names.""" match var_name: case "time": return "time" case "datetime": return "datetime" case "flux/FEDU": return ["Flux", "FEDU"] case "flux/FEIU": return ["Flux", "FEIU"] case "flux/alpha_eq": return "alpha_eq_model" case "flux/energy": return "energy_channels" case "flux/alpha_local": return "alpha_local" case "position/xGEO": return "xGEO" case _ if var_name == f"position/{mag_field}/MLT": return "MLT" case _ if var_name == f"position/{mag_field}/R0": return "R0" case _ if var_name == f"position/{mag_field}/Lstar": return "Lstar" case _ if var_name == f"position/{mag_field}/Lm": return "Lm" case _ if var_name == f"mag_field/{mag_field}/B_local": return "B_total" case "psd/PSD": return "PSD" case _ if var_name == f"psd/{mag_field}/inv_mu": return "InvMu" case _ if var_name == f"psd/{mag_field}/inv_K": return "InvK" case "density/density_local": return "density" case _: return None
[docs] def get_loaded_variables(self) -> list[str]: """Get a list of currently loaded variable names.""" loaded_vars = [] for var in VariableEnum: if var.var_name in self.__dict__: loaded_vars.append(var.var_name) return loaded_vars
def __eq__(self, other: RBMDataSet) -> bool: # ty :ignore[invalid-method-override] if ( self._file_loading_mode != other._file_loading_mode or self._satellite != other._satellite or self._instrument != other._instrument or self._mfm != other._mfm ): return False different_vars = self.get_different_variables(other) return len(different_vars) == 0 def get_different_variables(self, rbm_other: RBMDataSet) -> list[str]: different_vars: list[str] = [] self_vars = self.get_loaded_variables() other_vars = rbm_other.get_loaded_variables() for var in set(self_vars + other_vars): if var not in other_vars or var not in self_vars: different_vars.append(var) continue self_var = getattr(self, var) other_var = getattr(rbm_other, var) if not isinstance(other_var, type(self_var)): different_vars.append(var) continue if isinstance(self_var, list): if len(self_var) != len(other_var) or any(a != b for a, b in zip(self_var, other_var)): different_vars.append(var) continue elif isinstance(self_var, np.ndarray): if self_var.shape != other_var.shape or not np.allclose(self_var, other_var, equal_nan=True): different_vars.append(var) continue elif self_var != other_var: different_vars.append(var) continue return different_vars from .bin_and_interpolate_to_model_grid import bin_and_interpolate_to_model_grid # noqa: I001 from .identify_orbits import identify_orbits from .interp_functions import interp_flux, interp_psd from .linearize_trajectories import linearize_trajectories