diff --git a/src/sagemaker/base_deserializers.py b/src/sagemaker/base_deserializers.py index 7162e5274d..a152f0144d 100644 --- a/src/sagemaker/base_deserializers.py +++ b/src/sagemaker/base_deserializers.py @@ -196,14 +196,14 @@ class NumpyDeserializer(SimpleBaseDeserializer): single array. """ - def __init__(self, dtype=None, accept="application/x-npy", allow_pickle=True): + def __init__(self, dtype=None, accept="application/x-npy", allow_pickle=False): """Initialize a ``NumpyDeserializer`` instance. Args: dtype (str): The dtype of the data (default: None). accept (union[str, tuple[str]]): The MIME type (or tuple of allowable MIME types) that is expected from the inference endpoint (default: "application/x-npy"). - allow_pickle (bool): Allow loading pickled object arrays (default: True). + allow_pickle (bool): Allow loading pickled object arrays (default: False). """ super(NumpyDeserializer, self).__init__(accept=accept) self.dtype = dtype @@ -227,10 +227,21 @@ def deserialize(self, stream, content_type): if content_type == "application/json": return np.array(json.load(codecs.getreader("utf-8")(stream)), dtype=self.dtype) if content_type == "application/x-npy": - return np.load(io.BytesIO(stream.read()), allow_pickle=self.allow_pickle) + try: + return np.load(io.BytesIO(stream.read()), allow_pickle=self.allow_pickle) + except ValueError as ve: + raise ValueError( + "Please set the param allow_pickle=True \ + to deserialize pickle objects in NumpyDeserializer" + ).with_traceback(ve.__traceback__) if content_type == "application/x-npz": try: return np.load(io.BytesIO(stream.read()), allow_pickle=self.allow_pickle) + except ValueError as ve: + raise ValueError( + "Please set the param allow_pickle=True \ + to deserialize pickle objectsin NumpyDeserializer" + ).with_traceback(ve.__traceback__) finally: stream.close() finally: diff --git a/tests/unit/sagemaker/deserializers/test_deserializers.py b/tests/unit/sagemaker/deserializers/test_deserializers.py index b8ede11ba5..cb1923a094 100644 --- a/tests/unit/sagemaker/deserializers/test_deserializers.py +++ b/tests/unit/sagemaker/deserializers/test_deserializers.py @@ -142,7 +142,8 @@ def test_numpy_deserializer_from_npy(numpy_deserializer): assert np.array_equal(array, result) -def test_numpy_deserializer_from_npy_object_array(numpy_deserializer): +def test_numpy_deserializer_from_npy_object_array(): + numpy_deserializer = NumpyDeserializer(allow_pickle=True) array = np.array([{"a": "", "b": ""}, {"c": "", "d": ""}]) stream = io.BytesIO() np.save(stream, array)