Skip to content

Commit

Permalink
Update the spec to account for call_dps_packed
Browse files Browse the repository at this point in the history
  • Loading branch information
slyubomirsky committed Mar 10, 2023
1 parent 9d1855f commit 5707a0d
Showing 1 changed file with 9 additions and 2 deletions.
11 changes: 9 additions & 2 deletions relax_spec.md
Original file line number Diff line number Diff line change
Expand Up @@ -788,7 +788,7 @@ These semantic rules assume a single thread of evaluation on a single host machi
The above evaluation rules are general, but leave much room for implementations of operators to specify custom semantics. Certain operators are used to perform common operations and will be discussed here as well.
- `call_tir(prim_func, args, packed_ints, sinfo_args=[aS1, aS2, ..., aSk])`:
- `prim_func` must be a `PrimFunc` object in the current `IRModule` (we will call it `f`).
- `prim_func` must be a `GlobalVar` that denotes a `PrimFunc` in the current `IRModule` (we will call it `f`).
- `args` must be an expression that evaluates to a tuple of tensor values (where each member of a tuple will be a tensor argument to the `PrimFunc`). Let us call the members of the tuple `arg1`, `arg2`, ..., `argn`.
- `packed_ints` is an optional argument. If present, it must be a shape value (with `ShapeStructInfo`). If present, we will call the dimensions of the value`shape1`, `shape2`, ..., `shapem` for convenience.
- The `StructInfo` arguments `aS1` through `aSk` give the `StructInfo` of the results of calling the `PrimFunc`.
Expand All @@ -798,6 +798,13 @@ The above evaluation rules are general, but leave much room for implementations
- `f` will be called in destination-passing style, like so: `f(arg1, arg2, ..., argn, shape1, shape2, ..., shapem, r1, r2, ..., rk)`, omitting the `shapei` if `packed_ints` is not given. `f` is expected to mutate *only* the `ri` to give the output of the function, hence `call_tir` is considered pure.
- «If the shape or data type of the actual result do not correspond to the `aSi`, an error is issued.»
- After the call, the `ri` will be returned (returning `r1` directly if there is only a single result, otherwise returning `Tuple(fields=[r1, r2, ..., rk])`).
- «`call_dps_packed(global_symbol, args, packed_ints, sinfo_args=[aS1, aS2, ..., aSk])`: Proceeds similarly to `call_tir`, except it calls a `PackedFunc` registered under the name `global_symbol` instead of a `PrimFunc` object. The `PackedFunc` may modify any member of `args` (`packed_ints`, if present, is immutable) in addition to the results, so purity is not assumed. The `StructInfo` for the result will be determined int he same manner as in `call_tir`, where it will be `aS1` if `sinfo_args` has a length of 1 and `TupleStructInfo(fields=[aS1, aS2, ..., aSk])` otherwise.»
- `call_dps_packed(packed_func, args, sinfo_args=[aS1])`:
- `packed_func` must evaluate to a `PackedFunc` object.
- `args` must be a tuple; we will call its elements `arg1`, `arg2`, ..., `argn`.
- The `StructInfo` argument `aS1` may be either a single `TensorStructInfo` (whose `shape` field _must_ be a `ShapeExpr`), which we will call `ts1`, or a `TupleStructInfo` whose fields are all `TensorStructInfo` (whose `shape` fields _must_ be `ShapeExpr`s), which we will call `ts1`, `ts2`, ..., `tsm`.
- Let `r1`, `r2`, ..., `rm` be newly allocated tensors whose shape match the `StructInfo` args `ts1`, `ts2`, ..., `tsm`, respectively.
- Evaluate `f(arg1, arg2, ..., argn, r1, r2, ..., rm)`.
- «If the shape or data type of the actual result do not correspond to the `tsi`, an error is issued.»
- Return `r1` if `aS1` is a single `TensorStructInfo`; otherwise, return `Tuple(fields=[r1, r2, ..., rm])`.
- `shape_of(t)`: Given a tensor argument `t`, it returns its shape. The return value is a shape object.
- `null_value()`: Returns a null object (treated as `ObjectStructInfo`). This is used for indicating to operators that an optional argument has been omitted.

0 comments on commit 5707a0d

Please sign in to comment.