diff --git a/.travis.yml b/.travis.yml index 2d2a810..19a9f7e 100644 --- a/.travis.yml +++ b/.travis.yml @@ -6,7 +6,6 @@ notifications: install: - pip install -r test-requirements.txt -- pip install . script: - flake8 diff --git a/numpy-stubs/__init__.pyi b/numpy-stubs/__init__.pyi index 75d548f..eb9f21b 100644 --- a/numpy-stubs/__init__.pyi +++ b/numpy-stubs/__init__.pyi @@ -1,14 +1,32 @@ import builtins +import sys +from numpy.core._internal import _ctypes from typing import ( - Any, Dict, Iterable, List, Optional, Mapping, Sequence, Sized, - SupportsInt, SupportsFloat, SupportsComplex, SupportsBytes, SupportsAbs, - Text, Tuple, Type, TypeVar, Union, + Any, + Container, + Dict, + Iterable, + List, + Mapping, + Optional, + Sequence, + Sized, + SupportsAbs, + SupportsComplex, + SupportsFloat, + SupportsInt, + Text, + Tuple, + Type, + TypeVar, + Union, ) -import sys - -from numpy.core._internal import _ctypes +if sys.version_info[0] < 3: + class SupportsBytes: ... +else: + from typing import SupportsBytes _Shape = Tuple[int, ...] @@ -325,7 +343,7 @@ class _ArrayOrScalarCommon(SupportsInt, SupportsFloat, SupportsComplex, def __getattr__(self, name) -> Any: ... -class ndarray(_ArrayOrScalarCommon, Iterable, Sized): +class ndarray(_ArrayOrScalarCommon, Iterable, Sized, Container): real: ndarray imag: ndarray diff --git a/test-requirements.txt b/test-requirements.txt index 512c913..32a8c1f 100644 --- a/test-requirements.txt +++ b/test-requirements.txt @@ -2,3 +2,6 @@ flake8==3.3.0 flake8-pyi==17.3.0 pytest==3.4.2 mypy==0.590 +# This makes sure that the repo's stubs are accessible. Using MYPYPATH won't +# work. See https://github.com/python/mypy/issues/5007 for more details. +. diff --git a/tests/README.md b/tests/README.md index e548abf..91c0c48 100644 --- a/tests/README.md +++ b/tests/README.md @@ -22,6 +22,9 @@ reveal_type(x) # E: Right now, the error messages and types are must be **contained within corresponding mypy message**. +Test files that end in `_py3.py` will only be type checked against Python 3. +All other test files must be valid in both Python 2 and Python 3. + ## Running the tests We use `py.test` to orchestrate our tests. You can just run: @@ -34,6 +37,23 @@ to run the entire test suite. To run `mypy` on a specific file (which can be useful for debugging), you can also run: ``` -$ cd tests -$ MYPYPATH=.. mypy +mypy +``` + +Note that for either of these commands, you must run: + +``` +pip install -r test-requirements.txt +``` + +for the version of python that you're going to be running `py.test` or `mypy` +with. To ensure you're using the intended version of Python you can use +`python -m` versions of these commands instead: + +``` +python -m pytest +python -m mypy +python -m pip install -r test-requirements.txt ``` +Due to how mypy reads type information in PEP 561 packages, you'll need +to re-run the `pip install` command each time you change the stubs. diff --git a/tests/pass/simple.py b/tests/pass/simple.py index dd8bab8..e4b52ed 100644 --- a/tests/pass/simple.py +++ b/tests/pass/simple.py @@ -2,11 +2,12 @@ import operator import numpy as np -from typing import Iterable +from typing import Iterable # noqa: F401 # Basic checks array = np.array([1, 2]) -def ndarray_func(x: np.ndarray) -> np.ndarray: +def ndarray_func(x): + # type: (np.ndarray) -> np.ndarray return x ndarray_func(np.array([1, 2])) array == 1 @@ -28,7 +29,8 @@ def ndarray_func(x: np.ndarray) -> np.ndarray: np.dtype((np.int32, (np.int8, 4))) # Iteration and indexing -def iterable_func(x: Iterable) -> Iterable: +def iterable_func(x): + # type: (Iterable) -> Iterable return x iterable_func(array) [element for element in array] @@ -122,8 +124,6 @@ def iterable_func(x: Iterable) -> Iterable: 1 | array array |= 1 -array @ array - # unary arithmetic -array +array diff --git a/tests/pass/simple_py3.py b/tests/pass/simple_py3.py new file mode 100644 index 0000000..c05a1ce --- /dev/null +++ b/tests/pass/simple_py3.py @@ -0,0 +1,6 @@ +import numpy as np + +array = np.array([1, 2]) + +# The @ operator is not in python 2 +array @ array diff --git a/tests/test_stubs.py b/tests/test_stubs.py index fdf3b62..b0e8809 100644 --- a/tests/test_stubs.py +++ b/tests/test_stubs.py @@ -3,37 +3,48 @@ import pytest from mypy import api -ROOT_DIR = os.path.dirname(os.path.dirname(__file__)) -PASS_DIR = os.path.join(os.path.dirname(__file__), "pass") -FAIL_DIR = os.path.join(os.path.dirname(__file__), "fail") -REVEAL_DIR = os.path.join(os.path.dirname(__file__), "reveal") - -os.environ['MYPYPATH'] = ROOT_DIR +TESTS_DIR = os.path.dirname(__file__) +PASS_DIR = os.path.join(TESTS_DIR, "pass") +FAIL_DIR = os.path.join(TESTS_DIR, "fail") +REVEAL_DIR = os.path.join(TESTS_DIR, "reveal") def get_test_cases(directory): for root, __, files in os.walk(directory): for fname in files: if os.path.splitext(fname)[-1] == ".py": - # yield relative path for nice py.test name - yield os.path.relpath( - os.path.join(root, fname), start=directory) - - -@pytest.mark.parametrize("path", get_test_cases(PASS_DIR)) -def test_success(path): - stdout, stderr, exitcode = api.run([os.path.join(PASS_DIR, path)]) + fullpath = os.path.join(root, fname) + # Use relative path for nice py.test name + relpath = os.path.relpath(fullpath, start=directory) + skip_py2 = fname.endswith("_py3.py") + + for py_version_number in (2, 3): + if py_version_number == 2 and skip_py2: + continue + py2_arg = ['--py2'] if py_version_number == 2 else [] + + yield pytest.param( + fullpath, + py2_arg, + # Manually specify a name for the test + id="{} - python{}".format(relpath, py_version_number), + ) + + +@pytest.mark.parametrize("path,py2_arg", get_test_cases(PASS_DIR)) +def test_success(path, py2_arg): + stdout, stderr, exitcode = api.run([path] + py2_arg) assert stdout == '' assert exitcode == 0 -@pytest.mark.parametrize("path", get_test_cases(FAIL_DIR)) -def test_fail(path): - stdout, stderr, exitcode = api.run([os.path.join(FAIL_DIR, path)]) +@pytest.mark.parametrize("path,py2_arg", get_test_cases(FAIL_DIR)) +def test_fail(path, py2_arg): + stdout, stderr, exitcode = api.run([path] + py2_arg) assert exitcode != 0 - with open(os.path.join(FAIL_DIR, path)) as fin: + with open(path) as fin: lines = fin.readlines() errors = {} @@ -59,11 +70,11 @@ def test_fail(path): pytest.fail(f'Error {repr(errors[lineno])} not found') -@pytest.mark.parametrize("path", get_test_cases(REVEAL_DIR)) -def test_reveal(path): - stdout, stderr, exitcode = api.run([os.path.join(REVEAL_DIR, path)]) +@pytest.mark.parametrize("path,py2_arg", get_test_cases(REVEAL_DIR)) +def test_reveal(path, py2_arg): + stdout, stderr, exitcode = api.run([path] + py2_arg) - with open(os.path.join(REVEAL_DIR, path)) as fin: + with open(path) as fin: lines = fin.readlines() for error_line in stdout.split("\n"):