11program test_conv2d_network
22
33 use iso_fortran_env, only: stderr = > error_unit
4- use nf, only: conv2d, input, network
4+ use nf, only: conv2d, input, network, dense, sgd, maxpool2d
55
66 implicit none
77
@@ -21,6 +21,7 @@ program test_conv2d_network
2121 ok = .false.
2222 end if
2323
24+ ! Test for output shape
2425 allocate (sample_input(3 , 32 , 32 ))
2526 sample_input = 0
2627
@@ -32,6 +33,115 @@ program test_conv2d_network
3233 ok = .false.
3334 end if
3435
36+ deallocate (sample_input, output)
37+
38+ training1: block
39+
40+ type (network) :: cnn
41+ real :: y(1 )
42+ real :: tolerance = 1e-5
43+ integer :: n
44+ integer , parameter :: num_iterations = 1000
45+
46+ ! Test training of a minimal constant mapping
47+ allocate (sample_input(1 , 5 , 5 ))
48+ call random_number (sample_input)
49+
50+ cnn = network([ &
51+ input(shape (sample_input)), &
52+ conv2d(filters= 1 , kernel_size= 3 ), &
53+ conv2d(filters= 1 , kernel_size= 3 ), &
54+ dense(1 ) &
55+ ])
56+
57+ y = [0.1234567 ]
58+
59+ do n = 1 , num_iterations
60+ call cnn % forward(sample_input)
61+ call cnn % backward(y)
62+ call cnn % update(optimizer= sgd(learning_rate= 1 .))
63+ if (all (abs (cnn % predict(sample_input) - y) < tolerance)) exit
64+ end do
65+
66+ if (.not. n <= num_iterations) then
67+ write (stderr, ' (a)' ) &
68+ ' convolutional network 1 should converge in simple training.. failed'
69+ ok = .false.
70+ end if
71+
72+ end block training1
73+
74+ training2: block
75+
76+ type (network) :: cnn
77+ real :: x(1 , 8 , 8 )
78+ real :: y(1 )
79+ real :: tolerance = 1e-5
80+ integer :: n
81+ integer , parameter :: num_iterations = 1000
82+
83+ call random_number (x)
84+ y = [0.1234567 ]
85+
86+ cnn = network([ &
87+ input(shape (x)), &
88+ conv2d(filters= 1 , kernel_size= 3 ), &
89+ maxpool2d(pool_size= 2 ), &
90+ conv2d(filters= 1 , kernel_size= 3 ), &
91+ dense(1 ) &
92+ ])
93+
94+ do n = 1 , num_iterations
95+ call cnn % forward(x)
96+ call cnn % backward(y)
97+ call cnn % update(optimizer= sgd(learning_rate= 1 .))
98+ if (all (abs (cnn % predict(x) - y) < tolerance)) exit
99+ end do
100+
101+ if (.not. n <= num_iterations) then
102+ write (stderr, ' (a)' ) &
103+ ' convolutional network 2 should converge in simple training.. failed'
104+ ok = .false.
105+ end if
106+
107+ end block training2
108+
109+ training3: block
110+
111+ type (network) :: cnn
112+ real :: x(1 , 12 , 12 )
113+ real :: y(9 )
114+ real :: tolerance = 1e-5
115+ integer :: n
116+ integer , parameter :: num_iterations = 5000
117+
118+ call random_number (x)
119+ y = [0.12345 , 0.23456 , 0.34567 , 0.45678 , 0.56789 , 0.67890 , 0.78901 , 0.89012 , 0.90123 ]
120+
121+ cnn = network([ &
122+ input(shape (x)), &
123+ conv2d(filters= 1 , kernel_size= 3 ), & ! 1x12x12 input, 1x10x10 output
124+ maxpool2d(pool_size= 2 ), & ! 1x10x10 input, 1x5x5 output
125+ conv2d(filters= 1 , kernel_size= 3 ), & ! 1x5x5 input, 1x3x3 output
126+ dense(9 ) & ! 9 outputs
127+ ])
128+
129+ do n = 1 , num_iterations
130+ call cnn % forward(x)
131+ call cnn % backward(y)
132+ call cnn % update(optimizer= sgd(learning_rate= 1 .))
133+ if (all (abs (cnn % predict(x) - y) < tolerance)) exit
134+ end do
135+
136+ if (.not. n <= num_iterations) then
137+ write (stderr, ' (a)' ) &
138+ ' convolutional network 3 should converge in simple training.. failed'
139+ ok = .false.
140+ end if
141+
142+ end block training3
143+
144+
35145 if (ok) then
36146 print ' (a)' , ' test_conv2d_network: All tests passed.'
37147 else
0 commit comments