# Source code for variete.vrt.vrt

from __future__ import annotations

import copy
import hashlib
import tempfile
import warnings
from pathlib import Path
from tempfile import TemporaryDirectory
from typing import Any, Callable, Iterable, Literal, Sequence

import lxml.etree as ET
import numpy as np
import numpy.typing as npt
import rasterio as rio
from affine import Affine
from osgeo import gdal
from rasterio import CRS
from rasterio.coords import BoundingBox
from rasterio.warp import Resampling

from variete import misc
from variete.vrt.raster_bands import AnyRasterBand, raster_band_from_etree


def build_vrt(
    output_filepath: Path | str,
    filepaths: Path | str | list[Path | str],
    calculate_resolution: Literal["highest"] | Literal["lowest"] | Literal["average"] | Literal["user"] = "average",
    res: tuple[float, float] | None = None,
    separate: bool = False,
    output_bounds: BoundingBox | None = None,
    resample_algorithm: Resampling = Resampling.bilinear,
    target_aligned_pixels: bool = False,
    band_list: list[int] | None = None,
    add_alpha: bool = False,
    output_crs: CRS | int | str | None = None,
    allow_projection_difference: bool = False,
    src_nodata: int | float | None = None,
    vrt_nodata: int | float | None = None,
    strict: bool = True,
) -> None:
    """
    Build a VRT file from one or more rasters by calling ``gdal.BuildVRT``.

    :param output_filepath: Destination of the VRT; converted to ``str`` for GDAL.
    :param filepaths: A single path or a list of paths to the source rasters.
    :param calculate_resolution: Forwarded to GDAL as ``resolution``.
    :param res: Optional (x, y) resolution, forwarded as ``xRes``/``yRes``.
    :param separate: Forwarded to GDAL as ``separate``.
    :param output_bounds: Optional bounds, forwarded as a list in ``outputBounds``.
    :param resample_algorithm: Forwarded unchanged to GDAL as ``resampleAlg``.
    :param target_aligned_pixels: Forwarded as ``targetAlignedPixels``; requires ``res``.
    :param band_list: Forwarded to GDAL as ``bandList``.
    :param add_alpha: Forwarded to GDAL as ``addAlpha``.
    :param output_crs: Optional CRS. Integers are interpreted as EPSG codes and ``CRS`` objects are
        converted to WKT; anything else is passed through ``str``.
    :param allow_projection_difference: Forwarded as ``allowProjectionDifference``.
    :param src_nodata: Forwarded to GDAL as ``srcNodata``.
    :param vrt_nodata: Forwarded to GDAL as ``VRTNodata``.
    :param strict: Forwarded to GDAL as ``strict``.
    :raises ValueError: If ``target_aligned_pixels`` is set while ``res`` is None.
    """
    if res is None and target_aligned_pixels:
        raise ValueError(f"{target_aligned_pixels=} requires that 'res' is specified")

    # A lone path is wrapped so that GDAL always receives a list of sources.
    sources = [filepaths] if isinstance(filepaths, (str, Path)) else filepaths

    x_res: float | None
    y_res: float | None
    x_res, y_res = (res[0], res[1]) if res is not None else (None, None)

    # Normalise the CRS into something GDAL accepts as ``outputSRS``.
    if output_crs is None:
        output_srs = None
    elif isinstance(output_crs, int):
        output_srs = CRS.from_epsg(output_crs).to_wkt()
    elif isinstance(output_crs, CRS):
        output_srs = output_crs.to_wkt()
    else:
        output_srs = str(output_crs)

    bounds = None if output_bounds is None else list(output_bounds)

    gdal.BuildVRT(
        str(output_filepath),
        [str(source) for source in sources],
        resolution=calculate_resolution,
        xRes=x_res,
        yRes=y_res,
        separate=separate,
        outputBounds=bounds,
        resampleAlg=resample_algorithm,
        targetAlignedPixels=target_aligned_pixels,
        bandList=band_list,
        addAlpha=add_alpha,
        outputSRS=output_srs,
        allowProjectionDifference=allow_projection_difference,
        srcNodata=src_nodata,
        VRTNodata=vrt_nodata,
        strict=strict,
    )
