Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: gdal concurrency TDE-457 #96

Merged
merged 13 commits into from
Aug 23, 2022
2 changes: 1 addition & 1 deletion scripts/aws/aws_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ def parse_path(path: str) -> S3Path:
path (str): A S3 path.

Returns:
S3Path (NamedTupe): s3_path.bucket (str), s3_path.key (str)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not 100% related to this, but doesn't VS Code automatically detect the types of all of these? We shouldn't be duplicating the typing in both the function definition and the docs, as it often leads to the two getting out of sync.

S3Path (NamedTuple): s3_path.bucket (str), s3_path.key (str)
"""
parse = urlparse(path, allow_fragments=False)
return S3Path(parse.netloc, parse.path[1:])
Expand Down
5 changes: 5 additions & 0 deletions scripts/aws/tests/aws_helper_test.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from scripts.aws.aws_helper import parse_path
from scripts.cli.cli_helper import is_argo


def test_parse_path() -> None:
Expand All @@ -7,3 +8,7 @@ def test_parse_path() -> None:

assert path.bucket == "bucket-name"
assert path.key == "path/to/the/file.test"


def test_is_argo() -> None:
    """Outside an Argo workflow the ARGO_TEMPLATE env var is absent, so is_argo() is falsy."""
    assert is_argo() is False
5 changes: 5 additions & 0 deletions scripts/cli/cli_helper.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import argparse
import json
from os import environ
from typing import List

from linz_logger import get_log
Expand Down Expand Up @@ -30,3 +31,7 @@ def parse_source() -> List[str]:
arguments = parser.parse_args()

return format_source(arguments.source)


def is_argo() -> bool:
    """Report whether we are running inside an Argo workflow.

    Argo injects the ARGO_TEMPLATE environment variable into its pods;
    a missing or empty value means we are running elsewhere.
    """
    return environ.get("ARGO_TEMPLATE", "") != ""
2 changes: 1 addition & 1 deletion scripts/gdal/gdal_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,6 @@ def run_gdal(
get_log().error("run_gdal_error", command=command_to_string(command), error=proc.stderr.decode())
raise GDALExecutionException(proc.stderr.decode())

get_log().debug("run_gdal_succeded", command=command_to_string(command), stdout=proc.stdout.decode())
get_log().debug("run_gdal_succeeded", command=command_to_string(command), stdout=proc.stdout.decode())

return proc
9 changes: 6 additions & 3 deletions scripts/standardise_validate.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
from scripts.cli.cli_helper import parse_source
from scripts.cli.cli_helper import is_argo, parse_source
from scripts.non_visual_qa import non_visual_qa
from scripts.standardising import standardising
from scripts.standardising import start_standardising


def main() -> None:
    """Standardise the CLI-supplied source files, then run non-visual QA on the outputs."""
    source = parse_source()
    # Argo pods get more CPU, so fan out; default to serial elsewhere.
    concurrency = 4 if is_argo() else 1
    standardised_files = start_standardising(source, concurrency)
    non_visual_qa(standardised_files)


Expand Down
120 changes: 69 additions & 51 deletions scripts/standardising.py
Original file line number Diff line number Diff line change
@@ -1,80 +1,98 @@
import os
from multiprocessing import Pool
from typing import List

from linz_logger import get_log

from scripts.aws.aws_helper import parse_path
from scripts.cli.cli_helper import parse_source
from scripts.cli.cli_helper import is_argo, parse_source
from scripts.files.files_helper import get_file_name_from_path, is_tiff
from scripts.gdal.gdal_helper import run_gdal
from scripts.logging.time_helper import time_in_ms


def start_standardising(files: List[str], concurrency: int) -> List[str]:
    """Standardise a batch of source files in parallel.

    Args:
        files: source paths; non-TIFF entries are logged and skipped.
        concurrency: number of worker processes for the multiprocessing Pool.

    Returns:
        Paths of the standardised output files, one per input TIFF,
        in the same order as the filtered inputs.
    """
    start_time = time_in_ms()
    tiff_files = []

    get_log().info("standardising_start", source=files)

    for file in files:
        if is_tiff(file):
            tiff_files.append(file)
        else:
            get_log().info("standardising_file_not_tiff_skipped", file=file)

    # Pool.map blocks until every task has completed, and the context manager
    # tears the pool down on exit — the explicit p.close()/p.join() calls in
    # the original were redundant and have been removed.
    with Pool(concurrency) as p:
        output_files = p.map(standardising, tiff_files)

    get_log().info("standardising_end", source=files, duration=time_in_ms() - start_time)

    return output_files


def standardising(file: str) -> str:
    """Standardise one source TIFF into a COG under /tmp via gdal_translate.

    Args:
        file: path (e.g. an S3 path) of the source TIFF.

    Returns:
        Local path of the standardised output file.
    """
    output_folder = "/tmp/"

    get_log().info("standardising_start", source=file)

    _, src_file_path = parse_path(file)
    standardized_file_name = f"standardized_{get_file_name_from_path(src_file_path)}"
    tmp_file_path = os.path.join(output_folder, standardized_file_name)

    # COG creation options, appended below as repeated "-co" flags.
    creation_options = [
        "compress=lzw",
        "num_threads=all_cpus",
        "predictor=2",
        "overview_compress=webp",
        "bigtiff=yes",
        "overview_resampling=lanczos",
        "blocksize=512",
        "overview_quality=90",
        "sparse_ok=true",
    ]

    command = [
        "gdal_translate",
        "-q",
        # Rescale 0-255 to 0-254 so 255 can serve as the nodata value below.
        "-scale", "0", "255", "0", "254",
        "-a_srs", "EPSG:2193",
        "-a_nodata", "255",
        # Keep only the first three bands (RGB).
        "-b", "1", "-b", "2", "-b", "3",
        "-of", "COG",
    ]
    for option in creation_options:
        command.extend(["-co", option])

    run_gdal(command, input_file=file, output_file=tmp_file_path)

    return tmp_file_path


def main() -> None:
    """Entry point: standardise the CLI-supplied source files."""
    source = parse_source()
    # Argo pods get more CPU, so fan out; default to serial elsewhere.
    concurrency = 4 if is_argo() else 1
    start_standardising(source, concurrency)


if __name__ == "__main__":
Expand Down