44#:set C_KINDS_TYPES = list(zip(CMPLX_KINDS, CMPLX_TYPES, CMPLX_SUFFIX))
55
66submodule (stdlib_intrinsics) stdlib_intrinsics_matmul
7+ use stdlib_linalg_blas, only: gemm
8+ use stdlib_constants
79 implicit none
810
911contains
@@ -36,38 +38,84 @@ contains
3638 end do
3739 end function matmul_chain_order
3840
39- #:for k, t, s in I_KINDS_TYPES + R_KINDS_TYPES + C_KINDS_TYPES
41+ #:for k, t, s in R_KINDS_TYPES + C_KINDS_TYPES
4042
pure function matmul_chain_mult_${s}$_3 (m1, m2, m3, start, s, p) result(r)
    !! Multiply the three matrices m1*m2*m3 with two gemm calls, using the
    !! parenthesization selected by the split table `s`.
    !!
    !! Arguments:
    !!   m1, m2, m3 : the three factors, in chain order.
    !!   start      : position of m1 inside the dimension vector `p`.
    !!   s          : split table; s(start, start + 2) selects the grouping:
    !!                  == start     -> m1*(m2*m3)
    !!                  == start + 1 -> (m1*m2)*m3
    !!   p          : chain dimensions; the gemm calls below take factor i
    !!                (i = 1, 2, 3) to have shape p(start+i-1) x p(start+i).
    !!
    !! Result:
    !!   r : the product, allocated with shape p(start) x p(start + 3).
    !!
    !! Any other value of s(start, start + 2) triggers `error stop`.
    ${t}$, intent(in) :: m1(:,:), m2(:,:), m3(:,:)
    integer, intent(in) :: start, s(:,2:), p(:)
    ${t}$, allocatable :: r(:,:), temp(:,:)
    integer :: ord, m, n, k

    ord = s(start, start + 2)
    allocate(r(p(start), p(start + 3)))

    if (ord == start) then
        ! m1*(m2*m3)
        ! temp = m2*m3 : (p(start+1) x p(start+2)) * (p(start+2) x p(start+3))
        m = p(start + 1)
        n = p(start + 3)
        k = p(start + 2)
        allocate(temp(m, n))
        call gemm('N', 'N', m, n, k, one_${s}$, m2, m, m3, k, zero_${s}$, temp, m)

        ! r = m1*temp : (p(start) x p(start+1)) * (p(start+1) x p(start+3))
        m = p(start)
        n = p(start + 3)
        k = p(start + 1)
        call gemm('N', 'N', m, n, k, one_${s}$, m1, m, temp, k, zero_${s}$, r, m)
    else if (ord == start + 1) then
        ! (m1*m2)*m3
        ! temp = m1*m2 : (p(start) x p(start+1)) * (p(start+1) x p(start+2))
        m = p(start)
        n = p(start + 2)
        k = p(start + 1)
        allocate(temp(m, n))
        call gemm('N', 'N', m, n, k, one_${s}$, m1, m, m2, k, zero_${s}$, temp, m)

        ! r = temp*m3 : (p(start) x p(start+2)) * (p(start+2) x p(start+3)).
        ! The contraction length (and the leading dimension of m3) is
        ! p(start + 2), the column count of temp -- not p(start + 1).
        m = p(start)
        n = p(start + 3)
        k = p(start + 2)
        call gemm('N', 'N', m, n, k, one_${s}$, temp, m, m3, k, zero_${s}$, r, m)
    else
        error stop "stdlib_matmul: error: unexpected s(i,j)"
    end if

