Ability to parallelize between GPUs #30

Merged Jun 6, 2024 (31 commits)

Commits
9e13221  better printings (matsen, Jun 5, 2024)
03ea63e  factoring apart load_and_add_shm... (matsen, Jun 5, 2024)
fe8fa85  finding least used gpu (matsen, Jun 5, 2024)
4332ce0  picking device midstream (matsen, Jun 5, 2024)
a6444a2  pick device in train (matsen, Jun 5, 2024)
a857b8d  device for neutral model (matsen, Jun 5, 2024)
c2648ad  flopping model around every 10 epochs (matsen, Jun 5, 2024)
533a816  less flopping, global for GPU to use (matsen, Jun 6, 2024)
f90963a  moving optimizer as well (matsen, Jun 6, 2024)
7fbbfdd  cleaning out silly global (matsen, Jun 6, 2024)
d2a024c  not flopping midstream (matsen, Jun 6, 2024)
e320694  restoring tensor flopping; more thorough move? (matsen, Jun 6, 2024)
5a405f3  test: can we flop? (matsen, Jun 6, 2024)
22485ad  moving as part of cycle (matsen, Jun 6, 2024)
c20b826  moving datasets as well (matsen, Jun 6, 2024)
ea38435  giving up on switching gpus (matsen, Jun 6, 2024)
435d210  dropping model_and_optimizer_to (matsen, Jun 6, 2024)
03f68d2  feat: Add function to find least used CUDA GPU (matsen, Jun 6, 2024)
c6f0193  blarg (matsen, Jun 6, 2024)
addd57f  feat: Improve CUDA device selection for job (matsen, Jun 6, 2024)
3f3710c  no default for jobid (matsen, Jun 6, 2024)
db0607f  format (matsen, Jun 6, 2024)
ffffaec  gpu index improvement (matsen, Jun 6, 2024)
8f570f0  reverting changes to dnsm.py (matsen, Jun 6, 2024)
8195ca6  cleanup (matsen, Jun 6, 2024)
5f90930  cleanup and update (matsen, Jun 6, 2024)
248068e  device output (matsen, Jun 6, 2024)
399d413  device debug (matsen, Jun 6, 2024)
6455ab6  branch lengths are tensors (matsen, Jun 6, 2024)
647dde3  simplifying device handling of crepes (matsen, Jun 6, 2024)
3792d5f  clineau (matsen, Jun 6, 2024)
netam/common.py (32 additions, 3 deletions)
@@ -1,6 +1,7 @@
 import math
 import inspect
 import itertools
+import subprocess
 
 import numpy as np
 import torch
@@ -141,7 +142,32 @@ def stack_heterogeneous(tensors, pad_value=0.0):
     return torch.stack(padded_tensors)
 
 
-def pick_device():
+def find_least_used_cuda_gpu():
+    """
+    Find the least used CUDA GPU on the system using nvidia-smi.
+    If they are all idle, return None.
+    """
+    result = subprocess.run(
+        ["nvidia-smi", "--query-gpu=utilization.gpu", "--format=csv,nounits,noheader"],
+        stdout=subprocess.PIPE,
+        text=True,
+    )
+    if result.returncode != 0:
+        print(f"Error running nvidia-smi.")
+        return None
+    utilization = [int(x) for x in result.stdout.strip().split("\n")]
+    if max(utilization) == 0:
+        return None  # All GPUs are idle.
+    # else:
+    return utilization.index(min(utilization))
+
+
+def pick_device(gpu_index=0):
+    """
+    Pick a device for PyTorch to use. If CUDA is available, use the least used
+    GPU, and if all are idle use the gpu_index modulo the number of GPUs.
+    """
+
     # check that CUDA is usable
     def check_CUDA():
         try:
@@ -151,8 +177,11 @@ def check_CUDA():
             return False
 
     if torch.backends.cudnn.is_available() and check_CUDA():
-        print("Using CUDA")
-        return torch.device("cuda")
+        which_gpu = find_least_used_cuda_gpu()
+        if which_gpu is None:
+            which_gpu = gpu_index % torch.cuda.device_count()
+        print(f"Using CUDA GPU {which_gpu}")
+        return torch.device(f"cuda:{which_gpu}")
     elif torch.backends.mps.is_available():
         print("Using Metal Performance Shaders")
         return torch.device("mps")
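Taken together, the common.py changes let several jobs share a multi-GPU machine: each job asks nvidia-smi for the least-loaded GPU and, when every GPU is idle, falls back to a deterministic assignment keyed on a job index. A minimal sketch of how a caller might use this (the job_id variable is hypothetical; per the "no default for jobid" commit, the index is supplied by the calling job):

import torch

from netam.common import pick_device

# Hypothetical job index, e.g. a cluster array-task ID or a loop counter.
job_id = 3

# On an idle 2-GPU machine this resolves to cuda:1 (3 % 2); on a busy one,
# to whichever GPU nvidia-smi reports as least utilized.
device = pick_device(gpu_index=job_id)
model = torch.nn.Linear(4, 1).to(device)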
netam/dnsm.py (3 additions, 3 deletions)
@@ -183,11 +183,11 @@ def export_branch_lengths(self, out_csv_path):
         )
 
     def load_branch_lengths(self, in_csv_path):
-        self.branch_lengths = pd.read_csv(in_csv_path)["branch_length"].values
+        self.branch_lengths = torch.Tensor(
+            pd.read_csv(in_csv_path)["branch_length"].values
+        )
 
     def update_neutral_aa_mut_probs(self):
-        print("consolidating neutral rates into substitution probabilities...")
-
         neutral_aa_mut_prob_l = []
 
         for nt_parent, mask, rates, branch_length, subs_probs in zip(
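Per the "branch lengths are tensors" commit, dnsm.py now loads branch lengths as a torch tensor rather than a bare NumPy array, so they work directly with torch operations and device moves. A small sketch of the distinction, with illustrative values:

import pandas as pd
import torch

df = pd.DataFrame({"branch_length": [0.01, 0.05, 0.1]})

as_numpy = df["branch_length"].values  # NumPy array, as before this PR
as_tensor = torch.Tensor(as_numpy)     # float32 tensor, as loaded now

# The tensor participates directly in torch ops and .to(device) calls.
print(torch.log(as_tensor))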
netam/framework.py (17 additions, 13 deletions)
@@ -239,10 +239,12 @@ def __init__(self, encoder, model, training_hyperparameters={}):
         self.encoder = encoder
         self.model = model
         self.training_hyperparameters = training_hyperparameters
-        self.device = None
 
+    @property
+    def device(self):
+        return next(self.model.parameters()).device
+
     def to(self, device):
-        self.device = device
         self.model.to(device)
 
     def encode_sequences(self, sequences):
@@ -261,10 +263,9 @@ def encode_sequences(self, sequences):
 
     def __call__(self, sequences):
         encoded_parents, masks, wt_base_modifiers = self.encode_sequences(sequences)
-        if self.device is not None:
-            encoded_parents = encoded_parents.to(self.device)
-            masks = masks.to(self.device)
-            wt_base_modifiers = wt_base_modifiers.to(self.device)
+        encoded_parents = encoded_parents.to(self.device)
+        masks = masks.to(self.device)
+        wt_base_modifiers = wt_base_modifiers.to(self.device)
         with torch.no_grad():
             outputs = self.model(encoded_parents, masks, wt_base_modifiers)
         return tuple(t.detach().cpu() for t in outputs)
@@ -344,9 +345,7 @@ def trimmed_shm_model_outputs_of_crepe(crepe, parents):
     return trimmed_rates, trimmed_csps
 
 
-def load_and_add_shm_model_outputs_to_pcp_df(
-    pcp_df_path_gz, crepe_prefix, sample_count=None, chosen_v_families=None
-):
+def load_pcp_df(pcp_df_path_gz, sample_count=None, chosen_v_families=None):
     pcp_df = pd.read_csv(pcp_df_path_gz, compression="gzip", index_col=0).reset_index(
         drop=True
     )
@@ -356,7 +355,11 @@ def load_and_add_shm_model_outputs_to_pcp_df(
         pcp_df = pcp_df[pcp_df["v_family"].isin(chosen_v_families)]
     if sample_count is not None:
         pcp_df = pcp_df.sample(sample_count)
-    crepe = load_crepe(crepe_prefix)
+    return pcp_df
+
+
+def add_shm_model_outputs_to_pcp_df(pcp_df, crepe_prefix, device=None):
+    crepe = load_crepe(crepe_prefix, device=device)
     rates, csps = trimmed_shm_model_outputs_of_crepe(crepe, pcp_df["parent"])
     pcp_df["rates"] = rates
     pcp_df["subs_probs"] = csps
@@ -638,7 +641,7 @@ def standardize_and_optimize_branch_lengths(self, **optimization_kwargs):
         optimization_kwargs["optimization_tol"] = 1e-3
         # We do the branch length optimization on CPU but want to restore the
         # model to the device it was on before.
-        device = next(self.model.parameters()).device
+        device = self.device
         self.model.to("cpu")
         for dataset in [self.train_dataset, self.val_dataset]:
             if dataset is None:
@@ -705,7 +708,6 @@ def joint_train(
         schedule that uses a weighted geometric mean of the current learning
         rate and the initial learning rate that progressively moves towards
         keeping the current learning rate as the cycles progress.
-
         """
         if training_method == "full":
             optimize_branch_lengths = self.standardize_and_optimize_branch_lengths
@@ -719,9 +721,11 @@
             optimize_branch_lengths()
             self.mark_branch_lengths_optimized(0)
         for cycle in range(cycle_count):
+            print(f"### Beginning cycle {cycle + 1}/{cycle_count}")
             self.mark_branch_lengths_optimized(cycle + 1)
             current_lr = self.optimizer.param_groups[0]["lr"]
-            # set new_lr to be the geometric mean of current_lr and the learning rate
+            # set new_lr to be the geometric mean of current_lr and the
+            # originally-specified learning rate
             weight = 0.5 + cycle / (2 * cycle_count)
             new_lr = np.exp(
                 weight * np.log(current_lr) + (1 - weight) * np.log(self.learning_rate)
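The reworded comment in joint_train also spells out the cyclical schedule: each cycle's new learning rate is a weighted geometric mean of the current (decayed) rate and the originally-specified rate, with the weight drifting from 0.5 toward 1 so later cycles increasingly keep their current rate. A standalone sketch of that arithmetic, with illustrative values:

import numpy as np

initial_lr = 1e-3
current_lr = 2.5e-4  # whatever the optimizer has decayed to
cycle_count = 4

for cycle in range(cycle_count):
    # weight runs from 0.5 (cycle 0) toward 1.0, pulling early cycles back
    # toward initial_lr and letting later cycles keep current_lr.
    weight = 0.5 + cycle / (2 * cycle_count)
    new_lr = np.exp(
        weight * np.log(current_lr) + (1 - weight) * np.log(initial_lr)
    )
    print(f"cycle {cycle + 1}: lr = {new_lr:.2e}")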
tests/test_dnsm.py (6 additions, 2 deletions)
@@ -6,7 +6,8 @@
 from netam.framework import (
     crepe_exists,
     load_crepe,
-    load_and_add_shm_model_outputs_to_pcp_df,
+    load_pcp_df,
+    add_shm_model_outputs_to_pcp_df,
 )
 from netam.common import aa_idx_tensor_of_str_ambig, MAX_AMBIG_AA_IDX
 from netam.models import TransformerBinarySelectionModelWiggleAct
@@ -22,8 +23,11 @@ def test_aa_idx_tensor_of_str_ambig():
 
 @pytest.fixture
 def pcp_df():
-    df = load_and_add_shm_model_outputs_to_pcp_df(
+    df = load_pcp_df(
         "data/wyatt-10x-1p5m_pcp_2023-11-30_NI.first100.csv.gz",
+    )
+    df = add_shm_model_outputs_to_pcp_df(
+        df,
         "data/cnn_joi_sml-shmoof_small",
     )
     return df
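With the loader factored apart, reading the PCP dataframe and annotating it with SHM model outputs are now independent steps, and the SHM model can be placed on an explicit device. A hypothetical end-to-end use of the pieces in this PR (paths follow the test fixture; the sample_count value is illustrative):

from netam.common import pick_device
from netam.framework import load_pcp_df, add_shm_model_outputs_to_pcp_df

device = pick_device(gpu_index=0)
pcp_df = load_pcp_df(
    "data/wyatt-10x-1p5m_pcp_2023-11-30_NI.first100.csv.gz",
    sample_count=50,
)
pcp_df = add_shm_model_outputs_to_pcp_df(
    pcp_df, "data/cnn_joi_sml-shmoof_small", device=device
)
print(pcp_df[["rates", "subs_probs"]].head())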