diff --git a/lib/aws/AwsSigv4Signer.js b/lib/aws/AwsSigv4Signer.js index 7233611d5..f2086aee5 100644 --- a/lib/aws/AwsSigv4Signer.js +++ b/lib/aws/AwsSigv4Signer.js @@ -15,19 +15,66 @@ const Transport = require('../Transport') const aws4 = require('aws4') const AwsSigv4SignerError = require('./errors') -function AwsSigv4Signer (opts) { +const getAwsSDKCredentialsProvider = async () => { + // First try V3 + try { + const awsV3 = await import('@aws-sdk/credential-provider-node') + if (typeof awsV3.defaultProvider === 'function') { + return awsV3.defaultProvider() + } + } catch (err) { + // Ignore + } + try { + const awsV2 = await import('aws-sdk') + if (awsV2.default && typeof awsV2.default.config.getCredentials === 'function') { + return () => + new Promise((resolve, reject) => { + awsV2.default.config.getCredentials((err, credentials) => { + if (err) { + reject(err) + } else { + resolve(credentials) + } + }) + }) + } + } catch (err) { + // Ignore + } + + throw new AwsSigv4SignerError( + 'Unable to find a valid AWS SDK, please provide a valid getCredentials function to AwsSigv4Signer options.' + ) +} + +const awsDefaultCredentialsProvider = () => + new Promise((resolve, reject) => { + getAwsSDKCredentialsProvider() + .then((provider) => { + provider().then(resolve).catch(reject) + }) + .catch((err) => { + reject(err) + }) + }) + +function AwsSigv4Signer (opts = {}) { const credentialsState = { credentials: null } - if (opts && (!opts.region || opts.region === null || opts.region === '')) { + if (!opts.region) { throw new AwsSigv4SignerError('Region cannot be empty') } - if (opts && typeof opts.getCredentials !== 'function') { - throw new AwsSigv4SignerError('getCredentials function is required') + if (!opts.service) { + opts.service = 'es' + } + if (typeof opts.getCredentials !== 'function') { + opts.getCredentials = awsDefaultCredentialsProvider } function buildSignedRequestObject (request = {}) { - request.service = 'es' + request.service = opts.service request.region = opts.region request.headers = request.headers || {} request.headers.host = request.hostname diff --git a/test/unit/lib/aws/awssigv4signer.test.js b/test/unit/lib/aws/awssigv4signer.test.js index 9568450be..991b1835f 100644 --- a/test/unit/lib/aws/awssigv4signer.test.js +++ b/test/unit/lib/aws/awssigv4signer.test.js @@ -17,7 +17,7 @@ const { Connection } = require('../../../../index') const { Client, buildServer } = require('../../../utils') test('Sign with SigV4', (t) => { - t.plan(2) + t.plan(3) const mockCreds = { accessKeyId: uuidv4(), @@ -51,16 +51,17 @@ test('Sign with SigV4', (t) => { const signedRequest = auth.buildSignedRequestObject(request) t.hasProp(signedRequest.headers, 'X-Amz-Date') t.hasProp(signedRequest.headers, 'Authorization') + t.same(signedRequest.service, 'es') }) test('Sign with SigV4 failure (with empty region)', (t) => { - t.plan(2) - const mockCreds = { accessKeyId: uuidv4(), secretAccessKey: uuidv4() } + const mockRegions = [{ region: undefined }, { region: null }, { region: '' }, {}] + const AwsSigv4SignerOptions = { getCredentials: () => new Promise((resolve) => { @@ -68,13 +69,52 @@ test('Sign with SigV4 failure (with empty region)', (t) => { }) } - try { - AwsSigv4Signer(AwsSigv4SignerOptions) - t.fail('Should fail') - } catch (err) { - t.ok(err instanceof AwsSigv4SignerError) - t.same(err.message, 'Region cannot be empty') + mockRegions.forEach((region) => { + try { + AwsSigv4Signer(Object.assign({}, AwsSigv4SignerOptions, region)) + t.fail('Should fail') + } catch (err) { + t.ok(err instanceof AwsSigv4SignerError) + t.same(err.message, 'Region cannot be empty') + } + }) + + t.end() +}) + +test('Sign with SigV4 and provided service', (t) => { + t.plan(1) + + const mockCreds = { + accessKeyId: uuidv4(), + secretAccessKey: uuidv4() } + + const mockRegion = 'us-west-2' + const mockService = 'foo' + + const AwsSigv4SignerOptions = { + getCredentials: () => + new Promise((resolve) => { + setTimeout(() => resolve(mockCreds), 100) + }), + region: mockRegion, + service: mockService + } + + const auth = AwsSigv4Signer(AwsSigv4SignerOptions) + + const connection = new Connection({ + url: new URL('https://localhost:9200') + }) + + const request = connection.buildRequestObject({ + path: '/hello', + method: 'GET' + }) + + const signedRequest = auth.buildSignedRequestObject(request) + t.same(signedRequest.service, mockService) }) test('Sign with SigV4 failure (without getCredentials function)', (t) => {