Skip to content

Commit

Permalink
add check blob url in nightly_build
Browse files Browse the repository at this point in the history
Signed-off-by: hwangdeyu <[email protected]>
  • Loading branch information
hwangdeyu committed Dec 20, 2021
1 parent ec8b14c commit 45247dd
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 22 deletions.
13 changes: 3 additions & 10 deletions tests/keras2onnx_applications/nightly_build/test_transformers.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import numpy as np
import tensorflow as tf
from onnxconverter_common.onnx_ex import get_maximum_opset_supported
from test_utils import check_bloburl_access

sys.path.insert(0, os.path.join(dirname(abspath(__file__)), '../../keras2onnx_tests/'))
from test_utils import run_onnx_runtime
Expand All @@ -21,20 +22,12 @@
if os.environ.get('ENABLE_FULL_TRANSFORMER_TEST', '0') != '0':
enable_transformer_test = True

def check_bloburl_access():
url = r'https://lotus.blob.core.windows.net/converter-models/transformer_tokenizer/'
try:
response = urllib.request.urlopen(url)
if response.getcode() != 200:
return False
except urllib.error.URLError:
return False
return True
CONVERTER_TRANSFERMER_PATH = r'https://lotus.blob.core.windows.net/converter-models/transformer_tokenizer/'


@unittest.skipIf(is_tensorflow_older_than('2.1.0'),
"Transformers conversion need tensorflow 2.1.0+")
@unittest.skipIf(check_bloburl_access(), "Model blob url can't access")
@unittest.skipIf(check_bloburl_access(CONVERTER_TRANSFERMER_PATH), "Model blob url can't access.")
class TestTransformers(unittest.TestCase):

text_str = 'The quick brown fox jumps over lazy dog.'
Expand Down
13 changes: 2 additions & 11 deletions tests/keras2onnx_applications/nightly_build/test_yolov3.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from os.path import dirname, abspath
sys.path.insert(0, os.path.join(dirname(abspath(__file__)), '../../keras2onnx_tests/'))
from keras.models import load_model
from test_utils import check_bloburl_access

import urllib.request
YOLOV3_WEIGHTS_PATH = r'https://lotus.blob.core.windows.net/converter-models/yolov3.h5'
Expand All @@ -25,16 +26,6 @@
tmp_path = os.path.join(working_path, 'temp')


def check_bloburl_access(url):
try:
response = urllib.request.urlopen(url)
if response.getcode() != 200:
return False
except urllib.error.URLError:
return False
return True


class TestYoloV3(unittest.TestCase):

def setUp(self):
Expand All @@ -56,7 +47,7 @@ def post_compute(self, all_boxes, all_scores, indices):
@unittest.skipIf(StrictVersion(onnx.__version__.split('-')[0]) < StrictVersion("1.5.0"),
"NonMaxSuppression op is not supported for onnx < 1.5.0.")
@unittest.skipIf(check_bloburl_access(YOLOV3_WEIGHTS_PATH) or check_bloburl_access(YOLOV3_TINY_WEIGHTS_PATH),
"Model blob url can't access")
"Model blob url can't access.")
def test_yolov3(self):
img_path = os.path.join(os.path.dirname(__file__), '../data', 'street.jpg')
yolo3_yolo3_dir = os.path.join(os.path.dirname(__file__), '../../../keras-yolo3/yolo3')
Expand Down
11 changes: 10 additions & 1 deletion tests/keras2onnx_unit_tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,10 @@
import mock_keras2onnx
from mock_keras2onnx.proto import keras, is_keras_older_than
from mock_keras2onnx.proto.tfcompat import is_tf2
#from mock_keras2onnx.common.onnx_ops import apply_identity, OnnxOperatorBuilder
# from mock_keras2onnx.common.onnx_ops import apply_identity, OnnxOperatorBuilder
import time
import json
import urllib

working_path = os.path.abspath(os.path.dirname(__file__))
tmp_path = os.path.join(working_path, 'temp')
Expand Down Expand Up @@ -282,3 +283,11 @@ def run_image(model, model_files, img_path, model_name='onnx_conversion', rtol=1
onnx_model = mock_keras2onnx.convert_keras(model, model.name)
res = run_onnx_runtime(model_name, onnx_model, x, preds, model_files, rtol=rtol, atol=atol, compare_perf=compare_perf)
return res, msg


def check_bloburl_access(url):
try:
response = urllib.request.urlopen(url)
return response.getcode() == 200
except urllib.error.URLError:
return False

0 comments on commit 45247dd

Please sign in to comment.