to-do-list
- Issues in combining the POMO project with the heron project
- Feature-extraction analysis: NLTSP argues that the optimization time is mostly spent in the feature-analysis stage, so it redesigns the feature extraction. The work here is to write our own program that measures how long each step takes and analyze the results (see the timing sketch after this list).
- Comparative analysis of the AMOS, GTA, Bayesian, and Familyseer projects
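For the feature-extraction timing item above, a minimal per-step timing sketch. The stage names (extract_features, analyze) are placeholders for whatever pipeline stages end up being measured, not actual NLTSP/TVM entry points:

import time
from contextlib import contextmanager

timings = {}

@contextmanager
def timed(stage):
    # Accumulate wall-clock time per named stage.
    start = time.perf_counter()
    yield
    timings[stage] = timings.get(stage, 0.0) + time.perf_counter() - start

def profile_pipeline(candidates, extract_features, analyze):
    # Run each candidate through the two stages and record their times separately.
    for cand in candidates:
        with timed("feature_extraction"):
            feats = extract_features(cand)
        with timed("feature_analysis"):
            analyze(feats)
    return timings

Summing timings over a full tuning run should show whether feature analysis really dominates, as NLTSP claims.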
tensorization_phases/compute_transform.py
MappingGenerator → SAEntryGenerator → EntryGenerator
VMappingGenerator → CDParamGenerator → ParamGenerator
Call order of MappingApplier.apply: _ffi_api.MappingState, _ffi_api.MappingRequest, _ffi_api.MappingMainOp, main_op_mapping(init, request)
tensorization_phases/schedulers/cuda_v2.py
CUDAScheduleGeneratorV2 → AcceleratorScheduleGenerator → SAEntryGenerator
CUDAKernelParamGeneratorV2 → CDParamGenerator → ParamGenerator
tensorization_phases/schedule_base.py
SplitFactorGenerator → CDParamGenerator
VectorizeLengthGenerator
UnrollStepGenerator
InlineGenerator
SplitKGenerator
search/parameter.py
SAEntryGenerator → EntryGenerator
CDParamGenerator → ParamGenerator
FlipFlopParamGenerator
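The generator classes above share a propose/feedback pattern: a parameter generator proposes a candidate value, the search loop measures it, and the reward is fed back. A minimal sketch of that pattern (simplified and hypothetical; the real AMOS classes have different signatures and drive simulated annealing over entry records):

import random

class ParamGeneratorSketch:
    """Toy stand-in for the ParamGenerator / CDParamGenerator pattern:
    propose a candidate value, then receive a reward as feedback."""
    def __init__(self, choices):
        self.choices = choices
        self.scores = {c: 0.0 for c in choices}

    def get(self):
        # Mostly exploit the best-scoring choice, sometimes explore randomly.
        if random.random() < 0.2:
            return random.choice(self.choices)
        return max(self.scores, key=self.scores.get)

    def feedback(self, value, reward):
        self.scores[value] = max(self.scores[value], reward)

# e.g. a split-factor generator over candidate tile sizes
split_factor = ParamGeneratorSketch([1, 2, 4, 8, 16])
cand = split_factor.get()
split_factor.feedback(cand, reward=0.7)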
hw_abstraction/hw_abs_base.py
MemoryAbstraction → HardwareAbstraction
ComputeAbstraction → HardwareAbstraction
ElementwiseMemoryAbstraction → MemoryAbstraction
ElementwiseComputeAbstraction → ComputeAbstraction
hw_abstraction/cuda/wmma_base.py
WMMAStoreMatrixSync → MemoryAbstraction
WMMALoadMatrixSync → MemoryAbstraction
WMMAFillFragment → ComputeAbstraction
WMMAMmaSync → ComputeAbstraction
hw_abs_dag/hw_abs_dag_base.py, hw_abs_dag/cuda/wmma_base.py, hw_abs_dag/cuda/wmma_fp16.py
WMMAFp16Fp16 → WMMABaseHwAbsDAG → HardwareAbstractionDAG
Sibling DAGs: WMMABin1Int32, WMMAFp16Fp32, WMMAFp64Fp64, WMMAInt4Int32, WMMAInt8Int32, WMMATf32Fp32
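A toy illustration of the abstraction-subclass pattern these files follow: each memory/compute abstraction corresponds to one tensor-core intrinsic, and a hardware-abstraction DAG wires load → compute → store nodes together. Names and methods below are hypothetical simplifications, not the actual AMOS API:

class HardwareAbstractionSketch:
    """Toy base class: each abstraction names the device intrinsic it emits."""
    def instruction(self):
        raise NotImplementedError

class WMMALoadMatrixSyncSketch(HardwareAbstractionSketch):
    """Memory-side abstraction, modelling nvcuda::wmma::load_matrix_sync."""
    def instruction(self):
        return "nvcuda::wmma::load_matrix_sync"

class WMMAMmaSyncSketch(HardwareAbstractionSketch):
    """Compute-side abstraction, modelling nvcuda::wmma::mma_sync."""
    def instruction(self):
        return "nvcuda::wmma::mma_sync"

# A WMMAFp16Fp16-style DAG would then chain such nodes: load A/B -> mma -> store C.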
compute(wmma_C, body=[T.reduce(T.comm_reducer(lambda x, y: x + y, [T.float16(0)]), source=[wmma_A[i, rk] * wmma_B[rk, j]], init=[], axis=[T.iter_var(rk, T.Range(0, 16), "CommReduce", "")], condition=T.bool(True), value_index=0)], axis=[T.iter_var(i, T.Range(0, 32), "DataPar", ""), T.iter_var(j, T.Range(0, 8), "DataPar", "")], reduce_axis=[T.iter_var(rk, T.Range(0, 16), "CommReduce", "")], tag=, attrs={})
compute(C.vmap.main.cmap.main, body=[T.reduce(T.comm_reducer(lambda x, y: x + y, [T.float16(0)]), source=[T.Cast("float16", A.vmap.input.cmap.input[i_o_main, rk_o_main, i_main, rk_main] * B.vmap.input.cmap.input[j_o_main, rk_o_main, rk_main, j_main])], init=[], axis=[T.iter_var(rk_o_main, T.Range(0, 128), "CommReduce", ""), T.iter_var(rk_main, T.Range(0, 16), "CommReduce", "")], condition=T.bool(True), value_index=0)], axis=[T.iter_var(i_o_main, T.Range(0, 1), "DataPar", ""), T.iter_var(j_o_main, T.Range(0, 125), "DataPar", ""), T.iter_var(i_main, T.Range(0, 32), "DataPar", ""), T.iter_var(j_main, T.Range(0, 8), "DataPar", "")], reduce_axis=[T.iter_var(rk_o_main, T.Range(0, 128), "CommReduce", ""), T.iter_var(rk_main, T.Range(0, 16), "CommReduce", "")], tag=.vmap.main.cmap.main, attrs={})
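The two compute dumps above come from a dense fp16 GEMM workload; the second is the same op after AMOS's virtual/compute mapping has split the iteration space into 32x8x16 WMMA tiles (i = 1x32, j = 125x8, k = 128x16). For reference, a minimal TE re-declaration of the original workload, with shapes taken from the buffers in the lowered module below (this is just the plain workload definition, not the AMOS mapping code):

import tvm
from tvm import te

# dense: C[i, j] = sum_k A[i, k] * B[j, k], fp16, matching the (32, 2048) x (1000, 2048) buffers below
M, N, K = 32, 1000, 2048
A = te.placeholder((M, K), dtype="float16", name="A")
B = te.placeholder((N, K), dtype="float16", name="B")
k = te.reduce_axis((0, K), name="k")
dense = te.compute(
    (M, N),
    lambda i, j: te.sum(A[i, k] * B[j, k], axis=k),
    name="dense",
)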
from tvm.script import ir as I
from tvm.script import tir as T
@I.ir_module
class Module:
    @T.prim_func
    def main(A: T.Buffer((32, 2048), "float16"), B: T.Buffer((1000, 2048), "float16"), dense: T.Buffer((32, 1000), "float16")):
        T.func_attr({"from_legacy_te_schedule": T.bool(True), "tir.noalias": T.bool(True)})
        blockIdx_x = T.launch_thread("blockIdx.x", 125)
        dense_wmma_accumulator = T.allocate([256], "float16", "wmma.accumulator")
        A_shared = T.allocate([1792], "float16", "shared")
        B_shared = T.allocate([320], "float16", "shared")
        A_shared_wmma_matrix_a = T.allocate([512], "float16", "wmma.matrix_a")
        B_shared_wmma_matrix_b = T.allocate([128], "float16", "wmma.matrix_b")
        with T.launch_thread("threadIdx.y", 1) as threadIdx_y:
            T.tvm_fill_fragment(dense_wmma_accumulator, 32, 8, 16, 0, T.float32(0))
            for k_outer in T.unroll(64):
                for ax0_ax1_fused_outer_outer_outer in T.unroll(4):
                    threadIdx_y_1 = T.launch_thread("threadIdx.y", 1)
                    threadIdx_x = T.launch_thread("threadIdx.x", 32)
                    A_shared_1 = T.Buffer((1280,), "float16", data=A_shared, scope="shared")
                    A_1 = T.Buffer((65536,), "float16", data=A.data)
                    A_shared_1[ax0_ax1_fused_outer_outer_outer * 320 + threadIdx_x // 4 * 40 + threadIdx_x % 4 * 8:ax0_ax1_fused_outer_outer_outer * 320 + threadIdx_x // 4 * 40 + threadIdx_x % 4 * 8 + 8] = A_1[ax0_ax1_fused_outer_outer_outer * 16384 + threadIdx_x // 4 * 2048 + k_outer * 32 + threadIdx_x % 4 * 8:ax0_ax1_fused_outer_outer_outer * 16384 + threadIdx_x // 4 * 2048 + k_outer * 32 + threadIdx_x % 4 * 8 + 8]
                for ax0_ax1_fused_outer_outer_outer in T.unroll(2):
                    threadIdx_y_1 = T.launch_thread("threadIdx.y", 1)
                    threadIdx_x = T.launch_thread("threadIdx.x", 32)
                    B_shared_1 = T.Buffer((320,), "float16", data=B_shared, scope="shared")
                    B_1 = T.Buffer((2048000,), "float16", data=B.data)
                    B_shared_1[ax0_ax1_fused_outer_outer_outer * 160 + threadIdx_x // 8 * 40 + threadIdx_x % 8 * 4:ax0_ax1_fused_outer_outer_outer * 160 + threadIdx_x // 8 * 40 + threadIdx_x % 8 * 4 + 4] = B_1[blockIdx_x * 16384 + ax0_ax1_fused_outer_outer_outer * 8192 + threadIdx_x // 8 * 2048 + k_outer * 32 + threadIdx_x % 8 * 4:blockIdx_x * 16384 + ax0_ax1_fused_outer_outer_outer * 8192 + threadIdx_x // 8 * 2048 + k_outer * 32 + threadIdx_x % 8 * 4 + 4]
                for k_inner_outer in T.unroll(2):
                    cse_var_1: T.int32 = k_inner_outer * 16
                    T.tvm_load_matrix_sync(A_shared_wmma_matrix_a, 32, 8, 16, 0, T.tvm_access_ptr(T.type_annotation("float16"), A_shared, cse_var_1, 1280, 1), 40, "row_major")
                    T.tvm_load_matrix_sync(B_shared_wmma_matrix_b, 32, 8, 16, 0, T.tvm_access_ptr(T.type_annotation("float16"), B_shared, cse_var_1, 320, 1), 40, "col_major")
                    T.tvm_mma_sync(dense_wmma_accumulator, 0, A_shared_wmma_matrix_a, 0, B_shared_wmma_matrix_b, 0, dense_wmma_accumulator, 0)
            T.tvm_store_matrix_sync(dense_wmma_accumulator, 32, 8, 16, 0, T.tvm_access_ptr(T.type_annotation("float16"), A_shared, 0, 1792, 2), 56, "row_major")
        threadIdx_y = T.launch_thread("threadIdx.y", 1)
        threadIdx_x = T.launch_thread("threadIdx.x", 32)
        for i_inner_inner_inner_inner_outer in T.unroll(4):
            dense_1 = T.Buffer((32000,), "float16", data=dense.data)
            A_shared_1 = T.Buffer((1792,), "float16", data=A_shared, scope="shared")
            dense_1[threadIdx_x // 8 * 8000 + i_inner_inner_inner_inner_outer * 2000 + blockIdx_x * 8 + threadIdx_x % 8:threadIdx_x // 8 * 8000 + i_inner_inner_inner_inner_outer * 2000 + blockIdx_x * 8 + threadIdx_x % 8 + 2000:1000] = A_shared_1[threadIdx_x // 8 * 448 + i_inner_inner_inner_inner_outer * 112 + threadIdx_x % 8:threadIdx_x // 8 * 448 + i_inner_inner_inner_inner_outer * 112 + threadIdx_x % 8 + 112:56]
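A note on how a dump like the module above is obtained: assuming `s` is an existing TE schedule over the tensors [A, B, dense] (for example the dense workload sketched earlier), tvm.lower produces the IRModule whose printed form is shown here; the exact printing API differs slightly across TVM versions:

import tvm

# Assuming `s` is the tensorized TE schedule and A, B, dense are the TE tensors:
mod = tvm.lower(s, [A, B, dense], simple_mode=True)
print(mod)            # legacy text form, like the primfn dump further below
# On newer TVM versions the same module can also be printed as TVMScript:
# print(mod.script())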
#[version = "0.0.5"]
primfn(A_1: handle, B_1: handle, C.vmap.output_1: handle) -> ()
  attr = {"global_symbol": "main", "tir.noalias": True}
  buffers = {C.vmap.output: Buffer(C.vmap.output_2: Pointer(float16), float16, [32, 1000], []),
             B: Buffer(B_2: Pointer(float16), float16, [2048, 1000], []),
             A: Buffer(A_2: Pointer(float16), float16, [32, 2048], [])}
  buffer_map = {A_1: A, B_1: B, C.vmap.output_1: C.vmap.output} {
  attr [memcpy_dst: Pointer(float16)] "storage_scope" = "global";
  allocate(memcpy_dst, float16, [32000]) {
    attr [IterVar(blockIdx.x: int32, (nullptr), "ThreadIndex", "blockIdx.x")] "thread_extent" = 25;
    attr [C.vmap.main.cmap.main: Pointer(float16)] "storage_scope" = "local";
    allocate(C.vmap.main.cmap.main, float16, [1280]);
    attr [A.vmap.input.cmap.input.shared: Pointer(float16)] "storage_scope" = "shared";
    allocate(A.vmap.input.cmap.input.shared, float16, [2048]);
    attr [B.vmap.input.cmap.input.shared: Pointer(float16)] "storage_scope" = "shared";
    allocate(B.vmap.input.cmap.input.shared, float16, [2560]);
    attr [memcpy_dst_1: Pointer(float16)] "storage_scope" = "local";
    allocate(memcpy_dst_1, float16, [512]);
    attr [memcpy_dst_2: Pointer(float16)] "storage_scope" = "local";
    allocate(memcpy_dst_2, float16, [640]);
    attr [IterVar(threadIdx.y: int32, (nullptr), "ThreadIndex", "threadIdx.y")] "thread_extent" = 1 {
      for (i1.init: int32, 0, 5) {
        @tir.amos_compute("cuda", "wmma_fp16_fp16", "nvcuda::wmma::fill_fragment", C.vmap.main.cmap.main, 32, 8, 16, i1.init, 0f16, dtype=handle)
      }
      for (rk.o.main.outer.outer: int32, 0, 32) {
        attr [IterVar(threadIdx.x: int32, (nullptr), "ThreadIndex", "threadIdx.x")] "thread_extent" = 32;
        A.vmap.input.cmap.input.shared[ramp((threadIdx.x*8), 1, 8)] = (float16x8*)A_2[ramp((((floordiv(threadIdx.x, 2)*2048) + (rk.o.main.outer.outer*64)) + (floormod(threadIdx.x, 2)*8)), 1, 8)]
        attr [IterVar(threadIdx.x, (nullptr), "ThreadIndex", "threadIdx.x")] "thread_extent" = 32;
        A.vmap.input.cmap.input.shared[ramp((((floordiv(threadIdx.x, 2) + 16)*16) + (floormod(threadIdx.x, 2)*8)), 1, 8)] = (float16x8*)A_2[ramp(((((floordiv(threadIdx.x, 2) + 16)*2048) + (rk.o.main.outer.outer*64)) + (floormod(threadIdx.x, 2)*8)), 1, 8)]
        attr [IterVar(threadIdx.x, (nullptr), "ThreadIndex", "threadIdx.x")] "thread_extent" = 32;
        A.vmap.input.cmap.input.shared[ramp((((floordiv(threadIdx.x, 2)*16) + (floormod(threadIdx.x, 2)*8)) + 512), 1, 8)] = (float16x8*)A_2[ramp(((((floordiv(threadIdx.x, 2)*2048) + (rk.o.main.outer.outer*64)) + (floormod(threadIdx.x, 2)*8)) + 16), 1, 8)]
        attr [IterVar(threadIdx.x, (nullptr), "ThreadIndex", "threadIdx.x")] "thread_extent" = 32;
        A.vmap.input.cmap.input.shared[ramp((((floordiv(((threadIdx.x*8) + 768), 512)*512) + (floormod((floordiv(threadIdx.x, 2) + 16), 32)*16)) + (floormod(threadIdx.x, 2)*8)), 1, 8)] = (float16x8*)A_2[ramp(((((floormod((floordiv(threadIdx.x, 2) + 16), 32)*2048) + (rk.o.main.outer.outer*64)) + (floordiv(((threadIdx.x*8) + 768), 512)*16)) + (floormod(threadIdx.x, 2)*8)), 1, 8)]
        attr [IterVar(threadIdx.x, (nullptr), "ThreadIndex", "threadIdx.x")] "thread_extent" = 32;
        A.vmap.input.cmap.input.shared[ramp((((floordiv(threadIdx.x, 2)*16) + (floormod(threadIdx.x, 2)*8)) + 1024), 1, 8)] = (float16x8*)A_2[ramp(((((floordiv(threadIdx.x, 2)*2048) + (rk.o.main.outer.outer*64)) + (floormod(threadIdx.x, 2)*8)) + 32), 1, 8)]
        attr [IterVar(threadIdx.x, (nullptr), "ThreadIndex", "threadIdx.x")] "thread_extent" = 32;
        A.vmap.input.cmap.input.shared[ramp((((floordiv(((threadIdx.x*8) + 1280), 512)*512) + (floormod((floordiv(threadIdx.x, 2) + 16), 32)*16)) + (floormod(threadIdx.x, 2)*8)), 1, 8)] = (float16x8*)A_2[ramp(((((floormod((floordiv(threadIdx.x, 2) + 16), 32)*2048) + (rk.o.main.outer.outer*64)) + (floordiv(((threadIdx.x*8) + 1280), 512)*16)) + (floormod(threadIdx.x, 2)*8)), 1, 8)]
        attr [IterVar(threadIdx.x, (nullptr), "ThreadIndex", "threadIdx.x")] "thread_extent" = 32;
        A.vmap.input.cmap.input.shared[ramp((((floordiv(threadIdx.x, 2)*16) + (floormod(threadIdx.x, 2)*8)) + 1536), 1, 8)] = (float16x8*)A_2[ramp(((((floordiv(threadIdx.x, 2)*2048) + (rk.o.main.outer.outer*64)) + (floormod(threadIdx.x, 2)*8)) + 48), 1, 8)]
        attr [IterVar(threadIdx.x, (nullptr), "ThreadIndex", "threadIdx.x")] "thread_extent" = 32;
        A.vmap.input.cmap.input.shared[ramp((((floordiv(((threadIdx.x*8) + 1792), 512)*512) + (floormod((floordiv(threadIdx.x, 2) + 16), 32)*16)) + (floormod(threadIdx.x, 2)*8)), 1, 8)] = (float16x8*)A_2[ramp(((((floormod((floordiv(threadIdx.x, 2) + 16), 32)*2048) + (rk.o.main.outer.outer*64)) + (floordiv(((threadIdx.x*8) + 1792), 512)*16)) + (floormod(threadIdx.x, 2)*8)), 1, 8)]
        attr [IterVar(threadIdx.x, (nullptr), "ThreadIndex", "threadIdx.x")] "thread_extent" = 32;
        B.vmap.input.cmap.input.shared[ramp((threadIdx.x*8), 1, 8)] = (float16x8*)B_2[ramp((((rk.o.main.outer.outer*64000) + (threadIdx.x*1000)) + (blockIdx.x*40)), 1, 8)]
        attr [IterVar(threadIdx.x, (nullptr), "ThreadIndex", "threadIdx.x")] "thread_extent" = 32;
        B.vmap.input.cmap.input.shared[ramp((((floordiv(threadIdx.x, 16) + 2)*128) + (floormod(threadIdx.x, 16)*8)), 1, 8)] = (float16x8*)B_2[ramp(((((rk.o.main.outer.outer*64000) + ((floordiv(threadIdx.x, 16) + 2)*16000)) + (floormod(threadIdx.x, 16)*1000)) + (blockIdx.x*40)), 1, 8)]
        attr [IterVar(threadIdx.x, (nullptr), "ThreadIndex", "threadIdx.x")] "thread_extent" = 32;
        B.vmap.input.cmap.input.shared[ramp((((floordiv(threadIdx.x, 16)*128) + (floormod(threadIdx.x, 16)*8)) + 512), 1, 8)] = (float16x8*)B_2[ramp((((((rk.o.main.outer.outer*64000) + (floordiv(threadIdx.x, 16)*16000)) + (floormod(threadIdx.x, 16)*1000)) + (blockIdx.x*40)) + 8), 1, 8)]
        attr [IterVar(threadIdx.x, (nullptr), "ThreadIndex", "threadIdx.x")] "thread_extent" = 32;
        B.vmap.input.cmap.input.shared[ramp((((floordiv((threadIdx.x + 96), 64)*512) + (floormod((floordiv(threadIdx.x, 16) + 2), 4)*128)) + (floormod(threadIdx.x, 16)*8)), 1, 8)] = (float16x8*)B_2[ramp((((((rk.o.main.outer.outer*64000) + (floormod((floordiv(threadIdx.x, 16) + 2), 4)*16000)) + (floormod(threadIdx.x, 16)*1000)) + (blockIdx.x*40)) + (floordiv((threadIdx.x + 96), 64)*8)), 1, 8)]
        attr [IterVar(threadIdx.x, (nullptr), "ThreadIndex", "threadIdx.x")] "thread_extent" = 32;
        B.vmap.input.cmap.input.shared[ramp((((floordiv(threadIdx.x, 16)*128) + (floormod(threadIdx.x, 16)*8)) + 1024), 1, 8)] = (float16x8*)B_2[ramp((((((rk.o.main.outer.outer*64000) + (floordiv(threadIdx.x, 16)*16000)) + (floormod(threadIdx.x, 16)*1000)) + (blockIdx.x*40)) + 16), 1, 8)]
        attr [IterVar(threadIdx.x, (nullptr), "ThreadIndex", "threadIdx.x")] "thread_extent" = 32;
        B.vmap.input.cmap.input.shared[ramp((((floordiv((threadIdx.x + 160), 64)*512) + (floormod((floordiv(threadIdx.x, 16) + 2), 4)*128)) + (floormod(threadIdx.x, 16)*8)), 1, 8)] = (float16x8*)B_2[ramp((((((rk.o.main.outer.outer*64000) + (floormod((floordiv(threadIdx.x, 16) + 2), 4)*16000)) + (floormod(threadIdx.x, 16)*1000)) + (blockIdx.x*40)) + (floordiv((threadIdx.x + 160), 64)*8)), 1, 8)]
        attr [IterVar(threadIdx.x, (nullptr), "ThreadIndex", "threadIdx.x")] "thread_extent" = 32;
        B.vmap.input.cmap.input.shared[ramp((((floordiv(threadIdx.x, 16)*128) + (floormod(threadIdx.x, 16)*8)) + 1536), 1, 8)] = (float16x8*)B_2[ramp((((((rk.o.main.outer.outer*64000) + (floordiv(threadIdx.x, 16)*16000)) + (floormod(threadIdx.x, 16)*1000)) + (blockIdx.x*40)) + 24), 1, 8)]
        attr [IterVar(threadIdx.x, (nullptr), "ThreadIndex", "threadIdx.x")] "thread_extent" = 32;
        B.vmap.input.cmap.input.shared[ramp((((floordiv((threadIdx.x + 224), 64)*512) + (floormod((floordiv(threadIdx.x, 16) + 2), 4)*128)) + (floormod(threadIdx.x, 16)*8)), 1, 8)] = (float16x8*)B_2[ramp((((((rk.o.main.outer.outer*64000) + (floormod((floordiv(threadIdx.x, 16) + 2), 4)*16000)) + (floormod(threadIdx.x, 16)*1000)) + (blockIdx.x*40)) + (floordiv((threadIdx.x + 224), 64)*8)), 1, 8)]
        attr [IterVar(threadIdx.x, (nullptr), "ThreadIndex", "threadIdx.x")] "thread_extent" = 32;
        B.vmap.input.cmap.input.shared[ramp((((floordiv(threadIdx.x, 16)*128) + (floormod(threadIdx.x, 16)*8)) + 2048), 1, 8)] = (float16x8*)B_2[ramp((((((rk.o.main.outer.outer*64000) + (floordiv(threadIdx.x, 16)*16000)) + (floormod(threadIdx.x, 16)*1000)) + (blockIdx.x*40)) + 32), 1, 8)]
        attr [IterVar(threadIdx.x, (nullptr), "ThreadIndex", "threadIdx.x")] "thread_extent" = 32;
        B.vmap.input.cmap.input.shared[ramp((((floordiv((threadIdx.x + 288), 64)*512) + (floormod((floordiv(threadIdx.x, 16) + 2), 4)*128)) + (floormod(threadIdx.x, 16)*8)), 1, 8)] = (float16x8*)B_2[ramp((((((rk.o.main.outer.outer*64000) + (floormod((floordiv(threadIdx.x, 16) + 2), 4)*16000)) + (floormod(threadIdx.x, 16)*1000)) + (blockIdx.x*40)) + (floordiv((threadIdx.x + 288), 64)*8)), 1, 8)]
        @tir.amos_memory("cuda", "wmma_fp16_fp16", "nvcuda::wmma::load_matrix_sync", memcpy_dst_1, 32, 8, 16, 0, @tir.tvm_access_ptr(@tir.type_annotation(, dtype=float16), A.vmap.input.cmap.input.shared, 0, 512, 1, dtype=handle), 16, "nvcuda::wmma::row_major", dtype=handle)
        @tir.amos_memory("cuda", "wmma_fp16_fp16", "nvcuda::wmma::load_matrix_sync", memcpy_dst_2, 32, 8, 16, 0, @tir.tvm_access_ptr(@tir.type_annotation(, dtype=float16), B.vmap.input.cmap.input.shared, 0, 128, 1, dtype=handle), 8, "nvcuda::wmma::row_major", dtype=handle)
        @tir.amos_memory("cuda", "wmma_fp16_fp16", "nvcuda::wmma::load_matrix_sync", memcpy_dst_2, 32, 8, 16, 1, @tir.tvm_access_ptr(@tir.type_annotation(, dtype=float16), B.vmap.input.cmap.input.shared, 512, 128, 1, dtype=handle), 8, "nvcuda::wmma::row_major", dtype=handle)
        @tir.amos_memory("cuda", "wmma_fp16_fp16", "nvcuda::wmma::load_matrix_sync", memcpy_dst_2, 32, 8, 16, 2, @tir.tvm_access_ptr(@tir.type_annotation(, dtype=float16), B.vmap.input.cmap.input.shared, 1024, 128, 1, dtype=handle), 8, "nvcuda::wmma::row_major", dtype=handle)
        @tir.amos_memory("cuda", "wmma_fp16_fp16", "nvcuda::wmma::load_matrix_sync", memcpy_dst_2, 32, 8, 16, 3, @tir.tvm_access_ptr(@tir.type_annotation(, dtype=float16), B.vmap.input.cmap.input.shared, 1536, 128, 1, dtype=handle), 8, "nvcuda::wmma::row_major", dtype=handle)
        @tir.amos_memory("cuda", "wmma_fp16_fp16", "nvcuda::wmma::load_matrix_sync", memcpy_dst_2, 32, 8, 16, 4, @tir.tvm_access_ptr(@tir.type_annotation(, dtype=float16), B.vmap.input.cmap.input.shared, 2048, 128, 1, dtype=handle), 8, "nvcuda::wmma::row_major", dtype=handle)
        @tir.amos_compute("cuda", "wmma_fp16_fp16", "nvcuda::wmma::mma_sync", C.vmap.main.cmap.main, 0, memcpy_dst_1, 0, memcpy_dst_2, 0, C.vmap.main.cmap.main, 0, False, dtype=handle)
        @tir.amos_compute("cuda", "wmma_fp16_fp16", "nvcuda::wmma::mma_sync", C.vmap.main.cmap.main, 1, memcpy_dst_1, 0, memcpy_dst_2, 1, C.vmap.main.cmap.main, 1, False, dtype=handle)
        @tir.amos_compute("cuda", "wmma_fp16_fp16", "nvcuda::wmma::mma_sync", C.vmap.main.cmap.main, 2, memcpy_dst_1, 0, memcpy_dst_2, 2, C.vmap.main.cmap.main, 2, False, dtype=handle)
        @tir.amos_compute("cuda", "wmma_fp16_fp16", "nvcuda::wmma::mma_sync", C.vmap.main.cmap.main, 3, memcpy_dst_1, 0, memcpy_dst_2, 3, C.vmap.main.cmap.main, 3, False, dtype=handle)
        @tir.amos_compute("cuda", "wmma_fp16_fp16", "nvcuda::wmma::mma_sync", C.vmap.main.cmap.main, 4, memcpy_dst_1, 0, memcpy_dst_2, 4, C.vmap.main.cmap.main, 4, False, dtype=handle)
        @tir.amos_memory("cuda", "wmma_fp16_fp16", "nvcuda::wmma::load_matrix_sync", memcpy_dst_1, 32, 8, 16, 0, @tir.tvm_access_ptr(@tir.type_annotation(, dtype=float16), A.vmap.input.cmap.input.shared, 512, 512, 1, dtype=handle), 16, "nvcuda::wmma::row_major", dtype=handle)
        @tir.amos_memory("cuda", "wmma_fp16_fp16", "nvcuda::wmma::load_matrix_sync", memcpy_dst_2, 32, 8, 16, 0, @tir.tvm_access_ptr(@tir.type_annotation(, dtype=float16), B.vmap.input.cmap.input.shared, 128, 128, 1, dtype=handle), 8, "nvcuda::wmma::row_major", dtype=handle)
        @tir.amos_memory("cuda", "wmma_fp16_fp16", "nvcuda::wmma::load_matrix_sync", memcpy_dst_2, 32, 8, 16, 1, @tir.tvm_access_ptr(@tir.type_annotation(, dtype=float16), B.vmap.input.cmap.input.shared, 640, 128, 1, dtype=handle), 8, "nvcuda::wmma::row_major", dtype=handle)
        @tir.amos_memory("cuda", "wmma_fp16_fp16", "nvcuda::wmma::load_matrix_sync", memcpy_dst_2, 32, 8, 16, 2, @tir.tvm_access_ptr(@tir.type_annotation(, dtype=float16), B.vmap.input.cmap.input.shared, 1152, 128, 1, dtype=handle), 8, "nvcuda::wmma::row_major", dtype=handle)
        @tir.amos_memory("cuda", "wmma_fp16_fp16", "nvcuda::wmma::load_matrix_sync", memcpy_dst_2, 32, 8, 16, 3, @tir.tvm_access_ptr(@tir.type_annotation(, dtype=float16), B.vmap.input.cmap.input.shared, 1664, 128, 1, dtype=handle), 8, "nvcuda::wmma::row_major", dtype=handle)
        @tir.amos_memory("cuda", "wmma_fp16_fp16", "nvcuda::wmma::load_matrix_sync", memcpy_dst_2, 32, 8, 16, 4, @tir.tvm_access_ptr(@tir.type_annotation(, dtype=float16), B.vmap.input.cmap.input.shared, 2176, 128, 1, dtype=handle), 8, "nvcuda::wmma::row_major", dtype=handle)
        @tir.amos_compute("cuda", "wmma_fp16_fp16", "nvcuda::wmma::mma_sync", C.vmap.main.cmap.main, 0, memcpy_dst_1, 0, memcpy_dst_2, 0, C.vmap.main.cmap.main, 0, False, dtype=handle)
        @tir.amos_compute("cuda", "wmma_fp16_fp16", "nvcuda::wmma::mma_sync", C.vmap.main.cmap.main, 1, memcpy_dst_1, 0, memcpy_dst_2, 1, C.vmap.main.cmap.main, 1, False, dtype=handle)
        @tir.amos_compute("cuda", "wmma_fp16_fp16", "nvcuda::wmma::mma_sync", C.vmap.main.cmap.main, 2, memcpy_dst_1, 0, memcpy_dst_2, 2, C.vmap.main.cmap.main, 2, False, dtype=handle)
        @tir.amos_compute("cuda", "wmma_fp16_fp16", "nvcuda::wmma::mma_sync", C.vmap.main.cmap.main, 3, memcpy_dst_1, 0, memcpy_dst_2, 3, C.vmap.main.cmap.main, 3, False, dtype=handle)
        @tir.amos_compute("cuda", "wmma_fp16_fp16", "nvcuda::wmma::mma_sync", C.vmap.main.cmap.main, 4, memcpy_dst_1, 0, memcpy_dst_2, 4, C.vmap.main.cmap.main, 4, False, dtype=handle)
        @tir.amos_memory("cuda", "wmma_fp16_fp16", "nvcuda::wmma::load_matrix_sync", memcpy_dst_1, 32, 8, 16, 0, @tir.tvm_access_ptr(@tir.type_annotation(, dtype=float16), A.vmap.input.cmap.input.shared, 1024, 512, 1, dtype=handle), 16, "nvcuda::wmma::row_major", dtype=handle)
        @tir.amos_memory("cuda", "wmma_fp16_fp16", "nvcuda::wmma::load_matrix_sync", memcpy_dst_2, 32, 8, 16, 0, @tir.tvm_access_ptr(@tir.type_annotation(, dtype=float16), B.vmap.input.cmap.input.shared, 256, 128, 1, dtype=handle), 8, "nvcuda::wmma::row_major", dtype=handle)
        @tir.amos_memory("cuda", "wmma_fp16_fp16", "nvcuda::wmma::load_matrix_sync", memcpy_dst_2, 32, 8, 16, 1, @tir.tvm_access_ptr(@tir.type_annotation(, dtype=float16), B.vmap.input.cmap.input.shared, 768, 128, 1, dtype=handle), 8, "nvcuda::wmma::row_major", dtype=handle)
        @tir.amos_memory("cuda", "wmma_fp16_fp16", "nvcuda::wmma::load_matrix_sync", memcpy_dst_2, 32, 8, 16, 2, @tir.tvm_access_ptr(@tir.type_annotation(, dtype=float16), B.vmap.input.cmap.input.shared, 1280, 128, 1, dtype=handle), 8, "nvcuda::wmma::row_major", dtype=handle)
        @tir.amos_memory("cuda", "wmma_fp16_fp16", "nvcuda::wmma::load_matrix_sync", memcpy_dst_2, 32, 8, 16, 3, @tir.tvm_access_ptr(@tir.type_annotation(, dtype=float16), B.vmap.input.cmap.input.shared, 1792, 128, 1, dtype=handle), 8, "nvcuda::wmma::row_major", dtype=handle)
        @tir.amos_memory("cuda", "wmma_fp16_fp16", "nvcuda::wmma::load_matrix_sync", memcpy_dst_2, 32, 8, 16, 4, @tir.tvm_access_ptr(@tir.type_annotation(, dtype=float16), B.vmap.input.cmap.input.shared, 2304, 128, 1, dtype=handle), 8, "nvcuda::wmma::row_major", dtype=handle)
        @tir.amos_compute("cuda", "wmma_fp16_fp16", "nvcuda::wmma::mma_sync", C.vmap.main.cmap.main, 0, memcpy_dst_1, 0, memcpy_dst_2, 0, C.vmap.main.cmap.main, 0, False, dtype=handle)
        @tir.amos_compute("cuda", "wmma_fp16_fp16", "nvcuda::wmma::mma_sync", C.vmap.main.cmap.main, 1, memcpy_dst_1, 0, memcpy_dst_2, 1, C.vmap.main.cmap.main, 1, False, dtype=handle)
        @tir.amos_compute("cuda", "wmma_fp16_fp16", "nvcuda::wmma::mma_sync", C.vmap.main.cmap.main, 2, memcpy_dst_1, 0, memcpy_dst_2, 2, C.vmap.main.cmap.main, 2, False, dtype=handle)
        @tir.amos_compute("cuda", "wmma_fp16_fp16", "nvcuda::wmma::mma_sync", C.vmap.main.cmap.main, 3, memcpy_dst_1, 0, memcpy_dst_2, 3, C.vmap.main.cmap.main, 3, False, dtype=handle)
        @tir.amos_compute("cuda", "wmma_fp16_fp16", "nvcuda::wmma::mma_sync", C.vmap.main.cmap.main, 4, memcpy_dst_1, 0, memcpy_dst_2, 4, C.vmap.main.cmap.main, 4, False, dtype=handle)
        @tir.amos_memory("cuda", "wmma_fp16_fp16", "nvcuda::wmma::load_matrix_sync", memcpy_dst_1, 32, 8, 16, 0, @tir.tvm_access_ptr(@tir.type_annotation(, dtype=float16), A.vmap.input.cmap.input.shared, 1536, 512, 1, dtype=handle), 16, "nvcuda::wmma::row_major", dtype=handle)
        @tir.amos_memory("cuda", "wmma_fp16_fp16", "nvcuda::wmma::load_matrix_sync", memcpy_dst_2, 32, 8, 16, 0, @tir.tvm_access_ptr(@tir.type_annotation(, dtype=float16), B.vmap.input.cmap.input.shared, 384, 128, 1, dtype=handle), 8, "nvcuda::wmma::row_major", dtype=handle)
        @tir.amos_memory("cuda", "wmma_fp16_fp16", "nvcuda::wmma::load_matrix_sync", memcpy_dst_2, 32, 8, 16, 1, @tir.tvm_access_ptr(@tir.type_annotation(, dtype=float16), B.vmap.input.cmap.input.shared, 896, 128, 1, dtype=handle), 8, "nvcuda::wmma::row_major", dtype=handle)
        @tir.amos_memory("cuda", "wmma_fp16_fp16", "nvcuda::wmma::load_matrix_sync", memcpy_dst_2, 32, 8, 16, 2, @tir.tvm_access_ptr(@tir.type_annotation(, dtype=float16), B.vmap.input.cmap.input.shared, 1408, 128, 1, dtype=handle), 8, "nvcuda::wmma::row_major", dtype=handle)
        @tir.amos_memory("cuda", "wmma_fp16_fp16", "nvcuda::wmma::load_matrix_sync", memcpy_dst_2, 32, 8, 16, 3, @tir.tvm_access_ptr(@tir.type_annotation(, dtype=float16), B.vmap.input.cmap.input.shared, 1920, 128, 1, dtype=handle), 8, "nvcuda::wmma::row_major", dtype=handle)
        @tir.amos_memory("cuda", "wmma_fp16_fp16", "nvcuda::wmma::load_matrix_sync", memcpy_dst_2, 32, 8, 16, 4, @tir.tvm_access_ptr(@tir.type_annotation(, dtype=float16), B.vmap.input.cmap.input.shared, 2432, 128, 1, dtype=handle), 8, "nvcuda::wmma::row_major", dtype=handle)
        @tir.amos_compute("cuda", "wmma_fp16_fp16", "nvcuda::wmma::mma_sync", C.vmap.main.cmap.main, 0, memcpy_dst_1, 0, memcpy_dst_2, 0, C.vmap.main.cmap.main, 0, False, dtype=handle)
        @tir.amos_compute("cuda", "wmma_fp16_fp16", "nvcuda::wmma::mma_sync", C.vmap.main.cmap.main, 1, memcpy_dst_1, 0, memcpy_dst_2, 1, C.vmap.main.cmap.main, 1, False, dtype=handle)
        @tir.amos_compute("cuda", "wmma_fp16_fp16", "nvcuda::wmma::mma_sync", C.vmap.main.cmap.main, 2, memcpy_dst_1, 0, memcpy_dst_2, 2, C.vmap.main.cmap.main, 2, False, dtype=handle)
        @tir.amos_compute("cuda", "wmma_fp16_fp16", "nvcuda::wmma::mma_sync", C.vmap.main.cmap.main, 3, memcpy_dst_1, 0, memcpy_dst_2, 3, C.vmap.main.cmap.main, 3, False, dtype=handle)
        @tir.amos_compute("cuda", "wmma_fp16_fp16", "nvcuda::wmma::mma_sync", C.vmap.main.cmap.main, 4, memcpy_dst_1, 0, memcpy_dst_2, 4, C.vmap.main.cmap.main, 4, False, dtype=handle)
      }
      for (i1.inner: int32, 0, 5) {
        @tir.amos_memory("cuda", "wmma_fp16_fp16", "nvcuda::wmma::store_matrix_sync", C.vmap.main.cmap.main, 32, 8, 16, i1.inner, @tir.tvm_access_ptr(@tir.type_annotation(, dtype=float16), memcpy_dst, ((blockIdx.x*1280) + (i1.inner*256)), 256, 2, dtype=handle), 8, "nvcuda::wmma::mem_row_major", dtype=handle)
      }
    }
    attr [IterVar(blockIdx.x_1: int32, (nullptr), "ThreadIndex", "blockIdx.x")] "thread_extent" = 4;
    for (i.output.j.output.fused.outer.outer.inner: int32, 0, 10) {
      attr [IterVar(threadIdx.y_1: int32, (nullptr), "ThreadIndex", "threadIdx.y")] "thread_extent" = 25;
      attr [IterVar(threadIdx.x_1: int32, (nullptr), "ThreadIndex", "threadIdx.x")] "thread_extent" = 32;
      C.vmap.output_2[((((blockIdx.x_1*8000) + (i.output.j.output.fused.outer.outer.inner*800)) + (threadIdx.y_1*32)) + threadIdx.x_1)] = (float16*)memcpy_dst[((((floordiv(floormod((((i.output.j.output.fused.outer.outer.inner*800) + (threadIdx.y_1*32)) + threadIdx.x_1), 1000), 8)*256) + (blockIdx.x_1*64)) + (floordiv((((i.output.j.output.fused.outer.outer.inner*800) + (threadIdx.y_1*32)) + threadIdx.x_1), 1000)*8)) + floormod(threadIdx.x_1, 8))]
    }
  }
}
#[metadata]
{
  "root": 1,
  "nodes": [
    { "type_key": "" },
    { "type_key": "Map", "keys": [ "IntImm" ], "data": [2] },
    { "type_key": "Array", "data": [3] },
    { "type_key": "IntImm", "attrs": { "dtype": "bool", "value": "1" } }
  ],
  "b64ndarrays": [],
  "attrs": {"tvm_version": "0.8.dev0"}
}
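To support the step-by-step characterization work above, a small sketch for counting the AMOS intrinsic call sites (tir.amos_compute / tir.amos_memory) in a lowered module like the one dumped here, using TVM's TIR post-order visitor. Assumes a lowered IRModule `mod` is available; the op names are taken from the dump above:

import tvm
from collections import Counter

def count_amos_intrinsics(mod):
    """Count tir.amos_compute / tir.amos_memory call sites in every PrimFunc of `mod`."""
    counts = Counter()

    def visit(node):
        # Intrinsic calls show up as tir.Call nodes whose op name matches the dump.
        if isinstance(node, tvm.tir.Call) and getattr(node.op, "name", "") in (
            "tir.amos_compute",
            "tir.amos_memory",
        ):
            counts[node.op.name] += 1

    for _, func in mod.functions.items():
        if isinstance(func, tvm.tir.PrimFunc):
            tvm.tir.stmt_functor.post_order_visit(func.body, visit)
    return counts

# e.g. print(count_amos_intrinsics(mod))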