diff --git a/airflow/cli/commands/celery_command.py b/airflow/cli/commands/celery_command.py index 2466465c0d6ee..e01f17774d32b 100644 --- a/airflow/cli/commands/celery_command.py +++ b/airflow/cli/commands/celery_command.py @@ -156,24 +156,20 @@ def worker(args): if args.daemon: # Run Celery worker as daemon handle = setup_logging(log_file) - stdout = open(stdout, 'w+') - stderr = open(stderr, 'w+') - if args.umask: - umask = args.umask + with open(stdout, 'w+') as stdout_handle, open(stderr, 'w+') as stderr_handle: + if args.umask: + umask = args.umask - ctx = daemon.DaemonContext( - files_preserve=[handle], - umask=int(umask, 8), - stdout=stdout, - stderr=stderr, - ) - with ctx: - sub_proc = _serve_logs(skip_serve_logs) - worker_instance.run(**options) - - stdout.close() - stderr.close() + ctx = daemon.DaemonContext( + files_preserve=[handle], + umask=int(umask, 8), + stdout=stdout_handle, + stderr=stderr_handle, + ) + with ctx: + sub_proc = _serve_logs(skip_serve_logs) + worker_instance.run(**options) else: # Run Celery worker in the same process sub_proc = _serve_logs(skip_serve_logs) diff --git a/airflow/cli/commands/dag_command.py b/airflow/cli/commands/dag_command.py index 6e4aad811030a..fe2f329fef5db 100644 --- a/airflow/cli/commands/dag_command.py +++ b/airflow/cli/commands/dag_command.py @@ -195,7 +195,12 @@ def dag_show(args): def _display_dot_via_imgcat(dot: Dot): data = dot.pipe(format='png') try: - proc = subprocess.Popen("imgcat", stdout=subprocess.PIPE, stdin=subprocess.PIPE) + with subprocess.Popen("imgcat", stdout=subprocess.PIPE, stdin=subprocess.PIPE) as proc: + out, err = proc.communicate(data) + if out: + print(out.decode('utf-8')) + if err: + print(err.decode('utf-8')) except OSError as e: if e.errno == errno.ENOENT: raise SystemExit( @@ -203,11 +208,6 @@ def _display_dot_via_imgcat(dot: Dot): ) else: raise - out, err = proc.communicate(data) - if out: - print(out.decode('utf-8')) - if err: - print(err.decode('utf-8')) def _save_dot_to_file(dot: Dot, filename: str): diff --git a/airflow/cli/commands/info_command.py b/airflow/cli/commands/info_command.py index 2842722ea47ed..7aedf412d48b4 100644 --- a/airflow/cli/commands/info_command.py +++ b/airflow/cli/commands/info_command.py @@ -188,17 +188,17 @@ def __init__(self, anonymizer): def _get_version(cmd: List[str], grep: Optional[bytes] = None): """Return tools version.""" try: - proc = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT) + with subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT) as proc: + stdoutdata, _ = proc.communicate() + data = [f for f in stdoutdata.split(b"\n") if f] + if grep: + data = [line for line in data if grep in line] + if len(data) != 1: + return "NOT AVAILABLE" + else: + return data[0].decode() except OSError: return "NOT AVAILABLE" - stdoutdata, _ = proc.communicate() - data = [f for f in stdoutdata.split(b"\n") if f] - if grep: - data = [line for line in data if grep in line] - if len(data) != 1: - return "NOT AVAILABLE" - else: - return data[0].decode() @staticmethod def _task_logging_handler(): diff --git a/airflow/cli/commands/kerberos_command.py b/airflow/cli/commands/kerberos_command.py index a75c596183f35..d087ebe4fc223 100644 --- a/airflow/cli/commands/kerberos_command.py +++ b/airflow/cli/commands/kerberos_command.py @@ -34,19 +34,14 @@ def kerberos(args): pid, stdout, stderr, _ = setup_locations( "kerberos", args.pid, args.stdout, args.stderr, args.log_file ) - stdout = open(stdout, 'w+') - stderr = open(stderr, 'w+') - - ctx = 
daemon.DaemonContext( - pidfile=TimeoutPIDLockFile(pid, -1), - stdout=stdout, - stderr=stderr, - ) - - with ctx: - krb.run(principal=args.principal, keytab=args.keytab) - - stdout.close() - stderr.close() + with open(stdout, 'w+') as stdout_handle, open(stderr, 'w+') as stderr_handle: + ctx = daemon.DaemonContext( + pidfile=TimeoutPIDLockFile(pid, -1), + stdout=stdout_handle, + stderr=stderr_handle, + ) + + with ctx: + krb.run(principal=args.principal, keytab=args.keytab) else: krb.run(principal=args.principal, keytab=args.keytab) diff --git a/airflow/cli/commands/scheduler_command.py b/airflow/cli/commands/scheduler_command.py index 100a0f1d0f775..b66dafc032a16 100644 --- a/airflow/cli/commands/scheduler_command.py +++ b/airflow/cli/commands/scheduler_command.py @@ -42,20 +42,15 @@ def scheduler(args): "scheduler", args.pid, args.stdout, args.stderr, args.log_file ) handle = setup_logging(log_file) - stdout = open(stdout, 'w+') - stderr = open(stderr, 'w+') - - ctx = daemon.DaemonContext( - pidfile=TimeoutPIDLockFile(pid, -1), - files_preserve=[handle], - stdout=stdout, - stderr=stderr, - ) - with ctx: - job.run() - - stdout.close() - stderr.close() + with open(stdout, 'w+') as stdout_handle, open(stderr, 'w+') as stderr_handle: + ctx = daemon.DaemonContext( + pidfile=TimeoutPIDLockFile(pid, -1), + files_preserve=[handle], + stdout=stdout_handle, + stderr=stderr_handle, + ) + with ctx: + job.run() else: signal.signal(signal.SIGINT, sigint_handler) signal.signal(signal.SIGTERM, sigint_handler) diff --git a/airflow/cli/commands/webserver_command.py b/airflow/cli/commands/webserver_command.py index c7a9c8bd9ed48..e786eb3656875 100644 --- a/airflow/cli/commands/webserver_command.py +++ b/airflow/cli/commands/webserver_command.py @@ -480,5 +480,5 @@ def monitor_gunicorn(gunicorn_master_pid: int): monitor_gunicorn(gunicorn_master_proc.pid) else: - gunicorn_master_proc = subprocess.Popen(run_args, close_fds=True) - monitor_gunicorn(gunicorn_master_proc.pid) + with subprocess.Popen(run_args, close_fds=True) as gunicorn_master_proc: + monitor_gunicorn(gunicorn_master_proc.pid) diff --git a/airflow/hooks/subprocess.py b/airflow/hooks/subprocess.py index f3317153aa646..409dbb67551fa 100644 --- a/airflow/hooks/subprocess.py +++ b/airflow/hooks/subprocess.py @@ -62,6 +62,7 @@ def pre_exec(): self.log.info('Running command: %s', command) + # pylint: disable=consider-using-with self.sub_process = Popen( # pylint: disable=subprocess-popen-preexec-fn command, stdout=PIPE, diff --git a/airflow/models/dagbag.py b/airflow/models/dagbag.py index ac1201a0ff37b..12113e3f2f3df 100644 --- a/airflow/models/dagbag.py +++ b/airflow/models/dagbag.py @@ -316,45 +316,45 @@ def _load_modules_from_file(self, filepath, safe_mode): def _load_modules_from_zip(self, filepath, safe_mode): mods = [] - current_zip_file = zipfile.ZipFile(filepath) - for zip_info in current_zip_file.infolist(): - head, _ = os.path.split(zip_info.filename) - mod_name, ext = os.path.splitext(zip_info.filename) - if ext not in [".py", ".pyc"]: - continue - if head: - continue - - if mod_name == '__init__': - self.log.warning("Found __init__.%s at root of %s", ext, filepath) - - self.log.debug("Reading %s from %s", zip_info.filename, filepath) - - if not might_contain_dag(zip_info.filename, safe_mode, current_zip_file): - # todo: create ignore list - # Don't want to spam user with skip messages - if not self.has_logged: - self.has_logged = True - self.log.info( - "File %s:%s assumed to contain no DAGs. 
Skipping.", filepath, zip_info.filename - ) - continue - - if mod_name in sys.modules: - del sys.modules[mod_name] + with zipfile.ZipFile(filepath) as current_zip_file: + for zip_info in current_zip_file.infolist(): + head, _ = os.path.split(zip_info.filename) + mod_name, ext = os.path.splitext(zip_info.filename) + if ext not in [".py", ".pyc"]: + continue + if head: + continue + + if mod_name == '__init__': + self.log.warning("Found __init__.%s at root of %s", ext, filepath) + + self.log.debug("Reading %s from %s", zip_info.filename, filepath) + + if not might_contain_dag(zip_info.filename, safe_mode, current_zip_file): + # todo: create ignore list + # Don't want to spam user with skip messages + if not self.has_logged: + self.has_logged = True + self.log.info( + "File %s:%s assumed to contain no DAGs. Skipping.", filepath, zip_info.filename + ) + continue + + if mod_name in sys.modules: + del sys.modules[mod_name] - try: - sys.path.insert(0, filepath) - current_module = importlib.import_module(mod_name) - mods.append(current_module) - except Exception as e: # pylint: disable=broad-except - self.log.exception("Failed to import: %s", filepath) - if self.dagbag_import_error_tracebacks: - self.import_errors[filepath] = traceback.format_exc( - limit=-self.dagbag_import_error_traceback_depth - ) - else: - self.import_errors[filepath] = str(e) + try: + sys.path.insert(0, filepath) + current_module = importlib.import_module(mod_name) + mods.append(current_module) + except Exception as e: # pylint: disable=broad-except + self.log.exception("Failed to import: %s", filepath) + if self.dagbag_import_error_tracebacks: + self.import_errors[filepath] = traceback.format_exc( + limit=-self.dagbag_import_error_traceback_depth + ) + else: + self.import_errors[filepath] = str(e) return mods def _process_modules(self, filepath, mods, file_last_changed_on_disk): diff --git a/airflow/operators/bash.py b/airflow/operators/bash.py index 9216f846cf388..4f2409560c953 100644 --- a/airflow/operators/bash.py +++ b/airflow/operators/bash.py @@ -149,7 +149,6 @@ def __init__( self.skip_exit_code = skip_exit_code if kwargs.get('xcom_push') is not None: raise AirflowException("'xcom_push' was deprecated, use 'BaseOperator.do_xcom_push' instead") - self.sub_process = None @cached_property def subprocess_hook(self): diff --git a/airflow/providers/amazon/aws/operators/s3_file_transform.py b/airflow/providers/amazon/aws/operators/s3_file_transform.py index 65849ec09e814..1084e2bc764f7 100644 --- a/airflow/providers/amazon/aws/operators/s3_file_transform.py +++ b/airflow/providers/amazon/aws/operators/s3_file_transform.py @@ -135,25 +135,24 @@ def execute(self, context): f_source.flush() if self.transform_script is not None: - process = subprocess.Popen( + with subprocess.Popen( [self.transform_script, f_source.name, f_dest.name, *self.script_args], stdout=subprocess.PIPE, stderr=subprocess.STDOUT, close_fds=True, - ) - - self.log.info("Output:") - for line in iter(process.stdout.readline, b''): - self.log.info(line.decode(self.output_encoding).rstrip()) - - process.wait() - - if process.returncode: - raise AirflowException(f"Transform script failed: {process.returncode}") - else: - self.log.info( - "Transform script successful. 
Output temporarily located at %s", f_dest.name - ) + ) as process: + self.log.info("Output:") + for line in iter(process.stdout.readline, b''): + self.log.info(line.decode(self.output_encoding).rstrip()) + + process.wait() + + if process.returncode: + raise AirflowException(f"Transform script failed: {process.returncode}") + else: + self.log.info( + "Transform script successful. Output temporarily located at %s", f_dest.name + ) self.log.info("Uploading transformed file to S3") f_dest.flush() diff --git a/airflow/providers/amazon/aws/transfers/dynamodb_to_s3.py b/airflow/providers/amazon/aws/transfers/dynamodb_to_s3.py index 539640aaafdfd..9c63278231ad4 100644 --- a/airflow/providers/amazon/aws/transfers/dynamodb_to_s3.py +++ b/airflow/providers/amazon/aws/transfers/dynamodb_to_s3.py @@ -121,16 +121,15 @@ def execute(self, context) -> None: table = AwsDynamoDBHook().get_conn().Table(self.dynamodb_table_name) scan_kwargs = copy(self.dynamodb_scan_kwargs) if self.dynamodb_scan_kwargs else {} err = None - f = NamedTemporaryFile() - try: - f = self._scan_dynamodb_and_upload_to_s3(f, scan_kwargs, table) - except Exception as e: - err = e - raise e - finally: - if err is None: - _upload_file_to_s3(f, self.s3_bucket_name, self.s3_key_prefix) - f.close() + with NamedTemporaryFile() as f: + try: + f = self._scan_dynamodb_and_upload_to_s3(f, scan_kwargs, table) + except Exception as e: + err = e + raise e + finally: + if err is None: + _upload_file_to_s3(f, self.s3_bucket_name, self.s3_key_prefix) def _scan_dynamodb_and_upload_to_s3(self, temp_file: IO, scan_kwargs: dict, table: Any) -> IO: while True: @@ -150,5 +149,6 @@ def _scan_dynamodb_and_upload_to_s3(self, temp_file: IO, scan_kwargs: dict, tabl if getsize(temp_file.name) >= self.file_size: _upload_file_to_s3(temp_file, self.s3_bucket_name, self.s3_key_prefix) temp_file.close() + # pylint: disable=consider-using-with temp_file = NamedTemporaryFile() return temp_file diff --git a/airflow/providers/apache/beam/hooks/beam.py b/airflow/providers/apache/beam/hooks/beam.py index 2436210e2bd01..1cf7a8ba15a80 100644 --- a/airflow/providers/apache/beam/hooks/beam.py +++ b/airflow/providers/apache/beam/hooks/beam.py @@ -94,6 +94,7 @@ def __init__( self.log.info("Running command: %s", " ".join(shlex.quote(c) for c in cmd)) self.process_line_callback = process_line_callback self.job_id: Optional[str] = None + # pylint: disable=consider-using-with self._proc = subprocess.Popen( cmd, shell=False, diff --git a/airflow/providers/apache/hive/transfers/hive_to_mysql.py b/airflow/providers/apache/hive/transfers/hive_to_mysql.py index f8856c020b818..dfd80f8a5ae8a 100644 --- a/airflow/providers/apache/hive/transfers/hive_to_mysql.py +++ b/airflow/providers/apache/hive/transfers/hive_to_mysql.py @@ -96,29 +96,20 @@ def execute(self, context): if self.hive_conf: hive_conf.update(self.hive_conf) if self.bulk_load: - tmp_file = NamedTemporaryFile() - hive.to_csv( - self.sql, - tmp_file.name, - delimiter='\t', - lineterminator='\n', - output_header=False, - hive_conf=hive_conf, - ) + with NamedTemporaryFile() as tmp_file: + hive.to_csv( + self.sql, + tmp_file.name, + delimiter='\t', + lineterminator='\n', + output_header=False, + hive_conf=hive_conf, + ) + mysql = self._call_preoperator() + mysql.bulk_load(table=self.mysql_table, tmp_file=tmp_file.name) else: hive_results = hive.get_records(self.sql, hive_conf=hive_conf) - - mysql = MySqlHook(mysql_conn_id=self.mysql_conn_id) - - if self.mysql_preoperator: - self.log.info("Running MySQL preoperator") - 
mysql.run(self.mysql_preoperator) - - self.log.info("Inserting rows into MySQL") - if self.bulk_load: - mysql.bulk_load(table=self.mysql_table, tmp_file=tmp_file.name) - tmp_file.close() - else: + mysql = self._call_preoperator() mysql.insert_rows(table=self.mysql_table, rows=hive_results) if self.mysql_postoperator: @@ -126,3 +117,11 @@ def execute(self, context): mysql.run(self.mysql_postoperator) self.log.info("Done.") + + def _call_preoperator(self): + mysql = MySqlHook(mysql_conn_id=self.mysql_conn_id) + if self.mysql_preoperator: + self.log.info("Running MySQL preoperator") + mysql.run(self.mysql_preoperator) + self.log.info("Inserting rows into MySQL") + return mysql diff --git a/airflow/providers/apache/hive/transfers/mssql_to_hive.py b/airflow/providers/apache/hive/transfers/mssql_to_hive.py index 090a70285af77..3a277ba0c1832 100644 --- a/airflow/providers/apache/hive/transfers/mssql_to_hive.py +++ b/airflow/providers/apache/hive/transfers/mssql_to_hive.py @@ -101,12 +101,13 @@ def __init__( self.tblproperties = tblproperties @classmethod + # pylint: disable=c-extension-no-member,no-member def type_map(cls, mssql_type: int) -> str: """Maps MsSQL type to Hive type.""" map_dict = { - pymssql.BINARY.value: 'INT', # pylint: disable=c-extension-no-member - pymssql.DECIMAL.value: 'FLOAT', # pylint: disable=c-extension-no-member - pymssql.NUMBER.value: 'INT', # pylint: disable=c-extension-no-member + pymssql.BINARY.value: 'INT', + pymssql.DECIMAL.value: 'FLOAT', + pymssql.NUMBER.value: 'INT', } return map_dict.get(mssql_type, 'STRING') diff --git a/airflow/providers/apache/pinot/hooks/pinot.py b/airflow/providers/apache/pinot/hooks/pinot.py index 984742d0f4585..b20d0262446f0 100644 --- a/airflow/providers/apache/pinot/hooks/pinot.py +++ b/airflow/providers/apache/pinot/hooks/pinot.py @@ -225,26 +225,25 @@ def run_cli(self, cmd: List[str], verbose: bool = True) -> str: if verbose: self.log.info(" ".join(command)) - sub_process = subprocess.Popen( + with subprocess.Popen( command, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, close_fds=True, env=env - ) - - stdout = "" - if sub_process.stdout: - for line in iter(sub_process.stdout.readline, b''): - stdout += line.decode("utf-8") - if verbose: - self.log.info(line.decode("utf-8").strip()) - - sub_process.wait() - - # As of Pinot v0.1.0, either of "Error: ..." or "Exception caught: ..." - # is expected to be in the output messages. See: - # https://github.com/apache/incubator-pinot/blob/release-0.1.0/pinot-tools/src/main/java/org/apache/pinot/tools/admin/PinotAdministrator.java#L98-L101 - if (self.pinot_admin_system_exit and sub_process.returncode) or ( - "Error" in stdout or "Exception" in stdout - ): - raise AirflowException(stdout) + ) as sub_process: + stdout = "" + if sub_process.stdout: + for line in iter(sub_process.stdout.readline, b''): + stdout += line.decode("utf-8") + if verbose: + self.log.info(line.decode("utf-8").strip()) + + sub_process.wait() + + # As of Pinot v0.1.0, either of "Error: ..." or "Exception caught: ..." + # is expected to be in the output messages. 
See: + # https://github.com/apache/incubator-pinot/blob/release-0.1.0/pinot-tools/src/main/java/org/apache/pinot/tools/admin/PinotAdministrator.java#L98-L101 + if (self.pinot_admin_system_exit and sub_process.returncode) or ( + "Error" in stdout or "Exception" in stdout + ): + raise AirflowException(stdout) return stdout diff --git a/airflow/providers/apache/spark/hooks/spark_sql.py b/airflow/providers/apache/spark/hooks/spark_sql.py index ce3fdbc3f3126..b690f2cf78637 100644 --- a/airflow/providers/apache/spark/hooks/spark_sql.py +++ b/airflow/providers/apache/spark/hooks/spark_sql.py @@ -158,6 +158,7 @@ def run_query(self, cmd: str = "", **kwargs: Any) -> None: :type kwargs: dict """ spark_sql_cmd = self._prepare_command(cmd) + # pylint: disable=consider-using-with self._sp = subprocess.Popen(spark_sql_cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, **kwargs) for line in iter(self._sp.stdout): # type: ignore diff --git a/airflow/providers/apache/spark/hooks/spark_submit.py b/airflow/providers/apache/spark/hooks/spark_submit.py index 62839ba901c8b..6655fc88604ff 100644 --- a/airflow/providers/apache/spark/hooks/spark_submit.py +++ b/airflow/providers/apache/spark/hooks/spark_submit.py @@ -426,6 +426,7 @@ def submit(self, application: str = "", **kwargs: Any) -> None: env.update(self._env) kwargs["env"] = env + # pylint: disable=consider-using-with self._submit_sp = subprocess.Popen( spark_submit_cmd, stdout=subprocess.PIPE, @@ -644,11 +645,12 @@ def on_kill(self) -> None: self.log.info('Killing driver %s on cluster', self._driver_id) kill_cmd = self._build_spark_driver_kill_command() - driver_kill = subprocess.Popen(kill_cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE) - - self.log.info( - "Spark driver %s killed with return code: %s", self._driver_id, driver_kill.wait() - ) + with subprocess.Popen( + kill_cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE + ) as driver_kill: + self.log.info( + "Spark driver %s killed with return code: %s", self._driver_id, driver_kill.wait() + ) if self._submit_sp and self._submit_sp.poll() is None: self.log.info('Sending kill signal to %s', self._connection['spark_binary']) @@ -665,11 +667,10 @@ def on_kill(self) -> None: env = os.environ.copy() env["KRB5CCNAME"] = airflow_conf.get('kerberos', 'ccache') - yarn_kill = subprocess.Popen( + with subprocess.Popen( kill_cmd, env=env, stdout=subprocess.PIPE, stderr=subprocess.PIPE - ) - - self.log.info("YARN app killed with return code: %s", yarn_kill.wait()) + ) as yarn_kill: + self.log.info("YARN app killed with return code: %s", yarn_kill.wait()) if self._kubernetes_driver_pod: self.log.info('Killing pod %s on Kubernetes', self._kubernetes_driver_pod) diff --git a/airflow/providers/apache/sqoop/hooks/sqoop.py b/airflow/providers/apache/sqoop/hooks/sqoop.py index cab6063370fb4..ed5378af20a07 100644 --- a/airflow/providers/apache/sqoop/hooks/sqoop.py +++ b/airflow/providers/apache/sqoop/hooks/sqoop.py @@ -82,7 +82,6 @@ def __init__( self.num_mappers = num_mappers self.properties = properties or {} self.log.info("Using connection to: %s:%s/%s", self.conn.host, self.conn.port, self.conn.schema) - self.sub_process: Any = None def get_conn(self) -> Any: return self.conn @@ -107,17 +106,13 @@ def popen(self, cmd: List[str], **kwargs: Any) -> None: """ masked_cmd = ' '.join(self.cmd_mask_password(cmd)) self.log.info("Executing command: %s", masked_cmd) - self.sub_process = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, **kwargs) - - for line in iter(self.sub_process.stdout): # 
type: ignore - self.log.info(line.strip()) - - self.sub_process.wait() - - self.log.info("Command exited with return code %s", self.sub_process.returncode) - - if self.sub_process.returncode: - raise AirflowException(f"Sqoop command failed: {masked_cmd}") + with subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, **kwargs) as sub_process: + for line in iter(sub_process.stdout): # type: ignore + self.log.info(line.strip()) + sub_process.wait() + self.log.info("Command exited with return code %s", sub_process.returncode) + if sub_process.returncode: + raise AirflowException(f"Sqoop command failed: {masked_cmd}") def _prepare_command(self, export: bool = False) -> List[str]: sqoop_cmd_type = "export" if export else "import" diff --git a/airflow/providers/ftp/hooks/ftp.py b/airflow/providers/ftp/hooks/ftp.py index 48996d28128b6..a03e461a6169a 100644 --- a/airflow/providers/ftp/hooks/ftp.py +++ b/airflow/providers/ftp/hooks/ftp.py @@ -174,6 +174,7 @@ def write_to_file_with_progress(data): # file-like buffer if not callback: if is_path: + # pylint: disable=consider-using-with output_handle = open(local_full_path_or_buffer, 'wb') else: output_handle = local_full_path_or_buffer @@ -209,6 +210,7 @@ def store_file(self, remote_full_path: str, local_full_path_or_buffer: Any) -> N is_path = isinstance(local_full_path_or_buffer, str) if is_path: + # pylint: disable=consider-using-with input_handle = open(local_full_path_or_buffer, 'rb') else: input_handle = local_full_path_or_buffer diff --git a/airflow/providers/google/cloud/hooks/cloud_sql.py b/airflow/providers/google/cloud/hooks/cloud_sql.py index 51812b8878747..0ccc0ebdc3945 100644 --- a/airflow/providers/google/cloud/hooks/cloud_sql.py +++ b/airflow/providers/google/cloud/hooks/cloud_sql.py @@ -567,6 +567,7 @@ def start_proxy(self) -> None: Path(self.cloud_sql_proxy_socket_directory).mkdir(parents=True, exist_ok=True) command_to_run.extend(self._get_credential_parameters()) # pylint: disable=no-value-for-parameter self.log.info("Running the command: `%s`", " ".join(command_to_run)) + # pylint: disable=consider-using-with self.sql_proxy_process = Popen(command_to_run, stdin=PIPE, stdout=PIPE, stderr=PIPE) self.log.info("The pid of cloud_sql_proxy: %s", self.sql_proxy_process.pid) while True: diff --git a/airflow/providers/google/cloud/operators/gcs.py b/airflow/providers/google/cloud/operators/gcs.py index 91a8b57627ea3..8c8f241690fba 100644 --- a/airflow/providers/google/cloud/operators/gcs.py +++ b/airflow/providers/google/cloud/operators/gcs.py @@ -638,17 +638,17 @@ def execute(self, context: dict) -> None: self.log.info("Starting the transformation") cmd = [self.transform_script] if isinstance(self.transform_script, str) else self.transform_script cmd += [source_file.name, destination_file.name] - process = subprocess.Popen( + with subprocess.Popen( args=cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, close_fds=True - ) - self.log.info("Process output:") - if process.stdout: - for line in iter(process.stdout.readline, b''): - self.log.info(line.decode(self.output_encoding).rstrip()) + ) as process: + self.log.info("Process output:") + if process.stdout: + for line in iter(process.stdout.readline, b''): + self.log.info(line.decode(self.output_encoding).rstrip()) - process.wait() - if process.returncode: - raise AirflowException(f"Transform script failed: {process.returncode}") + process.wait() + if process.returncode: + raise AirflowException(f"Transform script failed: {process.returncode}") 
self.log.info("Transformation succeeded. Output temporarily located at %s", destination_file.name) @@ -865,17 +865,17 @@ def execute(self, context: dict) -> None: timespan_start.replace(microsecond=0).isoformat(), timespan_end.replace(microsecond=0).isoformat(), ] - process = subprocess.Popen( + with subprocess.Popen( args=cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, close_fds=True - ) - self.log.info("Process output:") - if process.stdout: - for line in iter(process.stdout.readline, b''): - self.log.info(line.decode(self.output_encoding).rstrip()) - - process.wait() - if process.returncode: - raise AirflowException(f"Transform script failed: {process.returncode}") + ) as process: + self.log.info("Process output:") + if process.stdout: + for line in iter(process.stdout.readline, b''): + self.log.info(line.decode(self.output_encoding).rstrip()) + + process.wait() + if process.returncode: + raise AirflowException(f"Transform script failed: {process.returncode}") self.log.info("Transformation succeeded. Output temporarily located at %s", temp_output_dir) diff --git a/airflow/providers/google/cloud/transfers/cassandra_to_gcs.py b/airflow/providers/google/cloud/transfers/cassandra_to_gcs.py index db83f9b5a47cd..bc348db9f7cf9 100644 --- a/airflow/providers/google/cloud/transfers/cassandra_to_gcs.py +++ b/airflow/providers/google/cloud/transfers/cassandra_to_gcs.py @@ -185,6 +185,7 @@ def execute(self, context: Dict[str, str]): # Close all sessions and connection associated with this Cassandra cluster hook.shutdown_cluster() + # pylint: disable=consider-using-with def _write_local_data_files(self, cursor): """ Takes a cursor, and writes results to a local file. @@ -211,6 +212,7 @@ def _write_local_data_files(self, cursor): return tmp_file_handles + # pylint: disable=consider-using-with def _write_local_schema_file(self, cursor): """ Takes a cursor, and writes the BigQuery schema for the results to a diff --git a/airflow/providers/google/cloud/transfers/sql_to_gcs.py b/airflow/providers/google/cloud/transfers/sql_to_gcs.py index ec77112c3da7d..e62364a073851 100644 --- a/airflow/providers/google/cloud/transfers/sql_to_gcs.py +++ b/airflow/providers/google/cloud/transfers/sql_to_gcs.py @@ -184,6 +184,7 @@ def _write_local_data_files(self, cursor): schema = list(map(lambda schema_tuple: schema_tuple[0], cursor.description)) col_type_dict = self._get_col_type_dict() file_no = 0 + # pylint: disable=consider-using-with tmp_file_handle = NamedTemporaryFile(delete=True) if self.export_format == 'csv': file_mime_type = 'text/csv' @@ -234,6 +235,7 @@ def _write_local_data_files(self, cursor): # Stop if the file exceeds the file size limit. 
if tmp_file_handle.tell() >= self.approx_max_file_size_bytes: file_no += 1 + # pylint: disable=consider-using-with tmp_file_handle = NamedTemporaryFile(delete=True) files_to_upload.append( { @@ -337,6 +339,7 @@ def _write_local_schema_file(self, cursor): self.log.info('Using schema for %s', self.schema_filename) self.log.debug("Current schema: %s", schema) + # pylint: disable=consider-using-with tmp_schema_file_handle = NamedTemporaryFile(delete=True) tmp_schema_file_handle.write(schema.encode('utf-8')) schema_file_to_upload = { diff --git a/airflow/providers/microsoft/mssql/hooks/mssql.py b/airflow/providers/microsoft/mssql/hooks/mssql.py index 928d0c45d3fec..eebce5eb466cd 100644 --- a/airflow/providers/microsoft/mssql/hooks/mssql.py +++ b/airflow/providers/microsoft/mssql/hooks/mssql.py @@ -38,13 +38,13 @@ def __init__(self, *args, **kwargs) -> None: def get_conn( self, - ) -> pymssql.connect: # pylint: disable=protected-access # pylint: disable=c-extension-no-member + ) -> pymssql.connect: # pylint: disable=protected-access,c-extension-no-member,no-member """Returns a mssql connection object""" conn = self.get_connection( self.mssql_conn_id # type: ignore[attr-defined] # pylint: disable=no-member ) # pylint: disable=c-extension-no-member - conn = pymssql.connect( + conn = pymssql.connect( # pylint: disable=no-member server=conn.host, user=conn.login, password=conn.password, @@ -55,10 +55,10 @@ def get_conn( def set_autocommit( self, - conn: pymssql.connect, # pylint: disable=c-extension-no-member + conn: pymssql.connect, # pylint: disable=c-extension-no-member, no-member autocommit: bool, ) -> None: conn.autocommit(autocommit) - def get_autocommit(self, conn: pymssql.connect): # pylint: disable=c-extension-no-member + def get_autocommit(self, conn: pymssql.connect): # pylint: disable=c-extension-no-member, no-member return conn.autocommit_state diff --git a/airflow/providers/mysql/transfers/vertica_to_mysql.py b/airflow/providers/mysql/transfers/vertica_to_mysql.py index 85da9ad787476..819cc26bb69ad 100644 --- a/airflow/providers/mysql/transfers/vertica_to_mysql.py +++ b/airflow/providers/mysql/transfers/vertica_to_mysql.py @@ -104,17 +104,16 @@ def execute(self, context): selected_columns = [d.name for d in cursor.description] if self.bulk_load: - tmpfile = NamedTemporaryFile("w") + with NamedTemporaryFile("w") as tmpfile: + self.log.info("Selecting rows from Vertica to local file %s...", tmpfile.name) + self.log.info(self.sql) - self.log.info("Selecting rows from Vertica to local file %s...", tmpfile.name) - self.log.info(self.sql) - - csv_writer = csv.writer(tmpfile, delimiter='\t', encoding='utf-8') - for row in cursor.iterate(): - csv_writer.writerow(row) - count += 1 + csv_writer = csv.writer(tmpfile, delimiter='\t', encoding='utf-8') + for row in cursor.iterate(): + csv_writer.writerow(row) + count += 1 - tmpfile.flush() + tmpfile.flush() else: self.log.info("Selecting rows from Vertica...") self.log.info(self.sql) diff --git a/airflow/providers/qubole/hooks/qubole.py b/airflow/providers/qubole/hooks/qubole.py index 6655de5f0f6d8..99e6027408709 100644 --- a/airflow/providers/qubole/hooks/qubole.py +++ b/airflow/providers/qubole/hooks/qubole.py @@ -202,6 +202,7 @@ def kill(self, ti): self.log.info('Sending KILL signal to Qubole Command Id: %s', self.cmd.id) self.cmd.cancel() + # pylint: disable=consider-using-with def get_results(self, ti=None, fp=None, inline: bool = True, delim=None, fetch: bool = True) -> str: """ Get results (or just s3 locations) of a command from Qubole and 
save into a file diff --git a/airflow/security/kerberos.py b/airflow/security/kerberos.py index f1ddfdf3c3845..299c327fe91e8 100644 --- a/airflow/security/kerberos.py +++ b/airflow/security/kerberos.py @@ -73,26 +73,26 @@ def renew_from_kt(principal: str, keytab: str, exit_on_fail: bool = True): ] log.info("Re-initialising kerberos from keytab: %s", " ".join(cmdv)) - subp = subprocess.Popen( + with subprocess.Popen( cmdv, stdout=subprocess.PIPE, stderr=subprocess.PIPE, close_fds=True, bufsize=-1, universal_newlines=True, - ) - subp.wait() - if subp.returncode != 0: - log.error( - "Couldn't reinit from keytab! `kinit' exited with %s.\n%s\n%s", - subp.returncode, - "\n".join(subp.stdout.readlines() if subp.stdout else []), - "\n".join(subp.stderr.readlines() if subp.stderr else []), - ) - if exit_on_fail: - sys.exit(subp.returncode) - else: - return subp.returncode + ) as subp: + subp.wait() + if subp.returncode != 0: + log.error( + "Couldn't reinit from keytab! `kinit' exited with %s.\n%s\n%s", + subp.returncode, + "\n".join(subp.stdout.readlines() if subp.stdout else []), + "\n".join(subp.stderr.readlines() if subp.stderr else []), + ) + if exit_on_fail: + sys.exit(subp.returncode) + else: + return subp.returncode global NEED_KRB181_WORKAROUND # pylint: disable=global-statement if NEED_KRB181_WORKAROUND is None: diff --git a/airflow/sensors/bash.py b/airflow/sensors/bash.py index 1b428049c699d..21c6ea98c7248 100644 --- a/airflow/sensors/bash.py +++ b/airflow/sensors/bash.py @@ -66,7 +66,8 @@ def poke(self, context): script_location = tmp_dir + "/" + fname self.log.info("Temporary script location: %s", script_location) self.log.info("Running command: %s", bash_command) - resp = Popen( # pylint: disable=subprocess-popen-preexec-fn + # pylint: disable=subprocess-popen-preexec-fn + with Popen( ['bash', fname], stdout=PIPE, stderr=STDOUT, @@ -74,13 +75,12 @@ def poke(self, context): cwd=tmp_dir, env=self.env, preexec_fn=os.setsid, - ) + ) as resp: + self.log.info("Output:") + for line in iter(resp.stdout.readline, b''): + line = line.decode(self.output_encoding).strip() + self.log.info(line) + resp.wait() + self.log.info("Command exited with return code %s", resp.returncode) - self.log.info("Output:") - for line in iter(resp.stdout.readline, b''): - line = line.decode(self.output_encoding).strip() - self.log.info(line) - resp.wait() - self.log.info("Command exited with return code %s", resp.returncode) - - return not resp.returncode + return not resp.returncode diff --git a/airflow/task/task_runner/base_task_runner.py b/airflow/task/task_runner/base_task_runner.py index 07e862537891d..c8035b04a86e2 100644 --- a/airflow/task/task_runner/base_task_runner.py +++ b/airflow/task/task_runner/base_task_runner.py @@ -84,6 +84,7 @@ def __init__(self, local_task_job): # - the runner can read/execute those values as it needs cfg_path = tmp_configuration_copy(chmod=0o600) + # pylint: disable=consider-using-with self._error_file = NamedTemporaryFile(delete=True) self._cfg_path = cfg_path self._command = ( @@ -132,7 +133,7 @@ def run_command(self, run_with=None): self.log.info("Running on host: %s", get_hostname()) self.log.info('Running: %s', full_cmd) - # pylint: disable=subprocess-popen-preexec-fn + # pylint: disable=subprocess-popen-preexec-fn,consider-using-with proc = subprocess.Popen( full_cmd, stdout=subprocess.PIPE, diff --git a/airflow/utils/file.py b/airflow/utils/file.py index 96515a03633e5..c5dca17d53793 100644 --- a/airflow/utils/file.py +++ b/airflow/utils/file.py @@ -41,6 +41,7 @@ def 
TemporaryDirectory(*args, **kwargs): # pylint: disable=invalid-name DeprecationWarning, stacklevel=2, ) + # pylint: disable=consider-using-with return TmpDir(*args, **kwargs) @@ -90,6 +91,7 @@ def open_maybe_zipped(fileloc, mode='r'): if archive and zipfile.is_zipfile(archive): return io.TextIOWrapper(zipfile.ZipFile(archive, mode=mode).open(filename)) else: + # pylint: disable=consider-using-with return open(fileloc, mode=mode) diff --git a/airflow/utils/process_utils.py b/airflow/utils/process_utils.py index 38607bddeeed7..b76ca7dd715ed 100644 --- a/airflow/utils/process_utils.py +++ b/airflow/utils/process_utils.py @@ -133,14 +133,16 @@ def execute_in_subprocess(cmd: List[str]): :type cmd: List[str] """ log.info("Executing cmd: %s", " ".join([shlex.quote(c) for c in cmd])) - proc = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, bufsize=0, close_fds=True) - log.info("Output:") - if proc.stdout: - with proc.stdout: - for line in iter(proc.stdout.readline, b''): - log.info("%s", line.decode().rstrip()) - - exit_code = proc.wait() + with subprocess.Popen( + cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, bufsize=0, close_fds=True + ) as proc: + log.info("Output:") + if proc.stdout: + with proc.stdout: + for line in iter(proc.stdout.readline, b''): + log.info("%s", line.decode().rstrip()) + + exit_code = proc.wait() if exit_code != 0: raise subprocess.CalledProcessError(exit_code, cmd) @@ -160,19 +162,18 @@ def execute_interactive(cmd: List[str], **kwargs): master_fd, slave_fd = pty.openpty() try: # pylint: disable=too-many-nested-blocks # use os.setsid() make it run in a new process group, or bash job control will not be enabled - proc = subprocess.Popen( + with subprocess.Popen( cmd, stdin=slave_fd, stdout=slave_fd, stderr=slave_fd, universal_newlines=True, **kwargs - ) - - while proc.poll() is None: - readable_fbs, _, _ = select.select([sys.stdin, master_fd], [], []) - if sys.stdin in readable_fbs: - input_data = os.read(sys.stdin.fileno(), 10240) - os.write(master_fd, input_data) - if master_fd in readable_fbs: - output_data = os.read(master_fd, 10240) - if output_data: - os.write(sys.stdout.fileno(), output_data) + ) as proc: + while proc.poll() is None: + readable_fbs, _, _ = select.select([sys.stdin, master_fd], [], []) + if sys.stdin in readable_fbs: + input_data = os.read(sys.stdin.fileno(), 10240) + os.write(master_fd, input_data) + if master_fd in readable_fbs: + output_data = os.read(master_fd, 10240) + if output_data: + os.write(sys.stdout.fileno(), output_data) finally: # restore tty settings back termios.tcsetattr(sys.stdin, termios.TCSADRAIN, old_tty) diff --git a/setup.py b/setup.py index 6582428c7c04e..0139f821bdd80 100644 --- a/setup.py +++ b/setup.py @@ -502,7 +502,7 @@ def get_sphinx_theme_version() -> str: 'paramiko', 'pipdeptree', 'pre-commit', - 'pylint~=2.7.4', + 'pylint~=2.8.1', 'pysftp', 'pytest~=6.0', 'pytest-cov', diff --git a/tests/cli/commands/test_dag_command.py b/tests/cli/commands/test_dag_command.py index ed696b083acac..df2ea76f51cb4 100644 --- a/tests/cli/commands/test_dag_command.py +++ b/tests/cli/commands/test_dag_command.py @@ -182,14 +182,17 @@ def test_show_dag_dave(self, mock_render_dag): @mock.patch("airflow.cli.commands.dag_command.render_dag") def test_show_dag_imgcat(self, mock_render_dag, mock_popen): mock_render_dag.return_value.pipe.return_value = b"DOT_DATA" - mock_popen.return_value.communicate.return_value = (b"OUT", b"ERR") + mock_proc = mock.MagicMock() + mock_proc.returncode = 0 + 
mock_proc.communicate.return_value = (b"OUT", b"ERR") + mock_popen.return_value.__enter__.return_value = mock_proc with contextlib.redirect_stdout(io.StringIO()) as temp_stdout: dag_command.dag_show( self.parser.parse_args(['dags', 'show', 'example_bash_operator', '--imgcat']) ) out = temp_stdout.getvalue() mock_render_dag.return_value.pipe.assert_called_once_with(format='png') - mock_popen.return_value.communicate.assert_called_once_with(b'DOT_DATA') + mock_proc.communicate.assert_called_once_with(b'DOT_DATA') assert "OUT" in out assert "ERR" in out diff --git a/tests/cli/commands/test_user_command.py b/tests/cli/commands/test_user_command.py index 867ff9d91d4d1..9e94045ede15f 100644 --- a/tests/cli/commands/test_user_command.py +++ b/tests/cli/commands/test_user_command.py @@ -248,21 +248,21 @@ def find_by_username(username): def _import_users_from_file(self, user_list): json_file_content = json.dumps(user_list) - f = tempfile.NamedTemporaryFile(delete=False) - try: - f.write(json_file_content.encode()) - f.flush() + with tempfile.NamedTemporaryFile(delete=False) as f: + try: + f.write(json_file_content.encode()) + f.flush() - args = self.parser.parse_args(['users', 'import', f.name]) - user_command.users_import(args) - finally: - os.remove(f.name) + args = self.parser.parse_args(['users', 'import', f.name]) + user_command.users_import(args) + finally: + os.remove(f.name) def _export_users_to_file(self): - f = tempfile.NamedTemporaryFile(delete=False) - args = self.parser.parse_args(['users', 'export', f.name]) - user_command.users_export(args) - return f.name + with tempfile.NamedTemporaryFile(delete=False) as f: + args = self.parser.parse_args(['users', 'export', f.name]) + user_command.users_export(args) + return f.name def test_cli_add_user_role(self): args = self.parser.parse_args( diff --git a/tests/cli/commands/test_variable_command.py b/tests/cli/commands/test_variable_command.py index 8b64ff8a66899..cd497f2f9c491 100644 --- a/tests/cli/commands/test_variable_command.py +++ b/tests/cli/commands/test_variable_command.py @@ -132,30 +132,31 @@ def test_variables_export(self): def test_variables_isolation(self): """Test isolation of variables""" - tmp1 = tempfile.NamedTemporaryFile(delete=True) - tmp2 = tempfile.NamedTemporaryFile(delete=True) + with tempfile.NamedTemporaryFile(delete=True) as tmp1, tempfile.NamedTemporaryFile( + delete=True + ) as tmp2: - # First export - variable_command.variables_set(self.parser.parse_args(['variables', 'set', 'foo', '{"foo":"bar"}'])) - variable_command.variables_set(self.parser.parse_args(['variables', 'set', 'bar', 'original'])) - variable_command.variables_export(self.parser.parse_args(['variables', 'export', tmp1.name])) - - first_exp = open(tmp1.name) + # First export + variable_command.variables_set( + self.parser.parse_args(['variables', 'set', 'foo', '{"foo":"bar"}']) + ) + variable_command.variables_set(self.parser.parse_args(['variables', 'set', 'bar', 'original'])) + variable_command.variables_export(self.parser.parse_args(['variables', 'export', tmp1.name])) - variable_command.variables_set(self.parser.parse_args(['variables', 'set', 'bar', 'updated'])) - variable_command.variables_set(self.parser.parse_args(['variables', 'set', 'foo', '{"foo":"oops"}'])) - variable_command.variables_delete(self.parser.parse_args(['variables', 'delete', 'foo'])) - variable_command.variables_import(self.parser.parse_args(['variables', 'import', tmp1.name])) + with open(tmp1.name) as first_exp: - assert 'original' == Variable.get('bar') - assert '{\n 
"foo": "bar"\n}' == Variable.get('foo') + variable_command.variables_set(self.parser.parse_args(['variables', 'set', 'bar', 'updated'])) + variable_command.variables_set( + self.parser.parse_args(['variables', 'set', 'foo', '{"foo":"oops"}']) + ) + variable_command.variables_delete(self.parser.parse_args(['variables', 'delete', 'foo'])) + variable_command.variables_import(self.parser.parse_args(['variables', 'import', tmp1.name])) - # Second export - variable_command.variables_export(self.parser.parse_args(['variables', 'export', tmp2.name])) + assert 'original' == Variable.get('bar') + assert '{\n "foo": "bar"\n}' == Variable.get('foo') - second_exp = open(tmp2.name) - assert first_exp.read() == second_exp.read() + # Second export + variable_command.variables_export(self.parser.parse_args(['variables', 'export', tmp2.name])) - # Clean up files - second_exp.close() - first_exp.close() + with open(tmp2.name) as second_exp: + assert first_exp.read() == second_exp.read() diff --git a/tests/cli/commands/test_webserver_command.py b/tests/cli/commands/test_webserver_command.py index 36ac81a342373..5e44c39f5ea92 100644 --- a/tests/cli/commands/test_webserver_command.py +++ b/tests/cli/commands/test_webserver_command.py @@ -287,6 +287,7 @@ def test_cli_webserver_foreground(self): AIRFLOW__WEBSERVER__WORKERS="1", ): # Run webserver in foreground and terminate it. + # pylint: disable=consider-using-with proc = subprocess.Popen(["airflow", "webserver"]) assert proc.poll() is None @@ -309,6 +310,7 @@ def test_cli_webserver_foreground_with_pid(self): AIRFLOW__CORE__LOAD_EXAMPLES="False", AIRFLOW__WEBSERVER__WORKERS="1", ): + # pylint: disable=consider-using-with proc = subprocess.Popen(["airflow", "webserver", "--pid", pidfile]) assert proc.poll() is None @@ -334,6 +336,7 @@ def test_cli_webserver_background(self): logfile = f"{tmpdir}/airflow-webserver.log" try: # Run webserver as daemon in background. Note that the wait method is not called. + # pylint: disable=consider-using-with proc = subprocess.Popen( [ "airflow", @@ -409,6 +412,7 @@ def test_cli_webserver_access_log_format(self): ): access_logfile = f"{tmpdir}/access.log" # Run webserver in foreground and terminate it. 
+ # pylint: disable=consider-using-with proc = subprocess.Popen( [ "airflow", @@ -424,11 +428,12 @@ def test_cli_webserver_access_log_format(self): # Wait for webserver process time.sleep(10) + # pylint: disable=consider-using-with proc2 = subprocess.Popen(["curl", "http://localhost:8080"]) proc2.wait(10) try: - file = open(access_logfile) - log = json.loads(file.read()) + with open(access_logfile) as file: + log = json.loads(file.read()) assert '127.0.0.1' == log.get('remote_ip') assert len(log) == 9 assert 'GET' == log.get('request_method') diff --git a/tests/models/test_dagbag.py b/tests/models/test_dagbag.py index ae03db9cc604a..3c9589640009b 100644 --- a/tests/models/test_dagbag.py +++ b/tests/models/test_dagbag.py @@ -130,12 +130,12 @@ def test_process_file_that_contains_multi_bytes_char(self): """ test that we're able to parse file that contains multi-byte char """ - f = NamedTemporaryFile() - f.write('\u3042'.encode()) # write multi-byte char (hiragana) - f.flush() + with NamedTemporaryFile() as f: + f.write('\u3042'.encode()) # write multi-byte char (hiragana) + f.flush() - dagbag = models.DagBag(dag_folder=self.empty_dir, include_examples=False) - assert [] == dagbag.process_file(f.name) + dagbag = models.DagBag(dag_folder=self.empty_dir, include_examples=False) + assert [] == dagbag.process_file(f.name) def test_zip_skip_log(self): """ @@ -285,13 +285,13 @@ def process_dag(self, create_dag): """ # write source to file source = textwrap.dedent(''.join(inspect.getsource(create_dag).splitlines(True)[1:-1])) - f = NamedTemporaryFile() - f.write(source.encode('utf8')) - f.flush() + with NamedTemporaryFile() as f: + f.write(source.encode('utf8')) + f.flush() - dagbag = models.DagBag(dag_folder=self.empty_dir, include_examples=False) - found_dags = dagbag.process_file(f.name) - return dagbag, found_dags, f.name + dagbag = models.DagBag(dag_folder=self.empty_dir, include_examples=False) + found_dags = dagbag.process_file(f.name) + return dagbag, found_dags, f.name def validate_dags(self, expected_parent_dag, actual_found_dags, actual_dagbag, should_be_found=True): expected_dag_ids = list(map(lambda dag: dag.dag_id, expected_parent_dag.subdags)) diff --git a/tests/providers/amazon/aws/operators/test_s3_file_transform.py b/tests/providers/amazon/aws/operators/test_s3_file_transform.py index a5df0e3e9c0a0..3b749438cb71b 100644 --- a/tests/providers/amazon/aws/operators/test_s3_file_transform.py +++ b/tests/providers/amazon/aws/operators/test_s3_file_transform.py @@ -135,10 +135,11 @@ def test_execute_with_select_expression(self, mock_select_key): @staticmethod def mock_process(mock_popen, return_code=0, process_output=None): - process = mock_popen.return_value - process.stdout.readline.side_effect = process_output or [] - process.wait.return_value = None - process.returncode = return_code + mock_proc = mock.MagicMock() + mock_proc.returncode = return_code + mock_proc.stdout.readline.side_effect = process_output or [] + mock_proc.wait.return_value = None + mock_popen.return_value.__enter__.return_value = mock_proc def s3_paths(self): conn = boto3.client('s3') diff --git a/tests/providers/apache/hive/transfers/test_hive_to_mysql.py b/tests/providers/apache/hive/transfers/test_hive_to_mysql.py index e8d75ba05c37f..c1fddd23202ed 100644 --- a/tests/providers/apache/hive/transfers/test_hive_to_mysql.py +++ b/tests/providers/apache/hive/transfers/test_hive_to_mysql.py @@ -18,7 +18,7 @@ import os import re import unittest -from unittest.mock import MagicMock, PropertyMock, patch +from 
unittest.mock import MagicMock, patch from airflow.providers.apache.hive.transfers.hive_to_mysql import HiveToMySqlOperator from airflow.utils import timezone @@ -73,26 +73,28 @@ def test_execute_with_mysql_postoperator(self, mock_hive_hook, mock_mysql_hook): @patch('airflow.providers.apache.hive.transfers.hive_to_mysql.MySqlHook') @patch('airflow.providers.apache.hive.transfers.hive_to_mysql.NamedTemporaryFile') @patch('airflow.providers.apache.hive.transfers.hive_to_mysql.HiveServer2Hook') - def test_execute_bulk_load(self, mock_hive_hook, mock_tmp_file, mock_mysql_hook): - type(mock_tmp_file).name = PropertyMock(return_value='tmp_file') + def test_execute_bulk_load(self, mock_hive_hook, mock_tmp_file_context, mock_mysql_hook): + mock_tmp_file = MagicMock() + mock_tmp_file.name = 'tmp_file' + mock_tmp_file_context.return_value.__enter__.return_value = mock_tmp_file context = {} self.kwargs.update(dict(bulk_load=True)) HiveToMySqlOperator(**self.kwargs).execute(context=context) - mock_tmp_file.assert_called_once_with() + mock_tmp_file_context.assert_called_once_with() mock_hive_hook.return_value.to_csv.assert_called_once_with( self.kwargs['sql'], - mock_tmp_file.return_value.name, + 'tmp_file', delimiter='\t', lineterminator='\n', output_header=False, hive_conf=context_to_airflow_vars(context), ) mock_mysql_hook.return_value.bulk_load.assert_called_once_with( - table=self.kwargs['mysql_table'], tmp_file=mock_tmp_file.return_value.name + table=self.kwargs['mysql_table'], tmp_file='tmp_file' ) - mock_tmp_file.return_value.close.assert_called_once_with() + mock_tmp_file_context.return_value.__exit__.assert_called_once_with(None, None, None) @patch('airflow.providers.apache.hive.transfers.hive_to_mysql.MySqlHook') def test_execute_with_hive_conf(self, mock_mysql_hook): diff --git a/tests/providers/apache/hive/transfers/test_mssql_to_hive.py b/tests/providers/apache/hive/transfers/test_mssql_to_hive.py index 953affd7cba8a..99455dab0fb17 100644 --- a/tests/providers/apache/hive/transfers/test_mssql_to_hive.py +++ b/tests/providers/apache/hive/transfers/test_mssql_to_hive.py @@ -41,19 +41,19 @@ def setUp(self): self.kwargs = dict(sql='sql', hive_table='table', task_id='test_mssql_to_hive', dag=None) def test_type_map_binary(self): - # pylint: disable=c-extension-no-member + # pylint: disable=c-extension-no-member, no-member mapped_type = MsSqlToHiveOperator(**self.kwargs).type_map(pymssql.BINARY.value) assert mapped_type == 'INT' def test_type_map_decimal(self): - # pylint: disable=c-extension-no-member + # pylint: disable=c-extension-no-member, no-member mapped_type = MsSqlToHiveOperator(**self.kwargs).type_map(pymssql.DECIMAL.value) assert mapped_type == 'FLOAT' def test_type_map_number(self): - # pylint: disable=c-extension-no-member + # pylint: disable=c-extension-no-member, no-member mapped_type = MsSqlToHiveOperator(**self.kwargs).type_map(pymssql.NUMBER.value) assert mapped_type == 'INT' diff --git a/tests/providers/apache/pinot/hooks/test_pinot.py b/tests/providers/apache/pinot/hooks/test_pinot.py index e7e6a5139f478..af0676f75e9f3 100644 --- a/tests/providers/apache/pinot/hooks/test_pinot.py +++ b/tests/providers/apache/pinot/hooks/test_pinot.py @@ -161,7 +161,7 @@ def test_run_cli_success(self, mock_popen): mock_proc = mock.MagicMock() mock_proc.returncode = 0 mock_proc.stdout = io.BytesIO(b'') - mock_popen.return_value = mock_proc + mock_popen.return_value.__enter__.return_value = mock_proc params = ["foo", "bar", "baz"] self.db_hook.run_cli(params) @@ -176,8 +176,7 @@ def 
test_run_cli_failure_error_message(self, mock_popen): mock_proc = mock.MagicMock() mock_proc.returncode = 0 mock_proc.stdout = io.BytesIO(msg) - mock_popen.return_value = mock_proc - + mock_popen.return_value.__enter__.return_value = mock_proc params = ["foo", "bar", "baz"] with pytest.raises(AirflowException): self.db_hook.run_cli(params) @@ -191,7 +190,7 @@ def test_run_cli_failure_status_code(self, mock_popen): mock_proc = mock.MagicMock() mock_proc.returncode = 1 mock_proc.stdout = io.BytesIO(b'') - mock_popen.return_value = mock_proc + mock_popen.return_value.__enter__.return_value = mock_proc self.db_hook.pinot_admin_system_exit = True params = ["foo", "bar", "baz"] diff --git a/tests/providers/apache/sqoop/hooks/test_sqoop.py b/tests/providers/apache/sqoop/hooks/test_sqoop.py index 08926d4568503..8c72ca9696e94 100644 --- a/tests/providers/apache/sqoop/hooks/test_sqoop.py +++ b/tests/providers/apache/sqoop/hooks/test_sqoop.py @@ -21,6 +21,7 @@ import json import unittest from io import StringIO +from unittest import mock from unittest.mock import call, patch import pytest @@ -95,13 +96,15 @@ def setUp(self): @patch('subprocess.Popen') def test_popen(self, mock_popen): # Given - mock_popen.return_value.stdout = StringIO('stdout') - mock_popen.return_value.stderr = StringIO('stderr') - mock_popen.return_value.returncode = 0 - mock_popen.return_value.communicate.return_value = [ + mock_proc = mock.MagicMock() + mock_proc.returncode = 0 + mock_proc.stdout = StringIO('stdout') + mock_proc.stderr = StringIO('stderr') + mock_proc.communicate.return_value = [ StringIO('stdout\nstdout'), StringIO('stderr\nstderr'), ] + mock_popen.return_value.__enter__.return_value = mock_proc # When hook = SqoopHook(conn_id='sqoop_test') diff --git a/tests/providers/google/cloud/hooks/test_gcs.py b/tests/providers/google/cloud/hooks/test_gcs.py index f4b50bc14bee3..8af62ef4762c5 100644 --- a/tests/providers/google/cloud/hooks/test_gcs.py +++ b/tests/providers/google/cloud/hooks/test_gcs.py @@ -781,6 +781,7 @@ def setUp(self): self.gcs_hook = gcs.GCSHook(gcp_conn_id='test') # generate a 384KiB test file (larger than the minimum 256KiB multipart chunk size) + # pylint: disable=consider-using-with self.testfile = tempfile.NamedTemporaryFile(delete=False) self.testfile.write(b"x" * 393216) self.testfile.flush() diff --git a/tests/providers/google/cloud/operators/test_gcs.py b/tests/providers/google/cloud/operators/test_gcs.py index 00ccebbd7995f..cac11ccf03f74 100644 --- a/tests/providers/google/cloud/operators/test_gcs.py +++ b/tests/providers/google/cloud/operators/test_gcs.py @@ -160,6 +160,7 @@ class TestGCSFileTransformOperator(unittest.TestCase): @mock.patch("airflow.providers.google.cloud.operators.gcs.subprocess") @mock.patch("airflow.providers.google.cloud.operators.gcs.GCSHook") def test_execute(self, mock_hook, mock_subprocess, mock_tempfile): + source_bucket = TEST_BUCKET source_object = "test.txt" destination_bucket = TEST_BUCKET + "-dest" @@ -177,11 +178,16 @@ def test_execute(self, mock_hook, mock_subprocess, mock_tempfile): mock_tempfile.return_value.__enter__.side_effect = [mock1, mock2] + mock_proc = mock.MagicMock() + mock_proc.returncode = 0 + mock_proc.stdout.readline = lambda: b"" + mock_proc.wait.return_value = None + mock_popen = mock.MagicMock() + mock_popen.return_value.__enter__.return_value = mock_proc + + mock_subprocess.Popen = mock_popen mock_subprocess.PIPE = "pipe" mock_subprocess.STDOUT = "stdout" - mock_subprocess.Popen.return_value.stdout.readline = lambda: b"" - 
mock_subprocess.Popen.return_value.wait.return_value = None - mock_subprocess.Popen.return_value.returncode = 0 op = GCSFileTransformOperator( task_id=TASK_ID, @@ -278,11 +284,16 @@ def test_execute(self, mock_hook, mock_subprocess, mock_tempdir): f"{source_prefix}/{file2}", ] + mock_proc = mock.MagicMock() + mock_proc.returncode = 0 + mock_proc.stdout.readline = lambda: b"" + mock_proc.wait.return_value = None + mock_popen = mock.MagicMock() + mock_popen.return_value.__enter__.return_value = mock_proc + + mock_subprocess.Popen = mock_popen mock_subprocess.PIPE = "pipe" mock_subprocess.STDOUT = "stdout" - mock_subprocess.Popen.return_value.stdout.readline = lambda: b"" - mock_subprocess.Popen.return_value.wait.return_value = None - mock_subprocess.Popen.return_value.returncode = 0 op = GCSTimeSpanFileTransformOperator( task_id=TASK_ID, diff --git a/tests/test_utils/logging_command_executor.py b/tests/test_utils/logging_command_executor.py index 5fca2446e76b2..6a7968755cd35 100644 --- a/tests/test_utils/logging_command_executor.py +++ b/tests/test_utils/logging_command_executor.py @@ -31,35 +31,35 @@ def execute_cmd(self, cmd, silent=False, cwd=None, env=None): return subprocess.call(args=cmd, stdout=dev_null, stderr=subprocess.STDOUT, env=env, cwd=cwd) else: self.log.info("Executing: '%s'", " ".join([shlex.quote(c) for c in cmd])) - process = subprocess.Popen( + with subprocess.Popen( args=cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, universal_newlines=True, cwd=cwd, env=env, - ) - output, err = process.communicate() - retcode = process.poll() - self.log.info("Stdout: %s", output) - self.log.info("Stderr: %s", err) - if retcode: - self.log.error("Error when executing %s", " ".join([shlex.quote(c) for c in cmd])) - return retcode + ) as process: + output, err = process.communicate() + retcode = process.poll() + self.log.info("Stdout: %s", output) + self.log.info("Stderr: %s", err) + if retcode: + self.log.error("Error when executing %s", " ".join([shlex.quote(c) for c in cmd])) + return retcode def check_output(self, cmd): self.log.info("Executing for output: '%s'", " ".join([shlex.quote(c) for c in cmd])) - process = subprocess.Popen(args=cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE) - output, err = process.communicate() - retcode = process.poll() - if retcode: - self.log.error("Error when executing '%s'", " ".join([shlex.quote(c) for c in cmd])) - self.log.info("Stdout: %s", output) - self.log.info("Stderr: %s", err) - raise AirflowException( - f"Retcode {retcode} on {' '.join(cmd)} with stdout: {output}, stderr: {err}" - ) - return output + with subprocess.Popen(args=cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE) as process: + output, err = process.communicate() + retcode = process.poll() + if retcode: + self.log.error("Error when executing '%s'", " ".join([shlex.quote(c) for c in cmd])) + self.log.info("Stdout: %s", output) + self.log.info("Stderr: %s", err) + raise AirflowException( + f"Retcode {retcode} on {' '.join(cmd)} with stdout: {output}, stderr: {err}" + ) + return output def get_executor() -> LoggingCommandExecutor: diff --git a/tests/utils/test_email.py b/tests/utils/test_email.py index b680fdc46fdc5..28d43284ee8d7 100644 --- a/tests/utils/test_email.py +++ b/tests/utils/test_email.py @@ -128,22 +128,22 @@ def test_build_mime_message(self): class TestEmailSmtp(unittest.TestCase): @mock.patch('airflow.utils.email.send_mime_email') def test_send_smtp(self, mock_send_mime): - attachment = tempfile.NamedTemporaryFile() - attachment.write(b'attachment') - 
attachment.seek(0) - utils.email.send_email_smtp('to', 'subject', 'content', files=[attachment.name]) - assert mock_send_mime.called - _, call_args = mock_send_mime.call_args - assert conf.get('smtp', 'SMTP_MAIL_FROM') == call_args['e_from'] - assert ['to'] == call_args['e_to'] - msg = call_args['mime_msg'] - assert 'subject' == msg['Subject'] - assert conf.get('smtp', 'SMTP_MAIL_FROM') == msg['From'] - assert 2 == len(msg.get_payload()) - filename = 'attachment; filename="' + os.path.basename(attachment.name) + '"' - assert filename == msg.get_payload()[-1].get('Content-Disposition') - mimeapp = MIMEApplication('attachment') - assert mimeapp.get_payload() == msg.get_payload()[-1].get_payload() + with tempfile.NamedTemporaryFile() as attachment: + attachment.write(b'attachment') + attachment.seek(0) + utils.email.send_email_smtp('to', 'subject', 'content', files=[attachment.name]) + assert mock_send_mime.called + _, call_args = mock_send_mime.call_args + assert conf.get('smtp', 'SMTP_MAIL_FROM') == call_args['e_from'] + assert ['to'] == call_args['e_to'] + msg = call_args['mime_msg'] + assert 'subject' == msg['Subject'] + assert conf.get('smtp', 'SMTP_MAIL_FROM') == msg['From'] + assert 2 == len(msg.get_payload()) + filename = 'attachment; filename="' + os.path.basename(attachment.name) + '"' + assert filename == msg.get_payload()[-1].get('Content-Disposition') + mimeapp = MIMEApplication('attachment') + assert mimeapp.get_payload() == msg.get_payload()[-1].get_payload() @mock.patch('airflow.utils.email.send_mime_email') def test_send_smtp_with_multibyte_content(self, mock_send_mime): @@ -156,23 +156,25 @@ def test_send_smtp_with_multibyte_content(self, mock_send_mime): @mock.patch('airflow.utils.email.send_mime_email') def test_send_bcc_smtp(self, mock_send_mime): - attachment = tempfile.NamedTemporaryFile() - attachment.write(b'attachment') - attachment.seek(0) - utils.email.send_email_smtp('to', 'subject', 'content', files=[attachment.name], cc='cc', bcc='bcc') - assert mock_send_mime.called - _, call_args = mock_send_mime.call_args - assert conf.get('smtp', 'SMTP_MAIL_FROM') == call_args['e_from'] - assert ['to', 'cc', 'bcc'] == call_args['e_to'] - msg = call_args['mime_msg'] - assert 'subject' == msg['Subject'] - assert conf.get('smtp', 'SMTP_MAIL_FROM') == msg['From'] - assert 2 == len(msg.get_payload()) - assert 'attachment; filename="' + os.path.basename(attachment.name) + '"' == msg.get_payload()[ - -1 - ].get('Content-Disposition') - mimeapp = MIMEApplication('attachment') - assert mimeapp.get_payload() == msg.get_payload()[-1].get_payload() + with tempfile.NamedTemporaryFile() as attachment: + attachment.write(b'attachment') + attachment.seek(0) + utils.email.send_email_smtp( + 'to', 'subject', 'content', files=[attachment.name], cc='cc', bcc='bcc' + ) + assert mock_send_mime.called + _, call_args = mock_send_mime.call_args + assert conf.get('smtp', 'SMTP_MAIL_FROM') == call_args['e_from'] + assert ['to', 'cc', 'bcc'] == call_args['e_to'] + msg = call_args['mime_msg'] + assert 'subject' == msg['Subject'] + assert conf.get('smtp', 'SMTP_MAIL_FROM') == msg['From'] + assert 2 == len(msg.get_payload()) + assert 'attachment; filename="' + os.path.basename(attachment.name) + '"' == msg.get_payload()[ + -1 + ].get('Content-Disposition') + mimeapp = MIMEApplication('attachment') + assert mimeapp.get_payload() == msg.get_payload()[-1].get_payload() @mock.patch('smtplib.SMTP_SSL') @mock.patch('smtplib.SMTP')
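For readers scanning the patch, two changes recur in nearly every hunk above: production code now enters subprocess.Popen (and NamedTemporaryFile) as a context manager to satisfy the consider-using-with check introduced by pylint 2.8, and the corresponding tests attach their fake process to __enter__ of the patched Popen rather than to return_value directly. The following standalone Python sketch illustrates both patterns; it is not part of the patch, and the names run_and_log and test_run_and_log are invented for the example.

import subprocess
from unittest import mock


def run_and_log(cmd):
    """Run a command, streaming its output, with the handle managed by a context manager."""
    # Entering Popen as a context manager closes the pipes and reaps the
    # process even if an exception is raised while reading its output.
    with subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT) as proc:
        output = b""
        for line in iter(proc.stdout.readline, b''):
            output += line
            print(line.decode().rstrip())
        returncode = proc.wait()
    if returncode:
        raise RuntimeError(f"Command failed with return code {returncode}")
    return output


def test_run_and_log():
    """Mock Popen the way the updated tests do: the fake process is what __enter__ returns."""
    with mock.patch("subprocess.Popen") as mock_popen:
        mock_proc = mock.MagicMock()
        mock_proc.returncode = 0
        mock_proc.wait.return_value = 0
        mock_proc.stdout.readline.side_effect = [b"hello\n", b""]
        # The code under test uses `with Popen(...) as proc`, so the fake
        # process must be wired to __enter__, not to return_value itself.
        mock_popen.return_value.__enter__.return_value = mock_proc

        assert run_and_log(["echo", "hello"]) == b"hello\n"
        mock_proc.wait.assert_called_once_with()


if __name__ == "__main__":
    test_run_and_log()
    print("ok")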