feat: add standardising script (TDE-303) (#23)
* feat: add standardising script

* chore: organize imports

* fix: set the AWS credentials in the command environment rather than in the global (os) environment

* fix: change the way the source path is converted for gdal to allow local files
paulfouquet authored Jun 23, 2022
1 parent ab2efc4 commit fd60e53
Showing 7 changed files with 223 additions and 4 deletions.
1 change: 1 addition & 0 deletions Dockerfile
@@ -18,4 +18,5 @@ RUN poetry config virtualenvs.create false \

# Copy Python scripts
COPY ./scripts/create_polygons.py /app/
COPY ./scripts/standardising.py /app/
COPY ./scripts/aws_helper.py /app/
151 changes: 150 additions & 1 deletion poetry.lock

Some generated files are not rendered by default.

1 change: 1 addition & 0 deletions pyproject.toml
@@ -16,6 +16,7 @@ boto3 = "^1.24.12"
linz-logger = "^0.6.0"

[tool.poetry.dev-dependencies]
pytest = "^7.1.2"

[build-system]
requires = ["poetry-core>=1.0.0"]
Empty file added scripts/__init__.py
13 changes: 10 additions & 3 deletions scripts/aws_helper.py
@@ -1,5 +1,7 @@
import json
from os import environ
from typing import Tuple
from urllib.parse import urlparse

import boto3
from linz_logger import get_log
@@ -35,20 +37,21 @@ def init_roles():
        bucket_roles[cfg["bucket"]] = cfg

def get_credentials(bucket_name: str):
-    get_log().debug("get_credentials_bucket_name", bucket_name=bucket_name)
+    get_log().debug("get_credentials", bucket_name=bucket_name)
    if not bucket_roles:
        init_roles()
    if bucket_name in bucket_roles:
        # FIXME: check if the token is expired - add a parameter
        if bucket_name not in bucket_credentials:
            role_arn = bucket_roles[bucket_name]["roleArn"]
-            get_log().debug("s3_assume_role", bucket_name=bucket_name, role_arn=role_arn)
+            get_log().debug("sts_assume_role", bucket_name=bucket_name, role_arn=role_arn)
            assumed_role_object = client_sts.assume_role(RoleArn=role_arn, RoleSessionName="gdal")
            bucket_credentials[bucket_name] = Credentials(
                assumed_role_object["Credentials"]["AccessKeyId"],
                assumed_role_object["Credentials"]["SecretAccessKey"],
                assumed_role_object["Credentials"]["SessionToken"],
            )

        return bucket_credentials[bucket_name]

    session_credentials = session.get_credentials()
@@ -70,6 +73,10 @@ def get_bucket(bucket_name):

    return s3_resource.Bucket(bucket_name)

-def bucket_name_from_path(path: str) -> str:
+def get_bucket_name_from_path(path: str) -> str:
    path_parts = path.replace("s3://", "").split("/")
    return path_parts.pop(0)

def parse_path(path: str) -> Tuple[str, str]:
    parse = urlparse(path, allow_fragments=False)
    return parse.netloc, parse.path
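
For illustration, a minimal sketch of what these two helpers return, using a made-up S3 path (the values below are examples, not part of the commit):

    # Hypothetical values for illustration only.
    bucket, key = parse_path("s3://my-bucket/imagery/tile.tiff")
    # bucket == "my-bucket", key == "/imagery/tile.tiff" (urlparse keeps the leading slash)
    name = get_bucket_name_from_path("s3://my-bucket/imagery/tile.tiff")
    # name == "my-bucket"
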
50 changes: 50 additions & 0 deletions scripts/standardising.py
@@ -0,0 +1,50 @@
import argparse
import os
import subprocess
import tempfile

from linz_logger import get_log

import aws_helper

parser = argparse.ArgumentParser()
parser.add_argument("--source", dest="source", required=True)
parser.add_argument("--destination", dest="destination", required=True)
arguments = parser.parse_args()
source = arguments.source
destination = arguments.destination
# TODO: if destination needs write permission we have to handle this

get_log().info("standardising", source=source, destination=destination)

src_bucket_name, src_file_path = aws_helper.parse_path(source)
dst_bucket_name, dst_path = aws_helper.parse_path(destination)
get_log().debug("source", bucket=src_bucket_name, file_path=src_file_path)
get_log().debug("destination", bucket=dst_bucket_name, file_path=dst_path)
dst_bucket = aws_helper.get_bucket(dst_bucket_name)

with tempfile.TemporaryDirectory() as tmp_dir:
    standardized_file_name = f"standardized_{os.path.basename(src_file_path)}"
    tmp_file_path = os.path.join(tmp_dir, standardized_file_name)
    src_gdal_path = source.replace("s3://", "/vsis3/")

    # Set the credentials for GDAL to be able to read the source file
    credentials = aws_helper.get_credentials(src_bucket_name)
    gdal_env = os.environ.copy()
    gdal_env["AWS_ACCESS_KEY_ID"] = credentials.access_key
    gdal_env["AWS_SECRET_ACCESS_KEY"] = credentials.secret_key
    gdal_env["AWS_SESSION_TOKEN"] = credentials.token

    # Run GDAL to standardise the file
    get_log().debug("run_gdal_translate", src=src_gdal_path, output=tmp_file_path)
    command = [
        "gdal_translate",
        "-q",  # quiet
        "-scale", "0", "255", "0", "254",  # rescale 0-255 to 0-254 so 255 is free for nodata
        "-a_srs", "EPSG:2193",  # assign NZTM2000
        "-a_nodata", "255",
        "-b", "1", "-b", "2", "-b", "3",  # keep the three RGB bands
        "-co", "compress=lzw",
        src_gdal_path,
        tmp_file_path,
    ]
    proc = subprocess.run(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, env=gdal_env)
    if proc.returncode != 0:
        get_log().error("run_gdal_translate_failed", command=" ".join(command))
        raise Exception(proc.stderr.decode())
    get_log().debug("run_gdal_translate_succeeded", command=" ".join(command))

    # Upload the standardised file to destination
    dst_file_path = os.path.join(dst_path, standardized_file_name).strip("/")
    get_log().debug("upload_file", path=dst_file_path)
    dst_bucket.upload_file(tmp_file_path, dst_file_path)
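
For reference, a hypothetical invocation of the new script from /app in the container (bucket names and paths are placeholders, not from the commit):

    python standardising.py --source s3://source-bucket/path/to/image.tiff --destination s3://destination-bucket/standardised
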
11 changes: 11 additions & 0 deletions scripts/tests/aws_helper_test.py
@@ -0,0 +1,11 @@
import pytest

from scripts.aws_helper import parse_path


def test_parse_path() -> None:
    s3_path = "s3://bucket-name/path/to/the/file.test"
    bucket_name, file_path = parse_path(s3_path)

    assert bucket_name == "bucket-name"
    assert file_path == "/path/to/the/file.test"
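
With pytest added as a dev dependency and the empty scripts/__init__.py making scripts importable as a package, the test can presumably be run from the repository root with something like:

    poetry run pytest scripts/tests/
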
