diff --git a/src/header.ts b/src/header.ts new file mode 100644 index 0000000..6e1986c --- /dev/null +++ b/src/header.ts @@ -0,0 +1,210 @@ +import { Buffer } from 'node:buffer' + +// TODO(lwvemike): remove before release +function todo(message: string) { + throw new Error(`TODO: ${message}`) +} + +interface Versions { + majorVersion: MajorVersion + minorVersion: MinorVersion +} +export interface BaseHeaderRecord { + majorVersion: MajorVersion + minorVersion: MinorVersion + flags: Flag + length: number + seqNo: number + sessionId: number + type: HeaderType +} + +export interface UnknownHeader { + majorVersion: number + minorVersion: number + flags: number + length: number + seqNo: number + sessionId: number + type: number +} + +function createVersionByte({ majorVersion, minorVersion }: Versions) { + return ((majorVersion & 0xF) << 4) | (minorVersion & 0xF) +} + +export const HEADER_TYPES = { + // TODO(lwvemike): maybe is not valid + TAC_DEFAULT: 0x00, + TAC_PLUS_AUTHEN: 0x01, // Authentication + TAC_PLUS_AUTHOR: 0x02, // Authorization + TAC_PLUS_ACCT: 0x03, // Accounting +} as const + +export const ALLOWED_HEADER_TYPES = Object.values(HEADER_TYPES) + +export type HeaderType = typeof ALLOWED_HEADER_TYPES[number] + +function isHeaderType(maybeType: number): maybeType is HeaderType { + return (ALLOWED_HEADER_TYPES as number[]).includes(maybeType) +} + +export const FLAGS = { + TAC_PLUS_UNENCRYPTED_FLAG: 0x01, + TAC_PLUS_SINGLE_CONNECT_FLAG: 0x04, +} as const + +export const ALLOWED_FLAGS = Object.values(FLAGS) + +function isFlag(maybeFlag: number): maybeFlag is Flag { + return (ALLOWED_FLAGS as number[]).includes(maybeFlag) +} + +export type Flag = typeof FLAGS[keyof typeof FLAGS] + +export const MAJOR_VERSIONS = { + TAC_PLUS_MAJOR_VER_DEFAULT: 0x0, + TAC_PLUS_MAJOR_VER: 0xC, +} as const + +const ALLOWED_MAJOR_VERSIONS = Object.values(MAJOR_VERSIONS) + +type MajorVersion = typeof ALLOWED_MAJOR_VERSIONS[number] + +function isMajorVersion(maybeMajorVersion: number): maybeMajorVersion is MajorVersion { + return (ALLOWED_MAJOR_VERSIONS as number[]).includes(maybeMajorVersion) +} + +export const MINOR_VERSIONS = { + TAC_PLUS_MINOR_VER_DEFAULT: 0x0, + TAC_PLUS_MINOR_VER_ONE: 0x1, +} as const + +const ALLOWED_MINOR_VERSIONS = Object.values(MINOR_VERSIONS) + +type MinorVersion = typeof ALLOWED_MINOR_VERSIONS[number] + +function isMinorVersion(maybeMinorVersion: number): maybeMinorVersion is MinorVersion { + return (ALLOWED_MINOR_VERSIONS as number[]).includes(maybeMinorVersion) +} + +function validateHeader({ majorVersion, minorVersion, flags, type, length, seqNo, sessionId }: UnknownHeader) { + if (!isMajorVersion(majorVersion)) { + throw new Error('Invalid major version') + } + + if (!isMinorVersion(minorVersion)) { + throw new Error('Invalid minor version') + } + + if (!isHeaderType(type)) { + throw new Error('Invalid header type') + } + + if (!isFlag(flags)) { + throw new Error('Invalid flag') + } + + return { + majorVersion, + minorVersion, + flags, + type, + length, + seqNo, + sessionId, + } +} + +type HeaderRecord = + & BaseHeaderRecord + & Record<'isEncrypted' | 'isSingleConnection', boolean> + +export class Header { + /** + * @throws Error + * @param raw + */ + static decode(raw: Buffer): HeaderRecord { + if (raw.length !== Header.SIZE) { + throw new Error(`Header size must be ${Header.SIZE}, but received ${raw.length}`) + } + + let offset = 0 + + const versionByte = raw.subarray(offset, 1).readUInt8(0) + offset += 1 + + const majorVersion = ((versionByte >> 4) & 0xF) + const minorVersion = (versionByte & 0xF) + + const type = raw.subarray(offset, 2).readUInt8(0) + offset += 1 + + const seqNo = raw.subarray(offset, 3).readUInt8(0) + if (seqNo === 255) { + todo('SeqNo is 255, you should handle restart the session') + } + offset += 1 + + const flags = raw.subarray(offset, 4).readUint8(0) + offset += 1 + + const sessionId = raw.subarray(offset, 8).readUInt32BE(0) + offset += 4 + + const length = raw.subarray(offset, 12).readUInt32BE(0) + + const header = validateHeader({ + majorVersion, + minorVersion, + flags, + type, + length, + seqNo, + sessionId, + }) + + return { + ...header, + isEncrypted: !((header.flags & FLAGS.TAC_PLUS_UNENCRYPTED_FLAG) === FLAGS.TAC_PLUS_UNENCRYPTED_FLAG), + isSingleConnection: ((header.flags & FLAGS.TAC_PLUS_SINGLE_CONNECT_FLAG) === FLAGS.TAC_PLUS_UNENCRYPTED_FLAG), + } + } + + static create(unknownHeader: UnknownHeader = Header.DEFAULT_HEADER): Buffer { + const buffer = Buffer.alloc(Header.SIZE) + + const { + majorVersion, + minorVersion, + type, + flags, + seqNo, + sessionId, + length, + } = validateHeader(unknownHeader) + + const versionByte = createVersionByte({ majorVersion, minorVersion }) + + buffer.writeUInt8(versionByte, 0) + buffer.writeUInt8(type, 1) + buffer.writeUInt8(seqNo, 2) + buffer.writeUInt8(flags, 3) + buffer.writeUInt32BE(sessionId, 4) + buffer.writeUInt32BE(length, 8) + + return buffer + } + + static readonly SIZE = 12 + static readonly DEFAULT_HEADER: BaseHeaderRecord = { + majorVersion: MAJOR_VERSIONS.TAC_PLUS_MAJOR_VER_DEFAULT, + minorVersion: MINOR_VERSIONS.TAC_PLUS_MINOR_VER_DEFAULT, + type: HEADER_TYPES.TAC_DEFAULT, + seqNo: 0x1, + flags: FLAGS.TAC_PLUS_UNENCRYPTED_FLAG, + sessionId: 0x0, + length: 0x0, + } +} diff --git a/test/index.test.ts b/test/index.test.ts index 15e9a89..ea99ab9 100644 --- a/test/index.test.ts +++ b/test/index.test.ts @@ -1,42 +1,71 @@ import { Buffer } from 'node:buffer' import { describe, expect, it } from 'vitest' -import { Header, PacketType } from '../src' +import { FLAGS, HEADER_TYPES, Header } from '../src/header' describe('@noction/tacacs-plus', () => { describe('header', () => { - it('should decode a valid header', () => { - const buffer = Buffer.from([ - 0x12, - 0x01, - 0x42, - 0x10, - 0x00, - 0x00, - 0x00, - 0x01, - 0x00, - 0x00, - 0x00, - 0x0C, - ]) - - const header = Header.decodeHeader(buffer) - - expect(header.majorVersion).toBe(1) - expect(header.minorVersion).toBe(2) - expect(header.type).toBe(PacketType.TAC_PLUS_AUTHEN) - expect(header.seq_no).toBe(66) - expect(header.flags).toBe(16) - expect(header.session_id).toBe(1) - expect(header.length).toBe(12) + describe('decodeHeader', () => { + it('should decode a valid header', () => { + const buffer = Buffer.from([ + 0x12, + 0x01, + 0x42, + FLAGS.TAC_PLUS_UNENCRYPTED_FLAG, + 0x00, + 0x00, + 0x00, + 0x01, + 0x00, + 0x00, + 0x00, + 0x0C, + ]) + + const header = Header.decode(buffer) + + expect(header.majorVersion).toBe(1) + expect(header.minorVersion).toBe(2) + expect(header.type).toBe(HEADER_TYPES.TAC_PLUS_AUTHEN) + expect(header.seqNo).toBe(66) + expect(header.flags).toBe(FLAGS.TAC_PLUS_UNENCRYPTED_FLAG) + expect(header.sessionId).toBe(1) + expect(header.length).toBe(12) + + expect(header.isEncrypted).toBe(false) + expect(header.isSingleConnection).toBe(false) + }) + + it('should throw an error for an invalid header size', () => { + const invalidBuffer = Buffer.from([0x00, 0x01]) + + expect(() => Header.decode(invalidBuffer)).toThrowError( + 'Header size must be 12, but received 2', + ) + }) }) - it('should throw an error for an invalid header size', () => { - const invalidBuffer = Buffer.from([0x00, 0x01]) + describe('create', () => { + it('should create a valid header with default values', () => { + const defaultHeader = Header.create() + + expect(defaultHeader).toBeInstanceOf(Buffer) + expect(defaultHeader.length).toBe(Header.SIZE) + }) + + it('should create a custom header with specified values', () => { + const customHeader = Header.create({ + majorVersion: 1, + minorVersion: 2, + type: HEADER_TYPES.TAC_PLUS_ACCT, + seqNo: 42, + flags: FLAGS.TAC_PLUS_SINGLE_CONNECT_FLAG, + sessionId: 123456, + length: 100, + }) - expect(() => Header.decodeHeader(invalidBuffer)).toThrowError( - 'Header size must be 12, but received 2', - ) + expect(customHeader).toBeInstanceOf(Buffer) + expect(customHeader.length).toBe(Header.SIZE) + }) }) }) })