diff --git a/CHANGELOG.md b/CHANGELOG.md index fb28276cec2..2a0326fb469 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,9 @@ +Changes in [24.0.0](https://github.com/matrix-org/matrix-js-sdk/releases/tag/v24.0.0) (2023-03-28) +================================================================================================== + +## ๐Ÿ› Bug Fixes + * Changes for matrix-js-sdk v24.0.0 + Changes in [23.5.0](https://github.com/matrix-org/matrix-js-sdk/releases/tag/v23.5.0) (2023-03-15) ================================================================================================== diff --git a/package.json b/package.json index 151458b5a19..15121a78871 100644 --- a/package.json +++ b/package.json @@ -1,6 +1,6 @@ { "name": "matrix-js-sdk", - "version": "23.5.0", + "version": "24.0.0", "description": "Matrix Client-Server SDK for Javascript", "engines": { "node": ">=16.0.0" diff --git a/spec/integ/crypto.spec.ts b/spec/integ/crypto.spec.ts index 83e49cca7f1..030fdafb09f 100644 --- a/spec/integ/crypto.spec.ts +++ b/spec/integ/crypto.spec.ts @@ -543,7 +543,7 @@ describe.each(Object.entries(CRYPTO_BACKENDS))("crypto (%s)", (backend: string, // if we're using the old crypto impl, stub out some methods in the device manager. // TODO: replace this with intercepts of the /keys/query endpoint to make it impl agnostic. if (aliceClient.crypto) { - aliceClient.crypto.deviceList.downloadKeys = () => Promise.resolve({}); + aliceClient.crypto.deviceList.downloadKeys = () => Promise.resolve(new Map()); aliceClient.crypto.deviceList.getUserByIdentityKey = () => "@bob:xyz"; } @@ -603,7 +603,7 @@ describe.each(Object.entries(CRYPTO_BACKENDS))("crypto (%s)", (backend: string, // if we're using the old crypto impl, stub out some methods in the device manager. // TODO: replace this with intercepts of the /keys/query endpoint to make it impl agnostic. if (aliceClient.crypto) { - aliceClient.crypto.deviceList.downloadKeys = () => Promise.resolve({}); + aliceClient.crypto.deviceList.downloadKeys = () => Promise.resolve(new Map()); aliceClient.crypto.deviceList.getUserByIdentityKey = () => "@bob:xyz"; } @@ -671,7 +671,7 @@ describe.each(Object.entries(CRYPTO_BACKENDS))("crypto (%s)", (backend: string, // if we're using the old crypto impl, stub out some methods in the device manager. // TODO: replace this with intercepts of the /keys/query endpoint to make it impl agnostic. if (aliceClient.crypto) { - aliceClient.crypto.deviceList.downloadKeys = () => Promise.resolve({}); + aliceClient.crypto.deviceList.downloadKeys = () => Promise.resolve(new Map()); aliceClient.crypto.deviceList.getUserByIdentityKey = () => "@bob:xyz"; } @@ -1027,8 +1027,8 @@ describe.each(Object.entries(CRYPTO_BACKENDS))("crypto (%s)", (backend: string, throw new Error("sendTextMessage succeeded on an unknown device"); } catch (e) { expect((e as any).name).toEqual("UnknownDeviceError"); - expect(Object.keys((e as any).devices)).toEqual([aliceClient.getUserId()!]); - expect(Object.keys((e as any)?.devices[aliceClient.getUserId()!])).toEqual(["DEVICE_ID"]); + expect([...(e as any).devices.keys()]).toEqual([aliceClient.getUserId()!]); + expect((e as any).devices.get(aliceClient.getUserId()!).has("DEVICE_ID")); } // mark the device as known, and resend. @@ -1099,7 +1099,7 @@ describe.each(Object.entries(CRYPTO_BACKENDS))("crypto (%s)", (backend: string, // if we're using the old crypto impl, stub out some methods in the device manager. // TODO: replace this with intercepts of the /keys/query endpoint to make it impl agnostic. if (aliceClient.crypto) { - aliceClient.crypto.deviceList.downloadKeys = () => Promise.resolve({}); + aliceClient.crypto.deviceList.downloadKeys = () => Promise.resolve(new Map()); aliceClient.crypto.deviceList.getUserByIdentityKey = () => "@bob:xyz"; } @@ -1255,7 +1255,7 @@ describe.each(Object.entries(CRYPTO_BACKENDS))("crypto (%s)", (backend: string, // if we're using the old crypto impl, stub out some methods in the device manager. // TODO: replace this with intercepts of the /keys/query endpoint to make it impl agnostic. if (aliceClient.crypto) { - aliceClient.crypto.deviceList.downloadKeys = () => Promise.resolve({}); + aliceClient.crypto.deviceList.downloadKeys = () => Promise.resolve(new Map()); aliceClient.crypto.deviceList.getUserByIdentityKey = () => "@bob:xyz"; } @@ -1322,7 +1322,7 @@ describe.each(Object.entries(CRYPTO_BACKENDS))("crypto (%s)", (backend: string, // if we're using the old crypto impl, stub out some methods in the device manager. // TODO: replace this with intercepts of the /keys/query endpoint to make it impl agnostic. if (aliceClient.crypto) { - aliceClient.crypto!.deviceList.downloadKeys = () => Promise.resolve({}); + aliceClient.crypto!.deviceList.downloadKeys = () => Promise.resolve(new Map()); aliceClient.crypto!.deviceList.getDeviceByIdentityKey = () => device; aliceClient.crypto!.deviceList.getUserByIdentityKey = () => beccaTestClient.client.getUserId()!; } diff --git a/spec/integ/matrix-client-methods.spec.ts b/spec/integ/matrix-client-methods.spec.ts index 21d224150d6..e8896c518ec 100644 --- a/spec/integ/matrix-client-methods.spec.ts +++ b/spec/integ/matrix-client-methods.spec.ts @@ -603,14 +603,14 @@ describe("MatrixClient", function () { }); const prom = client!.downloadKeys(["boris", "chaz"]).then(function (res) { - assertObjectContains(res.boris.dev1, { + assertObjectContains(res.get("boris")!.get("dev1")!, { verified: 0, // DeviceVerification.UNVERIFIED keys: { "ed25519:dev1": ed25519key }, algorithms: ["1"], unsigned: { abc: "def" }, }); - assertObjectContains(res.chaz.dev2, { + assertObjectContains(res.get("chaz")!.get("dev2")!, { verified: 0, // DeviceVerification.UNVERIFIED keys: { "ed25519:dev2": ed25519key }, algorithms: ["2"], diff --git a/spec/integ/olm-encryption-spec.ts b/spec/integ/olm-encryption-spec.ts index b6ce85492d4..171cb3fa0cc 100644 --- a/spec/integ/olm-encryption-spec.ts +++ b/spec/integ/olm-encryption-spec.ts @@ -472,7 +472,7 @@ describe("MatrixClient crypto", () => { aliTestClient.expectKeyQuery({ device_keys: { [aliUserId]: {} }, failures: {} }); await aliTestClient.start(); await bobTestClient.start(); - bobTestClient.client.crypto!.deviceList.downloadKeys = () => Promise.resolve({}); + bobTestClient.client.crypto!.deviceList.downloadKeys = () => Promise.resolve(new Map()); await firstSync(aliTestClient); await aliEnablesEncryption(); await aliSendsFirstMessage(); @@ -483,7 +483,7 @@ describe("MatrixClient crypto", () => { aliTestClient.expectKeyQuery({ device_keys: { [aliUserId]: {} }, failures: {} }); await aliTestClient.start(); await bobTestClient.start(); - bobTestClient.client.crypto!.deviceList.downloadKeys = () => Promise.resolve({}); + bobTestClient.client.crypto!.deviceList.downloadKeys = () => Promise.resolve(new Map()); await firstSync(aliTestClient); await aliEnablesEncryption(); await aliSendsFirstMessage(); @@ -545,7 +545,7 @@ describe("MatrixClient crypto", () => { aliTestClient.expectKeyQuery({ device_keys: { [aliUserId]: {} }, failures: {} }); await aliTestClient.start(); await bobTestClient.start(); - bobTestClient.client.crypto!.deviceList.downloadKeys = () => Promise.resolve({}); + bobTestClient.client.crypto!.deviceList.downloadKeys = () => Promise.resolve(new Map()); await firstSync(aliTestClient); await aliEnablesEncryption(); await aliSendsFirstMessage(); diff --git a/spec/test-utils/webrtc.ts b/spec/test-utils/webrtc.ts index ba93a6c1c97..aaac27ff451 100644 --- a/spec/test-utils/webrtc.ts +++ b/spec/test-utils/webrtc.ts @@ -30,6 +30,7 @@ import { RoomState, RoomStateEvent, RoomStateEventHandlerMap, + SendToDeviceContentMap, } from "../../src"; import { TypedEventEmitter } from "../../src/models/typed-event-emitter"; import { ReEmitter } from "../../src/ReEmitter"; @@ -443,11 +444,7 @@ export class MockCallMatrixClient extends TypedEventEmitter(); public sendToDevice = jest.fn< Promise<{}>, - [ - eventType: string, - contentMap: { [userId: string]: { [deviceId: string]: Record } }, - txnId?: string, - ] + [eventType: string, contentMap: SendToDeviceContentMap, txnId?: string] >(); public isInitialSyncComplete(): boolean { diff --git a/spec/unit/crypto.spec.ts b/spec/unit/crypto.spec.ts index ec1c660b7a1..50c9aa854fb 100644 --- a/spec/unit/crypto.spec.ts +++ b/spec/unit/crypto.spec.ts @@ -405,7 +405,7 @@ describe("Crypto", function () { // the first message can't be decrypted yet, but the second one // can let ksEvent = await keyshareEventForEvent(aliceClient, events[1], 1); - bobClient.crypto!.deviceList.downloadKeys = () => Promise.resolve({}); + bobClient.crypto!.deviceList.downloadKeys = () => Promise.resolve(new Map()); bobClient.crypto!.deviceList.getUserByIdentityKey = () => "@alice:example.com"; await bobDecryptor.onRoomKeyEvent(ksEvent); await decryptEventsPromise; @@ -1039,7 +1039,7 @@ describe("Crypto", function () { beforeEach(async () => { ensureOlmSessionsForDevices = jest.spyOn(olmlib, "ensureOlmSessionsForDevices"); - ensureOlmSessionsForDevices.mockResolvedValue({}); + ensureOlmSessionsForDevices.mockResolvedValue(new Map()); encryptMessageForDevice = jest.spyOn(olmlib, "encryptMessageForDevice"); encryptMessageForDevice.mockImplementation(async (...[result, , , , , , payload]) => { result.plaintext = { type: 0, body: JSON.stringify(payload) }; diff --git a/spec/unit/crypto/algorithms/megolm.spec.ts b/spec/unit/crypto/algorithms/megolm.spec.ts index 3931ee3fc5e..e5e92aebdb2 100644 --- a/spec/unit/crypto/algorithms/megolm.spec.ts +++ b/spec/unit/crypto/algorithms/megolm.spec.ts @@ -34,6 +34,7 @@ import { ClientEvent, MatrixClient, RoomMember } from "../../../../src"; import { DeviceInfo, IDevice } from "../../../../src/crypto/deviceinfo"; import { DeviceTrustLevel } from "../../../../src/crypto/CrossSigning"; import { MegolmEncryption as MegolmEncryptionClass } from "../../../../src/crypto/algorithms/megolm"; +import { recursiveMapToObject } from "../../../../src/utils"; import { sleep } from "../../../../src/utils"; const MegolmDecryption = algorithms.DECRYPTION_CLASSES.get("m.megolm.v1.aes-sha2")!; @@ -183,14 +184,22 @@ describe("MegolmDecryption", function () { const deviceInfo = {} as DeviceInfo; mockCrypto.getStoredDevice.mockReturnValue(deviceInfo); - mockOlmLib.ensureOlmSessionsForDevices.mockResolvedValue({ - "@alice:foo": { - alidevice: { - sessionId: "alisession", - device: new DeviceInfo("alidevice"), - }, - }, - }); + mockOlmLib.ensureOlmSessionsForDevices.mockResolvedValue( + new Map([ + [ + "@alice:foo", + new Map([ + [ + "alidevice", + { + sessionId: "alisession", + device: new DeviceInfo("alidevice"), + }, + ], + ]), + ], + ]), + ); const awaitEncryptForDevice = new Promise((res, rej) => { mockOlmLib.encryptMessageForDevice.mockImplementation(() => { @@ -357,11 +366,7 @@ describe("MegolmDecryption", function () { } as unknown as DeviceInfo; mockCrypto.downloadKeys.mockReturnValue( - Promise.resolve({ - "@alice:home.server": { - aliceDevice: aliceDeviceInfo, - }, - }), + Promise.resolve(new Map([["@alice:home.server", new Map([["aliceDevice", aliceDeviceInfo]])]])), ); mockCrypto.checkDeviceTrust.mockReturnValue({ @@ -523,23 +528,32 @@ describe("MegolmDecryption", function () { let megolm: MegolmEncryptionClass; let room: jest.Mocked; - const deviceMap: DeviceInfoMap = { - "user-a": { - "device-a": new DeviceInfo("device-a"), - "device-b": new DeviceInfo("device-b"), - "device-c": new DeviceInfo("device-c"), - }, - "user-b": { - "device-d": new DeviceInfo("device-d"), - "device-e": new DeviceInfo("device-e"), - "device-f": new DeviceInfo("device-f"), - }, - "user-c": { - "device-g": new DeviceInfo("device-g"), - "device-h": new DeviceInfo("device-h"), - "device-i": new DeviceInfo("device-i"), - }, - }; + const deviceMap: DeviceInfoMap = new Map([ + [ + "user-a", + new Map([ + ["device-a", new DeviceInfo("device-a")], + ["device-b", new DeviceInfo("device-b")], + ["device-c", new DeviceInfo("device-c")], + ]), + ], + [ + "user-b", + new Map([ + ["device-d", new DeviceInfo("device-d")], + ["device-e", new DeviceInfo("device-e")], + ["device-f", new DeviceInfo("device-f")], + ]), + ], + [ + "user-c", + new Map([ + ["device-g", new DeviceInfo("device-g")], + ["device-h", new DeviceInfo("device-h")], + ["device-i", new DeviceInfo("device-i")], + ]), + ], + ]); beforeEach(() => { room = testUtils.mock(Room, "Room") as jest.Mocked; @@ -572,8 +586,8 @@ describe("MegolmDecryption", function () { //@ts-ignore private member access, gross await megolm.encryptionPreparation?.promise; - for (const userId in deviceMap) { - for (const deviceId in deviceMap[userId]) { + for (const [userId, devices] of deviceMap) { + for (const deviceId of devices.keys()) { expect(mockCrypto.checkDeviceTrust).toHaveBeenCalledWith(userId, deviceId); } } @@ -658,20 +672,20 @@ describe("MegolmDecryption", function () { expect(aliceClient.sendToDevice).toHaveBeenCalled(); const [msgtype, contentMap] = mocked(aliceClient.sendToDevice).mock.calls[0]; expect(msgtype).toMatch(/^(org.matrix|m).room_key.withheld$/); - delete contentMap["@bob:example.com"].bobdevice1.session_id; - delete contentMap["@bob:example.com"].bobdevice1["org.matrix.msgid"]; - delete contentMap["@bob:example.com"].bobdevice2.session_id; - delete contentMap["@bob:example.com"].bobdevice2["org.matrix.msgid"]; - expect(contentMap).toStrictEqual({ - "@bob:example.com": { - bobdevice1: { + delete contentMap.get("@bob:example.com")?.get("bobdevice1")?.["session_id"]; + delete contentMap.get("@bob:example.com")?.get("bobdevice1")?.["org.matrix.msgid"]; + delete contentMap.get("@bob:example.com")?.get("bobdevice2")?.["session_id"]; + delete contentMap.get("@bob:example.com")?.get("bobdevice2")?.["org.matrix.msgid"]; + expect(recursiveMapToObject(contentMap)).toStrictEqual({ + ["@bob:example.com"]: { + ["bobdevice1"]: { algorithm: "m.megolm.v1.aes-sha2", room_id: roomId, code: "m.unverified", reason: "The sender has disabled encrypting to unverified devices.", sender_key: aliceDevice.deviceCurve25519Key, }, - bobdevice2: { + ["bobdevice2"]: { algorithm: "m.megolm.v1.aes-sha2", room_id: roomId, code: "m.blacklisted", @@ -839,10 +853,10 @@ describe("MegolmDecryption", function () { expect(aliceClient.sendToDevice).toHaveBeenCalled(); const [msgtype, contentMap] = mocked(aliceClient.sendToDevice).mock.calls[0]; expect(msgtype).toMatch(/^(org.matrix|m).room_key.withheld$/); - delete contentMap["@bob:example.com"]["bobdevice"]["org.matrix.msgid"]; - expect(contentMap).toStrictEqual({ - "@bob:example.com": { - bobdevice: { + delete contentMap.get("@bob:example.com")?.get("bobdevice")?.["org.matrix.msgid"]; + expect(recursiveMapToObject(contentMap)).toStrictEqual({ + ["@bob:example.com"]: { + ["bobdevice"]: { algorithm: "m.megolm.v1.aes-sha2", code: "m.no_olm", reason: "Unable to establish a secure channel.", diff --git a/spec/unit/crypto/algorithms/olm.spec.ts b/spec/unit/crypto/algorithms/olm.spec.ts index 6099ccceb6c..644bb96e391 100644 --- a/spec/unit/crypto/algorithms/olm.spec.ts +++ b/spec/unit/crypto/algorithms/olm.spec.ts @@ -146,18 +146,21 @@ describe("OlmDevice", function () { }); }, } as unknown as MockedObject; - const devicesByUser = { - "@bob:example.com": [ - DeviceInfo.fromStorage( - { - keys: { - "curve25519:ABCDEFG": "akey", + const devicesByUser = new Map([ + [ + "@bob:example.com", + [ + DeviceInfo.fromStorage( + { + keys: { + "curve25519:ABCDEFG": "akey", + }, }, - }, - "ABCDEFG", - ), + "ABCDEFG", + ), + ], ], - }; + ]); // start two tasks that try to ensure that there's an olm session const promises = Promise.all([ @@ -218,12 +221,8 @@ describe("OlmDevice", function () { // There's no required ordering of devices per user, so here we // create two different orderings so that each task reserves a // device the other task needs before continuing. - const devicesByUserAB = { - "@bob:example.com": [deviceBobA, deviceBobB], - }; - const devicesByUserBA = { - "@bob:example.com": [deviceBobB, deviceBobA], - }; + const devicesByUserAB = new Map([["@bob:example.com", [deviceBobA, deviceBobB]]]); + const devicesByUserBA = new Map([["@bob:example.com", [deviceBobB, deviceBobA]]]); const task1 = alwaysSucceed(olmlib.ensureOlmSessionsForDevices(aliceOlmDevice, baseApis, devicesByUserAB)); diff --git a/spec/unit/crypto/secrets.spec.ts b/spec/unit/crypto/secrets.spec.ts index dc86642578d..2a31f856c84 100644 --- a/spec/unit/crypto/secrets.spec.ts +++ b/spec/unit/crypto/secrets.spec.ts @@ -45,7 +45,7 @@ async function makeTestClient( await client.initCrypto(); // No need to download keys for these tests - jest.spyOn(client.crypto!, "downloadKeys").mockResolvedValue({}); + jest.spyOn(client.crypto!, "downloadKeys").mockResolvedValue(new Map()); return client; } @@ -274,7 +274,7 @@ describe("Secrets", function () { Object.values(otks)[0], ); - osborne2.client.crypto!.deviceList.downloadKeys = () => Promise.resolve({}); + osborne2.client.crypto!.deviceList.downloadKeys = () => Promise.resolve(new Map()); osborne2.client.crypto!.deviceList.getUserByIdentityKey = () => "@alice:example.com"; const request = await secretStorage.request("foo", ["VAX"]); diff --git a/spec/unit/crypto/verification/sas.spec.ts b/spec/unit/crypto/verification/sas.spec.ts index 3c4f224428e..c78dfeba237 100644 --- a/spec/unit/crypto/verification/sas.spec.ts +++ b/spec/unit/crypto/verification/sas.spec.ts @@ -121,12 +121,12 @@ describe("SAS verification", function () { alice.client.crypto!.deviceList.storeDevicesForUser("@bob:example.com", BOB_DEVICES); alice.client.downloadKeys = () => { - return Promise.resolve({}); + return Promise.resolve(new Map()); }; bob.client.crypto!.deviceList.storeDevicesForUser("@alice:example.com", ALICE_DEVICES); bob.client.downloadKeys = () => { - return Promise.resolve({}); + return Promise.resolve(new Map()); }; aliceSasEvent = null; @@ -176,6 +176,7 @@ describe("SAS verification", function () { } }); }); + afterEach(async () => { await Promise.all([alice.stop(), bob.stop()]); @@ -186,10 +187,14 @@ describe("SAS verification", function () { let macMethod; let keyAgreement; const origSendToDevice = bob.client.sendToDevice.bind(bob.client); - bob.client.sendToDevice = function (type, map) { + bob.client.sendToDevice = async (type, map) => { if (type === "m.key.verification.accept") { - macMethod = map[alice.client.getUserId()!][alice.client.deviceId!].message_authentication_code; - keyAgreement = map[alice.client.getUserId()!][alice.client.deviceId!].key_agreement_protocol; + macMethod = map + .get(alice.client.getUserId()!) + ?.get(alice.client.deviceId!)?.message_authentication_code; + keyAgreement = map + .get(alice.client.getUserId()!) + ?.get(alice.client.deviceId!)?.key_agreement_protocol; } return origSendToDevice(type, map); }; @@ -237,7 +242,7 @@ describe("SAS verification", function () { // has, since it is the same object. If this does not // happen, the verification will fail due to a hash // commitment mismatch. - map[bob.client.getUserId()!][bob.client.deviceId!].message_authentication_codes = [ + map.get(bob.client.getUserId()!)!.get(bob.client.deviceId!)!.message_authentication_codes = [ "hkdf-hmac-sha256", ]; } @@ -246,7 +251,9 @@ describe("SAS verification", function () { const bobOrigSendToDevice = bob.client.sendToDevice.bind(bob.client); bob.client.sendToDevice = (type, map) => { if (type === "m.key.verification.accept") { - macMethod = map[alice.client.getUserId()!][alice.client.deviceId!].message_authentication_code; + macMethod = map + .get(alice.client.getUserId()!)! + .get(alice.client.deviceId!)!.message_authentication_code; } return bobOrigSendToDevice(type, map); }; @@ -291,14 +298,18 @@ describe("SAS verification", function () { // has, since it is the same object. If this does not // happen, the verification will fail due to a hash // commitment mismatch. - map[bob.client.getUserId()!][bob.client.deviceId!].message_authentication_codes = ["hmac-sha256"]; + map.get(bob.client.getUserId()!)!.get(bob.client.deviceId!)!.message_authentication_codes = [ + "hmac-sha256", + ]; } return aliceOrigSendToDevice(type, map); }; const bobOrigSendToDevice = bob.client.sendToDevice.bind(bob.client); bob.client.sendToDevice = (type, map) => { if (type === "m.key.verification.accept") { - macMethod = map[alice.client.getUserId()!][alice.client.deviceId!].message_authentication_code; + macMethod = map + .get(alice.client.getUserId()!)! + .get(alice.client.deviceId!)!.message_authentication_code; } return bobOrigSendToDevice(type, map); }; @@ -454,7 +465,7 @@ describe("SAS verification", function () { ); }; alice.client.downloadKeys = () => { - return Promise.resolve({}); + return Promise.resolve(new Map()); }; bob.client.crypto!.setDeviceVerification = jest.fn(); @@ -472,7 +483,7 @@ describe("SAS verification", function () { return "bob+base64+ed25519+key"; }; bob.client.downloadKeys = () => { - return Promise.resolve({}); + return Promise.resolve(new Map()); }; aliceSasEvent = null; diff --git a/spec/unit/crypto/verification/util.ts b/spec/unit/crypto/verification/util.ts index 8b478660c1f..6454fd6004a 100644 --- a/spec/unit/crypto/verification/util.ts +++ b/spec/unit/crypto/verification/util.ts @@ -20,7 +20,7 @@ import { IContent, MatrixEvent } from "../../../../src/models/event"; import { IRoomTimelineData } from "../../../../src/models/event-timeline-set"; import { Room, RoomEvent } from "../../../../src/models/room"; import { logger } from "../../../../src/logger"; -import { MatrixClient, ClientEvent, ICreateClientOpts } from "../../../../src/client"; +import { MatrixClient, ClientEvent, ICreateClientOpts, SendToDeviceContentMap } from "../../../../src/client"; interface UserInfo { userId: string; @@ -36,16 +36,16 @@ export async function makeTestClients( const clientMap: Record> = {}; const makeSendToDevice = (matrixClient: MatrixClient): MatrixClient["sendToDevice"] => - async (type, map) => { + async (type: string, contentMap: SendToDeviceContentMap) => { // logger.log(this.getUserId(), "sends", type, map); - for (const [userId, devMap] of Object.entries(map)) { + for (const [userId, deviceMessages] of contentMap) { if (userId in clientMap) { - for (const [deviceId, msg] of Object.entries(devMap)) { + for (const [deviceId, message] of deviceMessages) { if (deviceId in clientMap[userId]) { const event = new MatrixEvent({ sender: matrixClient.getUserId()!, type: type, - content: msg, + content: message, }); const client = clientMap[userId][deviceId]; const decryptionPromise = event.isEncrypted() diff --git a/spec/unit/crypto/verification/verification_request.spec.ts b/spec/unit/crypto/verification/verification_request.spec.ts index ea6919216a3..2f42e54d8ca 100644 --- a/spec/unit/crypto/verification/verification_request.spec.ts +++ b/spec/unit/crypto/verification/verification_request.spec.ts @@ -25,6 +25,7 @@ import { IContent, MatrixEvent } from "../../../../src/models/event"; import { MatrixClient } from "../../../../src/client"; import { IVerificationChannel } from "../../../../src/crypto/verification/request/Channel"; import { VerificationBase } from "../../../../src/crypto/verification/Base"; +import { MapWithDefault } from "../../../../src/utils"; type MockClient = MatrixClient & { popEvents: () => MatrixEvent[]; @@ -33,7 +34,9 @@ type MockClient = MatrixClient & { function makeMockClient(userId: string, deviceId: string): MockClient { let counter = 1; let events: MatrixEvent[] = []; - const deviceEvents: Record> = {}; + const deviceEvents: MapWithDefault> = new MapWithDefault( + () => new MapWithDefault(() => []), + ); return { getUserId() { return userId; @@ -58,15 +61,11 @@ function makeMockClient(userId: string, deviceId: string): MockClient { return Promise.resolve({ event_id: eventId }); }, - sendToDevice(type: string, msgMap: Record>) { - for (const userId of Object.keys(msgMap)) { - const deviceMap = msgMap[userId]; - for (const deviceId of Object.keys(deviceMap)) { - const content = deviceMap[deviceId]; + sendToDevice(type: string, msgMap: Map>) { + for (const [userId, deviceMessages] of msgMap) { + for (const [deviceId, content] of deviceMessages) { const event = new MatrixEvent({ content, type }); - deviceEvents[userId] = deviceEvents[userId] || {}; - deviceEvents[userId][deviceId] = deviceEvents[userId][deviceId] || []; - deviceEvents[userId][deviceId].push(event); + deviceEvents.getOrCreate(userId).getOrCreate(deviceId).push(event); } } return Promise.resolve({}); @@ -79,14 +78,9 @@ function makeMockClient(userId: string, deviceId: string): MockClient { return e; }, - // @ts-ignore special testing fn popDeviceEvents(userId: string, deviceId: string): MatrixEvent[] { - const forDevice = deviceEvents[userId]; - const events = forDevice && forDevice[deviceId]; - const result = events || []; - if (events) { - delete forDevice[deviceId]; - } + const result = deviceEvents.get(userId)?.get(deviceId) || []; + deviceEvents?.get(userId)?.delete(deviceId); return result; }, } as unknown as MockClient; diff --git a/spec/unit/embedded.spec.ts b/spec/unit/embedded.spec.ts index ef5215e20f0..caab40ac056 100644 --- a/spec/unit/embedded.spec.ts +++ b/spec/unit/embedded.spec.ts @@ -204,9 +204,14 @@ describe("RoomWidgetClient", () => { }); describe("to-device messages", () => { - const unencryptedContentMap = { - "@alice:example.org": { "*": { hello: "alice!" } }, - "@bob:example.org": { bobDesktop: { hello: "bob!" } }, + const unencryptedContentMap = new Map([ + ["@alice:example.org", new Map([["*", { hello: "alice!" }]])], + ["@bob:example.org", new Map([["bobDesktop", { hello: "bob!" }]])], + ]); + + const expectedRequestData = { + ["@alice:example.org"]: { ["*"]: { hello: "alice!" } }, + ["@bob:example.org"]: { ["bobDesktop"]: { hello: "bob!" } }, }; it("sends unencrypted (sendToDevice)", async () => { @@ -214,7 +219,7 @@ describe("RoomWidgetClient", () => { expect(widgetApi.requestCapabilityToSendToDevice).toHaveBeenCalledWith("org.example.foo"); await client.sendToDevice("org.example.foo", unencryptedContentMap); - expect(widgetApi.sendToDevice).toHaveBeenCalledWith("org.example.foo", false, unencryptedContentMap); + expect(widgetApi.sendToDevice).toHaveBeenCalledWith("org.example.foo", false, expectedRequestData); }); it("sends unencrypted (queueToDevice)", async () => { @@ -229,7 +234,7 @@ describe("RoomWidgetClient", () => { ], }; await client.queueToDevice(batch); - expect(widgetApi.sendToDevice).toHaveBeenCalledWith("org.example.foo", false, unencryptedContentMap); + expect(widgetApi.sendToDevice).toHaveBeenCalledWith("org.example.foo", false, expectedRequestData); }); it("sends encrypted (encryptAndSendToDevices)", async () => { diff --git a/spec/unit/stores/memory.spec.ts b/spec/unit/stores/memory.spec.ts index fac3267dbba..30d6cff45f9 100644 --- a/spec/unit/stores/memory.spec.ts +++ b/spec/unit/stores/memory.spec.ts @@ -59,7 +59,7 @@ describe("MemoryStore", () => { await store.deleteAllData(); // empty object - expect(store.accountData).toEqual({}); + expect(store.accountData).toEqual(new Map()); }); }); }); diff --git a/spec/unit/utils.spec.ts b/spec/unit/utils.spec.ts index 98a4ddc3081..448bed0a5c2 100644 --- a/spec/unit/utils.spec.ts +++ b/spec/unit/utils.spec.ts @@ -24,9 +24,12 @@ import { lexicographicCompare, nextString, prevString, + recursiveMapToObject, simpleRetryOperation, stringToBase, sortEventsByLatestContentTimestamp, + safeSet, + MapWithDefault, } from "../../src/utils"; import { logger } from "../../src/logger"; import { mkMessage } from "../test-utils/test-utils"; @@ -606,6 +609,105 @@ describe("utils", function () { }); }); + describe("recursiveMapToObject", () => { + it.each([ + // empty map + { + map: new Map(), + expected: {}, + }, + // one level map + { + map: new Map([ + ["key1", "value 1"], + ["key2", 23], + ["key3", undefined], + ["key4", null], + ["key5", [1, 2, 3]], + ]), + expected: { key1: "value 1", key2: 23, key3: undefined, key4: null, key5: [1, 2, 3] }, + }, + // two level map + { + map: new Map([ + [ + "key1", + new Map([ + ["key1_1", "value 1"], + ["key1_2", "value 1.2"], + ]), + ], + ["key2", "value 2"], + ]), + expected: { key1: { key1_1: "value 1", key1_2: "value 1.2" }, key2: "value 2" }, + }, + // multi level map + { + map: new Map([ + ["key1", new Map([["key1_1", new Map([["key1_1_1", "value 1.1.1"]])]])], + ]), + expected: { key1: { key1_1: { key1_1_1: "value 1.1.1" } } }, + }, + // list of maps + { + map: new Map([ + [ + "key1", + [new Map([["key1_1", "value 1.1"]]), new Map([["key1_2", "value 1.2"]])], + ], + ]), + expected: { key1: [{ key1_1: "value 1.1" }, { key1_2: "value 1.2" }] }, + }, + // map โ†’ array โ†’ array โ†’ map + { + map: new Map([["key1", [[new Map([["key2", "value 2"]])]]]]), + expected: { + key1: [ + [ + { + key2: "value 2", + }, + ], + ], + }, + }, + ])("%# should convert the value", ({ map, expected }) => { + expect(recursiveMapToObject(map)).toStrictEqual(expected); + }); + }); + + describe("safeSet", () => { + it("should set a value", () => { + const obj = {}; + safeSet(obj, "testProp", "test value"); + expect(obj).toEqual({ testProp: "test value" }); + }); + + it.each(["__proto__", "prototype", "constructor"])("should raise an error when setting ยป%sยซ", (prop) => { + expect(() => { + safeSet({}, prop, "teset value"); + }).toThrow("Trying to modify prototype or constructor"); + }); + }); + + describe("MapWithDefault", () => { + it("getOrCreate should create the value if it does not exist", () => { + const newValue = {}; + const map = new MapWithDefault(() => newValue); + + // undefined before getOrCreate + expect(map.get("test")).toBeUndefined(); + + expect(map.getOrCreate("test")).toBe(newValue); + + // default value after getOrCreate + expect(map.get("test")).toBe(newValue); + + // test that it always returns the same value + expect(map.getOrCreate("test")).toBe(newValue); + }); + }); + describe("sleep", () => { it("resolves", async () => { await utils.sleep(0); diff --git a/spec/unit/webrtc/groupCall.spec.ts b/spec/unit/webrtc/groupCall.spec.ts index 9743b332e14..914a1246a3a 100644 --- a/spec/unit/webrtc/groupCall.spec.ts +++ b/spec/unit/webrtc/groupCall.spec.ts @@ -688,15 +688,15 @@ describe("Group Call", function () { expect(client1.sendToDevice.mock.calls[0][0]).toBe("m.call.invite"); const toDeviceCallContent = client1.sendToDevice.mock.calls[0][1]; - expect(Object.keys(toDeviceCallContent).length).toBe(1); - expect(Object.keys(toDeviceCallContent)[0]).toBe(FAKE_USER_ID_2); + expect(toDeviceCallContent.size).toBe(1); + expect(toDeviceCallContent.has(FAKE_USER_ID_2)).toBe(true); - const toDeviceBobDevices = toDeviceCallContent[FAKE_USER_ID_2]; - expect(Object.keys(toDeviceBobDevices).length).toBe(1); - expect(Object.keys(toDeviceBobDevices)[0]).toBe(FAKE_DEVICE_ID_2); + const toDeviceBobDevices = toDeviceCallContent.get(FAKE_USER_ID_2); + expect(toDeviceBobDevices?.size).toBe(1); + expect(toDeviceBobDevices?.has(FAKE_DEVICE_ID_2)).toBe(true); - const bobDeviceMessage = toDeviceBobDevices[FAKE_DEVICE_ID_2]; - expect(bobDeviceMessage.conf_id).toBe(FAKE_CONF_ID); + const bobDeviceMessage = toDeviceBobDevices?.get(FAKE_DEVICE_ID_2); + expect(bobDeviceMessage?.conf_id).toBe(FAKE_CONF_ID); } finally { await Promise.all([groupCall1.leave(), groupCall2.leave()]); } diff --git a/src/@types/read_receipts.ts b/src/@types/read_receipts.ts index 3032c5934df..759240387fc 100644 --- a/src/@types/read_receipts.ts +++ b/src/@types/read_receipts.ts @@ -38,7 +38,7 @@ export interface CachedReceipt { data: Receipt; } -export type ReceiptCache = { [eventId: string]: CachedReceipt[] }; +export type ReceiptCache = Map; export interface ReceiptContent { [eventId: string]: { @@ -49,11 +49,8 @@ export interface ReceiptContent { } // We will only hold a synthetic receipt if we do not have a real receipt or the synthetic is newer. -export type Receipts = { - [receiptType: string]: { - [userId: string]: [WrappedReceipt | null, WrappedReceipt | null]; // Pair (both nullable) - }; -}; +// map: receipt type โ†’ user Id โ†’ receipt +export type Receipts = Map>; export type CachedReceiptStructure = { eventId: string; diff --git a/src/ToDeviceMessageQueue.ts b/src/ToDeviceMessageQueue.ts index ec5922bb68c..59eada4db22 100644 --- a/src/ToDeviceMessageQueue.ts +++ b/src/ToDeviceMessageQueue.ts @@ -21,6 +21,7 @@ import { MatrixError } from "./http-api"; import { IndexedToDeviceBatch, ToDeviceBatch, ToDeviceBatchWithTxnId, ToDevicePayload } from "./models/ToDeviceMessage"; import { MatrixScheduler } from "./scheduler"; import { SyncState } from "./sync"; +import { MapWithDefault } from "./utils"; const MAX_BATCH_SIZE = 20; @@ -122,12 +123,9 @@ export class ToDeviceMessageQueue { * Attempts to send a batch of to-device messages. */ private async sendBatch(batch: IndexedToDeviceBatch): Promise { - const contentMap: Record> = {}; + const contentMap: MapWithDefault> = new MapWithDefault(() => new Map()); for (const item of batch.batch) { - if (!contentMap[item.userId]) { - contentMap[item.userId] = {}; - } - contentMap[item.userId][item.deviceId] = item.payload; + contentMap.getOrCreate(item.userId).set(item.deviceId, item.payload); } logger.info( diff --git a/src/client.ts b/src/client.ts index 8f861892a69..708abd33636 100644 --- a/src/client.ts +++ b/src/client.ts @@ -37,7 +37,7 @@ import { Filter, IFilterDefinition, IRoomEventFilter } from "./filter"; import { CallEventHandlerEvent, CallEventHandler, CallEventHandlerEventHandlerMap } from "./webrtc/callEventHandler"; import { GroupCallEventHandlerEvent, GroupCallEventHandlerEventHandlerMap } from "./webrtc/groupCallEventHandler"; import * as utils from "./utils"; -import { replaceParam, QueryDict, sleep } from "./utils"; +import { replaceParam, QueryDict, sleep, noUnsafeEventProps } from "./utils"; import { Direction, EventTimeline } from "./models/event-timeline"; import { IActionsObject, PushProcessor } from "./pushprocessor"; import { AutoDiscovery, AutoDiscoveryAction } from "./autodiscovery"; @@ -79,7 +79,7 @@ import { VerificationMethod, IRoomKeyRequestBody, } from "./crypto"; -import { DeviceInfo, IDevice } from "./crypto/deviceinfo"; +import { DeviceInfo } from "./crypto/deviceinfo"; import { decodeRecoveryKey } from "./crypto/recoverykey"; import { keyFromAuthData } from "./crypto/key_passphrase"; import { User, UserEvent, UserEventHandlerMap } from "./models/user"; @@ -206,6 +206,7 @@ import { LocalNotificationSettings } from "./@types/local_notifications"; import { buildFeatureSupportMap, Feature, ServerSupport } from "./feature"; import { CryptoBackend } from "./common-crypto/CryptoBackend"; import { RUST_SDK_STORE_PREFIX } from "./rust-crypto/constants"; +import { DeviceInfoMap } from "./crypto/DeviceList"; export type Store = IStore; @@ -505,6 +506,8 @@ enum CrossSigningKeyType { export type CrossSigningKeys = Record; +export type SendToDeviceContentMap = Map>>; + export interface ISignedKey { keys: Record; signatures: ISignatures; @@ -2243,7 +2246,7 @@ export class MatrixClient extends TypedEventEmitterdeviceId-\>{@link DeviceInfo} */ - public downloadKeys(userIds: string[], forceDownload?: boolean): Promise>> { + public downloadKeys(userIds: string[], forceDownload?: boolean): Promise { if (!this.crypto) { return Promise.reject(new Error("End-to-end encryption disabled")); } @@ -3760,9 +3763,9 @@ export class MatrixClient extends TypedEventEmitter = {}; - for (const [userId, devices] of Object.entries(deviceInfos)) { - devicesByUser[userId] = Object.values(devices); + const devicesByUser: Map = new Map(); + for (const [userId, devices] of deviceInfos) { + devicesByUser.set(userId, Array.from(devices.values())); } // XXX: Private member access @@ -5977,6 +5980,8 @@ export class MatrixClient extends TypedEventEmitter { if (res.state) { const roomState = eventTimeline.getState(dir)!; - const stateEvents = res.state.map(this.getEventMapper()); + const stateEvents = res.state.filter(noUnsafeEventProps).map(this.getEventMapper()); roomState.setUnknownStateEvents(stateEvents); } const token = res.end; - const matrixEvents = res.chunk.map(this.getEventMapper()); + const matrixEvents = res.chunk.filter(noUnsafeEventProps).map(this.getEventMapper()); const timelineSet = eventTimeline.getTimelineSet(); timelineSet.addEventsToTimeline(matrixEvents, backwards, eventTimeline, token); @@ -6059,7 +6064,7 @@ export class MatrixClient extends TypedEventEmitter { const mapper = this.getEventMapper(); - const matrixEvents = res.chunk.map(mapper); + const matrixEvents = res.chunk.filter(noUnsafeEventProps).map(mapper); // Process latest events first for (const event of matrixEvents.slice().reverse()) { @@ -6107,11 +6112,11 @@ export class MatrixClient extends TypedEventEmitter { if (res.state) { const roomState = eventTimeline.getState(dir)!; - const stateEvents = res.state.map(this.getEventMapper()); + const stateEvents = res.state.filter(noUnsafeEventProps).map(this.getEventMapper()); roomState.setUnknownStateEvents(stateEvents); } const token = res.end; - const matrixEvents = res.chunk.map(this.getEventMapper()); + const matrixEvents = res.chunk.filter(noUnsafeEventProps).map(this.getEventMapper()); const timelineSet = eventTimeline.getTimelineSet(); const [timelineEvents] = room.partitionThreadedEvents(matrixEvents); @@ -9129,24 +9134,22 @@ export class MatrixClient extends TypedEventEmitter } }, - txnId?: string, - ): Promise<{}> { + public sendToDevice(eventType: string, contentMap: SendToDeviceContentMap, txnId?: string): Promise<{}> { const path = utils.encodeUri("/sendToDevice/$eventType/$txnId", { $eventType: eventType, $txnId: txnId ? txnId : this.makeTxnId(), }); const body = { - messages: contentMap, + messages: utils.recursiveMapToObject(contentMap), }; - const targets = Object.keys(contentMap).reduce>((obj, key) => { - obj[key] = Object.keys(contentMap[key]); - return obj; - }, {}); + const targets = new Map(); + + for (const [userId, deviceMessages] of contentMap) { + targets.set(userId, Array.from(deviceMessages.keys())); + } + logger.log(`PUT ${path}`, targets); return this.http.authedRequest(Method.Put, path, undefined, body); diff --git a/src/crypto/DeviceList.ts b/src/crypto/DeviceList.ts index 292cf159b5a..a1ff0ebf144 100644 --- a/src/crypto/DeviceList.ts +++ b/src/crypto/DeviceList.ts @@ -58,7 +58,8 @@ export enum TrackingStatus { UpToDate, } -export type DeviceInfoMap = Record>; +// user-Id โ†’ device-Id โ†’ DeviceInfo +export type DeviceInfoMap = Map>; type EmittedEvents = CryptoEvent.WillUpdateDevices | CryptoEvent.DevicesUpdated | CryptoEvent.UserCrossSigningUpdated; @@ -301,13 +302,13 @@ export class DeviceList extends TypedEventEmitterdeviceId-\>{@link DeviceInfo}. */ private getDevicesFromStore(userIds: string[]): DeviceInfoMap { - const stored: DeviceInfoMap = {}; - userIds.forEach((u) => { - stored[u] = {}; - const devices = this.getStoredDevicesForUser(u) || []; - devices.forEach(function (dev) { - stored[u][dev.deviceId] = dev; + const stored: DeviceInfoMap = new Map(); + userIds.forEach((userId) => { + const deviceMap = new Map(); + this.getStoredDevicesForUser(userId)?.forEach(function (device) { + deviceMap.set(device.deviceId, device); }); + stored.set(userId, deviceMap); }); return stored; } diff --git a/src/crypto/EncryptionSetup.ts b/src/crypto/EncryptionSetup.ts index 7fe6d6457c3..f0cf4bf40bd 100644 --- a/src/crypto/EncryptionSetup.ts +++ b/src/crypto/EncryptionSetup.ts @@ -61,7 +61,7 @@ export class EncryptionSetupBuilder { * @param accountData - pre-existing account data, will only be read, not written. * @param delegateCryptoCallbacks - crypto callbacks to delegate to if the key isn't in cache yet */ - public constructor(accountData: Record, delegateCryptoCallbacks?: ICryptoCallbacks) { + public constructor(accountData: Map, delegateCryptoCallbacks?: ICryptoCallbacks) { this.accountDataClientAdapter = new AccountDataClientAdapter(accountData); this.crossSigningCallbacks = new CrossSigningCallbacks(); this.ssssCryptoCallbacks = new SSSSCryptoCallbacks(delegateCryptoCallbacks); @@ -246,7 +246,7 @@ class AccountDataClientAdapter /** * @param existingValues - existing account data */ - public constructor(private readonly existingValues: Record) { + public constructor(private readonly existingValues: Map) { super(); } @@ -265,7 +265,7 @@ class AccountDataClientAdapter if (modifiedValue) { return modifiedValue; } - const existingValue = this.existingValues[type]; + const existingValue = this.existingValues.get(type); if (existingValue) { return existingValue.getContent(); } diff --git a/src/crypto/OutgoingRoomKeyRequestManager.ts b/src/crypto/OutgoingRoomKeyRequestManager.ts index 27bf8389bc0..4628b3e8dd9 100644 --- a/src/crypto/OutgoingRoomKeyRequestManager.ts +++ b/src/crypto/OutgoingRoomKeyRequestManager.ts @@ -21,6 +21,7 @@ import { MatrixClient } from "../client"; import { IRoomKeyRequestBody, IRoomKeyRequestRecipient } from "./index"; import { CryptoStore, OutgoingRoomKeyRequest } from "./store/base"; import { EventType, ToDeviceMessageId } from "../@types/event"; +import { MapWithDefault } from "../utils"; /** * Internal module. Management of outgoing room key requests. @@ -460,15 +461,13 @@ export class OutgoingRoomKeyRequestManager { recipients: IRoomKeyRequestRecipient[], txnId?: string, ): Promise<{}> { - const contentMap: Record>> = {}; + const contentMap = new MapWithDefault>>(() => new Map()); for (const recip of recipients) { - if (!contentMap[recip.userId]) { - contentMap[recip.userId] = {}; - } - contentMap[recip.userId][recip.deviceId] = { + const userDeviceMap = contentMap.getOrCreate(recip.userId); + userDeviceMap.set(recip.deviceId, { ...message, [ToDeviceMessageId]: uuidv4(), - }; + }); } return this.baseApis.sendToDevice(EventType.RoomKeyRequest, contentMap, txnId); diff --git a/src/crypto/SecretStorage.ts b/src/crypto/SecretStorage.ts index c0aab32b880..f5e3fb59ce9 100644 --- a/src/crypto/SecretStorage.ts +++ b/src/crypto/SecretStorage.ts @@ -367,13 +367,11 @@ export class SecretStorage { requesting_device_id: this.baseApis.deviceId, request_id: requestId, }; - const toDevice: Record = {}; + const toDevice: Map = new Map(); for (const device of devices) { - toDevice[device] = cancelData; + toDevice.set(device, cancelData); } - this.baseApis.sendToDevice("m.secret.request", { - [this.baseApis.getUserId()!]: toDevice, - }); + this.baseApis.sendToDevice("m.secret.request", new Map([[this.baseApis.getUserId()!, toDevice]])); // and reject the promise so that anyone waiting on it will be // notified @@ -388,14 +386,12 @@ export class SecretStorage { request_id: requestId, [ToDeviceMessageId]: uuidv4(), }; - const toDevice: Record = {}; + const toDevice: Map = new Map(); for (const device of devices) { - toDevice[device] = requestData; + toDevice.set(device, requestData); } logger.info(`Request secret ${name} from ${devices}, id ${requestId}`); - this.baseApis.sendToDevice("m.secret.request", { - [this.baseApis.getUserId()!]: toDevice, - }); + this.baseApis.sendToDevice("m.secret.request", new Map([[this.baseApis.getUserId()!, toDevice]])); return { requestId, @@ -469,9 +465,11 @@ export class SecretStorage { ciphertext: {}, [ToDeviceMessageId]: uuidv4(), }; - await olmlib.ensureOlmSessionsForDevices(this.baseApis.crypto!.olmDevice, this.baseApis, { - [sender]: [this.baseApis.getStoredDevice(sender, deviceId)!], - }); + await olmlib.ensureOlmSessionsForDevices( + this.baseApis.crypto!.olmDevice, + this.baseApis, + new Map([[sender, [this.baseApis.getStoredDevice(sender, deviceId)!]]]), + ); await olmlib.encryptMessageForDevice( encryptedContent.ciphertext, this.baseApis.getUserId()!, @@ -481,11 +479,7 @@ export class SecretStorage { this.baseApis.getStoredDevice(sender, deviceId)!, payload, ); - const contentMap = { - [sender]: { - [deviceId]: encryptedContent, - }, - }; + const contentMap = new Map([[sender, new Map([[deviceId, encryptedContent]])]]); logger.info(`Sending ${content.name} secret for ${deviceId}`); this.baseApis.sendToDevice("m.room.encrypted", contentMap); diff --git a/src/crypto/algorithms/base.ts b/src/crypto/algorithms/base.ts index 06cb1830323..647300948bd 100644 --- a/src/crypto/algorithms/base.ts +++ b/src/crypto/algorithms/base.ts @@ -26,6 +26,7 @@ import { IContent, MatrixEvent, RoomMember } from "../../matrix"; import { Crypto, IEncryptedContent, IEventDecryptionResult, IncomingRoomKeyRequest } from ".."; import { DeviceInfo } from "../deviceinfo"; import { IRoomEncryption } from "../RoomList"; +import { DeviceInfoMap } from "../DeviceList"; /** * Map of registered encryption algorithm classes. A map from string to {@link EncryptionAlgorithm} class @@ -195,7 +196,7 @@ export abstract class DecryptionAlgorithm { } public onRoomKeyWithheldEvent?(event: MatrixEvent): Promise; - public sendSharedHistoryInboundSessions?(devicesByUser: Record): Promise; + public sendSharedHistoryInboundSessions?(devicesByUser: Map): Promise; } /** @@ -241,11 +242,7 @@ export class UnknownDeviceError extends Error { * @param msg - message describing the problem * @param devices - set of unknown devices per user we're warning about */ - public constructor( - msg: string, - public readonly devices: Record>, - public event?: MatrixEvent, - ) { + public constructor(msg: string, public readonly devices: DeviceInfoMap, public event?: MatrixEvent) { super(msg); this.name = "UnknownDeviceError"; this.devices = devices; diff --git a/src/crypto/algorithms/megolm.ts b/src/crypto/algorithms/megolm.ts index 934b69bd35d..061e169e39b 100644 --- a/src/crypto/algorithms/megolm.ts +++ b/src/crypto/algorithms/megolm.ts @@ -43,7 +43,7 @@ import { IMegolmEncryptedContent, IncomingRoomKeyRequest, IEncryptedContent } fr import { RoomKeyRequestState } from "../OutgoingRoomKeyRequestManager"; import { OlmGroupSessionExtraData } from "../../@types/crypto"; import { MatrixError } from "../../http-api"; -import { immediate } from "../../utils"; +import { immediate, MapWithDefault } from "../../utils"; // determine whether the key can be shared with invitees export function isRoomSharedHistory(room: Room): boolean { @@ -63,17 +63,27 @@ interface IBlockedDevice { deviceInfo: DeviceInfo; } -interface IBlockedMap { - [userId: string]: { - [deviceId: string]: IBlockedDevice; - }; -} +// map user Id โ†’ device Id โ†’ IBlockedDevice +type BlockedMap = Map>; export interface IOlmDevice { userId: string; deviceInfo: T; } +/** + * Tests whether an encrypted content has a ciphertext. + * Ciphertext can be a string or object depending on the content type {@link IEncryptedContent}. + * + * @param content - Encrypted content + * @returns true: has ciphertext, else false + */ +const hasCiphertext = (content: IEncryptedContent): boolean => { + return typeof content.ciphertext === "string" + ? !!content.ciphertext.length + : !!Object.keys(content.ciphertext).length; +}; + /** The result of parsing the an `m.room_key` or `m.forwarded_room_key` to-device event */ interface RoomKey { /** @@ -147,8 +157,8 @@ class OutboundSessionInfo { /** when the session was created (ms since the epoch) */ public creationTime: number; /** devices with which we have shared the session key `userId -> {deviceId -> SharedWithData}` */ - public sharedWithDevices: Record> = {}; - public blockedDevicesNotified: Record> = {}; + public sharedWithDevices: MapWithDefault> = new MapWithDefault(() => new Map()); + public blockedDevicesNotified: MapWithDefault> = new MapWithDefault(() => new Map()); /** * @param sharedHistory - whether the session can be freely shared with @@ -173,17 +183,11 @@ class OutboundSessionInfo { } public markSharedWithDevice(userId: string, deviceId: string, deviceKey: string, chainIndex: number): void { - if (!this.sharedWithDevices[userId]) { - this.sharedWithDevices[userId] = {}; - } - this.sharedWithDevices[userId][deviceId] = { deviceKey, messageIndex: chainIndex }; + this.sharedWithDevices.getOrCreate(userId).set(deviceId, { deviceKey, messageIndex: chainIndex }); } public markNotifiedBlockedDevice(userId: string, deviceId: string): void { - if (!this.blockedDevicesNotified[userId]) { - this.blockedDevicesNotified[userId] = {}; - } - this.blockedDevicesNotified[userId][deviceId] = true; + this.blockedDevicesNotified.getOrCreate(userId).set(deviceId, true); } /** @@ -196,23 +200,15 @@ class OutboundSessionInfo { * @returns true if we have shared the session with devices which aren't * in devicesInRoom. */ - public sharedWithTooManyDevices(devicesInRoom: Record>): boolean { - for (const userId in this.sharedWithDevices) { - if (!this.sharedWithDevices.hasOwnProperty(userId)) { - continue; - } - - if (!devicesInRoom.hasOwnProperty(userId)) { + public sharedWithTooManyDevices(devicesInRoom: DeviceInfoMap): boolean { + for (const [userId, devices] of this.sharedWithDevices) { + if (!devicesInRoom.has(userId)) { logger.log("Starting new megolm session because we shared with " + userId); return true; } - for (const deviceId in this.sharedWithDevices[userId]) { - if (!this.sharedWithDevices[userId].hasOwnProperty(deviceId)) { - continue; - } - - if (!devicesInRoom[userId].hasOwnProperty(deviceId)) { + for (const [deviceId] of devices) { + if (!devicesInRoom.get(userId)?.get(deviceId)) { logger.log("Starting new megolm session because we shared with " + userId + ":" + deviceId); return true; } @@ -292,7 +288,7 @@ export class MegolmEncryption extends EncryptionAlgorithm { private async ensureOutboundSession( room: Room, devicesInRoom: DeviceInfoMap, - blocked: IBlockedMap, + blocked: BlockedMap, singleOlmCreationPhase = false, ): Promise { // takes the previous OutboundSessionInfo, and considers whether to create @@ -360,21 +356,21 @@ export class MegolmEncryption extends EncryptionAlgorithm { devicesInRoom: DeviceInfoMap, sharedHistory: boolean, singleOlmCreationPhase: boolean, - blocked: IBlockedMap, + blocked: BlockedMap, session: OutboundSessionInfo, ): Promise { // now check if we need to share with any devices const shareMap: Record = {}; - for (const [userId, userDevices] of Object.entries(devicesInRoom)) { - for (const [deviceId, deviceInfo] of Object.entries(userDevices)) { + for (const [userId, userDevices] of devicesInRoom) { + for (const [deviceId, deviceInfo] of userDevices) { const key = deviceInfo.getIdentityKey(); if (key == this.olmDevice.deviceCurve25519Key) { // don't bother sending to ourself continue; } - if (!session.sharedWithDevices[userId] || session.sharedWithDevices[userId][deviceId] === undefined) { + if (!session.sharedWithDevices.get(userId)?.get(deviceId)) { shareMap[userId] = shareMap[userId] || []; shareMap[userId].push(deviceInfo); } @@ -402,9 +398,9 @@ export class MegolmEncryption extends EncryptionAlgorithm { await Promise.all([ (async (): Promise => { // share keys with devices that we already have a session for - const olmSessionList = Object.entries(olmSessions) + const olmSessionList = Array.from(olmSessions.entries()) .map(([userId, sessionsByUser]) => - Object.entries(sessionsByUser).map( + Array.from(sessionsByUser.entries()).map( ([deviceId, session]) => `${userId}/${deviceId}: ${session.sessionId}`, ), ) @@ -414,7 +410,7 @@ export class MegolmEncryption extends EncryptionAlgorithm { this.prefixedLogger.debug("Shared keys with existing Olm sessions"); })(), (async (): Promise => { - const deviceList = Object.entries(devicesWithoutSession) + const deviceList = Array.from(devicesWithoutSession.entries()) .map(([userId, devicesByUser]) => devicesByUser.map((device) => `${userId}/${device.deviceId}`)) .flat(1); this.prefixedLogger.debug( @@ -450,7 +446,7 @@ export class MegolmEncryption extends EncryptionAlgorithm { // do this in the background and don't block anything else while we // do this. We only need to retry users from servers that didn't // respond the first time. - const retryDevices: Record = {}; + const retryDevices: MapWithDefault = new MapWithDefault(() => []); const failedServerMap = new Set(); for (const server of failedServers) { failedServerMap.add(server); @@ -459,8 +455,7 @@ export class MegolmEncryption extends EncryptionAlgorithm { for (const { userId, deviceInfo } of errorDevices) { const userHS = userId.slice(userId.indexOf(":") + 1); if (failedServerMap.has(userHS)) { - retryDevices[userId] = retryDevices[userId] || []; - retryDevices[userId].push(deviceInfo); + retryDevices.getOrCreate(userId).push(deviceInfo); } else { // if we aren't going to retry, then handle it // as a failed device @@ -468,7 +463,7 @@ export class MegolmEncryption extends EncryptionAlgorithm { } } - const retryDeviceList = Object.entries(retryDevices) + const retryDeviceList = Array.from(retryDevices.entries()) .map(([userId, devicesByUser]) => devicesByUser.map((device) => `${userId}/${device.deviceId}`), ) @@ -493,25 +488,25 @@ export class MegolmEncryption extends EncryptionAlgorithm { })(), (async (): Promise => { this.prefixedLogger.debug( - `There are ${Object.entries(blocked).length} blocked devices:`, - Object.entries(blocked) + `There are ${blocked.size} blocked devices:`, + Array.from(blocked.entries()) .map(([userId, blockedByUser]) => - Object.entries(blockedByUser).map(([deviceId, _deviceInfo]) => `${userId}/${deviceId}`), + Array.from(blockedByUser.entries()).map( + ([deviceId, _deviceInfo]) => `${userId}/${deviceId}`, + ), ) .flat(1), ); // also, notify newly blocked devices that they're blocked - const blockedMap: Record> = {}; + const blockedMap: MapWithDefault> = new MapWithDefault( + () => new Map(), + ); let blockedCount = 0; - for (const [userId, userBlockedDevices] of Object.entries(blocked)) { - for (const [deviceId, device] of Object.entries(userBlockedDevices)) { - if ( - !session.blockedDevicesNotified[userId] || - session.blockedDevicesNotified[userId][deviceId] === undefined - ) { - blockedMap[userId] = blockedMap[userId] || {}; - blockedMap[userId][deviceId] = { device }; + for (const [userId, userBlockedDevices] of blocked) { + for (const [deviceId, device] of userBlockedDevices) { + if (session.blockedDevicesNotified.get(userId)?.get(deviceId) === undefined) { + blockedMap.getOrCreate(userId).set(deviceId, { device }); blockedCount++; } } @@ -520,7 +515,7 @@ export class MegolmEncryption extends EncryptionAlgorithm { if (blockedCount) { this.prefixedLogger.debug( `Notifying ${blockedCount} newly blocked devices:`, - Object.entries(blockedMap) + Array.from(blockedMap.entries()) .map(([userId, blockedByUser]) => Object.entries(blockedByUser).map(([deviceId, _deviceInfo]) => `${userId}/${deviceId}`), ) @@ -566,7 +561,7 @@ export class MegolmEncryption extends EncryptionAlgorithm { * * @internal * - * @param devicemap - the devices that have olm sessions, as returned by + * @param deviceMap - the devices that have olm sessions, as returned by * olmlib.ensureOlmSessionsForDevices. * @param devicesByUser - a map of user IDs to array of deviceInfo * @param noOlmDevices - an array to fill with devices that don't have @@ -576,23 +571,23 @@ export class MegolmEncryption extends EncryptionAlgorithm { * noOlmDevices is specified, then noOlmDevices will be returned. */ private getDevicesWithoutSessions( - devicemap: Record>, - devicesByUser: Record, + deviceMap: Map>, + devicesByUser: Map, noOlmDevices: IOlmDevice[] = [], ): IOlmDevice[] { - for (const [userId, devicesToShareWith] of Object.entries(devicesByUser)) { - const sessionResults = devicemap[userId]; + for (const [userId, devicesToShareWith] of devicesByUser) { + const sessionResults = deviceMap.get(userId); for (const deviceInfo of devicesToShareWith) { const deviceId = deviceInfo.deviceId; - const sessionResult = sessionResults[deviceId]; - if (!sessionResult.sessionId) { + const sessionResult = sessionResults?.get(deviceId); + if (!sessionResult?.sessionId) { // no session with this device, probably because there // were no one-time keys. noOlmDevices.push({ userId, deviceInfo }); - delete sessionResults[deviceId]; + sessionResults?.delete(deviceId); // ensureOlmSessionsForUsers has already done the logging, // so just skip it. @@ -615,7 +610,7 @@ export class MegolmEncryption extends EncryptionAlgorithm { * @returns the blocked devices, split into chunks */ private splitDevices( - devicesByUser: Record>, + devicesByUser: Map>, ): IOlmDevice[][] { const maxDevicesPerRequest = 20; @@ -623,8 +618,8 @@ export class MegolmEncryption extends EncryptionAlgorithm { let currentSlice: IOlmDevice[] = []; const mapSlices = [currentSlice]; - for (const [userId, userDevices] of Object.entries(devicesByUser)) { - for (const deviceInfo of Object.values(userDevices)) { + for (const [userId, userDevices] of devicesByUser) { + for (const deviceInfo of userDevices.values()) { currentSlice.push({ userId: userId, deviceInfo: deviceInfo.device, @@ -702,7 +697,7 @@ export class MegolmEncryption extends EncryptionAlgorithm { userDeviceMap: IOlmDevice[], payload: IPayload, ): Promise { - const contentMap: Record> = {}; + const contentMap: MapWithDefault> = new MapWithDefault(() => new Map()); for (const val of userDeviceMap) { const userId = val.userId; @@ -722,17 +717,14 @@ export class MegolmEncryption extends EncryptionAlgorithm { delete message.session_id; } - if (!contentMap[userId]) { - contentMap[userId] = {}; - } - contentMap[userId][deviceId] = message; + contentMap.getOrCreate(userId).set(deviceId, message); } await this.baseApis.sendToDevice("m.room_key.withheld", contentMap); // record the fact that we notified these blocked devices - for (const userId of Object.keys(contentMap)) { - for (const deviceId of Object.keys(contentMap[userId])) { + for (const [userId, userDeviceMap] of contentMap) { + for (const deviceId of userDeviceMap.keys()) { session.markNotifiedBlockedDevice(userId, deviceId); } } @@ -760,11 +752,11 @@ export class MegolmEncryption extends EncryptionAlgorithm { } // The chain index of the key we previously sent this device - if (obSessionInfo.sharedWithDevices[userId] === undefined) { + if (!obSessionInfo.sharedWithDevices.has(userId)) { this.prefixedLogger.debug(`megolm session ${senderKey}|${sessionId} never shared with user ${userId}`); return; } - const sessionSharedData = obSessionInfo.sharedWithDevices[userId][device.deviceId]; + const sessionSharedData = obSessionInfo.sharedWithDevices.get(userId)?.get(device.deviceId); if (sessionSharedData === undefined) { this.prefixedLogger.debug( `megolm session ${senderKey}|${sessionId} never shared with device ${userId}:${device.deviceId}`, @@ -796,9 +788,7 @@ export class MegolmEncryption extends EncryptionAlgorithm { return; } - await olmlib.ensureOlmSessionsForDevices(this.olmDevice, this.baseApis, { - [userId]: [device], - }); + await olmlib.ensureOlmSessionsForDevices(this.olmDevice, this.baseApis, new Map([[userId, [device]]])); const payload = { type: "m.forwarded_room_key", @@ -831,11 +821,10 @@ export class MegolmEncryption extends EncryptionAlgorithm { payload, ); - await this.baseApis.sendToDevice("m.room.encrypted", { - [userId]: { - [device.deviceId]: encryptedContent, - }, - }); + await this.baseApis.sendToDevice( + "m.room.encrypted", + new Map([[userId, new Map([[device.deviceId, encryptedContent]])]]), + ); this.prefixedLogger.debug( `Re-shared key for megolm session ${senderKey}|${sessionId} with ${userId}:${device.deviceId}`, ); @@ -865,7 +854,7 @@ export class MegolmEncryption extends EncryptionAlgorithm { session: OutboundSessionInfo, key: IOutboundGroupSessionKey, payload: IPayload, - devicesByUser: Record, + devicesByUser: Map, errorDevices: IOlmDevice[], otkTimeout: number, failedServers?: string[], @@ -887,9 +876,9 @@ export class MegolmEncryption extends EncryptionAlgorithm { session: OutboundSessionInfo, key: IOutboundGroupSessionKey, payload: IPayload, - devicemap: Record>, + deviceMap: Map>, ): Promise { - const userDeviceMaps = this.splitDevices(devicemap); + const userDeviceMaps = this.splitDevices(deviceMap); for (let i = 0; i < userDeviceMaps.length; i++) { const taskDetail = `megolm keys for ${session.sessionId} (slice ${i + 1}/${userDeviceMaps.length})`; @@ -934,19 +923,20 @@ export class MegolmEncryption extends EncryptionAlgorithm { this.prefixedLogger.debug( `Need to notify ${unnotifiedFailedDevices.length} failed devices which haven't been notified before`, ); - const blockedMap: Record> = {}; + const blockedMap: MapWithDefault> = new MapWithDefault( + () => new Map(), + ); for (const { userId, deviceInfo } of unnotifiedFailedDevices) { - blockedMap[userId] = blockedMap[userId] || {}; // we use a similar format to what // olmlib.ensureOlmSessionsForDevices returns, so that // we can use the same function to split - blockedMap[userId][deviceInfo.deviceId] = { + blockedMap.getOrCreate(userId).set(deviceInfo.deviceId, { device: { code: "m.no_olm", reason: WITHHELD_MESSAGES["m.no_olm"], deviceInfo, }, - }; + }); } // send the notifications @@ -964,7 +954,7 @@ export class MegolmEncryption extends EncryptionAlgorithm { */ private async notifyBlockedDevices( session: OutboundSessionInfo, - devicesByUser: Record>, + devicesByUser: Map>, ): Promise { const payload: IPayload = { room_id: this.roomId, @@ -1154,21 +1144,17 @@ export class MegolmEncryption extends EncryptionAlgorithm { * devices we should shared the session with. */ private checkForUnknownDevices(devicesInRoom: DeviceInfoMap): void { - const unknownDevices: Record> = {}; + const unknownDevices: MapWithDefault> = new MapWithDefault(() => new Map()); - Object.keys(devicesInRoom).forEach((userId) => { - Object.keys(devicesInRoom[userId]).forEach((deviceId) => { - const device = devicesInRoom[userId][deviceId]; + for (const [userId, userDevices] of devicesInRoom) { + for (const [deviceId, device] of userDevices) { if (device.isUnverified() && !device.isKnown()) { - if (!unknownDevices[userId]) { - unknownDevices[userId] = {}; - } - unknownDevices[userId][deviceId] = device; + unknownDevices.getOrCreate(userId).set(deviceId, device); } - }); - }); + } + } - if (Object.keys(unknownDevices).length) { + if (unknownDevices.size) { // it'd be kind to pass unknownDevices up to the user in this error throw new UnknownDeviceError( "This room contains unknown devices which have not been verified. " + @@ -1186,15 +1172,15 @@ export class MegolmEncryption extends EncryptionAlgorithm { * devices we should shared the session with. */ private removeUnknownDevices(devicesInRoom: DeviceInfoMap): void { - for (const [userId, userDevices] of Object.entries(devicesInRoom)) { - for (const [deviceId, device] of Object.entries(userDevices)) { + for (const [userId, userDevices] of devicesInRoom) { + for (const [deviceId, device] of userDevices) { if (device.isUnverified() && !device.isKnown()) { - delete userDevices[deviceId]; + userDevices.delete(deviceId); } } - if (Object.keys(userDevices).length === 0) { - delete devicesInRoom[userId]; + if (userDevices.size === 0) { + devicesInRoom.delete(userId); } } } @@ -1219,17 +1205,17 @@ export class MegolmEncryption extends EncryptionAlgorithm { private async getDevicesInRoom( room: Room, forceDistributeToUnverified?: boolean, - ): Promise<[DeviceInfoMap, IBlockedMap]>; + ): Promise<[DeviceInfoMap, BlockedMap]>; private async getDevicesInRoom( room: Room, forceDistributeToUnverified?: boolean, isCancelled?: () => boolean, - ): Promise; + ): Promise; private async getDevicesInRoom( room: Room, forceDistributeToUnverified = false, isCancelled?: () => boolean, - ): Promise { + ): Promise { const members = await room.getEncryptionTargetMembers(); this.prefixedLogger.debug( `Encrypting for users (shouldEncryptForInvitedMembers: ${room.shouldEncryptForInvitedMembers()}):`, @@ -1254,24 +1240,15 @@ export class MegolmEncryption extends EncryptionAlgorithm { // using all the device_lists changes and left fields. // See https://github.com/vector-im/element-web/issues/2305 for details. const devices = await this.crypto.downloadKeys(roomMembers, false); - const blocked: IBlockedMap = {}; if (isCancelled?.() === true) { return null; } + const blocked = new MapWithDefault>(() => new Map()); // remove any blocked devices - for (const userId in devices) { - if (!devices.hasOwnProperty(userId)) { - continue; - } - - const userDevices = devices[userId]; - for (const deviceId in userDevices) { - if (!userDevices.hasOwnProperty(deviceId)) { - continue; - } - + for (const [userId, userDevices] of devices) { + for (const [deviceId, userDevice] of userDevices) { // Yield prior to checking each device so that we don't block // updating/rendering for too long. // See https://github.com/vector-im/element-web/issues/21612 @@ -1280,19 +1257,17 @@ export class MegolmEncryption extends EncryptionAlgorithm { const deviceTrust = this.crypto.checkDeviceTrust(userId, deviceId); if ( - userDevices[deviceId].isBlocked() || + userDevice.isBlocked() || (!deviceTrust.isVerified() && isBlacklisting && !forceDistributeToUnverified) ) { - if (!blocked[userId]) { - blocked[userId] = {}; - } - const isBlocked = userDevices[deviceId].isBlocked(); - blocked[userId][deviceId] = { + const blockedDevices = blocked.getOrCreate(userId); + const isBlocked = userDevice.isBlocked(); + blockedDevices.set(deviceId, { code: isBlocked ? "m.blacklisted" : "m.unverified", reason: WITHHELD_MESSAGES[isBlocked ? "m.blacklisted" : "m.unverified"], - deviceInfo: userDevices[deviceId], - }; - delete userDevices[deviceId]; + deviceInfo: userDevice, + }); + userDevices.delete(deviceId); } } } @@ -1923,7 +1898,7 @@ export class MegolmDecryption extends DecryptionAlgorithm { // XXX: switch this to use encryptAndSendToDevices() rather than duplicating it? - await olmlib.ensureOlmSessionsForDevices(this.olmDevice, this.baseApis, { [sender]: [device] }, false); + await olmlib.ensureOlmSessionsForDevices(this.olmDevice, this.baseApis, new Map([[sender, [device]]]), false); const encryptedContent: IEncryptedContent = { algorithm: olmlib.OLM_ALGORITHM, sender_key: this.olmDevice.deviceCurve25519Key!, @@ -1942,11 +1917,10 @@ export class MegolmDecryption extends DecryptionAlgorithm { await this.olmDevice.recordSessionProblem(senderKey, "no_olm", true); - await this.baseApis.sendToDevice("m.room.encrypted", { - [sender]: { - [device.deviceId]: encryptedContent, - }, - }); + await this.baseApis.sendToDevice( + "m.room.encrypted", + new Map([[sender, new Map([[device.deviceId, encryptedContent]])]]), + ); } public hasKeysForKeyRequest(keyRequest: IncomingRoomKeyRequest): Promise { @@ -1969,12 +1943,10 @@ export class MegolmDecryption extends DecryptionAlgorithm { // XXX: switch this to use encryptAndSendToDevices()? this.olmlib - .ensureOlmSessionsForDevices(this.olmDevice, this.baseApis, { - [userId]: [deviceInfo], - }) + .ensureOlmSessionsForDevices(this.olmDevice, this.baseApis, new Map([[userId, [deviceInfo]]])) .then((devicemap) => { - const olmSessionResult = devicemap[userId][deviceId]; - if (!olmSessionResult.sessionId) { + const olmSessionResult = devicemap.get(userId)?.get(deviceId); + if (!olmSessionResult?.sessionId) { // no session with this device, probably because there // were no one-time keys. // @@ -2015,14 +1987,11 @@ export class MegolmDecryption extends DecryptionAlgorithm { payload!, ) .then(() => { - const contentMap = { - [userId]: { - [deviceId]: encryptedContent, - }, - }; - // TODO: retries - return this.baseApis.sendToDevice("m.room.encrypted", contentMap); + return this.baseApis.sendToDevice( + "m.room.encrypted", + new Map([[userId, new Map([[deviceId, encryptedContent]])]]), + ); }); }); } @@ -2162,12 +2131,12 @@ export class MegolmDecryption extends DecryptionAlgorithm { return !this.pendingEvents.has(senderKey); } - public async sendSharedHistoryInboundSessions(devicesByUser: Record): Promise { + public async sendSharedHistoryInboundSessions(devicesByUser: Map): Promise { await olmlib.ensureOlmSessionsForDevices(this.olmDevice, this.baseApis, devicesByUser); const sharedHistorySessions = await this.olmDevice.getSharedHistoryInboundGroupSessions(this.roomId); this.prefixedLogger.log( - `Sharing history in with users ${Object.keys(devicesByUser)}`, + `Sharing history in with users ${Array.from(devicesByUser.keys())}`, sharedHistorySessions.map(([senderKey, sessionId]) => `${senderKey}|${sessionId}`), ); for (const [senderKey, sessionId] of sharedHistorySessions) { @@ -2175,9 +2144,10 @@ export class MegolmDecryption extends DecryptionAlgorithm { // FIXME: use encryptAndSendToDevices() rather than duplicating it here. const promises: Promise[] = []; - const contentMap: Record> = {}; - for (const [userId, devices] of Object.entries(devicesByUser)) { - contentMap[userId] = {}; + const contentMap: Map> = new Map(); + for (const [userId, devices] of devicesByUser) { + const deviceMessages = new Map(); + contentMap.set(userId, deviceMessages); for (const deviceInfo of devices) { const encryptedContent: IEncryptedContent = { algorithm: olmlib.OLM_ALGORITHM, @@ -2185,7 +2155,7 @@ export class MegolmDecryption extends DecryptionAlgorithm { ciphertext: {}, [ToDeviceMessageId]: uuidv4(), }; - contentMap[userId][deviceInfo.deviceId] = encryptedContent; + deviceMessages.set(deviceInfo.deviceId, encryptedContent); promises.push( olmlib.encryptMessageForDevice( encryptedContent.ciphertext, @@ -2205,22 +2175,22 @@ export class MegolmDecryption extends DecryptionAlgorithm { // in which case it will have just not added anything to the ciphertext object. // There's no point sending messages to devices if we couldn't encrypt to them, // since that's effectively a blank message. - for (const userId of Object.keys(contentMap)) { - for (const deviceId of Object.keys(contentMap[userId])) { - if (Object.keys(contentMap[userId][deviceId].ciphertext).length === 0) { + for (const [userId, deviceMessages] of contentMap) { + for (const [deviceId, content] of deviceMessages) { + if (!hasCiphertext(content)) { this.prefixedLogger.log("No ciphertext for device " + userId + ":" + deviceId + ": pruning"); - delete contentMap[userId][deviceId]; + deviceMessages.delete(deviceId); } } // No devices left for that user? Strip that too. - if (Object.keys(contentMap[userId]).length === 0) { + if (deviceMessages.size === 0) { this.prefixedLogger.log("Pruned all devices for user " + userId); - delete contentMap[userId]; + contentMap.delete(userId); } } // Is there anything left? - if (Object.keys(contentMap).length === 0) { + if (contentMap.size === 0) { this.prefixedLogger.log("No users left to send to: aborting"); return; } diff --git a/src/crypto/backup.ts b/src/crypto/backup.ts index d71cce99c7d..d240bdab83c 100644 --- a/src/crypto/backup.ts +++ b/src/crypto/backup.ts @@ -25,7 +25,7 @@ import { MEGOLM_ALGORITHM, verifySignature } from "./olmlib"; import { DeviceInfo } from "./deviceinfo"; import { DeviceTrustLevel } from "./CrossSigning"; import { keyFromPassphrase } from "./key_passphrase"; -import { sleep } from "../utils"; +import { safeSet, sleep } from "../utils"; import { IndexedDBCryptoStore } from "./store/indexeddb-crypto-store"; import { encodeRecoveryKey } from "./recoverykey"; import { calculateKeyCheck, decryptAES, encryptAES, IEncryptedPayload } from "./aes"; @@ -498,9 +498,7 @@ export class BackupManager { const rooms: IKeyBackup["rooms"] = {}; for (const session of sessions) { const roomId = session.sessionData!.room_id; - if (rooms[roomId] === undefined) { - rooms[roomId] = { sessions: {} }; - } + safeSet(rooms, roomId, rooms[roomId] || { sessions: {} }); const sessionData = this.baseApis.crypto!.olmDevice.exportInboundGroupSession( session.senderKey, @@ -517,12 +515,12 @@ export class BackupManager { undefined; const verified = this.baseApis.crypto!.checkDeviceInfoTrust(userId!, device).isVerified(); - rooms[roomId]["sessions"][session.sessionId] = { + safeSet(rooms[roomId]["sessions"], session.sessionId, { first_message_index: sessionData.first_known_index, forwarded_count: forwardedCount, is_verified: verified, session_data: await this.algorithm!.encryptSession(sessionData), - }; + }); } await this.baseApis.sendKeyBackup(undefined, undefined, this.backupInfo!.version, { rooms }); diff --git a/src/crypto/index.ts b/src/crypto/index.ts index 5500872226f..2390268c16c 100644 --- a/src/crypto/index.ts +++ b/src/crypto/index.ts @@ -90,6 +90,7 @@ import { ISignatures } from "../@types/signed"; import { IMessage } from "./algorithms/olm"; import { CryptoBackend, OnSyncCompletedData } from "../common-crypto/CryptoBackend"; import { RoomState, RoomStateEvent } from "../models/room-state"; +import { MapWithDefault, recursiveMapToObject } from "../utils"; const DeviceVerification = DeviceInfo.DeviceVerification; @@ -399,7 +400,10 @@ export class Crypto extends TypedEventEmitter> = {}; + // Map: user Id โ†’ device Id โ†’ timestamp + private lastNewSessionForced: MapWithDefault> = new MapWithDefault( + () => new MapWithDefault(() => 0), + ); // This flag will be unset whilst the client processes a sync response // so that we don't start requesting keys until we've actually finished @@ -2690,11 +2694,13 @@ export class Crypto extends TypedEventEmitter>> { - const devicesByUser: Record = {}; + ): Promise>> { + // map user Id โ†’ DeviceInfo[] + const devicesByUser: Map = new Map(); for (const userId of users) { - devicesByUser[userId] = []; + const userDevices: DeviceInfo[] = []; + devicesByUser.set(userId, userDevices); const devices = this.getStoredDevicesForUser(userId) || []; for (const deviceInfo of devices) { @@ -2708,7 +2714,7 @@ export class Crypto extends TypedEventEmitter Date.now()) { logger.debug( "New session already forced with device " + @@ -3482,11 +3492,10 @@ export class Crypto extends TypedEventEmitter = {}; - devicesByUser[sender] = [device]; + const devicesByUser = new Map([[sender, [device]]]); await olmlib.ensureOlmSessionsForDevices(this.olmDevice, this.baseApis, devicesByUser, true); - this.lastNewSessionForced[sender][deviceKey] = Date.now(); + lastNewSessionDevices.set(deviceKey, Date.now()); // Now send a blank message on that session so the other side knows about it. // (The keyshare request is sent in the clear so that won't do) @@ -3513,11 +3522,10 @@ export class Crypto extends TypedEventEmitter(obj: T): Promise { - const sigs = obj.signatures || {}; + const sigs = new Map(Object.entries(obj.signatures || {})); const unsigned = obj.unsigned; delete obj.signatures; delete obj.unsigned; - sigs[this.userId] = sigs[this.userId] || {}; - sigs[this.userId]["ed25519:" + this.deviceId] = await this.olmDevice.sign(anotherjson.stringify(obj)); - obj.signatures = sigs; + const userSignatures = sigs.get(this.userId) || {}; + sigs.set(this.userId, userSignatures); + userSignatures["ed25519:" + this.deviceId] = await this.olmDevice.sign(anotherjson.stringify(obj)); + obj.signatures = recursiveMapToObject(sigs); if (unsigned !== undefined) obj.unsigned = unsigned; } } diff --git a/src/crypto/olmlib.ts b/src/crypto/olmlib.ts index 5f343771b17..c37b7f0a49e 100644 --- a/src/crypto/olmlib.ts +++ b/src/crypto/olmlib.ts @@ -30,6 +30,7 @@ import { ISignatures } from "../@types/signed"; import { MatrixEvent } from "../models/event"; import { EventType } from "../@types/event"; import { IMessage } from "./algorithms/olm"; +import { MapWithDefault } from "../utils"; enum Algorithm { Olm = "m.olm.v1.curve25519-aes-sha2", @@ -154,9 +155,11 @@ export async function getExistingOlmSessions( olmDevice: OlmDevice, baseApis: MatrixClient, devicesByUser: Record, -): Promise<[Record, Record>]> { - const devicesWithoutSession: { [userId: string]: DeviceInfo[] } = {}; - const sessions: { [userId: string]: { [deviceId: string]: IExistingOlmSession } } = {}; +): Promise<[Map, Map>]> { + // map user Id โ†’ DeviceInfo[] + const devicesWithoutSession: MapWithDefault = new MapWithDefault(() => []); + // map user Id โ†’ device Id โ†’ IExistingOlmSession + const sessions: MapWithDefault> = new MapWithDefault(() => new Map()); const promises: Promise[] = []; @@ -168,14 +171,12 @@ export async function getExistingOlmSessions( (async (): Promise => { const sessionId = await olmDevice.getSessionIdForDevice(key, true); if (sessionId === null) { - devicesWithoutSession[userId] = devicesWithoutSession[userId] || []; - devicesWithoutSession[userId].push(deviceInfo); + devicesWithoutSession.getOrCreate(userId).push(deviceInfo); } else { - sessions[userId] = sessions[userId] || {}; - sessions[userId][deviceId] = { + sessions.getOrCreate(userId).set(deviceId, { device: deviceInfo, sessionId: sessionId, - }; + }); } })(), ); @@ -210,24 +211,26 @@ export async function getExistingOlmSessions( export async function ensureOlmSessionsForDevices( olmDevice: OlmDevice, baseApis: MatrixClient, - devicesByUser: Record, + devicesByUser: Map, force = false, otkTimeout?: number, failedServers?: string[], log = logger, -): Promise>> { +): Promise>> { const devicesWithoutSession: [string, string][] = [ // [userId, deviceId], ... ]; - const result: { [userId: string]: { [deviceId: string]: IExistingOlmSession } } = {}; - const resolveSession: Record void> = {}; + // map user Id โ†’ device Id โ†’ IExistingOlmSession + const result: Map> = new Map(); + // map device key โ†’ resolve session fn + const resolveSession: Map void> = new Map(); // Mark all sessions this task intends to update as in progress. It is // important to do this for all devices this task cares about in a single // synchronous operation, as otherwise it is possible to have deadlocks // where multiple tasks wait indefinitely on another task to update some set // of common devices. - for (const [, devices] of Object.entries(devicesByUser)) { + for (const devices of devicesByUser.values()) { for (const deviceInfo of devices) { const key = deviceInfo.getIdentityKey(); @@ -242,17 +245,19 @@ export async function ensureOlmSessionsForDevices( // conditions. If we find that we already have a session, then // we'll resolve olmDevice.sessionsInProgress[key] = new Promise((resolve) => { - resolveSession[key] = (v: any): void => { + resolveSession.set(key, (v: any): void => { delete olmDevice.sessionsInProgress[key]; resolve(v); - }; + }); }); } } } - for (const [userId, devices] of Object.entries(devicesByUser)) { - result[userId] = {}; + for (const [userId, devices] of devicesByUser) { + const resultDevices = new Map(); + result.set(userId, resultDevices); + for (const deviceInfo of devices) { const deviceId = deviceInfo.deviceId; const key = deviceInfo.getIdentityKey(); @@ -268,20 +273,21 @@ export async function ensureOlmSessionsForDevices( log.info("Attempted to start session with ourself! Ignoring"); // We must fill in the section in the return value though, as callers // expect it to be there. - result[userId][deviceId] = { + resultDevices.set(deviceId, { device: deviceInfo, sessionId: null, - }; + }); continue; } const forWhom = `for ${key} (${userId}:${deviceId})`; - const sessionId = await olmDevice.getSessionIdForDevice(key, !!resolveSession[key], log); - if (sessionId !== null && resolveSession[key]) { + const sessionId = await olmDevice.getSessionIdForDevice(key, !!resolveSession.get(key), log); + const resolveSessionFn = resolveSession.get(key); + if (sessionId !== null && resolveSessionFn) { // we found a session, but we had marked the session as // in-progress, so resolve it now, which will unmark it and // unblock anything that was waiting - resolveSession[key](); + resolveSessionFn(); } if (sessionId === null || force) { if (force) { @@ -291,10 +297,10 @@ export async function ensureOlmSessionsForDevices( } devicesWithoutSession.push([userId, deviceId]); } - result[userId][deviceId] = { + resultDevices.set(deviceId, { device: deviceInfo, sessionId: sessionId, - }; + }); } } @@ -310,7 +316,7 @@ export async function ensureOlmSessionsForDevices( res = await baseApis.claimOneTimeKeys(devicesWithoutSession, oneTimeKeyAlgorithm, otkTimeout); log.debug(`Claimed ${taskDetail}`); } catch (e) { - for (const resolver of Object.values(resolveSession)) { + for (const resolver of resolveSession.values()) { resolver(); } log.log(`Failed to claim ${taskDetail}`, e, devicesWithoutSession); @@ -323,7 +329,7 @@ export async function ensureOlmSessionsForDevices( const otkResult = res.one_time_keys || ({} as IClaimOTKsResult["one_time_keys"]); const promises: Promise[] = []; - for (const [userId, devices] of Object.entries(devicesByUser)) { + for (const [userId, devices] of devicesByUser) { const userRes = otkResult[userId] || {}; for (const deviceInfo of devices) { const deviceId = deviceInfo.deviceId; @@ -336,7 +342,7 @@ export async function ensureOlmSessionsForDevices( continue; } - if (result[userId][deviceId].sessionId && !force) { + if (result.get(userId)?.get(deviceId)?.sessionId && !force) { // we already have a result for this device continue; } @@ -351,24 +357,19 @@ export async function ensureOlmSessionsForDevices( if (!oneTimeKey) { log.warn(`No one-time keys (alg=${oneTimeKeyAlgorithm}) ` + `for device ${userId}:${deviceId}`); - if (resolveSession[key]) { - resolveSession[key](); - } + resolveSession.get(key)?.(); continue; } promises.push( _verifyKeyAndStartSession(olmDevice, oneTimeKey, userId, deviceInfo).then( (sid) => { - if (resolveSession[key]) { - resolveSession[key](sid ?? undefined); - } - result[userId][deviceId].sessionId = sid; + resolveSession.get(key)?.(sid ?? undefined); + const deviceInfo = result.get(userId)?.get(deviceId); + if (deviceInfo) deviceInfo.sessionId = sid; }, (e) => { - if (resolveSession[key]) { - resolveSession[key](); - } + resolveSession.get(key)?.(); throw e; }, ), diff --git a/src/crypto/store/localStorage-crypto-store.ts b/src/crypto/store/localStorage-crypto-store.ts index 1a9adfb25d1..5552540f5d0 100644 --- a/src/crypto/store/localStorage-crypto-store.ts +++ b/src/crypto/store/localStorage-crypto-store.ts @@ -21,6 +21,7 @@ import { IOlmDevice } from "../algorithms/megolm"; import { IRoomEncryption } from "../RoomList"; import { ICrossSigningKey } from "../../client"; import { InboundGroupSessionData } from "../OlmDevice"; +import { safeSet } from "../../utils"; /** * Internal module. Partial localStorage backed storage for e2e. @@ -178,11 +179,11 @@ export class LocalStorageCryptoStore extends MemoryCryptoStore { if (userId in notifiedErrorDevices) { if (!(deviceInfo.deviceId in notifiedErrorDevices[userId])) { ret.push(device); - notifiedErrorDevices[userId][deviceInfo.deviceId] = true; + safeSet(notifiedErrorDevices[userId], deviceInfo.deviceId, true); } } else { ret.push(device); - notifiedErrorDevices[userId] = { [deviceInfo.deviceId]: true }; + safeSet(notifiedErrorDevices, userId, { [deviceInfo.deviceId]: true }); } } diff --git a/src/crypto/store/memory-crypto-store.ts b/src/crypto/store/memory-crypto-store.ts index ad779ca993b..29ae81bfede 100644 --- a/src/crypto/store/memory-crypto-store.ts +++ b/src/crypto/store/memory-crypto-store.ts @@ -33,6 +33,7 @@ import { ICrossSigningKey } from "../../client"; import { IOlmDevice } from "../algorithms/megolm"; import { IRoomEncryption } from "../RoomList"; import { InboundGroupSessionData } from "../OlmDevice"; +import { safeSet } from "../../utils"; /** * Internal module. in-memory storage for e2e. @@ -375,11 +376,11 @@ export class MemoryCryptoStore implements CryptoStore { if (userId in notifiedErrorDevices) { if (!(deviceInfo.deviceId in notifiedErrorDevices[userId])) { ret.push(device); - notifiedErrorDevices[userId][deviceInfo.deviceId] = true; + safeSet(notifiedErrorDevices[userId], deviceInfo.deviceId, true); } } else { ret.push(device); - notifiedErrorDevices[userId] = { [deviceInfo.deviceId]: true }; + safeSet(notifiedErrorDevices, userId, { [deviceInfo.deviceId]: true }); } } diff --git a/src/crypto/verification/request/ToDeviceChannel.ts b/src/crypto/verification/request/ToDeviceChannel.ts index cb9066457e6..d51b85ac75f 100644 --- a/src/crypto/verification/request/ToDeviceChannel.ts +++ b/src/crypto/verification/request/ToDeviceChannel.ts @@ -269,12 +269,12 @@ export class ToDeviceChannel implements IVerificationChannel { private async sendToDevices(type: string, content: Record, devices: string[]): Promise { if (devices.length) { - const msgMap: Record> = {}; + const deviceMessages: Map> = new Map(); for (const deviceId of devices) { - msgMap[deviceId] = content; + deviceMessages.set(deviceId, content); } - await this.client.sendToDevice(type, { [this.userId]: msgMap }); + await this.client.sendToDevice(type, new Map([[this.userId, deviceMessages]])); } } diff --git a/src/embedded.ts b/src/embedded.ts index 02dac6db2ce..a08b79a2f5c 100644 --- a/src/embedded.ts +++ b/src/embedded.ts @@ -29,15 +29,16 @@ import { IEvent, IContent, EventStatus } from "./models/event"; import { ISendEventResponse } from "./@types/requests"; import { EventType } from "./@types/event"; import { logger } from "./logger"; -import { MatrixClient, ClientEvent, IMatrixClientCreateOpts, IStartClientOpts } from "./client"; +import { MatrixClient, ClientEvent, IMatrixClientCreateOpts, IStartClientOpts, SendToDeviceContentMap } from "./client"; import { SyncApi, SyncState } from "./sync"; import { SlidingSyncSdk } from "./sliding-sync-sdk"; import { MatrixEvent } from "./models/event"; import { User } from "./models/user"; import { Room } from "./models/room"; -import { ToDeviceBatch } from "./models/ToDeviceMessage"; +import { ToDeviceBatch, ToDevicePayload } from "./models/ToDeviceMessage"; import { DeviceInfo } from "./crypto/deviceinfo"; import { IOlmDevice } from "./crypto/algorithms/megolm"; +import { MapWithDefault, recursiveMapToObject } from "./utils"; interface IStateEventRequest { eventType: string; @@ -234,35 +235,32 @@ export class RoomWidgetClient extends MatrixClient { return await this.widgetApi.sendStateEvent(eventType, stateKey, content, roomId); } - public async sendToDevice( - eventType: string, - contentMap: { [userId: string]: { [deviceId: string]: Record } }, - ): Promise<{}> { - await this.widgetApi.sendToDevice(eventType, false, contentMap); + public async sendToDevice(eventType: string, contentMap: SendToDeviceContentMap): Promise<{}> { + await this.widgetApi.sendToDevice(eventType, false, recursiveMapToObject(contentMap)); return {}; } public async queueToDevice({ eventType, batch }: ToDeviceBatch): Promise { - const contentMap: { [userId: string]: { [deviceId: string]: object } } = {}; + // map: user Id โ†’ device Id โ†’ payload + const contentMap: MapWithDefault> = new MapWithDefault(() => new Map()); for (const { userId, deviceId, payload } of batch) { - if (!contentMap[userId]) contentMap[userId] = {}; - contentMap[userId][deviceId] = payload; + contentMap.getOrCreate(userId).set(deviceId, payload); } - await this.widgetApi.sendToDevice(eventType, false, contentMap); + await this.widgetApi.sendToDevice(eventType, false, recursiveMapToObject(contentMap)); } public async encryptAndSendToDevices(userDeviceInfoArr: IOlmDevice[], payload: object): Promise { - const contentMap: { [userId: string]: { [deviceId: string]: object } } = {}; + // map: user Id โ†’ device Id โ†’ payload + const contentMap: MapWithDefault> = new MapWithDefault(() => new Map()); for (const { userId, deviceInfo: { deviceId }, } of userDeviceInfoArr) { - if (!contentMap[userId]) contentMap[userId] = {}; - contentMap[userId][deviceId] = payload; + contentMap.getOrCreate(userId).set(deviceId, payload); } - await this.widgetApi.sendToDevice((payload as { type: string }).type, true, contentMap); + await this.widgetApi.sendToDevice((payload as { type: string }).type, true, recursiveMapToObject(contentMap)); } // Overridden since we get TURN servers automatically over the widget API, diff --git a/src/models/read-receipt.ts b/src/models/read-receipt.ts index 85c1929320c..5858fe5bb99 100644 --- a/src/models/read-receipt.ts +++ b/src/models/read-receipt.ts @@ -16,7 +16,6 @@ import { MAIN_ROOM_TIMELINE, Receipt, ReceiptCache, - Receipts, ReceiptType, WrappedReceipt, } from "../@types/read_receipts"; @@ -25,6 +24,7 @@ import * as utils from "../utils"; import { MatrixEvent } from "./event"; import { EventType } from "../@types/event"; import { EventTimelineSet } from "./event-timeline-set"; +import { MapWithDefault } from "../utils"; import { NotificationCountType } from "./room"; export function synthesizeReceipt(userId: string, event: MatrixEvent, receiptType: ReceiptType): MatrixEvent { @@ -56,8 +56,11 @@ export abstract class ReadReceipt< // the form of this structure. This is sub-optimal for the exposed APIs // which pass in an event ID and get back some receipts, so we also store // a pre-cached list for this purpose. - private receipts: Receipts = {}; // { receipt_type: { user_id: Receipt } } - private receiptCacheByEventId: ReceiptCache = {}; // { event_id: CachedReceipt[] } + // Map: receipt type โ†’ user Id โ†’ receipt + private receipts = new MapWithDefault>( + () => new Map(), + ); + private receiptCacheByEventId: ReceiptCache = new Map(); public abstract getUnfilteredTimelineSet(): EventTimelineSet; public abstract timeline: MatrixEvent[]; @@ -74,7 +77,7 @@ export abstract class ReadReceipt< ignoreSynthesized = false, receiptType = ReceiptType.Read, ): WrappedReceipt | null { - const [realReceipt, syntheticReceipt] = this.receipts[receiptType]?.[userId] ?? []; + const [realReceipt, syntheticReceipt] = this.receipts.get(receiptType)?.get(userId) ?? [null, null]; if (ignoreSynthesized) { return realReceipt; } @@ -126,14 +129,13 @@ export abstract class ReadReceipt< receipt: Receipt, synthetic: boolean, ): void { - if (!this.receipts[receiptType]) { - this.receipts[receiptType] = {}; - } - if (!this.receipts[receiptType][userId]) { - this.receipts[receiptType][userId] = [null, null]; - } + const receiptTypesMap = this.receipts.getOrCreate(receiptType); + let pair = receiptTypesMap.get(userId); - const pair = this.receipts[receiptType][userId]; + if (!pair) { + pair = [null, null]; + receiptTypesMap.set(userId, pair); + } let existingReceipt = pair[ReceiptPairRealIndex]; if (synthetic) { @@ -185,23 +187,26 @@ export abstract class ReadReceipt< if (cachedReceipt === newCachedReceipt) return; // clean up any previous cache entry - if (cachedReceipt && this.receiptCacheByEventId[cachedReceipt.eventId]) { + if (cachedReceipt && this.receiptCacheByEventId.get(cachedReceipt.eventId)) { const previousEventId = cachedReceipt.eventId; // Remove the receipt we're about to clobber out of existence from the cache - this.receiptCacheByEventId[previousEventId] = this.receiptCacheByEventId[previousEventId].filter((r) => { - return r.type !== receiptType || r.userId !== userId; - }); + this.receiptCacheByEventId.set( + previousEventId, + this.receiptCacheByEventId.get(previousEventId)!.filter((r) => { + return r.type !== receiptType || r.userId !== userId; + }), + ); - if (this.receiptCacheByEventId[previousEventId].length < 1) { - delete this.receiptCacheByEventId[previousEventId]; // clean up the cache keys + if (this.receiptCacheByEventId.get(previousEventId)!.length < 1) { + this.receiptCacheByEventId.delete(previousEventId); // clean up the cache keys } } // cache the new one - if (!this.receiptCacheByEventId[eventId]) { - this.receiptCacheByEventId[eventId] = []; + if (!this.receiptCacheByEventId.get(eventId)) { + this.receiptCacheByEventId.set(eventId, []); } - this.receiptCacheByEventId[eventId].push({ + this.receiptCacheByEventId.get(eventId)!.push({ userId: userId, type: receiptType as ReceiptType, data: receipt, @@ -215,7 +220,7 @@ export abstract class ReadReceipt< * an empty list. */ public getReceiptsForEvent(event: MatrixEvent): CachedReceipt[] { - return this.receiptCacheByEventId[event.getId()!] || []; + return this.receiptCacheByEventId.get(event.getId()!) || []; } public abstract addReceipt(event: MatrixEvent, synthetic: boolean): void; diff --git a/src/models/room.ts b/src/models/room.ts index 399f9861437..0254429139c 100644 --- a/src/models/room.ts +++ b/src/models/room.ts @@ -25,7 +25,7 @@ import { import { Direction, EventTimeline } from "./event-timeline"; import { getHttpUriForMxc } from "../content-repo"; import * as utils from "../utils"; -import { normalize } from "../utils"; +import { normalize, noUnsafeEventProps } from "../utils"; import { IEvent, IThreadBundledRelationship, MatrixEvent, MatrixEventEvent, MatrixEventHandlerMap } from "./event"; import { EventStatus } from "./event-status"; import { RoomMember } from "./room-member"; @@ -311,7 +311,7 @@ export type RoomEventHandlerMap = { export class Room extends ReadReceipt { public readonly reEmitter: TypedReEmitter; - private txnToEvent: Record = {}; // Pending in-flight requests { string: MatrixEvent } + private txnToEvent: Map = new Map(); // Pending in-flight requests { string: MatrixEvent } private notificationCounts: NotificationCount = {}; private readonly threadNotifications = new Map(); public readonly cachedThreadReadReceipts = new Map(); @@ -356,7 +356,7 @@ export class Room extends ReadReceipt { * accountData Dict of per-room account_data events; the keys are the * event type and the values are the events. */ - public accountData: Record = {}; // $eventType: $event + public accountData: Map = new Map(); // $eventType: $event /** * The room summary. */ @@ -902,7 +902,7 @@ export class Room extends ReadReceipt { rawMembersEvents = await this.loadMembersFromServer(); logger.log(`LL: got ${rawMembersEvents.length} ` + `members from server for room ${this.roomId}`); } - const memberEvents = rawMembersEvents.map(this.client.getEventMapper()); + const memberEvents = rawMembersEvents.filter(noUnsafeEventProps).map(this.client.getEventMapper()); return { memberEvents, fromServer }; } @@ -2252,8 +2252,7 @@ export class Room extends ReadReceipt { const txnId = event.getUnsigned().transaction_id; if (!txnId && event.getSender() === this.myUserId) { // check the txn map for a matching event ID - for (const tid in this.txnToEvent) { - const localEvent = this.txnToEvent[tid]; + for (const [tid, localEvent] of this.txnToEvent) { if (localEvent.getId() === event.getId()) { logger.debug("processLiveEvent: found sent event without txn ID: ", tid, event.getId()); // update the unsigned field so we can re-use the same codepaths @@ -2328,7 +2327,7 @@ export class Room extends ReadReceipt { throw new Error("addPendingEvent called on an event with status " + event.status); } - if (this.txnToEvent[txnId]) { + if (this.txnToEvent.get(txnId)) { throw new Error("addPendingEvent called on an event with known txnId " + txnId); } @@ -2337,7 +2336,7 @@ export class Room extends ReadReceipt { // on the unfiltered timelineSet. EventTimeline.setEventMetadata(event, this.getLiveTimeline().getState(EventTimeline.FORWARDS)!, false); - this.txnToEvent[txnId] = event; + this.txnToEvent.set(txnId, event); if (this.pendingEventList) { if (this.pendingEventList.some((e) => e.status === EventStatus.NOT_SENT)) { logger.warn("Setting event as NOT_SENT due to messages in the same state"); @@ -2429,8 +2428,8 @@ export class Room extends ReadReceipt { this.relations.aggregateChildEvent(event); } - public getEventForTxnId(txnId: string): MatrixEvent { - return this.txnToEvent[txnId]; + public getEventForTxnId(txnId: string): MatrixEvent | undefined { + return this.txnToEvent.get(txnId); } /** @@ -2457,7 +2456,7 @@ export class Room extends ReadReceipt { logger.debug(`Got remote echo for event ${oldEventId} -> ${newEventId} old status ${oldStatus}`); // no longer pending - delete this.txnToEvent[remoteEvent.getUnsigned().transaction_id!]; + this.txnToEvent.delete(remoteEvent.getUnsigned().transaction_id!); // if it's in the pending list, remove it if (this.pendingEventList) { @@ -2670,7 +2669,7 @@ export class Room extends ReadReceipt { this.processLiveEvent(event); if (event.getUnsigned().transaction_id) { - const existingEvent = this.txnToEvent[event.getUnsigned().transaction_id!]; + const existingEvent = this.txnToEvent.get(event.getUnsigned().transaction_id!); if (existingEvent) { // remote echo of an event we sent earlier this.handleRemoteEcho(event, existingEvent); @@ -2939,8 +2938,9 @@ export class Room extends ReadReceipt { if (event.getType() === "m.tag") { this.addTags(event); } - const lastEvent = this.accountData[event.getType()]; - this.accountData[event.getType()] = event; + const eventType = event.getType(); + const lastEvent = this.accountData.get(eventType); + this.accountData.set(eventType, event); this.emit(RoomEvent.AccountData, event, this, lastEvent); } } @@ -2951,7 +2951,7 @@ export class Room extends ReadReceipt { * @returns the account_data event in question */ public getAccountData(type: EventType | string): MatrixEvent | undefined { - return this.accountData[type]; + return this.accountData.get(type); } /** diff --git a/src/store/index.ts b/src/store/index.ts index 84f613cf4dd..78d4fe13ce4 100644 --- a/src/store/index.ts +++ b/src/store/index.ts @@ -36,7 +36,7 @@ export interface ISavedSync { * A store for most of the data js-sdk needs to store, apart from crypto data */ export interface IStore { - readonly accountData: Record; // type : content + readonly accountData: Map; // type : content // XXX: The indexeddb store exposes a non-standard emitter for the "degraded" event // for when it falls back to being a memory store due to errors. diff --git a/src/store/memory.ts b/src/store/memory.ts index 025a632aaf2..d859dddc10b 100644 --- a/src/store/memory.ts +++ b/src/store/memory.ts @@ -31,6 +31,7 @@ import { ISyncResponse } from "../sync-accumulator"; import { IStateEventWithRoomId } from "../@types/search"; import { IndexedToDeviceBatch, ToDeviceBatchWithTxnId } from "../models/ToDeviceMessage"; import { IStoredClientOpts } from "../client"; +import { MapWithDefault } from "../utils"; function isValidFilterId(filterId?: string | number | null): boolean { const isValidStr = @@ -54,10 +55,10 @@ export class MemoryStore implements IStore { // userId: { // filterId: Filter // } - private filters: Record> = {}; - public accountData: Record = {}; // type : content + private filters: MapWithDefault> = new MapWithDefault(() => new Map()); + public accountData: Map = new Map(); // type: content protected readonly localStorage?: Storage; - private oobMembers: Record = {}; // roomId: [member events] + private oobMembers: Map = new Map(); // roomId: [member events] private pendingEvents: { [roomId: string]: Partial[] } = {}; private clientOptions?: IStoredClientOpts; private pendingToDeviceBatches: IndexedToDeviceBatch[] = []; @@ -220,10 +221,7 @@ export class MemoryStore implements IStore { */ public storeFilter(filter: Filter): void { if (!filter?.userId || !filter?.filterId) return; - if (!this.filters[filter.userId]) { - this.filters[filter.userId] = {}; - } - this.filters[filter.userId][filter.filterId] = filter; + this.filters.getOrCreate(filter.userId).set(filter.filterId, filter); } /** @@ -231,10 +229,7 @@ export class MemoryStore implements IStore { * @returns A filter or null. */ public getFilter(userId: string, filterId: string): Filter | null { - if (!this.filters[userId] || !this.filters[userId][filterId]) { - return null; - } - return this.filters[userId][filterId]; + return this.filters.get(userId)?.get(filterId) || null; } /** @@ -289,9 +284,9 @@ export class MemoryStore implements IStore { // MSC3391: an event with content of {} should be interpreted as deleted const isDeleted = !Object.keys(event.getContent()).length; if (isDeleted) { - delete this.accountData[event.getType()]; + this.accountData.delete(event.getType()); } else { - this.accountData[event.getType()] = event; + this.accountData.set(event.getType(), event); } }); } @@ -302,7 +297,7 @@ export class MemoryStore implements IStore { * @returns the user account_data event of given type, if any */ public getAccountData(eventType: EventType | string): MatrixEvent | undefined { - return this.accountData[eventType]; + return this.accountData.get(eventType); } /** @@ -368,14 +363,8 @@ export class MemoryStore implements IStore { // userId: User }; this.syncToken = null; - this.filters = { - // userId: { - // filterId: Filter - // } - }; - this.accountData = { - // type : content - }; + this.filters = new MapWithDefault(() => new Map()); + this.accountData = new Map(); // type : content return Promise.resolve(); } @@ -386,7 +375,7 @@ export class MemoryStore implements IStore { * @returns in case the members for this room haven't been stored yet */ public getOutOfBandMembers(roomId: string): Promise { - return Promise.resolve(this.oobMembers[roomId] || null); + return Promise.resolve(this.oobMembers.get(roomId) || null); } /** @@ -397,12 +386,12 @@ export class MemoryStore implements IStore { * @returns when all members have been stored */ public setOutOfBandMembers(roomId: string, membershipEvents: IStateEventWithRoomId[]): Promise { - this.oobMembers[roomId] = membershipEvents; + this.oobMembers.set(roomId, membershipEvents); return Promise.resolve(); } public clearOutOfBandMembers(roomId: string): Promise { - this.oobMembers = {}; + this.oobMembers.delete(roomId); return Promise.resolve(); } diff --git a/src/store/stub.ts b/src/store/stub.ts index 445f9e8ff27..e4402ed1e40 100644 --- a/src/store/stub.ts +++ b/src/store/stub.ts @@ -34,7 +34,7 @@ import { IStoredClientOpts } from "../client"; * Construct a stub store. This does no-ops on most store methods. */ export class StubStore implements IStore { - public readonly accountData = {}; // stub + public readonly accountData = new Map(); // stub private fromToken: string | null = null; /** @returns whether or not the database was newly created in this session. */ diff --git a/src/sync-accumulator.ts b/src/sync-accumulator.ts index be1e8a16ffa..fef03d74ce0 100644 --- a/src/sync-accumulator.ts +++ b/src/sync-accumulator.ts @@ -19,7 +19,7 @@ limitations under the License. */ import { logger } from "./logger"; -import { deepCopy, isSupportedReceiptType } from "./utils"; +import { deepCopy, isSupportedReceiptType, MapWithDefault, recursiveMapToObject } from "./utils"; import { IContent, IUnsigned } from "./models/event"; import { IRoomSummary } from "./models/room-summary"; import { EventType } from "./@types/event"; @@ -585,29 +585,31 @@ export class SyncAccumulator { } as IContent, }; + const receiptEventContent: MapWithDefault< + string, + MapWithDefault> + > = new MapWithDefault(() => new MapWithDefault(() => new Map())); + for (const [userId, receiptData] of Object.entries(roomData._readReceipts)) { - if (!receiptEvent.content[receiptData.eventId]) { - receiptEvent.content[receiptData.eventId] = {}; - } - if (!receiptEvent.content[receiptData.eventId][receiptData.type]) { - receiptEvent.content[receiptData.eventId][receiptData.type] = {}; - } - receiptEvent.content[receiptData.eventId][receiptData.type][userId] = receiptData.data; + receiptEventContent + .getOrCreate(receiptData.eventId) + .getOrCreate(receiptData.type) + .set(userId, receiptData.data); } for (const threadReceipts of Object.values(roomData._threadReadReceipts)) { for (const [userId, receiptData] of Object.entries(threadReceipts)) { - if (!receiptEvent.content[receiptData.eventId]) { - receiptEvent.content[receiptData.eventId] = {}; - } - if (!receiptEvent.content[receiptData.eventId][receiptData.type]) { - receiptEvent.content[receiptData.eventId][receiptData.type] = {}; - } - receiptEvent.content[receiptData.eventId][receiptData.type][userId] = receiptData.data; + receiptEventContent + .getOrCreate(receiptData.eventId) + .getOrCreate(receiptData.type) + .set(userId, receiptData.data); } } + + receiptEvent.content = recursiveMapToObject(receiptEventContent); + // add only if we have some receipt data - if (Object.keys(receiptEvent.content).length > 0) { + if (receiptEventContent.size > 0) { roomJson.ephemeral.events.push(receiptEvent as IMinimalEvent); } diff --git a/src/sync.ts b/src/sync.ts index 057ee1540da..3fa3616c883 100644 --- a/src/sync.ts +++ b/src/sync.ts @@ -29,7 +29,7 @@ import type { SyncCryptoCallbacks } from "./common-crypto/CryptoBackend"; import { User, UserEvent } from "./models/user"; import { NotificationCountType, Room, RoomEvent } from "./models/room"; import * as utils from "./utils"; -import { IDeferred } from "./utils"; +import { IDeferred, noUnsafeEventProps, unsafeProp } from "./utils"; import { Filter } from "./filter"; import { EventTimeline } from "./models/event-timeline"; import { logger } from "./logger"; @@ -1133,22 +1133,24 @@ export class SyncApi { // handle presence events (User objects) if (Array.isArray(data.presence?.events)) { - data.presence!.events.map(client.getEventMapper()).forEach(function (presenceEvent) { - let user = client.store.getUser(presenceEvent.getSender()!); - if (user) { - user.setPresenceEvent(presenceEvent); - } else { - user = createNewUser(client, presenceEvent.getSender()!); - user.setPresenceEvent(presenceEvent); - client.store.storeUser(user); - } - client.emit(ClientEvent.Event, presenceEvent); - }); + data.presence!.events.filter(noUnsafeEventProps) + .map(client.getEventMapper()) + .forEach(function (presenceEvent) { + let user = client.store.getUser(presenceEvent.getSender()!); + if (user) { + user.setPresenceEvent(presenceEvent); + } else { + user = createNewUser(client, presenceEvent.getSender()!); + user.setPresenceEvent(presenceEvent); + client.store.storeUser(user); + } + client.emit(ClientEvent.Event, presenceEvent); + }); } // handle non-room account_data if (Array.isArray(data.account_data?.events)) { - const events = data.account_data.events.map(client.getEventMapper()); + const events = data.account_data.events.filter(noUnsafeEventProps).map(client.getEventMapper()); const prevEventsMap = events.reduce>((m, c) => { m[c.getType()!] = client.store.getAccountData(c.getType()); return m; @@ -1171,7 +1173,7 @@ export class SyncApi { // handle to-device events if (data.to_device && Array.isArray(data.to_device.events) && data.to_device.events.length > 0) { - let toDeviceMessages: IToDeviceEvent[] = data.to_device.events; + let toDeviceMessages: IToDeviceEvent[] = data.to_device.events.filter(noUnsafeEventProps); if (this.syncOpts.cryptoCallbacks) { toDeviceMessages = await this.syncOpts.cryptoCallbacks.preprocessToDeviceMessages(toDeviceMessages); @@ -1635,18 +1637,20 @@ export class SyncApi { // to // [{stuff+Room+isBrandNewRoom}, {stuff+Room+isBrandNewRoom}] const client = this.client; - return Object.keys(obj).map((roomId) => { - const arrObj = obj[roomId] as T & { room: Room; isBrandNewRoom: boolean }; - let room = client.store.getRoom(roomId); - let isBrandNewRoom = false; - if (!room) { - room = this.createRoom(roomId); - isBrandNewRoom = true; - } - arrObj.room = room; - arrObj.isBrandNewRoom = isBrandNewRoom; - return arrObj; - }); + return Object.keys(obj) + .filter((k) => !unsafeProp(k)) + .map((roomId) => { + const arrObj = obj[roomId] as T & { room: Room; isBrandNewRoom: boolean }; + let room = client.store.getRoom(roomId); + let isBrandNewRoom = false; + if (!room) { + room = this.createRoom(roomId); + isBrandNewRoom = true; + } + arrObj.room = room; + arrObj.isBrandNewRoom = isBrandNewRoom; + return arrObj; + }); } private mapSyncEventsFormat( @@ -1659,7 +1663,7 @@ export class SyncApi { } const mapper = this.client.getEventMapper({ decrypt }); type TaggedEvent = (IStrippedState | IRoomEvent | IStateEvent | IMinimalEvent) & { room_id?: string }; - return (obj.events as TaggedEvent[]).map(function (e) { + return (obj.events as TaggedEvent[]).filter(noUnsafeEventProps).map(function (e) { if (room) { e.room_id = room.roomId; } diff --git a/src/utils.ts b/src/utils.ts index 6a15e97444c..0c3aea773cd 100644 --- a/src/utils.ts +++ b/src/utils.ts @@ -22,7 +22,7 @@ import unhomoglyph from "unhomoglyph"; import promiseRetry from "p-retry"; import { Optional } from "matrix-events-sdk"; -import { MatrixEvent } from "./models/event"; +import { IEvent, MatrixEvent } from "./models/event"; import { M_TIMESTAMP } from "./@types/location"; import { ReceiptType } from "./@types/read_receipts"; @@ -703,3 +703,68 @@ export function mapsEqual(x: Map, y: Map, eq = (v1: V, v2: V): } return true; } + +function processMapToObjectValue(value: any): any { + if (value instanceof Map) { + // Value is a Map. Recursively map it to an object. + return recursiveMapToObject(value); + } else if (Array.isArray(value)) { + // Value is an Array. Recursively map the value (e.g. to cover Array of Arrays). + return value.map((v) => processMapToObjectValue(v)); + } else { + return value; + } +} + +/** + * Recursively converts Maps to plain objects. + * Also supports sub-lists of Maps. + */ +export function recursiveMapToObject(map: Map): any { + const targetMap = new Map(); + + for (const [key, value] of map) { + targetMap.set(key, processMapToObjectValue(value)); + } + + return Object.fromEntries(targetMap.entries()); +} + +export function unsafeProp(prop: K): boolean { + return prop === "__proto__" || prop === "prototype" || prop === "constructor"; +} + +export function safeSet(obj: Record, prop: K, value: any): void { + if (unsafeProp(prop)) { + throw new Error("Trying to modify prototype or constructor"); + } + + obj[prop] = value; +} + +export function noUnsafeEventProps(event: Partial): boolean { + return !( + unsafeProp(event.room_id) || + unsafeProp(event.sender) || + unsafeProp(event.user_id) || + unsafeProp(event.event_id) + ); +} + +export class MapWithDefault extends Map { + public constructor(private createDefault: () => V) { + super(); + } + + /** + * Returns the value if the key already exists. + * If not, it creates a new value under that key using the ctor callback and returns it. + */ + public getOrCreate(key: K): V { + if (!this.has(key)) { + this.set(key, this.createDefault()); + } + + return this.get(key)!; + } +} diff --git a/src/webrtc/call.ts b/src/webrtc/call.ts index 6d48c1c06dc..17318b0c205 100644 --- a/src/webrtc/call.ts +++ b/src/webrtc/call.ts @@ -600,7 +600,7 @@ export class MatrixCall extends TypedEventEmitter([[userId, new Map([[this.opponentDeviceId, content]])]]), + ); } } else { this.emit(CallEvent.SendVoipEvent, {