Source code for spectral_indices.processor

from datetime import date
from typing import Any, Dict, List, Tuple, Union

from dask.distributed import Client, LocalCluster, Lock
from simplestac.utils import ItemCollection
from xarray import DataArray

from spectral_indices.filter import Filter
from spectral_indices.idb.models import Formula, Index, Sensor
from spectral_indices.indices_computer import IndicesComputer
from spectral_indices.pipeline.pipeline import Pipeline
from spectral_indices.roi.bbox import BoundingBox
from spectral_indices.sources.sources import DataSource
from spectral_indices.transformations.mask import ApplyMasking, Masking
from spectral_indices.transformations.save import SaveRaster


class Processor:
    """Processor wraps all the information needed to process some data.

    This is the main class of the library.

    Args:
        source (``DataSource``): Source of data to collect data from.
        indices (``Union[Union[Index, str], List[Union[Index, str]]]``): Indices to
            compute. Pass indices' short names to retrieve them from the database.
        timestamps (``Union[Union[str, date], List[Union[str, date]]]``): Date range
            of data.
        rois (``Union[BoundingBox, List[BoundingBox]]``): Regions of interest.
        pipeline (``Pipeline``): Pipeline to apply on data.
        filter (``Filter``, **optional**): Filter to apply on the item collection.
            Defaults to Filter().
        maskings (``Dict[str, List[int]]``, **optional**): Mapping from a mask name
            available on the source to the logical expression passed to ``Masking``
            as ``logical_string``. Defaults to {}.
        chunk_size (``int``, **optional**): Size of chunks to process for each
            thread. A value of x represents a chunk of shape (x, x) for the spatial
            dimensions (height, width). Defaults to 500.
        n_workers (``int``, **optional**): Number of workers to use. Defaults to 1.
        processes (``int``, **optional**): Number of processes to launch.
            Defaults to 1.
        worker_memory_limit (``str``, **optional**): Maximum memory allowed for a
            worker. Defaults to "24GB".
        threads (``int``, **optional**): Number of threads per worker. Defaults to 1.
    """

    # TODO rewrite docstrings for idb update
    def __init__(
        self,
        source: DataSource,
        indices: Union[Union[Index, str], List[Union[Index, str]]],
        timestamps: Union[Union[str, date], List[Union[str, date]]],
        rois: Union[BoundingBox, List[BoundingBox]],
        pipeline: Pipeline,
        filter: Filter = Filter(),
        maskings: Dict[str, List[int]] = {},
        chunk_size=500,
        n_workers=1,
        processes=1,
        worker_memory_limit="24GB",
        threads=1,
    ):
        self.source = source
        indices = indices if isinstance(indices, list) else [indices]
        self.indices = [Index.get_or_raise(short_name=i) for i in indices]
        self.sensor = Sensor.get_or_raise(name=source.sensor_name)
        self.formulas = [
            Formula.get_or_raise(index=i, sensor=self.sensor) for i in self.indices
        ]
        self.bands = self.required_bands()
        self.filter = filter
        self.rois = rois if isinstance(rois, list) else [rois]
        self.pipeline = pipeline
        self.timestamps = timestamps
        self.chunk_size = chunk_size
        self.n_workers = n_workers
        self.processes = processes if processes > 1 else False
        self.threads = threads
        self.worker_memory_limit = worker_memory_limit
        self.total_bounds = BoundingBox.union(self.rois)
        self.maskings: List[Masking] = []
        for mask_name, logical in maskings.items():
            if mask_name not in source.masks:
                raise ValueError(
                    f"{mask_name} not in source available masks: {source.masks}"
                )
            self.maskings.append(
                Masking(mask=source.masks[mask_name], logical_string=logical)
            )

    def required_bands(self):
        """Collect the bands required by all formulas, checking source availability."""
        bands = []
        for formula in self.formulas:
            formula_bands = formula.get_bands()
            if not set(formula_bands) <= set(self.source.bands):
                raise ValueError(
                    f"{formula.index} requires "
                    f"{set(formula_bands) - set(self.source.bands)} "
                    f"which are not available from source '{self.source.name}'"
                )
            bands += formula_bands
        return bands
    def get_collection(self) -> ItemCollection:
        """Get the item collection and apply filters."""
        # copy so that extending catalog_bands does not mutate self.bands in place
        catalog_bands = list(self.bands)
        if self.maskings:
            catalog_bands += [masking.mask.asset for masking in self.maskings]
        item_collection = self.source.fetch_collection(
            self.total_bounds, self.timestamps
        )
        item_collection = self.filter.filter_collection(
            item_collection, assets=catalog_bands
        )
        return item_collection
    def get_data(self) -> DataArray:
        """Get query data as an xarray DataArray.

        Returns:
            ``DataArray``:
                - Lazy DataArray for query data.
        """
        crs = self.total_bounds.crs
        # fetch STAC collection
        item_collection = self.get_collection()
        data_array: DataArray = item_collection.to_xarray(
            geometry=[roi.transform(crs) for roi in self.rois],
            epsg=crs.to_epsg(),
        )
        # merge items sharing the same timestamp, then sort chronologically
        data_array = data_array.groupby("time").mean()
        data_array = data_array.sortby("time")
        data_array = data_array.chunk(
            {
                "time": -1,
                "band": -1,
                "y": self.chunk_size,
                "x": self.chunk_size,
            }
        )
        return data_array
    def apply_masks(self, array: DataArray) -> DataArray:
        """Apply masks to the data array.

        Args:
            array (``DataArray``): Data array with mask bands.

        Returns:
            ``DataArray``:
                - Masked data array.
        """
        mask_transform = ApplyMasking(maskings=self.maskings)
        array = mask_transform.planify(array)
        return array
    def graph(self, array: DataArray) -> DataArray:
        """Apply the transformation graph to the array and return the lazy processed array.

        Args:
            array (``DataArray``): DataArray to apply the pipeline on. It contains
                everything the applied transformations may need.

        Returns:
            ``DataArray``:
                - Lazy processed array.
        """
        rio = array.rio
        indices_computer = IndicesComputer(formulas=self.formulas)
        indice = indices_computer.compute(array)
        # propagate the spatial metadata of the input onto the computed indices
        indice.rio.write_nodata(rio.nodata, inplace=True)
        indice.rio.write_crs(rio.crs, inplace=True)
        indice.rio.write_transform(rio.transform(), inplace=True)
        final = self.pipeline.run(indice)
        return final
    def default_cluster_client(self) -> Tuple[LocalCluster, Client]:
        """Build the default cluster & client based on dask.distributed."""
        cluster = LocalCluster(
            n_workers=self.n_workers,
            processes=self.processes,
            threads_per_worker=self.threads,
            memory_limit=self.worker_memory_limit,
        )
        client = Client(cluster)
        return cluster, client
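    # Sketch (assumption, not in the original source): instead of this default,
    # ``launch`` below also accepts a custom cluster/client pair, e.g.
    #   cluster = LocalCluster(n_workers=8, processes=False)
    #   client = Client(cluster)
    #   processor.launch(cluster=cluster, client=client)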
    def launch(
        self,
        save: str = "",
        nodata=-1000,
        compute: bool = False,
        cluster: Any = None,
        client: Any = None,
    ) -> DataArray:
        """Run the pipeline.

        Args:
            save (``str``, **optional**): Path to save the pipeline output. If
                provided, the SaveRaster transformation will be applied, so the
                array will be computed. Defaults to "".
            nodata (``int``, **optional**): Value used to fill NaN for raster
                writing. Defaults to -1000.
            compute (``bool``, **optional**): Whether to compute the result array.
                Defaults to False. **Warning**: if both save and compute are set,
                the array is computed twice, once during SaveRaster and once
                before returning the result, which may be inefficient.
            cluster (``Any``, **optional**): Custom cluster to apply the pipeline.
                Defaults to None (will use a dask LocalCluster).
            client (``Any``, **optional**): Custom client to apply the pipeline.
                Defaults to None (will use a dask Client).

        Returns:
            ``DataArray``:
                - Result of the pipeline as a lazy DataArray.
        """
        if not cluster or not client:
            default_cluster, default_client = self.default_cluster_client()
            cluster = cluster if cluster else default_cluster
            client = client if client else default_client
        data = self.get_data()
        if self.maskings:
            data = self.apply_masks(data)
        final = self.graph(data)
        with cluster, client:
            if isinstance(cluster, LocalCluster) and (save or compute):
                print(f"Dask dashboard: {cluster.dashboard_link}")
            if save:
                saver = SaveRaster(save, nodata=nodata)
                final = saver.planify(final, lock=Lock())
            if compute:
                final = final.compute()
            return final
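
A minimal usage sketch follows (not part of the module above). ``Sentinel2Source``, the "NDVI" short name, the date range, and the bare ``Pipeline()`` call are illustrative assumptions; only the ``Processor`` interface shown in this file comes from the source.

# Hypothetical usage sketch. ``Sentinel2Source`` stands in for any concrete
# DataSource subclass; "NDVI" for any index short name known to the database;
# ``Pipeline()`` is assumed to be constructible without required arguments.
from spectral_indices.pipeline.pipeline import Pipeline
from spectral_indices.processor import Processor
from spectral_indices.roi.bbox import BoundingBox

source = Sentinel2Source()  # hypothetical DataSource subclass
roi = BoundingBox(...)      # constructor arguments omitted

processor = Processor(
    source=source,
    indices="NDVI",  # short name resolved via Index.get_or_raise
    timestamps=["2023-01-01", "2023-12-31"],
    rois=roi,
    pipeline=Pipeline(),
)

# Lazy by default; pass save="out.tif" and/or compute=True to materialise it.
result = processor.launch()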