55
66submodule (stdlib_intrinsics) stdlib_intrinsics_matmul
77 use stdlib_linalg_blas, only: gemm
8+ use stdlib_linalg_state, only: linalg_state_type, linalg_error_handling, LINALG_VALUE_ERROR
89 use stdlib_constants
910 implicit none
1011
12+ character(len=*), parameter :: this = "stdlib_matmul"
13+
1114contains
1215
1316 ! Algorithm for the optimal parenthesization of matrices
@@ -122,41 +125,76 @@ contains
122125
123126 end function matmul_chain_mult_${s}$_4
124127
125- pure module function stdlib_matmul_${s}$ (m1, m2, m3, m4, m5) result(r)
128+ pure module subroutine stdlib_matmul_sub_${s}$ (res, m1, m2, m3, m4, m5, err)
129+ ${t}$, intent(out), allocatable :: res(:,:)
126130 ${t}$, intent(in) :: m1(:,:), m2(:,:)
127131 ${t}$, intent(in), optional :: m3(:,:), m4(:,:), m5(:,:)
128- ${t}$, allocatable :: r(:,:), temp(:,:), temp1(:,:)
132+ type(linalg_state_type), intent(out), optional :: err
133+ ${t}$, allocatable :: temp(:,:), temp1(:,:)
129134 integer :: p(6), num_present, m, n, k
130135 integer, allocatable :: s(:,:)
131136
137+ type(linalg_state_type) :: err0
138+
132139 p(1) = size(m1, 1)
133140 p(2) = size(m2, 1)
134141 p(3) = size(m2, 2)
135142
143+ if (size(m1, 2) /= p(2)) then
144+ err0 = linalg_state_type(this, LINALG_VALUE_ERROR, 'matrices m1, m2 not of compatible sizes')
145+ call linalg_error_handling(err0, err)
146+ allocate(res(0, 0))
147+ return
148+ end if
149+
136150 num_present = 2
137151 if (present(m3)) then
152+
153+ if (size(m3, 1) /= p(3)) then
154+ err0 = linalg_state_type(this, LINALG_VALUE_ERROR, 'matrices m2, m3 not of compatible sizes')
155+ call linalg_error_handling(err0, err)
156+ allocate(res(0, 0))
157+ return
158+ end if
159+
138160 p(3) = size(m3, 1)
139161 p(4) = size(m3, 2)
140162 num_present = num_present + 1
141163 end if
142164 if (present(m4)) then
165+
166+ if (size(m4, 1) /= p(4)) then
167+ err0 = linalg_state_type(this, LINALG_VALUE_ERROR, 'matrices m3, m4 not of compatible sizes')
168+ call linalg_error_handling(err0, err)
169+ allocate(res(0, 0))
170+ return
171+ end if
172+
143173 p(4) = size(m4, 1)
144174 p(5) = size(m4, 2)
145175 num_present = num_present + 1
146176 end if
147177 if (present(m5)) then
178+
179+ if (size(m5, 1) /= p(5)) then
180+ err0 = linalg_state_type(this, LINALG_VALUE_ERROR, 'matrices m4, m5 not of compatible sizes')
181+ call linalg_error_handling(err0, err)
182+ allocate(res(0, 0))
183+ return
184+ end if
185+
148186 p(5) = size(m5, 1)
149187 p(6) = size(m5, 2)
150188 num_present = num_present + 1
151189 end if
152190
153- allocate(r (p(1), p(num_present + 1)))
191+ allocate(res (p(1), p(num_present + 1)))
154192
155193 if (num_present == 2) then
156194 m = p(1)
157195 n = p(3)
158196 k = p(2)
159- call gemm('N', 'N', m, n, k, one_${s}$, m1, m, m2, k, zero_${s}$, r , m)
197+ call gemm('N', 'N', m, n, k, one_${s}$, m1, m, m2, k, zero_${s}$, res , m)
160198 return
161199 end if
162200
@@ -166,10 +204,10 @@ contains
166204 s = matmul_chain_order(p(1: num_present + 1))
167205
168206 if (num_present == 3) then
169- r = matmul_chain_mult_${s}$_3(m1, m2, m3, 1, s, p(1:4))
207+ res = matmul_chain_mult_${s}$_3(m1, m2, m3, 1, s, p(1:4))
170208 return
171209 else if (num_present == 4) then
172- r = matmul_chain_mult_${s}$_4(m1, m2, m3, m4, 1, s, p(1:5))
210+ res = matmul_chain_mult_${s}$_4(m1, m2, m3, m4, 1, s, p(1:5))
173211 return
174212 end if
175213
@@ -182,7 +220,7 @@ contains
182220 m = p(1)
183221 n = p(6)
184222 k = p(2)
185- call gemm('N', 'N', m, n, k, one_${s}$, m1, m, temp, k, zero_${s}$, r , m)
223+ call gemm('N', 'N', m, n, k, one_${s}$, m1, m, temp, k, zero_${s}$, res , m)
186224 case (2)
187225 ! (m1*m2)*(m3*m4*m5)
188226 m = p(1)
@@ -195,7 +233,7 @@ contains
195233
196234 k = n
197235 n = p(6)
198- call gemm('N', 'N', m, n, k, one_${s}$, temp, m, temp1, k, zero_${s}$, r , m)
236+ call gemm('N', 'N', m, n, k, one_${s}$, temp, m, temp1, k, zero_${s}$, res , m)
199237 case (3)
200238 ! (m1*m2*m3)*(m4*m5)
201239 temp = matmul_chain_mult_${s}$_3(m1, m2, m3, 3, s, p)
@@ -208,18 +246,35 @@ contains
208246
209247 k = m
210248 m = p(1)
211- call gemm('N', 'N', m, n, k, one_${s}$, temp, m, temp1, k, zero_${s}$, r , m)
249+ call gemm('N', 'N', m, n, k, one_${s}$, temp, m, temp1, k, zero_${s}$, res , m)
212250 case (4)
213251 ! (m1*m2*m3*m4)*m5
214252 temp = matmul_chain_mult_${s}$_4(m1, m2, m3, m4, 1, s, p)
215253 m = p(1)
216254 n = p(6)
217255 k = p(5)
218- call gemm('N', 'N', m, n, k, one_${s}$, temp, m, m5, k, zero_${s}$, r , m)
256+ call gemm('N', 'N', m, n, k, one_${s}$, temp, m, m5, k, zero_${s}$, res , m)
219257 case default
220- error stop "stdlib_matmul: error: unexpected s(i,j)"
258+ error stop "stdlib_matmul: internal error: unexpected s(i,j)"
221259 end select
222260
261+ end subroutine stdlib_matmul_sub_${s}$
262+
263+ pure module function stdlib_matmul_pure_${s}$ (m1, m2, m3, m4, m5) result(r)
264+ ${t}$, intent(in) :: m1(:,:), m2(:,:)
265+ ${t}$, intent(in), optional :: m3(:,:), m4(:,:), m5(:,:)
266+ ${t}$, allocatable :: r(:,:)
267+
268+ call stdlib_matmul_sub(r, m1, m2, m3, m4, m5)
269+ end function stdlib_matmul_pure_${s}$
270+
271+ module function stdlib_matmul_${s}$ (m1, m2, m3, m4, m5, err) result(r)
272+ ${t}$, intent(in) :: m1(:,:), m2(:,:)
273+ ${t}$, intent(in), optional :: m3(:,:), m4(:,:), m5(:,:)
274+ type(linalg_state_type), intent(out) :: err
275+ ${t}$, allocatable :: r(:,:)
276+
277+ call stdlib_matmul_sub(r, m1, m2, m3, m4, m5, err=err)
223278 end function stdlib_matmul_${s}$
224279
225280#:endfor
0 commit comments