Skip to content

Commit

Permalink
Add loss multiplication to preserver the single-process performance
Browse files Browse the repository at this point in the history
  • Loading branch information
yizhi.chen committed Jul 6, 2020
1 parent e838055 commit 16e7c26
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 8 deletions.
15 changes: 9 additions & 6 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
try: # Mixed precision training https://github.com/NVIDIA/apex
from apex import amp
except:
#from apex.parallel import DistributedDataParallel as DDP
print('Apex recommended for faster mixed precision training: https://github.com/NVIDIA/apex')
mixed_precision = False # not installed

Expand Down Expand Up @@ -120,7 +121,8 @@ def train(hyp, tb_writer, opt, device):
del pg0, pg1, pg2

# Load Model
if opt.local_rank in [-1, 0]:
# Avoid multiple downloads.
with torch_distributed_zero_first(opt.local_rank):
google_utils.attempt_download(weights)
start_epoch, best_fitness = 0, 0.0
if weights.endswith('.pt'): # pytorch format
Expand Down Expand Up @@ -274,6 +276,9 @@ def train(hyp, tb_writer, opt, device):

# Loss
loss, loss_items = compute_loss(pred, targets.to(device), model)
# loss is scaled with batch size in func compute_loss. But in DDP mode, gradient is averaged between devices.
if opt.local_rank != -1:
loss *= torch.distributed.get_world_size()
if not torch.isfinite(loss):
print('WARNING: non-finite loss, ending training ', loss_items)
return results
Expand All @@ -293,9 +298,9 @@ def train(hyp, tb_writer, opt, device):
ema.update(model)

# Print
mloss = (mloss * i + loss_items) / (i + 1) # update mean losses

if opt.local_rank in [-1, 0]:
# TODO: all_reduct mloss if in DDP mode.
mloss = (mloss * i + loss_items) / (i + 1) # update mean losses
mem = '%.3gG' % (torch.cuda.memory_cached() / 1E9 if torch.cuda.is_available() else 0) # (GB)
s = ('%10s' * 2 + '%10.4g' * 6) % (
'%g/%g' % (epoch, epochs - 1), mem, *mloss, targets.shape[0], imgs.shape[-1])
Expand All @@ -319,7 +324,7 @@ def train(hyp, tb_writer, opt, device):
final_epoch = epoch + 1 == epochs
if not opt.notest or final_epoch: # Calculate mAP
results, maps, times = test.test(opt.data,
batch_size=total_batch_size,
batch_size=batch_size,
imgsz=imgsz_test,
save_json=final_epoch and opt.data.endswith(os.sep + 'coco.yaml'),
model=ema.ema.module if hasattr(ema.ema, 'module') else ema.ema,
Expand Down Expand Up @@ -377,8 +382,6 @@ def train(hyp, tb_writer, opt, device):
if not opt.evolve:
plot_results() # save as results.png
print('%g epochs completed in %.3f hours.\n' % (epoch - start_epoch + 1, (time.time() - t0) / 3600))
if opt.local_rank == 0:
dist.destroy_process_group()
torch.cuda.empty_cache()
return results

Expand Down
4 changes: 2 additions & 2 deletions utils/torch_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,6 +200,6 @@ def update(self, model):
def update_attr(self, model):
# Assign attributes (which may change during training)
for k in model.__dict__.keys():
if not k.startswith('_') and (k != 'module' or not isinstance(getattr(model, k),
(torch.distributed.ProcessGroupNCCL, torch.distributed.Reducer))):
if not (k.startswith('_') or k == 'module' or
isinstance(getattr(model, k), (torch.distributed.ProcessGroupNCCL, torch.distributed.Reducer))):
setattr(self.ema, k, getattr(model, k))

0 comments on commit 16e7c26

Please sign in to comment.