# SPDX-FileCopyrightText: 2025 GFZ Helmholtz Centre for Geosciences
#
# SPDX-License-Identifier: Apache-2.0
from __future__ import annotations
import itertools
import time
from collections.abc import Iterable
from enum import Enum
from functools import partial
from multiprocessing import Pool
from typing import TYPE_CHECKING, Literal, TypeAlias, cast
import numpy as np
from numpy.typing import NDArray
from tqdm import tqdm
from swvo.io.RBMDataSet import RBMDataSet
# [docs] -- Sphinx "view source" link marker left over from the HTML extraction; not part of the module.
class TargetType(Enum):
    """How the two target vectors of an interpolation call are combined.

    Members
    -------
    TargetPairs
        The vectors are zipped element by element, so both must have the
        same length; target ``i`` is ``(first[i], second[i])``.
    TargetMeshGrid
        Every element of the first vector is combined with every element of
        the second one (Cartesian product).
    TargetMesh
        Alias of ``TargetMeshGrid``.
    """

    TargetPairs = 0
    TargetMeshGrid = 1
    # Alias (same value => enum alias, not a new member). ``interp_flux`` and
    # ``interp_psd`` advertise the string "TargetMesh" in their signatures and
    # resolve strings through ``TargetType[name]``; without this alias that
    # lookup raises KeyError.
    TargetMesh = 1
# List of 2-tuples, one per interpolation target. In this module the tuples are
# built by zipping (TargetPairs) or taking the Cartesian product (mesh) of the
# two target vectors, e.g. (energy, alpha) in interp_flux and (mu, K) in interp_psd.
TARGETS: TypeAlias = list[tuple[float | int, float | int]]
def _linear_interp(
    flux_left: float,
    flux_right: float,
    target_value: float,
    left_value: float,
    right_value: float,
) -> float:
    """Evaluate the straight line through two samples at ``target_value``.

    The line passes through ``(left_value, flux_left)`` and
    ``(right_value, flux_right)``. No range check is performed, so a
    ``target_value`` outside ``[left_value, right_value]`` extrapolates.
    ``left_value`` and ``right_value`` must differ, otherwise the division
    below is a division by zero.
    """
    span = right_value - left_value
    weight = (target_value - left_value) / span
    rise = flux_right - flux_left
    return flux_left + weight * rise
