From 1396f234e952dc28ddc687b9e2c232fff3cb6af8 Mon Sep 17 00:00:00 2001 From: Richard Hopton Date: Sun, 2 Apr 2023 18:34:08 -0700 Subject: [PATCH] Add encryption --- lib/connection.js | 69 +++++++-------- lib/utils/frameHelper.js | 41 +++++++++ lib/utils/messages.js | 30 ------- lib/utils/noiseFrameHelper.js | 136 ++++++++++++++++++++++++++++++ lib/utils/plaintextFrameHelper.js | 75 ++++++++++++++++ package.json | 1 + 6 files changed, 285 insertions(+), 67 deletions(-) create mode 100644 lib/utils/frameHelper.js create mode 100644 lib/utils/noiseFrameHelper.js create mode 100644 lib/utils/plaintextFrameHelper.js diff --git a/lib/connection.js b/lib/connection.js index 22db886..b93c420 100644 --- a/lib/connection.js +++ b/lib/connection.js @@ -1,9 +1,10 @@ const EventEmitter = require("events"); -const Net = require('net'); const { Entities } = require('./entities'); const { mapMessageByType } = require("./utils/mapMessageByType"); -const { serialize, deserialize, pb } = require('./utils/messages'); +const { pb } = require('./utils/messages'); const Package = require('../package.json'); +const PlaintextFrameHelper = require('./utils/PlaintextFrameHelper'); +const NoiseFrameHelper = require('./utils/NoiseFrameHelper'); class EsphomeNativeApiConnection extends EventEmitter { constructor({ @@ -11,6 +12,8 @@ class EsphomeNativeApiConnection extends EventEmitter { host, clientInfo = Package.name + ' ' + Package.version, password = '', + encryptionKey = '', + expectedServerName = '', reconnect = true, reconnectInterval = 30 * 1000, pingInterval = 15 * 1000, @@ -19,30 +22,18 @@ class EsphomeNativeApiConnection extends EventEmitter { super(); if (!host) throw new Error(`Host is required`); - this.socket = new Net.Socket(); - this.buffer = Buffer.from([]); + this.frameHelper = !encryptionKey ? new PlaintextFrameHelper(host, port) : new NoiseFrameHelper(host, port, encryptionKey, expectedServerName); - // socket data - this.socket.on('data', data => { - this.emit('data', data); - this.buffer = Buffer.concat([this.buffer, data]); - let message; - try { - while (message = deserialize(this.buffer)) { - this.buffer = this.buffer.slice(message.length); - const type = message.constructor.type; - const mapped = mapMessageByType(type, message.toObject()); - this.emit(`message.${type}`, mapped); - this.emit('message', type, mapped); - } - } catch (e) { - this.emit('error', e); - this.emit('unhandledData', data); - } - }) + this.frameHelper.on('message', (message) => { + const type = message.constructor.type; + const mapped = mapMessageByType(type, message.toObject()); - // socket close - this.socket.on('close', () => { + this.emit(`message.${type}`, mapped); + this.emit('message', type, mapped); + }); + + // frame helper close + this.frameHelper.on('close', () => { this.connected = false; this.authorized = false; clearInterval(this.pingTimer); @@ -54,8 +45,8 @@ class EsphomeNativeApiConnection extends EventEmitter { } }) - // socket connect - this.socket.on('connect', async () => { + // frame helper connect + this.frameHelper.on('connect', async () => { clearTimeout(this.reconnectTimer); this.connected = true; try { @@ -65,7 +56,7 @@ class EsphomeNativeApiConnection extends EventEmitter { this.authorized = true; } catch(e) { this.emit('error', e); - this.socket.end(); + this.frameHelper.end(); } this.pingTimer = setInterval(async () => { try { @@ -73,23 +64,26 @@ class EsphomeNativeApiConnection extends EventEmitter { this.pingCount = 0; } catch(e) { if (++this.pingCount >= this.pingAttempts) { - this.socket.end(); + this.frameHelper.end(); } } }, this.pingInterval); }) - // socket error - this.socket.on('error', e => { + // frame helper error + this.frameHelper.on('error', (e) => { this.emit('error', e); - this.socket.end(); + }) + + this.frameHelper.on('data', (data)=> { + this.emit('data', data); }) // DisconnectRequest this.on('message.DisconnectRequest', () => { try { this.sendMessage(new pb.DisconnectResponse()); - this.socket.destroy(); + this.frameHelper.destroy(); } catch(e) { this.emit('error', new Error(`Failed respond to DisconnectRequest. Reason: ${e.message}`)); } @@ -97,7 +91,7 @@ class EsphomeNativeApiConnection extends EventEmitter { // DisconnectResponse this.on('message.DisconnectResponse', () => { - this.socket.destroy(); + this.frameHelper.destroy(); }) // PingRequest @@ -127,6 +121,7 @@ class EsphomeNativeApiConnection extends EventEmitter { this.host = host; this.clientInfo = clientInfo; this.password = password; + this.encryptionKey = encryptionKey; this.reconnect = reconnect; this.reconnectTimer = null; this.reconnectInterval = reconnectInterval; @@ -153,20 +148,20 @@ class EsphomeNativeApiConnection extends EventEmitter { } connect() { if (this.connected) throw new Error(`Already connected. Can't connect.`); - this.socket.connect(this.port, this.host); + this.frameHelper.connect(); } disconnect() { clearInterval(this.pingTimer); clearTimeout(this.reconnectTimer); this.reconnect = false; this.sendMessage(new pb.DisconnectRequest()); - this.socket.removeAllListeners(); + this.frameHelper.removeAllListeners(); this.removeAllListeners(); - this.socket.destroy(); + this.frameHelper.destroy(); } sendMessage(message) { if (!this.connected) throw new Error(`Socket is not connected`); - this.socket.write(serialize(message)); + this.frameHelper.sendMessage(message); } sendCommandMessage(message) { if (!this.connected) throw new Error(`Not connected`); diff --git a/lib/utils/frameHelper.js b/lib/utils/frameHelper.js new file mode 100644 index 0000000..7854fd2 --- /dev/null +++ b/lib/utils/frameHelper.js @@ -0,0 +1,41 @@ +const EventEmitter = require("events"); +const Net = require("net"); +const { pb, id_to_type } = require("./messages"); + +class FrameHelper extends EventEmitter { + constructor(host, port) { + super(); + this.host = host; + this.port = port; + this.buffer = Buffer.from([]); + this.socket = new Net.Socket(); + this.socket.on("close", () => this.emit("close")); + this.socket.on("error", (e) => { + this.emit("error", e); + this.socket.end(); + }); + } + + connect() { + this.socket.connect(this.port, this.host); + } + + end() { + this.socket.end(); + } + + destroy() { + this.socket.destroy(); + } + + removeAllListeners() { + this.socket.removeAllListeners(); + super.removeAllListeners(); + } + + buildMessage(messageId, bytes) { + return pb[id_to_type[messageId]].deserializeBinary(bytes); + } +} + +module.exports = FrameHelper; diff --git a/lib/utils/messages.js b/lib/utils/messages.js index 57b8572..f513b70 100644 --- a/lib/utils/messages.js +++ b/lib/utils/messages.js @@ -1,5 +1,4 @@ const pb = require('../protoc/api_pb'); -const { varuint_to_bytes, recv_varuint } = require('./index'); const id_to_type = { @@ -102,34 +101,5 @@ for(const [ id, type ] of Object.entries(id_to_type)){ module.exports = { id_to_type, type_to_id, - serialize(message) { - const encoded = message.serializeBinary(); - return Buffer.from([ - 0, - ...varuint_to_bytes(encoded.length), - ...varuint_to_bytes(message.constructor.id), - ...encoded - ]); - }, - deserialize(buffer) { - if (buffer.length < 3) return null; - let offset = 0; - const next = () => { - if (offset >= buffer.length) return null; - return buffer[offset++]; - } - const t =next(); - if (t !== 0) throw new Error('Bad format. Expected 0 at the begin'); - - const message_length = recv_varuint(next); - if (message_length === null) return null; - const message_id = recv_varuint(next); - if (message_id === null) return null; - if (message_length + offset > buffer.length) return null; - // else if(message_length + offset < buffer.length) throw new Error(`Bad format. Expected buffer length = ${message_length + offset}. Received ${buffer.length}`); - const message = pb[id_to_type[message_id]].deserializeBinary(buffer.slice(offset, message_length + offset)); - message.length = message_length + offset; - return message; - }, pb } \ No newline at end of file diff --git a/lib/utils/noiseFrameHelper.js b/lib/utils/noiseFrameHelper.js new file mode 100644 index 0000000..cb28020 --- /dev/null +++ b/lib/utils/noiseFrameHelper.js @@ -0,0 +1,136 @@ +const FrameHelper = require("./frameHelper"); +const createNoise = require("@richardhopton/noise-c.wasm"); + +//const getAsciiBytes = (value) => new Uint8Array(Object.keys(value).map((i) => value.charCodeAt(i))); +const HANDSHAKE_HELLO = 1; +const HANDSHAKE_HANDSHAKE = 2; +const HANDSHAKE_READY = 3; +const HANDSHAKE_CLOSED = 4; + +class NoiseFrameHelper extends FrameHelper { + constructor(host, port, encryptionKey, expectedServerName) { + super(host, port); + this.encryptionKey = encryptionKey; + this.expectedServerName = expectedServerName; + + this.handshakeState = HANDSHAKE_HELLO; + this.socket.on("data", (data) => this.onData(data)); + this.socket.on("connect", async () => await this.onConnect()); + this.socket.on("close", () => this.handshakeState = HANDSHAKE_CLOSED); + } + + async onConnect() { + const psk = Buffer.from(this.encryptionKey, "base64"); + const noise = await new Promise((res) => createNoise(res)); + this.client = noise.HandshakeState( + "Noise_NNpsk0_25519_ChaChaPoly_SHA256", + noise.constants.NOISE_ROLE_INITIATOR + ); + this.client.Initialize( + new Uint8Array(Buffer.from("NoiseAPIInit\x00\x00")), + null, + null, + new Uint8Array(psk) + ); + this.write([]); + } + + extractFrameBytes() { + if (this.buffer.length < 3) return null; + const indicator = this.buffer[0]; + if (indicator != 1) + throw new Error("Bad format. Expected 1 at the begin"); + + const frameEnd = 3 + ((this.buffer[1] << 8) | this.buffer[2]); + if (this.buffer.length < frameEnd) return null; + const frame = this.buffer.subarray(3, frameEnd); + this.buffer = this.buffer.subarray(frameEnd); + return frame; + } + + onData(data) { + this.emit("data", data); + this.buffer = Buffer.concat([this.buffer, data]); + let frame; + while ((frame = this.extractFrameBytes())) { + switch (this.handshakeState) { + case HANDSHAKE_HELLO: + return this.handleHello(frame); + case HANDSHAKE_HANDSHAKE: + return this.handleHandshake(frame); + default: + const message = this.deserialize(this.decryptor.DecryptWithAd([], frame)); + return this.emit("message", message); + } + } + } + + handleHello(serverHello) { + const chosenProto = serverHello[0]; + if (chosenProto != 1) + throw new Error( + `Unknown protocol selected by server ${chosenProto}` + ); + if (!!this.expectedServerName) { + const serverNameEnd = serverHello.indexOf("\0", 1); + if (serverNameEnd > 1) { + const serverName = serverHello + .subarray(1, serverNameEnd) + .toString(); + if (this.expectedServerName != serverName) + throw new Error(`Server name mismatch, expected ${this.expectedServerName}, got ${serverName}`); + } + } + + this.handshakeState = HANDSHAKE_HANDSHAKE; + this.write([0, ...this.client.WriteMessage()]); + } + + handleHandshake(serverHandshake) { + const header = serverHandshake[0]; + const message = serverHandshake.subarray(1); + if (header != 0) { + throw new Error(`Handshake failure: ${message.toString()}`); + } + this.client.ReadMessage(message, true); + [this.encryptor, this.decryptor] = this.client.Split(); + this.handshakeState = HANDSHAKE_READY; + this.emit("connect"); + } + + serialize(message) { + const encoded = message.serializeBinary(); + const messageId = message.constructor.id; + const messageLength = encoded.length; + const buffer = Buffer.from([ + (messageId >> 8) & 255, + messageId & 255, + (messageLength >> 8) & 255, + messageLength & 255, + ...encoded, + ]); + return buffer; + } + + deserialize(buffer) { + if (buffer.length < 4) return null; + const messageId = (buffer[0] << 8) | buffer[1]; + const messageLength = (buffer[2] << 8) | buffer[3]; + const message = this.buildMessage(messageId, buffer.subarray(4, messageLength + 4)); + message.length = messageLength + 4; + return message; + } + + write(frame) { + const frameLength = frame.length; + const header = [1, (frameLength >> 8) & 255, frameLength & 255]; + const payload = Buffer.from([...header, ...frame]); + this.socket.write(payload); + } + + sendMessage(message) { + this.write(this.encryptor.EncryptWithAd([], this.serialize(message))); + } +} + +module.exports = NoiseFrameHelper; diff --git a/lib/utils/plaintextFrameHelper.js b/lib/utils/plaintextFrameHelper.js new file mode 100644 index 0000000..b258440 --- /dev/null +++ b/lib/utils/plaintextFrameHelper.js @@ -0,0 +1,75 @@ +const FrameHelper = require('./frameHelper'); +const { varuint_to_bytes, recv_varuint } = require("./index"); + +class PlaintextFrameHelper extends FrameHelper { + constructor(host, port) { + super(host, port); + this.buffer = Buffer.from([]); + this.socket.on("data", (data) => this.onData(data)); + this.socket.on("connect", () => this.onConnect()); + } + + serialize(message) { + const encoded = message.serializeBinary() + return Buffer.from([ + 0, + ...varuint_to_bytes(encoded.length), + ...varuint_to_bytes(message.constructor.id), + ...encoded, + ]); + } + + deserialize(buffer) { + if (buffer.length < 3) return null; + + let offset = 0; + const next = () => { + if (offset >= buffer.length) + return null; + return buffer[offset++]; + }; + const t = next(); + if (t !== 0) { + if(t === 1) throw new Error('Bad format: Encryption expected'); + throw new Error(`Bad format. Expected 0 at the begin`); + } + + const messageLength = recv_varuint(next); + if (messageLength === null) + return null; + const messageId = recv_varuint(next); + if (messageId === null) + return null; + if (messageLength + offset > buffer.length) + return null; + + const message = this.buildMessage(messageId, buffer.subarray(offset, messageLength + offset)); + message.length = messageLength + offset; + return message; + } + + onData(data) { + this.emit("data", data); + this.buffer = Buffer.concat([this.buffer, data]); + let message; + try { + while ((message = this.deserialize(this.buffer))) { + this.buffer = this.buffer.slice(message.length); + this.emit("message", message); + } + } catch (e) { + this.emit("error", e); + this.emit("unhandledData", data); + } + } + + onConnect() { + this.emit("connect"); + } + + sendMessage(message) { + this.socket.write(this.serialize(message)); + } +} + +module.exports = PlaintextFrameHelper; diff --git a/package.json b/package.json index 64a0b42..3fcdb40 100644 --- a/package.json +++ b/package.json @@ -22,6 +22,7 @@ }, "homepage": "https://github.com/twocolors/esphome-native-api#readme", "dependencies": { + "@richardhopton/noise-c.wasm": "https://github.com/richardhopton/noise-c.wasm#npm", "google-protobuf": "^3.21.2", "multicast-dns": "^7.2.5" }