Source code for variete.vrt.pixel_functions

"""
Pixel functions that allow arithmetic and other complex operations.
"""
from __future__ import annotations

from typing import Literal

import lxml.etree as ET

from variete import misc


[docs]class PixelFunction: name: str arguments: dict[str, str] | None code: str | None language: Literal["Python"] | None
[docs] def __init__( self, name: str, arguments: dict[str, str] | None = None, code: str | None = None, language: Literal["Python"] | None = None, ): for attr in ["arguments", "name", "code", "language"]: setattr(self, attr, locals()[attr])
[docs] def validate(self, n_bands: int) -> None: assert len(self.name) > 0, "The PixelFunction must have a name" assert self.language is None or self.language == "Python", "PixelFunction language must be None or 'Python'" assert n_bands > 0, "PixelFunction requires at least one band" if self.code is not None and self.language is None: raise AssertionError("PixelFunction code is given but the language is not provided")
[docs] def to_etree_keys(self) -> list[ET.Element]: keys = [] keys.append(misc.new_element("PixelFunctionType", self.name)) keys.append(misc.new_element("PixelFunctionArguments", None, self.arguments)) if self.language is not None: keys.append(misc.new_element("PixelFunctionLanguage", self.language)) if self.code is not None: keys.append(misc.new_element("PixelFunctionCode", self.code)) return keys
def __repr__(self) -> str: code_string = f" {len(self.code)} character custom code" if self.code is not None else "" return "\n".join([f"PixelFunction: {self.name}({self.arguments or ''}){code_string}"])
[docs]class SumPixelFunction(PixelFunction):
[docs] def __init__(self, constant: int | float | str | None = None): self.name = "sum" self.arguments = {"k": misc.number_to_gdal(constant)} if constant is not None else None self.code = self.language = None
[docs] def validate(self, n_bands: int) -> None: PixelFunction.validate(self, n_bands=n_bands) if "k" in (self.arguments or {}): n_bands += 1 if n_bands < 2: raise ValueError("SumPixelFunction requires at least two bands (or one band and a constant")
[docs]class MulPixelFunction(PixelFunction):
[docs] def __init__(self, constant: int | float | str | None = None) -> None: base = SumPixelFunction(constant=constant) self.name = "mul" for attr in ["arguments", "code", "language"]: setattr(self, attr, getattr(base, attr))
[docs]class DivPixelFunction(PixelFunction):
[docs] def __init__(self, constant: int | float | str | None = None) -> None: base = SumPixelFunction(constant=constant) self.name = "div" for attr in ["arguments", "code", "language"]: setattr(self, attr, getattr(base, attr))
[docs]class InvPixelFunction(PixelFunction):
[docs] def __init__(self) -> None: self.name = "inv" self.arguments = self.code = self.language = None
[docs]class ScalePixelFunction(PixelFunction):
[docs] def __init__(self) -> None: self.name = "scale" self.arguments = self.code = self.language = None
[docs]class ReplaceNodataPixelFunction(PixelFunction):
[docs] def __init__(self, value: int | float | str) -> None: self.name = "replace_nodata" self.arguments = {"to": misc.number_to_gdal(value)} self.code = self.language = None
AnyPixelFunction = ( PixelFunction | SumPixelFunction | ScalePixelFunction | MulPixelFunction | DivPixelFunction | InvPixelFunction )
[docs]def pixel_function_from_etree(elem: ET.Element) -> AnyPixelFunction: if (pixel_function_type_elem := elem.find("PixelFunctionType")) is not None: name = pixel_function_type_elem.text if name is None: raise AssertionError("PixelFunctionType key is empty. Invalid PixelFunction") else: raise AssertionError("Key PixelFunctionType does not exist. Invalid PixelFunction") language = getattr(elem.find("PixelFunctionLanguage"), "text", None) if (arguments_elem := elem.find("PixelFunctionArguments")) is not None: arguments = dict(arguments_elem.items()) else: arguments = {} code = getattr(elem.find("PixelFunctionCode"), "text", None) if code is None: if name == "sum": return SumPixelFunction(arguments.get("k", None)) if name == "scale": return ScalePixelFunction() if name == "mul": return MulPixelFunction(arguments.get("k", None)) if name == "div": return DivPixelFunction(arguments.get("k", None)) if name == "inv": return InvPixelFunction() if name == "replace_nodata": return ReplaceNodataPixelFunction(arguments["to"]) raise ValueError(f"Empty PixelFunctionCode and unknown type: {name}") return PixelFunction(name=name, arguments=arguments, code=code, language=language)