# SPDX-FileCopyrightText: 2025 GFZ Helmholtz Centre for Geosciences
#
# SPDX-License-Identifier: Apache-2.0
import logging
from datetime import datetime, timezone
from typing import Optional, overload
import numpy as np
import pandas as pd
from scipy.ndimage import gaussian_filter1d
# Module-level logger named after this module; used e.g. to report the NaN
# percentage of processed data frames.
logger = logging.getLogger(__name__)
[docs]
def any_nans(data: list[pd.DataFrame] | pd.DataFrame) -> bool:
    """Check whether a data frame, or any data frame in a list, contains NaN values.

    As a side effect, the NaN percentage of every data frame is logged via
    ``nan_percentage``.

    Parameters
    ----------
    data : list[pd.DataFrame] | pd.DataFrame
        Data frame or list of data frames to process.

    Returns
    -------
    bool
        True if at least one data frame contains at least one NaN value
        (in any column), False otherwise.
    """
    # Normalise to a list first: iterating a bare DataFrame yields its column
    # labels rather than the frame, which made DataFrame input crash below.
    frames = [data] if isinstance(data, pd.DataFrame) else data
    for df in frames:
        nan_percentage(df)  # called for its logging side effect only
    return any(bool(df.isna().any(axis=None)) for df in frames)
[docs]
def nan_percentage(data: pd.DataFrame) -> float:
    """Calculate the percentage of NaN values in the float columns of a data frame and log it.

    Only columns with dtype ``float64`` or ``float32`` are considered. The
    percentage is the number of NaN cells divided by the total number of cells
    in those columns, so the result always lies in [0, 100]. For a frame with a
    single float (data) column this equals the percentage of rows that are NaN.

    Parameters
    ----------
    data : pd.DataFrame
        The data frame to process.

    Returns
    -------
    float
        NaN percentage over the float columns; 0.0 if there are no float cells
        (no float columns or no rows).
    """
    float_data = data.select_dtypes(include=["float64", "float32"])
    # Normalise by rows * float columns: dividing by the row count alone could
    # exceed 100% with several float columns and divided by zero for empty frames.
    n_cells = float_data.size
    if n_cells:
        percentage = 100.0 * float(float_data.isna().to_numpy().sum()) / n_cells
    else:
        percentage = 0.0
    logger.info("Percentage of NaNs in data frame: %.2f%%", percentage)
    return percentage
[docs]
def construct_updated_data_frame(
data: list[pd.DataFrame] | pd.DataFrame,
data_one_model: list[pd.DataFrame] | pd.DataFrame,
model_label: str,
) -> list[pd.DataFrame]:
"""
Construct an updated data frame providing the previous data frame and the data frame of the current model call.
Also adds the model label to the data frame.
Parameters
----------
data : list[pd.DataFrame] | pd.DataFrame
The data frame or list of data frames to update.
data_one_model : list[pd.DataFrame] | pd.DataFrame
The data frame or list of data frames from the current model call.
model_label : str
The label of the model to add to the data frame.
Returns
-------
list[pd.DataFrame]
The updated data frame or list of data frames with the model label added.
"""
if isinstance(data_one_model, list) and data_one_model == []: # nothing to update
return data
if isinstance(data_one_model, pd.DataFrame):
data_one_model = [data_one_model]
if isinstance(data, pd.DataFrame):
data = [data]
# extend the data we have read so far to match the new ensemble numbers
if len(data) == 1 and len(data_one_model) > 1:
# Use copies to avoid all list elements sharing the same DataFrame reference.
# `[data[0]] * n` would create n aliases to the same object, causing in-place
# mutations (e.g. loc assignments) to bleed across all ensemble members.
data = [data[0].copy() for _ in range(len(data_one_model))]
elif len(data) != len(data_one_model):
msg = f"Tried to combine models with different ensemble numbers: {len(data)} and {len(data_one_model)}!"
raise ValueError(msg)
for i, _ in enumerate(data_one_model):
if "model" in data_one_model[i].columns:
mask_not_interpolated = data_one_model[i]["model"] != "interpolated"
data_one_model[i].loc[mask_not_interpolated, "model"] = model_label
mask_nan = data_one_model[i].isna().any(axis=1)
data_one_model[i].loc[mask_nan, "model"] = None
else:
data_one_model[i]["model"] = model_label
data_one_model[i].loc[data_one_model[i].isna().any(axis=1), "model"] = None
if "file_name" in data_one_model[i].columns:
data_one_model[i].loc[data_one_model[i]["file_name"].notna(), "model"] = model_label
data_one_model[i].loc[data_one_model[i]["file_name"].isna(), "model"] = None
if data[i].empty:
data[i] = data_one_model[i]
empty_idx = data[i].index[data[i].isna().all(axis=1)]
data[i].loc[empty_idx] = (
data[i].loc[empty_idx].combine_first(data_one_model[i].reindex(data[i].index).loc[empty_idx])
)
return data
[docs]
def datenum(
    date_input: datetime | int,
    month: Optional[int] = None,
    year: Optional[int] = None,
    hour: int = 0,
    minute: int = 0,
    seconds: int = 0,
) -> float:
    """Convert a date to a MATLAB serial date number (days, fractional).

    Either pass a ``datetime`` as ``date_input`` (all other arguments are then
    ignored), or pass the day of the month as ``date_input`` together with
    ``month`` and ``year`` (and optionally ``hour``, ``minute``, ``seconds``).
    The moment is normalised to UTC with ``enforce_utc_timezone`` before
    conversion, so naive values are treated as UTC.

    Parameters
    ----------
    date_input : datetime | int
        A datetime object, or the day of the month as an integer.
    month : int, optional
        Month of the date; required when ``date_input`` is an integer.
    year : int, optional
        Year of the date; required when ``date_input`` is an integer.
    hour : int
        Hour of the date, by default 0.
    minute : int
        Minute of the date, by default 0.
    seconds : int
        Seconds of the date, by default 0.

    Returns
    -------
    float
        The MATLAB serial date number.

    Raises
    ------
    ValueError
        If ``date_input`` is not a datetime and ``month`` or ``year`` is missing.
    """
    # Python's proleptic ordinal of 1970-01-01 plus 366 (the MATLAB/Python
    # day-numbering offset) gives the MATLAB datenum of the Unix epoch.
    matlab_unix_epoch = datetime(1970, 1, 1, tzinfo=timezone.utc).toordinal() + 366

    if not isinstance(date_input, datetime):
        if month is None or year is None:
            raise ValueError("Invalid input. Provide either a datetime object or year, month, and day.")
        date_input = datetime(
            year=year,
            month=month,
            day=date_input,
            hour=hour,
            minute=minute,
            second=seconds,
        )

    utc_moment = enforce_utc_timezone(date_input)
    seconds_per_day = 86400
    return utc_moment.timestamp() / seconds_per_day + matlab_unix_epoch
