-
Notifications
You must be signed in to change notification settings - Fork 3.4k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Replace NumPy with Torch in examples/fabric/ #17279
Replace NumPy with Torch in examples/fabric/ #17279
Conversation
lightning-fabric (GPUs) (testing pkg: fabric) is failing at the standalone test for I'm not sure how to go about resolving the failed test as the examples I changed aren't going into that test.. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM :)
minor comments regarding the creation of tensors.
examples/fabric/reinforcement_learning/train_fabric_decoupled.py
Outdated
Show resolved
Hide resolved
They are unrelated, you can ignore them and we will rerun them before merging. |
change numpy functions to the torch counterparts remove seeding numpy rng remove the numpy import
replace np.array in layer_init calls remove numpy import
replace the np.sqrt in layer init default arg replace np.logical_or with built-in or, no need for tensors/array. remove numpy import
subtraction was not valid for tensor of bools. Replace with logical not
same changes for all files remove numpy import
for more information, see https://pre-commit.ci
Subtraction with tensors of bools is not supported.
tolist() is significantly slower than numpy() Part of the torch so still allows for removal of numpy import
also to(device) moved within the tensor creation where possible tensor(..., device=device)
remove "from torch import Tensor" in some files where not needed
for more information, see https://pre-commit.ci
for more information, see https://pre-commit.ci
9bb27c9
to
aa49d64
Compare
What does this PR do?
Refactoring to remove all numpy function calls in the
examples/fabric
folder.Part of Issue: #16649
Related to Pull Requests: #17264, #17267 and #17278
Removal of numpy package import in the listed files.
Most replacements are simple substitutions of
np.<func>
totorch.<func>
, but also the inclusion of logical_not() where subtraction with tensors of bools are not allowed.Removal of seeding numpy rng
Where.numpy()
is called, it has been changed to.tolist()
as input to external functions..tolist()
is much slower than.numpy()
and since.numpy()
is part of the torch package this doesn't require the numpy import.Replace
Tensor(...).to(device)
withtorch.tensor(..., device=device)
Files:
examples/fabric/meta_learning/train_fabric.py
examples/fabric/meta_learning/train_torch.py
examples/fabric/reinforcement_learning/rl/agent.py
examples/fabric/reinforcement_learning/rl/utils.py
examples/fabric/reinforcement_learning/train_torch.py
examples/fabric/reinforcement_learning/train_fabric.py
examples/fabric/reinforcement_learning/train_fabric_decoupled.py
Part of #16649
Before submitting
PR review
Anyone in the community is welcome to review the PR.
Before you start reviewing, make sure you have read the review guidelines. In short, see the following bullet-list:
Reviewer checklist