def vrt_warp(
    output_filepath: Path | str,
    input_filepath: Path | str,
    dst_crs: CRS | int | str | None = None,
    dst_res: tuple[float, float] | float | None = None,
    dst_shape: tuple[int, int] | None = None,
    dst_bounds: BoundingBox | list[float] | None = None,
    dst_transform: Affine | None = None,
    src_nodata: int | float | None = None,
    dst_nodata: int | float | None = None,
    resampling: Resampling | str = "bilinear",
    multithread: bool = False,
) -> None:
    """
    Write a warped VRT of ``input_filepath`` to ``output_filepath`` using ``gdal.Warp``.

    :param output_filepath: Destination of the VRT; converted to ``str`` for GDAL.
    :param input_filepath: Raster to warp; converted to ``str`` for GDAL.
    :param dst_crs: Optional target CRS. Integers are interpreted as EPSG codes and ``CRS`` objects
        are converted to WKT; other values are forwarded unchanged as ``dstSRS``.
    :param dst_res: Target resolution; either one value for both axes or an (x, y) sequence.
    :param dst_shape: Target (height, width). Cannot be combined with ``dst_res``.
    :param dst_bounds: Target bounds, forwarded to GDAL as ``outputBounds``.
    :param dst_transform: Target transform. Requires ``dst_shape`` and cannot be combined with
        ``dst_res`` or ``dst_bounds``; it is converted to bounds for GDAL.
    :param src_nodata: Forwarded to GDAL as ``srcNodata``.
    :param dst_nodata: Forwarded to GDAL as ``dstNodata``; 0 is used when None.
    :param resampling: A ``Resampling`` member or the name of one.
    :param multithread: Forwarded to GDAL as ``multithread``.
    :raises ValueError: If an invalid combination of ``dst_*`` arguments is given.
    """
    # Reject invalid argument combinations up-front, before any conversion work is done.
    if dst_transform is not None:
        if dst_shape is None:
            raise ValueError("dst_transform requires dst_shape, which was not supplied.")
        if dst_res is not None:
            raise ValueError("dst_transform and dst_res cannot be used at the same time.")
        if dst_bounds is not None:
            raise ValueError("dst_transform and dst_bounds cannot be used at the same time.")
    if dst_shape is not None and dst_res is not None:
        raise ValueError("dst_shape and dst_res cannot be used at the same time.")

    if isinstance(resampling, str):
        resampling = getattr(Resampling, resampling)

    kwargs: dict[str, Any] = {
        "resampleAlg": misc.resampling_rio_to_gdal(resampling),
        "multithread": multithread,
        "format": "VRT",
        # This is strange. Warped pixels that are outside the range of the original raster get assigned to 0
        # Unclear if this can be overridden somehow! It should be dst_nodata or np.nan
        "dstNodata": 0 if dst_nodata is None else dst_nodata,
        "srcNodata": src_nodata,
    }

    # dst_crs is optional here, so it is only forwarded when given.
    if dst_crs is not None:
        if isinstance(dst_crs, int):
            kwargs["dstSRS"] = CRS.from_epsg(dst_crs).to_wkt()
        elif isinstance(dst_crs, CRS):
            kwargs["dstSRS"] = dst_crs.to_wkt()
        else:
            kwargs["dstSRS"] = dst_crs

    if dst_transform is not None:
        # The transform is expressed to GDAL through the bounds it implies for the given shape.
        kwargs["outputBounds"] = list(rio.transform.array_bounds(*dst_shape, dst_transform))
    elif dst_bounds is not None:
        # Previously accepted but never forwarded, which silently ignored the requested bounds.
        kwargs["outputBounds"] = list(dst_bounds)

    if dst_shape is not None:
        kwargs["width"] = dst_shape[1]
        kwargs["height"] = dst_shape[0]

    if dst_res is not None:
        if isinstance(dst_res, Sequence):
            kwargs["xRes"] = dst_res[0]  # type: ignore
            kwargs["yRes"] = dst_res[1]  # type: ignore
        else:
            kwargs["xRes"] = dst_res
            kwargs["yRes"] = dst_res

    gdal.Warp(str(output_filepath), str(input_filepath), **kwargs)
