# SPDX-FileCopyrightText: 2025 GFZ Helmholtz Centre for Geosciences
#
# SPDX-License-Identifier: Apache-2.0
from __future__ import annotations
import fnmatch
import pickle
import re
import typing
import warnings
from collections.abc import Iterable
from datetime import datetime, timedelta, timezone
from pathlib import Path
from typing import Any
import netCDF4
import numpy as np
import pandas as pd
from mat73 import loadmat
from numpy.typing import NDArray
from scipy.io import loadmat as sci_loadmat
from swvo.io.utils import enforce_utc_timezone
[docs]
def join_var(var1: NDArray[np.generic], var2: NDArray[np.generic]) -> NDArray[np.generic]:
    """Concatenate two arrays along their first axis.

    Args:
        var1 (NDArray[np.generic]): Array placed first in the result.
        var2 (NDArray[np.generic]): Array appended after ``var1``.

    Returns:
        NDArray[np.generic]: ``var1`` followed by ``var2`` along axis 0.
    """
    parts = [var1, var2]
    # np.concatenate joins along axis 0 when no axis is given.
    return np.concatenate(parts)
[docs]
def read_all_datasets_netcdf(file_path: str | Path) -> dict[str, Any]:
    """Read every variable of a NetCDF file, descending into nested groups.

    Variables are collected depth-first: the variables of a group come first,
    followed by the contents of each of its sub-groups. Each entry is keyed by
    the slash-separated path of group names leading to the variable.

    Args:
        file_path (str | Path): The path to the NetCDF file.

    Returns:
        dict[str, Any]: Mapping from the full variable path to the data read
        from that variable via ``variable[:]``. An empty dict is returned (and
        a message is printed) when the file does not exist.
    """
    file_path = Path(file_path)

    # Guard clause: a missing file is reported on stdout, not raised.
    if not file_path.exists():
        print(f"File not found: {file_path}")
        return {}

    def _qualify(prefix: str, name: str) -> str:
        # Top-level names carry no leading separator.
        return f"{prefix}/{name}" if prefix else name

    def _iter_variables(node: netCDF4.Group | netCDF4.Dataset, prefix: str = ""):
        # Variables of this level first ...
        for var_name, variable in node.variables.items():
            yield _qualify(prefix, var_name), variable[:]
        # ... then everything below each child group.
        for group_name, child in node.groups.items():
            yield from _iter_variables(child, _qualify(prefix, group_name))

    with netCDF4.Dataset(file_path, "r") as nc_file:
        # The generator must be consumed while the file is still open.
        return dict(_iter_variables(nc_file))
[docs]
def round_seconds(obj: datetime) -> datetime:
    """Round a datetime to the nearest whole second.

    Args:
        obj (datetime): The datetime to round.

    Returns:
        datetime: ``obj`` with zero microseconds; values with 500 000
        microseconds or more are rounded up to the next second.
    """
    truncated = obj.replace(microsecond=0)
    if obj.microsecond < 500_000:
        return truncated
    return truncated + timedelta(seconds=1)
[docs]
def python2matlab(datenum: datetime) -> float:
    """Convert a Python datetime to a MATLAB datenum.

    The input is normalised to UTC before conversion so that the calendar day
    and the fraction of the day are taken from the same time zone. Naive
    datetimes are interpreted as UTC.

    Args:
        datenum (datetime): The datetime to convert. May be naive (treated as
            UTC) or timezone-aware (converted to UTC).

    Returns:
        float: The proleptic Gregorian ordinal of the UTC date plus 366, plus
        the elapsed fraction of the UTC day rounded to 6 decimal places.
        Microseconds are ignored.
    """
    if datenum.utcoffset() is None:
        # astimezone() would interpret a naive value as local time, so the UTC
        # zone is attached directly instead.
        utc_time = datenum.replace(tzinfo=timezone.utc)
    else:
        utc_time = datenum.astimezone(timezone.utc)

    seconds_into_day = utc_time.hour * 3600 + utc_time.minute * 60 + utc_time.second
    frac = seconds_into_day / (24.0 * 60.0 * 60.0)

    # 366 is the offset between Python's ordinal day count and the datenum.
    return utc_time.toordinal() + 366 + round(frac, 6)
[docs]
def matlab2python(datenum: float | Iterable[float]) -> Iterable[datetime] | datetime:
    """Convert MATLAB datenum value(s) to Python datetime(s).

    The input is shifted by 719529 days and interpreted by pandas as a number
    of days relative to 1970-01-01. Each resulting datetime is passed through
    ``enforce_utc_timezone`` and rounded to the nearest second.

    Note:
        Every call installs a process-wide warnings filter that silences the
        "Discarding nonzero nanoseconds in conversion" warning.

    Args:
        datenum (float | Iterable[float]): A single datenum or a collection of
            datenums.

    Returns:
        Iterable[datetime] | datetime: A list of datetimes when the converted
        result is iterable, otherwise a single datetime.
    """
    warnings.filterwarnings("ignore", message="Discarding nonzero nanoseconds in conversion")

    days_since_epoch = np.asarray(datenum, dtype=float) - 719529
    epoch = pd.Timestamp("1970-01-01")
    converted = pd.to_datetime(days_since_epoch, unit="D", origin=epoch).to_pydatetime()  # ty:ignore[unresolved-attribute]

    # Scalar input yields a single datetime, which is not iterable.
    if not isinstance(converted, Iterable):
        return round_seconds(enforce_utc_timezone(converted))

    utc_values = enforce_utc_timezone(list(converted))  # ty:ignore[no-matching-overload]
    return list(map(round_seconds, utc_values))
[docs]
def pol2cart(
    theta: NDArray[np.float64], radius: NDArray[np.float64]
) -> tuple[NDArray[np.float64], NDArray[np.float64]]:
    """Transform polar coordinates to cartesian coordinates.

    Args:
        theta (NDArray[np.float64]): Angle in radians.
        radius (NDArray[np.float64]): Radial distance.

    Returns:
        tuple[NDArray[np.float64], NDArray[np.float64]]: The pair ``(x, y)``
        with ``x = radius * cos(theta)`` and ``y = radius * sin(theta)``.
    """
    cos_theta, sin_theta = np.cos(theta), np.sin(theta)
    return radius * cos_theta, radius * sin_theta
[docs]
def cart2pol(x: NDArray[np.float64], y: NDArray[np.float64]) -> tuple[NDArray[np.float64], NDArray[np.float64]]:
    """Transform cartesian coordinates to polar coordinates.

    The angle and radius are computed directly from the real inputs rather
    than through an intermediate complex number, because multiplying an
    infinite ``y`` by ``1j`` yields a NaN real part and hence a NaN angle.

    Args:
        x (NDArray[np.float64]): Cartesian x coordinate.
        y (NDArray[np.float64]): Cartesian y coordinate.

    Returns:
        tuple[NDArray[np.float64], NDArray[np.float64]]: The pair
        ``(theta, radius)`` with ``theta`` in radians in the interval
        ``(-pi, pi]`` and ``radius = hypot(x, y)``. The origin maps to
        ``theta = 0``.
    """
    # Adding 0.0 turns -0.0 into +0.0. This keeps points on the negative
    # x-axis at +pi (arctan2 would return -pi for y == -0.0) and gives the
    # origin a single, well-defined angle of 0.
    x_norm = x + 0.0
    y_norm = y + 0.0
    return np.arctan2(y_norm, x_norm), np.hypot(x_norm, y_norm)