from dataclasses import dataclass, field
from datetime import timedelta
from typing import Union
from numpy import timedelta64
from pandas import Timedelta
from xarray import DataArray
from spectral_indices.transformations.base import (Transformation,
TransformationFactory)
@TransformationFactory.register
@dataclass
class TemporalInterpolation(Transformation):
    """Perform temporal interpolation on an xarray ``DataArray``.

    Interior NaN gaps along ``time`` are filled with
    ``DataArray.interpolate_na``. Leading and trailing NaNs, which
    interpolation does not reach, are then filled with the nearest valid
    observation.

    Args:
        method (``str``): Interpolation method passed to ``interpolate_na``.
            See https://docs.xarray.dev/en/stable/generated/xarray.DataArray.interpolate_na.html
            for all available options. Defaults to ``'linear'``.
        gap (``Union[timedelta, timedelta64, Timedelta, None]``): Maximum
            length of an interior NaN gap to fill, forwarded as ``max_gap``.
            Longer interior gaps are left as NaN. Defaults to ``None``
            (no limit).
    """

    # TODO: allow optional args & kwargs to be forwarded to interpolate_na.
    method: str = field(default="linear")
    gap: Union[timedelta, timedelta64, Timedelta, None] = field(default=None)

    def planify(self, array: DataArray) -> DataArray:
        """Interpolate interior gaps along ``time``, then fill the edges.

        Args:
            array (``DataArray``): Input with a ``time`` dimension.

        Returns:
            ``DataArray``: ``array`` with interior NaN gaps interpolated
            (subject to ``gap``) and leading/trailing NaNs filled with the
            first/last valid value. An all-NaN series stays all-NaN.
        """
        # Keep the whole time axis in a single dask chunk: every fill below
        # works along ``time`` as a core dimension.
        array = array.chunk({"time": -1})

        # Interpolate FIRST. Forward/backward filling beforehand would remove
        # every NaN and turn this step (and ``method``/``gap``) into a no-op.
        interpolated = array.interpolate_na(
            dim="time",
            method=self.method,
            max_gap=self.gap,
        )

        # Positions before the first / after the last valid value.
        valid = interpolated.notnull()
        flip = {"time": slice(None, None, -1)}
        leading = valid.cumsum("time") == 0
        trailing = valid.isel(flip).cumsum("time").isel(flip) == 0

        # Nearest valid observation, applied ONLY on the edges so that interior
        # gaps longer than ``gap`` keep their NaNs.
        edge_filled = interpolated.ffill("time").bfill("time")
        return interpolated.where(~(leading | trailing), edge_filled)
@dataclass
class TemporalOperation(Transformation):
    """Common configuration for operations applied along one dimension.

    Args:
        dim (``str``): Dimension the operation runs along. Defaults to
            ``'time'``.
        skipna (``bool``): Forwarded as ``skipna`` by subclasses whose
            reductions accept it (e.g. ``ZScore``, ``ResampleTemporal``);
            other subclasses ignore it. Defaults to ``False``.
    """

    dim: str = "time"
    skipna: bool = False
@TransformationFactory.register
@dataclass
class Diff(TemporalOperation):
    """First-order discrete difference along ``dim``."""

    def planify(self, array: DataArray) -> DataArray:
        """Return ``array[i] - array[i-1]`` along ``dim``.

        The result has one fewer element along ``dim`` than the input.
        """
        differences = array.diff(dim=self.dim, n=1)
        return differences
@TransformationFactory.register
@dataclass
class RollingMean(TemporalOperation):
    """Moving average over a fixed-size window along ``dim``.

    Args:
        window (``int``): Number of samples per window. Defaults to 3.
        center (``bool``): If True, label each window at its centre rather
            than at its last sample. Defaults to False.
    """

    window: int = 3
    center: bool = False

    def planify(self, array: DataArray) -> DataArray:
        """Return the rolling mean of ``array`` along ``dim``.

        xarray's default ``min_periods`` (the window size) is used, so
        incomplete windows yield NaN.
        """
        windows = array.rolling(dim={self.dim: self.window}, center=self.center)
        return windows.mean()
