Skip to content

Commit

Permalink
Improve #3215 implementation (#3226)
Browse files Browse the repository at this point in the history
* Improve key upload request

* Add fallback keys check

* Review fixes

* Add comments about sliding sync usage of `processKeyCounts`

* Review fixes

* Better wording
  • Loading branch information
florianduros authored Apr 5, 2023
1 parent 6ebbc15 commit 2daa429
Show file tree
Hide file tree
Showing 6 changed files with 64 additions and 84 deletions.
44 changes: 24 additions & 20 deletions spec/integ/crypto.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -1937,51 +1937,51 @@ describe.each(Object.entries(CRYPTO_BACKENDS))("crypto (%s)", (backend: string,
jest.useRealTimers();
});

function listenToUpload(): Promise<number> {
function awaitKeyUploadRequest(): Promise<{ keysCount: number; fallbackKeysCount: number }> {
return new Promise((resolve) => {
const listener = (url: string, options: RequestInit) => {
const content = JSON.parse(options.body as string);
const keysCount = Object.keys(content?.one_time_keys || {}).length;
if (keysCount) resolve(keysCount);
const fallbackKeysCount = Object.keys(content?.fallback_keys || {}).length;
if (keysCount) resolve({ keysCount, fallbackKeysCount });
return {
one_time_key_counts: {
// The matrix client does `/upload` requests until 50 keys are uploaded
// We return here 60 to avoid the `/upload` request loop
signed_curve25519: keysCount ? 60 : keysCount,
},
};
};

// catch both r0 and v3 variants
fetchMock.post(
new URL("/_matrix/client/r0/keys/upload", aliceClient.getHomeserverUrl()).toString(),
listener,
{
for (const path of ["/_matrix/client/r0/keys/upload", "/_matrix/client/v3/keys/upload"]) {
fetchMock.post(new URL(path, aliceClient.getHomeserverUrl()).toString(), listener, {
// These routes are already defined in the E2EKeyReceiver
// We want to overwrite the behaviour of the E2EKeyReceiver
overwriteRoutes: true,
},
);
fetchMock.post(
new URL("/_matrix/client/v3/keys/upload", aliceClient.getHomeserverUrl()).toString(),
listener,
{
overwriteRoutes: true,
},
);
});
}
});
}

it("should make key upload request after sync", async () => {
let uploadPromise = listenToUpload();
let uploadPromise = awaitKeyUploadRequest();
expectAliceKeyQuery({ device_keys: { "@alice:localhost": {} }, failures: {} });
await startClientAndAwaitFirstSync();

syncResponder.sendOrQueueSyncResponse(getSyncResponse([]));

await syncPromise(aliceClient);
expect(await uploadPromise).toBeGreaterThan(0);

uploadPromise = listenToUpload();
// Verify that `/upload` is called on Alice's homesever
const { keysCount, fallbackKeysCount } = await uploadPromise;
expect(keysCount).toBeGreaterThan(0);
expect(fallbackKeysCount).toBe(0);

uploadPromise = awaitKeyUploadRequest();
syncResponder.sendOrQueueSyncResponse({
next_batch: 2,
device_one_time_keys_count: { signed_curve25519: 0 },
device_unused_fallback_key_types: [],
});

// Advance local date to 2 minutes
Expand All @@ -1990,7 +1990,11 @@ describe.each(Object.entries(CRYPTO_BACKENDS))("crypto (%s)", (backend: string,

await syncPromise(aliceClient);

expect(await uploadPromise).toBeGreaterThan(0);
// After we set device_one_time_keys_count to 0
// a `/upload` is expected
const res = await uploadPromise;
expect(res.keysCount).toBeGreaterThan(0);
expect(res.fallbackKeysCount).toBeGreaterThan(0);
});
});
});
10 changes: 0 additions & 10 deletions spec/unit/rust-crypto/rust-crypto.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -92,16 +92,6 @@ describe("RustCrypto", () => {
const res = await rustCrypto.preprocessToDeviceMessages(inputs);
expect(res).toEqual(inputs);
});

