Skip to content

Commit

Permalink
fix type hints
Browse files Browse the repository at this point in the history
  • Loading branch information
rabernat committed Jan 16, 2021
1 parent d493544 commit 83e1fc7
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 6 deletions.
15 changes: 12 additions & 3 deletions rechunker/executors/dask.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,12 @@ def _make_stage(stage: Stage) -> Delayed:
dsk = {(name, i): (stage.func, arg) for i, arg in enumerate(stage.map_args)}
# create a barrier
top_key = "stage-" + dask.base.tokenize(stage.func, stage.map_args)
dsk[top_key] = (lambda *args: None, *list(dsk))

def merge_all(*args):
# this function is dependent on its arguments but doesn't actually do anything
return None

dsk.update({top_key: (merge_all, *list(dsk))})
return Delayed(top_key, dsk)


Expand All @@ -62,9 +67,13 @@ def _merge_task(*args):

def _merge(*args: Iterable[Delayed]) -> Delayed:
name = "merge-" + dask.base.tokenize(*args)
keys = [arg.key for arg in args]
# mypy doesn't like arg.key
keys = [getattr(arg, "key") for arg in args]
new_task = (_merge_task, *keys)
graph = dask.base.merge(*[dask.utils.ensure_dict(d.dask) for d in args])
# mypy doesn't like arg.dask
graph = dask.base.merge(
*[dask.utils.ensure_dict(getattr(arg, "dask")) for arg in args]
)
graph[name] = new_task
d = Delayed(name, graph)
return d
Expand Down
7 changes: 4 additions & 3 deletions rechunker/executors/pipeline.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
import itertools
import math
from typing import Iterable, Iterator, Tuple, TypeVar
from typing import Iterable, Iterator, Tuple, TypeVar, Any

import dask
import numpy as np

from rechunker.types import (
CopySpec,
CopySpecExecutor,
MultiStagePipeline,
ParallelPipelines,
ReadableArray,
Expand Down Expand Up @@ -71,12 +72,12 @@ def specs_to_pipelines(specs: Iterable[CopySpec]) -> ParallelPipelines:
T = TypeVar("T")


class CopySpecToPipelinesMixin:
class CopySpecToPipelinesMixin(CopySpecExecutor):
# This signature doesn't work as a mixin because we don't know what type T is
def prepare_plan(self, specs: Iterable[CopySpec]) -> T:
pipelines = specs_to_pipelines(specs)
return self.pipelines_to_plan(pipelines)

def pipelines_to_plan(self, pipelines: ParallelPipelines) -> T:
def pipelines_to_plan(self, pipelines: ParallelPipelines) -> Any:
"""Transform ParallelPiplines to an execution plan"""
raise NotImplementedError
2 changes: 2 additions & 0 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -26,5 +26,7 @@ ignore_missing_imports = True
ignore_missing_imports = True
[mypy-pywren_ibm_cloud.*]
ignore_missing_imports = True
[mypy-numpy.*]
ignore_missing_imports = True
[mypy-zarr.*]
ignore_missing_imports = True

0 comments on commit 83e1fc7

Please sign in to comment.