from dataclasses import dataclass, field
import numpy as np
import xarray as xr
from scipy.integrate import simpson, trapezoid
from xarray import DataArray
from spectral_indices.transformations.base import (Transformation,
TransformationFactory)
@dataclass
class Aggregation(Transformation):
    """Base class for transformations operating along one dimension (sum, mean, ...).

    Subclasses read ``dim`` and ``skipna`` in their ``planify`` implementation
    and forward them to the xarray method they call; not every subclass uses
    ``skipna``.

    Args:
        dim (str): Name of the dimension to operate along. Defaults to "time".
        skipna (bool): If True, NaN values are skipped by the xarray method;
            if False, NaN values propagate to the result. Defaults to False.
    """

    dim: str = field(default="time")
    skipna: bool = field(default=False)
@TransformationFactory.register
@dataclass
class Sum(Aggregation):
    """Sum of the values along ``dim``."""

    def planify(self, array: DataArray) -> DataArray:
        """Return ``array`` summed over ``self.dim``.

        Args:
            array: Input data; must have a dimension named ``self.dim``.

        Returns:
            The reduced array, without ``self.dim``.
        """
        return array.sum(dim=self.dim, skipna=self.skipna)
@TransformationFactory.register
@dataclass
class Mean(Aggregation):
    """Arithmetic mean of the values along ``dim``."""

    def planify(self, array: DataArray) -> DataArray:
        """Return the mean of ``array`` over ``self.dim``.

        Args:
            array: Input data; must have a dimension named ``self.dim``.

        Returns:
            The reduced array, without ``self.dim``.
        """
        return array.mean(dim=self.dim, skipna=self.skipna)
@TransformationFactory.register
@dataclass
class Max(Aggregation):
    """Maximum of the values along ``dim``."""

    def planify(self, array: DataArray) -> DataArray:
        """Return the maximum of ``array`` over ``self.dim``.

        Args:
            array: Input data; must have a dimension named ``self.dim``.

        Returns:
            The reduced array, without ``self.dim``.
        """
        return array.max(dim=self.dim, skipna=self.skipna)
@TransformationFactory.register
@dataclass
class Min(Aggregation):
    """Minimum of the values along ``dim``."""

    def planify(self, array: DataArray) -> DataArray:
        """Return the minimum of ``array`` over ``self.dim``.

        Args:
            array: Input data; must have a dimension named ``self.dim``.

        Returns:
            The reduced array, without ``self.dim``.
        """
        return array.min(dim=self.dim, skipna=self.skipna)
@TransformationFactory.register
@dataclass
class Std(Aggregation):
    """Standard deviation along ``dim`` (xarray's default ``ddof=0``)."""

    def planify(self, array: DataArray) -> DataArray:
        """Reduce ``array`` over ``self.dim`` to its standard deviation."""
        axis_name, drop_nan = self.dim, self.skipna
        return array.std(dim=axis_name, skipna=drop_nan)
@TransformationFactory.register
@dataclass
class Quantile(Aggregation):
    """Quantile of the values along ``dim``.

    Args:
        q (float): Quantile level to compute; xarray requires 0 <= q <= 1.
            Defaults to 0.5, i.e. the median.
    """

    q: float = field(default=0.5)

    def planify(self, array: DataArray) -> DataArray:
        """Reduce ``array`` over ``self.dim`` to its ``self.q`` quantile."""
        level = self.q
        return array.quantile(level, dim=self.dim, skipna=self.skipna)
