|
12 | 12 |
|
13 | 13 | def ddp_setup(): |
14 | 14 | rank = int(os.environ["LOCAL_RANK"]) |
15 | | - if torch.accelerator.is_available(): |
16 | | - device_type = torch.accelerator.current_accelerator() |
17 | | - device: torch.device = torch.device(f"{device_type}:{rank}") |
18 | | - torch.accelerator.device_index(rank) |
19 | | - print(f"Running on rank {rank} on device {device}") |
20 | | - backend = torch.distributed.get_default_backend_for_device(device) |
21 | | - torch.distributed.init_process_group(backend=backend) |
22 | | - return device_type |
23 | | - else: |
24 | | - device = torch.device("cpu") |
25 | | - print(f"Running on device {device}") |
26 | | - torch.distributed.init_process_group(backend="gloo") |
27 | | - return device |
| 15 | + |
| 16 | + device = torch.device(f"{torch.accelerator.current_accelerator()}:{rank}") |
| 17 | + torch.accelerator.set_device_index(rank) |
| 18 | + print(f"Running on rank {rank} on device {device}") |
| 19 | + |
| 20 | + backend = torch.distributed.get_default_backend_for_device(rank) |
| 21 | + torch.distributed.init_process_group(backend=backend, rank=rank, device_id=rank) |
| 22 | + |
28 | 23 |
|
29 | 24 | class Trainer: |
30 | 25 | def __init__( |
@@ -52,7 +47,8 @@ def __init__( |
52 | 47 | self.model = DDP(self.model, device_ids=[self.local_rank]) |
53 | 48 |
|
54 | 49 | def _load_snapshot(self, snapshot_path): |
55 | | - loc = str(self.device) |
| 50 | + loc = str(torch.accelerator.current_accelerator()) |
| 51 | + |
56 | 52 | snapshot = torch.load(snapshot_path, map_location=loc) |
57 | 53 | self.model.load_state_dict(snapshot["MODEL_STATE"]) |
58 | 54 | self.epochs_run = snapshot["EPOCHS_RUN"] |
@@ -118,8 +114,8 @@ def main(save_every: int, total_epochs: int, batch_size: int, snapshot_path: str |
118 | 114 | if __name__ == "__main__": |
119 | 115 | import argparse |
120 | 116 | parser = argparse.ArgumentParser(description='simple distributed training job') |
121 | | - parser.add_argument('total_epochs', type=int, help='Total epochs to train the model') |
122 | | - parser.add_argument('save_every', type=int, help='How often to save a snapshot') |
| 117 | + parser.add_argument('total_epochs', default=50, type=int, help='Total epochs to train the model') |
| 118 | + parser.add_argument('save_every', default=5, type=int, help='How often to save a snapshot') |
123 | 119 | parser.add_argument('--batch_size', default=32, type=int, help='Input batch size on each device (default: 32)') |
124 | 120 | args = parser.parse_args() |
125 | 121 |
|
|
0 commit comments