from contextlib import nullcontext
from datetime import date
from typing import Any, Dict, List, Tuple, Union
from dask.distributed import Client, LocalCluster, Lock
from simplestac.utils import ItemCollection
from xarray import DataArray
from spectral_indices.filter import Filter
from spectral_indices.idb.models import Formula, Index, Sensor
from spectral_indices.indices_computer import IndicesComputer
from spectral_indices.pipeline.pipeline import Pipeline
from spectral_indices.roi.bbox import BoundingBox
from spectral_indices.sources.sources import DataSource
from spectral_indices.transformations.mask import ApplyMasking, Masking
from spectral_indices.transformations.save import SaveRaster
class Processor:
"""Processor wrap all informations to process some data. This is the main class of the library.
Args:
source (``DataSource``): Source of data to collect data from.
indice (``Indice``): Indice to compute. Pass Indices short names to retrieve indices from database.
timestamps (``Union[Union[str, date], List[Union[str, date]]]``): Date range of data.
rois (``Union[BoundingBox, List[BoundingBox]]``): Region of interest.
pipeline (``Pipeline``): Pipeline to apply on data.
filter (``Filter``, **optional**): Filter to apply on itemCollection. Defaults to Filter().
chunk_size (``int``, **optional**): Size of chunks to process for each thread. A value of x represent a chuk of shape(x,x) for spatial dimension (height, width).Defaults to 500.
n_workers (``int``, **optional**): Number of workers to use. Defaults to 4.
processes (``int``, **optional**): Number of processes to launch. Defaults to 1.
worker_memory_limit (``str``, **optional**): Maximum memory allowed for a worker. Defaults to "24GB".
threads (``int``, **optional**): Number of threads/worker. Defaults to 1.
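
    Example:
        A minimal sketch; ``source``, ``roi`` and ``pipeline`` stand for
        already-built instances, and ``"NDVI"`` is an illustrative index
        short name::

            processor = Processor(
                source=source,      # a concrete DataSource
                indices="NDVI",
                timestamps=["2022-01-01", "2022-12-31"],
                rois=roi,           # a BoundingBox
                pipeline=pipeline,  # a Pipeline
            )
            result = processor.launch(save="ndvi.tif")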
"""
# TODO rewrite docstrings for idb update
def __init__(
self,
source: DataSource,
indices: Union[Union[Index, str], List[Union[Index, str]]],
timestamps: Union[Union[str, date], List[Union[str, date]]],
rois: Union[BoundingBox, List[BoundingBox]],
pipeline: Pipeline,
filter: Filter = Filter(),
maskings: Dict[str, List[int]] = {},
chunk_size=500,
n_workers=1,
processes=1,
worker_memory_limit="24GB",
threads=1,
):
self.source = source
        indices = indices if isinstance(indices, list) else [indices]
        # accept both Index instances and short names to look up in the database
        self.indices = [
            i if isinstance(i, Index) else Index.get_or_raise(short_name=i)
            for i in indices
        ]
self.sensor = Sensor.get_or_raise(name=source.sensor_name)
self.formulas = [
Formula.get_or_raise(index=i, sensor=self.sensor) for i in self.indices
]
self.bands = self.required_bands()
self.filter = filter
self.rois = rois if isinstance(rois, list) else [rois]
self.pipeline = pipeline
self.timestamps = timestamps
self.chunk_size = chunk_size
self.n_workers = n_workers
        self.processes = processes > 1  # LocalCluster expects a boolean flag
self.threads = threads
self.worker_memory_limit = worker_memory_limit
self.total_bounds = BoundingBox.union(self.rois)
self.maskings: List[Masking] = []
for mask_name, logical in maskings.items():
if mask_name not in source.masks:
raise ValueError(
f"{mask_name} not in source available masks: {source.masks}"
)
self.maskings.append(
Masking(mask=source.masks[mask_name], logical_string=logical)
)
    def required_bands(self) -> List[str]:
        """Collect the bands required by the formulas, checking their availability in the source."""
        bands: List[str] = []
        for formula in self.formulas:
            formula_bands = formula.get_bands()
            if not set(formula_bands) <= set(self.source.bands):
                raise ValueError(
                    f"{formula.index} requires {set(formula_bands) - set(self.source.bands)} which are not available from source '{self.source.name}'"
                )
            bands += formula_bands
        # deduplicate while preserving order so shared bands are only listed once
        return list(dict.fromkeys(bands))
def get_collection(self) -> ItemCollection:
"""Get item collection and apply filters."""
        catalog_bands = list(self.bands)  # copy so the += below does not mutate self.bands
        if self.maskings:
            catalog_bands += [masking.mask.asset for masking in self.maskings]
        item_collection = self.source.fetch_collection(
            self.total_bounds, self.timestamps
        )
item_collection = self.filter.filter_collection(
item_collection, assets=catalog_bands
)
return item_collection
def get_data(self) -> DataArray:
"""Get query data as xarray DataArray.
Returns:
``DataArray``:
- Lazy DataArray for query data.
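
        Example:
            A sketch; no pixels are read until the array is computed::

                data = processor.get_data()
                # chunked as {"time": -1, "band": -1, "y": chunk_size, "x": chunk_size}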
"""
crs = self.total_bounds.crs
        # fetch the filtered STAC item collection
        item_collection = self.get_collection()
        data_array: DataArray = item_collection.to_xarray(
            geometry=[roi.transform(crs) for roi in self.rois],
            epsg=crs.to_epsg(),
        )
        # merge items sharing the same timestamp into a single time step
        data_array = data_array.groupby("time").mean()
        data_array = data_array.sortby("time")
data_array = data_array.chunk(
{
"time": -1,
"band": -1,
"y": self.chunk_size,
"x": self.chunk_size,
}
)
return data_array
def apply_masks(self, array: DataArray) -> DataArray:
"""Apply mask to data array.
Args:
array (``DataArray``): Data array with masks bands.
Returns:
``DataArray``:
- Masked data array.
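
        Example:
            A sketch, assuming the source exposes a mask named ``"SCL"``
            (the name and kept values are illustrative)::

                processor = Processor(..., maskings={"SCL": [4, 5]})
                masked = processor.apply_masks(processor.get_data())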
"""
mask_transform = ApplyMasking(maskings=self.maskings)
array = mask_transform.planify(array)
return array
def graph(self, array: DataArray) -> DataArray:
"""Apply transformations graph to array and return lazy processed array.
Args:
            array (``DataArray``): DataArray to apply the pipeline on. It contains everything the transformations need (bands, masks, metadata).
Returns:
``DataArray``:
- Lazy processed array.
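
        Example:
            A sketch; the returned array stays lazy until explicitly computed::

                lazy = processor.graph(processor.apply_masks(processor.get_data()))
                result = lazy.compute()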
"""
        rio = array.rio
        indices_computer = IndicesComputer(formulas=self.formulas)
        indice = indices_computer.compute(array)
        # computing the indices drops the rio metadata; restore it from the input array
        indice.rio.write_nodata(rio.nodata, inplace=True)
        indice.rio.write_crs(rio.crs, inplace=True)
        indice.rio.write_transform(rio.transform(), inplace=True)
final = self.pipeline.run(indice)
return final
def default_cluster_client(self) -> Tuple[LocalCluster, Client]:
"""Build default cluster & client based on dask distributed."""
cluster = LocalCluster(
n_workers=self.n_workers,
processes=self.processes,
threads_per_worker=self.threads,
memory_limit=self.worker_memory_limit,
)
client = Client(cluster)
return cluster, client
def launch(
self,
save: str = "",
nodata=-1000,
compute: bool = False,
cluster: Any = None,
client: Any = None,
) -> DataArray:
"""Run the pipeline.
Args:
            save (``str``, **optional**): Path to save the pipeline output. If provided, the SaveRaster transformation is applied, so the array is computed. Defaults to "".
            nodata (``int``, **optional**): Value used to fill NaN when writing the raster. Defaults to -1000.
            compute (``bool``, **optional**): Whether to compute the result array. Defaults to False. **Warning**: if both save and compute are set, the array is computed twice, once by SaveRaster and once before returning the result, which may be inefficient.
            cluster (``Any``, **optional**): Custom cluster to run the pipeline on. Defaults to None (a dask LocalCluster is used).
            client (``Any``, **optional**): Custom client to run the pipeline with. Defaults to None (a dask Client is used).
Returns:
``DataArray``:
- Result of pipeline as lazy DataArray.
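
        Example:
            A minimal sketch; the output path is a placeholder::

                result = processor.launch(save="out/ndvi.tif")

            To reuse an existing dask deployment (names are illustrative)::

                result = processor.launch(cluster=my_cluster, client=my_client, compute=True)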
"""
        if cluster is None and client is None:
            cluster, client = self.default_cluster_client()
        elif client is None:
            # attach a client to the provided cluster instead of spawning an unused default pair
            client = Client(cluster)
        elif cluster is None:
            # a client passed without a cluster manages its own deployment;
            # fall back to a no-op context manager if it exposes no cluster
            cluster = client.cluster or nullcontext()
data = self.get_data()
if self.maskings:
data = self.apply_masks(data)
final = self.graph(data)
with cluster, client:
if isinstance(cluster, LocalCluster) and (save or compute):
print(f"Dask dashboard: {cluster.dashboard_link}")
if save:
saver = SaveRaster(save, nodata=nodata)
final = saver.planify(final, lock=Lock())
if compute:
final = final.compute()
return final