diff --git a/README.markdown b/README.markdown index d7a24188f..67f3f92b6 100644 --- a/README.markdown +++ b/README.markdown @@ -446,6 +446,10 @@ Generated code will be placed in the Gradle build directory. - With `--ts_proto_opt=outputServices=false`, or `=none`, ts-proto will output NO service definitions. +- With `--ts_proto_opt=outputBeforeRequest=true`, ts-proto will add a function definition to the Rpc interface definition with the signature: `beforeRequest(request: )`. It will will also automatically set `outputTypeRegistry=true` and `outputServices=true`. Each of the Service's methods will call `beforeRequest` before performing it's request. + +- With `--ts_proto_opt=outputAfterResponse=true`, ts-proto will add a function definition to the Rpc interface definition with the signature: `afterResponse(response: )`. It will will also automatically set `outputTypeRegistry=true` and `outputServices=true`. Each of the Service's methods will call `afterResponse` before returning the response. + - With `--ts_proto_opt=useAbortSignal=true`, the generated services will accept an `AbortSignal` to cancel RPC calls. - With `--ts_proto_opt=useAsyncIterable=true`, the generated services will use `AsyncIterable` instead of `Observable`. diff --git a/integration/before-after-request/before-after-request-test.ts b/integration/before-after-request/before-after-request-test.ts new file mode 100644 index 000000000..9d2be2dfb --- /dev/null +++ b/integration/before-after-request/before-after-request-test.ts @@ -0,0 +1,47 @@ +import { FooServiceClientImpl, FooServiceCreateRequest, FooServiceCreateResponse } from "./simple"; +import { MessageType } from "./typeRegistry"; + +interface Rpc { + request(service: string, method: string, data: Uint8Array): Promise; + beforeRequest?(request: T): void; + afterResponse?(response: T): void; +} + +describe("before-after-request", () => { + const exampleData = { + kind: 1, + }; + let rpc = { + request: jest.fn(() => Promise.resolve(new Uint8Array())), + }; + let client = new FooServiceClientImpl(rpc); + const beforeRequest = jest.fn(); + const afterResponse = jest.fn(); + + beforeEach(() => { + jest.clearAllMocks(); + jest.spyOn(FooServiceCreateResponse, "decode").mockReturnValue(exampleData); + }); + + it("performs function before request if specified", async () => { + const req = FooServiceCreateRequest.create(exampleData); + client = new FooServiceClientImpl({ ...rpc, beforeRequest: beforeRequest }); + await client.Create(req); + expect(beforeRequest).toHaveBeenCalledWith(req); + }); + + it("performs function after request if specified", async () => { + const req = FooServiceCreateRequest.create(exampleData); + client = new FooServiceClientImpl({ ...rpc, afterResponse: afterResponse }); + await client.Create(req); + expect(afterResponse).toHaveBeenCalledWith(exampleData); + }); + + it("doesn't perform function before or after request if they are not specified", async () => { + const req = FooServiceCreateRequest.create(exampleData); + client = new FooServiceClientImpl({ ...rpc }); + await client.Create(req); + expect(beforeRequest).not.toHaveBeenCalled(); + expect(afterResponse).not.toHaveBeenCalled(); + }); +}); diff --git a/integration/before-after-request/parameters.txt b/integration/before-after-request/parameters.txt new file mode 100644 index 000000000..9335148d4 --- /dev/null +++ b/integration/before-after-request/parameters.txt @@ -0,0 +1 @@ +outputBeforeRequest=true,outputAfterResponse=true diff --git a/integration/before-after-request/simple.bin b/integration/before-after-request/simple.bin new file mode 100644 index 000000000..aa18fec38 Binary files /dev/null and b/integration/before-after-request/simple.bin differ diff --git a/integration/before-after-request/simple.proto b/integration/before-after-request/simple.proto new file mode 100644 index 000000000..4cc37d287 --- /dev/null +++ b/integration/before-after-request/simple.proto @@ -0,0 +1,37 @@ +syntax = "proto3"; +package simple; +import "simple2.proto"; + +enum SimpleEnum { + LOCAL_DEFAULT = 0; + LOCAL_FOO = 1; + LOCAL_BAR = 2; +} + +message Simple { + string name = 1; + simple2.Simple otherSimple = 2; +} + +message DifferentSimple { + string name = 1; + optional simple2.Simple otherOptionalSimple2 = 2; +} + +message SimpleEnums { + SimpleEnum local_enum = 1; + simple2.SimpleEnum import_enum = 2; +} + +message FooServiceCreateRequest { + simple2.FooService kind = 1; +} + +message FooServiceCreateResponse { + simple2.FooService kind = 1; +} + +service FooService { + rpc Create (FooServiceCreateRequest) returns (FooServiceCreateResponse); +} + diff --git a/integration/before-after-request/simple.ts b/integration/before-after-request/simple.ts new file mode 100644 index 000000000..c3d318639 --- /dev/null +++ b/integration/before-after-request/simple.ts @@ -0,0 +1,468 @@ +/* eslint-disable */ +import * as _m0 from "protobufjs/minimal"; +import { + FooService as FooService2, + fooServiceFromJSON, + fooServiceToJSON, + Simple as Simple3, + SimpleEnum as SimpleEnum1, + simpleEnumFromJSON as simpleEnumFromJSON4, + simpleEnumToJSON as simpleEnumToJSON5, +} from "./simple2"; + +export const protobufPackage = "simple"; + +export enum SimpleEnum { + LOCAL_DEFAULT = 0, + LOCAL_FOO = 1, + LOCAL_BAR = 2, + UNRECOGNIZED = -1, +} + +export function simpleEnumFromJSON(object: any): SimpleEnum { + switch (object) { + case 0: + case "LOCAL_DEFAULT": + return SimpleEnum.LOCAL_DEFAULT; + case 1: + case "LOCAL_FOO": + return SimpleEnum.LOCAL_FOO; + case 2: + case "LOCAL_BAR": + return SimpleEnum.LOCAL_BAR; + case -1: + case "UNRECOGNIZED": + default: + return SimpleEnum.UNRECOGNIZED; + } +} + +export function simpleEnumToJSON(object: SimpleEnum): string { + switch (object) { + case SimpleEnum.LOCAL_DEFAULT: + return "LOCAL_DEFAULT"; + case SimpleEnum.LOCAL_FOO: + return "LOCAL_FOO"; + case SimpleEnum.LOCAL_BAR: + return "LOCAL_BAR"; + case SimpleEnum.UNRECOGNIZED: + default: + return "UNRECOGNIZED"; + } +} + +export interface Simple { + name: string; + otherSimple: Simple3 | undefined; +} + +export interface DifferentSimple { + name: string; + otherOptionalSimple2?: Simple3 | undefined; +} + +export interface SimpleEnums { + localEnum: SimpleEnum; + importEnum: SimpleEnum1; +} + +export interface FooServiceCreateRequest { + kind: FooService2; +} + +export interface FooServiceCreateResponse { + kind: FooService2; +} + +function createBaseSimple(): Simple { + return { name: "", otherSimple: undefined }; +} + +export const Simple = { + encode(message: Simple, writer: _m0.Writer = _m0.Writer.create()): _m0.Writer { + if (message.name !== "") { + writer.uint32(10).string(message.name); + } + if (message.otherSimple !== undefined) { + Simple3.encode(message.otherSimple, writer.uint32(18).fork()).ldelim(); + } + return writer; + }, + + decode(input: _m0.Reader | Uint8Array, length?: number): Simple { + const reader = input instanceof _m0.Reader ? input : _m0.Reader.create(input); + let end = length === undefined ? reader.len : reader.pos + length; + const message = createBaseSimple(); + while (reader.pos < end) { + const tag = reader.uint32(); + switch (tag >>> 3) { + case 1: + if (tag !== 10) { + break; + } + + message.name = reader.string(); + continue; + case 2: + if (tag !== 18) { + break; + } + + message.otherSimple = Simple3.decode(reader, reader.uint32()); + continue; + } + if ((tag & 7) === 4 || tag === 0) { + break; + } + reader.skipType(tag & 7); + } + return message; + }, + + fromJSON(object: any): Simple { + return { + name: isSet(object.name) ? globalThis.String(object.name) : "", + otherSimple: isSet(object.otherSimple) ? Simple3.fromJSON(object.otherSimple) : undefined, + }; + }, + + toJSON(message: Simple): unknown { + const obj: any = {}; + if (message.name !== "") { + obj.name = message.name; + } + if (message.otherSimple !== undefined) { + obj.otherSimple = Simple3.toJSON(message.otherSimple); + } + return obj; + }, + + create, I>>(base?: I): Simple { + return Simple.fromPartial(base ?? ({} as any)); + }, + fromPartial, I>>(object: I): Simple { + const message = createBaseSimple(); + message.name = object.name ?? ""; + message.otherSimple = (object.otherSimple !== undefined && object.otherSimple !== null) + ? Simple3.fromPartial(object.otherSimple) + : undefined; + return message; + }, +}; + +function createBaseDifferentSimple(): DifferentSimple { + return { name: "", otherOptionalSimple2: undefined }; +} + +export const DifferentSimple = { + encode(message: DifferentSimple, writer: _m0.Writer = _m0.Writer.create()): _m0.Writer { + if (message.name !== "") { + writer.uint32(10).string(message.name); + } + if (message.otherOptionalSimple2 !== undefined) { + Simple3.encode(message.otherOptionalSimple2, writer.uint32(18).fork()).ldelim(); + } + return writer; + }, + + decode(input: _m0.Reader | Uint8Array, length?: number): DifferentSimple { + const reader = input instanceof _m0.Reader ? input : _m0.Reader.create(input); + let end = length === undefined ? reader.len : reader.pos + length; + const message = createBaseDifferentSimple(); + while (reader.pos < end) { + const tag = reader.uint32(); + switch (tag >>> 3) { + case 1: + if (tag !== 10) { + break; + } + + message.name = reader.string(); + continue; + case 2: + if (tag !== 18) { + break; + } + + message.otherOptionalSimple2 = Simple3.decode(reader, reader.uint32()); + continue; + } + if ((tag & 7) === 4 || tag === 0) { + break; + } + reader.skipType(tag & 7); + } + return message; + }, + + fromJSON(object: any): DifferentSimple { + return { + name: isSet(object.name) ? globalThis.String(object.name) : "", + otherOptionalSimple2: isSet(object.otherOptionalSimple2) + ? Simple3.fromJSON(object.otherOptionalSimple2) + : undefined, + }; + }, + + toJSON(message: DifferentSimple): unknown { + const obj: any = {}; + if (message.name !== "") { + obj.name = message.name; + } + if (message.otherOptionalSimple2 !== undefined) { + obj.otherOptionalSimple2 = Simple3.toJSON(message.otherOptionalSimple2); + } + return obj; + }, + + create, I>>(base?: I): DifferentSimple { + return DifferentSimple.fromPartial(base ?? ({} as any)); + }, + fromPartial, I>>(object: I): DifferentSimple { + const message = createBaseDifferentSimple(); + message.name = object.name ?? ""; + message.otherOptionalSimple2 = (object.otherOptionalSimple2 !== undefined && object.otherOptionalSimple2 !== null) + ? Simple3.fromPartial(object.otherOptionalSimple2) + : undefined; + return message; + }, +}; + +function createBaseSimpleEnums(): SimpleEnums { + return { localEnum: 0, importEnum: 0 }; +} + +export const SimpleEnums = { + encode(message: SimpleEnums, writer: _m0.Writer = _m0.Writer.create()): _m0.Writer { + if (message.localEnum !== 0) { + writer.uint32(8).int32(message.localEnum); + } + if (message.importEnum !== 0) { + writer.uint32(16).int32(message.importEnum); + } + return writer; + }, + + decode(input: _m0.Reader | Uint8Array, length?: number): SimpleEnums { + const reader = input instanceof _m0.Reader ? input : _m0.Reader.create(input); + let end = length === undefined ? reader.len : reader.pos + length; + const message = createBaseSimpleEnums(); + while (reader.pos < end) { + const tag = reader.uint32(); + switch (tag >>> 3) { + case 1: + if (tag !== 8) { + break; + } + + message.localEnum = reader.int32() as any; + continue; + case 2: + if (tag !== 16) { + break; + } + + message.importEnum = reader.int32() as any; + continue; + } + if ((tag & 7) === 4 || tag === 0) { + break; + } + reader.skipType(tag & 7); + } + return message; + }, + + fromJSON(object: any): SimpleEnums { + return { + localEnum: isSet(object.localEnum) ? simpleEnumFromJSON(object.localEnum) : 0, + importEnum: isSet(object.importEnum) ? simpleEnumFromJSON4(object.importEnum) : 0, + }; + }, + + toJSON(message: SimpleEnums): unknown { + const obj: any = {}; + if (message.localEnum !== 0) { + obj.localEnum = simpleEnumToJSON(message.localEnum); + } + if (message.importEnum !== 0) { + obj.importEnum = simpleEnumToJSON5(message.importEnum); + } + return obj; + }, + + create, I>>(base?: I): SimpleEnums { + return SimpleEnums.fromPartial(base ?? ({} as any)); + }, + fromPartial, I>>(object: I): SimpleEnums { + const message = createBaseSimpleEnums(); + message.localEnum = object.localEnum ?? 0; + message.importEnum = object.importEnum ?? 0; + return message; + }, +}; + +function createBaseFooServiceCreateRequest(): FooServiceCreateRequest { + return { kind: 0 }; +} + +export const FooServiceCreateRequest = { + encode(message: FooServiceCreateRequest, writer: _m0.Writer = _m0.Writer.create()): _m0.Writer { + if (message.kind !== 0) { + writer.uint32(8).int32(message.kind); + } + return writer; + }, + + decode(input: _m0.Reader | Uint8Array, length?: number): FooServiceCreateRequest { + const reader = input instanceof _m0.Reader ? input : _m0.Reader.create(input); + let end = length === undefined ? reader.len : reader.pos + length; + const message = createBaseFooServiceCreateRequest(); + while (reader.pos < end) { + const tag = reader.uint32(); + switch (tag >>> 3) { + case 1: + if (tag !== 8) { + break; + } + + message.kind = reader.int32() as any; + continue; + } + if ((tag & 7) === 4 || tag === 0) { + break; + } + reader.skipType(tag & 7); + } + return message; + }, + + fromJSON(object: any): FooServiceCreateRequest { + return { kind: isSet(object.kind) ? fooServiceFromJSON(object.kind) : 0 }; + }, + + toJSON(message: FooServiceCreateRequest): unknown { + const obj: any = {}; + if (message.kind !== 0) { + obj.kind = fooServiceToJSON(message.kind); + } + return obj; + }, + + create, I>>(base?: I): FooServiceCreateRequest { + return FooServiceCreateRequest.fromPartial(base ?? ({} as any)); + }, + fromPartial, I>>(object: I): FooServiceCreateRequest { + const message = createBaseFooServiceCreateRequest(); + message.kind = object.kind ?? 0; + return message; + }, +}; + +function createBaseFooServiceCreateResponse(): FooServiceCreateResponse { + return { kind: 0 }; +} + +export const FooServiceCreateResponse = { + encode(message: FooServiceCreateResponse, writer: _m0.Writer = _m0.Writer.create()): _m0.Writer { + if (message.kind !== 0) { + writer.uint32(8).int32(message.kind); + } + return writer; + }, + + decode(input: _m0.Reader | Uint8Array, length?: number): FooServiceCreateResponse { + const reader = input instanceof _m0.Reader ? input : _m0.Reader.create(input); + let end = length === undefined ? reader.len : reader.pos + length; + const message = createBaseFooServiceCreateResponse(); + while (reader.pos < end) { + const tag = reader.uint32(); + switch (tag >>> 3) { + case 1: + if (tag !== 8) { + break; + } + + message.kind = reader.int32() as any; + continue; + } + if ((tag & 7) === 4 || tag === 0) { + break; + } + reader.skipType(tag & 7); + } + return message; + }, + + fromJSON(object: any): FooServiceCreateResponse { + return { kind: isSet(object.kind) ? fooServiceFromJSON(object.kind) : 0 }; + }, + + toJSON(message: FooServiceCreateResponse): unknown { + const obj: any = {}; + if (message.kind !== 0) { + obj.kind = fooServiceToJSON(message.kind); + } + return obj; + }, + + create, I>>(base?: I): FooServiceCreateResponse { + return FooServiceCreateResponse.fromPartial(base ?? ({} as any)); + }, + fromPartial, I>>(object: I): FooServiceCreateResponse { + const message = createBaseFooServiceCreateResponse(); + message.kind = object.kind ?? 0; + return message; + }, +}; + +export interface FooService { + Create(request: FooServiceCreateRequest): Promise; +} + +export const FooServiceServiceName = "simple.FooService"; +export class FooServiceClientImpl implements FooService { + private readonly rpc: Rpc; + private readonly service: string; + constructor(rpc: Rpc, opts?: { service?: string }) { + this.service = opts?.service || FooServiceServiceName; + this.rpc = rpc; + this.Create = this.Create.bind(this); + } + Create(request: FooServiceCreateRequest): Promise { + const data = FooServiceCreateRequest.encode(request).finish(); + if (this.rpc.beforeRequest) { + this.rpc.beforeRequest(request); + } + const promise = this.rpc.request(this.service, "Create", data); + return promise.then((data) => { + const response = FooServiceCreateResponse.decode(_m0.Reader.create(data)); + if (this.rpc.afterResponse) { + this.rpc.afterResponse(response); + } + return response; + }); + } +} + +interface Rpc { + request(service: string, method: string, data: Uint8Array): Promise; + beforeRequest?(request: T): void; + afterResponse?(response: T): void; +} + +type Builtin = Date | Function | Uint8Array | string | number | boolean | undefined; + +export type DeepPartial = T extends Builtin ? T + : T extends globalThis.Array ? globalThis.Array> + : T extends ReadonlyArray ? ReadonlyArray> + : T extends {} ? { [K in keyof T]?: DeepPartial } + : Partial; + +type KeysOfUnion = T extends T ? keyof T : never; +export type Exact = P extends Builtin ? P + : P & { [K in keyof P]: Exact } & { [K in Exclude>]: never }; + +function isSet(value: any): boolean { + return value !== null && value !== undefined; +} diff --git a/integration/before-after-request/simple2.bin b/integration/before-after-request/simple2.bin new file mode 100644 index 000000000..a172b14bd Binary files /dev/null and b/integration/before-after-request/simple2.bin differ diff --git a/integration/before-after-request/simple2.proto b/integration/before-after-request/simple2.proto new file mode 100644 index 000000000..4d087f923 --- /dev/null +++ b/integration/before-after-request/simple2.proto @@ -0,0 +1,20 @@ +syntax = "proto3"; +package simple2; + +enum SimpleEnum { + IMPORT_DEFAULT = 0; + IMPORT_FOO = 10; + IMPORT_BAR = 11; +} + +enum FooService { + FOO_SERVICE_DEFAULT = 0; + FOO_SERVICE_FOO = 1; + FOO_SERVICE_BAR = 2; +} + +message Simple { + string name = 1; + int32 age = 2; +} + diff --git a/integration/before-after-request/simple2.ts b/integration/before-after-request/simple2.ts new file mode 100644 index 000000000..c842a7c0f --- /dev/null +++ b/integration/before-after-request/simple2.ts @@ -0,0 +1,177 @@ +/* eslint-disable */ +import * as _m0 from "protobufjs/minimal"; + +export const protobufPackage = "simple2"; + +export enum SimpleEnum { + IMPORT_DEFAULT = 0, + IMPORT_FOO = 10, + IMPORT_BAR = 11, + UNRECOGNIZED = -1, +} + +export function simpleEnumFromJSON(object: any): SimpleEnum { + switch (object) { + case 0: + case "IMPORT_DEFAULT": + return SimpleEnum.IMPORT_DEFAULT; + case 10: + case "IMPORT_FOO": + return SimpleEnum.IMPORT_FOO; + case 11: + case "IMPORT_BAR": + return SimpleEnum.IMPORT_BAR; + case -1: + case "UNRECOGNIZED": + default: + return SimpleEnum.UNRECOGNIZED; + } +} + +export function simpleEnumToJSON(object: SimpleEnum): string { + switch (object) { + case SimpleEnum.IMPORT_DEFAULT: + return "IMPORT_DEFAULT"; + case SimpleEnum.IMPORT_FOO: + return "IMPORT_FOO"; + case SimpleEnum.IMPORT_BAR: + return "IMPORT_BAR"; + case SimpleEnum.UNRECOGNIZED: + default: + return "UNRECOGNIZED"; + } +} + +export enum FooService { + FOO_SERVICE_DEFAULT = 0, + FOO_SERVICE_FOO = 1, + FOO_SERVICE_BAR = 2, + UNRECOGNIZED = -1, +} + +export function fooServiceFromJSON(object: any): FooService { + switch (object) { + case 0: + case "FOO_SERVICE_DEFAULT": + return FooService.FOO_SERVICE_DEFAULT; + case 1: + case "FOO_SERVICE_FOO": + return FooService.FOO_SERVICE_FOO; + case 2: + case "FOO_SERVICE_BAR": + return FooService.FOO_SERVICE_BAR; + case -1: + case "UNRECOGNIZED": + default: + return FooService.UNRECOGNIZED; + } +} + +export function fooServiceToJSON(object: FooService): string { + switch (object) { + case FooService.FOO_SERVICE_DEFAULT: + return "FOO_SERVICE_DEFAULT"; + case FooService.FOO_SERVICE_FOO: + return "FOO_SERVICE_FOO"; + case FooService.FOO_SERVICE_BAR: + return "FOO_SERVICE_BAR"; + case FooService.UNRECOGNIZED: + default: + return "UNRECOGNIZED"; + } +} + +export interface Simple { + name: string; + age: number; +} + +function createBaseSimple(): Simple { + return { name: "", age: 0 }; +} + +export const Simple = { + encode(message: Simple, writer: _m0.Writer = _m0.Writer.create()): _m0.Writer { + if (message.name !== "") { + writer.uint32(10).string(message.name); + } + if (message.age !== 0) { + writer.uint32(16).int32(message.age); + } + return writer; + }, + + decode(input: _m0.Reader | Uint8Array, length?: number): Simple { + const reader = input instanceof _m0.Reader ? input : _m0.Reader.create(input); + let end = length === undefined ? reader.len : reader.pos + length; + const message = createBaseSimple(); + while (reader.pos < end) { + const tag = reader.uint32(); + switch (tag >>> 3) { + case 1: + if (tag !== 10) { + break; + } + + message.name = reader.string(); + continue; + case 2: + if (tag !== 16) { + break; + } + + message.age = reader.int32(); + continue; + } + if ((tag & 7) === 4 || tag === 0) { + break; + } + reader.skipType(tag & 7); + } + return message; + }, + + fromJSON(object: any): Simple { + return { + name: isSet(object.name) ? globalThis.String(object.name) : "", + age: isSet(object.age) ? globalThis.Number(object.age) : 0, + }; + }, + + toJSON(message: Simple): unknown { + const obj: any = {}; + if (message.name !== "") { + obj.name = message.name; + } + if (message.age !== 0) { + obj.age = Math.round(message.age); + } + return obj; + }, + + create, I>>(base?: I): Simple { + return Simple.fromPartial(base ?? ({} as any)); + }, + fromPartial, I>>(object: I): Simple { + const message = createBaseSimple(); + message.name = object.name ?? ""; + message.age = object.age ?? 0; + return message; + }, +}; + +type Builtin = Date | Function | Uint8Array | string | number | boolean | undefined; + +export type DeepPartial = T extends Builtin ? T + : T extends globalThis.Array ? globalThis.Array> + : T extends ReadonlyArray ? ReadonlyArray> + : T extends {} ? { [K in keyof T]?: DeepPartial } + : Partial; + +type KeysOfUnion = T extends T ? keyof T : never; +export type Exact = P extends Builtin ? P + : P & { [K in keyof P]: Exact } & { [K in Exclude>]: never }; + +function isSet(value: any): boolean { + return value !== null && value !== undefined; +} diff --git a/integration/before-after-request/typeRegistry.ts b/integration/before-after-request/typeRegistry.ts new file mode 100644 index 000000000..f824e4d15 --- /dev/null +++ b/integration/before-after-request/typeRegistry.ts @@ -0,0 +1,22 @@ +/* eslint-disable */ +import * as _m0 from "protobufjs/minimal"; + +export interface MessageType { + $type: Message["$type"]; + encode(message: Message, writer?: _m0.Writer): _m0.Writer; + decode(input: _m0.Reader | Uint8Array, length?: number): Message; + fromJSON(object: any): Message; + toJSON(message: Message): unknown; + fromPartial(object: DeepPartial): Message; +} + +export type UnknownMessage = { $type: string }; + +export const messageTypeRegistry = new Map(); + +type Builtin = Date | Function | Uint8Array | string | number | boolean | undefined; +export type DeepPartial = T extends Builtin ? T + : T extends globalThis.Array ? globalThis.Array> + : T extends ReadonlyArray ? ReadonlyArray> + : T extends {} ? { [K in Exclude]?: DeepPartial } + : Partial; diff --git a/src/generate-services.ts b/src/generate-services.ts index 3d3402e28..ec7665314 100644 --- a/src/generate-services.ts +++ b/src/generate-services.ts @@ -120,7 +120,23 @@ function generateRegularRpcMethod(ctx: Context, methodDesc: MethodDescriptorProt const maybeAbortSignal = options.useAbortSignal ? "abortSignal || undefined," : ""; let encode = code`${rawInputType}.encode(request).finish()`; + let beforeRequest; + if (options.outputBeforeRequest) { + beforeRequest = code` + if (this.rpc.beforeRequest) { + this.rpc.beforeRequest(request); + }`; + } let decode = code`data => ${rawOutputType}.decode(${Reader}.create(data))`; + if (options.outputAfterResponse) { + decode = code`data => { + const response = ${rawOutputType}.decode(${Reader}.create(data)); + if (this.rpc.afterResponse) { + this.rpc.afterResponse(response); + } + return response; + }`; + } // if (options.useDate && rawOutputType.toString().includes("Timestamp")) { // decode = code`data => ${utils.fromTimestamp}(${rawOutputType}.decode(${Reader}.create(data)))`; @@ -160,7 +176,7 @@ function generateRegularRpcMethod(ctx: Context, methodDesc: MethodDescriptorProt ${methodDesc.formattedName}( ${joinCode(params, { on: "," })} ): ${responsePromiseOrObservable(ctx, methodDesc)} { - const data = ${encode}; + const data = ${encode}; ${beforeRequest ? beforeRequest : ""} const ${returnVariable} = this.rpc.${rpcMethod}( ${maybeCtx} this.service, @@ -351,6 +367,13 @@ export function generateRpcType(ctx: Context, hasStreamingMethods: boolean): Cod const maybeContextParam = options.context ? "ctx: Context," : ""; const maybeAbortSignalParam = options.useAbortSignal ? "abortSignal?: AbortSignal," : ""; const methods = [[code`request`, code`Uint8Array`, code`Promise`]]; + const additionalMethods = []; + if (options.outputBeforeRequest) { + additionalMethods.push(code`beforeRequest?(request: T): void;`); + } + if (options.outputAfterResponse) { + additionalMethods.push(code`afterResponse?(response: T): void;`); + } if (hasStreamingMethods) { const observable = observableType(ctx, true); methods.push([code`clientStreamingRequest`, code`${observable}`, code`Promise`]); @@ -373,6 +396,7 @@ export function generateRpcType(ctx: Context, hasStreamingMethods: boolean): Cod ${maybeAbortSignalParam} ): ${method[2]};`); }); + additionalMethods.forEach((method) => chunks.push(method)); chunks.push(code` }`); return joinCode(chunks, { on: "\n" }); } diff --git a/src/options.ts b/src/options.ts index 925d6ee35..88fab496e 100644 --- a/src/options.ts +++ b/src/options.ts @@ -87,6 +87,8 @@ export type Options = { outputExtensions: boolean; outputIndex: boolean; M: { [from: string]: string }; + outputBeforeRequest: boolean; + outputAfterResponse: boolean; }; export function defaultOptions(): Options { @@ -143,6 +145,8 @@ export function defaultOptions(): Options { outputExtensions: false, outputIndex: false, M: {}, + outputBeforeRequest: false, + outputAfterResponse: false, }; } @@ -240,6 +244,10 @@ export function optionsFromParameter(parameter: string | undefined): Options { options.exportCommonSymbols = false; } + if (options.outputBeforeRequest || options.outputAfterResponse) { + options.outputServices = [ServiceOption.DEFAULT]; + } + if (options.unrecognizedEnumValue) { // Make sure to cast number options to an actual number options.unrecognizedEnumValue = Number(options.unrecognizedEnumValue); diff --git a/tests/options-test.ts b/tests/options-test.ts index efd5f1647..91df889c3 100644 --- a/tests/options-test.ts +++ b/tests/options-test.ts @@ -2,6 +2,7 @@ import { DateOption, optionsFromParameter, ServiceOption } from "../src/options" describe("options", () => { it("can set outputJsonMethods with nestJs=true", () => { + console.log(optionsFromParameter("nestJs=true,outputJsonMethods=true")); expect(optionsFromParameter("nestJs=true,outputJsonMethods=true")).toMatchInlineSnapshot(` { "M": {}, @@ -25,6 +26,8 @@ describe("options", () => { "nestJs": true, "oneof": "properties", "onlyTypes": false, + "outputAfterResponse": false, + "outputBeforeRequest": false, "outputClientImpl": false, "outputEncodeMethods": false, "outputExtensions": false, @@ -163,4 +166,20 @@ describe("options", () => { useDate: DateOption.STRING, }); }); + + it("outputAfterResponse implies default service", () => { + const options = optionsFromParameter("outputAfterResponse=true"); + expect(options).toMatchObject({ + outputAfterResponse: true, + outputServices: [ServiceOption.DEFAULT], + }); + }); + + it("outputBeforeRequest implies default service", () => { + const options = optionsFromParameter("outputBeforeRequest=true"); + expect(options).toMatchObject({ + outputBeforeRequest: true, + outputServices: [ServiceOption.DEFAULT], + }); + }); });