From 2e7bf95cc14ce3cf36552de53dbc6176c3625b2e Mon Sep 17 00:00:00 2001 From: Caspar van Leeuwen Date: Wed, 27 Mar 2024 17:43:11 +0100 Subject: [PATCH 01/24] Initial commit for an EESSI PyTorch test that uses torchvision models --- .../tests/apps/PyTorch/PyTorch_torchvision.py | 153 +++++++++++++ .../src/pytorch_synthetic_benchmark.py | 213 ++++++++++++++++++ 2 files changed, 366 insertions(+) create mode 100644 eessi/testsuite/tests/apps/PyTorch/PyTorch_torchvision.py create mode 100644 eessi/testsuite/tests/apps/PyTorch/src/pytorch_synthetic_benchmark.py diff --git a/eessi/testsuite/tests/apps/PyTorch/PyTorch_torchvision.py b/eessi/testsuite/tests/apps/PyTorch/PyTorch_torchvision.py new file mode 100644 index 00000000..91d0a708 --- /dev/null +++ b/eessi/testsuite/tests/apps/PyTorch/PyTorch_torchvision.py @@ -0,0 +1,153 @@ +import reframe as rfm +import reframe.utility.sanity as sn + +from eessi.testsuite import hooks +from eessi.testsuite.constants import SCALES, TAGS, DEVICE_TYPES, COMPUTE_UNIT, CPU, CPU_SOCKET, GPU +from eessi.testsuite.utils import find_modules, log + +class PyTorch_torchvision(rfm.RunOnlyRegressionTest): + nn_model = parameter(['vgg16', 'resnet50', 'resnet152', 'densenet121', 'mobilenet_v3_large']) + ### SHOULD BE DETERMINED BY SCALE + #n_processes = parameter([1, 2, 4, 8, 16]) + scale = parameter(SCALES.keys()) + # Not sure how we would ensure the horovod module is _also_ loaded... + # parallel_strategy = parameter([None, 'horovod', 'ddp']) + parallel_strategy = parameter([None, 'ddp']) + compute_device = variable(str) + # module_name = parameter(find_modules('PyTorch-bundle')) + module_name = parameter(find_modules('torchvision')) + + descr = 'Benchmark that runs a selected torchvision model on synthetic data' + + executable = 'python' + + valid_prog_environs = ['default'] + valid_systems = ['*'] + + time_limit = '30m' + + @run_after('init') + def prepare_test(self): + + # Set nn_model as executable option + self.executable_opts = ['pytorch_synthetic_benchmark.py --model %s' % self.nn_model] + + # If not a GPU run, disable CUDA + if self.compute_device != DEVICE_TYPES[GPU]: + self.executable_opts += ['--no-cuda'] + + + + @run_after('init') + def apply_init_hooks(self): + # Filter on which scales are supported by the partitions defined in the ReFrame configuration + hooks.filter_supported_scales(self) + + # Make sure that GPU tests run in partitions that support running on a GPU, + # and that CPU-only tests run in partitions that support running CPU-only. + # Also support setting valid_systems on the cmd line. + hooks.filter_valid_systems_by_device_type(self, required_device_type=self.compute_device) + + # Support selecting modules on the cmd line. + hooks.set_modules(self) + + # Support selecting scales on the cmd line via tags. + hooks.set_tag_scale(self) + + @run_after('init') + def set_tag_ci(self): + if self.nn_model == 'resnet50': + self.tags.add(TAGS['CI']) + + @run_after('setup') + def apply_setup_hooks(self): + if self.compute_device==DEVICE_TYPES[GPU]: + hooks.assign_tasks_per_compute_unit(test=self, compute_unit=COMPUTE_UNIT[GPU]) + else: + # Hybrid code, so launch 1 rank per socket. + # Probably, launching 1 task per NUMA domain is even better, but the current hook doesn't support it + hooks.assign_tasks_per_compute_unit(test=self, compute_unit=COMPUTE_UNIT[CPU_SOCKET]) + + # This is a hybrid test, binding is important for performance + hooks.set_compact_process_binding(self) + + @run_after('setup') + def set_ddp_env_vars(self): + # Set environment variables for PyTorch DDP + ### TODO: THIS WILL ONLY WORK WITH SLURM, WE SHOULD MAKE A SKIP_IF BASED ON THE SCHEDULER + if self.parallel_strategy == 'ddp': + self.prerun_cmds = [ + 'export MASTER_PORT=$(expr 10000 + $(echo -n $SLURM_JOBID | tail -c 4))', + 'export WORLD_SIZE=%s' % self.num_tasks, + 'echo "WORLD_SIZE="${WORLD_SIZE}', + 'master_addr=$(scontrol show hostnames "$SLURM_JOB_NODELIST" | head -n 1)', + 'export MASTER_ADDR=${master_addr}', + 'echo "MASTER_ADDR"=${master_addr}', + ] + + + @run_after('setup') + def filter_invalid_parameter_combinations(self): + # We cannot detect this situation before the setup phase, because it requires self.num_tasks. + # Thus, the core count of the node needs to be known, which is only the case after the setup phase. + msg=f"Skipping test: parallel strategy is 'None', but requested process count is larger than one ({self.num_tasks})" + self.skip_if(self.num_tasks > 1 and self.parallel_strategy is None, msg) + msg=f"Skipping test: parallel strategy is {self.parallel_strategy}, but only one process is requested" + self.skip_if(self.num_tasks == 1 and not self.parallel_strategy is None, msg) + + @run_after('setup') + def pass_parallel_strategy(self): + # Set parallelization strategy when using more than one process + if self.num_tasks != 1: + self.executable_opts += ['--use-%s' % self.parallel_strategy] + + @run_after('setup') + def avoid_horovod_cpu_contention(self): + # Horovod had issues with CPU performance, see https://github.com/horovod/horovod/issues/2804 + # The root cause is Horovod having two threads with very high utilization, which interferes with + # the compute threads. It was fixed, but seems to be broken again in Horovod 0.28.1 + # The easiest workaround is to reduce the number of compute threads by 2 + if self.compute_device == DEVICE_TYPES[CPU] and self.parallel_strategy == 'horovod': + self.env_vars['OMP_NUM_THREADS'] = max(self.num_cpus_per_task-2, 2) # Never go below 2 compute threads + + @sanity_function + def assert_num_ranks(self): + '''Assert that the number of reported CPUs/GPUs used is correct''' + return sn.assert_found(r'Total img/sec on %s .PU\(s\):.*' % self.num_tasks, self.stdout) + + + @performance_function('img/sec') + def total_throughput(self): + '''Total training throughput, aggregated over all CPUs/GPUs''' + return sn.extractsingle(r'Total img/sec on [0-9]+ .PU\(s\):\s+(?P\S+)', self.stdout, 'perf', float) + + @performance_function('img/sec') + def througput_per_CPU(self): + '''Training througput per CPU''' + if self.compute_device == DEVICE_TYPES[CPU]: + return sn.extractsingle(r'Img/sec per CPU:\s+(?P\S+)', self.stdout, 'perf_per_cpu', float) + else: + return sn.extractsingle(r'Img/sec per GPU:\s+(?P\S+)', self.stdout, 'perf_per_gpu', float) + +@rfm.simple_test +class PyTorch_torchvision_CPU(PyTorch_torchvision): + compute_device = DEVICE_TYPES[CPU] + + +@rfm.simple_test +class PyTorch_torchvision_GPU(PyTorch_torchvision): + compute_device = DEVICE_TYPES[GPU] + precision = parameter(['default', 'mixed']) + + @run_after('init') + def prepare_gpu_test(self): + # Set precision + if self.precision == 'mixed': + self.executable_opts += ['--use-amp'] + + @run_after('init') + def skip_hvd_plus_amp(self): + '''Skip combination of horovod and AMP, it does not work see https://github.com/horovod/horovod/issues/1417''' + if self.parallel_strategy == 'horovod' and self.precision == 'mixed': + self.valid_systems = [] + diff --git a/eessi/testsuite/tests/apps/PyTorch/src/pytorch_synthetic_benchmark.py b/eessi/testsuite/tests/apps/PyTorch/src/pytorch_synthetic_benchmark.py new file mode 100644 index 00000000..f3237e7d --- /dev/null +++ b/eessi/testsuite/tests/apps/PyTorch/src/pytorch_synthetic_benchmark.py @@ -0,0 +1,213 @@ +from __future__ import print_function + +import argparse +import torch.backends.cudnn as cudnn +import torch.nn.functional as F +import torch.optim as optim +import torch.utils.data.distributed +from torchvision import models +import timeit +import numpy as np +import os + +# Benchmark settings +parser = argparse.ArgumentParser(description='PyTorch Synthetic Benchmark', + formatter_class=argparse.ArgumentDefaultsHelpFormatter) +parser.add_argument('--fp16-allreduce', action='store_true', default=False, + help='use fp16 compression during allreduce') + +parser.add_argument('--model', type=str, default='resnet50', + help='model to benchmark') +parser.add_argument('--batch-size', type=int, default=32, + help='input batch size') + +parser.add_argument('--num-warmup-batches', type=int, default=10, + help='number of warm-up batches that don\'t count towards benchmark') +parser.add_argument('--num-batches-per-iter', type=int, default=10, + help='number of batches per benchmark iteration') +parser.add_argument('--num-iters', type=int, default=10, + help='number of benchmark iterations') + +parser.add_argument('--no-cuda', action='store_true', default=False, + help='disables CUDA training') + +parser.add_argument('--use-adasum', action='store_true', default=False, + help='use adasum algorithm to do reduction') +parser.add_argument('--use-horovod', action='store_true', default=False) +parser.add_argument('--use-ddp', action='store_true', default=False) + +parser.add_argument('--use-amp', action='store_true', default=False, + help='Use PyTorch Automatic Mixed Precision (AMP)') + +args = parser.parse_args() +args.cuda = not args.no_cuda and torch.cuda.is_available() + +if args.use_horovod and args.use_ddp: + print("You can't specify to use both Horovod and Pytorch DDP, exiting...") + exit(1) + +# Set a default rank and world size, also for when ddp and horovod are not used +rank = 0 +world_size=1 +if args.use_horovod: + import horovod.torch as hvd + hvd.init() + rank = hvd.local_rank() + world_size = hvd.size() + + if args.cuda: + # If launched with srun, you are in a CGROUP with only 1 GPU, so you don't need to set it. + # If launched with mpirun, you see ALL local GPUs on the node, and you need to set which one + # this rank should use. + visible_gpus = torch.cuda.device_count() + # Horovod: pin GPU to local rank. + if visible_gpus > 1: + torch.cuda.set_device(hvd.local_rank()) + + # Should only be uncommented for debugging + # In ReFrame tests, a print from each rank can mess up the output file, causing + # performance and sanity patterns to not be found + # print(f"hvd.local_rank: {rank}", flush=True) + + +if args.use_ddp: + import torch.distributed as dist + from torch.nn.parallel import DistributedDataParallel as DDP + from socket import gethostname + + def setup(rank, world_size): + # initialize the process group + if args.cuda: + dist.init_process_group("nccl", rank=rank, world_size=world_size) + else: + dist.init_process_group("gloo", rank=rank, world_size=world_size) + + def cleanup(): + # clean up the distributed environment + dist.destroy_process_group() + + world_size = int(os.environ["SLURM_NTASKS"]) + # If launched with mpirun, get rank from this + rank = int(os.environ.get("OMPI_COMM_WORLD_RANK", -1)) + if rank == -1: + # Else it's launched with srun, get rank from this + rank = int(os.environ["SLURM_PROCID"]) + + setup(rank, world_size) + # log(f"Group initialized? {dist.is_initialized()}", rank) + if rank == 0: print(f"Group initialized? {dist.is_initialized()}", flush=True) + + # If launched with srun, you are in a CGROUP with only 1 GPU, so you don't need to set it. + # If launched with mpirun, you see ALL local GPUs on the node, and you need to set which one + # this rank should use. + visible_gpus = torch.cuda.device_count() + if visible_gpus > 1: + local_rank = rank - visible_gpus * (rank // visible_gpus) + torch.cuda.set_device(local_rank) + print(f"host: {gethostname()}, rank: {rank}, local_rank: {local_rank}") + else: + print(f"host: {gethostname()}, rank: {rank}") + +# This relies on the 'rank' set in the if args.use_horovod or args.use_ddp sections +def log(s, nl=True): + if (args.use_horovod or args.use_ddp) and rank != 0: + return + print(s, end='\n' if nl else '', flush=True) + +log(f"World size: {world_size}") + +# Used to be needed, but now seems that different SLURM tasks run within their own cgroup +# Each cgroup only contains a single GPU, which has GPU ID 0. So no longer needed to set +# one of the ranks to GPU 0 and one to GPU 1 +#if args.cuda and args.use_horovod: +# # Horovod: pin GPU to local rank. +# torch.cuda.set_device(hvd.local_rank()) + +torch.set_num_threads(int(os.environ['OMP_NUM_THREADS'])) +torch.set_num_interop_threads(2) + +cudnn.benchmark = True + +# Set up standard model. +model = getattr(models, args.model)() + +# By default, Adasum doesn't need scaling up learning rate. +lr_scaler = hvd.size() if not args.use_adasum and args.use_horovod else 1 + +if args.cuda: + # Move model to GPU. + model.cuda() + # If using GPU Adasum allreduce, scale learning rate by local_size. + if args.use_horovod and args.use_adasum and hvd.nccl_built(): + lr_scaler = hvd.local_size() + +# If using DDP, wrap model +if args.use_ddp: + model = DDP(model) + +optimizer = optim.SGD(model.parameters(), lr=0.01 * lr_scaler) + +# Horovod: (optional) compression algorithm. +if args.use_horovod: + compression = hvd.Compression.fp16 if args.fp16_allreduce else hvd.Compression.none + +# Horovod: wrap optimizer with DistributedOptimizer. +if args.use_horovod: + optimizer = hvd.DistributedOptimizer(optimizer, + named_parameters=model.named_parameters(), + compression=compression, + op=hvd.Adasum if args.use_adasum else hvd.Average) + + # Horovod: broadcast parameters & optimizer state. + hvd.broadcast_parameters(model.state_dict(), root_rank=0) + hvd.broadcast_optimizer_state(optimizer, root_rank=0) + +# Set up fixed fake data +data = torch.randn(args.batch_size, 3, 224, 224) +target = torch.LongTensor(args.batch_size).random_() % 1000 +if args.cuda: + data, target = data.cuda(), target.cuda() + +# Create GradScaler for automatic mixed precision +scaler = torch.cuda.amp.GradScaler(enabled=args.use_amp) + +# Set device_type for AMP +if args.cuda: + device_type="cuda" +else: + device_type="cpu" + +def benchmark_step(): + optimizer.zero_grad() + with torch.autocast(device_type=device_type, enabled=args.use_amp): + output = model(data) + loss = F.cross_entropy(output, target) + scaler.scale(loss).backward() + scaler.step(optimizer) + scaler.update() + +log('Model: %s' % args.model) +log('Batch size: %d' % args.batch_size) +device = 'GPU' if args.cuda else 'CPU' +if args.use_horovod: + log('Number of %ss: %d' % (device, hvd.size())) + +# Warm-up +log('Running warmup...') +timeit.timeit(benchmark_step, number=args.num_warmup_batches) + +# Benchmark +log('Running benchmark...') +img_secs = [] +for x in range(args.num_iters): + time = timeit.timeit(benchmark_step, number=args.num_batches_per_iter) + img_sec = args.batch_size * args.num_batches_per_iter / time + log('Iter #%d: %.1f img/sec per %s' % (x, img_sec, device)) + img_secs.append(img_sec) + +# Results +img_sec_mean = np.mean(img_secs) +img_sec_conf = 1.96 * np.std(img_secs) +log('Img/sec per %s: %.1f +-%.1f' % (device, img_sec_mean, img_sec_conf)) +log('Total img/sec on %d %s(s): %.1f +-%.1f' % + (world_size, device, world_size * img_sec_mean, world_size * img_sec_conf)) From e089760e44eff8b88ede0d8bb38627392af3ac45 Mon Sep 17 00:00:00 2001 From: Caspar van Leeuwen Date: Fri, 29 Mar 2024 12:00:40 +0100 Subject: [PATCH 02/24] Add option to assign number of tasks and cpus per task based on the amount of numa nodes in a node --- eessi/testsuite/hooks.py | 62 ++++++++++++++++++++++++++++++++++++++-- 1 file changed, 60 insertions(+), 2 deletions(-) diff --git a/eessi/testsuite/hooks.py b/eessi/testsuite/hooks.py index c06ff572..883a0e4d 100644 --- a/eessi/testsuite/hooks.py +++ b/eessi/testsuite/hooks.py @@ -57,7 +57,7 @@ def _assign_default_num_gpus_per_node(test: rfm.RegressionTest): def assign_tasks_per_compute_unit(test: rfm.RegressionTest, compute_unit: str, num_per: int = 1): """ - Assign one task per compute unit (COMPUTE_UNIT[CPU], COMPUTE_UNIT[CPU_SOCKET] or COMPUTE_UNIT[GPU]). + Assign one task per compute unit. Automatically sets num_tasks, num_tasks_per_node, num_cpus_per_task, and num_gpus_per_node, based on the current scale and the current partition’s num_cpus, max_avail_gpus_per_node and num_nodes. For GPU tests, one task per GPU is set, and num_cpus_per_task is based on the ratio of CPU-cores/GPUs. @@ -109,6 +109,8 @@ def assign_tasks_per_compute_unit(test: rfm.RegressionTest, compute_unit: str, n _assign_one_task_per_cpu(test) elif compute_unit == COMPUTE_UNIT[CPU_SOCKET]: _assign_one_task_per_cpu_socket(test) + elif compute_unit == COMPUTE_UNIT[NUMA_NODE]: + _assign_one_task_per_numa_node(test) elif compute_unit == COMPUTE_UNIT[NODE]: _assign_num_tasks_per_node(test, num_per) else: @@ -175,7 +177,7 @@ def _assign_one_task_per_cpu_socket(test: rfm.RegressionTest): test.num_tasks_per_node * test.num_cpus_per_task == test.default_num_cpus_per_node. Default resources requested: - - num_tasks_per_node = default_num_cpus_per_node + - num_tasks_per_node = default_num_cpus_per_node / num_cpus_per_socket - num_cpus_per_task = default_num_cpus_per_node / num_tasks_per_node """ # neither num_tasks_per_node nor num_cpus_per_task are set @@ -206,6 +208,62 @@ def _assign_one_task_per_cpu_socket(test: rfm.RegressionTest): log(f'num_tasks set to {test.num_tasks}') +def _assign_one_task_per_numa_node(test: rfm.RegressionTest): + """ + Determines the number of tasks per node by dividing the default_num_cpus_per_node by + the number of cpus available per numa node, and rounding up. The result is that for full-node jobs the default + will spawn one task per numa node, with a number of cpus per task equal to the number of cpus per numa node. + Other examples: + - half a node (i.e. node_part=2) on a system with 4 numa nodes would result in 2 tasks per node, + with number of cpus per task equal to the number of cpus per numa node. + - a quarter node (i.e. node_part=4) on a system with 2 numa nodes would result in 1 task per node, + with number of cpus equal to half a numa node. + - 2 cores (i.e. default_num_cpus_per_node=2) on a system with 4 cores per numa node would result in + 1 task per node, with 2 cpus per task + - 8 cores (i.e. default_num_cpus_per_node=4) on a system with 4 cores per numa node would result in + 2 tasks per node, with 4 cpus per task + + This default is set unless the test is run with: + --setvar num_tasks_per_node= and/or + --setvar num_cpus_per_task=. + In those cases, those take precedence, and the remaining variable (num_cpus_per task or + num_tasks_per_node respectively) is calculated based on the equality + test.num_tasks_per_node * test.num_cpus_per_task == test.default_num_cpus_per_node. + + Default resources requested: + - num_tasks_per_node = default_num_cpus_per_node / num_cores_per_numa_node + - num_cpus_per_task = default_num_cpus_per_node / num_tasks_per_node + """ + # neither num_tasks_per_node nor num_cpus_per_task are set + if not test.num_tasks_per_node and not test.num_cpus_per_task: + # Not needed, if num_cores_per_numa_node is really defined by reframe... https://reframe-hpc.readthedocs.io/en/stable/regression_test_api.html#reframe.core.systems.ProcessorInfo + # check_proc_attribute_defined(test, 'num_cpus') + check_proc_attribute_defined(test, 'num_cores_per_numa_node') + # num_cpus_per_socket = test.current_partition.processor.num_cpus / test.current_partition.processor.num_sockets + test.num_tasks_per_node = math.ceil(test.default_num_cpus_per_node / num_cores_per_numa_node) + test.num_cpus_per_task = int(test.default_num_cpus_per_node / test.num_tasks_per_node) + + # num_tasks_per_node is not set, but num_cpus_per_task is + elif not test.num_tasks_per_node: + # check_proc_attribute_defined(test, 'num_cpus') + check_proc_attribute_defined(test, 'num_cores_per_numa_node') + # num_cpus_per_socket = test.current_partition.processor.num_cpus / test.current_partition.processor.num_sockets # Unused? + test.num_tasks_per_node = int(test.default_num_cpus_per_node / test.num_cpus_per_task) + + # num_cpus_per_task is not set, but num_tasks_per_node is + elif not test.num_cpus_per_task: + test.num_cpus_per_task = int(test.default_num_cpus_per_node / test.num_tasks_per_node) + + else: + pass # both num_tasks_per_node and num_cpus_per_node are already set + + test.num_tasks = test.num_nodes * test.num_tasks_per_node + log(f'Number of tasks per node set to: {test.num_tasks_per_node}') + log(f'Number of cpus per task set to {test.num_cpus_per_task}') + log(f'num_tasks set to {test.num_tasks}') + + + def _assign_one_task_per_cpu(test: rfm.RegressionTest): """ Sets num_tasks_per_node and num_cpus_per_task such that it will run one task per core, From a675fc976437cb533c0f1ff5d80903b16c31925c Mon Sep 17 00:00:00 2001 From: Caspar van Leeuwen Date: Fri, 29 Mar 2024 12:09:14 +0100 Subject: [PATCH 03/24] We don't need to first calculate cpus_per_socket, it is available directly from the refraem.core.systems.ProcessorInfo object. See https://reframe-hpc.readthedocs.io/en/stable/regression_test_api.html#reframe.core.systems.ProcessorInfo --- eessi/testsuite/hooks.py | 15 ++------------- 1 file changed, 2 insertions(+), 13 deletions(-) diff --git a/eessi/testsuite/hooks.py b/eessi/testsuite/hooks.py index 883a0e4d..5956ac61 100644 --- a/eessi/testsuite/hooks.py +++ b/eessi/testsuite/hooks.py @@ -182,17 +182,12 @@ def _assign_one_task_per_cpu_socket(test: rfm.RegressionTest): """ # neither num_tasks_per_node nor num_cpus_per_task are set if not test.num_tasks_per_node and not test.num_cpus_per_task: - check_proc_attribute_defined(test, 'num_cpus') - check_proc_attribute_defined(test, 'num_sockets') - num_cpus_per_socket = test.current_partition.processor.num_cpus / test.current_partition.processor.num_sockets - test.num_tasks_per_node = math.ceil(test.default_num_cpus_per_node / num_cpus_per_socket) + check_proc_attribute_defined(test, 'num_cores_per_socket') + test.num_tasks_per_node = math.ceil(test.default_num_cpus_per_node / test.current_partition.processor.num_cores_per_socket) test.num_cpus_per_task = int(test.default_num_cpus_per_node / test.num_tasks_per_node) # num_tasks_per_node is not set, but num_cpus_per_task is elif not test.num_tasks_per_node: - check_proc_attribute_defined(test, 'num_cpus') - check_proc_attribute_defined(test, 'num_sockets') - num_cpus_per_socket = test.current_partition.processor.num_cpus / test.current_partition.processor.num_sockets test.num_tasks_per_node = int(test.default_num_cpus_per_node / test.num_cpus_per_task) # num_cpus_per_task is not set, but num_tasks_per_node is @@ -236,18 +231,12 @@ def _assign_one_task_per_numa_node(test: rfm.RegressionTest): """ # neither num_tasks_per_node nor num_cpus_per_task are set if not test.num_tasks_per_node and not test.num_cpus_per_task: - # Not needed, if num_cores_per_numa_node is really defined by reframe... https://reframe-hpc.readthedocs.io/en/stable/regression_test_api.html#reframe.core.systems.ProcessorInfo - # check_proc_attribute_defined(test, 'num_cpus') check_proc_attribute_defined(test, 'num_cores_per_numa_node') - # num_cpus_per_socket = test.current_partition.processor.num_cpus / test.current_partition.processor.num_sockets test.num_tasks_per_node = math.ceil(test.default_num_cpus_per_node / num_cores_per_numa_node) test.num_cpus_per_task = int(test.default_num_cpus_per_node / test.num_tasks_per_node) # num_tasks_per_node is not set, but num_cpus_per_task is elif not test.num_tasks_per_node: - # check_proc_attribute_defined(test, 'num_cpus') - check_proc_attribute_defined(test, 'num_cores_per_numa_node') - # num_cpus_per_socket = test.current_partition.processor.num_cpus / test.current_partition.processor.num_sockets # Unused? test.num_tasks_per_node = int(test.default_num_cpus_per_node / test.num_cpus_per_task) # num_cpus_per_task is not set, but num_tasks_per_node is From a6e53bc0124b7163dfede9d430c4d62046e4670b Mon Sep 17 00:00:00 2001 From: Caspar van Leeuwen Date: Fri, 29 Mar 2024 12:09:37 +0100 Subject: [PATCH 04/24] Define constant for numa node as a compute unit --- eessi/testsuite/constants.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/eessi/testsuite/constants.py b/eessi/testsuite/constants.py index 9b7d6ac3..5254f710 100644 --- a/eessi/testsuite/constants.py +++ b/eessi/testsuite/constants.py @@ -6,6 +6,7 @@ CI = 'CI' CPU = 'CPU' CPU_SOCKET = 'CPU_SOCKET' +NUMA_NODE = 'NUMA_NODE' GPU = 'GPU' GPU_VENDOR = 'GPU_VENDOR' INTEL = 'INTEL' @@ -21,6 +22,7 @@ COMPUTE_UNIT = { CPU: 'cpu', CPU_SOCKET: 'cpu_socket', + NUMA_NODE: 'numa_node', GPU: 'gpu', NODE: 'node', } From 8bb9bbd8bd631b387be3566b7b902138d495bc10 Mon Sep 17 00:00:00 2001 From: Caspar van Leeuwen Date: Fri, 29 Mar 2024 12:26:36 +0100 Subject: [PATCH 05/24] Fix num_cores_per_numa_node, use it from the current_partition info --- eessi/testsuite/hooks.py | 2 +- .../tests/apps/PyTorch/src/python_get_free_socket.py | 8 ++++++++ 2 files changed, 9 insertions(+), 1 deletion(-) create mode 100644 eessi/testsuite/tests/apps/PyTorch/src/python_get_free_socket.py diff --git a/eessi/testsuite/hooks.py b/eessi/testsuite/hooks.py index 5956ac61..b9a3dc93 100644 --- a/eessi/testsuite/hooks.py +++ b/eessi/testsuite/hooks.py @@ -232,7 +232,7 @@ def _assign_one_task_per_numa_node(test: rfm.RegressionTest): # neither num_tasks_per_node nor num_cpus_per_task are set if not test.num_tasks_per_node and not test.num_cpus_per_task: check_proc_attribute_defined(test, 'num_cores_per_numa_node') - test.num_tasks_per_node = math.ceil(test.default_num_cpus_per_node / num_cores_per_numa_node) + test.num_tasks_per_node = math.ceil(test.default_num_cpus_per_node / test.current_partition.processor.num_cores_per_numa_node) test.num_cpus_per_task = int(test.default_num_cpus_per_node / test.num_tasks_per_node) # num_tasks_per_node is not set, but num_cpus_per_task is diff --git a/eessi/testsuite/tests/apps/PyTorch/src/python_get_free_socket.py b/eessi/testsuite/tests/apps/PyTorch/src/python_get_free_socket.py new file mode 100644 index 00000000..a2981304 --- /dev/null +++ b/eessi/testsuite/tests/apps/PyTorch/src/python_get_free_socket.py @@ -0,0 +1,8 @@ +# Based on https://unix.stackexchange.com/a/132524 +import socket + +s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) +s.bind(('', 0)) +addr = s.getsockname() +print(addr[1]) +s.close() From a6bf34d3ad79a2f4734cdede08ac9cca12562944 Mon Sep 17 00:00:00 2001 From: Caspar van Leeuwen Date: Fri, 29 Mar 2024 12:28:19 +0100 Subject: [PATCH 06/24] Work around the issue of not being able to export varialbes in a launcher-agnostic way, by simply passing them to the python script as argument, and having that set it to the environment. Print clear error if SLURM or openMPIs mpirun are not used - we still rely on these to get the local rank, there is no other way --- .../tests/apps/PyTorch/PyTorch_torchvision.py | 19 ++++------- .../src/pytorch_synthetic_benchmark.py | 34 ++++++++++++++----- 2 files changed, 33 insertions(+), 20 deletions(-) diff --git a/eessi/testsuite/tests/apps/PyTorch/PyTorch_torchvision.py b/eessi/testsuite/tests/apps/PyTorch/PyTorch_torchvision.py index 91d0a708..6367c63e 100644 --- a/eessi/testsuite/tests/apps/PyTorch/PyTorch_torchvision.py +++ b/eessi/testsuite/tests/apps/PyTorch/PyTorch_torchvision.py @@ -2,7 +2,7 @@ import reframe.utility.sanity as sn from eessi.testsuite import hooks -from eessi.testsuite.constants import SCALES, TAGS, DEVICE_TYPES, COMPUTE_UNIT, CPU, CPU_SOCKET, GPU +from eessi.testsuite.constants import SCALES, TAGS, DEVICE_TYPES, COMPUTE_UNIT, CPU, NUMA_NODE, GPU, INVALID_SYSTEM from eessi.testsuite.utils import find_modules, log class PyTorch_torchvision(rfm.RunOnlyRegressionTest): @@ -66,7 +66,7 @@ def apply_setup_hooks(self): else: # Hybrid code, so launch 1 rank per socket. # Probably, launching 1 task per NUMA domain is even better, but the current hook doesn't support it - hooks.assign_tasks_per_compute_unit(test=self, compute_unit=COMPUTE_UNIT[CPU_SOCKET]) + hooks.assign_tasks_per_compute_unit(test=self, compute_unit=COMPUTE_UNIT[NUMA_NODE]) # This is a hybrid test, binding is important for performance hooks.set_compact_process_binding(self) @@ -76,15 +76,10 @@ def set_ddp_env_vars(self): # Set environment variables for PyTorch DDP ### TODO: THIS WILL ONLY WORK WITH SLURM, WE SHOULD MAKE A SKIP_IF BASED ON THE SCHEDULER if self.parallel_strategy == 'ddp': - self.prerun_cmds = [ - 'export MASTER_PORT=$(expr 10000 + $(echo -n $SLURM_JOBID | tail -c 4))', - 'export WORLD_SIZE=%s' % self.num_tasks, - 'echo "WORLD_SIZE="${WORLD_SIZE}', - 'master_addr=$(scontrol show hostnames "$SLURM_JOB_NODELIST" | head -n 1)', - 'export MASTER_ADDR=${master_addr}', - 'echo "MASTER_ADDR"=${master_addr}', - ] - + # Set additional options required by DDP + self.executable_opts += ["--master-port $(python python_get_free_socket.py)"] + self.executable_opts += ["--master-address $(hostname --fqdn)"] + self.executable_opts += ["--world-size %s" % self.num_tasks] @run_after('setup') def filter_invalid_parameter_combinations(self): @@ -149,5 +144,5 @@ def prepare_gpu_test(self): def skip_hvd_plus_amp(self): '''Skip combination of horovod and AMP, it does not work see https://github.com/horovod/horovod/issues/1417''' if self.parallel_strategy == 'horovod' and self.precision == 'mixed': - self.valid_systems = [] + self.valid_systems = [INVALID_SYSTEM] diff --git a/eessi/testsuite/tests/apps/PyTorch/src/pytorch_synthetic_benchmark.py b/eessi/testsuite/tests/apps/PyTorch/src/pytorch_synthetic_benchmark.py index f3237e7d..4c0db3be 100644 --- a/eessi/testsuite/tests/apps/PyTorch/src/pytorch_synthetic_benchmark.py +++ b/eessi/testsuite/tests/apps/PyTorch/src/pytorch_synthetic_benchmark.py @@ -1,14 +1,15 @@ -from __future__ import print_function - import argparse +import timeit +import os +import random + +import numpy as np + import torch.backends.cudnn as cudnn import torch.nn.functional as F import torch.optim as optim import torch.utils.data.distributed from torchvision import models -import timeit -import numpy as np -import os # Benchmark settings parser = argparse.ArgumentParser(description='PyTorch Synthetic Benchmark', @@ -38,6 +39,12 @@ parser.add_argument('--use-amp', action='store_true', default=False, help='Use PyTorch Automatic Mixed Precision (AMP)') +parser.add_argument('--world-size', type=int, default=1, + help='Define the world size for ddp') +parser.add_argument('--master-port', type=int, default=False, + help='Define a master port for ddp') +parser.add_argument('--master-address', type=str, default='localhost', + help='Define a master address for ddp') args = parser.parse_args() args.cuda = not args.no_cuda and torch.cuda.is_available() @@ -46,9 +53,15 @@ print("You can't specify to use both Horovod and Pytorch DDP, exiting...") exit(1) +# Set MASTER_ADDR and MASTER_PORT environment variables +# By doing it as part of this python script, we don't need to have the launchers export them +# This saves us from having to find a launcher-agnostic way of exporting variables +os.environ['MASTER_ADDR'] = args.master_address +os.environ['MASTER_PORT'] = '%s' % args.master_port + # Set a default rank and world size, also for when ddp and horovod are not used rank = 0 -world_size=1 +world_size = args.world_size if args.use_horovod: import horovod.torch as hvd hvd.init() @@ -86,12 +99,17 @@ def cleanup(): # clean up the distributed environment dist.destroy_process_group() - world_size = int(os.environ["SLURM_NTASKS"]) + # world_size = int(os.environ["SLURM_NTASKS"]) ## No longer needed now we pass it as argument? # If launched with mpirun, get rank from this rank = int(os.environ.get("OMPI_COMM_WORLD_RANK", -1)) if rank == -1: # Else it's launched with srun, get rank from this - rank = int(os.environ["SLURM_PROCID"]) + rank = int(os.environ.get("SLURM_PROCID", -1)) + if rank == -1: + err_msg = "ERROR: cannot determine local rank. This test currently only supports OpenMPI" + err_msg += " and srun as launchers. If you've configured a different launcher for your system" + err_msg += " this test will need to be extended with a method to get it's local rank for that launcher." + print(err_msg) setup(rank, world_size) # log(f"Group initialized? {dist.is_initialized()}", rank) From fce2a4518dbe43591abc5ec6d3992ffd3dac20bb Mon Sep 17 00:00:00 2001 From: Caspar van Leeuwen Date: Tue, 9 Apr 2024 10:50:47 +0200 Subject: [PATCH 07/24] Change order of imports so that initialization only happens after required environment variables have been set. --- .../src/pytorch_synthetic_benchmark.py | 46 ++++++++++++------- 1 file changed, 30 insertions(+), 16 deletions(-) diff --git a/eessi/testsuite/tests/apps/PyTorch/src/pytorch_synthetic_benchmark.py b/eessi/testsuite/tests/apps/PyTorch/src/pytorch_synthetic_benchmark.py index 4c0db3be..78ba43df 100644 --- a/eessi/testsuite/tests/apps/PyTorch/src/pytorch_synthetic_benchmark.py +++ b/eessi/testsuite/tests/apps/PyTorch/src/pytorch_synthetic_benchmark.py @@ -11,6 +11,7 @@ import torch.utils.data.distributed from torchvision import models + # Benchmark settings parser = argparse.ArgumentParser(description='PyTorch Synthetic Benchmark', formatter_class=argparse.ArgumentDefaultsHelpFormatter) @@ -84,20 +85,9 @@ if args.use_ddp: + from socket import gethostname import torch.distributed as dist from torch.nn.parallel import DistributedDataParallel as DDP - from socket import gethostname - - def setup(rank, world_size): - # initialize the process group - if args.cuda: - dist.init_process_group("nccl", rank=rank, world_size=world_size) - else: - dist.init_process_group("gloo", rank=rank, world_size=world_size) - - def cleanup(): - # clean up the distributed environment - dist.destroy_process_group() # world_size = int(os.environ["SLURM_NTASKS"]) ## No longer needed now we pass it as argument? # If launched with mpirun, get rank from this @@ -110,22 +100,46 @@ def cleanup(): err_msg += " and srun as launchers. If you've configured a different launcher for your system" err_msg += " this test will need to be extended with a method to get it's local rank for that launcher." print(err_msg) - - setup(rank, world_size) - # log(f"Group initialized? {dist.is_initialized()}", rank) - if rank == 0: print(f"Group initialized? {dist.is_initialized()}", flush=True) # If launched with srun, you are in a CGROUP with only 1 GPU, so you don't need to set it. # If launched with mpirun, you see ALL local GPUs on the node, and you need to set which one # this rank should use. visible_gpus = torch.cuda.device_count() if visible_gpus > 1: + print("Listing visible devices") + for i in range(torch.cuda.device_count()): + print(f"Device {i}: {torch.cuda.device(i)}") local_rank = rank - visible_gpus * (rank // visible_gpus) torch.cuda.set_device(local_rank) + print("Listing visible devices after setting one") + for i in range(torch.cuda.device_count()): + print(f"Device {i}: {torch.cuda.device(i)}") + # We should also set CUDA_VISIBLE_DEVICES, which gets respected by NCCL + os.environ['CUDA_VISIBLE_DEVICES'] = '%s' % local_rank print(f"host: {gethostname()}, rank: {rank}, local_rank: {local_rank}") else: print(f"host: {gethostname()}, rank: {rank}") + + def setup(rank, world_size): + + # initialize the process group + if args.cuda: + dist.init_process_group("nccl", rank=rank, world_size=world_size) + else: + dist.init_process_group("gloo", rank=rank, world_size=world_size) + + def cleanup(): + # clean up the distributed environment + dist.destroy_process_group() + + setup(rank, world_size) + # log(f"Group initialized? {dist.is_initialized()}", rank) + if rank == 0: print(f"Group initialized? {dist.is_initialized()}", flush=True) + + + + # This relies on the 'rank' set in the if args.use_horovod or args.use_ddp sections def log(s, nl=True): if (args.use_horovod or args.use_ddp) and rank != 0: From b868cc14b4eb2be2ee721cec531baef4183eeec0 Mon Sep 17 00:00:00 2001 From: Caspar van Leeuwen Date: Fri, 3 May 2024 15:45:36 +0200 Subject: [PATCH 08/24] Add comment on explicit assumption for computing the local rank --- .../tests/apps/PyTorch/src/pytorch_synthetic_benchmark.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/eessi/testsuite/tests/apps/PyTorch/src/pytorch_synthetic_benchmark.py b/eessi/testsuite/tests/apps/PyTorch/src/pytorch_synthetic_benchmark.py index 78ba43df..c349b089 100644 --- a/eessi/testsuite/tests/apps/PyTorch/src/pytorch_synthetic_benchmark.py +++ b/eessi/testsuite/tests/apps/PyTorch/src/pytorch_synthetic_benchmark.py @@ -109,6 +109,10 @@ print("Listing visible devices") for i in range(torch.cuda.device_count()): print(f"Device {i}: {torch.cuda.device(i)}") + # This assumes compact mapping of ranks to available hardware + # e.g. rank 0-x to node 1, rank x-y to node 2, etc + # Assuming the set_compact_process_binding hook from the EESSI testsuite is called, + # this condition should be satisfied local_rank = rank - visible_gpus * (rank // visible_gpus) torch.cuda.set_device(local_rank) print("Listing visible devices after setting one") From fed890617e3ed556f23b72b80e0e11246a88d4f0 Mon Sep 17 00:00:00 2001 From: Caspar van Leeuwen Date: Mon, 6 May 2024 11:04:35 +0200 Subject: [PATCH 09/24] Use EESSI prefix to name test --- eessi/testsuite/tests/apps/PyTorch/PyTorch_torchvision.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/eessi/testsuite/tests/apps/PyTorch/PyTorch_torchvision.py b/eessi/testsuite/tests/apps/PyTorch/PyTorch_torchvision.py index 6367c63e..ef5a854c 100644 --- a/eessi/testsuite/tests/apps/PyTorch/PyTorch_torchvision.py +++ b/eessi/testsuite/tests/apps/PyTorch/PyTorch_torchvision.py @@ -5,7 +5,7 @@ from eessi.testsuite.constants import SCALES, TAGS, DEVICE_TYPES, COMPUTE_UNIT, CPU, NUMA_NODE, GPU, INVALID_SYSTEM from eessi.testsuite.utils import find_modules, log -class PyTorch_torchvision(rfm.RunOnlyRegressionTest): +class EESSI_PyTorch_torchvision(rfm.RunOnlyRegressionTest): nn_model = parameter(['vgg16', 'resnet50', 'resnet152', 'densenet121', 'mobilenet_v3_large']) ### SHOULD BE DETERMINED BY SCALE #n_processes = parameter([1, 2, 4, 8, 16]) From 357f649296361339a127025cfbd0efc071d8058e Mon Sep 17 00:00:00 2001 From: Caspar van Leeuwen Date: Mon, 6 May 2024 11:08:50 +0200 Subject: [PATCH 10/24] Child classes should also be renamed and inherit from renamed class --- eessi/testsuite/tests/apps/PyTorch/PyTorch_torchvision.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/eessi/testsuite/tests/apps/PyTorch/PyTorch_torchvision.py b/eessi/testsuite/tests/apps/PyTorch/PyTorch_torchvision.py index ef5a854c..f7e4286b 100644 --- a/eessi/testsuite/tests/apps/PyTorch/PyTorch_torchvision.py +++ b/eessi/testsuite/tests/apps/PyTorch/PyTorch_torchvision.py @@ -125,12 +125,12 @@ def througput_per_CPU(self): return sn.extractsingle(r'Img/sec per GPU:\s+(?P\S+)', self.stdout, 'perf_per_gpu', float) @rfm.simple_test -class PyTorch_torchvision_CPU(PyTorch_torchvision): +class EESSI_PyTorch_torchvision_CPU(EESSI_PyTorch_torchvision): compute_device = DEVICE_TYPES[CPU] @rfm.simple_test -class PyTorch_torchvision_GPU(PyTorch_torchvision): +class EESSI_PyTorch_torchvision_GPU(EESSI_PyTorch_torchvision): compute_device = DEVICE_TYPES[GPU] precision = parameter(['default', 'mixed']) From 6b1e36ae63bdc49c3812ae5bac81ac063547da97 Mon Sep 17 00:00:00 2001 From: Caspar van Leeuwen Date: Mon, 6 May 2024 14:44:56 +0200 Subject: [PATCH 11/24] Remove stray blank line --- eessi/testsuite/hooks.py | 1 - 1 file changed, 1 deletion(-) diff --git a/eessi/testsuite/hooks.py b/eessi/testsuite/hooks.py index 1043a425..73306286 100644 --- a/eessi/testsuite/hooks.py +++ b/eessi/testsuite/hooks.py @@ -252,7 +252,6 @@ def _assign_one_task_per_numa_node(test: rfm.RegressionTest): log(f'num_tasks set to {test.num_tasks}') - def _assign_one_task_per_cpu(test: rfm.RegressionTest): """ Sets num_tasks_per_node and num_cpus_per_task such that it will run one task per core, From 2d3314174ab51d1867d1815edfcd2cc2cd0d9822 Mon Sep 17 00:00:00 2001 From: Caspar van Leeuwen Date: Mon, 6 May 2024 15:29:58 +0200 Subject: [PATCH 12/24] Rephrased comment, some changes to make the linter happy --- .../tests/apps/PyTorch/PyTorch_torchvision.py | 17 ++++++----------- 1 file changed, 6 insertions(+), 11 deletions(-) diff --git a/eessi/testsuite/tests/apps/PyTorch/PyTorch_torchvision.py b/eessi/testsuite/tests/apps/PyTorch/PyTorch_torchvision.py index f7e4286b..8011cd0d 100644 --- a/eessi/testsuite/tests/apps/PyTorch/PyTorch_torchvision.py +++ b/eessi/testsuite/tests/apps/PyTorch/PyTorch_torchvision.py @@ -1,21 +1,19 @@ import reframe as rfm import reframe.utility.sanity as sn +from reframe.core.builtins import parameter, run_after # added only to make the linter happy from eessi.testsuite import hooks from eessi.testsuite.constants import SCALES, TAGS, DEVICE_TYPES, COMPUTE_UNIT, CPU, NUMA_NODE, GPU, INVALID_SYSTEM -from eessi.testsuite.utils import find_modules, log +from eessi.testsuite.utils import find_modules + class EESSI_PyTorch_torchvision(rfm.RunOnlyRegressionTest): nn_model = parameter(['vgg16', 'resnet50', 'resnet152', 'densenet121', 'mobilenet_v3_large']) - ### SHOULD BE DETERMINED BY SCALE - #n_processes = parameter([1, 2, 4, 8, 16]) scale = parameter(SCALES.keys()) - # Not sure how we would ensure the horovod module is _also_ loaded... - # parallel_strategy = parameter([None, 'horovod', 'ddp']) parallel_strategy = parameter([None, 'ddp']) compute_device = variable(str) - # module_name = parameter(find_modules('PyTorch-bundle')) - module_name = parameter(find_modules('torchvision')) + # Both torchvision and PyTorch-bundle modules have everything needed to run this test + module_name = parameter(find_modules('torchvision') + find_modules('PyTorch-bundle')) descr = 'Benchmark that runs a selected torchvision model on synthetic data' @@ -36,8 +34,6 @@ def prepare_test(self): if self.compute_device != DEVICE_TYPES[GPU]: self.executable_opts += ['--no-cuda'] - - @run_after('init') def apply_init_hooks(self): # Filter on which scales are supported by the partitions defined in the ReFrame configuration @@ -74,7 +70,6 @@ def apply_setup_hooks(self): @run_after('setup') def set_ddp_env_vars(self): # Set environment variables for PyTorch DDP - ### TODO: THIS WILL ONLY WORK WITH SLURM, WE SHOULD MAKE A SKIP_IF BASED ON THE SCHEDULER if self.parallel_strategy == 'ddp': # Set additional options required by DDP self.executable_opts += ["--master-port $(python python_get_free_socket.py)"] @@ -109,7 +104,6 @@ def avoid_horovod_cpu_contention(self): def assert_num_ranks(self): '''Assert that the number of reported CPUs/GPUs used is correct''' return sn.assert_found(r'Total img/sec on %s .PU\(s\):.*' % self.num_tasks, self.stdout) - @performance_function('img/sec') def total_throughput(self): @@ -124,6 +118,7 @@ def througput_per_CPU(self): else: return sn.extractsingle(r'Img/sec per GPU:\s+(?P\S+)', self.stdout, 'perf_per_gpu', float) + @rfm.simple_test class EESSI_PyTorch_torchvision_CPU(EESSI_PyTorch_torchvision): compute_device = DEVICE_TYPES[CPU] From 4cb7b360ceeed5de4859328f4850a63b9724ebe3 Mon Sep 17 00:00:00 2001 From: Caspar van Leeuwen Date: Mon, 6 May 2024 16:25:00 +0200 Subject: [PATCH 13/24] Fix some more linter issues --- eessi/testsuite/tests/apps/PyTorch/PyTorch_torchvision.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/eessi/testsuite/tests/apps/PyTorch/PyTorch_torchvision.py b/eessi/testsuite/tests/apps/PyTorch/PyTorch_torchvision.py index 8011cd0d..b2aa36bf 100644 --- a/eessi/testsuite/tests/apps/PyTorch/PyTorch_torchvision.py +++ b/eessi/testsuite/tests/apps/PyTorch/PyTorch_torchvision.py @@ -1,6 +1,6 @@ import reframe as rfm import reframe.utility.sanity as sn -from reframe.core.builtins import parameter, run_after # added only to make the linter happy +from reframe.core.builtins import parameter, variable, run_after, sanity_function, performance_function # added only to make the linter happy from eessi.testsuite import hooks from eessi.testsuite.constants import SCALES, TAGS, DEVICE_TYPES, COMPUTE_UNIT, CPU, NUMA_NODE, GPU, INVALID_SYSTEM From 2f0bea270f5c520b9e6ba92f47d767b8f7842002 Mon Sep 17 00:00:00 2001 From: Caspar van Leeuwen Date: Mon, 6 May 2024 16:39:27 +0200 Subject: [PATCH 14/24] Fix some more linter issues --- .../tests/apps/PyTorch/PyTorch_torchvision.py | 13 +++++++------ .../apps/PyTorch/src/pytorch_synthetic_benchmark.py | 13 ++++--------- 2 files changed, 11 insertions(+), 15 deletions(-) diff --git a/eessi/testsuite/tests/apps/PyTorch/PyTorch_torchvision.py b/eessi/testsuite/tests/apps/PyTorch/PyTorch_torchvision.py index b2aa36bf..44cdd7f4 100644 --- a/eessi/testsuite/tests/apps/PyTorch/PyTorch_torchvision.py +++ b/eessi/testsuite/tests/apps/PyTorch/PyTorch_torchvision.py @@ -53,7 +53,7 @@ def apply_init_hooks(self): @run_after('init') def set_tag_ci(self): if self.nn_model == 'resnet50': - self.tags.add(TAGS['CI']) + self.tags.add(TAGS['CI']) @run_after('setup') def apply_setup_hooks(self): @@ -80,10 +80,12 @@ def set_ddp_env_vars(self): def filter_invalid_parameter_combinations(self): # We cannot detect this situation before the setup phase, because it requires self.num_tasks. # Thus, the core count of the node needs to be known, which is only the case after the setup phase. - msg=f"Skipping test: parallel strategy is 'None', but requested process count is larger than one ({self.num_tasks})" + msg = f"Skipping test: parallel strategy is 'None'," + msg += f" but requested process count is larger than one ({self.num_tasks})." self.skip_if(self.num_tasks > 1 and self.parallel_strategy is None, msg) - msg=f"Skipping test: parallel strategy is {self.parallel_strategy}, but only one process is requested" - self.skip_if(self.num_tasks == 1 and not self.parallel_strategy is None, msg) + msg = f"Skipping test: parallel strategy is {self.parallel_strategy}," + msg += f" but only one process is requested." + self.skip_if(self.num_tasks == 1 and self.parallel_strategy is not None, msg) @run_after('setup') def pass_parallel_strategy(self): @@ -98,7 +100,7 @@ def avoid_horovod_cpu_contention(self): # the compute threads. It was fixed, but seems to be broken again in Horovod 0.28.1 # The easiest workaround is to reduce the number of compute threads by 2 if self.compute_device == DEVICE_TYPES[CPU] and self.parallel_strategy == 'horovod': - self.env_vars['OMP_NUM_THREADS'] = max(self.num_cpus_per_task-2, 2) # Never go below 2 compute threads + self.env_vars['OMP_NUM_THREADS'] = max(self.num_cpus_per_task - 2, 2) # Never go below 2 compute threads @sanity_function def assert_num_ranks(self): @@ -140,4 +142,3 @@ def skip_hvd_plus_amp(self): '''Skip combination of horovod and AMP, it does not work see https://github.com/horovod/horovod/issues/1417''' if self.parallel_strategy == 'horovod' and self.precision == 'mixed': self.valid_systems = [INVALID_SYSTEM] - diff --git a/eessi/testsuite/tests/apps/PyTorch/src/pytorch_synthetic_benchmark.py b/eessi/testsuite/tests/apps/PyTorch/src/pytorch_synthetic_benchmark.py index c349b089..b414d549 100644 --- a/eessi/testsuite/tests/apps/PyTorch/src/pytorch_synthetic_benchmark.py +++ b/eessi/testsuite/tests/apps/PyTorch/src/pytorch_synthetic_benchmark.py @@ -1,7 +1,6 @@ import argparse import timeit import os -import random import numpy as np @@ -124,7 +123,6 @@ else: print(f"host: {gethostname()}, rank: {rank}") - def setup(rank, world_size): # initialize the process group @@ -132,17 +130,14 @@ def setup(rank, world_size): dist.init_process_group("nccl", rank=rank, world_size=world_size) else: dist.init_process_group("gloo", rank=rank, world_size=world_size) - + def cleanup(): # clean up the distributed environment dist.destroy_process_group() - - setup(rank, world_size) - # log(f"Group initialized? {dist.is_initialized()}", rank) - if rank == 0: print(f"Group initialized? {dist.is_initialized()}", flush=True) - - + setup(rank, world_size) + if rank == 0: + print(f"Group initialized? {dist.is_initialized()}", flush=True) # This relies on the 'rank' set in the if args.use_horovod or args.use_ddp sections def log(s, nl=True): From fc067b218a02fe3fdbf700fdec16287dfec882bd Mon Sep 17 00:00:00 2001 From: Caspar van Leeuwen Date: Mon, 6 May 2024 16:51:23 +0200 Subject: [PATCH 15/24] Fix linter issues --- eessi/testsuite/hooks.py | 8 ++++++-- .../PyTorch/src/pytorch_synthetic_benchmark.py | 15 ++++++--------- 2 files changed, 12 insertions(+), 11 deletions(-) diff --git a/eessi/testsuite/hooks.py b/eessi/testsuite/hooks.py index 73306286..4bcdf6e6 100644 --- a/eessi/testsuite/hooks.py +++ b/eessi/testsuite/hooks.py @@ -183,7 +183,9 @@ def _assign_one_task_per_cpu_socket(test: rfm.RegressionTest): # neither num_tasks_per_node nor num_cpus_per_task are set if not test.num_tasks_per_node and not test.num_cpus_per_task: check_proc_attribute_defined(test, 'num_cores_per_socket') - test.num_tasks_per_node = math.ceil(test.default_num_cpus_per_node / test.current_partition.processor.num_cores_per_socket) + test.num_tasks_per_node = math.ceil( + test.default_num_cpus_per_node / test.current_partition.processor.num_cores_per_socket + ) test.num_cpus_per_task = int(test.default_num_cpus_per_node / test.num_tasks_per_node) # num_tasks_per_node is not set, but num_cpus_per_task is @@ -232,7 +234,9 @@ def _assign_one_task_per_numa_node(test: rfm.RegressionTest): # neither num_tasks_per_node nor num_cpus_per_task are set if not test.num_tasks_per_node and not test.num_cpus_per_task: check_proc_attribute_defined(test, 'num_cores_per_numa_node') - test.num_tasks_per_node = math.ceil(test.default_num_cpus_per_node / test.current_partition.processor.num_cores_per_numa_node) + test.num_tasks_per_node = math.ceil( + test.default_num_cpus_per_node / test.current_partition.processor.num_cores_per_numa_node + ) test.num_cpus_per_task = int(test.default_num_cpus_per_node / test.num_tasks_per_node) # num_tasks_per_node is not set, but num_cpus_per_task is diff --git a/eessi/testsuite/tests/apps/PyTorch/src/pytorch_synthetic_benchmark.py b/eessi/testsuite/tests/apps/PyTorch/src/pytorch_synthetic_benchmark.py index b414d549..373790b5 100644 --- a/eessi/testsuite/tests/apps/PyTorch/src/pytorch_synthetic_benchmark.py +++ b/eessi/testsuite/tests/apps/PyTorch/src/pytorch_synthetic_benchmark.py @@ -139,20 +139,15 @@ def cleanup(): if rank == 0: print(f"Group initialized? {dist.is_initialized()}", flush=True) + # This relies on the 'rank' set in the if args.use_horovod or args.use_ddp sections def log(s, nl=True): if (args.use_horovod or args.use_ddp) and rank != 0: return print(s, end='\n' if nl else '', flush=True) -log(f"World size: {world_size}") -# Used to be needed, but now seems that different SLURM tasks run within their own cgroup -# Each cgroup only contains a single GPU, which has GPU ID 0. So no longer needed to set -# one of the ranks to GPU 0 and one to GPU 1 -#if args.cuda and args.use_horovod: -# # Horovod: pin GPU to local rank. -# torch.cuda.set_device(hvd.local_rank()) +log(f"World size: {world_size}") torch.set_num_threads(int(os.environ['OMP_NUM_THREADS'])) torch.set_num_interop_threads(2) @@ -204,9 +199,10 @@ def log(s, nl=True): # Set device_type for AMP if args.cuda: - device_type="cuda" + device_type = "cuda" else: - device_type="cpu" + device_type = "cpu" + def benchmark_step(): optimizer.zero_grad() @@ -217,6 +213,7 @@ def benchmark_step(): scaler.step(optimizer) scaler.update() + log('Model: %s' % args.model) log('Batch size: %d' % args.batch_size) device = 'GPU' if args.cuda else 'CPU' From 4ddfe23d9c844f4b5880278118913ee5bbe454b1 Mon Sep 17 00:00:00 2001 From: Caspar van Leeuwen Date: Mon, 6 May 2024 16:52:25 +0200 Subject: [PATCH 16/24] Fix linter issues --- .../testsuite/tests/apps/PyTorch/PyTorch_torchvision.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/eessi/testsuite/tests/apps/PyTorch/PyTorch_torchvision.py b/eessi/testsuite/tests/apps/PyTorch/PyTorch_torchvision.py index 44cdd7f4..68223cef 100644 --- a/eessi/testsuite/tests/apps/PyTorch/PyTorch_torchvision.py +++ b/eessi/testsuite/tests/apps/PyTorch/PyTorch_torchvision.py @@ -1,6 +1,7 @@ import reframe as rfm import reframe.utility.sanity as sn -from reframe.core.builtins import parameter, variable, run_after, sanity_function, performance_function # added only to make the linter happy +# Added only to make the linter happy +from reframe.core.builtins import parameter, variable, run_after, sanity_function, performance_function from eessi.testsuite import hooks from eessi.testsuite.constants import SCALES, TAGS, DEVICE_TYPES, COMPUTE_UNIT, CPU, NUMA_NODE, GPU, INVALID_SYSTEM @@ -57,7 +58,7 @@ def set_tag_ci(self): @run_after('setup') def apply_setup_hooks(self): - if self.compute_device==DEVICE_TYPES[GPU]: + if self.compute_device == DEVICE_TYPES[GPU]: hooks.assign_tasks_per_compute_unit(test=self, compute_unit=COMPUTE_UNIT[GPU]) else: # Hybrid code, so launch 1 rank per socket. @@ -80,11 +81,11 @@ def set_ddp_env_vars(self): def filter_invalid_parameter_combinations(self): # We cannot detect this situation before the setup phase, because it requires self.num_tasks. # Thus, the core count of the node needs to be known, which is only the case after the setup phase. - msg = f"Skipping test: parallel strategy is 'None'," + msg = "Skipping test: parallel strategy is 'None'," msg += f" but requested process count is larger than one ({self.num_tasks})." self.skip_if(self.num_tasks > 1 and self.parallel_strategy is None, msg) msg = f"Skipping test: parallel strategy is {self.parallel_strategy}," - msg += f" but only one process is requested." + msg += " but only one process is requested." self.skip_if(self.num_tasks == 1 and self.parallel_strategy is not None, msg) @run_after('setup') From 73b7e846009a47436562fe5c5cbb230bf287f71c Mon Sep 17 00:00:00 2001 From: Caspar van Leeuwen Date: Tue, 7 May 2024 11:24:30 +0200 Subject: [PATCH 17/24] Can't combine generators with plus, so use chain --- eessi/testsuite/tests/apps/PyTorch/PyTorch_torchvision.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/eessi/testsuite/tests/apps/PyTorch/PyTorch_torchvision.py b/eessi/testsuite/tests/apps/PyTorch/PyTorch_torchvision.py index 68223cef..575527cf 100644 --- a/eessi/testsuite/tests/apps/PyTorch/PyTorch_torchvision.py +++ b/eessi/testsuite/tests/apps/PyTorch/PyTorch_torchvision.py @@ -1,3 +1,5 @@ +from itertools import chain + import reframe as rfm import reframe.utility.sanity as sn # Added only to make the linter happy @@ -14,7 +16,7 @@ class EESSI_PyTorch_torchvision(rfm.RunOnlyRegressionTest): parallel_strategy = parameter([None, 'ddp']) compute_device = variable(str) # Both torchvision and PyTorch-bundle modules have everything needed to run this test - module_name = parameter(find_modules('torchvision') + find_modules('PyTorch-bundle')) + module_name = parameter(chain(find_modules('torchvision'), find_modules('PyTorch-bundle'))) descr = 'Benchmark that runs a selected torchvision model on synthetic data' From 8298e6a190645752b70b60e4a086a511566e866e Mon Sep 17 00:00:00 2001 From: Caspar van Leeuwen Date: Mon, 1 Jul 2024 12:24:35 +0200 Subject: [PATCH 18/24] Fix comments from Review Sam --- eessi/testsuite/hooks.py | 5 +++-- .../tests/apps/PyTorch/PyTorch_torchvision.py | 21 +++---------------- ..._get_free_socket.py => get_free_socket.py} | 0 3 files changed, 6 insertions(+), 20 deletions(-) rename eessi/testsuite/tests/apps/PyTorch/src/{python_get_free_socket.py => get_free_socket.py} (100%) diff --git a/eessi/testsuite/hooks.py b/eessi/testsuite/hooks.py index a354cba8..fd82087f 100644 --- a/eessi/testsuite/hooks.py +++ b/eessi/testsuite/hooks.py @@ -58,7 +58,8 @@ def _assign_default_num_gpus_per_node(test: rfm.RegressionTest): def assign_tasks_per_compute_unit(test: rfm.RegressionTest, compute_unit: str, num_per: int = 1): """ - Assign one task per compute unit. + Assign one task per compute unit. More than 1 task per compute unit can be assigned with + num_per for compute units that support it. Automatically sets num_tasks, num_tasks_per_node, num_cpus_per_task, and num_gpus_per_node, based on the current scale and the current partition’s num_cpus, max_avail_gpus_per_node and num_nodes. For GPU tests, one task per GPU is set, and num_cpus_per_task is based on the ratio of CPU-cores/GPUs. @@ -80,7 +81,7 @@ def assign_tasks_per_compute_unit(test: rfm.RegressionTest, compute_unit: str, n - assign_tasks_per_compute_unit(test, COMPUTE_UNIT[CPU_SOCKET]) will launch 2 tasks with 64 threads per task """ - if num_per != 1 and compute_unit in [COMPUTE_UNIT[GPU], COMPUTE_UNIT[CPU], COMPUTE_UNIT[CPU_SOCKET]]: + if num_per != 1 and compute_unit not in [COMPUTE_UNIT[NODE]]: raise NotImplementedError( f'Non-default num_per {num_per} is not implemented for compute_unit {compute_unit}.') diff --git a/eessi/testsuite/tests/apps/PyTorch/PyTorch_torchvision.py b/eessi/testsuite/tests/apps/PyTorch/PyTorch_torchvision.py index 575527cf..2235ac36 100644 --- a/eessi/testsuite/tests/apps/PyTorch/PyTorch_torchvision.py +++ b/eessi/testsuite/tests/apps/PyTorch/PyTorch_torchvision.py @@ -63,19 +63,18 @@ def apply_setup_hooks(self): if self.compute_device == DEVICE_TYPES[GPU]: hooks.assign_tasks_per_compute_unit(test=self, compute_unit=COMPUTE_UNIT[GPU]) else: - # Hybrid code, so launch 1 rank per socket. - # Probably, launching 1 task per NUMA domain is even better, but the current hook doesn't support it + # Hybrid code, for which launching one task per NUMA_NODE is typically the most efficient hooks.assign_tasks_per_compute_unit(test=self, compute_unit=COMPUTE_UNIT[NUMA_NODE]) # This is a hybrid test, binding is important for performance hooks.set_compact_process_binding(self) @run_after('setup') - def set_ddp_env_vars(self): + def set_ddp_options(self): # Set environment variables for PyTorch DDP if self.parallel_strategy == 'ddp': # Set additional options required by DDP - self.executable_opts += ["--master-port $(python python_get_free_socket.py)"] + self.executable_opts += ["--master-port $(python get_free_socket.py)"] self.executable_opts += ["--master-address $(hostname --fqdn)"] self.executable_opts += ["--world-size %s" % self.num_tasks] @@ -96,15 +95,6 @@ def pass_parallel_strategy(self): if self.num_tasks != 1: self.executable_opts += ['--use-%s' % self.parallel_strategy] - @run_after('setup') - def avoid_horovod_cpu_contention(self): - # Horovod had issues with CPU performance, see https://github.com/horovod/horovod/issues/2804 - # The root cause is Horovod having two threads with very high utilization, which interferes with - # the compute threads. It was fixed, but seems to be broken again in Horovod 0.28.1 - # The easiest workaround is to reduce the number of compute threads by 2 - if self.compute_device == DEVICE_TYPES[CPU] and self.parallel_strategy == 'horovod': - self.env_vars['OMP_NUM_THREADS'] = max(self.num_cpus_per_task - 2, 2) # Never go below 2 compute threads - @sanity_function def assert_num_ranks(self): '''Assert that the number of reported CPUs/GPUs used is correct''' @@ -140,8 +130,3 @@ def prepare_gpu_test(self): if self.precision == 'mixed': self.executable_opts += ['--use-amp'] - @run_after('init') - def skip_hvd_plus_amp(self): - '''Skip combination of horovod and AMP, it does not work see https://github.com/horovod/horovod/issues/1417''' - if self.parallel_strategy == 'horovod' and self.precision == 'mixed': - self.valid_systems = [INVALID_SYSTEM] diff --git a/eessi/testsuite/tests/apps/PyTorch/src/python_get_free_socket.py b/eessi/testsuite/tests/apps/PyTorch/src/get_free_socket.py similarity index 100% rename from eessi/testsuite/tests/apps/PyTorch/src/python_get_free_socket.py rename to eessi/testsuite/tests/apps/PyTorch/src/get_free_socket.py From af30b642af7d96d678eeaae0658a8a1e5b5e8f38 Mon Sep 17 00:00:00 2001 From: Caspar van Leeuwen Date: Mon, 1 Jul 2024 12:48:28 +0200 Subject: [PATCH 19/24] Make linter happy --- eessi/testsuite/tests/apps/PyTorch/PyTorch_torchvision.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/eessi/testsuite/tests/apps/PyTorch/PyTorch_torchvision.py b/eessi/testsuite/tests/apps/PyTorch/PyTorch_torchvision.py index 2235ac36..37630970 100644 --- a/eessi/testsuite/tests/apps/PyTorch/PyTorch_torchvision.py +++ b/eessi/testsuite/tests/apps/PyTorch/PyTorch_torchvision.py @@ -6,7 +6,7 @@ from reframe.core.builtins import parameter, variable, run_after, sanity_function, performance_function from eessi.testsuite import hooks -from eessi.testsuite.constants import SCALES, TAGS, DEVICE_TYPES, COMPUTE_UNIT, CPU, NUMA_NODE, GPU, INVALID_SYSTEM +from eessi.testsuite.constants import SCALES, TAGS, DEVICE_TYPES, COMPUTE_UNIT, CPU, NUMA_NODE, GPU from eessi.testsuite.utils import find_modules @@ -129,4 +129,3 @@ def prepare_gpu_test(self): # Set precision if self.precision == 'mixed': self.executable_opts += ['--use-amp'] - From 7ddeedb11241f9d51ca0c40a069f359190570b17 Mon Sep 17 00:00:00 2001 From: Caspar van Leeuwen Date: Mon, 1 Jul 2024 12:54:41 +0200 Subject: [PATCH 20/24] Remove training whitespace --- eessi/testsuite/hooks.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/eessi/testsuite/hooks.py b/eessi/testsuite/hooks.py index fd82087f..553e4d7f 100644 --- a/eessi/testsuite/hooks.py +++ b/eessi/testsuite/hooks.py @@ -59,7 +59,7 @@ def _assign_default_num_gpus_per_node(test: rfm.RegressionTest): def assign_tasks_per_compute_unit(test: rfm.RegressionTest, compute_unit: str, num_per: int = 1): """ Assign one task per compute unit. More than 1 task per compute unit can be assigned with - num_per for compute units that support it. + num_per for compute units that support it. Automatically sets num_tasks, num_tasks_per_node, num_cpus_per_task, and num_gpus_per_node, based on the current scale and the current partition’s num_cpus, max_avail_gpus_per_node and num_nodes. For GPU tests, one task per GPU is set, and num_cpus_per_task is based on the ratio of CPU-cores/GPUs. From d62443b10a19371507a3fc5d76c0abac1e2534d8 Mon Sep 17 00:00:00 2001 From: Caspar van Leeuwen Date: Thu, 25 Jul 2024 16:32:37 +0200 Subject: [PATCH 21/24] Add set_omp_num_threads hook from https://github.com/EESSI/test-suite/pull/133 --- eessi/testsuite/hooks.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/eessi/testsuite/hooks.py b/eessi/testsuite/hooks.py index 553e4d7f..260a94b2 100644 --- a/eessi/testsuite/hooks.py +++ b/eessi/testsuite/hooks.py @@ -2,6 +2,7 @@ Hooks for adding tags, filtering and setting job resources in ReFrame tests """ import math +import os import shlex import warnings @@ -680,6 +681,15 @@ def set_compact_thread_binding(test: rfm.RegressionTest): log(f'Set environment variable KMP_AFFINITY to {test.env_vars["KMP_AFFINITY"]}') +def set_omp_num_threads(test: rfm.RegressionTest): + """ + Set default number of OpenMP threads equal to number of CPUs per task, + unless OMP_NUM_THREADS is already set + """ + test.env_vars['OMP_NUM_THREADS'] = os.getenv('OMP_NUM_THREADS', test.num_cpus_per_task) + log(f'Set environment variable OMP_NUM_THREADS to {test.env_vars["OMP_NUM_THREADS"]}') + + def _check_always_request_gpus(test: rfm.RegressionTest): """ Make sure we always request enough GPUs if required for the current GPU partition (cluster-specific policy) From 00fca31c7904b0e3918e5c5e1a8667567491f9ae Mon Sep 17 00:00:00 2001 From: Caspar van Leeuwen Date: Thu, 25 Jul 2024 16:34:10 +0200 Subject: [PATCH 22/24] Call hook to set OMP_NUM_THREADS --- eessi/testsuite/tests/apps/PyTorch/PyTorch_torchvision.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/eessi/testsuite/tests/apps/PyTorch/PyTorch_torchvision.py b/eessi/testsuite/tests/apps/PyTorch/PyTorch_torchvision.py index 37630970..f5922ca6 100644 --- a/eessi/testsuite/tests/apps/PyTorch/PyTorch_torchvision.py +++ b/eessi/testsuite/tests/apps/PyTorch/PyTorch_torchvision.py @@ -69,6 +69,9 @@ def apply_setup_hooks(self): # This is a hybrid test, binding is important for performance hooks.set_compact_process_binding(self) + # Set OMP_NUM_THREADS based on the number of cores per task + hooks.set_omp_num_threads(self) + @run_after('setup') def set_ddp_options(self): # Set environment variables for PyTorch DDP From a69e2d3e68be827f0afd5382793507c9cff53998 Mon Sep 17 00:00:00 2001 From: Caspar van Leeuwen Date: Thu, 25 Jul 2024 16:49:38 +0200 Subject: [PATCH 23/24] Revert using the hook, it doesn't make sense to set OMP_NUM_THREADS conditionally, as this would be evaluated on the login node. That environment is irrelevant to the batch job --- eessi/testsuite/hooks.py | 10 ---------- .../tests/apps/PyTorch/PyTorch_torchvision.py | 2 +- 2 files changed, 1 insertion(+), 11 deletions(-) diff --git a/eessi/testsuite/hooks.py b/eessi/testsuite/hooks.py index 260a94b2..553e4d7f 100644 --- a/eessi/testsuite/hooks.py +++ b/eessi/testsuite/hooks.py @@ -2,7 +2,6 @@ Hooks for adding tags, filtering and setting job resources in ReFrame tests """ import math -import os import shlex import warnings @@ -681,15 +680,6 @@ def set_compact_thread_binding(test: rfm.RegressionTest): log(f'Set environment variable KMP_AFFINITY to {test.env_vars["KMP_AFFINITY"]}') -def set_omp_num_threads(test: rfm.RegressionTest): - """ - Set default number of OpenMP threads equal to number of CPUs per task, - unless OMP_NUM_THREADS is already set - """ - test.env_vars['OMP_NUM_THREADS'] = os.getenv('OMP_NUM_THREADS', test.num_cpus_per_task) - log(f'Set environment variable OMP_NUM_THREADS to {test.env_vars["OMP_NUM_THREADS"]}') - - def _check_always_request_gpus(test: rfm.RegressionTest): """ Make sure we always request enough GPUs if required for the current GPU partition (cluster-specific policy) diff --git a/eessi/testsuite/tests/apps/PyTorch/PyTorch_torchvision.py b/eessi/testsuite/tests/apps/PyTorch/PyTorch_torchvision.py index f5922ca6..890be234 100644 --- a/eessi/testsuite/tests/apps/PyTorch/PyTorch_torchvision.py +++ b/eessi/testsuite/tests/apps/PyTorch/PyTorch_torchvision.py @@ -70,7 +70,7 @@ def apply_setup_hooks(self): hooks.set_compact_process_binding(self) # Set OMP_NUM_THREADS based on the number of cores per task - hooks.set_omp_num_threads(self) + test.env_vars["OMP_NUM_THREADS"] = self.num_cpus_per_task @run_after('setup') def set_ddp_options(self): From 4c5c3e7df0dd1d1ae714da15c557ec360beb6b45 Mon Sep 17 00:00:00 2001 From: Caspar van Leeuwen Date: Thu, 25 Jul 2024 16:51:23 +0200 Subject: [PATCH 24/24] test is not defined, should be 'self' --- eessi/testsuite/tests/apps/PyTorch/PyTorch_torchvision.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/eessi/testsuite/tests/apps/PyTorch/PyTorch_torchvision.py b/eessi/testsuite/tests/apps/PyTorch/PyTorch_torchvision.py index 890be234..13171143 100644 --- a/eessi/testsuite/tests/apps/PyTorch/PyTorch_torchvision.py +++ b/eessi/testsuite/tests/apps/PyTorch/PyTorch_torchvision.py @@ -70,7 +70,7 @@ def apply_setup_hooks(self): hooks.set_compact_process_binding(self) # Set OMP_NUM_THREADS based on the number of cores per task - test.env_vars["OMP_NUM_THREADS"] = self.num_cpus_per_task + self.env_vars["OMP_NUM_THREADS"] = self.num_cpus_per_task @run_after('setup') def set_ddp_options(self):