[docs]
def datestr(datenum: float) -> str:
    """
    Convert a MATLAB datenum to a UTC date string with minute resolution.

    Parameters
    ----------
    datenum : float
        The MATLAB datenum to convert (interpreted as UTC).

    Returns
    -------
    str
        The formatted date string in the format "YYYYMMDDHHMM" (seconds are
        not included).
    """
    # Python's proleptic ordinal of 1970-01-01 plus 366 is the MATLAB datenum
    # of the Unix epoch (719529).
    matlab_epoch = datetime(1970, 1, 1, tzinfo=timezone.utc).toordinal() + 366
    # A datenum around 7e5 days only carries ~1e-5 s of precision once turned
    # into seconds, so an exact minute such as 12:30:00 can come back as
    # 12:29:59.99999 and be truncated to "1229" by strftime. Rounding to whole
    # milliseconds removes that float noise while moving genuine sub-second
    # values by at most half a millisecond.
    unix_timestamp = round((datenum - matlab_epoch) * 86400, 3)
    dt = datetime.fromtimestamp(unix_timestamp, tz=timezone.utc)
    return dt.strftime("%Y%m%d%H%M")
[docs]
def sw_mag_propagation(sw_data: pd.DataFrame) -> pd.DataFrame:
    """
    Propagate the solar wind magnetic field to the bow shock and magnetopause.

    Every sample is shifted forward in time by ``distance / speed``, with the
    delays smoothed by a Gaussian filter (sigma of 5 samples). Samples whose
    propagated time does not increase relative to the preceding sample are
    dropped. The remaining samples are rounded to the nearest minute and placed
    on a regular 1-minute UTC grid.

    Parameters
    ----------
    sw_data : pd.DataFrame
        Data frame containing solar wind data with a 'speed' column, indexed by
        time. A naive time index is interpreted as UTC. The input frame is not
        modified. Rows containing any NaN are discarded.

    Returns
    -------
    pd.DataFrame
        Data frame with propagated solar wind data, indexed by a 1-minute UTC
        time grid. Grid points without propagated data are NaN.
    """
    # Presumably km (≈ L1 to Earth) with 'speed' in km/s, so distance / speed
    # is a delay in seconds — TODO confirm units against the data source.
    distance = 1.5e6

    # Epoch seconds from a UTC-normalised index. ``datetime.timestamp()`` on the
    # naive datetimes of a naive index would use the machine's local time zone,
    # while the result below is labelled UTC.
    time_index = pd.DatetimeIndex(sw_data.index)
    if time_index.tz is None:
        time_index = time_index.tz_localize("UTC")
    else:
        time_index = time_index.tz_convert("UTC")
    epoch_seconds = (time_index - pd.Timestamp("1970-01-01", tz="UTC")) / pd.Timedelta(seconds=1)

    # ``assign`` returns a new frame, so the caller's data frame stays untouched.
    sw_data = sw_data.assign(t=np.asarray(epoch_seconds, dtype=np.float64)).dropna(how="any")

    shifted_time = distance / sw_data["speed"]
    shifted_time_smooth = gaussian_filter1d(np.asarray(shifted_time, dtype=np.float64), sigma=5)
    new_time_smooth = sw_data["t"].to_numpy(dtype=np.float64) + shifted_time_smooth

    stdate = sw_data["t"].min()
    endate = new_time_smooth.max()
    full_time_range = pd.date_range(
        pd.to_datetime(stdate, unit="s", utc=True).floor("min"),
        pd.to_datetime(endate, unit="s", utc=True).floor("min"),
        freq="1min",
        tz="UTC",
    )

    # Keep samples whose propagated time lies inside the covered window.
    in_window = (new_time_smooth >= stdate) & (new_time_smooth <= endate)
    sw_data = sw_data[in_window]
    new_time_smooth = new_time_smooth[in_window]

    # Drop samples whose propagated time does not advance past their predecessor
    # (e.g. faster wind overtaking slower wind). The first sample has no
    # predecessor and is kept.
    increasing = np.ones(new_time_smooth.shape[0], dtype=bool)
    increasing[1:] = np.diff(new_time_smooth) > 0
    sw_data = sw_data[increasing]
    new_time_smooth = new_time_smooth[increasing]

    sw_data = sw_data.assign(t=new_time_smooth).dropna()
    propagated = pd.DatetimeIndex(pd.to_datetime(sw_data["t"], unit="s", utc=True))
    sw_data.index = propagated.round("min")
    sw_data = sw_data[~sw_data.index.duplicated(keep="first")]
    return sw_data.reindex(full_time_range).drop(columns=["t"])
