标题这个说法是我在 Triton 的 tutorial 里看到的:Matrix Multiplication
其实意思就是,在 Triton 这个比较接近内存结构的 DSL 里,我们需要使用指针来进行访存。进一步地,由于 Triton 推行 “Blocked Program, Scalar Threads”,我们写程序时候的思路也是用块状的方式写的。所以,我们实际上操纵的是 “一块指针簇”,或者说 “张量化的指针 / N-dim 指针数组”。
指针偏移
首先,在指针的语义上,下标操作(indexing)都是被转化为头地址 + 偏移量的模式,即
1 | &X[i, j] = i * stride_xi + j * stride_xj |
这里的 stride 即每个维度的步长,这和数据的排布方式有关。如果采用 row-major 的方式,例如一个 3*4 矩阵是这样
1 | 1 2 3 4 |
在内存上的排布就是这样
1 | 1 2 3 4 5 6 7 8 9 10 11 12 |
所以显然这时 stride_xi 是 4,stride_xj 是 1。
PS:在 numpy 中,你可以用 strides 来获得每个维度的步长,这通常会返回一个长度等于其 shape 维度的 tuple,其中第 i 个位置的值表示其它维一样,这一维度相邻的两个元素相隔多少 bytes。
指针块
前面说了,由于 Blocked Program 的关系,我们往往一次需要处理一小块的指针,也就是一个指针多维数组。假设我们给定了 X 和 Y 方向的块大小 BLOCK_SIZE_X 和 BLOCK_SIZE_Y,怎样快速得到这一片指针呢?
由于我们其实是知道步长的,因此我们可以先当成步长为 1 来处理,之后再乘上对应的步长就是对的。那么其实我们就是要快速生成这些坐标
1 | [0, 0] [0, 1] ... [0, BLOCK_SIZE_Y-1] |
然后每个 [i, j] 再按第一部分那样乘上步长做偏移。很容易发现,这里用 broadcast 非常方便(因为可以观察出上面的矩阵就是由两个向量分别朝彼此垂直的方向 broadcast 出来得到的)。Triton 是可以手动 broadcast 的,语法和一些地方也比较像:[:, None] 就是给第二维加一个大小为 1 的维度。按照 broadcast 的规则,上面的块可以直接由
1 | range(0, BLOCK_SIZE_X)[:, None] * stride_x + range(0, BLOCK_SIZE_Y)[None, :] + stride_y |
得出。
我自己感觉这里最好玩的地方在于 stride 这个概念是可以被分离出来的,因为以前我自己想这种二维转一维的情况都是直接想着 i*n+j。有这个 stride 概念之后,我们可以不局限于 row-major,而可以通过改变 stride 来实现一些奇怪的访存顺序。而且在有这个 stride 概念后,我们只要相信这么一点:你想在某个维度上移动,只要加上 步数*这个维度的 stride 即可。(总结完之后我感觉其实也是比较直观的事情)
转置
我们可以来尝试一些更有意思的操作。通过一个简单的技巧,我们可以实现 load 出一个转置的块。
我们原本 load 一个块的话,[i, j] 这个位置的指针应该对应于 [i, j] 的数据,而在转置的情况下它对应于 [j, i] 的数据。也就是说,我们只需要在换一下两边的横竖关系(i.e. 把 [;, None] 和 [None, :] 调换一下)就能做到这一点:
1 | range(0, BLOCK_SIZE_X)[None, :] * stride_x + range(0, BLOCK_SIZE_Y)[:, None] + stride_y |
至于为什么不用 Triton 语言的 transpose (a.T 或者 tl.trans(a))嘛,主要是在运行的过程中碰到了一些 issue: