Skip to content

Commit

Permalink
feat(charts): modify custom api filter to include more fields
Browse files Browse the repository at this point in the history
  • Loading branch information
nytai committed Sep 25, 2020
1 parent d056e3d commit bb23780
Show file tree
Hide file tree
Showing 5 changed files with 40 additions and 24 deletions.
3 changes: 2 additions & 1 deletion superset-frontend/src/components/ListView/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,8 @@ export interface Filter {
| 'rel_o_m'
| 'title_or_slug'
| 'name_or_description'
| 'all_text';
| 'all_text'
| 'all_text_chart';
input?: 'text' | 'textarea' | 'select' | 'checkbox' | 'search';
unfilteredLabel?: string;
selects?: SelectOption[];
Expand Down
2 changes: 1 addition & 1 deletion superset-frontend/src/views/CRUD/chart/ChartList.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -364,7 +364,7 @@ function ChartList(props: ChartListProps) {
Header: t('Search'),
id: 'slice_name',
input: 'search',
operator: 'name_or_description',
operator: 'all_text_chart',
},
];

Expand Down
4 changes: 2 additions & 2 deletions superset/charts/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@
ChartUpdateFailedError,
)
from superset.charts.commands.update import UpdateChartCommand
from superset.charts.filters import ChartFilter, ChartNameOrDescriptionFilter
from superset.charts.filters import ChartAllTextFilter, ChartFilter
from superset.charts.schemas import (
CHART_SCHEMAS,
ChartDataQueryContextSchema,
Expand Down Expand Up @@ -145,7 +145,7 @@ class ChartRestApi(BaseSupersetModelRestApi):
]
base_order = ("changed_on", "desc")
base_filters = [["id", ChartFilter, lambda: []]]
search_filters = {"slice_name": [ChartNameOrDescriptionFilter]}
search_filters = {"slice_name": [ChartAllTextFilter]}

# Will just affect _info endpoint
edit_columns = ["slice_name"]
Expand Down
11 changes: 6 additions & 5 deletions superset/charts/filters.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,15 +21,14 @@
from sqlalchemy.orm.query import Query

from superset import security_manager
from superset.connectors.sqla.models import SqlaTable
from superset.models.slice import Slice
from superset.views.base import BaseFilter


class ChartNameOrDescriptionFilter(
BaseFilter
): # pylint: disable=too-few-public-methods
name = _("Name or Description")
arg_name = "name_or_description"
class ChartAllTextFilter(BaseFilter): # pylint: disable=too-few-public-methods
name = _("All Text")
arg_name = "all_text_chart"

def apply(self, query: Query, value: Any) -> Query:
if not value:
Expand All @@ -39,6 +38,8 @@ def apply(self, query: Query, value: Any) -> Query:
or_(
Slice.slice_name.ilike(ilike_value),
Slice.description.ilike(ilike_value),
Slice.viz_type.ilike(ilike_value),
SqlaTable.table_name.ilike(ilike_value),
)
)

Expand Down
44 changes: 29 additions & 15 deletions tests/charts/api_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -515,7 +515,7 @@ def test_get_chart(self):

def test_get_chart_not_found(self):
"""
Chart API: Test get chart not found
Chart API: Test get chart not found
"""
chart_id = 1000
self.login(username="admin")
Expand All @@ -525,7 +525,7 @@ def test_get_chart_not_found(self):

def test_get_chart_no_data_access(self):
"""
Chart API: Test get chart without data access
Chart API: Test get chart without data access
"""
self.login(username="gamma")
chart_no_access = (
Expand Down Expand Up @@ -596,34 +596,49 @@ def test_get_charts_custom_filter(self):
chart1 = self.insert_chart("foo_a", [admin.id], 1, description="ZY_bar")
chart2 = self.insert_chart("zy_foo", [admin.id], 1, description="desc1")
chart3 = self.insert_chart("foo_b", [admin.id], 1, description="desc1zy_")
chart4 = self.insert_chart("bar", [admin.id], 1, description="foo")
chart4 = self.insert_chart("foo_c", [admin.id], 1, viz_type="viz_zy_")
chart5 = self.insert_chart("bar", [admin.id], 1, description="foo")

arguments = {
"filters": [
{"col": "slice_name", "opr": "name_or_description", "value": "zy_"}
],
"filters": [{"col": "slice_name", "opr": "all_text_chart", "value": "zy_"}],
"order_column": "slice_name",
"order_direction": "asc",
"keys": ["none"],
"columns": ["slice_name", "description"],
"columns": ["slice_name", "description", "viz_type"],
}
self.login(username="admin")
uri = f"api/v1/chart/?q={prison.dumps(arguments)}"
rv = self.get_assert_metric(uri, "get_list")
self.assertEqual(rv.status_code, 200)
data = json.loads(rv.data.decode("utf-8"))
self.assertEqual(data["count"], 3)
self.assertEqual(data["count"], 4)

expected_response = [
{"description": "ZY_bar", "slice_name": "foo_a",},
{"description": "desc1zy_", "slice_name": "foo_b",},
{"description": "desc1", "slice_name": "zy_foo",},
{"description": "ZY_bar", "slice_name": "foo_a", "viz_type": None},
{"description": "desc1zy_", "slice_name": "foo_b", "viz_type": None},
{"description": None, "slice_name": "foo_c", "viz_type": "viz_zy_"},
{"description": "desc1", "slice_name": "zy_foo", "viz_type": None},
]
for index, item in enumerate(data["result"]):
self.assertEqual(
item["description"], expected_response[index]["description"]
)
self.assertEqual(item["slice_name"], expected_response[index]["slice_name"])
self.assertEqual(item["viz_type"], expected_response[index]["viz_type"])

# test filtering on datasource_name
arguments = {
"filters": [
{"col": "slice_name", "opr": "all_text_chart", "value": "energy",}
],
"keys": ["none"],
"columns": ["slice_name"],
}
uri = f"api/v1/chart/?q={prison.dumps(arguments)}"
rv = self.get_assert_metric(uri, "get_list")
self.assertEqual(rv.status_code, 200)
data = json.loads(rv.data.decode("utf-8"))
self.assertEqual(data["count"], 8)

self.logout()
self.login(username="gamma")
Expand All @@ -638,6 +653,7 @@ def test_get_charts_custom_filter(self):
db.session.delete(chart2)
db.session.delete(chart3)
db.session.delete(chart4)
db.session.delete(chart5)
db.session.commit()

def test_get_charts_page(self):
Expand Down Expand Up @@ -870,8 +886,7 @@ def test_chart_data_incorrect_request(self):
self.assertEqual(rv.status_code, 400)

def test_chart_data_with_invalid_datasource(self):
"""Chart data API: Test chart data query with invalid schema
"""
"""Chart data API: Test chart data query with invalid schema"""
self.login(username="admin")
table = self.get_table_by_name("birth_names")
payload = get_query_context(table.name, table.id, table.type)
Expand All @@ -880,8 +895,7 @@ def test_chart_data_with_invalid_datasource(self):
self.assertEqual(rv.status_code, 400)

def test_chart_data_with_invalid_enum_value(self):
"""Chart data API: Test chart data query with invalid enum value
"""
"""Chart data API: Test chart data query with invalid enum value"""
self.login(username="admin")
table = self.get_table_by_name("birth_names")
payload = get_query_context(table.name, table.id, table.type)
Expand Down

0 comments on commit bb23780

Please sign in to comment.