77#include < iostream>
88#include < sycl/sycl.hpp>
99
10- // using joint_matrix = sycl::ext::oneapi::experimental::matrix;
1110using use = sycl::ext::oneapi::experimental::matrix::use;
1211using layout = sycl::ext::oneapi::experimental::matrix::layout;
1312using bfloat16 = sycl::ext::oneapi::bfloat16;
1413
// Sub-group size the kernel is compiled for (passed to
// intel::reqd_sub_group_size).
constexpr size_t SG_SZ = 16;

// joint_matrix tile dimensions: the accumulator tile is TM x TN and the
// B tile is TK x TN.
constexpr size_t TM = 8;
constexpr size_t TN = SG_SZ;
constexpr size_t TK = 16;

// Factor every result element is multiplied by, both in the kernel's
// joint_matrix_apply step and in the host reference.
constexpr float ALPHA = 2.0f;

// Largest absolute difference (2^-7) tolerated between the device result
// and the host reference.
constexpr float BF16_EPSILON = 0.00781250f;
2223
2324template <typename T, size_t NUM_ROWS, size_t NUM_COLS> struct big_matrix {
2425private:
@@ -42,10 +43,9 @@ void matrix_multiply(big_matrix<T1, M, N> &C, big_matrix<T2, M, K> &A,
4243
4344 sycl::queue q;
4445 q.submit ([&](sycl::handler &cgh) {
45- sycl::accessor accC (bufC, cgh, sycl::read_write, sycl::no_init );
46+ sycl::accessor accC (bufC, cgh, sycl::read_write);
4647 sycl::accessor accA (bufA, cgh, sycl::read_only);
4748 sycl::accessor accB (bufB, cgh, sycl::read_only);
48-
4949 cgh.parallel_for (
5050 sycl::nd_range<2 >({NDRangeM, NDRangeN * SG_SZ}, {1 , 1 * SG_SZ}),
5151 [=](sycl::nd_item<2 > spmd_item) [[intel::reqd_sub_group_size (SG_SZ)]]
@@ -66,30 +66,32 @@ void matrix_multiply(big_matrix<T1, M, N> &C, big_matrix<T2, M, K> &A,
6666 // For B, we assume B has been already VNNIed.
6767 sycl::ext::oneapi::experimental::matrix::joint_matrix<
6868 sycl::sub_group, bfloat16, use::b, TK, TN,
69- sycl::ext::intel::experimental::matrix:: layout::packed >
69+ layout::ext_intel_packed >
7070 sub_b;
7171 sycl::ext::oneapi::experimental::matrix::joint_matrix<
7272 sycl::sub_group, float , use::accumulator, TM, TN>
7373 sub_c;
7474
75- joint_matrix_load (sg, sub_c,
76- accC.get_pointer () + (sg_startx * TM) * N +
77- sg_starty / SG_SZ * TN,
78- N, layout::row_major);
79- for (int k = 0 ; k < K / TK; k += 1 ) { //
75+ joint_matrix_fill (sg, sub_c, 1.0 );
76+ for (int k = 0 ; k < K / TK; k += 1 ) {
8077 joint_matrix_load (
81- sg, sub_a, accA.get_pointer () + (sg_startx * TM) * K + k * TK,
78+ sg, sub_a,
79+ accA.template get_multi_ptr <sycl::access::decorated::no>() +
80+ (sg_startx * TM) * K + k * TK,
8281 K);
83- joint_matrix_load (sg, sub_b,
84- accB.get_pointer () + (k * TK / 2 ) * (N * 2 ) +
85- sg_starty / SG_SZ * TN * 2 ,
86- N * 2 );
87- sub_c = joint_matrix_mad (sg, sub_a, sub_b, sub_c);
82+ joint_matrix_load (
83+ sg, sub_b,
84+ accB.template get_multi_ptr <sycl::access::decorated::no>() +
85+ (k * TK / 2 ) * (N * 2 ) + sg_starty / SG_SZ * TN * 2 ,
86+ N * 2 );
87+ joint_matrix_mad (sg, sub_c, sub_a, sub_b, sub_c);
8888 }
89- joint_matrix_store (sg, sub_c,
90- accC.get_pointer () + (sg_startx * TM) * N +
91- sg_starty / SG_SZ * TN,
92- N, layout::row_major);
89+ joint_matrix_apply (sg, sub_c, [=](float &x) { x *= ALPHA; });
90+ joint_matrix_store (
91+ sg, sub_c,
92+ accC.template get_multi_ptr <sycl::access::decorated::no>() +
93+ (sg_startx * TM) * N + sg_starty / SG_SZ * TN,
94+ N, layout::row_major);
9395 }); // parallel for
9496 }).wait ();
9597 // kernel end
@@ -100,53 +102,43 @@ static constexpr size_t MATRIX_N = TN * 2;
100102static constexpr size_t MATRIX_K = TK * 2 ;
101103bfloat16 A[MATRIX_M][MATRIX_K];
102104bfloat16 B[MATRIX_K / 2 ][MATRIX_N * 2 ];
103- unsigned short Aref[MATRIX_M][MATRIX_K];
104- unsigned short Bref[MATRIX_K / 2 ][MATRIX_N * 2 ];
105105float C[MATRIX_M][MATRIX_N];
106106float D[MATRIX_M][MATRIX_N];
107107
108- float make_fp32 (short x) {
109- unsigned int y = x ;
108+ float make_fp32 (bfloat16 x) {
109+ unsigned int y = *(( int *)&x) ;
110110 y = y << 16 ;
111111 float *res = reinterpret_cast <float *>(&y);
112112 return *res;
113113}
114114
115- unsigned short make_bf16 (float x) {
116- int *res = reinterpret_cast <int *>(&x);
117- *res = *res >> 16 ;
118- return (unsigned short )*res;
119- }
120-
121115void matrix_multiply_ref (int *A_mem, int *B_mem, int *C_mem, int M, int N,
122116 int K) {
123117 for (int m = 0 ; m < M; m++)
124118 for (int n = 0 ; n < N; n++) {
125119 for (int k = 0 ; k < K; k++) {
126- short *va = (short *)(A_mem + m * K + k);
127- short *vb = (short *)(B_mem + k * N + n);
120+ // Because B was assumed VNNIed
121+ bfloat16 *va = (bfloat16 *)(A_mem + m * K + k);
122+ bfloat16 *vb = (bfloat16 *)(B_mem + k * N + n);
128123 float acc = *((float *)(C_mem + m * N + n));
129124 for (int i = 0 ; i < 2 ; i++) {
130125 acc += (make_fp32 (va[i]) * make_fp32 (vb[i]));
131126 }
132127 *((float *)(C_mem + m * N + n)) = acc;
133128 }
129+ *((float *)(C_mem + m * N + n)) *= ALPHA;
134130 }
135131}
136132
137133int main () {
138134 for (int i = 0 ; i < MATRIX_M; i++) {
139135 for (int j = 0 ; j < MATRIX_K; j++) {
140- // bfloat16 is created using unsigned short since conversion from float to
141- // bfloat16 is not supported on the host side yet
142136 A[i][j] = bfloat16 (1 .0f * (i + j));
143- Aref[i][j] = make_bf16 (1 .0f * (i + j));
144137 }
145138 }
146139 for (int i = 0 ; i < MATRIX_K / 2 ; i++) {
147140 for (int j = 0 ; j < MATRIX_N * 2 ; j++) {
148141 B[i][j] = bfloat16 (2 .0f * i + 3 .0f * j);
149- Bref[i][j] = make_bf16 (2 .0f * i + 3 .0f * j);
150142 }
151143 }
152144 for (int i = 0 ; i < MATRIX_M; i++) {
@@ -161,13 +153,13 @@ int main() {
161153 big_matrix<bfloat16, MATRIX_M, MATRIX_K> MA ((bfloat16 *)&A);
162154 big_matrix<bfloat16, MATRIX_K / 2 , MATRIX_N * 2 > MB ((bfloat16 *)&B);
163155 matrix_multiply (MC, MA, MB);
164- matrix_multiply_ref ((int32_t *)Aref , (int32_t *)Bref , (int32_t *)D, MATRIX_M,
156+ matrix_multiply_ref ((int32_t *)A , (int32_t *)B , (int32_t *)D, MATRIX_M,
165157 MATRIX_N, MATRIX_K / 2 );
166158
167159 bool res = true ;
168160 for (int i = 0 ; i < MATRIX_M; i++) {
169161 for (int j = 0 ; j < MATRIX_N; j++) {
170- if ((fabs (C[i][j]) - fabs ( D[i][j])) > BF16_EPSILON)
162+ if ((fabs (C[i][j] - D[i][j])) > BF16_EPSILON)
171163 res = false ;
172164 }
173165 }
0 commit comments