@TransformationFactory.register
@dataclass
class SimpsonIntegration(Aggregation):
    """Area under the curve along ``dim`` using Simpson's rule.

    ``skipna`` is not used: NaN values propagate to the integral.

    Args:
        use_datetime (bool): If True, the ``dim`` coordinate must hold
            datetime64 values and each sample is placed at its elapsed time
            in days since the first sample. If False, the raw ``dim``
            coordinate values are used as sample positions. Defaults to True.
    """

    use_datetime: bool = field(default=True)

    def planify(self, array: DataArray) -> DataArray:
        """Integrate ``array`` over ``self.dim``.

        NOTE: with dask-backed input, ``self.dim`` must be a single chunk
        (an ``apply_ufunc`` core-dimension requirement).
        """
        coord = array[self.dim]
        # Integrate against the actual sample positions (the same abscissa as
        # TrapezoidIntegration) instead of assuming unit spacing, and honour
        # ``self.dim`` rather than a hard-coded "time".
        x = (coord - coord[0]) / np.timedelta64(1, "D") if self.use_datetime else coord

        def _simpson(values, positions):
            # ``x`` is keyword-only in recent SciPy releases; apply_ufunc moves
            # the core dimension to the last axis, hence axis=-1.
            return simpson(values, x=positions, axis=-1)

        return xr.apply_ufunc(
            _simpson,
            array,
            x,
            input_core_dims=[[self.dim], [self.dim]],
            dask="parallelized",
            output_dtypes=[float],
        )
@TransformationFactory.register
@dataclass
class TrapezoidIntegration(Aggregation):
    """Area under the curve along ``dim`` using the trapezoidal rule.

    ``skipna`` is not used: NaN values propagate to the integral.

    Args:
        use_datetime (bool): If True, the ``dim`` coordinate must hold
            datetime64 values and each sample is placed at its elapsed time
            in days since the first sample. If False, the raw ``dim``
            coordinate values are used as sample positions. Defaults to True.
    """

    use_datetime: bool = field(default=True)

    def planify(self, array: DataArray) -> DataArray:
        """Integrate ``array`` over ``self.dim``.

        NOTE: with dask-backed input, ``self.dim`` must be a single chunk
        (an ``apply_ufunc`` core-dimension requirement).
        """
        coord = array[self.dim]
        x = (coord - coord[0]) / np.timedelta64(1, "D") if self.use_datetime else coord
        # ``trapezoid`` integrates an N-D ``y`` along axis=-1 against a 1-D
        # ``x`` natively, so ``vectorize=True`` (a Python loop over every
        # non-core element) is unnecessary.
        return xr.apply_ufunc(
            trapezoid,
            array,
            x,
            input_core_dims=[[self.dim], [self.dim]],
            dask="parallelized",
            output_dtypes=[float],
        )
@TransformationFactory.register
@dataclass
class Cumsum(Aggregation):
    """Running (cumulative) sum along ``dim``; the dimension is kept."""

    def planify(self, array: DataArray) -> DataArray:
        """Return the cumulative sum of ``array`` along ``self.dim``."""
        options = dict(dim=self.dim, skipna=self.skipna)
        return array.cumsum(**options)
@TransformationFactory.register
@dataclass
class MinMaxScaling(Aggregation):
    """Rescale each series along ``dim`` to [0, 1] via (x - min) / (max - min).

    The result keeps ``dim``. Where max == min, the division is 0 / 0 and
    yields NaN. With ``skipna=False``, a single NaN in a series makes its
    min/max NaN, so the whole scaled series becomes NaN.
    """

    def planify(self, array: DataArray) -> DataArray:
        """Return ``array`` min-max scaled along ``self.dim``."""
        lowest = array.min(dim=self.dim, skipna=self.skipna)
        span = array.max(dim=self.dim, skipna=self.skipna) - lowest
        return (array - lowest) / span
@TransformationFactory.register
@dataclass
class CumulativeMovingAverage(Aggregation):
    """Running mean along ``dim``: element i is the mean of elements 0..i.

    With ``skipna=True``, NaN values are excluded from both the running sum
    and the running count. Positions before the first valid sample are NaN.
    With ``skipna=False``, every position from the first NaN onward is NaN.
    """

    def planify(self, array: DataArray) -> DataArray:
        """Return the cumulative moving average of ``array`` along ``self.dim``."""
        running_sum = array.cumsum(dim=self.dim, skipna=self.skipna)
        # Count only the valid samples seen so far. A positional count would
        # bias the mean when skipna=True, because each skipped NaN adds 0 to
        # the sum but 1 to the count. With skipna=False the sum is NaN from
        # the first NaN onward, and before it this equals a positional count.
        running_count = array.notnull().cumsum(dim=self.dim)
        # No valid sample yet means an undefined mean: mask zero counts so
        # those positions come out as NaN instead of dividing by zero.
        return running_sum / running_count.where(running_count > 0)