from dataclasses import dataclass, field
from typing import Any, Dict
from xarray import DataArray
from spectral_indices.transformations.base import (Transformation,
TransformationFactory)
@TransformationFactory.register
@dataclass
class SaveRaster(Transformation):
    """Save an xarray ``DataArray`` to a GeoTIFF file.

    Writing goes through ``rio.to_raster`` with a caller-supplied lock, which
    supports parallel saving of dask-chunked arrays.

    Args:
        out_path (``str``): Path to save raster.
        nodata (``float``): Value used to fill NaNs; it is also written as the
            raster nodata value. Defaults to -1000.
    """

    out_path: str
    nodata: float = field(default=-1000)

    def planify(self, array: DataArray, lock: Any) -> DataArray:
        """Planify the array saving on a pipeline.

        NaNs are replaced by ``self.nodata``, which is also recorded as the
        raster nodata value. Arrays with more than two dimensions have their
        extra dimension exposed as ``band`` (see ``_to_band_dim``) before the
        GeoTIFF is written to ``self.out_path``.

        Args:
            array: Array to save.
            lock: Lock forwarded unchanged to ``rio.to_raster``.

        Returns:
            The array as written: NaNs filled, nodata set and, for
            multi-dimensional input, the band dimension relabelled.

        Raises:
            ValueError: If ``array`` has more than two dimensions but none of
                them is ``index``, ``time`` or ``band``.
        """
        array = array.fillna(self.nodata)
        array = array.rio.write_nodata(self.nodata)
        if len(array.dims) > 2:
            array = self._to_band_dim(array)
        array.rio.to_raster(self.out_path, driver="GTiff", compute=True, lock=lock)
        return array

    @staticmethod
    def _to_band_dim(array: DataArray) -> DataArray:
        """Return ``array`` with its non-spatial dimension named ``band``.

        ``index`` and ``time`` dimensions are renamed to ``band``. Their labels
        become the band coordinate and the ``long_name`` attribute, which
        rioxarray uses as band descriptions. Time labels are rendered with
        ``strftime``. An existing ``band`` dimension is returned unchanged.

        Raises:
            ValueError: If none of ``index``, ``time`` or ``band`` is present.
        """
        if "index" in array.dims:
            array = array.rename({"index": "band"})
            band_names = list(array.coords["band"].values)
        elif "time" in array.dims:
            array = array.rename({"time": "band"})
            band_names = list(
                array.coords["band"].dt.strftime("%B %d, %Y, %r").values
            )
        elif "band" in array.dims:
            # Already in the layout rioxarray writes natively; nothing to relabel.
            return array
        else:
            # Previously this fell through with `band_names` unbound and
            # crashed with an UnboundLocalError.
            raise ValueError(
                f"Cannot save array with dims {tuple(array.dims)}: expected an "
                "'index', 'time' or 'band' dimension besides the spatial ones."
            )
        array = array.assign_coords(
            {
                "band": band_names,
                "x": list(array.coords["x"].values),
                "y": list(array.coords["y"].values),
            }
        )
        # Band descriptions are handed to GDAL, which expects plain strings.
        array.attrs["long_name"] = [str(name) for name in band_names]
        return array