@TransformationFactory.register
@dataclass
class RollingMedian(TemporalOperation):
    """Moving median over a fixed-size window along ``dim``.

    Args:
        window (``int``): Number of samples per window. Defaults to 3.
        center (``bool``): If True, label each window at its centre rather
            than at its last sample. Defaults to False.
    """

    window: int = 3
    center: bool = False

    def planify(self, array: DataArray) -> DataArray:
        """Return the rolling median of ``array`` along ``dim``.

        xarray's default ``min_periods`` (the window size) is used, so
        incomplete windows yield NaN.
        """
        windows = array.rolling(dim={self.dim: self.window}, center=self.center)
        return windows.median()
@TransformationFactory.register
@dataclass
class RollingStd(TemporalOperation):
    """Moving standard deviation over a fixed-size window along ``dim``.

    Args:
        window (``int``): Number of samples per window. Defaults to 3.
        center (``bool``): If True, label each window at its centre rather
            than at its last sample. Defaults to False.
    """

    window: int = 3
    center: bool = False

    def planify(self, array: DataArray) -> DataArray:
        """Return the rolling standard deviation of ``array`` along ``dim``.

        xarray's default ``min_periods`` (the window size) is used, so
        incomplete windows yield NaN.
        """
        windows = array.rolling(dim={self.dim: self.window}, center=self.center)
        return windows.std()
@TransformationFactory.register
@dataclass
class ZScore(TemporalOperation):
    """Standardise values along ``dim`` to zero mean and unit deviation."""

    def planify(self, array: DataArray) -> DataArray:
        """Return ``(array - mean) / std`` with statistics taken along ``dim``.

        ``std`` uses xarray's default ``ddof=0``. With ``skipna=False`` a
        single NaN along ``dim`` makes that whole series NaN.
        """
        centred = array - array.mean(dim=self.dim, skipna=self.skipna)
        spread = array.std(dim=self.dim, skipna=self.skipna)
        return centred / spread
@TransformationFactory.register
@dataclass
class TemporalDerivative(TemporalOperation):
    """Finite-difference rate of change along ``dim``.

    Each output value is ``(y[i] - y[i-1]) / (t[i] - t[i-1])``, labelled with
    ``t[i]``, so the result has one fewer element along ``dim`` than the
    input. When the ``dim`` coordinate is datetime-like, the step is
    expressed in (fractional) seconds, giving a rate per second. Otherwise
    the coordinate step is used as-is.
    """

    def planify(self, array: DataArray) -> DataArray:
        """Return the discrete derivative of ``array`` along ``dim``."""
        dy = array.diff(dim=self.dim)
        # Coordinate step, labelled exactly like ``dy`` (upper label), so the
        # division below aligns element-wise without any shift/fill.
        dt = array[self.dim].diff(dim=self.dim)
        if dt.dtype.kind == "m":
            # NumPy cannot divide a number by a timedelta64. Convert the step
            # to float seconds, which also keeps sub-second resolution that
            # ``astype("timedelta64[s]")`` would truncate to 0.
            dt = dt / timedelta64(1, "s")
        return dy / dt
@TransformationFactory.register
@dataclass
class BreakpointDetection(TemporalOperation):
    """Flag abrupt changes between consecutive samples along ``dim``.

    Args:
        threshold (``float``): A step whose absolute size is strictly greater
            than this value is flagged. Defaults to 0.1.
    """

    threshold: float = field(default=0.1)

    def planify(self, array: DataArray) -> DataArray:
        """Return a boolean mask, labelled like ``array.diff(dim)``.

        An element is True where ``|array[i] - array[i-1]| > threshold``.
        NaN steps compare False, so gaps are never flagged.
        """
        # Builtin abs() dispatches to DataArray.__abs__. Unlike pandas, xarray
        # does not document a ``DataArray.abs()`` method.
        step = abs(array.diff(dim=self.dim))
        return step > self.threshold
@TransformationFactory.register
@dataclass
class ResampleTemporal(TemporalOperation):
    """Resample along ``dim`` to a new frequency and aggregate each bin.

    Args:
        freq (``str``): Target frequency (pandas offset alias). Defaults to
            ``'1D'``.
        how (``str``): Name of the reduction method looked up on the resample
            object (e.g. ``'mean'``, ``'max'``). It is called with
            ``skipna=self.skipna``, so it must accept that keyword.
            Defaults to ``'mean'``.
    """

    freq: str = "1D"
    how: str = "mean"

    def planify(self, array: DataArray) -> DataArray:
        """Return ``array`` resampled to ``freq`` and reduced with ``how``."""
        bins = array.resample(indexer={self.dim: self.freq})
        return getattr(bins, self.how)(skipna=self.skipna)