Merge pull request #1368 from adriangonz/update-storage
Update Storage.py and initialiser image
seldondev authored Feb 4, 2020
2 parents c6e49d7 + 5205434 commit a72257f
Showing 5 changed files with 346 additions and 24 deletions.
107 changes: 94 additions & 13 deletions python/seldon_core/storage.py
@@ -25,6 +25,7 @@
 from azure.storage.blob import BlockBlobService
 from minio import Minio
 from seldon_core.imports_helper import _GCS_PRESENT
+from seldon_core.utils import getenv

 if _GCS_PRESENT:
     from google.auth import exceptions
@@ -78,6 +79,7 @@ def _download_s3(uri, temp_dir: str):
         bucket_name = bucket_args[0]
         bucket_path = bucket_args[1] if len(bucket_args) > 1 else ""
         objects = client.list_objects(bucket_name, prefix=bucket_path, recursive=True)
+        count = 0
         for obj in objects:
             # Replace any prefix from the object key with temp_dir
             subdir_object_key = obj.object_name.replace(bucket_path, "", 1).strip("/")
@@ -90,6 +92,13 @@
                 obj.object_name,
                 os.path.join(temp_dir, subdir_object_key),
             )
+            count = count + 1
+        if count == 0:
+            raise RuntimeError(
+                "Failed to fetch model. \
+                The path or model %s does not exist."
+                % (uri)
+            )

     @staticmethod
     def _download_gcs(uri, temp_dir: str):
@@ -105,6 +114,7 @@ def _download_gcs(uri, temp_dir: str):
         if not prefix.endswith("/"):
             prefix = prefix + "/"
         blobs = bucket.list_blobs(prefix=prefix)
+        count = 0
         for blob in blobs:
             # Replace any prefix from the object key with temp_dir
             subdir_object_key = blob.name.replace(bucket_path, "", 1).strip("/")
@@ -120,38 +130,102 @@ def _download_gcs(uri, temp_dir: str):
             dest_path = os.path.join(temp_dir, subdir_object_key)
             logging.info("Downloading: %s", dest_path)
             blob.download_to_filename(dest_path)
+            count = count + 1
+        if count == 0:
+            raise RuntimeError(
+                "Failed to fetch model. \
+                The path or model %s does not exist."
+                % (uri)
+            )

     @staticmethod
-    def _download_blob(uri, out_dir: str):
+    def _download_blob(uri, out_dir: str):  # pylint: disable=too-many-locals
         match = re.search(_BLOB_RE, uri)
         account_name = match.group(1)
         storage_url = match.group(2)
         container_name, prefix = storage_url.split("/", 1)

         logging.info(
-            "Connecting to BLOB account: %s, contianer: %s",
+            "Connecting to BLOB account: [%s], container: [%s], prefix: [%s]",
             account_name,
             container_name,
+            prefix,
         )
-        block_blob_service = BlockBlobService(account_name=account_name)
-        blobs = block_blob_service.list_blobs(container_name, prefix=prefix)

+        try:
+            block_blob_service = BlockBlobService(account_name=account_name)
+            blobs = block_blob_service.list_blobs(container_name, prefix=prefix)
+        except Exception:  # pylint: disable=broad-except
+            token = Storage._get_azure_storage_token()
+            if token is None:
+                logging.warning(
+                    "Azure credentials not found, retrying anonymous access"
+                )
+            block_blob_service = BlockBlobService(
+                account_name=account_name, token_credential=token
+            )
+            blobs = block_blob_service.list_blobs(container_name, prefix=prefix)
+        count = 0
         for blob in blobs:
+            dest_path = os.path.join(out_dir, blob.name)
             if "/" in blob.name:
-                head, _ = os.path.split(blob.name)
+                head, tail = os.path.split(blob.name)
+                if prefix is not None:
+                    head = head[len(prefix) :]
+                    if head.startswith("/"):
+                        head = head[1:]
                 dir_path = os.path.join(out_dir, head)
+                dest_path = os.path.join(dir_path, tail)
                 if not os.path.isdir(dir_path):
                     os.makedirs(dir_path)

-            dest_path = os.path.join(out_dir, blob.name)
-            logging.info("Downloading: %s", dest_path)
+            logging.info("Downloading: %s to %s", blob.name, dest_path)
             block_blob_service.get_blob_to_path(container_name, blob.name, dest_path)
+            count = count + 1
+        if count == 0:
+            raise RuntimeError(
+                "Failed to fetch model. \
+                The path or model %s does not exist."
+                % (uri)
+            )
+
+    @staticmethod
+    def _get_azure_storage_token():
+        tenant_id = os.getenv("AZ_TENANT_ID", "")
+        client_id = os.getenv("AZ_CLIENT_ID", "")
+        client_secret = os.getenv("AZ_CLIENT_SECRET", "")
+        subscription_id = os.getenv("AZ_SUBSCRIPTION_ID", "")
+
+        if (
+            tenant_id == ""
+            or client_id == ""
+            or client_secret == ""
+            or subscription_id == ""
+        ):
+            return None
+
+        # note the SP must have "Storage Blob Data Owner" perms for this to work
+        import adal
+        from azure.storage.common import TokenCredential
+
+        authority_url = "https://login.microsoftonline.com/" + tenant_id
+
+        context = adal.AuthenticationContext(authority_url)
+
+        token = context.acquire_token_with_client_credentials(
+            "https://storage.azure.com/", client_id, client_secret
+        )
+
+        token_credential = TokenCredential(token["accessToken"])
+
+        logging.info("Retrieved SP token credential for client_id: %s", client_id)
+
+        return token_credential

     @staticmethod
     def _download_local(uri, out_dir=None):
         local_path = uri.replace(_LOCAL_PREFIX, "", 1)
         if not os.path.exists(local_path):
-            raise Exception("Local path %s does not exist." % (uri))
+            raise RuntimeError("Local path %s does not exist." % (uri))

         if out_dir is None:
             return local_path
@@ -171,14 +245,21 @@ def _download_local(uri, out_dir=None):
     @staticmethod
     def _create_minio_client():
         # Remove possible http scheme for Minio
-        url = urlparse(os.getenv("S3_ENDPOINT", ""))
+        url = urlparse(os.getenv("AWS_ENDPOINT_URL", "s3.amazonaws.com"))
         use_ssl = (
-            url.scheme == "https" if url.scheme else bool(os.getenv("USE_SSL", True))
+            url.scheme == "https"
+            if url.scheme
+            # KFServing uses S3_USE_HTTPS, whereas Seldon was already using
+            # USE_SSL.
+            # To keep compatibility with the storage init layer we support
+            # both, giving priority to USE_SSL.
+            # https://github.com/SeldonIO/seldon-core/pull/827
+            # https://github.com/kubeflow/kfserving/pull/362
+            else bool(getenv("USE_SSL", "S3_USE_HTTPS", "false"))
         )
-        minioClient = Minio(
+        return Minio(
             url.netloc,
             access_key=os.getenv("AWS_ACCESS_KEY_ID", ""),
             secret_key=os.getenv("AWS_SECRET_ACCESS_KEY", ""),
             secure=use_ssl,
         )
-        return minioClient
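
For context, the new Azure code path keys off four service-principal environment variables. A minimal sketch of wiring them up (values are placeholders, not part of this commit):

import os

# Placeholder values -- a real deployment would inject these via the pod spec
# or a Kubernetes secret.
os.environ["AZ_TENANT_ID"] = "<tenant-id>"
os.environ["AZ_CLIENT_ID"] = "<client-id>"
os.environ["AZ_CLIENT_SECRET"] = "<client-secret>"
os.environ["AZ_SUBSCRIPTION_ID"] = "<subscription-id>"

# With all four set, Storage._get_azure_storage_token() acquires a token via
# adal and wraps it in a TokenCredential; if any are missing it returns None
# and _download_blob retries with anonymous access.

As the in-diff comment notes, the service principal needs "Storage Blob Data Owner" permissions on the storage account for this to work.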
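
Similarly, a quick sketch (illustrative only; the endpoint value is hypothetical) of how the new endpoint and SSL resolution in _create_minio_client behaves:

import os
from urllib.parse import urlparse

# Hypothetical in-cluster MinIO endpoint with an explicit scheme.
os.environ["AWS_ENDPOINT_URL"] = "http://minio-service.kubeflow:9000"

url = urlparse(os.getenv("AWS_ENDPOINT_URL", "s3.amazonaws.com"))
print(url.scheme, url.netloc)  # http minio-service.kubeflow:9000

# With an explicit scheme, SSL is decided by the scheme alone.
use_ssl = url.scheme == "https"  # False for http

# Without a scheme, the code falls back to USE_SSL, then S3_USE_HTTPS.
# One caveat: bool() on any non-empty string is truthy, so a value such as
# "false" still evaluates to True; only an unset variable (None) is False.
print(bool("false"), bool(None))  # True False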
29 changes: 25 additions & 4 deletions python/seldon_core/utils.py
@@ -1,3 +1,4 @@
+import os
 import json
 import sys
 import base64
@@ -520,21 +521,17 @@ def extract_request_parts_json(
     if not isinstance(request, dict):
         raise SeldonMicroserviceException(f"Invalid request data type: {request}")
     meta = request.get("meta", None)
-    datadef_type = None
     datadef = None

     if "data" in request:
         data_type = "data"
         datadef = request["data"]
         if "tensor" in datadef:
-            datadef_type = "tensor"
             tensor = datadef["tensor"]
             features = np.array(tensor["values"]).reshape(tensor["shape"])
         elif "ndarray" in datadef:
-            datadef_type = "ndarray"
             features = np.array(datadef["ndarray"])
         elif "tftensor" in datadef:
-            datadef_type = "tftensor"
             tf_proto = TensorProto()
             json_format.ParseDict(datadef["tftensor"], tf_proto)
             features = tf.make_ndarray(tf_proto)
@@ -597,3 +594,27 @@ def extract_feedback_request_parts(
     truth = grpc_datadef_to_array(request.truth.data)
     reward = request.reward
     return request.request.data, features, truth, reward
+
+
+def getenv(*env_vars, default=None):
+    """
+    Overload of os.getenv() to allow falling back through multiple environment
+    variables. The environment variables will be checked sequentially until one
+    of them is found.
+
+    Parameters
+    ----------
+    *env_vars
+        Variadic list of environment variable names to check.
+    default
+        Default value to return if none of the environment variables exist.
+
+    Returns
+    -------
+    Value of the first environment variable set or default.
+    """
+    for env_var in env_vars:
+        if env_var in os.environ:
+            return os.environ.get(env_var)
+
+    return default
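
A usage sketch for the new getenv helper (illustrative only; the variable names are hypothetical):

import os
from seldon_core.utils import getenv

# Hypothetical setup: only the KFServing-style name is set.
os.environ["S3_USE_HTTPS"] = "1"

# Variables are checked left to right; the first one present wins.
print(getenv("USE_SSL", "S3_USE_HTTPS"))              # 1
print(getenv("MISSING_A", "MISSING_B", default="0"))  # 0

# Note that default is keyword-only: a bare third positional argument is
# treated as one more environment variable name to check, not as a default.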