it("should pass through one time key counts", async () => {
const oneTimeKeyCounts = new Map<string, number>([["signed_curve25519", 50]]);
await expect(rustCrypto.preprocessOneTimeKeyCounts(oneTimeKeyCounts)).resolves.not.toBeDefined();
});

it("should pass through unused fallback keys", async () => {
const unusedFallbackKeys = new Set(["signed_curve25519"]);
await expect(rustCrypto.preprocessUnusedFallbackKeys(unusedFallbackKeys)).resolves.not.toBeDefined();
});
});

describe("outgoing requests", () => {
Expand Down
24 changes: 2 additions & 22 deletions src/common-crypto/CryptoBackend.ts
Original file line number Diff line number Diff line change
Expand Up @@ -106,32 +106,12 @@ export interface SyncCryptoCallbacks {
preprocessToDeviceMessages(events: IToDeviceEvent[]): Promise<IToDeviceEvent[]>;

/**
* Called by the /sync loop whenever there are incoming to-device messages.
*
* The implementation may preprocess the received messages (eg, decrypt them) and return an
* updated list of messages for dispatch to the rest of the system.
*
* Note that, unlike {@link ClientEvent.ToDeviceEvent} events, this is called on the raw to-device
* messages, rather than the results of any decryption attempts.
* Called by the /sync loop when one time key counts and unused fallback key details are received.
*
* @param oneTimeKeysCounts - the received one time key counts
* @returns A list of preprocessed to-device messages.
*/
preprocessOneTimeKeyCounts(oneTimeKeysCounts: Map<string, number>): Promise<void>;

/**
* Called by the /sync loop whenever there are incoming to-device messages.
*
* The implementation may preprocess the received messages (eg, decrypt them) and return an
* updated list of messages for dispatch to the rest of the system.
*
* Note that, unlike {@link ClientEvent.ToDeviceEvent} events, this is called on the raw to-device
* messages, rather than the results of any decryption attempts.
*
* @param unusedFallbackKeys - the received unused fallback keys
* @returns A list of preprocessed to-device messages.
*/
preprocessUnusedFallbackKeys(unusedFallbackKeys: Set<string>): Promise<void>;
processKeyCounts(oneTimeKeysCounts?: Record<string, number>, unusedFallbackKeys?: string[]): Promise<void>;

/**
* Called by the /sync loop whenever an m.room.encryption event is received.
Expand Down
26 changes: 19 additions & 7 deletions src/crypto/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -1822,6 +1822,10 @@ export class Crypto extends TypedEventEmitter<CryptoEvent, CryptoEventHandlerMap
* onSyncCompleted). The count is e.g. coming from a /sync response.
*
* @param currentCount - The current count of one_time_keys to be stored
*
* TODO This method is called in `processKeyCounts` method and in the `sliding-sync-sdk`.
* TODO The `sliding-sync-sdk` should call `processKeyCounts` directly instead of `updateOneTimeKeyCount`
* TODO Move the content of `updateOneTimeKeyCount` in `processKeyCounts` after the `sliding-sync-sdk` change
*/
public updateOneTimeKeyCount(currentCount: number): void {
if (isFinite(currentCount)) {
Expand All @@ -1831,6 +1835,9 @@ export class Crypto extends TypedEventEmitter<CryptoEvent, CryptoEventHandlerMap
}
}

// TODO This method is called in `processKeyCounts`, `uploadOneTimeKeys` methods and in the `sliding-sync-sdk`.
// TODO The `sliding-sync-sdk` should call `processKeyCounts` directly instead of `setNeedsNewFallback`
// TODO Move the content of `setNeedsNewFallback` in `processKeyCounts` after the `sliding-sync-sdk` change
public setNeedsNewFallback(needsNewFallback: boolean): void {
this.needsNewFallback = needsNewFallback;
}
Expand Down Expand Up @@ -3215,14 +3222,19 @@ export class Crypto extends TypedEventEmitter<CryptoEvent, CryptoEventHandlerMap
});
}

