from dataclasses import dataclass, field
from typing import Any, Dict, List, Union
from xarray import DataArray
from spectral_indices.sources.masks import Mask
from spectral_indices.transformations.base import (Transformation,
TransformationFactory)
@dataclass
class Masking:
    """Wrap masking metadata.

    Args:
        logical_string (``str``): A boolean expression, evaluated with
            ``eval()``, used as the condition of ``xarray.DataArray.where()``
            (values are kept where it is ``True`` and replaced with NaN
            elsewhere). Use ``mask`` to refer to the mask array, i.e.
            ``"mask > 5"``.
        mask (``Mask``): Corresponding Mask class. A plain ``dict`` (e.g. from
            a configuration file) is also accepted and converted with
            ``Mask(**mask)``.

    Raises:
        ValueError: If ``logical_string`` does not contain ``"mask"``.
    """

    # The default empty string never passes the check in __post_init__, so in
    # practice callers must always provide an expression.
    logical_string: str = field(default_factory=str)
    mask: Mask = field(default_factory=Mask)

    def __post_init__(self):
        # Plain substring test: it only ensures the expression mentions
        # ``mask``; it does not validate the expression itself.
        if "mask" not in self.logical_string:
            raise ValueError(
                f"logical string should contain 'mask' to represent the mask during the call of eval(), got {self.logical_string}"
            )
        # When maskings are built from plain dicts (as ApplyMasking allows),
        # the nested mask arrives as a dict too; coerce it so that
        # ``self.mask.asset`` is available later.
        # NOTE(review): assumes Mask accepts its fields as keyword arguments
        # (it is default-constructible, as used by default_factory) — confirm.
        if isinstance(self.mask, dict):
            self.mask = Mask(**self.mask)
@TransformationFactory.register
@dataclass
class ApplyMasking(Transformation):
    """Wraps multiple masking operations.

    Args:
        maskings (``list``): ``Masking`` instances, or dicts of ``Masking``
            keyword arguments that are converted in ``__post_init__``.
    """

    maskings: List[Union[Masking, Dict[str, Any]]] = field(default_factory=list)

    def __post_init__(self):
        # NOTE(review): super().__post_init__() is not called here; confirm
        # that Transformation does not define one that must run.
        # Dicts become Masking instances; an explicit None is treated as an
        # empty list.
        self.maskings = [
            masking if isinstance(masking, Masking) else Masking(**masking)
            for masking in (self.maskings or [])
        ]

    def planify(self, array: DataArray) -> DataArray:
        """Apply masking to data.

        Each masking keeps the values of ``array`` where its condition is
        ``True`` and replaces the others with NaN (``DataArray.where``).
        Conditions are applied one after another, so they combine as a
        logical AND.

        Args:
            array (``DataArray``): Array to mask that contains the band masks.

        Returns:
            ``DataArray``:
                - Lazy masked array without band masks.
        """
        # Several maskings may share one mask asset (e.g. two conditions on
        # the same classification band). De-duplicate while keeping order so
        # ``mask_assets`` has a unique ``band`` index: selecting one asset
        # below then drops the ``band`` dimension instead of returning
        # duplicated bands that would misalign with the data bands.
        mask_names = list(dict.fromkeys(m.mask.asset for m in self.maskings))
        mask_assets = array.sel(band=mask_names)
        array = array.drop_sel(band=mask_names)
        for masking in self.maskings:
            mask = mask_assets.sel(band=masking.mask.asset)
            # SECURITY: logical_string is executed with eval(); it must come
            # from trusted configuration, never from user-controlled input.
            # Only ``mask`` is exposed as a local name, so expressions do not
            # depend on this method's other local variables.
            condition = eval(masking.logical_string, globals(), {"mask": mask})
            array = array.where(condition)
        return array