def build_warped_vrt(
    vrt_filepath: Path | str,
    filepath: Path | str,
    dst_crs: CRS | int | str,
    resample_algorithm: Resampling = Resampling.bilinear,
    max_error: float = 0.125,
    src_crs: CRS | int | str | None = None,
) -> None:
    """
    Create a warped VRT of ``filepath`` with ``gdal.AutoCreateWarpedVRT`` and save it to ``vrt_filepath``.

    :param vrt_filepath: Destination of the warped VRT; converted to ``str`` for GDAL.
    :param filepath: Raster to warp; converted to ``str`` and opened with ``gdal.Open``.
    :param dst_crs: Target CRS. Integers are interpreted as EPSG codes and ``CRS`` objects are
        converted to WKT; other values are forwarded unchanged.
    :param resample_algorithm: Forwarded unchanged to ``gdal.AutoCreateWarpedVRT``.
    :param max_error: Forwarded unchanged to ``gdal.AutoCreateWarpedVRT``.
    :param src_crs: Optional source CRS, converted the same way as ``dst_crs``.
    :raises TypeError: If ``dst_crs`` is None.
    :raises RuntimeError: If GDAL returns no dataset when opening ``filepath`` or creating the VRT.
    """
    if dst_crs is None:
        raise TypeError("dst_crs has to be provided")

    # The destination is converted first, matching the order in which conversion errors used to surface.
    wkts: dict[str, Any] = {}
    for key, crs in (("dst_wkt", dst_crs), ("src_wkt", src_crs)):
        if isinstance(crs, int):
            wkts[key] = CRS.from_epsg(crs).to_wkt()
        elif isinstance(crs, CRS):
            wkts[key] = crs.to_wkt()
        else:
            # Covers both an omitted src_crs (None) and values that are passed through as given.
            wkts[key] = crs

    dataset = gdal.Open(str(filepath))
    # Without GDAL exceptions enabled, failures are reported by returning None instead of raising.
    if dataset is None:
        raise RuntimeError(f"GDAL could not open '{filepath}'")

    vrt_dataset = gdal.AutoCreateWarpedVRT(
        dataset, wkts["src_wkt"], wkts["dst_wkt"], resample_algorithm, max_error
    )
    if vrt_dataset is None:
        del dataset
        raise RuntimeError(f"GDAL could not create a warped VRT from '{filepath}'")

    vrt_dataset.GetDriver().CreateCopy(str(vrt_filepath), vrt_dataset)

    # Drop the GDAL handles explicitly, as the original implementation did.
    del dataset
    del vrt_dataset