def _interp_flux_parallel(
flux: NDArray[np.float64],
energy: NDArray[np.float64],
alpha_eq_model: NDArray[np.float64],
targets: list[tuple[float, float]],
it: int,
) -> list[float]:
result: list[float] = []
for _, (target_en_single, target_al_single) in enumerate(targets):
# find left and right alpha indices
# first find the two al levels, where en points must exist
al_right_idx = np.searchsorted(alpha_eq_model[it, :], target_al_single, side="right")
al_left_idx = al_right_idx - 1
if al_right_idx == 0 or al_right_idx >= len(alpha_eq_model[it, :]):
result.append(np.nan)
continue
finite_idx = np.argwhere(np.isfinite(energy[it, :]) & np.isfinite(flux[it, :, al_left_idx]))
if finite_idx.size == 0:
result.append(np.nan)
continue
energy_interp = np.squeeze(energy[it, finite_idx])
flux_interp = np.squeeze(flux[it, finite_idx, al_left_idx])
assert np.all(np.diff(energy_interp) > 0)
flux_left = float(np.interp(target_en_single, energy_interp, flux_interp, left=np.nan, right=np.nan))
finite_idx = np.argwhere(np.isfinite(energy[it, :]) & np.isfinite(flux[it, :, al_right_idx]))
if finite_idx.size == 0:
result.append(np.nan)
continue
energy_interp = np.squeeze(energy[it, finite_idx])
flux_interp = np.squeeze(flux[it, finite_idx, al_right_idx])
assert np.all(np.diff(energy_interp) > 0)
flux_right = float(np.interp(target_en_single, energy_interp, flux_interp, left=np.nan, right=np.nan))
result.append(
_linear_interp(
flux_left,
flux_right,
target_al_single,
alpha_eq_model[it, al_left_idx],
alpha_eq_model[it, al_right_idx],
)
)
return result
def interp_flux(
    self: RBMDataSet,
    target_en: float | list[float] | NDArray[np.float64],
    target_al: float | list[float],
    target_type: TargetType | Literal["TargetPairs", "TargetMesh"],
    n_threads: int = 10,
) -> NDArray[np.float64]:
    """Interpolate ``self.Flux`` to requested (energy, alpha) targets for every time.

    The work is distributed over time steps with a process pool; each worker
    call runs ``_interp_flux_parallel`` for one time index using ``self.Flux``,
    ``self.energy_channels`` and ``self.alpha_eq_model``.

    Parameters
    ----------
    target_en, target_al
        Target energies and alpha values; a scalar is treated as a
        one-element vector.
    target_type
        ``TargetType.TargetPairs`` zips both vectors (same length required);
        any other value builds the Cartesian product. Strings are resolved by
        member name; ``"TargetMesh"`` is accepted for ``TargetMeshGrid``.
    n_threads
        Number of worker processes of the pool.

    Returns
    -------
    NDArray[np.float64]
        Shape ``(len(self.time), N)`` for pairs, otherwise
        ``(len(self.time), len(target_en), len(target_al))``.

    Raises
    ------
    AssertionError
        If the vectors differ in length for ``TargetPairs``.
    KeyError
        If ``target_type`` is a string that names no ``TargetType`` member.
    """
    if not isinstance(target_en, Iterable):
        target_en = [target_en]
    if not isinstance(target_al, Iterable):
        target_al = [target_al]
    # Materialise once so that len() and repeated iteration also work for
    # one-shot iterables such as generators.
    energies = list(target_en)
    alphas = list(target_al)

    if isinstance(target_type, str):
        # The signature advertises "TargetMesh" while the enum member is
        # called TargetMeshGrid; translate instead of failing with KeyError.
        target_type = TargetType["TargetMeshGrid" if target_type == "TargetMesh" else target_type]

    n_times = len(self.time)
    if target_type == TargetType.TargetPairs:
        assert len(energies) == len(alphas), "For TargetType.Pairs, the target vectors must have the same size!"
        out_shape: tuple[int, ...] = (n_times, len(energies))
        targets = cast("TARGETS", list(zip(energies, alphas, strict=False)))
    else:
        out_shape = (n_times, len(energies), len(alphas))
        # product() iterates the energies in the outer loop, so the flat index
        # is ``ie * len(alphas) + ia`` -- exactly the C-order reshape used below.
        targets = cast("TARGETS", list(itertools.product(energies, alphas)))

    func = partial(
        _interp_flux_parallel,
        self.Flux,
        self.energy_channels,
        self.alpha_eq_model,
        targets,
    )

    with Pool(n_threads) as p:
        rs = p.map_async(func, range(n_times))

        # display progress bar if verbose
        if self._verbose:
            # NOTE: `_number_left` is a private attribute of the async result
            # object and may change between Python versions.
            total_elements = rs._number_left  # ty:ignore[unresolved-attribute]
            with tqdm(total=total_elements) as progress:
                while not rs.ready():
                    progress.n = total_elements - rs._number_left  # ty:ignore[unresolved-attribute]
                    progress.refresh()
                    # Returns as soon as the work is done instead of always
                    # sleeping for a full second.
                    rs.wait(timeout=1)
                # Leave the bar in its completed state.
                progress.n = total_elements
                progress.refresh()

        # get() blocks until all tasks finished and re-raises any exception
        # that occurred in a worker, so no separate error check is needed.
        parallel_results = rs.get()

    # One list of len(targets) values per time step -> (time, ...) array.
    return np.asarray(parallel_results, dtype=np.float64).reshape(out_shape)
