This repository has been archived by the owner on May 22, 2023. It is now read-only.

[Discuss] emit_te sugar in TVMScript #197

Closed
psrivas2 opened this issue Aug 1, 2022 · 4 comments
@psrivas2
Contributor

psrivas2 commented Aug 1, 2022

TVMScript is very useful for creating Relax IRModules, especially for unit tests. However, since Relax does not yet have its own operator set, creating an IRModule that uses several well-known operators is a hassle in TVMScript.

In such cases, one would have to generate the IRModule (or parts of it) with the BlockBuilder, because its emit_te interface can generate the TIR implementation corresponding to a TE compute.

I propose we introduce emit_te sugar in TVMScript as well, which would internally find the relevant op strategy and generate the corresponding TIR PrimFunc. The signature could look something like the following.

lv = R.emit_te(<te_compute>, input_arg0, input_arg1, ..., attrs=<dictionary of attributes>)

example: lv = R.emit_te(topi.add, x, y, attrs={'my_op_kind': 'addition operation', ...})
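For intuition, the sugar would do what BlockBuilder.emit_te already does today: take a TE compute function plus inputs, lower it to a named TIR PrimFunc registered at module level, and bind the result of calling it to a fresh variable. Below is a toy, TVM-free Python sketch of that pattern; the names (ToyModule, emit_te) are purely illustrative stand-ins, not TVM APIs.

```python
# Illustrative sketch only: mimics the emit_te pattern (compute fn + args ->
# module-level lowered function + a variable bound to its result) without TVM.

class ToyModule:
    """Collects 'lowered' functions, like an IRModule collects PrimFuncs."""

    def __init__(self):
        self.funcs = {}    # name -> callable (stand-in for TIR PrimFuncs)
        self._counter = 0

    def emit_te(self, compute, *args, attrs=None):
        # 1. Derive a unique global name from the compute op, like emit_te does.
        name = f"{compute.__name__}{self._counter}"
        self._counter += 1
        # 2. "Lower" the compute into a module-level function.
        self.funcs[name] = compute
        # 3. Bind the result of calling it (stand-in for binding a call_tir).
        return compute(*args)


def add(a, b):
    return a + b


mod = ToyModule()
lv = mod.emit_te(add, 3, 4)   # analogous to: lv = R.emit_te(topi.add, x, y)
```

After this runs, `lv` holds the computed result and `mod.funcs` contains a uniquely named entry for `add`, mirroring how each `R.emit_te` call would add one generated PrimFunc to the IRModule.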

This will allow us to replace test cases like the one below with TVMScript, which is easier to read.

Using BlockBuilder:

def test_fuse_simple():
    def before():
        bb = relax.BlockBuilder()
        x = relax.Var("x", [10, 20], relax.DynTensorType(2, "float32"))
        with bb.function("main", [x]):
            with bb.dataflow():
                lv0 = bb.emit_te(topi.add, x, relax.const(1, "float32"))
                lv1 = bb.emit_te(topi.exp, lv0)
                gv = bb.emit_output(bb.call_te(topi.squeeze, lv1))
            bb.emit_func_output(gv)

        return bb.get()

    def expected():
        bb = relax.BlockBuilder()
        x = relax.Var("x", [10, 20], relax.DynTensorType(2, "float32"))
        p0 = relax.Var("p0", (), relax.DynTensorType(0, "float32"))

        with bb.function("fused_add_exp_squeeze", [x, p0], attrs={"Primitive": 1}):
            with bb.dataflow():
                lv0 = bb.emit_te(topi.add, x, p0)
                lv1 = bb.emit_te(topi.exp, lv0)
                gv = bb.emit_output(bb.call_te(topi.squeeze, lv1))
            bb.emit_func_output(gv)
        fused_add_exp_squeeze = bb.get().get_global_var("fused_add_exp_squeeze")

        x = relax.Var("x", [10, 20], relax.DynTensorType(2, "float32"))
        with bb.function("main", [x]):
            with bb.dataflow():
                gv = bb.emit_output(
                    relax.Call(fused_add_exp_squeeze, [x, relax.const(1, "float32")])
                )
            bb.emit_func_output(gv)

        return bb.get()

    _check(before(), expected())

Using TVMScript with emit_te sugar:

def test_fuse_simple():
    @tvm.script.ir_module
    class Before:    
        @R.function
        def main(x: Tensor((10, 20), "float32")) -> Tensor(None, "float32", ndim = 2):
            # block 0
            with R.dataflow():
                lv0 = R.emit_te(topi.add, x, relax.const(1, "float32"))
                lv1 = R.emit_te(topi.exp, lv0)
                gv = R.emit_te(topi.squeeze, lv1)
                R.output(gv)
            return gv
        
    
    @tvm.script.ir_module
    class Expected:
        @R.function
        def fused_add_exp_squeeze(x: Tensor((10, 20), "float32"), p0: Tensor((), "float32")) -> Tensor(None, "float32", ndim = 2):
            with R.dataflow():
                lv0 = R.emit_te(topi.add, x, relax.const(1, "float32"))
                lv1 = R.emit_te(topi.exp, lv0)
                gv = R.emit_te(topi.squeeze, lv1)
                R.output(gv)
            return gv
        @R.function
        def main(x1: Tensor((10, 20), "float32")) -> Tensor(None, "float32", ndim = 2):
            with R.dataflow():
                gv1: Tensor((10, 20), "float32") = fused_add_exp_squeeze(x1, 1)
                R.output(gv1)
            return gv1

    _check(Before, Expected)
@tqchen
Contributor

tqchen commented Aug 1, 2022

This seems to be related to meta-programming, cc @yelite @cyx-6 @junrushao1994

@Hzfengsy
Member

Hzfengsy commented Aug 3, 2022

I totally agree it is useful. It can be done when the new parser is ready :)

@psrivas2
Contributor Author

psrivas2 commented Aug 3, 2022

Great, looking forward to it!
I'll keep the issue open just for tracking purposes.

@yongwww
Collaborator

yongwww commented Feb 27, 2023

Support was merged in tvm/unity: apache/tvm#14123

@yongwww yongwww closed this as completed Feb 27, 2023