diff --git a/backend/dataall/base/aws/s3_client.py b/backend/dataall/base/aws/s3_client.py
index 74f79e7b4..e05fa9933 100644
--- a/backend/dataall/base/aws/s3_client.py
+++ b/backend/dataall/base/aws/s3_client.py
@@ -1,8 +1,8 @@
+import logging
+
import boto3
from botocore.config import Config
-
from botocore.exceptions import ClientError
-import logging
log = logging.getLogger(__name__)
diff --git a/backend/dataall/modules/worksheets/api/__init__.py b/backend/dataall/modules/worksheets/api/__init__.py
index 1251b301b..48b0459d6 100644
--- a/backend/dataall/modules/worksheets/api/__init__.py
+++ b/backend/dataall/modules/worksheets/api/__init__.py
@@ -4,7 +4,12 @@
queries,
resolvers,
types,
- enums,
)
-__all__ = ['resolvers', 'types', 'input_types', 'queries', 'mutations', 'enums']
+__all__ = [
+ 'resolvers',
+ 'types',
+ 'input_types',
+ 'queries',
+ 'mutations',
+]
diff --git a/backend/dataall/modules/worksheets/api/input_types.py b/backend/dataall/modules/worksheets/api/input_types.py
index e2ec62dea..9fbeb75b1 100644
--- a/backend/dataall/modules/worksheets/api/input_types.py
+++ b/backend/dataall/modules/worksheets/api/input_types.py
@@ -75,3 +75,14 @@
gql.Argument(name='measures', type=gql.ArrayType(gql.Ref('WorksheetMeasureInput'))),
],
)
+
+
+WorksheetQueryResultDownloadUrlInput = gql.InputType(
+ name='WorksheetQueryResultDownloadUrlInput',
+ arguments=[
+ gql.Argument(name='athenaQueryId', type=gql.NonNullableType(gql.String)),
+ gql.Argument(name='fileFormat', type=gql.NonNullableType(gql.String)),
+ gql.Argument(name='environmentUri', type=gql.NonNullableType(gql.String)),
+ gql.Argument(name='worksheetUri', type=gql.NonNullableType(gql.String)),
+ ],
+)
diff --git a/backend/dataall/modules/worksheets/api/mutations.py b/backend/dataall/modules/worksheets/api/mutations.py
index a232fba83..efe4afa6b 100644
--- a/backend/dataall/modules/worksheets/api/mutations.py
+++ b/backend/dataall/modules/worksheets/api/mutations.py
@@ -1,5 +1,10 @@
from dataall.base.api import gql
-from dataall.modules.worksheets.api.resolvers import create_worksheet, delete_worksheet, update_worksheet
+from dataall.modules.worksheets.api.resolvers import (
+ create_worksheet,
+ delete_worksheet,
+ update_worksheet,
+ create_athena_query_result_download_url,
+)
createWorksheet = gql.MutationField(
@@ -27,3 +32,12 @@
],
type=gql.Boolean,
)
+
+createWorksheetQueryResultDownloadUrl = gql.MutationField(
+ name='createWorksheetQueryResultDownloadUrl',
+ resolver=create_athena_query_result_download_url,
+ args=[
+ gql.Argument(name='input', type=gql.Ref('WorksheetQueryResultDownloadUrlInput')),
+ ],
+ type=gql.Ref('WorksheetQueryResult'),
+)
diff --git a/backend/dataall/modules/worksheets/api/resolvers.py b/backend/dataall/modules/worksheets/api/resolvers.py
index 450667217..07bf5ea80 100644
--- a/backend/dataall/modules/worksheets/api/resolvers.py
+++ b/backend/dataall/modules/worksheets/api/resolvers.py
@@ -1,9 +1,10 @@
from dataall.base.db import exceptions
-from dataall.modules.worksheets.api.enums import WorksheetRole
+from dataall.modules.worksheets.services.worksheet_enums import WorksheetRole, WorksheetResultsFormat
from dataall.modules.worksheets.db.worksheet_models import Worksheet
from dataall.modules.worksheets.db.worksheet_repositories import WorksheetRepository
from dataall.modules.worksheets.services.worksheet_service import WorksheetService
from dataall.base.api.context import Context
+from dataall.modules.worksheets.services.worksheet_query_result_service import WorksheetQueryResultService
def create_worksheet(context: Context, source, input: dict = None):
@@ -69,3 +70,28 @@ def run_sql_query(context: Context, source, environmentUri: str = None, workshee
def delete_worksheet(context, source, worksheetUri: str = None):
with context.engine.scoped_session() as session:
return WorksheetService.delete_worksheet(session=session, uri=worksheetUri)
+
+
+def create_athena_query_result_download_url(context: Context, source, input: dict = None):
+    """Validate a download request and return the query result record with its presigned link.
+
+    Raises RequiredParameter when the input or one of its fields is missing, and InvalidInput
+    when fileFormat is not a member of WorksheetResultsFormat.
+    """
+    if not input:
+        raise exceptions.RequiredParameter('data')
+    # worksheetUri is checked too: it used to be read with input[...] and surfaced as a bare KeyError
+    for field in ('environmentUri', 'worksheetUri', 'athenaQueryId', 'fileFormat'):
+        if not input.get(field):
+            raise exceptions.RequiredParameter(field)
+    if not hasattr(WorksheetResultsFormat, input.get('fileFormat').upper()):
+        raise exceptions.InvalidInput(
+            'fileFormat',
+            input.get('fileFormat'),
+            ', '.join(result_format.value for result_format in WorksheetResultsFormat),
+        )
+
+    with context.engine.scoped_session() as session:
+        return WorksheetQueryResultService.download_sql_query_result(
+            session=session, uri=input['worksheetUri'], env_uri=input['environmentUri'], data=input
+        )
diff --git a/backend/dataall/modules/worksheets/api/types.py b/backend/dataall/modules/worksheets/api/types.py
index e9187353f..f55741f8b 100644
--- a/backend/dataall/modules/worksheets/api/types.py
+++ b/backend/dataall/modules/worksheets/api/types.py
@@ -86,15 +86,16 @@
name='WorksheetQueryResult',
fields=[
gql.Field(name='worksheetQueryResultUri', type=gql.ID),
- gql.Field(name='queryType', type=gql.NonNullableType(gql.String)),
- gql.Field(name='sqlBody', type=gql.NonNullableType(gql.String)),
+ gql.Field(name='sqlBody', type=gql.String),
gql.Field(name='AthenaQueryId', type=gql.NonNullableType(gql.String)),
gql.Field(name='region', type=gql.NonNullableType(gql.String)),
gql.Field(name='AwsAccountId', type=gql.NonNullableType(gql.String)),
- gql.Field(name='AthenaOutputBucketName', type=gql.NonNullableType(gql.String)),
- gql.Field(name='AthenaOutputKey', type=gql.NonNullableType(gql.String)),
- gql.Field(name='timeElapsedInSecond', type=gql.NonNullableType(gql.Integer)),
+ gql.Field(name='elapsedTimeInMs', type=gql.Integer),
gql.Field(name='created', type=gql.NonNullableType(gql.String)),
+ gql.Field(name='downloadLink', type=gql.String),
+ gql.Field(name='outputLocation', type=gql.String),
+ gql.Field(name='expiresIn', type=gql.AWSDateTime),
+ gql.Field(name='fileFormat', type=gql.String),
],
)
diff --git a/backend/dataall/modules/worksheets/aws/s3_client.py b/backend/dataall/modules/worksheets/aws/s3_client.py
new file mode 100644
index 000000000..9d07132e1
--- /dev/null
+++ b/backend/dataall/modules/worksheets/aws/s3_client.py
@@ -0,0 +1,69 @@
+import logging
+from typing import TYPE_CHECKING
+
+from botocore.exceptions import ClientError
+
+from dataall.base.aws.sts import SessionHelper
+from dataall.base.db.exceptions import AWSResourceNotFound
+
+if TYPE_CHECKING:
+ from dataall.core.environment.db.environment_models import Environment
+
+ try:
+ from mypy_boto3_s3 import S3Client as S3ClientType
+ except ImportError:
+ S3ClientType = None
+
+log = logging.getLogger(__name__)
+
+
+class S3Client:
+    """S3 operations through a client built from SessionHelper.remote_session for the environment."""
+
+    def __init__(self, env: 'Environment'):
+        self._client = SessionHelper.remote_session(env.AwsAccountId, env.region).client('s3', region_name=env.region)
+        self._env = env
+
+    @property
+    def client(self) -> 'S3ClientType':
+        return self._client
+
+    def get_presigned_url(self, bucket, key, expire_minutes: int = 15):
+        """Return a presigned get_object URL for bucket/key that expires after expire_minutes minutes."""
+        try:
+            return self.client.generate_presigned_url(
+                'get_object', Params={'Bucket': bucket, 'Key': key}, ExpiresIn=expire_minutes * 60
+            )
+        except ClientError as e:
+            log.error(f'Failed to get presigned URL due to: {e}')
+            raise
+
+    def object_exists(self, bucket, key) -> bool:
+        """Return True if head_object succeeds, False on a 404; raise AWSResourceNotFound otherwise."""
+        try:
+            self.client.head_object(Bucket=bucket, Key=key)
+            return True
+        except ClientError as e:
+            if e.response['Error']['Code'] == '404':
+                log.info(f'Object {key} not found in bucket {bucket}')
+                return False
+            log.error(f'Failed to check object existence due to: {e}')
+            # chain the ClientError so the underlying cause stays attached to the raised exception
+            raise AWSResourceNotFound('s3_object_exists', f'Object {key} not found in bucket {bucket}') from e
+
+    def put_object(self, bucket, key, body):
+        """Upload body to bucket/key; ClientError is logged and re-raised."""
+        try:
+            self.client.put_object(Bucket=bucket, Key=key, Body=body)
+        except ClientError as e:
+            log.error(f'Failed to put object due to: {e}')
+            raise
+
+    def get_object(self, bucket, key) -> str:
+        """Read bucket/key and return its body decoded as UTF-8; ClientError is logged and re-raised."""
+        try:
+            response = self.client.get_object(Bucket=bucket, Key=key)
+            return response['Body'].read().decode('utf-8')
+        except ClientError as e:
+            log.error(f'Failed to get object due to: {e}')
+            raise
diff --git a/backend/dataall/modules/worksheets/db/worksheet_models.py b/backend/dataall/modules/worksheets/db/worksheet_models.py
index 6549cb96c..b7a01a871 100644
--- a/backend/dataall/modules/worksheets/db/worksheet_models.py
+++ b/backend/dataall/modules/worksheets/db/worksheet_models.py
@@ -1,7 +1,7 @@
import datetime
import enum
-from sqlalchemy import Column, DateTime, Integer, Enum, String
+from sqlalchemy import Column, DateTime, Integer, Enum, String, BigInteger
from sqlalchemy.dialects import postgresql
from sqlalchemy.orm import query_expression
@@ -27,15 +27,23 @@ class Worksheet(Resource, Base):
class WorksheetQueryResult(Base):
__tablename__ = 'worksheet_query_result'
+ worksheetQueryResultUri = Column(String, primary_key=True, default=utils.uuid('worksheetQueryResultUri'))
worksheetUri = Column(String, nullable=False)
- AthenaQueryId = Column(String, primary_key=True)
- status = Column(String, nullable=False)
- queryType = Column(Enum(QueryType), nullable=False, default=True)
- sqlBody = Column(String, nullable=False)
+ AthenaQueryId = Column(String, nullable=False)
+ status = Column(String, nullable=True)
+ sqlBody = Column(String, nullable=True)
AwsAccountId = Column(String, nullable=False)
region = Column(String, nullable=False)
OutputLocation = Column(String, nullable=False)
error = Column(String, nullable=True)
ElapsedTimeInMs = Column(Integer, nullable=True)
- DataScannedInBytes = Column(Integer, nullable=True)
+ DataScannedInBytes = Column(BigInteger, nullable=True)
created = Column(DateTime, default=datetime.datetime.now)
+
+ downloadLink = Column(String, nullable=True)
+ expiresIn = Column(DateTime, nullable=True)
+ updated = Column(DateTime, nullable=False, onupdate=datetime.datetime.utcnow, default=datetime.datetime.utcnow)
+ fileFormat = Column(String, nullable=True)
+
+ def is_download_link_expired(self):
+ return self.expiresIn is None or self.expiresIn <= datetime.datetime.utcnow()
diff --git a/backend/dataall/modules/worksheets/db/worksheet_repositories.py b/backend/dataall/modules/worksheets/db/worksheet_repositories.py
index aea51761b..cac450943 100644
--- a/backend/dataall/modules/worksheets/db/worksheet_repositories.py
+++ b/backend/dataall/modules/worksheets/db/worksheet_repositories.py
@@ -53,3 +53,17 @@ def paginated_user_worksheets(session, username, groups, uri, data=None, check_p
page=data.get('page', WorksheetRepository._DEFAULT_PAGE),
page_size=data.get('pageSize', WorksheetRepository._DEFAULT_PAGE_SIZE),
).to_dict()
+
+    @staticmethod
+    def find_query_result_by_format(
+        session, worksheet_uri: str, athena_query_id: str, file_format: str
+    ) -> WorksheetQueryResult:
+        """Return the first stored result for this worksheet, Athena query id and file format."""
+        matching = session.query(WorksheetQueryResult).filter(
+            WorksheetQueryResult.worksheetUri == worksheet_uri,
+            WorksheetQueryResult.AthenaQueryId == athena_query_id,
+            WorksheetQueryResult.fileFormat == file_format,
+        )
+
+        # the outcome of .first() is handed back unchanged; callers test it for truthiness
+        return matching.first()
diff --git a/backend/dataall/modules/worksheets/api/enums.py b/backend/dataall/modules/worksheets/services/worksheet_enums.py
similarity index 65%
rename from backend/dataall/modules/worksheets/api/enums.py
rename to backend/dataall/modules/worksheets/services/worksheet_enums.py
index 3e9549f2a..c973c2399 100644
--- a/backend/dataall/modules/worksheets/api/enums.py
+++ b/backend/dataall/modules/worksheets/services/worksheet_enums.py
@@ -5,3 +5,8 @@ class WorksheetRole(GraphQLEnumMapper):
Creator = '950'
Admin = '900'
NoPermission = '000'
+
+
+class WorksheetResultsFormat(GraphQLEnumMapper):
+ CSV = 'csv'
+ XLSX = 'xlsx'
diff --git a/backend/dataall/modules/worksheets/services/worksheet_permissions.py b/backend/dataall/modules/worksheets/services/worksheet_permissions.py
index 0f494567d..301c2ddca 100644
--- a/backend/dataall/modules/worksheets/services/worksheet_permissions.py
+++ b/backend/dataall/modules/worksheets/services/worksheet_permissions.py
@@ -1,13 +1,12 @@
-from dataall.core.permissions.services.resources_permissions import (
- RESOURCES_ALL,
- RESOURCES_ALL_WITH_DESC,
-)
from dataall.core.permissions.services.environment_permissions import (
ENVIRONMENT_INVITED,
ENVIRONMENT_INVITATION_REQUEST,
ENVIRONMENT_ALL,
)
-
+from dataall.core.permissions.services.resources_permissions import (
+ RESOURCES_ALL,
+ RESOURCES_ALL_WITH_DESC,
+)
from dataall.core.permissions.services.tenant_permissions import TENANT_ALL, TENANT_ALL_WITH_DESC
MANAGE_WORKSHEETS = 'MANAGE_WORKSHEETS'
@@ -22,12 +21,10 @@
UPDATE_WORKSHEET = 'UPDATE_WORKSHEET'
DELETE_WORKSHEET = 'DELETE_WORKSHEET'
RUN_WORKSHEET_QUERY = 'RUN_WORKSHEET_QUERY'
-WORKSHEET_ALL = [
- GET_WORKSHEET,
- UPDATE_WORKSHEET,
- DELETE_WORKSHEET,
- RUN_WORKSHEET_QUERY,
-]
+DOWNLOAD_ATHENA_QUERY_RESULTS = 'DOWNLOAD_ATHENA_QUERY_RESULTS'
+
+
+WORKSHEET_ALL = [GET_WORKSHEET, UPDATE_WORKSHEET, DELETE_WORKSHEET, RUN_WORKSHEET_QUERY, DOWNLOAD_ATHENA_QUERY_RESULTS]
RESOURCES_ALL.extend(WORKSHEET_ALL)
diff --git a/backend/dataall/modules/worksheets/services/worksheet_query_result_service.py b/backend/dataall/modules/worksheets/services/worksheet_query_result_service.py
new file mode 100644
index 000000000..6557f68c2
--- /dev/null
+++ b/backend/dataall/modules/worksheets/services/worksheet_query_result_service.py
@@ -0,0 +1,123 @@
+import csv
+import io
+import os
+from datetime import datetime, timedelta
+from typing import TYPE_CHECKING
+
+from openpyxl import Workbook
+
+from dataall.base.db import exceptions
+from dataall.core.environment.services.environment_service import EnvironmentService
+from dataall.core.permissions.services.resource_policy_service import ResourcePolicyService
+from dataall.core.permissions.services.tenant_policy_service import TenantPolicyService
+from dataall.modules.worksheets.aws.s3_client import S3Client
+from dataall.modules.worksheets.db.worksheet_models import WorksheetQueryResult
+from dataall.modules.worksheets.db.worksheet_repositories import WorksheetRepository
+from dataall.modules.worksheets.services.worksheet_enums import WorksheetResultsFormat
+from dataall.modules.worksheets.services.worksheet_permissions import DOWNLOAD_ATHENA_QUERY_RESULTS, MANAGE_WORKSHEETS
+from dataall.modules.worksheets.services.worksheet_service import WorksheetService
+
+if TYPE_CHECKING:
+ try:
+ from sqlalchemy.orm import Session
+ from openpyxl.worksheet.worksheet import Worksheet
+ except ImportError:
+ print('skipping type checks as stubs are not installed')
+ Session = None
+ Worksheet = None
+
+
+class WorksheetQueryResultService:
+    """Create and refresh presigned download links for worksheet Athena query results."""
+
+    _DEFAULT_ATHENA_QUERIES_PATH = 'athenaqueries'
+    # os.getenv returns a str when the variable is set; cast so minute arithmetic and timedelta work
+    _DEFAULT_QUERY_RESULTS_TIMEOUT = int(os.getenv('QUERY_RESULT_TIMEOUT_MINUTES', 120))
+
+    @staticmethod
+    def _create_query_result(
+        environment_bucket: str, athena_workgroup: str, worksheet_uri: str, region: str, aws_account_id: str, data: dict
+    ) -> WorksheetQueryResult:
+        """Build a WorksheetQueryResult (not yet added to a session) for the workgroup's output prefix."""
+        return WorksheetQueryResult(
+            worksheetUri=worksheet_uri,
+            AthenaQueryId=data.get('athenaQueryId'),
+            fileFormat=data.get('fileFormat'),
+            OutputLocation=f's3://{environment_bucket}/{WorksheetQueryResultService._DEFAULT_ATHENA_QUERIES_PATH}/{athena_workgroup}/',
+            region=region,
+            AwsAccountId=aws_account_id,
+        )
+
+    @staticmethod
+    def _build_s3_file_path(workgroup: str, query_id: str, athena_queries_dir: str = None) -> str:
+        """Return the extension-less key '<dir>/<workgroup>/<query_id>'; dir defaults to the class path."""
+        athena_queries_dir = athena_queries_dir or WorksheetQueryResultService._DEFAULT_ATHENA_QUERIES_PATH
+        return f'{athena_queries_dir}/{workgroup}/{query_id}'
+
+    @staticmethod
+    def _convert_csv_to_xlsx(csv_data) -> io.BytesIO:
+        """Write the rows of CSV text into the active sheet of a new workbook; return a rewound buffer."""
+        wb = Workbook()
+        ws: 'Worksheet' = wb.active
+        # StringIO(newline='') keeps line breaks inside quoted fields; splitlines() dropped them
+        for row in csv.reader(io.StringIO(csv_data, newline='')):
+            ws.append(row)
+        excel_buffer = io.BytesIO()
+        wb.save(excel_buffer)
+        excel_buffer.seek(0)
+        return excel_buffer
+
+    @staticmethod
+    @TenantPolicyService.has_tenant_permission(MANAGE_WORKSHEETS)
+    @ResourcePolicyService.has_resource_permission(DOWNLOAD_ATHENA_QUERY_RESULTS)
+    def download_sql_query_result(session: 'Session', uri: str, env_uri: str, data: dict = None):
+        """Return the query result record for this worksheet/query/format with a usable download link.
+
+        XLSX requests first read the Athena CSV output, convert it and upload the workbook beside it.
+        A new presigned URL is generated only when the record has none or the stored one has expired.
+        Raises AWSResourceNotAvailable if reading, converting or uploading fails, and
+        AWSResourceNotFound if the requested file is missing from the environment bucket.
+        """
+        environment = EnvironmentService.get_environment_by_uri(session, env_uri)
+        worksheet = WorksheetService.get_worksheet_by_uri(session, uri)
+        env_group = EnvironmentService.get_environment_group(
+            session, worksheet.SamlAdminGroupName, environment.environmentUri
+        )
+        bucket = environment.EnvironmentDefaultBucketName
+        workgroup = env_group.environmentAthenaWorkGroup
+        sql_query_result = WorksheetRepository.find_query_result_by_format(
+            session, data.get('worksheetUri'), data.get('athenaQueryId'), data.get('fileFormat')
+        )
+        s3_client = S3Client(environment)
+        if not sql_query_result:
+            sql_query_result = WorksheetQueryResultService._create_query_result(
+                bucket, workgroup, worksheet.worksheetUri, environment.region, environment.AwsAccountId, data
+            )
+        output_file_s3_path = WorksheetQueryResultService._build_s3_file_path(workgroup, data.get('athenaQueryId'))
+        if sql_query_result.fileFormat == WorksheetResultsFormat.XLSX.value:
+            try:
+                csv_data = s3_client.get_object(
+                    bucket=bucket, key=f'{output_file_s3_path}.{WorksheetResultsFormat.CSV.value}'
+                )
+                excel_buffer = WorksheetQueryResultService._convert_csv_to_xlsx(csv_data)
+                s3_client.put_object(
+                    bucket=bucket, key=f'{output_file_s3_path}.{WorksheetResultsFormat.XLSX.value}', body=excel_buffer
+                )
+            except Exception as e:
+                raise exceptions.AWSResourceNotAvailable('CONVERT_CSV_TO_EXCEL', f'Failed to convert csv to xlsx: {e}')
+
+        output_key = f'{output_file_s3_path}.{sql_query_result.fileFormat}'
+        # object_exists returns False on a 404; ignoring it handed out links to missing objects
+        if not s3_client.object_exists(bucket=bucket, key=output_key):
+            raise exceptions.AWSResourceNotFound('DOWNLOAD_QUERY_RESULT', f'Object {output_key} not found')
+        if sql_query_result.is_download_link_expired():
+            timeout_minutes = WorksheetQueryResultService._DEFAULT_QUERY_RESULTS_TIMEOUT
+            sql_query_result.downloadLink = s3_client.get_presigned_url(
+                bucket=bucket, key=output_key, expire_minutes=timeout_minutes
+            )
+            sql_query_result.expiresIn = datetime.utcnow() + timedelta(minutes=timeout_minutes)
+
+        session.add(sql_query_result)
+        session.commit()
+
+        return sql_query_result
diff --git a/backend/migrations/versions/427db8f31999_backfill_MF_resource_permissions.py b/backend/migrations/versions/427db8f31999_backfill_MF_resource_permissions.py
index 5209963e8..1932a7fd5 100644
--- a/backend/migrations/versions/427db8f31999_backfill_MF_resource_permissions.py
+++ b/backend/migrations/versions/427db8f31999_backfill_MF_resource_permissions.py
@@ -24,7 +24,7 @@
# revision identifiers, used by Alembic.
revision = '427db8f31999'
-down_revision = 'f87aecc36d39'
+down_revision = 'd1d6da1b2d67'
branch_labels = None
depends_on = None
diff --git a/backend/migrations/versions/d1d6da1b2d67_add_columns_worksheet_query_result_model.py b/backend/migrations/versions/d1d6da1b2d67_add_columns_worksheet_query_result_model.py
new file mode 100644
index 000000000..d4cf99819
--- /dev/null
+++ b/backend/migrations/versions/d1d6da1b2d67_add_columns_worksheet_query_result_model.py
@@ -0,0 +1,46 @@
+"""add_columns_worksheet_query_result_model
+
+Revision ID: d1d6da1b2d67
+Revises: f87aecc36d39
+Create Date: 2024-09-10 14:34:31.492186
+
+"""
+
+from alembic import op
+import sqlalchemy as sa
+
+
+# revision identifiers, used by Alembic.
+revision = 'd1d6da1b2d67'
+down_revision = 'f87aecc36d39'
+branch_labels = None
+depends_on = None
+
+
+def upgrade():
+    op.alter_column('worksheet_query_result', 'status', nullable=True)
+    op.alter_column('worksheet_query_result', 'sqlBody', nullable=True)
+    op.alter_column('worksheet_query_result', 'DataScannedInBytes', type_=sa.BigInteger(), nullable=True)
+    op.add_column('worksheet_query_result', sa.Column('downloadLink', sa.String(), nullable=True))
+    op.add_column('worksheet_query_result', sa.Column('expiresIn', sa.DateTime(), nullable=True))
+    op.add_column('worksheet_query_result', sa.Column('updated', sa.DateTime(), nullable=False, server_default=sa.func.now()))
+    op.add_column('worksheet_query_result', sa.Column('fileFormat', sa.String(), nullable=True))
+    op.add_column('worksheet_query_result', sa.Column('worksheetQueryResultUri', sa.String(), nullable=True))
+    op.execute('UPDATE worksheet_query_result SET "worksheetQueryResultUri" = "AthenaQueryId"')  # backfill old rows
+    op.drop_constraint('AthenaQueryId', 'worksheet_query_result', type_='primary')
+    op.create_primary_key('worksheet_query_result_pkey', 'worksheet_query_result', ['worksheetQueryResultUri'])
+    op.drop_column('worksheet_query_result', 'queryType')
+
+
+def downgrade():
+ op.add_column('worksheet_query_result', sa.Column('queryType', sa.VARCHAR(), autoincrement=False, nullable=True))
+ op.drop_constraint('worksheet_query_result_pkey', 'worksheet_query_result', type_='primary')
+ op.create_primary_key('AthenaQueryId', 'worksheet_query_result', ['AthenaQueryId'])
+ op.drop_column('worksheet_query_result', 'worksheetQueryResultUri')
+ op.drop_column('worksheet_query_result', 'fileFormat')
+ op.drop_column('worksheet_query_result', 'updated')
+ op.drop_column('worksheet_query_result', 'expiresIn')
+ op.drop_column('worksheet_query_result', 'downloadLink')
+ op.alter_column('worksheet_query_result', 'DataScannedInBytes', type_=sa.Integer(), nullable=True)
+ op.alter_column('worksheet_query_result', 'sqlBody', nullable=False)
+ op.alter_column('worksheet_query_result', 'status', nullable=False)
diff --git a/backend/requirements.txt b/backend/requirements.txt
index 05cb6619c..4e18f9db7 100644
--- a/backend/requirements.txt
+++ b/backend/requirements.txt
@@ -1,7 +1,10 @@
ariadne==0.17.0
aws-xray-sdk==2.4.3
-boto3==1.35.26
-fastapi == 0.115.0
+boto3==1.34.119
+botocore==1.34.119
+fastapi == 0.109.2
+Flask==3.0.3
+flask-cors==5.0.0
nanoid==2.0.0
opensearch-py==1.0.0
PyAthena==2.3.0
@@ -11,5 +14,7 @@ PyYAML==6.0
requests==2.32.2
requests_aws4auth==1.1.1
sqlalchemy==1.3.24
+starlette==0.36.3
alembic==1.13.1
+openpyxl==3.1.5
retrying==1.3.4
diff --git a/frontend/src/modules/Worksheets/components/WorksheetResult.js b/frontend/src/modules/Worksheets/components/WorksheetResult.js
index e8e74e557..62fa6dea4 100644
--- a/frontend/src/modules/Worksheets/components/WorksheetResult.js
+++ b/frontend/src/modules/Worksheets/components/WorksheetResult.js
@@ -12,12 +12,69 @@ import {
TableRow
} from '@mui/material';
import PropTypes from 'prop-types';
-import React from 'react';
+import React, { useState, useCallback } from 'react';
import { FaBars } from 'react-icons/fa';
import * as ReactIf from 'react-if';
import { Scrollbar } from 'design';
-export const WorksheetResult = ({ results, loading }) => {
+import Stack from '@mui/material/Stack';
+import Radio from '@mui/material/Radio';
+import RadioGroup from '@mui/material/RadioGroup';
+import FormControlLabel from '@mui/material/FormControlLabel';
+import { Download } from '@mui/icons-material';
+import { LoadingButton } from '@mui/lab';
+import { useClient } from 'services';
+import { createWorksheetQueryResultDownloadUrl } from '../services';
+import { SET_ERROR, useDispatch } from 'globalErrors';
+
+export const WorksheetResult = ({
+ results,
+ loading,
+ currentEnv,
+ athenaQueryId,
+ worksheetUri
+}) => {
+ const [runningDownloadQuery, setRunningDownloadQuery] = useState(false);
+ const [fileType, setFileType] = useState('csv');
+ const client = useClient();
+ const dispatch = useDispatch();
+
+ const handleChange = (event) => {
+ setFileType(event.target.value);
+ };
+
+  const downloadFile = useCallback(async () => {
+    try {
+      setRunningDownloadQuery(true);
+      const response = await client.query(
+        createWorksheetQueryResultDownloadUrl({
+          fileFormat: fileType,
+          environmentUri: currentEnv.environmentUri,
+          athenaQueryId: athenaQueryId,
+          worksheetUri: worksheetUri
+        })
+      );
+
+      if (!response.errors) {
+        const link = document.createElement('a');
+        link.href =
+          response.data.createWorksheetQueryResultDownloadUrl.downloadLink;
+        // Append to html link element page
+        document.body.appendChild(link);
+        // Start download
+        link.click();
+        // Clean up and remove the link
+        link.parentNode.removeChild(link);
+      } else {
+        dispatch({ type: SET_ERROR, error: response.errors[0].message });
+      }
+    } catch (e) {
+      dispatch({ type: SET_ERROR, error: e.message });
+    } finally {
+      setRunningDownloadQuery(false);
+    }
+  }, [client, dispatch, currentEnv, athenaQueryId, fileType, worksheetUri]);
+
if (loading) {
return ;
}
@@ -34,6 +91,41 @@ export const WorksheetResult = ({ results, loading }) => {
Query Results
}
+ action={
+ <>
+
+
+ }
+ label="CSV"
+ />
+ }
+ label="XLSX"
+ />
+
+
+ }
+ sx={{ m: 1 }}
+ variant="contained"
+ >
+ Download
+
+
+ >
+ }
/>
diff --git a/frontend/src/modules/Worksheets/services/createWorksheetQueryResultDownloadUrl.js b/frontend/src/modules/Worksheets/services/createWorksheetQueryResultDownloadUrl.js
new file mode 100644
index 000000000..305ab5599
--- /dev/null
+++ b/frontend/src/modules/Worksheets/services/createWorksheetQueryResultDownloadUrl.js
@@ -0,0 +1,20 @@
+import { gql } from 'apollo-boost';
+
+export const createWorksheetQueryResultDownloadUrl = (input) => ({
+ variables: {
+ input
+ },
+ query: gql`
+ mutation CreateWorksheetQueryResultDownloadUrl(
+ $input: WorksheetQueryResultDownloadUrlInput!
+ ) {
+ createWorksheetQueryResultDownloadUrl(input: $input) {
+ downloadLink
+ AthenaQueryId
+ expiresIn
+ fileFormat
+ outputLocation
+ }
+ }
+ `
+});
diff --git a/frontend/src/modules/Worksheets/services/index.js b/frontend/src/modules/Worksheets/services/index.js
index b10e7d361..65637d6ae 100644
--- a/frontend/src/modules/Worksheets/services/index.js
+++ b/frontend/src/modules/Worksheets/services/index.js
@@ -6,3 +6,4 @@ export * from './listWorksheets';
export * from './runAthenaSqlQuery';
export * from './updateWorksheet';
export * from './listSharedDatasetTableColumns';
+export * from './createWorksheetQueryResultDownloadUrl';
diff --git a/frontend/src/modules/Worksheets/services/runAthenaSqlQuery.js b/frontend/src/modules/Worksheets/services/runAthenaSqlQuery.js
index 8d5666cf4..80bda11bf 100644
--- a/frontend/src/modules/Worksheets/services/runAthenaSqlQuery.js
+++ b/frontend/src/modules/Worksheets/services/runAthenaSqlQuery.js
@@ -21,6 +21,7 @@ export const runAthenaSqlQuery = ({
worksheetUri: $worksheetUri
sqlQuery: $sqlQuery
) {
+ AthenaQueryId
rows {
cells {
columnName
diff --git a/frontend/src/modules/Worksheets/views/WorksheetView.js b/frontend/src/modules/Worksheets/views/WorksheetView.js
index 545c3f190..419e51d77 100644
--- a/frontend/src/modules/Worksheets/views/WorksheetView.js
+++ b/frontend/src/modules/Worksheets/views/WorksheetView.js
@@ -76,6 +76,8 @@ const WorksheetView = () => {
const [runningQuery, setRunningQuery] = useState(false);
const [isEditWorksheetOpen, setIsEditWorksheetOpen] = useState(null);
const [isDeleteWorksheetOpen, setIsDeleteWorksheetOpen] = useState(null);
+ const [athenaQueryId, setAthenaQueryId] = useState();
+
const handleEditWorksheetModalOpen = () => {
setIsEditWorksheetOpen(true);
};
@@ -291,6 +293,7 @@ const WorksheetView = () => {
);
if (!response.errors) {
const athenaResults = response.data.runAthenaSqlQuery;
+ setAthenaQueryId(response.data.runAthenaSqlQuery.AthenaQueryId);
setResults({
rows: athenaResults.rows.map((c, index) => ({ ...c, id: index })),
columns: athenaResults.columns.map((c, index) => ({
@@ -636,7 +639,13 @@ const WorksheetView = () => {
-
+
diff --git a/tests/modules/worksheets/test_worksheet.py b/tests/modules/worksheets/test_worksheet.py
index 0f3b10e41..0a7bd43dc 100644
--- a/tests/modules/worksheets/test_worksheet.py
+++ b/tests/modules/worksheets/test_worksheet.py
@@ -1,4 +1,5 @@
import pytest
+from datetime import datetime
from dataall.modules.worksheets.api.resolvers import WorksheetRole
@@ -30,6 +31,19 @@ def worksheet(client, tenant, group):
return response.data.createWorksheet
+@pytest.fixture(scope='module', autouse=True)
+def mock_s3_client(module_mocker):
+ s3_client = module_mocker.patch(
+ 'dataall.modules.worksheets.services.worksheet_query_result_service.S3Client', autospec=True
+ )
+
+ s3_client.return_value.object_exists.return_value = True
+ s3_client.return_value.put_object.return_value = None
+ s3_client.return_value.get_object.return_value = '123,123,123'
+ s3_client.return_value.get_presigned_url.return_value = 'https://s3.amazonaws.com/file/123.csv'
+ yield s3_client
+
+
def test_create_worksheet(client, worksheet):
assert worksheet.label == 'my worksheet'
assert worksheet.owner == 'alice'
@@ -145,3 +159,106 @@ def test_update_worksheet(client, worksheet, group):
)
assert response.data.updateWorksheet.label == 'change label'
+
+
+def test_create_query_download_url(client, worksheet, env_fixture, group):
+ response = client.query(
+ """
+ mutation CreateWorksheetQueryResultDownloadUrl($input: WorksheetQueryResultDownloadUrlInput){
+ createWorksheetQueryResultDownloadUrl(input: $input){
+ sqlBody
+ AthenaQueryId
+ region
+ AwsAccountId
+ elapsedTimeInMs
+ created
+ downloadLink
+ outputLocation
+ expiresIn
+ fileFormat
+ }
+ }
+ """,
+ input={
+ 'worksheetUri': worksheet.worksheetUri,
+ 'athenaQueryId': '123',
+ 'fileFormat': 'csv',
+ 'environmentUri': env_fixture.environmentUri,
+ },
+ username='alice',
+ groups=[group.name],
+ )
+
+ expires_in = datetime.strptime(
+ response.data.createWorksheetQueryResultDownloadUrl.expiresIn, '%Y-%m-%d %H:%M:%S.%f'
+ )
+ assert response.data.createWorksheetQueryResultDownloadUrl.downloadLink is not None
+ assert response.data.createWorksheetQueryResultDownloadUrl.fileFormat == 'csv'
+ assert expires_in > datetime.utcnow()
+
+
+def test_tenant_unauthorized__create_query_download_url(client, worksheet, env_fixture, tenant):
+ response = client.query(
+ """
+ mutation CreateWorksheetQueryResultDownloadUrl($input: WorksheetQueryResultDownloadUrlInput){
+ createWorksheetQueryResultDownloadUrl(input: $input){
+ sqlBody
+ AthenaQueryId
+ region
+ AwsAccountId
+ elapsedTimeInMs
+ created
+ downloadLink
+ outputLocation
+ expiresIn
+ fileFormat
+ }
+ }
+ """,
+ input={
+ 'worksheetUri': worksheet.worksheetUri,
+ 'athenaQueryId': '123',
+ 'fileFormat': 'csv',
+ 'environmentUri': env_fixture.environmentUri,
+ },
+ )
+
+ assert response.errors is not None
+ assert len(response.errors) > 0
+ assert f'is not authorized to perform: MANAGE_WORKSHEETS on {tenant.name}' in response.errors[0].message
+
+
+def test_resource_unauthorized__create_query_download_url(client, worksheet, env_fixture, group2):
+ response = client.query(
+ """
+ mutation CreateWorksheetQueryResultDownloadUrl($input: WorksheetQueryResultDownloadUrlInput){
+ createWorksheetQueryResultDownloadUrl(input: $input){
+ sqlBody
+ AthenaQueryId
+ region
+ AwsAccountId
+ elapsedTimeInMs
+ created
+ downloadLink
+ outputLocation
+ expiresIn
+ fileFormat
+ }
+ }
+ """,
+ input={
+ 'worksheetUri': worksheet.worksheetUri,
+ 'athenaQueryId': '123',
+ 'fileFormat': 'csv',
+ 'environmentUri': env_fixture.environmentUri,
+ },
+ username='bob',
+ groups=[group2.name],
+ )
+
+ assert response.errors is not None
+ assert len(response.errors) > 0
+ assert (
+ f'is not authorized to perform: DOWNLOAD_ATHENA_QUERY_RESULTS on resource: {worksheet.worksheetUri}'
+ in response.errors[0].message
+ )