 import java.util.concurrent.ThreadLocalRandom;
 import java.util.concurrent.atomic.AtomicLong;
 import java.util.function.Function;
+import java.util.function.Supplier;
 import java.util.stream.Collectors;
+import java.util.stream.IntStream;
 import java.util.stream.LongStream;
 
 import static org.neo4j.gds.embeddings.graphsage.GraphSageHelper.embeddingsComputationGraph;
 public class GraphSageModelTrainer {
     private final long randomSeed;
     private final boolean useWeights;
-    private final double learningRate;
-    private final double tolerance;
-    private final int negativeSampleWeight;
-    private final int concurrency;
-    private final int epochs;
-    private final int maxIterations;
-    private final int maxSearchDepth;
     private final Function<Graph, List<LayerConfig>> layerConfigsFunction;
     private final FeatureFunction featureFunction;
     private final Collection<Weights<Matrix>> labelProjectionWeights;
     private final ExecutorService executor;
     private final ProgressTracker progressTracker;
-    private final int batchSize;
+    private final GraphSageTrainConfig config;
 
     public GraphSageModelTrainer(GraphSageTrainConfig config, ExecutorService executor, ProgressTracker progressTracker) {
         this(config, executor, progressTracker, new SingleLabelFeatureFunction(), Collections.emptyList());
@@ -94,14 +89,7 @@ public GraphSageModelTrainer(
         Collection<Weights<Matrix>> labelProjectionWeights
     ) {
         this.layerConfigsFunction = graph -> config.layerConfigs(firstLayerColumns(config, graph));
-        this.batchSize = config.batchSize();
-        this.learningRate = config.learningRate();
-        this.tolerance = config.tolerance();
-        this.negativeSampleWeight = config.negativeSampleWeight();
-        this.concurrency = config.concurrency();
-        this.epochs = config.epochs();
-        this.maxIterations = config.maxIterations();
-        this.maxSearchDepth = config.searchDepth();
+        this.config = config;
         this.featureFunction = featureFunction;
         this.labelProjectionWeights = labelProjectionWeights;
         this.executor = executor;
@@ -139,21 +127,29 @@ public ModelTrainResult train(Graph graph, HugeObjectArray<double[]> features) {
 
         var batchTasks = PartitionUtils.rangePartitionWithBatchSize(
             graph.nodeCount(),
-            batchSize,
+            config.batchSize(),
             batch -> createBatchTask(graph, features, layers, weights, batch)
         );
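+        // each iteration now trains on a random sample of the prepared batch tasks,
+        // sized by config.batchesPerIteration(nodeCount), instead of all of them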
+        var random = new Random(randomSeed);
+        Supplier<List<BatchTask>> batchTaskSampler = () -> IntStream.range(0, config.batchesPerIteration(graph.nodeCount()))
+            .mapToObj(__ -> batchTasks.get(random.nextInt(batchTasks.size())))
+            .collect(Collectors.toList());
 
         progressTracker.endSubTask("Prepare batches");
 
+        progressTracker.beginSubTask("Train model");
+
         boolean converged = false;
         var iterationLossesPerEpoch = new ArrayList<List<Double>>();
-
-        progressTracker.beginSubTask("Train model");
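+        // the last iteration loss of an epoch seeds the next epoch's convergence
+        // check; NaN guarantees the very first iteration never converges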
+        var prevEpochLoss = Double.NaN;
+        int epochs = config.epochs();
 
         for (int epoch = 1; epoch <= epochs && !converged; epoch++) {
             progressTracker.beginSubTask("Epoch");
-            var epochResult = trainEpoch(batchTasks, weights);
-            iterationLossesPerEpoch.add(epochResult.losses());
+            var epochResult = trainEpoch(batchTaskSampler, weights, prevEpochLoss);
+            List<Double> epochLosses = epochResult.losses();
+            iterationLossesPerEpoch.add(epochLosses);
+            prevEpochLoss = epochLosses.get(epochLosses.size() - 1);
             converged = epochResult.converged();
             progressTracker.endSubTask("Epoch");
         }
@@ -188,43 +184,52 @@ private BatchTask createBatchTask(
             useWeights ? localGraph::relationshipProperty : UNWEIGHTED,
             embeddingVariable,
             totalBatch,
-            negativeSampleWeight
+            config.negativeSampleWeight()
         );
 
-        return new BatchTask(lossFunction, weights, tolerance, progressTracker);
+        return new BatchTask(lossFunction, weights, progressTracker);
     }
 
-    private EpochResult trainEpoch(List<BatchTask> batchTasks, List<Weights<? extends Tensor<?>>> weights) {
-        var updater = new AdamOptimizer(weights, learningRate);
+    private EpochResult trainEpoch(
+        Supplier<List<BatchTask>> sampledBatchTaskSupplier,
+        List<Weights<? extends Tensor<?>>> weights,
+        double prevEpochLoss
+    ) {
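+        // Adam moment estimates start fresh for every epoch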
+        var updater = new AdamOptimizer(weights, config.learningRate());
 
         int iteration = 1;
         var iterationLosses = new ArrayList<Double>();
+        double prevLoss = prevEpochLoss;
         var converged = false;
 
-        for (;iteration <= maxIterations; iteration++) {
+        int maxIterations = config.maxIterations();
+        for (; iteration <= maxIterations; iteration++) {
             progressTracker.beginSubTask("Iteration");
 
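+            // resample the subset of batches used for this iteration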
+            var sampledBatchTasks = sampledBatchTaskSupplier.get();
+
             // run forward + maybe backward for each Batch
-            ParallelUtil.runWithConcurrency(concurrency, batchTasks, executor);
-            var avgLoss = batchTasks.stream().mapToDouble(BatchTask::loss).average().orElseThrow();
+            ParallelUtil.runWithConcurrency(config.concurrency(), sampledBatchTasks, executor);
+            var avgLoss = sampledBatchTasks.stream().mapToDouble(BatchTask::loss).average().orElseThrow();
             iterationLosses.add(avgLoss);
+            progressTracker.logMessage(formatWithLocale("LOSS: %.10f", avgLoss));
 
-            converged = batchTasks.stream().allMatch(task -> task.converged);
-            if (converged) {
-                progressTracker.endSubTask();
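+            // convergence is now based on the change in mean loss between iterations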
+            if (Math.abs(prevLoss - avgLoss) < config.tolerance()) {
+                converged = true;
+                progressTracker.endSubTask("Iteration");
                 break;
             }
 
-            var batchedGradients = batchTasks
+            prevLoss = avgLoss;
+
+            var batchedGradients = sampledBatchTasks
                 .stream()
                 .map(BatchTask::weightGradients)
                 .collect(Collectors.toList());
 
             var meanGradients = averageTensors(batchedGradients);
 
             updater.update(meanGradients);
-
-            progressTracker.logMessage(formatWithLocale("LOSS: %.10f", avgLoss));
             progressTracker.endSubTask("Iteration");
         }
 
@@ -243,34 +248,23 @@ static class BatchTask implements Runnable {
         private final Variable<Scalar> lossFunction;
         private final List<Weights<? extends Tensor<?>>> weightVariables;
         private List<? extends Tensor<?>> weightGradients;
-        private final double tolerance;
         private final ProgressTracker progressTracker;
-        private boolean converged;
-        private double prevLoss;
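+        // per-task convergence tracking moved to trainEpoch; a task only records its loss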
+        private double loss;
 
         BatchTask(
             Variable<Scalar> lossFunction,
             List<Weights<? extends Tensor<?>>> weightVariables,
-            double tolerance,
             ProgressTracker progressTracker
         ) {
             this.lossFunction = lossFunction;
             this.weightVariables = weightVariables;
-            this.tolerance = tolerance;
             this.progressTracker = progressTracker;
         }
 
         @Override
         public void run() {
-            if (converged) { // Don't try to go further
-                return;
-            }
-
             var localCtx = new ComputationContext();
-            var loss = localCtx.forward(lossFunction).value();
-
-            converged = Math.abs(prevLoss - loss) < tolerance;
-            prevLoss = loss;
+            loss = localCtx.forward(lossFunction).value();
 
             localCtx.backward(lossFunction);
             weightGradients = weightVariables.stream().map(localCtx::gradient).collect(Collectors.toList());
@@ -279,7 +273,7 @@ public void run() {
         }
 
         public double loss() {
-            return prevLoss;
+            return loss;
         }
 
         List<? extends Tensor<?>> weightGradients() {
@@ -312,7 +306,7 @@ LongStream neighborBatch(Graph graph, Partition batch, long batchLocalSeed) {
         // sample a neighbor for each batchNode
         batch.consume(nodeId -> {
             // randomWalk with at most maxSearchDepth steps and only save last node
-            int searchDepth = localRandom.nextInt(maxSearchDepth) + 1;
+            int searchDepth = localRandom.nextInt(config.searchDepth()) + 1;
             AtomicLong currentNode = new AtomicLong(nodeId);
             while (searchDepth > 0) {
                 NeighborhoodSampler neighborhoodSampler = new NeighborhoodSampler(currentNode.get() + searchDepth);