class VRTDataset:
    """
    In-memory description of a GDAL VRT dataset.

    Instances can be parsed from and serialised to VRT XML, saved to disk (including VRTs whose
    sources are themselves ``VRTDataset`` objects, called "nested" below) and opened with rasterio.
    """

    shape: tuple[int, int]
    crs: CRS
    crs_mapping: str
    transform: Affine
    raster_bands: list[AnyRasterBand]
    subclass: str | None
    # block_size: tuple[int, int] | None

    def __init__(
        self,
        shape: tuple[int, int],
        crs: CRS,
        transform: Affine,
        raster_bands: list[AnyRasterBand],
        crs_mapping: str = "2,1",
    ) -> None:
        """
        Store the dataset description.

        :param shape: (height, width) of the dataset.
        :param crs: Coordinate reference system of the dataset.
        :param transform: Affine transform of the dataset.
        :param raster_bands: Bands of the dataset.
        :param crs_mapping: Value of the ``dataAxisToSRSAxisMapping`` attribute of the SRS element.
        """
        # The assignment order is deliberate: sha1() hashes str(self.__dict__), which is order-sensitive.
        self.shape = shape
        self.crs = crs
        self.crs_mapping = crs_mapping
        self.transform = transform
        self.raster_bands = raster_bands
        self.subclass = None
        self.warp_options = None

    @property
    def n_bands(self) -> int:
        """Return the number of raster bands."""
        return len(self.raster_bands)

    @property
    def bounds(self) -> rio.coords.BoundingBox:
        """Return the bounding box implied by the shape and transform."""
        return rio.coords.BoundingBox(*rio.transform.array_bounds(*self.shape, self.transform))

    @property
    def res(self) -> tuple[float, float]:
        """
        Return the X/Y resolution of the dataset.
        """
        return self.transform.a, -self.transform.e

    def to_etree(self) -> ET.Element:
        """Serialise the dataset to a ``VRTDataset`` XML element."""
        vrt = ET.Element("VRTDataset", {"rasterXSize": str(self.shape[1]), "rasterYSize": str(self.shape[0])})

        crs = ET.SubElement(vrt, "SRS", {"dataAxisToSRSAxisMapping": self.crs_mapping})
        crs.text = misc.crs_to_string(self.crs)

        transform = ET.SubElement(vrt, "GeoTransform")
        transform.text = misc.transform_to_gdal(self.transform)

        for band in self.raster_bands:
            vrt.append(band.to_etree())

        return vrt

    def to_xml(self) -> str:
        """Return the dataset as an indented XML string."""
        vrt = self.to_etree()
        ET.indent(vrt)
        return ET.tostring(vrt).decode()

    @classmethod
    def from_etree(cls, root: ET.Element) -> VRTDataset:
        """
        Parse a dataset from a ``VRTDataset`` XML element.

        A missing raster size attribute is read as 0 and a missing ``dataAxisToSRSAxisMapping``
        attribute as "2,1".
        """
        x_size, y_size = (int(root.get(f"raster{k}Size", 0)) for k in ["X", "Y"])

        srs_elem, srs_text = misc.find_element(root, "SRS", "both")
        crs = CRS.from_string(srs_text)
        crs_mapping = srs_elem.get("dataAxisToSRSAxisMapping", "2,1")

        transform = misc.parse_gdal_transform(misc.find_element(root, "GeoTransform", True))

        raster_bands = [raster_band_from_etree(band) for band in root.findall("VRTRasterBand")]

        return cls(
            shape=(y_size, x_size), crs=crs, transform=transform, raster_bands=raster_bands, crs_mapping=crs_mapping
        )

    def copy(self) -> VRTDataset:
        """Return a deep copy of the dataset."""
        return copy.deepcopy(self)

    @classmethod
    def from_xml(cls, xml: str) -> VRTDataset:
        """Parse a dataset from a VRT XML string."""
        vrt = ET.fromstring(xml)
        return cls.from_etree(vrt)

    @classmethod
    def load_vrt(cls, filepath: Path | str) -> VRTDataset:
        """Read and parse the VRT file at ``filepath``."""
        with open(filepath) as infile:
            return cls.from_xml(infile.read())

    def save_vrt(self, filepath: str | Path) -> None:
        """Write the XML representation of the dataset to ``filepath``."""
        with open(filepath, "w") as outfile:
            outfile.write(self.to_xml())

    def _save_vrt_nested(self, filepath: Path, nested_level: list[int]) -> list[Path]:
        """
        Save this dataset and, recursively, every nested dataset it uses as a source.

        :param filepath: Path of the outermost VRT; nested files derive their names from it.
        :param nested_level: Indices identifying this dataset within the nesting hierarchy. Empty
            for the outermost dataset. The list is not modified.
        :returns: The written paths, starting with the path of this dataset.
        """
        if nested_level:
            save_filepath = filepath.with_stem(filepath.stem + "-nested-" + "-".join(map(str, nested_level)))
        else:
            save_filepath = filepath

        filepaths = [save_filepath]
        nested_index = 1
        # Work on a copy so that the nested sources of this instance are not replaced by file paths.
        vrt = self.copy()
        for raster_band in vrt.raster_bands:
            for source in raster_band.sources:
                if not hasattr(source.source_filename, "_save_vrt_nested"):
                    continue
                # A new list is built for the child instead of mutating the caller's list in place.
                child_filepaths = source.source_filename._save_vrt_nested(filepath, [*nested_level, nested_index])
                source.source_filename = child_filepaths[0]
                source.relative_filename = False

                filepaths += child_filepaths
                nested_index += 1

        vrt.save_vrt(save_filepath)
        return filepaths

    def save_vrt_nested(self, filepath: Path | str) -> list[Path]:
        """
        Save the dataset and all nested datasets, returning every written path.

        Duplicates are removed while keeping the order in which files were collected, so the first
        entry is the outermost VRT.
        """
        written = self._save_vrt_nested(filepath=Path(filepath).absolute(), nested_level=[])
        return list(dict.fromkeys(written))

    @classmethod
    def from_file(cls, filepaths: Path | str | list[Path | str], **kwargs: dict[str, Any]) -> VRTDataset:
        """
        Build a VRT of ``filepaths`` in a temporary directory and load it.

        :param filepaths: One or more raster paths, forwarded to ``build_vrt``.
        :param kwargs: Extra keyword arguments forwarded to ``build_vrt``.
        """
        with tempfile.TemporaryDirectory() as temp_dir:
            temp_vrt = Path(temp_dir).joinpath("temp.vrt")
            build_vrt(output_filepath=temp_vrt, filepaths=filepaths, **kwargs)  # type: ignore
            return cls.load_vrt(temp_vrt)

    def sha1(self) -> str:
        """Return the SHA-1 hex digest of the string form of the instance attributes."""
        return hashlib.sha1(str(self.__dict__).encode()).hexdigest()

    def is_nested(self) -> bool:
        """Return True if any band source refers to an object that has a ``to_tempfiles`` attribute."""
        return any(
            hasattr(source.source_filename, "to_tempfiles")
            for raster_band in self.raster_bands
            for source in raster_band.sources
        )

    def to_tempfiles(
        self, temp_dir: TemporaryDirectory[str] | str | Path | None = None
    ) -> tuple[TemporaryDirectory[str] | str | Path, Path]:
        """
        Save the dataset (and nested datasets) as ``vrtdataset.vrt`` inside a directory.

        :param temp_dir: Directory to save in. A new ``TemporaryDirectory`` is created when None.
        :returns: The directory that was used and the path of the outermost VRT.
        """
        if temp_dir is None:
            temp_dir = TemporaryDirectory(prefix="variete")

        if isinstance(temp_dir, TemporaryDirectory):
            temp_dir_path = Path(temp_dir.name)
        else:
            temp_dir_path = Path(temp_dir)

        filepath = temp_dir_path.joinpath("vrtdataset.vrt")
        self.save_vrt_nested(filepath)
        return temp_dir, filepath

    def to_memfile(self) -> rio.MemoryFile:
        """
        Return the XML of the dataset wrapped in a rasterio ``MemoryFile``.

        :raises ValueError: If the dataset is nested.
        """
        if self.is_nested():
            raise ValueError("Nested VRTs require temporary saving to work (see to_memfile_nested)")
        return rio.MemoryFile(self.to_xml().encode(), ext=".vrt")

    def to_memfile_nested(
        self, temp_dir: TemporaryDirectory[str] | str | Path | None = None
    ) -> tuple[TemporaryDirectory[str] | Path | str | None, rio.MemoryFile]:
        """
        Return a ``MemoryFile`` for the dataset, saving nested datasets to disk when needed.

        :param temp_dir: Directory for the nested files. A new ``TemporaryDirectory`` is created
            when None and the dataset is nested.
        :returns: The directory (unchanged when the dataset is not nested) and the memory file.
        """
        if not self.is_nested():
            return (temp_dir, self.to_memfile())

        if temp_dir is None:
            temp_dir = TemporaryDirectory(prefix="variete")

        _, filepath = self.to_tempfiles(temp_dir=temp_dir)
        with open(filepath, "rb") as infile:
            # Same extension as to_memfile() so that both code paths create the same kind of file.
            return (temp_dir, rio.MemoryFile(infile.read(), ext=".vrt"))

    @property
    def open_rio(self) -> Callable[..., rio.DatasetReader]:
        """
        Return the ``open`` method of a new ``MemoryFile`` holding this dataset.

        :raises ValueError: If the dataset is nested.
        """
        if self.is_nested():
            raise ValueError("Nested VRTs require temporary saving to work (see open_rio_nested)")
        return self.to_memfile().open

    def open_rio_nested(
        self, temp_dir: TemporaryDirectory[str] | str | Path | None = None
    ) -> tuple[TemporaryDirectory[str] | str | Path | None, Callable[..., rio.DatasetReader]]:
        """
        Variant of ``open_rio`` that also works for nested datasets.

        :param temp_dir: Directory for the nested files. A new ``TemporaryDirectory`` is created
            when None and the dataset is nested.
        :returns: The directory (unchanged when the dataset is not nested) and the open callable.
        """
        if not self.is_nested():
            return (temp_dir, self.open_rio)

        if temp_dir is None:
            temp_dir = TemporaryDirectory(prefix="variete")

        return (temp_dir, self.to_memfile_nested(temp_dir=temp_dir)[1].open)

    def sample(
        self,
        x_coord: float | Iterable[float],
        y_coord: float | Iterable[float],
        band: int | list[int] = 1,
        masked: bool = False,
    ) -> int | float | npt.NDArray[Any]:
        """
        Sample the dataset at one or more coordinates.

        :param x_coord: X coordinate(s). Python and numpy scalars are treated as a single point.
        :param y_coord: Y coordinate(s), handled like ``x_coord``.
        :param band: Forwarded to rasterio's ``sample`` as ``indexes``.
        :param masked: Forwarded to rasterio's ``sample``.
        :returns: A scalar when exactly one value was sampled, otherwise an array.
        :raises ValueError: If the dataset is nested.
        """
        # Integers and numpy scalars count as single coordinates too, not only Python floats.
        scalar_types = (int, float, np.number)
        x_coords: Iterable[float] = [x_coord] if isinstance(x_coord, scalar_types) else x_coord  # type: ignore
        y_coords: Iterable[float] = [y_coord] if isinstance(y_coord, scalar_types) else y_coord  # type: ignore

        with self.open_rio() as raster:
            with warnings.catch_warnings():
                warnings.filterwarnings(
                    "ignore", r".*Conversion of an array with ndim > 0 to a scalar is deprecated.*"
                )  # Should be removed when rasterio fixes this (2023-08-02)
                values = np.fromiter(
                    raster.sample(zip(x_coords, y_coords), indexes=band, masked=masked),
                    dtype=self.raster_bands[0].dtype,
                    count=-1 if not hasattr(x_coords, "__len__") else len(x_coords),  # type: ignore
                ).ravel()

        # Only a single sampled value is unwrapped; empty results are returned as an empty array.
        if values.size == 1:
            return values[0]
        return values

    def __repr__(self) -> str:
        return "\n".join(
            [f"VRTDataset: shape={self.shape}, crs=EPSG:{self.crs.to_epsg()}, bounds: {self.bounds}"]
            + ["\t" + "\n\t".join(band.__repr__().splitlines()) for band in self.raster_bands]
        )
