"""Temporal transformations (module ``spectral_indices.transformations.time``)."""

from dataclasses import dataclass, field
from datetime import timedelta
from typing import Union

from numpy import timedelta64
from pandas import Timedelta
from xarray import DataArray

from spectral_indices.transformations.base import (Transformation,
                                                   TransformationFactory)


@TransformationFactory.register
@dataclass
class TemporalInterpolation(Transformation):
    """Perform temporal interpolation on an xarray ``DataArray``.

    Interior NaN gaps along ``time`` are interpolated with ``method``; gaps
    longer than ``gap`` are left as NaN. Leading and trailing NaNs, which
    interpolation cannot reach, are filled with the nearest valid value in
    time.

    Args:
        method (``str``): Method to perform interpolation. See
            https://docs.xarray.dev/en/stable/generated/xarray.DataArray.interpolate_na.html
            for all available options. Defaults to 'linear'.
        gap (``Union[timedelta, timedelta64, Timedelta, None]``): Maximum
            length of NaN gap to fill. Defaults to None (no limit).
    """

    # TODO allow args & kwargs optional parameters from interpolate_na.
    method: str = field(default="linear")
    gap: Union[timedelta, timedelta64, Timedelta, None] = field(default=None)

    def planify(self, array: DataArray) -> DataArray:
        """Interpolate NaNs along ``time`` and fill the leading/trailing edges.

        Args:
            array: Data with a ``time`` dimension.

        Returns:
            The array rechunked to a single chunk along ``time``. Interior
            gaps are interpolated (subject to ``gap``) and edge NaNs are
            filled from the nearest valid sample. A series that is entirely
            NaN stays NaN.
        """
        # Presumably needed because interpolate_na on dask-backed data wants
        # the core dimension in one chunk.
        array = array.chunk({"time": -1})
        # Interpolate BEFORE any forward/backward fill. Filling first would
        # leave no NaNs for interpolate_na, silently ignoring ``method`` and
        # ``gap``.
        interpolated = array.interpolate_na(
            dim="time",
            method=self.method,
            max_gap=self.gap,
        )
        forward = interpolated.ffill("time")
        backward = interpolated.bfill("time")
        # A NaN that forward-fill cannot reach lies before the first valid
        # sample (leading edge); one that backward-fill cannot reach lies
        # after the last valid sample (trailing edge). Interior gaps left
        # open by ``gap`` are reachable by both and therefore stay NaN.
        leading = backward.where(forward.isnull())
        trailing = forward.where(backward.isnull())
        return interpolated.fillna(leading).fillna(trailing)
@dataclass
class TemporalOperation(Transformation):
    """Base class for transformations applied along a temporal dimension.

    Args:
        dim (``str``): Name of the dimension to operate along. Defaults to
            'time'.
        skipna (``bool``): Passed as ``skipna`` by subclasses whose xarray
            reductions accept it (``ZScore``, ``ResampleTemporal``).
            Defaults to False.
    """

    dim: str = field(default="time")
    skipna: bool = field(default=False)


@TransformationFactory.register
@dataclass
class Diff(TemporalOperation):
    """First discrete difference along ``dim``.

    The result is one element shorter than the input along ``dim``.
    """

    def planify(self, array: DataArray):
        return array.diff(dim=self.dim)


@TransformationFactory.register
@dataclass
class RollingMean(TemporalOperation):
    """Moving mean over ``window`` consecutive samples along ``dim``.

    Args:
        window (``int``): Number of samples in each window. Defaults to 3.
        center (``bool``): Whether labels are set at the window center.
            Defaults to False.
    """

    window: int = field(default=3)
    center: bool = field(default=False)

    def planify(self, array: DataArray):
        return array.rolling({self.dim: self.window}, center=self.center).mean()


@TransformationFactory.register
@dataclass
class RollingMedian(TemporalOperation):
    """Moving median over ``window`` consecutive samples along ``dim``.

    Args:
        window (``int``): Number of samples in each window. Defaults to 3.
        center (``bool``): Whether labels are set at the window center.
            Defaults to False.
    """

    window: int = field(default=3)
    center: bool = field(default=False)

    def planify(self, array: DataArray):
        return array.rolling({self.dim: self.window}, center=self.center).median()


@TransformationFactory.register
@dataclass
class RollingStd(TemporalOperation):
    """Moving standard deviation over ``window`` samples along ``dim``.

    Args:
        window (``int``): Number of samples in each window. Defaults to 3.
        center (``bool``): Whether labels are set at the window center.
            Defaults to False.
    """

    window: int = field(default=3)
    center: bool = field(default=False)

    def planify(self, array: DataArray):
        return array.rolling({self.dim: self.window}, center=self.center).std()


@TransformationFactory.register
@dataclass
class ZScore(TemporalOperation):
    """Standardize each series: ``(x - mean) / std`` computed along ``dim``."""

    def planify(self, array: DataArray):
        mean = array.mean(dim=self.dim, skipna=self.skipna)
        std = array.std(dim=self.dim, skipna=self.skipna)
        return (array - mean) / std


@TransformationFactory.register
@dataclass
class TemporalDerivative(TemporalOperation):
    """Finite-difference derivative ``dy / dt`` along ``dim``.

    When ``dim`` is a datetime coordinate, ``dt`` is expressed in seconds,
    so the result is "change per second". For a numeric coordinate the raw
    coordinate step is used. The result is one element shorter than the
    input along ``dim``.
    """

    def planify(self, array: DataArray):
        # Differencing the coordinate itself gives dt the same labels as dy,
        # so the division below aligns element-for-element.
        dt = array[self.dim].diff(dim=self.dim)
        if dt.dtype.kind == "m":
            # numpy cannot divide float data by timedelta64 (UFuncTypeError).
            # Convert the time step to a float number of seconds first.
            dt = dt / timedelta64(1, "s")
        dy = array.diff(dim=self.dim)
        return dy / dt


@TransformationFactory.register
@dataclass
class BreakpointDetection(TemporalOperation):
    """Flag steps whose absolute first difference exceeds ``threshold``.

    Args:
        threshold (``float``): Minimum absolute jump to flag. Defaults to 0.1.

    Returns a boolean array one element shorter than the input along ``dim``.
    Comparisons involving NaN evaluate to False.
    """

    threshold: float = field(default=0.1)

    def planify(self, array: DataArray):
        diff = array.diff(dim=self.dim).abs()
        return diff > self.threshold


@TransformationFactory.register
@dataclass
class ResampleTemporal(TemporalOperation):
    """Resample along ``dim`` to ``freq`` and aggregate each bin.

    Args:
        freq (``str``): Target frequency (pandas offset alias). Defaults to
            '1D'.
        how (``str``): Name of the aggregation method called on the xarray
            resample object (e.g. 'mean'). It is called with
            ``skipna=self.skipna``. Defaults to 'mean'.
    """

    freq: str = field(default="1D")
    how: str = field(default="mean")

    def planify(self, array: DataArray):
        resampled = array.resample({self.dim: self.freq})
        agg_func = getattr(resampled, self.how)
        return agg_func(skipna=self.skipna)