From 0670c225133b08a3ad4e195bc2554b2c82795606 Mon Sep 17 00:00:00 2001
From: Leo Schick
Date: Wed, 22 Nov 2023 10:11:16 +0100
Subject: [PATCH] add max_retries for parallel tasks

---
 mara_pipelines/parallel_tasks/files.py  | 31 ++++++++++++++++++++-----------
 mara_pipelines/parallel_tasks/python.py |  6 ++++--
 mara_pipelines/parallel_tasks/sql.py    |  3 ++-
 mara_pipelines/pipelines.py             |  4 +++-
 4 files changed, 29 insertions(+), 15 deletions(-)

diff --git a/mara_pipelines/parallel_tasks/files.py b/mara_pipelines/parallel_tasks/files.py
index 1e09c90..8ab1798 100644
--- a/mara_pipelines/parallel_tasks/files.py
+++ b/mara_pipelines/parallel_tasks/files.py
@@ -34,10 +34,12 @@ def __init__(self, id: str, description: str, file_pattern: str, read_mode: Read
                  max_number_of_parallel_tasks: Optional[int] = None, file_dependencies: Optional[List[str]] = None,
                  date_regex: Optional[str] = None, partition_target_table_by_day_id: bool = False, truncate_partitions: bool = False,
                  commands_before: Optional[List[pipelines.Command]] = None, commands_after: Optional[List[pipelines.Command]] = None,
-                 db_alias: Optional[str] = None, timezone: Optional[str] = None) -> None:
+                 db_alias: Optional[str] = None, timezone: Optional[str] = None,
+                 max_retries: Optional[int] = None) -> None:
         pipelines.ParallelTask.__init__(self, id=id, description=description,
                                         max_number_of_parallel_tasks=max_number_of_parallel_tasks,
-                                        commands_before=commands_before, commands_after=commands_after)
+                                        commands_before=commands_before, commands_after=commands_after,
+                                        max_retries=max_retries)
         self.file_pattern = file_pattern
         self.read_mode = read_mode
         self.date_regex = date_regex
@@ -139,12 +141,14 @@ def update_file_dependencies():
                     id='create_partitions',
                     description='Creates required target table partitions',
                     commands=[sql.ExecuteSQL(sql_statement='\n'.join(slice), echo_queries=False, db_alias=self.db_alias)
-                              for slice in more_itertools.sliced(sql_statements, 50)])
+                              for slice in more_itertools.sliced(sql_statements, 50)],
+                    max_retries=self.max_retries)

                 sub_pipeline.add(create_partitions_task)

             for n, chunk in enumerate(more_itertools.chunked(files_per_day.items(), chunk_size)):
-                task = pipelines.Task(id=str(n), description='Reads a portion of the files')
+                task = pipelines.Task(id=str(n), description='Reads a portion of the files',
+                                      max_retries=self.max_retries)
                 for (day, files) in chunk:
                     target_table = self.target_table + '_' + day.strftime("%Y%m%d")
                     for file in files:
@@ -155,7 +159,8 @@ def update_file_dependencies():
             for n, chunk in enumerate(more_itertools.chunked(files, chunk_size)):
                 sub_pipeline.add(
                     pipelines.Task(id=str(n), description=f'Reads {len(chunk)} files',
-                                   commands=sum([self.parallel_commands(x[0]) for x in chunk], [])))
+                                   commands=sum([self.parallel_commands(x[0]) for x in chunk], []),
+                                   max_retries=self.max_retries))

     def parallel_commands(self, file_name: str) -> List[pipelines.Command]:
         return [self.read_command(file_name)] + (
@@ -180,14 +185,16 @@ def __init__(self, id: str, description: str, file_pattern: str, read_mode: Read
                  mapper_script_file_name: Optional[str] = None, make_unique: bool = False,
                  db_alias: Optional[str] = None, delimiter_char: Optional[str] = None,
                  quote_char: Optional[str] = None, null_value_string: Optional[str] = None,
                  skip_header: Optional[bool] = None, csv_format: bool = False,
-                 timezone: Optional[str] = None, max_number_of_parallel_tasks: Optional[int] = None) -> None:
+                 timezone: Optional[str] = None, max_number_of_parallel_tasks: Optional[int] = None,
+                 max_retries: Optional[int] = None) -> None:
         _ParallelRead.__init__(self, id=id, description=description, file_pattern=file_pattern,
                                read_mode=read_mode, target_table=target_table, file_dependencies=file_dependencies,
                                date_regex=date_regex, partition_target_table_by_day_id=partition_target_table_by_day_id,
                                truncate_partitions=truncate_partitions, commands_before=commands_before,
                                commands_after=commands_after, db_alias=db_alias, timezone=timezone,
-                               max_number_of_parallel_tasks=max_number_of_parallel_tasks)
+                               max_number_of_parallel_tasks=max_number_of_parallel_tasks,
+                               max_retries=max_retries)
         self.compression = compression
         self.mapper_script_file_name = mapper_script_file_name or ''
         self.make_unique = make_unique
@@ -231,16 +238,18 @@ def html_doc_items(self) -> List[Tuple[str, str]]:

 class ParallelReadSqlite(_ParallelRead):
     def __init__(self, id: str, description: str, file_pattern: str, read_mode: ReadMode, sql_file_name: str,
-                 target_table: str, file_dependencies: List[str] = None, date_regex: str = None,
+                 target_table: str, file_dependencies: Optional[List[str]] = None, date_regex: Optional[str] = None,
                  partition_target_table_by_day_id: bool = False, truncate_partitions: bool = False,
-                 commands_before: List[pipelines.Command] = None, commands_after: List[pipelines.Command] = None,
-                 db_alias: str = None, timezone=None, max_number_of_parallel_tasks: int = None) -> None:
+                 commands_before: Optional[List[pipelines.Command]] = None, commands_after: Optional[List[pipelines.Command]] = None,
+                 db_alias: Optional[str] = None, timezone=None, max_number_of_parallel_tasks: Optional[int] = None,
+                 max_retries: Optional[int] = None) -> None:
         _ParallelRead.__init__(self, id=id, description=description, file_pattern=file_pattern, read_mode=read_mode,
                                target_table=target_table, file_dependencies=file_dependencies, date_regex=date_regex,
                                partition_target_table_by_day_id=partition_target_table_by_day_id,
                                truncate_partitions=truncate_partitions,
                                commands_before=commands_before, commands_after=commands_after, db_alias=db_alias,
-                               timezone=timezone, max_number_of_parallel_tasks=max_number_of_parallel_tasks)
+                               timezone=timezone, max_number_of_parallel_tasks=max_number_of_parallel_tasks,
+                               max_retries=max_retries)
         self.sql_file_name = sql_file_name

     def read_command(self, file_name: str) -> List[pipelines.Command]:
diff --git a/mara_pipelines/parallel_tasks/python.py b/mara_pipelines/parallel_tasks/python.py
index b274f77..bf9c4e7 100644
--- a/mara_pipelines/parallel_tasks/python.py
+++ b/mara_pipelines/parallel_tasks/python.py
@@ -27,7 +27,8 @@ def add_parallel_tasks(self, sub_pipeline: 'pipelines.Pipeline') -> None:
             sub_pipeline.add(pipelines.Task(
                 id='_'.join([re.sub('[^0-9a-z\-_]+', '', str(x).lower().replace('-', '_')) for x in parameter_tuple]),
                 description=f'Runs the script with parameters {repr(parameter_tuple)}',
-                commands=[python.ExecutePython(file_name=self.file_name, args=list(parameter_tuple))]))
+                commands=[python.ExecutePython(file_name=self.file_name, args=list(parameter_tuple))],
+                max_retries=self.max_retries))

     def html_doc_items(self) -> List[Tuple[str, str]]:
         path = self.parent.base_path() / self.file_name
@@ -58,7 +59,8 @@ def add_parallel_tasks(self, sub_pipeline: 'pipelines.Pipeline') -> None:
             sub_pipeline.add(pipelines.Task(
                 id=str(parameter).lower().replace(' ', '_').replace('-', '_'),
                 description=f'Runs the function with parameters {repr(parameter)}',
-                commands=[python.RunFunction(lambda args=parameter: self.function(args))]))
+                commands=[python.RunFunction(lambda args=parameter: self.function(args))],
+                max_retries=self.max_retries))

     def html_doc_items(self) -> List[Tuple[str, str]]:
         return [('function', _.pre[escape(str(self.function))]),
diff --git a/mara_pipelines/parallel_tasks/sql.py b/mara_pipelines/parallel_tasks/sql.py
index 9ec8f84..1006653 100644
--- a/mara_pipelines/parallel_tasks/sql.py
+++ b/mara_pipelines/parallel_tasks/sql.py
@@ -51,7 +51,8 @@ def add_parallel_tasks(self, sub_pipeline: 'pipelines.Pipeline') -> None:
                                                echo_queries=self.echo_queries, timezone=self.timezone, replace=replace)
                           if self.sql_file_name
                           else sql.ExecuteSQL(sql_statement=self.sql_statement, db_alias=self.db_alias,
-                                              echo_queries=self.echo_queries, timezone=self.timezone, replace=replace)]))
+                                              echo_queries=self.echo_queries, timezone=self.timezone, replace=replace)],
+                max_retries=self.max_retries))

     def html_doc_items(self) -> List[Tuple[str, str]]:
         return [('db', _.tt[self.db_alias])] \
diff --git a/mara_pipelines/pipelines.py b/mara_pipelines/pipelines.py
index 7fc538f..fa5e661 100644
--- a/mara_pipelines/pipelines.py
+++ b/mara_pipelines/pipelines.py
@@ -109,7 +109,8 @@ def run(self):

 class ParallelTask(Node):
     def __init__(self, id: str, description: str, max_number_of_parallel_tasks: Optional[int] = None,
-                 commands_before: Optional[List[Command]] = None, commands_after: Optional[List[Command]] = None) -> None:
+                 commands_before: Optional[List[Command]] = None, commands_after: Optional[List[Command]] = None,
+                 max_retries: Optional[int] = None) -> None:
         super().__init__(id, description)
         self.commands_before = []
         for command in commands_before or []:
@@ -117,6 +118,7 @@ def __init__(self, id: str, description: str, max_number_of_parallel_tasks: Opti
         self.commands_after = []
         for command in commands_after or []:
             self.add_command_after(command)
+        self.max_retries = max_retries
         self.max_number_of_parallel_tasks = max_number_of_parallel_tasks

     def add_command_before(self, command: Command):
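
Usage note (not part of the patch): with this change, max_retries can be passed to ParallelTask and to the readers that forward it (ParallelReadFile, ParallelReadSqlite), and every sub-task spawned by add_parallel_tasks is created with that value. Below is a minimal sketch of how a pipeline definition might use it; the pipeline id, file pattern, target table, and db alias are made up, and the exact import location of the ReadMode/Compression enums may differ between mara-pipelines versions:

    from mara_pipelines import pipelines
    from mara_pipelines.parallel_tasks.files import ParallelReadFile
    from mara_pipelines.commands.files import Compression, ReadMode

    pipeline = pipelines.Pipeline(
        id='raw_events',
        description='Loads raw event files into the data warehouse')

    pipeline.add(ParallelReadFile(
        id='read_events',
        description='Reads all new event files in parallel',
        file_pattern='events/*.csv',
        read_mode=ReadMode.ONLY_NEW,
        compression=Compression.NONE,
        target_table='raw.events',
        db_alias='dwh',
        max_number_of_parallel_tasks=4,
        # new in this patch: each generated sub-task is created with
        # max_retries=2, i.e. a failed run is retried up to two times
        # (assumed semantics of the existing Task.max_retries)
        max_retries=2))

Note that ParallelExecutePython and ParallelRunFunction pass self.max_retries through to their sub-tasks but do not expose it as a constructor argument in this patch, so it stays at the base-class default of None there.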