diff --git a/spec/unit/crypto/algorithms/olm.spec.js b/spec/unit/crypto/algorithms/olm.spec.js index 7b6bc626a4c..493c0eba238 100644 --- a/spec/unit/crypto/algorithms/olm.spec.js +++ b/spec/unit/crypto/algorithms/olm.spec.js @@ -1,5 +1,5 @@ /* -Copyright 2018 New Vector Ltd +Copyright 2018,2019 New Vector Ltd Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -23,6 +23,8 @@ import MockStorageApi from '../../../MockStorageApi'; import testUtils from '../../../test-utils'; import OlmDevice from '../../../../lib/crypto/OlmDevice'; +import olmlib from '../../../../lib/crypto/olmlib'; +import DeviceInfo from '../../../../lib/crypto/deviceinfo'; function makeOlmDevice() { const mockStorage = new MockStorageApi(); @@ -82,5 +84,61 @@ describe("OlmDecryption", function() { "The olm or proteus is an aquatic salamander in the family Proteidae", ); }); + + it("creates only one session at a time", async function() { + // if we call ensureOlmSessionsForDevices multiple times, it should + // only try to create one session at a time, even if the server is + // slow + let count = 0; + const baseApis = { + claimOneTimeKeys: () => { + // simulate a very slow server (.5 seconds to respond) + count++; + return new Promise((resolve, reject) => { + setTimeout(reject, 500); + }); + }, + }; + const devicesByUser = { + "@bob:example.com": [ + DeviceInfo.fromStorage({ + keys: { + "curve25519:ABCDEFG": "akey", + }, + }, "ABCDEFG"), + ], + }; + function alwaysSucceed(promise) { + // swallow any exception thrown by a promise, so that + // Promise.all doesn't abort + return promise.catch(() => {}); + } + + // start two tasks that try to ensure that there's an olm session + const promises = Promise.all([ + alwaysSucceed(olmlib.ensureOlmSessionsForDevices( + aliceOlmDevice, baseApis, devicesByUser, + )), + alwaysSucceed(olmlib.ensureOlmSessionsForDevices( + aliceOlmDevice, baseApis, devicesByUser, + )), + ]); + + await new Promise((resolve) => { + setTimeout(resolve, 200); + }); + + // after .2s, both tasks should have started, but one should be + // waiting on the other before trying to create a session, so + // claimOneTimeKeys should have only been called once + expect(count).toBe(1); + + await promises; + + // after waiting for both tasks to complete, the first task should + // have failed, so the second task should have tried to create a + // new session and will have called claimOneTimeKeys + expect(count).toBe(2); + }); }); }); diff --git a/src/crypto/OlmDevice.js b/src/crypto/OlmDevice.js index 421598ea5b2..77257bf523b 100644 --- a/src/crypto/OlmDevice.js +++ b/src/crypto/OlmDevice.js @@ -1,6 +1,6 @@ /* Copyright 2016 OpenMarket Ltd -Copyright 2017 New Vector Ltd +Copyright 2017, 2019 New Vector Ltd Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -102,6 +102,10 @@ function OlmDevice(sessionStore, cryptoStore) { // Keys are strings of form "||" // Values are objects of the form "{id: , timestamp: }" this._inboundGroupSessionMessageIndexes = {}; + + // Keep track of sessions that we're starting, so that we don't start + // multiple sessions for the same device at the same time. + this._sessionsInProgress = {}; } /** @@ -553,6 +557,15 @@ OlmDevice.prototype.createInboundSession = async function( * @return {Promise} a list of known session ids for the device */ OlmDevice.prototype.getSessionIdsForDevice = async function(theirDeviceIdentityKey) { + if (this._sessionsInProgress[theirDeviceIdentityKey]) { + console.log("waiting for session to be created"); + try { + await this._sessionsInProgress[theirDeviceIdentityKey]; + } catch (e) { + // if the session failed to be created, just fall through and + // return an empty result + } + } let sessionIds; await this._cryptoStore.doTxn( 'readonly', [IndexedDBCryptoStore.STORE_SESSIONS], @@ -573,10 +586,18 @@ OlmDevice.prototype.getSessionIdsForDevice = async function(theirDeviceIdentityK * * @param {string} theirDeviceIdentityKey Curve25519 identity key for the * remote device + * @param {boolean} nowait Don't wait for an in-progress session to complete. + * This should only be set to true of the calling function is the function + * that marked the session as being in-progress. * @return {Promise} session id, or null if no established session */ -OlmDevice.prototype.getSessionIdForDevice = async function(theirDeviceIdentityKey) { - const sessionInfos = await this.getSessionInfoForDevice(theirDeviceIdentityKey); +OlmDevice.prototype.getSessionIdForDevice = async function( + theirDeviceIdentityKey, nowait, +) { + const sessionInfos = await this.getSessionInfoForDevice( + theirDeviceIdentityKey, nowait, + ); + if (sessionInfos.length === 0) { return null; } @@ -611,9 +632,21 @@ OlmDevice.prototype.getSessionIdForDevice = async function(theirDeviceIdentityKe * message and is therefore past the pre-key stage), and 'sessionId'. * * @param {string} deviceIdentityKey Curve25519 identity key for the device + * @param {boolean} nowait Don't wait for an in-progress session to complete. + * This should only be set to true of the calling function is the function + * that marked the session as being in-progress. * @return {Array.<{sessionId: string, hasReceivedMessage: Boolean}>} */ -OlmDevice.prototype.getSessionInfoForDevice = async function(deviceIdentityKey) { +OlmDevice.prototype.getSessionInfoForDevice = async function(deviceIdentityKey, nowait) { + if (this._sessionsInProgress[deviceIdentityKey] && !nowait) { + logger.log("waiting for session to be created"); + try { + await this._sessionsInProgress[deviceIdentityKey]; + } catch (e) { + // if the session failed to be created, then just fall through and + // return an empty result + } + } const info = []; await this._cryptoStore.doTxn( diff --git a/src/crypto/olmlib.js b/src/crypto/olmlib.js index 4ee89bf8d3a..2752c368cff 100644 --- a/src/crypto/olmlib.js +++ b/src/crypto/olmlib.js @@ -1,5 +1,6 @@ /* Copyright 2016 OpenMarket Ltd +Copyright 2019 New Vector Ltd Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -137,6 +138,7 @@ module.exports.ensureOlmSessionsForDevices = async function( // [userId, deviceId], ... ]; const result = {}; + const resolveSession = {}; for (const userId in devicesByUser) { if (!devicesByUser.hasOwnProperty(userId)) { @@ -148,7 +150,36 @@ module.exports.ensureOlmSessionsForDevices = async function( const deviceInfo = devices[j]; const deviceId = deviceInfo.deviceId; const key = deviceInfo.getIdentityKey(); - const sessionId = await olmDevice.getSessionIdForDevice(key); + if (!olmDevice._sessionsInProgress[key]) { + // pre-emptively mark the session as in-progress to avoid race + // conditions. If we find that we already have a session, then + // we'll resolve + olmDevice._sessionsInProgress[key] = new Promise( + (resolve, reject) => { + resolveSession[key] = { + resolve: (...args) => { + delete olmDevice._sessionsInProgress[key]; + resolve(...args); + }, + reject: (...args) => { + delete olmDevice._sessionsInProgress[key]; + reject(...args); + }, + }; + }, + ); + } + const sessionId = await olmDevice.getSessionIdForDevice( + key, resolveSession[key], + ); + if (sessionId !== null && resolveSession[key]) { + // we found a session, but we had marked the session as + // in-progress, so unmark it and unblock anything that was + // waiting + delete olmDevice._sessionsInProgress[key]; + resolveSession[key].resolve(); + delete resolveSession[key]; + } if (sessionId === null || force) { devicesWithoutSession.push([userId, deviceId]); } @@ -163,16 +194,19 @@ module.exports.ensureOlmSessionsForDevices = async function( return result; } - // TODO: this has a race condition - if we try to send another message - // while we are claiming a key, we will end up claiming two and setting up - // two sessions. - // - // That should eventually resolve itself, but it's poor form. - const oneTimeKeyAlgorithm = "signed_curve25519"; - const res = await baseApis.claimOneTimeKeys( - devicesWithoutSession, oneTimeKeyAlgorithm, - ); + let res; + try { + res = await baseApis.claimOneTimeKeys( + devicesWithoutSession, oneTimeKeyAlgorithm, + ); + } catch (e) { + for (const resolver of Object.values(resolveSession)) { + resolver.resolve(); + } + logger.log("failed to claim one-time keys", e, devicesWithoutSession); + throw e; + } const otk_res = res.one_time_keys || {}; const promises = []; @@ -185,6 +219,7 @@ module.exports.ensureOlmSessionsForDevices = async function( for (let j = 0; j < devices.length; j++) { const deviceInfo = devices[j]; const deviceId = deviceInfo.deviceId; + const key = deviceInfo.getIdentityKey(); if (result[userId][deviceId].sessionId && !force) { // we already have a result for this device continue; @@ -199,10 +234,10 @@ module.exports.ensureOlmSessionsForDevices = async function( } if (!oneTimeKey) { - logger.warn( - "No one-time keys (alg=" + oneTimeKeyAlgorithm + - ") for device " + userId + ":" + deviceId, - ); + const msg = "No one-time keys (alg=" + oneTimeKeyAlgorithm + + ") for device " + userId + ":" + deviceId; + logger.warn(msg); + resolveSession[key].resolve(); continue; } @@ -210,7 +245,11 @@ module.exports.ensureOlmSessionsForDevices = async function( _verifyKeyAndStartSession( olmDevice, oneTimeKey, userId, deviceInfo, ).then((sid) => { + resolveSession[key].resolve(sid); result[userId][deviceId].sessionId = sid; + }, (e) => { + resolveSession[key].resolve(); + throw e; }), ); }