diff --git a/README.rst b/README.rst index 4aa9f767db..522493a838 100644 --- a/README.rst +++ b/README.rst @@ -173,6 +173,9 @@ Some more companies are using Luigi but haven't had a chance yet to write about * `Okko `_ * `ISVWorld `_ * `Big Data `_ +* `Movio `_ +* `Bonnier News `_ +* `Starsky Robotics `_ We're more than happy to have your company added here. Just send a PR on GitHub. diff --git a/doc/aggregate_artists.png b/doc/aggregate_artists.png index f7a395668b..e434fd2895 100644 Binary files a/doc/aggregate_artists.png and b/doc/aggregate_artists.png differ diff --git a/doc/configuration.rst b/doc/configuration.rst index d152855ca9..f206d53652 100644 --- a/doc/configuration.rst +++ b/doc/configuration.rst @@ -4,9 +4,12 @@ Configuration All configuration can be done by adding configuration files. Supported config parsers: -* ``cfg`` (default) + +* ``cfg`` (default), based on Python's standard ConfigParser_. Values may refer to environment variables using ``${ENVVAR}`` syntax. * ``toml`` +.. _ConfigParser: https://docs.python.org/3/library/configparser.html + You can choose right parser via ``LUIGI_CONFIG_PARSER`` environment variable. For example, ``LUIGI_CONFIG_PARSER=toml``. Default (cfg) parser are looked for in: @@ -202,6 +205,51 @@ rpc-retry-wait Defaults to 30 +[cors] +------ + +.. versionadded:: 2.8.0 + +These parameters control ``/api/`` ``CORS`` behaviour (see: `W3C Cross-Origin Resource Sharing +`_). + +enabled + Enables CORS support. + Defaults to false. + +allowed_origins + A list of allowed origins. Used only if ``allow_any_origin`` is false. + Configure in JSON array format, e.g. ["foo", "bar"]. + Defaults to empty. + +allow_any_origin + Accepts requests from any origin. + Defaults to false. + +allow_null_origin + Allows the request to set ``null`` value of the ``Origin`` header. + Defaults to false. + +max_age + Content of ``Access-Control-Max-Age``. + Defaults to 86400 (24 hours). + +allowed_methods + Content of ``Access-Control-Allow-Methods``. 
+ Defaults to ``GET, OPTIONS``. + +allowed_headers + Content of ``Access-Control-Allow-Headers``. + Defaults to ``Accept, Content-Type, Origin``. + +exposed_headers + Content of ``Access-Control-Expose-Headers``. + Defaults to empty string (will NOT be sent as a response header). + +allow_credentials + Indicates that the actual request can include user credentials. + Defaults to false. + .. _worker-config: [worker] @@ -607,7 +655,7 @@ is good practice to do so when you have a fixed set of resources. .. _retcode-config: [retcode] ----------- +--------- Configure return codes for the Luigi binary. In the case of multiple return codes that could apply, for example a failing task and missing data, the diff --git a/doc/dependency_graph.png b/doc/dependency_graph.png index 4196f30c02..28531b4c9d 100644 Binary files a/doc/dependency_graph.png and b/doc/dependency_graph.png differ diff --git a/doc/execution_model.png b/doc/execution_model.png index 535eb31066..11a2b28ff9 100644 Binary files a/doc/execution_model.png and b/doc/execution_model.png differ diff --git a/doc/history.png b/doc/history.png index e8173fcf85..9aa3e2dcbb 100644 Binary files a/doc/history.png and b/doc/history.png differ diff --git a/doc/history_by_id.png b/doc/history_by_id.png index 97a90f1cc5..c1ede5b7fd 100644 Binary files a/doc/history_by_id.png and b/doc/history_by_id.png differ diff --git a/doc/history_by_name.png b/doc/history_by_name.png index 5bef1291a4..a54fad59bd 100644 Binary files a/doc/history_by_name.png and b/doc/history_by_name.png differ diff --git a/doc/luigi.png b/doc/luigi.png index 586387c3ea..8285638906 100644 Binary files a/doc/luigi.png and b/doc/luigi.png differ diff --git a/doc/parameters_date_algebra.png b/doc/parameters_date_algebra.png index 7da7d616f0..37705065aa 100644 Binary files a/doc/parameters_date_algebra.png and b/doc/parameters_date_algebra.png differ diff --git a/doc/parameters_enum.png b/doc/parameters_enum.png index 9af33a4d51..2a224e44c4 100644 Binary files 
a/doc/parameters_enum.png and b/doc/parameters_enum.png differ diff --git a/doc/parameters_recursion.png b/doc/parameters_recursion.png index 43b23c2155..61eac1e4a8 100644 Binary files a/doc/parameters_recursion.png and b/doc/parameters_recursion.png differ diff --git a/doc/task_breakdown.png b/doc/task_breakdown.png index eeb793d45c..f2628b8210 100644 Binary files a/doc/task_breakdown.png and b/doc/task_breakdown.png differ diff --git a/doc/task_parameters.png b/doc/task_parameters.png index a26c2a5a74..b5677b55e5 100644 Binary files a/doc/task_parameters.png and b/doc/task_parameters.png differ diff --git a/doc/task_with_targets.png b/doc/task_with_targets.png index 0354144bce..46ce7f70fb 100644 Binary files a/doc/task_with_targets.png and b/doc/task_with_targets.png differ diff --git a/doc/tasks_input_output_requires.png b/doc/tasks_input_output_requires.png index 06a551b100..09f3b8378b 100644 Binary files a/doc/tasks_input_output_requires.png and b/doc/tasks_input_output_requires.png differ diff --git a/doc/tasks_with_dependencies.png b/doc/tasks_with_dependencies.png index 7adcb26674..63536582f3 100644 Binary files a/doc/tasks_with_dependencies.png and b/doc/tasks_with_dependencies.png differ diff --git a/doc/user_recs.png b/doc/user_recs.png index 3d3044bdb9..60943bf181 100644 Binary files a/doc/user_recs.png and b/doc/user_recs.png differ diff --git a/doc/visualiser_front_page.png b/doc/visualiser_front_page.png index 8acb6891ee..e5bc87bdd8 100644 Binary files a/doc/visualiser_front_page.png and b/doc/visualiser_front_page.png differ diff --git a/doc/web_server.png b/doc/web_server.png index cad1541d73..20605d0766 100644 Binary files a/doc/web_server.png and b/doc/web_server.png differ diff --git a/examples/spark_als.py b/examples/spark_als.py index dae0479169..c24734b87b 100644 --- a/examples/spark_als.py +++ b/examples/spark_als.py @@ -31,7 +31,7 @@ class UserItemMatrix(luigi.Task): def run(self): """ Generates :py:attr:`~.UserItemMatrix.data_size` 
elements. - Writes this data in \ separated value format into the target :py:func:`~/.UserItemMatrix.output`. + Writes this data in \\ separated value format into the target :py:func:`~/.UserItemMatrix.output`. The data has the following elements: @@ -43,7 +43,7 @@ def run(self): w = self.output().open('w') for user in range(self.data_size): track = int(random.random() * self.data_size) - w.write('%d\%d\%f' % (user, track, 1.0)) + w.write('%d\\%d\\%f' % (user, track, 1.0)) w.close() def output(self): diff --git a/luigi/configuration/cfg_parser.py b/luigi/configuration/cfg_parser.py index e0df87f10a..c4545fe75a 100644 --- a/luigi/configuration/cfg_parser.py +++ b/luigi/configuration/cfg_parser.py @@ -30,16 +30,94 @@ """ import os +import re import warnings try: - from ConfigParser import ConfigParser, NoOptionError, NoSectionError + from ConfigParser import ConfigParser, NoOptionError, NoSectionError, InterpolationError + Interpolation = object except ImportError: - from configparser import ConfigParser, NoOptionError, NoSectionError + from configparser import ConfigParser, NoOptionError, NoSectionError, InterpolationError + from configparser import Interpolation, BasicInterpolation from .base_parser import BaseParser +class InterpolationMissingEnvvarError(InterpolationError): + """ + Raised when option value refers to a nonexisting environment variable. + """ + + def __init__(self, option, section, value, envvar): + msg = ( + "Config refers to a nonexisting environment variable {}. " + "Section [{}], option {}={}" + ).format(envvar, section, option, value) + InterpolationError.__init__(self, option, section, msg) + + +class EnvironmentInterpolation(Interpolation): + """ + Custom interpolation which allows values to refer to environment variables + using the ``${ENVVAR}`` syntax. 
+ """ + _ENVRE = re.compile(r"\$\{([^}]+)\}") # matches "${envvar}" + + def before_get(self, parser, section, option, value, defaults): + return self._interpolate_env(option, section, value) + + def _interpolate_env(self, option, section, value): + rawval = value + parts = [] + while value: + match = self._ENVRE.search(value) + if match is None: + parts.append(value) + break + envvar = match.groups()[0] + try: + envval = os.environ[envvar] + except KeyError: + raise InterpolationMissingEnvvarError( + option, section, rawval, envvar) + start, end = match.span() + parts.append(value[:start]) + parts.append(envval) + value = value[end:] + return "".join(parts) + + +class CombinedInterpolation(Interpolation): + """ + Custom interpolation which applies multiple interpolations in series. + + :param interpolations: a sequence of configparser.Interpolation objects. + """ + + def __init__(self, interpolations): + self._interpolations = interpolations + + def before_get(self, parser, section, option, value, defaults): + for interp in self._interpolations: + value = interp.before_get(parser, section, option, value, defaults) + return value + + def before_read(self, parser, section, option, value): + for interp in self._interpolations: + value = interp.before_read(parser, section, option, value) + return value + + def before_set(self, parser, section, option, value): + for interp in self._interpolations: + value = interp.before_set(parser, section, option, value) + return value + + def before_write(self, parser, section, option, value): + for interp in self._interpolations: + value = interp.before_write(parser, section, option, value) + return value + + class LuigiConfigParser(BaseParser, ConfigParser): NO_DEFAULT = object() enabled = True @@ -50,6 +128,17 @@ class LuigiConfigParser(BaseParser, ConfigParser): 'client.cfg', # Deprecated old-style local luigi config 'luigi.cfg', ] + if hasattr(ConfigParser, "_interpolate"): + # Override ConfigParser._interpolate (Python 2) + def 
_interpolate(self, section, option, rawval, vars): + value = ConfigParser._interpolate(self, section, option, rawval, vars) + return EnvironmentInterpolation().before_get( + parser=self, section=section, option=option, + value=value, defaults=None, + ) + else: + # Override ConfigParser._DEFAULT_INTERPOLATION (Python 3) + _DEFAULT_INTERPOLATION = CombinedInterpolation([BasicInterpolation(), EnvironmentInterpolation()]) @classmethod def reload(cls): diff --git a/luigi/contrib/ecs.py b/luigi/contrib/ecs.py index f563e73dc0..8612d00499 100644 --- a/luigi/contrib/ecs.py +++ b/luigi/contrib/ecs.py @@ -182,6 +182,11 @@ def run(self): response = client.run_task(taskDefinition=self.task_def_arn, overrides=overrides, cluster=self.cluster) + + if response['failures']: + raise Exception(", ".join(["fail to run task {0} reason: {1}".format(failure['arn'], failure['reason']) + for failure in response['failures']])) + self._task_ids = [task['taskArn'] for task in response['tasks']] # Wait on task completion diff --git a/luigi/contrib/external_program.py b/luigi/contrib/external_program.py index 4d726f9eef..0377e6847c 100644 --- a/luigi/contrib/external_program.py +++ b/luigi/contrib/external_program.py @@ -59,7 +59,7 @@ class ExternalProgramTask(luigi.Task): behaviour can be overriden by passing ``--capture-output False`` """ - capture_output = luigi.BoolParameter(default=True, significant=False) + capture_output = luigi.BoolParameter(default=True, significant=False, positional=False) def program_args(self): """ diff --git a/luigi/contrib/hdfs/format.py b/luigi/contrib/hdfs/format.py index 1856abf3f4..f87ebd5bf3 100644 --- a/luigi/contrib/hdfs/format.py +++ b/luigi/contrib/hdfs/format.py @@ -1,9 +1,10 @@ -import luigi.format import logging import os + +import luigi.format from luigi.contrib.hdfs.config import load_hadoop_cmd from luigi.contrib.hdfs import config as hdfs_config -from luigi.contrib.hdfs.clients import remove, rename, mkdir, listdir +from luigi.contrib.hdfs.clients 
import remove, rename, mkdir, listdir, exists from luigi.contrib.hdfs.error import HDFSCliError logger = logging.getLogger('luigi-interface') @@ -75,9 +76,13 @@ def abort(self): def close(self): super(HdfsAtomicWriteDirPipe, self).close() try: - remove(self.path) - except HDFSCliError: - pass + if exists(self.path): + remove(self.path) + except Exception as ex: + if isinstance(ex, HDFSCliError) or "FileNotFoundException" in str(ex): + pass + else: + raise # it's unlikely to fail in this way but better safe than sorry if not all(result['result'] for result in rename(self.tmppath, self.path) or []): diff --git a/luigi/contrib/hdfs/snakebite_client.py b/luigi/contrib/hdfs/snakebite_client.py index 5c38787f4f..b6cd7aedfc 100644 --- a/luigi/contrib/hdfs/snakebite_client.py +++ b/luigi/contrib/hdfs/snakebite_client.py @@ -131,7 +131,7 @@ def remove(self, path, recursive=True, skip_trash=False): :param path: delete-able file(s) or directory(ies) :type path: either a string or a sequence of strings - :param recursive: delete directories trees like \*nix: rm -r + :param recursive: delete directories trees like \\*nix: rm -r :type recursive: boolean, default is True :param skip_trash: do or don't move deleted items into the trash first :type skip_trash: boolean, default is False (use trash) @@ -145,7 +145,7 @@ def chmod(self, path, permissions, recursive=False): :param path: update-able file(s) :type path: either a string or sequence of strings - :param permissions: \*nix style permission number + :param permissions: \\*nix style permission number :type permissions: octal :param recursive: change just listed entry(ies) or all in directories :type recursive: boolean, default is False @@ -242,7 +242,7 @@ def mkdir(self, path, parents=True, mode=0o755, raise_if_exists=False): :type path: string :param parents: create any missing parent directories :type parents: boolean, default is True - :param mode: \*nix style owner/group/other permissions + :param mode: \\*nix
style owner/group/other permissions :type mode: octal, default 0755 """ result = list(self.get_bite().mkdir(self.list_path(path), diff --git a/luigi/contrib/hdfs/target.py b/luigi/contrib/hdfs/target.py index 0655b23e0f..9beba56190 100644 --- a/luigi/contrib/hdfs/target.py +++ b/luigi/contrib/hdfs/target.py @@ -185,3 +185,38 @@ def _is_writable(self, path): return True except hdfs_clients.HDFSCliError: return False + + +class HdfsFlagTarget(HdfsTarget): + """ + Defines a target directory with a flag-file (defaults to `_SUCCESS`) used + to signify job success. + + This checks for two things: + + * the path exists (just like the HdfsTarget) + * the _SUCCESS file exists within the directory. + + Because Hadoop outputs into a directory and not a single file, + the path is assumed to be a directory. + """ + def __init__(self, path, format=None, client=None, flag='_SUCCESS'): + """ + Initializes a HdfsFlagTarget. + + :param path: the directory where the files are stored. + :type path: str + :param client: + :type client: + :param flag: + :type flag: str + """ + if not path.endswith("/"): + raise ValueError("HdfsFlagTarget requires the path to be to a " + "directory. It must end with a slash ( / ).") + super(HdfsFlagTarget, self).__init__(path, format, client) + self.flag = flag + + def exists(self): + hadoopSemaphore = self.path + self.flag + return self.fs.exists(hadoopSemaphore) diff --git a/luigi/contrib/mysqldb.py b/luigi/contrib/mysqldb.py index 9e198fcfc6..dabef5ff73 100644 --- a/luigi/contrib/mysqldb.py +++ b/luigi/contrib/mysqldb.py @@ -19,14 +19,16 @@ import luigi +from luigi.contrib import rdbms + logger = logging.getLogger('luigi-interface') try: import mysql.connector - from mysql.connector import errorcode + from mysql.connector import errorcode, Error except ImportError as e: logger.warning("Loading MySQL module without the python package mysql-connector-python.
\ - This will crash at runtime if MySQL functionality is used.") + This will crash at runtime if MySQL functionality is used.") class MySqlTarget(luigi.Target): @@ -147,3 +149,102 @@ def create_marker_table(self): else: raise connection.close() + + +class CopyToTable(rdbms.CopyToTable): + """ + Template task for inserting a data set into MySQL + + Usage: + Subclass and override the required `host`, `database`, `user`, + `password`, `table` and `columns` attributes. + + To customize how to access data from an input task, override the `rows` method + with a generator that yields each row as a tuple with fields ordered according to `columns`. + """ + + def rows(self): + """ + Return/yield tuples or lists corresponding to each row to be inserted. + """ + with self.input().open('r') as fobj: + for line in fobj: + yield line.strip('\n').split('\t') + +# everything below will rarely have to be overridden + + def output(self): + """ + Returns a MySqlTarget representing the inserted dataset. + + Normally you don't override this. + """ + return MySqlTarget( + host=self.host, + database=self.database, + user=self.user, + password=self.password, + table=self.table, + update_id=self.update_id + + ) + + def copy(self, cursor, file=None): + values = '({})'.format(','.join(['%s' for i in range(len(self.columns))])) + columns = '({})'.format(','.join([c[0] for c in self.columns])) + query = 'INSERT INTO {} {} VALUES {}'.format(self.table, columns, values) + rows = [] + + for idx, row in enumerate(self.rows()): + rows.append(row) + + if (idx + 1) % self.bulk_size == 0: + cursor.executemany(query, rows) + rows = [] + + cursor.executemany(query, rows) + + def run(self): + """ + Inserts data generated by rows() into target table. + + If the target table doesn't exist, self.create_table will be called to attempt to create the table. + + Normally you don't want to override this. 
+ """ + if not (self.table and self.columns): + raise Exception("table and columns need to be specified") + + connection = self.output().connect() + + # attempt to copy the data into mysql + # if it fails because the target table doesn't exist + # try to create it by running self.create_table + for attempt in range(2): + try: + cursor = connection.cursor() + logger.debug("calling init copy...") + self.init_copy(connection) + self.copy(cursor) + self.post_copy(connection) + if self.enable_metadata_columns: + self.post_copy_metacolumns(cursor) + except Error as err: + if err.errno == errorcode.ER_NO_SUCH_TABLE and attempt == 0: + # if first attempt fails with "relation not found", try creating table + # logger.info("Creating table %s", self.table) + connection.reconnect() + self.create_table(connection) + else: + raise + else: + break + + # mark as complete in same transaction + self.output().touch(connection) + connection.commit() + connection.close() + + @property + def bulk_size(self): + return 10000 diff --git a/luigi/contrib/opener.py b/luigi/contrib/opener.py index 7607583c52..93c360b9de 100644 --- a/luigi/contrib/opener.py +++ b/luigi/contrib/opener.py @@ -189,7 +189,7 @@ def get_target(cls, scheme, path, fragment, username, the expected target. """ - raise NotImplemented("get_target must be overridden") + raise NotImplementedError("get_target must be overridden") class MockOpener(Opener): diff --git a/luigi/contrib/s3.py b/luigi/contrib/s3.py index 9fb21d283f..1ccc438d7c 100644 --- a/luigi/contrib/s3.py +++ b/luigi/contrib/s3.py @@ -31,9 +31,6 @@ import warnings from multiprocessing.pool import ThreadPool -import botocore -from boto3.s3.transfer import TransferConfig - try: from urlparse import urlsplit except ImportError: @@ -54,6 +51,13 @@ logger = logging.getLogger('luigi-interface') +try: + from boto3.s3.transfer import TransferConfig + import botocore +except ImportError: + logger.warning("Loading S3 module without the python package boto3.
" + "Will crash at runtime if S3 functionality is used.") + # two different ways of marking a directory # with a suffix in S3 diff --git a/luigi/server.py b/luigi/server.py index 79a696cc82..6c97cb7427 100644 --- a/luigi/server.py +++ b/luigi/server.py @@ -36,12 +36,12 @@ # import atexit +import datetime import json import logging import os import signal import sys -import datetime import time import pkg_resources @@ -50,21 +50,64 @@ import tornado.netutil import tornado.web +from luigi import Config, parameter from luigi.scheduler import Scheduler, RPC_METHODS logger = logging.getLogger("luigi.server") +class cors(Config): + enabled = parameter.BoolParameter( + default=False, + description='Enables CORS support.') + allowed_origins = parameter.ListParameter( + default=[], + description='A list of allowed origins. Used only if `allow_any_origin` is false.') + allow_any_origin = parameter.BoolParameter( + default=False, + description='Accepts requests from any origin.') + allow_null_origin = parameter.BoolParameter( + default=False, + description='Allows the request to set `null` value of the `Origin` header.') + max_age = parameter.IntParameter( + default=86400, + description='Content of `Access-Control-Max-Age`.') + allowed_methods = parameter.Parameter( + default='GET, OPTIONS', + description='Content of `Access-Control-Allow-Methods`.') + allowed_headers = parameter.Parameter( + default='Accept, Content-Type, Origin', + description='Content of `Access-Control-Allow-Headers`.') + exposed_headers = parameter.Parameter( + default='', + description='Content of `Access-Control-Expose-Headers`.') + allow_credentials = parameter.BoolParameter( + default=False, + description='Indicates that the actual request can include user credentials.') + + def __init__(self, *args, **kwargs): + super(cors, self).__init__(*args, **kwargs) + self.allowed_origins = set(i for i in self.allowed_origins if i not in ['*', 'null']) + + class RPCHandler(tornado.web.RequestHandler): """ 
Handle remote scheduling calls using rpc.RemoteSchedulerResponder. """ + def __init__(self, *args, **kwargs): + super(RPCHandler, self).__init__(*args, **kwargs) + self._cors_config = cors() + def initialize(self, scheduler): self._scheduler = scheduler - self.set_header("Access-Control-Allow-Headers", "Accept, Authorization, Content-Type, Origin") - self.set_header("Access-Control-Allow-Methods", "GET, OPTIONS") - self.set_header("Access-Control-Allow-Origin", "*") + + def options(self, *args): + if self._cors_config.enabled: + self._handle_cors_preflight() + + self.set_status(204) + self.finish() def get(self, method): if method not in RPC_METHODS: @@ -75,12 +118,57 @@ def get(self, method): if hasattr(self._scheduler, method): result = getattr(self._scheduler, method)(**arguments) + + if self._cors_config.enabled: + self._handle_cors() + self.write({"response": result}) # wrap all json response in a dictionary else: self.send_error(404) post = get + def _handle_cors_preflight(self): + origin = self.request.headers.get('Origin') + if not origin: + return + + if origin == 'null': + if self._cors_config.allow_null_origin: + self.set_header('Access-Control-Allow-Origin', 'null') + self._set_other_cors_headers() + else: + if self._cors_config.allow_any_origin: + self.set_header('Access-Control-Allow-Origin', '*') + self._set_other_cors_headers() + elif origin in self._cors_config.allowed_origins: + self.set_header('Access-Control-Allow-Origin', origin) + self._set_other_cors_headers() + + def _handle_cors(self): + origin = self.request.headers.get('Origin') + if not origin: + return + + if origin == 'null': + if self._cors_config.allow_null_origin: + self.set_header('Access-Control-Allow-Origin', 'null') + else: + if self._cors_config.allow_any_origin: + self.set_header('Access-Control-Allow-Origin', '*') + elif origin in self._cors_config.allowed_origins: + self.set_header('Access-Control-Allow-Origin', origin) + self.set_header('Vary', 'Origin') + + def 
_set_other_cors_headers(self): + self.set_header('Access-Control-Max-Age', str(self._cors_config.max_age)) + self.set_header('Access-Control-Allow-Methods', self._cors_config.allowed_methods) + self.set_header('Access-Control-Allow-Headers', self._cors_config.allowed_headers) + if self._cors_config.allow_credentials: + self.set_header('Access-Control-Allow-Credentials', 'true') + if self._cors_config.exposed_headers: + self.set_header('Access-Control-Expose-Headers', self._cors_config.exposed_headers) + class BaseTaskHistoryHandler(tornado.web.RequestHandler): def initialize(self, scheduler): diff --git a/luigi/static/visualiser/lib/jquery-ui/css/images/ui-icons_222222_256x240.png b/luigi/static/visualiser/lib/jquery-ui/css/images/ui-icons_222222_256x240.png index c1cb1170c8..4586de1a05 100755 Binary files a/luigi/static/visualiser/lib/jquery-ui/css/images/ui-icons_222222_256x240.png and b/luigi/static/visualiser/lib/jquery-ui/css/images/ui-icons_222222_256x240.png differ diff --git a/luigi/static/visualiser/lib/jquery-ui/css/images/ui-icons_2e83ff_256x240.png b/luigi/static/visualiser/lib/jquery-ui/css/images/ui-icons_2e83ff_256x240.png index 84b601bf0f..c4dfa36bd5 100755 Binary files a/luigi/static/visualiser/lib/jquery-ui/css/images/ui-icons_2e83ff_256x240.png and b/luigi/static/visualiser/lib/jquery-ui/css/images/ui-icons_2e83ff_256x240.png differ diff --git a/luigi/static/visualiser/lib/jquery-ui/css/images/ui-icons_454545_256x240.png b/luigi/static/visualiser/lib/jquery-ui/css/images/ui-icons_454545_256x240.png index b6db1acdd4..c39a7411b0 100755 Binary files a/luigi/static/visualiser/lib/jquery-ui/css/images/ui-icons_454545_256x240.png and b/luigi/static/visualiser/lib/jquery-ui/css/images/ui-icons_454545_256x240.png differ diff --git a/luigi/static/visualiser/lib/jquery-ui/css/images/ui-icons_888888_256x240.png b/luigi/static/visualiser/lib/jquery-ui/css/images/ui-icons_888888_256x240.png index feea0e2026..cef233a7fd 100755 Binary files 
a/luigi/static/visualiser/lib/jquery-ui/css/images/ui-icons_888888_256x240.png and b/luigi/static/visualiser/lib/jquery-ui/css/images/ui-icons_888888_256x240.png differ diff --git a/luigi/static/visualiser/lib/jquery-ui/css/images/ui-icons_cd0a0a_256x240.png b/luigi/static/visualiser/lib/jquery-ui/css/images/ui-icons_cd0a0a_256x240.png index ed5b6b0930..64733826fd 100755 Binary files a/luigi/static/visualiser/lib/jquery-ui/css/images/ui-icons_cd0a0a_256x240.png and b/luigi/static/visualiser/lib/jquery-ui/css/images/ui-icons_cd0a0a_256x240.png differ diff --git a/luigi/tools/deps_tree.py b/luigi/tools/deps_tree.py index 13a63164c9..27a00313e3 100755 --- a/luigi/tools/deps_tree.py +++ b/luigi/tools/deps_tree.py @@ -44,7 +44,7 @@ def print_tree(task, indent='', last=True): ''' # dont bother printing out warnings about tasks with no output with warnings.catch_warnings(): - warnings.filterwarnings(action='ignore', message='Task .* without outputs has no custom complete\(\) method') + warnings.filterwarnings(action='ignore', message='Task .* without outputs has no custom complete\\(\\) method') is_task_complete = task.complete() is_complete = (bcolors.OKGREEN + 'COMPLETE' if is_task_complete else bcolors.OKBLUE + 'PENDING') + bcolors.ENDC name = task.__class__.__name__ diff --git a/luigi/worker.py b/luigi/worker.py index 6cdaff1884..c7a24c3c5c 100644 --- a/luigi/worker.py +++ b/luigi/worker.py @@ -61,7 +61,7 @@ from luigi.task import Task, flatten, getpaths, Config from luigi.task_register import TaskClassException from luigi.task_status import RUNNING -from luigi.parameter import BoolParameter, FloatParameter, IntParameter, Parameter +from luigi.parameter import BoolParameter, FloatParameter, IntParameter, OptionalParameter try: import simplejson as json @@ -446,12 +446,12 @@ class worker(Config): force_multiprocessing = BoolParameter(default=False, description='If true, use multiprocessing also when ' 'running with 1 worker') - task_process_context = 
Parameter(default=None, - description='If set to a fully qualified class name, the class will ' - 'be instantiated with a TaskProcess as its constructor parameter and ' - 'applied as a context manager around its run() call, so this can be ' - 'used for obtaining high level customizable monitoring or logging of ' - 'each individual Task run.') + task_process_context = OptionalParameter(default=None, + description='If set to a fully qualified class name, the class will ' + 'be instantiated with a TaskProcess as its constructor parameter and ' + 'applied as a context manager around its run() call, so this can be ' + 'used for obtaining high level customizable monitoring or logging of ' + 'each individual Task run.') class KeepAliveThread(threading.Thread): diff --git a/setup.py b/setup.py index 0fa21e38a2..c015102b73 100644 --- a/setup.py +++ b/setup.py @@ -13,7 +13,6 @@ # the License. import os -import sys from setuptools import setup @@ -41,21 +40,19 @@ def get_static_files(path): 'tornado>=4.0,<5', # https://pagure.io/python-daemon/issue/18 'python-daemon<2.2.0', + 'enum34>1.1.0;python_version<"3.4"', ] if os.environ.get('READTHEDOCS', None) == 'True': # So that we can build documentation for luigi.db_task_history and luigi.contrib.sqla install_requires.append('sqlalchemy') # readthedocs don't like python-daemon, see #1342 - install_requires.remove('python-daemon<3.0') + install_requires.remove('python-daemon<2.2.0') install_requires.append('sphinx>=1.4.4') # Value mirrored in doc/conf.py -if sys.version_info < (3, 4): - install_requires.append('enum34>1.1.0') - setup( name='luigi', - version='2.7.9', + version='2.8.0', description='Workflow mgmgt + task scheduling + dependency resolution', long_description=long_description, author='The Luigi Authors', diff --git a/test/config_env_test.py b/test/config_env_test.py new file mode 100644 index 0000000000..d5ce8a53a0 --- /dev/null +++ b/test/config_env_test.py @@ -0,0 +1,77 @@ +# -*- coding: utf-8 -*- +# +# Copyright 
2018 Vote inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import os + +from luigi.configuration import LuigiConfigParser, get_config +from luigi.configuration.cfg_parser import InterpolationMissingEnvvarError + +from helpers import LuigiTestCase, with_config + + +class ConfigParserTest(LuigiTestCase): + + environ = { + "TESTVAR": "1", + } + + def setUp(self): + self.environ_backup = { + key: os.environ[key] for key in self.environ + if key in os.environ + } + for key, value in self.environ.items(): + os.environ[key] = value + LuigiConfigParser._instance = None + super(ConfigParserTest, self).setUp() + + def tearDown(self): + for key in self.environ: + os.environ.pop(key) + for key, value in self.environ_backup.items(): + os.environ[key] = value + + @with_config({"test": { + "a": "testval", + "b": "%(a)s", + "c": "%(a)s%(a)s", + }}) + def test_basic_interpolation(self): + # Make sure the default ConfigParser behaviour is not broken + config = get_config() + + self.assertEqual(config.get("test", "b"), config.get("test", "a")) + self.assertEqual(config.get("test", "c"), 2 * config.get("test", "a")) + + @with_config({"test": { + "a": "${TESTVAR}", + "b": "${TESTVAR} ${TESTVAR}", + "c": "${TESTVAR} %(a)s", + "d": "${NONEXISTING}", + }}) + def test_env_interpolation(self): + config = get_config() + + self.assertEqual(config.get("test", "a"), "1") + self.assertEqual(config.getint("test", "a"), 1) + self.assertEqual(config.getboolean("test", "a"), True) + +
self.assertEqual(config.get("test", "b"), "1 1") + + self.assertEqual(config.get("test", "c"), "1 1") + + with self.assertRaises(InterpolationMissingEnvvarError): + config.get("test", "d") diff --git a/test/contrib/hdfs_test.py b/test/contrib/hdfs_test.py index 8b0e2f5edd..9991d3c852 100644 --- a/test/contrib/hdfs_test.py +++ b/test/contrib/hdfs_test.py @@ -532,27 +532,45 @@ def test_tmppath_not_configured(self): res9 = hdfs.tmppath(path9, include_unix_username=False) # Then: I should get correct results relative to Luigi temporary directory - self.assertRegexpMatches(res1, "^/tmp/dir1/dir2/file-luigitemp-\d+") + self.assertRegexpMatches(res1, "^/tmp/dir1/dir2/file-luigitemp-\\d+") # it would be better to see hdfs:///path instead of hdfs:/path, but single slash also works well - self.assertRegexpMatches(res2, "^hdfs:/tmp/dir1/dir2/file-luigitemp-\d+") - self.assertRegexpMatches(res3, "^hdfs://somehost/tmp/dir1/dir2/file-luigitemp-\d+") - self.assertRegexpMatches(res4, "^file:///tmp/dir1/dir2/file-luigitemp-\d+") - self.assertRegexpMatches(res5, "^/tmp/dir/file-luigitemp-\d+") + self.assertRegexpMatches(res2, "^hdfs:/tmp/dir1/dir2/file-luigitemp-\\d+") + self.assertRegexpMatches(res3, "^hdfs://somehost/tmp/dir1/dir2/file-luigitemp-\\d+") + self.assertRegexpMatches(res4, "^file:///tmp/dir1/dir2/file-luigitemp-\\d+") + self.assertRegexpMatches(res5, "^/tmp/dir/file-luigitemp-\\d+") # known issue with duplicated "tmp" if schema is present - self.assertRegexpMatches(res6, "^file:///tmp/tmp/dir/file-luigitemp-\d+") + self.assertRegexpMatches(res6, "^file:///tmp/tmp/dir/file-luigitemp-\\d+") # known issue with duplicated "tmp" if schema is present - self.assertRegexpMatches(res7, "^hdfs://somehost/tmp/tmp/dir/file-luigitemp-\d+") - self.assertRegexpMatches(res8, "^/tmp/luigitemp-\d+") + self.assertRegexpMatches(res7, "^hdfs://somehost/tmp/tmp/dir/file-luigitemp-\\d+") + self.assertRegexpMatches(res8, "^/tmp/luigitemp-\\d+") self.assertRegexpMatches(res9, 
"/tmp/tmpdir/file") def test_tmppath_username(self): self.assertRegexpMatches(hdfs.tmppath('/path/to/stuff', include_unix_username=True), - "^/tmp/[a-z0-9_]+/path/to/stuff-luigitemp-\d+") + "^/tmp/[a-z0-9_]+/path/to/stuff-luigitemp-\\d+") def test_pickle(self): t = hdfs.HdfsTarget("/tmp/dir") pickle.dumps(t) + def test_flag_target(self): + target = hdfs.HdfsFlagTarget("/some/dir/", format=format) + if target.exists(): + target.remove(skip_trash=True) + self.assertFalse(target.exists()) + + t1 = hdfs.HdfsTarget(target.path + "part-00000", format=format) + with t1.open('w'): + pass + t2 = hdfs.HdfsTarget(target.path + "_SUCCESS", format=format) + with t2.open('w'): + pass + self.assertTrue(target.exists()) + + def test_flag_target_fails_if_not_directory(self): + with self.assertRaises(ValueError): + hdfs.HdfsFlagTarget("/home/file.txt") + @attr('minicluster') class HdfsTargetTest(MiniClusterTestCase, HdfsTargetTestMixin): diff --git a/test/contrib/mysqldb_test.py b/test/contrib/mysqldb_test.py new file mode 100644 index 0000000000..fe720225d1 --- /dev/null +++ b/test/contrib/mysqldb_test.py @@ -0,0 +1,124 @@ +from luigi.tools.range import RangeDaily + +import mock + +import luigi.contrib.mysqldb + +import datetime +from helpers import unittest + +from nose.plugins.attrib import attr + + +def datetime_to_epoch(dt): + td = dt - datetime.datetime(1970, 1, 1) + return td.days * 86400 + td.seconds + td.microseconds / 1E6 + + +class MockMysqlCursor(mock.Mock): + """ + Keeps state to simulate executing SELECT queries and fetching results. 
+ """ + def __init__(self, existing_update_ids): + super(MockMysqlCursor, self).__init__() + self.existing = existing_update_ids + + def execute(self, query, params): + if query.startswith('SELECT 1 FROM table_updates'): + self.fetchone_result = (1, ) if params[0] in self.existing else None + else: + self.fetchone_result = None + + def fetchone(self): + return self.fetchone_result + + +class DummyMysqlImporter(luigi.contrib.mysqldb.CopyToTable): + date = luigi.DateParameter() + + host = 'dummy_host' + database = 'dummy_database' + user = 'dummy_user' + password = 'dummy_password' + table = 'dummy_table' + columns = ( + ('some_text', 'text'), + ('some_int', 'int'), + ) + + +# Testing that an existing update will not be run in RangeDaily +@attr('mysql') +class DailyCopyToTableTest(unittest.TestCase): + + @mock.patch('mysql.connector.connect') + def test_bulk_complete(self, mock_connect): + mock_cursor = MockMysqlCursor([ # Existing update_ids + DummyMysqlImporter(date=datetime.datetime(2015, 1, 3)).task_id + ]) + mock_connect.return_value.cursor.return_value = mock_cursor + + task = RangeDaily(of=DummyMysqlImporter, + start=datetime.date(2015, 1, 2), + now=datetime_to_epoch(datetime.datetime(2015, 1, 7))) + actual = sorted([t.task_id for t in task.requires()]) + + self.assertEqual(actual, sorted([ + DummyMysqlImporter(date=datetime.datetime(2015, 1, 2)).task_id, + DummyMysqlImporter(date=datetime.datetime(2015, 1, 4)).task_id, + DummyMysqlImporter(date=datetime.datetime(2015, 1, 5)).task_id, + DummyMysqlImporter(date=datetime.datetime(2015, 1, 6)).task_id, + ])) + self.assertFalse(task.complete()) + + +@attr('mysql') +class TestCopyToTableWithMetaColumns(unittest.TestCase): + @mock.patch("luigi.contrib.mysqldb.CopyToTable.enable_metadata_columns", new_callable=mock.PropertyMock, return_value=True) + @mock.patch("luigi.contrib.mysqldb.CopyToTable._add_metadata_columns") + @mock.patch("luigi.contrib.mysqldb.CopyToTable.post_copy_metacolumns") + 
@mock.patch("luigi.contrib.mysqldb.CopyToTable.rows", return_value=['row1', 'row2']) + @mock.patch("luigi.contrib.mysqldb.MySqlTarget") + @mock.patch('mysql.connector.connect') + def test_copy_with_metadata_columns_enabled(self, + mock_connect, + mock_mysql_target, + mock_rows, + mock_add_columns, + mock_update_columns, + mock_metadata_columns_enabled): + + task = DummyMysqlImporter(date=datetime.datetime(1991, 3, 24)) + + mock_cursor = MockMysqlCursor([task.task_id]) + mock_connect.return_value.cursor.return_value = mock_cursor + + task = DummyMysqlImporter(date=datetime.datetime(1991, 3, 24)) + task.run() + + self.assertTrue(mock_add_columns.called) + self.assertTrue(mock_update_columns.called) + + @mock.patch("luigi.contrib.mysqldb.CopyToTable.enable_metadata_columns", new_callable=mock.PropertyMock, return_value=False) + @mock.patch("luigi.contrib.mysqldb.CopyToTable._add_metadata_columns") + @mock.patch("luigi.contrib.mysqldb.CopyToTable.post_copy_metacolumns") + @mock.patch("luigi.contrib.mysqldb.CopyToTable.rows", return_value=['row1', 'row2']) + @mock.patch("luigi.contrib.mysqldb.MySqlTarget") + @mock.patch('mysql.connector.connect') + def test_copy_with_metadata_columns_disabled(self, + mock_connect, + mock_mysql_target, + mock_rows, + mock_add_columns, + mock_update_columns, + mock_metadata_columns_enabled): + + task = DummyMysqlImporter(date=datetime.datetime(1991, 3, 24)) + + mock_cursor = MockMysqlCursor([task.task_id]) + mock_connect.return_value.cursor.return_value = mock_cursor + + task.run() + + self.assertFalse(mock_add_columns.called) + self.assertFalse(mock_update_columns.called) diff --git a/test/server_test.py b/test/server_test.py index 7b0814768e..1516f8dcef 100644 --- a/test/server_test.py +++ b/test/server_test.py @@ -25,6 +25,7 @@ import luigi.rpc import luigi.server import luigi.cmdline +from luigi.configuration import get_config from luigi.scheduler import Scheduler from luigi.six.moves.urllib.parse import ( urlencode, ParseResult, 
class ServerTest(ServerTestBase):
    """Server HTTP tests; the methods here cover CORS headers on ``/api/``."""

    # Response headers that must all be absent when CORS is not granted,
    # in the order they are checked.
    _CORS_RESPONSE_HEADERS = (
        'Access-Control-Allow-Headers',
        'Access-Control-Allow-Methods',
        'Access-Control-Allow-Origin',
        'Access-Control-Max-Age',
        'Access-Control-Allow-Credentials',
        'Access-Control-Expose-Headers',
    )

    def setUp(self):
        super(ServerTest, self).setUp()
        # Capture the defaults from a clean [cors] section before the
        # permissive test settings are applied.
        get_config().remove_section('cors')
        self._default_cors = luigi.server.cors()

        config = get_config()
        for option in ('enabled', 'allow_any_origin', 'allow_null_origin'):
            config.set('cors', option, 'true')

    def tearDown(self):
        super(ServerTest, self).tearDown()
        get_config().remove_section('cors')

    def _fetch_api_headers(self, origin=None, preflight=False):
        """Request ``/api/graph`` and return the response headers as a dict.

        ``preflight`` issues an OPTIONS request; ``origin`` (when given) is
        sent as the ``Origin`` request header.
        """
        request = {}
        if preflight:
            request['method'] = 'OPTIONS'
        if origin is not None:
            request['headers'] = {'Origin': origin}
        return dict(self.fetch('/api/graph', **request).headers)

    def _assert_preflight_granted(self, headers, origin, credentials=None, exposed=None):
        """Assert a granted preflight response carrying the default settings.

        ``credentials``/``exposed`` left as None mean the matching header
        must be absent.
        """
        defaults = self._default_cors
        self.assertEqual(defaults.allowed_headers,
                         headers['Access-Control-Allow-Headers'])
        self.assertEqual(defaults.allowed_methods,
                         headers['Access-Control-Allow-Methods'])
        self.assertEqual(origin, headers['Access-Control-Allow-Origin'])
        self.assertEqual(str(defaults.max_age), headers['Access-Control-Max-Age'])
        if credentials is None:
            self.assertIsNone(headers.get('Access-Control-Allow-Credentials'))
        else:
            self.assertEqual(credentials, headers['Access-Control-Allow-Credentials'])
        if exposed is None:
            self.assertIsNone(headers.get('Access-Control-Expose-Headers'))
        else:
            self.assertEqual(exposed, headers['Access-Control-Expose-Headers'])

    def _assert_no_cors_headers(self, headers):
        """Assert that none of the CORS response headers are present."""
        for name in self._CORS_RESPONSE_HEADERS:
            self.assertNotIn(name, headers)

    # -- preflight (OPTIONS) requests ------------------------------------

    def test_api_preflight_cors_headers(self):
        found = self._fetch_api_headers(origin='foo', preflight=True)

        self._assert_preflight_granted(found, '*')

    def test_api_preflight_cors_headers_all_response_headers(self):
        get_config().set('cors', 'allow_credentials', 'true')
        get_config().set('cors', 'exposed_headers', 'foo, bar')
        found = self._fetch_api_headers(origin='foo', preflight=True)

        self._assert_preflight_granted(found, '*', credentials='true', exposed='foo, bar')

    def test_api_preflight_cors_headers_null_origin(self):
        found = self._fetch_api_headers(origin='null', preflight=True)

        self._assert_preflight_granted(found, 'null')

    def test_api_preflight_cors_headers_disallow_null(self):
        get_config().set('cors', 'allow_null_origin', 'false')
        found = self._fetch_api_headers(origin='null', preflight=True)

        self._assert_no_cors_headers(found)

    def test_api_preflight_cors_headers_disallow_any(self):
        get_config().set('cors', 'allow_any_origin', 'false')
        get_config().set('cors', 'allowed_origins', '["foo", "bar"]')
        found = self._fetch_api_headers(origin='foo', preflight=True)

        self._assert_preflight_granted(found, 'foo')

    def test_api_preflight_cors_headers_disallow_any_no_matched_allowed_origins(self):
        get_config().set('cors', 'allow_any_origin', 'false')
        get_config().set('cors', 'allowed_origins', '["foo", "bar"]')
        found = self._fetch_api_headers(origin='foobar', preflight=True)

        self._assert_no_cors_headers(found)

    def test_api_preflight_cors_headers_disallow_any_no_allowed_origins(self):
        get_config().set('cors', 'allow_any_origin', 'false')
        found = self._fetch_api_headers(origin='foo', preflight=True)

        self._assert_no_cors_headers(found)

    def test_api_preflight_cors_headers_disabled(self):
        get_config().set('cors', 'enabled', 'false')
        found = self._fetch_api_headers(origin='foo', preflight=True)

        self._assert_no_cors_headers(found)

    def test_api_preflight_cors_headers_no_origin_header(self):
        found = self._fetch_api_headers(preflight=True)

        self._assert_no_cors_headers(found)

    # -- plain requests ---------------------------------------------------

    def test_api_cors_headers(self):
        found = self._fetch_api_headers(origin='foo')

        self.assertEqual('*', found['Access-Control-Allow-Origin'])

    def test_api_cors_headers_null_origin(self):
        found = self._fetch_api_headers(origin='null')

        self.assertEqual('null', found['Access-Control-Allow-Origin'])

    def test_api_cors_headers_disallow_null(self):
        get_config().set('cors', 'allow_null_origin', 'false')
        found = self._fetch_api_headers(origin='null')

        self.assertIsNone(found.get('Access-Control-Allow-Origin'))

    def test_api_cors_headers_disallow_any(self):
        get_config().set('cors', 'allow_any_origin', 'false')
        get_config().set('cors', 'allowed_origins', '["foo", "bar"]')
        found = self._fetch_api_headers(origin='foo')

        self.assertEqual('foo', found['Access-Control-Allow-Origin'])

    def test_api_cors_headers_disallow_any_no_matched_allowed_origins(self):
        get_config().set('cors', 'allow_any_origin', 'false')
        get_config().set('cors', 'allowed_origins', '["foo", "bar"]')
        found = self._fetch_api_headers(origin='foobar')

        self.assertIsNone(found.get('Access-Control-Allow-Origin'))

    def test_api_cors_headers_disallow_any_no_allowed_origins(self):
        get_config().set('cors', 'allow_any_origin', 'false')
        found = self._fetch_api_headers(origin='foo')

        self.assertIsNone(found.get('Access-Control-Allow-Origin'))

    def test_api_cors_headers_disabled(self):
        get_config().set('cors', 'enabled', 'false')
        found = self._fetch_api_headers(origin='foo')

        self.assertIsNone(found.get('Access-Control-Allow-Origin'))

    def test_api_cors_headers_no_origin_header(self):
        found = self._fetch_api_headers()

        self.assertIsNone(found.get('Access-Control-Allow-Origin'))