Skip to content

Commit

Permalink
update multi gpu
Browse files Browse the repository at this point in the history
  • Loading branch information
bubbliiiing committed Apr 16, 2022
1 parent 2d10e3c commit 1b86e78
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion utils/utils_fit.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def train_step(images, labels, net, optimizer, loss, aux_branch, metrics):
@tf.function
def distributed_train_step(images, labels, net, optimizer, loss, aux_branch, metrics):
per_replica_losses, per_replica_score = strategy.run(train_step, args=(images, labels, net, optimizer, loss, aux_branch, metrics))
return strategy.reduce(tf.distribute.ReduceOp.MEAN, per_replica_losses,axis=None), strategy.reduce(tf.distribute.ReduceOp.MEAN, per_replica_score, axis=None)
return strategy.reduce(tf.distribute.ReduceOp.MEAN, per_replica_losses, axis=None), strategy.reduce(tf.distribute.ReduceOp.MEAN, per_replica_score, axis=None)
return distributed_train_step

@tf.function
Expand Down

0 comments on commit 1b86e78

Please sign in to comment.