From c0070146895d102ac2ee910afa56229077974828 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Tue, 17 Jan 2023 17:01:40 +0900 Subject: [PATCH] (Rebase) Squashed commit of the following: MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit commit 5bf9c8acf12dfba9865ac9f8480341298131dec4 Author: Masahiro Masuda Date: Tue Jan 17 16:10:16 2023 +0900 clean up commit 5506d92ed9a4c48c63f192ddcb576c9665d4ad5b Author: Masahiro Masuda Date: Tue Jan 17 15:39:39 2023 +0900 link and run compiled cutlass code, result correct commit 81d39f84ebb1a7bcfe5c2fa9f97ce2130f932dbb Author: Masahiro Masuda Date: Tue Jan 17 15:13:41 2023 +0900 compile generated cutlass code commit c2a68e14575c2711497347d5fc93d15b88c6c79b Author: Masahiro Masuda Date: Tue Jan 17 07:47:31 2023 +0900 codegen working commit ba26344f85ebe43f88852c8c18b754bf03df1ce1 Author: Masahiro Masuda Date: Mon Jan 16 19:41:47 2023 +0900 wip commit ed3ac6d632a4798e411573f30d1a090bc05a96fc Author: Masahiro Masuda Date: Mon Jan 16 17:53:10 2023 +0900 wip commit 47e09e54a0d405a14a602d7a6d31c49399c5662f Author: Masahiro Masuda Date: Mon Jan 16 17:32:58 2023 +0900 wip commit b9e5df768b188de3dda1ef0d0f3db3fd592535d9 Author: Masahiro Masuda Date: Mon Jan 16 17:25:37 2023 +0900 copy codegen_c base function commit fe20e653ecf548f07432f06cd17395b554e6faa5 Author: Masahiro Masuda Date: Sat Jan 14 08:43:57 2023 +0900 add cutlass stub commit 990eec78b58ca259bc067bb32e4020f28d88b7c8 Author: Masahiro Masuda Date: Sat Jan 14 08:18:57 2023 +0900 updated cutlass revision commit 591a8f1ba62d9f8e923f2dcc1702e7e7590e92e2 Author: Masahiro Masuda Date: Sat Jan 14 08:02:01 2023 +0900 conv2d + relu DNNL offload works commit 1365402079626eab5bf99bad96dbfa4abd750175 Author: Masahiro Masuda Date: Fri Jan 13 16:35:49 2023 +0900 starting DNNL codegen commit 4a72e7810b0df31a4fb13856b5b6320ced4e978e Author: Masahiro Masuda Date: Thu Jan 12 14:02:19 2023 +0900 clean up commit 61cc55e94123f3064e0d1200c70f33b4a537c4ad Author: Masahiro Masuda Date: Tue Jan 10 16:26:31 2023 +0900 pattern based partitioning working commit 2433733c5458302cbe05e534d6c99bec13fb6d36 Author: Masahiro Masuda Date: Tue Jan 10 08:30:20 2023 +0900 add conv2d match & run test commit 360429440acb7068fdfd982d597523ebe032eb20 Author: Ruihang Lai Date: Mon Jan 9 17:20:05 2023 -0500 [Op][O2e] Indexing and datatype operators (#338) commit e45bdb73824d120bb3b848d4fdaa54f88211b509 Author: Tianqi Chen Date: Mon Jan 9 14:59:26 2023 -0500 [VM] Supporting "compiled" exec mode. (#331) * [VM] Supporting "compiled" exec mode. This PR adds support of "compiled" mode to the VM. The compiled mode translate the relax function into TIR function and drive it through the TIR function. It is different from the micro AOT codegen, which generate TIR code that targets the micro C runtime environment and useful for resource limited settings with smaller set of features. Both leverages the low-level TIR build that is also shared with TensorIR. The current implementation targets full TVM (VM) runtime, that comes with PackedFunc, object, tuple, closure and all kinds of rich structure support. This also mean that we can leverage the full runtime support to handle things like allocation, dynamic shape, easy plugins and python interaction, which are not available in more limited runtime. The user directly use the same API to load the generated code regardless of compiled mode or bytecode. And just need to change one line ```python ex = relax.vm.build(mod, target, exec_mode="compiled") ``` Most of the codegen features are lifted before the codegen phase, so the overall implementation would be around 500 loc for each exec mode and can be further cut down with future introduction of PrimValue. The simplicity is thanks to the TVM runtime archiecture that allows us to compose things together in objects. The only difference is how the PackedFunc of high-level driving is being provided. In the case of bytecode it is normal interpretation and in the case of compiled mode it is TIR. It is a complete implementation Unit-testcases are added. All codegen build tests are updated to include two exec_modes and have passed locally. The only exception that we skipped some special packedfunc handling(printing) because can be further simplified after we introduce PrimValue. Co-authored-by: Junru Shao * Address review comments Co-authored-by: Junru Shao commit 32c2bf74eda5ff9cb958e6d54a29c324d53f2869 Author: Ruihang Lai Date: Mon Jan 9 13:45:14 2023 -0500 [Op][O2d] Manipulation operators (#337) As tracked by #332, this PR is the O2d milestone of the high-level operator introduction plan. This PR introduces a few manipulation operators: * broadcast_to * concat * expand_dims * flatten * permute_dims * reshape * split * squeeze These operators are all well-tested. commit b39d11a37c899a1625ecee0ffdacc5ef5444365f Author: Ruihang Lai Date: Mon Jan 9 10:57:19 2023 -0500 [O2h] Neural network and linear algebra operators (#343) commit 1d6d897ec223cc07768e0382c3e21a196ffdfac8 Author: Ruihang Lai Date: Sun Jan 8 20:21:50 2023 -0500 [O2g] Convolution, pooling and image operators (#341) commit 95f784ece1d61676b88b5455be3dab5e3ddbc75a Author: Ruihang Lai Date: Sun Jan 8 16:53:10 2023 -0500 [Op][O2f] Set and searching operators (#339) commit be1c32d817bbbbd56329378d6d929dce79ecb0f8 Author: Siyuan Feng Date: Mon Jan 9 03:38:20 2023 +0800 simple fix jupyter error reporting (#345) commit da11e4bf373349ce4142949099e29d11655aa88b Author: Siyuan Feng Date: Sun Jan 8 23:09:22 2023 +0800 [TVMScript] Symbolic shape computing (#342) commit 80808fbf9a02480abf337b8a5edffe34c963feec Author: Ruihang Lai Date: Sat Jan 7 18:31:00 2023 -0500 [Op][O2c] Creation operators (#336) commit 5efc8f7224f83766875e74669e139ec82119a504 Author: Ruihang Lai Date: Sat Jan 7 11:14:23 2023 -0500 [TIR] Create Layout with specified axis dtype (apache/tvm#13663) (#340) commit ae71be06c8252c211642abb9d5b3e4583bdb6f6a Author: Ruihang Lai Date: Fri Jan 6 16:41:18 2023 -0500 [Op][O2b] Statistical operators (#334) commit 8220df74e339cdb6dab38a803b80edc3cd6b92e2 Author: Ruihang Lai Date: Thu Jan 5 18:31:48 2023 -0500 [Op][O1][O2a] Utility, arithmetic and comparison operators (#333) As tracked by #332, this PR is the kickoff part of high-level operator introduction in Relax. This PR is about the milestone O1 and O2a. Specifically, this PR * introduces some of common utility functions that the registration and StructInfo inference of each operator will often use. * introduces unary arithmetic operators: cos, log, negative, sigmoid, sin, sqrt, tanh. * refactors and introduces binary arithmetic operators: add, divide, floor_divide, multiply, subtract. * introduces binary comparative operators: equal, greater, greater_equal, less, less_equal, not_equal. These operators are well tested from three perspective: P1. the op getter can get correct op by name P2. their StructInfo inference result are as expected under all kinds of cases P3. Relax TVMScript parser can parse the scripts with the op inside For operators in O2a, most operators share almost the same StructInfo inference logic. Therefore, for tests in P2, in each category, not every op is tested in every case. For each case, it is good to have only part of op in this category tested. This is intended not to make overlarge testing file. commit f1cab0a05f05829c4c35e2a7e613bd69f2a17fae Author: Siyuan Feng Date: Thu Jan 5 20:43:28 2023 +0800 [TVMScript] Ensure consistent struct info between assign lhs and rhs with sinfo annotation (#328) * [TVMScript] Ensure consistent struct info between assign lhs and rhs with sinfo annotation * fix * fix commit dc7072efe290d7e8c69d8e216311510981fc82e1 Author: Tianqi Chen Date: Wed Jan 4 10:13:08 2023 -0500 [REFACTOR] Hide VM Impl, Improve execution logic. (#326) * [REFACTOR] Hide VM Impl, Improve execution logic. This PR refactors VM by hiding most of the VM implementations and improve the overall execution logic. - Unifies PackedFunc and Closure Table. - Update Closure mechanism to no longer depend on string. - Update VMMemoryLower to VMBuiltinLower to incorporate more VM intrinsic lowering, move some of the codegen intrinsic to this phase. - Allow directly pass in function index as VM instruction. * Address comment commit 2449d8c205f0b6e2c346132695b56039b07e9a10 Author: Steven S. Lyubomirsky Date: Tue Jan 3 22:04:16 2023 -0500 [IR][ASTPrinter] Tweaks to AST printer's handling of struct info (#330) commit 2d352807090ba1b7e898fbdcb83d6d9427c762cf Author: Siyuan Feng Date: Tue Jan 3 23:20:47 2023 +0800 [TVMScript] Enforce `I.DeclareFunc` to have function signature (#329) commit dcae50e836a0c2999f52d96a372fc7de584951f4 Author: Tianqi Chen Date: Mon Jan 2 15:21:49 2023 -0500 [BACKEND] Refactor and introduce full match-cast support. (#324) * [BACKEND] Refactor and introduce full match-cast support. This PR refactors VMShapeLower to introduce full match-cast support that enables nested tuples, type checks at argument boundary and symbolic shape computation. Along the way we also refactors cleans up some of vm codegen logic and adding unit-tests for different stages. * address comments commit a36920bf672d22e1d31e1e6f81d0447fd7a55806 Author: Siyuan Feng Date: Mon Jan 2 23:31:04 2023 +0800 [TVMScript] Fix empty TupleStructInfo (#327) commit 80710a826bda66532eeda978668ed157b471b186 Author: Tianqi Chen Date: Fri Dec 30 15:57:50 2022 -0500 [CONTAINER] Hash/Equal/JSON support for ShapeTuple (#325) This PR add hash/equal/json support for shape tuple. commit 343a1e7e2174612031c70ba8547577c7d21839e4 Author: Tianqi Chen Date: Thu Dec 29 18:33:17 2022 -0500 [REFACTOR] StructInfo M3: MatchShape=>MatchCast (#323) * Introduce match cast, and code changes along * add match_cast parser support (#9) * Match cast support for VMShapeLower CanonicalizeBinding * Remove `match_shape` (#12) * Refactor ExprVisitor/Mutator to consider Expr in StructInfo. Co-authored-by: Siyuan Feng commit e332285559d61db1c5033b8d50cd9d4af6c6b6f4 Author: Tianqi Chen Date: Thu Dec 29 01:28:09 2022 -0500 [REFACTOR] StructInfo M2: Cleanups on legacy shape related items (#320) * [REFACTOR] Remove shape function * [WIP] Remove shape_, runtime_dep shape * Remove shape_ pass Compile * Remove RuntimeDepShape (#11) * BlockBuilder: remove CanProveShapeEqual, consolidate binding emit to EmitNormalize * Remove DimType, make get_shape_of API different from op.shape_of Changes the init importing to direct import so the VSCode nagivator can directly jump to the defintion point. * Apply suggestions from code review Co-authored-by: Ruihang Lai * Clarify cases where struct info can be determinstically derived * Fix remaining testcases * Remove InferShape/Type per comment. Co-authored-by: Siyuan Feng Co-authored-by: Ruihang Lai commit edadf247551f526188c0a08b3812ffc0a1f9d8bd Author: Ruihang Lai Date: Fri Dec 23 14:46:07 2022 -0500 [Analysis] Optionally check structure info in well-formedness check (#321) With the introduction of structure info in #314, the well-formedness check will report malformed whenever an Expr doesn’t have defined structure info. However, when writing tests for well-formedness check and normalizer, usually we will manually construct the Exprs, which means their structure info are not defined most of the time. As a consequence, the well-formedness check will always complain “the Expr xxx doesn’t have structure info populated.” Therefore, when the checker fails to complain about the original reason of malformed, which means the checker is not working, the tests will still pass and we won’t be able to realize there is something wrong with the checker. Thus, in this PR we add an optional flag to the well-formedness check. In well-formedness tests, we will turn off the structure info check so that the original reason of being malformed will be revealed correctly. --- This PR also cleans up the DiagnosticContext parameter in the WellFormed API - the diag_ctx has been unused since the merge of #99. commit d548459a1736378398ab773dce413d90d49376cf Author: Ruihang Lai Date: Fri Dec 23 07:33:25 2022 -0500 [Op] Enforce int64 output shape in CallTIR (#322) commit 10a87a455bbb84b0a0d20b22bd31784b9f4b9774 Author: Chaosfan Date: Fri Dec 23 08:03:48 2022 +0800 [Bugfix] Handle function name properly in Relax TVMScript printer (#317) * remove relax_func_name_ and change logic * well_formed check for globalvar and gsymbol consistency * revise the logic in well_formed and update test * Remove `global_symbol` in test_function_attr.py * Update docs Co-authored-by: Ruihang Lai commit 29aebb9d24cbf52ab21fd98996633534301ef34d Author: Tianqi Chen Date: Wed Dec 21 20:21:57 2022 -0500 [REFACTOR] M1: Change parser/printer to only depend on struct info (#319) * [REFACTOR] StructInfo M1: Parser/printer/Var/Function to only depend on struct info field * Update src/relax/backend/vm/vm_shape_lower.cc Co-authored-by: Ruihang Lai * Address comments * Allow function to have default value Co-authored-by: Siyuan Feng Co-authored-by: Ruihang Lai commit e6173430f491c1d88d2ab77ce0ab43a8c602df30 Author: Tianqi Chen Date: Wed Dec 21 00:42:29 2022 -0500 [REFACTOR][ARCH] Introduce StructInfo M0 (#314) * [IR] Introduce StructInfo * StructInfoFunctor and Analysis Support * [TVMScript] Parse type/shape annotation with StructInfo * remove runtime type assign * Remove type/shape during parsing (#2) * Normalizer prep: simple checks and legacy function renaming. * Struct info deduction in BlockBuilder. * Two TODOs * StructInfo Normalizer Fixes (#3) * StructInfo AST Fix * Fix Extern Func Deduction and shape mutator. * Update VoidStructInfo & globalvar (#4) * Fix passes and proper sinfo propagation. * Refactor EraseToWellDefined to Enable Remapping * [WIP] First stab at symbolic param tracking * Update EraseToWellDefined to support symbolic shape return (#5) * fix R.shape with ndim (#6) * Remove update shape/type * Address review comment, AnnotateTypeShape=>AnnotateStructInfo * Update include/tvm/script/ir_builder/relax/frame.h Co-authored-by: Ruihang Lai * Address comments * Update printer to use structinfo (#7) * Update Error mechanism to prep for obj loc based reporting * Symbolic shape aware function call return value derivation. The main flow works as follows: - Match and populate shape_var_map and var_map by visit each pair of param and call arguments. - Call EraseToWellDefined to map the ret parameter to new result. * [ANALYSIS] Refactor well-form to only look at struct info. * Update comments according to reviews. * Update include/tvm/relax/struct_info.h Co-authored-by: Ruihang Lai Co-authored-by: Siyuan Feng Co-authored-by: Tianqi Chen Co-authored-by: Ruihang Lai commit 151701740fac3a53b35799a82c85d86f91b720ee Author: Tianqi Chen Date: Fri Dec 16 17:48:26 2022 -0500 Update relay_translator.py commit ad0f3179a84b3bc167f91c3eb082cb996b1d04e2 Author: Ruihang Lai Date: Fri Dec 16 17:37:00 2022 -0500 [Translator] Remove global symbol and follow-up fix for #262 (#316) This PR removes the `global_symbol` linkage added by Relay Translator. It also fixes unaddressed comments of #262. All tests can pass locally and I believe it is safe to merge this PR directly. commit 850deded1201001d833ac65991fb1a4c6509cb1b Author: Ruihang Lai Date: Fri Dec 16 16:19:48 2022 -0500 [Translator] Support translating op calls with Tuple input (#262) Previously, when a Relay function contains a Call which directly uses Tuples as arguments (the example below), ``` %25 = (%23, %24) /* ty=(Tensor[(1, 160), float32], Tensor[(1, 160), float32]) */; %26 = concatenate(%25, axis=-1) /* ty=Tensor[(1, 320), float32] */; ``` our Relay-translator is unable to generate corresponding CallTIR, because the translator always assumes a argument of a Call is mapped to a single tensor (see the code snippet below: the translator directly passes the Relax variable `new_args[-1]` to function `te_tensors`, which translate a Var to a single tensor). https://github.com/tlc-pack/relax/blob/60e9a01cdfdd013945790fc03d5abad29b8a7c0b/python/tvm/relax/testing/relay_translator.py#L124 https://github.com/tlc-pack/relax/blob/60e9a01cdfdd013945790fc03d5abad29b8a7c0b/src/relax/ir/emit_te.h#L56-L61 But in fact, the Relax variable may correspond to a Tuple of tensors, which wasn’t taken into consideration before. And such case can lead to error in `TETensor`, when creating tensors. Therefore, this PR fixes the issue by examine the Relax variable before the tensor creation of Relay Call arguments. If an argument has shape Tuple and type TupleType, we break down the tuple Variable and emit a TupleGetItem for each field, and meanwhile create a tensor for each field. commit 54a0ff551adb90937073675b4fb3d5439b814398 Author: Siyuan Feng Date: Fri Dec 16 21:02:13 2022 +0800 Remove relax parser_v1 (#313) commit b363dd48aced8fb939880db8cf595ed65b7ecc77 Author: Steven S. Lyubomirsky Date: Wed Dec 14 22:51:38 2022 -0500 [Debugging][Arch] Expose `shape_` fields for `TupleGetItem` and `If` nodes, fix AST printer accordingly (#311) * Make the shape of If and TupleGetItem nodes accessible in Python * Remove order-dependency from AST printer tests * Trailing whitespace commit 4bb01fe4eccdd59614cc264838a389b21dd40388 Author: Yuchen Jin Date: Wed Dec 14 08:11:47 2022 -0800 [IR] Dedicated Relax Call, Constant, Tuple, TupleGetItem, If (#306) * relax.Constant. * Add callnode; * Tuple, tuplegetitem, If * mypy. * lint * rebase & fix printer. * rebase & remove virtual_device_ * address comments & leave todos. * address comments. * address comments. * tuple index. * type anno. commit 4cda8a5881fd4cd2473258b35244fc4129b6110c Author: Steven S. Lyubomirsky Date: Wed Dec 14 09:09:03 2022 -0500 [BlockBuilder][Refactor] Normalize nested `SeqExpr`s (#310) Co-authored-by: Ruihang Lai commit 5aab150f322526c1a7bfe6cea0f4d7a7543a7f46 Author: Ruihang Lai Date: Tue Dec 13 17:06:06 2022 -0500 [ExprMutator] No prologue in VisitWithNewScope when input is SeqExpr (#305) commit 0bf1f1b784f19298117e36016a2e522f58c143fc Author: Tianqi Chen Date: Tue Dec 13 15:27:05 2022 -0500 [REFACTOR] Refactor BlockBuilder (#308) commit 28d598b6a7c55f95f8f9c2ccd5c860ba5451232d Author: Siyuan Feng Date: Sun Dec 11 01:28:56 2022 +0800 [Normalizer] Combine Nearby Blocks in SeqExprs (#298) commit e152c50e368454afab75425fcb0863b1c328bf4c Author: Tianqi Chen Date: Thu Dec 8 19:33:18 2022 -0500 [ARCH] Add VisitBinding second-level dispatcher in Expr type. (#301) commit fed6b8fc88b824ec68260417793447dbe524c4c3 Author: Yuchen Jin Date: Wed Dec 7 16:55:40 2022 -0800 [Linkage] Cleanup global_symbol attachment and linkage. (#300) * Cleanup global_symbol attachment and linkage. * lint * Add global_symbol to the main function in translation. commit e0907d4fd03af1731310647d3d0547bdff2cfaf6 Author: Tianqi Chen Date: Tue Dec 6 21:35:20 2022 -0500 [ARCH] Introduce NestedMsg to robustly handle nested-tuple analysis (#295) commit 2eb99975dc1b40b83db7dcbb96b748503dcb3319 Author: Siyuan Feng Date: Mon Dec 5 21:57:21 2022 +0800 [TVMScript] Update sccript printer to enable roundtrip tests (#291) commit f8ab9890e14c2533c401969ebf11dd591beff592 Author: Hongyi Jin <3231950289@qq.com> Date: Sun Nov 27 09:59:26 2022 -0500 [RUNTIME] Correctly handling export_module when exporting modules of different type (#13489) commit 9009840e654a9900009f7776a19e26f29b1e3f85 Author: Steven S. Lyubomirsky Date: Fri Dec 2 18:33:50 2022 -0500 [Debugging] Support PackedFuncType in the AST Printer (#289) commit bda0e42f05eaba657c40a850486e55c39924f3bf Author: Steven S. Lyubomirsky Date: Fri Dec 2 18:31:39 2022 -0500 [IR][Bugfix] Improvements to the normalizer and well-formed checker (#288) commit d5fe87b21546995c7a88905bd04b4e944d28a0f4 Author: Yong Wu Date: Thu Dec 1 20:00:38 2022 -0800 Enforce i64 index in ShapeExpr (#281) commit 9c9eb5585501a5da0f25ca38d7d3ac8269b6714c Author: Yuchen Jin Date: Thu Dec 1 11:00:47 2022 -0800 [Parser] Register memory operators to new parser. (#279) commit 28c3f68cc51d2c22936c5496debcb8c2de54040b Author: Yong Wu Date: Thu Dec 1 08:55:31 2022 -0800 [TVMScript] enable the closure test (#280) * [TVMScript] enable the closure tests. commit eb9d531b2565cdd000f46e5ecae2c45b9f589abe Author: Yuchen Jin Date: Thu Dec 1 05:47:05 2022 -0800 [Normalizer] Enforce all Expr have checked_type_ invariance after normalization. (#287) commit 43f81ddf4afc2f4fdb214c9f994e844f53126cdb Author: Steven S. Lyubomirsky Date: Mon Nov 21 19:25:43 2022 -0500 [Debugging][Bugfix] Debug printer improvements: Print `shape_` and `checked_type_` for all nodes and handle non-binding `MatchShape`s (#261) The initial AST printer only included the `shape_` and `checked_type_` fields for variables because of the potential for infinite recursion (`shape_` nodes can contain other expressions, which in turn have `shape_` nodes). This PR cuts off the potential recursion to allow for printing these fields for all Relax expressions, which should be more useful for debugging. This PR also fixes a bug: The AST printer previously did not handle `MatchShape` bindings that did not bind a new variable. commit 304048c33956dddb5027fec26541d57f903d8ca2 Author: YuchenJin Date: Thu Nov 17 17:02:11 2022 -0800 Fix after rebase, and reorganize the TVMScript folder structure. Co-authored-by: Junru Shao Co-authored-by: Siyuan Feng commit e7277460f0a2c7c980be9323cdf7919dc38153e2 Author: Siyuan Feng Date: Thu Nov 17 00:31:32 2022 +0800 [TVMScript] Switch to the new parser (#276) * [TVMScript] Support cross-function call for relax function This PR adds support for cross-function call for relax function, by declaring a function signature (i.e. an empty function that contains params and return type/shape but w/o body.) However, the PR meets the issue of block_builder shape deduction, which does not use function `ret_shape` to infer the shape of GlobalVar Calls. commit 7152175762613130e3ba647c77cc9818312a5b06 Author: Yuchen Jin Date: Sat Nov 5 16:45:33 2022 -0500 [CI] Enable Mypy type checking for Relax; Fix typing errors to pass Mypy checking. (#270) commit 6f8f6da505b835345d7709d06bdfd8dddce7e85b Author: Lesheng Jin <34279105+LeshengJin@users.noreply.github.com> Date: Thu Nov 3 08:16:35 2022 -0700 Introduce memory primitives (#255) Introduce the memory primitives, including `relax.memory.{alloc_storage, alloc_tensor, kill_storage, kill_tensor}`. commit 48b7c158cc01532f9019a2e615f2d94766a9464c Author: Siyuan Feng Date: Thu Oct 20 08:30:47 2022 +0800 [TVMScript] Update Type Annotation Behavior of the Parser (#269) This commit changes the behavior of the parser to allow type annotations, as suggested by the community. The current behavior: - Use the more refined type/shape between user annotated and deduced type/shape. The updated behavior: - Always use user annotations - Only checks if the type/shape is valid. commit 5c3079bb6e1e4eeb4dc2d9b740facb2686c67519 Author: sung Date: Mon Oct 17 19:07:01 2022 -0700 Reenable autotvm silencer; fix e2e_auto_tir.py; fix lint. Co-authored-by: YuchenJin commit 85b81292626ab6f23caf2b61095a6f957b61b21c Author: sung Date: Mon Oct 17 18:09:34 2022 -0700 Recover: [Bugfix] Couple of bug fixes to run TVM-gen code together with BYOC (#249) commit c46ae8566582f1fcd8fcda1479943d3abb95b3b0 Author: sung Date: Mon Oct 17 17:16:01 2022 -0700 Recover: [Pass] Separate ApplyHistoryBest from tuning passes (#226) commit 83bc7cb144643d5823bf06220186528923835667 Author: Junru Shao Date: Sun Oct 16 22:52:56 2022 -0700 Enable Hexagon tests commit f9f4f7904ec5468a725b2ba924a619a7c5ed4e43 Author: Junru Shao Date: Sat Oct 15 15:25:56 2022 -0700 Recover dropped commits [TVMScript] B4: If branch support (#263) B8: Local Function Support (#258) [TVMScript] B3: Type annotation checks (#256) [TVMScript][Parser] B1: Dataflow block (#252) [TVMScript] B2: match shape support (#251) [TVMScript] B6/B7: Symbolic shape and var shadowing (#245) [TVMScript] B5: Support relax op (#244) [TVMScript] B0: Call_tir support (#243) enhance parser error reporting (#242) [TVMScript] A1: Relax Parser infra (#240) update ci image versions. (#241) [TVMScript] B2-4: TIR IRBuilder (#239) [TVMScript] A0: Relax IRBuilder infra (#235) [TVMScript] B5-6: TIR IRBuilder (#231) [TVMScript] B1: IRBuilder (#228) [TVMScript] New Parser: Part C (#218) [TVMScript] New Parser: Part A (#221) [TVMScript] New Parser: Part B (#217) Not recovered: [Pass] Separate ApplyHistoryBest from tuning passes (#226) [Bugfix] Couple of bug fixes to run TVM-gen code together with BYOC (#249) co-authored-by: Yuchen Jin co-authored-by: Siyuan Feng co-authored-by: Ruihang Lai commit 65a53034bc0bee9877a1bdf363c2eadcde35f226 Author: Steven S. Lyubomirsky Date: Thu Oct 13 23:06:55 2022 -0400 [Op][Debugging] Add `assert` operator (#260) It was brought up that Relay lacks an assert operator, so we may as well have one in Relax for debugging. One issue is that we can't name it "`assert`" because Python will treat it as a syntax error to have it as a field name for the "`relax`" module, i.e., `relax.assert` is a syntax error. Thus the op is named "`assert_op`," which is not ideal but serves its purpose. commit 71d96e6c0a314936fa49fd7bc1ea79069027ab12 Author: Yuchen Jin Date: Wed Oct 12 05:07:33 2022 -0700 [Pass] Support Function and If in Normalize pass. (#268) * Support Function and If in Normalize pass. * Use structural equality for expr_memo_. * Change back to pointer equality for expr_memo_; Add more tests. * rebase. commit 312a344cdeec66b1330a80d34ca78556fb338e7c Author: Steven S. Lyubomirsky Date: Tue Oct 11 18:25:29 2022 -0400 [Analysis] Expose analyses related to vars in Python (#265) Previously, analyses to gather up all variables, free variables, bound variables, all global variables, and all global variables that are called had been implemented in C++ but had not been exposed in Python or tested. This PR exposes these analyses and adds tests for them. Two further changes: * The analyses previously ignored variables bound in `MatchShape` nodes; these are now treated as bindings too. * `rec_global_vars` is renamed `called_global_vars`, since the analysis itself does not check recursion. commit 132702be7e7ed0256045d7a405e532c3d5beef6d Author: Steven S. Lyubomirsky Date: Mon Oct 10 18:19:38 2022 -0400 [Expr] Allow annotating return shape on function nodes (#253) This PR adds a `ret_shape` field for specifying the shape of the function's return value. At present, we will not use this information, but by adding it into the AST, we will be able to parse the return shape and use it in the future. Parser V1 in this PR will just always list the `ret_shape` as `RuntimeDepShape`. commit 7276c9e2ee13a4754775491ca36a7aae2d55b827 Author: Steven S. Lyubomirsky Date: Sat Sep 24 00:11:45 2022 -0400 [Bugfix][VM] Properly convert tensor inputs in `save_function` (#257) It was observed that closures saved using `save_function` would crash when used over RPC with the `time_evaluator`, whereas using `set_input` and `invoke_stateful` worked as normal. While I am not entirely sure why these failures happened over RPC only in `time_evaluator` (but not in other RPC trials), it became clear that `set_input` performs a conversion of input tensor values in `SetInputTensorWithIndex`, while `save_function` was not doing this. Adding this conversion fixed the observed bug. commit 7183c7ffbe896dd9b5f5742b62afe9c821dae682 Author: Josh Fromm Date: Wed Sep 21 17:07:08 2022 -0700 [Call TIR] Fix bug when invoking call_tir with scalar values. (#254) This small PR changes a check in the tvmscript parser to support empty shape tuples which are used to represent scalars. I added a scalar addition test to make sure it works properly. commit 605ba8d1548efb90980f9b18ea94f1d53f9ec3ec Author: Steven S. Lyubomirsky Date: Wed Sep 14 17:27:03 2022 -0400 [Bugfix][Op] Register attributes for unique and print (#248) Attempting to use `dump_ast` on functions containing the operators `relax.unique` and `relax.print` previously crashed due to being unable to query their attributes' keys. It turned out that this was a problem with the operator attributes: They had not been registered on the Python side, so Python representation treated them as opaque TVM objects. This PR corrects this mistake. commit f4525dd8a3e61f572b50107555cef4b469c971f4 Author: Steven S. Lyubomirsky Date: Wed Sep 14 17:24:40 2022 -0400 [VM][Benchmarking] Add option for saving e2e results as CSV file (#247) This PR makes some small additions to the end-to-end AutoTIR script, namely eliminating a bug (it was incorrectly using the stateful API) and adding an option to save the test results as a CSV file for benchmarking purposes (the data can then be separately analyzed as needed). These changes also required a small extension to the save_function method in the VM, namely allowing it to take keyword arguments. commit f1ee4b6cd2c3ee0596cef6f5b7ff7e715fb4ae0d Author: Ruihang Lai Date: Wed Sep 14 17:23:29 2022 -0400 [BugFix] Enable emit global MatchShape (#246) Fix an incorrect check which disables emitting global MatchShape outside a dataflow block and mistakenly enables emitting dataflow MatchShape outside a dataflow block. commit 0a7a0a9daf5f1a2fa06ee6cd6169a28d397821fa Author: Steven S. Lyubomirsky Date: Thu Sep 8 09:49:05 2022 -0400 [Pass] Canonicalizing Bindings (#233) It may be useful for some passes to collapse chains of definitions, particularly after other compiler transformations that may reduce or simplify some expressions. This pass will take chains of definitions and replace references to later definitions to the original one. It works by checking `LookupBinding` for each var use-site and replacing the var with its definition if the definition was another var. (Note: This required updating `BlockBuilder` to also update its binding map for `MatchShape` nodes; that was arguably a bug.) Additionally, `MatchShape` bindings where the `LHS` and the `RHS` are guaranteed to match at compile time are canonicalized into ordinary `VarBinding`s. commit 7a6f91f7d4077eebf926aa1f19281404494b9362 Author: Prakalp Srivastava Date: Thu Sep 1 07:02:57 2022 -0400 [Hexgaon] Use uploaded path to load module. (#238) * Fixes a bug to use the uploaded file remote path for loading the module remotely. * Modifies the task_python_hexagon.sh script to only run passing test on device. This is used by Jenkins CI. commit e50290140c204ae091e335b797a07f2f6567a163 Author: Lesheng Jin <34279105+LeshengJin@users.noreply.github.com> Date: Thu Aug 18 21:51:35 2022 -0700 [Pass] New Python ExprVisitor/ExprMutator! (#190) Add decorators `visitor` and `mutator` to help users create `ExprVisitor` and `ExprMutator` in Python. Users can customize visit/rewrite/post-order-rewrite function in Python. `PyExprVisitor` and `PyExprMutator` lists the functions users can customize. commit 7313855476cc522bf3e8bdbe7a60b82cd725fe4c Author: Ruihang Lai Date: Thu Aug 18 15:20:06 2022 -0400 [BugFix] Expose `relax.expr.Constant` to `relax.Constant` (#230) commit cdfd4e939f2d1e88c560a05d83ddf2f7afe70304 Author: Siyuan Feng Date: Thu Aug 18 02:25:13 2022 +0800 [FIX] Fix windows build issue when allocating a dynamic array (#219) In the current codebase, kNumArgs is a runtime-dependent variable (i.e. its value depends on the input shape of Array). Allocating arrays with runtime values is not allowed during building on Windows (I'm surprised it can be compiled on Linux and macOS) commit 887762cd97686ae23a61609ca9ffc8d6a2c5178b Author: Yong Wu Date: Mon Aug 15 08:00:31 2022 +0800 Update with rebase commit 5a23346bc437043b48866411e39dfcf066edda59 Author: Yuchen Jin Date: Sun Aug 14 14:44:12 2022 -0700 [Bugfix][VM] Fix var binding to a ConstantNode; Force VM if.cond register to take an NDArray instead of POD. (#216) Fix the bug in #212. The cause of this bug is VM Codegen did not handle binding ConstantNode to variable (`x = relax.const([1, 2])`) and save the constant NDArray to the register. Previously the codegen only handles the case where ConstantNode as CallNode's arguments. Now it's fixed and unit test is added. Fix the bug in https://github.com/tlc-pack/relax/issues/214#issuecomment-1211411432, the issue was caused by the VM simply read the condition register of the If instruction, and expect it to be a POD int or bool. https://github.com/tlc-pack/relax/commit/811e877c289fa52f55886c8a3e8dce10ed84915f adds a `LoadScalarInt` function similar to the Relay VM to check the If.cond register stores an NDArray, and cast it to int_64. Since we haven't introduced PrimValue and PrimType (that represents POD values like int and bool) to the Relax language yet, let's enforce `If->cond` to be a Tensor (NDArray at runtime). commit 6c9d403503297a0d0e28318bafcba9fc9c99ae42 Author: Steven S. Lyubomirsky Date: Fri Aug 12 13:53:28 2022 -0400 [VM][UX] Allow for saving closures to avoid extra dictionary lookups in timing trials (#208) This PR implements a function that allows for saving a `PackedFunc` in the VM's module that just calls an existing function with a specific set of arguments to address #179 and #178. The main use of this is for timing, to avoid some overhead in looking up functions. commit e172b40af31dc3384adbcf6e7b0bce7f31ce41ea Author: Jiawei Liu Date: Thu Aug 11 19:55:57 2022 -0500 [Pass][UX] Statement rewriter for DataflowBlock (#210) - Implements a few APIs to quickly perform statement-level mutation: `add`/`remove_unused`/`remove_all_unused`/`replace_all_uses`. - Implemented `remove_all_unused` to remove dead statements inside `DataflowBlock` cc: @psrivas2 - Address minor issues (unnecessary headers and bad docstrings) in https://github.com/tlc-pack/relax/pull/163 commit 37791e0a5d4a495365fd647f2cecbed16f3a3785 Author: Jiawei Liu Date: Thu Aug 11 13:50:56 2022 -0500 Clean warning messages by Clang and Pylint (#215) * refact: clean clang warning in relax * refact: fix pylint * fix cpplint and clangd suggestions * fix: no cpplint on virtual-override commit 0b00715dc634aa7f091e942a54a29ee9c802ccf9 Author: Steven S. Lyubomirsky Date: Wed Aug 10 11:47:37 2022 -0400 [VM][UX] Implement stateful API (#207) This PR implements the stateful API discussed in https://github.com/tlc-pack/relax/issues/179. It ensures that if you use `set_input` to set inputs, you must use `invoke_stateful` to run the function (otherwise failing) and must obtain the results using `get_output`. It handles nested tuple returns. commit ed7b77e040654582d1ab1b9535ebbc4da77da243 Author: Steven S. Lyubomirsky Date: Tue Aug 9 17:07:52 2022 -0400 [Op][Debugging] Add a print operator (#201) * Attempt at adding a print operator * Fix the registration * Actually use the format string * Improve test * Fix comment placement * Improve the docstring for relax_print * Handle tuples too * Formatting :( * Correct commit message * Match attr name across Python and C++ * Make print variadic commit a9bd3053c1106d1926fce1dc5787fc8be27f3985 Author: Sunghyun Park <49998730+sunggg@users.noreply.github.com> Date: Fri Aug 5 11:45:03 2022 -0400 [Pass] Implement legacy lowering pass that leverages relay op strategy (#189) This PR implements Relax Op lowering that leverages existing Relay Op Strategy (legacy). As ops like conv2d, matmul are relay-, relax- independent, this pass assumes that we can always find relay op equivalents for such relax ops and use their info to leverage the relay op strategy. commit 1a1bcf75d97b2e7e4f758b6cd08bd747b222ef36 Author: Sunghyun Park <49998730+sunggg@users.noreply.github.com> Date: Thu Aug 4 17:56:17 2022 -0400 [Pass] Introduce metaschedule as a tuning pass (#188) This PR delivers MetaSchedule tuning as a tuning passes. We can either tune at IRModule level with relax.transform.MetaScheduleTuneIRMod or tune at primfunc level with relax.transform.MetaScheduleTuneTIR. commit 7144654633477ea0d2bff300ba753dc8bfdeae4d Author: Steven S. Lyubomirsky Date: Thu Aug 4 14:34:10 2022 -0400 [Example][UX] Make the RPC timeout configurable in the `e2e_auto_tir` example (#186) Running the e2e_auto_tir example over RPC can run into issues due to timeouts because some models can take a long time to run on some machines. This PR makes the RPC timeout configurable to more easily address these issues. commit 81e565e5df90cfe12d22deb7b26845ea3aa13526 Author: Tianqi Chen Date: Wed Aug 3 19:38:21 2022 -0400 Fix BlockBuilder Scope Recovery in Misuse (#199) This happens in interactive usecases. When function scope exit triggers an error, we need to recovery the BlockBuilder.current properly so users can try again. commit 21b1e7dc35dc838214cd4b6f26fbc31492323b02 Author: Steven S. Lyubomirsky Date: Wed Aug 3 19:09:21 2022 -0400 [Testing][AST] Add a simple AST printer for debugging (#198) * Add ast printer * Print seq expr body * Match annotation field names to real AST * Handle call attrs and func ret types * Add more advanced test cases commit 89f55c8167a80b4b9c8751309b5db648fb4db047 Author: Jiawei Liu Date: Wed Aug 3 09:59:47 2022 -0500 [UX] Adopt changes from tvm-main and render code with IPython.display (#192) Render code with IPython.display.HTML if possible to fix the ansi-escape 24-bit rendering issue in Colab. commit 0b52b558eb14b3f113a4b543c8f0a824baaa58bc Author: Jiawei Liu Date: Mon Aug 1 11:59:24 2022 -0500 Dataflow Pattern Lang: Core Matching Features (#163) The structure is similar to the Relay's pattern matcher (https://github.com/apache/tvm/pull/5231). The main difference is that those pattern types are adopted to be relax-compatible. Relay pattern types, some less used patterns (IfPattern) and df-topological patterns (DominatorPattern) are ignored (some of them will be brought later). The implementation splits patterns into two parts: - **Match an Expression**: match an expression syntactically (`MatchExprPattern`, i.e., `DFPatternMatcher`); - **Match a Graph**: match a graph (cross multiple `VarBinding`) topologically (`MatchGraphPattern`); commit 74371634e9a011e63650b734aba20546b016c524 Author: Jiawei Liu Date: Tue Jul 26 20:06:25 2022 -0500 [UX] Highlight TVMScript with Pygments (#185) commit 15e54ef215950944ffd74858c12c30aabcb0dcce Author: Siyuan Feng Date: Sat Jul 23 11:22:13 2022 +0800 [Pass] Enhance BindParams to take numpy dict as input (#184) commit cf2e3b97110c805597059c5ba8303a653417e080 Author: Steven S. Lyubomirsky Date: Mon Jul 18 21:45:21 2022 -0400 [Bugfix][VM] Ensure set_input works over RPC by not returning an array of argument names (#183) Currently, attempting to use the VM's `set_input` method will fail over RPC because `set_input` calls `get_func_param_names`, which returns an array of parameter names. RPC does not support sending arrays. This PR corrects this issue by instead having `set_input` query the function arity and then query the argument names one by one, which is the approach taken by the Relay VM (accordingly, the names for the functions used to do this, `get_function_arity` and `get_function_param_name`, are taken from the Relay VM). This PR also adds a unit test over RPC on localhost. commit b0e57dbc0862499c3f2a7d91858354c41fcf5e95 Author: Yong Wu Date: Fri Jul 15 11:50:29 2022 -0700 Fix after rebase commit 3494b7a47bf0f7c3219538b2e9064b825cf3258c Author: Sunghyun Park <49998730+sunggg@users.noreply.github.com> Date: Mon Jul 18 00:38:41 2022 -0400 [Pass Infra] Tuning API serialization and database support (#168) * refactor tuning API to support serialization of Choice, Knob, Trace * Implement tuning api JSON database * Add comments * fix pylint * fix cpplint * reflect feedback * add minor comment for the future work commit 777549a6037cc97b698f53ed629cf65c33ae7eca Author: Siyuan Feng Date: Mon Jul 18 00:05:14 2022 +0800 [Fix] fix windows build issue (#182) TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS is needed when we have a default-like constructor (e.g. (Span span = Span())) commit b81e6a9838f92ba412a0bd4951a46cc61a43a22d Author: Siyuan Feng Date: Mon Jul 18 00:04:03 2022 +0800 fix print twice issue (#181) commit d4cc79ed664bbe34a4d9dab2923cd5a7a7c5b52c Author: Lesheng Jin <34279105+LeshengJin@users.noreply.github.com> Date: Thu Jul 14 09:15:44 2022 -0700 [Pass] Python ExprMutatorBase/ExprMutator (#172) - Rewrite ExprFunctor in Python. New ExprMutatorBase and ExprMutator in Python. - Implement demo passes: RewriteFMA and FuseFMA with Python ExprMutator. - Expose some functions to ffi in block_builder.py commit 01cdc4d43258b1fb9dcc630f05f38f792e3bc513 Author: Prakalp Srivastava Date: Tue Jul 12 19:25:51 2022 -0400 [VM] Deprecate API to save/load executable to file (#176) Executable `save_to_file` and `load_exec_from_file` API was used to save/load just the executable to/from file. This was confusing as it did not export the TensorIR kernels in the Relax Module, thus leading to bugs such as https://github.com/tlc-pack/relax/issues/175. Moreover, the API was only used in some tests, and not useful for end user. Deprecating this API to have a single uniform way of serializing/deserializing TVM IRModule using `export_library` and `tvm.runtime.load_module` API. commit 74b3d67e8ae74aed3446a5ae5a05b8f5586e2c3b Author: Yuchen Jin Date: Fri Jul 1 09:31:30 2022 -0700 [Refactor] Generic dispatching for `IsBaseOf`; Simplify Type/Expr initializations; `relax` -> `R` in printer; Disallow local function in VMCodegen (#171) - Generic dispatching for `IsBaseOf`: `IsBaseOf` uses a bunch of if-else to check if the subtype relation between the base type and derived type, now it's changed to use a generic TypeFunctor to dispatch on the base class to do the check. - Simplify Type/Expr initializations: We had to write `RuntimeDepShape(Span()`), `ObjectType(Span())` to initialize several Types and Exprs, this is due to the `TVM_DEFINE_OBJECT_REF_METHODS` macro that sets the constructor with `= default`. By changing to use `TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS`, we can now just write `RuntimeDepShape()` without specifying an empty span. - `relax` -> `R` in printer: Change to print `R` rather than `relax` in TVMScript as the default behavior. This is consistent with our test cases and TIR convention: using `T` as shorthand. - Disallow generating code for local function in VMCodegen: these local functions should have been lifted in the lambda lifting pass before codegen. commit 8fdc3ba3eae0d1ffc535e240be251aaae5546eb8 Author: Prakalp Srivastava Date: Thu Jun 30 15:14:40 2022 -0700 [Parser] Enable R.parser.pretty_print to print TIR PrimFunc (#174) This way we can have a uniform API to print IRModule, TensorIR function and Relax functions. commit ed0414540c9fbc063aa727cfc71bdee51a4bafdd Author: Prakalp Srivastava Date: Wed Jun 29 08:20:17 2022 -0700 Update tests to use `set_input` for rpc calls. (#173) Fix relax-hexagon tests to use set_input api, which is the correct way to invoke a function over RPC. commit 1f962bda7a79d13fee1a4f9f4ad3ddde4f5467b2 Author: Sunghyun Park <49998730+sunggg@users.noreply.github.com> Date: Tue Jun 28 20:49:33 2022 -0400 [BYOC][PASS] Prototype implementation of modular compilation w/ TensorRT (#164) This PR delivers the prototype of the followings: - Relax BYOC JSON codegen - Relax BYOC TensorRT codegen - Extension in Relax VM to support external modules - `RunCodegen` pass: run codegen for the annotated relax functions - Annotation (dispatch decision) will be done by earlier passes e.g., greedy heuristic, Collage - The generated runtime module and Codegen itself should be tvm object - Misc minor code improvement for other passes commit f25fe0c80670272582db3aa791901c7fa49fc59e Author: Prakalp Srivastava Date: Tue Jun 28 12:47:07 2022 -0700 Run static/dynamic models over Hexagon using Relax VM RPC (#167) * Move Relax VM builtins to src/runtime. * This fixes a bug we encountered while loading the module for hexagon. Since it was building the minimal runtime it was missing definition of Relax VM builtins. * Mark Hexagon module as DSO exportable. * Load Relax VM Executable over RPC * Support allocation for shape heap on device Co-authored-by: Yuchen Jin commit 25174be634b5e04f0468b48bd477f22b17e75f84 Author: Prakalp Srivastava Date: Fri Jun 24 13:33:04 2022 -0700 [CI] Enable Hexagon CI in Jenkins. (#169) Running all Hexagon tests in simulator is very slow. So we only run Relax related hexagon tests `test_relax_integration.py`. This test file is empty right now and it would be populated as we push relax-hexagon related changes. commit 225aecdb5d7d33f2af048f3aef9c9a6ac758f4fd Author: Yuchen Jin Date: Thu Jun 23 09:47:30 2022 -0700 [VM] Add set_input interface; Fix e2e tuning script. (#166) * Add set_input interface. * Address comments. commit 29a707cbd9be6e02dd8a3cd1961cfb53057eb51b Author: Lesheng Jin <34279105+LeshengJin@users.noreply.github.com> Date: Thu Jun 16 09:07:45 2022 -0700 WellFormed Instrument (#165) * add conftest for test/python/relax * [Wellformed Check]: allow TupleType as Function parameters * move WellFromedInstrument to relax.ir.instrument * add header commit b4c3c4bb65b09db7c9b3ec114d6680d14f306d37 Author: Yong Wu Date: Sat Jun 11 23:26:17 2022 -0700 Update after rebase commit 3c0e3c0ee08c78b17cc1ba0429727c199737403a Author: Yuchen Jin Date: Sat Jun 11 18:42:29 2022 -0700 [Relay translator] Allow replacing default topi function with user-provided TIR PrimFunc. (#159) * Add replace_op_with_tir to translator. * came up with a better name * better doc. commit f250f93eed886dc2c3a1cb1f8a4ab2077c57080e Author: Yong Wu Date: Sat Jun 11 15:20:21 2022 -0700 [Pass] Lambda Lifting (#99) commit b55fd31d4e11373b30a93f88412a3d6e2d21d3c1 Author: Siyuan Feng Date: Tue Jun 7 10:07:17 2022 +0800 [E2E] End-to-End tuning e2e_script (#153) Co-authored-by: Ruihang Lai Co-authored-by: Hongyi Jin <3231950289@qq.com> commit d3f94e73ec7b9c9ac7b3675f962e9030e55fa603 Author: Prakalp Srivastava Date: Thu Jun 2 08:19:18 2022 -0700 Fix shape lowering pass bug for non i64 dims. (#152) Prior to this change, VM Shape Lowering pass did not cast integer values to shape heap dtype (i64) which resulted in incorrect values when read from heap later. This PR adds a cast to i64 for such values. This also adds well-formed check to ensure shape dimensions are of integer types. commit 9cf777f48069d598eda276be0b9aabaf301acf0f Author: Yong Wu Date: Wed Jun 1 17:52:40 2022 -0700 [Parser] Add FuncType support (#154) * [Parser] Add FuncType support * Address comments commit f99121d506df45870cd026e052f5b3c41d4bd982 Author: Sunghyun Park <49998730+sunggg@users.noreply.github.com> Date: Wed Jun 1 09:01:40 2022 -0700 [PASS] Remove Unused Functions in IRModule (#151) commit a718e9f9e073ca0ea1790562254c09aaa863eaa4 Author: Sunghyun Park <49998730+sunggg@users.noreply.github.com> Date: Tue May 31 15:15:28 2022 -0700 [Pass Infra] Tuning Pass API (#144) commit a485b7bdb45f8379daa45e8c923a47fd6871cbdf Author: Tianqi Chen Date: Sun May 29 12:51:07 2022 -0400 [REFACTOR] Move TIR op kind analysis to relax as it is relax oriented (#155) This also keep TIR mostly independent from higher-level IR. commit abd20bdc9b87aa53e0c27e8c5c3fc195be5e8c91 Author: Siyuan Feng Date: Sun May 29 23:31:05 2022 +0800 add test cases for FuseTIR (#156) commit de42ec3d5ae0f0304060460764619a5a16995a33 Author: Siyuan Feng Date: Thu May 26 22:14:51 2022 +0800 [Pass] Relax Transform FuseTIR (#150) * [Pass] Relax Transform FuseTIR Co-authored-by: Hongyi Jin <3231950289@qq.com> Co-authored-by: Ruihang Lai commit 153d0cc8f2d39b23e63fcd6feaf9755a0eaf8c28 Author: Yuchen Jin Date: Wed May 25 15:44:59 2022 -0700 [Mutator] Separate unnormalized-form and normal-form mutators (#148) commit dfa42c09a3087605e805526ab7db7b49d6752ca5 Author: Prakalp Srivastava Date: Fri May 20 16:30:18 2022 -0700 Print/parse tir cast/max operations in Relax shape (#149) tir.cast and tir.max are commonly used operators in shape expression in Relax. These two operators often show up when importing Relay module with `Any` dims to Relax module. commit c7186fd44ad5865d84ac61fc2981a15c8af9be4c Author: Prakalp Srivastava Date: Thu May 19 18:29:12 2022 -0700 Add support to import relay models with Any dim. (#146) Converts Relay Any dimension to symbolic dim in Relax. commit ef9cf6baba1c2f7215746459ad5a9193df6572c9 Author: Yuchen Jin Date: Tue May 17 07:55:56 2022 -0700 Refactor shape lowering pass and Blockbuilder. (#145) commit 230def2284c21eaff520e58fa96a80313b6a7c8f Author: Yong Wu Date: Fri May 13 14:30:05 2022 -0700 Support Closure (#140) commit 0e998988aabdeb8d913e2889eb5a9d72bee35ca2 Author: Lesheng Jin <34279105+LeshengJin@users.noreply.github.com> Date: Thu May 12 17:13:15 2022 -0700 [Analysis] IRModule well-formed check (#142) commit 1bd4e685ffcc0c4b677af47ecc8609dbfacdfd9d Author: Yong Wu Date: Wed May 11 09:31:13 2022 -0700 Change after rebase commit d0ad35b375449c7e067a1edada7502557a03dd26 Author: Siyuan Feng Date: Tue May 10 08:44:22 2022 +0800 FuseOps for relax (#141) Co-authored-by: Ruihang Lai Co-authored-by: Hongyi Jin <3231950289@qq.com> commit ae7b5b79c40498203842b6c9193e91bcc1937bea Author: Prakalp Srivastava Date: Wed May 4 20:52:16 2022 -0700 Add `relax.unique` operator in Relax. (#135) * Add Unique operator in Relax. This adds the functionality to register a packed function implementation of any operator using `FCallPacked` attribute. The relax operator would be lowered to a call to the registered packed function during codegen. For example, in this change relax.unique is lowered to `relax.run.unique` packed function which uses torch.unique under the hood. * Add support for integer constants in Relax VM. This adds serialization, deserialization, and print support for integer constants. commit 1ca18611ae59ab4d1667066ed9921690d2a5611c Author: Siyuan Feng Date: Tue May 3 09:34:55 2022 +0800 Add ShapeType to ShapeExpr.checked_type during construction (#139) commit 6481d533ed259a080dede704f7443c4a2221a842 Author: Sunghyun Park <49998730+sunggg@users.noreply.github.com> Date: Mon May 2 16:26:08 2022 -0700 Introduce Relax function attribute and drop name field in Relax function (#136) commit d735ebd719d89c804691b29ee0d881c785384fc6 Author: Yuchen Jin Date: Sat Apr 30 18:45:14 2022 -0700 [BlockBuilder] Sub function call shape deduction: constant shape case. (#137) commit 10f8e56cbcb27beb373075e3c6e3a9728ffb5eb2 Author: Yuchen Jin Date: Thu Apr 28 16:59:38 2022 -0700 [AST][Type] Introduce ObjectType; Infer the type of call_packed by type_args; Refactor InferType/InferShape. (#132) commit 7e2038a8b662659dd6ba2e2a86bedbc6c3891bfa Author: Yuchen Jin Date: Mon Apr 25 17:20:19 2022 -0700 [AST][BlockBuilder] Normalize relax.Function; Refactor BlockBuilder to take optional input IRModule. (#133) commit f1eca6d74365c6b0665b64c86ececce86fd76df3 Author: Prakalp Srivastava Date: Sun Apr 24 07:09:11 2022 -0700 [Printer][Parser] Modify Tensor annotation printing and parsing. (#128) commit 296876eaf1246ea7948c69d2111cfea2ca51ca0c Author: Lesheng Jin <34279105+LeshengJin@users.noreply.github.com> Date: Fri Apr 22 08:05:13 2022 -0700 [Pass] Python pass decorator and ExprFunctor (#126) * Relax ExprFunctor in Python * fix the register bug * Expr_functor in relax * function/dataflowblock Pass in python * testcases * reformat * fix Tensor annotation() * add return type hint * type hint * new test * fix typo * remove memo commit 5199a206cc86cee9e43b0c8ddddf704acdc4b513 Author: Ruihang Lai Date: Thu Apr 21 22:20:33 2022 +0800 [Relax][MS] Task extraction with proper weights (#129) * [Relax][MS] Task extraction with proper weights (hzfengsy#32) * Add a unit test * Update the deduplication mapping / Update the unit test * Update test for DummyDB reusing * Remove unnecessary args * Remove unused import commit badee2add6700f12671d3223e43875ca050f537a Author: Sunghyun Park <49998730+sunggg@users.noreply.github.com> Date: Wed Apr 20 17:09:37 2022 -0700 [Relay Translator] Use OpStrategy for lowering (#130) * [Relay Translator] Use OpStrategy for lowering * Reflect feedback and fix lint issue * Consider contexts for PassContext, Target, .. for both pass application and lowering commit 4454563d240c547fb762cec770502b1e09b195f0 Author: Prakalp Srivastava Date: Wed Apr 13 21:00:54 2022 -0700 Deprecate `[]` in favor `()` in Tensor annotation. (#123) commit fab2d95697f7eecce90cb0ba12db2457caf4f2e3 Author: Yong Wu Date: Tue Apr 12 21:15:38 2022 -0700 Add tune_relax to integrate with task scheduler (#127) commit 39bab0d25f3e5bb48adf52534f2318149047f617 Author: Yong Wu Date: Tue Apr 12 16:22:33 2022 -0700 Update autotir integration after rebase commit caae30f06d237c3aebd00290802122bbfdb2ae26 Author: Yuchen Jin Date: Tue Apr 12 08:23:32 2022 -0700 [VM] Support sub function call and recursion. (#125) * Sub function call and recursion. * Address comment. commit e7c7c15972f6aa29f30a167a794db17f74a6bdeb Author: Ruihang Lai Date: Tue Apr 12 14:18:32 2022 +0800 [VM] Copy constant tensors to device (#124) * [VM] Copy constants to device (Hzfengsy#24) * [VM] Copy constants to device * Add unit tests * Specify shape and dtype for constant TE tensors in EmitTE commit ef0a3e689b3896fd30a392d094beaa8d68b6de07 Author: Lesheng Jin <34279105+LeshengJin@users.noreply.github.com> Date: Wed Apr 6 11:59:33 2022 -0700 DataflowBlockPass (#114) * add DataflowBlockPass * update fma_rewrite * drop the skip function * update test_fma_rewrite with DataflowBlockPass * fix the format * fix name * rewrite test in tvm script * add non-dataflow Vars check * add fail testcases * module->IRModule * add docstring to DataflowBlockNode * remove unused pattern * Transform Pass->DataflowBlock Pass * rename global var to global scope var * remove print stmt * reformat tests * add docstring to DataflowBlockMutator * fix filename * minor fix commit 2607f3b9112197045e773b0fc7ceb9ae57e844f8 Author: Yuchen Jin Date: Mon Apr 4 19:59:30 2022 -0700 Remove type annotation from Var. (#121) commit 969ffb4302f35344524ef36e74325c0d5e427b76 Author: Prakalp Srivastava Date: Mon Apr 4 08:33:43 2022 -0700 Add a new Expr to represent runtime dependent shapes. (#117) This can be used to represent runtime dependent shapes such as output of `unique` operator. Having explicit runtime dependent shape expression helps to distinguish the following two cases in AST - (1) shape has not been deduced (`shape_ = nullptr`), and (2) shape is runtime dependent. Previously both cases were mapped to `shape_ = nullptr`. commit 1e2a11f6326c9b3fd3807bbe5d97e4a20ce9dadd Author: Hongyi Jin <3231950289@qq.com> Date: Sun Apr 3 00:42:38 2022 +0800 [PASS] Fold constant & Bind Params (#113) * fold constant and bind params * fix test * format * format * format * address comments * format * address comment * address comment * format * fix type bug commit d441f1d0f2104b51287f9f29d9ec9f0e87f4b9d9 Author: Tianqi Chen Date: Sat Apr 2 00:00:19 2022 -0400 Temporary remove function type deduction in normalizer. (#119) * Temporary remove function type deduction in normalizer. This PR temporary removes the function type deduction in normalizer to unblock some of the followup passes that needs to check function type equality. Function's checked_type_ are left as nullptr for now. We should followup to add function type deduction from annotations. * revert the normalizer skip for now * comment out parser assert for now commit 159f599248e3c6faf969198d4e7cf03c4f3f6c70 Author: Yuchen Jin Date: Fri Apr 1 09:18:33 2022 -0700 [BlockBuilder] Deduce and fill shape/type for Expr in Normalize. (#116) commit 96c8bbc53286a0ca90ddcb92346156f23ab9efe3 Author: Yuchen Jin Date: Wed Mar 30 11:46:50 2022 -0700 [CI] Enable GPU tests; Add AutoTIR cuda test. (#115) * Add gpu ci. * Update autotir gpu test. commit 1e5c2dac7b01f73c7e3e1a8b092eb0f2b6cc5e28 Author: Tianqi Chen Date: Mon Mar 28 19:12:59 2022 -0400 [FIX] Fix structure equal hash for MatchShape (#112) The pattern field of the match shape can define variables, as a result, we need to add DefEqual and Hash here. Added a regression testcase. Lesson: we would benefit from more testcases with check_save_roundtrip checks(like this one) for more relax example. Additional change: - Redirected TVMScript printer to be able to print relax fragements useful for debugging. commit 8e466be1d1fa65b9df119e0563ef58c38e8562f2 Author: Siyuan Feng Date: Tue Mar 29 01:30:07 2022 +0800 introduce blockbuilder call_te (#110) commit 6ff1614ac3c9e63ea5b615a072a1d26a197b58f9 Author: Siyuan Feng Date: Sun Mar 27 00:02:53 2022 +0800 [FIX] fix structural_equal_hash (#107) * fix structural_equal_hash (cherry picked from commit e7e962634999739a32129378f61cc95f58335447) * address comment & pass the ci commit 31ed53c92192c74a3f55009e718b8ae0527ce078 Author: Yuchen Jin Date: Fri Mar 25 10:49:00 2022 -0700 [Bugfix] Fix call_tir parsing bug (#109) * Fix call_tir parsing bug. * update. commit 3c7ff5a272d4b004b9b86b79e0f10c33635cea05 Author: Yuchen Jin Date: Thu Mar 24 19:50:27 2022 -0700 [VM] Fix hardcoded device type in memory lowering (#106) * Add is_device field to attr. * Update. * Address comment. * update. * Update. commit 6bcdcf8d02809dbbafbbd9515ea7ada17bb00077 Author: Ruihang Lai Date: Thu Mar 24 23:04:11 2022 +0800 [VM] Initialize VM through packed function (#101) commit cfc779e732933eb43cb0bca6448c51fac51dc39f Author: Yong Wu Date: Tue Mar 22 19:44:37 2022 -0700 Fix after rebase commit c368324831d378033d9b0f6621f3ee3b366624e6 Author: Lesheng Jin <34279105+LeshengJin@users.noreply.github.com> Date: Tue Mar 22 18:51:40 2022 -0700 Improve printer for DynTensorType and ShapeExpr (#97) * improve Printer for DynTensorType & ShapeExpr * add testcases commit a861f2eeadc3ded5a98aa2947a6b17f077e29dc2 Author: Ruihang Lai Date: Tue Mar 22 23:16:33 2022 +0800 [VM][Refactor] Move VM files to TVM runtime directory (#98) commit d96806093e9ff50aaf4d46a89d1003f87385bf7e Author: Tianqi Chen Date: Mon Mar 21 12:03:59 2022 -0400 [VM] Refactor and improve vm. (#96) * [VM] Refactor and improve vm. - Have a separate function for RunInstCall. - Cache func_index lookup by table to avoid repeative lookup by str. - Move PackedFunc call arg stack to Frame to increase locality and avoid re-allocation in repeative calls. - Make frame stack of unique_ptr to avoid frame re-allocation and copy during frame.resize. - Pass curr_frame as arguments into sub-functions to make it explicit. * address review comments commit b14c100835910d78a0332fd6baf7947fd224fb2c Author: Ruihang Lai Date: Sun Mar 20 22:12:20 2022 +0800 [VM] Enhance VM Executable as a Subclass of runtime::Module (#95) Enhance VM Executable as a Subclass of runtime::Module commit f885b8d5e085d244c021ff924a433533ab4b769a Author: Yuchen Jin Date: Thu Mar 17 18:51:24 2022 -0700 Change call_tir convention; Unify shape/type deduction rule (#94) * Change call_tir convention and fix shape/type deduction. * test * output shape as 3rd arg. * address comments. * lint commit d4cf8b53e7dee09f9b8806ff915809607e0607a9 Author: Masahiro Masuda Date: Sat Mar 12 06:00:40 2022 +0900 Clean up task extraction (#92) * Clean up taske extraction * black commit 841593e227cd3c15e09298920057367af49d0765 Author: Prakalp Srivastava Date: Thu Mar 10 16:56:36 2022 -0800 Fix bug in relax.vm.build to pass target argument. (#91) Co-authored-by: Prakalp Srivastava commit 9cccd4f29d129ba62ba0d5d279b6fe9a35fe6828 Author: Yong Wu Date: Wed Mar 9 22:43:11 2022 -0800 Add metadata section, support constant and metadata in parser & printer (#76) * [CI] Set up CI; format and lint relax code to pass CI (#72) * init * fix lint * update task_lint * more lint * more lint * lint * jenkinsfile * jenkinsfile * run relax only tests * python3.7 for pytest * point to personal ci-cpu docker * docker pull * test * fix cmake config * update * update * rebase * rebase * AutoTIR integration (#58) * [WIP] Basic task extraction mechanism is implemented. * [WIP] For gradual integration with Relay pipeline, meta_schedule/integration.py is created for relax to avoid potential conflict. * support tir tuning and injection mode * Add target field for Relax Extracted Task * 1. Create relax namespace/tvm objects/... for metaschedule to preserve relay support. 2. Promote target field from Optional to Target * Support ApplyHistoryBest * Reflect feedback from Yuchen * minor improvement and fix linter issue * add ASF header * Reorganize file structure * fix lint errors * remove the import-outside-toplevel * Reflect comments * remove redundant comment * As per discussion w/ Yuchen, ApplyHistoryBest is introduced as a Relax transformation pass. * remove redundant print msg * fix lint * reflect comments * Yuchen's change * relax ConstantNode in parser and printer * Add constant data in the metasection * rebase * Support ir_module(metadata=json_str) * update test case * remove print info * Update tests * clang-format * pylint * fix ci * Save a copy of metadata in RelaxTransformer * Fix comments * fix comments Co-authored-by: Yuchen Jin Co-authored-by: Sunghyun Park <49998730+sunggg@users.noreply.github.com> commit 470ecc68e45ee6fc220bdecf005d64f114092dfc Author: Jinkun Chen <65396089+Robslhc@users.noreply.github.com> Date: Fri Mar 4 09:21:03 2022 +0800 [Bugfix] Fix bb multi-function creation bug (#86) commit 09dc5a0039fe01f855f65b54fdf657b8dd4c80a1 Author: YuchenJin Date: Wed Mar 2 11:00:41 2022 -0800 Rebase. commit 786c8f4bd75e11f86935e4366c79d376876ec2ac Author: Josh Fromm Date: Tue Mar 1 16:03:41 2022 -0800 Make offset type specific to avoid errors on non-linux systems. (#84) commit c2f0b86f292121eddece414b971d422bdf5ca2b3 Author: Ziheng Jiang Date: Fri Feb 25 12:18:08 2022 -0800 [TESTS] Enable Tests (#78) * Enable tests. * Updated. * Updated. * Updated. commit 449a2679e3bbde2582f26dc5eea4f8c6a0b97eb6 Author: Yuchen Jin Date: Fri Feb 25 10:43:35 2022 -0800 Bug fix; print ShapeExpr (#82) commit 113849c4d16e6885805c3a9b143fecf8600b00ad Author: Sunghyun Park <49998730+sunggg@users.noreply.github.com> Date: Mon Feb 14 16:16:15 2022 -0800 AutoTIR integration (#58) * [WIP] Basic task extraction mechanism is implemented. * [WIP] For gradual integration with Relay pipeline, meta_schedule/integration.py is created for relax to avoid potential conflict. * support tir tuning and injection mode * Add target field for Relax Extracted Task * 1. Create relax namespace/tvm objects/... for metaschedule to preserve relay support. 2. Promote target field from Optional to Target * Support ApplyHistoryBest * Reflect feedback from Yuchen * minor improvement and fix linter issue * add ASF header * Reorganize file structure * fix lint errors * remove the import-outside-toplevel * Reflect comments * remove redundant comment * As per discussion w/ Yuchen, ApplyHistoryBest is introduced as a Relax transformation pass. * remove redundant print msg * fix lint * reflect comments commit 4ac905b18471ccadb93b8fbfea1bc5e92eea51c9 Author: Yuchen Jin Date: Mon Feb 14 09:06:53 2022 -0800 Relay->Relax translator (ResNet example) (#75) * Relay translator; build static bert forward/backward. * rebase * Add ops. * resnet demo * cleanup code. * Rebase. * Address comments. * leverage FTVMCompute for most op translation; reuse relay.Constant. * lint. commit 6d69587f3ec0aa31b17b4d2e2814a5fb369fc342 Author: Yong Wu Date: Wed Feb 9 10:33:33 2022 -0800 Parse relax TupleGetItem (#77) * Parse relax TupleGetItem * Incorporate comments * fix comment commit 7239ec50f9493d5cf089964dea8a728da0a866dd Author: Yuchen Jin Date: Tue Feb 1 16:30:35 2022 -0800 [Pass] Refactor shape lowering pass to better handle static shape cases (#74) * Refactor shape lowering * Add regression test for generating unique name in AddFuncToContext. commit 9e1ac698f1599775a182b297f8e5d7accdf48f6f Author: Yuchen Jin Date: Tue Feb 1 16:28:07 2022 -0800 [BlockBuilder] Emit TupleGetItem (#73) * Emit TupleGetItem. * Remove a branch because TupleNode is not leaf node. * Address comment. commit 0971a8f55784cc95eca8ab9e4d8456e16c016ca2 Author: Yong Wu Date: Tue Jan 25 16:25:29 2022 -0800 [VM] Add control flow in VmCodeGen, add vm.builtin.copy (#69) * Add control flow in VmCodeGen, Merge test and target registers of If * Update test case * Add vm.builtin.copy packed func * fix lint * Remove true_offset of If instruction * fix comments commit 10b4efe9a0d21d97fb52b3d46a69dbdb5ab3c015 Author: Yuchen Jin Date: Mon Jan 24 17:18:16 2022 -0800 [CI] Set up CI; format and lint relax code to pass CI (#72) * init * fix lint * update task_lint * more lint * more lint * lint * jenkinsfile * jenkinsfile * run relax only tests * python3.7 for pytest * point to personal ci-cpu docker * docker pull * test * fix cmake config * update * update * rebase * rebase commit 11435eb3c1c6911d4455af7fafc5a50ab89c2a63 Author: Yuchen Jin Date: Mon Jan 24 14:50:20 2022 -0800 [BlockBuilder] Avoid generating duplicated PrimFunc (#68) * Avoid generating duplicated primfunc. * Move logic to c++. * Update method names. commit a938307e9903e9e5362e4f1887937e74c7b8c1c2 Author: Andrew Liu Date: Sun Jan 23 18:12:47 2022 -0800 [CreatePrimFunc] Support multi-source ReduceNode (#64) * initial * assert structural equal test commit 7f5ae9c5e8b79f1e3166f802b6e9f2408ea3dc4f Author: Junru Shao Date: Mon Jan 17 18:37:49 2022 -0800 xfail => pytest.raises; fix a unittest (#67) commit 6a0a2d1287c5c7e1d528b9f6f15cec556990b168 Author: Andrew Liu Date: Mon Jan 17 13:12:00 2022 -0800 [EmitTE] multi-output semantics for call_tir, Tuples (#62) * call_tir multi-output and tuples working * CallTIR checked_type_ quick fix * use te.compute fcompute returns multiple * comments commit aeb635fc6e04a57002ffeacf91a04a6794837776 Author: Junru Shao Date: Mon Jan 17 12:04:35 2022 -0800 [Relax/IR] Disallow creating Binding directly (#66) commit 0ae9d44abd55170af1d635f67427acc77ccd5e3e Author: Junru Shao Date: Sun Jan 16 12:18:49 2022 -0800 [Refactor] Format; Simplify PackedFunc Registration; Unify parameter order of `alloc_tensor` (#65) * [Refactor] Simplify the global registration logic * Address comments; Pass tests; Format stuff * Address comments commit e87a582a9f00e00e55137a3011a1ad4b4db8dde2 Author: Yuchen Jin Date: Thu Jan 13 17:28:19 2022 -0800 [VM] Add control flow to relax vm (#61) * Add if and goto instr. * Update python/tvm/relax/exec_builder.py Co-authored-by: Yong Wu * use set_body_method. Co-authored-by: Yong Wu commit 7f8cb36eaa0ecf60c2ef9681e7a8436da9e8c211 Author: Lily Orth-Smith Date: Wed Jan 5 15:42:35 2022 -0800 call_dps -> call_tir (#60) * Rename call_dps to call_tir * Rename call_dps_rewrite.cc commit a55c4b85777d6d200d8e08158bde5f0eae71894f Author: Andrew Liu Date: Tue Dec 14 16:35:09 2021 -0800 [EmitTe] Dynamic TIR Function (w/ unbound TIR vars) (#57) * initial WIP: relax.vm.call_tir_dyn_lowered * fix test * cleanup and comments * comments * use nn.Placeholder * remove _check_te_args * move mod_ into VMState * add packed integer values as optional argument to call_dps * add special case for call_dps in parser for optional argument * comments commit b20f7a96559949b720ee057a9e0be36682b89f7e Author: Yuchen Jin Date: Mon Dec 6 16:49:48 2021 -0800 [TESTING] pytorch-like nn.Module API to build neural network (#54) * nn module * address comments. * Add nn.init_params * Remove nn.Builder and use BlockBuilder instead. * Rebase. * Refactor block builder and add tests. * Address comments. * Update. commit 6a217afb152d22999782fc2616ece5e8bbca56bb Author: Yuchen Jin Date: Tue Nov 30 19:51:28 2021 -0800 Update vm build. (#55) commit 6ddf998c7d9789867d323250903c1f97d4375ac8 Author: Andrew Liu Date: Tue Nov 30 17:16:32 2021 -0800 [EmitTE] EmitTE Symbolic Shape (#53) commit 346fe5c94f60883ef3ea5a06e1de4cfa3b9ba729 Author: Yuchen Jin Date: Wed Nov 24 13:22:08 2021 -0800 Call topi and external library through emit_te and add MLP example (#50) commit df2c06813f95c6ba6dbc1fa5d509c82462541b80 Author: Yuchen Jin Date: Wed Nov 24 05:53:46 2021 -0800 Visit shape in Visitor/Mutator (#45) commit 0fc4547e889fc57c54dfb9e01318f602f0092347 Author: Ziheng Jiang Date: Sun Nov 21 12:43:30 2021 -0800 TE Integration (#36) * Init. * Proof of concept. * Rebase on the newest branch * Move to emit_te * Update emit_te * Make RXPlaceholderOpNode as a subclass of PlaceholderOpNode * Update * run vm test_te * Update argument conversion * Reset create_primfunc * Update doc * Update test * Add error message * Update * Update * Address comment * unit test check structural and validate_te_args * raise ValueError when multiple outputs * address comments * example usage emit_te * Rename to context_mod * Handle multiple call * Address comments * Address comments * Use unique name * remove * rename args to te_args * address comments * fix TVMscript manually * spelling Co-authored-by: Andrew Liu commit 1f409f33fc5887e755e6b7708c53cf3ba7038597 Author: Altan Haan Date: Wed Nov 17 19:12:06 2021 -0800 fix IRModule parsing by resolving GlobalVars later (#41) * fix IRModule parsing by resolving GlobalVars later * disable fast path that causes type inference problem for now * print checked type on vars if present * document ResolveGlobals commit 685b232e09371e1ce67148994b8af70441a24456 Author: Yuchen Jin Date: Tue Nov 16 08:45:53 2021 -0800 Update Shape lowering pass (#38) * Update shape lowering pass. * Rebase. commit bafde350c1b9abdff8a8dcbc4922f0387c09d5f1 Author: Yuchen Jin Date: Mon Nov 15 10:34:53 2021 -0800 Generic dispatching in Visitor (#39) commit a0b85289291f1e947b990c065ae62354242f235d Author: Yuchen Jin Date: Fri Nov 12 07:53:37 2021 -0800 Migrate passes to Pass Infra (#37) * Migrate relax passes -> Pass infra. * Update. * Add docs and update tests. * Rebase and change namespace. * Address comments. commit 2583394395e94ffc8ac60ee2d9673b13f93f196d Author: Yuchen Jin Date: Tue Nov 9 16:10:47 2021 -0800 ExprMutator refactor & Normalizer (#32) * fixes * revert checked_type visitor and fix relax usage * ExprNormalizer * fix that annoying bug and get tests passing * Memoization fix for the ExprMutator; separate VisitVarDef from use. * rebase. * rebase. * address part of comments. * address more comments * address more comments and add doc * address more comments * fix potential mutation bug * always assign normalized shape if can * address comments Co-authored-by: Altan Haan commit 75c1d497e8df74279a592f3076545cdc5742570b Author: Yuchen Jin Date: Mon Nov 8 17:45:40 2021 -0800 Fix vm build. (#35) commit 793d7fe8ff1c688b258b3ed639f76ecfa2a6342a Author: Yuchen Jin Date: Thu Nov 4 11:05:42 2021 -0700 VM compiler refactor (#25) * Return Instruction::Arg for each CodeGenLLVM::VisitExpr_. * Change VMCompiler to be an Object from ModuleNode. * Introduce intrinsics and attrs. * Generic handling of attribute codegen. * Do to-non-dataflow transform in call_dps_rewrite. * Back to special attr handling. * Address comments. * Standalone to_non_dataflow pass; more tests. * Rename decode/make shape to store/load shape. * Update. * Fix namespace, add comments. * rebase * Rename files. * nit commit abcb622355e7b87b359fd9f5588eb24473150807 Author: Altan Haan Date: Wed Nov 3 15:27:48 2021 -0700 rebase is green commit bfdc1d38a437ffcf34fcebeb0b0d6fc3ef4f83d6 Author: tqchen Date: Sat Oct 30 11:43:21 2021 -0400 Update fixes for rebase commit afa06ed0d5d779fbc7cc5357952dfe9b127cfca9 Author: Altan Haan Date: Sat Oct 30 07:04:37 2021 -0700 Fixes and improvements (#24) commit b75b0ba4ca93f319709dd7f91de182482e24a001 Author: Ziheng Jiang Date: Fri Oct 22 14:56:42 2021 -0700 End2End Lowering (#23) * call_dps lowering. * Improve shape lowering. * Support alloc_storage for dynamic shape. * implementt ToNonDF to transform program to non-dataflow format. * Fix the mutator issue. * Update build api, an issue occurred. * vm tests can pass. * Support shape tuple in executable seriablization. * Fix for test. * Minor fixes. * Address comments. * Add mutate binding var back. * Visit binding var and fix tests. Co-authored-by: YuchenJin commit 6acf69e363572e774b8a2b645a02026a1736a6b3 Author: Ziheng Jiang Date: Mon Oct 18 12:18:07 2021 -0700 End2End Lowering Stage2: Enable Lowering from ShapeExpr to VM Executable (#21) * rebase. * Update. * Update shape lowering, make sure the lowering pipeline works. * Address comment. commit dc11b5bd17a661b00278209123214a162fd78d45 Author: Yuchen Jin Date: Sun Oct 17 14:58:28 2021 -0700 Redesign IRBuilder to BlockBuilder (#22) * init * update * update * test case working * update and add multi block test case * check in * fixes * fix * update * add * update * add * update * address comments. Co-authored-by: Altan Haan commit a8878ffef4da2124c6c4e560b42ce00b122c1016 Author: Yuchen Jin Date: Wed Oct 6 15:35:05 2021 -0700 Add type hint. (#20) commit 601978f67e4520db30890a069090abd241de1bae Author: Yuchen Jin Date: Wed Oct 6 15:34:09 2021 -0700 VM compiler. (#18) * VM compiler. * Update. * Compile IRmodule; expose Python api * Add dtype contant serialization and type hint. * Address comments. * Add todos and fix lint. * Update * Update. commit e997f430c57eca9ed91864776383ac873c377f0c Author: Altan Haan Date: Thu Sep 30 14:31:13 2021 -0700 [Parser][Printer] explicitly parse and print attrs_type_key in calls (#19) * relax call_packed arity, return IRModule factory, print IRModule PrimFuncs * explicitly parse and print attrs_type_key on calls * print type even when attrs has no fields commit 0169aeeeb6d6d17fce77e8c91013ba9edcb52679 Author: Ziheng Jiang Date: Wed Sep 29 17:17:06 2021 -0700 [PASS] Shape lowering (#16) * [PASS] Shape lowering. * Update to IRModule based. * TIR function generation. * Improve. * Improve. * Improve test. * Improve. * Address comment. commit 7e95f3f22f2cb7400e342658355d007e3dafc044 Author: Altan Haan Date: Wed Sep 29 13:40:48 2021 -0700 [Parser][Printer] relax call_packed arity, return IRModule factory, print IRModule PrimFuncs (#17) commit fe12c055a999609e35f7ea7e51179a248d3c87d1 Author: Altan Haan Date: Tue Sep 28 19:31:46 2021 -0700 [Parser][Printer] Add class -> IRModule parsing, and extern func support for call_dps (#15) * update parser and printer for match_shape * support parsing class to IRModule, and extern func in call_dps commit 3fed13c175cf8e4b48fb3cbb546bfadb720fded7 Author: ziheng Date: Mon Sep 27 18:00:44 2021 -0700 Reorganize source code. (#14) commit 915947cc2b03c028f79fc338882ccde84c32a659 Author: Altan Haan Date: Mon Sep 27 16:45:51 2021 -0700 [Parser][Printer] update parser and printer for match_shape (#13) commit 0fd56015c37abc75dc4ffbb1cacdfd5182ffe41f Author: Yuchen Jin Date: Mon Sep 27 15:58:59 2021 -0700 Relax IRVisitor/IRMuator (#10) * ExprVisitor/ExprMutator for relax nodes. * Update Visitor & Mutator. * Update Mutator. * DataflowMutator interface. * EwiseFMARewriter. * Update fma rewrite and add test. * Update test. * Fix dataflow block dispatching. * Construct new dataflow block with IRBuilder. * VisitBinding return void and mutate internal IRBuilder. * Simplify. * Update emit dataflow output. * Explicit memeory allocation rewrite. * LazyIRBuilder. * Update ExplicitMemMutator. * Overload IRBuilder::Emit to have 3 styles. * Update IRBuilder/IRMutator interfaces and passes. * Add MatchShape binding to IRBuilder. * Improve IRMutator interface; add Normalize and CanProveShapeEqual to IRBuilder * Update EmitMatchShape. Co-authored-by: ZihengJiang commit 4c7f23ebc4028dce42f58b397fe4bbc9e9e61139 Author: Altan Haan Date: Mon Sep 27 14:03:36 2021 -0700 [Parser][Printer] More parser/printer improvements (#12) * Relax pretty printer initial prototype * call into TVMScriptPrinter for PrimFuncs * most round-trip tests pass * address comments * implement relax.output syntax for dataflow block outputs * remove leftover comments * fix Var constructor on ShapeExpr annotation * add printing and parsing for simple PrimExpr and Call Attrs commit 76727634b4afdaa23d7cb1bd9cd2fb71ccb180ac Author: ziheng Date: Mon Sep 27 11:30:46 2021 -0700 Update MatchShape AST Node (#11) * Update MatchShape AST Node. * Update. * Update. commit b4c5010a275560d122bdd7d216234cc9885337ff Author: Altan Haan Date: Fri Sep 24 15:06:49 2021 -0700 [Parser][Printer] Switch to output annotation for dataflow blocks (#9) * Relax pretty printer initial prototype * call into TVMScriptPrinter for PrimFuncs * most round-trip tests pass * address comments * implement relax.output syntax for dataflow block outputs * remove leftover comments * fix Var constructor on ShapeExpr annotation * fix DataflowVar as well commit 0900727b4f220651649858b13b5d22f461d973ff Author: Altan Haan Date: Thu Sep 23 13:36:07 2021 -0700 Relax pretty printer (#8) * Relax pretty printer initial prototype * call into TVMScriptPrinter for PrimFuncs * most round-trip tests pass * address comments * fix typo commit 5ef7b23f8191eb4430de9dcd791d005307f80911 Author: Yuchen Jin Date: Mon Sep 20 17:04:25 2021 -0700 Shape and type deduction (#7) * Shape and type deduction. * Fix header. * Add call attrs to the deduce signature. * Address comments. * Add DiagnosticContext to IRBuilder and inference signature. * Fix nits. commit 26f7c4aeb483d976a4bc7efcd929b756209d34f8 Author: Altan Haan Date: Mon Sep 13 14:06:07 2021 -0700 Relax IR Parser (#6) * Copy jared's frontend * Remove some extraneous code + add TODOs * Skeleton AST * Added more skeleton AST, worked on parsing shape annotations. Something is wrong with span_to_span * Fix spans * Type annotations parsing correctly * some match_shape support * More bug fixes! Some stuff parses. Importing into tests is messed up. We probably need to restructure this code as well. * refactor parser and fill out more stubs * some parser tests * yolo dataflow * checkpoint for rebase * hook up AST * add inline TIR parsing * some cleanup * support call_packed parsing to ExternFunc call * remove stub ops * improve docstrings * address nits * support coercing tuples to ShapeExpr when possible for call_dps Co-authored-by: electriclilies commit 58a5a8907bcb5bfff12e044e8d3d86b8ff1315cf Author: Yuchen Jin Date: Tue Aug 31 12:41:26 2021 -0700 Relax IRBuilder (#4) * Add initial IRBuilder. * Add function output to irbuilder; update based on new AST. * Add call method; clean up bindings * Add test. * Add multifuction test * Move implementation to C++; infer shape and type * update op python hook * More tests and bug fix * Add comments. * Update shape/type inference. * Restructure code; add python type hint. * Cleanup code. * Rebase; address comments. * Add call intrinsic. * nits. * Remove call op. * Migrate scope to C++ using tvm::With. * Address naming. * Add GetBlocks API. * Unify EmitOutput APIs; add more comments. * Remove shape and type deduction code. * Also remove the shape/type attr interface. * Address comments. * Differentiate global and local function. * Reset counter after building func/block. * Rebase. * Remove shape infer builtin. * Return from void function as empty tuple. Co-authored-by: Michalis Papadimitriou commit 2340269e1cae6d231def60721bffea8458f5442e Author: ziheng Date: Mon Aug 30 13:28:52 2021 -0700 Update AST and Shape() implementation (#5) * Update AST. * ShapeOf. * ShapeOf. * Address comment. commit 2d51aee116e58595b2d280c7831106f05055895d Author: ziheng Date: Sat Aug 21 10:13:16 2021 -0700 Implementation of CallDPS (#3) * Implementation of call_dps. * Implementation of PackedFuncExpr. * Test CallDPS for TIR function. * Rename. * Add header and comments. * Update. * Address comments. commit cb03c5e88942a5b402d96b9f1ceeab464425545e Author: Jared Roesch Date: Thu Aug 12 14:32:02 2021 -0700 Relax AST (#2) Co-authored-by: ZihengJiang commit 97a96eae2eb0822e6f25bbacb678c49e4621ce62 Author: ZihengJiang Date: Mon Apr 19 05:12:52 2021 -0700 Relax Virtual Machine Co-Authored-By: Yuchen Jin commit 9460385b974f3119cb601e3cb455ee8100f93f8e Author: Tianqi Chen Date: Fri May 21 15:31:13 2021 -0400 disable GH --- 3rdparty/cutlass | 2 +- cmake/modules/contrib/CUTLASS.cmake | 4 +- cmake/modules/contrib/DNNL.cmake | 8 +- .../using_pipeline_executor.py | 8 +- include/tvm/relax/dataflow_matcher.h | 3 + python/tvm/contrib/cutlass/build.py | 33 +- python/tvm/relax/dpl/context.py | 2 +- python/tvm/relax/dpl/pattern.py | 4 +- python/tvm/relax/transform/transform.py | 6 + src/relax/analysis/var2value.cc | 1 + .../contrib/codegen_json/codegen_json.h | 2 +- src/relax/backend/contrib/cutlass/codegen.cc | 1074 +++++++++++++++++ src/relax/backend/contrib/dnnl/codegen.cc | 128 ++ src/relax/ir/dataflow_matcher.cc | 19 +- src/relax/transform/fuse_ops.cc | 132 +- src/relay/backend/contrib/cutlass/codegen.cc | 4 +- src/runtime/contrib/dnnl/dnnl_json_runtime.cc | 4 +- tests/python/relax/test_codegen_cutlass.py | 263 ++++ tests/python/relax/test_codegen_dnnl.py | 207 ++++ 19 files changed, 1869 insertions(+), 35 deletions(-) create mode 100644 src/relax/backend/contrib/cutlass/codegen.cc create mode 100644 src/relax/backend/contrib/dnnl/codegen.cc create mode 100644 tests/python/relax/test_codegen_cutlass.py create mode 100644 tests/python/relax/test_codegen_dnnl.py diff --git a/3rdparty/cutlass b/3rdparty/cutlass index a3bcc6981d..8b42e751c6 160000 --- a/3rdparty/cutlass +++ b/3rdparty/cutlass @@ -1 +1 @@ -Subproject commit a3bcc6981d5dad3afb212689e2c7853d1b1ee45d +Subproject commit 8b42e751c63ba219755c8ed91af5f6ec1ecc1ee6 diff --git a/cmake/modules/contrib/CUTLASS.cmake b/cmake/modules/contrib/CUTLASS.cmake index afd5ef5302..4b4ef355b6 100644 --- a/cmake/modules/contrib/CUTLASS.cmake +++ b/cmake/modules/contrib/CUTLASS.cmake @@ -16,8 +16,8 @@ # under the License. if(USE_CUDA AND USE_CUTLASS) - tvm_file_glob(GLOB CUTLASS_RELAY_CONTRIB_SRC src/relay/backend/contrib/cutlass/*.cc) - list(APPEND COMPILER_SRCS ${CUTLASS_RELAY_CONTRIB_SRC}) + tvm_file_glob(GLOB CUTLASS_CONTRIB_SRC src/relay/backend/contrib/cutlass/*.cc src/relax/backend/contrib/cutlass/*.cc) + list(APPEND COMPILER_SRCS ${CUTLASS_CONTRIB_SRC}) message(STATUS "Build with CUTLASS") endif() diff --git a/cmake/modules/contrib/DNNL.cmake b/cmake/modules/contrib/DNNL.cmake index 7547af81eb..857f7bdfd5 100644 --- a/cmake/modules/contrib/DNNL.cmake +++ b/cmake/modules/contrib/DNNL.cmake @@ -21,8 +21,8 @@ if(IS_DIRECTORY ${USE_DNNL}) message(WARNING "Cannot find DNNL library at ${USE_DNNL}.") else() add_definitions(-DUSE_JSON_RUNTIME=1) - tvm_file_glob(GLOB DNNL_RELAY_CONTRIB_SRC src/relay/backend/contrib/dnnl/*.cc) - list(APPEND COMPILER_SRCS ${DNNL_RELAY_CONTRIB_SRC}) + tvm_file_glob(GLOB DNNL_CONTRIB_SRC src/relay/backend/contrib/dnnl/*.cc src/relax/backend/contrib/dnnl/*.cc) + list(APPEND COMPILER_SRCS ${DNNL_CONTRIB_SRC}) list(APPEND TVM_RUNTIME_LINKER_LIBS ${EXTERN_LIBRARY_DNNL}) tvm_file_glob(GLOB DNNL_CONTRIB_SRC src/runtime/contrib/dnnl/dnnl_json_runtime.cc @@ -34,8 +34,8 @@ if(IS_DIRECTORY ${USE_DNNL}) endif() elseif((USE_DNNL STREQUAL "ON") OR (USE_DNNL STREQUAL "JSON")) add_definitions(-DUSE_JSON_RUNTIME=1) - tvm_file_glob(GLOB DNNL_RELAY_CONTRIB_SRC src/relay/backend/contrib/dnnl/*.cc) - list(APPEND COMPILER_SRCS ${DNNL_RELAY_CONTRIB_SRC}) + tvm_file_glob(GLOB DNNL_CONTRIB_SRC src/relay/backend/contrib/dnnl/*.cc src/relax/backend/contrib/dnnl/*.cc) + list(APPEND COMPILER_SRCS ${DNNL_CONTRIB_SRC}) find_library(EXTERN_LIBRARY_DNNL dnnl) list(APPEND TVM_RUNTIME_LINKER_LIBS ${EXTERN_LIBRARY_DNNL}) diff --git a/gallery/how_to/work_with_relay/using_pipeline_executor.py b/gallery/how_to/work_with_relay/using_pipeline_executor.py index 8f61368656..4a28a59251 100755 --- a/gallery/how_to/work_with_relay/using_pipeline_executor.py +++ b/gallery/how_to/work_with_relay/using_pipeline_executor.py @@ -29,12 +29,8 @@ from tvm import relay from tvm.relay import testing import tvm.testing -from tvm.contrib.cutlass import ( - has_cutlass, - num_cutlass_partitions, - finalize_modules, - finalize_modules_vm, -) +from tvm.contrib.cutlass import finalize_modules + img_size = 8 ####################################################################### diff --git a/include/tvm/relax/dataflow_matcher.h b/include/tvm/relax/dataflow_matcher.h index e394e9ff53..ae0c7e548c 100644 --- a/include/tvm/relax/dataflow_matcher.h +++ b/include/tvm/relax/dataflow_matcher.h @@ -45,6 +45,9 @@ namespace relax { */ bool MatchExpr(DFPattern pattern, Expr expr, Optional> var2val = NullOpt); +Optional> ExtractMatchedExpr( + DFPattern pattern, Expr expr, Optional> bindings_opt = NullOpt); + /** * \brief Match a sub-graph in a DataflowBlock with a graph of patterns and return the mapping. * \note This algorithm returns the first matched sub-graph. Use `start_hint` to specify the diff --git a/python/tvm/contrib/cutlass/build.py b/python/tvm/contrib/cutlass/build.py index 68d8fe7cef..af95622d76 100644 --- a/python/tvm/contrib/cutlass/build.py +++ b/python/tvm/contrib/cutlass/build.py @@ -20,7 +20,7 @@ import os import multiprocessing import tvm -from tvm import runtime, relay +from tvm import runtime, relay, relax from tvm.contrib.nvcc import get_cuda_version from tvm._ffi.registry import register_func from .gen_gemm import CutlassGemmProfiler @@ -516,6 +516,22 @@ def tune_cutlass_function( ) +@register_func("contrib.cutlass.compile") +def compile_cutlass_module(c_source_module): + # TODO: Pass them as param + tmp_dir = "tmp" + compile_config = {"sm": 80, "threads": -1, "use_fast_math": False} + + function_names = c_source_module.get_function("get_func_names")() + compile_options = _get_cutlass_compile_options(**compile_config) + lib_path = os.path.join(tmp_dir, "cutlass.o") + logger.info("Compiling generated CUTLASS code") + c_source_module.export_library(lib_path, workspace_dir=tmp_dir, **compile_options) + + # Recover static library + return tvm.runtime.load_static_library(lib_path, function_names) + + @register_func("relay.ext.cutlass.compile_for_cutlass") def compile_for_cutlass(mod, cutlass_target): """Given an IRModule with at least one Compiler='cutlass' Relay function, return a @@ -558,6 +574,8 @@ def compile_for_cutlass(mod, cutlass_target): logger.info("Creating CSource module for CUTLASS") create_c_source_module = tvm._ffi.get_global_func("relay.ext.cutlass.create_c_source_module") c_module = create_c_source_module(mod) + + # TODO: use compile_cutlass_module above function_names = c_module.get_function("get_func_names")() compile_options = _get_cutlass_compile_options(**compile_config) lib_path = os.path.join(tmp_dir, "cutlass.o") @@ -633,3 +651,16 @@ def finalize_modules_vm(vm_exec, lib_path="compile.so", vmcode_path="vmcode.ro", fo.write(code) lib = tvm.runtime.load_module(lib_path) return tvm.runtime.vm.Executable.load_exec(code, lib) + + +def finalize_modules_relax( + vm_exec, lib_path="compile.so", vmcode_path="vmcode.ro", tmp_dir="./tmp" +): + lib_path = os.path.join(tmp_dir, lib_path) + vmcode_path = os.path.join(tmp_dir, vmcode_path) + + lib = vm_exec.mod + lib.export_library(lib_path, workspace_dir=tmp_dir, cc="nvcc") + lib = tvm.runtime.load_module(lib_path) + + return relax.vm.Executable(lib) diff --git a/python/tvm/relax/dpl/context.py b/python/tvm/relax/dpl/context.py index a621d31460..69a5e70ed0 100644 --- a/python/tvm/relax/dpl/context.py +++ b/python/tvm/relax/dpl/context.py @@ -20,7 +20,7 @@ from typing import Optional, Dict import tvm -from tvm.relax import DataflowBlock, Var +from ..expr import DataflowBlock, Var from .pattern import DFPattern from . import _ffi as ffi diff --git a/python/tvm/relax/dpl/pattern.py b/python/tvm/relax/dpl/pattern.py index 31dbffda4a..7c360e57ab 100644 --- a/python/tvm/relax/dpl/pattern.py +++ b/python/tvm/relax/dpl/pattern.py @@ -25,10 +25,10 @@ import tvm import tvm._ffi as tvm_ffi from tvm.ir.expr import PrimExpr -from tvm.relax import Expr, Var from tvm.relay.op import get from tvm.ir.container import Array +from ..expr import Expr, Var from ...ir import make_node from ...runtime import Object from ...ir.base import Node @@ -198,7 +198,7 @@ def match(self, expr, var2val: Optional[Dict[Var, Expr]] = None) -> bool: Unlike Relay whose function is an expression, functions in Relax consists of blocks of bindings that they are not syntactically connected. We use a mapping (i.e., var2val) to migrate the gap. For example, to when matching - "relax.add(lv0, lv1)", given var2val, we match lv0's binded expression + "relax.add(lv0, lv1)", given var2val, we match lv0's bound expression when the recursive pattern matching goes to check lv0. The var2val mapping can be computed through the tvm.relax.analysis.get_var2val function. """ diff --git a/python/tvm/relax/transform/transform.py b/python/tvm/relax/transform/transform.py index 323b4f3d1a..d657dd9e8b 100644 --- a/python/tvm/relax/transform/transform.py +++ b/python/tvm/relax/transform/transform.py @@ -25,6 +25,7 @@ import tvm.ir from tvm.runtime import NDArray from . import _ffi_api +from ..dpl import DFPattern @tvm._ffi.register_object("relax.FunctionPass") @@ -286,6 +287,11 @@ def FuseOps(fuse_opt_level=-1) -> tvm.ir.transform.Pass: return _ffi_api.FuseOps(fuse_opt_level) # type: ignore +def FuseOpsByPattern(pattern: DFPattern) -> tvm.ir.transform.Pass: + """TODO""" + return _ffi_api.FuseOpsByPattern(pattern) # type: ignore + + def FuseTIR() -> tvm.ir.transform.Pass: """Fuse primitive relax function into a larger TIR function if possible diff --git a/src/relax/analysis/var2value.cc b/src/relax/analysis/var2value.cc index d034afeb21..0e30427397 100644 --- a/src/relax/analysis/var2value.cc +++ b/src/relax/analysis/var2value.cc @@ -28,6 +28,7 @@ class Var2ValAnalysis : public relax::ExprVisitor { tvm::runtime::Map var2value_; void VisitBinding_(const VarBindingNode* binding) override { var2value_.Set(binding->var, binding->value); + VisitExpr(binding->value); } }; diff --git a/src/relax/backend/contrib/codegen_json/codegen_json.h b/src/relax/backend/contrib/codegen_json/codegen_json.h index 809156bfeb..7daa63f7b6 100644 --- a/src/relax/backend/contrib/codegen_json/codegen_json.h +++ b/src/relax/backend/contrib/codegen_json/codegen_json.h @@ -358,7 +358,7 @@ class JSONSerializer // TODO(@sunggg): Revisit when we have op naming convention. // Currently, simply remove "relax." prefix to make it work. - name = std::string("tensorrt.") + name.substr(6); + name = std::string("dnnl.") + name.substr(6); std::vector inputs; for (const auto& arg : cn->args) { diff --git a/src/relax/backend/contrib/cutlass/codegen.cc b/src/relax/backend/contrib/cutlass/codegen.cc new file mode 100644 index 0000000000..f4380071ed --- /dev/null +++ b/src/relax/backend/contrib/cutlass/codegen.cc @@ -0,0 +1,1074 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/relax/backend/contrib/cutlass/codegen.cc + * \brief Implementation of the CUTLASS JSON serializer. + */ +#include +#include +#include +#include + +#include +#include +#include + +#include "../utils.h" + +namespace tvm { +namespace relax { +namespace contrib { + +using Str2StrMap = std::unordered_map; + +static Str2StrMap dtype_map = {{"float16", "cutlass::half_t"}, + {"float32", "float"}, + {"int8", "int8_t"}, + {"uint8", "uint8_t"}, + {"int32", "int32_t"}}; + +constexpr const char* kAnyDim = "Any"; + +std::string GetDimAsStr(ObjectRef dim) { + if (auto d = dim.as()) { + return std::to_string(d->value); + } + return kAnyDim; +} + +inline void CutlassPrint(std::ostringstream& os, const std::string& stmt, int indent = 2) { + for (int i = 0; i < indent; ++i) { + os << " "; + } + os << stmt; +} + +Str2StrMap ArgsCommon(const Map& attrs) { + Str2StrMap args; + auto arg0_dtype = std::string(attrs["arg0_dtype"].as()->data); + auto arg1_dtype = std::string(attrs["arg1_dtype"].as()->data); + auto ret_dtype = std::string(attrs["ret_dtype"].as()->data); + args["ElementInputA"] = dtype_map.at(arg0_dtype); + args["ElementInputB"] = dtype_map.at(arg1_dtype); + args["ElementOutput"] = dtype_map.at(ret_dtype); + args["op_def"] = std::string(attrs["cutlass_op_def"].as()->data); + args["op_name"] = std::string(attrs["cutlass_op_name"].as()->data); + args["op_type"] = std::string(attrs["op_type"].as()->data); + return args; +} + +Str2StrMap GemmArgsCommon(const Map& attrs) { + Str2StrMap args = ArgsCommon(attrs); + args["lda"] = std::string(attrs["lda"].as()->data); + args["ldb"] = std::string(attrs["ldb"].as()->data); + args["ldc"] = std::string(attrs["ldc"].as()->data); + return args; +} + +Str2StrMap DenseArgs(const Map& attrs) { + Str2StrMap args = GemmArgsCommon(attrs); + auto arg0_shape = attrs["arg0_shape"].as(); + auto arg1_shape = attrs["arg1_shape"].as(); + args["M"] = GetDimAsStr(arg0_shape->at(0)); + args["K"] = GetDimAsStr(arg0_shape->at(1)); + args["N"] = GetDimAsStr(arg1_shape->at(0)); + return args; +} + +Str2StrMap BatchMatmulArgs(const Map& attrs) { + Str2StrMap args = GemmArgsCommon(attrs); + args["batch"] = GetDimAsStr(attrs["batch"]); + args["batch_stride_A"] = GetDimAsStr(attrs["batch_stride_A"]); + args["batch_stride_B"] = GetDimAsStr(attrs["batch_stride_B"]); + args["batch_stride_C"] = GetDimAsStr(attrs["batch_stride_C"]); + auto arg0_shape = attrs["arg0_shape"].as(); + auto arg1_shape = attrs["arg1_shape"].as(); + args["M"] = GetDimAsStr(arg0_shape->at(1)); + args["K"] = GetDimAsStr(arg0_shape->at(2)); + args["N"] = GetDimAsStr(arg1_shape->at(1)); + return args; +} + +void AppendPrologue(std::ostringstream& gemm_decl, const Str2StrMap& attrs, + const std::vector& func_args, const std::string& kernel, + bool has_bias, bool is_gelu, int m_axis_idx, int n_axis_idx, int k_axis_idx) { + CutlassPrint(gemm_decl, "using ElementInputA = " + attrs.at("ElementInputA") + ";\n"); + CutlassPrint(gemm_decl, "using ElementInputB = " + attrs.at("ElementInputB") + ";\n"); + CutlassPrint(gemm_decl, "using ElementOutput = " + attrs.at("ElementOutput") + ";\n"); + CutlassPrint(gemm_decl, "using ElementComputeEpilogue = " + attrs.at("ElementOutput") + ";\n"); + CutlassPrint(gemm_decl, attrs.at("op_def")); + CutlassPrint(gemm_decl, "using " + kernel + " = Operation_" + attrs.at("op_name") + ";\n"); + + auto get_dim = [&attrs, &func_args](const std::string& axis, int arg_idx, int axis_idx) { + if (attrs.at(axis) == kAnyDim) { + return func_args[arg_idx] + "->shape[" + std::to_string(axis_idx) + "]"; + } else { + return attrs.at(axis); + } + }; + CutlassPrint(gemm_decl, "int M = " + get_dim("M", 0, m_axis_idx) + ";\n"); + CutlassPrint(gemm_decl, "int N = " + get_dim("N", 1, n_axis_idx) + ";\n"); + CutlassPrint(gemm_decl, "int K = " + get_dim("K", 0, k_axis_idx) + ";\n"); + CutlassPrint(gemm_decl, "cutlass::gemm::GemmCoord problem_size(M, N, K);\n"); + CutlassPrint(gemm_decl, "ElementComputeEpilogue alpha = ElementComputeEpilogue(1);\n"); + if (is_gelu) { + // GeLU epilogue does not compile with NoBetaScaling, so we explicitly specify the scale. + CutlassPrint(gemm_decl, "ElementComputeEpilogue beta = ElementComputeEpilogue(1);\n"); + } else { + CutlassPrint(gemm_decl, "ElementComputeEpilogue beta = ElementComputeEpilogue(0);\n"); + } + + ICHECK(func_args.size() >= 2); + CutlassPrint(gemm_decl, "void* ptr_a = (void*)(" + func_args[0] + "->data);\n"); + CutlassPrint(gemm_decl, "void* ptr_b = (void*)(" + func_args[1] + "->data);\n"); + if (has_bias) { + ICHECK(func_args.size() >= 3); + CutlassPrint(gemm_decl, "void* ptr_c_bias = (void*)(" + func_args[2] + "->data);\n"); + } + + CutlassPrint(gemm_decl, "void* ptr_out = (void*)(out0->data);\n"); + + CutlassPrint(gemm_decl, "typename " + kernel + "::Arguments arguments{\n"); + CutlassPrint(gemm_decl, " problem_size,\n"); +} + +void AppendGemmExecute(std::ostringstream& gemm_decl, const std::string& kernel) { + // Using the arguments, query for extra workspace required for matrix multiplication computation + CutlassPrint(gemm_decl, + "size_t workspace_size = " + kernel + "::get_workspace_size(arguments);\n"); + // Allocate workspace memory + CutlassPrint(gemm_decl, + "cutlass::device_memory::allocation workspace(workspace_size);\n"); + // Instantiate CUTLASS kernel depending on template + CutlassPrint(gemm_decl, kernel + " gemm_op;\n"); + + // Check the problem size is supported or not + CutlassPrint(gemm_decl, "cutlass::Status status = gemm_op.can_implement(arguments);\n"); + CutlassPrint(gemm_decl, "CHECK(status == cutlass::Status::kSuccess);\n"); + // Initialize CUTLASS kernel with arguments and workspace pointer + CutlassPrint(gemm_decl, "status = gemm_op.initialize(arguments, workspace.get());\n"); + CutlassPrint(gemm_decl, "CHECK(status == cutlass::Status::kSuccess);\n"); + // Launch initialized CUTLASS kernel + CutlassPrint(gemm_decl, "status = gemm_op();\n"); + CutlassPrint(gemm_decl, "CHECK(status == cutlass::Status::kSuccess);\n"); +} + +std::string DenseOp(std::string id, const Str2StrMap& attrs, + const std::vector& func_args) { + bool has_bias = attrs.at("op_type").find("bias") != std::string::npos; + bool is_gelu = + attrs.at("op_type").find("cutlass.dense_bias_gelu") != std::string::npos; // fp32 or fp16 + std::ostringstream gemm_decl; + AppendPrologue(gemm_decl, attrs, func_args, "Gemm", has_bias, is_gelu, 0, 0, 1); + + CutlassPrint(gemm_decl, " {static_cast(ptr_a), " + attrs.at("lda") + "},\n"); + CutlassPrint(gemm_decl, " {static_cast(ptr_b), " + attrs.at("ldb") + "},\n"); + if (has_bias) { + CutlassPrint(gemm_decl, " {static_cast(ptr_c_bias), 0},\n"); + } else { + CutlassPrint(gemm_decl, " {static_cast(ptr_out), " + attrs.at("ldc") + "},\n"); + } + CutlassPrint(gemm_decl, " {static_cast(ptr_out), " + attrs.at("ldc") + "},\n"); + if (has_bias && !is_gelu) { + CutlassPrint(gemm_decl, " {alpha},\n"); + } else { + // For GeLU, we explicitly specify the scale. + CutlassPrint(gemm_decl, " {alpha, beta},\n"); + } + CutlassPrint(gemm_decl, " 1};\n"); // split_k_slices + + AppendGemmExecute(gemm_decl, "Gemm"); + return gemm_decl.str(); +} + +std::string BatchMatmulOp(std::string id, const Str2StrMap& attrs, + const std::vector& func_args) { + std::ostringstream gemm_decl; + AppendPrologue(gemm_decl, attrs, func_args, "BatchedGemm", false, false, 1, 1, 2); + + auto get_batch_stride = [&attrs, &func_args](const std::string& name, int arg0_idx, int arg1_idx, + int arg0_axis_idx, int arg1_axis_idx) { + if (attrs.at(name) == kAnyDim) { + return func_args[arg0_idx] + "->shape[" + std::to_string(arg0_axis_idx) + "] * " + + func_args[arg1_idx] + "->shape[" + std::to_string(arg1_axis_idx) + "]"; + } else { + return attrs.at(name); + } + }; + + CutlassPrint(gemm_decl, " {static_cast(ptr_a), " + attrs.at("lda") + "},\n"); + CutlassPrint(gemm_decl, get_batch_stride("batch_stride_A", 0, 0, 1, 2) + ",\n"); + CutlassPrint(gemm_decl, " {static_cast(ptr_b), " + attrs.at("ldb") + "},\n"); + CutlassPrint(gemm_decl, get_batch_stride("batch_stride_B", 1, 1, 1, 2) + ",\n"); + CutlassPrint(gemm_decl, " {static_cast(ptr_out), " + attrs.at("ldc") + "},\n"); + CutlassPrint(gemm_decl, get_batch_stride("batch_stride_C", 0, 1, 1, 1) + ",\n"); + CutlassPrint(gemm_decl, " {static_cast(ptr_out), " + attrs.at("ldc") + "},\n"); + CutlassPrint(gemm_decl, get_batch_stride("batch_stride_C", 0, 1, 1, 1) + ",\n"); + CutlassPrint(gemm_decl, " {alpha, beta},\n"); + + if (attrs.at("batch") == kAnyDim) { + CutlassPrint(gemm_decl, func_args[0] + "->shape[0]" + "};\n"); + } else { + CutlassPrint(gemm_decl, attrs.at("batch") + "};\n"); + } + + AppendGemmExecute(gemm_decl, "BatchedGemm"); + return gemm_decl.str(); +} + +Str2StrMap Conv2dArgs(const Map& attrs, bool is_dgrad = false, + bool is_wgrad = false) { + Str2StrMap args = ArgsCommon(attrs); + auto arg0_shape = attrs["arg0_shape"].as(); + auto arg1_shape = attrs["arg1_shape"].as(); + auto ret_shape = attrs["ret_shape"].as(); + auto activation_shape = arg0_shape; + auto weight_shape = arg1_shape; + auto output_shape = ret_shape; + + if (is_dgrad) { + activation_shape = ret_shape; + output_shape = arg0_shape; + } else if (is_wgrad) { + activation_shape = arg1_shape; + weight_shape = ret_shape; + output_shape = arg0_shape; + } + + args["N"] = GetDimAsStr(activation_shape->at(0)); + args["H"] = GetDimAsStr(activation_shape->at(1)); + args["W"] = GetDimAsStr(activation_shape->at(2)); + args["C"] = GetDimAsStr(activation_shape->at(3)); + args["P"] = GetDimAsStr(output_shape->at(1)); + args["Q"] = GetDimAsStr(output_shape->at(2)); + args["K"] = GetDimAsStr(output_shape->at(3)); + args["R"] = GetDimAsStr(weight_shape->at(1)); + args["S"] = GetDimAsStr(weight_shape->at(2)); + args["pad_h"] = GetDimAsStr(attrs["padding"].as()->at(0)); + args["pad_w"] = GetDimAsStr(attrs["padding"].as()->at(1)); + args["stride_h"] = GetDimAsStr(attrs["strides"].as()->at(0)); + args["stride_w"] = GetDimAsStr(attrs["strides"].as()->at(1)); + args["dilation_h"] = GetDimAsStr(attrs["dilation"].as()->at(0)); + args["dilation_w"] = GetDimAsStr(attrs["dilation"].as()->at(1)); + + return args; +} + +std::string Conv2dOp(std::string id, const Str2StrMap& attrs, + const std::vector& func_args, bool has_residual_block = false) { + auto op_type = attrs.at("op_type"); + bool has_bias = op_type.find("bias") != std::string::npos; + bool no_bias_scaling = op_type != "cutlass.conv2d_bias_sigmoid" && + op_type != "cutlass.conv2d_bias_silu" && + op_type != "cutlass.conv2d_bias_hardswish"; + + const std::string op_name = attrs.at("op_name"); + std::ostringstream conv2d_decl; + CutlassPrint(conv2d_decl, attrs.at("op_def")); + CutlassPrint(conv2d_decl, "using Operation_" + op_name + + " = cutlass::conv::device::ImplicitGemmConvolution<" + op_name + + ">;\n"); + CutlassPrint(conv2d_decl, "using Conv2d = Operation_" + op_name + ";\n"); + CutlassPrint(conv2d_decl, "using ElementInputA = Conv2d::ElementA;\n"); + CutlassPrint(conv2d_decl, "using ElementInputB = Conv2d::ElementB;\n"); + CutlassPrint(conv2d_decl, "using ElementComputeEpilogue = Conv2d::ElementAccumulator;\n"); + + auto get_dim = [&attrs](const std::string& axis, const std::string& var_name, int axis_idx) { + if (attrs.at(axis) == kAnyDim) { + return var_name + "->shape[" + std::to_string(axis_idx) + "]"; + } else { + return attrs.at(axis); + } + }; + + CutlassPrint(conv2d_decl, "int N = " + get_dim("N", func_args[0], 0) + ";\n"); + CutlassPrint(conv2d_decl, "int H = " + get_dim("H", func_args[0], 1) + ";\n"); + CutlassPrint(conv2d_decl, "int W = " + get_dim("W", func_args[0], 2) + ";\n"); + CutlassPrint(conv2d_decl, "int C = " + attrs.at("C") + ";\n"); + CutlassPrint(conv2d_decl, "int K = " + attrs.at("K") + ";\n"); + CutlassPrint(conv2d_decl, "int R = " + attrs.at("R") + ";\n"); + CutlassPrint(conv2d_decl, "int S = " + attrs.at("S") + ";\n"); + CutlassPrint(conv2d_decl, "int P = " + get_dim("P", "out0", 1) + ";\n"); + CutlassPrint(conv2d_decl, "int Q = " + get_dim("Q", "out0", 2) + ";\n"); + CutlassPrint(conv2d_decl, "int pad_h = " + attrs.at("pad_h") + ";\n"); + CutlassPrint(conv2d_decl, "int pad_w = " + attrs.at("pad_w") + ";\n"); + CutlassPrint(conv2d_decl, "int stride_h = " + attrs.at("stride_h") + ";\n"); + CutlassPrint(conv2d_decl, "int stride_w = " + attrs.at("stride_w") + ";\n"); + CutlassPrint(conv2d_decl, "int dilation_h = " + attrs.at("dilation_h") + ";\n"); + CutlassPrint(conv2d_decl, "int dilation_w = " + attrs.at("dilation_w") + ";\n"); + + const bool use_split_k = op_name.find("splitk") != std::string::npos; + + if (use_split_k) { + std::string split_k_slices = op_name.substr(op_name.find_last_not_of("0123456789") + 1); + CutlassPrint(conv2d_decl, "int split_k_slices = " + split_k_slices + ";\n"); + } else { + CutlassPrint(conv2d_decl, "int split_k_slices = 1;\n"); + } + + CutlassPrint( + conv2d_decl, + "cutlass::conv::Conv2dProblemSize problem_size(N, H, W, C, K, R, S, P, Q, pad_h, pad_w, " + "stride_h, stride_w, dilation_h, dilation_w, cutlass::conv::Mode::kCrossCorrelation, " + "split_k_slices);\n"); + + const std::string split_k_mode = use_split_k ? "kParallel" : "kSerial"; + CutlassPrint(conv2d_decl, + "const cutlass::conv::SplitKMode split_k_mode = cutlass::conv::SplitKMode::" + + split_k_mode + ";\n"); + + bool is_wgrad = op_type.find("backward_weight") != std::string::npos; + bool is_dgrad = op_type.find("conv2d_transpose") != std::string::npos; + + ICHECK(func_args.size() >= 2); + CutlassPrint(conv2d_decl, "void* ptr_a = (void*)(" + func_args[0] + "->data);\n"); + CutlassPrint(conv2d_decl, "void* ptr_b = (void*)(" + func_args[1] + "->data);\n"); + + if (has_residual_block) { + ICHECK(func_args.size() >= 4); + CutlassPrint(conv2d_decl, "void* ptr_bias = (void*)(" + func_args[2] + "->data);\n"); + CutlassPrint(conv2d_decl, "void* ptr_residual = (void*)(" + func_args[3] + "->data);\n"); + } else if (has_bias) { + ICHECK(func_args.size() >= 3); + CutlassPrint(conv2d_decl, "void* ptr_c_bias = (void*)(" + func_args[2] + "->data);\n"); + } + + CutlassPrint(conv2d_decl, "void* ptr_out = (void*)(out0->data);\n"); + CutlassPrint(conv2d_decl, "ElementComputeEpilogue alpha = ElementComputeEpilogue(1);\n"); + if ((!has_bias || no_bias_scaling) && !has_residual_block) { + CutlassPrint(conv2d_decl, "ElementComputeEpilogue beta = ElementComputeEpilogue(0);\n"); + } else { + CutlassPrint(conv2d_decl, "ElementComputeEpilogue beta = ElementComputeEpilogue(1);\n"); + } + CutlassPrint(conv2d_decl, "using cutlass::layout::TensorNHWC;\n"); + CutlassPrint(conv2d_decl, + "auto activation_shape = TensorNHWC::packed(cutlass::make_Coord(N, H, W, C));\n"); + CutlassPrint(conv2d_decl, + "auto weight_shape = TensorNHWC::packed(cutlass::make_Coord(K, R, S, C));\n"); + CutlassPrint(conv2d_decl, + "auto output_oshape = TensorNHWC::packed(cutlass::make_Coord(N, P, Q, K));\n"); + + if (is_wgrad) { + CutlassPrint(conv2d_decl, "TensorNHWC layout_A(output_oshape);\n"); + CutlassPrint(conv2d_decl, "TensorNHWC layout_B(activation_shape);\n"); + CutlassPrint(conv2d_decl, "TensorNHWC layout_C(weight_shape);\n\n"); + CutlassPrint(conv2d_decl, "TensorNHWC layout_D(weight_shape);\n\n"); + } else if (is_dgrad) { + CutlassPrint(conv2d_decl, "TensorNHWC layout_A(output_oshape);\n"); + CutlassPrint(conv2d_decl, "TensorNHWC layout_B(weight_shape);\n"); + CutlassPrint(conv2d_decl, "TensorNHWC layout_C(activation_shape);\n\n"); + CutlassPrint(conv2d_decl, "TensorNHWC layout_D(activation_shape);\n\n"); + } else { + CutlassPrint(conv2d_decl, "TensorNHWC layout_A(activation_shape);\n"); + CutlassPrint(conv2d_decl, "TensorNHWC layout_B(weight_shape);\n"); + CutlassPrint(conv2d_decl, "TensorNHWC layout_C(output_oshape);\n\n"); + CutlassPrint(conv2d_decl, "TensorNHWC layout_D(output_oshape);\n\n"); + } + + if (use_split_k) { + CutlassPrint(conv2d_decl, "using ElementOutput = EpilogueOutputOp::ElementOutput;\n"); + } else { + CutlassPrint(conv2d_decl, "using ElementOutput = Conv2d::ElementC;\n"); + } + + std::string tensor_c_init = "{static_cast(ptr_out), layout_C}"; + if (has_residual_block) { + tensor_c_init = "{static_cast(ptr_residual), layout_C}"; + } else if (has_bias) { + tensor_c_init = + "{static_cast(ptr_c_bias), cutlass::layout::TensorNHWC::Stride(0)}"; + } + + CutlassPrint(conv2d_decl, + "cutlass::TensorRef tensor_c" + tensor_c_init + ";\n"); + CutlassPrint(conv2d_decl, + "cutlass::TensorRef " + "tensor_d{static_cast(ptr_out),layout_D};\n"); + + CutlassPrint(conv2d_decl, "typename Conv2d::Arguments arguments{\n"); + CutlassPrint(conv2d_decl, " problem_size,\n"); + CutlassPrint(conv2d_decl, " {static_cast(ptr_a), layout_A},\n"); + CutlassPrint(conv2d_decl, " {static_cast(ptr_b), layout_B},\n"); + + if (use_split_k) { + CutlassPrint(conv2d_decl, "{nullptr, TensorNHWC()},\n"); + CutlassPrint(conv2d_decl, "{nullptr, TensorNHWC()},\n"); + } else { + CutlassPrint(conv2d_decl, " tensor_c,\n"); + CutlassPrint(conv2d_decl, " tensor_d,\n"); + } + + if (has_residual_block) { + ICHECK(use_split_k == false) << "Split-k not supported for residual block fusion"; + CutlassPrint(conv2d_decl, "{alpha, beta},\n"); + CutlassPrint(conv2d_decl, "cutlass::conv::SplitKMode::kSerial,\n"); // split_k_slices + CutlassPrint(conv2d_decl, "static_cast(ptr_bias),\n"); + CutlassPrint(conv2d_decl, "nullptr, 0, K};\n"); + } else if (has_bias && no_bias_scaling) { + CutlassPrint(conv2d_decl, " {alpha},\n"); + CutlassPrint(conv2d_decl, "split_k_mode\n};\n"); + } else { + CutlassPrint(conv2d_decl, "{alpha, beta},\n"); + CutlassPrint(conv2d_decl, "split_k_mode\n};\n"); + } + + CutlassPrint(conv2d_decl, "Conv2d conv2d_op;\n"); + + CutlassPrint(conv2d_decl, "size_t workspace_size = conv2d_op.get_workspace_size(arguments);\n"); + // Allocate workspace memory + CutlassPrint(conv2d_decl, + "cutlass::device_memory::allocation workspace(workspace_size);\n"); + // Check the problem size is supported or not + CutlassPrint(conv2d_decl, "cutlass::Status status = conv2d_op.can_implement(arguments);\n"); + CutlassPrint(conv2d_decl, "CHECK(status == cutlass::Status::kSuccess);\n\n"); + + if (use_split_k) { + CutlassPrint(conv2d_decl, + "arguments.ref_D.reset(reinterpret_cast(workspace.get())," + " layout_D);\n\n"); + } + + // Initialize CUTLASS kernel with arguments and workspace pointer + CutlassPrint(conv2d_decl, "status = conv2d_op.initialize(arguments, workspace.get());\n"); + CutlassPrint(conv2d_decl, "CHECK(status == cutlass::Status::kSuccess);\n\n"); + + if (use_split_k) { + CutlassPrint( + conv2d_decl, + "arguments.output_op = {ElementComputeEpilogue(1), ElementComputeEpilogue(0)}; \n"); + CutlassPrint(conv2d_decl, "status = conv2d_op.update(arguments, workspace.get()); \n"); + CutlassPrint(conv2d_decl, "CHECK(status == cutlass::Status::kSuccess);\n\n"); + } + + // Launch initialized CUTLASS kernel + CutlassPrint(conv2d_decl, "status = conv2d_op();\n"); + CutlassPrint(conv2d_decl, "CHECK(status == cutlass::Status::kSuccess);\n\n"); + + if (use_split_k) { + CutlassPrint(conv2d_decl, "ReductionDevice reduction_op;\n"); + CutlassPrint(conv2d_decl, + "const static cutlass::conv::Operator kConvolutionalOperator = " + "Conv2d::kConvolutionalOperator;\n"); + CutlassPrint(conv2d_decl, "typename ReductionDevice::Arguments reduction_args(\n"); + CutlassPrint(conv2d_decl, + "cutlass::conv::implicit_gemm_problem_size(kConvolutionalOperator, " + "problem_size).mn(),\n"); + CutlassPrint(conv2d_decl, "problem_size.split_k_slices,\n"); + CutlassPrint(conv2d_decl, + "cutlass::conv::implicit_gemm_tensor_c_size(kConvolutionalOperator, " + "problem_size),\n"); + CutlassPrint(conv2d_decl, "{\n"); + CutlassPrint(conv2d_decl, + " reinterpret_cast (workspace.get()),\n"); + CutlassPrint(conv2d_decl, + "ReductionStrideIndex(tensor_c.stride()[Conv2d::ImplicitGemmKernel::" + "kTensorCStrideIdx])\n"); + CutlassPrint(conv2d_decl, "},\n"); + CutlassPrint(conv2d_decl, "{\n"); + CutlassPrint(conv2d_decl, "tensor_d.data(),\n"); + CutlassPrint(conv2d_decl, + "ReductionStrideIndex(tensor_d.stride()[Conv2d::ImplicitGemmKernel::" + "kTensorCStrideIdx])\n"); + CutlassPrint(conv2d_decl, "},\n"); + CutlassPrint(conv2d_decl, "{\n"); + CutlassPrint(conv2d_decl, "tensor_c.data(),\n"); + CutlassPrint(conv2d_decl, + "ReductionStrideIndex(tensor_c.stride()[Conv2d::ImplicitGemmKernel::" + "kTensorCStrideIdx])\n"); + CutlassPrint(conv2d_decl, "},\n"); + CutlassPrint(conv2d_decl, " {alpha, beta}\n"); + CutlassPrint(conv2d_decl, ");\n\n"); + CutlassPrint(conv2d_decl, "status = reduction_op.initialize(reduction_args, nullptr);\n"); + CutlassPrint(conv2d_decl, "status = reduction_op();\n"); + } + + return conv2d_decl.str(); +} + +struct Output { + std::string name; + std::string dtype; + int size; + bool need_copy; +}; + +struct GenerateBodyOutput { + std::string decl; + std::vector buffers; + std::vector outputs; +}; + +inline bool IsOp(const CallNode* call, const std::string& op_name) { + const auto* op_node = call->op.as(); + if (!op_node) return false; + Op op = GetRef(op_node); + return op == Op::Get(op_name); +} + +class CodegenCutlass : public tvm::relax::backend::MemoizedExprTranslator> { + public: + CodegenCutlass(const std::string& id, const Map& attrs, const Expr& expr) { + // todo: clean up + this->ext_func_id_ = id; + this->attrs_ = attrs; + bindings_ = AnalyzeVar2Value(expr); + } + + std::vector VisitExpr_(const VarNode* node) final { + ext_func_args_.push_back(GetRef(node)); + Output output; + output.name = node->name_hint(); + return {output}; + } + + std::vector VisitExpr_(const CallNode* call) final { + const auto* fn_var = call->op.as(); + ICHECK(fn_var); + const auto func = Downcast(bindings_[GetRef(fn_var)]); + ICHECK(func.defined()) << "Only composite function is supported for CUTLASS."; + GenerateBodyOutput ret = GenerateCompositeFunctionCall(func, call); + ext_func_body_.push_back(ret.decl); + return ret.outputs; + } + + std::string JIT(const std::vector& out) { + CHECK(out.size() > 0); + code_stream_ << "void " << ext_func_id_ << "_("; + + for (const auto& arg : ext_func_args_) { + code_stream_ << "DLTensor* " << arg->name_hint() << ", "; + } + for (size_t i = 0; i < out.size() - 1; ++i) { + code_stream_ << "DLTensor* out" << i << ", "; + } + code_stream_ << "DLTensor* out" << out.size() - 1 << ") {\n"; + this->EnterScope(); + + // Function body + for (auto decl : buf_decl_) { + this->PrintIndents(); + code_stream_ << decl << "\n"; + } + code_stream_ << "\n"; + for (auto stmt : ext_func_body_) { + this->PrintIndents(); + code_stream_ << stmt << "\n"; + } + + this->ExitScope(); + code_stream_ << "}\n"; + + this->GenerateBackendCFunc(ext_func_id_, ext_func_args_, /*const_arr_name=*/"", out, true); + return code_stream_.str(); + } + + /*! \brief The external function source code stream. */ + std::ostringstream code_stream_; + + protected: + std::vector VisitExpr_(const FunctionNode* fn) { + ICHECK(fn->GetAttr(attr::kComposite).defined()) + << "JSON runtime only supports composite functions"; + // FunctionNode should be handled by the caller. + return {}; + } + + std::vector VisitBinding_(const VarBindingNode* binding) { + ICHECK_EQ(memo_.count(binding->var), 0); + memo_[binding->var] = VisitExpr(binding->value); + return VisitExpr(binding->value); + } + + std::vector VisitBinding(const Binding& binding) { + std::vector nodes; + if (const auto* node = binding.as()) { + auto from_b = VisitBinding_(node); + nodes.insert(nodes.end(), from_b.begin(), from_b.end()); + } else { + LOG(FATAL) << "Unimplemented type: " << binding->GetTypeKey(); + } + return nodes; + } + + std::vector VisitBindingBlock(const BindingBlock& block) { + std::vector nodes; + if (const auto* node = block.as()) { + auto from_bb = VisitBindingBlock_(node); + nodes.insert(nodes.end(), from_bb.begin(), from_bb.end()); + } else if (const auto* node = block.as()) { + auto from_bb = VisitBindingBlock_(node); + nodes.insert(nodes.end(), from_bb.begin(), from_bb.end()); + } else { + LOG(FATAL) << "TypeError: Invalid type: " << block->GetTypeKey(); + } + return nodes; + } + + std::vector VisitBindingBlock_(const BindingBlockNode* block) { + std::vector nodes; + for (Binding binding : block->bindings) { + auto from_b = VisitBinding(binding); + nodes.insert(nodes.end(), from_b.begin(), from_b.end()); + } + return nodes; + } + + std::vector VisitBindingBlock_(const DataflowBlockNode* block) { + std::vector nodes; + for (Binding binding : block->bindings) { + auto from_b = VisitBinding(binding); + nodes.insert(nodes.end(), from_b.begin(), from_b.end()); + } + return nodes; + } + + std::vector VisitExpr_(const SeqExprNode* op) { + std::vector nodes; + + for (BindingBlock block : op->blocks) { + auto from_bb = VisitBindingBlock(block); + } + + auto from_body = VisitExpr(op->body); + nodes.insert(nodes.end(), from_body.begin(), from_body.end()); + + return nodes; + } + + private: + std::vector GetArgumentNames(const CallNode* call) { + std::vector arg_names; + for (size_t i = 0; i < call->args.size(); ++i) { + auto res = VisitExpr(call->args[i]); + for (const auto& out : res) { + arg_names.push_back(out.name); + } + } + return arg_names; + } + + GenerateBodyOutput GenerateCompositeFunctionCall(Function callee, const CallNode* caller) { + const auto pattern_name = callee->GetAttr(attr::kComposite); + ICHECK(pattern_name.defined()) << "Only functions with composite attribute are supported."; + + if (pattern_name == "conv2d_bias_relu") { + const CallNode* conv2d_call = caller; + for (auto [var, val] : bindings_) { + if (val->IsInstance() && IsOp(val.as(), "relax.nn.conv2d")) { + conv2d_call = val.as(); + break; + } + } + return GenerateBody(conv2d_call, "cutlass_conv2d_bias_relu", GetArgumentNames(caller), + Conv2dArgs(std::ref(attrs_))); + } + + LOG(FATAL) << "Unknown composite function: " << pattern_name; + return {}; + } + + GenerateBodyOutput GenerateBody(const CallNode* root_call, const std::string& func_name, + const std::vector& func_args, + const Str2StrMap& attribute_args) { + // Make function call with input buffers when visiting arguements + ICHECK_GT(func_args.size(), 0); + std::ostringstream decl_stream; + decl_stream << "(" << func_args[0]; + for (size_t i = 1; i < func_args.size(); ++i) { + decl_stream << ", " << func_args[i]; + } + // Analyze the output buffers + auto struct_info = GetStructInfo(GetRef(root_call)); + + std::vector out_types; + if (const auto* tensor_sinfo = struct_info.as()) { + out_types.emplace_back(backend::DType2String(tensor_sinfo->dtype)); + } else { + LOG(FATAL) << "Unimplemented"; + } + + GenerateBodyOutput ret; + for (const auto& out_type : out_types) { + const std::string out = "out" + std::to_string(buf_idx_++); + decl_stream << ", " << out; + Output output; + output.name = out; + output.dtype = out_type; + output.need_copy = false; + ret.outputs.push_back(output); + } + decl_stream << ");"; + if (func_name.find("dense") != std::string::npos) { + ret.decl = DenseOp(ext_func_id_, attribute_args, func_args); + } else if (func_name == "cutlass_batch_matmul") { + ret.decl = BatchMatmulOp(ext_func_id_, attribute_args, func_args); + } else if (func_name.find("conv2d") != std::string::npos) { + ret.decl = Conv2dOp(ext_func_id_, attribute_args, func_args); + } + return ret; + } + + /*! \brief Print indents using spaces. */ + void PrintIndents() { + for (int i = 0; i < indent_; i++) { + code_stream_ << ' '; + } + } + + /*! + * \brief Enter a new scope. + */ + void EnterScope() { indent_ += 2; } + + /*! + * \brief Exit a scope. + */ + void ExitScope() { + ICHECK_GE(indent_, 2U) << "Wrong ident found."; + indent_ -= 2; + } + + /*! + * \brief Creates a runtime function header + */ + void PrintRuntimeFunctionHeader(std::string func_name) { + code_stream_ << "#ifdef __cplusplus\n"; + code_stream_ << "extern \"C\" {\n"; + code_stream_ << "#endif\n"; + code_stream_ << "TVM_DLL int32_t "; + code_stream_ << func_name << "("; + code_stream_ << "TVMValue* args, "; + code_stream_ << "int* type_code, "; + code_stream_ << "int num_args, "; + code_stream_ << "TVMValue* out_value, "; + code_stream_ << "int* out_type_code) {\n"; + } + + /*! + * \brief Adds a line to convert TVMValue args to DLTensors + */ + void PrintArgToData(int idx) { + PrintIndents(); + code_stream_ << "DLTensor* arg" << idx << " = "; + code_stream_ << "(DLTensor*)(((TVMValue*)args)[" << idx << "].v_handle);\n"; + } + + /*! + * \brief Adds a line to convert TVMValue rets to DLTensors + */ + void PrintRetToData(int idx) { + PrintIndents(); + code_stream_ << "DLTensor* ret" << idx << " = "; + code_stream_ << "(DLTensor*)(((TVMValue*)args)[" << idx << "].v_handle);\n"; + } + + /*! + * \brief Gerenate C code for the external function. + * + * \param func_name The name of the external function. + * \param args arguments to the external function. + * + * \code + * + * Array foo_consts; + * + * // An example code for the generated C function. + * int foo_wrapper_(DLTensor* arg0, + * DLTensor* arg1, + * DLTensor* out) { + * foo_((float*)(arg0->data), + * (float*)(arg1->data), + * (float*)(out->data)); + * return 0; + * } + * + * TVM_DLL_EXPORT_TYPED_FUNC(foo, foo_wrapper_); + * + * int foo_init_wrapper_(Array arr) { + * foo_consts = arr; + * return 0; + * } + * + * TVM_DLL_EXPORT_TYPED_FUNC(__init_foo, foo_init_wrapper_); + * + * \endcode + */ + void GenerateBackendCFunc(const std::string& func_name, const Array& args, + const std::string& const_arr_name, const std::vector& outs, + bool pass_dl_tensor = false) { + // Print signature + code_stream_ << "\n"; + + code_stream_ << "int " << func_name << "_wrapper_("; + for (size_t i = 0; i < args.size(); i++) { + code_stream_ << "DLTensor* arg" << i << ",\n"; + code_stream_ << "\t"; + } + for (size_t i = 0; i < outs.size() - 1; i++) { + code_stream_ << "DLTensor* out" << i << ",\n"; + code_stream_ << "\t"; + } + code_stream_ << "DLTensor* out" << outs.size() - 1 << ") {\n"; + + EnterScope(); + + // Generate the internal call. + PrintIndents(); + code_stream_ << func_name << "_("; + for (size_t i = 0; i < args.size(); i++) { + if (pass_dl_tensor) { + code_stream_ << "arg" << i << ",\n"; + } else { + const auto& dtype_str = GetDtypeString(args[i]); + code_stream_ << "(" << dtype_str << "*)(arg" << i << "->data),\n"; + } + PrintIndents(); + } + for (size_t i = 0; i < outs.size() - 1; i++) { + if (pass_dl_tensor) { + code_stream_ << "out" << i << ",\n"; + } else { + code_stream_ << "(" << outs[i].dtype << "*)(out" << i << "->data),\n"; + } + PrintIndents(); + } + if (pass_dl_tensor) { + code_stream_ << "out" << outs.size() - 1 << ");\n"; + } else { + code_stream_ << "(" << outs.back().dtype << "*)(out" << outs.size() - 1 << "->data));\n"; + } + PrintIndents(); + code_stream_ << "return 0;\n"; + ExitScope(); + code_stream_ << "}\n\n"; + + // Create the external function + PrintRuntimeFunctionHeader(func_name); + EnterScope(); + for (size_t i = 0; i < args.size(); i++) { + PrintArgToData(i); + } + for (size_t i = 0; i < outs.size(); i++) { + PrintRetToData(args.size() + i); + } + PrintIndents(); + code_stream_ << func_name << "_wrapper_("; + for (size_t i = 0; i < args.size(); i++) { + code_stream_ << "arg" << i << ","; + } + for (size_t i = 0; i < outs.size() - 1; i++) { + code_stream_ << "ret" << args.size() + i << ","; + } + code_stream_ << "ret" << args.size() + outs.size() - 1 << ");\n"; + PrintIndents(); + code_stream_ << "return 0;\n"; + ExitScope(); + code_stream_ << "}\n"; + code_stream_ << "#ifdef __cplusplus\n"; + code_stream_ << "}\n"; + code_stream_ << "#endif\n"; + + if (!const_arr_name.empty()) { + // If there are constants, insert the __init_ and the wrapper + // This segment would be generated in C++ because of the usage + // of tvm::runtime::Array. This is not ideal, but this to demonstrate + // constant copying process used packed imports in other external + // codegen. Moreover, in microTVM we dont expect this part to be generated. + code_stream_ << "#ifdef __cplusplus\n"; + code_stream_ << "int " << func_name + << "_init_wrapper_(tvm::runtime::Array arr) {\n"; + EnterScope(); + PrintIndents(); + code_stream_ << func_name << "_consts = arr;\n"; + code_stream_ << "return 0;\n"; + ExitScope(); + code_stream_ << "}\n\n"; + code_stream_ << "TVM_DLL_EXPORT_TYPED_FUNC(__init_" << func_name << ", " << func_name + << "_init_wrapper_);\n\n"; + code_stream_ << "#endif\n"; + } + } + + std::string GetDtypeString(const Var& var) { + auto ttype = var->checked_type().as(); + ICHECK(ttype) << "Expect TensorTypeNode"; + return GetDtypeString(ttype); + } + + /*! + * \brief Returns dtype string + * + * \param ttype TensorTypeNode* to get the dtype of + * + * \return The dtype string. + */ + std::string GetDtypeString(const TensorTypeNode* ttype) { + std::string dtype; + if (runtime::TypeMatch(ttype->dtype, kDLFloat, 32)) { + dtype = "float"; + } else if (runtime::TypeMatch(ttype->dtype, kDLFloat, 16)) { + dtype = "half"; + } else if (runtime::TypeMatch(ttype->dtype, kDLBfloat, 16)) { + dtype = "bfloat"; + } else if (runtime::TypeMatch(ttype->dtype, kDLInt, 32)) { + dtype = "int"; + } else if (runtime::TypeMatch(ttype->dtype, kDLInt, 64)) { + dtype = "int64_t"; + } else { + LOG(FATAL) << "Unsupported dtype " << ttype->dtype; + } + + return dtype; + } + + /*! \brief Indent of the source code. */ + int indent_{0}; + /*! \brief The id of the external cutlass ext_func. */ + std::string ext_func_id_; + /*! \brief The attrs of the external cutlass ext_func. */ + Map attrs_; + /*! + * \brief The index to track the output buffer. Each kernel will redirect the + * output to a buffer that may be consumed by other kernels. + */ + int buf_idx_{0}; + /*! \brief The arguments used by a wrapped function that calls CUTLASS kernels. */ + Array ext_func_args_; + /*! \brief Statement of the function that will be compiled using CUTLASS kernels. */ + std::vector ext_func_body_; + /*! \brief The declaration of intermediate buffers. */ + std::vector buf_decl_; + + Map bindings_; +}; + +class CutlassModuleCodegen { + public: + runtime::Module CreateCSourceModule(Function f) { + EmitPreamble(); + GenCutlassFunc(f); + return Finalize(); + } + + private: + void EmitPreamble() { + // create header + code_stream_ << "#include \n"; + code_stream_ << "#include \n"; + code_stream_ << "#include \n"; + code_stream_ << "#include \n"; + code_stream_ << "#include \n"; + code_stream_ << "#include \n"; + code_stream_ << "#include \n"; + // cutlass header + code_stream_ << "#include \n"; + code_stream_ << "#include \n"; + code_stream_ << "#include \n"; + code_stream_ << "#include \n"; + code_stream_ << "#include \n"; + code_stream_ << "#include \n"; + code_stream_ << "#include \n"; + code_stream_ << "#include \n"; + code_stream_ << "#include \n"; + code_stream_ << "#include \n"; + code_stream_ << "#include \n"; + code_stream_ << "#include \n"; + code_stream_ << "#include \n"; + code_stream_ << "#include \n"; + code_stream_ << "#include \n"; + code_stream_ << "#include \n"; + code_stream_ << "#include \n"; + code_stream_ << "#include \n"; + code_stream_ << "#include \n"; + code_stream_ << "#include \n"; + } + + void GenCutlassFunc(const Function& function) { + ICHECK(function.defined()) << "Input error: expect a Relay function."; + + // Record the external symbol for runtime lookup. + Optional opt_global_symbol = function->GetAttr(tvm::attr::kGlobalSymbol); + ICHECK(opt_global_symbol.defined()) + << "CUTLASS functions must have a " << tvm::attr::kGlobalSymbol << " attribute"; + std::string sid = opt_global_symbol.value(); + if (std::find(func_names_.begin(), func_names_.end(), sid) != func_names_.end()) { + // Already emitted. + return; + } + func_names_.push_back(sid); + + const auto* attrs = function->attrs.as(); + ICHECK(attrs != nullptr); + const auto dict = attrs->dict; + CodegenCutlass builder(sid, dict, function); + VLOG(1) << "Creating cutlass C code for '" << sid << "' from:\n" << PrettyPrint(function); + auto out = builder.VisitExpr(function->body); + code_stream_ << builder.JIT(out); + } + + runtime::Module Finalize() { + ICHECK(!func_names_.empty()) + << "Should only create CUTLASS CSourceModule if have at least one CUTLASS partition"; + const auto* pf = runtime::Registry::Get("runtime.CSourceModuleCreate"); + ICHECK(pf != nullptr) << "Cannot find CSource module to create the external runtime module"; + VLOG(1) << "Generated CUTLASS code:" << std::endl << code_stream_.str(); + return (*pf)(code_stream_.str(), "cu", func_names_, /*const_vars=*/Array()); + } + + /*! + * \brief Returns \p expr as function if it is a \p Function with "Compiler" attribute + * value "cutlass". + */ + static const FunctionNode* GetCutlassFunctionNode(const Expr& expr) { + if (const auto* function_node = expr.as()) { + Optional opt_compiler = function_node->GetAttr(attr::kCodegen); + if (opt_compiler.defined() && opt_compiler.value() == "cutlass") { + return function_node; + } + } + return nullptr; + } + + /*! \brief The accumulated code stream that will be compiled by NVCC */ + std::ostringstream code_stream_; + /*! \brief The accumulated function names. */ + Array func_names_; +}; // CutlassModuleCodegen + +/*! + * \brief Create a runtime module for CUTLASS. + * \param ref The ext_func Relay expression/module to be executed using extern ops. + * \return A runtime module. + */ +runtime::Module CUTLASSCompiler(const ObjectRef& ref) { + ICHECK(ref->IsInstance()) << "The input ref is expected to be a Relax function."; + Function func = Downcast(ref); + std::string func_name = backend::GetExtSymbol(func); + auto source_mod = CutlassModuleCodegen().CreateCSourceModule(func); + const auto* pf = runtime::Registry::Get("contrib.cutlass.compile"); + ICHECK(pf != nullptr); + return (*pf)(source_mod); +} + +TVM_REGISTER_GLOBAL("relax.ext.cutlass").set_body_typed(CUTLASSCompiler); + +} // namespace contrib +} // namespace relax +} // namespace tvm diff --git a/src/relax/backend/contrib/dnnl/codegen.cc b/src/relax/backend/contrib/dnnl/codegen.cc new file mode 100644 index 0000000000..cf2379a0fa --- /dev/null +++ b/src/relax/backend/contrib/dnnl/codegen.cc @@ -0,0 +1,128 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/relax/backend/contrib/dnnl/codegen.cc + * \brief Implementation of the DNNL JSON serializer. + */ +#include +#include +#include +#include + +#include +#include +#include + +#include "../codegen_json/codegen_json.h" +#include "../utils.h" + +namespace tvm { +namespace relax { +namespace contrib { + +using JSONGraphNode = tvm::runtime::json::JSONGraphNode; +using JSONGraphNodeEntry = tvm::runtime::json::JSONGraphNodeEntry; +using JSONSerializer = backend::contrib::JSONSerializer; + +inline bool IsOp(const CallNode* call, const std::string& op_name) { + const auto* op_node = call->op.as(); + if (!op_node) return false; + Op op = GetRef(op_node); + return op == Op::Get(op_name); +} + +/*! + * \brief Generates an DNNLModule from a relax expression by serializing the expression to a + * json representation. DNNL is not required here because use of DNNL APIs is deferred until + * runtime. + */ +class DNNLJSONSerializer : public JSONSerializer { + public: + DNNLJSONSerializer(const std::string& symbol, const Expr& expr) + : JSONSerializer(symbol, expr), bindings_(AnalyzeVar2Value(expr)) {} + + using JSONSerializer::VisitExpr_; + + std::vector VisitExpr_(const CallNode* call_node) final { + // The call must be to an inline "Composite" function + const auto* fn_var = call_node->op.as(); + ICHECK(fn_var); + const auto fn = Downcast(bindings_[GetRef(fn_var)]); + ICHECK(fn.defined()); + + auto opt_composite = fn->GetAttr(attr::kComposite); + ICHECK(opt_composite.defined()); + + std::string name = opt_composite.value(); + + const CallNode* root_call = call_node; + if (name.find("conv2d") != std::string::npos) { + for (auto [var, val] : bindings_) { + if (val->IsInstance() && IsOp(val.as(), "relax.nn.conv2d")) { + root_call = val.as(); + break; + } + } + ICHECK(root_call->op.as()) << "Not op node"; + } else { + LOG(FATAL) << "Unimplemented"; + } + + std::vector inputs; + for (const auto& arg : call_node->args) { + auto res = VisitExpr(arg); + inputs.insert(inputs.end(), res.begin(), res.end()); + } + auto node = std::make_shared(name, /* name_ */ + "kernel", /* op_type_ */ + inputs, 1 /* num_outputs_ */); + SetCallNodeAttribute(node, root_call); + return AddNode(node, GetRef(call_node)); + } + + private: + Map bindings_; +}; + +/*! + * \brief Create a runtime module for DNNL. + * \param ref The ext_func Relay expression/module to be executed using extern ops. + * \return A runtime module. + */ +runtime::Module DNNLCompiler(const ObjectRef& ref) { + ICHECK(ref->IsInstance()) << "The input ref is expected to be a Relax function."; + Function func = Downcast(ref); + std::string func_name = backend::GetExtSymbol(func); + + DNNLJSONSerializer serializer(func_name, func); + serializer.serialize(); + std::string graph_json = serializer.GetJSON(); + auto param_names = serializer.GetParams(); + const auto* pf = runtime::Registry::Get("runtime.DNNLJSONRuntimeCreate"); + ICHECK(pf != nullptr) << "Cannot find DNNL runtime module create function."; + runtime::Module lib = (*pf)(func_name, graph_json, param_names); + return lib; +} + +TVM_REGISTER_GLOBAL("relax.ext.dnnl").set_body_typed(DNNLCompiler); + +} // namespace contrib +} // namespace relax +} // namespace tvm diff --git a/src/relax/ir/dataflow_matcher.cc b/src/relax/ir/dataflow_matcher.cc index 76bfdb12d2..d061f675e3 100644 --- a/src/relax/ir/dataflow_matcher.cc +++ b/src/relax/ir/dataflow_matcher.cc @@ -499,7 +499,7 @@ bool DFPatternMatcher::VisitDFPattern_(const WildcardPatternNode* op, const Expr } bool MatchExpr(DFPattern pattern, Expr expr, Optional> var2val) { - if (var2val.defined()) // autojump is enabled with var2val. + if (var2val) // autojump is enabled with var2val. return DFPatternMatcher(std::move(var2val.value())).Match(pattern, expr); else return DFPatternMatcher().Match(pattern, expr); @@ -507,6 +507,23 @@ bool MatchExpr(DFPattern pattern, Expr expr, Optional> v TVM_REGISTER_GLOBAL("relax.dpl.match_expr").set_body_typed(MatchExpr); +Optional> ExtractMatchedExpr(DFPattern pattern, Expr expr, + Optional> bindings_opt) { + auto bindings = bindings_opt ? bindings_opt.value() : Map{}; + DFPatternMatcher matcher(bindings); + + if (!matcher.Match(pattern, expr)) { + return NullOpt; + } + + Map matching; + for (const auto& [pat, matches] : matcher.GetMemo()) { + ICHECK(matches.size() == 1) << "More than one match for the pattern " << pat; + matching.Set(pat, matches[0]); + } + return matching; +} + struct PNode { const DFPatternNode* ptr; const VarNode* matched = nullptr; diff --git a/src/relax/transform/fuse_ops.cc b/src/relax/transform/fuse_ops.cc index 0983db3989..85768c7174 100644 --- a/src/relax/transform/fuse_ops.cc +++ b/src/relax/transform/fuse_ops.cc @@ -27,6 +27,9 @@ * A follow-up pass named "FuseTIR" will generate a TIR PrimFunc for each grouped function. */ +#include +#include +#include #include #include #include @@ -372,15 +375,22 @@ class FunctionCreator : public ExprMutator { if (const auto* var_binding = binding.as()) { if (const auto* call = var_binding->value.as()) { - ICHECK(call->op == Op::Get("relax.call_tir")); - // Update the name of the function. - name_hint_ = name_hint_ + "_" + Downcast(call->args[0])->name_hint; + if (call->op == Op::Get("relax.call_tir")) { + // Update the name of the function. + name_hint_ = name_hint_ + "_" + Downcast(call->args[0])->name_hint; - const Tuple& args = Downcast(call->args[1]); - for (const Expr& arg : args->fields) { - CheckDefAndUpdateParam(arg); + const Tuple& args = Downcast(call->args[1]); + for (const Expr& arg : args->fields) { + CheckDefAndUpdateParam(arg); + } + // TODO(tvm-team): handle shape expr + } else { + ICHECK(call->op->IsInstance()); + name_hint_ = name_hint_ + "_" + Downcast(call->op)->name; + for (const Expr& arg : call->args) { + CheckDefAndUpdateParam(arg); + } } - // TODO(tvm-team): handle shape expr } else { const auto* tuple_item = var_binding->value.as(); ICHECK(tuple_item != nullptr); @@ -407,7 +417,7 @@ class FunctionCreator : public ExprMutator { /*! * \brief Create the grouped function according according to the collected bindings and parameters - * \note The created function won't be returned immediately. Tt's stored in the `function_` field. + * \note The created function won't be returned immediately. It's stored in the `function_` field. */ void CreateFunction() { // Step 1. Start constructing a new dataflow block. @@ -543,14 +553,16 @@ class OperatorFusor : public ExprMutator { } } + OperatorFusor(IRModule mod, + const std::unordered_map& obj2group) + : ExprMutator(mod), mod_(std::move(mod)), obj2group_(obj2group) {} + /*! * \brief The main transformation on the IRModule * \return The new IRModule after transformation */ IRModule Transform() { - for (const auto& kv : mod_->functions) { - const GlobalVar& gv = kv.first; - const BaseFunc& func = kv.second; + for (const auto& [gv, func] : mod_->functions) { // Only visit Relax function without attr kPrimitive. if (func->IsInstance() && !func->HasNonzeroAttr(attr::kPrimitive)) { auto updated_func = Downcast(VisitExpr(func)); @@ -579,8 +591,7 @@ class OperatorFusor : public ExprMutator { CollectFuncBoundary(block->bindings); // Step 3. Create the grouped function for each group. - for (auto& kv : group2func_) { - FunctionCreator& creator = kv.second; + for (auto& [_, creator] : group2func_) { creator.CreateFunction(); } @@ -757,6 +768,90 @@ IRModule FuseOps(IRModule mod, int opt_level, size_t max_fuse_depth) { return mod; } +static Map GetBindingInverse(const Map& binding) { + Map value_to_bound_var; + for (const auto& [var, val] : binding) { + value_to_bound_var.Set(val, var); + } + return value_to_bound_var; +} + +class PatternBasedPartitioner : ExprVisitor { + public: + using Group = GraphPartitioner::Group; + using ExprVisitor::VisitExpr_; + + static std::unordered_map Run(DFPattern pattern, Expr expr, + support::Arena* arena) { + PatternBasedPartitioner part(pattern, AnalyzeVar2Value(expr)); + PostOrderVisit( + expr, [arena, &part](const Expr& e) { part.group_map_[e.get()] = arena->make(); }); + part.VisitExpr(expr); + return part.group_map_; + } + + PatternBasedPartitioner(DFPattern pattern, const tvm::runtime::Map& bindings) + : pat_(pattern), bindings_(bindings), value_to_bound_var_(GetBindingInverse(bindings)) {} + + void VisitBindingBlock_(const DataflowBlockNode* block) final { + for (const auto& binding : block->bindings) { + auto it = group_map_.find(binding->var.get()); + ICHECK(it != group_map_.end()); + if (const auto* var_binding = binding.as()) { + VisitExpr(var_binding->value); + } + } + } + + void VisitExpr_(const CallNode* call) override { + if (auto matches_opt = ExtractMatchedExpr(pat_, GetRef(call), bindings_)) { + auto parent_group = GetGroupForBoundVar(GetRef(call)); + ICHECK(parent_group); + + for (const auto& [_, match] : matches_opt.value()) { + ICHECK(group_map_.count(match.get())); + if (!match->IsInstance()) { + AddToGroup(match, parent_group); + if (value_to_bound_var_.count(match) && GetGroupForBoundVar(match)->num_nodes == 1) { + AddToGroup(value_to_bound_var_[match], parent_group); + } + } + } + } + } + + private: + void AddToGroup(Expr e, Group* to) { + if (group_map_[e.get()] != to) { + --group_map_[e.get()]->num_nodes; + group_map_[e.get()] = to; + ++to->num_nodes; + } + } + + Group* GetGroupForBoundVar(Expr e) { + ICHECK(value_to_bound_var_.count(e)); + auto bound_var = value_to_bound_var_[e]; + ICHECK(group_map_.count(bound_var.get())); + return group_map_[bound_var.get()]; + } + + DFPattern pat_; + Map bindings_; + Map value_to_bound_var_; + std::unordered_map group_map_; +}; + +IRModule FuseOpsByPattern(DFPattern pattern, IRModule mod) { + std::unordered_map group_map; + support::Arena arena; + for (const auto& [gv, func] : mod->functions) { + auto map = PatternBasedPartitioner::Run(pattern, func, &arena); + group_map.insert(map.begin(), map.end()); + } + return OperatorFusor(mod, group_map).Transform(); +} + namespace transform { Pass FuseOps(int fuse_opt_level) { @@ -774,6 +869,17 @@ Pass FuseOps(int fuse_opt_level) { TVM_REGISTER_GLOBAL("relax.transform.FuseOps").set_body_typed(FuseOps); +Pass FuseOpsByPattern(DFPattern pattern) { + runtime::TypedPackedFunc pass_func = // + [=](IRModule m, PassContext pc) { return relax::FuseOpsByPattern(pattern, m); }; + return CreateModulePass(/*pass_function=*/pass_func, // + /*opt_level=*/0, // + /*pass_name=*/"FuseOpsByPattern", // + /*required=*/{}); +} + +TVM_REGISTER_GLOBAL("relax.transform.FuseOpsByPattern").set_body_typed(FuseOpsByPattern); + } // namespace transform } // namespace relax diff --git a/src/relay/backend/contrib/cutlass/codegen.cc b/src/relay/backend/contrib/cutlass/codegen.cc index 173dcf5e5f..0e36768c70 100644 --- a/src/relay/backend/contrib/cutlass/codegen.cc +++ b/src/relay/backend/contrib/cutlass/codegen.cc @@ -801,8 +801,8 @@ class CutlassModuleCodegen { runtime::Module CreateCSourceModule() { EmitPreamble(); - for (const auto& kv : mod_->functions) { - if (const auto* function_node = GetCutlassFunctionNode(kv.second)) { + for (const auto& [_, f] : mod_->functions) { + if (const auto* function_node = GetCutlassFunctionNode(f)) { GenCutlassFunc(GetRef(function_node)); } } diff --git a/src/runtime/contrib/dnnl/dnnl_json_runtime.cc b/src/runtime/contrib/dnnl/dnnl_json_runtime.cc index ba06d082c4..a9e621117d 100644 --- a/src/runtime/contrib/dnnl/dnnl_json_runtime.cc +++ b/src/runtime/contrib/dnnl/dnnl_json_runtime.cc @@ -73,6 +73,7 @@ class DNNLJSONRuntime : public JSONRuntimeBase { /* Thread safe implementation of Run. Keep runtime instance immutable */ void Run(const TVMArgs& args) const { + LOG(INFO) << "Running DNNL"; auto arg_data_provider = makeIODataProvider(args); auto mem_solver = tensor_registry_.MakeSolver(arg_data_provider); // Execute primitives one by one @@ -316,7 +317,8 @@ class DNNLJSONRuntime : public JSONRuntimeBase { auto padding = GetNodeAttr>(node, "padding"); std::vector padding_l(padding.begin(), padding.begin() + padding.size() / 2); std::vector padding_r(padding.begin() + padding.size() / 2, padding.end()); - auto groups = GetNodeAttr(node, "groups"); + // todo: groups attribute missing in Relax conv2d + auto groups = 1; // GetNodeAttr(node, "groups"); auto src_layout = GetNodeAttr(node, "data_layout"); auto dst_layout = GetNodeAttr(node, "out_layout"); auto wgh_layout = GetNodeAttr(node, "kernel_layout"); diff --git a/tests/python/relax/test_codegen_cutlass.py b/tests/python/relax/test_codegen_cutlass.py new file mode 100644 index 0000000000..5225294f93 --- /dev/null +++ b/tests/python/relax/test_codegen_cutlass.py @@ -0,0 +1,263 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import numpy as np +import tvm +import tvm.testing + +from tvm import relax, relay +from tvm.script import relax as R +from tvm.relax.dpl import * +from tvm.contrib.cutlass.build import finalize_modules_relax + + +op_name = "cutlass_tensorop_h1688fprop_optimized_256x128_32x2_nhwc_align8" + +op_def = """ + using cutlass_tensorop_h1688fprop_optimized_256x128_32x2_nhwc_align8 = + typename cutlass::conv::kernel::DefaultConv2dFprop< + cutlass::half_t, + cutlass::layout::TensorNHWC, + cutlass::half_t, + cutlass::layout::TensorNHWC, + cutlass::half_t, + cutlass::layout::TensorNHWC, + cutlass::half_t, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm75, + cutlass::gemm::GemmShape<256, 128, 32>, + cutlass::gemm::GemmShape<64, 64, 32 >, + cutlass::gemm::GemmShape<16, 8, 8>, + + cutlass::epilogue::thread::LinearCombinationRelu< + cutlass::half_t, + 8, + cutlass::half_t, + cutlass::half_t, + cutlass::epilogue::thread::ScaleType::NoBetaScaling + >, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<4>, // cutlass::gemm::threadblock::GemmSplitKIdentityThreadblockSwizzle<>, + 2, + cutlass::arch::OpMultiplyAdd, + cutlass::conv::IteratorAlgorithm::kOptimized, + cutlass::conv::StrideSupport::kStrided, + 8, + 8 + >::Kernel; +""" + + +def make_conv_pattern(conv_name, with_bias=False, activation=None): + data = wildcard() + weight = wildcard() + conv = is_op(conv_name)(data, weight) + + if with_bias: + bias = wildcard() + conv_out = is_op("relax.add")(conv, bias) + else: + conv_out = conv + + if activation: + return is_op(activation)(conv_out) + + return conv_out + + +@tvm.script.ir_module +class Conv2dBiasReLU: + @R.function + def conv2d( + data: R.Tensor((16, 32, 32, 16), "float16"), + weight: R.Tensor((32, 3, 3, 16), "float16"), + bias: R.Tensor((1, 1, 1, 32), "float16"), + ): + with R.dataflow(): + conv1 = relax.op.nn.relu( + relax.op.add( + relax.op.nn.conv2d( + data, weight, padding=(1, 1), data_layout="NHWC", kernel_layout="OHWI" + ), + bias, + ) + ) + R.output(conv1) + + return conv1 + + +@tvm.script.ir_module +class Conv2dBiasReLUPartitioned: + @R.function + def main( + data: R.Tensor((16, 32, 32, 16), dtype="float16"), + weight: R.Tensor((32, 3, 3, 16), dtype="float16"), + bias: R.Tensor((1, 1, 1, 32), dtype="float16"), + ) -> R.Tensor((16, 32, 32, 32), dtype="float16"): + # block 0 + with R.dataflow(): + gv: R.Tensor( + (16, 32, 32, 32), dtype="float16" + ) = fused_relax_nn_conv2d_relax_add_relax_nn_relu(data, weight, bias) + R.output(gv) + return gv + + @R.function + def fused_relax_nn_conv2d_relax_add_relax_nn_relu( + data1: R.Tensor((16, 32, 32, 16), dtype="float16"), + weight1: R.Tensor((32, 3, 3, 16), dtype="float16"), + bias1: R.Tensor((1, 1, 1, 32), dtype="float16"), + ) -> R.Tensor((16, 32, 32, 32), dtype="float16"): + R.func_attr( + {"Codegen": "cutlass", "global_symbol": "fused_relax_nn_conv2d_relax_add_relax_nn_relu"} + ) + + @R.function + def fused_relax_nn_conv2d_relax_add_relax_nn_relu_inner( + data1: R.Tensor((16, 32, 32, 16), dtype="float16"), + weight1: R.Tensor((32, 3, 3, 16), dtype="float16"), + bias1: R.Tensor((1, 1, 1, 32), dtype="float16"), + ) -> R.Tensor((16, 32, 32, 32), dtype="float16"): + # function attr dict + R.func_attr({"Primitive": 1, "Composite": "conv2d_bias_relu"}) + # block 0 + with R.dataflow(): + lv: R.Tensor((16, 32, 32, 32), dtype="float16") = R.nn.conv2d( + data1, + weight1, + strides=[1, 1], + padding=[1, 1], + dilation=[1, 1], + data_layout="NHWC", + kernel_layout="OHWI", + out_layout="NHWC", + out_dtype="", + ) + lv1: R.Tensor((16, 32, 32, 32), dtype="float16") = R.add(lv, bias1) + gv1: R.Tensor((16, 32, 32, 32), dtype="float16") = R.nn.relu(lv1) + R.output(gv1) + return gv1 + + return fused_relax_nn_conv2d_relax_add_relax_nn_relu_inner(data1, weight1, bias1) + + +def annotate_attributes(mod): + # TODO: automate + f_name = "fused_relax_nn_conv2d_relax_add_relax_nn_relu" + f = mod[f_name] + + for k, v in { + "arg0_dtype": "float16", + "arg1_dtype": "float16", + "ret_dtype": "float32", + "arg0_shape": "float16", + "arg1_dtype": "float16", + "ret_dtype": "float32", + "op_type": "conv2d_bias_relu", + "arg0_shape": [16, 32, 32, 16], + "arg1_shape": [32, 3, 3, 16], + "ret_shape": [16, 32, 32, 32], + "strides": [1, 1], + "padding": [1, 1], + "dilation": [1, 1], + "cutlass_op_name": op_name, + "cutlass_op_def": op_def, + }.items(): + f = f.with_attr(k, v) + + mod[f_name] = f + + return mod + + +def test_conv2d_partition(): + mod = Conv2dBiasReLU + pat = make_conv_pattern("relax.nn.conv2d", True, "relax.nn.relu") + mod = relax.transform.FuseOpsByPattern(pat)(mod) + + print(mod.script()) + + +def get_relay_conv2d_bias_relu(d_shape, w_shape): + data = relay.var("data", shape=d_shape) + weight = relay.var("weight", shape=w_shape) + bias = relay.var("bias", shape=(1, 1, 1, w_shape[0])) + return relay.nn.relu( + relay.nn.conv2d( + data=data, + weight=weight, + kernel_size=(3, 3), + padding=(1, 1), + data_layout="NHWC", + kernel_layout="OHWI", + ) + + bias + ) + + +def get_ref(data_np, weight_np, bias_np): + relay_mod = tvm.IRModule.from_expr(get_relay_conv2d_bias_relu(data_np.shape, weight_np.shape)) + + with tvm.transform.PassContext(opt_level=3): + seq = tvm.transform.Sequential( + [relay.transform.ConvertLayout({"nn.conv2d": ["NCHW", "OIHW"]})] + ) + relay_mod = seq(relay_mod) + + ref = ( + relay.create_executor("graph", mod=relay_mod, device=tvm.gpu(0), target="cuda") + .evaluate()(*[data_np, weight_np, bias_np]) + .numpy() + ) + + return ref + + +def test_conv2d_offload(): + data_np = np.random.randn(16, 32, 32, 16).astype("float16") + weight_np = np.random.randn(32, 3, 3, 16).astype("float16") + bias_np = np.random.randn(1, 1, 1, 32).astype("float16") + + seq = tvm.transform.Sequential( + [ + relax.transform.RunCodegen(), + relax.transform.RemoveUnusedFunctions(), + ] + ) + + mod = annotate_attributes(Conv2dBiasReLUPartitioned) + mod = seq(mod) + + target = tvm.target.Target("cuda") + ex = relax.vm.build(mod, target) + ex = finalize_modules_relax(ex) + + dev = tvm.gpu(0) + vm = relax.VirtualMachine(ex, dev) + + data = tvm.nd.array(data_np, dev) + weight = tvm.nd.array(weight_np, dev) + bias = tvm.nd.array(bias_np, dev) + out = vm["main"](data, weight, bias).numpy() + + ref = get_ref(data_np, weight_np, bias_np) + + print(np.max(np.abs(out - ref)), np.mean(np.abs(out - ref))) + + +if __name__ == "__main__": + test_conv2d_offload() diff --git a/tests/python/relax/test_codegen_dnnl.py b/tests/python/relax/test_codegen_dnnl.py new file mode 100644 index 0000000000..d7af00c0b4 --- /dev/null +++ b/tests/python/relax/test_codegen_dnnl.py @@ -0,0 +1,207 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import numpy as np +import tvm +import tvm.testing + +from tvm import relax, relay +from tvm.script import relax as R +from tvm.relax.dpl import * + + +def make_conv_pattern(conv_name, with_bias=False, activation=None): + data = wildcard() + weight = wildcard() + conv = is_op(conv_name)(data, weight) + + if with_bias: + bias = wildcard() + conv_out = is_op("add")(conv, bias) + else: + conv_out = conv + + return is_op(activation)(conv_out) + + +@tvm.script.ir_module +class Conv2dReLUx2: + @R.function + def conv2d( + data: R.Tensor((1, 64, 56, 56), "float32"), + weight1: R.Tensor((64, 64, 3, 3), "float32"), + weight2: R.Tensor((64, 64, 3, 3), "float32"), + ): + with R.dataflow(): + conv1 = relax.op.nn.relu(relax.op.nn.conv2d(data, weight1, padding=(1, 1))) + conv2d = relax.op.nn.relu(relax.op.nn.conv2d(conv1, weight2, padding=(0, 0))) + R.output(conv2d) + + return conv2d + + +def get_relay_conv2d_relu_x2(d_shape, w_shape): + data = relay.var("data", shape=d_shape) + weight1 = relay.var("weight1", shape=w_shape) + weight2 = relay.var("weight2", shape=w_shape) + conv1 = relay.nn.relu( + relay.nn.conv2d( + data=data, + weight=weight1, + kernel_size=w_shape[2:], + padding=(1, 1), + ) + ) + return relay.nn.relu( + relay.nn.conv2d( + data=conv1, + weight=weight2, + kernel_size=w_shape[2:], + padding=(0, 0), + ) + ) + + +def test_conv2d_partition(): + mod = Conv2dReLUx2 + pat = make_conv_pattern("relax.nn.conv2d", False, "relax.nn.relu") + mod = relax.transform.FuseOpsByPattern(pat)(mod) + print(mod.script()) + + +@tvm.script.ir_module +class Conv2dReLUx2Partitioned: + @R.function + def main( + data: R.Tensor((1, 64, 56, 56), dtype="float32"), + weight1: R.Tensor((64, 64, 3, 3), dtype="float32"), + weight2: R.Tensor((64, 64, 3, 3), dtype="float32"), + ) -> R.Tensor((1, 64, 54, 54), dtype="float32"): + # block 0 + with R.dataflow(): + lv: R.Tensor((1, 64, 56, 56), dtype="float32") = fused_relax_nn_conv2d_relax_nn_relu( + data, weight1 + ) + gv: R.Tensor((1, 64, 54, 54), dtype="float32") = fused_relax_nn_conv2d_relax_nn_relu1( + lv, weight2 + ) + R.output(gv) + return gv + + @R.function + def fused_relax_nn_conv2d_relax_nn_relu( + data1: R.Tensor((1, 64, 56, 56), dtype="float32"), + weight11: R.Tensor((64, 64, 3, 3), dtype="float32"), + ) -> R.Tensor((1, 64, 56, 56), dtype="float32"): + R.func_attr({"Codegen": "dnnl", "global_symbol": "fused_relax_nn_conv2d_relax_nn_relu"}) + + @R.function + def fused_relax_nn_conv2d_relax_nn_relu_inner( + data1: R.Tensor((1, 64, 56, 56), dtype="float32"), + weight11: R.Tensor((64, 64, 3, 3), dtype="float32"), + ) -> R.Tensor((1, 64, 56, 56), dtype="float32"): + # function attr dict + R.func_attr({"Primitive": 1, "Composite": "conv2d_relu"}) + # block 0 + with R.dataflow(): + lv1: R.Tensor((1, 64, 56, 56), dtype="float32") = R.nn.conv2d( + data1, + weight11, + strides=[1, 1], + padding=[1, 1, 1, 1], + dilation=[1, 1], + data_layout="NCHW", + kernel_layout="OIHW", + out_layout="NCHW", + out_dtype="", + ) + gv1: R.Tensor((1, 64, 56, 56), dtype="float32") = R.nn.relu(lv1) + R.output(gv1) + return gv1 + + return fused_relax_nn_conv2d_relax_nn_relu_inner(data1, weight11) + + @R.function + def fused_relax_nn_conv2d_relax_nn_relu1( + conv1: R.Tensor((1, 64, 56, 56), dtype="float32"), + weight21: R.Tensor((64, 64, 3, 3), dtype="float32"), + ) -> R.Tensor((1, 64, 54, 54), dtype="float32"): + R.func_attr({"Codegen": "dnnl", "global_symbol": "fused_relax_nn_conv2d_relax_nn_relu1"}) + + @R.function + def fused_relax_nn_conv2d_relax_nn_relu1_inner( + conv1: R.Tensor((1, 64, 56, 56), dtype="float32"), + weight21: R.Tensor((64, 64, 3, 3), dtype="float32"), + ) -> R.Tensor((1, 64, 54, 54), dtype="float32"): + # function attr dict + R.func_attr({"Primitive": 1, "Composite": "conv2d_relu"}) + # block 0 + with R.dataflow(): + lv2: R.Tensor((1, 64, 54, 54), dtype="float32") = R.nn.conv2d( + conv1, + weight21, + strides=[1, 1], + padding=[0, 0, 0, 0], + dilation=[1, 1], + data_layout="NCHW", + kernel_layout="OIHW", + out_layout="NCHW", + out_dtype="", + ) + gv2: R.Tensor((1, 64, 54, 54), dtype="float32") = R.nn.relu(lv2) + R.output(gv2) + return gv2 + + return fused_relax_nn_conv2d_relax_nn_relu1_inner(conv1, weight21) + + +def test_dnnl_offload(): + seq = tvm.transform.Sequential( + [ + relax.transform.RunCodegen(), + relax.transform.RemoveUnusedFunctions(), + ] + ) + + mod = seq(Conv2dReLUx2Partitioned) + print(mod.script()) + + target = tvm.target.Target("llvm") + ex = relax.vm.build(mod, target) + + vm = relax.VirtualMachine(ex, tvm.cpu()) + f = vm["main"] + + data_np = np.random.randn(1, 64, 56, 56).astype("float32") + weight1_np = np.random.randn(64, 64, 3, 3).astype("float32") + weight2_np = np.random.randn(64, 64, 3, 3).astype("float32") + out = f(tvm.nd.array(data_np), tvm.nd.array(weight1_np), tvm.nd.array(weight2_np)).numpy() + + relay_mod = tvm.IRModule.from_expr(get_relay_conv2d_relu_x2(data_np.shape, weight1_np.shape)) + + ref = ( + relay.create_executor("graph", mod=relay_mod, device=tvm.cpu(0), target="llvm") + .evaluate()(*[data_np, weight1_np, weight2_np]) + .numpy() + ) + + print(np.max(np.abs(out - ref)), np.mean(np.abs(out - ref))) + + +if __name__ == "__main__": + test_conv2d_partition() + # test_dnnl_offload()