Skip to content

Commit

Permalink
HiveTableTarget inherited from HivePartitionTarget
Browse files Browse the repository at this point in the history
  • Loading branch information
[email protected] committed Dec 24, 2019
1 parent b9a1bbd commit 702a4e7
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 45 deletions.
65 changes: 22 additions & 43 deletions luigi/contrib/hive.py
Original file line number Diff line number Diff line change
Expand Up @@ -469,34 +469,6 @@ def run_job(self, job, tracking_url_callback=None):
return luigi.contrib.hadoop.run_and_track_hadoop_job(arglist, job.set_tracking_url)


class HiveTableTarget(luigi.Target):
    """
    Target representing a Hive table.

    exists returns true if the table exists.
    """

    def __init__(self, table, database='default', client=None):
        """
        :param table: name of the Hive table.
        :param database: name of the database the table lives in.
        :param client: client object used for the table lookups; when falsy,
                       the result of ``get_default_client()`` is used instead.
        """
        self.database = database
        self.table = table
        self.client = client or get_default_client()

    def exists(self):
        """
        Returns whatever the client's ``table_exists`` reports for this table.
        """
        logger.debug("Checking if Hive table '%s.%s' exists", self.database, self.table)
        return self.client.table_exists(self.table, self.database)

    @property
    def path(self):
        """
        Returns the path to this table in HDFS.

        :raises Exception: if the client returns no location for the table.
        """
        location = self.client.table_location(self.table, self.database)
        if not location:
            raise Exception("Couldn't find location for table: {0}".format(str(self)))
        return location

    def open(self, mode):
        """
        Opening a Hive table target is not supported.

        :raises NotImplementedError: always.
        """
        # The exception must be raised, not returned: returning it hands the
        # caller an exception instance that looks like a usable file object
        # and only fails later, far away from the real cause.
        raise NotImplementedError("open() is not supported for HiveTableTarget")


class HivePartitionTarget(luigi.Target):
"""
exists returns true if the table's partition exists.
Expand All @@ -507,7 +479,6 @@ def __init__(self, table, partition, database='default', fail_missing_table=True
self.table = table
self.partition = partition
self.client = client or get_default_client()

self.fail_missing_table = fail_missing_table

def exists(self):
Expand All @@ -516,7 +487,7 @@ def exists(self):
"Checking Hive table '{d}.{t}' for partition {p}".format(
d=self.database,
t=self.table,
p=str(self.partition)
p=str(self.partition or {})
)
)

Expand All @@ -543,14 +514,28 @@ def path(self):
return location

def open(self, mode):
return NotImplementedError("open() is not supported for HivePartitionTarget")
return NotImplementedError("open() is not supported for {}".format(self.__class__.__name__))


class HiveTableTarget(HivePartitionTarget):
    """
    exists returns true if the table exists.

    Implemented as a HivePartitionTarget constructed without a partition.
    """

    def __init__(self, table, database='default', client=None):
        """
        :param table: name of the Hive table.
        :param database: name of the database the table lives in.
        :param client: client object forwarded unchanged to the parent class.
        """
        # A whole-table target is expressed to the parent as "no partition";
        # fail_missing_table is always passed as True here.
        parent_init = super(HiveTableTarget, self).__init__
        parent_init(
            table,
            None,
            database=database,
            fail_missing_table=True,
            client=client,
        )


class ExternalHiveTask(luigi.ExternalTask):
"""
External task that depends on a Hive table/partition.
"""

database = luigi.Parameter(default='default')
table = luigi.Parameter()
partition = luigi.DictParameter(
Expand All @@ -559,14 +544,8 @@ class ExternalHiveTask(luigi.ExternalTask):
)

def output(self):
if self.partition:
return HivePartitionTarget(
database=self.database,
table=self.table,
partition=self.partition,
)
else:
return HiveTableTarget(
database=self.database,
table=self.table,
)
return HivePartitionTarget(
database=self.database,
table=self.table,
partition=self.partition,
)
5 changes: 3 additions & 2 deletions test/contrib/hive_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -460,7 +460,7 @@ def test_hive_table_target(self):
client = mock.Mock()
target = luigi.contrib.hive.HiveTableTarget(database='db', table='foo', client=client)
target.exists()
client.table_exists.assert_called_with('foo', 'db')
client.table_exists.assert_called_with('foo', 'db', None)

def test_hive_partition_target(self):
client = mock.Mock()
Expand All @@ -480,9 +480,10 @@ class _Task(luigi.contrib.hive.ExternalHiveTask):
output = _Task().output()

# assert
assert isinstance(output, luigi.contrib.hive.HiveTableTarget)
assert isinstance(output, luigi.contrib.hive.HivePartitionTarget)
assert output.database == 'schema1'
assert output.table == 'table1'
assert output.partition == {}

def test_partition_exists(self):
# arrange
Expand Down

0 comments on commit 702a4e7

Please sign in to comment.