class WarpedVRTDataset(VRTDataset):
    """A VRTDataset that specifies a GDAL warp operation."""

    shape: tuple[int, int]
    crs: CRS
    crs_mapping: str
    transform: Affine
    block_size: tuple[int, int]
    raster_bands: list[AnyRasterBand]
    warp_memory_limit: float
    resample_algorithm: Resampling
    dst_dtype: str
    options: dict[str, str]
    source_dataset: str | Path
    relative_filename: bool | None
    band_mapping: list[tuple[int, int]]
    max_error: float
    approximate: bool
    src_transform: Affine
    src_inv_transform: Affine
    dst_transform: Affine
    dst_inv_transform: Affine

    # Attribute name -> XML tag of the four geotransforms stored in the GenImgProjTransformer element.
    _TRANSFORM_TAGS = (
        ("src_transform", "SrcGeoTransform"),
        ("src_inv_transform", "SrcInvGeoTransform"),
        ("dst_transform", "DstGeoTransform"),
        ("dst_inv_transform", "DstInvGeoTransform"),
    )

    def __init__(
        self,
        shape: tuple[int, int],
        crs: CRS,
        transform: Affine,
        raster_bands: list[AnyRasterBand],
        resample_algorithm: Resampling,
        block_size: tuple[int, int],
        dst_dtype: str,
        options: dict[str, str],
        source_dataset: str | Path,
        band_mapping: list[tuple[int, int]],
        src_transform: Affine,
        src_inv_transform: Affine,
        dst_transform: Affine,
        dst_inv_transform: Affine,
        crs_mapping: str = "2,1",
        relative_filename: bool | None = None,
        max_error: float = 0.125,
        approximate: bool = True,
        warp_memory_limit: float = 6.71089e07,
    ):
        """
        Store the description of the warped dataset.

        When ``relative_filename`` is None it is derived from ``source_dataset``: a ``Path`` counts
        as relative unless it is absolute, and anything else counts as relative. A ``crs_mapping``
        of None is replaced by "2,1". All other arguments are stored unchanged.
        """
        if relative_filename is not None:
            self.relative_filename = relative_filename
        elif isinstance(source_dataset, Path):
            self.relative_filename = not source_dataset.is_absolute()
        else:
            self.relative_filename = True

        # Attributes are assigned in a fixed order because sha1() hashes str(self.__dict__).
        self.shape = shape
        self.crs = crs
        self.transform = transform
        self.raster_bands = raster_bands
        self.resample_algorithm = resample_algorithm
        self.block_size = block_size
        self.dst_dtype = dst_dtype
        self.options = options
        self.source_dataset = source_dataset
        self.band_mapping = band_mapping
        self.src_transform = src_transform
        self.src_inv_transform = src_inv_transform
        self.dst_transform = dst_transform
        self.dst_inv_transform = dst_inv_transform
        self.crs_mapping = "2,1" if crs_mapping is None else crs_mapping
        self.warp_memory_limit = warp_memory_limit
        self.max_error = max_error
        self.approximate = approximate

    @classmethod
    def from_etree(cls, root: ET.Element) -> WarpedVRTDataset:
        """
        Parse a warped dataset from a ``VRTDataset`` XML element.

        The common parts are parsed with ``VRTDataset.from_etree``; the warp settings are read from
        the ``GDALWarpOptions`` element. ``approximate`` is always set to True.

        :raises ValueError: If a ``BandMapping`` element lacks its ``src`` or ``dst`` attribute.
        """
        base = VRTDataset.from_etree(root)

        block_x = int(misc.find_element(root, "BlockXSize", True, "1"))
        block_y = int(misc.find_element(root, "BlockYSize", True, "1"))

        warp_options = misc.find_element(root, "GDALWarpOptions", False, None)

        resample_text = misc.find_element(warp_options, "ResampleAlg", True, "bilinear")
        resample_algorithm = misc.resampling_gdal_to_rio(resample_text)
        dtype_text = misc.find_element(warp_options, "WorkingDataType", True, "float32")
        dst_dtype = misc.dtype_gdal_to_numpy(dtype_text)
        warp_memory_limit = float(misc.find_element(warp_options, "WarpMemoryLimit", True, "0"))

        source_elem, source_text = misc.find_element(warp_options, "SourceDataset", text="both")
        # Paths starting with "/vsi" are kept as strings; everything else becomes a Path.
        source_dataset: str | Path = source_text if source_text.startswith("/vsi") else Path(source_text)
        relative_filename = bool(int(source_elem.get("relativeToVRT", 0)))

        # Options lacking either a name or a text value are skipped.
        options = {}
        for option in warp_options.findall("Option"):
            option_name = option.get("name")
            if option_name is None or option.text is None:
                continue
            options[option_name] = option.text

        approx_transformer = misc.find_element(warp_options, ["Transformer", "ApproxTransformer"])
        max_error = float(getattr(approx_transformer.find("MaxError"), "text", 0.125))

        proj_transformer = misc.find_element(approx_transformer, ["BaseTransformer", "GenImgProjTransformer"])
        transforms = {
            attr: misc.parse_gdal_transform(misc.find_element(proj_transformer, tag, text=True))
            for attr, tag in cls._TRANSFORM_TAGS
        }

        band_mapping = []
        for mapping in misc.find_element(warp_options, "BandList").findall("BandMapping"):
            src_band = mapping.get("src")
            if src_band is None:
                raise ValueError("Invalid src in BandMapping")
            dst_band = mapping.get("dst")
            if dst_band is None:
                raise ValueError("Invalid dst in BandMapping")
            band_mapping.append((int(src_band), int(dst_band)))

        return cls(
            shape=base.shape,
            crs=base.crs,
            transform=base.transform,
            raster_bands=base.raster_bands,
            crs_mapping=base.crs_mapping,
            block_size=(block_x, block_y),
            resample_algorithm=resample_algorithm,
            approximate=True,
            warp_memory_limit=warp_memory_limit,
            dst_dtype=dst_dtype,
            relative_filename=relative_filename,
            source_dataset=source_dataset,
            max_error=max_error,
            options=options,
            band_mapping=band_mapping,
            **transforms,
        )

    def to_etree(self) -> ET.Element:
        """Serialise the dataset to a ``VRTDataset`` XML element with ``subClass="VRTWarpedDataset"``."""
        root = ET.Element(
            "VRTDataset",
            {"rasterXSize": str(self.shape[1]), "rasterYSize": str(self.shape[0]), "subClass": "VRTWarpedDataset"},
        )

        srs_elem = ET.SubElement(root, "SRS", {"dataAxisToSRSAxisMapping": self.crs_mapping})
        srs_elem.text = misc.crs_to_string(self.crs)

        geotransform_elem = ET.SubElement(root, "GeoTransform")
        geotransform_elem.text = misc.transform_to_gdal(self.transform)

        for raster_band in self.raster_bands:
            root.append(raster_band.to_etree())

        for axis, dim in enumerate("XY"):
            block_elem = ET.SubElement(root, f"Block{dim}Size")
            block_elem.text = str(self.block_size[axis])

        warp_elem = ET.SubElement(root, "GDALWarpOptions")
        warp_elem.append(misc.new_element("WarpMemoryLimit", str(self.warp_memory_limit)))
        warp_elem.append(misc.new_element("ResampleAlg", misc.resampling_rio_to_gdal(self.resample_algorithm)))
        warp_elem.append(misc.new_element("WorkingDataType", misc.dtype_numpy_to_gdal(self.dst_dtype)))

        for option_name, option_value in self.options.items():
            warp_elem.append(misc.new_element("Option", option_value, {"name": option_name}))

        relative_flag = str(int(self.relative_filename or 0))
        warp_elem.append(
            misc.new_element("SourceDataset", str(self.source_dataset), {"relativeToVRT": relative_flag})
        )

        outer_transformer = ET.SubElement(warp_elem, "Transformer")
        approx_transformer = ET.SubElement(outer_transformer, "ApproxTransformer")
        approx_transformer.append(misc.new_element("MaxError", str(self.max_error)))

        base_transformer = ET.SubElement(approx_transformer, "BaseTransformer")
        proj_transformer = ET.SubElement(base_transformer, "GenImgProjTransformer")
        for attr, tag in self._TRANSFORM_TAGS:
            # Spaces are stripped from the serialised transform before it is written.
            transform_text = misc.transform_to_gdal(getattr(self, attr)).replace(" ", "")
            proj_transformer.append(misc.new_element(tag, transform_text))

        band_list_elem = ET.SubElement(warp_elem, "BandList")
        for src_band, dst_band in self.band_mapping:
            band_list_elem.append(
                misc.new_element("BandMapping", None, {"src": str(src_band), "dst": str(dst_band)})
            )

        return root

    def is_nested(self) -> bool:
        """Return True if the source dataset is an object that has a ``to_tempfiles`` attribute."""
        return hasattr(self.source_dataset, "to_tempfiles")

    def _save_vrt_nested(self, filepath: Path, nested_level: list[int]) -> list[Path]:
        """
        Save this dataset and, if the source dataset is nested, that dataset too.

        :param filepath: Path of the outermost VRT; nested files derive their names from it.
        :param nested_level: Indices identifying this dataset within the nesting hierarchy. Empty
            for the outermost dataset. A 0 is appended to the list in place.
        :returns: The written paths, starting with the path of this dataset.
        """
        if nested_level:
            save_filepath = filepath.with_stem(filepath.stem + "-nested-" + "-".join(map(str, nested_level)))
        else:
            save_filepath = filepath

        nested_level.append(0)
        written = [save_filepath]

        # Work on a copy so that the nested source of this instance is not replaced by a file path.
        warped = self.copy()
        if warped.is_nested():
            child_level = [*nested_level[:-1], 1]
            child_filepaths = warped.source_dataset._save_vrt_nested(filepath, child_level)
            warped.source_dataset = child_filepaths[0]
            warped.relative_filename = False
            written += child_filepaths

        warped.save_vrt(save_filepath)
        return written

    @classmethod  # type: ignore
    def from_file(
        cls, filepath: Path | str, dst_crs: CRS | int | str, **kwargs: dict[str, Any]
    ) -> VRTDataset:  # type: ignore
        """
        Build a warped VRT of ``filepath`` in a temporary directory and load it.

        :param filepath: Raster to warp.
        :param dst_crs: Target CRS, forwarded to ``build_warped_vrt``.
        :param kwargs: Extra keyword arguments forwarded to ``build_warped_vrt``.
        """
        with tempfile.TemporaryDirectory() as temp_dir:
            temp_vrt = Path(temp_dir).joinpath("temp.vrt")
            build_warped_vrt(vrt_filepath=temp_vrt, filepath=filepath, dst_crs=dst_crs, **kwargs)  # type: ignore
            warped = cls.load_vrt(temp_vrt)

        # Nodata values are not transferred with GDALs WarpedVRT builder, so this has to be done manually
        with rio.open(filepath) as raster:
            for raster_band in warped.raster_bands:
                raster_band.nodata = raster.nodata

        return warped
