Skip to content

Commit

Permalink
Add encryption
Browse files Browse the repository at this point in the history
  • Loading branch information
richardhopton committed Apr 3, 2023
1 parent d3d0723 commit 1396f23
Show file tree
Hide file tree
Showing 6 changed files with 285 additions and 67 deletions.
69 changes: 32 additions & 37 deletions lib/connection.js
Original file line number Diff line number Diff line change
@@ -1,16 +1,19 @@
const EventEmitter = require("events");
const Net = require('net');
const { Entities } = require('./entities');
const { mapMessageByType } = require("./utils/mapMessageByType");
const { serialize, deserialize, pb } = require('./utils/messages');
const { pb } = require('./utils/messages');
const Package = require('../package.json');
const PlaintextFrameHelper = require('./utils/PlaintextFrameHelper');
const NoiseFrameHelper = require('./utils/NoiseFrameHelper');

class EsphomeNativeApiConnection extends EventEmitter {
constructor({
port = 6053,
host,
clientInfo = Package.name + ' ' + Package.version,
password = '',
encryptionKey = '',
expectedServerName = '',
reconnect = true,
reconnectInterval = 30 * 1000,
pingInterval = 15 * 1000,
Expand All @@ -19,30 +22,18 @@ class EsphomeNativeApiConnection extends EventEmitter {
super();
if (!host) throw new Error(`Host is required`);

this.socket = new Net.Socket();
this.buffer = Buffer.from([]);
this.frameHelper = !encryptionKey ? new PlaintextFrameHelper(host, port) : new NoiseFrameHelper(host, port, encryptionKey, expectedServerName);

// socket data
this.socket.on('data', data => {
this.emit('data', data);
this.buffer = Buffer.concat([this.buffer, data]);
let message;
try {
while (message = deserialize(this.buffer)) {
this.buffer = this.buffer.slice(message.length);
const type = message.constructor.type;
const mapped = mapMessageByType(type, message.toObject());
this.emit(`message.${type}`, mapped);
this.emit('message', type, mapped);
}
} catch (e) {
this.emit('error', e);
this.emit('unhandledData', data);
}
})
this.frameHelper.on('message', (message) => {
const type = message.constructor.type;
const mapped = mapMessageByType(type, message.toObject());

// socket close
this.socket.on('close', () => {
this.emit(`message.${type}`, mapped);
this.emit('message', type, mapped);
});

// frame helper close
this.frameHelper.on('close', () => {
this.connected = false;
this.authorized = false;
clearInterval(this.pingTimer);
Expand All @@ -54,8 +45,8 @@ class EsphomeNativeApiConnection extends EventEmitter {
}
})

// socket connect
this.socket.on('connect', async () => {
// frame helper connect
this.frameHelper.on('connect', async () => {
clearTimeout(this.reconnectTimer);
this.connected = true;
try {
Expand All @@ -65,39 +56,42 @@ class EsphomeNativeApiConnection extends EventEmitter {
this.authorized = true;
} catch(e) {
this.emit('error', e);
this.socket.end();
this.frameHelper.end();
}
this.pingTimer = setInterval(async () => {
try {
await this.pingService();
this.pingCount = 0;
} catch(e) {
if (++this.pingCount >= this.pingAttempts) {
this.socket.end();
this.frameHelper.end();
}
}
}, this.pingInterval);
})

// socket error
this.socket.on('error', e => {
// frame helper error
this.frameHelper.on('error', (e) => {
this.emit('error', e);
this.socket.end();
})

this.frameHelper.on('data', (data)=> {
this.emit('data', data);
})

// DisconnectRequest
this.on('message.DisconnectRequest', () => {
try {
this.sendMessage(new pb.DisconnectResponse());
this.socket.destroy();
this.frameHelper.destroy();
} catch(e) {
this.emit('error', new Error(`Failed respond to DisconnectRequest. Reason: ${e.message}`));
}
})

// DisconnectResponse
this.on('message.DisconnectResponse', () => {
this.socket.destroy();
this.frameHelper.destroy();
})

// PingRequest
Expand Down Expand Up @@ -127,6 +121,7 @@ class EsphomeNativeApiConnection extends EventEmitter {
this.host = host;
this.clientInfo = clientInfo;
this.password = password;
this.encryptionKey = encryptionKey;
this.reconnect = reconnect;
this.reconnectTimer = null;
this.reconnectInterval = reconnectInterval;
Expand All @@ -153,20 +148,20 @@ class EsphomeNativeApiConnection extends EventEmitter {
}
connect() {
if (this.connected) throw new Error(`Already connected. Can't connect.`);
this.socket.connect(this.port, this.host);
this.frameHelper.connect();
}
disconnect() {
clearInterval(this.pingTimer);
clearTimeout(this.reconnectTimer);
this.reconnect = false;
this.sendMessage(new pb.DisconnectRequest());
this.socket.removeAllListeners();
this.frameHelper.removeAllListeners();
this.removeAllListeners();
this.socket.destroy();
this.frameHelper.destroy();
}
sendMessage(message) {
if (!this.connected) throw new Error(`Socket is not connected`);
this.socket.write(serialize(message));
this.frameHelper.sendMessage(message);
}
sendCommandMessage(message) {
if (!this.connected) throw new Error(`Not connected`);
Expand Down
41 changes: 41 additions & 0 deletions lib/utils/frameHelper.js
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
const EventEmitter = require("events");
const Net = require("net");
const { pb, id_to_type } = require("./messages");

class FrameHelper extends EventEmitter {
constructor(host, port) {
super();
this.host = host;
this.port = port;
this.buffer = Buffer.from([]);
this.socket = new Net.Socket();
this.socket.on("close", () => this.emit("close"));
this.socket.on("error", (e) => {
this.emit("error", e);
this.socket.end();
});
}

connect() {
this.socket.connect(this.port, this.host);
}

end() {
this.socket.end();
}

destroy() {
this.socket.destroy();
}

removeAllListeners() {
this.socket.removeAllListeners();
super.removeAllListeners();
}

buildMessage(messageId, bytes) {
return pb[id_to_type[messageId]].deserializeBinary(bytes);
}
}

module.exports = FrameHelper;
30 changes: 0 additions & 30 deletions lib/utils/messages.js
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
const pb = require('../protoc/api_pb');
const { varuint_to_bytes, recv_varuint } = require('./index');

const id_to_type =
{
Expand Down Expand Up @@ -102,34 +101,5 @@ for(const [ id, type ] of Object.entries(id_to_type)){
module.exports = {
id_to_type,
type_to_id,
serialize(message) {
const encoded = message.serializeBinary();
return Buffer.from([
0,
...varuint_to_bytes(encoded.length),
...varuint_to_bytes(message.constructor.id),
...encoded
]);
},
deserialize(buffer) {
if (buffer.length < 3) return null;
let offset = 0;
const next = () => {
if (offset >= buffer.length) return null;
return buffer[offset++];
}
const t =next();
if (t !== 0) throw new Error('Bad format. Expected 0 at the begin');

const message_length = recv_varuint(next);
if (message_length === null) return null;
const message_id = recv_varuint(next);
if (message_id === null) return null;
if (message_length + offset > buffer.length) return null;
// else if(message_length + offset < buffer.length) throw new Error(`Bad format. Expected buffer length = ${message_length + offset}. Received ${buffer.length}`);
const message = pb[id_to_type[message_id]].deserializeBinary(buffer.slice(offset, message_length + offset));
message.length = message_length + offset;
return message;
},
pb
}
136 changes: 136 additions & 0 deletions lib/utils/noiseFrameHelper.js
Original file line number Diff line number Diff line change
@@ -0,0 +1,136 @@
const FrameHelper = require("./frameHelper");
const createNoise = require("@richardhopton/noise-c.wasm");

//const getAsciiBytes = (value) => new Uint8Array(Object.keys(value).map((i) => value.charCodeAt(i)));
const HANDSHAKE_HELLO = 1;
const HANDSHAKE_HANDSHAKE = 2;
const HANDSHAKE_READY = 3;
const HANDSHAKE_CLOSED = 4;

class NoiseFrameHelper extends FrameHelper {
constructor(host, port, encryptionKey, expectedServerName) {
super(host, port);
this.encryptionKey = encryptionKey;
this.expectedServerName = expectedServerName;

this.handshakeState = HANDSHAKE_HELLO;
this.socket.on("data", (data) => this.onData(data));
this.socket.on("connect", async () => await this.onConnect());
this.socket.on("close", () => this.handshakeState = HANDSHAKE_CLOSED);
}

async onConnect() {
const psk = Buffer.from(this.encryptionKey, "base64");
const noise = await new Promise((res) => createNoise(res));
this.client = noise.HandshakeState(
"Noise_NNpsk0_25519_ChaChaPoly_SHA256",
noise.constants.NOISE_ROLE_INITIATOR
);
this.client.Initialize(
new Uint8Array(Buffer.from("NoiseAPIInit\x00\x00")),
null,
null,
new Uint8Array(psk)
);
this.write([]);
}

extractFrameBytes() {
if (this.buffer.length < 3) return null;
const indicator = this.buffer[0];
if (indicator != 1)
throw new Error("Bad format. Expected 1 at the begin");

const frameEnd = 3 + ((this.buffer[1] << 8) | this.buffer[2]);
if (this.buffer.length < frameEnd) return null;
const frame = this.buffer.subarray(3, frameEnd);
this.buffer = this.buffer.subarray(frameEnd);
return frame;
}

onData(data) {
this.emit("data", data);
this.buffer = Buffer.concat([this.buffer, data]);
let frame;
while ((frame = this.extractFrameBytes())) {
switch (this.handshakeState) {
case HANDSHAKE_HELLO:
return this.handleHello(frame);
case HANDSHAKE_HANDSHAKE:
return this.handleHandshake(frame);
default:
const message = this.deserialize(this.decryptor.DecryptWithAd([], frame));
return this.emit("message", message);
}
}
}

handleHello(serverHello) {
const chosenProto = serverHello[0];
if (chosenProto != 1)
throw new Error(
`Unknown protocol selected by server ${chosenProto}`
);
if (!!this.expectedServerName) {
const serverNameEnd = serverHello.indexOf("\0", 1);
if (serverNameEnd > 1) {
const serverName = serverHello
.subarray(1, serverNameEnd)
.toString();
if (this.expectedServerName != serverName)
throw new Error(`Server name mismatch, expected ${this.expectedServerName}, got ${serverName}`);
}
}

this.handshakeState = HANDSHAKE_HANDSHAKE;
this.write([0, ...this.client.WriteMessage()]);
}

handleHandshake(serverHandshake) {
const header = serverHandshake[0];
const message = serverHandshake.subarray(1);
if (header != 0) {
throw new Error(`Handshake failure: ${message.toString()}`);
}
this.client.ReadMessage(message, true);
[this.encryptor, this.decryptor] = this.client.Split();
this.handshakeState = HANDSHAKE_READY;
this.emit("connect");
}

serialize(message) {
const encoded = message.serializeBinary();
const messageId = message.constructor.id;
const messageLength = encoded.length;
const buffer = Buffer.from([
(messageId >> 8) & 255,
messageId & 255,
(messageLength >> 8) & 255,
messageLength & 255,
...encoded,
]);
return buffer;
}

deserialize(buffer) {
if (buffer.length < 4) return null;
const messageId = (buffer[0] << 8) | buffer[1];
const messageLength = (buffer[2] << 8) | buffer[3];
const message = this.buildMessage(messageId, buffer.subarray(4, messageLength + 4));
message.length = messageLength + 4;
return message;
}

write(frame) {
const frameLength = frame.length;
const header = [1, (frameLength >> 8) & 255, frameLength & 255];
const payload = Buffer.from([...header, ...frame]);
this.socket.write(payload);
}

sendMessage(message) {
this.write(this.encryptor.EncryptWithAd([], this.serialize(message)));
}
}

module.exports = NoiseFrameHelper;
Loading

0 comments on commit 1396f23

Please sign in to comment.