diff --git a/backend/dataall/base/aws/s3_client.py b/backend/dataall/base/aws/s3_client.py index 74f79e7b4..e05fa9933 100644 --- a/backend/dataall/base/aws/s3_client.py +++ b/backend/dataall/base/aws/s3_client.py @@ -1,8 +1,8 @@ +import logging + import boto3 from botocore.config import Config - from botocore.exceptions import ClientError -import logging log = logging.getLogger(__name__) diff --git a/backend/dataall/modules/worksheets/api/__init__.py b/backend/dataall/modules/worksheets/api/__init__.py index 1251b301b..48b0459d6 100644 --- a/backend/dataall/modules/worksheets/api/__init__.py +++ b/backend/dataall/modules/worksheets/api/__init__.py @@ -4,7 +4,12 @@ queries, resolvers, types, - enums, ) -__all__ = ['resolvers', 'types', 'input_types', 'queries', 'mutations', 'enums'] +__all__ = [ + 'resolvers', + 'types', + 'input_types', + 'queries', + 'mutations', +] diff --git a/backend/dataall/modules/worksheets/api/input_types.py b/backend/dataall/modules/worksheets/api/input_types.py index e2ec62dea..9fbeb75b1 100644 --- a/backend/dataall/modules/worksheets/api/input_types.py +++ b/backend/dataall/modules/worksheets/api/input_types.py @@ -75,3 +75,14 @@ gql.Argument(name='measures', type=gql.ArrayType(gql.Ref('WorksheetMeasureInput'))), ], ) + + +WorksheetQueryResultDownloadUrlInput = gql.InputType( + name='WorksheetQueryResultDownloadUrlInput', + arguments=[ + gql.Argument(name='athenaQueryId', type=gql.NonNullableType(gql.String)), + gql.Argument(name='fileFormat', type=gql.NonNullableType(gql.String)), + gql.Argument(name='environmentUri', type=gql.NonNullableType(gql.String)), + gql.Argument(name='worksheetUri', type=gql.NonNullableType(gql.String)), + ], +) diff --git a/backend/dataall/modules/worksheets/api/mutations.py b/backend/dataall/modules/worksheets/api/mutations.py index a232fba83..efe4afa6b 100644 --- a/backend/dataall/modules/worksheets/api/mutations.py +++ b/backend/dataall/modules/worksheets/api/mutations.py @@ -1,5 +1,10 @@ from dataall.base.api import gql -from dataall.modules.worksheets.api.resolvers import create_worksheet, delete_worksheet, update_worksheet +from dataall.modules.worksheets.api.resolvers import ( + create_worksheet, + delete_worksheet, + update_worksheet, + create_athena_query_result_download_url, +) createWorksheet = gql.MutationField( @@ -27,3 +32,12 @@ ], type=gql.Boolean, ) + +createWorksheetQueryResultDownloadUrl = gql.MutationField( + name='createWorksheetQueryResultDownloadUrl', + resolver=create_athena_query_result_download_url, + args=[ + gql.Argument(name='input', type=gql.Ref('WorksheetQueryResultDownloadUrlInput')), + ], + type=gql.Ref('WorksheetQueryResult'), +) diff --git a/backend/dataall/modules/worksheets/api/resolvers.py b/backend/dataall/modules/worksheets/api/resolvers.py index 450667217..07bf5ea80 100644 --- a/backend/dataall/modules/worksheets/api/resolvers.py +++ b/backend/dataall/modules/worksheets/api/resolvers.py @@ -1,9 +1,10 @@ from dataall.base.db import exceptions -from dataall.modules.worksheets.api.enums import WorksheetRole +from dataall.modules.worksheets.services.worksheet_enums import WorksheetRole, WorksheetResultsFormat from dataall.modules.worksheets.db.worksheet_models import Worksheet from dataall.modules.worksheets.db.worksheet_repositories import WorksheetRepository from dataall.modules.worksheets.services.worksheet_service import WorksheetService from dataall.base.api.context import Context +from dataall.modules.worksheets.services.worksheet_query_result_service import WorksheetQueryResultService def create_worksheet(context: Context, source, input: dict = None): @@ -69,3 +70,28 @@ def run_sql_query(context: Context, source, environmentUri: str = None, workshee def delete_worksheet(context, source, worksheetUri: str = None): with context.engine.scoped_session() as session: return WorksheetService.delete_worksheet(session=session, uri=worksheetUri) + + +def create_athena_query_result_download_url(context: Context, source, input: dict = None): + if not input: + raise exceptions.RequiredParameter('data') + if not input.get('environmentUri'): + raise exceptions.RequiredParameter('environmentUri') + if not input.get('athenaQueryId'): + raise exceptions.RequiredParameter('athenaQueryId') + if not input.get('fileFormat'): + raise exceptions.RequiredParameter('fileFormat') + if not hasattr(WorksheetResultsFormat, input.get('fileFormat').upper()): + raise exceptions.InvalidInput( + 'fileFormat', + input.get('fileFormat'), + ', '.join(result_format.value for result_format in WorksheetResultsFormat), + ) + + env_uri = input['environmentUri'] + worksheet_uri = input['worksheetUri'] + + with context.engine.scoped_session() as session: + return WorksheetQueryResultService.download_sql_query_result( + session=session, uri=worksheet_uri, env_uri=env_uri, data=input + ) diff --git a/backend/dataall/modules/worksheets/api/types.py b/backend/dataall/modules/worksheets/api/types.py index e9187353f..f55741f8b 100644 --- a/backend/dataall/modules/worksheets/api/types.py +++ b/backend/dataall/modules/worksheets/api/types.py @@ -86,15 +86,16 @@ name='WorksheetQueryResult', fields=[ gql.Field(name='worksheetQueryResultUri', type=gql.ID), - gql.Field(name='queryType', type=gql.NonNullableType(gql.String)), - gql.Field(name='sqlBody', type=gql.NonNullableType(gql.String)), + gql.Field(name='sqlBody', type=gql.String), gql.Field(name='AthenaQueryId', type=gql.NonNullableType(gql.String)), gql.Field(name='region', type=gql.NonNullableType(gql.String)), gql.Field(name='AwsAccountId', type=gql.NonNullableType(gql.String)), - gql.Field(name='AthenaOutputBucketName', type=gql.NonNullableType(gql.String)), - gql.Field(name='AthenaOutputKey', type=gql.NonNullableType(gql.String)), - gql.Field(name='timeElapsedInSecond', type=gql.NonNullableType(gql.Integer)), + gql.Field(name='elapsedTimeInMs', type=gql.Integer), gql.Field(name='created', type=gql.NonNullableType(gql.String)), + gql.Field(name='downloadLink', type=gql.String), + gql.Field(name='outputLocation', type=gql.String), + gql.Field(name='expiresIn', type=gql.AWSDateTime), + gql.Field(name='fileFormat', type=gql.String), ], ) diff --git a/backend/dataall/modules/worksheets/aws/s3_client.py b/backend/dataall/modules/worksheets/aws/s3_client.py new file mode 100644 index 000000000..9d07132e1 --- /dev/null +++ b/backend/dataall/modules/worksheets/aws/s3_client.py @@ -0,0 +1,69 @@ +import logging +from typing import TYPE_CHECKING + +from botocore.exceptions import ClientError + +from dataall.base.aws.sts import SessionHelper +from dataall.base.db.exceptions import AWSResourceNotFound + +if TYPE_CHECKING: + from dataall.core.environment.db.environment_models import Environment + + try: + from mypy_boto3_s3 import S3Client as S3ClientType + except ImportError: + S3ClientType = None + +log = logging.getLogger(__name__) + + +class S3Client: + def __init__(self, env: 'Environment'): + self._client = SessionHelper.remote_session(env.AwsAccountId, env.region).client('s3', region_name=env.region) + self._env = env + + @property + def client(self) -> 'S3ClientType': + return self._client + + def get_presigned_url(self, bucket, key, expire_minutes: int = 15): + expire_seconds = expire_minutes * 60 + try: + presigned_url = self.client.generate_presigned_url( + 'get_object', + Params=dict( + Bucket=bucket, + Key=key, + ), + ExpiresIn=expire_seconds, + ) + return presigned_url + except ClientError as e: + log.error(f'Failed to get presigned URL due to: {e}') + raise e + + def object_exists(self, bucket, key) -> bool: + try: + self.client.head_object(Bucket=bucket, Key=key) + return True + except ClientError as e: + if e.response['Error']['Code'] == '404': + log.info(f'Object {key} not found in bucket {bucket}') + return False + log.error(f'Failed to check object existence due to: {e}') + raise AWSResourceNotFound('s3_object_exists', f'Object {key} not found in bucket {bucket}') + + def put_object(self, bucket, key, body): + try: + self.client.put_object(Bucket=bucket, Key=key, Body=body) + except ClientError as e: + log.error(f'Failed to put object due to: {e}') + raise e + + def get_object(self, bucket, key) -> str: + try: + response = self.client.get_object(Bucket=bucket, Key=key) + return response['Body'].read().decode('utf-8') + except ClientError as e: + log.error(f'Failed to get object due to: {e}') + raise e diff --git a/backend/dataall/modules/worksheets/db/worksheet_models.py b/backend/dataall/modules/worksheets/db/worksheet_models.py index 6549cb96c..b7a01a871 100644 --- a/backend/dataall/modules/worksheets/db/worksheet_models.py +++ b/backend/dataall/modules/worksheets/db/worksheet_models.py @@ -1,7 +1,7 @@ import datetime import enum -from sqlalchemy import Column, DateTime, Integer, Enum, String +from sqlalchemy import Column, DateTime, Integer, Enum, String, BigInteger from sqlalchemy.dialects import postgresql from sqlalchemy.orm import query_expression @@ -27,15 +27,23 @@ class Worksheet(Resource, Base): class WorksheetQueryResult(Base): __tablename__ = 'worksheet_query_result' + worksheetQueryResultUri = Column(String, primary_key=True, default=utils.uuid('worksheetQueryResultUri')) worksheetUri = Column(String, nullable=False) - AthenaQueryId = Column(String, primary_key=True) - status = Column(String, nullable=False) - queryType = Column(Enum(QueryType), nullable=False, default=True) - sqlBody = Column(String, nullable=False) + AthenaQueryId = Column(String, nullable=False) + status = Column(String, nullable=True) + sqlBody = Column(String, nullable=True) AwsAccountId = Column(String, nullable=False) region = Column(String, nullable=False) OutputLocation = Column(String, nullable=False) error = Column(String, nullable=True) ElapsedTimeInMs = Column(Integer, nullable=True) - DataScannedInBytes = Column(Integer, nullable=True) + DataScannedInBytes = Column(BigInteger, nullable=True) created = Column(DateTime, default=datetime.datetime.now) + + downloadLink = Column(String, nullable=True) + expiresIn = Column(DateTime, nullable=True) + updated = Column(DateTime, nullable=False, onupdate=datetime.datetime.utcnow, default=datetime.datetime.utcnow) + fileFormat = Column(String, nullable=True) + + def is_download_link_expired(self): + return self.expiresIn is None or self.expiresIn <= datetime.datetime.utcnow() diff --git a/backend/dataall/modules/worksheets/db/worksheet_repositories.py b/backend/dataall/modules/worksheets/db/worksheet_repositories.py index aea51761b..cac450943 100644 --- a/backend/dataall/modules/worksheets/db/worksheet_repositories.py +++ b/backend/dataall/modules/worksheets/db/worksheet_repositories.py @@ -53,3 +53,17 @@ def paginated_user_worksheets(session, username, groups, uri, data=None, check_p page=data.get('page', WorksheetRepository._DEFAULT_PAGE), page_size=data.get('pageSize', WorksheetRepository._DEFAULT_PAGE_SIZE), ).to_dict() + + @staticmethod + def find_query_result_by_format( + session, worksheet_uri: str, athena_query_id: str, file_format: str + ) -> WorksheetQueryResult: + return ( + session.query(WorksheetQueryResult) + .filter( + WorksheetQueryResult.worksheetUri == worksheet_uri, + WorksheetQueryResult.AthenaQueryId == athena_query_id, + WorksheetQueryResult.fileFormat == file_format, + ) + .first() + ) diff --git a/backend/dataall/modules/worksheets/api/enums.py b/backend/dataall/modules/worksheets/services/worksheet_enums.py similarity index 65% rename from backend/dataall/modules/worksheets/api/enums.py rename to backend/dataall/modules/worksheets/services/worksheet_enums.py index 3e9549f2a..c973c2399 100644 --- a/backend/dataall/modules/worksheets/api/enums.py +++ b/backend/dataall/modules/worksheets/services/worksheet_enums.py @@ -5,3 +5,8 @@ class WorksheetRole(GraphQLEnumMapper): Creator = '950' Admin = '900' NoPermission = '000' + + +class WorksheetResultsFormat(GraphQLEnumMapper): + CSV = 'csv' + XLSX = 'xlsx' diff --git a/backend/dataall/modules/worksheets/services/worksheet_permissions.py b/backend/dataall/modules/worksheets/services/worksheet_permissions.py index 0f494567d..301c2ddca 100644 --- a/backend/dataall/modules/worksheets/services/worksheet_permissions.py +++ b/backend/dataall/modules/worksheets/services/worksheet_permissions.py @@ -1,13 +1,12 @@ -from dataall.core.permissions.services.resources_permissions import ( - RESOURCES_ALL, - RESOURCES_ALL_WITH_DESC, -) from dataall.core.permissions.services.environment_permissions import ( ENVIRONMENT_INVITED, ENVIRONMENT_INVITATION_REQUEST, ENVIRONMENT_ALL, ) - +from dataall.core.permissions.services.resources_permissions import ( + RESOURCES_ALL, + RESOURCES_ALL_WITH_DESC, +) from dataall.core.permissions.services.tenant_permissions import TENANT_ALL, TENANT_ALL_WITH_DESC MANAGE_WORKSHEETS = 'MANAGE_WORKSHEETS' @@ -22,12 +21,10 @@ UPDATE_WORKSHEET = 'UPDATE_WORKSHEET' DELETE_WORKSHEET = 'DELETE_WORKSHEET' RUN_WORKSHEET_QUERY = 'RUN_WORKSHEET_QUERY' -WORKSHEET_ALL = [ - GET_WORKSHEET, - UPDATE_WORKSHEET, - DELETE_WORKSHEET, - RUN_WORKSHEET_QUERY, -] +DOWNLOAD_ATHENA_QUERY_RESULTS = 'DOWNLOAD_ATHENA_QUERY_RESULTS' + + +WORKSHEET_ALL = [GET_WORKSHEET, UPDATE_WORKSHEET, DELETE_WORKSHEET, RUN_WORKSHEET_QUERY, DOWNLOAD_ATHENA_QUERY_RESULTS] RESOURCES_ALL.extend(WORKSHEET_ALL) diff --git a/backend/dataall/modules/worksheets/services/worksheet_query_result_service.py b/backend/dataall/modules/worksheets/services/worksheet_query_result_service.py new file mode 100644 index 000000000..6557f68c2 --- /dev/null +++ b/backend/dataall/modules/worksheets/services/worksheet_query_result_service.py @@ -0,0 +1,123 @@ +import csv +import io +import os +from datetime import datetime, timedelta +from typing import TYPE_CHECKING + +from openpyxl import Workbook + +from dataall.base.db import exceptions +from dataall.core.environment.services.environment_service import EnvironmentService +from dataall.core.permissions.services.resource_policy_service import ResourcePolicyService +from dataall.core.permissions.services.tenant_policy_service import TenantPolicyService +from dataall.modules.worksheets.aws.s3_client import S3Client +from dataall.modules.worksheets.db.worksheet_models import WorksheetQueryResult +from dataall.modules.worksheets.db.worksheet_repositories import WorksheetRepository +from dataall.modules.worksheets.services.worksheet_enums import WorksheetResultsFormat +from dataall.modules.worksheets.services.worksheet_permissions import DOWNLOAD_ATHENA_QUERY_RESULTS, MANAGE_WORKSHEETS +from dataall.modules.worksheets.services.worksheet_service import WorksheetService + +if TYPE_CHECKING: + try: + from sqlalchemy.orm import Session + from openpyxl.worksheet.worksheet import Worksheet + except ImportError: + print('skipping type checks as stubs are not installed') + Session = None + Worksheet = None + + +class WorksheetQueryResultService: + _DEFAULT_ATHENA_QUERIES_PATH = 'athenaqueries' + _DEFAULT_QUERY_RESULTS_TIMEOUT = os.getenv('QUERY_RESULT_TIMEOUT_MINUTES', 120) + + @staticmethod + def _create_query_result( + environment_bucket: str, athena_workgroup: str, worksheet_uri: str, region: str, aws_account_id: str, data: dict + ) -> WorksheetQueryResult: + sql_query_result = WorksheetQueryResult( + worksheetUri=worksheet_uri, + AthenaQueryId=data.get('athenaQueryId'), + fileFormat=data.get('fileFormat'), + OutputLocation=f's3://{environment_bucket}/{WorksheetQueryResultService._DEFAULT_ATHENA_QUERIES_PATH}/{athena_workgroup}/', + region=region, + AwsAccountId=aws_account_id, + ) + return sql_query_result + + @staticmethod + def _build_s3_file_path(workgroup: str, query_id: str, athena_queries_dir: str = None) -> str: + athena_queries_dir = athena_queries_dir or WorksheetQueryResultService._DEFAULT_ATHENA_QUERIES_PATH + return f'{athena_queries_dir}/{workgroup}/{query_id}' + + @staticmethod + def _convert_csv_to_xlsx(csv_data) -> io.BytesIO: + wb = Workbook() + ws: 'Worksheet' = wb.active + csv_reader = csv.reader(csv_data.splitlines()) + for row in csv_reader: + ws.append(row) + + excel_buffer = io.BytesIO() + wb.save(excel_buffer) + excel_buffer.seek(0) + return excel_buffer + + @staticmethod + @TenantPolicyService.has_tenant_permission(MANAGE_WORKSHEETS) + @ResourcePolicyService.has_resource_permission(DOWNLOAD_ATHENA_QUERY_RESULTS) + def download_sql_query_result(session: 'Session', uri: str, env_uri: str, data: dict = None): + environment = EnvironmentService.get_environment_by_uri(session, env_uri) + worksheet = WorksheetService.get_worksheet_by_uri(session, uri) + env_group = EnvironmentService.get_environment_group( + session, worksheet.SamlAdminGroupName, environment.environmentUri + ) + sql_query_result = WorksheetRepository.find_query_result_by_format( + session, data.get('worksheetUri'), data.get('athenaQueryId'), data.get('fileFormat') + ) + s3_client = S3Client(environment) + if not sql_query_result: + sql_query_result = WorksheetQueryResultService._create_query_result( + environment.EnvironmentDefaultBucketName, + env_group.environmentAthenaWorkGroup, + worksheet.worksheetUri, + environment.region, + environment.AwsAccountId, + data, + ) + output_file_s3_path = WorksheetQueryResultService._build_s3_file_path( + env_group.environmentAthenaWorkGroup, data.get('athenaQueryId') + ) + if sql_query_result.fileFormat == WorksheetResultsFormat.XLSX.value: + try: + csv_data = s3_client.get_object( + bucket=environment.EnvironmentDefaultBucketName, + key=f'{output_file_s3_path}.{WorksheetResultsFormat.CSV.value}', + ) + excel_buffer = WorksheetQueryResultService._convert_csv_to_xlsx(csv_data) + s3_client.put_object( + bucket=environment.EnvironmentDefaultBucketName, + key=f'{output_file_s3_path}.{WorksheetResultsFormat.XLSX.value}', + body=excel_buffer, + ) + except Exception as e: + raise exceptions.AWSResourceNotAvailable('CONVERT_CSV_TO_EXCEL', f'Failed to convert csv to xlsx: {e}') + + s3_client.object_exists( + bucket=environment.EnvironmentDefaultBucketName, key=f'{output_file_s3_path}.{sql_query_result.fileFormat}' + ) + if sql_query_result.is_download_link_expired(): + url = s3_client.get_presigned_url( + bucket=environment.EnvironmentDefaultBucketName, + key=f'{output_file_s3_path}.{sql_query_result.fileFormat}', + expire_minutes=WorksheetQueryResultService._DEFAULT_QUERY_RESULTS_TIMEOUT, + ) + sql_query_result.downloadLink = url + sql_query_result.expiresIn = datetime.utcnow() + timedelta( + minutes=WorksheetQueryResultService._DEFAULT_QUERY_RESULTS_TIMEOUT + ) + + session.add(sql_query_result) + session.commit() + + return sql_query_result diff --git a/backend/migrations/versions/427db8f31999_backfill_MF_resource_permissions.py b/backend/migrations/versions/427db8f31999_backfill_MF_resource_permissions.py index 5209963e8..1932a7fd5 100644 --- a/backend/migrations/versions/427db8f31999_backfill_MF_resource_permissions.py +++ b/backend/migrations/versions/427db8f31999_backfill_MF_resource_permissions.py @@ -24,7 +24,7 @@ # revision identifiers, used by Alembic. revision = '427db8f31999' -down_revision = 'f87aecc36d39' +down_revision = 'd1d6da1b2d67' branch_labels = None depends_on = None diff --git a/backend/migrations/versions/d1d6da1b2d67_add_columns_worksheet_query_result_model.py b/backend/migrations/versions/d1d6da1b2d67_add_columns_worksheet_query_result_model.py new file mode 100644 index 000000000..d4cf99819 --- /dev/null +++ b/backend/migrations/versions/d1d6da1b2d67_add_columns_worksheet_query_result_model.py @@ -0,0 +1,46 @@ +"""add_columns_worksheet_query_result_model + +Revision ID: d1d6da1b2d67 +Revises: d274e756f0ae +Create Date: 2024-09-10 14:34:31.492186 + +""" + +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision = 'd1d6da1b2d67' +down_revision = 'f87aecc36d39' +branch_labels = None +depends_on = None + + +def upgrade(): + op.alter_column('worksheet_query_result', 'status', nullable=True) + op.alter_column('worksheet_query_result', 'sqlBody', nullable=True) + op.alter_column('worksheet_query_result', 'DataScannedInBytes', type_=sa.BigInteger(), nullable=True) + op.add_column('worksheet_query_result', sa.Column('downloadLink', sa.String(), nullable=True)) + op.add_column('worksheet_query_result', sa.Column('expiresIn', sa.DateTime(), nullable=True)) + op.add_column('worksheet_query_result', sa.Column('updated', sa.DateTime(), nullable=False)) + op.add_column('worksheet_query_result', sa.Column('fileFormat', sa.String(), nullable=True)) + op.add_column('worksheet_query_result', sa.Column('worksheetQueryResultUri', sa.String(), nullable=False)) + op.drop_constraint('AthenaQueryId', 'worksheet_query_result', type_='primary') + op.create_primary_key('worksheet_query_result_pkey', 'worksheet_query_result', ['worksheetQueryResultUri']) + op.alter_column('worksheet_query_result', 'AthenaQueryId', nullable=False) + op.drop_column('worksheet_query_result', 'queryType') + + +def downgrade(): + op.add_column('worksheet_query_result', sa.Column('queryType', sa.VARCHAR(), autoincrement=False, nullable=True)) + op.drop_constraint('worksheet_query_result_pkey', 'worksheet_query_result', type_='primary') + op.create_primary_key('AthenaQueryId', 'worksheet_query_result', ['AthenaQueryId']) + op.drop_column('worksheet_query_result', 'worksheetQueryResultUri') + op.drop_column('worksheet_query_result', 'fileFormat') + op.drop_column('worksheet_query_result', 'updated') + op.drop_column('worksheet_query_result', 'expiresIn') + op.drop_column('worksheet_query_result', 'downloadLink') + op.alter_column('worksheet_query_result', 'DataScannedInBytes', type_=sa.Integer(), nullable=True) + op.alter_column('worksheet_query_result', 'sqlBody', nullable=False) + op.alter_column('worksheet_query_result', 'status', nullable=False) diff --git a/backend/requirements.txt b/backend/requirements.txt index 05cb6619c..4e18f9db7 100644 --- a/backend/requirements.txt +++ b/backend/requirements.txt @@ -1,7 +1,10 @@ ariadne==0.17.0 aws-xray-sdk==2.4.3 -boto3==1.35.26 -fastapi == 0.115.0 +boto3==1.34.119 +botocore==1.34.119 +fastapi == 0.109.2 +Flask==3.0.3 +flask-cors==5.0.0 nanoid==2.0.0 opensearch-py==1.0.0 PyAthena==2.3.0 @@ -11,5 +14,7 @@ PyYAML==6.0 requests==2.32.2 requests_aws4auth==1.1.1 sqlalchemy==1.3.24 +starlette==0.36.3 alembic==1.13.1 +openpyxl==3.1.5 retrying==1.3.4 diff --git a/frontend/src/modules/Worksheets/components/WorksheetResult.js b/frontend/src/modules/Worksheets/components/WorksheetResult.js index e8e74e557..62fa6dea4 100644 --- a/frontend/src/modules/Worksheets/components/WorksheetResult.js +++ b/frontend/src/modules/Worksheets/components/WorksheetResult.js @@ -12,12 +12,69 @@ import { TableRow } from '@mui/material'; import PropTypes from 'prop-types'; -import React from 'react'; +import React, { useState, useCallback } from 'react'; import { FaBars } from 'react-icons/fa'; import * as ReactIf from 'react-if'; import { Scrollbar } from 'design'; -export const WorksheetResult = ({ results, loading }) => { +import Stack from '@mui/material/Stack'; +import Radio from '@mui/material/Radio'; +import RadioGroup from '@mui/material/RadioGroup'; +import FormControlLabel from '@mui/material/FormControlLabel'; +import { Download } from '@mui/icons-material'; +import { LoadingButton } from '@mui/lab'; +import { useClient } from 'services'; +import { createWorksheetQueryResultDownloadUrl } from '../services'; +import { SET_ERROR, useDispatch } from 'globalErrors'; + +export const WorksheetResult = ({ + results, + loading, + currentEnv, + athenaQueryId, + worksheetUri +}) => { + const [runningDownloadQuery, setRunningDownloadQuery] = useState(false); + const [fileType, setFileType] = useState('csv'); + const client = useClient(); + const dispatch = useDispatch(); + + const handleChange = (event) => { + setFileType(event.target.value); + }; + + const downloadFile = useCallback(async () => { + try { + setRunningDownloadQuery(true); + const response = await client.query( + createWorksheetQueryResultDownloadUrl({ + fileFormat: fileType, + environmentUri: currentEnv.environmentUri, + athenaQueryId: athenaQueryId, + worksheetUri: worksheetUri + }) + ); + + if (!response.errors) { + const link = document.createElement('a'); + link.href = + response.data.createWorksheetQueryResultDownloadUrl.downloadLink; + // Append to html link element page + document.body.appendChild(link); + // Start download + link.click(); + // Clean up and remove the link + link.parentNode.removeChild(link); + } else { + dispatch({ type: SET_ERROR, error: response.errors[0].message }); + } + } catch (e) { + dispatch({ type: SET_ERROR, error: e.message }); + } finally { + setRunningDownloadQuery(false); + } + }, [client, dispatch, currentEnv, athenaQueryId, fileType]); + if (loading) { return ; } @@ -34,6 +91,41 @@ export const WorksheetResult = ({ results, loading }) => { Query Results } + action={ + <> + + + } + label="CSV" + /> + } + label="XLSX" + /> + + + } + sx={{ m: 1 }} + variant="contained" + > + Download + + + + } /> diff --git a/frontend/src/modules/Worksheets/services/createWorksheetQueryResultDownloadUrl.js b/frontend/src/modules/Worksheets/services/createWorksheetQueryResultDownloadUrl.js new file mode 100644 index 000000000..305ab5599 --- /dev/null +++ b/frontend/src/modules/Worksheets/services/createWorksheetQueryResultDownloadUrl.js @@ -0,0 +1,20 @@ +import { gql } from 'apollo-boost'; + +export const createWorksheetQueryResultDownloadUrl = (input) => ({ + variables: { + input + }, + query: gql` + mutation CreateWorksheetQueryResultDownloadUrl( + $input: WorksheetQueryResultDownloadUrlInput! + ) { + createWorksheetQueryResultDownloadUrl(input: $input) { + downloadLink + AthenaQueryId + expiresIn + fileFormat + outputLocation + } + } + ` +}); diff --git a/frontend/src/modules/Worksheets/services/index.js b/frontend/src/modules/Worksheets/services/index.js index b10e7d361..65637d6ae 100644 --- a/frontend/src/modules/Worksheets/services/index.js +++ b/frontend/src/modules/Worksheets/services/index.js @@ -6,3 +6,4 @@ export * from './listWorksheets'; export * from './runAthenaSqlQuery'; export * from './updateWorksheet'; export * from './listSharedDatasetTableColumns'; +export * from './createWorksheetQueryResultDownloadUrl'; diff --git a/frontend/src/modules/Worksheets/services/runAthenaSqlQuery.js b/frontend/src/modules/Worksheets/services/runAthenaSqlQuery.js index 8d5666cf4..80bda11bf 100644 --- a/frontend/src/modules/Worksheets/services/runAthenaSqlQuery.js +++ b/frontend/src/modules/Worksheets/services/runAthenaSqlQuery.js @@ -21,6 +21,7 @@ export const runAthenaSqlQuery = ({ worksheetUri: $worksheetUri sqlQuery: $sqlQuery ) { + AthenaQueryId rows { cells { columnName diff --git a/frontend/src/modules/Worksheets/views/WorksheetView.js b/frontend/src/modules/Worksheets/views/WorksheetView.js index 545c3f190..419e51d77 100644 --- a/frontend/src/modules/Worksheets/views/WorksheetView.js +++ b/frontend/src/modules/Worksheets/views/WorksheetView.js @@ -76,6 +76,8 @@ const WorksheetView = () => { const [runningQuery, setRunningQuery] = useState(false); const [isEditWorksheetOpen, setIsEditWorksheetOpen] = useState(null); const [isDeleteWorksheetOpen, setIsDeleteWorksheetOpen] = useState(null); + const [athenaQueryId, setAthenaQueryId] = useState(); + const handleEditWorksheetModalOpen = () => { setIsEditWorksheetOpen(true); }; @@ -291,6 +293,7 @@ const WorksheetView = () => { ); if (!response.errors) { const athenaResults = response.data.runAthenaSqlQuery; + setAthenaQueryId(response.data.runAthenaSqlQuery.AthenaQueryId); setResults({ rows: athenaResults.rows.map((c, index) => ({ ...c, id: index })), columns: athenaResults.columns.map((c, index) => ({ @@ -636,7 +639,13 @@ const WorksheetView = () => { - + diff --git a/tests/modules/worksheets/test_worksheet.py b/tests/modules/worksheets/test_worksheet.py index 0f3b10e41..0a7bd43dc 100644 --- a/tests/modules/worksheets/test_worksheet.py +++ b/tests/modules/worksheets/test_worksheet.py @@ -1,4 +1,5 @@ import pytest +from future.backports.datetime import datetime from dataall.modules.worksheets.api.resolvers import WorksheetRole @@ -30,6 +31,19 @@ def worksheet(client, tenant, group): return response.data.createWorksheet +@pytest.fixture(scope='module', autouse=True) +def mock_s3_client(module_mocker): + s3_client = module_mocker.patch( + 'dataall.modules.worksheets.services.worksheet_query_result_service.S3Client', autospec=True + ) + + s3_client.return_value.object_exists.return_value = True + s3_client.return_value.put_object.return_value = None + s3_client.return_value.get_object.return_value = '123,123,123' + s3_client.return_value.get_presigned_url.return_value = 'https://s3.amazonaws.com/file/123.csv' + yield s3_client + + def test_create_worksheet(client, worksheet): assert worksheet.label == 'my worksheet' assert worksheet.owner == 'alice' @@ -145,3 +159,106 @@ def test_update_worksheet(client, worksheet, group): ) assert response.data.updateWorksheet.label == 'change label' + + +def test_create_query_download_url(client, worksheet, env_fixture, group): + response = client.query( + """ + mutation CreateWorksheetQueryResultDownloadUrl($input: WorksheetQueryResultDownloadUrlInput){ + createWorksheetQueryResultDownloadUrl(input: $input){ + sqlBody + AthenaQueryId + region + AwsAccountId + elapsedTimeInMs + created + downloadLink + outputLocation + expiresIn + fileFormat + } + } + """, + input={ + 'worksheetUri': worksheet.worksheetUri, + 'athenaQueryId': '123', + 'fileFormat': 'csv', + 'environmentUri': env_fixture.environmentUri, + }, + username='alice', + groups=[group.name], + ) + + expires_in = datetime.strptime( + response.data.createWorksheetQueryResultDownloadUrl.expiresIn, '%Y-%m-%d %H:%M:%S.%f' + ) + assert response.data.createWorksheetQueryResultDownloadUrl.downloadLink is not None + assert response.data.createWorksheetQueryResultDownloadUrl.fileFormat == 'csv' + assert expires_in > datetime.utcnow() + + +def test_tenant_unauthorized__create_query_download_url(client, worksheet, env_fixture, tenant): + response = client.query( + """ + mutation CreateWorksheetQueryResultDownloadUrl($input: WorksheetQueryResultDownloadUrlInput){ + createWorksheetQueryResultDownloadUrl(input: $input){ + sqlBody + AthenaQueryId + region + AwsAccountId + elapsedTimeInMs + created + downloadLink + outputLocation + expiresIn + fileFormat + } + } + """, + input={ + 'worksheetUri': worksheet.worksheetUri, + 'athenaQueryId': '123', + 'fileFormat': 'csv', + 'environmentUri': env_fixture.environmentUri, + }, + ) + + assert response.errors is not None + assert len(response.errors) > 0 + assert f'is not authorized to perform: MANAGE_WORKSHEETS on {tenant.name}' in response.errors[0].message + + +def test_resource_unauthorized__create_query_download_url(client, worksheet, env_fixture, group2): + response = client.query( + """ + mutation CreateWorksheetQueryResultDownloadUrl($input: WorksheetQueryResultDownloadUrlInput){ + createWorksheetQueryResultDownloadUrl(input: $input){ + sqlBody + AthenaQueryId + region + AwsAccountId + elapsedTimeInMs + created + downloadLink + outputLocation + expiresIn + fileFormat + } + } + """, + input={ + 'worksheetUri': worksheet.worksheetUri, + 'athenaQueryId': '123', + 'fileFormat': 'csv', + 'environmentUri': env_fixture.environmentUri, + }, + username='bob', + groups=[group2.name], + ) + + assert response.errors is not None + assert len(response.errors) > 0 + assert ( + f'is not authorized to perform: DOWNLOAD_ATHENA_QUERY_RESULTS on resource: {worksheet.worksheetUri}' + in response.errors[0].message + )