public preprocessOneTimeKeyCounts(oneTimeKeysCounts: Map<string, number>): Promise<void> {
const currentCount = oneTimeKeysCounts.get("signed_curve25519") || 0;
this.updateOneTimeKeyCount(currentCount);
return Promise.resolve();
}
public processKeyCounts(oneTimeKeysCounts?: Record<string, number>, unusedFallbackKeys?: string[]): Promise<void> {
if (oneTimeKeysCounts !== undefined) {
this.updateOneTimeKeyCount(oneTimeKeysCounts["signed_curve25519"] || 0);
}

if (unusedFallbackKeys !== undefined) {
// If `unusedFallbackKeys` is defined, that means `device_unused_fallback_key_types`
// is present in the sync response, which indicates that the server supports fallback keys.
//
// If there's no unused signed_curve25519 fallback key, we need a new one.
this.setNeedsNewFallback(!unusedFallbackKeys.includes("signed_curve25519"));
}

public preprocessUnusedFallbackKeys(unusedFallbackKeys: Set<string>): Promise<void> {
this.setNeedsNewFallback(!unusedFallbackKeys.has("signed_curve25519"));
return Promise.resolve();
}

Expand Down
26 changes: 14 additions & 12 deletions src/rust-crypto/rust-crypto.ts
Original file line number Diff line number Diff line change
Expand Up @@ -209,22 +209,24 @@ export class RustCrypto implements CryptoBackend {
return this.receiveSyncChanges({ events });
}

/** called by the sync loop to preprocess one time key counts
/** called by the sync loop to process one time key counts and unused fallback keys
*
* @param oneTimeKeysCounts - the received one time key counts
* @returns A list of preprocessed to-device messages.
*/
public async preprocessOneTimeKeyCounts(oneTimeKeysCounts: Map<string, number>): Promise<void> {
await this.receiveSyncChanges({ oneTimeKeysCounts });
}

/** called by the sync loop to preprocess unused fallback keys
*
* @param unusedFallbackKeys - the received unused fallback keys
* @returns A list of preprocessed to-device messages.
*/
public async preprocessUnusedFallbackKeys(unusedFallbackKeys: Set<string>): Promise<void> {
await this.receiveSyncChanges({ unusedFallbackKeys });
public async processKeyCounts(
oneTimeKeysCounts?: Record<string, number>,
unusedFallbackKeys?: string[],
): Promise<void> {
const mapOneTimeKeysCount = oneTimeKeysCounts && new Map<string, number>(Object.entries(oneTimeKeysCounts));
const setUnusedFallbackKeys = unusedFallbackKeys && new Set<string>(unusedFallbackKeys);

if (mapOneTimeKeysCount !== undefined || setUnusedFallbackKeys !== undefined) {
await this.receiveSyncChanges({
oneTimeKeysCounts: mapOneTimeKeysCount,
unusedFallbackKeys: setUnusedFallbackKeys,
});
}
}

/** called by the sync loop on m.room.encrypted events
Expand Down
18 changes: 5 additions & 13 deletions src/sync.ts
Original file line number Diff line number Diff line change
Expand Up @@ -1524,19 +1524,11 @@ export class SyncApi {
}
}

// Handle one_time_keys_count
if (data.device_one_time_keys_count) {
const map = new Map<string, number>(Object.entries(data.device_one_time_keys_count));
this.syncOpts.cryptoCallbacks?.preprocessOneTimeKeyCounts(map);
}
if (data.device_unused_fallback_key_types || data["org.matrix.msc2732.device_unused_fallback_key_types"]) {
// The presence of device_unused_fallback_key_types indicates that the
// server supports fallback keys. If there's no unused
// signed_curve25519 fallback key we need a new one.
const unusedFallbackKeys =
data.device_unused_fallback_key_types || data["org.matrix.msc2732.device_unused_fallback_key_types"];
this.syncOpts.cryptoCallbacks?.preprocessUnusedFallbackKeys(new Set<string>(unusedFallbackKeys || null));
}
// Handle one_time_keys_count and unused fallback keys
this.syncOpts.cryptoCallbacks?.processKeyCounts(
data.device_one_time_keys_count,
data.device_unused_fallback_key_types ?? data["org.matrix.msc2732.device_unused_fallback_key_types"],
);
}

/**
Expand Down

0 comments on commit 2daa429

Please sign in to comment.