diff --git a/README.rst b/README.rst
index 4aa9f767db..522493a838 100644
--- a/README.rst
+++ b/README.rst
@@ -173,6 +173,9 @@ Some more companies are using Luigi but haven't had a chance yet to write about
* `Okko `_
* `ISVWorld `_
* `Big Data `_
+* `Movio `_
+* `Bonnier News `_
+* `Starsky Robotics `_
We're more than happy to have your company added here. Just send a PR on GitHub.
diff --git a/doc/aggregate_artists.png b/doc/aggregate_artists.png
index f7a395668b..e434fd2895 100644
Binary files a/doc/aggregate_artists.png and b/doc/aggregate_artists.png differ
diff --git a/doc/configuration.rst b/doc/configuration.rst
index d152855ca9..f206d53652 100644
--- a/doc/configuration.rst
+++ b/doc/configuration.rst
@@ -4,9 +4,12 @@ Configuration
All configuration can be done by adding configuration files.
Supported config parsers:
-* ``cfg`` (default)
+
+* ``cfg`` (default), based on Python's standard ConfigParser_. Values may refer to environment variables using ``${ENVVAR}`` syntax.
* ``toml``
+.. _ConfigParser: https://docs.python.org/3/library/configparser.html
+
You can choose right parser via ``LUIGI_CONFIG_PARSER`` environment variable. For example, ``LUIGI_CONFIG_PARSER=toml``.
Default (cfg) parser are looked for in:
@@ -202,6 +205,51 @@ rpc-retry-wait
Defaults to 30
+[cors]
+------
+
+.. versionadded:: 2.8.0
+
+These parameters control ``/api/`` ``CORS`` behaviour (see: `W3C Cross-Origin Resource Sharing
+<https://www.w3.org/TR/cors/>`_).
+
+enabled
+ Enables CORS support.
+ Defaults to false.
+
+allowed_origins
+ A list of allowed origins. Used only if ``allow_any_origin`` is false.
+ Configure in JSON array format, e.g. ["foo", "bar"].
+ Defaults to empty.
+
+allow_any_origin
+ Accepts requests from any origin.
+ Defaults to false.
+
+allow_null_origin
+ Allows the request to set ``null`` value of the ``Origin`` header.
+ Defaults to false.
+
+max_age
+ Content of ``Access-Control-Max-Age``.
+ Defaults to 86400 (24 hours).
+
+allowed_methods
+ Content of ``Access-Control-Allow-Methods``.
+ Defaults to ``GET, OPTIONS``.
+
+allowed_headers
+ Content of ``Access-Control-Allow-Headers``.
+ Defaults to ``Accept, Content-Type, Origin``.
+
+exposed_headers
+ Content of ``Access-Control-Expose-Headers``.
+ Defaults to empty string (will NOT be sent as a response header).
+
+allow_credentials
+ Indicates that the actual request can include user credentials.
+ Defaults to false.
+
.. _worker-config:
[worker]
@@ -607,7 +655,7 @@ is good practice to do so when you have a fixed set of resources.
.. _retcode-config:
[retcode]
-----------
+---------
Configure return codes for the Luigi binary. In the case of multiple return
codes that could apply, for example a failing task and missing data, the
diff --git a/doc/dependency_graph.png b/doc/dependency_graph.png
index 4196f30c02..28531b4c9d 100644
Binary files a/doc/dependency_graph.png and b/doc/dependency_graph.png differ
diff --git a/doc/execution_model.png b/doc/execution_model.png
index 535eb31066..11a2b28ff9 100644
Binary files a/doc/execution_model.png and b/doc/execution_model.png differ
diff --git a/doc/history.png b/doc/history.png
index e8173fcf85..9aa3e2dcbb 100644
Binary files a/doc/history.png and b/doc/history.png differ
diff --git a/doc/history_by_id.png b/doc/history_by_id.png
index 97a90f1cc5..c1ede5b7fd 100644
Binary files a/doc/history_by_id.png and b/doc/history_by_id.png differ
diff --git a/doc/history_by_name.png b/doc/history_by_name.png
index 5bef1291a4..a54fad59bd 100644
Binary files a/doc/history_by_name.png and b/doc/history_by_name.png differ
diff --git a/doc/luigi.png b/doc/luigi.png
index 586387c3ea..8285638906 100644
Binary files a/doc/luigi.png and b/doc/luigi.png differ
diff --git a/doc/parameters_date_algebra.png b/doc/parameters_date_algebra.png
index 7da7d616f0..37705065aa 100644
Binary files a/doc/parameters_date_algebra.png and b/doc/parameters_date_algebra.png differ
diff --git a/doc/parameters_enum.png b/doc/parameters_enum.png
index 9af33a4d51..2a224e44c4 100644
Binary files a/doc/parameters_enum.png and b/doc/parameters_enum.png differ
diff --git a/doc/parameters_recursion.png b/doc/parameters_recursion.png
index 43b23c2155..61eac1e4a8 100644
Binary files a/doc/parameters_recursion.png and b/doc/parameters_recursion.png differ
diff --git a/doc/task_breakdown.png b/doc/task_breakdown.png
index eeb793d45c..f2628b8210 100644
Binary files a/doc/task_breakdown.png and b/doc/task_breakdown.png differ
diff --git a/doc/task_parameters.png b/doc/task_parameters.png
index a26c2a5a74..b5677b55e5 100644
Binary files a/doc/task_parameters.png and b/doc/task_parameters.png differ
diff --git a/doc/task_with_targets.png b/doc/task_with_targets.png
index 0354144bce..46ce7f70fb 100644
Binary files a/doc/task_with_targets.png and b/doc/task_with_targets.png differ
diff --git a/doc/tasks_input_output_requires.png b/doc/tasks_input_output_requires.png
index 06a551b100..09f3b8378b 100644
Binary files a/doc/tasks_input_output_requires.png and b/doc/tasks_input_output_requires.png differ
diff --git a/doc/tasks_with_dependencies.png b/doc/tasks_with_dependencies.png
index 7adcb26674..63536582f3 100644
Binary files a/doc/tasks_with_dependencies.png and b/doc/tasks_with_dependencies.png differ
diff --git a/doc/user_recs.png b/doc/user_recs.png
index 3d3044bdb9..60943bf181 100644
Binary files a/doc/user_recs.png and b/doc/user_recs.png differ
diff --git a/doc/visualiser_front_page.png b/doc/visualiser_front_page.png
index 8acb6891ee..e5bc87bdd8 100644
Binary files a/doc/visualiser_front_page.png and b/doc/visualiser_front_page.png differ
diff --git a/doc/web_server.png b/doc/web_server.png
index cad1541d73..20605d0766 100644
Binary files a/doc/web_server.png and b/doc/web_server.png differ
diff --git a/examples/spark_als.py b/examples/spark_als.py
index dae0479169..c24734b87b 100644
--- a/examples/spark_als.py
+++ b/examples/spark_als.py
@@ -31,7 +31,7 @@ class UserItemMatrix(luigi.Task):
def run(self):
"""
Generates :py:attr:`~.UserItemMatrix.data_size` elements.
- Writes this data in \ separated value format into the target :py:func:`~/.UserItemMatrix.output`.
+ Writes this data in \\ separated value format into the target :py:func:`~/.UserItemMatrix.output`.
The data has the following elements:
@@ -43,7 +43,7 @@ def run(self):
w = self.output().open('w')
for user in range(self.data_size):
track = int(random.random() * self.data_size)
- w.write('%d\%d\%f' % (user, track, 1.0))
+ w.write('%d\\%d\\%f' % (user, track, 1.0))
w.close()
def output(self):
diff --git a/luigi/configuration/cfg_parser.py b/luigi/configuration/cfg_parser.py
index e0df87f10a..c4545fe75a 100644
--- a/luigi/configuration/cfg_parser.py
+++ b/luigi/configuration/cfg_parser.py
@@ -30,16 +30,94 @@
"""
import os
+import re
import warnings
try:
- from ConfigParser import ConfigParser, NoOptionError, NoSectionError
+ from ConfigParser import ConfigParser, NoOptionError, NoSectionError, InterpolationError
+ Interpolation = object
except ImportError:
- from configparser import ConfigParser, NoOptionError, NoSectionError
+ from configparser import ConfigParser, NoOptionError, NoSectionError, InterpolationError
+ from configparser import Interpolation, BasicInterpolation
from .base_parser import BaseParser
+class InterpolationMissingEnvvarError(InterpolationError):
+    """
+    Signals that an option value references an environment variable which is not set.
+    """
+
+    def __init__(self, option, section, value, envvar):
+        template = (
+            "Config refers to a nonexisting environment variable {}. "
+            "Section [{}], option {}={}"
+        )
+        InterpolationError.__init__(self, option, section, template.format(envvar, section, option, value))
+
+
+class EnvironmentInterpolation(Interpolation):
+    """
+    Interpolation step that expands ``${ENVVAR}`` references in option values
+    with the corresponding entries of ``os.environ``.
+    """
+    _ENVRE = re.compile(r"\$\{([^}]+)\}")  # matches "${envvar}"
+
+    def before_get(self, parser, section, option, value, defaults):
+        # parser and defaults are not used by this interpolation
+        return self._interpolate_env(option, section, value)
+
+    def _interpolate_env(self, option, section, value):
+        """
+        Return ``value`` with every ``${NAME}`` replaced by ``os.environ[NAME]``.
+
+        Raises InterpolationMissingEnvvarError for the first variable that is not set.
+        """
+        if not value:
+            return ""
+        pieces, cursor = [], 0
+        for found in self._ENVRE.finditer(value):
+            name = found.group(1)
+            if name not in os.environ:
+                raise InterpolationMissingEnvvarError(option, section, value, name)
+            pieces.append(value[cursor:found.start()])
+            pieces.append(os.environ[name])
+            cursor = found.end()
+        pieces.append(value[cursor:])
+        return "".join(pieces)
+
+
+class CombinedInterpolation(Interpolation):
+    """
+    Interpolation that pipes a value through several interpolations in order.
+
+    Every hook is forwarded to each wrapped interpolation in turn, the result of
+    one becoming the input of the next.
+
+    :param interpolations: a sequence of configparser.Interpolation objects.
+    """
+
+    def __init__(self, interpolations):
+        self._interpolations = interpolations
+
+    def _chain(self, hook, parser, section, option, value, *extra):
+        for step in self._interpolations:
+            value = getattr(step, hook)(parser, section, option, value, *extra)
+        return value
+
+    def before_get(self, parser, section, option, value, defaults):
+        return self._chain("before_get", parser, section, option, value, defaults)
+
+    def before_read(self, parser, section, option, value):
+        return self._chain("before_read", parser, section, option, value)
+
+    def before_set(self, parser, section, option, value):
+        return self._chain("before_set", parser, section, option, value)
+
+    def before_write(self, parser, section, option, value):
+        return self._chain("before_write", parser, section, option, value)
+
+
class LuigiConfigParser(BaseParser, ConfigParser):
NO_DEFAULT = object()
enabled = True
@@ -50,6 +128,17 @@ class LuigiConfigParser(BaseParser, ConfigParser):
'client.cfg', # Deprecated old-style local luigi config
'luigi.cfg',
]
+ if hasattr(ConfigParser, "_interpolate"):
+ # Override ConfigParser._interpolate (Python 2)
+ def _interpolate(self, section, option, rawval, vars):
+ value = ConfigParser._interpolate(self, section, option, rawval, vars)
+ return EnvironmentInterpolation().before_get(
+ parser=self, section=section, option=option,
+ value=value, defaults=None,
+ )
+ else:
+ # Override ConfigParser._DEFAULT_INTERPOLATION (Python 3)
+ _DEFAULT_INTERPOLATION = CombinedInterpolation([BasicInterpolation(), EnvironmentInterpolation()])
@classmethod
def reload(cls):
diff --git a/luigi/contrib/ecs.py b/luigi/contrib/ecs.py
index f563e73dc0..8612d00499 100644
--- a/luigi/contrib/ecs.py
+++ b/luigi/contrib/ecs.py
@@ -182,6 +182,11 @@ def run(self):
response = client.run_task(taskDefinition=self.task_def_arn,
overrides=overrides,
cluster=self.cluster)
+
+ if response['failures']:
+ raise Exception(", ".join(["fail to run task {0} reason: {1}".format(failure['arn'], failure['reason'])
+ for failure in response['failures']]))
+
self._task_ids = [task['taskArn'] for task in response['tasks']]
# Wait on task completion
diff --git a/luigi/contrib/external_program.py b/luigi/contrib/external_program.py
index 4d726f9eef..0377e6847c 100644
--- a/luigi/contrib/external_program.py
+++ b/luigi/contrib/external_program.py
@@ -59,7 +59,7 @@ class ExternalProgramTask(luigi.Task):
behaviour can be overriden by passing ``--capture-output False``
"""
- capture_output = luigi.BoolParameter(default=True, significant=False)
+ capture_output = luigi.BoolParameter(default=True, significant=False, positional=False)
def program_args(self):
"""
diff --git a/luigi/contrib/hdfs/format.py b/luigi/contrib/hdfs/format.py
index 1856abf3f4..f87ebd5bf3 100644
--- a/luigi/contrib/hdfs/format.py
+++ b/luigi/contrib/hdfs/format.py
@@ -1,9 +1,10 @@
-import luigi.format
import logging
import os
+
+import luigi.format
from luigi.contrib.hdfs.config import load_hadoop_cmd
from luigi.contrib.hdfs import config as hdfs_config
-from luigi.contrib.hdfs.clients import remove, rename, mkdir, listdir
+from luigi.contrib.hdfs.clients import remove, rename, mkdir, listdir, exists
from luigi.contrib.hdfs.error import HDFSCliError
logger = logging.getLogger('luigi-interface')
@@ -75,9 +76,13 @@ def abort(self):
def close(self):
super(HdfsAtomicWriteDirPipe, self).close()
try:
- remove(self.path)
- except HDFSCliError:
- pass
+            if exists(self.path):
+                remove(self.path)
+        except HDFSCliError:
+            pass  # best effort: a CLI failure while removing the old target is ignored
+        except Exception as ex:
+            if "FileNotFoundException" not in str(ex):  # str has no .contains(); use `in`
+                raise  # bare raise keeps the original traceback
# it's unlikely to fail in this way but better safe than sorry
if not all(result['result'] for result in rename(self.tmppath, self.path) or []):
diff --git a/luigi/contrib/hdfs/snakebite_client.py b/luigi/contrib/hdfs/snakebite_client.py
index 5c38787f4f..b6cd7aedfc 100644
--- a/luigi/contrib/hdfs/snakebite_client.py
+++ b/luigi/contrib/hdfs/snakebite_client.py
@@ -131,7 +131,7 @@ def remove(self, path, recursive=True, skip_trash=False):
:param path: delete-able file(s) or directory(ies)
:type path: either a string or a sequence of strings
- :param recursive: delete directories trees like \*nix: rm -r
+ :param recursive: delete directories trees like \\*nix: rm -r
:type recursive: boolean, default is True
:param skip_trash: do or don't move deleted items into the trash first
:type skip_trash: boolean, default is False (use trash)
@@ -145,7 +145,7 @@ def chmod(self, path, permissions, recursive=False):
:param path: update-able file(s)
:type path: either a string or sequence of strings
- :param permissions: \*nix style permission number
+ :param permissions: \\*nix style permission number
:type permissions: octal
:param recursive: change just listed entry(ies) or all in directories
:type recursive: boolean, default is False
@@ -242,7 +242,7 @@ def mkdir(self, path, parents=True, mode=0o755, raise_if_exists=False):
:type path: string
:param parents: create any missing parent directories
:type parents: boolean, default is True
- :param mode: \*nix style owner/group/other permissions
+ :param mode: \\*nix style owner/group/other permissions
:type mode: octal, default 0755
"""
result = list(self.get_bite().mkdir(self.list_path(path),
diff --git a/luigi/contrib/hdfs/target.py b/luigi/contrib/hdfs/target.py
index 0655b23e0f..9beba56190 100644
--- a/luigi/contrib/hdfs/target.py
+++ b/luigi/contrib/hdfs/target.py
@@ -185,3 +185,38 @@ def _is_writable(self, path):
return True
except hdfs_clients.HDFSCliError:
return False
+
+
+class HdfsFlagTarget(HdfsTarget):
+    """
+    Defines a target directory with a flag-file (defaults to `_SUCCESS`) used
+    to signify job success.
+
+    The target exists only when the flag file is present inside the directory,
+    i.e. when ``path + flag`` exists on the file system.
+
+    Because Hadoop outputs into a directory and not a single file,
+    the path is assumed to be a directory.
+    """
+    def __init__(self, path, format=None, client=None, flag='_SUCCESS'):
+        """
+        Initializes a HdfsFlagTarget.
+
+        :param path: the directory where the files are stored; must end with ``/``.
+        :type path: str
+        :param format: forwarded unchanged to ``HdfsTarget.__init__``.
+        :param client: forwarded unchanged to ``HdfsTarget.__init__``.
+        :param flag: name of the flag file looked up inside ``path``.
+        :type flag: str
+        :raises ValueError: if ``path`` is empty or does not end with a slash.
+        """
+        # endswith() also rejects "" cleanly; path[-1] raised IndexError there
+        if not path.endswith("/"):
+            raise ValueError("HdfsFlagTarget requires the path to be to a "
+                             "directory. It must end with a slash ( / ).")
+        super(HdfsFlagTarget, self).__init__(path, format, client)
+        self.flag = flag
+
+    def exists(self):
+        """Return whether the flag file exists inside the target directory."""
+        return self.fs.exists(self.path + self.flag)
diff --git a/luigi/contrib/mysqldb.py b/luigi/contrib/mysqldb.py
index 9e198fcfc6..dabef5ff73 100644
--- a/luigi/contrib/mysqldb.py
+++ b/luigi/contrib/mysqldb.py
@@ -19,14 +19,16 @@
import luigi
+from luigi.contrib import rdbms
+
logger = logging.getLogger('luigi-interface')
try:
import mysql.connector
- from mysql.connector import errorcode
+ from mysql.connector import errorcode, Error
except ImportError as e:
logger.warning("Loading MySQL module without the python package mysql-connector-python. \
- This will crash at runtime if MySQL functionality is used.")
+ This will crash at runtime if MySQL functionality is used.")
class MySqlTarget(luigi.Target):
@@ -147,3 +149,102 @@ def create_marker_table(self):
else:
raise
connection.close()
+
+
+class CopyToTable(rdbms.CopyToTable):
+    """
+    Template task for inserting a data set into MySQL
+
+    Usage:
+    Subclass and override the required `host`, `database`, `user`,
+    `password`, `table` and `columns` attributes.
+
+    To customize how to access data from an input task, override the `rows` method
+    with a generator that yields each row as a tuple with fields ordered according to `columns`.
+    """
+
+    def rows(self):
+        """
+        Return/yield tuples or lists corresponding to each row to be inserted.
+        """
+        with self.input().open('r') as fobj:
+            for line in fobj:
+                yield line.strip('\n').split('\t')
+
+# everything below will rarely have to be overridden
+
+    def output(self):
+        """
+        Returns a MySqlTarget representing the inserted dataset.
+
+        Normally you don't override this.
+        """
+        return MySqlTarget(
+            host=self.host,
+            database=self.database,
+            user=self.user,
+            password=self.password,
+            table=self.table,
+            update_id=self.update_id,
+        )
+
+    def copy(self, cursor, file=None):
+        """Insert everything yielded by rows() in batches of ``bulk_size`` rows (``file`` is unused)."""
+        placeholders = '({})'.format(','.join(['%s'] * len(self.columns)))
+        columns = '({})'.format(','.join([c[0] for c in self.columns]))
+        query = 'INSERT INTO {} {} VALUES {}'.format(self.table, columns, placeholders)
+        rows = []
+
+        for row in self.rows():
+            rows.append(row)
+            if len(rows) >= self.bulk_size:
+                cursor.executemany(query, rows)
+                rows = []
+
+        if rows:  # flush the last partial batch; nothing to send when it is empty
+            cursor.executemany(query, rows)
+
+    def run(self):
+        """
+        Inserts data generated by rows() into target table.
+
+        If the target table doesn't exist, self.create_table will be called to attempt to create the table.
+
+        Normally you don't want to override this.
+        """
+        if not (self.table and self.columns):
+            raise Exception("table and columns need to be specified")
+
+        connection = self.output().connect()
+        try:
+            # attempt to copy the data into mysql
+            # if it fails because the target table doesn't exist
+            # try to create it by running self.create_table
+            for attempt in range(2):
+                try:
+                    cursor = connection.cursor()
+                    self.init_copy(connection)
+                    self.copy(cursor)
+                    self.post_copy(connection)
+                    if self.enable_metadata_columns:
+                        self.post_copy_metacolumns(cursor)
+                except Error as err:
+                    if err.errno == errorcode.ER_NO_SUCH_TABLE and attempt == 0:
+                        # first attempt failed because the table is missing: create it and retry
+                        logger.info("Creating table %s", self.table)
+                        connection.reconnect()
+                        self.create_table(connection)
+                    else:
+                        raise
+                else:
+                    break
+
+            # mark as complete in same transaction
+            self.output().touch(connection)
+            connection.commit()
+        finally:
+            connection.close()
+
+    @property
+    def bulk_size(self):
+        return 10000
diff --git a/luigi/contrib/opener.py b/luigi/contrib/opener.py
index 7607583c52..93c360b9de 100644
--- a/luigi/contrib/opener.py
+++ b/luigi/contrib/opener.py
@@ -189,7 +189,7 @@ def get_target(cls, scheme, path, fragment, username,
the expected target.
"""
- raise NotImplemented("get_target must be overridden")
+ raise NotImplementedError("get_target must be overridden")
class MockOpener(Opener):
diff --git a/luigi/contrib/s3.py b/luigi/contrib/s3.py
index 9fb21d283f..1ccc438d7c 100644
--- a/luigi/contrib/s3.py
+++ b/luigi/contrib/s3.py
@@ -31,9 +31,6 @@
import warnings
from multiprocessing.pool import ThreadPool
-import botocore
-from boto3.s3.transfer import TransferConfig
-
try:
from urlparse import urlsplit
except ImportError:
@@ -54,6 +51,13 @@
logger = logging.getLogger('luigi-interface')
+try:
+ from boto3.s3.transfer import TransferConfig
+ import botocore
+except ImportError:
+ logger.warning("Loading S3 module without the python package boto3. "
+ "Will crash at runtime if S3 functionality is used.")
+
# two different ways of marking a directory
# with a suffix in S3
diff --git a/luigi/server.py b/luigi/server.py
index 79a696cc82..6c97cb7427 100644
--- a/luigi/server.py
+++ b/luigi/server.py
@@ -36,12 +36,12 @@
#
import atexit
+import datetime
import json
import logging
import os
import signal
import sys
-import datetime
import time
import pkg_resources
@@ -50,21 +50,64 @@
import tornado.netutil
import tornado.web
+from luigi import Config, parameter
from luigi.scheduler import Scheduler, RPC_METHODS
logger = logging.getLogger("luigi.server")
+class cors(Config):
+    """
+    CORS settings for the scheduler's RPC API, read by :class:`RPCHandler`.
+
+    Each parameter's ``description`` explains what it controls.
+    """
+
+    enabled = parameter.BoolParameter(
+        default=False, description='Enables CORS support.')
+    allowed_origins = parameter.ListParameter(
+        default=[], description='A list of allowed origins. Used only if `allow_any_origin` is false.')
+    allow_any_origin = parameter.BoolParameter(
+        default=False, description='Accepts requests from any origin.')
+    allow_null_origin = parameter.BoolParameter(
+        default=False, description='Allows the request to set `null` value of the `Origin` header.')
+    max_age = parameter.IntParameter(
+        default=86400, description='Content of `Access-Control-Max-Age`.')
+    allowed_methods = parameter.Parameter(
+        default='GET, OPTIONS', description='Content of `Access-Control-Allow-Methods`.')
+    allowed_headers = parameter.Parameter(
+        default='Accept, Content-Type, Origin', description='Content of `Access-Control-Allow-Headers`.')
+    exposed_headers = parameter.Parameter(
+        default='', description='Content of `Access-Control-Expose-Headers`.')
+    allow_credentials = parameter.BoolParameter(
+        default=False, description='Indicates that the actual request can include user credentials.')
+
+    def __init__(self, *args, **kwargs):
+        super(cors, self).__init__(*args, **kwargs)
+        # '*' and 'null' are governed by allow_any_origin / allow_null_origin,
+        # so they are dropped from the explicitly listed origins.
+        listed = self.allowed_origins
+        self.allowed_origins = {origin for origin in listed if origin not in ('*', 'null')}
+
+
class RPCHandler(tornado.web.RequestHandler):
"""
Handle remote scheduling calls using rpc.RemoteSchedulerResponder.
"""
+ def __init__(self, *args, **kwargs):
+ super(RPCHandler, self).__init__(*args, **kwargs)
+ self._cors_config = cors()
+
def initialize(self, scheduler):
self._scheduler = scheduler
- self.set_header("Access-Control-Allow-Headers", "Accept, Authorization, Content-Type, Origin")
- self.set_header("Access-Control-Allow-Methods", "GET, OPTIONS")
- self.set_header("Access-Control-Allow-Origin", "*")
+
+ def options(self, *args):
+ if self._cors_config.enabled:
+ self._handle_cors_preflight()
+
+ self.set_status(204)
+ self.finish()
def get(self, method):
if method not in RPC_METHODS:
@@ -75,12 +118,57 @@ def get(self, method):
if hasattr(self._scheduler, method):
result = getattr(self._scheduler, method)(**arguments)
+
+ if self._cors_config.enabled:
+ self._handle_cors()
+
self.write({"response": result}) # wrap all json response in a dictionary
else:
self.send_error(404)
post = get
+    def _handle_cors_preflight(self):
+        """Set the CORS response headers for a preflight request when its Origin is permitted."""
+        origin = self.request.headers.get('Origin')
+        if not origin:
+            return
+        if origin == 'null':
+            granted = 'null' if self._cors_config.allow_null_origin else None
+        elif self._cors_config.allow_any_origin:
+            granted = '*'
+        elif origin in self._cors_config.allowed_origins:
+            granted = origin
+        else:
+            granted = None
+        if granted is not None:
+            self.set_header('Access-Control-Allow-Origin', granted)
+            self._set_other_cors_headers()
+
+    def _handle_cors(self):
+        """Set ``Access-Control-Allow-Origin`` on an RPC response when the request Origin is permitted."""
+        origin = self.request.headers.get('Origin')
+        if not origin:
+            return
+        if origin == 'null':
+            if self._cors_config.allow_null_origin:
+                self.set_header('Access-Control-Allow-Origin', 'null')
+        elif self._cors_config.allow_any_origin:
+            self.set_header('Access-Control-Allow-Origin', '*')
+        elif origin in self._cors_config.allowed_origins:
+            # the echoed value depends on the request, hence the Vary header
+            self.set_header('Access-Control-Allow-Origin', origin)
+            self.set_header('Vary', 'Origin')
+
+ def _set_other_cors_headers(self):
+ self.set_header('Access-Control-Max-Age', str(self._cors_config.max_age))
+ self.set_header('Access-Control-Allow-Methods', self._cors_config.allowed_methods)
+ self.set_header('Access-Control-Allow-Headers', self._cors_config.allowed_headers)
+ if self._cors_config.allow_credentials:
+ self.set_header('Access-Control-Allow-Credentials', 'true')
+ if self._cors_config.exposed_headers:
+ self.set_header('Access-Control-Expose-Headers', self._cors_config.exposed_headers)
+
class BaseTaskHistoryHandler(tornado.web.RequestHandler):
def initialize(self, scheduler):
diff --git a/luigi/static/visualiser/lib/jquery-ui/css/images/ui-icons_222222_256x240.png b/luigi/static/visualiser/lib/jquery-ui/css/images/ui-icons_222222_256x240.png
index c1cb1170c8..4586de1a05 100755
Binary files a/luigi/static/visualiser/lib/jquery-ui/css/images/ui-icons_222222_256x240.png and b/luigi/static/visualiser/lib/jquery-ui/css/images/ui-icons_222222_256x240.png differ
diff --git a/luigi/static/visualiser/lib/jquery-ui/css/images/ui-icons_2e83ff_256x240.png b/luigi/static/visualiser/lib/jquery-ui/css/images/ui-icons_2e83ff_256x240.png
index 84b601bf0f..c4dfa36bd5 100755
Binary files a/luigi/static/visualiser/lib/jquery-ui/css/images/ui-icons_2e83ff_256x240.png and b/luigi/static/visualiser/lib/jquery-ui/css/images/ui-icons_2e83ff_256x240.png differ
diff --git a/luigi/static/visualiser/lib/jquery-ui/css/images/ui-icons_454545_256x240.png b/luigi/static/visualiser/lib/jquery-ui/css/images/ui-icons_454545_256x240.png
index b6db1acdd4..c39a7411b0 100755
Binary files a/luigi/static/visualiser/lib/jquery-ui/css/images/ui-icons_454545_256x240.png and b/luigi/static/visualiser/lib/jquery-ui/css/images/ui-icons_454545_256x240.png differ
diff --git a/luigi/static/visualiser/lib/jquery-ui/css/images/ui-icons_888888_256x240.png b/luigi/static/visualiser/lib/jquery-ui/css/images/ui-icons_888888_256x240.png
index feea0e2026..cef233a7fd 100755
Binary files a/luigi/static/visualiser/lib/jquery-ui/css/images/ui-icons_888888_256x240.png and b/luigi/static/visualiser/lib/jquery-ui/css/images/ui-icons_888888_256x240.png differ
diff --git a/luigi/static/visualiser/lib/jquery-ui/css/images/ui-icons_cd0a0a_256x240.png b/luigi/static/visualiser/lib/jquery-ui/css/images/ui-icons_cd0a0a_256x240.png
index ed5b6b0930..64733826fd 100755
Binary files a/luigi/static/visualiser/lib/jquery-ui/css/images/ui-icons_cd0a0a_256x240.png and b/luigi/static/visualiser/lib/jquery-ui/css/images/ui-icons_cd0a0a_256x240.png differ
diff --git a/luigi/tools/deps_tree.py b/luigi/tools/deps_tree.py
index 13a63164c9..27a00313e3 100755
--- a/luigi/tools/deps_tree.py
+++ b/luigi/tools/deps_tree.py
@@ -44,7 +44,7 @@ def print_tree(task, indent='', last=True):
'''
# dont bother printing out warnings about tasks with no output
with warnings.catch_warnings():
- warnings.filterwarnings(action='ignore', message='Task .* without outputs has no custom complete\(\) method')
+ warnings.filterwarnings(action='ignore', message='Task .* without outputs has no custom complete\\(\\) method')
is_task_complete = task.complete()
is_complete = (bcolors.OKGREEN + 'COMPLETE' if is_task_complete else bcolors.OKBLUE + 'PENDING') + bcolors.ENDC
name = task.__class__.__name__
diff --git a/luigi/worker.py b/luigi/worker.py
index 6cdaff1884..c7a24c3c5c 100644
--- a/luigi/worker.py
+++ b/luigi/worker.py
@@ -61,7 +61,7 @@
from luigi.task import Task, flatten, getpaths, Config
from luigi.task_register import TaskClassException
from luigi.task_status import RUNNING
-from luigi.parameter import BoolParameter, FloatParameter, IntParameter, Parameter
+from luigi.parameter import BoolParameter, FloatParameter, IntParameter, OptionalParameter
try:
import simplejson as json
@@ -446,12 +446,12 @@ class worker(Config):
force_multiprocessing = BoolParameter(default=False,
description='If true, use multiprocessing also when '
'running with 1 worker')
- task_process_context = Parameter(default=None,
- description='If set to a fully qualified class name, the class will '
- 'be instantiated with a TaskProcess as its constructor parameter and '
- 'applied as a context manager around its run() call, so this can be '
- 'used for obtaining high level customizable monitoring or logging of '
- 'each individual Task run.')
+ task_process_context = OptionalParameter(default=None,
+ description='If set to a fully qualified class name, the class will '
+ 'be instantiated with a TaskProcess as its constructor parameter and '
+ 'applied as a context manager around its run() call, so this can be '
+ 'used for obtaining high level customizable monitoring or logging of '
+ 'each individual Task run.')
class KeepAliveThread(threading.Thread):
diff --git a/setup.py b/setup.py
index 0fa21e38a2..c015102b73 100644
--- a/setup.py
+++ b/setup.py
@@ -13,7 +13,6 @@
# the License.
import os
-import sys
from setuptools import setup
@@ -41,21 +40,19 @@ def get_static_files(path):
'tornado>=4.0,<5',
# https://pagure.io/python-daemon/issue/18
'python-daemon<2.2.0',
+ 'enum34>1.1.0;python_version<"3.4"',
]
if os.environ.get('READTHEDOCS', None) == 'True':
# So that we can build documentation for luigi.db_task_history and luigi.contrib.sqla
install_requires.append('sqlalchemy')
# readthedocs don't like python-daemon, see #1342
- install_requires.remove('python-daemon<3.0')
+ install_requires.remove('python-daemon<2.2.0')
install_requires.append('sphinx>=1.4.4') # Value mirrored in doc/conf.py
-if sys.version_info < (3, 4):
- install_requires.append('enum34>1.1.0')
-
setup(
name='luigi',
- version='2.7.9',
+ version='2.8.0',
description='Workflow mgmgt + task scheduling + dependency resolution',
long_description=long_description,
author='The Luigi Authors',
diff --git a/test/config_env_test.py b/test/config_env_test.py
new file mode 100644
index 0000000000..d5ce8a53a0
--- /dev/null
+++ b/test/config_env_test.py
@@ -0,0 +1,77 @@
+# -*- coding: utf-8 -*-
+#
+# Copyright 2018 Vote inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+import os
+
+from luigi.configuration import LuigiConfigParser, get_config
+from luigi.configuration.cfg_parser import InterpolationMissingEnvvarError
+
+from helpers import LuigiTestCase, with_config
+
+
+class ConfigParserTest(LuigiTestCase):
+    """Checks ``${ENVVAR}`` interpolation of the cfg parser alongside the standard ``%(name)s`` one."""
+
+    environ = {
+        "TESTVAR": "1",
+    }
+
+    def setUp(self):
+        # Keep pre-existing values as a key -> value mapping so tearDown can restore them.
+        self.environ_backup = {
+            key: os.environ[key] for key in self.environ if key in os.environ
+        }
+        os.environ.update(self.environ)
+        LuigiConfigParser._instance = None
+        super(ConfigParserTest, self).setUp()
+
+    def tearDown(self):
+        for key in self.environ:
+            os.environ.pop(key, None)
+        os.environ.update(self.environ_backup)
+        super(ConfigParserTest, self).tearDown()
+
+    @with_config({"test": {
+        "a": "testval",
+        "b": "%(a)s",
+        "c": "%(a)s%(a)s",
+    }})
+    def test_basic_interpolation(self):
+        # Make sure the default ConfigParser behaviour is not broken
+        config = get_config()
+
+        self.assertEqual(config.get("test", "b"), config.get("test", "a"))
+        self.assertEqual(config.get("test", "c"), 2 * config.get("test", "a"))
+
+    @with_config({"test": {
+        "a": "${TESTVAR}",
+        "b": "${TESTVAR} ${TESTVAR}",
+        "c": "${TESTVAR} %(a)s",
+        "d": "${NONEXISTING}",
+    }})
+    def test_env_interpolation(self):
+        config = get_config()
+
+        self.assertEqual(config.get("test", "a"), "1")
+        self.assertEqual(config.getint("test", "a"), 1)
+        self.assertEqual(config.getboolean("test", "a"), True)
+
+        self.assertEqual(config.get("test", "b"), "1 1")
+
+        self.assertEqual(config.get("test", "c"), "1 1")
+
+        with self.assertRaises(InterpolationMissingEnvvarError):
+            config.get("test", "d")
diff --git a/test/contrib/hdfs_test.py b/test/contrib/hdfs_test.py
index 8b0e2f5edd..9991d3c852 100644
--- a/test/contrib/hdfs_test.py
+++ b/test/contrib/hdfs_test.py
@@ -532,27 +532,45 @@ def test_tmppath_not_configured(self):
res9 = hdfs.tmppath(path9, include_unix_username=False)
# Then: I should get correct results relative to Luigi temporary directory
- self.assertRegexpMatches(res1, "^/tmp/dir1/dir2/file-luigitemp-\d+")
+ self.assertRegexpMatches(res1, "^/tmp/dir1/dir2/file-luigitemp-\\d+")
# it would be better to see hdfs:///path instead of hdfs:/path, but single slash also works well
- self.assertRegexpMatches(res2, "^hdfs:/tmp/dir1/dir2/file-luigitemp-\d+")
- self.assertRegexpMatches(res3, "^hdfs://somehost/tmp/dir1/dir2/file-luigitemp-\d+")
- self.assertRegexpMatches(res4, "^file:///tmp/dir1/dir2/file-luigitemp-\d+")
- self.assertRegexpMatches(res5, "^/tmp/dir/file-luigitemp-\d+")
+ self.assertRegexpMatches(res2, "^hdfs:/tmp/dir1/dir2/file-luigitemp-\\d+")
+ self.assertRegexpMatches(res3, "^hdfs://somehost/tmp/dir1/dir2/file-luigitemp-\\d+")
+ self.assertRegexpMatches(res4, "^file:///tmp/dir1/dir2/file-luigitemp-\\d+")
+ self.assertRegexpMatches(res5, "^/tmp/dir/file-luigitemp-\\d+")
# known issue with duplicated "tmp" if schema is present
- self.assertRegexpMatches(res6, "^file:///tmp/tmp/dir/file-luigitemp-\d+")
+ self.assertRegexpMatches(res6, "^file:///tmp/tmp/dir/file-luigitemp-\\d+")
# known issue with duplicated "tmp" if schema is present
- self.assertRegexpMatches(res7, "^hdfs://somehost/tmp/tmp/dir/file-luigitemp-\d+")
- self.assertRegexpMatches(res8, "^/tmp/luigitemp-\d+")
+ self.assertRegexpMatches(res7, "^hdfs://somehost/tmp/tmp/dir/file-luigitemp-\\d+")
+ self.assertRegexpMatches(res8, "^/tmp/luigitemp-\\d+")
self.assertRegexpMatches(res9, "/tmp/tmpdir/file")
def test_tmppath_username(self):
self.assertRegexpMatches(hdfs.tmppath('/path/to/stuff', include_unix_username=True),
- "^/tmp/[a-z0-9_]+/path/to/stuff-luigitemp-\d+")
+ "^/tmp/[a-z0-9_]+/path/to/stuff-luigitemp-\\d+")
def test_pickle(self):
t = hdfs.HdfsTarget("/tmp/dir")
pickle.dumps(t)
+    def test_flag_target(self):
+        """Check HdfsFlagTarget.exists() before and after its directory is populated.
+
+        The target must report missing once removed, and present after both a
+        ``part-00000`` file and a ``_SUCCESS`` file have been written under it.
+        """
+        # NOTE(review): ``format`` is not bound in this method; unless the module
+        # defines that name it resolves to the ``format`` builtin -- confirm intent.
+        flag_target = hdfs.HdfsFlagTarget("/some/dir/", format=format)
+        if flag_target.exists():
+            flag_target.remove(skip_trash=True)
+        self.assertFalse(flag_target.exists())
+
+        # Create both files empty: the data part first, then the flag file.
+        for file_name in ("part-00000", "_SUCCESS"):
+            with hdfs.HdfsTarget(flag_target.path + file_name, format=format).open('w'):
+                pass
+        self.assertTrue(flag_target.exists())
+
+    def test_flag_target_fails_if_not_directory(self):
+        """Building an HdfsFlagTarget for ``/home/file.txt`` must raise ValueError."""
+        self.assertRaises(ValueError, hdfs.HdfsFlagTarget, "/home/file.txt")
+
@attr('minicluster')
class HdfsTargetTest(MiniClusterTestCase, HdfsTargetTestMixin):
diff --git a/test/contrib/mysqldb_test.py b/test/contrib/mysqldb_test.py
new file mode 100644
index 0000000000..fe720225d1
--- /dev/null
+++ b/test/contrib/mysqldb_test.py
@@ -0,0 +1,124 @@
+from luigi.tools.range import RangeDaily
+
+import mock
+
+import luigi.contrib.mysqldb
+
+import datetime
+from helpers import unittest
+
+from nose.plugins.attrib import attr
+
+
+def datetime_to_epoch(dt):
+    """Return the seconds between 1970-01-01 00:00:00 and the naive datetime *dt*.
+
+    The result is a float; sub-second precision comes from ``dt.microsecond``.
+    """
+    return (dt - datetime.datetime(1970, 1, 1)).total_seconds()
+
+
+class MockMysqlCursor(mock.Mock):
+    """
+    Stateful cursor stand-in: ``execute`` records the row that the following
+    ``fetchone`` call hands back.
+    """
+    def __init__(self, existing_update_ids):
+        super(MockMysqlCursor, self).__init__()
+        # Update ids for which the table_updates lookup reports a row.
+        self.existing = existing_update_ids
+
+    def execute(self, query, params):
+        # Only the table_updates lookup can produce a row; any other query,
+        # or an id that is not in ``existing``, yields no row.
+        is_lookup = query.startswith('SELECT 1 FROM table_updates')
+        found = is_lookup and params[0] in self.existing
+        self.fetchone_result = (1, ) if found else None
+
+    def fetchone(self):
+        return self.fetchone_result
+
+
+class DummyMysqlImporter(luigi.contrib.mysqldb.CopyToTable):
+    """CopyToTable task with placeholder connection settings and one date parameter."""
+
+    date = luigi.DateParameter()
+
+    # Placeholder connection settings and destination table.
+    host, database = 'dummy_host', 'dummy_database'
+    user, password = 'dummy_user', 'dummy_password'
+    table = 'dummy_table'
+
+    # (column name, column type) pairs of the destination table.
+    columns = (('some_text', 'text'), ('some_int', 'int'))
+
+
+# Testing that an existing update will not be run in RangeDaily
+@attr('mysql')
+class DailyCopyToTableTest(unittest.TestCase):
+
+    @mock.patch('mysql.connector.connect')
+    def test_bulk_complete(self, mock_connect):
+        """RangeDaily must require each day from Jan 2 to Jan 6 except the one already updated."""
+        def task_id_for(day):
+            return DummyMysqlImporter(date=datetime.datetime(2015, 1, day)).task_id
+
+        # Only 2015-01-03 is reported by the cursor as an existing update id.
+        mock_connect.return_value.cursor.return_value = MockMysqlCursor([task_id_for(3)])
+
+        range_task = RangeDaily(of=DummyMysqlImporter,
+                                start=datetime.date(2015, 1, 2),
+                                now=datetime_to_epoch(datetime.datetime(2015, 1, 7)))
+        required_ids = sorted(t.task_id for t in range_task.requires())
+
+        self.assertEqual(required_ids, sorted(task_id_for(day) for day in (2, 4, 5, 6)))
+        self.assertFalse(range_task.complete())
+
+
+@attr('mysql')
+class TestCopyToTableWithMetaColumns(unittest.TestCase):
+    """Check that CopyToTable.run() touches metadata columns only when they are enabled.
+
+    ``mock.patch`` decorators inject their mocks bottom-up, so the argument
+    names below follow the decorator stack from the last decorator upwards.
+    """
+
+    def _run_importer(self, mock_connect):
+        # Run the dummy import against a cursor that already knows its update id.
+        task = DummyMysqlImporter(date=datetime.datetime(1991, 3, 24))
+        mock_connect.return_value.cursor.return_value = MockMysqlCursor([task.task_id])
+        task.run()
+
+    @mock.patch("luigi.contrib.mysqldb.CopyToTable.enable_metadata_columns", new_callable=mock.PropertyMock, return_value=True)
+    @mock.patch("luigi.contrib.mysqldb.CopyToTable._add_metadata_columns")
+    @mock.patch("luigi.contrib.mysqldb.CopyToTable.post_copy_metacolumns")
+    @mock.patch("luigi.contrib.mysqldb.CopyToTable.rows", return_value=['row1', 'row2'])
+    @mock.patch("luigi.contrib.mysqldb.MySqlTarget")
+    @mock.patch('mysql.connector.connect')
+    def test_copy_with_metadata_columns_enabled(self,
+                                                mock_connect,
+                                                mock_mysql_target,
+                                                mock_rows,
+                                                mock_post_copy_metacolumns,
+                                                mock_add_metadata_columns,
+                                                mock_metadata_columns_enabled):
+
+        self._run_importer(mock_connect)
+
+        self.assertTrue(mock_add_metadata_columns.called)
+        self.assertTrue(mock_post_copy_metacolumns.called)
+
+    @mock.patch("luigi.contrib.mysqldb.CopyToTable.enable_metadata_columns", new_callable=mock.PropertyMock, return_value=False)
+    @mock.patch("luigi.contrib.mysqldb.CopyToTable._add_metadata_columns")
+    @mock.patch("luigi.contrib.mysqldb.CopyToTable.post_copy_metacolumns")
+    @mock.patch("luigi.contrib.mysqldb.CopyToTable.rows", return_value=['row1', 'row2'])
+    @mock.patch("luigi.contrib.mysqldb.MySqlTarget")
+    @mock.patch('mysql.connector.connect')
+    def test_copy_with_metadata_columns_disabled(self,
+                                                 mock_connect,
+                                                 mock_mysql_target,
+                                                 mock_rows,
+                                                 mock_post_copy_metacolumns,
+                                                 mock_add_metadata_columns,
+                                                 mock_metadata_columns_enabled):
+
+        self._run_importer(mock_connect)
+
+        self.assertFalse(mock_add_metadata_columns.called)
+        self.assertFalse(mock_post_copy_metacolumns.called)
diff --git a/test/server_test.py b/test/server_test.py
index 7b0814768e..1516f8dcef 100644
--- a/test/server_test.py
+++ b/test/server_test.py
@@ -25,6 +25,7 @@
import luigi.rpc
import luigi.server
import luigi.cmdline
+from luigi.configuration import get_config
from luigi.scheduler import Scheduler
from luigi.six.moves.urllib.parse import (
urlencode, ParseResult, quote as urlquote
@@ -84,6 +85,19 @@ def tearDown(self):
class ServerTest(ServerTestBase):
+    def setUp(self):
+        """Record the default CORS settings, then enable permissive CORS for the test."""
+        super(ServerTest, self).setUp()
+        get_config().remove_section('cors')
+        # Built while no [cors] section exists; the tests compare responses
+        # against this object's attributes.
+        self._default_cors = luigi.server.cors()
+
+        for option in ('enabled', 'allow_any_origin', 'allow_null_origin'):
+            get_config().set('cors', option, 'true')
+
+    def tearDown(self):
+        """Run the parent cleanup, then drop the [cors] overrides set for the test."""
+        super(ServerTest, self).tearDown()
+        config = get_config()
+        config.remove_section('cors')
+
def test_visualiser(self):
page = self.fetch('/').body
self.assertTrue(page.find(b'') != -1)
@@ -98,16 +112,176 @@ def test_404(self):
def test_api_404(self):
self._test_404('/api/foo')
+    # Every CORS response header a preflight reply may carry.
+    _CORS_RESPONSE_HEADERS = (
+        'Access-Control-Allow-Headers',
+        'Access-Control-Allow-Methods',
+        'Access-Control-Allow-Origin',
+        'Access-Control-Max-Age',
+        'Access-Control-Allow-Credentials',
+        'Access-Control-Expose-Headers',
+    )
+
+    def _preflight_headers(self, origin=None):
+        # Send OPTIONS /api/graph (with an Origin header only when one is given)
+        # and return the response headers as a plain dict.
+        kwargs = {'method': 'OPTIONS'}
+        if origin is not None:
+            kwargs['headers'] = {'Origin': origin}
+        return dict(self.fetch('/api/graph', **kwargs).headers)
+
+    def _assert_preflight_allowed(self, headers, origin, credentials=None, exposed=None):
+        # Assert an accepted preflight: default allowed headers/methods/max-age,
+        # the expected Allow-Origin, and the optional headers (None means absent).
+        self.assertEqual(self._default_cors.allowed_headers,
+                         headers['Access-Control-Allow-Headers'])
+        self.assertEqual(self._default_cors.allowed_methods,
+                         headers['Access-Control-Allow-Methods'])
+        self.assertEqual(origin, headers['Access-Control-Allow-Origin'])
+        self.assertEqual(str(self._default_cors.max_age), headers['Access-Control-Max-Age'])
+        self.assertEqual(credentials, headers.get('Access-Control-Allow-Credentials'))
+        self.assertEqual(exposed, headers.get('Access-Control-Expose-Headers'))
+
+    def _assert_no_cors_headers(self, headers):
+        # Assert a rejected preflight: no CORS response header at all.
+        for name in self._CORS_RESPONSE_HEADERS:
+            self.assertNotIn(name, headers)
+
+    def test_api_preflight_cors_headers(self):
+        headers = self._preflight_headers('foo')
+
+        self._assert_preflight_allowed(headers, origin='*')
+
+    def test_api_preflight_cors_headers_all_response_headers(self):
+        get_config().set('cors', 'allow_credentials', 'true')
+        get_config().set('cors', 'exposed_headers', 'foo, bar')
+        headers = self._preflight_headers('foo')
+
+        self._assert_preflight_allowed(headers, origin='*',
+                                       credentials='true', exposed='foo, bar')
+
+    def test_api_preflight_cors_headers_null_origin(self):
+        headers = self._preflight_headers('null')
+
+        self._assert_preflight_allowed(headers, origin='null')
+
+    def test_api_preflight_cors_headers_disallow_null(self):
+        get_config().set('cors', 'allow_null_origin', 'false')
+        headers = self._preflight_headers('null')
+
+        self._assert_no_cors_headers(headers)
+
+    def test_api_preflight_cors_headers_disallow_any(self):
+        get_config().set('cors', 'allow_any_origin', 'false')
+        get_config().set('cors', 'allowed_origins', '["foo", "bar"]')
+        headers = self._preflight_headers('foo')
+
+        self._assert_preflight_allowed(headers, origin='foo')
+
+    def test_api_preflight_cors_headers_disallow_any_no_matched_allowed_origins(self):
+        get_config().set('cors', 'allow_any_origin', 'false')
+        get_config().set('cors', 'allowed_origins', '["foo", "bar"]')
+        headers = self._preflight_headers('foobar')
+
+        self._assert_no_cors_headers(headers)
+
+    def test_api_preflight_cors_headers_disallow_any_no_allowed_origins(self):
+        get_config().set('cors', 'allow_any_origin', 'false')
+        headers = self._preflight_headers('foo')
+
+        self._assert_no_cors_headers(headers)
+
+    def test_api_preflight_cors_headers_disabled(self):
+        get_config().set('cors', 'enabled', 'false')
+        headers = self._preflight_headers('foo')
+
+        self._assert_no_cors_headers(headers)
+
+    def test_api_preflight_cors_headers_no_origin_header(self):
+        headers = self._preflight_headers()
+
+        self._assert_no_cors_headers(headers)
+
def test_api_cors_headers(self):
- response = self.fetch('/api/graph')
+ response = self.fetch('/api/graph', headers={'Origin': 'foo'})
+ headers = dict(response.headers)
+
+ self.assertEqual('*', headers['Access-Control-Allow-Origin'])
+
+    def test_api_cors_headers_null_origin(self):
+        """A GET sent with ``Origin: null`` must get ``null`` back as the allowed origin."""
+        reply = self.fetch('/api/graph', headers={'Origin': 'null'})
+        reply_headers = dict(reply.headers)
+
+        self.assertEqual('null', reply_headers['Access-Control-Allow-Origin'])
+
+    def test_api_cors_headers_disallow_null(self):
+        """With ``allow_null_origin`` off, ``Origin: null`` must get no allowed-origin header."""
+        get_config().set('cors', 'allow_null_origin', 'false')
+        reply = self.fetch('/api/graph', headers={'Origin': 'null'})
+        reply_headers = dict(reply.headers)
+
+        self.assertIsNone(reply_headers.get('Access-Control-Allow-Origin'))
+
+ def test_api_cors_headers_disallow_any(self):
+ get_config().set('cors', 'allow_any_origin', 'false')
+ get_config().set('cors', 'allowed_origins', '["foo", "bar"]')
+ response = self.fetch('/api/graph', headers={'Origin': 'foo'})
headers = dict(response.headers)
- def _set(name):
- return set(headers[name].replace(" ", "").split(","))
+ self.assertEqual('foo', headers['Access-Control-Allow-Origin'])
+
+    def test_api_cors_headers_disallow_any_no_matched_allowed_origins(self):
+        """An origin outside ``allowed_origins`` must get no allowed-origin header."""
+        get_config().set('cors', 'allow_any_origin', 'false')
+        get_config().set('cors', 'allowed_origins', '["foo", "bar"]')
+        # 'foobar' is not an entry of the list, although it starts with one.
+        reply = self.fetch('/api/graph', headers={'Origin': 'foobar'})
+        reply_headers = dict(reply.headers)
+
+        self.assertIsNone(reply_headers.get('Access-Control-Allow-Origin'))
+
+    def test_api_cors_headers_disallow_any_no_allowed_origins(self):
+        """With ``allow_any_origin`` off and no ``allowed_origins`` set, no origin is allowed."""
+        get_config().set('cors', 'allow_any_origin', 'false')
+        reply = self.fetch('/api/graph', headers={'Origin': 'foo'})
+        reply_headers = dict(reply.headers)
+
+        self.assertIsNone(reply_headers.get('Access-Control-Allow-Origin'))
+
+    def test_api_cors_headers_disabled(self):
+        """With CORS disabled, a request carrying an Origin gets no allowed-origin header."""
+        get_config().set('cors', 'enabled', 'false')
+        reply = self.fetch('/api/graph', headers={'Origin': 'foo'})
+        reply_headers = dict(reply.headers)
+
+        self.assertIsNone(reply_headers.get('Access-Control-Allow-Origin'))
+
+ def test_api_cors_headers_no_origin_header(self):
+ response = self.fetch('/api/graph')
+ headers = dict(response.headers)
- self.assertSetEqual(_set("Access-Control-Allow-Headers"), {"Content-Type", "Accept", "Authorization", "Origin"})
- self.assertSetEqual(_set("Access-Control-Allow-Methods"), {"GET", "OPTIONS"})
- self.assertEqual(headers["Access-Control-Allow-Origin"], "*")
+ self.assertIsNone(headers.get('Access-Control-Allow-Origin'))
class _ServerTest(unittest.TestCase):
diff --git a/test/webhdfs_minicluster.py b/test/webhdfs_minicluster.py
index 16d59a55f4..3f803030f9 100644
--- a/test/webhdfs_minicluster.py
+++ b/test/webhdfs_minicluster.py
@@ -66,7 +66,7 @@ def _get_namenode_port(self):
line = f.readline()
print(line.rstrip())
- m = re.match(".*Jetty bound to port (\d+).*", line)
+ m = re.match(".*Jetty bound to port (\\d+).*", line)
if just_seen_webhdfs and m:
return int(m.group(1))
just_seen_webhdfs = re.match(".*namenode.*webhdfs.*", line)
diff --git a/tox.ini b/tox.ini
index 6eb434746c..07ae0e753e 100644
--- a/tox.ini
+++ b/tox.ini
@@ -21,6 +21,7 @@ deps=
cdh,hdp: snakebite>=2.5.2,<2.6.0
cdh,hdp: hdfs>=2.0.4,<3.0.0
postgres: psycopg2<3.0
+ mysql-connector-python>=8.0.12
gcloud: google-api-python-client>=1.4.0,<2.0
py27-gcloud: avro
py33-gcloud,py34-gcloud,py35-gcloud,py36-gcloud: avro-python3