Skip to content

Commit

Permalink
linting
Browse files Browse the repository at this point in the history
  • Loading branch information
Crivella committed Mar 1, 2024
1 parent 00a6aee commit 2f5aa33
Showing 1 changed file with 43 additions and 25 deletions.
68 changes: 43 additions & 25 deletions hpctestlib/sciapps/qespresso/benchmarks.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,6 @@
import os
from typing import TypeVar

R = TypeVar('R')

import reframe as rfm
import reframe.utility.sanity as sn
from reframe.core.builtins import (performance_function, run_after, run_before,
Expand All @@ -12,6 +10,8 @@
from reframe.core.parameters import TestParam as parameter
from reframe.core.variables import TestVar as variable

R = TypeVar('R')

INPUT_TEMPLATE = """&CONTROL
calculation = "scf",
prefix = "Si",
Expand All @@ -26,7 +26,7 @@
nat = 2,
ntyp = 1,
nbnd = {nbnd}
ecutwfc = {ecut}
ecutwfc = {ecut}
/
&ELECTRONS
conv_thr = 1.D-8,
Expand All @@ -42,6 +42,7 @@
"""


@rfm.simple_test
class QEspressoPWCheck(rfm.RunOnlyRegressionTest):
"""QuantumESPRESSO benchmark test.
Expand All @@ -53,9 +54,10 @@ class QEspressoPWCheck(rfm.RunOnlyRegressionTest):
The benchmarks consist on a set of different inputs files ..."""

# Tests the performance of the FFTW algorithm, higher ecut -> more FFTs
ecut = parameter([50,150], loggable=True)
# Tests the performance of the diagonalization algorithm, higher nbnd -> bigger matrices
nbnd = parameter([10,200], loggable=True)
ecut = parameter([50, 150], loggable=True)
# Tests the performance of the diagonalization algorithm,
# higher nbnd -> bigger matrices
nbnd = parameter([10, 200], loggable=True)

executable = 'pw.x'
tags = {'sciapp', 'chemistry'}
Expand All @@ -68,26 +70,29 @@ class QEspressoPWCheck(rfm.RunOnlyRegressionTest):
@run_after('init')
def prepare_test(self):
"""Hook to the set the downloading of the pseudo-potentials"""
self.prerun_cmds = [
f'wget -q http://pseudopotentials.quantum-espresso.org/upf_files/{self.pp_name}'
]
self.prerun_cmds = [(
'wget -q http://pseudopotentials.quantum-espresso.org/upf_files/'
f'{self.pp_name}'
)]
self.executable_opts += [f'-in {self.input_name}']

@run_after('setup')
def write_input(self):
"""Write the input file for the calculation"""
inp_file = os.path.join(self.stagedir, self.input_name)
with open(inp_file, 'w', encoding='utf-8') as file:
file.write(INPUT_TEMPLATE.format(
ecut=self.ecut,
nbnd=self.nbnd,
pseudo=self.pp_name,
file.write(
INPUT_TEMPLATE.format(
ecut=self.ecut,
nbnd=self.nbnd,
pseudo=self.pp_name,
))

@staticmethod
@sn.deferrable
def extractsingle_or_val(*args, on_except_value: str = '0s') -> str:
"""Wrap extractsingle_or_val to return a default value if the regex is not found
"""Wrap extractsingle_or_val to return a default value if the regex is
not found.
Returns:
str: The value of the regular expression
Expand All @@ -96,7 +101,7 @@ def extractsingle_or_val(*args, on_except_value: str = '0s') -> str:
res = sn.extractsingle(*args).evaluate()
except SanityError:
res = on_except_value

return res

@staticmethod
Expand All @@ -106,23 +111,29 @@ def convert_timings(timing: str) -> float:

if timing is None:
return 0


days, timing = (['0', '0'] + timing.split('d'))[-2:]
hours, timing = (['0', '0'] + timing.split('h'))[-2:]
minutes, timing = (['0', '0'] + timing.split('m'))[-2:]
seconds = timing.split('s')[0]

return float(days) * 86400 + float(hours) * 3600 + float(minutes) * 60 + float(seconds)
return (
float(days) * 86400 +
float(hours) * 3600 +
float(minutes) * 60 +
float(seconds)
)


@performance_function('s')
def extract_report_time(self, name: str = None, kind: str = None) -> float:
"""Extract timings from pw.x stdout
Args:
name (str, optional): Name of the timing to extract. Defaults to None.
kind (str, optional): Kind ('cpu' or 'wall) of timing to extract. Defaults to None.
name (str, optional): Name of the timing to extract.
Defaults to None.
kind (str, optional): Kind ('cpu' or 'wall) of timing to extract.
Defaults to None.
Raises:
ValueError: If the kind is not 'cpu' or 'wall'
Expand All @@ -143,17 +154,24 @@ def extract_report_time(self, name: str = None, kind: str = None) -> float:
# Possible formats
# PWSCF : 4d 6h19m CPU 10d14h38m WALL
# --> (Should also catch spaces)
return self.convert_timings(self.extractsingle_or_val(
fr'{name}\s+:\s+(.+)\s+CPU\s+(.+)\s+WALL', self.stdout, tag, str
))
return self.convert_timings(
self.extractsingle_or_val(
fr'{name}\s+:\s+(.+)\s+CPU\s+(.+)\s+WALL',
self.stdout, tag, str
))

@run_before('performance')
def set_perf_variables(self):
"""Build a dictionary of performance variables"""

for name in ['PWSCF', 'electrons', 'c_bands', 'cegterg', 'calbec', 'fft', 'ffts', 'fftw']:
timings = [
'PWSCF', 'electrons', 'c_bands', 'cegterg', 'calbec',
'fft', 'ffts', 'fftw'
]
for name in timings:
for kind in ['cpu', 'wall']:
self.perf_variables[f'{name}_{kind}'] = self.extract_report_time(name, kind)
res = self.extract_report_time(name, kind)
self.perf_variables[f'{name}_{kind}'] = res

@sanity_function
def assert_job_finished(self):
Expand Down

0 comments on commit 2f5aa33

Please sign in to comment.