Skip to content

Commit

Permalink
[TVMScript] Allow T.target("device", host="host") to specify host (ap…
Browse files Browse the repository at this point in the history
…ache#14915)

[TVMScript] Allow T.target("device", host="host") in TVMScript

Prior to this commit, the `TargetNode::host` could be specified in
TVMScript as part of the config dictionary, under the key `"host"`.
However, this required all other device parameters to be explicitly
specified, rather than using any of the short-hand string
representations.  This commit forwards the `host` argument from TVMScript's
`T.target` method to `tvm.target.Target`, allowing both the device and
host to be specified using the shorthand string representation.

```python
@T.prim_func
def before_this_commit():
    T.func_attr(
        {
            "target": T.target(
                {
                    "arch": "sm_86",
                    "host": {"keys": ["cpu"], "kind": "llvm", "tag": ""},
                    "keys": ["cuda", "gpu"],
                    "kind": "cuda",
                    "max_num_threads": 1024,
                    "tag": "",
                    "thread_warp_size": 32,
                }
            )
        }
    )
    T.evaluate(0)

@T.prim_func
def after_this_commit():
    T.func_attr({"target": T.target("cuda", host="llvm")})
    T.evaluate(0)
```
  • Loading branch information
Lunderberg authored and mei-ye committed Jun 1, 2023
1 parent a59cfbc commit b7afe10
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 2 deletions.
22 changes: 20 additions & 2 deletions python/tvm/script/ir_builder/tir/ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -1655,7 +1655,10 @@ def index_map(
return IndexMap.from_func(mapping, inverse_index_map=inverse_index_map)


def target(target_config: Union[Dict, str]) -> Target:
def target(
target_config: Union[Dict, str],
host: Optional[Union[Dict, str, Target]] = None,
) -> Target:
"""
Create a target
Expand All @@ -1664,6 +1667,9 @@ def target(target_config: Union[Dict, str]) -> Target:
target_config : Union[Dict, str]
The target configuration.
host : Optional[Union[Dict, str, Target]]
The target configuration.
Returns
-------
res : Target
Expand All @@ -1673,7 +1679,19 @@ def target(target_config: Union[Dict, str]) -> Target:
raise ValueError(
f"T.target expected a config dict or string, but got {type(target_config)}"
)
return Target(target_config)
if host is not None and not isinstance(host, (str, dict, Target)):
raise ValueError(
"T.target expected the host to be "
"a config dict, string, or T.target, "
f"but got {type(host)}"
)
if isinstance(target_config, dict) and "host" in target_config and host is not None:
raise ValueError(
"T.target expects to either receive the host "
"as part of the target's config dictionary, "
"or as a separate argument, but not both."
)
return Target(target_config, host)


def Range(begin: PrimExpr, end: PrimExpr) -> ir.Range: # pylint: disable=invalid-name
Expand Down
10 changes: 10 additions & 0 deletions tests/python/unittest/test_tvmscript_roundtrip.py
Original file line number Diff line number Diff line change
Expand Up @@ -3123,6 +3123,15 @@ def func_with_target_spec_by_str() -> None:
return func_with_target_spec_by_str


def func_with_target_and_host_spec_by_str():
@T.prim_func
def func():
T.func_attr({"target": T.target("nvidia/nvidia-a100", host="llvm")})
T.evaluate(0)

return func


def func_root_attr():
@T.prim_func
def func_root_attr():
Expand Down Expand Up @@ -3883,6 +3892,7 @@ def func():
nontrivial_range_axis,
func_with_target_spec_by_config,
func_with_target_spec_by_str,
func_with_target_and_host_spec_by_str,
func_root_attr,
func_trivial_root_block,
func_nested_root_block,
Expand Down

0 comments on commit b7afe10

Please sign in to comment.