# Union of the two dataset classes defined above; used as the return type of the VRT loader functions.
AnyVRTDataset = VRTDataset | WarpedVRTDataset
def dataset_from_etree(elem: ET.Element) -> AnyVRTDataset:
    """
    Parse a ``VRTDataset`` XML element into the matching dataset class.

    Elements whose ``subClass`` attribute is "VRTWarpedDataset" are parsed as ``WarpedVRTDataset``.
    Any other ``subClass`` value triggers a warning and is ignored, so the element is parsed as a
    plain ``VRTDataset``.

    :raises ValueError: If the tag of ``elem`` is not "VRTDataset".
    """
    root_tag = elem.tag
    if root_tag != "VRTDataset":
        raise ValueError(f"Invalid root tag for VRT: {root_tag}")

    subclass = elem.get("subClass")
    if subclass == "VRTWarpedDataset":
        dataset_class = WarpedVRTDataset
    else:
        if subclass is not None:
            warnings.warn(f"Unexpected subClass tag: {subclass}. Ignoring it", stacklevel=2)
        dataset_class = VRTDataset

    return dataset_class.from_etree(elem)
def load_vrt(filepath: str | Path) -> AnyVRTDataset:
    """
    Read the VRT file at ``filepath`` and parse it with ``dataset_from_etree``.

    :param filepath: Path to the VRT file.
    :returns: The parsed dataset.
    """
    with open(filepath) as infile:
        xml_text = infile.read()

    return dataset_from_etree(ET.fromstring(xml_text))