From 8af5e16ca23538f890222837e4e06838adfe9010 Mon Sep 17 00:00:00 2001 From: anison Date: Tue, 10 Sep 2024 15:28:43 +0530 Subject: [PATCH 1/8] Initial query download feature changes --- backend/dataall/base/aws/s3_client.py | 29 ++++ .../modules/worksheets/api/input_types.py | 11 ++ .../modules/worksheets/api/mutations.py | 16 +- .../modules/worksheets/api/resolvers.py | 8 + .../dataall/modules/worksheets/api/types.py | 10 +- .../modules/worksheets/db/worksheet_models.py | 16 +- .../worksheets/db/worksheet_repositories.py | 14 ++ .../worksheet_query_result_service.py | 139 ++++++++++++++++++ ...dd_columns_worksheet_query_result_model.py | 37 +++++ backend/requirements.txt | 3 +- tests/modules/worksheets/test_worksheet.py | 69 +++++++++ 11 files changed, 342 insertions(+), 10 deletions(-) create mode 100644 backend/dataall/modules/worksheets/services/worksheet_query_result_service.py create mode 100644 backend/migrations/versions/d1d6da1b2d67_add_columns_worksheet_query_result_model.py diff --git a/backend/dataall/base/aws/s3_client.py b/backend/dataall/base/aws/s3_client.py index 74f79e7b4..04d43e6b3 100644 --- a/backend/dataall/base/aws/s3_client.py +++ b/backend/dataall/base/aws/s3_client.py @@ -3,6 +3,7 @@ from botocore.exceptions import ClientError import logging +from dataall.base.db.exceptions import AWSResourceNotFound log = logging.getLogger(__name__) @@ -33,3 +34,31 @@ def get_presigned_url(region, bucket, key, expire_minutes: int = 15): except ClientError as e: log.error(f'Failed to get presigned URL due to: {e}') raise e + + @staticmethod + def object_exists(region, bucket, key) -> bool: + try: + S3_client.client(region, None).head_object(Bucket=bucket, Key=key) + return True + except ClientError as e: + log.error(f'Failed to check object existence due to: {e}') + if e.response['Error']['Code'] == '404': + return False + raise AWSResourceNotFound('s3_object_exists', f'Object {key} not found in bucket {bucket}') + + @staticmethod + def put_object(region, bucket, key, body): + try: + S3_client.client(region, None).put_object(Bucket=bucket, Key=key, Body=body) + except ClientError as e: + log.error(f'Failed to put object due to: {e}') + raise e + + @staticmethod + def get_object(region, bucket, key): + try: + response = S3_client.client(region, None).get_object(Bucket=bucket, Key=key) + return response['Body'].read().decode('utf-8') + except ClientError as e: + log.error(f'Failed to get object due to: {e}') + raise e diff --git a/backend/dataall/modules/worksheets/api/input_types.py b/backend/dataall/modules/worksheets/api/input_types.py index e2ec62dea..9fbeb75b1 100644 --- a/backend/dataall/modules/worksheets/api/input_types.py +++ b/backend/dataall/modules/worksheets/api/input_types.py @@ -75,3 +75,14 @@ gql.Argument(name='measures', type=gql.ArrayType(gql.Ref('WorksheetMeasureInput'))), ], ) + + +WorksheetQueryResultDownloadUrlInput = gql.InputType( + name='WorksheetQueryResultDownloadUrlInput', + arguments=[ + gql.Argument(name='athenaQueryId', type=gql.NonNullableType(gql.String)), + gql.Argument(name='fileFormat', type=gql.NonNullableType(gql.String)), + gql.Argument(name='environmentUri', type=gql.NonNullableType(gql.String)), + gql.Argument(name='worksheetUri', type=gql.NonNullableType(gql.String)), + ], +) diff --git a/backend/dataall/modules/worksheets/api/mutations.py b/backend/dataall/modules/worksheets/api/mutations.py index a232fba83..efe4afa6b 100644 --- a/backend/dataall/modules/worksheets/api/mutations.py +++ b/backend/dataall/modules/worksheets/api/mutations.py 
@@ -1,5 +1,10 @@ from dataall.base.api import gql -from dataall.modules.worksheets.api.resolvers import create_worksheet, delete_worksheet, update_worksheet +from dataall.modules.worksheets.api.resolvers import ( + create_worksheet, + delete_worksheet, + update_worksheet, + create_athena_query_result_download_url, +) createWorksheet = gql.MutationField( @@ -27,3 +32,12 @@ ], type=gql.Boolean, ) + +createWorksheetQueryResultDownloadUrl = gql.MutationField( + name='createWorksheetQueryResultDownloadUrl', + resolver=create_athena_query_result_download_url, + args=[ + gql.Argument(name='input', type=gql.Ref('WorksheetQueryResultDownloadUrlInput')), + ], + type=gql.Ref('WorksheetQueryResult'), +) diff --git a/backend/dataall/modules/worksheets/api/resolvers.py b/backend/dataall/modules/worksheets/api/resolvers.py index 450667217..adc90f170 100644 --- a/backend/dataall/modules/worksheets/api/resolvers.py +++ b/backend/dataall/modules/worksheets/api/resolvers.py @@ -4,6 +4,7 @@ from dataall.modules.worksheets.db.worksheet_repositories import WorksheetRepository from dataall.modules.worksheets.services.worksheet_service import WorksheetService from dataall.base.api.context import Context +from dataall.modules.worksheets.services.worksheet_query_result_service import WorksheetQueryResultService def create_worksheet(context: Context, source, input: dict = None): @@ -69,3 +70,10 @@ def run_sql_query(context: Context, source, environmentUri: str = None, workshee def delete_worksheet(context, source, worksheetUri: str = None): with context.engine.scoped_session() as session: return WorksheetService.delete_worksheet(session=session, uri=worksheetUri) + + +def create_athena_query_result_download_url(context: Context, source, input: dict = None): + WorksheetQueryResultService.validate_input(input) + + with context.engine.scoped_session() as session: + return WorksheetQueryResultService.download_sql_query_result(session=session, data=input) diff --git a/backend/dataall/modules/worksheets/api/types.py b/backend/dataall/modules/worksheets/api/types.py index e9187353f..72e2e6344 100644 --- a/backend/dataall/modules/worksheets/api/types.py +++ b/backend/dataall/modules/worksheets/api/types.py @@ -87,14 +87,16 @@ fields=[ gql.Field(name='worksheetQueryResultUri', type=gql.ID), gql.Field(name='queryType', type=gql.NonNullableType(gql.String)), - gql.Field(name='sqlBody', type=gql.NonNullableType(gql.String)), + gql.Field(name='sqlBody', type=gql.String), gql.Field(name='AthenaQueryId', type=gql.NonNullableType(gql.String)), gql.Field(name='region', type=gql.NonNullableType(gql.String)), gql.Field(name='AwsAccountId', type=gql.NonNullableType(gql.String)), - gql.Field(name='AthenaOutputBucketName', type=gql.NonNullableType(gql.String)), - gql.Field(name='AthenaOutputKey', type=gql.NonNullableType(gql.String)), - gql.Field(name='timeElapsedInSecond', type=gql.NonNullableType(gql.Integer)), + gql.Field(name='ElapsedTimeInMs', type=gql.Integer), gql.Field(name='created', type=gql.NonNullableType(gql.String)), + gql.Field(name='downloadLink', type=gql.String), + gql.Field(name='OutputLocation', type=gql.String), + gql.Field(name='expiresIn', type=gql.AWSDateTime), + gql.Field(name='fileFormat', type=gql.String), ], ) diff --git a/backend/dataall/modules/worksheets/db/worksheet_models.py b/backend/dataall/modules/worksheets/db/worksheet_models.py index 6549cb96c..281c4b369 100644 --- a/backend/dataall/modules/worksheets/db/worksheet_models.py +++ b/backend/dataall/modules/worksheets/db/worksheet_models.py @@ -1,7 +1,7 
@@ import datetime import enum -from sqlalchemy import Column, DateTime, Integer, Enum, String +from sqlalchemy import Column, DateTime, Integer, Enum, String, BigInteger from sqlalchemy.dialects import postgresql from sqlalchemy.orm import query_expression @@ -29,13 +29,21 @@ class WorksheetQueryResult(Base): __tablename__ = 'worksheet_query_result' worksheetUri = Column(String, nullable=False) AthenaQueryId = Column(String, primary_key=True) - status = Column(String, nullable=False) + status = Column(String, nullable=True) queryType = Column(Enum(QueryType), nullable=False, default=True) - sqlBody = Column(String, nullable=False) + sqlBody = Column(String, nullable=True) AwsAccountId = Column(String, nullable=False) region = Column(String, nullable=False) OutputLocation = Column(String, nullable=False) error = Column(String, nullable=True) ElapsedTimeInMs = Column(Integer, nullable=True) - DataScannedInBytes = Column(Integer, nullable=True) + DataScannedInBytes = Column(BigInteger, nullable=True) created = Column(DateTime, default=datetime.datetime.now) + + downloadLink = Column(String, nullable=True) + expiresIn = Column(DateTime, nullable=True) + updated = Column(DateTime, nullable=False, onupdate=datetime.datetime.utcnow, default=datetime.datetime.utcnow) + fileFormat = Column(String, nullable=True) + + def is_download_link_expired(self): + return self.expiresIn is None or self.expiresIn <= datetime.datetime.utcnow() diff --git a/backend/dataall/modules/worksheets/db/worksheet_repositories.py b/backend/dataall/modules/worksheets/db/worksheet_repositories.py index aea51761b..cac450943 100644 --- a/backend/dataall/modules/worksheets/db/worksheet_repositories.py +++ b/backend/dataall/modules/worksheets/db/worksheet_repositories.py @@ -53,3 +53,17 @@ def paginated_user_worksheets(session, username, groups, uri, data=None, check_p page=data.get('page', WorksheetRepository._DEFAULT_PAGE), page_size=data.get('pageSize', WorksheetRepository._DEFAULT_PAGE_SIZE), ).to_dict() + + @staticmethod + def find_query_result_by_format( + session, worksheet_uri: str, athena_query_id: str, file_format: str + ) -> WorksheetQueryResult: + return ( + session.query(WorksheetQueryResult) + .filter( + WorksheetQueryResult.worksheetUri == worksheet_uri, + WorksheetQueryResult.AthenaQueryId == athena_query_id, + WorksheetQueryResult.fileFormat == file_format, + ) + .first() + ) diff --git a/backend/dataall/modules/worksheets/services/worksheet_query_result_service.py b/backend/dataall/modules/worksheets/services/worksheet_query_result_service.py new file mode 100644 index 000000000..13cefb95b --- /dev/null +++ b/backend/dataall/modules/worksheets/services/worksheet_query_result_service.py @@ -0,0 +1,139 @@ +import csv +import io +import os +from datetime import datetime, timedelta +from typing import TYPE_CHECKING + +from openpyxl import Workbook + +from dataall.base.aws.s3_client import S3_client +from dataall.base.db import exceptions +from dataall.core.environment.services.environment_service import EnvironmentService +from dataall.modules.worksheets.db.worksheet_models import WorksheetQueryResult +from dataall.modules.worksheets.db.worksheet_repositories import WorksheetRepository +from dataall.modules.worksheets.services.worksheet_service import WorksheetService + +if TYPE_CHECKING: + try: + from sqlalchemy.orm import Session + from mypy_boto3_s3.client import S3Client + except ImportError: + print('skipping type checks as stubs are not installed') + S3Client = None + Session = None + + +class 
WorksheetQueryResultService: + SupportedFormats = {'csv', 'xlsx'} + + @staticmethod + def validate_input(data): + if not data: + raise exceptions.InvalidInput('data', data, 'input is required') + if not data.get('athenaQueryId'): + raise exceptions.RequiredParameter('athenaQueryId') + if not data.get('fileFormat'): + raise exceptions.RequiredParameter('fileFormat') + if data.get('fileFormat', '').lower() not in WorksheetQueryResultService.SupportedFormats: + raise exceptions.InvalidInput( + 'fileFormat', data.get('fileFormat'), ', '.join(WorksheetQueryResultService.SupportedFormats) + ) + + @staticmethod + def get_output_bucket(session: 'Session', environment_uri: str) -> str: + environment = EnvironmentService.get_environment_by_uri(session, environment_uri) + bucket = environment.EnvironmentDefaultBucketName + return bucket + + @staticmethod + def create_query_result( + environment_bucket: str, athena_workgroup: str, worksheet_uri: str, region: str, aws_account_id: str, data: dict + ) -> WorksheetQueryResult: + sql_query_result = WorksheetQueryResult( + worksheetUri=worksheet_uri, + AthenaQueryId=data.get('athenaQueryId'), + fileFormat=data.get('fileFormat'), + OutputLocation=f's3://{environment_bucket}/athenaqueries/{athena_workgroup}/', + region=region, + AwsAccountId=aws_account_id, + queryType='data', + ) + return sql_query_result + + @staticmethod + def get_file_key( + workgroup: str, query_id: str, file_format: str = 'csv', athena_queries_dir: str = 'athenaqueries' + ) -> str: + return f'{athena_queries_dir}/{workgroup}/{query_id}.{file_format}' + + @staticmethod + def convert_csv_to_xlsx(csv_data) -> io.BytesIO: + wb = Workbook() + ws = wb.active + csv_reader = csv.reader(csv_data.splitlines()) + for row in csv_reader: + ws.append(row) + + excel_buffer = io.BytesIO() + wb.save(excel_buffer) + excel_buffer.seek(0) + return excel_buffer + + @staticmethod + def handle_xlsx_format(output_bucket: str, file_key: str) -> bool: + aws_region_name = os.getenv('AWS_REGION_NAME', 'eu-west-1') + file_name, _ = file_key.split('.') + csv_data = S3_client.get_object(region=aws_region_name, bucket=output_bucket, key=f'{file_name}.csv') + excel_buffer = WorksheetQueryResultService.convert_csv_to_xlsx(csv_data) + S3_client.put_object(region=aws_region_name, bucket=output_bucket, key=file_key, body=excel_buffer) + return True + + @staticmethod + def download_sql_query_result(session: 'Session', data: dict = None): + # # default timeout for the download link is 2 hours(in minutes) + default_timeout = os.getenv('QUERY_RESULT_TIMEOUT_MINUTES', 120) + + environment = EnvironmentService.get_environment_by_uri(session, data.get('environmentUri')) + worksheet = WorksheetService.get_worksheet_by_uri(session, data.get('worksheetUri')) + env_group = EnvironmentService.get_environment_group( + session, worksheet.SamlAdminGroupName, environment.environmentUri + ) + output_file_key = WorksheetQueryResultService.get_file_key( + env_group.environmentAthenaWorkGroup, data.get('athenaQueryId'), data.get('fileFormat') + ) + sql_query_result = WorksheetRepository.find_query_result_by_format( + session, data.get('worksheetUri'), data.get('athenaQueryId'), data.get('fileFormat') + ) + if data.get('fileFormat') == 'xlsx': + is_job_failed = WorksheetQueryResultService.handle_xlsx_format( + environment.EnvironmentDefaultBucketName, output_file_key + ) + if is_job_failed: + raise ValueError('Error while preparing the xlsx file') + + if not sql_query_result: + sql_query_result = 
WorksheetQueryResultService.create_query_result( + environment.EnvironmentDefaultBucketName, + env_group.environmentAthenaWorkGroup, + worksheet.worksheetUri, + environment.region, + environment.AwsAccountId, + data, + ) + S3_client.object_exists( + region=environment.region, bucket=environment.EnvironmentDefaultBucketName, key=output_file_key + ) + if sql_query_result.is_download_link_expired(): + url = S3_client.get_presigned_url( + region=environment.region, + bucket=environment.EnvironmentDefaultBucketName, + key=output_file_key, + expire_minutes=default_timeout, + ) + sql_query_result.downloadLink = url + sql_query_result.expiresIn = datetime.utcnow() + timedelta(seconds=default_timeout) + + session.add(sql_query_result) + session.commit() + + return sql_query_result diff --git a/backend/migrations/versions/d1d6da1b2d67_add_columns_worksheet_query_result_model.py b/backend/migrations/versions/d1d6da1b2d67_add_columns_worksheet_query_result_model.py new file mode 100644 index 000000000..4b0b48e7d --- /dev/null +++ b/backend/migrations/versions/d1d6da1b2d67_add_columns_worksheet_query_result_model.py @@ -0,0 +1,37 @@ +"""add_columns_worksheet_query_result_model + +Revision ID: d1d6da1b2d67 +Revises: d274e756f0ae +Create Date: 2024-09-10 14:34:31.492186 + +""" + +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision = 'd1d6da1b2d67' +down_revision = 'd274e756f0ae' +branch_labels = None +depends_on = None + + +def upgrade(): + op.alter_column('worksheet_query_result', 'status', nullable=True) + op.alter_column('worksheet_query_result', 'sqlBody', nullable=True) + op.alter_column('worksheet_query_result', 'DataScannedInBytes', type_=sa.BigInteger(), nullable=True) + op.add_column('worksheet_query_result', sa.Column('downloadLink', sa.String(), nullable=True)) + op.add_column('worksheet_query_result', sa.Column('expiresIn', sa.DateTime(), nullable=True)) + op.add_column('worksheet_query_result', sa.Column('updated', sa.DateTime(), nullable=False)) + op.add_column('worksheet_query_result', sa.Column('fileFormat', sa.String(), nullable=True)) + + +def downgrade(): + op.drop_column('worksheet_query_result', 'fileFormat') + op.drop_column('worksheet_query_result', 'updated') + op.drop_column('worksheet_query_result', 'expiresIn') + op.drop_column('worksheet_query_result', 'downloadLink') + op.alter_column('worksheet_query_result', 'DataScannedInBytes', type_=sa.Integer(), nullable=True) + op.alter_column('worksheet_query_result', 'sqlBody', nullable=False) + op.alter_column('worksheet_query_result', 'status', nullable=False) diff --git a/backend/requirements.txt b/backend/requirements.txt index 94f7927ca..d7b6e35f9 100644 --- a/backend/requirements.txt +++ b/backend/requirements.txt @@ -15,4 +15,5 @@ requests==2.32.2 requests_aws4auth==1.1.1 sqlalchemy==1.3.24 starlette==0.36.3 -alembic==1.13.1 \ No newline at end of file +alembic==1.13.1 +openpyxl==3.1.5 diff --git a/tests/modules/worksheets/test_worksheet.py b/tests/modules/worksheets/test_worksheet.py index 0f3b10e41..2edffe0a7 100644 --- a/tests/modules/worksheets/test_worksheet.py +++ b/tests/modules/worksheets/test_worksheet.py @@ -1,5 +1,9 @@ import pytest +from unittest.mock import MagicMock + +from future.backports.datetime import datetime + from dataall.modules.worksheets.api.resolvers import WorksheetRole @@ -30,6 +34,38 @@ def worksheet(client, tenant, group): return response.data.createWorksheet +@pytest.fixture(scope='module', autouse=True) +def mock_s3_client(module_mocker): + 
s3_client = MagicMock() + module_mocker.patch('dataall.modules.worksheets.services.worksheet_query_result_service.S3_client', s3_client) + + # s3_client.client.return_value = s3_client + + s3_client().object_exists.return_value = True + s3_client().put_object.return_value = None + s3_client().get_object.return_value = '123,123,123' + s3_client.get_presigned_url.return_value = 'https://s3.amazonaws.com/file/123.csv' + yield s3_client + + +# @pytest.fixture(scope='module') +# def dataset1( +# module_mocker, +# org_fixture: Organization, +# env_fixture: Environment, +# dataset: typing.Callable, +# group, +# ) -> S3Dataset: +# kms_client = MagicMock() +# module_mocker.patch('dataall.modules.s3_datasets.services.dataset_service.KmsClient', kms_client) +# +# kms_client().get_key_id.return_value = mocked_key_id +# +# d = dataset(org=org_fixture, env=env_fixture, name='dataset1', owner=env_fixture.owner, group=group.name) +# print(d) +# yield d + + def test_create_worksheet(client, worksheet): assert worksheet.label == 'my worksheet' assert worksheet.owner == 'alice' @@ -145,3 +181,36 @@ def test_update_worksheet(client, worksheet, group): ) assert response.data.updateWorksheet.label == 'change label' + + +def test_create_query_download_url(client, worksheet, env_fixture): + response = client.query( + """ + mutation CreateWorksheetQueryResultDownloadUrl($input:WorksheetQueryResultDownloadUrlInput){ + createWorksheetQueryResultDownloadUrl(input:$input){ + queryType + sqlBody + AthenaQueryId + region + AwsAccountId + ElapsedTimeInMs + created + downloadLink + OutputLocation + expiresIn + fileFormat + } + } + """, + input={ + 'worksheetUri': worksheet.worksheetUri, + 'athenaQueryId': '123', + 'fileFormat': 'csv', + 'environmentUri': env_fixture.environmentUri, + }, + ) + + expires_in = datetime.strptime(response.data.createWorksheetQueryResultDownloadUrl.created, '%Y-%m-%d %H:%M:%S.%f') + assert response.data.createWorksheetQueryResultDownloadUrl.downloadLink is not None + assert response.data.createWorksheetQueryResultDownloadUrl.fileFormat == 'csv' + assert expires_in > datetime.utcnow() From 01d681db39722f3ad04493271cf2b02533789e92 Mon Sep 17 00:00:00 2001 From: anisubhra-syncron Date: Wed, 11 Sep 2024 12:10:49 +0530 Subject: [PATCH 2/8] download feature implementation from frontend --- .../Worksheets/components/WorksheetResult.js | 86 ++++++++++++++++++- .../createWorksheetQueryResultDownloadUrl.js | 40 +++++++++ .../src/modules/Worksheets/services/index.js | 1 + .../Worksheets/services/runAthenaSqlQuery.js | 1 + .../modules/Worksheets/views/WorksheetView.js | 11 ++- 5 files changed, 136 insertions(+), 3 deletions(-) create mode 100644 frontend/src/modules/Worksheets/services/createWorksheetQueryResultDownloadUrl.js diff --git a/frontend/src/modules/Worksheets/components/WorksheetResult.js b/frontend/src/modules/Worksheets/components/WorksheetResult.js index e8e74e557..faee82e52 100644 --- a/frontend/src/modules/Worksheets/components/WorksheetResult.js +++ b/frontend/src/modules/Worksheets/components/WorksheetResult.js @@ -12,12 +12,66 @@ import { TableRow } from '@mui/material'; import PropTypes from 'prop-types'; -import React from 'react'; +import React, { useState, useCallback } from 'react'; import { FaBars } from 'react-icons/fa'; import * as ReactIf from 'react-if'; import { Scrollbar } from 'design'; -export const WorksheetResult = ({ results, loading }) => { +import Stack from '@mui/material/Stack'; +import Radio from '@mui/material/Radio'; +import RadioGroup from 
'@mui/material/RadioGroup'; +import FormControlLabel from '@mui/material/FormControlLabel'; +import { Download } from '@mui/icons-material'; +import { LoadingButton } from '@mui/lab'; +import { + useClient +} from 'services'; +import { + createWorksheetQueryResultDownloadUrl +} from '../services'; +import { SET_ERROR, useDispatch } from 'globalErrors'; + +export const WorksheetResult = ({ results, loading, currentEnv, athenaQueryId, worksheetUri }) => { + const [runningDownloadQuery, setRunningDownloadQuery] = useState(false); + const [fileType, setFileType] = useState('csv'); + const client = useClient(); + const dispatch = useDispatch(); + + const handleChange = (event) => { + setFileType(event.target.value); + }; + + const downloadFile = useCallback(async () => { + try { + setRunningDownloadQuery(true); + const response = await client.query( + createWorksheetQueryResultDownloadUrl({ + fileFormat: fileType, + environmentUri: currentEnv.environmentUri, + athenaQueryId: athenaQueryId, + worksheetUri: worksheetUri + }) + ); + + if (!response.errors) { + const link = document.createElement('a'); + link.href = response.data.createWorksheetQueryResultDownloadUrl.downloadLink; + // Append to html link element page + document.body.appendChild(link); + // Start download + link.click(); + // Clean up and remove the link + link.parentNode.removeChild(link); + } else { + dispatch({ type: SET_ERROR, error: response.errors[0].message }); + } + } catch (e) { + dispatch({ type: SET_ERROR, error: e.message }); + } finally { + setRunningDownloadQuery(false); + } + }, [client, dispatch, currentEnv, athenaQueryId, fileType]) + if (loading) { return ; } @@ -34,6 +88,34 @@ export const WorksheetResult = ({ results, loading }) => { Query Results } + action={ + <> + + + } label="CSV" /> + } label="XLSX" /> + + + } + sx={{ m: 1 }} + variant="contained" + > + Download + + + + + } /> diff --git a/frontend/src/modules/Worksheets/services/createWorksheetQueryResultDownloadUrl.js b/frontend/src/modules/Worksheets/services/createWorksheetQueryResultDownloadUrl.js new file mode 100644 index 000000000..c5fec1942 --- /dev/null +++ b/frontend/src/modules/Worksheets/services/createWorksheetQueryResultDownloadUrl.js @@ -0,0 +1,40 @@ +import { gql } from 'apollo-boost'; + +export const createWorksheetQueryResultDownloadUrl = ({ + fileFormat, + environmentUri, + athenaQueryId, + worksheetUri +}) => ({ + variables: { + fileFormat, + environmentUri, + athenaQueryId, + worksheetUri + }, + query: gql` + mutation CreateWorksheetQueryResultDownloadUrl( + $fileFormat: String! + $environmentUri: String! + $athenaQueryId: String! + $worksheetUri: String! 
+ ) { + createWorksheetQueryResultDownloadUrl( + input: { + fileFormat: $fileFormat + environmentUri: $environmentUri + athenaQueryId: $athenaQueryId + worksheetUri: $worksheetUri + } + ) { + downloadLink + AthenaQueryId + expiresIn + fileFormat + OutputLocation + } + } + ` +}); + + diff --git a/frontend/src/modules/Worksheets/services/index.js b/frontend/src/modules/Worksheets/services/index.js index b10e7d361..ee5a24b33 100644 --- a/frontend/src/modules/Worksheets/services/index.js +++ b/frontend/src/modules/Worksheets/services/index.js @@ -6,3 +6,4 @@ export * from './listWorksheets'; export * from './runAthenaSqlQuery'; export * from './updateWorksheet'; export * from './listSharedDatasetTableColumns'; +export * from './createWorksheetQueryResultDownloadUrl'; \ No newline at end of file diff --git a/frontend/src/modules/Worksheets/services/runAthenaSqlQuery.js b/frontend/src/modules/Worksheets/services/runAthenaSqlQuery.js index 8d5666cf4..80bda11bf 100644 --- a/frontend/src/modules/Worksheets/services/runAthenaSqlQuery.js +++ b/frontend/src/modules/Worksheets/services/runAthenaSqlQuery.js @@ -21,6 +21,7 @@ export const runAthenaSqlQuery = ({ worksheetUri: $worksheetUri sqlQuery: $sqlQuery ) { + AthenaQueryId rows { cells { columnName diff --git a/frontend/src/modules/Worksheets/views/WorksheetView.js b/frontend/src/modules/Worksheets/views/WorksheetView.js index 545c3f190..419e51d77 100644 --- a/frontend/src/modules/Worksheets/views/WorksheetView.js +++ b/frontend/src/modules/Worksheets/views/WorksheetView.js @@ -76,6 +76,8 @@ const WorksheetView = () => { const [runningQuery, setRunningQuery] = useState(false); const [isEditWorksheetOpen, setIsEditWorksheetOpen] = useState(null); const [isDeleteWorksheetOpen, setIsDeleteWorksheetOpen] = useState(null); + const [athenaQueryId, setAthenaQueryId] = useState(); + const handleEditWorksheetModalOpen = () => { setIsEditWorksheetOpen(true); }; @@ -291,6 +293,7 @@ const WorksheetView = () => { ); if (!response.errors) { const athenaResults = response.data.runAthenaSqlQuery; + setAthenaQueryId(response.data.runAthenaSqlQuery.AthenaQueryId); setResults({ rows: athenaResults.rows.map((c, index) => ({ ...c, id: index })), columns: athenaResults.columns.map((c, index) => ({ @@ -636,7 +639,13 @@ const WorksheetView = () => { - + From 6c8cd5477b1fcaf6bb44d99b3cf56cdfe66624bc Mon Sep 17 00:00:00 2001 From: anison Date: Thu, 10 Oct 2024 14:00:45 +0530 Subject: [PATCH 3/8] fixing download query result issues --- .../modules/worksheets/api/resolvers.py | 15 ++- .../dataall/modules/worksheets/api/types.py | 5 +- .../modules/worksheets/aws/s3_client.py | 73 ++++++++++++++ .../modules/worksheets/db/worksheet_models.py | 5 +- .../enums.py => services/worksheet_enums.py} | 5 + .../worksheet_query_result_service.py | 97 +++++++------------ ...dd_columns_worksheet_query_result_model.py | 9 +- .../createWorksheetQueryResultDownloadUrl.js | 28 +----- tests/modules/worksheets/test_worksheet.py | 23 +---- 9 files changed, 146 insertions(+), 114 deletions(-) create mode 100644 backend/dataall/modules/worksheets/aws/s3_client.py rename backend/dataall/modules/worksheets/{api/enums.py => services/worksheet_enums.py} (65%) diff --git a/backend/dataall/modules/worksheets/api/resolvers.py b/backend/dataall/modules/worksheets/api/resolvers.py index adc90f170..331973be1 100644 --- a/backend/dataall/modules/worksheets/api/resolvers.py +++ b/backend/dataall/modules/worksheets/api/resolvers.py @@ -1,5 +1,5 @@ from dataall.base.db import exceptions -from 
dataall.modules.worksheets.api.enums import WorksheetRole +from dataall.modules.worksheets.services.worksheet_enums import WorksheetRole, WorksheetResultsFormat from dataall.modules.worksheets.db.worksheet_models import Worksheet from dataall.modules.worksheets.db.worksheet_repositories import WorksheetRepository from dataall.modules.worksheets.services.worksheet_service import WorksheetService @@ -73,7 +73,18 @@ def delete_worksheet(context, source, worksheetUri: str = None): def create_athena_query_result_download_url(context: Context, source, input: dict = None): - WorksheetQueryResultService.validate_input(input) + + if not input: + # raise exceptions.InvalidInput('data', input, 'input is required') + raise exceptions.RequiredParameter('data') + if not input.get('athenaQueryId'): + raise exceptions.RequiredParameter('athenaQueryId') + if not input.get('fileFormat'): + raise exceptions.RequiredParameter('fileFormat') + if not hasattr(WorksheetResultsFormat, input.get('fileFormat').upper()): + raise exceptions.InvalidInput( + 'fileFormat', input.get('fileFormat'), + ', '.join(result_format.value for result_format in WorksheetResultsFormat)) with context.engine.scoped_session() as session: return WorksheetQueryResultService.download_sql_query_result(session=session, data=input) diff --git a/backend/dataall/modules/worksheets/api/types.py b/backend/dataall/modules/worksheets/api/types.py index 72e2e6344..f55741f8b 100644 --- a/backend/dataall/modules/worksheets/api/types.py +++ b/backend/dataall/modules/worksheets/api/types.py @@ -86,15 +86,14 @@ name='WorksheetQueryResult', fields=[ gql.Field(name='worksheetQueryResultUri', type=gql.ID), - gql.Field(name='queryType', type=gql.NonNullableType(gql.String)), gql.Field(name='sqlBody', type=gql.String), gql.Field(name='AthenaQueryId', type=gql.NonNullableType(gql.String)), gql.Field(name='region', type=gql.NonNullableType(gql.String)), gql.Field(name='AwsAccountId', type=gql.NonNullableType(gql.String)), - gql.Field(name='ElapsedTimeInMs', type=gql.Integer), + gql.Field(name='elapsedTimeInMs', type=gql.Integer), gql.Field(name='created', type=gql.NonNullableType(gql.String)), gql.Field(name='downloadLink', type=gql.String), - gql.Field(name='OutputLocation', type=gql.String), + gql.Field(name='outputLocation', type=gql.String), gql.Field(name='expiresIn', type=gql.AWSDateTime), gql.Field(name='fileFormat', type=gql.String), ], diff --git a/backend/dataall/modules/worksheets/aws/s3_client.py b/backend/dataall/modules/worksheets/aws/s3_client.py new file mode 100644 index 000000000..8cab77e85 --- /dev/null +++ b/backend/dataall/modules/worksheets/aws/s3_client.py @@ -0,0 +1,73 @@ +import boto3 +from botocore.config import Config + +from botocore.exceptions import ClientError +import logging +from dataall.base.db.exceptions import AWSResourceNotFound +from dataall.base.aws.sts import SessionHelper + +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from dataall.core.environment.db.environment_models import Environment + try: + from mypy_boto3_s3 import S3Client as S3ClientType + except ImportError: + S3ClientType = None + +log = logging.getLogger(__name__) + + +class S3Client: + + def __init__(self, env: 'Environment'): + self._client = SessionHelper.remote_session(env.AwsAccountId, env.region).client('s3', region_name=env.region) + self._env = env + + @property + def client(self) -> 'S3ClientType': + return self._client + + def get_presigned_url(self, bucket, key, expire_minutes: int = 15): + expire_seconds = expire_minutes * 60 + try: + 
presigned_url = self.client.generate_presigned_url( + 'get_object', + Params=dict( + Bucket=bucket, + Key=key, + ), + ExpiresIn=expire_seconds, + ) + return presigned_url + except ClientError as e: + log.error(f'Failed to get presigned URL due to: {e}') + raise e + + def object_exists(self, bucket, key) -> bool: + try: + self.client.head_object(Bucket=bucket, Key=key) + return True + except ClientError as e: + if e.response['Error']['Code'] == '404': + log.info(f'Object {key} not found in bucket {bucket}') + return False + log.error(f'Failed to check object existence due to: {e}') + raise AWSResourceNotFound('s3_object_exists', f'Object {key} not found in bucket {bucket}') + + + def put_object(self, bucket, key, body): + try: + self.client.put_object(Bucket=bucket, Key=key, Body=body) + except ClientError as e: + log.error(f'Failed to put object due to: {e}') + raise e + + + def get_object(self, bucket, key) -> str: + try: + response = self.client.get_object(Bucket=bucket, Key=key) + return response['Body'].read().decode('utf-8') + except ClientError as e: + log.error(f'Failed to get object due to: {e}') + raise e diff --git a/backend/dataall/modules/worksheets/db/worksheet_models.py b/backend/dataall/modules/worksheets/db/worksheet_models.py index 281c4b369..5342a65b2 100644 --- a/backend/dataall/modules/worksheets/db/worksheet_models.py +++ b/backend/dataall/modules/worksheets/db/worksheet_models.py @@ -1,6 +1,7 @@ import datetime import enum +from future.backports.email.policy import default from sqlalchemy import Column, DateTime, Integer, Enum, String, BigInteger from sqlalchemy.dialects import postgresql from sqlalchemy.orm import query_expression @@ -27,10 +28,10 @@ class Worksheet(Resource, Base): class WorksheetQueryResult(Base): __tablename__ = 'worksheet_query_result' + worksheetQueryResultUri = Column(String, primary_key=True, default=utils.uuid('worksheetQueryResultUri')) worksheetUri = Column(String, nullable=False) - AthenaQueryId = Column(String, primary_key=True) + AthenaQueryId = Column(String, nullable=False) status = Column(String, nullable=True) - queryType = Column(Enum(QueryType), nullable=False, default=True) sqlBody = Column(String, nullable=True) AwsAccountId = Column(String, nullable=False) region = Column(String, nullable=False) diff --git a/backend/dataall/modules/worksheets/api/enums.py b/backend/dataall/modules/worksheets/services/worksheet_enums.py similarity index 65% rename from backend/dataall/modules/worksheets/api/enums.py rename to backend/dataall/modules/worksheets/services/worksheet_enums.py index 3e9549f2a..c973c2399 100644 --- a/backend/dataall/modules/worksheets/api/enums.py +++ b/backend/dataall/modules/worksheets/services/worksheet_enums.py @@ -5,3 +5,8 @@ class WorksheetRole(GraphQLEnumMapper): Creator = '950' Admin = '900' NoPermission = '000' + + +class WorksheetResultsFormat(GraphQLEnumMapper): + CSV = 'csv' + XLSX = 'xlsx' diff --git a/backend/dataall/modules/worksheets/services/worksheet_query_result_service.py b/backend/dataall/modules/worksheets/services/worksheet_query_result_service.py index 13cefb95b..58840b148 100644 --- a/backend/dataall/modules/worksheets/services/worksheet_query_result_service.py +++ b/backend/dataall/modules/worksheets/services/worksheet_query_result_service.py @@ -1,52 +1,37 @@ import csv import io import os -from datetime import datetime, timedelta +from datetime import datetime, timedelta, UTC as DATETIME_UTC from typing import TYPE_CHECKING from openpyxl import Workbook -from dataall.base.aws.s3_client import 
S3_client from dataall.base.db import exceptions from dataall.core.environment.services.environment_service import EnvironmentService +from dataall.core.permissions.services.resource_policy_service import ResourcePolicyService +from dataall.modules.worksheets.aws.s3_client import S3Client from dataall.modules.worksheets.db.worksheet_models import WorksheetQueryResult from dataall.modules.worksheets.db.worksheet_repositories import WorksheetRepository +from dataall.modules.worksheets.services.worksheet_enums import WorksheetResultsFormat +from dataall.modules.worksheets.services.worksheet_permissions import RUN_ATHENA_QUERY from dataall.modules.worksheets.services.worksheet_service import WorksheetService if TYPE_CHECKING: try: from sqlalchemy.orm import Session - from mypy_boto3_s3.client import S3Client + from openpyxl.worksheet.worksheet import Worksheet except ImportError: print('skipping type checks as stubs are not installed') - S3Client = None Session = None + Worksheet = None class WorksheetQueryResultService: - SupportedFormats = {'csv', 'xlsx'} + _DEFAULT_ATHENA_QUERIES_PATH = 'athenaqueries' + _DEFAULT_QUERY_RESULTS_TIMEOUT = os.getenv('QUERY_RESULT_TIMEOUT_MINUTES', 120) @staticmethod - def validate_input(data): - if not data: - raise exceptions.InvalidInput('data', data, 'input is required') - if not data.get('athenaQueryId'): - raise exceptions.RequiredParameter('athenaQueryId') - if not data.get('fileFormat'): - raise exceptions.RequiredParameter('fileFormat') - if data.get('fileFormat', '').lower() not in WorksheetQueryResultService.SupportedFormats: - raise exceptions.InvalidInput( - 'fileFormat', data.get('fileFormat'), ', '.join(WorksheetQueryResultService.SupportedFormats) - ) - - @staticmethod - def get_output_bucket(session: 'Session', environment_uri: str) -> str: - environment = EnvironmentService.get_environment_by_uri(session, environment_uri) - bucket = environment.EnvironmentDefaultBucketName - return bucket - - @staticmethod - def create_query_result( + def _create_query_result( environment_bucket: str, athena_workgroup: str, worksheet_uri: str, region: str, aws_account_id: str, data: dict ) -> WorksheetQueryResult: sql_query_result = WorksheetQueryResult( @@ -55,21 +40,21 @@ def create_query_result( fileFormat=data.get('fileFormat'), OutputLocation=f's3://{environment_bucket}/athenaqueries/{athena_workgroup}/', region=region, - AwsAccountId=aws_account_id, - queryType='data', + AwsAccountId=aws_account_id ) return sql_query_result @staticmethod - def get_file_key( - workgroup: str, query_id: str, file_format: str = 'csv', athena_queries_dir: str = 'athenaqueries' + def build_s3_file_path( + workgroup: str, query_id: str, athena_queries_dir: str = None ) -> str: - return f'{athena_queries_dir}/{workgroup}/{query_id}.{file_format}' + athena_queries_dir = athena_queries_dir or WorksheetQueryResultService._DEFAULT_ATHENA_QUERIES_PATH + return f'{athena_queries_dir}/{workgroup}/{query_id}' @staticmethod def convert_csv_to_xlsx(csv_data) -> io.BytesIO: wb = Workbook() - ws = wb.active + ws: 'Worksheet' = wb.active csv_reader = csv.reader(csv_data.splitlines()) for row in csv_reader: ws.append(row) @@ -80,39 +65,19 @@ def convert_csv_to_xlsx(csv_data) -> io.BytesIO: return excel_buffer @staticmethod - def handle_xlsx_format(output_bucket: str, file_key: str) -> bool: - aws_region_name = os.getenv('AWS_REGION_NAME', 'eu-west-1') - file_name, _ = file_key.split('.') - csv_data = S3_client.get_object(region=aws_region_name, bucket=output_bucket, key=f'{file_name}.csv') - 
excel_buffer = WorksheetQueryResultService.convert_csv_to_xlsx(csv_data) - S3_client.put_object(region=aws_region_name, bucket=output_bucket, key=file_key, body=excel_buffer) - return True - - @staticmethod + @ResourcePolicyService.has_resource_permission(RUN_ATHENA_QUERY) def download_sql_query_result(session: 'Session', data: dict = None): - # # default timeout for the download link is 2 hours(in minutes) - default_timeout = os.getenv('QUERY_RESULT_TIMEOUT_MINUTES', 120) - environment = EnvironmentService.get_environment_by_uri(session, data.get('environmentUri')) worksheet = WorksheetService.get_worksheet_by_uri(session, data.get('worksheetUri')) env_group = EnvironmentService.get_environment_group( session, worksheet.SamlAdminGroupName, environment.environmentUri ) - output_file_key = WorksheetQueryResultService.get_file_key( - env_group.environmentAthenaWorkGroup, data.get('athenaQueryId'), data.get('fileFormat') - ) sql_query_result = WorksheetRepository.find_query_result_by_format( session, data.get('worksheetUri'), data.get('athenaQueryId'), data.get('fileFormat') ) - if data.get('fileFormat') == 'xlsx': - is_job_failed = WorksheetQueryResultService.handle_xlsx_format( - environment.EnvironmentDefaultBucketName, output_file_key - ) - if is_job_failed: - raise ValueError('Error while preparing the xlsx file') - + s3_client = S3Client(environment) if not sql_query_result: - sql_query_result = WorksheetQueryResultService.create_query_result( + sql_query_result = WorksheetQueryResultService._create_query_result( environment.EnvironmentDefaultBucketName, env_group.environmentAthenaWorkGroup, worksheet.worksheetUri, @@ -120,18 +85,28 @@ def download_sql_query_result(session: 'Session', data: dict = None): environment.AwsAccountId, data, ) - S3_client.object_exists( - region=environment.region, bucket=environment.EnvironmentDefaultBucketName, key=output_file_key + output_file_s3_path = WorksheetQueryResultService.build_s3_file_path( + env_group.environmentAthenaWorkGroup, data.get('athenaQueryId') + ) + if sql_query_result.fileFormat == WorksheetResultsFormat.XLSX.value: + try: + csv_data = s3_client.get_object(bucket=environment.EnvironmentDefaultBucketName, key=f'{output_file_s3_path}.{WorksheetResultsFormat.CSV.value}') + excel_buffer = WorksheetQueryResultService.convert_csv_to_xlsx(csv_data) + s3_client.put_object(bucket=environment.EnvironmentDefaultBucketName, key=f'{output_file_s3_path}.{WorksheetResultsFormat.XLSX.value}', body=excel_buffer) + except Exception as e: + raise exceptions.AWSResourceNotAvailable('CONVERT_CSV_TO_EXCEL',f'Failed to convert csv to xlsx: {e}') + + s3_client.object_exists( + bucket=environment.EnvironmentDefaultBucketName, key=f'{output_file_s3_path}.{sql_query_result.fileFormat}' ) if sql_query_result.is_download_link_expired(): - url = S3_client.get_presigned_url( - region=environment.region, + url = s3_client.get_presigned_url( bucket=environment.EnvironmentDefaultBucketName, - key=output_file_key, - expire_minutes=default_timeout, + key=f'{output_file_s3_path}.{sql_query_result.fileFormat}', + expire_minutes=WorksheetQueryResultService._DEFAULT_QUERY_RESULTS_TIMEOUT, ) sql_query_result.downloadLink = url - sql_query_result.expiresIn = datetime.utcnow() + timedelta(seconds=default_timeout) + sql_query_result.expiresIn = datetime.now(DATETIME_UTC) + timedelta(minutes=WorksheetQueryResultService._DEFAULT_QUERY_RESULTS_TIMEOUT) session.add(sql_query_result) session.commit() diff --git 
a/backend/migrations/versions/d1d6da1b2d67_add_columns_worksheet_query_result_model.py b/backend/migrations/versions/d1d6da1b2d67_add_columns_worksheet_query_result_model.py index 4b0b48e7d..b03a78d98 100644 --- a/backend/migrations/versions/d1d6da1b2d67_add_columns_worksheet_query_result_model.py +++ b/backend/migrations/versions/d1d6da1b2d67_add_columns_worksheet_query_result_model.py @@ -12,7 +12,7 @@ # revision identifiers, used by Alembic. revision = 'd1d6da1b2d67' -down_revision = 'd274e756f0ae' +down_revision = 'f87aecc36d39' branch_labels = None depends_on = None @@ -25,9 +25,16 @@ def upgrade(): op.add_column('worksheet_query_result', sa.Column('expiresIn', sa.DateTime(), nullable=True)) op.add_column('worksheet_query_result', sa.Column('updated', sa.DateTime(), nullable=False)) op.add_column('worksheet_query_result', sa.Column('fileFormat', sa.String(), nullable=True)) + op.drop_constraint('AthenaQueryId', 'worksheet_query_result', type_='primary') + op.create_primary_key('worksheet_query_result_pkey', 'worksheet_query_result', ['worksheetQueryResultUri']) + op.alter_column('worksheet_query_result', 'AthenaQueryId', nullable=False) + op.drop_column('worksheet_query_result', 'queryType') def downgrade(): + op.add_column('worksheet_query_result', sa.Column('queryType', sa.VARCHAR(), autoincrement=False, nullable=True)) + op.drop_constraint('worksheet_query_result_pkey', 'worksheet_query_result', type_='primary') + op.create_primary_key('AthenaQueryId', 'worksheet_query_result', ['AthenaQueryId']) op.drop_column('worksheet_query_result', 'fileFormat') op.drop_column('worksheet_query_result', 'updated') op.drop_column('worksheet_query_result', 'expiresIn') diff --git a/frontend/src/modules/Worksheets/services/createWorksheetQueryResultDownloadUrl.js b/frontend/src/modules/Worksheets/services/createWorksheetQueryResultDownloadUrl.js index c5fec1942..bc6141cd5 100644 --- a/frontend/src/modules/Worksheets/services/createWorksheetQueryResultDownloadUrl.js +++ b/frontend/src/modules/Worksheets/services/createWorksheetQueryResultDownloadUrl.js @@ -1,40 +1,22 @@ import { gql } from 'apollo-boost'; -export const createWorksheetQueryResultDownloadUrl = ({ - fileFormat, - environmentUri, - athenaQueryId, - worksheetUri -}) => ({ +export const createWorksheetQueryResultDownloadUrl = (input) => ({ variables: { - fileFormat, - environmentUri, - athenaQueryId, - worksheetUri + input }, query: gql` mutation CreateWorksheetQueryResultDownloadUrl( - $fileFormat: String! - $environmentUri: String! - $athenaQueryId: String! - $worksheetUri: String! + $input: WorksheetQueryResultDownloadUrlInput! 
) { createWorksheetQueryResultDownloadUrl( - input: { - fileFormat: $fileFormat - environmentUri: $environmentUri - athenaQueryId: $athenaQueryId - worksheetUri: $worksheetUri - } + input: $input ) { downloadLink AthenaQueryId expiresIn fileFormat - OutputLocation + outputLocation } } ` }); - - diff --git a/tests/modules/worksheets/test_worksheet.py b/tests/modules/worksheets/test_worksheet.py index 2edffe0a7..3f73fcc1b 100644 --- a/tests/modules/worksheets/test_worksheet.py +++ b/tests/modules/worksheets/test_worksheet.py @@ -37,9 +37,7 @@ def worksheet(client, tenant, group): @pytest.fixture(scope='module', autouse=True) def mock_s3_client(module_mocker): s3_client = MagicMock() - module_mocker.patch('dataall.modules.worksheets.services.worksheet_query_result_service.S3_client', s3_client) - - # s3_client.client.return_value = s3_client + module_mocker.patch('dataall.modules.worksheets.services.worksheet_query_result_service.S3_client', s3_client, autospec=True) s3_client().object_exists.return_value = True s3_client().put_object.return_value = None @@ -48,24 +46,6 @@ def mock_s3_client(module_mocker): yield s3_client -# @pytest.fixture(scope='module') -# def dataset1( -# module_mocker, -# org_fixture: Organization, -# env_fixture: Environment, -# dataset: typing.Callable, -# group, -# ) -> S3Dataset: -# kms_client = MagicMock() -# module_mocker.patch('dataall.modules.s3_datasets.services.dataset_service.KmsClient', kms_client) -# -# kms_client().get_key_id.return_value = mocked_key_id -# -# d = dataset(org=org_fixture, env=env_fixture, name='dataset1', owner=env_fixture.owner, group=group.name) -# print(d) -# yield d - - def test_create_worksheet(client, worksheet): assert worksheet.label == 'my worksheet' assert worksheet.owner == 'alice' @@ -188,7 +168,6 @@ def test_create_query_download_url(client, worksheet, env_fixture): """ mutation CreateWorksheetQueryResultDownloadUrl($input:WorksheetQueryResultDownloadUrlInput){ createWorksheetQueryResultDownloadUrl(input:$input){ - queryType sqlBody AthenaQueryId region From caf497b5273c6a3ba1c0d5e1ad7389896c714e47 Mon Sep 17 00:00:00 2001 From: anison Date: Thu, 10 Oct 2024 16:52:18 +0530 Subject: [PATCH 4/8] fix linting problems --- .../modules/worksheets/api/__init__.py | 9 +++-- .../modules/worksheets/api/resolvers.py | 13 ++++--- .../modules/worksheets/aws/s3_client.py | 4 +-- .../modules/worksheets/db/worksheet_models.py | 1 - .../services/worksheet_permissions.py | 5 +++ .../worksheet_query_result_service.py | 35 +++++++++++-------- .../Worksheets/components/WorksheetResult.js | 34 +++++++++++------- .../createWorksheetQueryResultDownloadUrl.js | 4 +-- .../src/modules/Worksheets/services/index.js | 2 +- tests/modules/worksheets/test_worksheet.py | 19 +++++----- 10 files changed, 77 insertions(+), 49 deletions(-) diff --git a/backend/dataall/modules/worksheets/api/__init__.py b/backend/dataall/modules/worksheets/api/__init__.py index 1251b301b..48b0459d6 100644 --- a/backend/dataall/modules/worksheets/api/__init__.py +++ b/backend/dataall/modules/worksheets/api/__init__.py @@ -4,7 +4,12 @@ queries, resolvers, types, - enums, ) -__all__ = ['resolvers', 'types', 'input_types', 'queries', 'mutations', 'enums'] +__all__ = [ + 'resolvers', + 'types', + 'input_types', + 'queries', + 'mutations', +] diff --git a/backend/dataall/modules/worksheets/api/resolvers.py b/backend/dataall/modules/worksheets/api/resolvers.py index 331973be1..74c2c047b 100644 --- a/backend/dataall/modules/worksheets/api/resolvers.py +++ 
b/backend/dataall/modules/worksheets/api/resolvers.py @@ -73,18 +73,23 @@ def delete_worksheet(context, source, worksheetUri: str = None): def create_athena_query_result_download_url(context: Context, source, input: dict = None): - if not input: # raise exceptions.InvalidInput('data', input, 'input is required') raise exceptions.RequiredParameter('data') + if not input.get('environmentUri'): + raise exceptions.RequiredParameter('environmentUri') if not input.get('athenaQueryId'): raise exceptions.RequiredParameter('athenaQueryId') if not input.get('fileFormat'): raise exceptions.RequiredParameter('fileFormat') if not hasattr(WorksheetResultsFormat, input.get('fileFormat').upper()): raise exceptions.InvalidInput( - 'fileFormat', input.get('fileFormat'), - ', '.join(result_format.value for result_format in WorksheetResultsFormat)) + 'fileFormat', + input.get('fileFormat'), + ', '.join(result_format.value for result_format in WorksheetResultsFormat), + ) + + env_uri = input['environmentUri'] with context.engine.scoped_session() as session: - return WorksheetQueryResultService.download_sql_query_result(session=session, data=input) + return WorksheetQueryResultService.download_sql_query_result(session=session, env_uri=env_uri, data=input) diff --git a/backend/dataall/modules/worksheets/aws/s3_client.py b/backend/dataall/modules/worksheets/aws/s3_client.py index 8cab77e85..b20c78e20 100644 --- a/backend/dataall/modules/worksheets/aws/s3_client.py +++ b/backend/dataall/modules/worksheets/aws/s3_client.py @@ -10,6 +10,7 @@ if TYPE_CHECKING: from dataall.core.environment.db.environment_models import Environment + try: from mypy_boto3_s3 import S3Client as S3ClientType except ImportError: @@ -19,7 +20,6 @@ class S3Client: - def __init__(self, env: 'Environment'): self._client = SessionHelper.remote_session(env.AwsAccountId, env.region).client('s3', region_name=env.region) self._env = env @@ -55,7 +55,6 @@ def object_exists(self, bucket, key) -> bool: log.error(f'Failed to check object existence due to: {e}') raise AWSResourceNotFound('s3_object_exists', f'Object {key} not found in bucket {bucket}') - def put_object(self, bucket, key, body): try: self.client.put_object(Bucket=bucket, Key=key, Body=body) @@ -63,7 +62,6 @@ def put_object(self, bucket, key, body): log.error(f'Failed to put object due to: {e}') raise e - def get_object(self, bucket, key) -> str: try: response = self.client.get_object(Bucket=bucket, Key=key) diff --git a/backend/dataall/modules/worksheets/db/worksheet_models.py b/backend/dataall/modules/worksheets/db/worksheet_models.py index 5342a65b2..b7a01a871 100644 --- a/backend/dataall/modules/worksheets/db/worksheet_models.py +++ b/backend/dataall/modules/worksheets/db/worksheet_models.py @@ -1,7 +1,6 @@ import datetime import enum -from future.backports.email.policy import default from sqlalchemy import Column, DateTime, Integer, Enum, String, BigInteger from sqlalchemy.dialects import postgresql from sqlalchemy.orm import query_expression diff --git a/backend/dataall/modules/worksheets/services/worksheet_permissions.py b/backend/dataall/modules/worksheets/services/worksheet_permissions.py index 0f494567d..c4aa2266e 100644 --- a/backend/dataall/modules/worksheets/services/worksheet_permissions.py +++ b/backend/dataall/modules/worksheets/services/worksheet_permissions.py @@ -38,12 +38,17 @@ RUN ATHENA QUERY """ RUN_ATHENA_QUERY = 'RUN_ATHENA_QUERY' +RUN_ATHENA_QUERY_TENANT = 'RUN_ATHENA_QUERY_TENANT' ENVIRONMENT_INVITED.append(RUN_ATHENA_QUERY) 
ENVIRONMENT_INVITATION_REQUEST.append(RUN_ATHENA_QUERY) ENVIRONMENT_ALL.append(RUN_ATHENA_QUERY) +ENVIRONMENT_ALL.append(RUN_ATHENA_QUERY_TENANT) RESOURCES_ALL.append(RUN_ATHENA_QUERY) RESOURCES_ALL_WITH_DESC[RUN_ATHENA_QUERY] = 'Run Worksheet Athena queries on this environment' + +TENANT_ALL.append(RUN_ATHENA_QUERY_TENANT) +TENANT_ALL_WITH_DESC[RUN_ATHENA_QUERY_TENANT] = 'Run Worksheet Athena queries on any environment' diff --git a/backend/dataall/modules/worksheets/services/worksheet_query_result_service.py b/backend/dataall/modules/worksheets/services/worksheet_query_result_service.py index 58840b148..073b0ec6f 100644 --- a/backend/dataall/modules/worksheets/services/worksheet_query_result_service.py +++ b/backend/dataall/modules/worksheets/services/worksheet_query_result_service.py @@ -1,19 +1,19 @@ import csv import io import os -from datetime import datetime, timedelta, UTC as DATETIME_UTC +from datetime import datetime, timedelta from typing import TYPE_CHECKING from openpyxl import Workbook from dataall.base.db import exceptions from dataall.core.environment.services.environment_service import EnvironmentService -from dataall.core.permissions.services.resource_policy_service import ResourcePolicyService +from dataall.core.permissions.services.tenant_policy_service import TenantPolicyService from dataall.modules.worksheets.aws.s3_client import S3Client from dataall.modules.worksheets.db.worksheet_models import WorksheetQueryResult from dataall.modules.worksheets.db.worksheet_repositories import WorksheetRepository from dataall.modules.worksheets.services.worksheet_enums import WorksheetResultsFormat -from dataall.modules.worksheets.services.worksheet_permissions import RUN_ATHENA_QUERY +from dataall.modules.worksheets.services.worksheet_permissions import RUN_ATHENA_QUERY_TENANT from dataall.modules.worksheets.services.worksheet_service import WorksheetService if TYPE_CHECKING: @@ -40,14 +40,12 @@ def _create_query_result( fileFormat=data.get('fileFormat'), OutputLocation=f's3://{environment_bucket}/athenaqueries/{athena_workgroup}/', region=region, - AwsAccountId=aws_account_id + AwsAccountId=aws_account_id, ) return sql_query_result @staticmethod - def build_s3_file_path( - workgroup: str, query_id: str, athena_queries_dir: str = None - ) -> str: + def build_s3_file_path(workgroup: str, query_id: str, athena_queries_dir: str = None) -> str: athena_queries_dir = athena_queries_dir or WorksheetQueryResultService._DEFAULT_ATHENA_QUERIES_PATH return f'{athena_queries_dir}/{workgroup}/{query_id}' @@ -65,9 +63,9 @@ def convert_csv_to_xlsx(csv_data) -> io.BytesIO: return excel_buffer @staticmethod - @ResourcePolicyService.has_resource_permission(RUN_ATHENA_QUERY) - def download_sql_query_result(session: 'Session', data: dict = None): - environment = EnvironmentService.get_environment_by_uri(session, data.get('environmentUri')) + @TenantPolicyService.has_tenant_permission(RUN_ATHENA_QUERY_TENANT) + def download_sql_query_result(session: 'Session', env_uri: str, data: dict = None): + environment = EnvironmentService.get_environment_by_uri(session, env_uri) worksheet = WorksheetService.get_worksheet_by_uri(session, data.get('worksheetUri')) env_group = EnvironmentService.get_environment_group( session, worksheet.SamlAdminGroupName, environment.environmentUri @@ -90,11 +88,18 @@ def download_sql_query_result(session: 'Session', data: dict = None): ) if sql_query_result.fileFormat == WorksheetResultsFormat.XLSX.value: try: - csv_data = 
s3_client.get_object(bucket=environment.EnvironmentDefaultBucketName, key=f'{output_file_s3_path}.{WorksheetResultsFormat.CSV.value}') + csv_data = s3_client.get_object( + bucket=environment.EnvironmentDefaultBucketName, + key=f'{output_file_s3_path}.{WorksheetResultsFormat.CSV.value}', + ) excel_buffer = WorksheetQueryResultService.convert_csv_to_xlsx(csv_data) - s3_client.put_object(bucket=environment.EnvironmentDefaultBucketName, key=f'{output_file_s3_path}.{WorksheetResultsFormat.XLSX.value}', body=excel_buffer) + s3_client.put_object( + bucket=environment.EnvironmentDefaultBucketName, + key=f'{output_file_s3_path}.{WorksheetResultsFormat.XLSX.value}', + body=excel_buffer, + ) except Exception as e: - raise exceptions.AWSResourceNotAvailable('CONVERT_CSV_TO_EXCEL',f'Failed to convert csv to xlsx: {e}') + raise exceptions.AWSResourceNotAvailable('CONVERT_CSV_TO_EXCEL', f'Failed to convert csv to xlsx: {e}') s3_client.object_exists( bucket=environment.EnvironmentDefaultBucketName, key=f'{output_file_s3_path}.{sql_query_result.fileFormat}' @@ -106,7 +111,9 @@ def download_sql_query_result(session: 'Session', data: dict = None): expire_minutes=WorksheetQueryResultService._DEFAULT_QUERY_RESULTS_TIMEOUT, ) sql_query_result.downloadLink = url - sql_query_result.expiresIn = datetime.now(DATETIME_UTC) + timedelta(minutes=WorksheetQueryResultService._DEFAULT_QUERY_RESULTS_TIMEOUT) + sql_query_result.expiresIn = datetime.utcnow() + timedelta( + minutes=WorksheetQueryResultService._DEFAULT_QUERY_RESULTS_TIMEOUT + ) session.add(sql_query_result) session.commit() diff --git a/frontend/src/modules/Worksheets/components/WorksheetResult.js b/frontend/src/modules/Worksheets/components/WorksheetResult.js index faee82e52..62fa6dea4 100644 --- a/frontend/src/modules/Worksheets/components/WorksheetResult.js +++ b/frontend/src/modules/Worksheets/components/WorksheetResult.js @@ -23,15 +23,17 @@ import RadioGroup from '@mui/material/RadioGroup'; import FormControlLabel from '@mui/material/FormControlLabel'; import { Download } from '@mui/icons-material'; import { LoadingButton } from '@mui/lab'; -import { - useClient -} from 'services'; -import { - createWorksheetQueryResultDownloadUrl -} from '../services'; +import { useClient } from 'services'; +import { createWorksheetQueryResultDownloadUrl } from '../services'; import { SET_ERROR, useDispatch } from 'globalErrors'; -export const WorksheetResult = ({ results, loading, currentEnv, athenaQueryId, worksheetUri }) => { +export const WorksheetResult = ({ + results, + loading, + currentEnv, + athenaQueryId, + worksheetUri +}) => { const [runningDownloadQuery, setRunningDownloadQuery] = useState(false); const [fileType, setFileType] = useState('csv'); const client = useClient(); @@ -55,7 +57,8 @@ export const WorksheetResult = ({ results, loading, currentEnv, athenaQueryId, w if (!response.errors) { const link = document.createElement('a'); - link.href = response.data.createWorksheetQueryResultDownloadUrl.downloadLink; + link.href = + response.data.createWorksheetQueryResultDownloadUrl.downloadLink; // Append to html link element page document.body.appendChild(link); // Start download @@ -70,7 +73,7 @@ export const WorksheetResult = ({ results, loading, currentEnv, athenaQueryId, w } finally { setRunningDownloadQuery(false); } - }, [client, dispatch, currentEnv, athenaQueryId, fileType]) + }, [client, dispatch, currentEnv, athenaQueryId, fileType]); if (loading) { return ; @@ -98,8 +101,16 @@ export const WorksheetResult = ({ results, loading, currentEnv, 
athenaQueryId, w
             value={fileType}
             onChange={handleChange}
           >
-            } label="CSV" />
-            } label="XLSX" />
+            }
+              label="CSV"
+            />
+            }
+              label="XLSX"
+            />
-          } />
diff --git a/frontend/src/modules/Worksheets/services/createWorksheetQueryResultDownloadUrl.js b/frontend/src/modules/Worksheets/services/createWorksheetQueryResultDownloadUrl.js
index bc6141cd5..305ab5599 100644
--- a/frontend/src/modules/Worksheets/services/createWorksheetQueryResultDownloadUrl.js
+++ b/frontend/src/modules/Worksheets/services/createWorksheetQueryResultDownloadUrl.js
@@ -8,9 +8,7 @@ export const createWorksheetQueryResultDownloadUrl = (input) => ({
       mutation CreateWorksheetQueryResultDownloadUrl(
         $input: WorksheetQueryResultDownloadUrlInput!
       ) {
-        createWorksheetQueryResultDownloadUrl(
-          input: $input
-        ) {
+        createWorksheetQueryResultDownloadUrl(input: $input) {
           downloadLink
           AthenaQueryId
           expiresIn
diff --git a/frontend/src/modules/Worksheets/services/index.js b/frontend/src/modules/Worksheets/services/index.js
index ee5a24b33..65637d6ae 100644
--- a/frontend/src/modules/Worksheets/services/index.js
+++ b/frontend/src/modules/Worksheets/services/index.js
@@ -6,4 +6,4 @@ export * from './listWorksheets';
 export * from './runAthenaSqlQuery';
 export * from './updateWorksheet';
 export * from './listSharedDatasetTableColumns';
-export * from './createWorksheetQueryResultDownloadUrl';
\ No newline at end of file
+export * from './createWorksheetQueryResultDownloadUrl';
diff --git a/tests/modules/worksheets/test_worksheet.py b/tests/modules/worksheets/test_worksheet.py
index 3f73fcc1b..dc4ffe799 100644
--- a/tests/modules/worksheets/test_worksheet.py
+++ b/tests/modules/worksheets/test_worksheet.py
@@ -36,12 +36,13 @@ def worksheet(client, tenant, group):
 @pytest.fixture(scope='module', autouse=True)
 def mock_s3_client(module_mocker):
-    s3_client = MagicMock()
-    module_mocker.patch('dataall.modules.worksheets.services.worksheet_query_result_service.S3_client', s3_client, autospec=True)
+    s3_client = module_mocker.patch(
+        'dataall.modules.worksheets.services.worksheet_query_result_service.S3Client', autospec=True
+    )
-    s3_client().object_exists.return_value = True
-    s3_client().put_object.return_value = None
-    s3_client().get_object.return_value = '123,123,123'
+    s3_client.object_exists.return_value = True
+    s3_client.put_object.return_value = None
+    s3_client.get_object.return_value = '123,123,123'
     s3_client.get_presigned_url.return_value = 'https://s3.amazonaws.com/file/123.csv'
     yield s3_client
@@ -166,16 +167,16 @@ def test_update_worksheet(client, worksheet, group):
 def test_create_query_download_url(client, worksheet, env_fixture):
     response = client.query(
         """
-        mutation CreateWorksheetQueryResultDownloadUrl($input:WorksheetQueryResultDownloadUrlInput){
-            createWorksheetQueryResultDownloadUrl(input:$input){
+        mutation CreateWorksheetQueryResultDownloadUrl($input: WorksheetQueryResultDownloadUrlInput){
+            createWorksheetQueryResultDownloadUrl(input: $input){
                 sqlBody
                 AthenaQueryId
                 region
                 AwsAccountId
-                ElapsedTimeInMs
+                elapsedTimeInMs
                 created
                 downloadLink
-                OutputLocation
+                outputLocation
                 expiresIn
                 fileFormat
             }

From be85bdd36ab587f365432146f4c1073138b64c2f Mon Sep 17 00:00:00 2001
From: anison
Date: Fri, 18 Oct 2024 15:22:58 +0530
Subject: [PATCH 5/8] refactor changes for result download feature

---
 .../modules/worksheets/api/resolvers.py       |  6 +++--
 .../modules/worksheets/aws/s3_client.py       | 10 ++++----
 .../services/worksheet_permissions.py         | 23 +++++++++++--------
 .../worksheet_query_result_service.py         | 12 ++++++----
...dd_columns_worksheet_query_result_model.py | 2 ++ 5 files changed, 30 insertions(+), 23 deletions(-) diff --git a/backend/dataall/modules/worksheets/api/resolvers.py b/backend/dataall/modules/worksheets/api/resolvers.py index 74c2c047b..07bf5ea80 100644 --- a/backend/dataall/modules/worksheets/api/resolvers.py +++ b/backend/dataall/modules/worksheets/api/resolvers.py @@ -74,7 +74,6 @@ def delete_worksheet(context, source, worksheetUri: str = None): def create_athena_query_result_download_url(context: Context, source, input: dict = None): if not input: - # raise exceptions.InvalidInput('data', input, 'input is required') raise exceptions.RequiredParameter('data') if not input.get('environmentUri'): raise exceptions.RequiredParameter('environmentUri') @@ -90,6 +89,9 @@ def create_athena_query_result_download_url(context: Context, source, input: dic ) env_uri = input['environmentUri'] + worksheet_uri = input['worksheetUri'] with context.engine.scoped_session() as session: - return WorksheetQueryResultService.download_sql_query_result(session=session, env_uri=env_uri, data=input) + return WorksheetQueryResultService.download_sql_query_result( + session=session, uri=worksheet_uri, env_uri=env_uri, data=input + ) diff --git a/backend/dataall/modules/worksheets/aws/s3_client.py b/backend/dataall/modules/worksheets/aws/s3_client.py index b20c78e20..9d07132e1 100644 --- a/backend/dataall/modules/worksheets/aws/s3_client.py +++ b/backend/dataall/modules/worksheets/aws/s3_client.py @@ -1,12 +1,10 @@ -import boto3 -from botocore.config import Config +import logging +from typing import TYPE_CHECKING from botocore.exceptions import ClientError -import logging -from dataall.base.db.exceptions import AWSResourceNotFound -from dataall.base.aws.sts import SessionHelper -from typing import TYPE_CHECKING +from dataall.base.aws.sts import SessionHelper +from dataall.base.db.exceptions import AWSResourceNotFound if TYPE_CHECKING: from dataall.core.environment.db.environment_models import Environment diff --git a/backend/dataall/modules/worksheets/services/worksheet_permissions.py b/backend/dataall/modules/worksheets/services/worksheet_permissions.py index c4aa2266e..b8cee3f5d 100644 --- a/backend/dataall/modules/worksheets/services/worksheet_permissions.py +++ b/backend/dataall/modules/worksheets/services/worksheet_permissions.py @@ -22,12 +22,10 @@ UPDATE_WORKSHEET = 'UPDATE_WORKSHEET' DELETE_WORKSHEET = 'DELETE_WORKSHEET' RUN_WORKSHEET_QUERY = 'RUN_WORKSHEET_QUERY' -WORKSHEET_ALL = [ - GET_WORKSHEET, - UPDATE_WORKSHEET, - DELETE_WORKSHEET, - RUN_WORKSHEET_QUERY, -] +DOWNLOAD_ATHENA_QUERY_RESULTS = 'DOWNLOAD_ATHENA_QUERY_RESULTS' + + +WORKSHEET_ALL = [GET_WORKSHEET, UPDATE_WORKSHEET, DELETE_WORKSHEET, RUN_WORKSHEET_QUERY, DOWNLOAD_ATHENA_QUERY_RESULTS] RESOURCES_ALL.extend(WORKSHEET_ALL) @@ -38,17 +36,22 @@ RUN ATHENA QUERY """ RUN_ATHENA_QUERY = 'RUN_ATHENA_QUERY' -RUN_ATHENA_QUERY_TENANT = 'RUN_ATHENA_QUERY_TENANT' ENVIRONMENT_INVITED.append(RUN_ATHENA_QUERY) ENVIRONMENT_INVITATION_REQUEST.append(RUN_ATHENA_QUERY) ENVIRONMENT_ALL.append(RUN_ATHENA_QUERY) -ENVIRONMENT_ALL.append(RUN_ATHENA_QUERY_TENANT) RESOURCES_ALL.append(RUN_ATHENA_QUERY) RESOURCES_ALL_WITH_DESC[RUN_ATHENA_QUERY] = 'Run Worksheet Athena queries on this environment' -TENANT_ALL.append(RUN_ATHENA_QUERY_TENANT) -TENANT_ALL_WITH_DESC[RUN_ATHENA_QUERY_TENANT] = 'Run Worksheet Athena queries on any environment' + +""" +DOWNLOAD ATHENA QUERY RESULTS +""" +ENVIRONMENT_INVITED.append(DOWNLOAD_ATHENA_QUERY_RESULTS) 
+ENVIRONMENT_INVITATION_REQUEST.append(DOWNLOAD_ATHENA_QUERY_RESULTS) +ENVIRONMENT_ALL.append(DOWNLOAD_ATHENA_QUERY_RESULTS) + +RESOURCES_ALL_WITH_DESC[DOWNLOAD_ATHENA_QUERY_RESULTS] = 'Download Worksheet Athena query results' diff --git a/backend/dataall/modules/worksheets/services/worksheet_query_result_service.py b/backend/dataall/modules/worksheets/services/worksheet_query_result_service.py index 073b0ec6f..0521a2562 100644 --- a/backend/dataall/modules/worksheets/services/worksheet_query_result_service.py +++ b/backend/dataall/modules/worksheets/services/worksheet_query_result_service.py @@ -8,12 +8,13 @@ from dataall.base.db import exceptions from dataall.core.environment.services.environment_service import EnvironmentService +from dataall.core.permissions.services.resource_policy_service import ResourcePolicyService from dataall.core.permissions.services.tenant_policy_service import TenantPolicyService from dataall.modules.worksheets.aws.s3_client import S3Client from dataall.modules.worksheets.db.worksheet_models import WorksheetQueryResult from dataall.modules.worksheets.db.worksheet_repositories import WorksheetRepository from dataall.modules.worksheets.services.worksheet_enums import WorksheetResultsFormat -from dataall.modules.worksheets.services.worksheet_permissions import RUN_ATHENA_QUERY_TENANT +from dataall.modules.worksheets.services.worksheet_permissions import DOWNLOAD_ATHENA_QUERY_RESULTS, MANAGE_WORKSHEETS from dataall.modules.worksheets.services.worksheet_service import WorksheetService if TYPE_CHECKING: @@ -38,7 +39,7 @@ def _create_query_result( worksheetUri=worksheet_uri, AthenaQueryId=data.get('athenaQueryId'), fileFormat=data.get('fileFormat'), - OutputLocation=f's3://{environment_bucket}/athenaqueries/{athena_workgroup}/', + OutputLocation=f's3://{environment_bucket}/{WorksheetQueryResultService._DEFAULT_ATHENA_QUERIES_PATH}/{athena_workgroup}/', region=region, AwsAccountId=aws_account_id, ) @@ -63,10 +64,11 @@ def convert_csv_to_xlsx(csv_data) -> io.BytesIO: return excel_buffer @staticmethod - @TenantPolicyService.has_tenant_permission(RUN_ATHENA_QUERY_TENANT) - def download_sql_query_result(session: 'Session', env_uri: str, data: dict = None): + @TenantPolicyService.has_tenant_permission(MANAGE_WORKSHEETS) + @ResourcePolicyService.has_resource_permission(DOWNLOAD_ATHENA_QUERY_RESULTS) + def download_sql_query_result(session: 'Session', uri: str, env_uri: str, data: dict = None): environment = EnvironmentService.get_environment_by_uri(session, env_uri) - worksheet = WorksheetService.get_worksheet_by_uri(session, data.get('worksheetUri')) + worksheet = WorksheetService.get_worksheet_by_uri(session, uri) env_group = EnvironmentService.get_environment_group( session, worksheet.SamlAdminGroupName, environment.environmentUri ) diff --git a/backend/migrations/versions/d1d6da1b2d67_add_columns_worksheet_query_result_model.py b/backend/migrations/versions/d1d6da1b2d67_add_columns_worksheet_query_result_model.py index b03a78d98..d4cf99819 100644 --- a/backend/migrations/versions/d1d6da1b2d67_add_columns_worksheet_query_result_model.py +++ b/backend/migrations/versions/d1d6da1b2d67_add_columns_worksheet_query_result_model.py @@ -25,6 +25,7 @@ def upgrade(): op.add_column('worksheet_query_result', sa.Column('expiresIn', sa.DateTime(), nullable=True)) op.add_column('worksheet_query_result', sa.Column('updated', sa.DateTime(), nullable=False)) op.add_column('worksheet_query_result', sa.Column('fileFormat', sa.String(), nullable=True)) + 
op.add_column('worksheet_query_result', sa.Column('worksheetQueryResultUri', sa.String(), nullable=False)) op.drop_constraint('AthenaQueryId', 'worksheet_query_result', type_='primary') op.create_primary_key('worksheet_query_result_pkey', 'worksheet_query_result', ['worksheetQueryResultUri']) op.alter_column('worksheet_query_result', 'AthenaQueryId', nullable=False) @@ -35,6 +36,7 @@ def downgrade(): op.add_column('worksheet_query_result', sa.Column('queryType', sa.VARCHAR(), autoincrement=False, nullable=True)) op.drop_constraint('worksheet_query_result_pkey', 'worksheet_query_result', type_='primary') op.create_primary_key('AthenaQueryId', 'worksheet_query_result', ['AthenaQueryId']) + op.drop_column('worksheet_query_result', 'worksheetQueryResultUri') op.drop_column('worksheet_query_result', 'fileFormat') op.drop_column('worksheet_query_result', 'updated') op.drop_column('worksheet_query_result', 'expiresIn') From f7cb94b7758ed074a5a2fd5d33cafa6d012e3ae0 Mon Sep 17 00:00:00 2001 From: anison Date: Fri, 18 Oct 2024 17:22:33 +0530 Subject: [PATCH 6/8] fix: test_create_query_download_url Testcase --- tests/modules/worksheets/test_worksheet.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/tests/modules/worksheets/test_worksheet.py b/tests/modules/worksheets/test_worksheet.py index dc4ffe799..3e467b4f9 100644 --- a/tests/modules/worksheets/test_worksheet.py +++ b/tests/modules/worksheets/test_worksheet.py @@ -40,10 +40,10 @@ def mock_s3_client(module_mocker): 'dataall.modules.worksheets.services.worksheet_query_result_service.S3Client', autospec=True ) - s3_client.object_exists.return_value = True - s3_client.put_object.return_value = None - s3_client.get_object.return_value = '123,123,123' - s3_client.get_presigned_url.return_value = 'https://s3.amazonaws.com/file/123.csv' + s3_client.return_value.object_exists.return_value = True + s3_client.return_value.put_object.return_value = None + s3_client.return_value.get_object.return_value = '123,123,123' + s3_client.return_value.get_presigned_url.return_value = 'https://s3.amazonaws.com/file/123.csv' yield s3_client @@ -164,7 +164,7 @@ def test_update_worksheet(client, worksheet, group): assert response.data.updateWorksheet.label == 'change label' -def test_create_query_download_url(client, worksheet, env_fixture): +def test_create_query_download_url(client, worksheet, env_fixture, group): response = client.query( """ mutation CreateWorksheetQueryResultDownloadUrl($input: WorksheetQueryResultDownloadUrlInput){ @@ -188,6 +188,8 @@ def test_create_query_download_url(client, worksheet, env_fixture): 'fileFormat': 'csv', 'environmentUri': env_fixture.environmentUri, }, + username='alice', + groups=[group.name], ) expires_in = datetime.strptime(response.data.createWorksheetQueryResultDownloadUrl.created, '%Y-%m-%d %H:%M:%S.%f') From 87f7cc754b133b0be23f5ef7290a7f0c4ad72683 Mon Sep 17 00:00:00 2001 From: anison Date: Tue, 22 Oct 2024 18:38:29 +0530 Subject: [PATCH 7/8] fix: add unauthorized testcases for query result download --- backend/dataall/base/aws/s3_client.py | 33 +-------- .../services/worksheet_permissions.py | 19 ++---- .../worksheet_query_result_service.py | 8 +-- tests/modules/worksheets/test_worksheet.py | 67 ++++++++++++++++++- 4 files changed, 74 insertions(+), 53 deletions(-) diff --git a/backend/dataall/base/aws/s3_client.py b/backend/dataall/base/aws/s3_client.py index 04d43e6b3..e05fa9933 100644 --- a/backend/dataall/base/aws/s3_client.py +++ b/backend/dataall/base/aws/s3_client.py @@ -1,9 +1,8 @@ 
+import logging + import boto3 from botocore.config import Config - from botocore.exceptions import ClientError -import logging -from dataall.base.db.exceptions import AWSResourceNotFound log = logging.getLogger(__name__) @@ -34,31 +33,3 @@ def get_presigned_url(region, bucket, key, expire_minutes: int = 15): except ClientError as e: log.error(f'Failed to get presigned URL due to: {e}') raise e - - @staticmethod - def object_exists(region, bucket, key) -> bool: - try: - S3_client.client(region, None).head_object(Bucket=bucket, Key=key) - return True - except ClientError as e: - log.error(f'Failed to check object existence due to: {e}') - if e.response['Error']['Code'] == '404': - return False - raise AWSResourceNotFound('s3_object_exists', f'Object {key} not found in bucket {bucket}') - - @staticmethod - def put_object(region, bucket, key, body): - try: - S3_client.client(region, None).put_object(Bucket=bucket, Key=key, Body=body) - except ClientError as e: - log.error(f'Failed to put object due to: {e}') - raise e - - @staticmethod - def get_object(region, bucket, key): - try: - response = S3_client.client(region, None).get_object(Bucket=bucket, Key=key) - return response['Body'].read().decode('utf-8') - except ClientError as e: - log.error(f'Failed to get object due to: {e}') - raise e diff --git a/backend/dataall/modules/worksheets/services/worksheet_permissions.py b/backend/dataall/modules/worksheets/services/worksheet_permissions.py index b8cee3f5d..301c2ddca 100644 --- a/backend/dataall/modules/worksheets/services/worksheet_permissions.py +++ b/backend/dataall/modules/worksheets/services/worksheet_permissions.py @@ -1,13 +1,12 @@ -from dataall.core.permissions.services.resources_permissions import ( - RESOURCES_ALL, - RESOURCES_ALL_WITH_DESC, -) from dataall.core.permissions.services.environment_permissions import ( ENVIRONMENT_INVITED, ENVIRONMENT_INVITATION_REQUEST, ENVIRONMENT_ALL, ) - +from dataall.core.permissions.services.resources_permissions import ( + RESOURCES_ALL, + RESOURCES_ALL_WITH_DESC, +) from dataall.core.permissions.services.tenant_permissions import TENANT_ALL, TENANT_ALL_WITH_DESC MANAGE_WORKSHEETS = 'MANAGE_WORKSHEETS' @@ -45,13 +44,3 @@ RESOURCES_ALL.append(RUN_ATHENA_QUERY) RESOURCES_ALL_WITH_DESC[RUN_ATHENA_QUERY] = 'Run Worksheet Athena queries on this environment' - - -""" -DOWNLOAD ATHENA QUERY RESULTS -""" -ENVIRONMENT_INVITED.append(DOWNLOAD_ATHENA_QUERY_RESULTS) -ENVIRONMENT_INVITATION_REQUEST.append(DOWNLOAD_ATHENA_QUERY_RESULTS) -ENVIRONMENT_ALL.append(DOWNLOAD_ATHENA_QUERY_RESULTS) - -RESOURCES_ALL_WITH_DESC[DOWNLOAD_ATHENA_QUERY_RESULTS] = 'Download Worksheet Athena query results' diff --git a/backend/dataall/modules/worksheets/services/worksheet_query_result_service.py b/backend/dataall/modules/worksheets/services/worksheet_query_result_service.py index 0521a2562..6557f68c2 100644 --- a/backend/dataall/modules/worksheets/services/worksheet_query_result_service.py +++ b/backend/dataall/modules/worksheets/services/worksheet_query_result_service.py @@ -46,12 +46,12 @@ def _create_query_result( return sql_query_result @staticmethod - def build_s3_file_path(workgroup: str, query_id: str, athena_queries_dir: str = None) -> str: + def _build_s3_file_path(workgroup: str, query_id: str, athena_queries_dir: str = None) -> str: athena_queries_dir = athena_queries_dir or WorksheetQueryResultService._DEFAULT_ATHENA_QUERIES_PATH return f'{athena_queries_dir}/{workgroup}/{query_id}' @staticmethod - def convert_csv_to_xlsx(csv_data) -> io.BytesIO: + def 
_convert_csv_to_xlsx(csv_data) -> io.BytesIO: wb = Workbook() ws: 'Worksheet' = wb.active csv_reader = csv.reader(csv_data.splitlines()) @@ -85,7 +85,7 @@ def download_sql_query_result(session: 'Session', uri: str, env_uri: str, data: environment.AwsAccountId, data, ) - output_file_s3_path = WorksheetQueryResultService.build_s3_file_path( + output_file_s3_path = WorksheetQueryResultService._build_s3_file_path( env_group.environmentAthenaWorkGroup, data.get('athenaQueryId') ) if sql_query_result.fileFormat == WorksheetResultsFormat.XLSX.value: @@ -94,7 +94,7 @@ def download_sql_query_result(session: 'Session', uri: str, env_uri: str, data: bucket=environment.EnvironmentDefaultBucketName, key=f'{output_file_s3_path}.{WorksheetResultsFormat.CSV.value}', ) - excel_buffer = WorksheetQueryResultService.convert_csv_to_xlsx(csv_data) + excel_buffer = WorksheetQueryResultService._convert_csv_to_xlsx(csv_data) s3_client.put_object( bucket=environment.EnvironmentDefaultBucketName, key=f'{output_file_s3_path}.{WorksheetResultsFormat.XLSX.value}', diff --git a/tests/modules/worksheets/test_worksheet.py b/tests/modules/worksheets/test_worksheet.py index 3e467b4f9..fd6b8bd89 100644 --- a/tests/modules/worksheets/test_worksheet.py +++ b/tests/modules/worksheets/test_worksheet.py @@ -1,7 +1,4 @@ import pytest - -from unittest.mock import MagicMock - from future.backports.datetime import datetime from dataall.modules.worksheets.api.resolvers import WorksheetRole @@ -196,3 +193,67 @@ def test_create_query_download_url(client, worksheet, env_fixture, group): assert response.data.createWorksheetQueryResultDownloadUrl.downloadLink is not None assert response.data.createWorksheetQueryResultDownloadUrl.fileFormat == 'csv' assert expires_in > datetime.utcnow() + + +def test_tenant_unauthorized__create_query_download_url(client, worksheet, env_fixture, tenant): + response = client.query( + """ + mutation CreateWorksheetQueryResultDownloadUrl($input: WorksheetQueryResultDownloadUrlInput){ + createWorksheetQueryResultDownloadUrl(input: $input){ + sqlBody + AthenaQueryId + region + AwsAccountId + elapsedTimeInMs + created + downloadLink + outputLocation + expiresIn + fileFormat + } + } + """, + input={ + 'worksheetUri': worksheet.worksheetUri, + 'athenaQueryId': '123', + 'fileFormat': 'csv', + 'environmentUri': env_fixture.environmentUri, + }, + ) + + assert response.errors is not None + assert len(response.errors) > 0 + assert f'is not authorized to perform: MANAGE_WORKSHEETS on {tenant.name}' in response.errors[0].message + + +def test_resource_unauthorized__create_query_download_url(client, worksheet, env_fixture, group2): + response = client.query( + """ + mutation CreateWorksheetQueryResultDownloadUrl($input: WorksheetQueryResultDownloadUrlInput){ + createWorksheetQueryResultDownloadUrl(input: $input){ + sqlBody + AthenaQueryId + region + AwsAccountId + elapsedTimeInMs + created + downloadLink + outputLocation + expiresIn + fileFormat + } + } + """, + input={ + 'worksheetUri': worksheet.worksheetUri, + 'athenaQueryId': '123', + 'fileFormat': 'csv', + 'environmentUri': env_fixture.environmentUri, + }, + username='bob', + groups=[group2.name] + ) + + assert response.errors is not None + assert len(response.errors) > 0 + assert f'is not authorized to perform: DOWNLOAD_ATHENA_QUERY_RESULTS on resource: {worksheet.worksheetUri}' in response.errors[0].message From 5c2be053cb3dfb17f0e24919f4391cbfa8f3130a Mon Sep 17 00:00:00 2001 From: anison Date: Wed, 23 Oct 2024 15:07:33 +0530 Subject: [PATCH 8/8] fixing CI/CD checks 
--- .../427db8f31999_backfill_MF_resource_permissions.py | 2 +- tests/modules/worksheets/test_worksheet.py | 11 ++++++++--- 2 files changed, 9 insertions(+), 4 deletions(-) diff --git a/backend/migrations/versions/427db8f31999_backfill_MF_resource_permissions.py b/backend/migrations/versions/427db8f31999_backfill_MF_resource_permissions.py index 5209963e8..1932a7fd5 100644 --- a/backend/migrations/versions/427db8f31999_backfill_MF_resource_permissions.py +++ b/backend/migrations/versions/427db8f31999_backfill_MF_resource_permissions.py @@ -24,7 +24,7 @@ # revision identifiers, used by Alembic. revision = '427db8f31999' -down_revision = 'f87aecc36d39' +down_revision = 'd1d6da1b2d67' branch_labels = None depends_on = None diff --git a/tests/modules/worksheets/test_worksheet.py b/tests/modules/worksheets/test_worksheet.py index fd6b8bd89..0a7bd43dc 100644 --- a/tests/modules/worksheets/test_worksheet.py +++ b/tests/modules/worksheets/test_worksheet.py @@ -189,7 +189,9 @@ def test_create_query_download_url(client, worksheet, env_fixture, group): groups=[group.name], ) - expires_in = datetime.strptime(response.data.createWorksheetQueryResultDownloadUrl.created, '%Y-%m-%d %H:%M:%S.%f') + expires_in = datetime.strptime( + response.data.createWorksheetQueryResultDownloadUrl.expiresIn, '%Y-%m-%d %H:%M:%S.%f' + ) assert response.data.createWorksheetQueryResultDownloadUrl.downloadLink is not None assert response.data.createWorksheetQueryResultDownloadUrl.fileFormat == 'csv' assert expires_in > datetime.utcnow() @@ -251,9 +253,12 @@ def test_resource_unauthorized__create_query_download_url(client, worksheet, env 'environmentUri': env_fixture.environmentUri, }, username='bob', - groups=[group2.name] + groups=[group2.name], ) assert response.errors is not None assert len(response.errors) > 0 - assert f'is not authorized to perform: DOWNLOAD_ATHENA_QUERY_RESULTS on resource: {worksheet.worksheetUri}' in response.errors[0].message + assert ( + f'is not authorized to perform: DOWNLOAD_ATHENA_QUERY_RESULTS on resource: {worksheet.worksheetUri}' + in response.errors[0].message + )