Skip to content
This repository has been archived by the owner on Jun 10, 2020. It is now read-only.

Commit

Permalink
Ensure stubs are valid for Python 2 and fix running of tests (#19)
Browse files Browse the repository at this point in the history
* Ensure stubs are valid for Python 2 and fix running of tests

The stubs contained an unconditional reference to SupportsBytes, which only exists in Python 3. To make these valid on Python 2, conditionally import that Protocol in Python 3 and otherwise use a dummy class in Python 2. Also have `ndarray` extend `Contains`, while we're here.

This also extends the test suites to run all tests against both Python 2 and Python 3, with the ability to specify that certain tests should only be run against Python 3 (eg to test Python 3 exclusive operators). This should help prevent errors like this moving forward.

One downside of this is that flake8 doesn't understand the `# type:` comments, so it thinks that imports from `typing` are unused. A workaround for this is to add `# noqa: F401` at the end of the relevant imports, though this is a bit tedious.

Finally, change how test requirements are installed and how the `numpy-stubs` package is exposed to mypy, and update the README/Travis file to reflect this. See python/mypy#5007 for more details about the rational behind this change.

* Split `pip install .` out of the `test-requirements.txt` file, update Travis and README files accordingly.
  • Loading branch information
FuegoFro authored and shoyer committed May 10, 2018
1 parent c9d28a2 commit a4857d4
Show file tree
Hide file tree
Showing 5 changed files with 95 additions and 36 deletions.
32 changes: 25 additions & 7 deletions numpy-stubs/__init__.pyi
Original file line number Diff line number Diff line change
@@ -1,14 +1,32 @@
import builtins
import sys

from numpy.core._internal import _ctypes
from typing import (
Any, Dict, Iterable, List, Optional, Mapping, Sequence, Sized,
SupportsInt, SupportsFloat, SupportsComplex, SupportsBytes, SupportsAbs,
Text, Tuple, Type, TypeVar, Union,
Any,
Container,
Dict,
Iterable,
List,
Mapping,
Optional,
Sequence,
Sized,
SupportsAbs,
SupportsComplex,
SupportsFloat,
SupportsInt,
Text,
Tuple,
Type,
TypeVar,
Union,
)

import sys

from numpy.core._internal import _ctypes
if sys.version_info[0] < 3:
class SupportsBytes: ...
else:
from typing import SupportsBytes

_Shape = Tuple[int, ...]

Expand Down Expand Up @@ -325,7 +343,7 @@ class _ArrayOrScalarCommon(SupportsInt, SupportsFloat, SupportsComplex,
def __getattr__(self, name) -> Any: ...


class ndarray(_ArrayOrScalarCommon, Iterable, Sized):
class ndarray(_ArrayOrScalarCommon, Iterable, Sized, Container):
real: ndarray
imag: ndarray

Expand Down
28 changes: 26 additions & 2 deletions tests/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,22 @@ reveal_type(x) # E: <type name>
Right now, the error messages and types are must be **contained within
corresponding mypy message**.

Test files that end in `_py3.py` will only be type checked against Python 3.
All other test files must be valid in both Python 2 and Python 3.

## Running the tests

To setup your test environment, cd into the root of the repo and run:


```
pip install -r test-requirements.txt
pip install .
```

Note that due to how mypy reads type information in PEP 561 packages, you'll
need to re-run the `pip install .` command each time you change the stubs.

We use `py.test` to orchestrate our tests. You can just run:

```
Expand All @@ -34,6 +48,16 @@ to run the entire test suite. To run `mypy` on a specific file (which
can be useful for debugging), you can also run:

```
$ cd tests
$ MYPYPATH=.. mypy <file_path>
mypy <file_path>
```

Note that it is assumed that all of these commands target the same
underlying Python interpreter. To ensure you're using the intended version of
Python you can use `python -m` versions of these commands instead:

```
python -m pip install -r test-requirements.txt
python -m pip install .
python -m pytest
python -m mypy <file_path>
```
10 changes: 5 additions & 5 deletions tests/pass/simple.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,12 @@
import operator

import numpy as np
from typing import Iterable
from typing import Iterable # noqa: F401

# Basic checks
array = np.array([1, 2])
def ndarray_func(x: np.ndarray) -> np.ndarray:
def ndarray_func(x):
# type: (np.ndarray) -> np.ndarray
return x
ndarray_func(np.array([1, 2]))
array == 1
Expand All @@ -28,7 +29,8 @@ def ndarray_func(x: np.ndarray) -> np.ndarray:
np.dtype((np.int32, (np.int8, 4)))

# Iteration and indexing
def iterable_func(x: Iterable) -> Iterable:
def iterable_func(x):
# type: (Iterable) -> Iterable
return x
iterable_func(array)
[element for element in array]
Expand Down Expand Up @@ -122,8 +124,6 @@ def iterable_func(x: Iterable) -> Iterable:
1 | array
array |= 1

array @ array

# unary arithmetic
-array
+array
Expand Down
6 changes: 6 additions & 0 deletions tests/pass/simple_py3.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
import numpy as np

array = np.array([1, 2])

# The @ operator is not in python 2
array @ array
55 changes: 33 additions & 22 deletions tests/test_stubs.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,37 +3,48 @@
import pytest
from mypy import api

ROOT_DIR = os.path.dirname(os.path.dirname(__file__))
PASS_DIR = os.path.join(os.path.dirname(__file__), "pass")
FAIL_DIR = os.path.join(os.path.dirname(__file__), "fail")
REVEAL_DIR = os.path.join(os.path.dirname(__file__), "reveal")

os.environ['MYPYPATH'] = ROOT_DIR
TESTS_DIR = os.path.dirname(__file__)
PASS_DIR = os.path.join(TESTS_DIR, "pass")
FAIL_DIR = os.path.join(TESTS_DIR, "fail")
REVEAL_DIR = os.path.join(TESTS_DIR, "reveal")


def get_test_cases(directory):
for root, __, files in os.walk(directory):
for fname in files:
if os.path.splitext(fname)[-1] == ".py":
# yield relative path for nice py.test name
yield os.path.relpath(
os.path.join(root, fname), start=directory)


@pytest.mark.parametrize("path", get_test_cases(PASS_DIR))
def test_success(path):
stdout, stderr, exitcode = api.run([os.path.join(PASS_DIR, path)])
fullpath = os.path.join(root, fname)
# Use relative path for nice py.test name
relpath = os.path.relpath(fullpath, start=directory)
skip_py2 = fname.endswith("_py3.py")

for py_version_number in (2, 3):
if py_version_number == 2 and skip_py2:
continue
py2_arg = ['--py2'] if py_version_number == 2 else []

yield pytest.param(
fullpath,
py2_arg,
# Manually specify a name for the test
id="{} - python{}".format(relpath, py_version_number),
)


@pytest.mark.parametrize("path,py2_arg", get_test_cases(PASS_DIR))
def test_success(path, py2_arg):
stdout, stderr, exitcode = api.run([path] + py2_arg)
assert stdout == ''
assert exitcode == 0


@pytest.mark.parametrize("path", get_test_cases(FAIL_DIR))
def test_fail(path):
stdout, stderr, exitcode = api.run([os.path.join(FAIL_DIR, path)])
@pytest.mark.parametrize("path,py2_arg", get_test_cases(FAIL_DIR))
def test_fail(path, py2_arg):
stdout, stderr, exitcode = api.run([path] + py2_arg)

assert exitcode != 0

with open(os.path.join(FAIL_DIR, path)) as fin:
with open(path) as fin:
lines = fin.readlines()

errors = {}
Expand All @@ -59,11 +70,11 @@ def test_fail(path):
pytest.fail(f'Error {repr(errors[lineno])} not found')


@pytest.mark.parametrize("path", get_test_cases(REVEAL_DIR))
def test_reveal(path):
stdout, stderr, exitcode = api.run([os.path.join(REVEAL_DIR, path)])
@pytest.mark.parametrize("path,py2_arg", get_test_cases(REVEAL_DIR))
def test_reveal(path, py2_arg):
stdout, stderr, exitcode = api.run([path] + py2_arg)

with open(os.path.join(REVEAL_DIR, path)) as fin:
with open(path) as fin:
lines = fin.readlines()

for error_line in stdout.split("\n"):
Expand Down

0 comments on commit a4857d4

Please sign in to comment.