From 4eab6fcf4edda9522fe92d95ca81b8f7eb3112d8 Mon Sep 17 00:00:00 2001 From: Steven Lyubomirsky Date: Thu, 20 Oct 2022 19:32:32 -0400 Subject: [PATCH 01/30] Initial spec --- relax_spec.md | 749 ++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 749 insertions(+) create mode 100644 relax_spec.md diff --git a/relax_spec.md b/relax_spec.md new file mode 100644 index 0000000000..40f69e4ba9 --- /dev/null +++ b/relax_spec.md @@ -0,0 +1,749 @@ +# Informal Relax Language Specification + +Note: Text in «double chevrons» indicates features not present in the current prototype. + +In order to develop and test Relax, it is important for compiler developers to agree on what a given program in Relax means and what makes it valid so that test cases can be evaluated independently of any particular Relax implementation. This document is intended to describe Relax's grammar constructs (its [abstract syntax tree](https://en.wikipedia.org/wiki/Abstract_syntax_tree), or AST), the semantics of its grammar (what the different constructs mean), Relax's type system and type-checking rules (what makes a Relax program valid), and its rules for reasoning about tensor shapes in detailed though still informal terms. If necessary, we may encode these rules more formally to allow for more automated analysis. + +Though this document will use the TVMScript front end for some examples, specifying the mapping from Python's AST to Relax's AST will be deferred until the parser becomes more stable. + +# Table of Contents + +1. [Overview](#overview) +2. [Top-Level Program Organization](#top-level-program-organization-irmodule) +3. [Values in Relax](#values-in-relax) +4. [Variable Scoping](#variable-scoping) +5. [Well-Formedness Criteria](#well-formedness-criteria) +6. [Types in Relax](#types-in-relax) +7. [Shapes in Relax](#shapes-in-relax) +8. [Semantics](#detailed-semantics) + +# Overview + +This section will outline the grammar of Relax and give very brief descriptions of the different components, including the semantics, type system, and shape system. The rest of this document will provide more detailed descriptions of these facets of the language, including the validity conditions that the type system and shape system uphold. + +## Grammar + +Below is a diagram of the various AST constructs in Relax, including types. In code, these are defined on the C++ side in `include/tvm/relax/{expr.h, type.h}` and in Python in `python/tvm/relax/{expr.py, ty.py}`. This diagram will give the names of the AST nodes and the types and names of their members. The semantics will describe what computation each construct represents; an AST is simply data. A Relax program consists of an `IRModule` with global variables bound to Relax functions that implement the computations of interest. + +(On the notation: `[x]` means "a list of `x`," `x?` means "optionally `x`," `{x: y}` means "a map of `x` to `y`," `x | y` means "`x` or `y`," and `#` is used for comments.) + +``` +# PrimExprs are defined in TIR, see include/tvm/tir/expr.h +# They are intended to have the same semantics as in TIR +PrimExpr ::= + Var(name: string) # shape variables + | IntImm(value: int64) + | Add(a: PrimExpr, b: PrimExpr) + | Sub(a: PrimExpr, b: PrimExpr) + | Mul(a: PrimExpr, b: PrimExpr) + | Div(a: PrimExpr, b: PrimExpr) + | Min(a: PrimExpr, b: PrimExpr) + | Max(a: PrimExpr, b: PrimExpr) + | Not(a: PrimExpr) + | And(a: PrimExpr, b: PrimExpr) + | Or(a: PrimExpr, b: PrimExpr) + | Select(condition: PrimExpr, true_value: PrimExpr, false_value: PrimExpr) + # (others may be added later, as deemed necessary) + +Type ::= DynTensorType(ndim: int, dtype: DataType) + | ShapeType() + | ObjectType() + | TupleType(fields: [Type]) + | FuncType(arg_types: [Type], ret_type: Type, «pure: bool») + +# expressions +Expr ::= Constant(data: NDArray) + # scoped to functions or SeqExprs + | Var(name_hint: string) + # scoped to DataflowBlocks + | DataflowVar(name_hint: string) + | GlobalVar(name_hint: string) + | Tuple(fields: [Expr]) + | SeqExpr(blocks: [BindingBlock], body: Expr) + | Function(params: [Var], body: Expr, ret_type: Type?, attrs: Attrs?) + | If(cond: Expr, true_branch: Expr, false_branch: Expr) + | ExternFunc(global_symbol: string) + | Call(op: Expr, args: [Expr], type_args: [Type], attrs: Attrs?) + | ShapeExpr(values: [PrimExpr]) + | TupleGetItem(tuple_value: Expr, index: int) + | Op(op_name: string) + | RuntimeDepShape() + +# binding blocks (analogous to sequence of statements) +BindingBlock ::= + BindingBlock(bindings: [Binding]) + | DataflowBlock(bindings: [Binding]) + +# bindings (analogous to statements) +Binding ::= + VarBinding(var: Var|DataflowVar, value: Expr) + | MatchShape(var: (Var|DataflowVar)?, pattern: [PrimExpr], value: Expr) + +# Relax programs are IRModules. Modules may bind global variables either to +# Relax functions or TIR PrimFuncs (specified separately). +# The Relax compiler may analyze and modify the TIR PrimFUncs as well. +Program ::= IRModule(funcs: {GlobalVar: Function|PrimFunc}) +``` + +## Expression Survey + +This specification provides a more detailed description of what each expression and type represents and what conditions make them valid. To motivate and provide more context for the full specification later in this document, this section will briefly summarize the purpose of each node. + +1. `Constant` nodes construct tensor constants (n-dimensional arrays of scalars). +2. `Tuple` nodes construct a tuple (fixed-size ordered grouping) of Relax values. +3. `Var`, `DataflowVar`, and `GlobalVar` nodes are all variables, referring to named stored values of different kinds. Variables in Relax must be bound exactly once. `GlobalVar`s are bound in the `IRModule` itself and refer to Relax functions or TIR `PrimFunc`s. `Var` nodes are bound either within functions, where they represent function parameters, or in `VarBinding` or `MatchShape` nodes in `BindingBlock`s, as we will discuss below. `DataflowVar`s are similar to `Var`s and can be bound only within `DataflowBlock`s. +4. `PrimExpr`s are used to represent dimensions of shapes in `ShapeExpr` and `MatchShape` nodes. These represent operations on integers with their own `Var` nodes (`tir::Var`), which we will refer to as "shape variables". Shape variables can only be used in other `PrimExpr`s and are scoped like `Var` nodes (`relax::Var`), which we will call "Relax variables." +5. `Call` nodes represent function calls. The callee argument (the `op`) can be an `ExternFunc` node (representing a call to a `PackedFunc`), an `Op` node (representing a call to a Relax operator), or an arbitrary expression. + 1. For `ExternFunc` nodes, the call will look up the registered `PackedFunc` by its global symbol and will call it with the given arguments (note that a TIR `PrimFunc` can be compiled into a `PackedFunc` and called using `ExternFunc` by defining a `global_symbol` attribute in the `PrimFunc`). «The attribute "pure" can be specified on a call to an `ExternFunc` to indicate that it performs no side effects (for use inside `DataflowBlock`s).» + 2. `Op` nodes refer to built-in Relax operators, which the compiler is free to implement as is deemed appropriate. Certain operators implement important operations, like `call_tir` (allows for calling TIR `PrimFunc`s) «and `cast` (performs dynamic type conversions).» + 3. Any other expression must evaluate to a closure; the closure will then be called with the given arguments. + + Calls to `ExternFunc`s and operators may perform side effects, hence it is important to reason about whether a function call is permitted inside a `DataflowBlock`. + +6. `If` nodes represent branching control flow. First the condition expression is evaluated, and it must evaluate to a Boolean scalar. If the condition is true, the true branch is evaluated and its result is used; otherwise, the false branch is evaluated and its result is used. +7. `TupleGetItem` nodes represent tuple indexing. The `tuple_value` expression must evaluate to a tuple with at least `index + 1` items and the item with the given index will be returned. +8. `SeqExpr` describes a sequence of binding blocks followed by a return expression. The `SeqExpr` opens a new scope. Its binding blocks are evaluated in order and add new variables to the scope. Binding blocks are either ordinary `BindingBlock`s or `DataflowBlock`s and both consist of a series of bindings. `DataflowBlock`s are the only kind allowed to introduce bindings with `DataflowVar`s and it does not permit any constructs featuring control flow (`If` nodes or recursive calls) or calls to (possibly) impure functions. There are two different kinds of bindings: + 1. `VarBinding`s: The `value` expression (the right-hand side of the binding) of the binding is evaluated first and is bound to the `var` expression, which must be a new `Var` or `DataflowVar` (in a dataflow block). The newly bound variable will have that value for the remainder of the scope (`DataflowVar`s are scoped only to the `DataflowBlock` in which they appear; `Var`s are scoped to the entire `SeqExpr`). + 2. `MatchShape`s: The `value` expression is evaluated and the resulting shape is dynamically checked against the shape denoted by the `PrimExpr`s in the `pattern` field. + 1. If `value` evaluates to a tensor value, the pattern will be checked against the shape of the tensor; if it evaluates to a shape value, the pattern will be checked directly against the shape. + 2. Any shape dimension in the pattern that consists of a single new shape variable is treated as a binding: The variable is bound to the size of the corresponding dimension of the value being matched. + 3. If the shapes do not match, an error is triggered. If there is a variable provided, the value is bound to the `var` expression (if the variable is omitted, the shape check is performed and any shape variables are updated, but no new binding is introduced). Shape variables introduced in a `SeqExpr` are similarly scoped to the `SeqExpr`. + + The `SeqExpr`'s `body` expression is allowed to reference any `Var`s introduced within the `SeqExpr`'s binding blocks in addition to those that were in the outer scope; the `body` expression is evaluated after the binding blocks and its value is what is returned. Any Relax variables and shape variables introduced in the `SeqExpr` are removed from scope after the expression finishes evaluating. + +9. `ShapeExpr` nodes construct shape literals. The `PrimExpr`s within it describe how to compute each dimension; they are free to use any shape variables that are in scope. +10. `Function` nodes represent function definitions, taking in the listed parameters and evaluating the body expression in a new scope (meaning any variables defined from within the function cannot be referenced outside it). Function definitions may be nested in any other expression and they evaluate into closure values, ensuring that functions are first-class. Closures capture any variables from the outer scope that are used in their body, both Relax variables and shape variables. Note that function definitions themselves are anonymous—a function must be registered in the `IRModule` (bound to a `GlobalVar`) or appear on the right-hand side of a binding to have a name in order to be called recursively. + + The function can have shape annotations on the parameters and a return shape parameter. When the function is called, the annotations on parameters checked against the argument values in similar fashion to `MatchShape` and can introduce new shape variables that are scoped to the function. + + «A function mapped bound to a `GlobalVar` can have a `global_symbol` attribute defined to indicate that it should be externally linked externally (be accessible outside the `IRModule`). The absence of a `global_symbol` attribute on a function definition bound to a `GlobalVar` indicates that it is "private" and hence can be called only within the `IRModule`.» + +11. `RuntimeDepShape` nodes are used to denote that a shape is unknown at compile time and must be deduced at run time. These nodes may appear only in shape annotations and have no run-time semantics of their own. + +## Purity and Dataflow Blocks + +A function or operator is called "pure" if it does not have side effects, which refers to any change in program state besides returning a result. Side effects include mutating values other than those they create, aborting the program, or file I/O (including writing to the console). Purity is a useful property for compiler optimizations, since calls to pure functions can be reordered or duplicated or (if the result is unused) eliminated without changing any other program behavior. Most deep learning operators are pure, as they perform arithmetic on tensors and return a new tensor containing the result. «In Relax, we conservatively assume that any function that calls an impure function is itself impure, though the attribute `force_pure` on a function can be used as an override (e.g., if a function creates a new tensor, mutates it, and returns it, that is still pure but does not satisfy the conservative rule).» + +Above, it is mentioned that `DataflowBlock`s are not allowed to contain constructs featuring control flow (`If` nodes or recursive calls to the current function) or calls to impure functions. This ensures that `DataflowBlock`s represent a directed acyclic graph of pure operations, which is similar to the graph-like abstractions of traditional deep learning frameworks. This allows many common optimizations from past frameworks to be directly adapted to `DataflowBlock`s without having to accommodate additional reasoning about more expressive features like control flow and side effects. + +There is one visible side effect that Relax permits inside otherwise "pure" functions, namely exiting the program with an error. This can arise in the following cases: + +- Shape matching errors (from `MatchShape` or from implicit shape checks upon calling a Relax function) +- Errors raised by otherwise pure Relax operators or `PackedFunc`s, such as in `cast` (which dynamically checks types). Since the purity of operators or `PackedFunc`s must be manually registered, this means that it is permissible to register an operator or `PackedFunc` as being pure if its only side effect is issuing an error in some cases. + +Even though an abnormal program exit is a visible side effect and removing or reordering it changes the observable semantics, it would be too great a restriction to prohibit error checking inside `DataflowBlock`s. Relax does not have any notion of exception handling, so the only consequence of a failed safety check can be exiting the program. It is permissible for the compiler to reorder, duplicate, or eliminate `MatchShape`, `cast`, or otherwise pure operations that have the potential of failing, provided that doing so does not change the value returned by the program or any other visible behavior. + +To indicate that an operator or `PackedFunc` that can abort with an error should *never* be reordered or removed by the compiler, it should *not* be marked as pure. However, this means that it cannot be used inside a `DataflowBlock`. + +Note that in some programming languages like Koka, non-termination is also considered a side effect, since it can in some sense be "observed" by a user and affects the visible behavior of a program (e.g., if there is an infinite loop before a print statement, the print will never happen). However, since non-termination cannot be automatically detected in general and is unlikely to arise in deep learning models, we do not attempt to systematically track non-termination in Relax. In general, the Relax compiler is allowed to reorder or remove otherwise pure function calls even if they may not terminate. For example, if a pure function `f` that returns an integer scalar does not terminate, it is permissible in principle to rewrite `f() - f()` to 0. + +Exiting with an error and infinitely looping are traditionally considered "[divergence](https://en.wikipedia.org/wiki/Divergence_(computer_science))" in the programming languages literature. As a general principle, Relax's compiler is permitted to turn a program that diverges into a program that does not diverge (provided that no other visible effects change) so long as it never transforms a program that does not diverge into one that diverges. + +## Type System Survey + +The types in Relax correspond to the broad categories of the values given above: + +1. `DynTensorType` corresponds to tensor values, giving the scalar data type and the number of dimensions (rank), both of which are optional. +2. `TupleType` corresponds to tuple values, giving the type of each member of the tuple. +3. `ShapeType` corresponds to shape values. +4. `FunctionType` corresponds to function values (closures), giving the types of the parameters, the return type, «and whether the function is pure.» +5. `ObjectType` is the parent type of all Relax types and corresponds to all the values above as well as any values returned by `PackedFunc` calls that do not fit in the above categories. + +The type checking rules assign types to every variable in scope and every type of expression based on the values it returns, making use of subtyping to assign more general types when a more specific one cannot be determined. «Relax is strongly typed, meaning that if a type encountered is less specific than the one expected, an error will be issued and an explicit cast (via the `cast` operator) will be required.» + +## Shape System Survey + +In Relax, tensor shapes are not handled in the type system; each expression instead a has an associated shape expression. In many cases, these shape computations can allow for statically concluding that two shapes are the same and thus eliminate the need for dynamic checks via `MatchShape`. However, when shapes cannot be statically concluded to be the same, it may be necessary for there to be dynamic checks. The compiler is also free to make use of shape expressions for memory planning purposes. «Relax is "strongly shaped," meaning that if the compiler cannot conclude that shapes match in certain cases, an error will be issued and an explicit `MatchShape` will be required.» + +--- + +# Top-level Program Organization: `IRModule` + +As with Relay, the top level of organization for a Relax program is an `IRModule`. An `IRModule` contains mappings of global variables to functions, both Relax functions as well as TIR functions (which can be called from Relax). The global function called `main` is usually considered the entry point to the program (meaning that execution starts by calling that function), though any function with a `global_symbol` attribute can be specified as the entry point during compilation. In the AST (see below), the names of Relax functions in the `IRModule`s are `GlobalVar` nodes. + +Oftentimes, compiler passes operate only on particular functions or add new functions to the `IRModule`, but a pass can operate over the entirety of a Relax program by iterating through all the functions in an `IRModule`. + +# Values in Relax + +Here are the classes of values that Relax operates over, meaning that they can be assigned to variables or be the result of evaluating expressions. + +- *Tensors* are n-dimensional arrays of scalar values (which can be integers of fixed bitwidths, floats of fixed bitwidths, or bools). A tensor's *shape* is a tuple of the size of each dimension; the number of dimensions is a tensor's *rank*. For example, a vector (1, 2, 3) is a rank-1 tensor of shape `(3,)`. Note that scalars are tensor values with a rank of 0, meaning that their shape is `()`. +- *Tuples* represent a fixed-size grouping of other Relax values (tensors, closures, shapes, objects, or other tuples, to an arbitrary degree of nesting). Note that an empty tuple, i.e., `()`, also called "unit" in functional programming, is commonly used as the return value for operations return no value (as may be the case in some `PackedFunc` or operator calls that have side effects). +- *Closures* are the values resulting from evaluating Relax function expressions; closures can be passed around like other values, ensuring that functions are first-class in Relax. Functions defined in Relax can capture variables from outer scopes. A [closure](https://en.wikipedia.org/wiki/Closure_(computer_programming)) consists of a function and a mapping of any variables "captured" (those are *free variables* in the function body, variables from an outer scope that are neither arguments nor defined within the function but are used in the function) to their values. Closures capture both Relax-level local variables and shape variables from outer scopes. A closure also stores a name for itself when the body contains recursive calls. «Closures additionally carry some *run-time type information* (RTTI) indicating their argument types and result type, in order to facilitate dynamic type checks (since it is not otherwise possible to introspect the function contained within a closure); the precise form of the RTTI is left up to the compiler implementation to determine so long as the `cast` operator can verify the type of a closure. Closures can be evaluated in a call node, which results in calling the function with the call's arguments and the captured values.» +- *Tensor shapes* (shape values) are tuples of integers describing a tensor shape, obtained by evaluating `ShapeExpr`s. +- Additionally, there are further *arbitrary objects* that do not belong in the above categories. These can be returned by `PackedFunc`s and operators; additionally, we treat TIR `PrimFunc`s as opaque objects. Though Relax expressions other than `PackedFunc` and operator calls cannot use those objects, Relax should pass around these values faithfully. In the future we may add more value types in order to distinguish between different objects, but at present we treat these all as arbitrary values of type `ObjectType`. + +## Representation of Values at Run Time + +Because Relax supports calls to arbitrary `PackedFunc`s that can operate on a low level, it is necessary to define a convention for how values will be represented at run time. At this time, the specification does not require any specific representation and permits compiler implementations to choose their own representations, provided that each value type listed above can be recognized at run time (for dynamic type checks). This means that Relax programs that call `PackedFunc`s directly are not portable across compiler implementations: The `PackedFunc`s used must be able to operate on the run-time representations of values. + +Possible specification in terms of the TVM object system: + +- Tensors are represented at run time as `NDArray`s (see `include/tvm/NDArray.h`). +- Tuples are represented using TVM ADTs (algebraic data types), which are arrays of TVM objects with a tag (see `include/tvm/runtime/container/adt.h`). Tuples use a tag of 0. +- At run time, closures are represented as a `ClosureObj` (see `include/tvm/runtime/container/closure.h`); in the Relax VM these more specifically use the `VMClosureObj` (see [`https://github.com/tlc-pack/relax/blob/relax/include/tvm/runtime/relax_vm/executable.h`](https://github.com/tlc-pack/relax/blob/relax/include/tvm/runtime/relax_vm/executable.h)). +- Shape values are represented at run time as a `ShapeTuple` (see `include/tvm/runtime/container/shape_tuple.h`). +- We require objects other than the above values used by and returned by `PackedFunc` to inherit from TVM's `Object` class (defined in `include/tvm/runtime/Object.h`). Note that `PackedFunc`s are capable of using and returning all TVM POD (plain-old data) values (see `include/tvm/runtimes/packed_func.h`), which includes some representations that do not inherit from `Object`. In the future, we may define semantics for other values, but at present, these are *unsupported* in Relax and we make no guarantees about the semantics of calling `PackedFunc`s that use or return anything that does not inherit from `Object`. + +# Variable Scoping + +There are four relevant scopes in Relax, which determine where variables are visible and can be used: + +1. Global: `GlobalVar`s can be referenced from any function in the `IRModule`, whether a Relax function or a TIR `PrimFunc`. All global functions are visible to each other and to themselves, allowing for mutual recursion. +2. Function: The parameters to a function (ordinary `Var` nodes) can be referenced from anywhere in that function. In a recursive binding (a `Binding` node where the RHS is a `Function` node or `GlobalVar` being mapped to a function at the `IRModule` level), the variable being bound is also scoped to that function, allowing for defining a recursive function. +3. `SeqExpr`: `Var` nodes defined in a `BindingBlock` in a `SeqExpr` node can be referenced in any later binding within the same `BindingBlock`, in any binding within any later `BindingBlock` in that `SeqExpr` node, or in the `SeqExpr`'s body expression. The variables defined in the `BindingBlock`s leave scope once the `SeqExpr` returns. +4. `DataflowBlock`: `DataflowVar`s introduced in a `DataflowBlock` can be referenced in any later binding within that `DataflowBlock`, but leave scope *once that `DataflowBlock` finishes executing*. Definitions in a `DataflowBlock` that are intended to leave the `DataflowBlock` should be bound to an ordinary `Var`. + +# Well-Formedness Criteria + +Prior to type-checking and shape inference, Relax programs must conform to certain syntactic criteria to be valid. + +1. `DataflowVar`s can be bound only inside `DataflowBlock`s. Additionally, a `DataflowVar` may not be used outside of the `DataflowBlock` in which it is defined. +2. A `Var` of any kind used in the program must be either a function parameter or appear on the LHS of a binding exactly once. In the binding where a `Var` is defined, the same `Var` is permitted to occur in the RHS of the binding only if the binding is defining a function (i.e., local functions are permitted to be recursive). +3. A `Var` of any kind may not appear before it is bound. Namely, if a `Var` is bound in a `BindingBlock` in a `SeqExpr`, that `Var` may not appear in bindings that precede the one where it appears on the LHS. +4. «A return shape annotation for a function is not allowed to use any shape variables that are not in scope at the function definition. That is, the only shape variables that can appear on the return shape annotation are those defined in the outer scope or those introduced in the argument shape annotations.» +5. In each function, `PrimExpr` variables (shape variables) similarly may not appear in `ShapeExpr`s or shape annotations before the shape variables are bound (either in function signatures or `MatchShape` bindings). A shape variable is bound only when it appears in a dimension by itself (for example, a dimension consisting of `x` will bind `x`; however, `2*x` is not a binding and is considered an error if `x` has not yet been bound) in a `MatchShape` node or a function argument shape annotation. +6. The following constructs are not permitted to occur inside `DataflowBlock`s, which must be side effect– and control flow–free: + 1. Recursive calls to the current function + 2. Calls to a global function that is mutually recursive with the current function + 3. `If` nodes + + «Calls to Relax functions, `ExternFuncs`, or `Op`s that are not pure are also not permitted, but this must be detected during type checking.» + +7. «For functions that contain recursive calls to themselves or mutually recursive global functions (i.e., those where function `a` calls function `b` and function `b` calls function `a`), a return type annotation is *required*. [TODO: Do we also require a return shape annotation in such cases?]» +8. `Op` nodes may appear only as the `op` argument to `Call` nodes. +9. `ExternFunc` expressions may appear only as the `op` argument to `Call` nodes. +10. The `type_args` field is used only for type checking calls of `ExternFunc`s and `Op`s. Calls to `ExternFunc`s must have exactly one type argument, indicating the return type. Calls to `Op`s may use `type_args` as they wish. No other calls may have a non-empty `type_args`. +11. If a variable has both a type annotation and a shape annotation, the `ndim` of any `DynTensorType`s must match the number of dimensions in the corresponding shape annotation. +12. A function definition inside a `DataflowBlock` may not use `DataflowVar`s from the outer scope in its body. We do not define closure capturing for `DataflowVar`s. +13. «At least one global function in the `IRModule` must be externally linked (have a `global_symbol` attribute) in order to serve as a program entry point.» +14. «If a global function has a defined `global_symbol` attribute, the `global_symbol` name must be the same as the `GlobalVar`'s name hint.» +15. «Any `PackedFunc` or operator called in a shape annotation or `shape_` expression must be pure and be annotated as such.» +16. The node `RuntimeDepShape` may appear only in shape annotations and `shape_` expressions. It has no defined semantics at run time. + +# Types in Relax + +Relax presently has five types, defined in the implementation in `python/tvm/relax/ty.py` and `include/tvm/relax/type.h`: + +1. `DynTensorType`, referring to tensor values (referred to in the front-end as `Tensor`). In Relax, tensor types keep track of the rank (number of dimensions) in the `ndim` field and the data type of the tensor data in the `dtype` field. Both the rank and data type are optional: Using -1 for `ndim` indicates that the tensor is of unknown rank and using `DataType::Void()` for `dtype` indicates that it's of unknown data type. +2. `ShapeType`, referring to shape values. +3. `FuncType`, referring to functions (closures). `FuncType`s specify the types of their parameters, a return type, and whether the function is pure. +4. `TupleType`, referring to tuple values, giving the types of their fields. +5. `ObjectType`, referring to any Relax value, including values used and returned by `PackedFunc` or operator calls that do not belong in any of the above categories. + +## Subtyping + +Relax implements subtyping, which means that members of types can be accepted where members of their supertypes are accepted. We will denote the subtyping relationship as `T1 <: T2`, indicating that `T1` is a subtype of `T2`. For example. if `T1 <: T2` and some function expects an argument of type `T2`, then passing a member of type `T1` to that function is permitted; passing a member of type `T2` as an argument to a function that expects type `T1` for that argument is *not* permitted—the value would have to be dynamically cast to `T1` using the `cast` operator. + +### Rules for Subtyping + +1. Reflexivity: For all types `T`, `T <: T`. +2. Transitivity: For all types `T1`, `T2`, and `T3`, if `T1 <: T2` and `T2 <: T3`, then `T1 <: T3`. +3. For all types `T`, `T <: ObjectType`. Hence, `ObjectType` is a supertype to all Relax types (all values in Relax are members of `ObjectType`). +4. Rules for `DynTensorType`: + 1. For all fixed `ndim` values `m`, where `m` ≥ 0, and `dtype`s `d`, `DynTensorType(ndim=m, dtype=d) <: DynTensorType(ndim=m, dtype=Void)`. + 2. For all fixed `ndim` values `m` and `dtype`s `d` that are not `Void`, `DynTensorType(ndim=m, dtype=d) <: DynTensorType(ndim=-1, dtype=d)`. + 3. Corollary: `DynTensorType(ndim=-1, dtype=Void)` is a supertype to all tensor types, since it refers to any possible tensor value. +5. Suppose we have types `T1 <: T1'`, `T2 <: T2'`, …, `Tn <: Tn'`. Then `TupleType(fields=[T1, T2, ..., Tn]) <: TupleType(fields=[T1', T2', ..., Tn'])`. +6. Rules for `FuncType`: + 1. Impure functions are supertypes to pure functions. Namely, if we have types `T1`, `T2`, …, `Tn` and `Tr`, then `FuncType(arg_types=[T1, T2, ..., Tn], ret_type=Tr, pure=True) <: FuncType(arg_types=[T1, T2, ..., Tn], ret_type=Tr, pure=False)`. + 2. Suppose we have types `T1' <: T1`, `T2' <: T2`, …, `Tn' <: Tn` and `Tr <: Tr'`. Then `FuncType(arg_types=[T1, T2, ... Tn], ret_type=Tr, pure=p) <: FuncType(arg_types=[T1', T2', ..., Tn'], ret_type=Tr', pure=p)`. Note the direction of the subtyping relationships for the argument and return types: We must be able to *call* this function with the *same* arguments and *use the returned value* wherever it is accepted—hence a function that takes more general arguments and returns a more specific return value can be used in place of the original. + +These rules allow us to define the least upper bound (LUB) for any two types `T1` and `T2`, meaning that it is the most specific type `T` for which `T1 <: T` and `T2 <: T` ("most specific" meaning that if there exists some other `T'` for which `T1 <: T'` and `T2 <: T'`, then `T <: T'`). The LUB is guaranteed to exist for any two types because `Object` is a supertype to all types. + +Note that the rule for obtaining the LUB of function types relies on the counterpart to the LUB, the greatest lower bound (GLB). The GLB is not guaranteed to exist for any two types in Relax, as there is no single type that is a subtype of all others. + +We can give an algorithm for determining the LUB and GLB for two types, in pseudocode: + +```python +def find_glb(T1 : Type, T2 : Type) -> Type?: + if T1 == T2: # syntactic equality + return T2 + if T1 is ObjectType: + return T2 + if T2 is ObjectType: + return T1 + if T1 and T2 are not both DynTensorType, not both TupleType, not both FuncType, or not both ShapeType: + return None + if T1 and T2 are both DynTensorType: + ret_ndim = T1.ndim + ret_dtype = T1.dtype + if ret_ndim == -1: + ret_ndim == T2.ndim + if ret_dtype == Void: + ret_dtype = T2.dtype + if ret_ndim != -1 and T2.ndim != ret_ndim: + # mismatch, so there's no common lower bound + return None + if ret_dtype != Void and T2.dtype != ret_dtype: + return None + return DynTensorType(ret_ndim, ret_dtype) + if T1 and T2 are both TupleType: + if they do not have the same length: + return None + fields = [] + for field1, field2 in zip(T1.fields, T2.fields): + glb = find_glb(field1, field2) + if glb is None: + return None + fields.append(glb) + return TupleType(fields) + if T1 and T2 are both FuncType: + «if they are not both pure or both impure:» + «return None» + purity = T1.purity + if they do not have the same arity: + return None + # mutual recursion with finding the LUB + arg_types = [ + find_lub(arg_type1, arg_type2) + for arg_type1, arg_type2 in zip(T1.arg_types, T2.arg_types) + ] + ret_type = find_glb(T1.ret_type, T2.ret_type) + if ret_type is None: + return None + return FuncType(arg_types, ret_type, purity) + +def find_lub(T1 : Type, T2 : Type) -> Type: + if T1 == T2: # syntactic equality + return T1 + if T1 or T2 is ObjectType: + return Object + if T1 or T2 are not both DynTensorType, or both TupleType, or both FuncType, or both ShapeType: + return ObjectType + if T1 and T2 are both DynTensorType: + res_ndim = T1.ndim + res_dtype = T1.dtype + if T1.ndim != T2.ndim: + res_ndim = -1 + if T1.dtype != T2.dtype: + res_dtype = Void + return DynTensorType(res_ndim, res_dtype) + if T1 and T2 are both TupleType: + if they do not have the same length: + return ObjectType + return TupleType([ + find_lub(field1, field2) + for field1, field2 in zip(T1.fields, T2.fields) + ]) + if T1 and T2 are both FuncType: + «purity = (True iff they're both pure)» + if they do not have the same arity: + return ObjectType + arg_types = [] + for arg_type1, arg_type2 in zip(T1.arg_types, T2.arg_types): + # potential mutual recursion + glb = find_glb(arg_type1, arg_type2) + if glb is None: + return ObjectType + arg_types.append(glb) + return FuncType(arg_types, find_lub(T1.ret_type, T2.ret_type), «purity») +``` + +### When Type Conversions are Necessary + +For two types `T1` and `T2`, if `T1 <: T2`, then a value of type `T1` can be passed anywhere a value of type `T2` is expected without any need for type conversions or dynamic checks. + +*However*, if `T1 <: T2`, then passing a value of type `T2` where `T1` is expected can only be done if there has been some kind of dynamic check or conversion of that value. «Relax is *strongly* *typed*, meaning that the compiler will give an error in this situation and require an explicit conversion via the `cast` operator, which inspects the value's run-time representation and exits the program with an error message if the value is not a subtype of T1.» + +If `T1` is not a subtype of `T2` and `T2` is not a subtype of `T1`, then it is always a type error to pass a value of either type where a value of the other is expected (no member of either type can be a member of the other). + +## Type Checking Rules + +The type checking rules for Relax are relatively simple and allow in some cases for types to be inferred without user annotations. Below, we describe how the types for each expression can be derived and when type checking should return an error. + +Let us consider a typing context `Γ`, which is a map of variables to types. + +1. «We type check the entire `IRModule` one function definition at a time. To handle mutual recursion, we prepopulate `Γ` with the annotated types of all global functions that are called mutually recursively. We then proceed to check the types of the global functions one at a time.» +2. Given a variable `v`, if `v` is in `Γ`, then we return `Γ[v]`. Otherwise, it is an error. +3. Given a constant expression `Constant(data)`, the type is `DynTensorType(ndim=n, dtype=d)` where `n` is the number of dimensions in `data` and `d` is its data type (recall that `data` is an `NDArray` literal). +4. Given a shape expression `ShapeExpr(dims)`, its type is `ShapeType`. +5. The type of a `RuntimeDepShape` expression is `ShapeType`. +6. Given a tuple literal `Tuple([e1, e2, ..., en])`, suppose that `e1` has the type `T1`, `e2` has the type `T2`, …, and `en` has the type `Tn`. Then the type is `TupleType([T1, T2, .., Tn])`. +7. Given a call node `Call(op=op, args=[a1, a2, ..., an], type_args=[aT])`: + 1. If `op` is a Relax `Op` node, then we look up its registered `FInferType` property. `FInferType` is a macro that takes in the `Call` node and produces a type. We return the type `op.FInferType(Call(op, [a1, ..., an], type_args=[aT]))`. The implementation of `FInferType` is free to throw errors. + 2. If `op` is `ExternFunc`, then use the sole member of `type_args` (calls to `ExternFunc`s are required to have exactly one `type_args` member) `aT` as the return type. Packed functions may be passed any combination of values and return any value; it is the responsibility of the packed function itself to do any validation. + 3. Otherwise, check the types of the subexpressions, left to right. Suppose `op` has the type `Tf`, `a1` has type `T1`, …, and `an` has type `Tn`. If `Tf` is not a function type with exactly `n` arguments, we consider it a type error and require an explicit cast. Suppose `Tf` has argument types `T1'`, `T2'`, …, `Tn'`. Consider it a type error if any of the following does not hold: `T1 <: T1'`, `T2 <: T2'`, …, or `Tn <: Tn'`. Then the return type is `Tf.ret_type`. +8. Given a conditional expression `If(cond, true_branch, false_branch)`, we first assert that the type of `cond` is `DynTensorType(ndim=0, dtype=bool)`, giving an error otherwise. Next, we recursively check the types of `true_branch` and `false_branch`; suppose they yield types `Tt` and `Tf`, respectively. The return type will be `T = LUB(Tt, Tf)`. +9. For a `TupleGetItem(t, i)` expression, suppose that the type of `t` is `T`. If `T` is a tuple type with at least `i` members, then return the `i`th type in `T`. «Give a type-checking error and require an explicit cast if `T` is not a tuple type with at least `i` members.» +10. Let us consider a sequence expression `SeqExpr(blocks = [b0, b1, ..., bn], body)`. + 1. We type check the binding blocks `b0`, `b1`, …, `bn` in order. For each block, we go through the bindings in order. + 1. «If the current block is a `DataflowBlock`, consider it an error if any binding contains a call to an expression with a function type that is not pure, a call to an `ExternFunc` that does not have the `pure` attribute, or a call to an `Op` that does not have a `pure` attribute.» + 2. For each binding `VarBinding(v : T, e)` in the current block, where `T` is the optional annotation on `v`, check the type of `e` and suppose it is `T'`. If `T` has been omitted, then add `v` to `Γ` with type `T'`. If `T` has been defined, then emit an error if `T'` is not a subtype of `T` and add `v` to `Γ` with type `T`. «If `T'` is a supertype of `T`, emit an error and require a cast.» Note that this means that annotated types can be *less specific* than the inferred type and that a user annotation forces the type system to consider the variable as having a less specific type than it does. + 3. For each `MatchShape(v: T, e, shape_pattern)`, where `T` is an optional type annotation, let the checked type of `e` be `T'`. + 1. If `T'` is `ShapeType`, then emit an error if `T` is not a supertype of `ShapeType`. Add `v` to `Γ` with type `T`. + 2. If `T'` is `DynTensorType`: + 1. If the `ndim` of `T'` is `n` ≥ 0, then emit an error if the length of the given `shape_pattern` is not `n`. Let the datatype of `T'` be `d`. + 2. If `T` is not a supertype of `DynTensorType(ndim=len(shape_pattern), dtype=d)`, then emit an error. If `T` is a subtype of that type, emit an error and request a cast. + 3. Add `v` to `Γ` with type `T`. + 3. If `T'` is `ObjectType`, then the only type we can conclude for `v` is `ObjectType`. If `T` is not `ObjectType`, emit an error and request a cast. + 4. If `T'` is `TupleType` or `FuncType`, emit a type error. + 2. If the current block is a `DataflowBlock`, remove any `DataflowVar`s from `Γ` after we have finished processing the bindings in that block. + 3. Finally, the type of the `SeqExpr` is the type of `body` checked under the current `Γ`. Afterwards, remove all bindings added from the `b0`, `b1`, …, `bn` from `Γ`. +11. Let us consider a function `Function(v1 : T1, v2 : T2, ..., vn : Tn, attrs=a) -> Tr: body`. + 1. For handling recursive calls: If the function has been bound to a name `fv` (which may be a `Var` or a `GlobalVar`), then add `fv` to `Γ` with the type `FuncType([T1, T2, ..., Tn], Tr, pure=p)` and proceed as below, where `p` is `True` if a `pure` attribute is included and `False` otherwise. Remove `fv` from `Γ` before returning. + 2. Add `v1` to `Γ` with type `T1`, `v2` with type `T2`, …, and `tn` with type `Tn`. Recursively type check `body` in this new context: + 1. «Determining purity: If `body` contains any call to a function whose return type does not specify that it is pure, a call to an `ExternFunc` that does not have the `pure` attribute, or a call to an `Op` that does not specify the `pure` attribute, then we consider the function to be (potentially) impure. If all calls are to functions whose return type specifies purity or that include the `pure` attribute on the call or `Op`, then the function is treated as pure.» + 2. «Suppose the purity defined in the previous step is `p'`. Suppose the annotated function purity (in the attributes) is `p`. If `p'` is false while `p` is true, then it is a type error; if `p` was omitted, use `p'` for `p`.» + 3. «If the function has the attribute "`force_pure`," then consider `p` to be true, even if the check above judged the function not to be pure. The compiler may emit a warning in this situation.» + 4. Suppose the result of type-checking `body` is `Tr'`. If the current function is not recursive or a mutually recursive global function and `Tr` was omitted, consider the function to have type `FuncType([T1, T2, ..., Tn], Tr', pure=p)`. If `Tr' <: Tr`, then we consider the function to have type `FuncType([T1, T2, ..., Tn], Tr, pure=p)`. «If `Tr <: Tr'`, return an error and require an explicit cast in the function body.» If `Tr'` is not a subtype of `Tr` and `Tr` is also not a subtype of `Tr'` (meaning a dynamic cast cannot succeed), this is an error. + 5. Remove `v1`, `v2`, …, and `vn` from `Γ` before returning. + +# Shapes in Relax + +In Relay, shapes are part of tensor types and there is much analysis of tensor shapes done at compile time. In Relax, to allow for greater flexibility for variable-shape tensors and make it easier to implement new operators, shapes can be checked at run time. Though every expression in Relax has a shape associated with it just as expressions also have types, there is no requirement that the shape be expressed at compile time. Instead, the compiler merely requires that an expression's shape define *a way* to compute a fully specified shape at run time. Users have the ability to make use of shape variables and arithmetic expressions to encode a wide variety of shape constraints that can be checked dynamically. + +Nevertheless, in many cases, these shapes can be analyzed at compile time (particularly when they are consist of constants or deducible variables) to facilitate compile-time optimization much like is possible with Relay or TIR. Through constant propagation, function inlining, and other partial evaluation–like transformations, we can potentially eliminate many more dynamic checks by allowing some shape computations to be simplified at compile time. + +## Defining Shape Computations + +In Relax, each expression has an associated shape computation, which defines how that expression's shape can be computed based on the shapes of its subexpressions. We will refer to this computation as `shape_`, as that is what it is called in the implementation. This essentially serves as a mechanism for propagating shape annotations on variable bindings and function definitions to other expressions and enable more compile-time analysis of shapes. In particular, `shape_` is useful for memory planning. These computations can also be used to simplify shape checking and eliminate many dynamic checks. + +### Expressing Dimensions + +A tensor shape is a tuple of TIR `PrimExpr`s, where each `PrimExpr` corresponds to a dimension. The use of TIR `PrimExpr`s for shape dimension allows shape computations to express complex constraints that include variables and integer arithmetic expressions in addition to just constant dimensions. + +**Scope of Shape Variables** + +Shape variables can be introduced in two places in a Relax program: In a function signature, where they may be included with the argument shapes and return shape annotations, or in `MatchShape` bindings. Shape variables used in the function signature are scoped to the entire function in which they appear. Shape variables used in `MatchShape` bindings are scoped only to the `SeqExpr` in which they appear. + +**Informal Semantics of `PrimExpr`s for Dimensions** + +1. Shape variables can be bound to a value exactly once: at the start of a function for shape annotations on function arguments, in `MatchShape` bindings, or before a function returns (for shape variables on the return type). In particular, matching a `PrimExpr` consisting only of an uninitialized shape variable is treated as its binding (see below on `MatchShape`). After a shape variable has been bound for the first time, future uses of it will refer to the same value. +2. It is not legal to use a shape var that has not yet been bound. This results in an error at run time, though most cases can be detected at compile time. +3. «Local functions will "capture" defined shape variables from the parent scope with their present values in the resulting closure.» +4. If all variables in the `PrimExpr` are defined, `PrimExpr` arithmetic will generally be evaluated according to the semantics of TIR. + +### Evaluating `MatchShape` + +`MatchShape` allows for binding shape variables in Relax. It can be used with either tensor values or shape values, and in both cases the evaluation of the `PrimExpr`s proceeds similarly. + +1. Evaluating `MatchShape(v, t, s)`, where `t` is a tensor value and `s` is a list of `PrimExpr`s corresponding to shape dimensions: + 1. Suppose `s` is `(p1, p2, ..., pn)` , where each variables is a `PrimExpr`. We evaluate `p1`, then `p2`, and so, in that order according to the following rules (corresponding to the `i`th dimension): + 1. If the current `PrimExpr` consists only of an uninitialized shape variable, we bind the shape variable in that scope to the concrete value of the `i`th dimension of the value of `t`. + 2. Evaluate the current `PrimExpr` and compare it to the concrete value of the `i`th dimension of `t`. Raise an error if they do not match. + 2. If `v` is provided, bind `t` to `v` (see the general semantics for how that should be implemented). +2. Evaluating `MatchShape(v, S, s)`, where `S` is a shape value proceeds identically to the above, except the `PrimExpr`s are compared to the `i`th element of `S`. + +### General Shape Computation Grammar + +Shape computations can consist of the following expressions, which are a subset of general Relax `Expr`s: + +``` +ShapeCompExpr ::= ShapeExpr(dims: [PrimExpr]) + | RuntimeDepShape() + | Tuple(fields: [ShapeCompExpr]) + | Call(op: Op|ExternFunc, args: [Var|Constant]) + | TupleGetItem(tuple_value: ShapeCompExpr, index: int) +``` + +The shape expressions can be interpreted as follows: + +- `ShapeExpr` describes the shape of a tensor as a list of dimensions +- `Tuple` describes the shapes of each member of a tuple +- `TupleGetItem` describes the shape of a member of a tuple +- `Call` describes the shape of a function (or operator) call return value in terms of its arguments +- `RuntimeDepShape` describes shapes that are unknown at compile time (like when a shape annotation is omitted) or the shapes of values that don't have shapes (like shapes themselves, paradoxically: they *are* shapes but do not *have* shapes). + +The `PrimExpr`s in a `ShapeCompExpr` can reference the same shape variables as in shape annotations, with the same semantics. + +**Restrictions** + +Shape computations are allowed to include calls to operators and even `PackedFunc`s, but these operators and `PackedFunc`s *must* be pure. Shape computations are primarily used for memory planning and it is at the compiler's discretion when, if ever, to evaluate them (except as described below), hence they must not have side effects. + +**Shape Annotations** + +For shape annotations, we use `ShapeCompExpr` as the grammar, as with `shape_` expressions. `ShapeExpr` is used to annotate shapes of tensor values, `Tuple` is used to annotate the shapes of tuple values, and `RuntimeDepShape` is used to indicate annotations that have been omitted or shapes that cannot be known at compile time (like the shapes of tensors whose rank is unknown at compile time). `Call` is used to annotate the shapes of calls to operators and `TupleGetItem` annotates the shapes of tuple indices. + +For example, suppose we have a tuple where some fields are tensors like the following: + +```python +x : Tuple(Tensor((m, n), "int32"), Tuple(), Tensor((), "int32"), Tensor(_, "int32")) = ... +``` + +It has the shape annotation + +```python +Tuple([ShapeExpr([m, n]), Tuple([]), ShapeExpr([]), RuntimeDepShape]) +``` + +Note that it is [a well-formedness requirement](https://www.notion.so/Informal-Relax-Language-Specification-d1fdedb8fae84f0d82b9f880f25e7370) that if any field in a type has a `ShapeExpr` annotation, it must be a `DynTensorType` with an `ndim` matching the number of dimensions in the `ShapeExpr`. For example, in the above function signatures, the `ndim` in the type annotations must be 2. + +### Assigning Shape Variables at the Start and End of a Function + +Shape variables are bound at the start and end of a function or in `MatchShape` bindings. We can describe the behavior at the start and end of a function in terms of the semantics of `MatchShape`, as the shape annotations in function arguments and return types are treated as "syntactic sugar" for `MatchShape` bindings. Suppose a function has the following signature, where the `Ti` are type annotation and the `Si` are shape annotations: + +```python +def f(arg1 : (T1, S1), arg2 : (T2, S2), ..., argn : (Tn, Sn)) -> (Tr, Sr): + return body +``` + +This can be treated as a macro that expands to + +```python +def f(arg1 : T1, arg2 : T2, ..., argn : Tn) -> Tr: + check_annotation(arg1, T1, S1) + check_annotation(arg2, T2, S2) + ... + check_annotation(argn, Tn, Sn) + ret_var = body + check_annotation(ret_var, Tr, Sr) + return ret_var +``` + +Because `MatchShape` is defined only for tensor and shape values, we must use a macro to handle other possible types that may be passed into a function, given here in pseudocode: + +```python +def check_annotation(e: Expr, s: ShapeCompExpr) -> Expr: + if s is a ShapeExpr: + tmp = fresh_var() + # type checking should ensure that e is always a tensor + return SeqExpr( + [BindingBlock([MatchShape(tmp, e, s.dims)])], + tmp + ) + else if s is a Tuple: + # type checking should ensure that e is always a tuple and the lengths match + shapes = s.fields + tmp = fresh_var() + return SeqExpr( + [BindingBlock([ + VarBinding(tmp, e), + # recursive in case we have nested tuples + VarBinding(fresh_var(), check_annotation(TupleGetItem(tmp, 0), shapes[0])), + VarBinding(fresh_var(), check_annotation(TupleGetItem(tmp, 1), shapes[1])), + ..., + VarBinding(fresh_var(), check_annotation(TupleGetItem(tmp, n-1), shapes[n-1])) + ])], tmp + ) + else if s is a Call: + tmp = fresh_var() + return SeqExpr( + [BindingBlock([ + VarBinding(tmp, e), + # completely dynamic check that does not assign shape vars. + VarBinding(fresh_var(), dynamically_check_shapes(shape_of(tmp), s)) + ])], tmp + ) + else if s is TupleGetItem: + val = s.tuple_value + if val is Tuple: + return check_annotation(e, val.fields[s.index]) + # otherwise, evaluate it + return SeqExpr( + [BindingBlock([ + VarBinding(tmp, e), + VarBinding(fresh_var(), dynamically_check_shapes(shape_of(tmp), s)) + ])], tmp + ) + else if s is RuntimeDepShape: + # no need to check + return e +``` + +### Evaluating Shape Expressions + +Every shape expression in the program (`shape_`) is associated with a program expression. Other than in the above procedure for checking function parameter shapes and the return shape, the specification does not guarantee that any `shape_` expression will ever be evaluated or how many times it may be evaluated; `shape_` is intended primarily for the benefit of memory planning. Hence, all `shape_` expressions must be pure and must be guaranteed to terminate. The `shape_` for a given expression `e` is intended to be evaluated *before* `e`. + +Shape expressions follow the same evaluation rules as general program expressions. In particular, shape functions are permitted to reference any variable that is in scope at the point of its associated expression; i.e., when evaluated, they form closures that capture any free variables (Relax variables and shape variables) referenced in their body. The `RuntimeDepShape` expression has no semantics at run time and indicates a shape that cannot be predicted in advance. If a `RuntimeDepShape` is encountered at any point while dynamically checking a shape match (see the `check_annotation` procedure above), it should "short-circuit" the match and cause the match to succeed immediately. + +### Building Up `shape_` for Each Expression + +For each expression type, we can recursively build up an associated `shape_` expression according to the following rules: + +1. For `Constant(value)`, the `shape_` expression is a `ShapeExpr` corresponding to the concrete shape of `value`. For example, for `Constant(1)`, `shape_` is `ShapeExpr([])` and for `Constant([1, 2])`, `shape_` is `ShapeExpr([2])`. +2. For `Tuple(fields)`, `shape_` can be defined as `Tuple([field.shape_ for field in fields])`. +3. For `ShapeExpr`s, `shape_` is `RuntimeDepShape`. +4. `RuntimeDepShape` expressions should appear only in shape expressions; their `shape_` is not defined. +5. For `If(cond, true_branch, false_branch)`, we compare the `shape_` of `true_branch` and `false_branch`. If these can be proven equivalent (by a method that the compiler implementation is free to determine), then the `If` node's `shape_` is that shape. If they do not match, then we set it to `RuntimeDepShape`. +6. For `SeqExpr`, we set the `shape_` to be the `shape_` of the body expression. The `shape_` must respect the scoping rules for the `SeqExpr`: If the `shape_` of the body expression contains shape variables not defined in the outer scope (i.e., shape variables that are scoped to the `SeqExpr` only) or if the `shape_` contains any `Var`s or `DataflowVar`s scoped to the `SeqExpr`, use `RuntimeDepShape` as the shape. +7. For handling variable bindings: + 1. For the arguments to a function, set the `shape_` to the annotated shape. If the annotation is omitted, use `RuntimeDepShape`. + 2. In the general `VarBinding(v, e)`, if `v` does not have a shape annotation or the annotation is `RuntimeDepShape`, then we set the `shape_` of `v` to the `shape_` of `e`. If `v` has a shape annotation, then if the `shape_` of `e` can be proven equivalent to the shape annotation, use the shape annotation for the `shape_` of `v`. «Otherwise, give an error and require an explicit `MatchShape`.» + + It is up to the compiler implementation to decide what method to use for attempting to prove equivalence. + + 3. For bindings where the RHS is a function literal or assigning the `shape_` of a `GlobalVar`, see the rule for `Function` nodes. + 4. For `MatchShape(var, value, shape)`, we set the `shape_` of `var` to `shape`, as it will be dynamically checked. +8. For `TupleGetItem(tuple_value, i)`, we examine the `shape_` of `tuple_value`; suppose it is `s`. If `s` is a `Tuple`, then we use its `i`th field. If it is `RuntimeDepShape`, we use `RuntimeDepShape`. If it is a `Call` to a function that returns a tuple with at least `i + 1` members, set the `shape_` to `TupleGetItem(s, i)`. Otherwise, raise an error at compile time (though this should not happen if type checking has passed). +9. For `Call` nodes: + 1. For a call to an `ExternFunc`, we use `RuntimeDepShape` because we cannot analyze the shapes of arbitrary `PackedFunc`s and must check dynamically. + 2. For a call to an `Op`, we use the manually defined `FInferShape` macro if it has been defined and `RuntimeDepShape` if it has not. `FInferShape` is a function that takes in the call node and produces a `ShapeCompExpr`. + 3. For all other cases with `Call(op, args)`, we consider the following cases: + 1. If `op` is a `GlobalVar` or a `Var` that refers to a function defined in the current scope, look up the `Function` node it references; let us call it `f`. Similarly, if `op` is itself a `Function` node, let `f` be `op`. + + Attempt to perform [beta-reduction](https://en.wikipedia.org/wiki/Lambda_calculus#%CE%B2-reduction) on `f`'s return shape. A pseudocode procedure for this beta-reduction is given below, as a macro. + + 1. If the return shape of `f` is a `Call` node or contains any `Call` nodes, substitute any parameters of `f` for the corresponding member of `args`. (E.g., if `f` has parameters `p1`, `p2`, …, `pn` and any of these variables appears in the return shape, `p1` should be replaced with the first member of `args`; `p2`, with the second; etc.) If any member of `args` that is substituted this way is not a `Var` or `Constant`, consider beta-reduction to fail. + 2. For each shape annotation in the parameters of `f`, attempt to match it with the `shape_` of the corresponding member of `args`, substituting shape variables in the return shape accordingly. If the `shape_` of the member of `args` is `RuntimeDepShape`, consider beta-reduction to fail. If the `shape_` is not `RuntimeDepShape` but is incompatible with the parameter's shape annotation (e.g., a `Tuple` where a `ShapeExpr` was expected), report an error at compile time. + + If `f`'s return shape is `RuntimeDepShape`, then consider the call result to have `RuntimeDepShape`. If beta-reduction is considered to fail, then consider the call result to have `RuntimeDepShape`. If it succeeds, use the resulting shape as the `shape_` of the call result. + + 2. Otherwise, consider the result of the call to have `RuntimeDepShape`. +10. For a function node, set the `shape_` to `RuntimeDepShape`. + +### Procedure for Substituting a Function Return Shape to Determine the Shape of a Call + +The `substitute_shape` procedure defined below describes how the shape expression for a call result can be defined given the call arguments and the return shape annotation on the corresponding function node. Note that this procedure can obtain much more precise results in the cases of `Call` or `TupleGetItem` return shapes. + +```python +def map_shape_vars(param_shape: ShapeCompExpr, arg_shape: ShapeCompExpr, shape_var_mapping: {tir::Var : PrimExpr}) -> bool: + if param_shape is RuntimeDepShape or arg_shape is RuntimeDepShape: + return False + if param_shape is ShapeExpr and arg_shape is ShapeExpr: + if len(param_shape.values) != len(arg_shape.values): + raise UnificationError("Shapes are of incompatible ranks") + for param_dim, arg_dim in zip(param_shape.values, arg_shape.values): + if param_dim in shape_var_mapping: + # syntactic equality + if arg_dim != shape_var_mapping[param_dim]: + # if they are statically not equal, e.g., 5 != 7 or 3 + 3 != 3*3 + if can_prove_not_equal(arg_dim, shape_var_mapping[param_dim]): + raise UnificationError("Incompatible dimensions") + else: + return False + else: + shape_var_mapping[param_dim] = arg_dim + return True + if param_shape is Tuple and arg_shape is Tuple: + if len(param_shape.fields) != len(arg_shape.fields): + raise UnificationError("Tuples are of incompatible lengths") + for param_field, arg_field in zip(param_shape.fields, arg_shape.fields): + ret = map_shape_vars(param_field, arg_field, shape_var_mapping) + if not ret: + return False + return True + if param_shape is TupleGetItem and arg_shape is TupleGetItem: + # Does not necessarily indicate a unification error, + # depending on what the tuple values are. + # Constant folding the TupleGetItem nodes could improve this unification case + if param_shape.index != arg_shape.index: + return False + return map_shape_vars(param_shape.tup_value, arg_shape.tup_value) + if param_shape is Call and arg_shape is Call: + # no dimension mapping to do in this case + return True + # if either is a Call or TupleGetItem, it is possible that the shapes + # can match dynamically even if they don't match statically + if (param_shape is Call + or param_shape is TupleGetItem + or arg_shape is Call + or arg_shape is TupleGetItem): + return False + raise UnificationError("Incompatible shape constructs") + +def substitute_vars(target: Expr, var_mapping: {Var: Expr}, shape_var_mapping: {tir::Var: PrimExpr}) -> Expr: + def substitute_shape_vars(target: PrimExpr): + if target is tir::Var: + if target in shape_var_mapping: + return shape_var_mapping[target] + else: + return target + # proceed recursively in all subexpressions, checking for vars + + if target is Var: + if target in var_mapping: + return var_mapping[target] + return target + if target is ShapeExpr: + return ShapeExpr([ + substitute_shape_vars(dim) + for dim in target.values + ]) + # recurse through all other cases, checking for vars and shape exprs analogously + +def substitute_shape(func_params, arg_exprs, ret_shape): + var_mapping = {param: arg_expr for param, arg_expr in zip(func_params, arg_exprs)} + shape_var_mapping = {} + for param, arg_expr in zip(func_params, arg_exprs): + can_unify = map_shape_vars(param.shape_, arg_expr.shape_, shape_var_mapping) + if not can_unify: + return RuntimeDepShape() + + new_shape = substitute_vars(ret_shape, var_mapping, shape_var_mapping) + if new_shape contains any free (Relax or shape) variables: + return RuntimeDepShape() + return new_shape +``` + +### Note on Proving Shapes Equivalent and Eliminating Dynamic Checks + +There can be some complexity involved in checking whether two shapes match during shape inference. A very simple, conservative method for determining equality is simply using alpha-equivalence: If the two shapes have the same structure, then they are equivalent. However, this method is conservative and can overlook numerical properties in `PrimExpr`s. We leave it up to compiler implementations as to whether to use more advanced methods for proving equivalence, such as attempting to use algebraic rewrite rules. (As a consequence, portability requires inserting dynamic checks wherever there needs to be a comparison of shapes.) + +Note that optimizations like function inlining or constant folding could allow for simplifying many shape annotations and expressions and make it possible to conclude at compile time that shapes in more cases are equivalent. In general, developing compiler infrastructure for partial evaluation and reasoning about common situations with shape annotations may eliminate many dynamic checks. + +Applying some kind of normalization or algebraic simplifications to `PrimExpr`s used in shape annotations and in `shape_` fields can also make it easier to conclude that certain dynamic checks may not be necessary by increasing the likelihood that more `shape_` expressions could be made syntactically identical to the shape annotations. It would also be possible to generate compile-time warnings if analysis reveals that two shapes may not match (either using rewrite rules or by trying random values for shape variables and checking). + +Since most dynamic shape checks are done for safety, it may be feasible to introduce a compilation mode that eliminates almost all dynamic shape checks. Some shape checks may not be possible to eliminate, since the body of the program may construct `ShapeExpr`s and use them in calls to `PackedFunc`s, so some bindings to shape variables may need to be preserved, per a liveness analysis. + +## Possible Extensions to the Shape Expression System + +We may consider two possible extensions to the shape expression system in order to accommodate two further cases: + +1. An explicit wildcard dimension (e.g., using `tir::Any`) to allow for dimensions to be specified as "unknown" in function return shapes. As described at present, the only way for a function to specify a partly unknown return shape is to make the entire return shape unknown (`RuntimeDepShape`), which loses partial shape information. +2. Adding `shape_` expressions consisting of functions, to allow arbitrary closures to have a known shape. This would allow the shapes of calls to closures of unknown origin (namely, in a higher-order function) to have their shapes correctly inferred rather than made `RuntimeDepShape`. + +In both cases, these additions would entail additional complexity (shape inference macros for operators would have to deal with potential `tir::Any` nodes and we would have to define rules for constructing, calling, and simplifying functions in `shape_` expressions). However, the advantage of implementing these features would be increasing the amount of shape information present at compile time and hence that could be used by lower levels of the compiler stack. The present defaults of using `RuntimeDepShape` means that either of these changes could be pursued in the future without breaking existing code, since these would generally have to be paired with explicit `MatchShape` dynamic checks, which will still work even if we add rules to automatically infer the shapes in those cases. + +# Detailed Semantics + +## Program Entry Point + +In the `IRModule`, every mapping of a `GlobalVar` to a `Function` node or a TIR `PrimFunc` should be processed first and added to the global scope. «Global functions that have a `global_symbol` attribute should be externally linked, meaning that they can be invoked as program entry points; those that do not have a `global_symbol` attribute can be called only from within the global functions in the `IRModule`.» + +The rules for evaluating `Function` nodes into closures are given below. TIR `PrimFunc`s evaluate into objects that are opaque to Relax; these objects have type `Object` and can be used only by the `call_tir` operator. None of the values in global scope is mutable. Execution of a Relax function in an IR module thus begins by evaluating all globally visible functions into a form in which they can be accessed. + +## Evaluating Expressions + +For each expression, we define how it affects the program's visible state and the order in which they are evaluated. Below, all evaluation results are passed by reference (and hence possibly alias) unless it is explicitly specified that they allocate new values. + +1. The node `Constant(value)` creates a new tensor whose contents are `value`. +2. A variable (whether `Var`, `DataflowVar` , or `GlobalVar`) evaluates to the stored value for that variable in the current scope. +3. The node `Tuple([e1, e2, ..., en])` evaluates `e1` (yielding value `v1`), then `e2` (yielding value `v2`), …, and finally `en` (yielding value `vn`) in that order and creates a new tuple value containing `v1`, `v2`, …, and `vn` in that order. +4. The node `TupleGetItem(t, i)` is evaluated by first evaluating `t` (which, per type checking, must evaluate to a tuple) and then returning the `i`th field of the result. +5. The node `ShapeExpr([p1, p2, ..., pn])` evaluates the `PrimExpr`s `p1` (yielding dimension value `v1`), `p2` (yielding dimension value `v2`), …, and finally `pn` (yielding dimension value `vn`) in that order, using the current shape context, and creates a new shape value whose dimensions are `v1`, `v2`, …, `vn`, in that order. +6. `RuntimeDepShape` expressions must not appear in the general body of a program; it is a well-formedness error if they do. They do not have any defined semantics. +7. The node `Function([v1, v2, ..., vn], body)` returns a new closure containing the function definition itself and a mapping of any free Relax variables or shape variables in `body` to the values they hold in the current scope when the `Function` node is encountered. If the function is the RHS of a local binding, the bound variable should also be included in the closure's binding map and should be mapped to the closure itself (to allow for recursive calls). Closure capturing is done *by reference*; no values will be copied and references to captured values will alias their values in the outer scope. `DataflowVar`s are not captured by closures. +8. The node `If(cond, true_branch, false_branch)` is evaluated as follows: + 1. First `cond` is evaluated. Let the result be `r` (per type checking, it must be a bool scalar). + 2. If `r` is true, evaluate the `true_branch` and return its result. + 3. If `r` is false, evaluate the `false_branch` and return its result. +9. The node `Call(op, [arg1, arg2, ..., argn])` is evaluated as follows: + 1. If `op` is an `ExternFunc` node, then evaluate `arg1`, `arg2`, …, `argn` in that order and call the results `a1`, `a2`, …, `an`. Next, look up the `PackedFunc` registered under the global symbol name. If it exists (it is an error at run time if it does not), call the `PackedFunc` using the given arguments and return the result. Note that if a TIR `PrimFunc` in the `IRModule` has a global symbol attribute registered, it can be called as an `ExternFunc` using that global symbol as well. `PackedFunc`s may have arbitrary side effect and are responsible for whether the result is a newly allocated value or aliases another value. + 2. If `op` is an `Op` node, then evaluate `arg1`, `arg2`, …, `argn` in that order and call the results `a1`, `a2`, …, `an`. It is up to the compiler implementation to decide how operators should be implemented (some may have an associated `PackedFunc` and others may be built into the executor implementation). The operator may mutate its arguments. It is also up to the operator implementation as to whether the result is newly allocated or aliases another value. «(TODO: Once we have operators for logical and AND and OR, we should also define short-circuiting semantics for those.)» + 3. In all other cases, first evaluate `op` (it must evaluate to a closure). Next, we evaluate `arg1`, `arg2`, …, `argn` in that order and call the results `a1`, `a2`, …, `an`. Push a new scope onto the stack where arguments `v1`, `v2`, …, `vn` in the closure are bound to `a1`, `a2`, …, and `an`, respectively, and all variables saved in the closure are added to the scope. Evaluate the closure body in this new scope; this will be the return value of the call. Pop the scope before returning the value. +10. For the node `SeqExpr(blocks, body)`, we evaluate as follows: + 1. Push a new scope onto the stack. + 2. Iterate through the `BindingBlock`s in `blocks` in order. We will call the current one `block`. For each binding in `Block`: + 1. If the binding is `MatchShape(var, value, shape)`, perform the shape matching and shape variable updates as described in the shape evaluation section. If `var` is provided, `var` will be bound to `value` in the current scope; this assignment is aliasing and no new value is allocated. If `var` is not provided, then the shape check is performed and shape variables are updated, but no new binding is introduced. + 2. If the binding is `VarBinding(var, value)`, then evaluate `value` and bind `var` to that value in the current scope; this assignment is aliasing and no new value is allocated. + 3. If `block` is a `DataflowBlock`, remove all `DataflowVar`s bound in the block from the current scope before proceeding to the next block. + 3. After iterating through the binding blocks, evaluate `body` in the current scope. That will be the return value of the `SeqExpr`. + 4. Pop the scope, removing any `Var` bindings introduced in the `SeqExpr`. This should also remove any shape variables introduced and bound in the `SeqExpr` as well. + +### Optimizations + +Optimizations are allowed to reorder and modify the operations of a program in any way so long as they do not change the value returned by evaluating the program or any visible behavior of the program. For the purposes of compilation, visible behaviors consist of side effects like mutating values in the program or external effects like I/O (printing to the console, creating files, etc.) and the order and number of times in which they happen. + +«Within `DataflowBlock`s, it is permitted for the compiler to remove or reorder `MatchShape` or `cast` operations even though this can affect the "visible behavior" of the program (since they can exit with an error). It is also permitted for the compiler to optimize away potential non-termination within `DataflowBlock`s: For example, if some pure function `f` has an integer return type and does not terminate, it is permissible to optimize `f() - f()` to 0 within a `DataflowBlock`. In general, the compiler is permitted to make programs "more defined" (terminating when the original did not terminate, not raising an error when the original raised an error) within a `DataflowBlock`, but never "less defined" (giving an error when the original did not give an error, not terminating when the original did not terminate). Outside of `DataflowBlock`s, error messages and potential non-termination must be preserved faithfully.» + +The specification makes no guarantees about certain memory-related properties and hence also does not consider them to be "visible behaviors": + +- Whether an allocation happens at a given point. Compiler implementations are permitted to reuse already-allocated memory if it would not interfere with visible state in any other way, per the aliasing rules (`PackedFunc`s or operators may mutate values that are passed to them and those mutations should be visible as per aliasing in this specification). Copying values or sharing representations (e.g., interning constants) between values may be done only if they will not affect any other visible behaviors, dependent on the aliasing behavior. +- It is entirely the domain of compiler implementations to make guarantees (or not) as to whether memory allocations will succeed. +- `PackedFunc`s or operators can, in principle, access information about the machine's state and make changes to allocation policies or the state that affect how memory allocations are performed. The specification makes no guarantees in such an event. + +These semantic rules assume a single thread of evaluation on a single host machine. At this time, it is unspecified as to how Relax programs should behave if split over distinct threads or across multiple machines. + +### Notable Operators + +The above evaluation rules are general, but leave much room for implementations of operators to specify custom semantics. Certain operators are used to perform common operations and will be discussed here as well. + +- `call_tir(prim_func, arg1, arg2, ..., argn, shape, type_args=[aT])`: `prim_func` must be a `PrimFunc` object in the current `IRModule` (we will call it `f`). The `shape` argument gives the shapes of the result of calling the TIR `PrimFunc`: It must be either of `ShapeType` (corresponding to returning a single tensor) or `TupleType` whose members are `ShapeType` (corresponding to returning a tuples of tensors). The type arg `aT` gives the type of the result of calling the `PrimFunc` and it must correspond to `shape` (namely, if `shape` is of `ShapeType`, `aT` must be a `DynTensorType`; if `shape` is of `TupleType`, `aT` must be a `TupleType` whose fields are `ShapeType`). `aT` is used especially to provide the `dtype` of returned tensors. + + Based on `shape`, the resulting tensor or tuple `r` will be allocated according to the sizes given in `shape`. `f` will be called in destination-passing style, like so: `f(arg1, ..., argn, *r)`. The asterisk denotes that if `r` is a tuple, it will be "unrolled," so the call will be `f(arg1, ..., argn, r1, ..., rn)`, where the `ri` are the fields of `r`. `f` is expected to mutate *only* `r` to give the output of the function, hence `call_tir` is considered pure. If the shape or data type of the actual result do not correspond to `shape` or `aT`, an error is issued.» After the call, `r` is returned. + +- `call_dps_packed(global_symbol, arg1, arg2, ..., argn, shape, type_args=[aT])`: Proceeds similarly to `call_tir`, except it calls a `PackedFunc` registered under the name `global_symbol`. The `PackedFunc` may modify `arg1`, `arg2`, …, or `argn` in addition to the result tensor, so purity is not assumed. A type argument `aT` must be given to specify the return type. +- `shape_of(t)`: Given a tensor argument `t`, it returns its shape. The return value is a new shape object. +- «`cast(v, type_args=[aT])`: Given an argument `v`, it dynamically checks if `v`'s run-time representation is a subtype of `aT`. If it is not, it exits the program with an error message. Otherwise, it returns `v`.» + From 1c570e16dc185302b6aff37b9de60d1d6ee85858 Mon Sep 17 00:00:00 2001 From: Steven Lyubomirsky Date: Tue, 25 Oct 2022 18:16:36 -0400 Subject: [PATCH 02/30] call_dps_packed is not yet implemented --- relax_spec.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/relax_spec.md b/relax_spec.md index 40f69e4ba9..4123539323 100644 --- a/relax_spec.md +++ b/relax_spec.md @@ -743,7 +743,7 @@ The above evaluation rules are general, but leave much room for implementations Based on `shape`, the resulting tensor or tuple `r` will be allocated according to the sizes given in `shape`. `f` will be called in destination-passing style, like so: `f(arg1, ..., argn, *r)`. The asterisk denotes that if `r` is a tuple, it will be "unrolled," so the call will be `f(arg1, ..., argn, r1, ..., rn)`, where the `ri` are the fields of `r`. `f` is expected to mutate *only* `r` to give the output of the function, hence `call_tir` is considered pure. If the shape or data type of the actual result do not correspond to `shape` or `aT`, an error is issued.» After the call, `r` is returned. -- `call_dps_packed(global_symbol, arg1, arg2, ..., argn, shape, type_args=[aT])`: Proceeds similarly to `call_tir`, except it calls a `PackedFunc` registered under the name `global_symbol`. The `PackedFunc` may modify `arg1`, `arg2`, …, or `argn` in addition to the result tensor, so purity is not assumed. A type argument `aT` must be given to specify the return type. +- «`call_dps_packed(global_symbol, arg1, arg2, ..., argn, shape, type_args=[aT])`: Proceeds similarly to `call_tir`, except it calls a `PackedFunc` registered under the name `global_symbol`. The `PackedFunc` may modify `arg1`, `arg2`, …, or `argn` in addition to the result tensor, so purity is not assumed. A type argument `aT` must be given to specify the return type.» - `shape_of(t)`: Given a tensor argument `t`, it returns its shape. The return value is a new shape object. - «`cast(v, type_args=[aT])`: Given an argument `v`, it dynamically checks if `v`'s run-time representation is a subtype of `aT`. If it is not, it exits the program with an error message. Otherwise, it returns `v`.» From 3946c2a47ccd52c32d973a02fafd4a0720861277 Mon Sep 17 00:00:00 2001 From: Steven Lyubomirsky Date: Tue, 25 Oct 2022 18:22:07 -0400 Subject: [PATCH 03/30] Many shape mechanics are still unimplemented --- relax_spec.md | 33 ++++++++++++++++++--------------- 1 file changed, 18 insertions(+), 15 deletions(-) diff --git a/relax_spec.md b/relax_spec.md index 4123539323..1f4c828868 100644 --- a/relax_spec.md +++ b/relax_spec.md @@ -432,16 +432,16 @@ Shape computations can consist of the following expressions, which are a subset ``` ShapeCompExpr ::= ShapeExpr(dims: [PrimExpr]) | RuntimeDepShape() - | Tuple(fields: [ShapeCompExpr]) + | «Tuple(fields: [ShapeCompExpr])» | Call(op: Op|ExternFunc, args: [Var|Constant]) - | TupleGetItem(tuple_value: ShapeCompExpr, index: int) + | «TupleGetItem(tuple_value: ShapeCompExpr, index: int)» ``` The shape expressions can be interpreted as follows: - `ShapeExpr` describes the shape of a tensor as a list of dimensions -- `Tuple` describes the shapes of each member of a tuple -- `TupleGetItem` describes the shape of a member of a tuple +- «`Tuple` describes the shapes of each member of a tuple» +- «`TupleGetItem` describes the shape of a member of a tuple» - `Call` describes the shape of a function (or operator) call return value in terms of its arguments - `RuntimeDepShape` describes shapes that are unknown at compile time (like when a shape annotation is omitted) or the shapes of values that don't have shapes (like shapes themselves, paradoxically: they *are* shapes but do not *have* shapes). @@ -453,9 +453,9 @@ Shape computations are allowed to include calls to operators and even `PackedFun **Shape Annotations** -For shape annotations, we use `ShapeCompExpr` as the grammar, as with `shape_` expressions. `ShapeExpr` is used to annotate shapes of tensor values, `Tuple` is used to annotate the shapes of tuple values, and `RuntimeDepShape` is used to indicate annotations that have been omitted or shapes that cannot be known at compile time (like the shapes of tensors whose rank is unknown at compile time). `Call` is used to annotate the shapes of calls to operators and `TupleGetItem` annotates the shapes of tuple indices. +For shape annotations, we use `ShapeCompExpr` as the grammar, as with `shape_` expressions. `ShapeExpr` is used to annotate shapes of tensor values, «`Tuple` is used to annotate the shapes of tuple values», and `RuntimeDepShape` is used to indicate annotations that have been omitted or shapes that cannot be known at compile time (like the shapes of tensors whose rank is unknown at compile time). `Call` is used to annotate the shapes of calls to operators and «`TupleGetItem` annotates the shapes of tuple indices.» -For example, suppose we have a tuple where some fields are tensors like the following: +«For example, suppose we have a tuple where some fields are tensors like the following: ```python x : Tuple(Tensor((m, n), "int32"), Tuple(), Tensor((), "int32"), Tensor(_, "int32")) = ... @@ -466,12 +466,13 @@ It has the shape annotation ```python Tuple([ShapeExpr([m, n]), Tuple([]), ShapeExpr([]), RuntimeDepShape]) ``` +» Note that it is [a well-formedness requirement](https://www.notion.so/Informal-Relax-Language-Specification-d1fdedb8fae84f0d82b9f880f25e7370) that if any field in a type has a `ShapeExpr` annotation, it must be a `DynTensorType` with an `ndim` matching the number of dimensions in the `ShapeExpr`. For example, in the above function signatures, the `ndim` in the type annotations must be 2. -### Assigning Shape Variables at the Start and End of a Function +### «Assigning Shape Variables at the Start and End of a Function» -Shape variables are bound at the start and end of a function or in `MatchShape` bindings. We can describe the behavior at the start and end of a function in terms of the semantics of `MatchShape`, as the shape annotations in function arguments and return types are treated as "syntactic sugar" for `MatchShape` bindings. Suppose a function has the following signature, where the `Ti` are type annotation and the `Si` are shape annotations: +«Shape variables are bound at the start and end of a function or in `MatchShape` bindings. We can describe the behavior at the start and end of a function in terms of the semantics of `MatchShape`, as the shape annotations in function arguments and return types are treated as "syntactic sugar" for `MatchShape` bindings. Suppose a function has the following signature, where the `Ti` are type annotation and the `Si` are shape annotations: ```python def f(arg1 : (T1, S1), arg2 : (T2, S2), ..., argn : (Tn, Sn)) -> (Tr, Sr): @@ -490,6 +491,7 @@ def f(arg1 : T1, arg2 : T2, ..., argn : Tn) -> Tr: check_annotation(ret_var, Tr, Sr) return ret_var ``` +» Because `MatchShape` is defined only for tensor and shape values, we must use a macro to handle other possible types that may be passed into a function, given here in pseudocode: @@ -502,7 +504,7 @@ def check_annotation(e: Expr, s: ShapeCompExpr) -> Expr: [BindingBlock([MatchShape(tmp, e, s.dims)])], tmp ) - else if s is a Tuple: + «else if s is a Tuple: # type checking should ensure that e is always a tuple and the lengths match shapes = s.fields tmp = fresh_var() @@ -515,7 +517,7 @@ def check_annotation(e: Expr, s: ShapeCompExpr) -> Expr: ..., VarBinding(fresh_var(), check_annotation(TupleGetItem(tmp, n-1), shapes[n-1])) ])], tmp - ) + )» else if s is a Call: tmp = fresh_var() return SeqExpr( @@ -525,7 +527,7 @@ def check_annotation(e: Expr, s: ShapeCompExpr) -> Expr: VarBinding(fresh_var(), dynamically_check_shapes(shape_of(tmp), s)) ])], tmp ) - else if s is TupleGetItem: + «else if s is TupleGetItem: val = s.tuple_value if val is Tuple: return check_annotation(e, val.fields[s.index]) @@ -535,7 +537,7 @@ def check_annotation(e: Expr, s: ShapeCompExpr) -> Expr: VarBinding(tmp, e), VarBinding(fresh_var(), dynamically_check_shapes(shape_of(tmp), s)) ])], tmp - ) + )» else if s is RuntimeDepShape: # no need to check return e @@ -552,7 +554,7 @@ Shape expressions follow the same evaluation rules as general program expression For each expression type, we can recursively build up an associated `shape_` expression according to the following rules: 1. For `Constant(value)`, the `shape_` expression is a `ShapeExpr` corresponding to the concrete shape of `value`. For example, for `Constant(1)`, `shape_` is `ShapeExpr([])` and for `Constant([1, 2])`, `shape_` is `ShapeExpr([2])`. -2. For `Tuple(fields)`, `shape_` can be defined as `Tuple([field.shape_ for field in fields])`. +2. «For `Tuple(fields)`, `shape_` can be defined as `Tuple([field.shape_ for field in fields])`.» 3. For `ShapeExpr`s, `shape_` is `RuntimeDepShape`. 4. `RuntimeDepShape` expressions should appear only in shape expressions; their `shape_` is not defined. 5. For `If(cond, true_branch, false_branch)`, we compare the `shape_` of `true_branch` and `false_branch`. If these can be proven equivalent (by a method that the compiler implementation is free to determine), then the `If` node's `shape_` is that shape. If they do not match, then we set it to `RuntimeDepShape`. @@ -565,11 +567,11 @@ For each expression type, we can recursively build up an associated `shape_` exp 3. For bindings where the RHS is a function literal or assigning the `shape_` of a `GlobalVar`, see the rule for `Function` nodes. 4. For `MatchShape(var, value, shape)`, we set the `shape_` of `var` to `shape`, as it will be dynamically checked. -8. For `TupleGetItem(tuple_value, i)`, we examine the `shape_` of `tuple_value`; suppose it is `s`. If `s` is a `Tuple`, then we use its `i`th field. If it is `RuntimeDepShape`, we use `RuntimeDepShape`. If it is a `Call` to a function that returns a tuple with at least `i + 1` members, set the `shape_` to `TupleGetItem(s, i)`. Otherwise, raise an error at compile time (though this should not happen if type checking has passed). +8. «For `TupleGetItem(tuple_value, i)`, we examine the `shape_` of `tuple_value`; suppose it is `s`. If `s` is a `Tuple`, then we use its `i`th field. If it is `RuntimeDepShape`, we use `RuntimeDepShape`. If it is a `Call` to a function that returns a tuple with at least `i + 1` members, set the `shape_` to `TupleGetItem(s, i)`. Otherwise, raise an error at compile time (though this should not happen if type checking has passed).» 9. For `Call` nodes: 1. For a call to an `ExternFunc`, we use `RuntimeDepShape` because we cannot analyze the shapes of arbitrary `PackedFunc`s and must check dynamically. 2. For a call to an `Op`, we use the manually defined `FInferShape` macro if it has been defined and `RuntimeDepShape` if it has not. `FInferShape` is a function that takes in the call node and produces a `ShapeCompExpr`. - 3. For all other cases with `Call(op, args)`, we consider the following cases: + 3. «For all other cases with `Call(op, args)`, we consider the following cases: 1. If `op` is a `GlobalVar` or a `Var` that refers to a function defined in the current scope, look up the `Function` node it references; let us call it `f`. Similarly, if `op` is itself a `Function` node, let `f` be `op`. Attempt to perform [beta-reduction](https://en.wikipedia.org/wiki/Lambda_calculus#%CE%B2-reduction) on `f`'s return shape. A pseudocode procedure for this beta-reduction is given below, as a macro. @@ -580,6 +582,7 @@ For each expression type, we can recursively build up an associated `shape_` exp If `f`'s return shape is `RuntimeDepShape`, then consider the call result to have `RuntimeDepShape`. If beta-reduction is considered to fail, then consider the call result to have `RuntimeDepShape`. If it succeeds, use the resulting shape as the `shape_` of the call result. 2. Otherwise, consider the result of the call to have `RuntimeDepShape`. + » 10. For a function node, set the `shape_` to `RuntimeDepShape`. ### Procedure for Substituting a Function Return Shape to Determine the Shape of a Call From a914e2893323b5aadafd0d3abe2afbd8780f0f10 Mon Sep 17 00:00:00 2001 From: Steven Lyubomirsky Date: Tue, 25 Oct 2022 18:23:41 -0400 Subject: [PATCH 04/30] Indicate datatype in AST diagram --- relax_spec.md | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/relax_spec.md b/relax_spec.md index 1f4c828868..d38afff8f9 100644 --- a/relax_spec.md +++ b/relax_spec.md @@ -51,6 +51,12 @@ Type ::= DynTensorType(ndim: int, dtype: DataType) | TupleType(fields: [Type]) | FuncType(arg_types: [Type], ret_type: Type, «pure: bool») +DataType ::= + Int(bitwidth: int) + | Float(bitwidth: int) + | Bool() + | Void() + # expressions Expr ::= Constant(data: NDArray) # scoped to functions or SeqExprs From b9aa684ff3930cde2d699b4ad73ef7cd0eae33ec Mon Sep 17 00:00:00 2001 From: Steven Lyubomirsky Date: Tue, 25 Oct 2022 18:37:14 -0400 Subject: [PATCH 05/30] Add text about variable shadowing --- relax_spec.md | 21 +++++++++++++++++++++ 1 file changed, 21 insertions(+) diff --git a/relax_spec.md b/relax_spec.md index d38afff8f9..6357a23c0d 100644 --- a/relax_spec.md +++ b/relax_spec.md @@ -200,6 +200,27 @@ There are four relevant scopes in Relax, which determine where variables are vis 3. `SeqExpr`: `Var` nodes defined in a `BindingBlock` in a `SeqExpr` node can be referenced in any later binding within the same `BindingBlock`, in any binding within any later `BindingBlock` in that `SeqExpr` node, or in the `SeqExpr`'s body expression. The variables defined in the `BindingBlock`s leave scope once the `SeqExpr` returns. 4. `DataflowBlock`: `DataflowVar`s introduced in a `DataflowBlock` can be referenced in any later binding within that `DataflowBlock`, but leave scope *once that `DataflowBlock` finishes executing*. Definitions in a `DataflowBlock` that are intended to leave the `DataflowBlock` should be bound to an ordinary `Var`. +Note that Relax variables must be bound _exactly_ once. A global variable is bound if it is mapped to a function in the `IRModule` and a local variable is bound if it appears as a function parameter or if it appears on the left-hand side (LHS) of a binding (`VarBinding` or `MatchShape`). + +«If there is another binding to a local variable with the same name as an already-bound variable, that is binding is considered to _shadow_ the previous binding, i.e., it is a binding to a new, distinct variable that happens to have the same name as the existing variable. The new, shadowing variable will exist only in the current scope; if the older variable was defined in an outer scope, then future uses of that name will refer to the older variable. [See the Wikipedia page for more information on variable shadowing.](https://en.wikipedia.org/wiki/Variable_shadowing)» + +Below is an example of shadowing, in pseudocode: + +```python +@R.function +def func(x: Tensor) -> Tensor: + if True: + # the true branch will be a nested SeqExpr and hence a new scope + # this x will shadow the function parameter x + x = R.const(1) + R.print(x) # prints 1 + # the inner x goes out of scope + else: + R.print("not executed") + R.print(x) # this x is the function parameter + return x +``` + # Well-Formedness Criteria Prior to type-checking and shape inference, Relax programs must conform to certain syntactic criteria to be valid. From 548adc5976e042c0ab85291784d172bd6720f527 Mon Sep 17 00:00:00 2001 From: Steven Lyubomirsky Date: Tue, 22 Nov 2022 22:45:38 -0500 Subject: [PATCH 06/30] Discuss differences from Relay --- relax_spec.md | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/relax_spec.md b/relax_spec.md index 6357a23c0d..89866323ca 100644 --- a/relax_spec.md +++ b/relax_spec.md @@ -21,6 +21,12 @@ Though this document will use the TVMScript front end for some examples, specify This section will outline the grammar of Relax and give very brief descriptions of the different components, including the semantics, type system, and shape system. The rest of this document will provide more detailed descriptions of these facets of the language, including the validity conditions that the type system and shape system uphold. +## Differences from Relay + +Per the [original workshop paper](https://arxiv.org/abs/1810.00952) and the [later report](https://arxiv.org/abs/1904.08368), Relay was designed to be a high-level functional language for expressing deep learning models at a high level. While Relay is not entirely pure (the `Ref` type is modeled after reference types in SML and similar functional languages), the assumption in Relay is that tensor operators are generally pure, meaning that they do not change the program state other than by producing new values. Additionally, Relay's type system also requires operators to have type relations that infer static tensor types or conclude that a dimension is unknown at compile time (`Any`). The need to register type relations and ensure operators' purity makes it difficult to add new operators to Relay and particularly difficult to call directly into TIR or external libraries, which are often not pure; any such extension requires adding new operators and abstracting over any impurity. + +While Relax aims to be as general and expressive in Relay, Relax is intended to make it much easier to interoperate with external libraries and especially with TIR. In particular, Relax includes a mechanism for calling arbitrary TVM `PackedFunc`s (which can call external libraries) and special support for TIR. The language accordingly does not assume that such operations are pure, though this does require reasoning about aliasing and similar issues. Additionally, tensor shapes are no longer handled during type checking; each expression has an associated shape _computation_ associated with it, in addition to a type. These shape computations support static reasoning about shapes in many cases, but also facilitate a fallback to dynamic checking when that is not possible. This approach to shapes allows for richer shape constraints to be checked at run time (such as with _symbolic_ shapes, where some dimensions are variables) and allows for more quickly integrating calls into TIR or external libraries into Relax code by obviating the need for type relations. + ## Grammar Below is a diagram of the various AST constructs in Relax, including types. In code, these are defined on the C++ side in `include/tvm/relax/{expr.h, type.h}` and in Python in `python/tvm/relax/{expr.py, ty.py}`. This diagram will give the names of the AST nodes and the types and names of their members. The semantics will describe what computation each construct represents; an AST is simply data. A Relax program consists of an `IRModule` with global variables bound to Relax functions that implement the computations of interest. From ca3b79279c0b980ac5535dbe0b745630d2e81ef0 Mon Sep 17 00:00:00 2001 From: Steven Lyubomirsky Date: Tue, 22 Nov 2022 22:56:35 -0500 Subject: [PATCH 07/30] Correct typo --- relax_spec.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/relax_spec.md b/relax_spec.md index 89866323ca..6084c4e286 100644 --- a/relax_spec.md +++ b/relax_spec.md @@ -25,7 +25,7 @@ This section will outline the grammar of Relax and give very brief descriptions Per the [original workshop paper](https://arxiv.org/abs/1810.00952) and the [later report](https://arxiv.org/abs/1904.08368), Relay was designed to be a high-level functional language for expressing deep learning models at a high level. While Relay is not entirely pure (the `Ref` type is modeled after reference types in SML and similar functional languages), the assumption in Relay is that tensor operators are generally pure, meaning that they do not change the program state other than by producing new values. Additionally, Relay's type system also requires operators to have type relations that infer static tensor types or conclude that a dimension is unknown at compile time (`Any`). The need to register type relations and ensure operators' purity makes it difficult to add new operators to Relay and particularly difficult to call directly into TIR or external libraries, which are often not pure; any such extension requires adding new operators and abstracting over any impurity. -While Relax aims to be as general and expressive in Relay, Relax is intended to make it much easier to interoperate with external libraries and especially with TIR. In particular, Relax includes a mechanism for calling arbitrary TVM `PackedFunc`s (which can call external libraries) and special support for TIR. The language accordingly does not assume that such operations are pure, though this does require reasoning about aliasing and similar issues. Additionally, tensor shapes are no longer handled during type checking; each expression has an associated shape _computation_ associated with it, in addition to a type. These shape computations support static reasoning about shapes in many cases, but also facilitate a fallback to dynamic checking when that is not possible. This approach to shapes allows for richer shape constraints to be checked at run time (such as with _symbolic_ shapes, where some dimensions are variables) and allows for more quickly integrating calls into TIR or external libraries into Relax code by obviating the need for type relations. +While Relax aims to be as general and expressive as Relay, Relax is intended to make it much easier to interoperate with external libraries and especially with TIR. In particular, Relax includes a mechanism for calling arbitrary TVM `PackedFunc`s (which can call external libraries) and special support for TIR. The language accordingly does not assume that such operations are pure, though this does require reasoning about aliasing and similar issues. Additionally, tensor shapes are no longer handled during type checking; each expression has an associated shape _computation_ associated with it, in addition to a type. These shape computations support static reasoning about shapes in many cases, but also facilitate a fallback to dynamic checking when that is not possible. This approach to shapes allows for richer shape constraints to be checked at run time (such as with _symbolic_ shapes, where some dimensions are variables) and allows for more quickly integrating calls into TIR or external libraries into Relax code by obviating the need for type relations. ## Grammar From 928533c16863bd841fc30c237b12c11fbde21794 Mon Sep 17 00:00:00 2001 From: Steven Lyubomirsky Date: Thu, 1 Dec 2022 22:48:23 -0500 Subject: [PATCH 08/30] Add description of PackedFuncType --- relax_spec.md | 17 ++++++++++------- 1 file changed, 10 insertions(+), 7 deletions(-) diff --git a/relax_spec.md b/relax_spec.md index 6084c4e286..5899fdfdca 100644 --- a/relax_spec.md +++ b/relax_spec.md @@ -56,6 +56,7 @@ Type ::= DynTensorType(ndim: int, dtype: DataType) | ObjectType() | TupleType(fields: [Type]) | FuncType(arg_types: [Type], ret_type: Type, «pure: bool») + | PackedFuncType() DataType ::= Int(bitwidth: int) @@ -159,7 +160,8 @@ The types in Relax correspond to the broad categories of the values given above: 2. `TupleType` corresponds to tuple values, giving the type of each member of the tuple. 3. `ShapeType` corresponds to shape values. 4. `FunctionType` corresponds to function values (closures), giving the types of the parameters, the return type, «and whether the function is pure.» -5. `ObjectType` is the parent type of all Relax types and corresponds to all the values above as well as any values returned by `PackedFunc` calls that do not fit in the above categories. +5. `PackedFuncType` is the type given to arbitrary packed functions (external functions). Since packed functions are not first-class values (`ExternFunc` can appear only in the `op` position of a `Call` node), these do not actually correspond to any value in Relax, but can be used to assign a type to `ExternFunc` nodes. +6. `ObjectType` is the parent type of all Relax types and corresponds to all the values above as well as any values returned by `PackedFunc` calls that do not fit in the above categories. The type checking rules assign types to every variable in scope and every type of expression based on the values it returns, making use of subtyping to assign more general types when a more specific one cannot be determined. «Relax is strongly typed, meaning that if a type encountered is less specific than the one expected, an error will be issued and an explicit cast (via the `cast` operator) will be required.» @@ -342,7 +344,7 @@ def find_lub(T1 : Type, T2 : Type) -> Type: return T1 if T1 or T2 is ObjectType: return Object - if T1 or T2 are not both DynTensorType, or both TupleType, or both FuncType, or both ShapeType: + if T1 or T2 are not both DynTensorType, or both TupleType, or both FuncType, or both ShapeType, or both PackedFuncType: return ObjectType if T1 and T2 are both DynTensorType: res_ndim = T1.ndim @@ -393,13 +395,14 @@ Let us consider a typing context `Γ`, which is a map of variables to types. 4. Given a shape expression `ShapeExpr(dims)`, its type is `ShapeType`. 5. The type of a `RuntimeDepShape` expression is `ShapeType`. 6. Given a tuple literal `Tuple([e1, e2, ..., en])`, suppose that `e1` has the type `T1`, `e2` has the type `T2`, …, and `en` has the type `Tn`. Then the type is `TupleType([T1, T2, .., Tn])`. -7. Given a call node `Call(op=op, args=[a1, a2, ..., an], type_args=[aT])`: +7. Given an `ExternFunc` expression, assign it the type `PackedFuncType()`. +8. Given a call node `Call(op=op, args=[a1, a2, ..., an], type_args=[aT])`: 1. If `op` is a Relax `Op` node, then we look up its registered `FInferType` property. `FInferType` is a macro that takes in the `Call` node and produces a type. We return the type `op.FInferType(Call(op, [a1, ..., an], type_args=[aT]))`. The implementation of `FInferType` is free to throw errors. 2. If `op` is `ExternFunc`, then use the sole member of `type_args` (calls to `ExternFunc`s are required to have exactly one `type_args` member) `aT` as the return type. Packed functions may be passed any combination of values and return any value; it is the responsibility of the packed function itself to do any validation. 3. Otherwise, check the types of the subexpressions, left to right. Suppose `op` has the type `Tf`, `a1` has type `T1`, …, and `an` has type `Tn`. If `Tf` is not a function type with exactly `n` arguments, we consider it a type error and require an explicit cast. Suppose `Tf` has argument types `T1'`, `T2'`, …, `Tn'`. Consider it a type error if any of the following does not hold: `T1 <: T1'`, `T2 <: T2'`, …, or `Tn <: Tn'`. Then the return type is `Tf.ret_type`. -8. Given a conditional expression `If(cond, true_branch, false_branch)`, we first assert that the type of `cond` is `DynTensorType(ndim=0, dtype=bool)`, giving an error otherwise. Next, we recursively check the types of `true_branch` and `false_branch`; suppose they yield types `Tt` and `Tf`, respectively. The return type will be `T = LUB(Tt, Tf)`. -9. For a `TupleGetItem(t, i)` expression, suppose that the type of `t` is `T`. If `T` is a tuple type with at least `i` members, then return the `i`th type in `T`. «Give a type-checking error and require an explicit cast if `T` is not a tuple type with at least `i` members.» -10. Let us consider a sequence expression `SeqExpr(blocks = [b0, b1, ..., bn], body)`. +9. Given a conditional expression `If(cond, true_branch, false_branch)`, we first assert that the type of `cond` is `DynTensorType(ndim=0, dtype=bool)`, giving an error otherwise. Next, we recursively check the types of `true_branch` and `false_branch`; suppose they yield types `Tt` and `Tf`, respectively. The return type will be `T = LUB(Tt, Tf)`. +10. For a `TupleGetItem(t, i)` expression, suppose that the type of `t` is `T`. If `T` is a tuple type with at least `i` members, then return the `i`th type in `T`. «Give a type-checking error and require an explicit cast if `T` is not a tuple type with at least `i` members.» +11. Let us consider a sequence expression `SeqExpr(blocks = [b0, b1, ..., bn], body)`. 1. We type check the binding blocks `b0`, `b1`, …, `bn` in order. For each block, we go through the bindings in order. 1. «If the current block is a `DataflowBlock`, consider it an error if any binding contains a call to an expression with a function type that is not pure, a call to an `ExternFunc` that does not have the `pure` attribute, or a call to an `Op` that does not have a `pure` attribute.» 2. For each binding `VarBinding(v : T, e)` in the current block, where `T` is the optional annotation on `v`, check the type of `e` and suppose it is `T'`. If `T` has been omitted, then add `v` to `Γ` with type `T'`. If `T` has been defined, then emit an error if `T'` is not a subtype of `T` and add `v` to `Γ` with type `T`. «If `T'` is a supertype of `T`, emit an error and require a cast.» Note that this means that annotated types can be *less specific* than the inferred type and that a user annotation forces the type system to consider the variable as having a less specific type than it does. @@ -413,7 +416,7 @@ Let us consider a typing context `Γ`, which is a map of variables to types. 4. If `T'` is `TupleType` or `FuncType`, emit a type error. 2. If the current block is a `DataflowBlock`, remove any `DataflowVar`s from `Γ` after we have finished processing the bindings in that block. 3. Finally, the type of the `SeqExpr` is the type of `body` checked under the current `Γ`. Afterwards, remove all bindings added from the `b0`, `b1`, …, `bn` from `Γ`. -11. Let us consider a function `Function(v1 : T1, v2 : T2, ..., vn : Tn, attrs=a) -> Tr: body`. +12. Let us consider a function `Function(v1 : T1, v2 : T2, ..., vn : Tn, attrs=a) -> Tr: body`. 1. For handling recursive calls: If the function has been bound to a name `fv` (which may be a `Var` or a `GlobalVar`), then add `fv` to `Γ` with the type `FuncType([T1, T2, ..., Tn], Tr, pure=p)` and proceed as below, where `p` is `True` if a `pure` attribute is included and `False` otherwise. Remove `fv` from `Γ` before returning. 2. Add `v1` to `Γ` with type `T1`, `v2` with type `T2`, …, and `tn` with type `Tn`. Recursively type check `body` in this new context: 1. «Determining purity: If `body` contains any call to a function whose return type does not specify that it is pure, a call to an `ExternFunc` that does not have the `pure` attribute, or a call to an `Op` that does not specify the `pure` attribute, then we consider the function to be (potentially) impure. If all calls are to functions whose return type specifies purity or that include the `pure` attribute on the call or `Op`, then the function is treated as pure.» From c710e6032f9bf98e5d40623f301dd64339d8b0a3 Mon Sep 17 00:00:00 2001 From: Steven Lyubomirsky Date: Wed, 7 Dec 2022 17:05:56 -0500 Subject: [PATCH 09/30] Add a couple of missed references to PackedFuncType --- relax_spec.md | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/relax_spec.md b/relax_spec.md index 5899fdfdca..b690d6ab7b 100644 --- a/relax_spec.md +++ b/relax_spec.md @@ -258,13 +258,14 @@ Prior to type-checking and shape inference, Relax programs must conform to certa # Types in Relax -Relax presently has five types, defined in the implementation in `python/tvm/relax/ty.py` and `include/tvm/relax/type.h`: +Relax presently has six types, defined in the implementation in `python/tvm/relax/ty.py` and `include/tvm/relax/type.h`: 1. `DynTensorType`, referring to tensor values (referred to in the front-end as `Tensor`). In Relax, tensor types keep track of the rank (number of dimensions) in the `ndim` field and the data type of the tensor data in the `dtype` field. Both the rank and data type are optional: Using -1 for `ndim` indicates that the tensor is of unknown rank and using `DataType::Void()` for `dtype` indicates that it's of unknown data type. 2. `ShapeType`, referring to shape values. 3. `FuncType`, referring to functions (closures). `FuncType`s specify the types of their parameters, a return type, and whether the function is pure. 4. `TupleType`, referring to tuple values, giving the types of their fields. -5. `ObjectType`, referring to any Relax value, including values used and returned by `PackedFunc` or operator calls that do not belong in any of the above categories. +5. `PackedFuncType`, referring to the type of PackedFunctions. +6. `ObjectType`, referring to any Relax value, including values used and returned by `PackedFunc` or operator calls that do not belong in any of the above categories. ## Subtyping @@ -298,7 +299,7 @@ def find_glb(T1 : Type, T2 : Type) -> Type?: return T2 if T2 is ObjectType: return T1 - if T1 and T2 are not both DynTensorType, not both TupleType, not both FuncType, or not both ShapeType: + if T1 and T2 are not both DynTensorType, not both TupleType, not both FuncType, or not both ShapeType, or not both PackedFuncType: return None if T1 and T2 are both DynTensorType: ret_ndim = T1.ndim From c92b3011b6a913ee172b19edb6176c8094c1d9f1 Mon Sep 17 00:00:00 2001 From: Steven Lyubomirsky Date: Wed, 7 Dec 2022 17:07:56 -0500 Subject: [PATCH 10/30] Add forward pointer to the type-checking rule for local functions --- relax_spec.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/relax_spec.md b/relax_spec.md index b690d6ab7b..c41ddf04ad 100644 --- a/relax_spec.md +++ b/relax_spec.md @@ -406,7 +406,7 @@ Let us consider a typing context `Γ`, which is a map of variables to types. 11. Let us consider a sequence expression `SeqExpr(blocks = [b0, b1, ..., bn], body)`. 1. We type check the binding blocks `b0`, `b1`, …, `bn` in order. For each block, we go through the bindings in order. 1. «If the current block is a `DataflowBlock`, consider it an error if any binding contains a call to an expression with a function type that is not pure, a call to an `ExternFunc` that does not have the `pure` attribute, or a call to an `Op` that does not have a `pure` attribute.» - 2. For each binding `VarBinding(v : T, e)` in the current block, where `T` is the optional annotation on `v`, check the type of `e` and suppose it is `T'`. If `T` has been omitted, then add `v` to `Γ` with type `T'`. If `T` has been defined, then emit an error if `T'` is not a subtype of `T` and add `v` to `Γ` with type `T`. «If `T'` is a supertype of `T`, emit an error and require a cast.» Note that this means that annotated types can be *less specific* than the inferred type and that a user annotation forces the type system to consider the variable as having a less specific type than it does. + 2. For each binding `VarBinding(v : T, e)` in the current block, where `T` is the optional annotation on `v`, check the type of `e` and suppose it is `T'`. If `T` has been omitted, then add `v` to `Γ` with type `T'`. If `T` has been defined, then emit an error if `T'` is not a subtype of `T` and add `v` to `Γ` with type `T`. «If `T'` is a supertype of `T`, emit an error and require a cast.» Note that this means that annotated types can be *less specific* than the inferred type and that a user annotation forces the type system to consider the variable as having a less specific type than it does. (Note: In the case where `e` is a `Function` literal, the type annotation `T` is not optional and we add `v` to `Γ` before type-checking the function body; see the rule for `Function` nodes.) 3. For each `MatchShape(v: T, e, shape_pattern)`, where `T` is an optional type annotation, let the checked type of `e` be `T'`. 1. If `T'` is `ShapeType`, then emit an error if `T` is not a supertype of `ShapeType`. Add `v` to `Γ` with type `T`. 2. If `T'` is `DynTensorType`: From be02660712c2cfce6e5da0cfdd5a046a611e018c Mon Sep 17 00:00:00 2001 From: Steven Lyubomirsky Date: Wed, 14 Dec 2022 23:35:07 -0500 Subject: [PATCH 11/30] Describe normal form in the spec --- relax_spec.md | 34 +++++++++++++++++++++++++++++----- 1 file changed, 29 insertions(+), 5 deletions(-) diff --git a/relax_spec.md b/relax_spec.md index c41ddf04ad..84868ac139 100644 --- a/relax_spec.md +++ b/relax_spec.md @@ -12,10 +12,11 @@ Though this document will use the TVMScript front end for some examples, specify 2. [Top-Level Program Organization](#top-level-program-organization-irmodule) 3. [Values in Relax](#values-in-relax) 4. [Variable Scoping](#variable-scoping) -5. [Well-Formedness Criteria](#well-formedness-criteria) -6. [Types in Relax](#types-in-relax) -7. [Shapes in Relax](#shapes-in-relax) -8. [Semantics](#detailed-semantics) +5. [Normal Form](#normal-form) +6. [Well-Formedness Criteria](#well-formedness-criteria) +7. [Types in Relax](#types-in-relax) +8. [Shapes in Relax](#shapes-in-relax) +9. [Semantics](#detailed-semantics) # Overview @@ -229,10 +230,31 @@ def func(x: Tensor) -> Tensor: return x ``` +# Normal Form + +To simplify the writing of Relax passes, we define a normal form for Relax programs, based on the [administrative normal form](https://en.wikipedia.org/wiki/A-normal_form) (A-normal form, or ANF). See [this post](https://matt.might.net/articles/a-normalization/) by Matt Might for a discussion of some of the advantages of ANF in traditional compilation; in particular, ANF results in programs without nesting, which is very convenient for writing program transformations. Because the type- and shape-checking rules for operators rely on macros (`FInferType` and `FInferShape`), _this means that the structure of the program can affect type and shape inference_. Putting programs into normal form (and lacking nesting) not only simplifies the writing of these macros but it also ensures that these type- and shape-checking rules will be predictable, hence _it is required to transform programs into normal form_ before applying type- or shape-checking. + +The normal form for Relax is very similar to ANF; differences will be noted. Here are the criteria required for a program to be in normal form: +1. Within a `SeqExpr`, the right-hand side of any binding (the `value` field in the AST) must either be a "leaf expression" or a non-leaf expression where all subexpressions are leaf expressions. Leaf expressions are the following: Variables (`Var`, `DataflowVar`, or `GlobalVar`), `Constant`, `ShapeExpr`, `RuntimeDepShape`, or (_unlike_ ANF) `Tuple`. `Tuple` nodes are considered "leaf" expressions even though they contain nesting purely for convenience in writing passes; many operators rely on grouping arguments using tuples, so that is a form of nesting permitted and expected. Otherwise, non-leaf expressions used as subexpressions must be bound to variables; this includes any non-leaf expressions nested inside a `Tuple`. +2. `SeqExpr`s may appear only in the following locations: + 1. In the `body` field of a `Function` node. + 2. In the `true_branch` and `false_branch` fields of `If` nodes. +3. In fact, the `body` field of a `Function` node and the `true_branch` and `false_branch` fields of `If` nodes _must_ be `SeqExpr`s. If these fields are not `SeqExpr`s, they must be "wrapped" in a `SeqExpr`. +4. Within a `SeqExpr`, `BindingBlock`s must be consolidated. For example, if there is a `BindingBlock` that comes after another `BindingBlock`, the two blocks should be combined to form a single `BindingBlock` with all the bindings in the same order. Consecutive `DataflowBlock`s should be consolidated as well. However, a `DataflowBlock` cannot be consolidated with an ordinary `BindingBlock`. + +Programs that are parsed should be "normalized" before performing type-checking or shape-checking or before doing any further optimizations. Note that the process of "flattening" `SeqExpr`s and consolidating `BindingBlock`s does increase the visibility of the variables in those `SeqExpr`s and `BindingBlock`s, but this is safe, since it will not cause any variable to be referenced outside of its original scope. The specification does not require any particular method of normalizing a program so long as the final program conforms to the above-listed criteria. Here is a general approach: +1. For each function in the `IRModule`, ensure that the body is a `SeqExpr`. If the body is not a `SeqExpr`, wrap the function body in a `SeqExpr`, creating a new `BindingBlock` to hold `VarBinding`s for any non-leaf expressions that need to be bound to variables. +2. If the function body is already a `SeqExpr`, consolidate all `BindingBlock`s, then check if the `body` field of the `SeqExpr` is a leaf expression. If not, bind it to a new var in the final `BindingBlock` and replace the `SeqExpr` body with the new var. +3. If the function body is not a `SeqExpr`, then recurse down the body's AST, binding any nested non-leaf expressions to a var in the current scope (doing this process in breadth-first order from left to right will respect the evaluation order in the semantics). If the body itself is a non-leaf expression, finally bind it to a var and have the final `SeqExpr` return the new var. +4. If an `If` node is encountered, ensure the `true_branch` and `false_branch` fields are `SeqExpr`s (consolidate `BindingBlock`s if necessary) or "wrap" them in `SeqExpr`s in the same manner as the function body. +5. If a `SeqExpr` node is encountered as the `value` node in a binding, "flatten" the `SeqExpr` by adding its bindings to the current scope and replacing the `SeqExpr` with its body. If the `SeqExpr` body is a non-leaf expression, normalize it recursively in the same manner as in step 3 before replacing the binding. Note that if the current scope (the location of the binding) is a `DataflowBlock` and the nested `SeqExpr` contains an ordinary `BindingBlock`, that indicates a malformed program. + + # Well-Formedness Criteria -Prior to type-checking and shape inference, Relax programs must conform to certain syntactic criteria to be valid. +Prior to type-checking and shape inference, Relax programs must conform to certain syntactic criteria to be valid, which includes conforming to the expectations of the above-described normal form. +The following criteria apply to all programs (including before normalization): 1. `DataflowVar`s can be bound only inside `DataflowBlock`s. Additionally, a `DataflowVar` may not be used outside of the `DataflowBlock` in which it is defined. 2. A `Var` of any kind used in the program must be either a function parameter or appear on the LHS of a binding exactly once. In the binding where a `Var` is defined, the same `Var` is permitted to occur in the RHS of the binding only if the binding is defining a function (i.e., local functions are permitted to be recursive). 3. A `Var` of any kind may not appear before it is bound. Namely, if a `Var` is bound in a `BindingBlock` in a `SeqExpr`, that `Var` may not appear in bindings that precede the one where it appears on the LHS. @@ -256,6 +278,8 @@ Prior to type-checking and shape inference, Relax programs must conform to certa 15. «Any `PackedFunc` or operator called in a shape annotation or `shape_` expression must be pure and be annotated as such.» 16. The node `RuntimeDepShape` may appear only in shape annotations and `shape_` expressions. It has no defined semantics at run time. +Additionally, the criteria for normal form listed in the previous section must apply to any program that has been normalized. + # Types in Relax Relax presently has six types, defined in the implementation in `python/tvm/relax/ty.py` and `include/tvm/relax/type.h`: From e4ad832436d3ec80b2e7465f3d61e4026e524012 Mon Sep 17 00:00:00 2001 From: "Steven S. Lyubomirsky" Date: Tue, 20 Dec 2022 19:54:54 -0500 Subject: [PATCH 12/30] Specify consolidating empty binding blocks --- relax_spec.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/relax_spec.md b/relax_spec.md index 84868ac139..b27e75c8e0 100644 --- a/relax_spec.md +++ b/relax_spec.md @@ -240,7 +240,7 @@ The normal form for Relax is very similar to ANF; differences will be noted. Her 1. In the `body` field of a `Function` node. 2. In the `true_branch` and `false_branch` fields of `If` nodes. 3. In fact, the `body` field of a `Function` node and the `true_branch` and `false_branch` fields of `If` nodes _must_ be `SeqExpr`s. If these fields are not `SeqExpr`s, they must be "wrapped" in a `SeqExpr`. -4. Within a `SeqExpr`, `BindingBlock`s must be consolidated. For example, if there is a `BindingBlock` that comes after another `BindingBlock`, the two blocks should be combined to form a single `BindingBlock` with all the bindings in the same order. Consecutive `DataflowBlock`s should be consolidated as well. However, a `DataflowBlock` cannot be consolidated with an ordinary `BindingBlock`. +4. Within a `SeqExpr`, `BindingBlock`s must be consolidated. For example, if there is a `BindingBlock` that comes after another `BindingBlock`, the two blocks should be combined to form a single `BindingBlock` with all the bindings in the same order. Consecutive `DataflowBlock`s should be consolidated as well. Empty `BindingBlock`s should be dropped. However, a `DataflowBlock` cannot be consolidated with an ordinary `BindingBlock`. If all the `BindingBlock`s are empty, then the `blocks` field of the `SeqExpr` should be set to an empty list. Programs that are parsed should be "normalized" before performing type-checking or shape-checking or before doing any further optimizations. Note that the process of "flattening" `SeqExpr`s and consolidating `BindingBlock`s does increase the visibility of the variables in those `SeqExpr`s and `BindingBlock`s, but this is safe, since it will not cause any variable to be referenced outside of its original scope. The specification does not require any particular method of normalizing a program so long as the final program conforms to the above-listed criteria. Here is a general approach: 1. For each function in the `IRModule`, ensure that the body is a `SeqExpr`. If the body is not a `SeqExpr`, wrap the function body in a `SeqExpr`, creating a new `BindingBlock` to hold `VarBinding`s for any non-leaf expressions that need to be bound to variables. From 1075ce6283fed30fd3c90d9f56d22341b5a72565 Mon Sep 17 00:00:00 2001 From: Steven Lyubomirsky Date: Wed, 4 Jan 2023 19:16:43 -0500 Subject: [PATCH 13/30] Add ndim parameter to ShapeType --- relax_spec.md | 27 +++++++++++++++++++++------ 1 file changed, 21 insertions(+), 6 deletions(-) diff --git a/relax_spec.md b/relax_spec.md index b27e75c8e0..d0e9db96a5 100644 --- a/relax_spec.md +++ b/relax_spec.md @@ -53,7 +53,7 @@ PrimExpr ::= # (others may be added later, as deemed necessary) Type ::= DynTensorType(ndim: int, dtype: DataType) - | ShapeType() + | ShapeType(ndim: int) | ObjectType() | TupleType(fields: [Type]) | FuncType(arg_types: [Type], ret_type: Type, «pure: bool») @@ -159,7 +159,7 @@ The types in Relax correspond to the broad categories of the values given above: 1. `DynTensorType` corresponds to tensor values, giving the scalar data type and the number of dimensions (rank), both of which are optional. 2. `TupleType` corresponds to tuple values, giving the type of each member of the tuple. -3. `ShapeType` corresponds to shape values. +3. `ShapeType` corresponds to shape values, optionally giving the number of dimensions in the shape. 4. `FunctionType` corresponds to function values (closures), giving the types of the parameters, the return type, «and whether the function is pure.» 5. `PackedFuncType` is the type given to arbitrary packed functions (external functions). Since packed functions are not first-class values (`ExternFunc` can appear only in the `op` position of a `Call` node), these do not actually correspond to any value in Relax, but can be used to assign a type to `ExternFunc` nodes. 6. `ObjectType` is the parent type of all Relax types and corresponds to all the values above as well as any values returned by `PackedFunc` calls that do not fit in the above categories. @@ -285,7 +285,7 @@ Additionally, the criteria for normal form listed in the previous section must a Relax presently has six types, defined in the implementation in `python/tvm/relax/ty.py` and `include/tvm/relax/type.h`: 1. `DynTensorType`, referring to tensor values (referred to in the front-end as `Tensor`). In Relax, tensor types keep track of the rank (number of dimensions) in the `ndim` field and the data type of the tensor data in the `dtype` field. Both the rank and data type are optional: Using -1 for `ndim` indicates that the tensor is of unknown rank and using `DataType::Void()` for `dtype` indicates that it's of unknown data type. -2. `ShapeType`, referring to shape values. +2. `ShapeType`, referring to shape values. The number of dimensions in the shape as given as `ndim` and is optional (using -1 for `ndim` indicates an unknown number of dimensions). 3. `FuncType`, referring to functions (closures). `FuncType`s specify the types of their parameters, a return type, and whether the function is pure. 4. `TupleType`, referring to tuple values, giving the types of their fields. 5. `PackedFuncType`, referring to the type of PackedFunctions. @@ -325,6 +325,13 @@ def find_glb(T1 : Type, T2 : Type) -> Type?: return T1 if T1 and T2 are not both DynTensorType, not both TupleType, not both FuncType, or not both ShapeType, or not both PackedFuncType: return None + if T1 and T2 are both ShapeType: + ret_ndim = T1.ndim + if ret_ndim == -1: + ret_ndim == T2.ndim + if ret_ndim != -1 and T2.ndim != ret_ndim: + return None + return ShapeType(ret_ndim) if T1 and T2 are both DynTensorType: ret_ndim = T1.ndim ret_dtype = T1.dtype @@ -371,6 +378,11 @@ def find_lub(T1 : Type, T2 : Type) -> Type: return Object if T1 or T2 are not both DynTensorType, or both TupleType, or both FuncType, or both ShapeType, or both PackedFuncType: return ObjectType + if T1 and T2 are both ShapeType: + res_ndim = T1.ndim + if T1.ndim != T2.ndim: + res_ndim = -1 + return ShapeType(res_ndim) if T1 and T2 are both DynTensorType: res_ndim = T1.ndim res_dtype = T1.dtype @@ -417,8 +429,8 @@ Let us consider a typing context `Γ`, which is a map of variables to types. 1. «We type check the entire `IRModule` one function definition at a time. To handle mutual recursion, we prepopulate `Γ` with the annotated types of all global functions that are called mutually recursively. We then proceed to check the types of the global functions one at a time.» 2. Given a variable `v`, if `v` is in `Γ`, then we return `Γ[v]`. Otherwise, it is an error. 3. Given a constant expression `Constant(data)`, the type is `DynTensorType(ndim=n, dtype=d)` where `n` is the number of dimensions in `data` and `d` is its data type (recall that `data` is an `NDArray` literal). -4. Given a shape expression `ShapeExpr(dims)`, its type is `ShapeType`. -5. The type of a `RuntimeDepShape` expression is `ShapeType`. +4. Given a shape expression `ShapeExpr(dims)`, its type is `ShapeType(n)`, where `n` is the length of `dims`. +5. The type of a `RuntimeDepShape` expression is `ShapeType(-1)`. 6. Given a tuple literal `Tuple([e1, e2, ..., en])`, suppose that `e1` has the type `T1`, `e2` has the type `T2`, …, and `en` has the type `Tn`. Then the type is `TupleType([T1, T2, .., Tn])`. 7. Given an `ExternFunc` expression, assign it the type `PackedFuncType()`. 8. Given a call node `Call(op=op, args=[a1, a2, ..., an], type_args=[aT])`: @@ -432,7 +444,10 @@ Let us consider a typing context `Γ`, which is a map of variables to types. 1. «If the current block is a `DataflowBlock`, consider it an error if any binding contains a call to an expression with a function type that is not pure, a call to an `ExternFunc` that does not have the `pure` attribute, or a call to an `Op` that does not have a `pure` attribute.» 2. For each binding `VarBinding(v : T, e)` in the current block, where `T` is the optional annotation on `v`, check the type of `e` and suppose it is `T'`. If `T` has been omitted, then add `v` to `Γ` with type `T'`. If `T` has been defined, then emit an error if `T'` is not a subtype of `T` and add `v` to `Γ` with type `T`. «If `T'` is a supertype of `T`, emit an error and require a cast.» Note that this means that annotated types can be *less specific* than the inferred type and that a user annotation forces the type system to consider the variable as having a less specific type than it does. (Note: In the case where `e` is a `Function` literal, the type annotation `T` is not optional and we add `v` to `Γ` before type-checking the function body; see the rule for `Function` nodes.) 3. For each `MatchShape(v: T, e, shape_pattern)`, where `T` is an optional type annotation, let the checked type of `e` be `T'`. - 1. If `T'` is `ShapeType`, then emit an error if `T` is not a supertype of `ShapeType`. Add `v` to `Γ` with type `T`. + 1. If `T'` is `ShapeType`: + 1. Emit an error if `T` is not a supertype of `ShapeType`. + 2. If the `ndim` of `T'` is `n` ≥ 0, then emit an error if the length of the given `shape_pattern` is not `n`. + 3. Add `v` to `Γ` with type `T`. 2. If `T'` is `DynTensorType`: 1. If the `ndim` of `T'` is `n` ≥ 0, then emit an error if the length of the given `shape_pattern` is not `n`. Let the datatype of `T'` be `d`. 2. If `T` is not a supertype of `DynTensorType(ndim=len(shape_pattern), dtype=d)`, then emit an error. If `T` is a subtype of that type, emit an error and request a cast. From 8fd40082bfb6b3312dae678e746245a489647fd6 Mon Sep 17 00:00:00 2001 From: Steven Lyubomirsky Date: Sat, 7 Jan 2023 21:10:25 -0500 Subject: [PATCH 14/30] StructInfo update --- relax_spec.md | 768 ++++++++++++++++++++++++++++++-------------------- 1 file changed, 457 insertions(+), 311 deletions(-) diff --git a/relax_spec.md b/relax_spec.md index d0e9db96a5..0e2b4c3660 100644 --- a/relax_spec.md +++ b/relax_spec.md @@ -2,7 +2,7 @@ Note: Text in «double chevrons» indicates features not present in the current prototype. -In order to develop and test Relax, it is important for compiler developers to agree on what a given program in Relax means and what makes it valid so that test cases can be evaluated independently of any particular Relax implementation. This document is intended to describe Relax's grammar constructs (its [abstract syntax tree](https://en.wikipedia.org/wiki/Abstract_syntax_tree), or AST), the semantics of its grammar (what the different constructs mean), Relax's type system and type-checking rules (what makes a Relax program valid), and its rules for reasoning about tensor shapes in detailed though still informal terms. If necessary, we may encode these rules more formally to allow for more automated analysis. +In order to develop and test Relax, it is important for compiler developers to agree on what a given program in Relax means and what makes it valid so that test cases can be evaluated independently of any particular Relax implementation. This document is intended to describe Relax's grammar constructs (its [abstract syntax tree](https://en.wikipedia.org/wiki/Abstract_syntax_tree), or AST), the semantics of its grammar (what the different constructs mean), Relax's type system and type-checking rules (what makes a Relax program valid), and its rules for reasoning about structural information (such as tensor shapes) in detailed though still informal terms. If necessary, we may encode these rules more formally to allow for more automated analysis. Though this document will use the TVMScript front end for some examples, specifying the mapping from Python's AST to Relax's AST will be deferred until the parser becomes more stable. @@ -15,7 +15,7 @@ Though this document will use the TVMScript front end for some examples, specify 5. [Normal Form](#normal-form) 6. [Well-Formedness Criteria](#well-formedness-criteria) 7. [Types in Relax](#types-in-relax) -8. [Shapes in Relax](#shapes-in-relax) +8. [Structural Information in Relax](#structural-information-in-relax) 9. [Semantics](#detailed-semantics) # Overview @@ -26,7 +26,7 @@ This section will outline the grammar of Relax and give very brief descriptions Per the [original workshop paper](https://arxiv.org/abs/1810.00952) and the [later report](https://arxiv.org/abs/1904.08368), Relay was designed to be a high-level functional language for expressing deep learning models at a high level. While Relay is not entirely pure (the `Ref` type is modeled after reference types in SML and similar functional languages), the assumption in Relay is that tensor operators are generally pure, meaning that they do not change the program state other than by producing new values. Additionally, Relay's type system also requires operators to have type relations that infer static tensor types or conclude that a dimension is unknown at compile time (`Any`). The need to register type relations and ensure operators' purity makes it difficult to add new operators to Relay and particularly difficult to call directly into TIR or external libraries, which are often not pure; any such extension requires adding new operators and abstracting over any impurity. -While Relax aims to be as general and expressive as Relay, Relax is intended to make it much easier to interoperate with external libraries and especially with TIR. In particular, Relax includes a mechanism for calling arbitrary TVM `PackedFunc`s (which can call external libraries) and special support for TIR. The language accordingly does not assume that such operations are pure, though this does require reasoning about aliasing and similar issues. Additionally, tensor shapes are no longer handled during type checking; each expression has an associated shape _computation_ associated with it, in addition to a type. These shape computations support static reasoning about shapes in many cases, but also facilitate a fallback to dynamic checking when that is not possible. This approach to shapes allows for richer shape constraints to be checked at run time (such as with _symbolic_ shapes, where some dimensions are variables) and allows for more quickly integrating calls into TIR or external libraries into Relax code by obviating the need for type relations. +While Relax aims to be as general and expressive as Relay, Relax is intended to make it much easier to interoperate with external libraries and especially with TIR. In particular, Relax includes a mechanism for calling arbitrary TVM `PackedFunc`s (which can call external libraries) and special support for TIR. The language accordingly does not assume that such operations are pure, though this does require reasoning about aliasing and similar issues. Additionally, tensor shapes are no longer handled during type checking; each expression has associated structural information associated with it, in addition to a type. This structural information supports static reasoning about tensor shapes in many cases, but also facilitates a fallback to dynamic checking when that is not possible. This approach to shapes allows for richer shape constraints and other structural properties to be checked at run time (such as with _symbolic_ shapes, where some dimensions are variables) and allows for more quickly integrating calls into TIR or external libraries into Relax code by obviating the need for type relations. ## Grammar @@ -59,29 +59,33 @@ Type ::= DynTensorType(ndim: int, dtype: DataType) | FuncType(arg_types: [Type], ret_type: Type, «pure: bool») | PackedFuncType() -DataType ::= - Int(bitwidth: int) - | Float(bitwidth: int) - | Bool() - | Void() +DataType ::= Int(bitwidth: int) + | Float(bitwidth: int) + | Bool() + | Void() + +StructInfo ::= TensorStructInfo(shape: Expr?, dtype: DataType, ndim: int) + | ShapeStructInfo(values: [PrimExpr]?, ndim: int) + | ObjectStructInfo() + | TupleStructInfo(fields: [StructInfo]) + | FuncStructInfo(params: [StructInfo]?, ret: StructInfo, derive_func: EnvFunc?*) # expressions Expr ::= Constant(data: NDArray) # scoped to functions or SeqExprs - | Var(name_hint: string) + | Var(name_hint: string, struct_info_annotation: StructInfo?) # scoped to DataflowBlocks - | DataflowVar(name_hint: string) + | DataflowVar(name_hint: string, struct_info_annotation: StructInfo?) | GlobalVar(name_hint: string) | Tuple(fields: [Expr]) | SeqExpr(blocks: [BindingBlock], body: Expr) - | Function(params: [Var], body: Expr, ret_type: Type?, attrs: Attrs?) + | Function(params: [Var], body: Expr, ret_struct_info: StructInfo?, attrs: Attrs?) | If(cond: Expr, true_branch: Expr, false_branch: Expr) | ExternFunc(global_symbol: string) | Call(op: Expr, args: [Expr], type_args: [Type], attrs: Attrs?) | ShapeExpr(values: [PrimExpr]) | TupleGetItem(tuple_value: Expr, index: int) | Op(op_name: string) - | RuntimeDepShape() # binding blocks (analogous to sequence of statements) BindingBlock ::= @@ -91,7 +95,7 @@ BindingBlock ::= # bindings (analogous to statements) Binding ::= VarBinding(var: Var|DataflowVar, value: Expr) - | MatchShape(var: (Var|DataflowVar)?, pattern: [PrimExpr], value: Expr) + | MatchCast(var: (Var|DataflowVar)?, struct_info: StructInfo, value: Expr) # Relax programs are IRModules. Modules may bind global variables either to # Relax functions or TIR PrimFuncs (specified separately). @@ -99,13 +103,15 @@ Binding ::= Program ::= IRModule(funcs: {GlobalVar: Function|PrimFunc}) ``` +*The `derive_func` field of `FuncStructInfo` is a macro in the meta-language: Given a function call and the variable mapping context, return the `StructInfo` of the result. This field is used only at compile time for reasoning about the `StructInfo` of calls to `ExternFunc`s. + ## Expression Survey This specification provides a more detailed description of what each expression and type represents and what conditions make them valid. To motivate and provide more context for the full specification later in this document, this section will briefly summarize the purpose of each node. 1. `Constant` nodes construct tensor constants (n-dimensional arrays of scalars). 2. `Tuple` nodes construct a tuple (fixed-size ordered grouping) of Relax values. -3. `Var`, `DataflowVar`, and `GlobalVar` nodes are all variables, referring to named stored values of different kinds. Variables in Relax must be bound exactly once. `GlobalVar`s are bound in the `IRModule` itself and refer to Relax functions or TIR `PrimFunc`s. `Var` nodes are bound either within functions, where they represent function parameters, or in `VarBinding` or `MatchShape` nodes in `BindingBlock`s, as we will discuss below. `DataflowVar`s are similar to `Var`s and can be bound only within `DataflowBlock`s. +3. `Var`, `DataflowVar`, and `GlobalVar` nodes are all variables, referring to named stored values of different kinds. Variables in Relax must be bound exactly once. `GlobalVar`s are bound in the `IRModule` itself and refer to Relax functions or TIR `PrimFunc`s. `Var` nodes are bound either within functions, where they represent function parameters, or in `VarBinding` or `MatchCast` nodes in `BindingBlock`s, as we will discuss below. `DataflowVar`s are similar to `Var`s and can be bound only within `DataflowBlock`s. 4. `PrimExpr`s are used to represent dimensions of shapes in `ShapeExpr` and `MatchShape` nodes. These represent operations on integers with their own `Var` nodes (`tir::Var`), which we will refer to as "shape variables". Shape variables can only be used in other `PrimExpr`s and are scoped like `Var` nodes (`relax::Var`), which we will call "Relax variables." 5. `Call` nodes represent function calls. The callee argument (the `op`) can be an `ExternFunc` node (representing a call to a `PackedFunc`), an `Op` node (representing a call to a Relax operator), or an arbitrary expression. 1. For `ExternFunc` nodes, the call will look up the registered `PackedFunc` by its global symbol and will call it with the given arguments (note that a TIR `PrimFunc` can be compiled into a `PackedFunc` and called using `ExternFunc` by defining a `global_symbol` attribute in the `PrimFunc`). «The attribute "pure" can be specified on a call to an `ExternFunc` to indicate that it performs no side effects (for use inside `DataflowBlock`s).» @@ -118,22 +124,23 @@ This specification provides a more detailed description of what each expression 7. `TupleGetItem` nodes represent tuple indexing. The `tuple_value` expression must evaluate to a tuple with at least `index + 1` items and the item with the given index will be returned. 8. `SeqExpr` describes a sequence of binding blocks followed by a return expression. The `SeqExpr` opens a new scope. Its binding blocks are evaluated in order and add new variables to the scope. Binding blocks are either ordinary `BindingBlock`s or `DataflowBlock`s and both consist of a series of bindings. `DataflowBlock`s are the only kind allowed to introduce bindings with `DataflowVar`s and it does not permit any constructs featuring control flow (`If` nodes or recursive calls) or calls to (possibly) impure functions. There are two different kinds of bindings: 1. `VarBinding`s: The `value` expression (the right-hand side of the binding) of the binding is evaluated first and is bound to the `var` expression, which must be a new `Var` or `DataflowVar` (in a dataflow block). The newly bound variable will have that value for the remainder of the scope (`DataflowVar`s are scoped only to the `DataflowBlock` in which they appear; `Var`s are scoped to the entire `SeqExpr`). - 2. `MatchShape`s: The `value` expression is evaluated and the resulting shape is dynamically checked against the shape denoted by the `PrimExpr`s in the `pattern` field. - 1. If `value` evaluates to a tensor value, the pattern will be checked against the shape of the tensor; if it evaluates to a shape value, the pattern will be checked directly against the shape. - 2. Any shape dimension in the pattern that consists of a single new shape variable is treated as a binding: The variable is bound to the size of the corresponding dimension of the value being matched. - 3. If the shapes do not match, an error is triggered. If there is a variable provided, the value is bound to the `var` expression (if the variable is omitted, the shape check is performed and any shape variables are updated, but no new binding is introduced). Shape variables introduced in a `SeqExpr` are similarly scoped to the `SeqExpr`. + 2. `MatchCast`s: The `value` expression is evaluated and the result is dynamically checked against the structural information given in the `struct_info` field. + 1. The types must match: All `StructInfo` variants correspond to a type (`TensorStructInfo` to `DynTensorType`, `ShapeStructInfo` to `ShapeType`, etc.) and each type corresponds to a value (`DynTensorType` to a tensor value, `ShapeType` to shape values, etc.), so if the structure of `value` does not correspond to `struct_info`, an error is triggered. The structure of `value` is compared recursively with `struct_info`, so all components of `value` must match up with any nested structural information. Special comparison rules: + 1. For comparing tensor values to `TensorStructInfo`, `ndim` must match the number of dimensions in the tensor value (unless `ndim` is -1) and `dtype` must match the datatype used (unless `dtype` is `Void`). If `shape` has been specified, the shape of the value must match that encoded by `shape`; if specified, `shape` must be either a `Var` already bound in the current scope or a `ShapeExpr`. + 2. For comparing shape values to `ShapeStructInfo`, `ndim` must match the number of dimensions in the shape value (unless `ndim` is -1). If `values` has been specified, the shape value must match that encoded by `values`. + 3. «For comparing closures (function values) to `FuncStructInfo`, it is necessary for the compiled program to track run-time structural information for closures, since it is not possible to introspect the closure; this subject will be discussed in further detail later in the document.» + 2. When comparing tensor values with `TensorStructInfo` or shape values with `ShapeStructInfo`, any member of `shape` in `TensorStructInfo` (if `shape` is a `ShapeExpr`) or `values` in `ShapeStructInfo` that consists of a single new (hitherto unbound) shape variable is treated as a binding: The shape variable is bound to the size of the corresponding dimension of the value being matched. + 3. If there is a variable provided, the value is bound to the `var` expression (if the variable is omitted, the structural check is performed and any shape variables are updated, but no new binding is introduced). Shape variables introduced in a `SeqExpr` are similarly scoped to the `SeqExpr`. The `SeqExpr`'s `body` expression is allowed to reference any `Var`s introduced within the `SeqExpr`'s binding blocks in addition to those that were in the outer scope; the `body` expression is evaluated after the binding blocks and its value is what is returned. Any Relax variables and shape variables introduced in the `SeqExpr` are removed from scope after the expression finishes evaluating. 9. `ShapeExpr` nodes construct shape literals. The `PrimExpr`s within it describe how to compute each dimension; they are free to use any shape variables that are in scope. 10. `Function` nodes represent function definitions, taking in the listed parameters and evaluating the body expression in a new scope (meaning any variables defined from within the function cannot be referenced outside it). Function definitions may be nested in any other expression and they evaluate into closure values, ensuring that functions are first-class. Closures capture any variables from the outer scope that are used in their body, both Relax variables and shape variables. Note that function definitions themselves are anonymous—a function must be registered in the `IRModule` (bound to a `GlobalVar`) or appear on the right-hand side of a binding to have a name in order to be called recursively. - The function can have shape annotations on the parameters and a return shape parameter. When the function is called, the annotations on parameters checked against the argument values in similar fashion to `MatchShape` and can introduce new shape variables that are scoped to the function. + The function can have structural annotations on the parameters and a structural annotation for the return value. When the function is called, the annotations on parameters are checked against the argument values in similar fashion to `MatchCast` and can introduce new shape variables that are scoped to the function. Additionally, the structural information of the return value is checked against the annotation before the call returns. «A function mapped bound to a `GlobalVar` can have a `global_symbol` attribute defined to indicate that it should be externally linked externally (be accessible outside the `IRModule`). The absence of a `global_symbol` attribute on a function definition bound to a `GlobalVar` indicates that it is "private" and hence can be called only within the `IRModule`.» -11. `RuntimeDepShape` nodes are used to denote that a shape is unknown at compile time and must be deduced at run time. These nodes may appear only in shape annotations and have no run-time semantics of their own. - ## Purity and Dataflow Blocks A function or operator is called "pure" if it does not have side effects, which refers to any change in program state besides returning a result. Side effects include mutating values other than those they create, aborting the program, or file I/O (including writing to the console). Purity is a useful property for compiler optimizations, since calls to pure functions can be reordered or duplicated or (if the result is unused) eliminated without changing any other program behavior. Most deep learning operators are pure, as they perform arithmetic on tensors and return a new tensor containing the result. «In Relax, we conservatively assume that any function that calls an impure function is itself impure, though the attribute `force_pure` on a function can be used as an override (e.g., if a function creates a new tensor, mutates it, and returns it, that is still pure but does not satisfy the conservative rule).» @@ -142,10 +149,10 @@ Above, it is mentioned that `DataflowBlock`s are not allowed to contain construc There is one visible side effect that Relax permits inside otherwise "pure" functions, namely exiting the program with an error. This can arise in the following cases: -- Shape matching errors (from `MatchShape` or from implicit shape checks upon calling a Relax function) -- Errors raised by otherwise pure Relax operators or `PackedFunc`s, such as in `cast` (which dynamically checks types). Since the purity of operators or `PackedFunc`s must be manually registered, this means that it is permissible to register an operator or `PackedFunc` as being pure if its only side effect is issuing an error in some cases. +- Casting errors (from `MatchCast` or from implicit structural information checks upon calling a Relax function) +- Errors raised by otherwise pure Relax operators or `PackedFunc`s. Since the purity of operators or `PackedFunc`s must be manually registered, this means that it is permissible to register an operator or `PackedFunc` as being pure if its only side effect is issuing an error in some cases. -Even though an abnormal program exit is a visible side effect and removing or reordering it changes the observable semantics, it would be too great a restriction to prohibit error checking inside `DataflowBlock`s. Relax does not have any notion of exception handling, so the only consequence of a failed safety check can be exiting the program. It is permissible for the compiler to reorder, duplicate, or eliminate `MatchShape`, `cast`, or otherwise pure operations that have the potential of failing, provided that doing so does not change the value returned by the program or any other visible behavior. +Even though an abnormal program exit is a visible side effect and removing or reordering it changes the observable semantics, it would be too great a restriction to prohibit error checking inside `DataflowBlock`s. Relax does not have any notion of exception handling, so the only consequence of a failed safety check can be exiting the program. It is permissible for the compiler to reorder, duplicate, or eliminate `MatchCast`, or otherwise pure operations that have the potential of failing, provided that doing so does not change the value returned by the program or any other visible behavior. To indicate that an operator or `PackedFunc` that can abort with an error should *never* be reordered or removed by the compiler, it should *not* be marked as pure. However, this means that it cannot be used inside a `DataflowBlock`. @@ -166,7 +173,7 @@ The types in Relax correspond to the broad categories of the values given above: The type checking rules assign types to every variable in scope and every type of expression based on the values it returns, making use of subtyping to assign more general types when a more specific one cannot be determined. «Relax is strongly typed, meaning that if a type encountered is less specific than the one expected, an error will be issued and an explicit cast (via the `cast` operator) will be required.» -## Shape System Survey +## Structural Information System Survey In Relax, tensor shapes are not handled in the type system; each expression instead a has an associated shape expression. In many cases, these shape computations can allow for statically concluding that two shapes are the same and thus eliminate the need for dynamic checks via `MatchShape`. However, when shapes cannot be statically concluded to be the same, it may be necessary for there to be dynamic checks. The compiler is also free to make use of shape expressions for memory planning purposes. «Relax is "strongly shaped," meaning that if the compiler cannot conclude that shapes match in certain cases, an error will be issued and an explicit `MatchShape` will be required.» @@ -183,7 +190,7 @@ Oftentimes, compiler passes operate only on particular functions or add new func Here are the classes of values that Relax operates over, meaning that they can be assigned to variables or be the result of evaluating expressions. - *Tensors* are n-dimensional arrays of scalar values (which can be integers of fixed bitwidths, floats of fixed bitwidths, or bools). A tensor's *shape* is a tuple of the size of each dimension; the number of dimensions is a tensor's *rank*. For example, a vector (1, 2, 3) is a rank-1 tensor of shape `(3,)`. Note that scalars are tensor values with a rank of 0, meaning that their shape is `()`. -- *Tuples* represent a fixed-size grouping of other Relax values (tensors, closures, shapes, objects, or other tuples, to an arbitrary degree of nesting). Note that an empty tuple, i.e., `()`, also called "unit" in functional programming, is commonly used as the return value for operations return no value (as may be the case in some `PackedFunc` or operator calls that have side effects). +- *Tuples* represent a fixed-size grouping of other Relax values (tensors, closures, shapes, objects, or other tuples, to an arbitrary degree of nesting). Note that an empty tuple, i.e., `()`, also called "unit" in functional programming, is commonly used as the return value for operations not intended to return a value (as may be the case in some `PackedFunc` or operator calls that have side effects). - *Closures* are the values resulting from evaluating Relax function expressions; closures can be passed around like other values, ensuring that functions are first-class in Relax. Functions defined in Relax can capture variables from outer scopes. A [closure](https://en.wikipedia.org/wiki/Closure_(computer_programming)) consists of a function and a mapping of any variables "captured" (those are *free variables* in the function body, variables from an outer scope that are neither arguments nor defined within the function but are used in the function) to their values. Closures capture both Relax-level local variables and shape variables from outer scopes. A closure also stores a name for itself when the body contains recursive calls. «Closures additionally carry some *run-time type information* (RTTI) indicating their argument types and result type, in order to facilitate dynamic type checks (since it is not otherwise possible to introspect the function contained within a closure); the precise form of the RTTI is left up to the compiler implementation to determine so long as the `cast` operator can verify the type of a closure. Closures can be evaluated in a call node, which results in calling the function with the call's arguments and the captured values.» - *Tensor shapes* (shape values) are tuples of integers describing a tensor shape, obtained by evaluating `ShapeExpr`s. - Additionally, there are further *arbitrary objects* that do not belong in the above categories. These can be returned by `PackedFunc`s and operators; additionally, we treat TIR `PrimFunc`s as opaque objects. Though Relax expressions other than `PackedFunc` and operator calls cannot use those objects, Relax should pass around these values faithfully. In the future we may add more value types in order to distinguish between different objects, but at present we treat these all as arbitrary values of type `ObjectType`. @@ -222,8 +229,8 @@ def func(x: Tensor) -> Tensor: # the true branch will be a nested SeqExpr and hence a new scope # this x will shadow the function parameter x x = R.const(1) - R.print(x) # prints 1 - # the inner x goes out of scope + R.print(x) # prints 1 + # the inner x goes out of scope else: R.print("not executed") R.print(x) # this x is the function parameter @@ -232,17 +239,17 @@ def func(x: Tensor) -> Tensor: # Normal Form -To simplify the writing of Relax passes, we define a normal form for Relax programs, based on the [administrative normal form](https://en.wikipedia.org/wiki/A-normal_form) (A-normal form, or ANF). See [this post](https://matt.might.net/articles/a-normalization/) by Matt Might for a discussion of some of the advantages of ANF in traditional compilation; in particular, ANF results in programs without nesting, which is very convenient for writing program transformations. Because the type- and shape-checking rules for operators rely on macros (`FInferType` and `FInferShape`), _this means that the structure of the program can affect type and shape inference_. Putting programs into normal form (and lacking nesting) not only simplifies the writing of these macros but it also ensures that these type- and shape-checking rules will be predictable, hence _it is required to transform programs into normal form_ before applying type- or shape-checking. +To simplify the writing of Relax passes, we define a normal form for Relax programs, based on the [administrative normal form](https://en.wikipedia.org/wiki/A-normal_form) (A-normal form, or ANF). See [this post](https://matt.might.net/articles/a-normalization/) by Matt Might for a discussion of some of the advantages of ANF in traditional compilation; in particular, ANF results in programs without nesting, which is very convenient for writing program transformations. Because the type- and structure-checking rules for operators rely on macros (`FInferShapeInfo`), _this means that the structure of the program can affect type and structure inference_. Putting programs into normal form (and lacking nesting) not only simplifies the writing of these macros but it also ensures that these type- and structure-checking rules will be predictable, hence _it is required to transform programs into normal form_ before applying type or structure checking. The normal form for Relax is very similar to ANF; differences will be noted. Here are the criteria required for a program to be in normal form: -1. Within a `SeqExpr`, the right-hand side of any binding (the `value` field in the AST) must either be a "leaf expression" or a non-leaf expression where all subexpressions are leaf expressions. Leaf expressions are the following: Variables (`Var`, `DataflowVar`, or `GlobalVar`), `Constant`, `ShapeExpr`, `RuntimeDepShape`, or (_unlike_ ANF) `Tuple`. `Tuple` nodes are considered "leaf" expressions even though they contain nesting purely for convenience in writing passes; many operators rely on grouping arguments using tuples, so that is a form of nesting permitted and expected. Otherwise, non-leaf expressions used as subexpressions must be bound to variables; this includes any non-leaf expressions nested inside a `Tuple`. +1. Within a `SeqExpr`, the right-hand side of any binding (the `value` field in the AST) must either be a "leaf expression" or a non-leaf expression where all subexpressions are leaf expressions. Leaf expressions are the following: Variables (`Var`, `DataflowVar`, or `GlobalVar`), `Constant`, `ShapeExpr`, or (_unlike_ ANF) `Tuple`. `Tuple` nodes are considered "leaf" expressions even though they contain nesting purely for convenience in writing passes; many operators rely on grouping arguments using tuples, so that is a form of nesting permitted and expected. Otherwise, non-leaf expressions used as subexpressions must be bound to variables; this includes any non-leaf expressions nested inside a `Tuple`. 2. `SeqExpr`s may appear only in the following locations: 1. In the `body` field of a `Function` node. 2. In the `true_branch` and `false_branch` fields of `If` nodes. 3. In fact, the `body` field of a `Function` node and the `true_branch` and `false_branch` fields of `If` nodes _must_ be `SeqExpr`s. If these fields are not `SeqExpr`s, they must be "wrapped" in a `SeqExpr`. 4. Within a `SeqExpr`, `BindingBlock`s must be consolidated. For example, if there is a `BindingBlock` that comes after another `BindingBlock`, the two blocks should be combined to form a single `BindingBlock` with all the bindings in the same order. Consecutive `DataflowBlock`s should be consolidated as well. Empty `BindingBlock`s should be dropped. However, a `DataflowBlock` cannot be consolidated with an ordinary `BindingBlock`. If all the `BindingBlock`s are empty, then the `blocks` field of the `SeqExpr` should be set to an empty list. -Programs that are parsed should be "normalized" before performing type-checking or shape-checking or before doing any further optimizations. Note that the process of "flattening" `SeqExpr`s and consolidating `BindingBlock`s does increase the visibility of the variables in those `SeqExpr`s and `BindingBlock`s, but this is safe, since it will not cause any variable to be referenced outside of its original scope. The specification does not require any particular method of normalizing a program so long as the final program conforms to the above-listed criteria. Here is a general approach: +Programs that are parsed should be "normalized" before performing type checking or structure checking or before doing any further optimizations. Note that the process of "flattening" `SeqExpr`s and consolidating `BindingBlock`s does increase the visibility of the variables in those `SeqExpr`s and `BindingBlock`s, but this is safe, since it will not cause any variable to be referenced outside of its original scope. The specification does not require any particular method of normalizing a program so long as the final program conforms to the above-listed criteria. Here is a general approach: 1. For each function in the `IRModule`, ensure that the body is a `SeqExpr`. If the body is not a `SeqExpr`, wrap the function body in a `SeqExpr`, creating a new `BindingBlock` to hold `VarBinding`s for any non-leaf expressions that need to be bound to variables. 2. If the function body is already a `SeqExpr`, consolidate all `BindingBlock`s, then check if the `body` field of the `SeqExpr` is a leaf expression. If not, bind it to a new var in the final `BindingBlock` and replace the `SeqExpr` body with the new var. 3. If the function body is not a `SeqExpr`, then recurse down the body's AST, binding any nested non-leaf expressions to a var in the current scope (doing this process in breadth-first order from left to right will respect the evaluation order in the semantics). If the body itself is a non-leaf expression, finally bind it to a var and have the final `SeqExpr` return the new var. @@ -258,8 +265,8 @@ The following criteria apply to all programs (including before normalization): 1. `DataflowVar`s can be bound only inside `DataflowBlock`s. Additionally, a `DataflowVar` may not be used outside of the `DataflowBlock` in which it is defined. 2. A `Var` of any kind used in the program must be either a function parameter or appear on the LHS of a binding exactly once. In the binding where a `Var` is defined, the same `Var` is permitted to occur in the RHS of the binding only if the binding is defining a function (i.e., local functions are permitted to be recursive). 3. A `Var` of any kind may not appear before it is bound. Namely, if a `Var` is bound in a `BindingBlock` in a `SeqExpr`, that `Var` may not appear in bindings that precede the one where it appears on the LHS. -4. «A return shape annotation for a function is not allowed to use any shape variables that are not in scope at the function definition. That is, the only shape variables that can appear on the return shape annotation are those defined in the outer scope or those introduced in the argument shape annotations.» -5. In each function, `PrimExpr` variables (shape variables) similarly may not appear in `ShapeExpr`s or shape annotations before the shape variables are bound (either in function signatures or `MatchShape` bindings). A shape variable is bound only when it appears in a dimension by itself (for example, a dimension consisting of `x` will bind `x`; however, `2*x` is not a binding and is considered an error if `x` has not yet been bound) in a `MatchShape` node or a function argument shape annotation. +4. «A return structural annotation for a function is not allowed to use any shape variables that are not in scope at the function definition. That is, the only shape variables that can appear on the return structural annotation are those defined in the outer scope or those introduced in the argument structural annotations.» +5. In each function, `PrimExpr` variables (shape variables) similarly may not appear in `ShapeExpr`s or shape annotations before the shape variables are bound (either in function signatures or `MatchCast` bindings). A shape variable is bound only when it appears in a dimension by itself (for example, a dimension consisting of `x` will bind `x`; however, `2*x` is not a binding and is considered an error if `x` has not yet been bound) in a `MatchCast` node or a function argument shape annotation. 6. The following constructs are not permitted to occur inside `DataflowBlock`s, which must be side effect– and control flow–free: 1. Recursive calls to the current function 2. Calls to a global function that is mutually recursive with the current function @@ -267,22 +274,25 @@ The following criteria apply to all programs (including before normalization): «Calls to Relax functions, `ExternFuncs`, or `Op`s that are not pure are also not permitted, but this must be detected during type checking.» -7. «For functions that contain recursive calls to themselves or mutually recursive global functions (i.e., those where function `a` calls function `b` and function `b` calls function `a`), a return type annotation is *required*. [TODO: Do we also require a return shape annotation in such cases?]» +7. «For functions that contain recursive calls to themselves or mutually recursive global functions (i.e., those where function `a` calls function `b` and function `b` calls function `a`), a return structural annotation is *required*.» 8. `Op` nodes may appear only as the `op` argument to `Call` nodes. 9. `ExternFunc` expressions may appear only as the `op` argument to `Call` nodes. -10. The `type_args` field is used only for type checking calls of `ExternFunc`s and `Op`s. Calls to `ExternFunc`s must have exactly one type argument, indicating the return type. Calls to `Op`s may use `type_args` as they wish. No other calls may have a non-empty `type_args`. +10. The `type_args` field is used only for type checking calls of `ExternFunc`s and `Op`s. No other calls may have a non-empty `type_args`. 11. If a variable has both a type annotation and a shape annotation, the `ndim` of any `DynTensorType`s must match the number of dimensions in the corresponding shape annotation. 12. A function definition inside a `DataflowBlock` may not use `DataflowVar`s from the outer scope in its body. We do not define closure capturing for `DataflowVar`s. 13. «At least one global function in the `IRModule` must be externally linked (have a `global_symbol` attribute) in order to serve as a program entry point.» 14. «If a global function has a defined `global_symbol` attribute, the `global_symbol` name must be the same as the `GlobalVar`'s name hint.» -15. «Any `PackedFunc` or operator called in a shape annotation or `shape_` expression must be pure and be annotated as such.» -16. The node `RuntimeDepShape` may appear only in shape annotations and `shape_` expressions. It has no defined semantics at run time. +15. If the `shape` field of a `TensorStructInfo` in any structural annotation is given, the only permissible expressions are `Var` (the variable must be in scope at the location of the annotation) or `ShapeExpr` (in which any shape variables used must already be in scope, unless the `TensorStructInfo` is the `struct_info` field of a `MatchCast`, in which case a new shape variable is allowed to appear in a dimension by itself). Additionally, if the `shape` field is a `ShapeExpr`, the number of dimensions must match the `ndim` field. +16. If the `values` field of a `ShapeStructInfo` in any structural annotation is given, any shape variables used in it must already be in scope, unless the `ShapeStructInfo` is the `struct_info` field of a `MatchCast`, in which case a new shape variable is allowed to appear by itself as a member of `values`. Additionally, the `ndim` field must match the length of `values`. +17. The `params` and `derive_func` field may not be simultaneously defined in a `FuncStructInfo` annotation; that is, if one is defined, the other must not be defined. Additionally, at least one of `params` and `derive_func` _must_ be defined for each `FuncStructInfo` in an annotation. Additionally, the criteria for normal form listed in the previous section must apply to any program that has been normalized. # Types in Relax -Relax presently has six types, defined in the implementation in `python/tvm/relax/ty.py` and `include/tvm/relax/type.h`: +Relax's type system is intended to enforce strong guarantees that values are passed correctly between expressions. The design emphasis is on simplicity, aiming to leave more complex analysis to the structural information. + +Relax presently has six types, corresponding to the values in the language: 1. `DynTensorType`, referring to tensor values (referred to in the front-end as `Tensor`). In Relax, tensor types keep track of the rank (number of dimensions) in the `ndim` field and the data type of the tensor data in the `dtype` field. Both the rank and data type are optional: Using -1 for `ndim` indicates that the tensor is of unknown rank and using `DataType::Void()` for `dtype` indicates that it's of unknown data type. 2. `ShapeType`, referring to shape values. The number of dimensions in the shape as given as `ndim` and is optional (using -1 for `ndim` indicates an unknown number of dimensions). @@ -291,6 +301,32 @@ Relax presently has six types, defined in the implementation in `python/tvm/rela 5. `PackedFuncType`, referring to the type of PackedFunctions. 6. `ObjectType`, referring to any Relax value, including values used and returned by `PackedFunc` or operator calls that do not belong in any of the above categories. +## Erasing Structural Information into Types + +Several type-checking rules rely on structural annotations or rules for defining the structural information for a call to an `Op` or `PackedFunc`. In general, types are simpler than structural information (to facilitate more precise reasoning). Structural information can be convereted into a type as follows (in pseudocode): + +```python +def erase_struct_info(si: StructInfo) -> Type: + if si is TensorStructInfo: + return DynTensorType(ndim=si.ndim, dtype=si.dtype) + if si is ShapeStructInfo: + return ShapeType(ndim=si.ndim) + if si is TupleStructInfo: + return TupleType(fields=[erase_struct_info(field) for field in si.fields]) + if si is FuncStructInfo: + # this should be the case only for packed funcs + if si.params is not specified: + return PackedFuncType() + return FuncType( + arg_types=[erase_struct_info(arg_type) for arg_type in si.params], + ret_type=erase_struct_info(si.ret) + pure=False) # TODO: This suggests we should either handle purity + # in StructInfo entirely (and not make it part of the type) + # or include it in both StructInfo and the type system + # only remaining case is ObjectStructInfo + return ObjectType() +``` + ## Subtyping Relax implements subtyping, which means that members of types can be accepted where members of their supertypes are accepted. We will denote the subtyping relationship as `T1 <: T2`, indicating that `T1` is a subtype of `T2`. For example. if `T1 <: T2` and some function expects an argument of type `T2`, then passing a member of type `T1` to that function is permitted; passing a member of type `T2` as an argument to a function that expects type `T1` for that argument is *not* permitted—the value would have to be dynamically cast to `T1` using the `cast` operator. @@ -416,7 +452,7 @@ def find_lub(T1 : Type, T2 : Type) -> Type: For two types `T1` and `T2`, if `T1 <: T2`, then a value of type `T1` can be passed anywhere a value of type `T2` is expected without any need for type conversions or dynamic checks. -*However*, if `T1 <: T2`, then passing a value of type `T2` where `T1` is expected can only be done if there has been some kind of dynamic check or conversion of that value. «Relax is *strongly* *typed*, meaning that the compiler will give an error in this situation and require an explicit conversion via the `cast` operator, which inspects the value's run-time representation and exits the program with an error message if the value is not a subtype of T1.» +*However*, if `T1 <: T2`, then passing a value of type `T2` where `T1` is expected can only be done if there has been some kind of dynamic check or conversion of that value. «Relax is *strongly typed*, meaning that the compiler will give an error in this situation and require an explicit conversion via a `MatchCast` node, which inspects the value's run-time representation.» If `T1` is not a subtype of `T2` and `T2` is not a subtype of `T1`, then it is always a type error to pass a value of either type where a value of the other is expected (no member of either type can be a member of the other). @@ -430,339 +466,452 @@ Let us consider a typing context `Γ`, which is a map of variables to types. 2. Given a variable `v`, if `v` is in `Γ`, then we return `Γ[v]`. Otherwise, it is an error. 3. Given a constant expression `Constant(data)`, the type is `DynTensorType(ndim=n, dtype=d)` where `n` is the number of dimensions in `data` and `d` is its data type (recall that `data` is an `NDArray` literal). 4. Given a shape expression `ShapeExpr(dims)`, its type is `ShapeType(n)`, where `n` is the length of `dims`. -5. The type of a `RuntimeDepShape` expression is `ShapeType(-1)`. -6. Given a tuple literal `Tuple([e1, e2, ..., en])`, suppose that `e1` has the type `T1`, `e2` has the type `T2`, …, and `en` has the type `Tn`. Then the type is `TupleType([T1, T2, .., Tn])`. -7. Given an `ExternFunc` expression, assign it the type `PackedFuncType()`. -8. Given a call node `Call(op=op, args=[a1, a2, ..., an], type_args=[aT])`: - 1. If `op` is a Relax `Op` node, then we look up its registered `FInferType` property. `FInferType` is a macro that takes in the `Call` node and produces a type. We return the type `op.FInferType(Call(op, [a1, ..., an], type_args=[aT]))`. The implementation of `FInferType` is free to throw errors. - 2. If `op` is `ExternFunc`, then use the sole member of `type_args` (calls to `ExternFunc`s are required to have exactly one `type_args` member) `aT` as the return type. Packed functions may be passed any combination of values and return any value; it is the responsibility of the packed function itself to do any validation. +5. Given a tuple literal `Tuple([e1, e2, ..., en])`, suppose that `e1` has the type `T1`, `e2` has the type `T2`, …, and `en` has the type `Tn`. Then the type is `TupleType([T1, T2, .., Tn])`. +6. Given an `ExternFunc` expression, assign it the type `PackedFuncType()`. +7. Given a call node `Call(op=op, args=[a1, a2, ..., an], type_args=[aT1, aT2, ..., aTn])`: + 1. If `op` is a Relax `Op` node, then we look up its registered `FInferStructInfo` property. `FInferStructInfo` is a macro that takes in the `Call` node and produces structural information. Invoke `op.FInferStructInfo(Call(op, [a1, ..., an], type_args=[aT1, aT2, ..., aTn]))` and convert the result to a type using the `erase_struct_info` procedure defined above. The implementation of `FInferStructInfo` is free to throw errors. + 2. If `op` is `ExternFunc`, note that packed functions may be passed any combination of values and return any value; it is the responsibility of the packed function's implementation to do any validation at run time. However, the type system uses the `type_args` field to determine the result type as follows: + 1. If there are no `type_args`, the resulting type is `ObjectType()`. + 2. If there is exactly one member of `type_args`, use that as the return type. + 3. If there are multiple members of `type_args`, then the type is `TupleType(fields=[aT1, aT2, ..., aTn])`. 3. Otherwise, check the types of the subexpressions, left to right. Suppose `op` has the type `Tf`, `a1` has type `T1`, …, and `an` has type `Tn`. If `Tf` is not a function type with exactly `n` arguments, we consider it a type error and require an explicit cast. Suppose `Tf` has argument types `T1'`, `T2'`, …, `Tn'`. Consider it a type error if any of the following does not hold: `T1 <: T1'`, `T2 <: T2'`, …, or `Tn <: Tn'`. Then the return type is `Tf.ret_type`. -9. Given a conditional expression `If(cond, true_branch, false_branch)`, we first assert that the type of `cond` is `DynTensorType(ndim=0, dtype=bool)`, giving an error otherwise. Next, we recursively check the types of `true_branch` and `false_branch`; suppose they yield types `Tt` and `Tf`, respectively. The return type will be `T = LUB(Tt, Tf)`. -10. For a `TupleGetItem(t, i)` expression, suppose that the type of `t` is `T`. If `T` is a tuple type with at least `i` members, then return the `i`th type in `T`. «Give a type-checking error and require an explicit cast if `T` is not a tuple type with at least `i` members.» -11. Let us consider a sequence expression `SeqExpr(blocks = [b0, b1, ..., bn], body)`. +8. Given a conditional expression `If(cond, true_branch, false_branch)`, we first assert that the type of `cond` is `DynTensorType(ndim=0, dtype=bool)`, giving an error otherwise. Next, we recursively check the types of `true_branch` and `false_branch`; suppose they yield types `Tt` and `Tf`, respectively. The return type will be `T = LUB(Tt, Tf)`. +9. For a `TupleGetItem(t, i)` expression, suppose that the type of `t` is `T`. If `T` is a tuple type with at least `i` members, then return the `i`th type in `T`. «Give a type-checking error and require an explicit cast if `T` is not a tuple type with at least `i` members.» +10. Let us consider a sequence expression `SeqExpr(blocks = [b0, b1, ..., bn], body)`. 1. We type check the binding blocks `b0`, `b1`, …, `bn` in order. For each block, we go through the bindings in order. 1. «If the current block is a `DataflowBlock`, consider it an error if any binding contains a call to an expression with a function type that is not pure, a call to an `ExternFunc` that does not have the `pure` attribute, or a call to an `Op` that does not have a `pure` attribute.» - 2. For each binding `VarBinding(v : T, e)` in the current block, where `T` is the optional annotation on `v`, check the type of `e` and suppose it is `T'`. If `T` has been omitted, then add `v` to `Γ` with type `T'`. If `T` has been defined, then emit an error if `T'` is not a subtype of `T` and add `v` to `Γ` with type `T`. «If `T'` is a supertype of `T`, emit an error and require a cast.» Note that this means that annotated types can be *less specific* than the inferred type and that a user annotation forces the type system to consider the variable as having a less specific type than it does. (Note: In the case where `e` is a `Function` literal, the type annotation `T` is not optional and we add `v` to `Γ` before type-checking the function body; see the rule for `Function` nodes.) - 3. For each `MatchShape(v: T, e, shape_pattern)`, where `T` is an optional type annotation, let the checked type of `e` be `T'`. - 1. If `T'` is `ShapeType`: - 1. Emit an error if `T` is not a supertype of `ShapeType`. - 2. If the `ndim` of `T'` is `n` ≥ 0, then emit an error if the length of the given `shape_pattern` is not `n`. - 3. Add `v` to `Γ` with type `T`. - 2. If `T'` is `DynTensorType`: - 1. If the `ndim` of `T'` is `n` ≥ 0, then emit an error if the length of the given `shape_pattern` is not `n`. Let the datatype of `T'` be `d`. - 2. If `T` is not a supertype of `DynTensorType(ndim=len(shape_pattern), dtype=d)`, then emit an error. If `T` is a subtype of that type, emit an error and request a cast. - 3. Add `v` to `Γ` with type `T`. - 3. If `T'` is `ObjectType`, then the only type we can conclude for `v` is `ObjectType`. If `T` is not `ObjectType`, emit an error and request a cast. - 4. If `T'` is `TupleType` or `FuncType`, emit a type error. + 2. For each binding `VarBinding(v, e)` in the current block, check the type of `e` and suppose it is `T'`. If `v` has a structural annotation, then let `T` be the corresponding type (via the `erase_struct_info` procedure above). If there is no annotation, then add `v` to `Γ` with type `T'`. If `T` has been defined, then emit an error if `T'` is not a subtype of `T` and otherwise add `v` to `Γ` with type `T`. «If `T'` is a supertype of `T`, emit an error and require a cast.» Note that this means that annotated types can be *less specific* than the inferred type and that a user annotation forces the type system to consider the variable as having a less specific type than it does. (Note: In the case where `e` is a `Function` literal, we require `v` to have a structural annotation add `v` to `Γ` with its annotated type before type-checking the function body; see the rule for `Function` nodes.) + 3. For each `MatchCast(v, e, struct_info)`: + 1. Check the type of `e` and let it be `T'`. + 2. Let `T''` be the type corresponding to `struct_info` (via the `erase_struct_info` procedure). + 3. Emit a warning if `T'` is not a supertype of `T''` and `T''` is also not a supertype of `T'`; this indicates that the cast is _guaranteed_ to fail at run time. + 4. If `v` has been defined and it has a structural annotation, then let `T` be its corresponding type (via `erase_struct_info`). + 5. If `T` has been defined, then emit an error if `T` is not a supertype of `T''`. + 6. If `v` has been defined and does not have a structural annotation, then add `v` to `Γ` with type `T''`. If `T` has also been defined, then add `v` to `Γ` with type `T`. 2. If the current block is a `DataflowBlock`, remove any `DataflowVar`s from `Γ` after we have finished processing the bindings in that block. 3. Finally, the type of the `SeqExpr` is the type of `body` checked under the current `Γ`. Afterwards, remove all bindings added from the `b0`, `b1`, …, `bn` from `Γ`. -12. Let us consider a function `Function(v1 : T1, v2 : T2, ..., vn : Tn, attrs=a) -> Tr: body`. - 1. For handling recursive calls: If the function has been bound to a name `fv` (which may be a `Var` or a `GlobalVar`), then add `fv` to `Γ` with the type `FuncType([T1, T2, ..., Tn], Tr, pure=p)` and proceed as below, where `p` is `True` if a `pure` attribute is included and `False` otherwise. Remove `fv` from `Γ` before returning. +11. Let us consider a function `Function(params=[v1, v2, ..., vn], body, ret_struct_info)`. All of the vars are required to have structural annotations; let `T1` be the type corresponding to `v1`'s annotation (via `erase_struct_info`), `T2` be the type corresponding to `v2`'s annotation, etc.. + 1. For handling recursive calls: If the function has been bound to a name `fv` (which may be a `Var` or a `GlobalVar`), then add `fv` to `Γ` with the type `FuncType([T1, T2, ..., Tn], Tr, pure=p)` and proceed as below, «where `p` is `True` if a `pure` attribute is included and `False` otherwise». Remove `fv` from `Γ` before returning. 2. Add `v1` to `Γ` with type `T1`, `v2` with type `T2`, …, and `tn` with type `Tn`. Recursively type check `body` in this new context: 1. «Determining purity: If `body` contains any call to a function whose return type does not specify that it is pure, a call to an `ExternFunc` that does not have the `pure` attribute, or a call to an `Op` that does not specify the `pure` attribute, then we consider the function to be (potentially) impure. If all calls are to functions whose return type specifies purity or that include the `pure` attribute on the call or `Op`, then the function is treated as pure.» 2. «Suppose the purity defined in the previous step is `p'`. Suppose the annotated function purity (in the attributes) is `p`. If `p'` is false while `p` is true, then it is a type error; if `p` was omitted, use `p'` for `p`.» 3. «If the function has the attribute "`force_pure`," then consider `p` to be true, even if the check above judged the function not to be pure. The compiler may emit a warning in this situation.» - 4. Suppose the result of type-checking `body` is `Tr'`. If the current function is not recursive or a mutually recursive global function and `Tr` was omitted, consider the function to have type `FuncType([T1, T2, ..., Tn], Tr', pure=p)`. If `Tr' <: Tr`, then we consider the function to have type `FuncType([T1, T2, ..., Tn], Tr, pure=p)`. «If `Tr <: Tr'`, return an error and require an explicit cast in the function body.» If `Tr'` is not a subtype of `Tr` and `Tr` is also not a subtype of `Tr'` (meaning a dynamic cast cannot succeed), this is an error. + 4. Suppose the result of type-checking `body` is `Tr'`. If the current function is not recursive or a mutually recursive global function and `ret_struct_info` is undefined, consider the function to have type `FuncType([T1, T2, ..., Tn], Tr', pure=p)`. If `ret_struct_info` is defined, then let `Tr` be `erase_struct_info(ret_struct_info)`. If `Tr' <: Tr`, then we consider the function to have type `FuncType([T1, T2, ..., Tn], Tr, pure=p)`. «If `Tr <: Tr'`, return an error and require an explicit cast in the function body.» If `Tr'` is not a subtype of `Tr` and `Tr` is also not a subtype of `Tr'` (meaning a dynamic cast cannot succeed), this is an error. 5. Remove `v1`, `v2`, …, and `vn` from `Γ` before returning. -# Shapes in Relax +# Structural Information in Relax -In Relay, shapes are part of tensor types and there is much analysis of tensor shapes done at compile time. In Relax, to allow for greater flexibility for variable-shape tensors and make it easier to implement new operators, shapes can be checked at run time. Though every expression in Relax has a shape associated with it just as expressions also have types, there is no requirement that the shape be expressed at compile time. Instead, the compiler merely requires that an expression's shape define *a way* to compute a fully specified shape at run time. Users have the ability to make use of shape variables and arithmetic expressions to encode a wide variety of shape constraints that can be checked dynamically. +In Relay, shapes are part of tensor types and there is much analysis of tensor shapes done at compile time. While this allows Relay's type system to make strong guarantees about tensor shapes, it results in greater complexity in type checking and makes it difficult to implement new operators or handle cases like tensors with symbolic shapes. -Nevertheless, in many cases, these shapes can be analyzed at compile time (particularly when they are consist of constants or deducible variables) to facilitate compile-time optimization much like is possible with Relay or TIR. Through constant propagation, function inlining, and other partial evaluation–like transformations, we can potentially eliminate many more dynamic checks by allowing some shape computations to be simplified at compile time. +Relax instead aims to facilitate analysis of more complex properties like shapes by tracking _structural information_ pertaining, encoding as much analysis as is feasible at compile-time in a _"best-effort"_ fashion. Anything that cannot be proved statically can instead be checked at run time. Each Relax expression has structural information associated with it just as it has a type. Indeed, the structural information for each expression can be simplified into a type (recall [the procedure for doing so](#erasing-structural-information-into-types)), so the structural information for an expression can be thought of as an extended type that is checked in a less precise manner. The best-effort nature of the structural system in Relax means that the analysis may detect _some_ errors at compile time and report them, but it may give warnings when it _cannot_ draw conclusions, perhaps suggesting that dynamic checks via `MatchCast` should be inserted. Note that the precision of the static analysis can potentially be improved by some compile-time optimizations like constant propagation, function inlining, and other partial evaluation–like transformations. -## Defining Shape Computations +Tensor shapes are the primary motivation for including structural information in Relax, as shape information is particularly important for memory planning. Relax's structural information system uses expressions to encode tensor shapes, which allows for using shape variables and arithmetic expressions to encode a rich variety of shape constraints. Note, however, that the structural system could potentially be extended to encode and analyze further information, like tensor sparsity or density. -In Relax, each expression has an associated shape computation, which defines how that expression's shape can be computed based on the shapes of its subexpressions. We will refer to this computation as `shape_`, as that is what it is called in the implementation. This essentially serves as a mechanism for propagating shape annotations on variable bindings and function definitions to other expressions and enable more compile-time analysis of shapes. In particular, `shape_` is useful for memory planning. These computations can also be used to simplify shape checking and eliminate many dynamic checks. +## Defining Structural Information -### Expressing Dimensions +As with types, the structural information in Relax corresponds to the values in the language: +* `TensorStructInfo` describes tensor values. Like in `DynTensorType`, the `dtype` field gives the datatype (with `Void` indicating a statically unknown datatype), the `ndim` field gives the rank (with -1 indicating a statically unknown rank). Unlike `DynTensorType`, there is an optional `shape` field which, if defined, describes the shape of the tensor using either a `ShapeExpr` or a `Var` whose type is `ShapeType`. If `shape` is a `ShapeExpr`, the `PrimExpr`s in the `ShapeExpr`'s dimensions describe how to compute each dimension of the shape (or are constants). If `shape` is a `Var`, the `Var` can assign the result of an arbitrary computation (that returns a shape). which can be useful for memory planning. +* `ShapeStructInfo` describes shape values. Like `ShapeType`, it has an `ndim` field that gives the number of dimensions in the shape (with -1 indicating that it is statically unknown). It additionally has an optional `values` field. If defined, `values` gives a list of `PrimExpr`s that indicate how to compute the dimensions of the shape, potentially providing further information for static analyses. +* `TupleStructInfo` describes tuple values, namely by giving the structural information for each of the tuple's members via `fields`. +* `FuncStructInfo` describes closure values or `PackedFunc`s. There are two ways in which to specify `FuncStructInfo`: + 1. By specifying `params` and `ret` (for closures). `params` gives the structural information corresponding to each of the function's parameters and `ret` gives the structural information corresponding to the result of calling the function. In this case, the `derive_func` field is left undefined. + 2. By giving a `derive_func` macro (for `PackedFunc`s). The `derive_func` macro is takes a call to the corresponding `PackedFunc` and the variable mapping context and returns the `StructInfo` of the result. In this case, the `params` field is left undefined and the `ret` field is ignored. +* `ObjectStructInfo` describes arbitrary object values. -A tensor shape is a tuple of TIR `PrimExpr`s, where each `PrimExpr` corresponds to a dimension. The use of TIR `PrimExpr`s for shape dimension allows shape computations to express complex constraints that include variables and integer arithmetic expressions in addition to just constant dimensions. +While these categories correspond closely to types, they serve as a mechanism for propagating further information (especially as given in shape annotations in variable bindings) throughout the program and facilitating more static analysis. + +### Expressing Shape Dimensions + +A tensor shape is a tuple of TIR `PrimExpr`s, where each `PrimExpr` corresponds to a dimension. The use of TIR `PrimExpr`s for shape dimensions allows shape computations to express complex constraints that include variables and integer arithmetic expressions in addition to just constant dimensions. **Scope of Shape Variables** -Shape variables can be introduced in two places in a Relax program: In a function signature, where they may be included with the argument shapes and return shape annotations, or in `MatchShape` bindings. Shape variables used in the function signature are scoped to the entire function in which they appear. Shape variables used in `MatchShape` bindings are scoped only to the `SeqExpr` in which they appear. +New shape variables can be bound in two places in a Relax program: In `TensorStructInfo` or `ShapeStructInfo` annotations on function parameters or as the `struct_info` parameter in a `MatchCast` binding. Shape variables used in the function signature are scoped to the entire function in which they appear (including in the return structural annotation). Shape variables used in `MatchShape` bindings are scoped only to the `SeqExpr` in which they appear. **Informal Semantics of `PrimExpr`s for Dimensions** -1. Shape variables can be bound to a value exactly once: at the start of a function for shape annotations on function arguments, in `MatchShape` bindings, or before a function returns (for shape variables on the return type). In particular, matching a `PrimExpr` consisting only of an uninitialized shape variable is treated as its binding (see below on `MatchShape`). After a shape variable has been bound for the first time, future uses of it will refer to the same value. -2. It is not legal to use a shape var that has not yet been bound. This results in an error at run time, though most cases can be detected at compile time. +1. Shape variables can be bound to a value exactly once, either at the start of a function for shape annotations on function arguments or in `MatchCast` bindings. In particular, matching a `PrimExpr` consisting only of an uninitialized shape variable is treated as its binding (see below on `MatchCast`). After a shape variable has been bound for the first time, future uses of it will refer to the same value. +2. It is not legal to use a shape var that has not yet been bound. This results in an error at compile time. 3. «Local functions will "capture" defined shape variables from the parent scope with their present values in the resulting closure.» 4. If all variables in the `PrimExpr` are defined, `PrimExpr` arithmetic will generally be evaluated according to the semantics of TIR. -### Evaluating `MatchShape` +### Evaluating `MatchCast` -`MatchShape` allows for binding shape variables in Relax. It can be used with either tensor values or shape values, and in both cases the evaluation of the `PrimExpr`s proceeds similarly. +Because structural information is checked in a "best-effort" fashion, it is not always possible for the compiler to statically draw conclusions about all details of a given value's structural information. Hence, `MatchCast` allows for checking this information at run time, similar to a typecast. However, `MatchCast` also allows for binding shape variables in the process of pattern matching, hence the "match" portion of its name. -1. Evaluating `MatchShape(v, t, s)`, where `t` is a tensor value and `s` is a list of `PrimExpr`s corresponding to shape dimensions: - 1. Suppose `s` is `(p1, p2, ..., pn)` , where each variables is a `PrimExpr`. We evaluate `p1`, then `p2`, and so, in that order according to the following rules (corresponding to the `i`th dimension): - 1. If the current `PrimExpr` consists only of an uninitialized shape variable, we bind the shape variable in that scope to the concrete value of the `i`th dimension of the value of `t`. - 2. Evaluate the current `PrimExpr` and compare it to the concrete value of the `i`th dimension of `t`. Raise an error if they do not match. - 2. If `v` is provided, bind `t` to `v` (see the general semantics for how that should be implemented). -2. Evaluating `MatchShape(v, S, s)`, where `S` is a shape value proceeds identically to the above, except the `PrimExpr`s are compared to the `i`th element of `S`. +This section describes the run-time checking performed by `MatchCast(var, value, struct_info)`, for each combination of value and structural annotation (if `var` is defined, then `value` will be bound to `var` as discussed in the [general section on semantics](#detailed-semantics)). If any check given below fails, an error is raised by the `MatchCast`. -### General Shape Computation Grammar +1. If `struct_info` is `ObjectStructInfo`, then no additional check is performed. All values in Relax are objects. +2. If `struct_info` is `TensorStructInfo(ndim, dtype, shape)`, then check that `value` is a tensor value, that it has a rank of `ndim` (if `ndim` is not -1), a datatype of `dtype` (if `dtype` is not `Void`). If `shape` is defined, consider the following cases: + 1. If `shape` is a `Var`, then check that the concrete shape of `value` matches the value bound to the `Var`. + 2. If `shape` is a `ShapeExpr`, then compare the fields of the `ShapeExpr` to the concrete shape of `value`, dimension by dimension (comparing the `i`th field of the `ShapeExpr` to the `i`th dimension of the shape of `value`). Give an error if the number of the dimensions does not match the number of fields in the `ShapeExpr`. + 1. If a field of the `ShapeExpr` consists of only an unbound shape variable, then bind that variable to the value of the dimension. + 2. Otherwise, evaluate the field of the `ShapeExpr` and ensure that it matches the concrete value of the dimension. +3. If `struct_info` is `ShapeStructInfo(ndim, values)`, then check that `value` is a shape value, that it has `ndim` dimensions (if `ndim` is not -1). If `values` is defined, then compare it to the concrete shape value (comparing the `i`th member of `values` to the `i`th field of the shape value): + 1. If the `i`th member of `values` consists of only an unbound shape variable, then bind that variable to the `i`th field of the the concrete shape value. + 2. Otherwise, evaluate the `i`th member of `values` and check that it is equal to teh `i`th field of the concrete shape value. +4. If `struct_info` is `TupleStructInfo(fields)`, then check that `value` is a tuple value with `n` fields, where `n` is the length of `fields`. Also recursively check the `i`th field of the tuple value against the `i`th member of `fields`. +5. If `struct_info` is `FuncStructInfo(params, ret, derive_func)`, then if `params` is defined, check that `value` is a closure value; if `derive_func` is defined, check that `value` is a `PackedFunc`. No further validation may be done on a `PackedFunc`. «If `value` is a closure value, then it can contain run-time structural information indicating the structural information of its intended arguments and return value that can be compared against `params` and `ret`.» -Shape computations can consist of the following expressions, which are a subset of general Relax `Expr`s: +### Checking Structural Information at the Start and End of a Function -``` -ShapeCompExpr ::= ShapeExpr(dims: [PrimExpr]) - | RuntimeDepShape() - | «Tuple(fields: [ShapeCompExpr])» - | Call(op: Op|ExternFunc, args: [Var|Constant]) - | «TupleGetItem(tuple_value: ShapeCompExpr, index: int)» +«Shape variables are bound at the start and end of a function or in `MatchShape` bindings. We can describe the behavior at the start and end of a function in terms of the semantics of `MatchCast`, as the shape annotations in function arguments and return types are treated as "syntactic sugar" for `MatchCast` bindings. Suppose a function has the following signature, where the `Si` are structural annotations: + +```python +def f(arg1 : S1, arg2 : S2, ..., argn : Sn) -> Sr: + return body ``` -The shape expressions can be interpreted as follows: +This can be treated as a macro that expands to -- `ShapeExpr` describes the shape of a tensor as a list of dimensions -- «`Tuple` describes the shapes of each member of a tuple» -- «`TupleGetItem` describes the shape of a member of a tuple» -- `Call` describes the shape of a function (or operator) call return value in terms of its arguments -- `RuntimeDepShape` describes shapes that are unknown at compile time (like when a shape annotation is omitted) or the shapes of values that don't have shapes (like shapes themselves, paradoxically: they *are* shapes but do not *have* shapes). +```python +def f(arg1, arg2, ..., argn): + MatchCast(arg1, S1) + MatchCast(arg2, S2) + ... + MatchCast(argn, Sn) + ret_var = body + MatchCast(ret_var, Sr) + return ret_var +``` +» -The `PrimExpr`s in a `ShapeCompExpr` can reference the same shape variables as in shape annotations, with the same semantics. +## Deriving the Structural Information for Each Expression -**Restrictions** +For each expression type, we can recursively build up the structural information associated with the expression. -Shape computations are allowed to include calls to operators and even `PackedFunc`s, but these operators and `PackedFunc`s *must* be pure. Shape computations are primarily used for memory planning and it is at the compiler's discretion when, if ever, to evaluate them (except as described below), hence they must not have side effects. +### Auxiliary Procedures -**Shape Annotations** +**`derive_func` for `FuncStructInfo`** -For shape annotations, we use `ShapeCompExpr` as the grammar, as with `shape_` expressions. `ShapeExpr` is used to annotate shapes of tensor values, «`Tuple` is used to annotate the shapes of tuple values», and `RuntimeDepShape` is used to indicate annotations that have been omitted or shapes that cannot be known at compile time (like the shapes of tensors whose rank is unknown at compile time). `Call` is used to annotate the shapes of calls to operators and «`TupleGetItem` annotates the shapes of tuple indices.» +There are two special `derive_func` values built into the compiler that are used for checking the structural information of `PackedFunc`s. -«For example, suppose we have a tuple where some fields are tensors like the following: +The first is `default_derive`, giving a simple way to determine the resulting structural information of a `PackedFunc` from its type arguments. `default_derive` takes one argument that is a `Call` node and is defined as follows: +1. Suppose its call node argument is `Call(op, [arg1, arg2, ..., argn], type_args=[aT1, aT2, ..., aTn])`. +2. If `type_args` is of length 0, then return `ObjectStructInfo()`. +3. If `type_args` is of length 1, then return `wrap_type(aT1)`. +4. If `type_args` is of a greater length than 1, then return `TupleStructInfo(fields=[wrap_type(aT1), wrap_type(aT2), ..., wrap_type(aTn)])`. -```python -x : Tuple(Tensor((m, n), "int32"), Tuple(), Tensor((), "int32"), Tensor(_, "int32")) = ... -``` +The second is `empty_derive`, which is the weakest possible derivation. It simply returns `ObjectStructInfo` regardless of its argument. This is used for worst-case deducation of `StructInfo` for a `PackedFunc`. + +**Wrapping Types** -It has the shape annotation +For deriving the structural information for a `PackedFunc` call, the type arguments are converted into structural information. This is a straightforward procedure, given here in pseudocode: ```python -Tuple([ShapeExpr([m, n]), Tuple([]), ShapeExpr([]), RuntimeDepShape]) +def wrap_type(t: Type) -> StructInfo: + if t is ObjectType: + return ObjectStructInfo() + if t is PackedFuncType: + # leave params undefined; see default_derive below + return FuncStructInfo(ret=ObjectStructInfo(), derive_func=default_derive) + if t is FuncType: + # leave derive_func undefined + return FuncStructInfo( + params=[wrap_type(arg_type) for arg_type in t.arg_types], + ret=wrap_type(t.ret_type) + ) + if t is TupleType: + return TupleStructInfo(fields=[wrap_type(field) for field in t.fields]) + if t is ShapeType: + # leave values undefined + return ShapeStructInfo(ndim=t.ndim) + if t is DynTensorType: + # leave shape undefined + return TensorStructInfo(ndim=t.ndim, dtype=t.dtype) ``` -» - -Note that it is [a well-formedness requirement](https://www.notion.so/Informal-Relax-Language-Specification-d1fdedb8fae84f0d82b9f880f25e7370) that if any field in a type has a `ShapeExpr` annotation, it must be a `DynTensorType` with an `ndim` matching the number of dimensions in the `ShapeExpr`. For example, in the above function signatures, the `ndim` in the type annotations must be 2. -### «Assigning Shape Variables at the Start and End of a Function» +**Erasing Out-of-Scope Information** -«Shape variables are bound at the start and end of a function or in `MatchShape` bindings. We can describe the behavior at the start and end of a function in terms of the semantics of `MatchShape`, as the shape annotations in function arguments and return types are treated as "syntactic sugar" for `MatchShape` bindings. Suppose a function has the following signature, where the `Ti` are type annotation and the `Si` are shape annotations: +When returning a value from an inner scope to an outer scope (namely, the `body` field of a `SeqExpr`, which may use variables defined in the binding blocks, and the `body` field of a `Function`, which may use variables defined in the function body), it may be possible for the derived `TensorStructInfo` or `ShapeStructInfo` to contain Relax variables or shape vars that have gone out of scope. We defined a procedure to check for any of these out-of-scope variables and weaken the structural information not to include it. The procedure is defined below, in pseudocode: ```python -def f(arg1 : (T1, S1), arg2 : (T2, S2), ..., argn : (Tn, Sn)) -> (Tr, Sr): - return body +def erase_to_well_defined( + s: StructInfo, + var_scope: set of Relax vars in current scope, + shape_var_scope: set of shape vars in current scope) + -> StructInfo: + + if s is ObjectStructInfo: + return s + if s is TensorStructInfo: + if s.shape is defined: + if (s.shape is a Relax var that is not in var_scope + or s.shape is a ShapeExpr that contains any shape var not in shape_var_scope): + # leave shape undefined + return TensorStructInfo(ndim=s.ndim, dtype=s.dtype) + else: + return s + else: + return s + if s is ShapeStructInfo: + if (s.values is defined + and any member of s.values contains a shape var not in shape_var_scope): + # leave values undefined + return ShapeStructInfo(ndim=s.ndim) + if s is TupleStructInfo: + return TupleStructInfo( + fields=[ + erase_to_well_defined(field, var_scope, shape_var_scope) + for field in s.fields + ] + ) + if s is FuncStructInfo: + if params is defined: + return FuncStructInfo( + params=[ + erase_to_well_defined(param, var_scope, shape_var_scope) + for param in s.params + ], + ret=erase_to_well_defined(s.ret, var_scope, shape_var_scope) + ) + else: + return FuncStructInfo( + ret=erase_to_well_defined(s.ret, var_scope, shape_var_scope), + derive_func=s.derive_func + ) ``` -This can be treated as a macro that expands to +**Substituting Free Shape Variables in `FuncStructInfo`** -```python -def f(arg1 : T1, arg2 : T2, ..., argn : Tn) -> Tr: - check_annotation(arg1, T1, S1) - check_annotation(arg2, T2, S2) - ... - check_annotation(argn, Tn, Sn) - ret_var = body - check_annotation(ret_var, Tr, Sr) - return ret_var -``` -» +The `params` field of `FuncStructInfo` can contain free shape variables, indicating that these shape variables are bound to the corresponding dimensions of the argument when the function is called. For checking the compatibility of two function types, we can construct a mapping of shape variables and then substitute shape variables according to the mapping. The mapping can be constructed by doing a simple structural match, as when checking alpha-equivalence. -Because `MatchShape` is defined only for tensor and shape values, we must use a macro to handle other possible types that may be passed into a function, given here in pseudocode: +For clarity, additional detail on how the mapping should be constructed is given here in pseudocode: ```python -def check_annotation(e: Expr, s: ShapeCompExpr) -> Expr: - if s is a ShapeExpr: - tmp = fresh_var() - # type checking should ensure that e is always a tensor - return SeqExpr( - [BindingBlock([MatchShape(tmp, e, s.dims)])], - tmp - ) - «else if s is a Tuple: - # type checking should ensure that e is always a tuple and the lengths match - shapes = s.fields - tmp = fresh_var() - return SeqExpr( - [BindingBlock([ - VarBinding(tmp, e), - # recursive in case we have nested tuples - VarBinding(fresh_var(), check_annotation(TupleGetItem(tmp, 0), shapes[0])), - VarBinding(fresh_var(), check_annotation(TupleGetItem(tmp, 1), shapes[1])), - ..., - VarBinding(fresh_var(), check_annotation(TupleGetItem(tmp, n-1), shapes[n-1])) - ])], tmp - )» - else if s is a Call: - tmp = fresh_var() - return SeqExpr( - [BindingBlock([ - VarBinding(tmp, e), - # completely dynamic check that does not assign shape vars. - VarBinding(fresh_var(), dynamically_check_shapes(shape_of(tmp), s)) - ])], tmp - ) - «else if s is TupleGetItem: - val = s.tuple_value - if val is Tuple: - return check_annotation(e, val.fields[s.index]) - # otherwise, evaluate it - return SeqExpr( - [BindingBlock([ - VarBinding(tmp, e), - VarBinding(fresh_var(), dynamically_check_shapes(shape_of(tmp), s)) - ])], tmp - )» - else if s is RuntimeDepShape: - # no need to check - return e +def get_shape_var_mapping(S1: StructInfo, S2: StructInfo) -> {tir::Var, PrimExpr}: + if S1 and S2 are not the same type: + return {} + if S1 and S2 are both TupleStructInfo: + if S1.fields and S2.fields don't have the same length: + return {} + ret = {} + for 0 <= i < length of S1.fields: + ret = union of ret and get_shape_var_mapping(S1.fields[i], S2.fields[i]) + return ret + if S1 and S2 are both FuncStructInfo: + if S1 and S2 both have params defined and the params are the same length: + ret = {} + for 0 <= i < length of S1.params: + ret = union of ret and get_shape_var_mapping(S1.params[i], S2.params[i]) + # don't look at the return field; it's not a binding position + return ret + else: + return {} + if S1 and S2 are both ShapeStructInfo: + if S1 and S2 both have values defined and the values are the same length: + ret = {} + for 0 <= i < length of S1.values: + if S1.values[i] is an unbound shape variable: + ret[S1.values[i]] = S1.values[i] + return ret + else: + return {} + if S1 and S2 are both TensorStructInfo: + if ( + S1 and S2 both have shape defined + and the shapes are both ShapeExprs + and their values fields are the same length + ): + ret = {} + for 0 <= i < length of S1.shape.values: + if S1.shape.values[i] is an unbound shape variable: + ret[S1.shape.values[i]] = S2.shape.values[i] + return ret + else: + return {} ``` -### Evaluating Shape Expressions - -Every shape expression in the program (`shape_`) is associated with a program expression. Other than in the above procedure for checking function parameter shapes and the return shape, the specification does not guarantee that any `shape_` expression will ever be evaluated or how many times it may be evaluated; `shape_` is intended primarily for the benefit of memory planning. Hence, all `shape_` expressions must be pure and must be guaranteed to terminate. The `shape_` for a given expression `e` is intended to be evaluated *before* `e`. - -Shape expressions follow the same evaluation rules as general program expressions. In particular, shape functions are permitted to reference any variable that is in scope at the point of its associated expression; i.e., when evaluated, they form closures that capture any free variables (Relax variables and shape variables) referenced in their body. The `RuntimeDepShape` expression has no semantics at run time and indicates a shape that cannot be predicted in advance. If a `RuntimeDepShape` is encountered at any point while dynamically checking a shape match (see the `check_annotation` procedure above), it should "short-circuit" the match and cause the match to succeed immediately. - -### Building Up `shape_` for Each Expression - -For each expression type, we can recursively build up an associated `shape_` expression according to the following rules: - -1. For `Constant(value)`, the `shape_` expression is a `ShapeExpr` corresponding to the concrete shape of `value`. For example, for `Constant(1)`, `shape_` is `ShapeExpr([])` and for `Constant([1, 2])`, `shape_` is `ShapeExpr([2])`. -2. «For `Tuple(fields)`, `shape_` can be defined as `Tuple([field.shape_ for field in fields])`.» -3. For `ShapeExpr`s, `shape_` is `RuntimeDepShape`. -4. `RuntimeDepShape` expressions should appear only in shape expressions; their `shape_` is not defined. -5. For `If(cond, true_branch, false_branch)`, we compare the `shape_` of `true_branch` and `false_branch`. If these can be proven equivalent (by a method that the compiler implementation is free to determine), then the `If` node's `shape_` is that shape. If they do not match, then we set it to `RuntimeDepShape`. -6. For `SeqExpr`, we set the `shape_` to be the `shape_` of the body expression. The `shape_` must respect the scoping rules for the `SeqExpr`: If the `shape_` of the body expression contains shape variables not defined in the outer scope (i.e., shape variables that are scoped to the `SeqExpr` only) or if the `shape_` contains any `Var`s or `DataflowVar`s scoped to the `SeqExpr`, use `RuntimeDepShape` as the shape. -7. For handling variable bindings: - 1. For the arguments to a function, set the `shape_` to the annotated shape. If the annotation is omitted, use `RuntimeDepShape`. - 2. In the general `VarBinding(v, e)`, if `v` does not have a shape annotation or the annotation is `RuntimeDepShape`, then we set the `shape_` of `v` to the `shape_` of `e`. If `v` has a shape annotation, then if the `shape_` of `e` can be proven equivalent to the shape annotation, use the shape annotation for the `shape_` of `v`. «Otherwise, give an error and require an explicit `MatchShape`.» - - It is up to the compiler implementation to decide what method to use for attempting to prove equivalence. - - 3. For bindings where the RHS is a function literal or assigning the `shape_` of a `GlobalVar`, see the rule for `Function` nodes. - 4. For `MatchShape(var, value, shape)`, we set the `shape_` of `var` to `shape`, as it will be dynamically checked. -8. «For `TupleGetItem(tuple_value, i)`, we examine the `shape_` of `tuple_value`; suppose it is `s`. If `s` is a `Tuple`, then we use its `i`th field. If it is `RuntimeDepShape`, we use `RuntimeDepShape`. If it is a `Call` to a function that returns a tuple with at least `i + 1` members, set the `shape_` to `TupleGetItem(s, i)`. Otherwise, raise an error at compile time (though this should not happen if type checking has passed).» -9. For `Call` nodes: - 1. For a call to an `ExternFunc`, we use `RuntimeDepShape` because we cannot analyze the shapes of arbitrary `PackedFunc`s and must check dynamically. - 2. For a call to an `Op`, we use the manually defined `FInferShape` macro if it has been defined and `RuntimeDepShape` if it has not. `FInferShape` is a function that takes in the call node and produces a `ShapeCompExpr`. - 3. «For all other cases with `Call(op, args)`, we consider the following cases: - 1. If `op` is a `GlobalVar` or a `Var` that refers to a function defined in the current scope, look up the `Function` node it references; let us call it `f`. Similarly, if `op` is itself a `Function` node, let `f` be `op`. - - Attempt to perform [beta-reduction](https://en.wikipedia.org/wiki/Lambda_calculus#%CE%B2-reduction) on `f`'s return shape. A pseudocode procedure for this beta-reduction is given below, as a macro. - - 1. If the return shape of `f` is a `Call` node or contains any `Call` nodes, substitute any parameters of `f` for the corresponding member of `args`. (E.g., if `f` has parameters `p1`, `p2`, …, `pn` and any of these variables appears in the return shape, `p1` should be replaced with the first member of `args`; `p2`, with the second; etc.) If any member of `args` that is substituted this way is not a `Var` or `Constant`, consider beta-reduction to fail. - 2. For each shape annotation in the parameters of `f`, attempt to match it with the `shape_` of the corresponding member of `args`, substituting shape variables in the return shape accordingly. If the `shape_` of the member of `args` is `RuntimeDepShape`, consider beta-reduction to fail. If the `shape_` is not `RuntimeDepShape` but is incompatible with the parameter's shape annotation (e.g., a `Tuple` where a `ShapeExpr` was expected), report an error at compile time. - - If `f`'s return shape is `RuntimeDepShape`, then consider the call result to have `RuntimeDepShape`. If beta-reduction is considered to fail, then consider the call result to have `RuntimeDepShape`. If it succeeds, use the resulting shape as the `shape_` of the call result. - - 2. Otherwise, consider the result of the call to have `RuntimeDepShape`. - » -10. For a function node, set the `shape_` to `RuntimeDepShape`. - -### Procedure for Substituting a Function Return Shape to Determine the Shape of a Call - -The `substitute_shape` procedure defined below describes how the shape expression for a call result can be defined given the call arguments and the return shape annotation on the corresponding function node. Note that this procedure can obtain much more precise results in the cases of `Call` or `TupleGetItem` return shapes. - +**Checking Compatibility** + +In many cases during the derivation of structural information, it is important to judge when two distinct structural information encodings are compatible with each other or when they are too different from each other to be reconciled, which can indicate an error. In the case of shape information, this could mean having two symbolic shapes that can be proven not to be equal to each other. Because shape expressions can contain arithmetic and it can be very difficult to statically prove whether two arithmetic expressions are equal, we permit the compiler implementation to make a best-effort attempt to prove equality for arithmetic expressions. (The user can insert a `MatchCast` to check definitively.) Since the checks are best-effort, the compatibility check will only report incompatibility if two values are _definitely_ different from each other. + +We can check if some structural information `S1` is accepted where structural information `S2` is expected by the process given below, which we refer to as `check_compability(S1, S2)` for convenience. `check_compatibility` can find that `S1` and `S2` are compatible, possibly compatible, or incompatible. "Incompatible" indicates a definite mismatch that should result in a compiler error; "possibly compatible" indicates that the structures may or may not match and should likely result in a compiler warning (indicating that a user may want to insert a dynamic check). An invariant that should should is that if `check_compatibility(S1, S2)` returns "compatible" or "possible compatible", `erase_struct_info(S1) <: erase_struct_info(S2)` should hold; that is, compatibility of structural information should be consistent with typing rules. + +1. If `S2` is `ObjectStructInfo`, then they are compatible. +2. Otherwise, if `S1` and `S2` are not both `TensorStructInfo` or both `TupleStructInfo`, etc. (besides `ObjectStructInfo`), then report an incompatibility. +3. If `S1` and `S2` are both `TupleStructInfo`: + 1. If `S1.fields` is not the same length as `S2.fields`, they are incompatible + 2. Call `check_compability(S1.fields[i], S2.fields[i])` for all `i`. If any pair of fields is incompatible, then `S1` and `S2` are incompatible. If no pair of fields is incompatible but at least one is possibly compatible, then `S1` and `S2` are possibly compatible. If all pairs of fields are compatible, then `S1` and `S2` are compatible. +4. If `S1` and `S2` are both `ShapeStructInfo`: + 1. `S2.ndim` is -1, then they are compatible. + 2. Otherwise, give an error if `S1.ndim` does not match `S2.ndim`. + 3. If `values` is not defined for `S2`, then they are compatible. + 4. If `values` is defined for `S2` but not defined for `S1`, then they are possibly compatible. + 5. If `values` is defined for both `S1` and `S2`, then the two are incompatible if `S1.values[i]` can be proven to be _not_ equal to `S2.values[i]` for some `i`. If all members can be proven to be equal, then they are compatible. Otherwise, if at least one pair of values cannot be proven to be either equal or unequal, then they are possibly compatible. +5. If `S1` and `S2` are both `TensorStructInfo`: + 1. If `S2.dtype` is not `Void` and does not match `S1.dtype`, then they are incompatible. + 2. If `S2.ndim` is not -1 and does not match `S1.ndim`, then they are incompatible. + 3. If `S2.shape` is not defined, then they are compatible. + 4. If `S2.shape` is defined and `S1.shape` is not defined, then they are possibly compatible. + 5. Otherwise, if both `shape` fields are given and either is a `Var`, then consider `S1` and `S2` compatible if the compiler can statically prove that the `Var` holds the same value as the other `shape` field, consider them possibly compatible if the compiler cannot draw a conclusion one way or the other, and consider them incompatible if the `Var` definitely has a different value from the other `shape`. + 6. If both `shape` fields are given and they are both `ShapeExpr` nodes, then `S1` and `S2` are incompatible if the compiler can prove that some dimension of `S1.shape` is _not_ equal to the corresponding dimension of `S2.shape`. Otherwise, if the all dimensions can be proven to be equal, then consider them compatible. If at least one pair of dimensions cannot be proven to be equal or unequal, consider them possibly compatible. +6. If `S1` and `S2` are both `FuncStructInfo`: + 1. If `S1` and `S2` don't both have defined `params` or both have undefined `params`, consider them incompatible. + 2. If both `S1` and `S2` have undefined `params`, consider them compatible if they have an identical `derive_func` and consider them possibly compatible if they have different `derive_func`s (as they is no further way to introspect the `derive_func` and draw static conslusions about `PackedFunc`s). + 3. If `params` is defined for both `S1` and `S2`: + 1. Consider them incompatible if the `params` have different lengths. + 2. Next, map unbound shape variables as follows: Get a variable mapping `m` by applying `get_shape_var_mapping(S1.params[i], S2.params[i])` for all values of `i`, taking the union of all resulting mappings. Next, substitute all occurrences of the shape variables in `S1` with their values in `m`. + 3. If `check_compatible(S2.params[i], S1.params[i])` (note the direction of the check: see the subtyping rule for `FuncType`) is incompatible for any `i` or if `check_compatible(S1.ret, S2.ret)` is incompatible, then they are incompatible. Otherwise, if `check_compatible(S2.params[i], S1.params[i])` is possibly compatible for any `i` or if `check_compatible(S1.ret, S2.ret)` is possibly compatible, consider `S1` and `S2` possibly compatible. Consider `S1` and `S2` compatible only if all checks are compatible. + +**Unification** + +Analogously to subtyping, we can also consider a hierarchy of structural information, considering some structural information to more or less specific than other structural information. Accordingly, we can also define a least upper bound for structural information, as with types. + +We can define an analogue to subtyping for structural information, as below. We say that `S1` is more specific than `S2` and denote it as `S1 <<: S2` (to distinguish from the notation on subtyping) based on the conditions given here. As an invariant, if `S1 <<: S2` holds, then `erase_struct_info(S1) <: erase_struct_info(S2)`, though the converse may not be true. +1. Reflexivity: `S1 <<: S1` for all `S1`. +2. Transitivity: For all `S1`, `S2`, and `S3`, if `S1 <<: S2` and `S2 <<: S3`, then `S1 <<: S3`. +3. For all `S1`, `S1 <<: ObjectStructInfo()`. +4. For `TensorStructInfo`: + 1. Given any datatype `d`, an arbitrary `ndim` `n`, and an arbitrary expression `s` (possibly undefined), `TensorStructInfo(ndim=n, dtype=d, shape=s) <<: TensorStructInfo(ndim=-1, dtype=d)`. + 2. Given any datatype `d`, an arbitrary `ndim` `n`, and an arbitrary expression `s` (possibly undefined), `TensorStructInfo(ndim=n, dtype=d, shape=s) <<: TensorStructInfo(ndim=n, dtype=Void, shape=s)`. + 3. Given any datatype `d`, an arbitrary `ndim` `n`, and an arbitrary expression `s` (not undefined), `TensorStructInfo(ndim=n, dtype=d, shape=s) <<: TensorStructInfo(ndim=n, dtype=d, shape=undefined)`. + 4. Given any datatype `d`, an arbitrary `ndim` `n`, and arbitrary expressions `s1` and `s2` (both defined), then `TensorStructInfo(ndim=n, dtype=d, shape=s1) <<: TensorStructInfo(ndim=n, dtype=d, shape=s2)` if `s1` and `s2` are _definitely_ or _possibly_ statically equal. +5. For `ShapeStructInfo`: + 1. Given an arbitrary `ndim` `n` and an arbitrary set of values `v` (possibly undefined), `ShapeStructInfo(ndim=n, values=v) <<: ShapeStructInfo(ndim=-1)`. + 2. Given an arbitrary `ndim` `n` and an arbitrary set of values `v` (not undefined), `ShapeStructInfo(ndim=n, values=v) <<: ShapeStructInfo(ndim=n, values=undefined)`. + 3. Given an arbitrary `ndim` `n` and two arbitrary sets of values `v1` and `v2` (both defined), `ShapeStructInfo(ndim=n, values=v1) <<: ShapeStructInfo(ndim=n, values=v2)` if, for all valid `i`, `v1[i]` and `v2[i]` can be proven to be _definitely_ or _possibly_ statically equal. +6. Given two lists of structural information `fields1` and `fields2`, `TupleStructInfo(fields=fields1) <<: TupleStructInfo(fields=fields2)` if `fields1` and `fields2` are the same length and for all `i`, `fields1[i] <<: fields2[i]`. +7. For `FuncStructInfo`: + 1. Given an arbitrary derivation function `derive_func`, `FuncStructInfo(ret=ObjectStructInfo(), derive_func=derive_func) <<: FuncStructInfo(ret=ObjectStructInfo(), derive_func=empty_derive)`. + 2. Corollary, following from reflexivity: For two `FuncStructInfo` `F1` and `F2` with undefined `params`, `F1 <<: F2` only if `F1.derive_func` and `F2.derive_func` are identical. + 3. Given two lists of structural information parameters `P1` and `P2` and two structural information annotations `R1` and `R2`, `FuncStructInfo(params=P1, ret=R1) <<: FuncStructInfo(params=P2, ret=R2)` if `P1` and `P2` are the same length and for all `i`, `P2[i] <<: P1[i]` and `R1 <<: R2`. + +Given these rules, we can define how to unify (get the LUB) of two structural information annotations as follows (in pseudocode): ```python -def map_shape_vars(param_shape: ShapeCompExpr, arg_shape: ShapeCompExpr, shape_var_mapping: {tir::Var : PrimExpr}) -> bool: - if param_shape is RuntimeDepShape or arg_shape is RuntimeDepShape: - return False - if param_shape is ShapeExpr and arg_shape is ShapeExpr: - if len(param_shape.values) != len(arg_shape.values): - raise UnificationError("Shapes are of incompatible ranks") - for param_dim, arg_dim in zip(param_shape.values, arg_shape.values): - if param_dim in shape_var_mapping: - # syntactic equality - if arg_dim != shape_var_mapping[param_dim]: - # if they are statically not equal, e.g., 5 != 7 or 3 + 3 != 3*3 - if can_prove_not_equal(arg_dim, shape_var_mapping[param_dim]): - raise UnificationError("Incompatible dimensions") - else: - return False - else: - shape_var_mapping[param_dim] = arg_dim - return True - if param_shape is Tuple and arg_shape is Tuple: - if len(param_shape.fields) != len(arg_shape.fields): - raise UnificationError("Tuples are of incompatible lengths") - for param_field, arg_field in zip(param_shape.fields, arg_shape.fields): - ret = map_shape_vars(param_field, arg_field, shape_var_mapping) - if not ret: - return False - return True - if param_shape is TupleGetItem and arg_shape is TupleGetItem: - # Does not necessarily indicate a unification error, - # depending on what the tuple values are. - # Constant folding the TupleGetItem nodes could improve this unification case - if param_shape.index != arg_shape.index: - return False - return map_shape_vars(param_shape.tup_value, arg_shape.tup_value) - if param_shape is Call and arg_shape is Call: - # no dimension mapping to do in this case - return True - # if either is a Call or TupleGetItem, it is possible that the shapes - # can match dynamically even if they don't match statically - if (param_shape is Call - or param_shape is TupleGetItem - or arg_shape is Call - or arg_shape is TupleGetItem): - return False - raise UnificationError("Incompatible shape constructs") - -def substitute_vars(target: Expr, var_mapping: {Var: Expr}, shape_var_mapping: {tir::Var: PrimExpr}) -> Expr: - def substitute_shape_vars(target: PrimExpr): - if target is tir::Var: - if target in shape_var_mapping: - return shape_var_mapping[target] - else: - return target - # proceed recursively in all subexpressions, checking for vars - - if target is Var: - if target in var_mapping: - return var_mapping[target] - return target - if target is ShapeExpr: - return ShapeExpr([ - substitute_shape_vars(dim) - for dim in target.values +def unify_struct_info(S1: StructInfo, S2: StructInfo) -> StructInfo: + if S2 is ObjectStructInfo: + return S1 + if S1 is ObjectStructInfo: + return S2 + if S1 and S2 do not match types (e.g., not both TensorStructInfo, etc): + return ObjectStructInfo() + if S1 and S2 are both ShapeStructInfo: + if S1.ndim == -1: + return S1 + if S2.ndim == -1: + return S2 + if S1.ndim != S2.ndim: + return ShapeStructInfo(ndim=-1) + if S1.ndim == S2.ndim: + if S1.values is undefined: + return S1 + if S2.values is defined: + return S2 + if S1.values can be statically proven to match S2.values: + return S1 + # values either proven not to match or unknown + return ShapeStructInfo(ndim=S1.ndim) # leave values undefined + if S1 and S2 are both TensorStructInfo: + ndim = S1.ndim if S1.ndim == S2.ndim else -1 + dtype = S1.dtype if S1.dtype == S2.dtype else Void + if ( + S1.ndim == -1 or S2.ndim == -1 or S1.ndim != S2.ndim + or S1.shape is undefined or S2.shape is undefined + ): + return TensorStructInfo(ndim=ndim, dtype=dtype) # leave shape undefined + # both shapes are defined + if S1.shape can be proven to equal S2.shape: + return S1 + # either proven to be unequal or cannot be concluded whether they are equal + return TensorStructInfo(ndim=ndim, dtype=dtype) # leave shape undefined + if S1 and S2 are both TupleStructInfo: + if S1.fields and S2.fields are of different lengths: + return ObjectStructInfo() + return TupleStructInfo( + unify_struct_info(S1.fields[i], S2.fields[i]) + for 0 <= i < length of S1.fields ]) - # recurse through all other cases, checking for vars and shape exprs analogously - -def substitute_shape(func_params, arg_exprs, ret_shape): - var_mapping = {param: arg_expr for param, arg_expr in zip(func_params, arg_exprs)} - shape_var_mapping = {} - for param, arg_expr in zip(func_params, arg_exprs): - can_unify = map_shape_vars(param.shape_, arg_expr.shape_, shape_var_mapping) - if not can_unify: - return RuntimeDepShape() - - new_shape = substitute_vars(ret_shape, var_mapping, shape_var_mapping) - if new_shape contains any free (Relax or shape) variables: - return RuntimeDepShape() - return new_shape + if S1 and S2 are both FuncStructInfo: + if S1.params and S2.params are not both defined or both undefined: + return ObjectStructInfo() + if S1.params and S2.params are both undefined: + # they must be the same function, not bothering to check eta-equivalence + if S1.derive_func == S2.derive_func: + return S1 + return FuncStructInfo(ret=ObjectStructInfo(), derive_func=empty_derive) + if S1.params and S2.params are both defined: + if S1.params and S2.params do not have the same length: + return ObjectStructInfo() + unified_params = [] + for 0 <= i < length of S1.params: + unified_param = unify_struct_info(S1.params[i], S2.params[i]) + # That is, if the params judged to be equal, use them. + # If there is some pair that is not equal, + # we can't unify these types except with ObjectStructInfo + # See the use of GLB with FuncTypes + if unified_param <<: S1.params[i] and unified_param <<: S2.params[i]: + unified_params[i] = unified_param + else: + return ObjectStructInfo() + return FuncStructInfo(params=unified_params, ret=unify_struct_info(S1.ret, S2.ret)) ``` +### Derivation Rules + +Let `Δ` be the structural information context for Relax variables (to distinguish from `Γ` for types) and let `Σ` track which shape variables are in scope. + +1. «Prepopulate `Δ` with the annotated types of all global functions (see the rule for `Function` nodes) that are called mutually recursively. Afterwards check the structural information of the global functions one at a time and populate the entry of `Δ` corresponding to that `GlobalVar`.» +2. For a variable (`Var`, `DataflowVar`, or `GlobalVar`) `v`, look up `Δ[v]` for the structural information. +3. For `Constant(value)`, the resulting structural information is `TensorStructInfo(ndim, dtype, shape)` where `ndim` is the concrete rank of `value`, `dtype` is the concrete datatype used in `value`, and `shape` is a `ShapeExpr` giving the concrete shape of `value. For example, for `Constant(1)`, `shape` is `ShapeExpr([])` and for `Constant([1, 2])`, `shape` is `ShapeExpr([IntImm(2, "int64")])`. +4. For `Tuple(fields)`, the resulting structural information is `TupleStructInfo([f.struct_info for f in fields])`, after deriving the structural information for the fields recursively. +5. For `ShapeExpr(values)`, the resulting structural information is `ShapeStructInfo(ndim, values)`, where `ndim` is the length of `values`. +6. For `If(cond, true_branch, false_branch)`, we compare the structural information of `true_branch` and `false_branch` (call these `S_t` and `S_f`, respectively). The resulting structural information is `unify_struct_info(S_t, S_f)`. +7. For `SeqExpr(blocks, body)`: + 1. For each binding block in `blocks` (call the current one `block`): + 1. Process each binding in the block, updating `Δ` and `Σ` accordingly (this is discussed in detail below). + 2. If `block` is a `DataflowBlock`, then remove all `DataflowVar`s introduced in `block` from `Δ` before proceeding to the next block. + 2. Next derive the structural information for `body`. Let us call this `S`. + 3. Remove all Relax variables introduced in `blocks` from `Δ` and all shape variables introduced in `blocks` from `Σ`. + 4. The structural information of the entire `SeqExpr` is `erase_to_well_defined(S, Δ, Σ)`. +8. For handling variable bindings: + 1. If `v` is the argument to a function, then if `v` has a structural annotation `S`, set `Δ[v]` to `S`. Add any unbound shape variables in `S` to `Σ`. If `v` does not have a structural annotation, set `Δ[v]` to `ObjectStructInfo()`. + 2. In the general `VarBinding(v, e)`: + 1. If `e` is a function literal, then recursion is permitted. In this case, `v` must have a structural annotation `Sv`. Derive the structural information for `e` as follows: Set `Δ[v]` to `Sv`, apply the normal rule for function literals (given below) to `e` to derive structural information `Se`, and finally remove `v` from `Δ`. Raise an error if `Se` and `Sv` are not compatible (via `check_compatibility`). + 2. Otherwise, derive the structural information of `e` and call it `Se`. + 3. If `v` has a structural annotation `Sv`, then apply `check_compatibility` to `Sv` and `Se`. If they are compatible, then set `Δ[v]` to `Sv` (respecting the user's intent in giving an annotation). Give a warning if `Sv` is more specific than `Se`. If are not compatible, then raise an error. + 4. If `v` does not have a structural annotation, then set `Δ[v]` to `Se`. + 3. For `MatchCast(v, value, S)`: + 1. Derive the structural information of `value` and call it `Sv`. + 2. Add any new shape variables in `S` to `Σ`. + 3. If `S <<: Sv` and `Sv <<: S` do not both hold, give a warning, as this indicates a cast that will always fail at run time. + 4. If `v` is given and it has a structural annotation `S'`, then give an error if `S` and `S'` are not compatible via `check_compatibility`. If they are compatible, then set `Δ[v]` to `S'` (respecting the user's intent in giving an annotation). (TODO: It doesn't seem very sensible to have a dynamic cast and give a different annotation, perhaps we should simply not permit doing that.) + 5. If `v` is given and it does not have a structural annotation, then set `Δ[v]` to `S`. +9. For `TupleGetItem(tuple_value, i)`, derive the structural information for `tuple_value` and call it `St`. Raise an error if `St` is not `TupleStructInfo`. If `St` is `TupleStructInfo(fields)`, then raise an error if `fields` value has less than `i + 1` members (this should not happen if type checking has passed) and use `fields[i]` (zero-based) as the structural information for the `TupleGetItem`. +10. For an `ExternFunc` node, the resulting structural information is `FuncStructInfo(params=None, ret=ObjectStructInfo(), derive_func=default_derive)`. +11. For `Call(op, [arg1, arg2, ..., argn], type_args=[aT1, aT2, ..., aTn])`: + 1. For a call to an `Op`, we use the manually defined `FInferStructInfo` macro if it has been defined and `ObjectStructInfo()` if it has not. `FInferStructInfo` is a function that takes in the call node and returns the structural information of the result. + 2. Otherwise, derive the structural information for `op` and call it `Sf`. Next derive the structural information for the args and call it `S1`, `S2`, ..., and `Sn`. + 1. Give an error if `Sf` is not `FuncStructInfo`. + 2. If the `derive_func` field of `Sf` is defined, then apply the `derive_func` macro to the call node to derive the structural information for the call node, ignoring the `ret` field of `Sf`. + 3. Otherwise, `params` must be defined. Give an error if the length of `params` does not match the number of call arguments. + 4. Next, attempt to perform [beta-reduction](https://en.wikipedia.org/wiki/Lambda_calculus#%CE%B2-reduction) by matching unbound shape variables in `params` with the `Si`. Namely, get a shape var mapping `m` by applying `get_shape_var_mapping(params[i], Si)` for all `i` and taking the union of all resulting mappings. Replace all variables in `m` with their mapping in `Sf`. + 5. After the substitutions, give an error if `check_compatibility` indicates that the `i`th member of `params` and `Si` are incompatible for some `i` (warn if they are only possibly compatible). + 6. Use `erase_to_well_defined(Sf.ret, Δ, Σ)` as the resulting structural information. +12. For `Function(params=[v1, v2, ..., vn], body, ret_struct_info)`: + 1. Let `S1`, `S2`, ..., `Sn` be the structural information of the parameters. If `vi` has a structural annotation, then use that annotation for `Si`; if not, use `ObjectStructInfo()`. Let `Sr` be `ret_struct_info` if it is defined and `ObjectStructInfo()` if not. + 2. If the function is bound to a `GlobalVar` `gv`, set `Δ[gv]` to `FuncStructInfo(params=[S1, S2, ..., Sn], ret=Sr)`. Still check the structural information in `body`, however. + 3. For each of the `vi`, set `Δ[vi]` to `Si`. Additionally, add all new shape variables introduced in the `Si` to `Σ`. + 4. Derive the structural information for `body`, calling it `Sb`. + 5. Give an error if `Sb` is incompatible with `Sr` via `check_compatibility` (warn if only possibly compatible). + 6. If `ret_struct_info` is defined, use `FuncStructInfo(params=[S1, S2, ..., Sn], ret_struct_info)` as the structural information for the function. If `ret_struct_info` is not defined, use `FuncStructInfo(params=[S1, S2, ..., Sn], erase_to_well_defined(Sb, Δ, Σ))`. + 7. Remove all variables added to `Δ` and `Σ` during the derivation. + ### Note on Proving Shapes Equivalent and Eliminating Dynamic Checks There can be some complexity involved in checking whether two shapes match during shape inference. A very simple, conservative method for determining equality is simply using alpha-equivalence: If the two shapes have the same structure, then they are equivalent. However, this method is conservative and can overlook numerical properties in `PrimExpr`s. We leave it up to compiler implementations as to whether to use more advanced methods for proving equivalence, such as attempting to use algebraic rewrite rules. (As a consequence, portability requires inserting dynamic checks wherever there needs to be a comparison of shapes.) Note that optimizations like function inlining or constant folding could allow for simplifying many shape annotations and expressions and make it possible to conclude at compile time that shapes in more cases are equivalent. In general, developing compiler infrastructure for partial evaluation and reasoning about common situations with shape annotations may eliminate many dynamic checks. -Applying some kind of normalization or algebraic simplifications to `PrimExpr`s used in shape annotations and in `shape_` fields can also make it easier to conclude that certain dynamic checks may not be necessary by increasing the likelihood that more `shape_` expressions could be made syntactically identical to the shape annotations. It would also be possible to generate compile-time warnings if analysis reveals that two shapes may not match (either using rewrite rules or by trying random values for shape variables and checking). +Applying some kind of normalization or algebraic simplifications to `PrimExpr`s used in structural information and `MatchCast` bindings can also make it easier to conclude that certain dynamic checks may not be necessary by increasing the likelihood that more derive structural information could be made syntactically identical to the structural annotations. It would also be possible to generate compile-time warnings if analysis reveals that two shapes may not match (either using rewrite rules or by trying random values for shape variables and checking). -Since most dynamic shape checks are done for safety, it may be feasible to introduce a compilation mode that eliminates almost all dynamic shape checks. Some shape checks may not be possible to eliminate, since the body of the program may construct `ShapeExpr`s and use them in calls to `PackedFunc`s, so some bindings to shape variables may need to be preserved, per a liveness analysis. +Since most dynamic structure checks are done for safety, it may be feasible to introduce a compilation mode that eliminates almost all dynamic structure checks. Some structure checks may not be possible to eliminate, since `ShapeExpr`s can use shape variables introduced in `MatchCast` brindings, so this would require some liveness analysis. -## Possible Extensions to the Shape Expression System +## Possible Extension: Indicating Unknown Dimensions -We may consider two possible extensions to the shape expression system in order to accommodate two further cases: +A further case that may be of interest might be using an explicit wildcard dimension (e.g., using `tir::Any`) to allow for dimensions to be specified as "unknown" in function return shapes. As described at present, the only way for a function to specify a partly unknown return shape is to make the entire return shape unknown (`RuntimeDepShape`), which loses partial shape information. -1. An explicit wildcard dimension (e.g., using `tir::Any`) to allow for dimensions to be specified as "unknown" in function return shapes. As described at present, the only way for a function to specify a partly unknown return shape is to make the entire return shape unknown (`RuntimeDepShape`), which loses partial shape information. -2. Adding `shape_` expressions consisting of functions, to allow arbitrary closures to have a known shape. This would allow the shapes of calls to closures of unknown origin (namely, in a higher-order function) to have their shapes correctly inferred rather than made `RuntimeDepShape`. - -In both cases, these additions would entail additional complexity (shape inference macros for operators would have to deal with potential `tir::Any` nodes and we would have to define rules for constructing, calling, and simplifying functions in `shape_` expressions). However, the advantage of implementing these features would be increasing the amount of shape information present at compile time and hence that could be used by lower levels of the compiler stack. The present defaults of using `RuntimeDepShape` means that either of these changes could be pursued in the future without breaking existing code, since these would generally have to be paired with explicit `MatchShape` dynamic checks, which will still work even if we add rules to automatically infer the shapes in those cases. +This addition would entail some, as `FInferStructInfo` and `derive_func` macros would have to deal with potential `tir::Any` nodes. However, the advantage of implementing it would be increasing the amount of shape information present at compile time and hence that could be used by lower levels of the compiler stack. The present defaults of using undefined `shape` fields means that either of these changes could be pursued in the future without breaking existing code, since these would generally have to be paired with explicit `MatchCast` dynamic checks, which will still work even if we add rules to automatically infer the shapes in those cases. # Detailed Semantics @@ -781,20 +930,19 @@ For each expression, we define how it affects the program's visible state and th 3. The node `Tuple([e1, e2, ..., en])` evaluates `e1` (yielding value `v1`), then `e2` (yielding value `v2`), …, and finally `en` (yielding value `vn`) in that order and creates a new tuple value containing `v1`, `v2`, …, and `vn` in that order. 4. The node `TupleGetItem(t, i)` is evaluated by first evaluating `t` (which, per type checking, must evaluate to a tuple) and then returning the `i`th field of the result. 5. The node `ShapeExpr([p1, p2, ..., pn])` evaluates the `PrimExpr`s `p1` (yielding dimension value `v1`), `p2` (yielding dimension value `v2`), …, and finally `pn` (yielding dimension value `vn`) in that order, using the current shape context, and creates a new shape value whose dimensions are `v1`, `v2`, …, `vn`, in that order. -6. `RuntimeDepShape` expressions must not appear in the general body of a program; it is a well-formedness error if they do. They do not have any defined semantics. -7. The node `Function([v1, v2, ..., vn], body)` returns a new closure containing the function definition itself and a mapping of any free Relax variables or shape variables in `body` to the values they hold in the current scope when the `Function` node is encountered. If the function is the RHS of a local binding, the bound variable should also be included in the closure's binding map and should be mapped to the closure itself (to allow for recursive calls). Closure capturing is done *by reference*; no values will be copied and references to captured values will alias their values in the outer scope. `DataflowVar`s are not captured by closures. -8. The node `If(cond, true_branch, false_branch)` is evaluated as follows: +6. The node `Function([v1, v2, ..., vn], body)` returns a new closure containing the function definition itself and a mapping of any free Relax variables or shape variables in `body` to the values they hold in the current scope when the `Function` node is encountered. If the function is the RHS of a local binding, the bound variable should also be included in the closure's binding map and should be mapped to the closure itself (to allow for recursive calls). Closure capturing is done *by reference*; no values will be copied and references to captured values will alias their values in the outer scope. `DataflowVar`s are not captured by closures. +7. The node `If(cond, true_branch, false_branch)` is evaluated as follows: 1. First `cond` is evaluated. Let the result be `r` (per type checking, it must be a bool scalar). 2. If `r` is true, evaluate the `true_branch` and return its result. 3. If `r` is false, evaluate the `false_branch` and return its result. -9. The node `Call(op, [arg1, arg2, ..., argn])` is evaluated as follows: +8. The node `Call(op, [arg1, arg2, ..., argn])` is evaluated as follows: 1. If `op` is an `ExternFunc` node, then evaluate `arg1`, `arg2`, …, `argn` in that order and call the results `a1`, `a2`, …, `an`. Next, look up the `PackedFunc` registered under the global symbol name. If it exists (it is an error at run time if it does not), call the `PackedFunc` using the given arguments and return the result. Note that if a TIR `PrimFunc` in the `IRModule` has a global symbol attribute registered, it can be called as an `ExternFunc` using that global symbol as well. `PackedFunc`s may have arbitrary side effect and are responsible for whether the result is a newly allocated value or aliases another value. 2. If `op` is an `Op` node, then evaluate `arg1`, `arg2`, …, `argn` in that order and call the results `a1`, `a2`, …, `an`. It is up to the compiler implementation to decide how operators should be implemented (some may have an associated `PackedFunc` and others may be built into the executor implementation). The operator may mutate its arguments. It is also up to the operator implementation as to whether the result is newly allocated or aliases another value. «(TODO: Once we have operators for logical and AND and OR, we should also define short-circuiting semantics for those.)» - 3. In all other cases, first evaluate `op` (it must evaluate to a closure). Next, we evaluate `arg1`, `arg2`, …, `argn` in that order and call the results `a1`, `a2`, …, `an`. Push a new scope onto the stack where arguments `v1`, `v2`, …, `vn` in the closure are bound to `a1`, `a2`, …, and `an`, respectively, and all variables saved in the closure are added to the scope. Evaluate the closure body in this new scope; this will be the return value of the call. Pop the scope before returning the value. -10. For the node `SeqExpr(blocks, body)`, we evaluate as follows: + 3. In all other cases, first evaluate `op` (it must evaluate to a closure). Next, we evaluate `arg1`, `arg2`, …, `argn` in that order and call the results `a1`, `a2`, …, `an`. Push a new scope onto the stack where arguments `v1`, `v2`, …, `vn` in the closure are bound to `a1`, `a2`, …, and `an`, respectively, and all variables saved in the closure are added to the scope. Evaluate the closure body in this new scope; this will be the return value of the call. Pop the scope before returning the value. (Note that the checking of the structural information of the argument result values and the body values should be done as described in the previous section.) +9. For the node `SeqExpr(blocks, body)`, we evaluate as follows: 1. Push a new scope onto the stack. 2. Iterate through the `BindingBlock`s in `blocks` in order. We will call the current one `block`. For each binding in `Block`: - 1. If the binding is `MatchShape(var, value, shape)`, perform the shape matching and shape variable updates as described in the shape evaluation section. If `var` is provided, `var` will be bound to `value` in the current scope; this assignment is aliasing and no new value is allocated. If `var` is not provided, then the shape check is performed and shape variables are updated, but no new binding is introduced. + 1. If the binding is `MatchCast(var, value, struct_info)`, perform the structure matching and shape variable updates as described in the structural information section. If `var` is provided, `var` will be bound to `value` in the current scope; this assignment is aliasing and no new value is allocated. If `var` is not provided, then the structural check is performed and shape variables are updated, but no new binding is introduced. 2. If the binding is `VarBinding(var, value)`, then evaluate `value` and bind `var` to that value in the current scope; this assignment is aliasing and no new value is allocated. 3. If `block` is a `DataflowBlock`, remove all `DataflowVar`s bound in the block from the current scope before proceeding to the next block. 3. After iterating through the binding blocks, evaluate `body` in the current scope. That will be the return value of the `SeqExpr`. @@ -804,7 +952,7 @@ For each expression, we define how it affects the program's visible state and th Optimizations are allowed to reorder and modify the operations of a program in any way so long as they do not change the value returned by evaluating the program or any visible behavior of the program. For the purposes of compilation, visible behaviors consist of side effects like mutating values in the program or external effects like I/O (printing to the console, creating files, etc.) and the order and number of times in which they happen. -«Within `DataflowBlock`s, it is permitted for the compiler to remove or reorder `MatchShape` or `cast` operations even though this can affect the "visible behavior" of the program (since they can exit with an error). It is also permitted for the compiler to optimize away potential non-termination within `DataflowBlock`s: For example, if some pure function `f` has an integer return type and does not terminate, it is permissible to optimize `f() - f()` to 0 within a `DataflowBlock`. In general, the compiler is permitted to make programs "more defined" (terminating when the original did not terminate, not raising an error when the original raised an error) within a `DataflowBlock`, but never "less defined" (giving an error when the original did not give an error, not terminating when the original did not terminate). Outside of `DataflowBlock`s, error messages and potential non-termination must be preserved faithfully.» +«Within `DataflowBlock`s, it is permitted for the compiler to remove or reorder `MatchCast` operations even though this can affect the "visible behavior" of the program (since they can exit with an error). It is also permitted for the compiler to optimize away potential non-termination within `DataflowBlock`s: For example, if some pure function `f` has an integer return type and does not terminate, it is permissible to optimize `f() - f()` to 0 within a `DataflowBlock`. In general, the compiler is permitted to make programs "more defined" (terminating when the original did not terminate, not raising an error when the original raised an error) within a `DataflowBlock`, but never "less defined" (giving an error when the original did not give an error, not terminating when the original did not terminate). Outside of `DataflowBlock`s, error messages and potential non-termination must be preserved faithfully.» The specification makes no guarantees about certain memory-related properties and hence also does not consider them to be "visible behaviors": @@ -824,5 +972,3 @@ The above evaluation rules are general, but leave much room for implementations - «`call_dps_packed(global_symbol, arg1, arg2, ..., argn, shape, type_args=[aT])`: Proceeds similarly to `call_tir`, except it calls a `PackedFunc` registered under the name `global_symbol`. The `PackedFunc` may modify `arg1`, `arg2`, …, or `argn` in addition to the result tensor, so purity is not assumed. A type argument `aT` must be given to specify the return type.» - `shape_of(t)`: Given a tensor argument `t`, it returns its shape. The return value is a new shape object. -- «`cast(v, type_args=[aT])`: Given an argument `v`, it dynamically checks if `v`'s run-time representation is a subtype of `aT`. If it is not, it exits the program with an error message. Otherwise, it returns `v`.» - From 0d89b8fd73062a99d1a007124847c10132bfe3c7 Mon Sep 17 00:00:00 2001 From: Steven Lyubomirsky Date: Sat, 7 Jan 2023 21:33:16 -0500 Subject: [PATCH 15/30] Further StructInfo revisions --- relax_spec.md | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/relax_spec.md b/relax_spec.md index 0e2b4c3660..6f1d2ea80d 100644 --- a/relax_spec.md +++ b/relax_spec.md @@ -175,7 +175,7 @@ The type checking rules assign types to every variable in scope and every type o ## Structural Information System Survey -In Relax, tensor shapes are not handled in the type system; each expression instead a has an associated shape expression. In many cases, these shape computations can allow for statically concluding that two shapes are the same and thus eliminate the need for dynamic checks via `MatchShape`. However, when shapes cannot be statically concluded to be the same, it may be necessary for there to be dynamic checks. The compiler is also free to make use of shape expressions for memory planning purposes. «Relax is "strongly shaped," meaning that if the compiler cannot conclude that shapes match in certain cases, an error will be issued and an explicit `MatchShape` will be required.» +In Relax, tensor shapes are not handled in the type system, even though it would be greatly beneficial for the compiler to make use of shape information for static optimizations. Instead, shape information is tracked using Relax's structural information system, in which every expression has structural information associated with it (like tensor shapes) that is more expressive than its type. Structural information can convey richer properties about expressions, like tensor shapes, and can facilitate a greater degree of static reasoning. However, when it is not feasible for the compiler to draw conclusions about structural information, this information can be checked dynamically via `MatchCast`. The structural information is essentially an extended type system, so `MatchCast` also serves to handle type casting. --- @@ -526,7 +526,7 @@ A tensor shape is a tuple of TIR `PrimExpr`s, where each `PrimExpr` corresponds **Scope of Shape Variables** -New shape variables can be bound in two places in a Relax program: In `TensorStructInfo` or `ShapeStructInfo` annotations on function parameters or as the `struct_info` parameter in a `MatchCast` binding. Shape variables used in the function signature are scoped to the entire function in which they appear (including in the return structural annotation). Shape variables used in `MatchShape` bindings are scoped only to the `SeqExpr` in which they appear. +New shape variables can be bound in two places in a Relax program: In `TensorStructInfo` or `ShapeStructInfo` annotations on function parameters or as the `struct_info` parameter in a `MatchCast` binding. Shape variables used in the function signature are scoped to the entire function in which they appear (including in the return structural annotation). Shape variables used in `MatchCast` bindings are scoped only to the `SeqExpr` in which they appear. **Informal Semantics of `PrimExpr`s for Dimensions** @@ -555,7 +555,7 @@ This section describes the run-time checking performed by `MatchCast(var, value, ### Checking Structural Information at the Start and End of a Function -«Shape variables are bound at the start and end of a function or in `MatchShape` bindings. We can describe the behavior at the start and end of a function in terms of the semantics of `MatchCast`, as the shape annotations in function arguments and return types are treated as "syntactic sugar" for `MatchCast` bindings. Suppose a function has the following signature, where the `Si` are structural annotations: +«Shape variables are bound at the start and end of a function or in `MatchCast` bindings. We can describe the behavior at the start and end of a function in terms of the semantics of `MatchCast`, as the shape annotations in function arguments and return types are treated as "syntactic sugar" for `MatchCast` bindings. Suppose a function has the following signature, where the `Si` are structural annotations: ```python def f(arg1 : S1, arg2 : S2, ..., argn : Sn) -> Sr: From 76cb10c2967ab4e208608a42a625355b0c895bcb Mon Sep 17 00:00:00 2001 From: Steven Lyubomirsky Date: Sat, 7 Jan 2023 21:33:55 -0500 Subject: [PATCH 16/30] Make PackedFunc first-class --- relax_spec.md | 59 ++++++++++++++++++++++++++------------------------- 1 file changed, 30 insertions(+), 29 deletions(-) diff --git a/relax_spec.md b/relax_spec.md index 6f1d2ea80d..b7f5a867ff 100644 --- a/relax_spec.md +++ b/relax_spec.md @@ -112,17 +112,17 @@ This specification provides a more detailed description of what each expression 1. `Constant` nodes construct tensor constants (n-dimensional arrays of scalars). 2. `Tuple` nodes construct a tuple (fixed-size ordered grouping) of Relax values. 3. `Var`, `DataflowVar`, and `GlobalVar` nodes are all variables, referring to named stored values of different kinds. Variables in Relax must be bound exactly once. `GlobalVar`s are bound in the `IRModule` itself and refer to Relax functions or TIR `PrimFunc`s. `Var` nodes are bound either within functions, where they represent function parameters, or in `VarBinding` or `MatchCast` nodes in `BindingBlock`s, as we will discuss below. `DataflowVar`s are similar to `Var`s and can be bound only within `DataflowBlock`s. -4. `PrimExpr`s are used to represent dimensions of shapes in `ShapeExpr` and `MatchShape` nodes. These represent operations on integers with their own `Var` nodes (`tir::Var`), which we will refer to as "shape variables". Shape variables can only be used in other `PrimExpr`s and are scoped like `Var` nodes (`relax::Var`), which we will call "Relax variables." -5. `Call` nodes represent function calls. The callee argument (the `op`) can be an `ExternFunc` node (representing a call to a `PackedFunc`), an `Op` node (representing a call to a Relax operator), or an arbitrary expression. - 1. For `ExternFunc` nodes, the call will look up the registered `PackedFunc` by its global symbol and will call it with the given arguments (note that a TIR `PrimFunc` can be compiled into a `PackedFunc` and called using `ExternFunc` by defining a `global_symbol` attribute in the `PrimFunc`). «The attribute "pure" can be specified on a call to an `ExternFunc` to indicate that it performs no side effects (for use inside `DataflowBlock`s).» - 2. `Op` nodes refer to built-in Relax operators, which the compiler is free to implement as is deemed appropriate. Certain operators implement important operations, like `call_tir` (allows for calling TIR `PrimFunc`s) «and `cast` (performs dynamic type conversions).» - 3. Any other expression must evaluate to a closure; the closure will then be called with the given arguments. +4. `PrimExpr`s are used to represent dimensions of shapes in `ShapeExpr` and `MatchCast` nodes. These represent operations on integers with their own `Var` nodes (`tir::Var`), which we will refer to as "shape variables". Shape variables can only be used in other `PrimExpr`s and are scoped like `Var` nodes (`relax::Var`), which we will call "Relax variables." +5. `ExternFunc` nodes evaluate into `PackedFunc`s; the implementation will look up the registered `PackedFunc` by its global symbol. +6. `Call` nodes represent function calls. The callee argument (the `op`) can be an `ExternFunc` node (representing a call to a `PackedFunc`), an `Op` node (representing a call to a Relax operator), or an arbitrary expression. + 1. `Op` nodes refer to built-in Relax operators, which the compiler is free to implement as is deemed appropriate. Certain operators implement important operations, like `call_tir` (allows for calling TIR `PrimFunc`s) «and `cast` (performs dynamic type conversions).» + 2. Any other expression must evaluate to a `PackedFunc` or a closure; the result of evaluating `op` will then be called with the given arguments. - Calls to `ExternFunc`s and operators may perform side effects, hence it is important to reason about whether a function call is permitted inside a `DataflowBlock`. + Calls to `ExternFunc`s and operators may perform side effects, hence it is important to reason about whether a function call is permitted inside a `DataflowBlock`. «The attribute "pure" can be specified on a call to an `ExternFunc` to indicate that it performs no side effects (for use inside `DataflowBlock`s).» -6. `If` nodes represent branching control flow. First the condition expression is evaluated, and it must evaluate to a Boolean scalar. If the condition is true, the true branch is evaluated and its result is used; otherwise, the false branch is evaluated and its result is used. -7. `TupleGetItem` nodes represent tuple indexing. The `tuple_value` expression must evaluate to a tuple with at least `index + 1` items and the item with the given index will be returned. -8. `SeqExpr` describes a sequence of binding blocks followed by a return expression. The `SeqExpr` opens a new scope. Its binding blocks are evaluated in order and add new variables to the scope. Binding blocks are either ordinary `BindingBlock`s or `DataflowBlock`s and both consist of a series of bindings. `DataflowBlock`s are the only kind allowed to introduce bindings with `DataflowVar`s and it does not permit any constructs featuring control flow (`If` nodes or recursive calls) or calls to (possibly) impure functions. There are two different kinds of bindings: +7. `If` nodes represent branching control flow. First the condition expression is evaluated, and it must evaluate to a Boolean scalar. If the condition is true, the true branch is evaluated and its result is used; otherwise, the false branch is evaluated and its result is used. +8. `TupleGetItem` nodes represent tuple indexing. The `tuple_value` expression must evaluate to a tuple with at least `index + 1` items and the item with the given index will be returned. +9. `SeqExpr` describes a sequence of binding blocks followed by a return expression. The `SeqExpr` opens a new scope. Its binding blocks are evaluated in order and add new variables to the scope. Binding blocks are either ordinary `BindingBlock`s or `DataflowBlock`s and both consist of a series of bindings. `DataflowBlock`s are the only kind allowed to introduce bindings with `DataflowVar`s and it does not permit any constructs featuring control flow (`If` nodes or recursive calls) or calls to (possibly) impure functions. There are two different kinds of bindings: 1. `VarBinding`s: The `value` expression (the right-hand side of the binding) of the binding is evaluated first and is bound to the `var` expression, which must be a new `Var` or `DataflowVar` (in a dataflow block). The newly bound variable will have that value for the remainder of the scope (`DataflowVar`s are scoped only to the `DataflowBlock` in which they appear; `Var`s are scoped to the entire `SeqExpr`). 2. `MatchCast`s: The `value` expression is evaluated and the result is dynamically checked against the structural information given in the `struct_info` field. 1. The types must match: All `StructInfo` variants correspond to a type (`TensorStructInfo` to `DynTensorType`, `ShapeStructInfo` to `ShapeType`, etc.) and each type corresponds to a value (`DynTensorType` to a tensor value, `ShapeType` to shape values, etc.), so if the structure of `value` does not correspond to `struct_info`, an error is triggered. The structure of `value` is compared recursively with `struct_info`, so all components of `value` must match up with any nested structural information. Special comparison rules: @@ -134,8 +134,8 @@ This specification provides a more detailed description of what each expression The `SeqExpr`'s `body` expression is allowed to reference any `Var`s introduced within the `SeqExpr`'s binding blocks in addition to those that were in the outer scope; the `body` expression is evaluated after the binding blocks and its value is what is returned. Any Relax variables and shape variables introduced in the `SeqExpr` are removed from scope after the expression finishes evaluating. -9. `ShapeExpr` nodes construct shape literals. The `PrimExpr`s within it describe how to compute each dimension; they are free to use any shape variables that are in scope. -10. `Function` nodes represent function definitions, taking in the listed parameters and evaluating the body expression in a new scope (meaning any variables defined from within the function cannot be referenced outside it). Function definitions may be nested in any other expression and they evaluate into closure values, ensuring that functions are first-class. Closures capture any variables from the outer scope that are used in their body, both Relax variables and shape variables. Note that function definitions themselves are anonymous—a function must be registered in the `IRModule` (bound to a `GlobalVar`) or appear on the right-hand side of a binding to have a name in order to be called recursively. +10. `ShapeExpr` nodes construct shape literals. The `PrimExpr`s within it describe how to compute each dimension; they are free to use any shape variables that are in scope. +11. `Function` nodes represent function definitions, taking in the listed parameters and evaluating the body expression in a new scope (meaning any variables defined from within the function cannot be referenced outside it). Function definitions may be nested in any other expression and they evaluate into closure values, ensuring that functions are first-class. Closures capture any variables from the outer scope that are used in their body, both Relax variables and shape variables. Note that function definitions themselves are anonymous—a function must be registered in the `IRModule` (bound to a `GlobalVar`) or appear on the right-hand side of a binding to have a name in order to be called recursively. The function can have structural annotations on the parameters and a structural annotation for the return value. When the function is called, the annotations on parameters are checked against the argument values in similar fashion to `MatchCast` and can introduce new shape variables that are scoped to the function. Additionally, the structural information of the return value is checked against the annotation before the call returns. @@ -168,7 +168,7 @@ The types in Relax correspond to the broad categories of the values given above: 2. `TupleType` corresponds to tuple values, giving the type of each member of the tuple. 3. `ShapeType` corresponds to shape values, optionally giving the number of dimensions in the shape. 4. `FunctionType` corresponds to function values (closures), giving the types of the parameters, the return type, «and whether the function is pure.» -5. `PackedFuncType` is the type given to arbitrary packed functions (external functions). Since packed functions are not first-class values (`ExternFunc` can appear only in the `op` position of a `Call` node), these do not actually correspond to any value in Relax, but can be used to assign a type to `ExternFunc` nodes. +5. `PackedFuncType` is the type given to arbitrary packed functions (external functions). 6. `ObjectType` is the parent type of all Relax types and corresponds to all the values above as well as any values returned by `PackedFunc` calls that do not fit in the above categories. The type checking rules assign types to every variable in scope and every type of expression based on the values it returns, making use of subtyping to assign more general types when a more specific one cannot be determined. «Relax is strongly typed, meaning that if a type encountered is less specific than the one expected, an error will be issued and an explicit cast (via the `cast` operator) will be required.» @@ -191,8 +191,9 @@ Here are the classes of values that Relax operates over, meaning that they can b - *Tensors* are n-dimensional arrays of scalar values (which can be integers of fixed bitwidths, floats of fixed bitwidths, or bools). A tensor's *shape* is a tuple of the size of each dimension; the number of dimensions is a tensor's *rank*. For example, a vector (1, 2, 3) is a rank-1 tensor of shape `(3,)`. Note that scalars are tensor values with a rank of 0, meaning that their shape is `()`. - *Tuples* represent a fixed-size grouping of other Relax values (tensors, closures, shapes, objects, or other tuples, to an arbitrary degree of nesting). Note that an empty tuple, i.e., `()`, also called "unit" in functional programming, is commonly used as the return value for operations not intended to return a value (as may be the case in some `PackedFunc` or operator calls that have side effects). -- *Closures* are the values resulting from evaluating Relax function expressions; closures can be passed around like other values, ensuring that functions are first-class in Relax. Functions defined in Relax can capture variables from outer scopes. A [closure](https://en.wikipedia.org/wiki/Closure_(computer_programming)) consists of a function and a mapping of any variables "captured" (those are *free variables* in the function body, variables from an outer scope that are neither arguments nor defined within the function but are used in the function) to their values. Closures capture both Relax-level local variables and shape variables from outer scopes. A closure also stores a name for itself when the body contains recursive calls. «Closures additionally carry some *run-time type information* (RTTI) indicating their argument types and result type, in order to facilitate dynamic type checks (since it is not otherwise possible to introspect the function contained within a closure); the precise form of the RTTI is left up to the compiler implementation to determine so long as the `cast` operator can verify the type of a closure. Closures can be evaluated in a call node, which results in calling the function with the call's arguments and the captured values.» +- *Closures* are the values resulting from evaluating Relax function expressions; closures can be passed around like other values, ensuring that functions are first-class in Relax. Functions defined in Relax can capture variables from outer scopes. A [closure](https://en.wikipedia.org/wiki/Closure_(computer_programming)) consists of a function and a mapping of any variables "captured" (those are *free variables* in the function body, variables from an outer scope that are neither arguments nor defined within the function but are used in the function) to their values. Closures capture both Relax-level local variables and shape variables from outer scopes. A closure also stores a name for itself when the body contains recursive calls. «Closures additionally carry some *run-time structural information* (RTSI) indicating their argument and result structures, in order to facilitate dynamic structural checks (since it is not otherwise possible to introspect the function contained within a closure); the precise form of the RTSI is left up to the compiler implementation to determine so long as `MatchCast` can verify the structure of a closure. Closures can be evaluated in a call node, which results in calling the function with the call's arguments and the captured values.» - *Tensor shapes* (shape values) are tuples of integers describing a tensor shape, obtained by evaluating `ShapeExpr`s. +- *Packed functions* (`PackedFunc`s or external functions) represent arbitrary opaque functions implemented in TVM. That is, packed functions are routines that are defined outside of Relax and cannot be inspected by the compiler. They can perform side effects and return arbitrary values. - Additionally, there are further *arbitrary objects* that do not belong in the above categories. These can be returned by `PackedFunc`s and operators; additionally, we treat TIR `PrimFunc`s as opaque objects. Though Relax expressions other than `PackedFunc` and operator calls cannot use those objects, Relax should pass around these values faithfully. In the future we may add more value types in order to distinguish between different objects, but at present we treat these all as arbitrary values of type `ObjectType`. ## Representation of Values at Run Time @@ -216,7 +217,7 @@ There are four relevant scopes in Relax, which determine where variables are vis 3. `SeqExpr`: `Var` nodes defined in a `BindingBlock` in a `SeqExpr` node can be referenced in any later binding within the same `BindingBlock`, in any binding within any later `BindingBlock` in that `SeqExpr` node, or in the `SeqExpr`'s body expression. The variables defined in the `BindingBlock`s leave scope once the `SeqExpr` returns. 4. `DataflowBlock`: `DataflowVar`s introduced in a `DataflowBlock` can be referenced in any later binding within that `DataflowBlock`, but leave scope *once that `DataflowBlock` finishes executing*. Definitions in a `DataflowBlock` that are intended to leave the `DataflowBlock` should be bound to an ordinary `Var`. -Note that Relax variables must be bound _exactly_ once. A global variable is bound if it is mapped to a function in the `IRModule` and a local variable is bound if it appears as a function parameter or if it appears on the left-hand side (LHS) of a binding (`VarBinding` or `MatchShape`). +Note that Relax variables must be bound _exactly_ once. A global variable is bound if it is mapped to a function in the `IRModule` and a local variable is bound if it appears as a function parameter or if it appears on the left-hand side (LHS) of a binding (`VarBinding` or `MatchCast`). «If there is another binding to a local variable with the same name as an already-bound variable, that is binding is considered to _shadow_ the previous binding, i.e., it is a binding to a new, distinct variable that happens to have the same name as the existing variable. The new, shadowing variable will exist only in the current scope; if the older variable was defined in an outer scope, then future uses of that name will refer to the older variable. [See the Wikipedia page for more information on variable shadowing.](https://en.wikipedia.org/wiki/Variable_shadowing)» @@ -276,15 +277,13 @@ The following criteria apply to all programs (including before normalization): 7. «For functions that contain recursive calls to themselves or mutually recursive global functions (i.e., those where function `a` calls function `b` and function `b` calls function `a`), a return structural annotation is *required*.» 8. `Op` nodes may appear only as the `op` argument to `Call` nodes. -9. `ExternFunc` expressions may appear only as the `op` argument to `Call` nodes. -10. The `type_args` field is used only for type checking calls of `ExternFunc`s and `Op`s. No other calls may have a non-empty `type_args`. -11. If a variable has both a type annotation and a shape annotation, the `ndim` of any `DynTensorType`s must match the number of dimensions in the corresponding shape annotation. -12. A function definition inside a `DataflowBlock` may not use `DataflowVar`s from the outer scope in its body. We do not define closure capturing for `DataflowVar`s. -13. «At least one global function in the `IRModule` must be externally linked (have a `global_symbol` attribute) in order to serve as a program entry point.» -14. «If a global function has a defined `global_symbol` attribute, the `global_symbol` name must be the same as the `GlobalVar`'s name hint.» -15. If the `shape` field of a `TensorStructInfo` in any structural annotation is given, the only permissible expressions are `Var` (the variable must be in scope at the location of the annotation) or `ShapeExpr` (in which any shape variables used must already be in scope, unless the `TensorStructInfo` is the `struct_info` field of a `MatchCast`, in which case a new shape variable is allowed to appear in a dimension by itself). Additionally, if the `shape` field is a `ShapeExpr`, the number of dimensions must match the `ndim` field. -16. If the `values` field of a `ShapeStructInfo` in any structural annotation is given, any shape variables used in it must already be in scope, unless the `ShapeStructInfo` is the `struct_info` field of a `MatchCast`, in which case a new shape variable is allowed to appear by itself as a member of `values`. Additionally, the `ndim` field must match the length of `values`. -17. The `params` and `derive_func` field may not be simultaneously defined in a `FuncStructInfo` annotation; that is, if one is defined, the other must not be defined. Additionally, at least one of `params` and `derive_func` _must_ be defined for each `FuncStructInfo` in an annotation. +9. If a variable has both a type annotation and a shape annotation, the `ndim` of any `DynTensorType`s must match the number of dimensions in the corresponding shape annotation. +10. A function definition inside a `DataflowBlock` may not use `DataflowVar`s from the outer scope in its body. We do not define closure capturing for `DataflowVar`s. +11. «At least one global function in the `IRModule` must be externally linked (have a `global_symbol` attribute) in order to serve as a program entry point.» +12. «If a global function has a defined `global_symbol` attribute, the `global_symbol` name must be the same as the `GlobalVar`'s name hint.» +13. If the `shape` field of a `TensorStructInfo` in any structural annotation is given, the only permissible expressions are `Var` (the variable must be in scope at the location of the annotation) or `ShapeExpr` (in which any shape variables used must already be in scope, unless the `TensorStructInfo` is the `struct_info` field of a `MatchCast`, in which case a new shape variable is allowed to appear in a dimension by itself). Additionally, if the `shape` field is a `ShapeExpr`, the number of dimensions must match the `ndim` field. +14. If the `values` field of a `ShapeStructInfo` in any structural annotation is given, any shape variables used in it must already be in scope, unless the `ShapeStructInfo` is the `struct_info` field of a `MatchCast`, in which case a new shape variable is allowed to appear by itself as a member of `values`. Additionally, the `ndim` field must match the length of `values`. +15. The `params` and `derive_func` field may not be simultaneously defined in a `FuncStructInfo` annotation; that is, if one is defined, the other must not be defined. Additionally, at least one of `params` and `derive_func` _must_ be defined for each `FuncStructInfo` in an annotation. Additionally, the criteria for normal form listed in the previous section must apply to any program that has been normalized. @@ -470,7 +469,7 @@ Let us consider a typing context `Γ`, which is a map of variables to types. 6. Given an `ExternFunc` expression, assign it the type `PackedFuncType()`. 7. Given a call node `Call(op=op, args=[a1, a2, ..., an], type_args=[aT1, aT2, ..., aTn])`: 1. If `op` is a Relax `Op` node, then we look up its registered `FInferStructInfo` property. `FInferStructInfo` is a macro that takes in the `Call` node and produces structural information. Invoke `op.FInferStructInfo(Call(op, [a1, ..., an], type_args=[aT1, aT2, ..., aTn]))` and convert the result to a type using the `erase_struct_info` procedure defined above. The implementation of `FInferStructInfo` is free to throw errors. - 2. If `op` is `ExternFunc`, note that packed functions may be passed any combination of values and return any value; it is the responsibility of the packed function's implementation to do any validation at run time. However, the type system uses the `type_args` field to determine the result type as follows: + 2. If `op` has `PackedFuncType`, note that packed functions may be passed any combination of values and return any value; it is the responsibility of the packed function's implementation to do any validation at run time. (TODO: `derive_func` should be used here, propagated from the structural information.) However, the type system uses the `type_args` field to determine the result type as follows: 1. If there are no `type_args`, the resulting type is `ObjectType()`. 2. If there is exactly one member of `type_args`, use that as the return type. 3. If there are multiple members of `type_args`, then the type is `TupleType(fields=[aT1, aT2, ..., aTn])`. @@ -935,10 +934,12 @@ For each expression, we define how it affects the program's visible state and th 1. First `cond` is evaluated. Let the result be `r` (per type checking, it must be a bool scalar). 2. If `r` is true, evaluate the `true_branch` and return its result. 3. If `r` is false, evaluate the `false_branch` and return its result. -8. The node `Call(op, [arg1, arg2, ..., argn])` is evaluated as follows: - 1. If `op` is an `ExternFunc` node, then evaluate `arg1`, `arg2`, …, `argn` in that order and call the results `a1`, `a2`, …, `an`. Next, look up the `PackedFunc` registered under the global symbol name. If it exists (it is an error at run time if it does not), call the `PackedFunc` using the given arguments and return the result. Note that if a TIR `PrimFunc` in the `IRModule` has a global symbol attribute registered, it can be called as an `ExternFunc` using that global symbol as well. `PackedFunc`s may have arbitrary side effect and are responsible for whether the result is a newly allocated value or aliases another value. - 2. If `op` is an `Op` node, then evaluate `arg1`, `arg2`, …, `argn` in that order and call the results `a1`, `a2`, …, `an`. It is up to the compiler implementation to decide how operators should be implemented (some may have an associated `PackedFunc` and others may be built into the executor implementation). The operator may mutate its arguments. It is also up to the operator implementation as to whether the result is newly allocated or aliases another value. «(TODO: Once we have operators for logical and AND and OR, we should also define short-circuiting semantics for those.)» - 3. In all other cases, first evaluate `op` (it must evaluate to a closure). Next, we evaluate `arg1`, `arg2`, …, `argn` in that order and call the results `a1`, `a2`, …, `an`. Push a new scope onto the stack where arguments `v1`, `v2`, …, `vn` in the closure are bound to `a1`, `a2`, …, and `an`, respectively, and all variables saved in the closure are added to the scope. Evaluate the closure body in this new scope; this will be the return value of the call. Pop the scope before returning the value. (Note that the checking of the structural information of the argument result values and the body values should be done as described in the previous section.) +8. The node `ExternFunc(global_symbol)` is evaluated by looking up the global symbol name and returning the `PackedFunc` if it exists (it is an error if it does not.) Note that if a TIR `PrimFunc` in the `IRModule` has a global symbol attribute registered, it can be called as an `ExternFunc` using that global symbol as well. +9. The node `Call(op, [arg1, arg2, ..., argn])` is evaluated as follows: + 1. If `op` is an `Op` node, then evaluate `arg1`, `arg2`, …, `argn` in that order and call the results `a1`, `a2`, …, `an`. It is up to the compiler implementation to decide how operators should be implemented (some may have an associated `PackedFunc` and others may be built into the executor implementation). The operator may mutate its arguments. It is also up to the operator implementation as to whether the result is newly allocated or aliases another value. «(TODO: Once we have operators for logical and AND and OR, we should also define short-circuiting semantics for those.)» + 2. Otherwise, first evaluate `op` (it must evaluate to a closure or `PackedFunc`). Next, we evaluate `arg1`, `arg2`, …, `argn` in that order and call the results `a1`, `a2`, …, `an`. + 1. If `op` evaluated to a closure, push a new scope onto the stack where arguments `v1`, `v2`, …, `vn` in the closure are bound to `a1`, `a2`, …, and `an`, respectively, and all variables saved in the closure are added to the scope. Evaluate the closure body in this new scope; this will be the return value of the call. Pop the scope before returning the value. (Note that the checking of the structural information of the argument result values and the body values should be done as described in the previous section.) + 2. If `op` evaluated to a `PackedFunc`, simply invoke it. `PackedFunc`s may have arbitrary side effect and are responsible for whether the result is a newly allocated value or aliases another value. 9. For the node `SeqExpr(blocks, body)`, we evaluate as follows: 1. Push a new scope onto the stack. 2. Iterate through the `BindingBlock`s in `blocks` in order. We will call the current one `block`. For each binding in `Block`: @@ -946,7 +947,7 @@ For each expression, we define how it affects the program's visible state and th 2. If the binding is `VarBinding(var, value)`, then evaluate `value` and bind `var` to that value in the current scope; this assignment is aliasing and no new value is allocated. 3. If `block` is a `DataflowBlock`, remove all `DataflowVar`s bound in the block from the current scope before proceeding to the next block. 3. After iterating through the binding blocks, evaluate `body` in the current scope. That will be the return value of the `SeqExpr`. - 4. Pop the scope, removing any `Var` bindings introduced in the `SeqExpr`. This should also remove any shape variables introduced and bound in the `SeqExpr` as well. + 4. Pop the scope, removing any `Var` or shape variable bindings introduced in the `SeqExpr`. ### Optimizations From 3f368a09d2d9ecbc72e3ae128feb9d87e5601c80 Mon Sep 17 00:00:00 2001 From: Steven Lyubomirsky Date: Mon, 9 Jan 2023 15:11:24 -0500 Subject: [PATCH 17/30] erase_to_well_defined should handle unbound shape vars in FuncStructInfo --- relax_spec.md | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/relax_spec.md b/relax_spec.md index b7f5a867ff..164b0e31b5 100644 --- a/relax_spec.md +++ b/relax_spec.md @@ -657,13 +657,17 @@ def erase_to_well_defined( ) if s is FuncStructInfo: if params is defined: - return FuncStructInfo( - params=[ - erase_to_well_defined(param, var_scope, shape_var_scope) - for param in s.params - ], + new_params = [] + for param in s.params: + if param contains unbound shape variables: + insert unbound shape variables into shape_var_scope + new_params.append(erase_to_well_defined(param, var_scope, shape_var_scope)) + ret = FuncStructInfo( + params=new_params, ret=erase_to_well_defined(s.ret, var_scope, shape_var_scope) ) + remove any unbound shape variables added into shape_var_scope above + return ret else: return FuncStructInfo( ret=erase_to_well_defined(s.ret, var_scope, shape_var_scope), From 8387ef0f1d1060db6c0ecbd5d35a7f2ccd69cf43 Mon Sep 17 00:00:00 2001 From: Steven Lyubomirsky Date: Tue, 31 Jan 2023 19:56:16 -0500 Subject: [PATCH 18/30] Include `sinfo_args`, greatly condense discussion of types --- relax_spec.md | 661 ++++++++++++++++---------------------------------- 1 file changed, 208 insertions(+), 453 deletions(-) diff --git a/relax_spec.md b/relax_spec.md index 164b0e31b5..cc911b314a 100644 --- a/relax_spec.md +++ b/relax_spec.md @@ -14,13 +14,12 @@ Though this document will use the TVMScript front end for some examples, specify 4. [Variable Scoping](#variable-scoping) 5. [Normal Form](#normal-form) 6. [Well-Formedness Criteria](#well-formedness-criteria) -7. [Types in Relax](#types-in-relax) -8. [Structural Information in Relax](#structural-information-in-relax) -9. [Semantics](#detailed-semantics) +7. [Structural Information in Relax](#structural-information-in-relax) +8. [Semantics](#detailed-semantics) # Overview -This section will outline the grammar of Relax and give very brief descriptions of the different components, including the semantics, type system, and shape system. The rest of this document will provide more detailed descriptions of these facets of the language, including the validity conditions that the type system and shape system uphold. +This section will outline the grammar of Relax and give very brief descriptions of the different components, including the semantics and structural information (`StructInfo`) system. The rest of this document will provide more detailed descriptions of these facets of the language, including the validity conditions that the `StructInfo` system upholds. ## Differences from Relay @@ -52,13 +51,6 @@ PrimExpr ::= | Select(condition: PrimExpr, true_value: PrimExpr, false_value: PrimExpr) # (others may be added later, as deemed necessary) -Type ::= DynTensorType(ndim: int, dtype: DataType) - | ShapeType(ndim: int) - | ObjectType() - | TupleType(fields: [Type]) - | FuncType(arg_types: [Type], ret_type: Type, «pure: bool») - | PackedFuncType() - DataType ::= Int(bitwidth: int) | Float(bitwidth: int) | Bool() @@ -82,7 +74,7 @@ Expr ::= Constant(data: NDArray) | Function(params: [Var], body: Expr, ret_struct_info: StructInfo?, attrs: Attrs?) | If(cond: Expr, true_branch: Expr, false_branch: Expr) | ExternFunc(global_symbol: string) - | Call(op: Expr, args: [Expr], type_args: [Type], attrs: Attrs?) + | Call(op: Expr, args: [Expr], sinfo_args: [StructInfo], attrs: Attrs?) | ShapeExpr(values: [PrimExpr]) | TupleGetItem(tuple_value: Expr, index: int) | Op(op_name: string) @@ -107,7 +99,7 @@ Program ::= IRModule(funcs: {GlobalVar: Function|PrimFunc}) ## Expression Survey -This specification provides a more detailed description of what each expression and type represents and what conditions make them valid. To motivate and provide more context for the full specification later in this document, this section will briefly summarize the purpose of each node. +This specification provides a more detailed description of what each expression and `StructInfo` represents and what conditions make them valid. To motivate and provide more context for the full specification later in this document, this section will briefly summarize the purpose of each node. 1. `Constant` nodes construct tensor constants (n-dimensional arrays of scalars). 2. `Tuple` nodes construct a tuple (fixed-size ordered grouping) of Relax values. @@ -115,17 +107,17 @@ This specification provides a more detailed description of what each expression 4. `PrimExpr`s are used to represent dimensions of shapes in `ShapeExpr` and `MatchCast` nodes. These represent operations on integers with their own `Var` nodes (`tir::Var`), which we will refer to as "shape variables". Shape variables can only be used in other `PrimExpr`s and are scoped like `Var` nodes (`relax::Var`), which we will call "Relax variables." 5. `ExternFunc` nodes evaluate into `PackedFunc`s; the implementation will look up the registered `PackedFunc` by its global symbol. 6. `Call` nodes represent function calls. The callee argument (the `op`) can be an `ExternFunc` node (representing a call to a `PackedFunc`), an `Op` node (representing a call to a Relax operator), or an arbitrary expression. - 1. `Op` nodes refer to built-in Relax operators, which the compiler is free to implement as is deemed appropriate. Certain operators implement important operations, like `call_tir` (allows for calling TIR `PrimFunc`s) «and `cast` (performs dynamic type conversions).» + 1. `Op` nodes refer to built-in Relax operators, which the compiler is free to implement as is deemed appropriate. Certain operators implement important operations, like `call_tir` (allows for calling TIR `PrimFunc`s). 2. Any other expression must evaluate to a `PackedFunc` or a closure; the result of evaluating `op` will then be called with the given arguments. - Calls to `ExternFunc`s and operators may perform side effects, hence it is important to reason about whether a function call is permitted inside a `DataflowBlock`. «The attribute "pure" can be specified on a call to an `ExternFunc` to indicate that it performs no side effects (for use inside `DataflowBlock`s).» + Calls to `ExternFunc`s and operators may perform side effects, hence it is important to reason about whether a function call is permitted inside a `DataflowBlock`. 7. `If` nodes represent branching control flow. First the condition expression is evaluated, and it must evaluate to a Boolean scalar. If the condition is true, the true branch is evaluated and its result is used; otherwise, the false branch is evaluated and its result is used. 8. `TupleGetItem` nodes represent tuple indexing. The `tuple_value` expression must evaluate to a tuple with at least `index + 1` items and the item with the given index will be returned. 9. `SeqExpr` describes a sequence of binding blocks followed by a return expression. The `SeqExpr` opens a new scope. Its binding blocks are evaluated in order and add new variables to the scope. Binding blocks are either ordinary `BindingBlock`s or `DataflowBlock`s and both consist of a series of bindings. `DataflowBlock`s are the only kind allowed to introduce bindings with `DataflowVar`s and it does not permit any constructs featuring control flow (`If` nodes or recursive calls) or calls to (possibly) impure functions. There are two different kinds of bindings: 1. `VarBinding`s: The `value` expression (the right-hand side of the binding) of the binding is evaluated first and is bound to the `var` expression, which must be a new `Var` or `DataflowVar` (in a dataflow block). The newly bound variable will have that value for the remainder of the scope (`DataflowVar`s are scoped only to the `DataflowBlock` in which they appear; `Var`s are scoped to the entire `SeqExpr`). 2. `MatchCast`s: The `value` expression is evaluated and the result is dynamically checked against the structural information given in the `struct_info` field. - 1. The types must match: All `StructInfo` variants correspond to a type (`TensorStructInfo` to `DynTensorType`, `ShapeStructInfo` to `ShapeType`, etc.) and each type corresponds to a value (`DynTensorType` to a tensor value, `ShapeType` to shape values, etc.), so if the structure of `value` does not correspond to `struct_info`, an error is triggered. The structure of `value` is compared recursively with `struct_info`, so all components of `value` must match up with any nested structural information. Special comparison rules: + 1. The types must match: All `StructInfo` variants correspond to a category of value value (`TensorStructInfo` to a tensor value, `ShapeStructInfo` to shape values, etc.), so if the structure of `value` does not correspond to `struct_info`, an error is triggered. The structure of `value` is compared recursively with `struct_info`, so all components of `value` must match up with any nested structural information. Special comparison rules: 1. For comparing tensor values to `TensorStructInfo`, `ndim` must match the number of dimensions in the tensor value (unless `ndim` is -1) and `dtype` must match the datatype used (unless `dtype` is `Void`). If `shape` has been specified, the shape of the value must match that encoded by `shape`; if specified, `shape` must be either a `Var` already bound in the current scope or a `ShapeExpr`. 2. For comparing shape values to `ShapeStructInfo`, `ndim` must match the number of dimensions in the shape value (unless `ndim` is -1). If `values` has been specified, the shape value must match that encoded by `values`. 3. «For comparing closures (function values) to `FuncStructInfo`, it is necessary for the compiled program to track run-time structural information for closures, since it is not possible to introspect the closure; this subject will be discussed in further detail later in the document.» @@ -143,7 +135,7 @@ This specification provides a more detailed description of what each expression ## Purity and Dataflow Blocks -A function or operator is called "pure" if it does not have side effects, which refers to any change in program state besides returning a result. Side effects include mutating values other than those they create, aborting the program, or file I/O (including writing to the console). Purity is a useful property for compiler optimizations, since calls to pure functions can be reordered or duplicated or (if the result is unused) eliminated without changing any other program behavior. Most deep learning operators are pure, as they perform arithmetic on tensors and return a new tensor containing the result. «In Relax, we conservatively assume that any function that calls an impure function is itself impure, though the attribute `force_pure` on a function can be used as an override (e.g., if a function creates a new tensor, mutates it, and returns it, that is still pure but does not satisfy the conservative rule).» +A function or operator is called "pure" if it does not have side effects, which refers to any change in program state besides returning a result. Side effects include mutating values other than those they create, aborting the program, or file I/O (including writing to the console). Purity is a useful property for compiler optimizations, since calls to pure functions can be reordered or duplicated or (if the result is unused) eliminated without changing any other program behavior. Most deep learning operators are pure, as they perform arithmetic on tensors and return a new tensor containing the result. Above, it is mentioned that `DataflowBlock`s are not allowed to contain constructs featuring control flow (`If` nodes or recursive calls to the current function) or calls to impure functions. This ensures that `DataflowBlock`s represent a directed acyclic graph of pure operations, which is similar to the graph-like abstractions of traditional deep learning frameworks. This allows many common optimizations from past frameworks to be directly adapted to `DataflowBlock`s without having to accommodate additional reasoning about more expressive features like control flow and side effects. @@ -154,28 +146,22 @@ There is one visible side effect that Relax permits inside otherwise "pure" func Even though an abnormal program exit is a visible side effect and removing or reordering it changes the observable semantics, it would be too great a restriction to prohibit error checking inside `DataflowBlock`s. Relax does not have any notion of exception handling, so the only consequence of a failed safety check can be exiting the program. It is permissible for the compiler to reorder, duplicate, or eliminate `MatchCast`, or otherwise pure operations that have the potential of failing, provided that doing so does not change the value returned by the program or any other visible behavior. -To indicate that an operator or `PackedFunc` that can abort with an error should *never* be reordered or removed by the compiler, it should *not* be marked as pure. However, this means that it cannot be used inside a `DataflowBlock`. - Note that in some programming languages like Koka, non-termination is also considered a side effect, since it can in some sense be "observed" by a user and affects the visible behavior of a program (e.g., if there is an infinite loop before a print statement, the print will never happen). However, since non-termination cannot be automatically detected in general and is unlikely to arise in deep learning models, we do not attempt to systematically track non-termination in Relax. In general, the Relax compiler is allowed to reorder or remove otherwise pure function calls even if they may not terminate. For example, if a pure function `f` that returns an integer scalar does not terminate, it is permissible in principle to rewrite `f() - f()` to 0. Exiting with an error and infinitely looping are traditionally considered "[divergence](https://en.wikipedia.org/wiki/Divergence_(computer_science))" in the programming languages literature. As a general principle, Relax's compiler is permitted to turn a program that diverges into a program that does not diverge (provided that no other visible effects change) so long as it never transforms a program that does not diverge into one that diverges. -## Type System Survey - -The types in Relax correspond to the broad categories of the values given above: - -1. `DynTensorType` corresponds to tensor values, giving the scalar data type and the number of dimensions (rank), both of which are optional. -2. `TupleType` corresponds to tuple values, giving the type of each member of the tuple. -3. `ShapeType` corresponds to shape values, optionally giving the number of dimensions in the shape. -4. `FunctionType` corresponds to function values (closures), giving the types of the parameters, the return type, «and whether the function is pure.» -5. `PackedFuncType` is the type given to arbitrary packed functions (external functions). -6. `ObjectType` is the parent type of all Relax types and corresponds to all the values above as well as any values returned by `PackedFunc` calls that do not fit in the above categories. +## Structural Information (`StructInfo`) System Survey -The type checking rules assign types to every variable in scope and every type of expression based on the values it returns, making use of subtyping to assign more general types when a more specific one cannot be determined. «Relax is strongly typed, meaning that if a type encountered is less specific than the one expected, an error will be issued and an explicit cast (via the `cast` operator) will be required.» +Analogously to a type system in most languages, Relax tracks structural information (referred to as `StructInfo` in the implementation) related to the categories of values in Relax: +1. `TensorStructInfo` corresponds to tensor values, giving the scalar data type, the number of dimensions (rank), and an expression that computes the tensor's shape (either a `ShapeExpr` or a `Var`), all of which are optional. +2. `TupleStructInfo` corresponds to tuple values, giving the `StructInfo` for each member of the tuple. +3. `ShapeStructInfo` corresponds to shape values, optionally giving the number of dimensions in the shape and an expression that computes the shape's dimensions (either a `ShapeExpr` or a `Var`). +4. `FunctionStructInfo` corresponds to function values (closures) and `PackedFunc`s (external functions), giving the types of the parameters, the return type, «and whether the function is pure.» +5. `ObjectStructInfo` is a parent to all Relax `StructInfo` and corresponds to all the values above as well as any values returned by `PackedFunc` calls that do not fit in the above categories. -## Structural Information System Survey +`StructInfo` is assigned to every variable in scope and every type of expression based on the values it returns via a set of inference rules defined later in the specification, making use of subtyping to assign more general `StructInfo` when a more specific one cannot be determined. «Relax is strongly typed, meaning that if the `StructInfo` inferred is less specific than the one expected, an error will be issued and an explicit check via `MatchCast` will be required.» -In Relax, tensor shapes are not handled in the type system, even though it would be greatly beneficial for the compiler to make use of shape information for static optimizations. Instead, shape information is tracked using Relax's structural information system, in which every expression has structural information associated with it (like tensor shapes) that is more expressive than its type. Structural information can convey richer properties about expressions, like tensor shapes, and can facilitate a greater degree of static reasoning. However, when it is not feasible for the compiler to draw conclusions about structural information, this information can be checked dynamically via `MatchCast`. The structural information is essentially an extended type system, so `MatchCast` also serves to handle type casting. +In Relax, tensor shapes are not statically handled in the type system, even though it would be greatly beneficial for the compiler to make use of shape information for static optimizations. Instead, shape information is tracked using Relax's structural information system, in which every expression has structural information associated with it (like tensor shapes) that is more expressive than its type. `StructInfo` can convey richer properties about expressions, like tensor shapes, and can facilitate a greater degree of static reasoning. However, when it is not feasible for the compiler to draw conclusions about structural information, this information can be checked dynamically via `MatchCast`. The structural information is essentially an extended type system, so `MatchCast` also serves to handle type casting. --- @@ -194,11 +180,11 @@ Here are the classes of values that Relax operates over, meaning that they can b - *Closures* are the values resulting from evaluating Relax function expressions; closures can be passed around like other values, ensuring that functions are first-class in Relax. Functions defined in Relax can capture variables from outer scopes. A [closure](https://en.wikipedia.org/wiki/Closure_(computer_programming)) consists of a function and a mapping of any variables "captured" (those are *free variables* in the function body, variables from an outer scope that are neither arguments nor defined within the function but are used in the function) to their values. Closures capture both Relax-level local variables and shape variables from outer scopes. A closure also stores a name for itself when the body contains recursive calls. «Closures additionally carry some *run-time structural information* (RTSI) indicating their argument and result structures, in order to facilitate dynamic structural checks (since it is not otherwise possible to introspect the function contained within a closure); the precise form of the RTSI is left up to the compiler implementation to determine so long as `MatchCast` can verify the structure of a closure. Closures can be evaluated in a call node, which results in calling the function with the call's arguments and the captured values.» - *Tensor shapes* (shape values) are tuples of integers describing a tensor shape, obtained by evaluating `ShapeExpr`s. - *Packed functions* (`PackedFunc`s or external functions) represent arbitrary opaque functions implemented in TVM. That is, packed functions are routines that are defined outside of Relax and cannot be inspected by the compiler. They can perform side effects and return arbitrary values. -- Additionally, there are further *arbitrary objects* that do not belong in the above categories. These can be returned by `PackedFunc`s and operators; additionally, we treat TIR `PrimFunc`s as opaque objects. Though Relax expressions other than `PackedFunc` and operator calls cannot use those objects, Relax should pass around these values faithfully. In the future we may add more value types in order to distinguish between different objects, but at present we treat these all as arbitrary values of type `ObjectType`. +- Additionally, there are further *arbitrary objects* that do not belong in the above categories. These can be returned by `PackedFunc`s and operators; additionally, we treat TIR `PrimFunc`s as opaque objects. Though Relax expressions other than `PackedFunc` and operator calls cannot use those objects, Relax should pass around these values faithfully. In the future we may add more value types in order to distinguish between different objects, but at present we treat these all as arbitrary values with `ObjectStructInfo`. ## Representation of Values at Run Time -Because Relax supports calls to arbitrary `PackedFunc`s that can operate on a low level, it is necessary to define a convention for how values will be represented at run time. At this time, the specification does not require any specific representation and permits compiler implementations to choose their own representations, provided that each value type listed above can be recognized at run time (for dynamic type checks). This means that Relax programs that call `PackedFunc`s directly are not portable across compiler implementations: The `PackedFunc`s used must be able to operate on the run-time representations of values. +Because Relax supports calls to arbitrary `PackedFunc`s that can operate on a low level, it is necessary to define a convention for how values will be represented at run time. At this time, the specification does not require any specific representation and permits compiler implementations to choose their own representations, provided that each value type listed above can be recognized at run time (for dynamic `StructInfo` checks). This means that Relax programs that call `PackedFunc`s directly are not portable across compiler implementations: The `PackedFunc`s used must be able to operate on the run-time representations of values. Possible specification in terms of the TVM object system: @@ -240,7 +226,7 @@ def func(x: Tensor) -> Tensor: # Normal Form -To simplify the writing of Relax passes, we define a normal form for Relax programs, based on the [administrative normal form](https://en.wikipedia.org/wiki/A-normal_form) (A-normal form, or ANF). See [this post](https://matt.might.net/articles/a-normalization/) by Matt Might for a discussion of some of the advantages of ANF in traditional compilation; in particular, ANF results in programs without nesting, which is very convenient for writing program transformations. Because the type- and structure-checking rules for operators rely on macros (`FInferShapeInfo`), _this means that the structure of the program can affect type and structure inference_. Putting programs into normal form (and lacking nesting) not only simplifies the writing of these macros but it also ensures that these type- and structure-checking rules will be predictable, hence _it is required to transform programs into normal form_ before applying type or structure checking. +To simplify the writing of Relax passes, we define a normal form for Relax programs, based on the [administrative normal form](https://en.wikipedia.org/wiki/A-normal_form) (A-normal form, or ANF). See [this post](https://matt.might.net/articles/a-normalization/) by Matt Might for a discussion of some of the advantages of ANF in traditional compilation; in particular, ANF results in programs without nesting, which is very convenient for writing program transformations. Because the `StructInfo`-checking rules for operators rely on macros (`FInferShapeInfo`), _this means that the structure of the program can affect `StructInfo` inference_. Putting programs into normal form (and lacking nesting) not only simplifies the writing of these macros but it also ensures that these `StructInfo`-checking rules will be predictable, hence _it is required to transform programs into normal form_ before applying `StructInfo` checking. The normal form for Relax is very similar to ANF; differences will be noted. Here are the criteria required for a program to be in normal form: 1. Within a `SeqExpr`, the right-hand side of any binding (the `value` field in the AST) must either be a "leaf expression" or a non-leaf expression where all subexpressions are leaf expressions. Leaf expressions are the following: Variables (`Var`, `DataflowVar`, or `GlobalVar`), `Constant`, `ShapeExpr`, or (_unlike_ ANF) `Tuple`. `Tuple` nodes are considered "leaf" expressions even though they contain nesting purely for convenience in writing passes; many operators rely on grouping arguments using tuples, so that is a form of nesting permitted and expected. Otherwise, non-leaf expressions used as subexpressions must be bound to variables; this includes any non-leaf expressions nested inside a `Tuple`. @@ -250,7 +236,7 @@ The normal form for Relax is very similar to ANF; differences will be noted. Her 3. In fact, the `body` field of a `Function` node and the `true_branch` and `false_branch` fields of `If` nodes _must_ be `SeqExpr`s. If these fields are not `SeqExpr`s, they must be "wrapped" in a `SeqExpr`. 4. Within a `SeqExpr`, `BindingBlock`s must be consolidated. For example, if there is a `BindingBlock` that comes after another `BindingBlock`, the two blocks should be combined to form a single `BindingBlock` with all the bindings in the same order. Consecutive `DataflowBlock`s should be consolidated as well. Empty `BindingBlock`s should be dropped. However, a `DataflowBlock` cannot be consolidated with an ordinary `BindingBlock`. If all the `BindingBlock`s are empty, then the `blocks` field of the `SeqExpr` should be set to an empty list. -Programs that are parsed should be "normalized" before performing type checking or structure checking or before doing any further optimizations. Note that the process of "flattening" `SeqExpr`s and consolidating `BindingBlock`s does increase the visibility of the variables in those `SeqExpr`s and `BindingBlock`s, but this is safe, since it will not cause any variable to be referenced outside of its original scope. The specification does not require any particular method of normalizing a program so long as the final program conforms to the above-listed criteria. Here is a general approach: +Programs that are parsed should be "normalized" before performing `StructInfo` checking or before doing any further optimizations. Note that the process of "flattening" `SeqExpr`s and consolidating `BindingBlock`s does increase the visibility of the variables in those `SeqExpr`s and `BindingBlock`s, but this is safe, since it will not cause any variable to be referenced outside of its original scope. The specification does not require any particular method of normalizing a program so long as the final program conforms to the above-listed criteria. Here is a general approach: 1. For each function in the `IRModule`, ensure that the body is a `SeqExpr`. If the body is not a `SeqExpr`, wrap the function body in a `SeqExpr`, creating a new `BindingBlock` to hold `VarBinding`s for any non-leaf expressions that need to be bound to variables. 2. If the function body is already a `SeqExpr`, consolidate all `BindingBlock`s, then check if the `body` field of the `SeqExpr` is a leaf expression. If not, bind it to a new var in the final `BindingBlock` and replace the `SeqExpr` body with the new var. 3. If the function body is not a `SeqExpr`, then recurse down the body's AST, binding any nested non-leaf expressions to a var in the current scope (doing this process in breadth-first order from left to right will respect the evaluation order in the semantics). If the body itself is a non-leaf expression, finally bind it to a var and have the final `SeqExpr` return the new var. @@ -260,7 +246,7 @@ Programs that are parsed should be "normalized" before performing type checking # Well-Formedness Criteria -Prior to type-checking and shape inference, Relax programs must conform to certain syntactic criteria to be valid, which includes conforming to the expectations of the above-described normal form. +Prior to `StructInfo` checking, Relax programs must conform to certain syntactic criteria to be valid, which includes conforming to the expectations of the above-described normal form. The following criteria apply to all programs (including before normalization): 1. `DataflowVar`s can be bound only inside `DataflowBlock`s. Additionally, a `DataflowVar` may not be used outside of the `DataflowBlock` in which it is defined. @@ -273,11 +259,11 @@ The following criteria apply to all programs (including before normalization): 2. Calls to a global function that is mutually recursive with the current function 3. `If` nodes - «Calls to Relax functions, `ExternFuncs`, or `Op`s that are not pure are also not permitted, but this must be detected during type checking.» + «Calls to Relax functions, `ExternFuncs`, or `Op`s that are not pure are also not permitted, but this must be detected during `StructInfo` checking.» 7. «For functions that contain recursive calls to themselves or mutually recursive global functions (i.e., those where function `a` calls function `b` and function `b` calls function `a`), a return structural annotation is *required*.» 8. `Op` nodes may appear only as the `op` argument to `Call` nodes. -9. If a variable has both a type annotation and a shape annotation, the `ndim` of any `DynTensorType`s must match the number of dimensions in the corresponding shape annotation. +9. If a variable has a `StructInfo` annotation, the `ndim` of any `TensorStructInfo` and `ShapeStructInfo`s must match the number of dimensions in their `shape` and `values` fields, respectively. 10. A function definition inside a `DataflowBlock` may not use `DataflowVar`s from the outer scope in its body. We do not define closure capturing for `DataflowVar`s. 11. «At least one global function in the `IRModule` must be externally linked (have a `global_symbol` attribute) in order to serve as a program entry point.» 12. «If a global function has a defined `global_symbol` attribute, the `global_symbol` name must be the same as the `GlobalVar`'s name hint.» @@ -287,238 +273,23 @@ The following criteria apply to all programs (including before normalization): Additionally, the criteria for normal form listed in the previous section must apply to any program that has been normalized. -# Types in Relax - -Relax's type system is intended to enforce strong guarantees that values are passed correctly between expressions. The design emphasis is on simplicity, aiming to leave more complex analysis to the structural information. - -Relax presently has six types, corresponding to the values in the language: - -1. `DynTensorType`, referring to tensor values (referred to in the front-end as `Tensor`). In Relax, tensor types keep track of the rank (number of dimensions) in the `ndim` field and the data type of the tensor data in the `dtype` field. Both the rank and data type are optional: Using -1 for `ndim` indicates that the tensor is of unknown rank and using `DataType::Void()` for `dtype` indicates that it's of unknown data type. -2. `ShapeType`, referring to shape values. The number of dimensions in the shape as given as `ndim` and is optional (using -1 for `ndim` indicates an unknown number of dimensions). -3. `FuncType`, referring to functions (closures). `FuncType`s specify the types of their parameters, a return type, and whether the function is pure. -4. `TupleType`, referring to tuple values, giving the types of their fields. -5. `PackedFuncType`, referring to the type of PackedFunctions. -6. `ObjectType`, referring to any Relax value, including values used and returned by `PackedFunc` or operator calls that do not belong in any of the above categories. - -## Erasing Structural Information into Types - -Several type-checking rules rely on structural annotations or rules for defining the structural information for a call to an `Op` or `PackedFunc`. In general, types are simpler than structural information (to facilitate more precise reasoning). Structural information can be convereted into a type as follows (in pseudocode): - -```python -def erase_struct_info(si: StructInfo) -> Type: - if si is TensorStructInfo: - return DynTensorType(ndim=si.ndim, dtype=si.dtype) - if si is ShapeStructInfo: - return ShapeType(ndim=si.ndim) - if si is TupleStructInfo: - return TupleType(fields=[erase_struct_info(field) for field in si.fields]) - if si is FuncStructInfo: - # this should be the case only for packed funcs - if si.params is not specified: - return PackedFuncType() - return FuncType( - arg_types=[erase_struct_info(arg_type) for arg_type in si.params], - ret_type=erase_struct_info(si.ret) - pure=False) # TODO: This suggests we should either handle purity - # in StructInfo entirely (and not make it part of the type) - # or include it in both StructInfo and the type system - # only remaining case is ObjectStructInfo - return ObjectType() -``` - -## Subtyping - -Relax implements subtyping, which means that members of types can be accepted where members of their supertypes are accepted. We will denote the subtyping relationship as `T1 <: T2`, indicating that `T1` is a subtype of `T2`. For example. if `T1 <: T2` and some function expects an argument of type `T2`, then passing a member of type `T1` to that function is permitted; passing a member of type `T2` as an argument to a function that expects type `T1` for that argument is *not* permitted—the value would have to be dynamically cast to `T1` using the `cast` operator. - -### Rules for Subtyping - -1. Reflexivity: For all types `T`, `T <: T`. -2. Transitivity: For all types `T1`, `T2`, and `T3`, if `T1 <: T2` and `T2 <: T3`, then `T1 <: T3`. -3. For all types `T`, `T <: ObjectType`. Hence, `ObjectType` is a supertype to all Relax types (all values in Relax are members of `ObjectType`). -4. Rules for `DynTensorType`: - 1. For all fixed `ndim` values `m`, where `m` ≥ 0, and `dtype`s `d`, `DynTensorType(ndim=m, dtype=d) <: DynTensorType(ndim=m, dtype=Void)`. - 2. For all fixed `ndim` values `m` and `dtype`s `d` that are not `Void`, `DynTensorType(ndim=m, dtype=d) <: DynTensorType(ndim=-1, dtype=d)`. - 3. Corollary: `DynTensorType(ndim=-1, dtype=Void)` is a supertype to all tensor types, since it refers to any possible tensor value. -5. Suppose we have types `T1 <: T1'`, `T2 <: T2'`, …, `Tn <: Tn'`. Then `TupleType(fields=[T1, T2, ..., Tn]) <: TupleType(fields=[T1', T2', ..., Tn'])`. -6. Rules for `FuncType`: - 1. Impure functions are supertypes to pure functions. Namely, if we have types `T1`, `T2`, …, `Tn` and `Tr`, then `FuncType(arg_types=[T1, T2, ..., Tn], ret_type=Tr, pure=True) <: FuncType(arg_types=[T1, T2, ..., Tn], ret_type=Tr, pure=False)`. - 2. Suppose we have types `T1' <: T1`, `T2' <: T2`, …, `Tn' <: Tn` and `Tr <: Tr'`. Then `FuncType(arg_types=[T1, T2, ... Tn], ret_type=Tr, pure=p) <: FuncType(arg_types=[T1', T2', ..., Tn'], ret_type=Tr', pure=p)`. Note the direction of the subtyping relationships for the argument and return types: We must be able to *call* this function with the *same* arguments and *use the returned value* wherever it is accepted—hence a function that takes more general arguments and returns a more specific return value can be used in place of the original. - -These rules allow us to define the least upper bound (LUB) for any two types `T1` and `T2`, meaning that it is the most specific type `T` for which `T1 <: T` and `T2 <: T` ("most specific" meaning that if there exists some other `T'` for which `T1 <: T'` and `T2 <: T'`, then `T <: T'`). The LUB is guaranteed to exist for any two types because `Object` is a supertype to all types. - -Note that the rule for obtaining the LUB of function types relies on the counterpart to the LUB, the greatest lower bound (GLB). The GLB is not guaranteed to exist for any two types in Relax, as there is no single type that is a subtype of all others. - -We can give an algorithm for determining the LUB and GLB for two types, in pseudocode: +# Structural Information (`StructInfo`) in Relax -```python -def find_glb(T1 : Type, T2 : Type) -> Type?: - if T1 == T2: # syntactic equality - return T2 - if T1 is ObjectType: - return T2 - if T2 is ObjectType: - return T1 - if T1 and T2 are not both DynTensorType, not both TupleType, not both FuncType, or not both ShapeType, or not both PackedFuncType: - return None - if T1 and T2 are both ShapeType: - ret_ndim = T1.ndim - if ret_ndim == -1: - ret_ndim == T2.ndim - if ret_ndim != -1 and T2.ndim != ret_ndim: - return None - return ShapeType(ret_ndim) - if T1 and T2 are both DynTensorType: - ret_ndim = T1.ndim - ret_dtype = T1.dtype - if ret_ndim == -1: - ret_ndim == T2.ndim - if ret_dtype == Void: - ret_dtype = T2.dtype - if ret_ndim != -1 and T2.ndim != ret_ndim: - # mismatch, so there's no common lower bound - return None - if ret_dtype != Void and T2.dtype != ret_dtype: - return None - return DynTensorType(ret_ndim, ret_dtype) - if T1 and T2 are both TupleType: - if they do not have the same length: - return None - fields = [] - for field1, field2 in zip(T1.fields, T2.fields): - glb = find_glb(field1, field2) - if glb is None: - return None - fields.append(glb) - return TupleType(fields) - if T1 and T2 are both FuncType: - «if they are not both pure or both impure:» - «return None» - purity = T1.purity - if they do not have the same arity: - return None - # mutual recursion with finding the LUB - arg_types = [ - find_lub(arg_type1, arg_type2) - for arg_type1, arg_type2 in zip(T1.arg_types, T2.arg_types) - ] - ret_type = find_glb(T1.ret_type, T2.ret_type) - if ret_type is None: - return None - return FuncType(arg_types, ret_type, purity) - -def find_lub(T1 : Type, T2 : Type) -> Type: - if T1 == T2: # syntactic equality - return T1 - if T1 or T2 is ObjectType: - return Object - if T1 or T2 are not both DynTensorType, or both TupleType, or both FuncType, or both ShapeType, or both PackedFuncType: - return ObjectType - if T1 and T2 are both ShapeType: - res_ndim = T1.ndim - if T1.ndim != T2.ndim: - res_ndim = -1 - return ShapeType(res_ndim) - if T1 and T2 are both DynTensorType: - res_ndim = T1.ndim - res_dtype = T1.dtype - if T1.ndim != T2.ndim: - res_ndim = -1 - if T1.dtype != T2.dtype: - res_dtype = Void - return DynTensorType(res_ndim, res_dtype) - if T1 and T2 are both TupleType: - if they do not have the same length: - return ObjectType - return TupleType([ - find_lub(field1, field2) - for field1, field2 in zip(T1.fields, T2.fields) - ]) - if T1 and T2 are both FuncType: - «purity = (True iff they're both pure)» - if they do not have the same arity: - return ObjectType - arg_types = [] - for arg_type1, arg_type2 in zip(T1.arg_types, T2.arg_types): - # potential mutual recursion - glb = find_glb(arg_type1, arg_type2) - if glb is None: - return ObjectType - arg_types.append(glb) - return FuncType(arg_types, find_lub(T1.ret_type, T2.ret_type), «purity») -``` +Structural information in Relax is intended to enforce basic guarantees that values are passed correctly between expressions, while also analyzing more complex properties like tensor shapes in a _"best-effort"_ fashion. Namely, anything that cannot be proved statically can instead be checked at run time. Each Relax expression has structural information associated with it. The best-effort nature of the structural system in Relax means that the analysis may detect _some_ errors at compile time and report them, but it may give warnings when it _cannot_ draw conclusions, perhaps suggesting that dynamic checks via `MatchCast` should be inserted. Note that the precision of the static analysis can potentially be improved by some compile-time optimizations like constant propagation, function inlining, and other partial evaluation–like transformations. -### When Type Conversions are Necessary - -For two types `T1` and `T2`, if `T1 <: T2`, then a value of type `T1` can be passed anywhere a value of type `T2` is expected without any need for type conversions or dynamic checks. - -*However*, if `T1 <: T2`, then passing a value of type `T2` where `T1` is expected can only be done if there has been some kind of dynamic check or conversion of that value. «Relax is *strongly typed*, meaning that the compiler will give an error in this situation and require an explicit conversion via a `MatchCast` node, which inspects the value's run-time representation.» - -If `T1` is not a subtype of `T2` and `T2` is not a subtype of `T1`, then it is always a type error to pass a value of either type where a value of the other is expected (no member of either type can be a member of the other). - -## Type Checking Rules - -The type checking rules for Relax are relatively simple and allow in some cases for types to be inferred without user annotations. Below, we describe how the types for each expression can be derived and when type checking should return an error. - -Let us consider a typing context `Γ`, which is a map of variables to types. - -1. «We type check the entire `IRModule` one function definition at a time. To handle mutual recursion, we prepopulate `Γ` with the annotated types of all global functions that are called mutually recursively. We then proceed to check the types of the global functions one at a time.» -2. Given a variable `v`, if `v` is in `Γ`, then we return `Γ[v]`. Otherwise, it is an error. -3. Given a constant expression `Constant(data)`, the type is `DynTensorType(ndim=n, dtype=d)` where `n` is the number of dimensions in `data` and `d` is its data type (recall that `data` is an `NDArray` literal). -4. Given a shape expression `ShapeExpr(dims)`, its type is `ShapeType(n)`, where `n` is the length of `dims`. -5. Given a tuple literal `Tuple([e1, e2, ..., en])`, suppose that `e1` has the type `T1`, `e2` has the type `T2`, …, and `en` has the type `Tn`. Then the type is `TupleType([T1, T2, .., Tn])`. -6. Given an `ExternFunc` expression, assign it the type `PackedFuncType()`. -7. Given a call node `Call(op=op, args=[a1, a2, ..., an], type_args=[aT1, aT2, ..., aTn])`: - 1. If `op` is a Relax `Op` node, then we look up its registered `FInferStructInfo` property. `FInferStructInfo` is a macro that takes in the `Call` node and produces structural information. Invoke `op.FInferStructInfo(Call(op, [a1, ..., an], type_args=[aT1, aT2, ..., aTn]))` and convert the result to a type using the `erase_struct_info` procedure defined above. The implementation of `FInferStructInfo` is free to throw errors. - 2. If `op` has `PackedFuncType`, note that packed functions may be passed any combination of values and return any value; it is the responsibility of the packed function's implementation to do any validation at run time. (TODO: `derive_func` should be used here, propagated from the structural information.) However, the type system uses the `type_args` field to determine the result type as follows: - 1. If there are no `type_args`, the resulting type is `ObjectType()`. - 2. If there is exactly one member of `type_args`, use that as the return type. - 3. If there are multiple members of `type_args`, then the type is `TupleType(fields=[aT1, aT2, ..., aTn])`. - 3. Otherwise, check the types of the subexpressions, left to right. Suppose `op` has the type `Tf`, `a1` has type `T1`, …, and `an` has type `Tn`. If `Tf` is not a function type with exactly `n` arguments, we consider it a type error and require an explicit cast. Suppose `Tf` has argument types `T1'`, `T2'`, …, `Tn'`. Consider it a type error if any of the following does not hold: `T1 <: T1'`, `T2 <: T2'`, …, or `Tn <: Tn'`. Then the return type is `Tf.ret_type`. -8. Given a conditional expression `If(cond, true_branch, false_branch)`, we first assert that the type of `cond` is `DynTensorType(ndim=0, dtype=bool)`, giving an error otherwise. Next, we recursively check the types of `true_branch` and `false_branch`; suppose they yield types `Tt` and `Tf`, respectively. The return type will be `T = LUB(Tt, Tf)`. -9. For a `TupleGetItem(t, i)` expression, suppose that the type of `t` is `T`. If `T` is a tuple type with at least `i` members, then return the `i`th type in `T`. «Give a type-checking error and require an explicit cast if `T` is not a tuple type with at least `i` members.» -10. Let us consider a sequence expression `SeqExpr(blocks = [b0, b1, ..., bn], body)`. - 1. We type check the binding blocks `b0`, `b1`, …, `bn` in order. For each block, we go through the bindings in order. - 1. «If the current block is a `DataflowBlock`, consider it an error if any binding contains a call to an expression with a function type that is not pure, a call to an `ExternFunc` that does not have the `pure` attribute, or a call to an `Op` that does not have a `pure` attribute.» - 2. For each binding `VarBinding(v, e)` in the current block, check the type of `e` and suppose it is `T'`. If `v` has a structural annotation, then let `T` be the corresponding type (via the `erase_struct_info` procedure above). If there is no annotation, then add `v` to `Γ` with type `T'`. If `T` has been defined, then emit an error if `T'` is not a subtype of `T` and otherwise add `v` to `Γ` with type `T`. «If `T'` is a supertype of `T`, emit an error and require a cast.» Note that this means that annotated types can be *less specific* than the inferred type and that a user annotation forces the type system to consider the variable as having a less specific type than it does. (Note: In the case where `e` is a `Function` literal, we require `v` to have a structural annotation add `v` to `Γ` with its annotated type before type-checking the function body; see the rule for `Function` nodes.) - 3. For each `MatchCast(v, e, struct_info)`: - 1. Check the type of `e` and let it be `T'`. - 2. Let `T''` be the type corresponding to `struct_info` (via the `erase_struct_info` procedure). - 3. Emit a warning if `T'` is not a supertype of `T''` and `T''` is also not a supertype of `T'`; this indicates that the cast is _guaranteed_ to fail at run time. - 4. If `v` has been defined and it has a structural annotation, then let `T` be its corresponding type (via `erase_struct_info`). - 5. If `T` has been defined, then emit an error if `T` is not a supertype of `T''`. - 6. If `v` has been defined and does not have a structural annotation, then add `v` to `Γ` with type `T''`. If `T` has also been defined, then add `v` to `Γ` with type `T`. - 2. If the current block is a `DataflowBlock`, remove any `DataflowVar`s from `Γ` after we have finished processing the bindings in that block. - 3. Finally, the type of the `SeqExpr` is the type of `body` checked under the current `Γ`. Afterwards, remove all bindings added from the `b0`, `b1`, …, `bn` from `Γ`. -11. Let us consider a function `Function(params=[v1, v2, ..., vn], body, ret_struct_info)`. All of the vars are required to have structural annotations; let `T1` be the type corresponding to `v1`'s annotation (via `erase_struct_info`), `T2` be the type corresponding to `v2`'s annotation, etc.. - 1. For handling recursive calls: If the function has been bound to a name `fv` (which may be a `Var` or a `GlobalVar`), then add `fv` to `Γ` with the type `FuncType([T1, T2, ..., Tn], Tr, pure=p)` and proceed as below, «where `p` is `True` if a `pure` attribute is included and `False` otherwise». Remove `fv` from `Γ` before returning. - 2. Add `v1` to `Γ` with type `T1`, `v2` with type `T2`, …, and `tn` with type `Tn`. Recursively type check `body` in this new context: - 1. «Determining purity: If `body` contains any call to a function whose return type does not specify that it is pure, a call to an `ExternFunc` that does not have the `pure` attribute, or a call to an `Op` that does not specify the `pure` attribute, then we consider the function to be (potentially) impure. If all calls are to functions whose return type specifies purity or that include the `pure` attribute on the call or `Op`, then the function is treated as pure.» - 2. «Suppose the purity defined in the previous step is `p'`. Suppose the annotated function purity (in the attributes) is `p`. If `p'` is false while `p` is true, then it is a type error; if `p` was omitted, use `p'` for `p`.» - 3. «If the function has the attribute "`force_pure`," then consider `p` to be true, even if the check above judged the function not to be pure. The compiler may emit a warning in this situation.» - 4. Suppose the result of type-checking `body` is `Tr'`. If the current function is not recursive or a mutually recursive global function and `ret_struct_info` is undefined, consider the function to have type `FuncType([T1, T2, ..., Tn], Tr', pure=p)`. If `ret_struct_info` is defined, then let `Tr` be `erase_struct_info(ret_struct_info)`. If `Tr' <: Tr`, then we consider the function to have type `FuncType([T1, T2, ..., Tn], Tr, pure=p)`. «If `Tr <: Tr'`, return an error and require an explicit cast in the function body.» If `Tr'` is not a subtype of `Tr` and `Tr` is also not a subtype of `Tr'` (meaning a dynamic cast cannot succeed), this is an error. - 5. Remove `v1`, `v2`, …, and `vn` from `Γ` before returning. - -# Structural Information in Relax - -In Relay, shapes are part of tensor types and there is much analysis of tensor shapes done at compile time. While this allows Relay's type system to make strong guarantees about tensor shapes, it results in greater complexity in type checking and makes it difficult to implement new operators or handle cases like tensors with symbolic shapes. - -Relax instead aims to facilitate analysis of more complex properties like shapes by tracking _structural information_ pertaining, encoding as much analysis as is feasible at compile-time in a _"best-effort"_ fashion. Anything that cannot be proved statically can instead be checked at run time. Each Relax expression has structural information associated with it just as it has a type. Indeed, the structural information for each expression can be simplified into a type (recall [the procedure for doing so](#erasing-structural-information-into-types)), so the structural information for an expression can be thought of as an extended type that is checked in a less precise manner. The best-effort nature of the structural system in Relax means that the analysis may detect _some_ errors at compile time and report them, but it may give warnings when it _cannot_ draw conclusions, perhaps suggesting that dynamic checks via `MatchCast` should be inserted. Note that the precision of the static analysis can potentially be improved by some compile-time optimizations like constant propagation, function inlining, and other partial evaluation–like transformations. - -Tensor shapes are the primary motivation for including structural information in Relax, as shape information is particularly important for memory planning. Relax's structural information system uses expressions to encode tensor shapes, which allows for using shape variables and arithmetic expressions to encode a rich variety of shape constraints. Note, however, that the structural system could potentially be extended to encode and analyze further information, like tensor sparsity or density. +Tensor shapes are the primary motivation for including structural information in Relax, as shape information is particularly important for memory planning. In Relay, shapes are part of tensor types and there is much analysis of tensor shapes done at compile time. While this allows Relay's type system to make strong guarantees about tensor shapes, it results in greater complexity in type checking and makes it difficult to implement new operators or handle cases like tensors with symbolic shapes. By contrast, Relax's `StructInfo` system uses expressions to encode tensor shapes, which allows for using shape variables and arithmetic expressions to encode a rich variety of shape constraints. Note, however, that the structural system could potentially be extended to encode and analyze further information, like tensor sparsity or density. ## Defining Structural Information -As with types, the structural information in Relax corresponds to the values in the language: -* `TensorStructInfo` describes tensor values. Like in `DynTensorType`, the `dtype` field gives the datatype (with `Void` indicating a statically unknown datatype), the `ndim` field gives the rank (with -1 indicating a statically unknown rank). Unlike `DynTensorType`, there is an optional `shape` field which, if defined, describes the shape of the tensor using either a `ShapeExpr` or a `Var` whose type is `ShapeType`. If `shape` is a `ShapeExpr`, the `PrimExpr`s in the `ShapeExpr`'s dimensions describe how to compute each dimension of the shape (or are constants). If `shape` is a `Var`, the `Var` can assign the result of an arbitrary computation (that returns a shape). which can be useful for memory planning. -* `ShapeStructInfo` describes shape values. Like `ShapeType`, it has an `ndim` field that gives the number of dimensions in the shape (with -1 indicating that it is statically unknown). It additionally has an optional `values` field. If defined, `values` gives a list of `PrimExpr`s that indicate how to compute the dimensions of the shape, potentially providing further information for static analyses. -* `TupleStructInfo` describes tuple values, namely by giving the structural information for each of the tuple's members via `fields`. +The structural information in Relax corresponds to the values in the language: +* `TensorStructInfo` describes tensor values. The `dtype` field gives the datatype (with `Void` indicating a statically unknown datatype), the `ndim` field gives the rank (with -1 indicating a statically unknown rank). Unlike `DynTensorType`, there is an optional `shape` field which, if defined, describes the shape of the tensor using either a `ShapeExpr` or a `Var` (with `ShapeStructInfo`). If `shape` is a `ShapeExpr`, the `PrimExpr`s in the `ShapeExpr`'s dimensions describe how to compute each dimension of the shape (or are constants). If `shape` is a `Var`, the `Var` can assign the result of an arbitrary computation that returns a shape value, which can be useful for memory planning. +* `ShapeStructInfo` describes shape values. It has an `ndim` field that gives the number of dimensions in the shape (with -1 indicating that it is statically unknown). It additionally has an optional `values` field. If defined, `values` gives a list of `PrimExpr`s that indicate how to compute the dimensions of the shape, potentially providing further information for static analyses. +* `TupleStructInfo` describes tuple values, namely by giving the `StructInfo` for each of the tuple's members via `fields`. * `FuncStructInfo` describes closure values or `PackedFunc`s. There are two ways in which to specify `FuncStructInfo`: - 1. By specifying `params` and `ret` (for closures). `params` gives the structural information corresponding to each of the function's parameters and `ret` gives the structural information corresponding to the result of calling the function. In this case, the `derive_func` field is left undefined. + 1. By specifying `params` and `ret` (for closures). `params` gives the `StructInfo` corresponding to each of the function's parameters and `ret` gives the `StructInfo` corresponding to the result of calling the function. In this case, the `derive_func` field is left undefined. 2. By giving a `derive_func` macro (for `PackedFunc`s). The `derive_func` macro is takes a call to the corresponding `PackedFunc` and the variable mapping context and returns the `StructInfo` of the result. In this case, the `params` field is left undefined and the `ret` field is ignored. * `ObjectStructInfo` describes arbitrary object values. -While these categories correspond closely to types, they serve as a mechanism for propagating further information (especially as given in shape annotations in variable bindings) throughout the program and facilitating more static analysis. - ### Expressing Shape Dimensions A tensor shape is a tuple of TIR `PrimExpr`s, where each `PrimExpr` corresponds to a dimension. The use of TIR `PrimExpr`s for shape dimensions allows shape computations to express complex constraints that include variables and integer arithmetic expressions in addition to just constant dimensions. @@ -544,7 +315,7 @@ This section describes the run-time checking performed by `MatchCast(var, value, 2. If `struct_info` is `TensorStructInfo(ndim, dtype, shape)`, then check that `value` is a tensor value, that it has a rank of `ndim` (if `ndim` is not -1), a datatype of `dtype` (if `dtype` is not `Void`). If `shape` is defined, consider the following cases: 1. If `shape` is a `Var`, then check that the concrete shape of `value` matches the value bound to the `Var`. 2. If `shape` is a `ShapeExpr`, then compare the fields of the `ShapeExpr` to the concrete shape of `value`, dimension by dimension (comparing the `i`th field of the `ShapeExpr` to the `i`th dimension of the shape of `value`). Give an error if the number of the dimensions does not match the number of fields in the `ShapeExpr`. - 1. If a field of the `ShapeExpr` consists of only an unbound shape variable, then bind that variable to the value of the dimension. + 1. If a field of the `ShapeExpr` consists of only an unbound shape variable, then bind that variable to the value of the dimension. 2. Otherwise, evaluate the field of the `ShapeExpr` and ensure that it matches the concrete value of the dimension. 3. If `struct_info` is `ShapeStructInfo(ndim, values)`, then check that `value` is a shape value, that it has `ndim` dimensions (if `ndim` is not -1). If `values` is defined, then compare it to the concrete shape value (comparing the `i`th member of `values` to the `i`th field of the shape value): 1. If the `i`th member of `values` consists of only an unbound shape variable, then bind that variable to the `i`th field of the the concrete shape value. @@ -575,9 +346,110 @@ def f(arg1, arg2, ..., argn): ``` » +## Subtyping for `StructInfo` + +Relax implements subtyping for `StructInfo`, which means that values with some `StructInfo` can be accepted where values with more general `StructInfo` are accepted We will denote the subtyping relationship as `S1 <: S2`, indicating that `S1` is a subtype of `S2`. For example. if `S1 <: S2` and some function expects an argument with `StructInfo` `S2`, then passing a value with `StructInfo` `S1` to that function is permitted; passing a value with `StructInfo` `S2` as an argument to a function that expects `S1` for that argument is *not* permitted—the value would have to be dynamically cast to `S1` using `MatchCast`. + +Note that judging subtyping requires potentially reasoning about arbitrary `ShapeExpr`s. We assume that the compiler is able to draw the following three conclusions about two shape expressions, acting conservatively (it will consider values to be _definitely_ equal or _definitely not_ equal only if it is certain): +* They are _definitely_ statically equal in all cases. +* They are _possibly_ statically equal. +* They are _definitely not_ statically equal in at least one case. + +1. Reflexivity: `S1 <: S1` for all `S1`. +2. Transitivity: For all `S1`, `S2`, and `S3`, if `S1 <: S2` and `S2 <: S3`, then `S1 <<: S3`. +3. For all `S1`, `S1 <: ObjectStructInfo()`. +4. For `TensorStructInfo`: + 1. Given any datatype `d`, an arbitrary `ndim` `n`, and an arbitrary expression `s` (possibly undefined), `TensorStructInfo(ndim=n, dtype=d, shape=s) <: TensorStructInfo(ndim=-1, dtype=d)`. + 2. Given any datatype `d`, an arbitrary `ndim` `n`, and an arbitrary expression `s` (possibly undefined), `TensorStructInfo(ndim=n, dtype=d, shape=s) <: TensorStructInfo(ndim=n, dtype=Void, shape=s)`. + 3. Given any datatype `d`, an arbitrary `ndim` `n`, and an arbitrary expression `s`, `TensorStructInfo(ndim=n, dtype=d, shape=s) <: TensorStructInfo(ndim=n, dtype=d, shape=undefined)`. + 4. Given any datatype `d`, an arbitrary `ndim` `n`, and arbitrary expressions `s1` and `s2` (both defined), then `TensorStructInfo(ndim=n, dtype=d, shape=s1) <: TensorStructInfo(ndim=n, dtype=d, shape=s2)` if `s1` and `s2` are _definitely_ statically equal. We say that `TensorStructInfo(ndim=n, dtype=d, shape=s1) <: TensorStructInfo(ndim=n, dtype=d, shape=s2)` _possibly_ holds if `s1` and `s2` are _possibly_ statically equal. +5. For `ShapeStructInfo`: + 1. Given an arbitrary `ndim` `n` and an arbitrary set of values `v` (possibly undefined), `ShapeStructInfo(ndim=n, values=v) <: ShapeStructInfo(ndim=-1)`. + 2. Given an arbitrary `ndim` `n` and an arbitrary set of values `v` (not undefined), `ShapeStructInfo(ndim=n, values=v) <: ShapeStructInfo(ndim=n, values=undefined)`. + 3. Given an arbitrary `ndim` `n` and two arbitrary sets of values `v1` and `v2` (both defined), `ShapeStructInfo(ndim=n, values=v1) <: ShapeStructInfo(ndim=n, values=v2)` if, for all valid `i`, `v1[i]` and `v2[i]` can be proven to be _definitely_ statically equal. We say that `ShapeStructInfo(ndim=n, values=v1) <: ShapeStructInfo(ndim=n, values=v2)` _possibly_ holds if `v1` and `v2` are _possibly_ statically equal. +6. Given two lists of `StructInfo` `fields1` and `fields2`, `TupleStructInfo(fields=fields1) <: TupleStructInfo(fields=fields2)` if `fields1` and `fields2` are the same length and for all `i`, `fields1[i] <: fields2[i]`. We consider the subtyping relationship to _possibly_ hold if any of the subtyping relationships for the fields only possibly holds. +7. For `FuncStructInfo`: + 1. Given an arbitrary derivation function `derive_func`, `FuncStructInfo(ret=ObjectStructInfo(), derive_func=derive_func) <: FuncStructInfo(ret=ObjectStructInfo(), derive_func=empty_derive)`. + 2. Corollary, following from reflexivity: For two `FuncStructInfo` `F1` and `F2` with undefined `params`, `F1 <: F2` only if `F1.derive_func` and `F2.derive_func` are identical. + 3. Given two lists of `StructInfo` parameters `P1` and `P2` and two `StructInfo` annotations `R1` and `R2`, `FuncStructInfo(params=P1, ret=R1) <: FuncStructInfo(params=P2, ret=R2)` if `P1` and `P2` are the same length and for all `i`, `P2[i] <: P1[i]` and `R1 <: R2`. We consider the subtyping relationship to _possibly_ hold if any of the subtyping relationships given only possibly holds. + +These rules allow us to define the least upper bound (LUB) for any two `StructInfo` `S1` and `S2`, meaning that it is the most specific `StructInfo` `S` for which `S1 <: S` and `S2 <: S` ("most specific" meaning that if there exists some other `S'` for which `S1 <: S'` and `S2 <: S'`, then `S <: S'`), modulo reasoning about arithmetic (for example, the compiler may judge that two shape expressions are _possibly_ equivalent rather than _definitely_ equivalent). The LUB is guaranteed to exist for any two `StructInfo` because all `StructInfo` are subtypes of `ObjectStructInfo`. + +We can define how to find the LUB of two structural information annotations (modulo arithmetic reasoning) as follows, in pseudocode: + +```python +def unify_struct_info(S1: StructInfo, S2: StructInfo) -> StructInfo: + if S2 is ObjectStructInfo: + return S1 + if S1 is ObjectStructInfo: + return S2 + if S1 and S2 do not match types (e.g., not both TensorStructInfo, etc): + return ObjectStructInfo() + if S1 and S2 are both ShapeStructInfo: + if S1.ndim == -1: + return S1 + if S2.ndim == -1: + return S2 + if S1.ndim != S2.ndim: + return ShapeStructInfo(ndim=-1) + if S1.ndim == S2.ndim: + if S1.values is undefined: + return S1 + if S2.values is defined: + return S2 + if S1.values can be statically proven to match S2.values: + return S1 + # values either proven not to match or unknown + return ShapeStructInfo(ndim=S1.ndim) # leave values undefined + if S1 and S2 are both TensorStructInfo: + ndim = S1.ndim if S1.ndim == S2.ndim else -1 + dtype = S1.dtype if S1.dtype == S2.dtype else Void + if ( + S1.ndim == -1 or S2.ndim == -1 or S1.ndim != S2.ndim + or S1.shape is undefined or S2.shape is undefined + ): + return TensorStructInfo(ndim=ndim, dtype=dtype) # leave shape undefined + # both shapes are defined + if S1.shape can be proven to equal S2.shape: + return S1 + # either proven to be unequal or cannot be concluded whether they are equal + return TensorStructInfo(ndim=ndim, dtype=dtype) # leave shape undefined + if S1 and S2 are both TupleStructInfo: + if S1.fields and S2.fields are of different lengths: + return ObjectStructInfo() + return TupleStructInfo( + unify_struct_info(S1.fields[i], S2.fields[i]) + for 0 <= i < length of S1.fields + ]) + if S1 and S2 are both FuncStructInfo: + if S1.params and S2.params are not both defined or both undefined: + return ObjectStructInfo() + if S1.params and S2.params are both undefined: + # they must be the same function, not bothering to check eta-equivalence + if S1.derive_func == S2.derive_func: + return S1 + return FuncStructInfo(ret=ObjectStructInfo(), derive_func=empty_derive) + if S1.params and S2.params are both defined: + if S1.params and S2.params do not have the same length: + return ObjectStructInfo() + unified_params = [] + for 0 <= i < length of S1.params: + unified_param = unify_struct_info(S1.params[i], S2.params[i]) + # That is, if the params judged to be equal, use them. + # If there is some pair that is not equal, + # we can't unify these types except with ObjectStructInfo. + # This rule should suffice in practice; otherwise we would + # need to give a full definition of the GLB + if unified_param <: S1.params[i] and unified_param <: S2.params[i]: + unified_params[i] = unified_param + else: + return ObjectStructInfo() + return FuncStructInfo(params=unified_params, ret=unify_struct_info(S1.ret, S2.ret)) +``` + ## Deriving the Structural Information for Each Expression -For each expression type, we can recursively build up the structural information associated with the expression. +For each kind of expression, we can recursively build up the structural information associated with the expression. ### Auxiliary Procedures @@ -585,41 +457,14 @@ For each expression type, we can recursively build up the structural information There are two special `derive_func` values built into the compiler that are used for checking the structural information of `PackedFunc`s. -The first is `default_derive`, giving a simple way to determine the resulting structural information of a `PackedFunc` from its type arguments. `default_derive` takes one argument that is a `Call` node and is defined as follows: -1. Suppose its call node argument is `Call(op, [arg1, arg2, ..., argn], type_args=[aT1, aT2, ..., aTn])`. -2. If `type_args` is of length 0, then return `ObjectStructInfo()`. -3. If `type_args` is of length 1, then return `wrap_type(aT1)`. -4. If `type_args` is of a greater length than 1, then return `TupleStructInfo(fields=[wrap_type(aT1), wrap_type(aT2), ..., wrap_type(aTn)])`. +The first is `default_derive`, giving a simple way to determine the resulting structural information of a `PackedFunc` from its `StructInfo` arguments. `default_derive` takes one argument that is a `Call` node and is defined as follows: +1. Suppose its call node argument is `Call(op, [arg1, arg2, ..., argn], sinfo_args=[aS1, aS2, ..., aSn])`. +2. If `sinfo_args` is of length 0, then return `ObjectStructInfo()`. +3. If `sinfo_args` is of length 1, then return `aS1`. +4. If `sinfo_args` is of a greater length than 1, then return `TupleStructInfo(fields=[aS1, aS2, ..., aSn])`. The second is `empty_derive`, which is the weakest possible derivation. It simply returns `ObjectStructInfo` regardless of its argument. This is used for worst-case deducation of `StructInfo` for a `PackedFunc`. -**Wrapping Types** - -For deriving the structural information for a `PackedFunc` call, the type arguments are converted into structural information. This is a straightforward procedure, given here in pseudocode: - -```python -def wrap_type(t: Type) -> StructInfo: - if t is ObjectType: - return ObjectStructInfo() - if t is PackedFuncType: - # leave params undefined; see default_derive below - return FuncStructInfo(ret=ObjectStructInfo(), derive_func=default_derive) - if t is FuncType: - # leave derive_func undefined - return FuncStructInfo( - params=[wrap_type(arg_type) for arg_type in t.arg_types], - ret=wrap_type(t.ret_type) - ) - if t is TupleType: - return TupleStructInfo(fields=[wrap_type(field) for field in t.fields]) - if t is ShapeType: - # leave values undefined - return ShapeStructInfo(ndim=t.ndim) - if t is DynTensorType: - # leave shape undefined - return TensorStructInfo(ndim=t.ndim, dtype=t.dtype) -``` - **Erasing Out-of-Scope Information** When returning a value from an inner scope to an outer scope (namely, the `body` field of a `SeqExpr`, which may use variables defined in the binding blocks, and the `body` field of a `Function`, which may use variables defined in the function body), it may be possible for the derived `TensorStructInfo` or `ShapeStructInfo` to contain Relax variables or shape vars that have gone out of scope. We defined a procedure to check for any of these out-of-scope variables and weaken the structural information not to include it. The procedure is defined below, in pseudocode: @@ -725,180 +570,59 @@ def get_shape_var_mapping(S1: StructInfo, S2: StructInfo) -> {tir::Var, PrimExpr return {} ``` -**Checking Compatibility** - -In many cases during the derivation of structural information, it is important to judge when two distinct structural information encodings are compatible with each other or when they are too different from each other to be reconciled, which can indicate an error. In the case of shape information, this could mean having two symbolic shapes that can be proven not to be equal to each other. Because shape expressions can contain arithmetic and it can be very difficult to statically prove whether two arithmetic expressions are equal, we permit the compiler implementation to make a best-effort attempt to prove equality for arithmetic expressions. (The user can insert a `MatchCast` to check definitively.) Since the checks are best-effort, the compatibility check will only report incompatibility if two values are _definitely_ different from each other. - -We can check if some structural information `S1` is accepted where structural information `S2` is expected by the process given below, which we refer to as `check_compability(S1, S2)` for convenience. `check_compatibility` can find that `S1` and `S2` are compatible, possibly compatible, or incompatible. "Incompatible" indicates a definite mismatch that should result in a compiler error; "possibly compatible" indicates that the structures may or may not match and should likely result in a compiler warning (indicating that a user may want to insert a dynamic check). An invariant that should should is that if `check_compatibility(S1, S2)` returns "compatible" or "possible compatible", `erase_struct_info(S1) <: erase_struct_info(S2)` should hold; that is, compatibility of structural information should be consistent with typing rules. - -1. If `S2` is `ObjectStructInfo`, then they are compatible. -2. Otherwise, if `S1` and `S2` are not both `TensorStructInfo` or both `TupleStructInfo`, etc. (besides `ObjectStructInfo`), then report an incompatibility. -3. If `S1` and `S2` are both `TupleStructInfo`: - 1. If `S1.fields` is not the same length as `S2.fields`, they are incompatible - 2. Call `check_compability(S1.fields[i], S2.fields[i])` for all `i`. If any pair of fields is incompatible, then `S1` and `S2` are incompatible. If no pair of fields is incompatible but at least one is possibly compatible, then `S1` and `S2` are possibly compatible. If all pairs of fields are compatible, then `S1` and `S2` are compatible. -4. If `S1` and `S2` are both `ShapeStructInfo`: - 1. `S2.ndim` is -1, then they are compatible. - 2. Otherwise, give an error if `S1.ndim` does not match `S2.ndim`. - 3. If `values` is not defined for `S2`, then they are compatible. - 4. If `values` is defined for `S2` but not defined for `S1`, then they are possibly compatible. - 5. If `values` is defined for both `S1` and `S2`, then the two are incompatible if `S1.values[i]` can be proven to be _not_ equal to `S2.values[i]` for some `i`. If all members can be proven to be equal, then they are compatible. Otherwise, if at least one pair of values cannot be proven to be either equal or unequal, then they are possibly compatible. -5. If `S1` and `S2` are both `TensorStructInfo`: - 1. If `S2.dtype` is not `Void` and does not match `S1.dtype`, then they are incompatible. - 2. If `S2.ndim` is not -1 and does not match `S1.ndim`, then they are incompatible. - 3. If `S2.shape` is not defined, then they are compatible. - 4. If `S2.shape` is defined and `S1.shape` is not defined, then they are possibly compatible. - 5. Otherwise, if both `shape` fields are given and either is a `Var`, then consider `S1` and `S2` compatible if the compiler can statically prove that the `Var` holds the same value as the other `shape` field, consider them possibly compatible if the compiler cannot draw a conclusion one way or the other, and consider them incompatible if the `Var` definitely has a different value from the other `shape`. - 6. If both `shape` fields are given and they are both `ShapeExpr` nodes, then `S1` and `S2` are incompatible if the compiler can prove that some dimension of `S1.shape` is _not_ equal to the corresponding dimension of `S2.shape`. Otherwise, if the all dimensions can be proven to be equal, then consider them compatible. If at least one pair of dimensions cannot be proven to be equal or unequal, consider them possibly compatible. -6. If `S1` and `S2` are both `FuncStructInfo`: - 1. If `S1` and `S2` don't both have defined `params` or both have undefined `params`, consider them incompatible. - 2. If both `S1` and `S2` have undefined `params`, consider them compatible if they have an identical `derive_func` and consider them possibly compatible if they have different `derive_func`s (as they is no further way to introspect the `derive_func` and draw static conslusions about `PackedFunc`s). - 3. If `params` is defined for both `S1` and `S2`: - 1. Consider them incompatible if the `params` have different lengths. - 2. Next, map unbound shape variables as follows: Get a variable mapping `m` by applying `get_shape_var_mapping(S1.params[i], S2.params[i])` for all values of `i`, taking the union of all resulting mappings. Next, substitute all occurrences of the shape variables in `S1` with their values in `m`. - 3. If `check_compatible(S2.params[i], S1.params[i])` (note the direction of the check: see the subtyping rule for `FuncType`) is incompatible for any `i` or if `check_compatible(S1.ret, S2.ret)` is incompatible, then they are incompatible. Otherwise, if `check_compatible(S2.params[i], S1.params[i])` is possibly compatible for any `i` or if `check_compatible(S1.ret, S2.ret)` is possibly compatible, consider `S1` and `S2` possibly compatible. Consider `S1` and `S2` compatible only if all checks are compatible. - -**Unification** - -Analogously to subtyping, we can also consider a hierarchy of structural information, considering some structural information to more or less specific than other structural information. Accordingly, we can also define a least upper bound for structural information, as with types. - -We can define an analogue to subtyping for structural information, as below. We say that `S1` is more specific than `S2` and denote it as `S1 <<: S2` (to distinguish from the notation on subtyping) based on the conditions given here. As an invariant, if `S1 <<: S2` holds, then `erase_struct_info(S1) <: erase_struct_info(S2)`, though the converse may not be true. -1. Reflexivity: `S1 <<: S1` for all `S1`. -2. Transitivity: For all `S1`, `S2`, and `S3`, if `S1 <<: S2` and `S2 <<: S3`, then `S1 <<: S3`. -3. For all `S1`, `S1 <<: ObjectStructInfo()`. -4. For `TensorStructInfo`: - 1. Given any datatype `d`, an arbitrary `ndim` `n`, and an arbitrary expression `s` (possibly undefined), `TensorStructInfo(ndim=n, dtype=d, shape=s) <<: TensorStructInfo(ndim=-1, dtype=d)`. - 2. Given any datatype `d`, an arbitrary `ndim` `n`, and an arbitrary expression `s` (possibly undefined), `TensorStructInfo(ndim=n, dtype=d, shape=s) <<: TensorStructInfo(ndim=n, dtype=Void, shape=s)`. - 3. Given any datatype `d`, an arbitrary `ndim` `n`, and an arbitrary expression `s` (not undefined), `TensorStructInfo(ndim=n, dtype=d, shape=s) <<: TensorStructInfo(ndim=n, dtype=d, shape=undefined)`. - 4. Given any datatype `d`, an arbitrary `ndim` `n`, and arbitrary expressions `s1` and `s2` (both defined), then `TensorStructInfo(ndim=n, dtype=d, shape=s1) <<: TensorStructInfo(ndim=n, dtype=d, shape=s2)` if `s1` and `s2` are _definitely_ or _possibly_ statically equal. -5. For `ShapeStructInfo`: - 1. Given an arbitrary `ndim` `n` and an arbitrary set of values `v` (possibly undefined), `ShapeStructInfo(ndim=n, values=v) <<: ShapeStructInfo(ndim=-1)`. - 2. Given an arbitrary `ndim` `n` and an arbitrary set of values `v` (not undefined), `ShapeStructInfo(ndim=n, values=v) <<: ShapeStructInfo(ndim=n, values=undefined)`. - 3. Given an arbitrary `ndim` `n` and two arbitrary sets of values `v1` and `v2` (both defined), `ShapeStructInfo(ndim=n, values=v1) <<: ShapeStructInfo(ndim=n, values=v2)` if, for all valid `i`, `v1[i]` and `v2[i]` can be proven to be _definitely_ or _possibly_ statically equal. -6. Given two lists of structural information `fields1` and `fields2`, `TupleStructInfo(fields=fields1) <<: TupleStructInfo(fields=fields2)` if `fields1` and `fields2` are the same length and for all `i`, `fields1[i] <<: fields2[i]`. -7. For `FuncStructInfo`: - 1. Given an arbitrary derivation function `derive_func`, `FuncStructInfo(ret=ObjectStructInfo(), derive_func=derive_func) <<: FuncStructInfo(ret=ObjectStructInfo(), derive_func=empty_derive)`. - 2. Corollary, following from reflexivity: For two `FuncStructInfo` `F1` and `F2` with undefined `params`, `F1 <<: F2` only if `F1.derive_func` and `F2.derive_func` are identical. - 3. Given two lists of structural information parameters `P1` and `P2` and two structural information annotations `R1` and `R2`, `FuncStructInfo(params=P1, ret=R1) <<: FuncStructInfo(params=P2, ret=R2)` if `P1` and `P2` are the same length and for all `i`, `P2[i] <<: P1[i]` and `R1 <<: R2`. - -Given these rules, we can define how to unify (get the LUB) of two structural information annotations as follows (in pseudocode): -```python -def unify_struct_info(S1: StructInfo, S2: StructInfo) -> StructInfo: - if S2 is ObjectStructInfo: - return S1 - if S1 is ObjectStructInfo: - return S2 - if S1 and S2 do not match types (e.g., not both TensorStructInfo, etc): - return ObjectStructInfo() - if S1 and S2 are both ShapeStructInfo: - if S1.ndim == -1: - return S1 - if S2.ndim == -1: - return S2 - if S1.ndim != S2.ndim: - return ShapeStructInfo(ndim=-1) - if S1.ndim == S2.ndim: - if S1.values is undefined: - return S1 - if S2.values is defined: - return S2 - if S1.values can be statically proven to match S2.values: - return S1 - # values either proven not to match or unknown - return ShapeStructInfo(ndim=S1.ndim) # leave values undefined - if S1 and S2 are both TensorStructInfo: - ndim = S1.ndim if S1.ndim == S2.ndim else -1 - dtype = S1.dtype if S1.dtype == S2.dtype else Void - if ( - S1.ndim == -1 or S2.ndim == -1 or S1.ndim != S2.ndim - or S1.shape is undefined or S2.shape is undefined - ): - return TensorStructInfo(ndim=ndim, dtype=dtype) # leave shape undefined - # both shapes are defined - if S1.shape can be proven to equal S2.shape: - return S1 - # either proven to be unequal or cannot be concluded whether they are equal - return TensorStructInfo(ndim=ndim, dtype=dtype) # leave shape undefined - if S1 and S2 are both TupleStructInfo: - if S1.fields and S2.fields are of different lengths: - return ObjectStructInfo() - return TupleStructInfo( - unify_struct_info(S1.fields[i], S2.fields[i]) - for 0 <= i < length of S1.fields - ]) - if S1 and S2 are both FuncStructInfo: - if S1.params and S2.params are not both defined or both undefined: - return ObjectStructInfo() - if S1.params and S2.params are both undefined: - # they must be the same function, not bothering to check eta-equivalence - if S1.derive_func == S2.derive_func: - return S1 - return FuncStructInfo(ret=ObjectStructInfo(), derive_func=empty_derive) - if S1.params and S2.params are both defined: - if S1.params and S2.params do not have the same length: - return ObjectStructInfo() - unified_params = [] - for 0 <= i < length of S1.params: - unified_param = unify_struct_info(S1.params[i], S2.params[i]) - # That is, if the params judged to be equal, use them. - # If there is some pair that is not equal, - # we can't unify these types except with ObjectStructInfo - # See the use of GLB with FuncTypes - if unified_param <<: S1.params[i] and unified_param <<: S2.params[i]: - unified_params[i] = unified_param - else: - return ObjectStructInfo() - return FuncStructInfo(params=unified_params, ret=unify_struct_info(S1.ret, S2.ret)) -``` - ### Derivation Rules -Let `Δ` be the structural information context for Relax variables (to distinguish from `Γ` for types) and let `Σ` track which shape variables are in scope. +Let `Γ` be the `StructInfo` context for Relax variables and let `Σ` track which shape variables are in scope. -1. «Prepopulate `Δ` with the annotated types of all global functions (see the rule for `Function` nodes) that are called mutually recursively. Afterwards check the structural information of the global functions one at a time and populate the entry of `Δ` corresponding to that `GlobalVar`.» -2. For a variable (`Var`, `DataflowVar`, or `GlobalVar`) `v`, look up `Δ[v]` for the structural information. +1. «Prepopulate `Γ` with the annotated types of all global functions (see the rule for `Function` nodes) that are called mutually recursively. Afterwards check the structural information of the global functions one at a time and populate the entry of `Γ` corresponding to that `GlobalVar`.» +2. For a variable (`Var`, `DataflowVar`, or `GlobalVar`) `v`, look up `Γ[v]` for the structural information. 3. For `Constant(value)`, the resulting structural information is `TensorStructInfo(ndim, dtype, shape)` where `ndim` is the concrete rank of `value`, `dtype` is the concrete datatype used in `value`, and `shape` is a `ShapeExpr` giving the concrete shape of `value. For example, for `Constant(1)`, `shape` is `ShapeExpr([])` and for `Constant([1, 2])`, `shape` is `ShapeExpr([IntImm(2, "int64")])`. -4. For `Tuple(fields)`, the resulting structural information is `TupleStructInfo([f.struct_info for f in fields])`, after deriving the structural information for the fields recursively. +4. For `Tuple(fields)`, suppose that `fields` is comprised of expressions `E1`, `E2`, ..., `En`. Let the `StructInfo` for these expressions be `S1`, `S2`, ..., `Sn`, respectively. Then the resulting `StructInfo` is `TupleStructInfo(fields=[S1, S2, ..., Sn])`. 5. For `ShapeExpr(values)`, the resulting structural information is `ShapeStructInfo(ndim, values)`, where `ndim` is the length of `values`. 6. For `If(cond, true_branch, false_branch)`, we compare the structural information of `true_branch` and `false_branch` (call these `S_t` and `S_f`, respectively). The resulting structural information is `unify_struct_info(S_t, S_f)`. 7. For `SeqExpr(blocks, body)`: 1. For each binding block in `blocks` (call the current one `block`): - 1. Process each binding in the block, updating `Δ` and `Σ` accordingly (this is discussed in detail below). - 2. If `block` is a `DataflowBlock`, then remove all `DataflowVar`s introduced in `block` from `Δ` before proceeding to the next block. + 1. Process each binding in the block, updating `Γ` and `Σ` accordingly (this is discussed in detail below). + 2. If `block` is a `DataflowBlock`, then remove all `DataflowVar`s introduced in `block` from `Γ` before proceeding to the next block. 2. Next derive the structural information for `body`. Let us call this `S`. - 3. Remove all Relax variables introduced in `blocks` from `Δ` and all shape variables introduced in `blocks` from `Σ`. - 4. The structural information of the entire `SeqExpr` is `erase_to_well_defined(S, Δ, Σ)`. + 3. Remove all Relax variables introduced in `blocks` from `Γ` and all shape variables introduced in `blocks` from `Σ`. + 4. The structural information of the entire `SeqExpr` is `erase_to_well_defined(S, Γ, Σ)`. 8. For handling variable bindings: - 1. If `v` is the argument to a function, then if `v` has a structural annotation `S`, set `Δ[v]` to `S`. Add any unbound shape variables in `S` to `Σ`. If `v` does not have a structural annotation, set `Δ[v]` to `ObjectStructInfo()`. + 1. If `v` is the argument to a function, then if `v` has a structural annotation `S`, set `Γ[v]` to `S`. Add any unbound shape variables in `S` to `Σ`. If `v` does not have a structural annotation, set `Γ[v]` to `ObjectStructInfo()`. 2. In the general `VarBinding(v, e)`: - 1. If `e` is a function literal, then recursion is permitted. In this case, `v` must have a structural annotation `Sv`. Derive the structural information for `e` as follows: Set `Δ[v]` to `Sv`, apply the normal rule for function literals (given below) to `e` to derive structural information `Se`, and finally remove `v` from `Δ`. Raise an error if `Se` and `Sv` are not compatible (via `check_compatibility`). + 1. If `e` is a function literal, then recursion is permitted. In this case, `v` must have a structural annotation `Sv`. Derive the structural information for `e` as follows: Set `Γ[v]` to `Sv`, apply the normal rule for function literals (given below) to `e` to derive structural information `Se`, and finally remove `v` from `Γ`. Raise an error if `Se` and `Sv` are not compatible (via `check_compatibility`). 2. Otherwise, derive the structural information of `e` and call it `Se`. - 3. If `v` has a structural annotation `Sv`, then apply `check_compatibility` to `Sv` and `Se`. If they are compatible, then set `Δ[v]` to `Sv` (respecting the user's intent in giving an annotation). Give a warning if `Sv` is more specific than `Se`. If are not compatible, then raise an error. - 4. If `v` does not have a structural annotation, then set `Δ[v]` to `Se`. + 3. If `v` has a structural annotation `Sv`, then apply `check_compatibility` to `Sv` and `Se`. If they are compatible, then set `Γ[v]` to `Sv` (respecting the user's intent in giving an annotation). Give a warning if `Sv` is more specific than `Se`. If are not compatible, then raise an error. + 4. If `v` does not have a structural annotation, then set `Γ[v]` to `Se`. 3. For `MatchCast(v, value, S)`: 1. Derive the structural information of `value` and call it `Sv`. 2. Add any new shape variables in `S` to `Σ`. - 3. If `S <<: Sv` and `Sv <<: S` do not both hold, give a warning, as this indicates a cast that will always fail at run time. - 4. If `v` is given and it has a structural annotation `S'`, then give an error if `S` and `S'` are not compatible via `check_compatibility`. If they are compatible, then set `Δ[v]` to `S'` (respecting the user's intent in giving an annotation). (TODO: It doesn't seem very sensible to have a dynamic cast and give a different annotation, perhaps we should simply not permit doing that.) - 5. If `v` is given and it does not have a structural annotation, then set `Δ[v]` to `S`. -9. For `TupleGetItem(tuple_value, i)`, derive the structural information for `tuple_value` and call it `St`. Raise an error if `St` is not `TupleStructInfo`. If `St` is `TupleStructInfo(fields)`, then raise an error if `fields` value has less than `i + 1` members (this should not happen if type checking has passed) and use `fields[i]` (zero-based) as the structural information for the `TupleGetItem`. + 3. If `S <: Sv` and `Sv <: S` both do not hold, give a warning, as this indicates a cast that will _always_ fail at run time. (Conversely, if `Sv <: S`, then the cast will always succeed.) + 4. If `v` is given and it has a structural annotation `S'`, then give an error if `S <: S'` does not hold. If they are compatible, then set `Γ[v]` to `S'` (respecting the user's intent in giving an annotation). (TODO: It doesn't seem very sensible to have a dynamic cast and give a different annotation, perhaps we should simply not permit doing that.) + 5. If `v` is given and it does not have a structural annotation, then set `Γ[v]` to `S`. +9. For `TupleGetItem(tuple_value, i)`: + 1. Derive the structural information for `tuple_value` and call it `St`. + 2. Raise an error if `St` is not `TupleStructInfo`. + 3. If `St` is `TupleStructInfo(fields)`, then raise an error if `fields` value has less than `i + 1` members. + 4. Use `fields[i]` (zero-based) as the structural information for the `TupleGetItem`. 10. For an `ExternFunc` node, the resulting structural information is `FuncStructInfo(params=None, ret=ObjectStructInfo(), derive_func=default_derive)`. 11. For `Call(op, [arg1, arg2, ..., argn], type_args=[aT1, aT2, ..., aTn])`: 1. For a call to an `Op`, we use the manually defined `FInferStructInfo` macro if it has been defined and `ObjectStructInfo()` if it has not. `FInferStructInfo` is a function that takes in the call node and returns the structural information of the result. 2. Otherwise, derive the structural information for `op` and call it `Sf`. Next derive the structural information for the args and call it `S1`, `S2`, ..., and `Sn`. 1. Give an error if `Sf` is not `FuncStructInfo`. 2. If the `derive_func` field of `Sf` is defined, then apply the `derive_func` macro to the call node to derive the structural information for the call node, ignoring the `ret` field of `Sf`. - 3. Otherwise, `params` must be defined. Give an error if the length of `params` does not match the number of call arguments. + 3. Otherwise, `params` must be defined. Give an error if the length of `params` does not match the number of call arguments. Let the members of params be `P1`, `P2`, ..., `Pn`. 4. Next, attempt to perform [beta-reduction](https://en.wikipedia.org/wiki/Lambda_calculus#%CE%B2-reduction) by matching unbound shape variables in `params` with the `Si`. Namely, get a shape var mapping `m` by applying `get_shape_var_mapping(params[i], Si)` for all `i` and taking the union of all resulting mappings. Replace all variables in `m` with their mapping in `Sf`. - 5. After the substitutions, give an error if `check_compatibility` indicates that the `i`th member of `params` and `Si` are incompatible for some `i` (warn if they are only possibly compatible). - 6. Use `erase_to_well_defined(Sf.ret, Δ, Σ)` as the resulting structural information. + 5. After the substitutions, give an error if `Pi <: Si` does not hold for some `i` (give a warning if it _possibly_ holds). + 6. Use `erase_to_well_defined(Sf.ret, Γ, Σ)` as the resulting structural information. 12. For `Function(params=[v1, v2, ..., vn], body, ret_struct_info)`: 1. Let `S1`, `S2`, ..., `Sn` be the structural information of the parameters. If `vi` has a structural annotation, then use that annotation for `Si`; if not, use `ObjectStructInfo()`. Let `Sr` be `ret_struct_info` if it is defined and `ObjectStructInfo()` if not. - 2. If the function is bound to a `GlobalVar` `gv`, set `Δ[gv]` to `FuncStructInfo(params=[S1, S2, ..., Sn], ret=Sr)`. Still check the structural information in `body`, however. - 3. For each of the `vi`, set `Δ[vi]` to `Si`. Additionally, add all new shape variables introduced in the `Si` to `Σ`. + 2. If the function is bound to a `GlobalVar` `gv`, set `Γ[gv]` to `FuncStructInfo(params=[S1, S2, ..., Sn], ret=Sr)`. Still check the structural information in `body`, however. + 3. For each of the `vi`, set `Γ[vi]` to `Si`. Additionally, add all new shape variables introduced in the `Si` to `Σ`. 4. Derive the structural information for `body`, calling it `Sb`. 5. Give an error if `Sb` is incompatible with `Sr` via `check_compatibility` (warn if only possibly compatible). - 6. If `ret_struct_info` is defined, use `FuncStructInfo(params=[S1, S2, ..., Sn], ret_struct_info)` as the structural information for the function. If `ret_struct_info` is not defined, use `FuncStructInfo(params=[S1, S2, ..., Sn], erase_to_well_defined(Sb, Δ, Σ))`. - 7. Remove all variables added to `Δ` and `Σ` during the derivation. + 6. If `ret_struct_info` is defined, use `FuncStructInfo(params=[S1, S2, ..., Sn], ret_struct_info)` as the structural information for the function. If `ret_struct_info` is not defined, use `FuncStructInfo(params=[S1, S2, ..., Sn], erase_to_well_defined(Sb, Γ, Σ))`. + 7. Remove all variables added to `Γ` and `Σ` during the above steps of the derivation. ### Note on Proving Shapes Equivalent and Eliminating Dynamic Checks @@ -914,7 +638,41 @@ Since most dynamic structure checks are done for safety, it may be feasible to i A further case that may be of interest might be using an explicit wildcard dimension (e.g., using `tir::Any`) to allow for dimensions to be specified as "unknown" in function return shapes. As described at present, the only way for a function to specify a partly unknown return shape is to make the entire return shape unknown (`RuntimeDepShape`), which loses partial shape information. -This addition would entail some, as `FInferStructInfo` and `derive_func` macros would have to deal with potential `tir::Any` nodes. However, the advantage of implementing it would be increasing the amount of shape information present at compile time and hence that could be used by lower levels of the compiler stack. The present defaults of using undefined `shape` fields means that either of these changes could be pursued in the future without breaking existing code, since these would generally have to be paired with explicit `MatchCast` dynamic checks, which will still work even if we add rules to automatically infer the shapes in those cases. +This addition would entail some, as `FInferStructInfo` and `derive_func` macros would have to deal with potential `tir::Any` nodes. However, the advantage of implementing it would be increasing the amount of shape information present at compile time and hence that could be used by lower levels of the compiler stack. The present defaults of using more general `StructInfo` means that either of these changes could be pursued in the future without breaking existing code, since these would generally have to be paired with explicit `MatchCast` dynamic checks, which will still work even if we add rules to automatically infer the shapes in those cases. + +## Traditional Types + +For comparison with Relay, it may be useful to simplify `StructInfo` into more traditional types that do not contain any expressions (such as in `TensorStructInfo` and `ShapeStructInfo`). We can define Relax types as follows: + +``` +Type ::= + DynTensorType(ndim: int, dtype: DataType) + | ShapeType(ndim: int) + | TupleType(fields: [Type]) + | PackedFuncType() + | FuncType(arg_types: [Type], ret_type: Type) + | ObjectType() +``` + +We can "erase" `StructInfo` into types by the following procedure (in psuedocode): +```python +def erase_struct_info(si: StructInfo) -> Type: + if si is TensorStructInfo: + return DynTensorType(ndim=si.ndim, dtype=si.dtype) + if si is ShapeStructInfo: + return ShapeType(ndim=si.ndim) + if si is TupleStructInfo: + return TupleType(fields=[erase_struct_info(field) for field in si.fields]) + if si is FuncStructInfo: + # this should be the case only for packed funcs + if si.params is not specified: + return PackedFuncType() + return FuncType( + arg_types=[erase_struct_info(arg_type) for arg_type in si.params], + ret_type=erase_struct_info(si.ret)) + # only remaining case is ObjectStructInfo + return ObjectType() +``` # Detailed Semantics @@ -922,7 +680,7 @@ This addition would entail some, as `FInferStructInfo` and `derive_func` macros In the `IRModule`, every mapping of a `GlobalVar` to a `Function` node or a TIR `PrimFunc` should be processed first and added to the global scope. «Global functions that have a `global_symbol` attribute should be externally linked, meaning that they can be invoked as program entry points; those that do not have a `global_symbol` attribute can be called only from within the global functions in the `IRModule`.» -The rules for evaluating `Function` nodes into closures are given below. TIR `PrimFunc`s evaluate into objects that are opaque to Relax; these objects have type `Object` and can be used only by the `call_tir` operator. None of the values in global scope is mutable. Execution of a Relax function in an IR module thus begins by evaluating all globally visible functions into a form in which they can be accessed. +The rules for evaluating `Function` nodes into closures are given below. TIR `PrimFunc`s evaluate into objects that are opaque to Relax; these objects are of `ObjectStructInfo` and can be used only by the `call_tir` operator. None of the values in global scope is mutable. Execution of a Relax function in an IR module thus begins by evaluating all globally visible functions into a form in which they can be accessed. ## Evaluating Expressions @@ -931,11 +689,11 @@ For each expression, we define how it affects the program's visible state and th 1. The node `Constant(value)` creates a new tensor whose contents are `value`. 2. A variable (whether `Var`, `DataflowVar` , or `GlobalVar`) evaluates to the stored value for that variable in the current scope. 3. The node `Tuple([e1, e2, ..., en])` evaluates `e1` (yielding value `v1`), then `e2` (yielding value `v2`), …, and finally `en` (yielding value `vn`) in that order and creates a new tuple value containing `v1`, `v2`, …, and `vn` in that order. -4. The node `TupleGetItem(t, i)` is evaluated by first evaluating `t` (which, per type checking, must evaluate to a tuple) and then returning the `i`th field of the result. +4. The node `TupleGetItem(t, i)` is evaluated by first evaluating `t` (which, per `StructInfo` checking, must evaluate to a tuple) and then returning the `i`th field of the result. 5. The node `ShapeExpr([p1, p2, ..., pn])` evaluates the `PrimExpr`s `p1` (yielding dimension value `v1`), `p2` (yielding dimension value `v2`), …, and finally `pn` (yielding dimension value `vn`) in that order, using the current shape context, and creates a new shape value whose dimensions are `v1`, `v2`, …, `vn`, in that order. 6. The node `Function([v1, v2, ..., vn], body)` returns a new closure containing the function definition itself and a mapping of any free Relax variables or shape variables in `body` to the values they hold in the current scope when the `Function` node is encountered. If the function is the RHS of a local binding, the bound variable should also be included in the closure's binding map and should be mapped to the closure itself (to allow for recursive calls). Closure capturing is done *by reference*; no values will be copied and references to captured values will alias their values in the outer scope. `DataflowVar`s are not captured by closures. 7. The node `If(cond, true_branch, false_branch)` is evaluated as follows: - 1. First `cond` is evaluated. Let the result be `r` (per type checking, it must be a bool scalar). + 1. First `cond` is evaluated. Let the result be `r` (per `StructInfo` checking, it must be a bool scalar). 2. If `r` is true, evaluate the `true_branch` and return its result. 3. If `r` is false, evaluate the `false_branch` and return its result. 8. The node `ExternFunc(global_symbol)` is evaluated by looking up the global symbol name and returning the `PackedFunc` if it exists (it is an error if it does not.) Note that if a TIR `PrimFunc` in the `IRModule` has a global symbol attribute registered, it can be called as an `ExternFunc` using that global symbol as well. @@ -971,9 +729,6 @@ These semantic rules assume a single thread of evaluation on a single host machi The above evaluation rules are general, but leave much room for implementations of operators to specify custom semantics. Certain operators are used to perform common operations and will be discussed here as well. -- `call_tir(prim_func, arg1, arg2, ..., argn, shape, type_args=[aT])`: `prim_func` must be a `PrimFunc` object in the current `IRModule` (we will call it `f`). The `shape` argument gives the shapes of the result of calling the TIR `PrimFunc`: It must be either of `ShapeType` (corresponding to returning a single tensor) or `TupleType` whose members are `ShapeType` (corresponding to returning a tuples of tensors). The type arg `aT` gives the type of the result of calling the `PrimFunc` and it must correspond to `shape` (namely, if `shape` is of `ShapeType`, `aT` must be a `DynTensorType`; if `shape` is of `TupleType`, `aT` must be a `TupleType` whose fields are `ShapeType`). `aT` is used especially to provide the `dtype` of returned tensors. - - Based on `shape`, the resulting tensor or tuple `r` will be allocated according to the sizes given in `shape`. `f` will be called in destination-passing style, like so: `f(arg1, ..., argn, *r)`. The asterisk denotes that if `r` is a tuple, it will be "unrolled," so the call will be `f(arg1, ..., argn, r1, ..., rn)`, where the `ri` are the fields of `r`. `f` is expected to mutate *only* `r` to give the output of the function, hence `call_tir` is considered pure. If the shape or data type of the actual result do not correspond to `shape` or `aT`, an error is issued.» After the call, `r` is returned. - -- «`call_dps_packed(global_symbol, arg1, arg2, ..., argn, shape, type_args=[aT])`: Proceeds similarly to `call_tir`, except it calls a `PackedFunc` registered under the name `global_symbol`. The `PackedFunc` may modify `arg1`, `arg2`, …, or `argn` in addition to the result tensor, so purity is not assumed. A type argument `aT` must be given to specify the return type.» +- `call_tir(prim_func, arg1, arg2, ..., argn, sinfo_args=[aS])`: `prim_func` must be a `PrimFunc` object in the current `IRModule` (we will call it `f`). The `StructInfo` arg `aS` gives the `StructInfo` of the result of calling the `PrimFunc`; it must be a `TensorStructInfo` with a `shape` field corresponding to a constant shape expression and a non-`Void` `dtype`, denoting the shape of the resulting tensor, or a a `TupleStringInfo` where all the `fields` are `TensorStructInfo`. Based on `aS`, the resulting tensor or tuple `r` will be allocated according to the sizes given in their `shape` fields. `f` will be called in destination-passing style, like so: `f(arg1, ..., argn, *r)`. The asterisk denotes that if `r` is a tuple, it will be "unrolled," so the call will be `f(arg1, ..., argn, r1, ..., rn)`, where the `ri` are the fields of `r`. `f` is expected to mutate *only* `r` to give the output of the function, hence `call_tir` is considered pure. «If the shape or data type of the actual result do not correspond to `shape` or `aT`, an error is issued.» After the call, `r` is returned. +- «`call_dps_packed(global_symbol, arg1, arg2, ..., argn, sinfo_args=[aS])`: Proceeds similarly to `call_tir`, except it calls a `PackedFunc` registered under the name `global_symbol`. The `PackedFunc` may modify `arg1`, `arg2`, …, or `argn` in addition to the results, so purity is not assumed. `aS` denotes the `StructInfo` for the result.» - `shape_of(t)`: Given a tensor argument `t`, it returns its shape. The return value is a new shape object. From 6272880e0f7a09b8bae3a07b70219ccad16abaaa Mon Sep 17 00:00:00 2001 From: Steven Lyubomirsky Date: Tue, 31 Jan 2023 20:47:11 -0500 Subject: [PATCH 19/30] First draft of PrimValues --- relax_spec.md | 92 ++++++++++++++++++++++++++++++++++----------------- 1 file changed, 62 insertions(+), 30 deletions(-) diff --git a/relax_spec.md b/relax_spec.md index cc911b314a..fc7a2bf33f 100644 --- a/relax_spec.md +++ b/relax_spec.md @@ -58,6 +58,7 @@ DataType ::= Int(bitwidth: int) StructInfo ::= TensorStructInfo(shape: Expr?, dtype: DataType, ndim: int) | ShapeStructInfo(values: [PrimExpr]?, ndim: int) + | PrimStructInfo(dtype: DataType) | ObjectStructInfo() | TupleStructInfo(fields: [StructInfo]) | FuncStructInfo(params: [StructInfo]?, ret: StructInfo, derive_func: EnvFunc?*) @@ -71,6 +72,9 @@ Expr ::= Constant(data: NDArray) | GlobalVar(name_hint: string) | Tuple(fields: [Expr]) | SeqExpr(blocks: [BindingBlock], body: Expr) + | PrimValue(value: PrimExpr) + | StringImm(value: string) + | DataTypeImm(value: DataType) | Function(params: [Var], body: Expr, ret_struct_info: StructInfo?, attrs: Attrs?) | If(cond: Expr, true_branch: Expr, false_branch: Expr) | ExternFunc(global_symbol: string) @@ -106,15 +110,18 @@ This specification provides a more detailed description of what each expression 3. `Var`, `DataflowVar`, and `GlobalVar` nodes are all variables, referring to named stored values of different kinds. Variables in Relax must be bound exactly once. `GlobalVar`s are bound in the `IRModule` itself and refer to Relax functions or TIR `PrimFunc`s. `Var` nodes are bound either within functions, where they represent function parameters, or in `VarBinding` or `MatchCast` nodes in `BindingBlock`s, as we will discuss below. `DataflowVar`s are similar to `Var`s and can be bound only within `DataflowBlock`s. 4. `PrimExpr`s are used to represent dimensions of shapes in `ShapeExpr` and `MatchCast` nodes. These represent operations on integers with their own `Var` nodes (`tir::Var`), which we will refer to as "shape variables". Shape variables can only be used in other `PrimExpr`s and are scoped like `Var` nodes (`relax::Var`), which we will call "Relax variables." 5. `ExternFunc` nodes evaluate into `PackedFunc`s; the implementation will look up the registered `PackedFunc` by its global symbol. -6. `Call` nodes represent function calls. The callee argument (the `op`) can be an `ExternFunc` node (representing a call to a `PackedFunc`), an `Op` node (representing a call to a Relax operator), or an arbitrary expression. +5. `PrimValue` nodes construct immutable scalar values from `PrimExpr`s, primarily for interacting with `ExternFunc`s or operators. These scalars are boxed within TVM objects, allowing them to be nested inside TVM's containers. (By contrast, zero-dimensional tensors defined via `Constant` are mutable.) +6. `StringImm` nodes construct strings, intended primarily for interacting with `ExternFunc`s or operators. +7. `DataTypeImm` nodes construct representations of TIR datatypes, intended primarily for interacting with `ExternFunc`s or operators. +8. `Call` nodes represent function calls. The callee argument (the `op`) can be an `ExternFunc` node (representing a call to a `PackedFunc`), an `Op` node (representing a call to a Relax operator), or an arbitrary expression. 1. `Op` nodes refer to built-in Relax operators, which the compiler is free to implement as is deemed appropriate. Certain operators implement important operations, like `call_tir` (allows for calling TIR `PrimFunc`s). 2. Any other expression must evaluate to a `PackedFunc` or a closure; the result of evaluating `op` will then be called with the given arguments. Calls to `ExternFunc`s and operators may perform side effects, hence it is important to reason about whether a function call is permitted inside a `DataflowBlock`. -7. `If` nodes represent branching control flow. First the condition expression is evaluated, and it must evaluate to a Boolean scalar. If the condition is true, the true branch is evaluated and its result is used; otherwise, the false branch is evaluated and its result is used. -8. `TupleGetItem` nodes represent tuple indexing. The `tuple_value` expression must evaluate to a tuple with at least `index + 1` items and the item with the given index will be returned. -9. `SeqExpr` describes a sequence of binding blocks followed by a return expression. The `SeqExpr` opens a new scope. Its binding blocks are evaluated in order and add new variables to the scope. Binding blocks are either ordinary `BindingBlock`s or `DataflowBlock`s and both consist of a series of bindings. `DataflowBlock`s are the only kind allowed to introduce bindings with `DataflowVar`s and it does not permit any constructs featuring control flow (`If` nodes or recursive calls) or calls to (possibly) impure functions. There are two different kinds of bindings: +9. `If` nodes represent branching control flow. First the condition expression is evaluated, and it must evaluate to a Boolean scalar. If the condition is true, the true branch is evaluated and its result is used; otherwise, the false branch is evaluated and its result is used. +10. `TupleGetItem` nodes represent tuple indexing. The `tuple_value` expression must evaluate to a tuple with at least `index + 1` items and the item with the given index will be returned. +11. `SeqExpr` describes a sequence of binding blocks followed by a return expression. The `SeqExpr` opens a new scope. Its binding blocks are evaluated in order and add new variables to the scope. Binding blocks are either ordinary `BindingBlock`s or `DataflowBlock`s and both consist of a series of bindings. `DataflowBlock`s are the only kind allowed to introduce bindings with `DataflowVar`s and it does not permit any constructs featuring control flow (`If` nodes or recursive calls) or calls to (possibly) impure functions. There are two different kinds of bindings: 1. `VarBinding`s: The `value` expression (the right-hand side of the binding) of the binding is evaluated first and is bound to the `var` expression, which must be a new `Var` or `DataflowVar` (in a dataflow block). The newly bound variable will have that value for the remainder of the scope (`DataflowVar`s are scoped only to the `DataflowBlock` in which they appear; `Var`s are scoped to the entire `SeqExpr`). 2. `MatchCast`s: The `value` expression is evaluated and the result is dynamically checked against the structural information given in the `struct_info` field. 1. The types must match: All `StructInfo` variants correspond to a category of value value (`TensorStructInfo` to a tensor value, `ShapeStructInfo` to shape values, etc.), so if the structure of `value` does not correspond to `struct_info`, an error is triggered. The structure of `value` is compared recursively with `struct_info`, so all components of `value` must match up with any nested structural information. Special comparison rules: @@ -126,8 +133,8 @@ This specification provides a more detailed description of what each expression The `SeqExpr`'s `body` expression is allowed to reference any `Var`s introduced within the `SeqExpr`'s binding blocks in addition to those that were in the outer scope; the `body` expression is evaluated after the binding blocks and its value is what is returned. Any Relax variables and shape variables introduced in the `SeqExpr` are removed from scope after the expression finishes evaluating. -10. `ShapeExpr` nodes construct shape literals. The `PrimExpr`s within it describe how to compute each dimension; they are free to use any shape variables that are in scope. -11. `Function` nodes represent function definitions, taking in the listed parameters and evaluating the body expression in a new scope (meaning any variables defined from within the function cannot be referenced outside it). Function definitions may be nested in any other expression and they evaluate into closure values, ensuring that functions are first-class. Closures capture any variables from the outer scope that are used in their body, both Relax variables and shape variables. Note that function definitions themselves are anonymous—a function must be registered in the `IRModule` (bound to a `GlobalVar`) or appear on the right-hand side of a binding to have a name in order to be called recursively. +12. `ShapeExpr` nodes construct shape literals. The `PrimExpr`s within it describe how to compute each dimension; they are free to use any shape variables that are in scope. +13. `Function` nodes represent function definitions, taking in the listed parameters and evaluating the body expression in a new scope (meaning any variables defined from within the function cannot be referenced outside it). Function definitions may be nested in any other expression and they evaluate into closure values, ensuring that functions are first-class. Closures capture any variables from the outer scope that are used in their body, both Relax variables and shape variables. Note that function definitions themselves are anonymous—a function must be registered in the `IRModule` (bound to a `GlobalVar`) or appear on the right-hand side of a binding to have a name in order to be called recursively. The function can have structural annotations on the parameters and a structural annotation for the return value. When the function is called, the annotations on parameters are checked against the argument values in similar fashion to `MatchCast` and can introduce new shape variables that are scoped to the function. Additionally, the structural information of the return value is checked against the annotation before the call returns. @@ -155,9 +162,10 @@ Exiting with an error and infinitely looping are traditionally considered "[dive Analogously to a type system in most languages, Relax tracks structural information (referred to as `StructInfo` in the implementation) related to the categories of values in Relax: 1. `TensorStructInfo` corresponds to tensor values, giving the scalar data type, the number of dimensions (rank), and an expression that computes the tensor's shape (either a `ShapeExpr` or a `Var`), all of which are optional. 2. `TupleStructInfo` corresponds to tuple values, giving the `StructInfo` for each member of the tuple. -3. `ShapeStructInfo` corresponds to shape values, optionally giving the number of dimensions in the shape and an expression that computes the shape's dimensions (either a `ShapeExpr` or a `Var`). -4. `FunctionStructInfo` corresponds to function values (closures) and `PackedFunc`s (external functions), giving the types of the parameters, the return type, «and whether the function is pure.» -5. `ObjectStructInfo` is a parent to all Relax `StructInfo` and corresponds to all the values above as well as any values returned by `PackedFunc` calls that do not fit in the above categories. +3. `PrimStructInfo` corresponds to `PrimValue`s (immutable scalar values), giving their TIR datatype. +4. `ShapeStructInfo` corresponds to shape values, optionally giving the number of dimensions in the shape and an expression that computes the shape's dimensions (either a `ShapeExpr` or a `Var`). +5. `FunctionStructInfo` corresponds to function values (closures) and `PackedFunc`s (external functions), giving the types of the parameters, the return type, «and whether the function is pure.» +6. `ObjectStructInfo` is a parent to all Relax `StructInfo` and corresponds to all the values above as well as any values returned by `PackedFunc` calls that do not fit in the above categories. `StructInfo` is assigned to every variable in scope and every type of expression based on the values it returns via a set of inference rules defined later in the specification, making use of subtyping to assign more general `StructInfo` when a more specific one cannot be determined. «Relax is strongly typed, meaning that if the `StructInfo` inferred is less specific than the one expected, an error will be issued and an explicit check via `MatchCast` will be required.» @@ -180,7 +188,8 @@ Here are the classes of values that Relax operates over, meaning that they can b - *Closures* are the values resulting from evaluating Relax function expressions; closures can be passed around like other values, ensuring that functions are first-class in Relax. Functions defined in Relax can capture variables from outer scopes. A [closure](https://en.wikipedia.org/wiki/Closure_(computer_programming)) consists of a function and a mapping of any variables "captured" (those are *free variables* in the function body, variables from an outer scope that are neither arguments nor defined within the function but are used in the function) to their values. Closures capture both Relax-level local variables and shape variables from outer scopes. A closure also stores a name for itself when the body contains recursive calls. «Closures additionally carry some *run-time structural information* (RTSI) indicating their argument and result structures, in order to facilitate dynamic structural checks (since it is not otherwise possible to introspect the function contained within a closure); the precise form of the RTSI is left up to the compiler implementation to determine so long as `MatchCast` can verify the structure of a closure. Closures can be evaluated in a call node, which results in calling the function with the call's arguments and the captured values.» - *Tensor shapes* (shape values) are tuples of integers describing a tensor shape, obtained by evaluating `ShapeExpr`s. - *Packed functions* (`PackedFunc`s or external functions) represent arbitrary opaque functions implemented in TVM. That is, packed functions are routines that are defined outside of Relax and cannot be inspected by the compiler. They can perform side effects and return arbitrary values. -- Additionally, there are further *arbitrary objects* that do not belong in the above categories. These can be returned by `PackedFunc`s and operators; additionally, we treat TIR `PrimFunc`s as opaque objects. Though Relax expressions other than `PackedFunc` and operator calls cannot use those objects, Relax should pass around these values faithfully. In the future we may add more value types in order to distinguish between different objects, but at present we treat these all as arbitrary values with `ObjectStructInfo`. +- *Primitive values* (`PrimValue`s) represent immutable scalar values that are primarily intended for being passed to external procedures, like calls to `PackedFunc`s. As a rule of thumb, scalar values intended for arithmetical computations should be 0-rank tensors while scalar values meant to serve as metadata should be `PrimValue`s. +- Additionally, there are further *arbitrary objects* that do not belong in the above categories. These can be returned by `PackedFunc`s and operators; additionally, we treat TIR `PrimFunc`s as opaque objects. Though Relax expressions other than `PackedFunc` and operator calls cannot use those objects, Relax should pass around these values faithfully. In the future we may add more value types in order to distinguish between different objects, but at present we treat these all as arbitrary values with `ObjectStructInfo`. Note that, for now, strings and TIR datatypes are also treated as opaque objects. ## Representation of Values at Run Time @@ -192,6 +201,7 @@ Possible specification in terms of the TVM object system: - Tuples are represented using TVM ADTs (algebraic data types), which are arrays of TVM objects with a tag (see `include/tvm/runtime/container/adt.h`). Tuples use a tag of 0. - At run time, closures are represented as a `ClosureObj` (see `include/tvm/runtime/container/closure.h`); in the Relax VM these more specifically use the `VMClosureObj` (see [`https://github.com/tlc-pack/relax/blob/relax/include/tvm/runtime/relax_vm/executable.h`](https://github.com/tlc-pack/relax/blob/relax/include/tvm/runtime/relax_vm/executable.h)). - Shape values are represented at run time as a `ShapeTuple` (see `include/tvm/runtime/container/shape_tuple.h`). +- Strings are represented using TVM's `String` container (see `include/tvm/runtime/container/string.h`). - We require objects other than the above values used by and returned by `PackedFunc` to inherit from TVM's `Object` class (defined in `include/tvm/runtime/Object.h`). Note that `PackedFunc`s are capable of using and returning all TVM POD (plain-old data) values (see `include/tvm/runtimes/packed_func.h`), which includes some representations that do not inherit from `Object`. In the future, we may define semantics for other values, but at present, these are *unsupported* in Relax and we make no guarantees about the semantics of calling `PackedFunc`s that use or return anything that does not inherit from `Object`. # Variable Scoping @@ -284,6 +294,7 @@ Tensor shapes are the primary motivation for including structural information in The structural information in Relax corresponds to the values in the language: * `TensorStructInfo` describes tensor values. The `dtype` field gives the datatype (with `Void` indicating a statically unknown datatype), the `ndim` field gives the rank (with -1 indicating a statically unknown rank). Unlike `DynTensorType`, there is an optional `shape` field which, if defined, describes the shape of the tensor using either a `ShapeExpr` or a `Var` (with `ShapeStructInfo`). If `shape` is a `ShapeExpr`, the `PrimExpr`s in the `ShapeExpr`'s dimensions describe how to compute each dimension of the shape (or are constants). If `shape` is a `Var`, the `Var` can assign the result of an arbitrary computation that returns a shape value, which can be useful for memory planning. * `ShapeStructInfo` describes shape values. It has an `ndim` field that gives the number of dimensions in the shape (with -1 indicating that it is statically unknown). It additionally has an optional `values` field. If defined, `values` gives a list of `PrimExpr`s that indicate how to compute the dimensions of the shape, potentially providing further information for static analyses. +* `PrimStructInfo` describes `PrimValue`s, giving their TIR datatype. * `TupleStructInfo` describes tuple values, namely by giving the `StructInfo` for each of the tuple's members via `fields`. * `FuncStructInfo` describes closure values or `PackedFunc`s. There are two ways in which to specify `FuncStructInfo`: 1. By specifying `params` and `ret` (for closures). `params` gives the `StructInfo` corresponding to each of the function's parameters and `ret` gives the `StructInfo` corresponding to the result of calling the function. In this case, the `derive_func` field is left undefined. @@ -317,11 +328,12 @@ This section describes the run-time checking performed by `MatchCast(var, value, 2. If `shape` is a `ShapeExpr`, then compare the fields of the `ShapeExpr` to the concrete shape of `value`, dimension by dimension (comparing the `i`th field of the `ShapeExpr` to the `i`th dimension of the shape of `value`). Give an error if the number of the dimensions does not match the number of fields in the `ShapeExpr`. 1. If a field of the `ShapeExpr` consists of only an unbound shape variable, then bind that variable to the value of the dimension. 2. Otherwise, evaluate the field of the `ShapeExpr` and ensure that it matches the concrete value of the dimension. -3. If `struct_info` is `ShapeStructInfo(ndim, values)`, then check that `value` is a shape value, that it has `ndim` dimensions (if `ndim` is not -1). If `values` is defined, then compare it to the concrete shape value (comparing the `i`th member of `values` to the `i`th field of the shape value): +3. If `struct_info` is `PrimStructInfo(dtype)`, then check that `value` is a `PrimValue` and that the underlying scalar has datatype `dtype` in TIR (according to TIR's type-checking rules). +4. If `struct_info` is `ShapeStructInfo(ndim, values)`, then check that `value` is a shape value, that it has `ndim` dimensions (if `ndim` is not -1). If `values` is defined, then compare it to the concrete shape value (comparing the `i`th member of `values` to the `i`th field of the shape value): 1. If the `i`th member of `values` consists of only an unbound shape variable, then bind that variable to the `i`th field of the the concrete shape value. 2. Otherwise, evaluate the `i`th member of `values` and check that it is equal to teh `i`th field of the concrete shape value. -4. If `struct_info` is `TupleStructInfo(fields)`, then check that `value` is a tuple value with `n` fields, where `n` is the length of `fields`. Also recursively check the `i`th field of the tuple value against the `i`th member of `fields`. -5. If `struct_info` is `FuncStructInfo(params, ret, derive_func)`, then if `params` is defined, check that `value` is a closure value; if `derive_func` is defined, check that `value` is a `PackedFunc`. No further validation may be done on a `PackedFunc`. «If `value` is a closure value, then it can contain run-time structural information indicating the structural information of its intended arguments and return value that can be compared against `params` and `ret`.» +5. If `struct_info` is `TupleStructInfo(fields)`, then check that `value` is a tuple value with `n` fields, where `n` is the length of `fields`. Also recursively check the `i`th field of the tuple value against the `i`th member of `fields`. +6. If `struct_info` is `FuncStructInfo(params, ret, derive_func)`, then if `params` is defined, check that `value` is a closure value; if `derive_func` is defined, check that `value` is a `PackedFunc`. No further validation may be done on a `PackedFunc`. «If `value` is a closure value, then it can contain run-time structural information indicating the structural information of its intended arguments and return value that can be compared against `params` and `ret`.» ### Checking Structural Information at the Start and End of a Function @@ -368,7 +380,8 @@ Note that judging subtyping requires potentially reasoning about arbitrary `Shap 2. Given an arbitrary `ndim` `n` and an arbitrary set of values `v` (not undefined), `ShapeStructInfo(ndim=n, values=v) <: ShapeStructInfo(ndim=n, values=undefined)`. 3. Given an arbitrary `ndim` `n` and two arbitrary sets of values `v1` and `v2` (both defined), `ShapeStructInfo(ndim=n, values=v1) <: ShapeStructInfo(ndim=n, values=v2)` if, for all valid `i`, `v1[i]` and `v2[i]` can be proven to be _definitely_ statically equal. We say that `ShapeStructInfo(ndim=n, values=v1) <: ShapeStructInfo(ndim=n, values=v2)` _possibly_ holds if `v1` and `v2` are _possibly_ statically equal. 6. Given two lists of `StructInfo` `fields1` and `fields2`, `TupleStructInfo(fields=fields1) <: TupleStructInfo(fields=fields2)` if `fields1` and `fields2` are the same length and for all `i`, `fields1[i] <: fields2[i]`. We consider the subtyping relationship to _possibly_ hold if any of the subtyping relationships for the fields only possibly holds. -7. For `FuncStructInfo`: +7. For `PrimStructInfo`, `PrimStructInfo(dt1) <: PrimStructInfo(dt2)` holds if `dt1` and `dt2` are the same. That is, we do not have subtyping for TIR datatypes or `PrimStructInfo`. +8. For `FuncStructInfo`: 1. Given an arbitrary derivation function `derive_func`, `FuncStructInfo(ret=ObjectStructInfo(), derive_func=derive_func) <: FuncStructInfo(ret=ObjectStructInfo(), derive_func=empty_derive)`. 2. Corollary, following from reflexivity: For two `FuncStructInfo` `F1` and `F2` with undefined `params`, `F1 <: F2` only if `F1.derive_func` and `F2.derive_func` are identical. 3. Given two lists of `StructInfo` parameters `P1` and `P2` and two `StructInfo` annotations `R1` and `R2`, `FuncStructInfo(params=P1, ret=R1) <: FuncStructInfo(params=P2, ret=R2)` if `P1` and `P2` are the same length and for all `i`, `P2[i] <: P1[i]` and `R1 <: R2`. We consider the subtyping relationship to _possibly_ hold if any of the subtyping relationships given only possibly holds. @@ -385,6 +398,10 @@ def unify_struct_info(S1: StructInfo, S2: StructInfo) -> StructInfo: return S2 if S1 and S2 do not match types (e.g., not both TensorStructInfo, etc): return ObjectStructInfo() + if S1 and S2 are both PrimStructInfo: + if S1.dtype == S2.dtype: + return S1 + return ObjectStructInfo() if S1 and S2 are both ShapeStructInfo: if S1.ndim == -1: return S1 @@ -478,6 +495,8 @@ def erase_to_well_defined( if s is ObjectStructInfo: return s + if s is PrimStructInfo: + return s if s is TensorStructInfo: if s.shape is defined: if (s.shape is a Relax var that is not in var_scope @@ -530,6 +549,8 @@ For clarity, additional detail on how the mapping should be constructed is given def get_shape_var_mapping(S1: StructInfo, S2: StructInfo) -> {tir::Var, PrimExpr}: if S1 and S2 are not the same type: return {} + if S1 and S2 are both PrimStructInfo: + return {} if S1 and S2 are both TupleStructInfo: if S1.fields and S2.fields don't have the same length: return {} @@ -577,17 +598,20 @@ Let `Γ` be the `StructInfo` context for Relax variables and let `Σ` track whic 1. «Prepopulate `Γ` with the annotated types of all global functions (see the rule for `Function` nodes) that are called mutually recursively. Afterwards check the structural information of the global functions one at a time and populate the entry of `Γ` corresponding to that `GlobalVar`.» 2. For a variable (`Var`, `DataflowVar`, or `GlobalVar`) `v`, look up `Γ[v]` for the structural information. 3. For `Constant(value)`, the resulting structural information is `TensorStructInfo(ndim, dtype, shape)` where `ndim` is the concrete rank of `value`, `dtype` is the concrete datatype used in `value`, and `shape` is a `ShapeExpr` giving the concrete shape of `value. For example, for `Constant(1)`, `shape` is `ShapeExpr([])` and for `Constant([1, 2])`, `shape` is `ShapeExpr([IntImm(2, "int64")])`. -4. For `Tuple(fields)`, suppose that `fields` is comprised of expressions `E1`, `E2`, ..., `En`. Let the `StructInfo` for these expressions be `S1`, `S2`, ..., `Sn`, respectively. Then the resulting `StructInfo` is `TupleStructInfo(fields=[S1, S2, ..., Sn])`. -5. For `ShapeExpr(values)`, the resulting structural information is `ShapeStructInfo(ndim, values)`, where `ndim` is the length of `values`. -6. For `If(cond, true_branch, false_branch)`, we compare the structural information of `true_branch` and `false_branch` (call these `S_t` and `S_f`, respectively). The resulting structural information is `unify_struct_info(S_t, S_f)`. -7. For `SeqExpr(blocks, body)`: +4. For `PrimValue(prim_expr)`, the resulting `StructInfo` is `PrimStructInfo(dt)`, where `dt` is the datatype of `prim_expr`, derived according to the type-checking rules for TIR. +5. For `StringImm(s)`, the resulting `StructInfo` is `ObjectStructInfo()`. +6. For `DataTypeImm(dt)`, the resulting `StructInfo` is `ObjectStructInfo()`. +7. For `Tuple(fields)`, suppose that `fields` is comprised of expressions `E1`, `E2`, ..., `En`. Let the `StructInfo` for these expressions be `S1`, `S2`, ..., `Sn`, respectively. Then the resulting `StructInfo` is `TupleStructInfo(fields=[S1, S2, ..., Sn])`. +8. For `ShapeExpr(values)`, the resulting structural information is `ShapeStructInfo(ndim, values)`, where `ndim` is the length of `values`. +9. For `If(cond, true_branch, false_branch)`, we compare the structural information of `true_branch` and `false_branch` (call these `S_t` and `S_f`, respectively). The resulting structural information is `unify_struct_info(S_t, S_f)`. +10. For `SeqExpr(blocks, body)`: 1. For each binding block in `blocks` (call the current one `block`): 1. Process each binding in the block, updating `Γ` and `Σ` accordingly (this is discussed in detail below). 2. If `block` is a `DataflowBlock`, then remove all `DataflowVar`s introduced in `block` from `Γ` before proceeding to the next block. 2. Next derive the structural information for `body`. Let us call this `S`. 3. Remove all Relax variables introduced in `blocks` from `Γ` and all shape variables introduced in `blocks` from `Σ`. 4. The structural information of the entire `SeqExpr` is `erase_to_well_defined(S, Γ, Σ)`. -8. For handling variable bindings: +11. For handling variable bindings: 1. If `v` is the argument to a function, then if `v` has a structural annotation `S`, set `Γ[v]` to `S`. Add any unbound shape variables in `S` to `Σ`. If `v` does not have a structural annotation, set `Γ[v]` to `ObjectStructInfo()`. 2. In the general `VarBinding(v, e)`: 1. If `e` is a function literal, then recursion is permitted. In this case, `v` must have a structural annotation `Sv`. Derive the structural information for `e` as follows: Set `Γ[v]` to `Sv`, apply the normal rule for function literals (given below) to `e` to derive structural information `Se`, and finally remove `v` from `Γ`. Raise an error if `Se` and `Sv` are not compatible (via `check_compatibility`). @@ -600,13 +624,13 @@ Let `Γ` be the `StructInfo` context for Relax variables and let `Σ` track whic 3. If `S <: Sv` and `Sv <: S` both do not hold, give a warning, as this indicates a cast that will _always_ fail at run time. (Conversely, if `Sv <: S`, then the cast will always succeed.) 4. If `v` is given and it has a structural annotation `S'`, then give an error if `S <: S'` does not hold. If they are compatible, then set `Γ[v]` to `S'` (respecting the user's intent in giving an annotation). (TODO: It doesn't seem very sensible to have a dynamic cast and give a different annotation, perhaps we should simply not permit doing that.) 5. If `v` is given and it does not have a structural annotation, then set `Γ[v]` to `S`. -9. For `TupleGetItem(tuple_value, i)`: +12. For `TupleGetItem(tuple_value, i)`: 1. Derive the structural information for `tuple_value` and call it `St`. 2. Raise an error if `St` is not `TupleStructInfo`. 3. If `St` is `TupleStructInfo(fields)`, then raise an error if `fields` value has less than `i + 1` members. 4. Use `fields[i]` (zero-based) as the structural information for the `TupleGetItem`. -10. For an `ExternFunc` node, the resulting structural information is `FuncStructInfo(params=None, ret=ObjectStructInfo(), derive_func=default_derive)`. -11. For `Call(op, [arg1, arg2, ..., argn], type_args=[aT1, aT2, ..., aTn])`: +13. For an `ExternFunc` node, the resulting structural information is `FuncStructInfo(params=None, ret=ObjectStructInfo(), derive_func=default_derive)`. +14. For `Call(op, [arg1, arg2, ..., argn], type_args=[aT1, aT2, ..., aTn])`: 1. For a call to an `Op`, we use the manually defined `FInferStructInfo` macro if it has been defined and `ObjectStructInfo()` if it has not. `FInferStructInfo` is a function that takes in the call node and returns the structural information of the result. 2. Otherwise, derive the structural information for `op` and call it `Sf`. Next derive the structural information for the args and call it `S1`, `S2`, ..., and `Sn`. 1. Give an error if `Sf` is not `FuncStructInfo`. @@ -615,7 +639,7 @@ Let `Γ` be the `StructInfo` context for Relax variables and let `Σ` track whic 4. Next, attempt to perform [beta-reduction](https://en.wikipedia.org/wiki/Lambda_calculus#%CE%B2-reduction) by matching unbound shape variables in `params` with the `Si`. Namely, get a shape var mapping `m` by applying `get_shape_var_mapping(params[i], Si)` for all `i` and taking the union of all resulting mappings. Replace all variables in `m` with their mapping in `Sf`. 5. After the substitutions, give an error if `Pi <: Si` does not hold for some `i` (give a warning if it _possibly_ holds). 6. Use `erase_to_well_defined(Sf.ret, Γ, Σ)` as the resulting structural information. -12. For `Function(params=[v1, v2, ..., vn], body, ret_struct_info)`: +15. For `Function(params=[v1, v2, ..., vn], body, ret_struct_info)`: 1. Let `S1`, `S2`, ..., `Sn` be the structural information of the parameters. If `vi` has a structural annotation, then use that annotation for `Si`; if not, use `ObjectStructInfo()`. Let `Sr` be `ret_struct_info` if it is defined and `ObjectStructInfo()` if not. 2. If the function is bound to a `GlobalVar` `gv`, set `Γ[gv]` to `FuncStructInfo(params=[S1, S2, ..., Sn], ret=Sr)`. Still check the structural information in `body`, however. 3. For each of the `vi`, set `Γ[vi]` to `Si`. Additionally, add all new shape variables introduced in the `Si` to `Σ`. @@ -648,6 +672,7 @@ For comparison with Relay, it may be useful to simplify `StructInfo` into more t Type ::= DynTensorType(ndim: int, dtype: DataType) | ShapeType(ndim: int) + | PrimType(dtype: DataType) | TupleType(fields: [Type]) | PackedFuncType() | FuncType(arg_types: [Type], ret_type: Type) @@ -661,6 +686,8 @@ def erase_struct_info(si: StructInfo) -> Type: return DynTensorType(ndim=si.ndim, dtype=si.dtype) if si is ShapeStructInfo: return ShapeType(ndim=si.ndim) + if si is PrimStructInfo: + return PrimType(dtype=si.dtype) if si is TupleStructInfo: return TupleType(fields=[erase_struct_info(field) for field in si.fields]) if si is FuncStructInfo: @@ -689,20 +716,23 @@ For each expression, we define how it affects the program's visible state and th 1. The node `Constant(value)` creates a new tensor whose contents are `value`. 2. A variable (whether `Var`, `DataflowVar` , or `GlobalVar`) evaluates to the stored value for that variable in the current scope. 3. The node `Tuple([e1, e2, ..., en])` evaluates `e1` (yielding value `v1`), then `e2` (yielding value `v2`), …, and finally `en` (yielding value `vn`) in that order and creates a new tuple value containing `v1`, `v2`, …, and `vn` in that order. -4. The node `TupleGetItem(t, i)` is evaluated by first evaluating `t` (which, per `StructInfo` checking, must evaluate to a tuple) and then returning the `i`th field of the result. -5. The node `ShapeExpr([p1, p2, ..., pn])` evaluates the `PrimExpr`s `p1` (yielding dimension value `v1`), `p2` (yielding dimension value `v2`), …, and finally `pn` (yielding dimension value `vn`) in that order, using the current shape context, and creates a new shape value whose dimensions are `v1`, `v2`, …, `vn`, in that order. -6. The node `Function([v1, v2, ..., vn], body)` returns a new closure containing the function definition itself and a mapping of any free Relax variables or shape variables in `body` to the values they hold in the current scope when the `Function` node is encountered. If the function is the RHS of a local binding, the bound variable should also be included in the closure's binding map and should be mapped to the closure itself (to allow for recursive calls). Closure capturing is done *by reference*; no values will be copied and references to captured values will alias their values in the outer scope. `DataflowVar`s are not captured by closures. -7. The node `If(cond, true_branch, false_branch)` is evaluated as follows: +4. The node `PrimType(prim_expr)` evaluates the `PrimExpr` `prim_expr` first, obtaining a resulting `pv`. It then creates an immutable `PrimValue` containing `pv`. +5. The node `StringImm(s)` creates an immutable string container whose contents is `s`. It does not necessarily have to be a _new_ string container if, for example, string interning is implemented. +6. The node `DataTypeImm(dt)` creates a new immutable datatype representation. +7. The node `TupleGetItem(t, i)` is evaluated by first evaluating `t` (which, per `StructInfo` checking, must evaluate to a tuple) and then returning the `i`th field of the result. +8. The node `ShapeExpr([p1, p2, ..., pn])` evaluates the `PrimExpr`s `p1` (yielding dimension value `v1`), `p2` (yielding dimension value `v2`), …, and finally `pn` (yielding dimension value `vn`) in that order, using the current shape context, and creates a new shape value whose dimensions are `v1`, `v2`, …, `vn`, in that order. +9. The node `Function([v1, v2, ..., vn], body)` returns a new closure containing the function definition itself and a mapping of any free Relax variables or shape variables in `body` to the values they hold in the current scope when the `Function` node is encountered. If the function is the RHS of a local binding, the bound variable should also be included in the closure's binding map and should be mapped to the closure itself (to allow for recursive calls). Closure capturing is done *by reference*; no values will be copied and references to captured values will alias their values in the outer scope. `DataflowVar`s are not captured by closures. +10. The node `If(cond, true_branch, false_branch)` is evaluated as follows: 1. First `cond` is evaluated. Let the result be `r` (per `StructInfo` checking, it must be a bool scalar). 2. If `r` is true, evaluate the `true_branch` and return its result. 3. If `r` is false, evaluate the `false_branch` and return its result. -8. The node `ExternFunc(global_symbol)` is evaluated by looking up the global symbol name and returning the `PackedFunc` if it exists (it is an error if it does not.) Note that if a TIR `PrimFunc` in the `IRModule` has a global symbol attribute registered, it can be called as an `ExternFunc` using that global symbol as well. -9. The node `Call(op, [arg1, arg2, ..., argn])` is evaluated as follows: +11. The node `ExternFunc(global_symbol)` is evaluated by looking up the global symbol name and returning the `PackedFunc` if it exists (it is an error if it does not.) Note that if a TIR `PrimFunc` in the `IRModule` has a global symbol attribute registered, it can be called as an `ExternFunc` using that global symbol as well. +12. The node `Call(op, [arg1, arg2, ..., argn])` is evaluated as follows: 1. If `op` is an `Op` node, then evaluate `arg1`, `arg2`, …, `argn` in that order and call the results `a1`, `a2`, …, `an`. It is up to the compiler implementation to decide how operators should be implemented (some may have an associated `PackedFunc` and others may be built into the executor implementation). The operator may mutate its arguments. It is also up to the operator implementation as to whether the result is newly allocated or aliases another value. «(TODO: Once we have operators for logical and AND and OR, we should also define short-circuiting semantics for those.)» 2. Otherwise, first evaluate `op` (it must evaluate to a closure or `PackedFunc`). Next, we evaluate `arg1`, `arg2`, …, `argn` in that order and call the results `a1`, `a2`, …, `an`. 1. If `op` evaluated to a closure, push a new scope onto the stack where arguments `v1`, `v2`, …, `vn` in the closure are bound to `a1`, `a2`, …, and `an`, respectively, and all variables saved in the closure are added to the scope. Evaluate the closure body in this new scope; this will be the return value of the call. Pop the scope before returning the value. (Note that the checking of the structural information of the argument result values and the body values should be done as described in the previous section.) 2. If `op` evaluated to a `PackedFunc`, simply invoke it. `PackedFunc`s may have arbitrary side effect and are responsible for whether the result is a newly allocated value or aliases another value. -9. For the node `SeqExpr(blocks, body)`, we evaluate as follows: +13. For the node `SeqExpr(blocks, body)`, we evaluate as follows: 1. Push a new scope onto the stack. 2. Iterate through the `BindingBlock`s in `blocks` in order. We will call the current one `block`. For each binding in `Block`: 1. If the binding is `MatchCast(var, value, struct_info)`, perform the structure matching and shape variable updates as described in the structural information section. If `var` is provided, `var` will be bound to `value` in the current scope; this assignment is aliasing and no new value is allocated. If `var` is not provided, then the structural check is performed and shape variables are updated, but no new binding is introduced. @@ -717,6 +747,8 @@ Optimizations are allowed to reorder and modify the operations of a program in a «Within `DataflowBlock`s, it is permitted for the compiler to remove or reorder `MatchCast` operations even though this can affect the "visible behavior" of the program (since they can exit with an error). It is also permitted for the compiler to optimize away potential non-termination within `DataflowBlock`s: For example, if some pure function `f` has an integer return type and does not terminate, it is permissible to optimize `f() - f()` to 0 within a `DataflowBlock`. In general, the compiler is permitted to make programs "more defined" (terminating when the original did not terminate, not raising an error when the original raised an error) within a `DataflowBlock`, but never "less defined" (giving an error when the original did not give an error, not terminating when the original did not terminate). Outside of `DataflowBlock`s, error messages and potential non-termination must be preserved faithfully.» +For immutable containers like those for the results of `PrimValue`, `StringImm`, and `DataTypeImm`, it is not required for the results of evaluating these expressions to be _new_ containers—it is permitted for the compiler to reuse existing objects provided that the values contained within are identical. This optimization is called [interning](https://en.wikipedia.org/wiki/String_interning). However, for operations that return new mutable values (in particular, operations that return tensor values), those _must_ be newly allocated, since reusing values can affect the behavior under aliasing. + The specification makes no guarantees about certain memory-related properties and hence also does not consider them to be "visible behaviors": - Whether an allocation happens at a given point. Compiler implementations are permitted to reuse already-allocated memory if it would not interfere with visible state in any other way, per the aliasing rules (`PackedFunc`s or operators may mutate values that are passed to them and those mutations should be visible as per aliasing in this specification). Copying values or sharing representations (e.g., interning constants) between values may be done only if they will not affect any other visible behaviors, dependent on the aliasing behavior. From b88ce9edc0a0dd774af410e31289545a2d2f3fdb Mon Sep 17 00:00:00 2001 From: Steven Lyubomirsky Date: Wed, 1 Feb 2023 17:19:02 -0500 Subject: [PATCH 20/30] Example of what DataTypeImm is used for --- relax_spec.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/relax_spec.md b/relax_spec.md index fc7a2bf33f..d502302be2 100644 --- a/relax_spec.md +++ b/relax_spec.md @@ -112,7 +112,7 @@ This specification provides a more detailed description of what each expression 5. `ExternFunc` nodes evaluate into `PackedFunc`s; the implementation will look up the registered `PackedFunc` by its global symbol. 5. `PrimValue` nodes construct immutable scalar values from `PrimExpr`s, primarily for interacting with `ExternFunc`s or operators. These scalars are boxed within TVM objects, allowing them to be nested inside TVM's containers. (By contrast, zero-dimensional tensors defined via `Constant` are mutable.) 6. `StringImm` nodes construct strings, intended primarily for interacting with `ExternFunc`s or operators. -7. `DataTypeImm` nodes construct representations of TIR datatypes, intended primarily for interacting with `ExternFunc`s or operators. +7. `DataTypeImm` nodes construct representations of TIR datatypes, intended primarily for interacting with `ExternFunc`s or operators (e.g., for TIR intrinsics that take a datatype as an input). 8. `Call` nodes represent function calls. The callee argument (the `op`) can be an `ExternFunc` node (representing a call to a `PackedFunc`), an `Op` node (representing a call to a Relax operator), or an arbitrary expression. 1. `Op` nodes refer to built-in Relax operators, which the compiler is free to implement as is deemed appropriate. Certain operators implement important operations, like `call_tir` (allows for calling TIR `PrimFunc`s). 2. Any other expression must evaluate to a `PackedFunc` or a closure; the result of evaluating `op` will then be called with the given arguments. From 1d219e50fe5697703274fc9f4261a735c249148a Mon Sep 17 00:00:00 2001 From: Steven Lyubomirsky Date: Wed, 1 Feb 2023 17:22:08 -0500 Subject: [PATCH 21/30] Restrictions on PrimValue and PrimStructInfo --- relax_spec.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/relax_spec.md b/relax_spec.md index d502302be2..8b02492e61 100644 --- a/relax_spec.md +++ b/relax_spec.md @@ -280,6 +280,8 @@ The following criteria apply to all programs (including before normalization): 13. If the `shape` field of a `TensorStructInfo` in any structural annotation is given, the only permissible expressions are `Var` (the variable must be in scope at the location of the annotation) or `ShapeExpr` (in which any shape variables used must already be in scope, unless the `TensorStructInfo` is the `struct_info` field of a `MatchCast`, in which case a new shape variable is allowed to appear in a dimension by itself). Additionally, if the `shape` field is a `ShapeExpr`, the number of dimensions must match the `ndim` field. 14. If the `values` field of a `ShapeStructInfo` in any structural annotation is given, any shape variables used in it must already be in scope, unless the `ShapeStructInfo` is the `struct_info` field of a `MatchCast`, in which case a new shape variable is allowed to appear by itself as a member of `values`. Additionally, the `ndim` field must match the length of `values`. 15. The `params` and `derive_func` field may not be simultaneously defined in a `FuncStructInfo` annotation; that is, if one is defined, the other must not be defined. Additionally, at least one of `params` and `derive_func` _must_ be defined for each `FuncStructInfo` in an annotation. +16. `PrimValue` nodes are intended only to be used with `value`s consisting of TIR `IntImm`s and `FloatImm`s. +17. `PrimStructInfo` annotations should use only the `Int` and `Float` datatypes for their `dtype` fields. Additionally, the criteria for normal form listed in the previous section must apply to any program that has been normalized. From eae9f7da9873cbd485899df9f988ab4cf775969f Mon Sep 17 00:00:00 2001 From: Steven Lyubomirsky Date: Wed, 1 Feb 2023 17:39:33 -0500 Subject: [PATCH 22/30] Unify the notation used for TIR dtypes and Relax dtypes --- relax_spec.md | 36 +++++++++++++++++++++++++----------- 1 file changed, 25 insertions(+), 11 deletions(-) diff --git a/relax_spec.md b/relax_spec.md index 8b02492e61..a3c1878556 100644 --- a/relax_spec.md +++ b/relax_spec.md @@ -51,10 +51,11 @@ PrimExpr ::= | Select(condition: PrimExpr, true_value: PrimExpr, false_value: PrimExpr) # (others may be added later, as deemed necessary) -DataType ::= Int(bitwidth: int) - | Float(bitwidth: int) - | Bool() - | Void() +# Also from TIR +DataType ::= Int(bits: int, lanes: int) + | UInt(bits: int, lanes: int) + | Float(bits: int, lanes: int) + | Handle(bits: int, lanes: int) StructInfo ::= TensorStructInfo(shape: Expr?, dtype: DataType, ndim: int) | ShapeStructInfo(values: [PrimExpr]?, ndim: int) @@ -99,7 +100,19 @@ Binding ::= Program ::= IRModule(funcs: {GlobalVar: Function|PrimFunc}) ``` -*The `derive_func` field of `FuncStructInfo` is a macro in the meta-language: Given a function call and the variable mapping context, return the `StructInfo` of the result. This field is used only at compile time for reasoning about the `StructInfo` of calls to `ExternFunc`s. +### Notes on `derive_func` + +The `derive_func` field of `FuncStructInfo` is a macro in the meta-language: Given a function call and the variable mapping context, return the `StructInfo` of the result. This field is used only at compile time for reasoning about the `StructInfo` of calls to `ExternFunc`s. + +### Notes on `DataType` and Related Terminology + +The representation of datatypes, `DataType`, in the above AST is taken directly from TIR. However, the usage of datatypes in Relax is more restricted than in TIR. +1. The `lanes` field for the `Int`, `UInt`, and `Float` datatypes must always be 1; we do not directly consider vectorized values in Relax. +2. The `lanes` field for the `Handle` datatype must always be 0, indicating that it is `Void` (see below). The `bits` field for `Handle` should always be set to 64 (it will not be used by Relax). + +We also define the following special notation for datatypes, to be used in the rest of the specification: +1. `Bool()`: This is shorthand for `UInt(bits=1, lanes=1)`, since TIR does not have a separate Boolean type. "True" refers to a value of 1 in this datatype and "false" refers to a value of 0. For convenience, we will refer to Boolean values as a separate datatype in the specification, due to their significance in `If` nodes. +2. `Void()`: This is shorthand for `Handle(bits=64, lanes=0)`. TIR uses this datatype to refer to opaque objects; in Relax, it is used to denote an unknown datatype. ## Expression Survey @@ -183,7 +196,7 @@ Oftentimes, compiler passes operate only on particular functions or add new func Here are the classes of values that Relax operates over, meaning that they can be assigned to variables or be the result of evaluating expressions. -- *Tensors* are n-dimensional arrays of scalar values (which can be integers of fixed bitwidths, floats of fixed bitwidths, or bools). A tensor's *shape* is a tuple of the size of each dimension; the number of dimensions is a tensor's *rank*. For example, a vector (1, 2, 3) is a rank-1 tensor of shape `(3,)`. Note that scalars are tensor values with a rank of 0, meaning that their shape is `()`. +- *Tensors* are n-dimensional arrays of scalar values (which can be signed or unsigned integers of fixed bitwidths, floats of fixed bitwidths, or Boolean values). A tensor's *shape* is a tuple of the size of each dimension; the number of dimensions is a tensor's *rank*. For example, a vector (1, 2, 3) is a rank-1 tensor of shape `(3,)`. Note that scalars are tensor values with a rank of 0, meaning that their shape is `()`. - *Tuples* represent a fixed-size grouping of other Relax values (tensors, closures, shapes, objects, or other tuples, to an arbitrary degree of nesting). Note that an empty tuple, i.e., `()`, also called "unit" in functional programming, is commonly used as the return value for operations not intended to return a value (as may be the case in some `PackedFunc` or operator calls that have side effects). - *Closures* are the values resulting from evaluating Relax function expressions; closures can be passed around like other values, ensuring that functions are first-class in Relax. Functions defined in Relax can capture variables from outer scopes. A [closure](https://en.wikipedia.org/wiki/Closure_(computer_programming)) consists of a function and a mapping of any variables "captured" (those are *free variables* in the function body, variables from an outer scope that are neither arguments nor defined within the function but are used in the function) to their values. Closures capture both Relax-level local variables and shape variables from outer scopes. A closure also stores a name for itself when the body contains recursive calls. «Closures additionally carry some *run-time structural information* (RTSI) indicating their argument and result structures, in order to facilitate dynamic structural checks (since it is not otherwise possible to introspect the function contained within a closure); the precise form of the RTSI is left up to the compiler implementation to determine so long as `MatchCast` can verify the structure of a closure. Closures can be evaluated in a call node, which results in calling the function with the call's arguments and the captured values.» - *Tensor shapes* (shape values) are tuples of integers describing a tensor shape, obtained by evaluating `ShapeExpr`s. @@ -280,10 +293,11 @@ The following criteria apply to all programs (including before normalization): 13. If the `shape` field of a `TensorStructInfo` in any structural annotation is given, the only permissible expressions are `Var` (the variable must be in scope at the location of the annotation) or `ShapeExpr` (in which any shape variables used must already be in scope, unless the `TensorStructInfo` is the `struct_info` field of a `MatchCast`, in which case a new shape variable is allowed to appear in a dimension by itself). Additionally, if the `shape` field is a `ShapeExpr`, the number of dimensions must match the `ndim` field. 14. If the `values` field of a `ShapeStructInfo` in any structural annotation is given, any shape variables used in it must already be in scope, unless the `ShapeStructInfo` is the `struct_info` field of a `MatchCast`, in which case a new shape variable is allowed to appear by itself as a member of `values`. Additionally, the `ndim` field must match the length of `values`. 15. The `params` and `derive_func` field may not be simultaneously defined in a `FuncStructInfo` annotation; that is, if one is defined, the other must not be defined. Additionally, at least one of `params` and `derive_func` _must_ be defined for each `FuncStructInfo` in an annotation. -16. `PrimValue` nodes are intended only to be used with `value`s consisting of TIR `IntImm`s and `FloatImm`s. -17. `PrimStructInfo` annotations should use only the `Int` and `Float` datatypes for their `dtype` fields. +16. `PrimValue` nodes are intended only to be used with `value`s consisting of TIR `IntImm`s and `FloatImm`s (with `lanes` set to 1). +17. `PrimStructInfo` annotations should use only the `Int`, `UInt`, or `Float` datatypes for their `dtype` fields. +18. Per [the notes on `DataType`](#notes-on-datatype-and-related-terminology), any `DataType` annotation must have a `lanes` value of 1 for the `Int`, `UInt`, or `Float` datatypes and a `lanes` value of 0 for the `Handle` (`Void`) datatype. Additionally, `bits` must be 64 for `Void`. The supported bitwidths for `Int` and `UInt` are 1, 8, 16, 32, and 64; the supported bitwidths for `Float` are 16, 32, and 64. -Additionally, the criteria for normal form listed in the previous section must apply to any program that has been normalized. +Additionally, the criteria for normal form listed in [the previous section](#normal-form) must apply to any program that has been normalized. # Structural Information (`StructInfo`) in Relax @@ -374,7 +388,7 @@ Note that judging subtyping requires potentially reasoning about arbitrary `Shap 3. For all `S1`, `S1 <: ObjectStructInfo()`. 4. For `TensorStructInfo`: 1. Given any datatype `d`, an arbitrary `ndim` `n`, and an arbitrary expression `s` (possibly undefined), `TensorStructInfo(ndim=n, dtype=d, shape=s) <: TensorStructInfo(ndim=-1, dtype=d)`. - 2. Given any datatype `d`, an arbitrary `ndim` `n`, and an arbitrary expression `s` (possibly undefined), `TensorStructInfo(ndim=n, dtype=d, shape=s) <: TensorStructInfo(ndim=n, dtype=Void, shape=s)`. + 2. Given any datatype `d`, an arbitrary `ndim` `n`, and an arbitrary expression `s` (possibly undefined), `TensorStructInfo(ndim=n, dtype=d, shape=s) <: TensorStructInfo(ndim=n, dtype=Void(), shape=s)`. 3. Given any datatype `d`, an arbitrary `ndim` `n`, and an arbitrary expression `s`, `TensorStructInfo(ndim=n, dtype=d, shape=s) <: TensorStructInfo(ndim=n, dtype=d, shape=undefined)`. 4. Given any datatype `d`, an arbitrary `ndim` `n`, and arbitrary expressions `s1` and `s2` (both defined), then `TensorStructInfo(ndim=n, dtype=d, shape=s1) <: TensorStructInfo(ndim=n, dtype=d, shape=s2)` if `s1` and `s2` are _definitely_ statically equal. We say that `TensorStructInfo(ndim=n, dtype=d, shape=s1) <: TensorStructInfo(ndim=n, dtype=d, shape=s2)` _possibly_ holds if `s1` and `s2` are _possibly_ statically equal. 5. For `ShapeStructInfo`: @@ -725,7 +739,7 @@ For each expression, we define how it affects the program's visible state and th 8. The node `ShapeExpr([p1, p2, ..., pn])` evaluates the `PrimExpr`s `p1` (yielding dimension value `v1`), `p2` (yielding dimension value `v2`), …, and finally `pn` (yielding dimension value `vn`) in that order, using the current shape context, and creates a new shape value whose dimensions are `v1`, `v2`, …, `vn`, in that order. 9. The node `Function([v1, v2, ..., vn], body)` returns a new closure containing the function definition itself and a mapping of any free Relax variables or shape variables in `body` to the values they hold in the current scope when the `Function` node is encountered. If the function is the RHS of a local binding, the bound variable should also be included in the closure's binding map and should be mapped to the closure itself (to allow for recursive calls). Closure capturing is done *by reference*; no values will be copied and references to captured values will alias their values in the outer scope. `DataflowVar`s are not captured by closures. 10. The node `If(cond, true_branch, false_branch)` is evaluated as follows: - 1. First `cond` is evaluated. Let the result be `r` (per `StructInfo` checking, it must be a bool scalar). + 1. First `cond` is evaluated. Let the result be `r` (per `StructInfo` checking, it must be a `Bool` scalar). 2. If `r` is true, evaluate the `true_branch` and return its result. 3. If `r` is false, evaluate the `false_branch` and return its result. 11. The node `ExternFunc(global_symbol)` is evaluated by looking up the global symbol name and returning the `PackedFunc` if it exists (it is an error if it does not.) Note that if a TIR `PrimFunc` in the `IRModule` has a global symbol attribute registered, it can be called as an `ExternFunc` using that global symbol as well. From 17cdd298a8b7bf09f7db12485b3d34095e4139c5 Mon Sep 17 00:00:00 2001 From: Steven Lyubomirsky Date: Wed, 1 Feb 2023 18:00:49 -0500 Subject: [PATCH 23/30] Note immutability of tuples and shapes and possibility of interning --- relax_spec.md | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/relax_spec.md b/relax_spec.md index a3c1878556..89e70c3b79 100644 --- a/relax_spec.md +++ b/relax_spec.md @@ -119,7 +119,7 @@ We also define the following special notation for datatypes, to be used in the r This specification provides a more detailed description of what each expression and `StructInfo` represents and what conditions make them valid. To motivate and provide more context for the full specification later in this document, this section will briefly summarize the purpose of each node. 1. `Constant` nodes construct tensor constants (n-dimensional arrays of scalars). -2. `Tuple` nodes construct a tuple (fixed-size ordered grouping) of Relax values. +2. `Tuple` nodes construct a tuple (immutable fixed-size ordered grouping) of Relax values. 3. `Var`, `DataflowVar`, and `GlobalVar` nodes are all variables, referring to named stored values of different kinds. Variables in Relax must be bound exactly once. `GlobalVar`s are bound in the `IRModule` itself and refer to Relax functions or TIR `PrimFunc`s. `Var` nodes are bound either within functions, where they represent function parameters, or in `VarBinding` or `MatchCast` nodes in `BindingBlock`s, as we will discuss below. `DataflowVar`s are similar to `Var`s and can be bound only within `DataflowBlock`s. 4. `PrimExpr`s are used to represent dimensions of shapes in `ShapeExpr` and `MatchCast` nodes. These represent operations on integers with their own `Var` nodes (`tir::Var`), which we will refer to as "shape variables". Shape variables can only be used in other `PrimExpr`s and are scoped like `Var` nodes (`relax::Var`), which we will call "Relax variables." 5. `ExternFunc` nodes evaluate into `PackedFunc`s; the implementation will look up the registered `PackedFunc` by its global symbol. @@ -146,7 +146,7 @@ This specification provides a more detailed description of what each expression The `SeqExpr`'s `body` expression is allowed to reference any `Var`s introduced within the `SeqExpr`'s binding blocks in addition to those that were in the outer scope; the `body` expression is evaluated after the binding blocks and its value is what is returned. Any Relax variables and shape variables introduced in the `SeqExpr` are removed from scope after the expression finishes evaluating. -12. `ShapeExpr` nodes construct shape literals. The `PrimExpr`s within it describe how to compute each dimension; they are free to use any shape variables that are in scope. +12. `ShapeExpr` nodes construct shape literals, which are immutable collections of shape dimensions. The `PrimExpr`s within it describe how to compute each dimension; they are free to use any shape variables that are in scope. 13. `Function` nodes represent function definitions, taking in the listed parameters and evaluating the body expression in a new scope (meaning any variables defined from within the function cannot be referenced outside it). Function definitions may be nested in any other expression and they evaluate into closure values, ensuring that functions are first-class. Closures capture any variables from the outer scope that are used in their body, both Relax variables and shape variables. Note that function definitions themselves are anonymous—a function must be registered in the `IRModule` (bound to a `GlobalVar`) or appear on the right-hand side of a binding to have a name in order to be called recursively. The function can have structural annotations on the parameters and a structural annotation for the return value. When the function is called, the annotations on parameters are checked against the argument values in similar fashion to `MatchCast` and can introduce new shape variables that are scoped to the function. Additionally, the structural information of the return value is checked against the annotation before the call returns. @@ -197,9 +197,9 @@ Oftentimes, compiler passes operate only on particular functions or add new func Here are the classes of values that Relax operates over, meaning that they can be assigned to variables or be the result of evaluating expressions. - *Tensors* are n-dimensional arrays of scalar values (which can be signed or unsigned integers of fixed bitwidths, floats of fixed bitwidths, or Boolean values). A tensor's *shape* is a tuple of the size of each dimension; the number of dimensions is a tensor's *rank*. For example, a vector (1, 2, 3) is a rank-1 tensor of shape `(3,)`. Note that scalars are tensor values with a rank of 0, meaning that their shape is `()`. -- *Tuples* represent a fixed-size grouping of other Relax values (tensors, closures, shapes, objects, or other tuples, to an arbitrary degree of nesting). Note that an empty tuple, i.e., `()`, also called "unit" in functional programming, is commonly used as the return value for operations not intended to return a value (as may be the case in some `PackedFunc` or operator calls that have side effects). +- *Tuples* represent a fixed-size immutable grouping of other Relax values (tensors, closures, shapes, objects, or other tuples, to an arbitrary degree of nesting). Note that an empty tuple, i.e., `()`, also called "unit" in functional programming, is commonly used as the return value for operations not intended to return a value (as may be the case in some `PackedFunc` or operator calls that have side effects). - *Closures* are the values resulting from evaluating Relax function expressions; closures can be passed around like other values, ensuring that functions are first-class in Relax. Functions defined in Relax can capture variables from outer scopes. A [closure](https://en.wikipedia.org/wiki/Closure_(computer_programming)) consists of a function and a mapping of any variables "captured" (those are *free variables* in the function body, variables from an outer scope that are neither arguments nor defined within the function but are used in the function) to their values. Closures capture both Relax-level local variables and shape variables from outer scopes. A closure also stores a name for itself when the body contains recursive calls. «Closures additionally carry some *run-time structural information* (RTSI) indicating their argument and result structures, in order to facilitate dynamic structural checks (since it is not otherwise possible to introspect the function contained within a closure); the precise form of the RTSI is left up to the compiler implementation to determine so long as `MatchCast` can verify the structure of a closure. Closures can be evaluated in a call node, which results in calling the function with the call's arguments and the captured values.» -- *Tensor shapes* (shape values) are tuples of integers describing a tensor shape, obtained by evaluating `ShapeExpr`s. +- *Tensor shapes* (shape values) are immutable tuples of integers describing a tensor shape, obtained by evaluating `ShapeExpr`s. - *Packed functions* (`PackedFunc`s or external functions) represent arbitrary opaque functions implemented in TVM. That is, packed functions are routines that are defined outside of Relax and cannot be inspected by the compiler. They can perform side effects and return arbitrary values. - *Primitive values* (`PrimValue`s) represent immutable scalar values that are primarily intended for being passed to external procedures, like calls to `PackedFunc`s. As a rule of thumb, scalar values intended for arithmetical computations should be 0-rank tensors while scalar values meant to serve as metadata should be `PrimValue`s. - Additionally, there are further *arbitrary objects* that do not belong in the above categories. These can be returned by `PackedFunc`s and operators; additionally, we treat TIR `PrimFunc`s as opaque objects. Though Relax expressions other than `PackedFunc` and operator calls cannot use those objects, Relax should pass around these values faithfully. In the future we may add more value types in order to distinguish between different objects, but at present we treat these all as arbitrary values with `ObjectStructInfo`. Note that, for now, strings and TIR datatypes are also treated as opaque objects. @@ -731,12 +731,12 @@ For each expression, we define how it affects the program's visible state and th 1. The node `Constant(value)` creates a new tensor whose contents are `value`. 2. A variable (whether `Var`, `DataflowVar` , or `GlobalVar`) evaluates to the stored value for that variable in the current scope. -3. The node `Tuple([e1, e2, ..., en])` evaluates `e1` (yielding value `v1`), then `e2` (yielding value `v2`), …, and finally `en` (yielding value `vn`) in that order and creates a new tuple value containing `v1`, `v2`, …, and `vn` in that order. +3. The node `Tuple([e1, e2, ..., en])` evaluates `e1` (yielding value `v1`), then `e2` (yielding value `v2`), …, and finally `en` (yielding value `vn`) in that order and returns a tuple value containing `v1`, `v2`, …, and `vn` in that order. 4. The node `PrimType(prim_expr)` evaluates the `PrimExpr` `prim_expr` first, obtaining a resulting `pv`. It then creates an immutable `PrimValue` containing `pv`. 5. The node `StringImm(s)` creates an immutable string container whose contents is `s`. It does not necessarily have to be a _new_ string container if, for example, string interning is implemented. 6. The node `DataTypeImm(dt)` creates a new immutable datatype representation. 7. The node `TupleGetItem(t, i)` is evaluated by first evaluating `t` (which, per `StructInfo` checking, must evaluate to a tuple) and then returning the `i`th field of the result. -8. The node `ShapeExpr([p1, p2, ..., pn])` evaluates the `PrimExpr`s `p1` (yielding dimension value `v1`), `p2` (yielding dimension value `v2`), …, and finally `pn` (yielding dimension value `vn`) in that order, using the current shape context, and creates a new shape value whose dimensions are `v1`, `v2`, …, `vn`, in that order. +8. The node `ShapeExpr([p1, p2, ..., pn])` evaluates the `PrimExpr`s `p1` (yielding dimension value `v1`), `p2` (yielding dimension value `v2`), …, and finally `pn` (yielding dimension value `vn`) in that order, using the current shape context, and returns a shape value whose dimensions are `v1`, `v2`, …, `vn`, in that order. 9. The node `Function([v1, v2, ..., vn], body)` returns a new closure containing the function definition itself and a mapping of any free Relax variables or shape variables in `body` to the values they hold in the current scope when the `Function` node is encountered. If the function is the RHS of a local binding, the bound variable should also be included in the closure's binding map and should be mapped to the closure itself (to allow for recursive calls). Closure capturing is done *by reference*; no values will be copied and references to captured values will alias their values in the outer scope. `DataflowVar`s are not captured by closures. 10. The node `If(cond, true_branch, false_branch)` is evaluated as follows: 1. First `cond` is evaluated. Let the result be `r` (per `StructInfo` checking, it must be a `Bool` scalar). @@ -763,7 +763,7 @@ Optimizations are allowed to reorder and modify the operations of a program in a «Within `DataflowBlock`s, it is permitted for the compiler to remove or reorder `MatchCast` operations even though this can affect the "visible behavior" of the program (since they can exit with an error). It is also permitted for the compiler to optimize away potential non-termination within `DataflowBlock`s: For example, if some pure function `f` has an integer return type and does not terminate, it is permissible to optimize `f() - f()` to 0 within a `DataflowBlock`. In general, the compiler is permitted to make programs "more defined" (terminating when the original did not terminate, not raising an error when the original raised an error) within a `DataflowBlock`, but never "less defined" (giving an error when the original did not give an error, not terminating when the original did not terminate). Outside of `DataflowBlock`s, error messages and potential non-termination must be preserved faithfully.» -For immutable containers like those for the results of `PrimValue`, `StringImm`, and `DataTypeImm`, it is not required for the results of evaluating these expressions to be _new_ containers—it is permitted for the compiler to reuse existing objects provided that the values contained within are identical. This optimization is called [interning](https://en.wikipedia.org/wiki/String_interning). However, for operations that return new mutable values (in particular, operations that return tensor values), those _must_ be newly allocated, since reusing values can affect the behavior under aliasing. +For immutable containers like those for the results of `Tuple`, `ShapeExpr`, `PrimValue`, `StringImm`, and `DataTypeImm`, it is not required for the results of evaluating these expressions to be _new_ containers—it is permitted for the compiler to reuse existing objects provided that the values contained within are identical. This optimization is called [interning](https://en.wikipedia.org/wiki/String_interning). However, for operations that return new mutable values (in particular, operations that return tensor values), those _must_ be newly allocated, since reusing values can affect the behavior under aliasing. The specification makes no guarantees about certain memory-related properties and hence also does not consider them to be "visible behaviors": From 7c87f32a5ff130aacc23313baf687438a317e9aa Mon Sep 17 00:00:00 2001 From: Steven Lyubomirsky Date: Wed, 1 Feb 2023 18:01:36 -0500 Subject: [PATCH 24/30] Fix numbering in expression summary --- relax_spec.md | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/relax_spec.md b/relax_spec.md index 89e70c3b79..aea0840411 100644 --- a/relax_spec.md +++ b/relax_spec.md @@ -123,18 +123,18 @@ This specification provides a more detailed description of what each expression 3. `Var`, `DataflowVar`, and `GlobalVar` nodes are all variables, referring to named stored values of different kinds. Variables in Relax must be bound exactly once. `GlobalVar`s are bound in the `IRModule` itself and refer to Relax functions or TIR `PrimFunc`s. `Var` nodes are bound either within functions, where they represent function parameters, or in `VarBinding` or `MatchCast` nodes in `BindingBlock`s, as we will discuss below. `DataflowVar`s are similar to `Var`s and can be bound only within `DataflowBlock`s. 4. `PrimExpr`s are used to represent dimensions of shapes in `ShapeExpr` and `MatchCast` nodes. These represent operations on integers with their own `Var` nodes (`tir::Var`), which we will refer to as "shape variables". Shape variables can only be used in other `PrimExpr`s and are scoped like `Var` nodes (`relax::Var`), which we will call "Relax variables." 5. `ExternFunc` nodes evaluate into `PackedFunc`s; the implementation will look up the registered `PackedFunc` by its global symbol. -5. `PrimValue` nodes construct immutable scalar values from `PrimExpr`s, primarily for interacting with `ExternFunc`s or operators. These scalars are boxed within TVM objects, allowing them to be nested inside TVM's containers. (By contrast, zero-dimensional tensors defined via `Constant` are mutable.) -6. `StringImm` nodes construct strings, intended primarily for interacting with `ExternFunc`s or operators. -7. `DataTypeImm` nodes construct representations of TIR datatypes, intended primarily for interacting with `ExternFunc`s or operators (e.g., for TIR intrinsics that take a datatype as an input). -8. `Call` nodes represent function calls. The callee argument (the `op`) can be an `ExternFunc` node (representing a call to a `PackedFunc`), an `Op` node (representing a call to a Relax operator), or an arbitrary expression. +6. `PrimValue` nodes construct immutable scalar values from `PrimExpr`s, primarily for interacting with `ExternFunc`s or operators. These scalars are boxed within TVM objects, allowing them to be nested inside TVM's containers. (By contrast, zero-dimensional tensors defined via `Constant` are mutable.) +7. `StringImm` nodes construct strings, intended primarily for interacting with `ExternFunc`s or operators. +8. `DataTypeImm` nodes construct representations of TIR datatypes, intended primarily for interacting with `ExternFunc`s or operators (e.g., for TIR intrinsics that take a datatype as an input). +9. `Call` nodes represent function calls. The callee argument (the `op`) can be an `ExternFunc` node (representing a call to a `PackedFunc`), an `Op` node (representing a call to a Relax operator), or an arbitrary expression. 1. `Op` nodes refer to built-in Relax operators, which the compiler is free to implement as is deemed appropriate. Certain operators implement important operations, like `call_tir` (allows for calling TIR `PrimFunc`s). 2. Any other expression must evaluate to a `PackedFunc` or a closure; the result of evaluating `op` will then be called with the given arguments. Calls to `ExternFunc`s and operators may perform side effects, hence it is important to reason about whether a function call is permitted inside a `DataflowBlock`. -9. `If` nodes represent branching control flow. First the condition expression is evaluated, and it must evaluate to a Boolean scalar. If the condition is true, the true branch is evaluated and its result is used; otherwise, the false branch is evaluated and its result is used. -10. `TupleGetItem` nodes represent tuple indexing. The `tuple_value` expression must evaluate to a tuple with at least `index + 1` items and the item with the given index will be returned. -11. `SeqExpr` describes a sequence of binding blocks followed by a return expression. The `SeqExpr` opens a new scope. Its binding blocks are evaluated in order and add new variables to the scope. Binding blocks are either ordinary `BindingBlock`s or `DataflowBlock`s and both consist of a series of bindings. `DataflowBlock`s are the only kind allowed to introduce bindings with `DataflowVar`s and it does not permit any constructs featuring control flow (`If` nodes or recursive calls) or calls to (possibly) impure functions. There are two different kinds of bindings: +10. `If` nodes represent branching control flow. First the condition expression is evaluated, and it must evaluate to a Boolean scalar. If the condition is true, the true branch is evaluated and its result is used; otherwise, the false branch is evaluated and its result is used. +11. `TupleGetItem` nodes represent tuple indexing. The `tuple_value` expression must evaluate to a tuple with at least `index + 1` items and the item with the given index will be returned. +12. `SeqExpr` describes a sequence of binding blocks followed by a return expression. The `SeqExpr` opens a new scope. Its binding blocks are evaluated in order and add new variables to the scope. Binding blocks are either ordinary `BindingBlock`s or `DataflowBlock`s and both consist of a series of bindings. `DataflowBlock`s are the only kind allowed to introduce bindings with `DataflowVar`s and it does not permit any constructs featuring control flow (`If` nodes or recursive calls) or calls to (possibly) impure functions. There are two different kinds of bindings: 1. `VarBinding`s: The `value` expression (the right-hand side of the binding) of the binding is evaluated first and is bound to the `var` expression, which must be a new `Var` or `DataflowVar` (in a dataflow block). The newly bound variable will have that value for the remainder of the scope (`DataflowVar`s are scoped only to the `DataflowBlock` in which they appear; `Var`s are scoped to the entire `SeqExpr`). 2. `MatchCast`s: The `value` expression is evaluated and the result is dynamically checked against the structural information given in the `struct_info` field. 1. The types must match: All `StructInfo` variants correspond to a category of value value (`TensorStructInfo` to a tensor value, `ShapeStructInfo` to shape values, etc.), so if the structure of `value` does not correspond to `struct_info`, an error is triggered. The structure of `value` is compared recursively with `struct_info`, so all components of `value` must match up with any nested structural information. Special comparison rules: @@ -146,8 +146,8 @@ This specification provides a more detailed description of what each expression The `SeqExpr`'s `body` expression is allowed to reference any `Var`s introduced within the `SeqExpr`'s binding blocks in addition to those that were in the outer scope; the `body` expression is evaluated after the binding blocks and its value is what is returned. Any Relax variables and shape variables introduced in the `SeqExpr` are removed from scope after the expression finishes evaluating. -12. `ShapeExpr` nodes construct shape literals, which are immutable collections of shape dimensions. The `PrimExpr`s within it describe how to compute each dimension; they are free to use any shape variables that are in scope. -13. `Function` nodes represent function definitions, taking in the listed parameters and evaluating the body expression in a new scope (meaning any variables defined from within the function cannot be referenced outside it). Function definitions may be nested in any other expression and they evaluate into closure values, ensuring that functions are first-class. Closures capture any variables from the outer scope that are used in their body, both Relax variables and shape variables. Note that function definitions themselves are anonymous—a function must be registered in the `IRModule` (bound to a `GlobalVar`) or appear on the right-hand side of a binding to have a name in order to be called recursively. +13. `ShapeExpr` nodes construct shape literals, which are immutable collections of shape dimensions. The `PrimExpr`s within it describe how to compute each dimension; they are free to use any shape variables that are in scope. +14. `Function` nodes represent function definitions, taking in the listed parameters and evaluating the body expression in a new scope (meaning any variables defined from within the function cannot be referenced outside it). Function definitions may be nested in any other expression and they evaluate into closure values, ensuring that functions are first-class. Closures capture any variables from the outer scope that are used in their body, both Relax variables and shape variables. Note that function definitions themselves are anonymous—a function must be registered in the `IRModule` (bound to a `GlobalVar`) or appear on the right-hand side of a binding to have a name in order to be called recursively. The function can have structural annotations on the parameters and a structural annotation for the return value. When the function is called, the annotations on parameters are checked against the argument values in similar fashion to `MatchCast` and can introduce new shape variables that are scoped to the function. Additionally, the structural information of the return value is checked against the annotation before the call returns. From 491ba4ad691f9863a249e1f022292b3e17dc6238 Mon Sep 17 00:00:00 2001 From: Steven Lyubomirsky Date: Thu, 2 Feb 2023 17:58:54 -0500 Subject: [PATCH 25/30] Update the description of call_tir --- relax_spec.md | 16 +++++++++++++--- 1 file changed, 13 insertions(+), 3 deletions(-) diff --git a/relax_spec.md b/relax_spec.md index aea0840411..468788db7b 100644 --- a/relax_spec.md +++ b/relax_spec.md @@ -777,6 +777,16 @@ These semantic rules assume a single thread of evaluation on a single host machi The above evaluation rules are general, but leave much room for implementations of operators to specify custom semantics. Certain operators are used to perform common operations and will be discussed here as well. -- `call_tir(prim_func, arg1, arg2, ..., argn, sinfo_args=[aS])`: `prim_func` must be a `PrimFunc` object in the current `IRModule` (we will call it `f`). The `StructInfo` arg `aS` gives the `StructInfo` of the result of calling the `PrimFunc`; it must be a `TensorStructInfo` with a `shape` field corresponding to a constant shape expression and a non-`Void` `dtype`, denoting the shape of the resulting tensor, or a a `TupleStringInfo` where all the `fields` are `TensorStructInfo`. Based on `aS`, the resulting tensor or tuple `r` will be allocated according to the sizes given in their `shape` fields. `f` will be called in destination-passing style, like so: `f(arg1, ..., argn, *r)`. The asterisk denotes that if `r` is a tuple, it will be "unrolled," so the call will be `f(arg1, ..., argn, r1, ..., rn)`, where the `ri` are the fields of `r`. `f` is expected to mutate *only* `r` to give the output of the function, hence `call_tir` is considered pure. «If the shape or data type of the actual result do not correspond to `shape` or `aT`, an error is issued.» After the call, `r` is returned. -- «`call_dps_packed(global_symbol, arg1, arg2, ..., argn, sinfo_args=[aS])`: Proceeds similarly to `call_tir`, except it calls a `PackedFunc` registered under the name `global_symbol`. The `PackedFunc` may modify `arg1`, `arg2`, …, or `argn` in addition to the results, so purity is not assumed. `aS` denotes the `StructInfo` for the result.» -- `shape_of(t)`: Given a tensor argument `t`, it returns its shape. The return value is a new shape object. +- `call_tir(prim_func, args, packed_ints, sinfo_args=[aS1, aS2, ..., aSk])`: + - `prim_func` must be a `PrimFunc` object in the current `IRModule` (we will call it `f`). + - `args` must should be an expression that evaluates to a tuple of tensor values (where each member of a tuple will be a tensor argument to the `PrimFunc`). Let us call the members of the tuple `arg1`, `arg2`, ..., `argn`. + - `packed_ints` is an optional argument. If present, it must be a shape value (with `ShapeStructInfo`). Each dimension of the value (which we will call `shape1`, `shape2`, ..., `shapem`) + - The `StructInfo` arguments `aS1` through `aSk` give the `StructInfo` of the results of calling the `PrimFunc`. + - All the `aSi` must be `TensorStructInfo` with a `shape` field consisting of a `ShapeExpr` (possibly containing shape variables) and a non-`Void` `dtype`, denoting the shape of the resulting tensors. + - If there is exactly one member of `sinfo_args`, then the operation returns a single tensor with that shape; if there are multiple or zero members of `sinfo_args`, then the result will have the `StructInfo` `TupleStructInfo(fields=[aS1, as2, ..., aSk])`. + - Based on the `aSi`, the resulting tensors `r1`, `r2`, ..., `rk` will be allocated according to the sizes given in their `shape` fields. + - `f` will be called in destination-passing style, like so: `f(arg1, arg2, ..., argn, shape1, shape2, ..., shapem, r1, r2, ..., rk)`, omitting the `shapei` if `packed_ints` is not given. `f` is expected to mutate *only* the `ri` to give the output of the function, hence `call_tir` is considered pure. + - «If the shape or data type of the actual result do not correspond to the `aSi`, an error is issued.» + - After the call, the `ri` will be returned (returning `r1` directly if there is only a single result, otherwise returning `Tuple(fields=[r1, r2, ..., rk])`). +- «`call_dps_packed(global_symbol, args, packed_ints, sinfo_args=[aS1, aS2, ..., aSk])`: Proceeds similarly to `call_tir`, except it calls a `PackedFunc` registered under the name `global_symbol` instead of a `PrimFunc` object. The `PackedFunc` may modify any member of `args` (`packed_ints`, if present, is immutable) in addition to the results, so purity is not assumed. The `StructInfo` for the result will be determined int he same manner as in `call_tir`, where it will be `aS1` if `sinfo_args` has a length of 1 and `TupleStructInfo(fields=[aS1, aS2, ..., aSk])` otherwise.» +- `shape_of(t)`: Given a tensor argument `t`, it returns its shape. The return value is a shape object. From 47ab40b36ef60388e3c3166ea2ce1637353f3c23 Mon Sep 17 00:00:00 2001 From: Steven Lyubomirsky Date: Mon, 6 Feb 2023 18:18:53 -0500 Subject: [PATCH 26/30] Specify invariants for TensorStructInfo --- relax_spec.md | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/relax_spec.md b/relax_spec.md index 468788db7b..8e84d2c68d 100644 --- a/relax_spec.md +++ b/relax_spec.md @@ -374,6 +374,16 @@ def f(arg1, arg2, ..., argn): ``` » +### Invariants for `TensorStructInfo` + +Because the `shape` field of `TensorStructInfo` is an expression (either a `Var` or `ShapeExpr`), that expression may have its own `StructInfo`. In any `TensorStructInfo` derived by the below inference rules for `StructInfo` or in any `StructInfo` annotation, the following properties must hold of the `shape` field in `TensorStructInfo`: +1. If the `shape` field is a `Var`, the `Var` must have `ShapeStructInfo`. The `ndim` for the `Var`'s `ShapeStructInfo` must match that of the `TensorStructInfo`. +2. If the `shape` field is a literal `ShapeExpr`, then the `ndim` for the `TensorStructInfo` must match the number of fields in the `shape`'s `values` field (this is noted in the [well-formedness rules](#well-formedness-criteria)). + +Any shape variables that appear in the `ShapeStructInfo` must be in scope where the annotation appears. + +In particular, it is not permitted for the `TensorStructInfo` to have an unknown rank (`ndim` of -1) when the `shape` field has a non-negative `ndim`. + ## Subtyping for `StructInfo` Relax implements subtyping for `StructInfo`, which means that values with some `StructInfo` can be accepted where values with more general `StructInfo` are accepted We will denote the subtyping relationship as `S1 <: S2`, indicating that `S1` is a subtype of `S2`. For example. if `S1 <: S2` and some function expects an argument with `StructInfo` `S2`, then passing a value with `StructInfo` `S1` to that function is permitted; passing a value with `StructInfo` `S2` as an argument to a function that expects `S1` for that argument is *not* permitted—the value would have to be dynamically cast to `S1` using `MatchCast`. From 2336452cf6e23a435e3de3f4bb445c50806b19b7 Mon Sep 17 00:00:00 2001 From: Steven Lyubomirsky Date: Thu, 9 Feb 2023 21:09:15 -0500 Subject: [PATCH 27/30] PrimValue, StringImm, DataTypeImm are leaf nodes --- relax_spec.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/relax_spec.md b/relax_spec.md index 8e84d2c68d..978eccf6b6 100644 --- a/relax_spec.md +++ b/relax_spec.md @@ -252,7 +252,7 @@ def func(x: Tensor) -> Tensor: To simplify the writing of Relax passes, we define a normal form for Relax programs, based on the [administrative normal form](https://en.wikipedia.org/wiki/A-normal_form) (A-normal form, or ANF). See [this post](https://matt.might.net/articles/a-normalization/) by Matt Might for a discussion of some of the advantages of ANF in traditional compilation; in particular, ANF results in programs without nesting, which is very convenient for writing program transformations. Because the `StructInfo`-checking rules for operators rely on macros (`FInferShapeInfo`), _this means that the structure of the program can affect `StructInfo` inference_. Putting programs into normal form (and lacking nesting) not only simplifies the writing of these macros but it also ensures that these `StructInfo`-checking rules will be predictable, hence _it is required to transform programs into normal form_ before applying `StructInfo` checking. The normal form for Relax is very similar to ANF; differences will be noted. Here are the criteria required for a program to be in normal form: -1. Within a `SeqExpr`, the right-hand side of any binding (the `value` field in the AST) must either be a "leaf expression" or a non-leaf expression where all subexpressions are leaf expressions. Leaf expressions are the following: Variables (`Var`, `DataflowVar`, or `GlobalVar`), `Constant`, `ShapeExpr`, or (_unlike_ ANF) `Tuple`. `Tuple` nodes are considered "leaf" expressions even though they contain nesting purely for convenience in writing passes; many operators rely on grouping arguments using tuples, so that is a form of nesting permitted and expected. Otherwise, non-leaf expressions used as subexpressions must be bound to variables; this includes any non-leaf expressions nested inside a `Tuple`. +1. Within a `SeqExpr`, the right-hand side of any binding (the `value` field in the AST) must either be a "leaf expression" or a non-leaf expression where all subexpressions are leaf expressions. Leaf expressions are the following: Variables (`Var`, `DataflowVar`, or `GlobalVar`), `Constant`, `ShapeExpr`, `PrimValue`, `StringImm`, `DataTypeImm`, or (_unlike_ ANF) `Tuple`. `Tuple` nodes are considered "leaf" expressions even though they contain nesting purely for convenience in writing passes; many operators rely on grouping arguments using tuples, so that is a form of nesting permitted and expected. Otherwise, non-leaf expressions used as subexpressions must be bound to variables; this includes any non-leaf expressions nested inside a `Tuple`. 2. `SeqExpr`s may appear only in the following locations: 1. In the `body` field of a `Function` node. 2. In the `true_branch` and `false_branch` fields of `If` nodes. From 613b84087ab16bb8d760d2be4517ce305ed6f543 Mon Sep 17 00:00:00 2001 From: Steven Lyubomirsky Date: Fri, 10 Feb 2023 16:34:15 -0500 Subject: [PATCH 28/30] Tuples are represented using Arrays, not ADTs now --- relax_spec.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/relax_spec.md b/relax_spec.md index 978eccf6b6..27953d2023 100644 --- a/relax_spec.md +++ b/relax_spec.md @@ -211,7 +211,7 @@ Because Relax supports calls to arbitrary `PackedFunc`s that can operate on a lo Possible specification in terms of the TVM object system: - Tensors are represented at run time as `NDArray`s (see `include/tvm/NDArray.h`). -- Tuples are represented using TVM ADTs (algebraic data types), which are arrays of TVM objects with a tag (see `include/tvm/runtime/container/adt.h`). Tuples use a tag of 0. +- Tuples are represented using TVM `Array`s (in contrast to `NDArray`s), which are immutable (see `include/tvm/runtime/container/array.h`). - At run time, closures are represented as a `ClosureObj` (see `include/tvm/runtime/container/closure.h`); in the Relax VM these more specifically use the `VMClosureObj` (see [`https://github.com/tlc-pack/relax/blob/relax/include/tvm/runtime/relax_vm/executable.h`](https://github.com/tlc-pack/relax/blob/relax/include/tvm/runtime/relax_vm/executable.h)). - Shape values are represented at run time as a `ShapeTuple` (see `include/tvm/runtime/container/shape_tuple.h`). - Strings are represented using TVM's `String` container (see `include/tvm/runtime/container/string.h`). From 7ebc082cb2ae6e69659c8a1c6043c309c8f529ee Mon Sep 17 00:00:00 2001 From: Steven Lyubomirsky Date: Mon, 13 Feb 2023 15:41:40 -0500 Subject: [PATCH 29/30] Fix typo --- relax_spec.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/relax_spec.md b/relax_spec.md index 27953d2023..839c96d9f7 100644 --- a/relax_spec.md +++ b/relax_spec.md @@ -752,7 +752,7 @@ For each expression, we define how it affects the program's visible state and th 1. First `cond` is evaluated. Let the result be `r` (per `StructInfo` checking, it must be a `Bool` scalar). 2. If `r` is true, evaluate the `true_branch` and return its result. 3. If `r` is false, evaluate the `false_branch` and return its result. -11. The node `ExternFunc(global_symbol)` is evaluated by looking up the global symbol name and returning the `PackedFunc` if it exists (it is an error if it does not.) Note that if a TIR `PrimFunc` in the `IRModule` has a global symbol attribute registered, it can be called as an `ExternFunc` using that global symbol as well. +11. The node `ExternFunc(global_symbol)` is evaluated by looking up the global symbol name and returning the `PackedFunc` if it exists (it is an error if it does not). Note that if a TIR `PrimFunc` in the `IRModule` has a global symbol attribute registered, it can be called as an `ExternFunc` using that global symbol as well. 12. The node `Call(op, [arg1, arg2, ..., argn])` is evaluated as follows: 1. If `op` is an `Op` node, then evaluate `arg1`, `arg2`, …, `argn` in that order and call the results `a1`, `a2`, …, `an`. It is up to the compiler implementation to decide how operators should be implemented (some may have an associated `PackedFunc` and others may be built into the executor implementation). The operator may mutate its arguments. It is also up to the operator implementation as to whether the result is newly allocated or aliases another value. «(TODO: Once we have operators for logical and AND and OR, we should also define short-circuiting semantics for those.)» 2. Otherwise, first evaluate `op` (it must evaluate to a closure or `PackedFunc`). Next, we evaluate `arg1`, `arg2`, …, `argn` in that order and call the results `a1`, `a2`, …, `an`. From 4071b9b5bcfa19a6507d923b9b5d197216173645 Mon Sep 17 00:00:00 2001 From: Steven Lyubomirsky Date: Tue, 14 Feb 2023 21:35:16 -0500 Subject: [PATCH 30/30] Add mention of null value --- relax_spec.md | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/relax_spec.md b/relax_spec.md index 839c96d9f7..60de82fa28 100644 --- a/relax_spec.md +++ b/relax_spec.md @@ -202,7 +202,7 @@ Here are the classes of values that Relax operates over, meaning that they can b - *Tensor shapes* (shape values) are immutable tuples of integers describing a tensor shape, obtained by evaluating `ShapeExpr`s. - *Packed functions* (`PackedFunc`s or external functions) represent arbitrary opaque functions implemented in TVM. That is, packed functions are routines that are defined outside of Relax and cannot be inspected by the compiler. They can perform side effects and return arbitrary values. - *Primitive values* (`PrimValue`s) represent immutable scalar values that are primarily intended for being passed to external procedures, like calls to `PackedFunc`s. As a rule of thumb, scalar values intended for arithmetical computations should be 0-rank tensors while scalar values meant to serve as metadata should be `PrimValue`s. -- Additionally, there are further *arbitrary objects* that do not belong in the above categories. These can be returned by `PackedFunc`s and operators; additionally, we treat TIR `PrimFunc`s as opaque objects. Though Relax expressions other than `PackedFunc` and operator calls cannot use those objects, Relax should pass around these values faithfully. In the future we may add more value types in order to distinguish between different objects, but at present we treat these all as arbitrary values with `ObjectStructInfo`. Note that, for now, strings and TIR datatypes are also treated as opaque objects. +- Additionally, there are further *arbitrary objects* that do not belong in the above categories. These can be returned by `PackedFunc`s and operators; additionally, we treat TIR `PrimFunc`s as opaque objects. Though Relax expressions other than `PackedFunc` and operator calls cannot use those objects, Relax should pass around these values faithfully. In the future we may add more value types in order to distinguish between different objects, but at present we treat these all as arbitrary values with `ObjectStructInfo`. Note that, for now, strings and TIR datatypes are also treated as opaque objects. Another noteworthy value in this category is the _null object_ (the result of returning a null pointer in C++ or passing in `None` through the Python FFI), which is returned by the `null_value()` operator. ## Representation of Values at Run Time @@ -800,3 +800,4 @@ The above evaluation rules are general, but leave much room for implementations - After the call, the `ri` will be returned (returning `r1` directly if there is only a single result, otherwise returning `Tuple(fields=[r1, r2, ..., rk])`). - «`call_dps_packed(global_symbol, args, packed_ints, sinfo_args=[aS1, aS2, ..., aSk])`: Proceeds similarly to `call_tir`, except it calls a `PackedFunc` registered under the name `global_symbol` instead of a `PrimFunc` object. The `PackedFunc` may modify any member of `args` (`packed_ints`, if present, is immutable) in addition to the results, so purity is not assumed. The `StructInfo` for the result will be determined int he same manner as in `call_tir`, where it will be `aS1` if `sinfo_args` has a length of 1 and `TupleStructInfo(fields=[aS1, aS2, ..., aSk])` otherwise.» - `shape_of(t)`: Given a tensor argument `t`, it returns its shape. The return value is a shape object. +- `null_value()`: Returns a null object (treated as `ObjectStructInfo`). This is used for indicating to operators that an optional argument has been omitted. \ No newline at end of file