
Commit f20cbc1

Tony authored and committed
initial commit:
- Added `--val-interval` argument. Eval and checkpointing are only applied every val-interval epochs.
- Changed `float` to `Optional[float]` in the typing of the scheduler `step` function parameter `metric`.
- Skipped the step of the base scheduler in the plateau scheduler when no metric is given, to avoid a `TypeError` when converting `None` to `float`.
- Added `or last_batch` to the logging condition during training, to be consistent with validation.
1 parent 47c18f4 commit f20cbc1

3 files changed: +20 -6 lines changed


timm/scheduler/plateau_lr.py

Lines changed: 5 additions & 3 deletions
@@ -5,7 +5,7 @@
 Hacked together by / Copyright 2020 Ross Wightman
 """
 import torch
-from typing import List
+from typing import List, Optional
 
 from .scheduler import Scheduler
 
@@ -86,12 +86,14 @@ def step(self, epoch, metric=None):
                     param_group['lr'] = self.restore_lr[i]
                 self.restore_lr = None
 
-            self.lr_scheduler.step(metric, epoch)  # step the base scheduler
+            # step the base scheduler if metric given
+            if metric is not None:
+                self.lr_scheduler.step(metric, epoch)
 
             if self._is_apply_noise(epoch):
                 self._apply_noise(epoch)
 
-    def step_update(self, num_updates: int, metric: float = None):
+    def step_update(self, num_updates: int, metric: Optional[float] = None):
         return None
 
     def _apply_noise(self, epoch):
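
The guard matters because the wrapped base scheduler here is PyTorch's ReduceLROnPlateau, whose step() converts its metric argument to float. A minimal standalone sketch of the failure mode the `if metric is not None` check avoids (illustration only, not timm code):

import torch
from torch.optim.lr_scheduler import ReduceLROnPlateau

# Toy parameter/optimizer just to construct the scheduler.
param = torch.nn.Parameter(torch.zeros(1))
optimizer = torch.optim.SGD([param], lr=0.1)
scheduler = ReduceLROnPlateau(optimizer)

scheduler.step(0.5)       # fine: metric is a number
try:
    scheduler.step(None)  # internally does float(metrics) -> TypeError
except TypeError as err:
    print(err)            # e.g. "float() argument must be a string or a number, not 'NoneType'"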

timm/scheduler/scheduler.py

Lines changed: 2 additions & 2 deletions
@@ -74,14 +74,14 @@ def _get_values(self, t: int, on_epoch: bool = True) -> Optional[List[float]]:
             return None
         return self._get_lr(t)
 
-    def step(self, epoch: int, metric: float = None) -> None:
+    def step(self, epoch: int, metric: Optional[float] = None) -> None:
         self.metric = metric
         values = self._get_values(epoch, on_epoch=True)
         if values is not None:
             values = self._add_noise(values, epoch)
             self.update_groups(values)
 
-    def step_update(self, num_updates: int, metric: float = None):
+    def step_update(self, num_updates: int, metric: Optional[float] = None):
         self.metric = metric
         values = self._get_values(num_updates, on_epoch=False)
         if values is not None:
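
For context, Optional[float] makes explicit that metric may be a float or None. A minimal standalone sketch of why the annotation was tightened (illustration only; under PEP 484 a parameter defaulting to None should be annotated Optional, and strict type checkers reject the implicit form):

from typing import Optional

def step(epoch: int, metric: Optional[float] = None) -> None:
    # `metric: float = None` relies on implicit-Optional behaviour that
    # strict checkers (e.g. mypy with no_implicit_optional) flag as an error.
    if metric is not None:
        print(f"epoch {epoch}: metric={metric:.4f}")
    else:
        print(f"epoch {epoch}: no metric")

step(3)          # OK: metric defaults to None
step(3, 0.721)   # OK: metric is a float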

train.py

Lines changed: 13 additions & 1 deletion
@@ -370,6 +370,8 @@
                    help='worker seed mode (default: all)')
 group.add_argument('--log-interval', type=int, default=50, metavar='N',
                    help='how many batches to wait before logging training status')
+group.add_argument('--val-interval', type=int, default=1, metavar='N',
+                   help='how many epochs between validation and checkpointing')
 group.add_argument('--recovery-interval', type=int, default=0, metavar='N',
                    help='how many batches to wait before writing recovery checkpoint')
 group.add_argument('--checkpoint-hist', type=int, default=10, metavar='N',
@@ -1013,6 +1015,16 @@ def main():
                     _logger.info("Distributing BatchNorm running means and vars")
                 utils.distribute_bn(model, args.world_size, args.dist_bn == 'reduce')
 
+            if (epoch + 1) % args.val_interval != 0:
+                if utils.is_primary(args):
+                    _logger.info("Skipping eval and checkpointing ")
+                if lr_scheduler is not None:
+                    # step LR for next epoch
+                    # careful when using metric dependent lr_scheduler
+                    lr_scheduler.step(epoch + 1, metric=None)
+                # skip validation and metric logic
+                continue
+
             if loader_eval is not None:
                 eval_metrics = validate(
                     model,
@@ -1252,7 +1264,7 @@ def _backward(_loss):
         update_time_m.update(time.time() - update_start_time)
         update_start_time = time_now
 
-        if update_idx % args.log_interval == 0:
+        if update_idx % args.log_interval == 0 or last_batch:
             lrl = [param_group['lr'] for param_group in optimizer.param_groups]
             lr = sum(lrl) / len(lrl)