1919from pytorch_lightning .callbacks .progress import TQDMProgressBar
2020from torch import nn
2121from torch .nn import functional as F
22- from torch .utils .data import DataLoader , random_split
22+ from torch .utils .data import DataLoader , random_split , RandomSampler
2323from torchmetrics import Accuracy
2424from torchvision import transforms
2525from torchvision .datasets import MNIST
@@ -127,7 +127,11 @@ def setup(self, stage=None):
127127 )
128128
129129 def train_dataloader (self ):
130- return DataLoader (self .mnist_train , batch_size = BATCH_SIZE )
130+ return DataLoader (
131+ self .mnist_train ,
132+ batch_size = BATCH_SIZE ,
133+ sampler = RandomSampler (self .mnist_train , num_samples = 1000 ),
134+ )
131135
132136 def val_dataloader (self ):
133137 return DataLoader (self .mnist_val , batch_size = BATCH_SIZE )
@@ -147,10 +151,11 @@ def test_dataloader(self):
147151trainer = Trainer (
148152 accelerator = "auto" ,
149153 # devices=1 if torch.cuda.is_available() else None, # limiting got iPython runs
150- max_epochs = 5 ,
154+ max_epochs = 3 ,
151155 callbacks = [TQDMProgressBar (refresh_rate = 20 )],
152156 num_nodes = int (os .environ .get ("GROUP_WORLD_SIZE" , 1 )),
153157 devices = int (os .environ .get ("LOCAL_WORLD_SIZE" , 1 )),
158+ replace_sampler_ddp = False ,
154159 strategy = "ddp" ,
155160)
156161
0 commit comments