Skip to content

Commit

Permalink
fix tests
Browse files Browse the repository at this point in the history
  • Loading branch information
nikolagigic committed Mar 10, 2021
1 parent 0189072 commit bfdc994
Show file tree
Hide file tree
Showing 5 changed files with 50 additions and 58 deletions.
17 changes: 10 additions & 7 deletions superset/db_engine_specs/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@
from sqlalchemy.orm import Session
from sqlalchemy.sql import quoted_name, text
from sqlalchemy.sql.expression import ColumnClause, ColumnElement, Select, TextAsFrom
from sqlalchemy.types import TypeEngine
from sqlalchemy.types import String, TypeEngine, UnicodeText

from superset import app, security_manager, sql_parse
from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
Expand Down Expand Up @@ -160,7 +160,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
),
(
re.compile(r"^integer", re.IGNORECASE),
types.Integer(),
types.Integer,
GenericDataType.NUMERIC,
),
(
Expand Down Expand Up @@ -195,12 +195,15 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
GenericDataType.NUMERIC,
),
(
re.compile(r"^varchar", re.IGNORECASE),
types.VARCHAR(),
GenericDataType.STRING,
re.compile(r"^N((VAR)?CHAR|TEXT)", re.IGNORECASE),
UnicodeText(),
utils.GenericDataType.STRING,
),
(
re.compile(r"^((VAR)?CHAR|TEXT|STRING)", re.IGNORECASE),
String(),
utils.GenericDataType.STRING,
),
(re.compile(r"^char", re.IGNORECASE), types.CHAR(), GenericDataType.STRING),
(re.compile(r"^text", re.IGNORECASE), types.Text(), GenericDataType.STRING),
(re.compile(r"^date", re.IGNORECASE), types.Date(), GenericDataType.TEMPORAL,),
(
re.compile(r"^timestamp", re.IGNORECASE),
Expand Down
13 changes: 0 additions & 13 deletions superset/db_engine_specs/mssql.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,19 +77,6 @@ def fetch_data(
# Lists of `pyodbc.Row` need to be unpacked further
return cls.pyodbc_rows_to_tuples(data)

column_type_mappings = (
(
re.compile(r"^N((VAR)?CHAR|TEXT)", re.IGNORECASE),
UnicodeText(),
utils.GenericDataType.STRING,
),
(
re.compile(r"^((VAR)?CHAR|TEXT|STRING)", re.IGNORECASE),
String(),
utils.GenericDataType.STRING,
),
)

@classmethod
def extract_error_message(cls, ex: Exception) -> str:
if str(ex).startswith("(8155,"):
Expand Down
11 changes: 6 additions & 5 deletions superset/db_engine_specs/presto.py
Original file line number Diff line number Diff line change
Expand Up @@ -388,7 +388,7 @@ def _show_columns(
),
(
re.compile(r"^integer.*", re.IGNORECASE),
types.INTEGER,
types.INTEGER(),
utils.GenericDataType.NUMERIC,
),
(
Expand Down Expand Up @@ -1201,10 +1201,11 @@ def get_column_spec( # type: ignore
] = column_type_mappings,
) -> Union[ColumnSpec, None]:

column_spec = super().get_column_spec(native_type)
column_spec = super().get_column_spec(
native_type, column_type_mappings=column_type_mappings
)

if column_spec:
return column_spec

return super().get_column_spec(
native_type, column_type_mappings=column_type_mappings
)
return super().get_column_spec(native_type)
5 changes: 3 additions & 2 deletions tests/databases/commands_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ def test_export_database_command(self, mock_g):
"allow_csv_upload": True,
"allow_ctas": True,
"allow_cvas": True,
"allow_run_async": False,
"allow_run_async": True,
"cache_timeout": None,
"database_name": "examples",
"expose_in_sqllab": True,
Expand Down Expand Up @@ -247,7 +247,8 @@ def test_export_database_command(self, mock_g):
"version": "1.0.0",
}
expected_metadata["columns"].sort(key=lambda x: x["column_name"])
assert metadata == expected_metadata
self.maxDiff = None
self.assertEquals(metadata, expected_metadata)

@patch("superset.security.manager.g")
def test_export_database_command_no_access(self, mock_g):
Expand Down
62 changes: 31 additions & 31 deletions tests/db_engine_specs/presto_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -535,37 +535,37 @@ def test_presto_expand_data_array(self):
self.assertEqual(actual_expanded_cols, expected_expanded_cols)

def test_get_sqla_column_type(self):
sqla_type, generic_type = PrestoEngineSpec.get_sqla_column_type("varchar(255)")
assert isinstance(sqla_type, types.VARCHAR)
assert sqla_type.length is None
self.assertEqual(generic_type, GenericDataType.STRING)

sqla_type, generic_type = PrestoEngineSpec.get_sqla_column_type("varchar")
assert isinstance(sqla_type, types.String)
assert sqla_type.length is None
self.assertEqual(generic_type, GenericDataType.STRING)

sqla_type, generic_type = PrestoEngineSpec.get_sqla_column_type("char(10)")
assert isinstance(sqla_type, types.CHAR)
assert sqla_type.length is None
self.assertEqual(generic_type, GenericDataType.STRING)

sqla_type, generic_type = PrestoEngineSpec.get_sqla_column_type("char")
assert isinstance(sqla_type, types.CHAR)
assert sqla_type.length is None
self.assertEqual(generic_type, GenericDataType.STRING)

sqla_type, generic_type = PrestoEngineSpec.get_sqla_column_type("integer")
assert isinstance(sqla_type, types.Integer)
self.assertEqual(generic_type, GenericDataType.NUMERIC)

sqla_type, generic_type = PrestoEngineSpec.get_sqla_column_type("time")
assert isinstance(sqla_type, types.Time)
self.assertEqual(generic_type, GenericDataType.TEMPORAL)

sqla_type, generic_type = PrestoEngineSpec.get_sqla_column_type("timestamp")
assert isinstance(sqla_type, types.TIMESTAMP)
self.assertEqual(generic_type, GenericDataType.TEMPORAL)
column_spec = PrestoEngineSpec.get_column_spec("varchar(255)")
assert isinstance(column_spec.sqla_type, types.VARCHAR)
assert column_spec.sqla_type.length == 255
self.assertEqual(column_spec.generic_type, GenericDataType.STRING)

column_spec = PrestoEngineSpec.get_column_spec("varchar")
assert isinstance(column_spec.sqla_type, types.String)
assert column_spec.sqla_type.length is None
self.assertEqual(column_spec.generic_type, GenericDataType.STRING)

column_spec = PrestoEngineSpec.get_column_spec("char(10)")
assert isinstance(column_spec.sqla_type, types.CHAR)
assert column_spec.sqla_type.length == 10
self.assertEqual(column_spec.generic_type, GenericDataType.STRING)

column_spec = PrestoEngineSpec.get_column_spec("char")
assert isinstance(column_spec.sqla_type, types.CHAR)
assert column_spec.sqla_type.length is None
self.assertEqual(column_spec.generic_type, GenericDataType.STRING)

column_spec = PrestoEngineSpec.get_column_spec("integer")
assert isinstance(column_spec.sqla_type, types.Integer)
self.assertEqual(column_spec.generic_type, GenericDataType.NUMERIC)

column_spec = PrestoEngineSpec.get_column_spec("time")
assert isinstance(column_spec.sqla_type, types.Time)
self.assertEqual(column_spec.generic_type, GenericDataType.TEMPORAL)

column_spec = PrestoEngineSpec.get_column_spec("timestamp")
assert isinstance(column_spec.sqla_type, types.TIMESTAMP)
self.assertEqual(column_spec.generic_type, GenericDataType.TEMPORAL)

sqla_type = PrestoEngineSpec.get_sqla_column_type(None)
assert sqla_type is None
Expand Down

0 comments on commit bfdc994

Please sign in to comment.