def _interp_psd_along_mu(
    mu_column: NDArray[np.float64],
    psd_column: NDArray[np.float64],
    mu_target: float,
) -> float:
    """Interpolate one K column of the PSD to ``mu_target``.

    Non-finite samples are dropped; if the remaining mu values are not
    strictly increasing they are sorted and de-duplicated. NaN is returned
    when fewer than two usable samples remain or when ``mu_target`` lies
    outside their range.
    """
    usable = np.isfinite(mu_column) & np.isfinite(psd_column)
    mu_samples = np.asarray(mu_column[usable], dtype=float)
    psd_samples = np.asarray(psd_column[usable], dtype=float)
    if mu_samples.size < 2:
        return np.nan

    if not np.all(np.diff(mu_samples) > 0):
        # np.interp requires increasing sample points.
        order = np.argsort(mu_samples)
        mu_samples, psd_samples = mu_samples[order], psd_samples[order]
        mu_samples, first_idx = np.unique(mu_samples, return_index=True)
        psd_samples = psd_samples[first_idx]
        if mu_samples.size < 2:
            return np.nan

    return float(np.interp(mu_target, mu_samples, psd_samples, left=np.nan, right=np.nan))


def _interp_psd_parallel(
    psd: NDArray[np.float64],
    invmu: NDArray[np.float64],
    invk: NDArray[np.float64],
    targets: list[tuple[float, float]],
    it: int,
) -> list[float]:
    """Interpolate PSD at time index `it` to (mu_target, K_target) pairs in `targets`.

    Shapes per time slice:
      psd[it]   -> (nE, nA)
      invmu[it] -> (nE, nA)
      invk[it]  -> (nA,)

    For each target the two K columns bracketing ``K_target`` are interpolated
    along mu, then the two values are interpolated linearly in K.

    Returns one value per target, in input order. NaN is returned when fewer
    than two finite K values exist, when ``K_target`` is below the first or
    at/above the last usable K (the upper edge is exclusive), or when either
    bracketing column cannot be interpolated at ``mu_target``.
    """
    # ---- 0) Extract this time slice
    psd_i = psd[it, :, :]  # (nE, nA)
    mu_i = invmu[it, :, :]  # (nE, nA)
    k_row = invk[it, :]  # (nA,)

    # ---- 1) Drop NaN K bins; at least two K points are needed to bracket
    finite_k = np.isfinite(k_row)
    k_use = k_row[finite_k]  # (nA_valid,)
    if k_use.size < 2:
        return [np.nan] * len(targets)
    psd_use = psd_i[:, finite_k]  # (nE, nA_valid)
    mu_use = mu_i[:, finite_k]  # (nE, nA_valid)

    # ---- 2) searchsorted needs ascending K; a descending row is flipped.
    #         Only the first two entries are inspected to decide this.
    if k_use[1] < k_use[0]:
        k_use = k_use[::-1]
        psd_use = psd_use[:, ::-1]
        mu_use = mu_use[:, ::-1]

    # ---- 3) For each (mu*, K*) target: 1D along mu, then linear across K
    out: list[float] = []
    for mu_t, k_t in targets:
        # 3a) Bracket in K
        k_right = np.searchsorted(k_use, k_t, side="right")
        if k_right == 0 or k_right >= k_use.size:
            out.append(np.nan)
            continue
        k_left = k_right - 1

        # 3b) Interp along mu at the LEFT and RIGHT K column
        psd_left = _interp_psd_along_mu(mu_use[:, k_left], psd_use[:, k_left], mu_t)
        if not np.isfinite(psd_left):
            out.append(np.nan)
            continue
        psd_right = _interp_psd_along_mu(mu_use[:, k_right], psd_use[:, k_right], mu_t)
        if not np.isfinite(psd_right):
            out.append(np.nan)
            continue

        # 3c) Linear across K to K_t
        out.append(_linear_interp(psd_left, psd_right, k_t, k_use[k_left], k_use[k_right]))
    return out
