Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

only create one session at a time per device #857

Merged
merged 2 commits into from
Mar 13, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 59 additions & 1 deletion spec/unit/crypto/algorithms/olm.spec.js
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
Copyright 2018 New Vector Ltd
Copyright 2018,2019 New Vector Ltd

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
Expand All @@ -23,6 +23,8 @@ import MockStorageApi from '../../../MockStorageApi';
import testUtils from '../../../test-utils';

import OlmDevice from '../../../../lib/crypto/OlmDevice';
import olmlib from '../../../../lib/crypto/olmlib';
import DeviceInfo from '../../../../lib/crypto/deviceinfo';

function makeOlmDevice() {
const mockStorage = new MockStorageApi();
Expand Down Expand Up @@ -82,5 +84,61 @@ describe("OlmDecryption", function() {
"The olm or proteus is an aquatic salamander in the family Proteidae",
);
});

it("creates only one session at a time", async function() {
// if we call ensureOlmSessionsForDevices multiple times, it should
// only try to create one session at a time, even if the server is
// slow
let count = 0;
const baseApis = {
claimOneTimeKeys: () => {
// simulate a very slow server (.5 seconds to respond)
count++;
return new Promise((resolve, reject) => {
setTimeout(reject, 500);
});
},
};
const devicesByUser = {
"@bob:example.com": [
DeviceInfo.fromStorage({
keys: {
"curve25519:ABCDEFG": "akey",
},
}, "ABCDEFG"),
],
};
function alwaysSucceed(promise) {
// swallow any exception thrown by a promise, so that
// Promise.all doesn't abort
return promise.catch(() => {});
}

// start two tasks that try to ensure that there's an olm session
const promises = Promise.all([
alwaysSucceed(olmlib.ensureOlmSessionsForDevices(
aliceOlmDevice, baseApis, devicesByUser,
)),
alwaysSucceed(olmlib.ensureOlmSessionsForDevices(
aliceOlmDevice, baseApis, devicesByUser,
)),
]);

await new Promise((resolve) => {
setTimeout(resolve, 200);
});

// after .2s, both tasks should have started, but one should be
// waiting on the other before trying to create a session, so
// claimOneTimeKeys should have only been called once
expect(count).toBe(1);

await promises;

// after waiting for both tasks to complete, the first task should
// have failed, so the second task should have tried to create a
// new session and will have called claimOneTimeKeys
expect(count).toBe(2);
});
});
});
41 changes: 37 additions & 4 deletions src/crypto/OlmDevice.js
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
/*
Copyright 2016 OpenMarket Ltd
Copyright 2017 New Vector Ltd
Copyright 2017, 2019 New Vector Ltd

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -102,6 +102,10 @@ function OlmDevice(sessionStore, cryptoStore) {
// Keys are strings of form "<senderKey>|<session_id>|<message_index>"
// Values are objects of the form "{id: <event id>, timestamp: <ts>}"
this._inboundGroupSessionMessageIndexes = {};

// Keep track of sessions that we're starting, so that we don't start
// multiple sessions for the same device at the same time.
this._sessionsInProgress = {};
}

/**
Expand Down Expand Up @@ -553,6 +557,15 @@ OlmDevice.prototype.createInboundSession = async function(
* @return {Promise<string[]>} a list of known session ids for the device
*/
OlmDevice.prototype.getSessionIdsForDevice = async function(theirDeviceIdentityKey) {
if (this._sessionsInProgress[theirDeviceIdentityKey]) {
console.log("waiting for session to be created");
try {
await this._sessionsInProgress[theirDeviceIdentityKey];
} catch (e) {
// if the session failed to be created, just fall through and
// return an empty result
}
}
let sessionIds;
await this._cryptoStore.doTxn(
'readonly', [IndexedDBCryptoStore.STORE_SESSIONS],
Expand All @@ -573,10 +586,18 @@ OlmDevice.prototype.getSessionIdsForDevice = async function(theirDeviceIdentityK
*
* @param {string} theirDeviceIdentityKey Curve25519 identity key for the
* remote device
* @param {boolean} nowait Don't wait for an in-progress session to complete.
* This should only be set to true of the calling function is the function
* that marked the session as being in-progress.
* @return {Promise<?string>} session id, or null if no established session
*/
OlmDevice.prototype.getSessionIdForDevice = async function(theirDeviceIdentityKey) {
const sessionInfos = await this.getSessionInfoForDevice(theirDeviceIdentityKey);
OlmDevice.prototype.getSessionIdForDevice = async function(
theirDeviceIdentityKey, nowait,
) {
const sessionInfos = await this.getSessionInfoForDevice(
theirDeviceIdentityKey, nowait,
);

if (sessionInfos.length === 0) {
return null;
}
Expand Down Expand Up @@ -611,9 +632,21 @@ OlmDevice.prototype.getSessionIdForDevice = async function(theirDeviceIdentityKe
* message and is therefore past the pre-key stage), and 'sessionId'.
*
* @param {string} deviceIdentityKey Curve25519 identity key for the device
* @param {boolean} nowait Don't wait for an in-progress session to complete.
* This should only be set to true of the calling function is the function
* that marked the session as being in-progress.
* @return {Array.<{sessionId: string, hasReceivedMessage: Boolean}>}
*/
OlmDevice.prototype.getSessionInfoForDevice = async function(deviceIdentityKey) {
OlmDevice.prototype.getSessionInfoForDevice = async function(deviceIdentityKey, nowait) {
if (this._sessionsInProgress[deviceIdentityKey] && !nowait) {
logger.log("waiting for session to be created");
try {
await this._sessionsInProgress[deviceIdentityKey];
} catch (e) {
// if the session failed to be created, then just fall through and
// return an empty result
}
}
const info = [];

await this._cryptoStore.doTxn(
Expand Down
67 changes: 53 additions & 14 deletions src/crypto/olmlib.js
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
/*
Copyright 2016 OpenMarket Ltd
Copyright 2019 New Vector Ltd

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -137,6 +138,7 @@ module.exports.ensureOlmSessionsForDevices = async function(
// [userId, deviceId], ...
];
const result = {};
const resolveSession = {};

for (const userId in devicesByUser) {
if (!devicesByUser.hasOwnProperty(userId)) {
Expand All @@ -148,7 +150,36 @@ module.exports.ensureOlmSessionsForDevices = async function(
const deviceInfo = devices[j];
const deviceId = deviceInfo.deviceId;
const key = deviceInfo.getIdentityKey();
const sessionId = await olmDevice.getSessionIdForDevice(key);
if (!olmDevice._sessionsInProgress[key]) {
// pre-emptively mark the session as in-progress to avoid race
// conditions. If we find that we already have a session, then
// we'll resolve
olmDevice._sessionsInProgress[key] = new Promise(
(resolve, reject) => {
resolveSession[key] = {
resolve: (...args) => {
delete olmDevice._sessionsInProgress[key];
resolve(...args);
},
reject: (...args) => {
delete olmDevice._sessionsInProgress[key];
reject(...args);
},
};
},
);
}
const sessionId = await olmDevice.getSessionIdForDevice(
key, resolveSession[key],
);
if (sessionId !== null && resolveSession[key]) {
// we found a session, but we had marked the session as
// in-progress, so unmark it and unblock anything that was
// waiting
delete olmDevice._sessionsInProgress[key];
resolveSession[key].resolve();
delete resolveSession[key];
}
if (sessionId === null || force) {
devicesWithoutSession.push([userId, deviceId]);
}
Expand All @@ -163,16 +194,19 @@ module.exports.ensureOlmSessionsForDevices = async function(
return result;
}

// TODO: this has a race condition - if we try to send another message
// while we are claiming a key, we will end up claiming two and setting up
// two sessions.
//
// That should eventually resolve itself, but it's poor form.

const oneTimeKeyAlgorithm = "signed_curve25519";
const res = await baseApis.claimOneTimeKeys(
devicesWithoutSession, oneTimeKeyAlgorithm,
);
let res;
try {
res = await baseApis.claimOneTimeKeys(
devicesWithoutSession, oneTimeKeyAlgorithm,
);
} catch (e) {
for (const resolver of Object.values(resolveSession)) {
resolver.resolve();
}
logger.log("failed to claim one-time keys", e, devicesWithoutSession);
throw e;
}

const otk_res = res.one_time_keys || {};
const promises = [];
Expand All @@ -185,6 +219,7 @@ module.exports.ensureOlmSessionsForDevices = async function(
for (let j = 0; j < devices.length; j++) {
const deviceInfo = devices[j];
const deviceId = deviceInfo.deviceId;
const key = deviceInfo.getIdentityKey();
if (result[userId][deviceId].sessionId && !force) {
// we already have a result for this device
continue;
Expand All @@ -199,18 +234,22 @@ module.exports.ensureOlmSessionsForDevices = async function(
}

if (!oneTimeKey) {
logger.warn(
"No one-time keys (alg=" + oneTimeKeyAlgorithm +
") for device " + userId + ":" + deviceId,
);
const msg = "No one-time keys (alg=" + oneTimeKeyAlgorithm +
") for device " + userId + ":" + deviceId;
logger.warn(msg);
resolveSession[key].resolve();
continue;
}

promises.push(
_verifyKeyAndStartSession(
olmDevice, oneTimeKey, userId, deviceInfo,
).then((sid) => {
resolveSession[key].resolve(sid);
result[userId][deviceId].sessionId = sid;
}, (e) => {
resolveSession[key].resolve();
throw e;
}),
);
}
Expand Down