"""
Definitions for the VRaster class.
Most actual functionality is in the `vrt` module.
"""
from __future__ import annotations
import copy
import tempfile
from pathlib import Path
from typing import Any, Callable, Iterable, Literal, overload
import numpy as np
import numpy.typing as npt
import rasterio as rio
from osgeo import gdal
from rasterio.coords import BoundingBox
from rasterio.crs import CRS
from rasterio.transform import Affine
from rasterio.warp import Resampling
from rasterio.windows import Window
from variete import misc
from variete.vrt import pixel_functions
from variete.vrt.pixel_functions import ScalePixelFunction
from variete.vrt.raster_bands import VRTDerivedRasterBand
from variete.vrt.sources import SimpleSource
from variete.vrt.vrt import AnyVRTDataset, VRTDataset, build_vrt, load_vrt, vrt_warp
# tqdm is an optional dependency and will simply raise a custom exception if it's explicitly asked for.
try:
from tqdm import tqdm
_has_tqdm = True
except ImportError:
_has_tqdm = False
# For code simplicity, a dummy tqdm class is ironically needed. This makes it so that a tqdm context can
# always be entered (even though it doesn't do anything if tqdm is not installed)
class tqdm: # type: ignore
def __init__(
self, total: float, disable: bool = False, smoothing: float | None = None, desc: str | None = None
) -> None:
...
def update(self, value: Any) -> None:
...
def __enter__(self) -> None:
...
def __exit__(self, *_: Any) -> None:
...
[docs]class VRasterStep:
"""A VRTDataset and an associated name, for logging purposes."""
dataset: AnyVRTDataset
name: str
[docs] def __init__(self, dataset: AnyVRTDataset, name: str):
for attr in ["dataset", "name"]:
setattr(self, attr, locals()[attr])
[docs]class VRaster:
"""
A "Virtual Raster" containing information on how to process a raster on disk.
A VRaster has no data loaded in memory, other than the processing steps to take when evaluating.
Evaluation is mainly done through the `VRaster.read()` or `VRaster.write()` functions.
"""
# The steps list is a (largely) unordered list of steps and dependencies to the last (current) dataset
# The last dataset (VRaster.steps[-1]) is always the current one, but no other step is required for evaluation.
# All other steps are only to show the steps that were taken to get to the latest, and are all self-contained.
steps: list[VRasterStep]
[docs] def __init__(self, steps: list[VRasterStep] | None = None):
self.steps = steps or []
[docs] @classmethod
def load_file(cls, filepath: str | Path) -> VRaster:
"""
Load a VRaster from a file.
Parameters
----------
filepath
The filepath to a GDAL-supported dataset.
Returns
-------
A newly created VRaster
"""
step = VRasterStep(VRTDataset.from_file(filepath), name="load_file")
return cls(steps=[step])
[docs] def save_vrt(self, filepath: str | Path) -> list[Path]:
"""
Save the VRaster as a VRT or a stack of VRTs.
If the VRaster is nested (depends on more than one VRTDataset), all dependents will be saved too.
Parameters
----------
filepath
The filepath to save the VRT. Multiple VRTs may be saved with suffixes.
Returns
-------
A list of filepaths that were created (multiple in case of a nested VRaster).
"""
if self.last.is_nested():
return self.last.save_vrt_nested(filepath)
else:
self.last.save_vrt(filepath)
return [Path(filepath)]
def _check_compatibility(self, other: VRaster) -> str | None:
"""Check if this VRaster is compatible with another VRaster."""
if self.crs != other.crs:
return f"CRS is different: {self.crs} != {other.crs}"
if self.n_bands != other.n_bands:
return f"Number of bands must be the same: {self.n_bands} != {other.n_bands}"
if self.transform != other.transform:
return f"Transforms must be the same: {self.transform} != {other.transform}"
if self.shape != other.shape:
return f"Shapes must be the same: {self.shape} != {other.shape}"
return None
@overload
def read(
self,
band: int | list[int] | None,
out: npt.ArrayLike | None,
window: Window | None,
masked: Literal[True],
**kwargs: dict[str, Any],
) -> np.ma.MaskedArray[Any, Any]:
...
@overload
def read(
self,
band: int | list[int] | None,
out: npt.ArrayLike | None,
window: Window | None,
masked: Literal[False],
**kwargs: dict[str, Any],
) -> npt.NDarray[Any]:
...
@overload
def read(
self,
band: int | list[int] | None = None,
out: npt.ArrayLike | None = None,
window: Window | None = None,
masked: bool = False,
**kwargs: dict[str, Any],
) -> npt.NDArray[Any] | np.ma.MaskedArray[Any, Any]:
...
[docs] def read(
self,
band: int | list[int] | None = None,
out: npt.ArrayLike | None = None,
window: Window | None = None,
masked: bool = False,
**kwargs: dict[str, Any],
) -> npt.NDArray[Any] | np.ma.MaskedArray[Any, Any]:
"""
Read the contents of a VRaster into memory.
Parameters
----------
band
A band index or a list of indices to load. Defaults to all bands.
out
Optional: The destination array to read to.
window
Optional: Read only a part of the VRaster (see the rasterio Window documentation)
masked
Return a masked array where all nodata values are masked.
**kwargs
Optional keyword arguments to supply the rio.DatasetReader.read method.
Returns
-------
A numpy array of shape (bands, height, width) or a numpy masked_array
"""
with tempfile.TemporaryDirectory() as temp_dir:
filepath = self.last.to_tempfiles(temp_dir)[1]
with rio.open(filepath) as raster:
return raster.read(band, out=out, masked=masked, window=window, **kwargs)
[docs] def write(
self,
filepath: Path | str,
format: str | None = None,
tiled: bool | None = None,
compress: str | None = "deflate",
predictor: Literal[1] | Literal[2] | Literal[3] | None = None,
zlevel: int | str | None = None,
creation_options: dict[str, str | int | bool] | None = None,
progress: bool = False,
callback: Callable[[float, Any, Any], Any] | None = None,
) -> None:
"""
Write the VRaster to a file.
Parameters
----------
filepath
The output filepath to write the file.
format
The output format (e.g. "GTiff"). If not given, the format is inferred from the filename.
tiled
Whether to write the blocks in tiles (True) or in strips (False)
compress
What compression algorithm to use.
predictor
Which compression predictor to use (only valid in some compression schemes).
zlevel
The level of compression to use. For deflate, valid numbers range between 1 and 12
creation_options
Other creation options to provide to GDAL as a {key: value} dictionary
progress
Whether to show a tqdm progress bar. tqdm needs to be installed for this to work.
callback
A callback function for the writer that takes three positional arguments.
The first argument is the progress, ranging from 0-1.
Raises
------
AssertionError
If any requirement for file creation is not filled.
ValueError
If the provided arguments are incompatible.
"""
filepath = Path(filepath)
if not filepath.parent.is_dir():
raise AssertionError("Filepath parent directory does not exist")
if progress and callback is not None:
raise ValueError("'progress' needs to be False if 'callback' is used")
if progress and not _has_tqdm:
raise ValueError("tqdm is required for 'progress=True'. For pip, use 'pip install tqdm'.")
if creation_options is None:
creation_options = {}
lowercase_keys = [key.lower() for key in creation_options]
for key, value in [("COMPRESS", compress), ("TILED", tiled), ("PREDICTOR", predictor), ("ZLEVEL", zlevel)]:
if key.lower() in lowercase_keys or value is None:
continue
creation_options[key] = value
# Always initialize a tqdm context, because it's easier with the context manager..
with tempfile.TemporaryDirectory() as temp_dir, tqdm(
total=100, disable=(not progress), smoothing=0.1, desc=f"Writing {filepath.name}"
) as progress_bar:
_, vrt_path = self.last.to_tempfiles(temp_dir=temp_dir)
if progress:
# This callback function will scale 0-1 to 0-100 and only show integer increments.
prev = 0.0
def callback(value: float, *_: Any) -> None:
nonlocal prev
new_value = value * 100.0
if int(new_value) > int(prev):
progress_bar.update(int(new_value - prev))
prev += new_value
gdal.Translate(
str(filepath),
str(vrt_path),
format=format,
creationOptions=[f"{k}={v}" for k, v in creation_options.items()],
callback=callback,
)
[docs] def add(self, other: int | float | VRaster) -> VRaster:
"""
Perform addition on the VRaster
Parameters
----------
other
A constant value or another VRaster to add.
Returns
-------
A new VRaster
"""
new_vraster = self.copy()
new = new_vraster.last.copy()
if isinstance(other, VRaster):
if (message := self._check_compatibility(other)) is not None:
raise AssertionError(message)
for i, band in enumerate(new.raster_bands):
if misc.nested_getattr(band, ["pixel_function", "name"]) == "scale":
band.sources.append(
SimpleSource(
source_filename=other.last,
source_band=i + 1,
)
)
else:
new_band = VRTDerivedRasterBand.from_raster_band(
band=band, pixel_function=pixel_functions.SumPixelFunction()
)
new_band.sources = [
SimpleSource(
source_filename=new_vraster.last,
source_band=i + 1,
),
SimpleSource(
source_filename=other.last,
source_band=i + 1,
),
]
new.raster_bands[i] = new_band
name = "add_vraster"
else:
for i, band in enumerate(new.raster_bands):
if misc.nested_getattr(band, ["pixel_function", "name"]) == "scale":
if band.offset is not None:
band.offset += other
else:
band.offset = other
else:
new_band = VRTDerivedRasterBand.from_raster_band(band=band, pixel_function=ScalePixelFunction())
new_band.sources = [
SimpleSource(
source_filename=new_vraster.last,
)
]
new_band.offset = other
new.raster_bands[i] = new_band
name = "add_constant"
new_vraster.steps.append(VRasterStep(new, name))
return new_vraster
[docs] def multiply(self, other: int | float | VRaster) -> VRaster:
"""
Perform multiplication on the VRaster
Parameters
----------
other
A constant value or another VRaster to multiply.
Returns
-------
A new VRaster
"""
new_vraster = self.copy()
new = new_vraster.last.copy()
if isinstance(other, VRaster):
if (message := self._check_compatibility(other)) is not None:
raise AssertionError(message)
for i, band in enumerate(new.raster_bands):
if misc.nested_getattr(band, ["pixel_function", "name"]) == "mul":
band.sources.append(
SimpleSource(
source_filename=other.last,
source_band=i + 1,
)
)
else:
new_band = VRTDerivedRasterBand.from_raster_band(
band=band, pixel_function=pixel_functions.MulPixelFunction()
)
new_band.sources = [
SimpleSource(
source_filename=new_vraster.last,
source_band=i + 1,
),
SimpleSource(
source_filename=other.last,
source_band=i + 1,
),
]
new.raster_bands[i] = new_band
# raise NotImplementedError("Not yet implemented for VRaster")
name = "multiply_vraster"
else:
for i, band in enumerate(new.raster_bands):
if misc.nested_getattr(band, ["pixel_function", "name"]) == "scale":
if band.scale is not None:
band.scale *= other
else:
band.scale = other
if band.offset is not None:
band.offset *= other
else:
new_band = VRTDerivedRasterBand.from_raster_band(band=band, pixel_function=ScalePixelFunction())
new_band.sources = [
SimpleSource(
source_filename=new_vraster.last,
)
]
new_band.scale = other
new.raster_bands[i] = new_band
name = "multiply_constant"
new_vraster.steps.append(VRasterStep(new, name))
return new_vraster
[docs] def warp(
self,
reference: VRaster | None = None,
crs: CRS | int | str | None = None,
res: tuple[float, float] | float | None = None,
shape: tuple[int, int] | None = None,
bounds: BoundingBox | list[float] | None = None,
transform: Affine | None = None,
resampling: Resampling | str = "bilinear",
dst_nodata: int | float | None = None,
multithread: bool = False,
) -> VRaster:
"""
Warp the VRaster to new bounds, resolutions and/or coordinate systems.
This wraps the functionality of gdal.Warp
Parameters
----------
reference
Optional: A reference VRaster to get the CRS, transform and shape from.
Note: It silently overrides the `shape`, `crs` and `transform` arguments.
If only parts of the reference parameters should be used, supply them directly instead (e.g. VRaster.crs).
crs
The target coordinate reference system (CRS).
If an integer is given, it's parsed as an EPSG code (e.g. 4326 -> WGS84).
res
The target resolution in georeferenced units. If only one value is given, it is used for both axes.
shape
The target shape of the VRaster in pixels as (height, width).
bounds
The target corner bounds of the VRaster. If a list is given, it's parsed as [xmin, ymin, xmax, ymax]
transform
The target affine transform of the VRaster.
resampling
The target resampling algorithm, e.g. "bilinear" or "cubic_spline".
See rio.warp.Resampling for all available algorithms.
dst_nodata
Destination nodata value to use after warping. Defaults to the source nodata value.
multithread
Use multithreading for the warp operation.
Returns
-------
A new VRaster
"""
new_vraster = self.copy()
warp_kwargs = {
"dst_crs": crs,
"dst_res": res,
"dst_shape": shape,
"dst_bounds": bounds,
"dst_transform": transform,
"dst_nodata": dst_nodata,
"resampling": resampling,
}
for band in self.last.raster_bands:
if band.nodata is not None:
warp_kwargs["src_nodata"] = band.nodata
break
if reference is not None:
for key, value in [
("dst_crs", reference.crs),
("dst_shape", reference.shape),
("dst_transform", reference.transform),
]:
if warp_kwargs[key] is None:
warp_kwargs[key] = value
with tempfile.TemporaryDirectory() as temp_dir:
_, vrt_filepath = new_vraster.last.to_tempfiles(temp_dir)
warped_path = vrt_filepath.with_stem("warped")
vrt_warp(output_filepath=warped_path, input_filepath=vrt_filepath, **warp_kwargs) # type: ignore
warped = load_vrt(warped_path)
new_path = vrt_filepath.with_stem("new")
build_vrt(new_path, warped_path)
new = load_vrt(new_path)
warped.source_dataset = new_vraster.last
for band in new.raster_bands:
band.sources = [SimpleSource(source_filename=warped, source_band=band.band)]
new_vraster.steps.append(VRasterStep(warped, "warp"))
new_vraster.steps.append(VRasterStep(new, "warp_wrapped"))
return new_vraster
[docs] def replace_nodata(self, value: int | float) -> VRaster:
"""
Replace all nodata pixels with the given value.
Parameters
----------
value
The value to replace nodata with
Returns
-------
A new VRaster
"""
# TODO: When no nodata value exists, rio throws an unhelpful error when trying to read. It's not as simple as
# just checking for self.nodata (yet; 2023-04-26), because nodata values may be inherited in many ways.
# Either the self.nodata property should be better, or a custom error handler be made.
new_vraster = self.copy()
new = new_vraster.last.copy()
for i, band in enumerate(new.raster_bands):
new_band = VRTDerivedRasterBand.from_raster_band(
band=band, pixel_function=pixel_functions.ReplaceNodataPixelFunction(value=value)
)
new_band.sources = [SimpleSource(source_filename=new_vraster.last, source_band=i + 1)]
new.raster_bands[i] = new_band
new_vraster.steps.append(VRasterStep(new, "replace_nodata"))
return new_vraster
[docs] def inverse(self) -> VRaster:
"""
Invert the VRaster (1 / x)
Returns
-------
A new VRaster.
"""
new_vraster = self.copy()
new = new_vraster.last.copy()
for i, band in enumerate(new.raster_bands):
new_band = VRTDerivedRasterBand.from_raster_band(
band=band, pixel_function=pixel_functions.InvPixelFunction()
)
new_band.sources = [SimpleSource(source_filename=new_vraster.last, source_band=i + 1)]
new.raster_bands[i] = new_band
new_vraster.steps.append(VRasterStep(new, "inverse"))
return new_vraster
[docs] def divide(self, other: int | float | VRaster) -> VRaster:
"""
Perform division on the VRaster
Parameters
----------
other
A constant value or another VRaster to divide.
Returns
-------
A new VRaster.
"""
if isinstance(other, VRaster):
new_vraster = self.copy()
new = new_vraster.last.copy()
if (message := self._check_compatibility(other)) is not None:
raise AssertionError(message)
for i, band in enumerate(new.raster_bands):
new_band = VRTDerivedRasterBand.from_raster_band(
band=band, pixel_function=pixel_functions.DivPixelFunction()
)
new_band.sources = [
SimpleSource(
source_filename=new_vraster.last,
source_band=i + 1,
),
SimpleSource(
source_filename=other.last.copy(),
source_band=i + 1,
),
]
new.raster_bands[i] = new_band
new_vraster.steps.append(VRasterStep(new, "divide_vraster"))
else:
new_vraster = self.multiply(1 / other)
new_vraster.steps[-1].name = "divide_constant"
return new_vraster
[docs] def subtract(self, other: int | float | VRaster) -> VRaster:
"""
Perform subtraction on the VRaster
Parameters
----------
other
A constant value or another VRaster to subtract.
Returns
-------
A new VRaster
"""
if isinstance(other, VRaster):
negative = other.multiply(-1)
new = self.add(negative)
new.steps[-1].name = "subtract_vraster"
else:
new = self.add(-other)
new.steps[-1].name = "subtract_constant"
return new
@property
def n_bands(self) -> int:
return self.last.n_bands
@property
def crs(self) -> CRS:
return self.last.crs
@property
def transform(self) -> Affine:
return self.last.transform
@property
def bounds(self) -> BoundingBox:
return self.last.bounds
@property
def res(self) -> tuple[float, float]:
return self.last.res
[docs] def copy(self) -> VRaster:
return copy.deepcopy(self)
@property
def shape(self) -> tuple[int, int]:
return self.last.shape
@property
def last(self) -> VRTDataset:
return self.steps[-1].dataset
@property
def nodata(self) -> int | float | None:
"""Get the first nodata value in the raster."""
for band in self.last.raster_bands:
if band.nodata is not None:
return band.nodata
return None
@nodata.setter
def nodata(self, new_nodata: float | int | None) -> None:
"""Set the first nodata value in the raster."""
for band in self.last.raster_bands:
band.nodata = new_nodata
@overload
def sample(
self, x_coord: Iterable[float], y_coord: Iterable[float], band: int | list[int], masked: Literal[False]
) -> npt.NDArray[Any]:
...
@overload
def sample(
self, x_coord: Iterable[float], y_coord: Iterable[float], band: int | list[int], masked: Literal[True]
) -> np.ma.MaskedArray[Any, Any]:
...
@overload
def sample(
self, x_coord: float, y_coord: float, band: list[int], masked: Literal[True]
) -> np.ma.MaskedArray[Any, Any]:
...
@overload
def sample(self, x_coord: float, y_coord: float, band: list[int], masked: Literal[False]) -> npt.NDArray[Any]:
...
@overload
def sample(self, x_coord: float, y_coord: float, band: int, masked: bool) -> int | float:
...
@overload
def sample(
self,
x_coord: float | Iterable[float],
y_coord: float | Iterable[float],
band: int | list[int] = 1,
masked: bool = False,
) -> int | float | npt.NDArray[Any] | np.ma.MaskedArray[Any, Any]:
...
[docs] def sample(
self,
x_coord: float | Iterable[float],
y_coord: float | Iterable[float],
band: int | list[int] = 1,
masked: bool = False,
) -> int | float | npt.NDArray[Any] | np.ma.MaskedArray[Any, Any]:
"""
Sample values at the given georeferenced coordinates of a VRaster.
Parameters
----------
x_coord
The x (easting/longitude) coordinate(s) to sample
y_coord
The x (northing/latitude) coordinate(s) to sample
band
The band(s) to sample from. Defaults to the first.
masked
Return a masked array with nodata values masked out.
Returns
-------
If one coordinate and one band:
One sampled value
If multiple coordinates and/or multiple bands:
An array of coordinates
"""
if self.last.is_nested():
with tempfile.TemporaryDirectory(prefix="variete") as temp_dir:
return load_vrt(self.last.to_tempfiles(temp_dir=temp_dir)[1]).sample(
x_coord=x_coord, y_coord=y_coord, band=band, masked=masked
)
return self.steps[-1].dataset.sample(x_coord=x_coord, y_coord=y_coord, band=band, masked=masked)
@overload
def sample_rowcol(self, row: float, col: float, band: int, masked: bool) -> int | float:
...
@overload
def sample_rowcol(
self, row: float, col: float, band: list[int], masked: Literal[True]
) -> np.ma.MaskedArray[Any, Any]:
...
@overload
def sample_rowcol(self, row: float, col: float, band: list[int], masked: Literal[False]) -> npt.NDArray[Any]:
...
@overload
def sample_rowcol(
self, row: Iterable[float], col: Iterable[float], band: int | list[int], masked: Literal[True]
) -> np.ma.MaskedArray[Any, Any]:
...
@overload
def sample_rowcol(
self, row: Iterable[float], col: Iterable[float], band: int | list[int], masked: Literal[False]
) -> npt.NDArray[Any]:
...
@overload
def sample_rowcol(
self,
row: float | Iterable[float],
col: float | Iterable[float],
band: int | list[int] = 1,
masked: bool = False,
) -> int | float | npt.NDArray[Any] | np.ma.MaskedArray[Any, Any]:
...
[docs] def sample_rowcol(
self,
row: float | Iterable[float],
col: float | Iterable[float],
band: int | list[int] = 1,
masked: bool = False,
) -> int | float | npt.NDArray[Any] | np.ma.MaskedArray[Any, Any]:
"""
Sample values at the given row(s) and column(s) of a VRaster.
Parameters
----------
row
The row(s) to sample.
y_coord
The column(s) to sample.
band
The band(s) to sample from. Defaults to the first.
masked
Return a masked array with nodata values masked out.
Returns
-------
If one coordinate and one band:
One sampled value
If multiple coordinates and/or multiple bands:
An array of coordinates
"""
x_coord, y_coord = rio.transform.xy(self.transform, row, col)
return self.sample(x_coord, y_coord, band=band, masked=masked) # type: ignore
def __div__(self, other: int | float | VRaster) -> VRaster:
return self.divide(other)
def __rdiv__(self, other: int | float | VRaster) -> VRaster:
return self.inverse().__rmul__(other)
def __add__(self, other: int | float | VRaster) -> VRaster:
return self.add(other)
def __radd__(self, other: int | float | VRaster) -> VRaster:
return self.__add__(other)
def __sub__(self, other: int | float | VRaster) -> VRaster:
return self.subtract(other)
def __neg__(self) -> VRaster:
return self.multiply(-1)
def __rsub__(self, other: int | float | VRaster) -> VRaster:
return self.__neg__().__add__(other)
def __mul__(self, other: int | float | VRaster) -> VRaster:
return self.multiply(other)
def __rmul__(self, other: int | float | VRaster) -> VRaster:
return self.__mul__(other)
def __truediv__(self, other: int | float | VRaster) -> VRaster:
return self.__div__(other)
def __rtruediv__(self, other: int | float | VRaster) -> VRaster:
return self.__rdiv__(other)
[docs]def load(filepath: str | Path, nodata_to_nan: bool = True) -> VRaster:
"""
Load a VRaster from a file.
Parameters
----------
filepath
The path to a GDAL-readable dataset.
nodata_to_nan
Whether to convert nodata values to np.nan on load
Returns
-------
A new VRaster
"""
vraster = VRaster.load_file(filepath)
if nodata_to_nan:
replace = False
for band in vraster.last.raster_bands:
if band.nodata is not None:
replace = True
if replace:
vraster = vraster.replace_nodata(np.nan)
return vraster