# [docs] -- Sphinx "view source" link marker left over from the HTML extraction; not part of the module.
def interp_psd(
    self: RBMDataSet,
    target_mu: float | list[float] | NDArray[np.float64],
    target_K: float | list[float] | NDArray[np.float64],
    target_type: TargetType | Literal["TargetPairs", "TargetMesh"],
    n_threads: int = 10,
) -> NDArray[np.float64]:
    """Interpolate PSD to requested (mu, K) targets for every time.

    The work is distributed over time steps with a process pool; each worker
    call runs ``_interp_psd_parallel`` for one time index using ``self.PSD``,
    ``self.InvMu`` and ``self.InvK``.

    Parameters
    ----------
    target_mu, target_K
        Target mu and K values; a scalar is treated as a one-element vector.
    target_type
        ``TargetType.TargetPairs`` zips both vectors (same length required);
        any other value builds the Cartesian product. Strings are resolved by
        member name; ``"TargetMesh"`` is accepted for ``TargetMeshGrid``.
    n_threads
        Number of worker processes of the pool.

    Output shapes (matching interp_flux semantics):
      - TargetPairs     -> (time, N)
      - TargetMeshGrid  -> (time, n_mu, n_K)

    Raises
    ------
    AssertionError
        If the vectors differ in length for ``TargetPairs``.
    KeyError
        If ``target_type`` is a string that names no ``TargetType`` member.
    """
    if not isinstance(target_mu, Iterable):
        target_mu = [target_mu]
    if not isinstance(target_K, Iterable):
        target_K = [target_K]
    # Materialise once so that len() and repeated iteration also work for
    # one-shot iterables such as generators.
    mu_values = list(target_mu)
    k_values = list(target_K)

    if isinstance(target_type, str):
        # The signature advertises "TargetMesh" while the enum member is
        # called TargetMeshGrid; translate instead of failing with KeyError.
        target_type = TargetType["TargetMeshGrid" if target_type == "TargetMesh" else target_type]

    n_times = len(self.time)
    if target_type == TargetType.TargetPairs:
        assert len(mu_values) == len(k_values), "For TargetType.Pairs, mu and K vectors must have the same size!"
        out_shape: tuple[int, ...] = (n_times, len(mu_values))
        targets = cast("TARGETS", list(zip(mu_values, k_values, strict=False)))
    else:
        out_shape = (n_times, len(mu_values), len(k_values))
        # product() iterates mu in the outer loop, so the flat index is
        # ``im * n_K + iK`` -- exactly the C-order reshape used below.
        targets = cast("TARGETS", list(itertools.product(mu_values, k_values)))

    # Reading the attributes once also triggers a lazy loader, if any.
    psd, inv_mu, inv_k = self.PSD, self.InvMu, self.InvK

    # parallel over time (same pattern as interp_flux)
    func = partial(_interp_psd_parallel, psd, inv_mu, inv_k, targets)

    with Pool(n_threads) as p:
        rs = p.map_async(func, range(n_times))

        if self._verbose:
            # NOTE: `_number_left` is a private attribute of the async result
            # object and may change between Python versions.
            total_elements = rs._number_left  # ty:ignore[unresolved-attribute]
            with tqdm(total=total_elements) as progress:
                while not rs.ready():
                    progress.n = total_elements - rs._number_left  # ty:ignore[unresolved-attribute]
                    progress.refresh()
                    # Returns as soon as the work is done instead of always
                    # sleeping for a full second.
                    rs.wait(timeout=1)
                # Leave the bar in its completed state.
                progress.n = total_elements
                progress.refresh()

        # get() blocks until all tasks finished and re-raises any exception
        # that occurred in a worker, so no separate error check is needed.
        parallel_results = rs.get()

    # One list of len(targets) values per time step -> (time, ...) array.
    return np.asarray(parallel_results, dtype=np.float64).reshape(out_shape)