end function matmul_chain_mult_${s}$_3
5778
58- pure function matmul_chain_mult_${s}$_4 (m1, m2, m3, m4, start, s) result(r)
79+ pure function matmul_chain_mult_${s}$_4 (m1, m2, m3, m4, start, s, p ) result(r)
5980 ${t}$, intent(in) :: m1(:,:), m2(:,:), m3(:,:), m4(:,:)
60- integer, intent(in) :: start, s(:,2:)
61- ${t}$, allocatable :: r(:,:)
62- integer :: tmp
63- tmp = s(start, start + 3)
64-
65- if (tmp == start) then
66- r = matmul(m1, matmul_chain_mult_${s}$_3(m2, m3, m4, start + 1, s))
67- else if (tmp == start + 1) then
68- r = matmul(matmul(m1, m2), matmul(m3, m4))
69- else if (tmp == start + 2) then
70- r = matmul(matmul_chain_mult_${s}$_3(m1, m2, m3, start, s), m4)
81+ integer, intent(in) :: start, s(:,2:), p(:)
82+ ${t}$, allocatable :: r(:,:), temp(:,:), temp1(:,:)
83+ integer :: ord, m, n, k
84+ ord = s(start, start + 3)
85+ allocate(r(p(start), p(start + 4)))
86+
87+ if (ord == start) then
88+ ! m1*(m2*m3*m4)
89+ temp = matmul_chain_mult_${s}$_3(m2, m3, m4, start + 1, s, p)
90+ m = p(start)
91+ n = p(start + 4)
92+ k = p(start + 1)
93+ call gemm('N', 'N', m, n, k, one_${s}$, m1, m, temp, k, zero_${s}$, r, m)
94+ else if (ord == start + 1) then
95+ ! (m1*m2)*(m3*m4)
96+ m = p(start)
97+ n = p(start + 2)
98+ k = p(start + 1)
99+ allocate(temp(m,n))
100+ call gemm('N', 'N', m, n, k, one_${s}$, m1, m, m2, k, zero_${s}$, temp, m)
101+
102+ m = p(start + 2)
103+ n = p(start + 4)
104+ k = p(start + 3)
105+ allocate(temp1(m,n))
106+ call gemm('N', 'N', m, n, k, one_${s}$, m3, m, m4, k, zero_${s}$, temp1, m)
107+
108+ m = p(start)
109+ n = p(start + 4)
110+ k = p(start + 2)
111+ call gemm('N', 'N', m, n, k, one_${s}$, temp, m, temp1, k, zero_${s}$, r, m)
112+ else if (ord == start + 2) then
113+ ! (m1*m2*m3)*m4
114+ temp = matmul_chain_mult_${s}$_3(m1, m2, m3, start, s, p)
115+ m = p(start)
116+ n = p(start + 4)
117+ k = p(start + 3)
118+ call gemm('N', 'N', m, n, k, one_${s}$, temp, m, m4, k, zero_${s}$, r, m)
71119 else
72120 error stop "stdlib_matmul: error: unexpected s(i,j)"
73121 end if
@@ -77,8 +125,8 @@ contains
77125 pure module function stdlib_matmul_${s}$ (m1, m2, m3, m4, m5) result(r)
78126 ${t}$, intent(in) :: m1(:,:), m2(:,:)
79127 ${t}$, intent(in), optional :: m3(:,:), m4(:,:), m5(:,:)
80- ${t}$, allocatable :: r(:,:)
81- integer :: p(6), num_present
128+ ${t}$, allocatable :: r(:,:), temp(:,:), temp1(:,:)
129+ integer :: p(6), num_present, m, n, k
82130 integer, allocatable :: s(:,:)
83131
84132 p(1) = size(m1, 1)
@@ -102,8 +150,13 @@ contains
102150 num_present = num_present + 1
103151 end if
104152
153+ allocate(r(p(1), p(num_present + 1)))
154+
105155 if (num_present == 2) then
106- r = matmul(m1, m2)
156+ m = p(1)
157+ n = p(3)
158+ k = p(2)
159+ call gemm('N', 'N', m, n, k, one_${s}$, m1, m, m2, k, zero_${s}$, r, m)
107160 return
108161 end if
109162
@@ -113,24 +166,56 @@ contains
113166 s = matmul_chain_order(p(1: num_present + 1))
114167
115168 if (num_present == 3) then
116- r = matmul_chain_mult_${s}$_3(m1, m2, m3, 1, s)
169+ r = matmul_chain_mult_${s}$_3(m1, m2, m3, 1, s, p(1:4) )
117170 return
118171 else if (num_present == 4) then
119- r = matmul_chain_mult_${s}$_4(m1, m2, m3, m4, 1, s)
172+ r = matmul_chain_mult_${s}$_4(m1, m2, m3, m4, 1, s, p(1:5) )
120173 return
121174 end if
122175
123176 ! Now num_present is 5
124177
! Five-matrix chain: s(1, 5) is the position of the outermost split.
! Each case builds the left and/or right sub-product into a temporary
! and combines them with one final gemm into r, which was allocated
! above as p(1) x p(num_present + 1), i.e. p(1) x p(6) for five factors.
select case (s(1, 5))
    case (1)
        ! m1*(m2*m3*m4*m5)
        ! The right sub-chain starts at chain position 2.
        temp = matmul_chain_mult_${s}$_4(m2, m3, m4, m5, 2, s, p)
        m = p(1)
        n = p(6)
        k = p(2)
        call gemm('N', 'N', m, n, k, one_${s}$, m1, m, temp, k, zero_${s}$, r, m)
    case (2)
        ! (m1*m2)*(m3*m4*m5)
        ! temp = m1*m2 : (p(1) x p(2)) * (p(2) x p(3))
        m = p(1)
        n = p(3)
        k = p(2)
        allocate(temp(m, n))
        call gemm('N', 'N', m, n, k, one_${s}$, m1, m, m2, k, zero_${s}$, temp, m)

        ! The right sub-chain starts at chain position 3.
        temp1 = matmul_chain_mult_${s}$_3(m3, m4, m5, 3, s, p)

        ! r = temp*temp1 : (p(1) x p(3)) * (p(3) x p(6))
        m = p(1)
        n = p(6)
        k = p(3)
        call gemm('N', 'N', m, n, k, one_${s}$, temp, m, temp1, k, zero_${s}$, r, m)
    case (3)
        ! (m1*m2*m3)*(m4*m5)
        ! The left sub-chain m1*m2*m3 starts at chain position 1, so the
        ! helper must read s(1, 3) and p(1:4). Passing 3 here would make
        ! it use the split and dimensions that belong to m3*m4*m5.
        temp = matmul_chain_mult_${s}$_3(m1, m2, m3, 1, s, p)

        ! temp1 = m4*m5 : (p(4) x p(5)) * (p(5) x p(6))
        m = p(4)
        n = p(6)
        k = p(5)
        allocate(temp1(m, n))
        call gemm('N', 'N', m, n, k, one_${s}$, m4, m, m5, k, zero_${s}$, temp1, m)

        ! r = temp*temp1 : (p(1) x p(4)) * (p(4) x p(6))
        m = p(1)
        n = p(6)
        k = p(4)
        call gemm('N', 'N', m, n, k, one_${s}$, temp, m, temp1, k, zero_${s}$, r, m)
    case (4)
        ! (m1*m2*m3*m4)*m5
        ! The left sub-chain starts at chain position 1.
        temp = matmul_chain_mult_${s}$_4(m1, m2, m3, m4, 1, s, p)
        m = p(1)
        n = p(6)
        k = p(5)
        call gemm('N', 'N', m, n, k, one_${s}$, temp, m, m5, k, zero_${s}$, r, m)
    case default
        error stop "stdlib_matmul: error: unexpected s(i,j)"
end select
0 commit comments