"""Source code for the ``spectral_indices.transformations.mask`` module (masking transformations)."""

from dataclasses import dataclass, field
from typing import Any, Dict, List, Union

from xarray import DataArray

from spectral_indices.sources.masks import Mask
from spectral_indices.transformations.base import (Transformation,
                                                   TransformationFactory)


@dataclass
class Masking:
    """Wrap masking metadata.

    Args:
        logical_string (``str``): Python expression giving the condition to
            pass to ``xarray.DataArray.where()``. Use the name ``mask`` to
            refer to the mask array, e.g. ``"mask > 5"``.
        mask (``Mask``): Corresponding Mask; its ``asset`` attribute is used
            as the ``band`` label of the mask data.

    Raises:
        TypeError: If ``logical_string`` is not a string.
        ValueError: If ``logical_string`` does not mention ``mask`` or is not
            a syntactically valid Python expression.
    """

    logical_string: str = field(default_factory=str)
    mask: Mask = field(default_factory=Mask)

    def __post_init__(self):
        if not isinstance(self.logical_string, str):
            raise TypeError(
                f"logical string should be a str, got {type(self.logical_string).__name__}"
            )
        if "mask" not in self.logical_string:
            raise ValueError(
                f"logical string should contain 'mask' to represent the mask during the call of eval(), got {self.logical_string}"
            )
        # Parse (without executing) so that a malformed expression is reported
        # when the configuration is built, not later when it is eval()'d.
        try:
            compile(self.logical_string, "<masking>", "eval")
        except SyntaxError as err:
            raise ValueError(
                f"logical string is not a valid Python expression: {self.logical_string!r}"
            ) from err
@TransformationFactory.register
@dataclass
class ApplyMasking(Transformation):
    """Wrap multiple masking operations.

    Args:
        maskings (``List[Union[Masking, Dict[str, Any]]]``): Masking
            definitions. Plain dictionaries are converted with
            ``Masking(**masking)``.
    """

    maskings: List[Union[Masking, Dict[str, Any]]] = field(default_factory=list)

    def __post_init__(self):
        # NOTE(review): super().__post_init__() is not called; confirm that
        # Transformation does not define one that needs to run.
        # Normalise config-style dictionaries into Masking instances.
        self.maskings = [
            masking if isinstance(masking, Masking) else Masking(**masking)
            for masking in self.maskings
        ]

    def planify(self, array: DataArray) -> DataArray:
        """Apply masking to data.

        Mask bands are taken out of ``array``. The remaining bands are then
        filtered with ``DataArray.where`` once per masking, using that
        masking's ``logical_string`` as the condition.

        Args:
            array (``DataArray``): Array to mask; it must contain the mask
                bands, selected by their ``band`` label.

        Returns:
            ``DataArray``:
                - Lazy masked array without the mask bands.
        """
        # Several maskings may read the same mask band (e.g. two conditions
        # on one classification layer). Deduplicate, keeping order, so the
        # ``band`` index of ``mask_assets`` has unique labels and each
        # ``sel`` below yields a single band.
        mask_names = list(dict.fromkeys(m.mask.asset for m in self.maskings))
        mask_assets = array.sel(band=mask_names)
        array = array.drop_sel(band=mask_names)
        for masking in self.maskings:
            # The expression refers to the mask via the local name ``mask``;
            # this variable name is part of the documented contract.
            mask = mask_assets.sel(band=masking.mask.asset)
            # SECURITY NOTE(review): eval() runs arbitrary code. The
            # expression comes from configuration, so it must only ever be
            # loaded from trusted sources.
            array = array.where(eval(masking.logical_string))
        return array