diff --git a/qds_sdk/engine.py b/qds_sdk/engine.py index c84ca366..534c940a 100644 --- a/qds_sdk/engine.py +++ b/qds_sdk/engine.py @@ -11,6 +11,7 @@ class Engine: def __init__(self, flavour=None): self.flavour = flavour self.hadoop_settings = {} + self.hive_settings = {} self.presto_settings = {} self.spark_settings = {} self.airflow_settings = {} @@ -26,6 +27,7 @@ def set_engine_config(self, custom_presto_config=None, spark_version=None, custom_spark_config=None, + hive_version=None, dbtap_id=None, fernet_key=None, overrides=None, @@ -56,6 +58,8 @@ def set_engine_config(self, custom_spark_config: Specify the custom Spark configuration overrides + hive_version: Version of hive to be used in cluster + dbtap_id: ID of the data store inside QDS fernet_key: Encryption key for sensitive information inside airflow database. @@ -75,8 +79,10 @@ def set_engine_config(self, ''' - self.set_hadoop_settings(custom_hadoop_config, use_qubole_placement_policy, is_ha, fairscheduler_config_xml, + self.set_hadoop_settings(custom_hadoop_config, use_qubole_placement_policy, + is_ha, fairscheduler_config_xml, default_pool, enable_rubix) + self.set_hive_settings(hive_version) self.set_presto_settings(presto_version, custom_presto_config) self.set_spark_settings(spark_version, custom_spark_config) self.set_airflow_settings(dbtap_id, fernet_key, overrides, airflow_version, airflow_python_version) @@ -103,6 +109,10 @@ def set_hadoop_settings(self, self.set_fairscheduler_settings(fairscheduler_config_xml, default_pool) self.hadoop_settings['enable_rubix'] = enable_rubix + def set_hive_settings(self, + hive_version=None): + self.hive_settings['hive_version'] = hive_version + def set_presto_settings(self, presto_version=None, custom_presto_config=None): @@ -147,6 +157,7 @@ def set_engine_config_settings(self, arguments): custom_presto_config=custom_presto_config, spark_version=arguments.spark_version, custom_spark_config=arguments.custom_spark_config, + hive_version=arguments.hive_version, dbtap_id=arguments.dbtap_id, fernet_key=arguments.fernet_key, overrides=arguments.overrides, @@ -218,6 +229,11 @@ def engine_parser(argparser): dest="presto_custom_config_file", help="location of file containg custom" + " presto configuration overrides") + hive_settings_group = argparser.add_argument_group("hive version settings") + hive_settings_group.add_argument("--hive_version", + dest="hive_version", + default=None, + help="Version of hive for the cluster",) spark_settings_group = argparser.add_argument_group("spark settings") spark_settings_group.add_argument("--spark-version", diff --git a/tests/test_clusterv2.py b/tests/test_clusterv2.py index 078108ff..17859456 100644 --- a/tests/test_clusterv2.py +++ b/tests/test_clusterv2.py @@ -562,6 +562,26 @@ def test_mlflow_engine_config(self): }}, 'cluster_info': {'label': ['test_label'], }}) + def test_hive_engine_config(self): + with tempfile.NamedTemporaryFile() as temp: + temp.write("config.properties:\na=1\nb=2".encode("utf8")) + temp.flush() + sys.argv = ['qds.py', '--version', 'v2', 'cluster', 'create', '--label', 'test_label', + '--flavour', 'hadoop2', '--hive_version', '2.3'] + Qubole.cloud = None + print_command() + Connection._api_call = Mock(return_value={}) + qds.main() + Connection._api_call.assert_called_with('POST', 'clusters', + {'engine_config': + {'flavour': 'hadoop2', + 'hive_settings': { + 'hive_version': '2.3' + }}, + 'cluster_info': {'label': ['test_label'],}}) + + + def test_persistent_security_groups_v2(self): sys.argv = ['qds.py', '--version', 'v2', 'cluster', 'create', '--label', 'test_label', '--persistent-security-groups', 'sg1, sg2']