Skip to content

Commit 2eb1352

Browse files
author
ssjia
committed
[ET-VK] buffer implementation of rotary positional embeddings
Pull Request resolved: #15620

Title says it all!

ghstack-source-id: 321258711
@exported-using-ghexport

Differential Revision: [D86340338](https://our.internmc.facebook.com/intern/diff/D86340338/)
1 parent 47fde1e commit 2eb1352

File tree

7 files changed

+139
-48
lines changed

7 files changed

+139
-48
lines changed

backends/vulkan/_passes/tag_memory_meta_pass.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -138,7 +138,7 @@ def __init__(
138138

139139
# Magic number to limit "lookahead" when tracing through users of an operator
140140
# to constrain the representation of its arguments/outputs.
141-
self.max_trace_search_depth = 20
141+
self.max_trace_search_depth = None
142142

143143
def is_valid_op_node(self, node: Any) -> bool:
144144
"""

backends/vulkan/op_registry.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -692,7 +692,7 @@ def register_sdpa_ops():
692692
@update_features(exir_ops.edge.et_vk.apply_rotary_emb.default)
693693
def register_rotary_emb_op():
694694
return OpFeatures(
695-
inputs_storage=utils.WIDTH_PACKED_TEXTURE,
695+
inputs_storage=utils.CONTIGUOUS_ANY,
696696
supports_resize=True,
697697
)
698698

backends/vulkan/runtime/graph/ops/glsl/embedding_texture.glsl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -38,8 +38,8 @@ int load_embedding_idx(const TensorIndex4D out_tidx) {
3838
indices_tidx.data.xyz = out_tidx.data.yzw;
3939
indices_tidx.data.w = 0;
4040

41-
TextureElementIndex elem_pos = tensor_idx_to_texture_element_idx_simple(
42-
indices_tidx, indices);
41+
TextureElementIndex elem_pos = tensor4d_idx_to_texture_element_idx_simple(
42+
indices, indices_tidx);
4343

4444
const ivec4 in_texel = texelFetch(t_indices, elem_pos.pos, 0);
4545
return in_texel[elem_pos.comp];
@@ -61,7 +61,7 @@ void main() {
6161
return;
6262
}
6363

64-
TensorIndex4D out_tidx = texture_pos_to_tensor_idx_simple(out_pos, outp);
64+
TensorIndex4D out_tidx = texture_pos_to_tensor4d_idx_simple(outp, out_pos);
6565
const int embedding_idx = load_embedding_idx(out_tidx);
6666

6767
const VEC4_T weight_texel = load_weight_texel(embedding_idx, out_tidx.data.x);

backends/vulkan/runtime/graph/ops/glsl/indexing.glslh

Lines changed: 40 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -147,6 +147,20 @@ struct TensorIndex4D {
147147
ivec4 data;
148148
};
149149

150+
TensorIndex4D zero_tensor4d_idx() {
151+
TensorIndex4D tidx;
152+
tidx.data = ivec4(0);
153+
return tidx;
154+
}
155+
156+
bool out_of_bounds(const TensorIndex4D tidx, const BufferMetadata meta) {
157+
return any(greaterThanEqual(tidx.data, meta.sizes[0]));
158+
}
159+
160+
bool out_of_bounds(const TensorIndex4D tidx, const TextureMetadata meta) {
161+
return any(greaterThanEqual(tidx.data, meta.sizes));
162+
}
163+
150164
//
151165
// TextureElementIndex
152166
//
@@ -245,15 +259,9 @@ void clamp_tensor_idx(const BufferMetadata meta, inout TensorIndex tidx) {
245259
tidx.data[1] = min(tidx.data[1], meta.sizes[1] - 1);
246260
}
247261

248-
TensorIndex4D zero_tensor4d_idx() {
249-
TensorIndex4D tidx;
250-
tidx.data = ivec4(0);
251-
return tidx;
252-
}
253-
254262
// Does not account for axis mapping or batches
255-
TensorIndex4D texture_pos_to_tensor_idx_simple(
256-
const ivec3 pos, const TextureMetadata meta) {
263+
TensorIndex4D texture_pos_to_tensor4d_idx_simple(
264+
const TextureMetadata meta, const ivec3 pos) {
257265
TensorIndex4D tidx;
258266
tidx.data.xyz = pos;
259267
tidx.data.w = 0;
@@ -262,8 +270,20 @@ TensorIndex4D texture_pos_to_tensor_idx_simple(
262270
}
263271

264272
// Does not account for axis mapping or batches
265-
TextureElementIndex tensor_idx_to_texture_element_idx_simple(
266-
const TensorIndex4D tidx, const TextureMetadata meta) {
273+
ivec3 tensor4d_idx_to_texel_pos_simple(
274+
const TextureMetadata meta, const TensorIndex4D tidx) {
275+
ivec3 texel_pos;
276+
277+
const int packed_dim_idx = tidx.data[meta.packed_dim];
278+
279+
texel_pos = tidx.data.xyz;
280+
texel_pos[meta.packed_dim] = div_4(packed_dim_idx);
281+
return texel_pos;
282+
}
283+
284+
// Does not account for axis mapping or batches
285+
TextureElementIndex tensor4d_idx_to_texture_element_idx_simple(
286+
const TextureMetadata meta, const TensorIndex4D tidx) {
267287
const int packed_dim_idx = tidx.data[meta.packed_dim];
268288
TextureElementIndex tex_idx;
269289
tex_idx.pos = tidx.data.xyz;
@@ -272,6 +292,16 @@ TextureElementIndex tensor_idx_to_texture_element_idx_simple(
272292
return tex_idx;
273293
}
274294

295+
uint tensor4d_idx_to_linear_idx(
296+
const BufferMetadata meta,
297+
const TensorIndex4D tidx) {
298+
uint lin_idx = 0;
299+
for (int d = 0; d < 4; ++d) {
300+
lin_idx += meta.strides[0][d] * tidx.data[d];
301+
}
302+
return lin_idx;
303+
}
304+
275305
//
276306
// Debug utilities
277307
//

backends/vulkan/runtime/graph/ops/glsl/rotary_embedding.glsl

Lines changed: 74 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -13,23 +13,29 @@
1313
#define VEC4_T ${texel_load_type(DTYPE, STORAGE)}
1414

1515
${define_required_extensions(DTYPE)}
16+
${define_active_storage_type(STORAGE)}
1617

1718
layout(std430) buffer;
1819

19-
${layout_declare_tensor(B, "w", "xqout", DTYPE, STORAGE)}
20-
${layout_declare_tensor(B, "w", "xkout", DTYPE, STORAGE)}
21-
${layout_declare_tensor(B, "r", "xq", DTYPE, STORAGE)}
22-
${layout_declare_tensor(B, "r", "xk", DTYPE, STORAGE)}
23-
${layout_declare_tensor(B, "r", "freqs_cos", DTYPE, STORAGE)}
24-
${layout_declare_tensor(B, "r", "freqs_sin", DTYPE, STORAGE)}
25-
${layout_declare_ubo(B, "ivec3", "xqout_limits")}
26-
${layout_declare_ubo(B, "ivec3", "xkout_limits")}
20+
#include "indexing.glslh"
2721

28-
layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
22+
${layout_declare_tensor(B, "w", "t_xqout", DTYPE, STORAGE, is_scalar_array=False)}
23+
${layout_declare_tensor(B, "w", "t_xkout", DTYPE, STORAGE, is_scalar_array=False)}
24+
${layout_declare_tensor(B, "r", "t_xq", DTYPE, STORAGE, is_scalar_array=False)}
25+
${layout_declare_tensor(B, "r", "t_xk", DTYPE, STORAGE, is_scalar_array=False)}
26+
${layout_declare_tensor(B, "r", "t_freqs_cos", DTYPE, STORAGE, is_scalar_array=False)}
27+
${layout_declare_tensor(B, "r", "t_freqs_sin", DTYPE, STORAGE, is_scalar_array=False)}
2928

30-
layout(constant_id = 3) const int packed_dim = 0;
29+
$if STORAGE == "buffer":
30+
${layout_declare_ubo(B, "BufferMetadata", "xqout")}
31+
${layout_declare_ubo(B, "BufferMetadata", "xkout")}
32+
${layout_declare_ubo(B, "BufferMetadata", "freqs_cos")}
33+
$else:
34+
${layout_declare_ubo(B, "TextureMetadata", "xqout")}
35+
${layout_declare_ubo(B, "TextureMetadata", "xkout")}
36+
${layout_declare_ubo(B, "TextureMetadata", "freqs_cos")}
3137

32-
#include "indexing_utils.h"
38+
layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
3339

3440
/*
3541
* This shader computes rotary positional embeddings which are used in the Llama
@@ -39,7 +45,7 @@ layout(constant_id = 3) const int packed_dim = 0;
3945
* 1. xq (batch_size, sequence_len, num_heads, head_dim)
4046
* 2. xk (batch_size, sequence_len, num_kv_heads, head_dim)
4147
* 3. freqs_cos (sequence_len, head_dim / 2)
42-
* 4. freqs_cos (sequence_len, head_dim / 2)
48+
* 4. freqs_sin (sequence_len, head_dim / 2)
4349
*
4450
* Two output tensors are produced, with the same shapes as xq and xk
4551
* respectively.
@@ -66,23 +72,43 @@ void main() {
6672
// Each thread will write to two output locations to maximize data re-use.
6773
// One texel loaded from the freqs_cos/freqs_sin tensors can be used to
6874
// calculate two output texels.
69-
const ivec3 x_pos_1 = ivec3(
70-
gl_GlobalInvocationID.x * 2, gl_GlobalInvocationID.yz);
71-
const ivec3 x_pos_2 = ivec3(x_pos_1.x + 1, x_pos_1.yz);
75+
TensorIndex4D out_tidx_1 = zero_tensor4d_idx();
76+
out_tidx_1.data.x = int(gl_GlobalInvocationID.x) * 8;
77+
out_tidx_1.data.yz = ivec2(gl_GlobalInvocationID.yz);
78+
79+
TensorIndex4D out_tidx_2 = out_tidx_1;
80+
out_tidx_2.data.x += 4;
7281

73-
if (any(greaterThanEqual(x_pos_2, xqout_limits))) {
82+
if (out_of_bounds(out_tidx_2, xqout)) {
7483
return;
7584
}
7685

77-
const ivec3 freqs_pos = ivec3(gl_GlobalInvocationID.xz, 0);
86+
TensorIndex4D freqs_tidx = zero_tensor4d_idx();
87+
freqs_tidx.data.x = int(gl_GlobalInvocationID.x) * 4;
88+
freqs_tidx.data.y = out_tidx_1.data.z;
7889

79-
VEC4_T cos_tex = load_texel(freqs_cos, freqs_pos);
80-
VEC4_T sin_tex = load_texel(freqs_sin, freqs_pos);
90+
#ifdef USING_BUFFER
91+
const uint freqs_texel_bufi = div_4(tensor4d_idx_to_linear_idx(freqs_cos, freqs_tidx));
92+
VEC4_T cos_tex = t_freqs_cos[freqs_texel_bufi];
93+
VEC4_T sin_tex = t_freqs_sin[freqs_texel_bufi];
8194

82-
// Compute xqout
95+
uint x_texel_bufi_1 = div_4(tensor4d_idx_to_linear_idx(xqout, out_tidx_1));
96+
uint x_texel_bufi_2 = div_4(tensor4d_idx_to_linear_idx(xqout, out_tidx_2));
97+
VEC4_T x_tex_1 = t_xq[x_texel_bufi_1];
98+
VEC4_T x_tex_2 = t_xq[x_texel_bufi_2];
99+
100+
#else // USING_TEXTURE
101+
const ivec3 freqs_pos = tensor4d_idx_to_texel_pos_simple(freqs_cos, freqs_tidx);
102+
VEC4_T cos_tex = texelFetch(t_freqs_cos, freqs_pos, 0);
103+
VEC4_T sin_tex = texelFetch(t_freqs_sin, freqs_pos, 0);
83104

84-
VEC4_T x_tex_1 = load_texel(xq, x_pos_1);
85-
VEC4_T x_tex_2 = load_texel(xq, x_pos_2);
105+
const ivec3 x_pos_1 = tensor4d_idx_to_texel_pos_simple(xqout, out_tidx_1);
106+
const ivec3 x_pos_2 = tensor4d_idx_to_texel_pos_simple(xqout, out_tidx_2);
107+
VEC4_T x_tex_1 = texelFetch(t_xq, x_pos_1, 0);
108+
VEC4_T x_tex_2 = texelFetch(t_xq, x_pos_2, 0);
109+
#endif
110+
111+
// Compute xqout
86112

87113
// Separate into even and odd elements
88114
VEC4_T x_r = VEC4_T(x_tex_1.xz, x_tex_2.xz);
@@ -94,20 +120,34 @@ void main() {
94120
VEC4_T xout_tex_1 = VEC4_T(xout_r.x, xout_i.x, xout_r.y, xout_i.y);
95121
VEC4_T xout_tex_2 = VEC4_T(xout_r.z, xout_i.z, xout_r.w, xout_i.w);
96122

97-
write_texel(xqout, x_pos_1, xout_tex_1);
98-
write_texel(xqout, x_pos_2, xout_tex_2);
123+
#ifdef USING_BUFFER
124+
t_xqout[x_texel_bufi_1] = xout_tex_1;
125+
t_xqout[x_texel_bufi_2] = xout_tex_2;
126+
#else // USING_TEXTURE
127+
imageStore(t_xqout, x_pos_1, xout_tex_1);
128+
imageStore(t_xqout, x_pos_2, xout_tex_2);
129+
#endif
99130

100131
// n_heads will be greater than or equal to n_kv_heads, therefore xq and xqout
101132
// may have a larger height dim than xk and xkout. Only compute xkout if this
102133
// invocation is still within bounds.
103-
if (any(greaterThanEqual(x_pos_2, xkout_limits))) {
134+
if (out_of_bounds(out_tidx_2, xkout)) {
104135
return;
105136
}
106137

107138
// Compute xkout
108139

109-
x_tex_1 = load_texel(xk, x_pos_1);
110-
x_tex_2 = load_texel(xk, x_pos_2);
140+
#ifdef USING_BUFFER
141+
x_texel_bufi_1 = div_4(tensor4d_idx_to_linear_idx(xkout, out_tidx_1));
142+
x_texel_bufi_2 = div_4(tensor4d_idx_to_linear_idx(xkout, out_tidx_2));
143+
144+
x_tex_1 = t_xk[x_texel_bufi_1];
145+
x_tex_2 = t_xk[x_texel_bufi_2];
146+
147+
#else // USING_TEXTURE
148+
x_tex_1 = texelFetch(t_xk, x_pos_1, 0);
149+
x_tex_2 = texelFetch(t_xk, x_pos_2, 0);
150+
#endif
111151

112152
x_r = VEC4_T(x_tex_1.xz, x_tex_2.xz);
113153
x_i = VEC4_T(x_tex_1.yw, x_tex_2.yw);
@@ -118,6 +158,11 @@ void main() {
118158
xout_tex_1 = VEC4_T(xout_r.x, xout_i.x, xout_r.y, xout_i.y);
119159
xout_tex_2 = VEC4_T(xout_r.z, xout_i.z, xout_r.w, xout_i.w);
120160

121-
write_texel(xkout, x_pos_1, xout_tex_1);
122-
write_texel(xkout, x_pos_2, xout_tex_2);
161+
#ifdef USING_BUFFER
162+
t_xkout[x_texel_bufi_1] = xout_tex_1;
163+
t_xkout[x_texel_bufi_2] = xout_tex_2;
164+
#else // USING_TEXTURE
165+
imageStore(t_xkout, x_pos_1, xout_tex_1);
166+
imageStore(t_xkout, x_pos_2, xout_tex_2);
167+
#endif
123168
}

backends/vulkan/runtime/graph/ops/glsl/rotary_embedding.yaml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,9 @@ rotary_embedding:
33
DTYPE: float
44
STORAGE: texture3d
55
generate_variant_forall:
6+
STORAGE:
7+
- VALUE: texture3d
8+
- VALUE: buffer
69
DTYPE:
710
- VALUE: half
811
- VALUE: float

backends/vulkan/runtime/graph/ops/impl/RotaryEmbedding.cpp

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -43,10 +43,17 @@ utils::uvec3 rotary_embedding_global_wg_size(
4343

4444
const ValueRef xq_out = args.at(0).refs.at(0);
4545

46-
utils::uvec3 global_wg_size = graph->logical_limits_of(xq_out);
47-
global_wg_size[0] /= 2;
46+
// Head dim texel size
47+
const uint32_t D4 = utils::div_up_4(graph->size_at<uint32_t>(-1, xq_out));
48+
// Divide by 2 since each invocation computes 2 output locations
49+
const uint32_t D8 = utils::div_up(D4, uint32_t(2));
4850

49-
return global_wg_size;
51+
// Number of query heads
52+
const uint32_t QH = graph->size_at<uint32_t>(-2, xq_out);
53+
// Input tokens sequence length
54+
const uint32_t S = graph->size_at<uint32_t>(-3, xq_out);
55+
56+
return {D8, QH, S};
5057
}
5158

5259
void add_rotary_embedding_node(
@@ -73,8 +80,14 @@ void add_rotary_embedding_node(
7380
VK_CHECK_COND(graph.has_standard_axis_map(freqs_sin));
7481

7582
std::string kernel_name = "rotary_embedding";
83+
add_storage_type_suffix(kernel_name, graph.storage_type_of(xq_out));
7684
add_dtype_suffix(kernel_name, graph.dtype_of(xq_out));
7785

86+
vkapi::ParamsBindList param_ubos = {
87+
graph.meta_ubo(xq_out),
88+
graph.meta_ubo(xk_out),
89+
graph.meta_ubo(freqs_cos)};
90+
7891
graph.execute_nodes().emplace_back(new DynamicDispatchNode(
7992
graph,
8093
VK_KERNEL_FROM_STR(kernel_name),
@@ -84,7 +97,7 @@ void add_rotary_embedding_node(
8497
{{{xq_out, xk_out}, vkapi::kWrite},
8598
{{xq, xk, freqs_cos, freqs_sin}, vkapi::kRead}},
8699
// Parameter buffers
87-
{graph.logical_limits_ubo(xq_out), graph.logical_limits_ubo(xk_out)},
100+
param_ubos,
88101
// Push Constants
89102
{},
90103
// Specialization Constants

0 commit comments

Comments
 (0)