Pointer Arithmetics in Triton

| Views: | Total Words: 3.9k | Reading Time: 4 mins.

标题这个说法是我在 Triton 的 tutorial 里看到的:Matrix Multiplication

其实意思就是,在 Triton 这个比较接近内存结构的 DSL 里,我们需要使用指针来进行访存。进一步地,由于 Triton 推行 “Blocked Program, Scalar Threads”,我们写程序时候的思路也是用块状的方式写的。所以,我们实际上操纵的是 “一块指针簇”,或者说 “张量化的指针 / N-dim 指针数组”。

指针偏移

首先,在指针的语义上,下标操作(indexing)都是被转化为头地址 + 偏移量的模式,即

1
&X[i, j] = i * stride_xi + j * stride_xj

这里的 stride 即每个维度的步长,这和数据的排布方式有关。如果采用 row-major 的方式,例如一个 3*4 矩阵是这样

1
2
3
1  2  3  4
5 6 7 8
9 10 11 12

在内存上的排布就是这样

1
1 2 3 4 5 6 7 8 9 10 11 12

所以显然这时 stride_xi 是 4,stride_xj 是 1。

PS:在 numpy 中,你可以用 strides 来获得每个维度的步长,这通常会返回一个长度等于其 shape 维度的 tuple,其中第 i 个位置的值表示其它维一样,这一维度相邻的两个元素相隔多少 bytes。

指针块

前面说了,由于 Blocked Program 的关系,我们往往一次需要处理一小块的指针,也就是一个指针多维数组。假设我们给定了 X 和 Y 方向的块大小 BLOCK_SIZE_X 和 BLOCK_SIZE_Y,怎样快速得到这一片指针呢?

由于我们其实是知道步长的,因此我们可以先当成步长为 1 来处理,之后再乘上对应的步长就是对的。那么其实我们就是要快速生成这些坐标

1
2
3
4
[0, 0] [0, 1] ... [0, BLOCK_SIZE_Y-1]
[1, 0] [1, 1] ... [1, BLOCK_SIZE_Y-1]
...
[BLOCK_SIZE_X-1, 0] [BLOCK_SIZE_X-1, 1] ... [BLOCK_SIZE_X-1, BLOCK_SIZE_Y-1]

然后每个 [i, j] 再按第一部分那样乘上步长做偏移。很容易发现,这里用 broadcast 非常方便(因为可以观察出上面的矩阵就是由两个向量分别朝彼此垂直的方向 broadcast 出来得到的)。Triton 是可以手动 broadcast 的,语法和一些地方也比较像:[:, None] 就是给第二维加一个大小为 1 的维度。按照 broadcast 的规则,上面的块可以直接由

1
range(0, BLOCK_SIZE_X)[:, None] * stride_x + range(0, BLOCK_SIZE_Y)[None, :] + stride_y

得出。

我自己感觉这里最好玩的地方在于 stride 这个概念是可以被分离出来的,因为以前我自己想这种二维转一维的情况都是直接想着 i*n+j。有这个 stride 概念之后,我们可以不局限于 row-major,而可以通过改变 stride 来实现一些奇怪的访存顺序。而且在有这个 stride 概念后,我们只要相信这么一点:你想在某个维度上移动,只要加上 步数*这个维度的 stride 即可。(总结完之后我感觉其实也是比较直观的事情)

转置

我们可以来尝试一些更有意思的操作。通过一个简单的技巧,我们可以实现 load 出一个转置的块。

我们原本 load 一个块的话,[i, j] 这个位置的指针应该对应于 [i, j] 的数据,而在转置的情况下它对应于 [j, i] 的数据。也就是说,我们只需要在换一下两边的横竖关系(i.e. 把 [;, None] 和 [None, :] 调换一下)就能做到这一点:

1
range(0, BLOCK_SIZE_X)[None, :] * stride_x + range(0, BLOCK_SIZE_Y)[:, None] + stride_y

至于为什么不用 Triton 语言的 transpose (a.T 或者 tl.trans(a))嘛,主要是在运行的过程中碰到了一些 issue:

Author: SiriusNEO

Published on: Metric Space

All posts on this blog are licensed under the CC BY-NC-SA 4.0 license unless otherwise noted.