# Static-typing overloads for ``enforce_utc_timezone``: they tell type checkers
# that the return type follows the input type. ``typing.overload`` replaces each
# stub with a placeholder; the undecorated definition that follows provides the
# runtime behaviour.
# Order matters: type checkers pick the first matching overload, so subclasses
# must precede their base classes (pd.Timestamp subclasses datetime and
# pd.DatetimeIndex subclasses pd.Index). With the datetime overload listed
# first, the pd.Timestamp overload could never be selected.
# NOTE(review): at runtime only a pd.DatetimeIndex is accepted; any other
# pd.Index raises TypeError despite the pd.Index overload below.
@overload
def enforce_utc_timezone(time: pd.Timestamp) -> pd.Timestamp: ...
@overload
def enforce_utc_timezone(time: datetime) -> datetime: ...
@overload
def enforce_utc_timezone(time: list[datetime]) -> list[datetime]: ...
@overload
def enforce_utc_timezone(time: pd.Series) -> pd.Series: ...
@overload
def enforce_utc_timezone(time: pd.DatetimeIndex) -> pd.DatetimeIndex: ...
@overload
def enforce_utc_timezone(time: pd.Index) -> pd.Index: ...
[docs]
def enforce_utc_timezone(time: datetime | list[datetime] | pd.Timestamp | pd.Series | pd.DatetimeIndex | pd.Index):
"""
Ensure datetime object(s) have UTC timezone information.
If the provided datetime object(s) are naive (lack timezone info),
UTC timezone is assigned. If they already have a timezone, they are
converted to UTC.
Parameters
----------
time : datetime, Iterable[datetime], pd.Timestamp, pd.Series, or pd.DatetimeIndex
The datetime object(s) to process. Can be:
- Single datetime.datetime object
- List of datetime.datetime objects
- Single pandas.Timestamp object
- pandas.Series with datetime64 dtype
- pandas.DatetimeIndex
Returns
-------
datetime, list of datetime, pd.Timestamp, pd.Series, or pd.DatetimeIndex
The datetime object(s) in UTC timezone.
Returns the same type as input.
Notes
-----
- For naive datetimes, this function assumes the times are already in UTC and simply adds the timezone information
- For timezone-aware datetimes, conversion to UTC is performed
- When processing pandas objects, the operation is vectorized for efficiency
"""
if isinstance(time, pd.Series):
if pd.api.types.is_datetime64_any_dtype(time):
if time.dt.tz is None:
return time.dt.tz_localize("UTC")
else:
return time.dt.tz_convert("UTC")
else:
raise TypeError(f"Series must have datetime64 dtype, got {time.dtype}")
elif isinstance(time, pd.DatetimeIndex):
if time.tz is None:
return time.tz_localize("UTC")
else:
return time.tz_convert("UTC")
elif isinstance(time, pd.Timestamp):
if time.tz is None:
return time.tz_localize("UTC")
else:
return time.tz_convert("UTC")
elif isinstance(time, list):
return [
(
dt.replace(tzinfo=timezone.utc)
if isinstance(dt, datetime) and dt.tzinfo is None
else dt.astimezone(timezone.utc)
if isinstance(dt, datetime)
else dt
)
for dt in time
]
elif isinstance(time, datetime):
if time.tzinfo is None:
return time.replace(tzinfo=timezone.utc)
else:
return time.astimezone(timezone.utc)
else:
raise TypeError(
f"Unsupported type: {type(time)}. "
f"Expected datetime, list of datetime, pd.Timestamp, pd.Series, or pd.DatetimeIndex"
)