1313#define VEC4_T ${texel_load_type(DTYPE, STORAGE)}
1414
1515${define_required_extensions(DTYPE)}
16+ ${define_active_storage_type(STORAGE)}
1617
1718layout (std430) buffer ;
1819
19- ${layout_declare_tensor(B, "w", "xqout", DTYPE, STORAGE)}
20- ${layout_declare_tensor(B, "w", "xkout", DTYPE, STORAGE)}
21- ${layout_declare_tensor(B, "r", "xq", DTYPE, STORAGE)}
22- ${layout_declare_tensor(B, "r", "xk", DTYPE, STORAGE)}
23- ${layout_declare_tensor(B, "r", "freqs_cos", DTYPE, STORAGE)}
24- ${layout_declare_tensor(B, "r", "freqs_sin", DTYPE, STORAGE)}
25- ${layout_declare_ubo(B, "ivec3 ", "xqout_limits")}
26- ${layout_declare_ubo(B, "ivec3 ", "xkout_limits")}
20+ #include "indexing.glslh"
2721
28- layout (local_size_x_id = 0 , local_size_y_id = 1 , local_size_z_id = 2 ) in ;
22+ ${layout_declare_tensor(B, "w", "t_xqout", DTYPE, STORAGE, is_scalar_array= False)}
23+ ${layout_declare_tensor(B, "w", "t_xkout", DTYPE, STORAGE, is_scalar_array= False)}
24+ ${layout_declare_tensor(B, "r", "t_xq", DTYPE, STORAGE, is_scalar_array= False)}
25+ ${layout_declare_tensor(B, "r", "t_xk", DTYPE, STORAGE, is_scalar_array= False)}
26+ ${layout_declare_tensor(B, "r", "t_freqs_cos", DTYPE, STORAGE, is_scalar_array= False)}
27+ ${layout_declare_tensor(B, "r", "t_freqs_sin", DTYPE, STORAGE, is_scalar_array= False)}
2928
30- layout (constant_id = 3 ) const int packed_dim = 0 ;
29+ $if STORAGE == "buffer ":
30+ ${layout_declare_ubo(B, "BufferMetadata", "xqout")}
31+ ${layout_declare_ubo(B, "BufferMetadata", "xkout")}
32+ ${layout_declare_ubo(B, "BufferMetadata", "freqs_cos")}
33+ $else :
34+ ${layout_declare_ubo(B, "TextureMetadata", "xqout")}
35+ ${layout_declare_ubo(B, "TextureMetadata", "xkout")}
36+ ${layout_declare_ubo(B, "TextureMetadata", "freqs_cos")}
3137
32- #include "indexing_utils.h"
38+ layout (local_size_x_id = 0 , local_size_y_id = 1 , local_size_z_id = 2 ) in ;
3339
3440/*
3541 * This shader computes rotary positional embeddings which are used in the Llama
@@ -39,7 +45,7 @@ layout(constant_id = 3) const int packed_dim = 0;
3945 * 1. xq (batch_size, sequence_len, num_heads, head_dim)
4046 * 2. xk (batch_size, sequence_len, num_kv_heads, head_dim)
4147 * 3. freqs_cos (sequence_len, head_dim / 2)
42- * 4. freqs_cos (sequence_len, head_dim / 2)
48+ * 4. freqs_sin (sequence_len, head_dim / 2)
4349 *
4450 * Two output tensors are produced, with the same shapes as xq and xk
4551 * respectively.
@@ -66,23 +72,43 @@ void main() {
6672 // Each thread will write to two output locations to maximize data re-use.
6773 // One texel loaded from the freqs_cos/freqs_sin tensors can be used to
6874 // calculate two output texels.
69- const ivec3 x_pos_1 = ivec3 (
70- gl_GlobalInvocationID.x * 2 , gl_GlobalInvocationID.yz);
71- const ivec3 x_pos_2 = ivec3 (x_pos_1.x + 1 , x_pos_1.yz);
75+ TensorIndex4D out_tidx_1 = zero_tensor4d_idx();
76+ out_tidx_1.data.x = int (gl_GlobalInvocationID.x) * 8 ;
77+ out_tidx_1.data.yz = ivec2 (gl_GlobalInvocationID.yz);
78+
79+ TensorIndex4D out_tidx_2 = out_tidx_1;
80+ out_tidx_2.data.x += 4 ;
7281
73- if (any ( greaterThanEqual (x_pos_2, xqout_limits) )) {
82+ if (out_of_bounds(out_tidx_2, xqout )) {
7483 return ;
7584 }
7685
77- const ivec3 freqs_pos = ivec3 (gl_GlobalInvocationID.xz, 0 );
86+ TensorIndex4D freqs_tidx = zero_tensor4d_idx();
87+ freqs_tidx.data.x = int (gl_GlobalInvocationID.x) * 4 ;
88+ freqs_tidx.data.y = out_tidx_1.data.z;
7889
79- VEC4_T cos_tex = load_texel(freqs_cos, freqs_pos);
80- VEC4_T sin_tex = load_texel(freqs_sin, freqs_pos);
90+ #ifdef USING_BUFFER
91+ const uint freqs_texel_bufi = div_4(tensor4d_idx_to_linear_idx(freqs_cos, freqs_tidx));
92+ VEC4_T cos_tex = t_freqs_cos[freqs_texel_bufi];
93+ VEC4_T sin_tex = t_freqs_sin[freqs_texel_bufi];
8194
82- // Compute xqout
95+ uint x_texel_bufi_1 = div_4(tensor4d_idx_to_linear_idx(xqout, out_tidx_1));
96+ uint x_texel_bufi_2 = div_4(tensor4d_idx_to_linear_idx(xqout, out_tidx_2));
97+ VEC4_T x_tex_1 = t_xq[x_texel_bufi_1];
98+ VEC4_T x_tex_2 = t_xq[x_texel_bufi_2];
99+
100+ #else // USING_TEXTURE
101+ const ivec3 freqs_pos = tensor4d_idx_to_texel_pos_simple(freqs_cos, freqs_tidx);
102+ VEC4_T cos_tex = texelFetch(t_freqs_cos, freqs_pos, 0 );
103+ VEC4_T sin_tex = texelFetch(t_freqs_sin, freqs_pos, 0 );
83104
84- VEC4_T x_tex_1 = load_texel(xq, x_pos_1);
85- VEC4_T x_tex_2 = load_texel(xq, x_pos_2);
105+ const ivec3 x_pos_1 = tensor4d_idx_to_texel_pos_simple(xqout, out_tidx_1);
106+ const ivec3 x_pos_2 = tensor4d_idx_to_texel_pos_simple(xqout, out_tidx_2);
107+ VEC4_T x_tex_1 = texelFetch(t_xq, x_pos_1, 0 );
108+ VEC4_T x_tex_2 = texelFetch(t_xq, x_pos_2, 0 );
109+ #endif
110+
111+ // Compute xqout
86112
87113 // Separate into even and odd elements
88114 VEC4_T x_r = VEC4_T(x_tex_1.xz, x_tex_2.xz);
@@ -94,20 +120,34 @@ void main() {
94120 VEC4_T xout_tex_1 = VEC4_T(xout_r.x, xout_i.x, xout_r.y, xout_i.y);
95121 VEC4_T xout_tex_2 = VEC4_T(xout_r.z, xout_i.z, xout_r.w, xout_i.w);
96122
97- write_texel(xqout, x_pos_1, xout_tex_1);
98- write_texel(xqout, x_pos_2, xout_tex_2);
123+ #ifdef USING_BUFFER
124+ t_xqout[x_texel_bufi_1] = xout_tex_1;
125+ t_xqout[x_texel_bufi_2] = xout_tex_2;
126+ #else // USING_TEXTURE
127+ imageStore(t_xqout, x_pos_1, xout_tex_1);
128+ imageStore(t_xqout, x_pos_2, xout_tex_2);
129+ #endif
99130
100131 // n_heads will be greater than or equal to n_kv_heads, therefore xq and xqout
101132 // may have a larger height dim than xk and xkout. Only compute xkout if this
102133 // invocation is still within bounds.
103- if (any ( greaterThanEqual (x_pos_2, xkout_limits) )) {
134+ if (out_of_bounds(out_tidx_2, xkout )) {
104135 return ;
105136 }
106137
107138 // Compute xkout
108139
109- x_tex_1 = load_texel(xk, x_pos_1);
110- x_tex_2 = load_texel(xk, x_pos_2);
140+ #ifdef USING_BUFFER
141+ x_texel_bufi_1 = div_4(tensor4d_idx_to_linear_idx(xkout, out_tidx_1));
142+ x_texel_bufi_2 = div_4(tensor4d_idx_to_linear_idx(xkout, out_tidx_2));
143+
144+ x_tex_1 = t_xk[x_texel_bufi_1];
145+ x_tex_2 = t_xk[x_texel_bufi_2];
146+
147+ #else // USING_TEXTURE
148+ x_tex_1 = texelFetch(t_xk, x_pos_1, 0 );
149+ x_tex_2 = texelFetch(t_xk, x_pos_2, 0 );
150+ #endif
111151
112152 x_r = VEC4_T(x_tex_1.xz, x_tex_2.xz);
113153 x_i = VEC4_T(x_tex_1.yw, x_tex_2.yw);
@@ -118,6 +158,11 @@ void main() {
118158 xout_tex_1 = VEC4_T(xout_r.x, xout_i.x, xout_r.y, xout_i.y);
119159 xout_tex_2 = VEC4_T(xout_r.z, xout_i.z, xout_r.w, xout_i.w);
120160
121- write_texel(xkout, x_pos_1, xout_tex_1);
122- write_texel(xkout, x_pos_2, xout_tex_2);
161+ #ifdef USING_BUFFER
162+ t_xkout[x_texel_bufi_1] = xout_tex_1;
163+ t_xkout[x_texel_bufi_2] = xout_tex_2;
164+ #else // USING_TEXTURE
165+ imageStore(t_xkout, x_pos_1, xout_tex_1);
166+ imageStore(t_xkout, x_pos_2, xout_tex_2);
167+ #endif
123168}
0 commit comments