From 599b7e73c0c86f065dbbb2946573aa3ba83419f3 Mon Sep 17 00:00:00 2001 From: Arif Wider Date: Wed, 20 Dec 2017 16:15:06 +0100 Subject: [PATCH] fixed tests --- .travis.yml | 6 +- src/evaluation.py | 5 +- test/merger_test.py | 70 ------------------- ...oad_and_split_test.py => splitter_test.py} | 6 +- 4 files changed, 11 insertions(+), 76 deletions(-) delete mode 100644 test/merger_test.py rename test/{load_and_split_test.py => splitter_test.py} (78%) diff --git a/.travis.yml b/.travis.yml index 47447a8..aea6cf0 100644 --- a/.travis.yml +++ b/.travis.yml @@ -9,4 +9,8 @@ script: - pytest notifications: - slack: tw-datalab:tySjyrBfi9MdHdiuk5hh8Jbd \ No newline at end of file + slack: + rooms: + - tw-datalab:tySjyrBfi9MdHdiuk5hh8Jbd + on_success: change + on_failure: always \ No newline at end of file diff --git a/src/evaluation.py b/src/evaluation.py index 2a24222..5c7f7b5 100644 --- a/src/evaluation.py +++ b/src/evaluation.py @@ -7,7 +7,8 @@ def nwrmsle(predictions, targets, weights): predictions = np.array([np.nan if x < 0 else x for x in predictions]) elif type(predictions) == pd.Series: predictions[predictions < 0] = np.nan - targets[targets < 0] = np.nan + targetsf = targets.astype(float) + targetsf[targets < 0] = np.nan weights = 1 + 0.25 * weights - log_square_errors = (np.log(predictions + 1) - np.log(targets + 1)) ** 2 + log_square_errors = (np.log(predictions + 1) - np.log(targetsf + 1)) ** 2 return(np.sqrt(np.sum(weights * log_square_errors) / np.sum(weights))) diff --git a/test/merger_test.py b/test/merger_test.py deleted file mode 100644 index b0665a0..0000000 --- a/test/merger_test.py +++ /dev/null @@ -1,70 +0,0 @@ -import sys -import os -import pytest -import logging -# import numpy as np -# from pytest import approx -from pyspark.sql import SparkSession -sys.path.append(os.path.join('..', 'src')) -sys.path.append(os.path.join('src')) -import merger - - -def quiet_py4j(): - """Suppress spark logging for the test context.""" - logger = logging.getLogger('py4j') - logger.setLevel(logging.WARN) - - -@pytest.fixture(scope="session") -def spark_session(request): - """Fixture for creating a spark context.""" - - spark = (SparkSession - .builder - .config("spark.driver.memory", "4g") - .appName('pytest-pyspark-local-testing') - .getOrCreate()) - request.addfinalizer(lambda: spark.stop()) - - quiet_py4j() - return spark - - -def test_table_attribute_formatter(): - table_name = 'table' - column_names = ['attrib1', 'attrib2'] - formatted_string = merger.table_attribute_formatter(table_name, column_names) - assert formatted_string == 'table.attrib1, table.attrib2' - - -def test_loj_sql_formatter(): - right_attributes = "a.x, a.y, a.z" - on_column = "qrs" - sql_statement = merger.loj_sql_formatter(right_attributes, on_column) - assert sql_statement == "SELECT left.*, a.x, a.y, a.z FROM left LEFT JOIN right ON left.qrs = right.qrs" - - -def test_join_datasets_on_column(spark_session): - train_headers = ['id', 'date', 'store_nbr', 'item_nbr', 'unit_sales', 'onpromotion'] - train_data1 = ['0', '2013-01-01', '25', '1111', '7.0', '0'] - train_data2 = ['1', '2013-01-01', '25', '9999', '1.0', '0'] - mockTrain = spark_session.createDataFrame([train_data1, train_data2], schema=train_headers) - - items_headers = ['item_nbr', 'family', 'class', 'perishable'] - items_data1 = ['1111', 'A', 'C', '1'] - items_data2 = ['1234', 'B', 'C', '0'] - mockItems = spark_session.createDataFrame([items_data1, items_data2], schema=items_headers) - mockTrain.createOrReplaceTempView("train") - mockItems.createOrReplaceTempView("items") - on_column = "item_nbr" - columns = ["family", "class", "perishable"] - train_items_merged = merger.leftOuterJoin(mockTrain, mockItems, on_column, columns) - assert train_items_merged.count() == 2 - assert train_items_merged.columns == ['id', 'date', 'store_nbr', 'item_nbr', 'unit_sales', 'onpromotion', 'family', 'class', 'perishable'] - assert train_items_merged.count() == 2 - # asserting that the family column of item 9999 is null, because we have - # no store, family, or class data for it - assert train_items_merged.filter(train_items_merged.item_nbr == 9999).collect()[0].family is None - assert train_items_merged.filter(train_items_merged.item_nbr == 9999).collect()[0].id is not None - assert train_items_merged.filter(train_items_merged.item_nbr == 1111).collect()[0].family is not None diff --git a/test/load_and_split_test.py b/test/splitter_test.py similarity index 78% rename from test/load_and_split_test.py rename to test/splitter_test.py index 993842a..a836013 100644 --- a/test/load_and_split_test.py +++ b/test/splitter_test.py @@ -3,11 +3,11 @@ import pandas as pd sys.path.append(os.path.join('..', 'src')) sys.path.append(os.path.join('src')) -import load_and_split +import splitter def test_get_validation_period(): latest_date = pd.to_datetime('2017-11-22') - actual_begin_date, actual_end_date = load_and_split.get_validation_period(latest_date) + actual_begin_date, actual_end_date = splitter.get_validation_period(latest_date) expected_begin_date = pd.to_datetime('2017-11-01') expected_end_date = pd.to_datetime('2017-11-16') assert actual_begin_date == expected_begin_date @@ -22,6 +22,6 @@ def test_split_validation_train_by_validation_period(): validation_end_date = pd.to_datetime('2017-11-30') d = {'date': [date1, date2, date3, date4], 'col2': [3, 4, 5, 6]} df = pd.DataFrame(data=d) - df_train, df_validation = load_and_split.split_validation_train_by_validation_period(df, validation_begin_date, validation_end_date) + df_train, df_validation = splitter.split_validation_train_by_validation_period(df, validation_begin_date, validation_end_date) assert df_train.shape[0] == 1 assert df_validation.shape[0] == 2