Skip to content

Commit

Permalink
feat: support configuring what models to include for zod and trpc plu…
Browse files Browse the repository at this point in the history
…gins (#747)
  • Loading branch information
ymc9 authored Oct 11, 2023
1 parent 30b95eb commit a5d15a3
Show file tree
Hide file tree
Showing 6 changed files with 389 additions and 41 deletions.
40 changes: 14 additions & 26 deletions packages/plugins/trpc/src/generator.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import {
PluginOptions,
RUNTIME_PACKAGE,
getPrismaClientImportSpec,
parseOptionAsStrings,
requireOption,
resolvePath,
saveProject,
Expand Down Expand Up @@ -32,11 +33,14 @@ export async function generate(model: Model, options: PluginOptions, dmmf: DMMF.
let outDir = requireOption<string>(options, 'output');
outDir = resolvePath(outDir, options);

// resolve "generateModels" option
const generateModels = parseOptionAsStrings(options, 'generateModels', name);

// resolve "generateModelActions" option
const generateModelActions = parseOptionAsStrings(options, 'generateModelActions');
const generateModelActions = parseOptionAsStrings(options, 'generateModelActions', name);

// resolve "generateClientHelpers" option
const generateClientHelpers = parseOptionAsStrings(options, 'generateClientHelpers');
const generateClientHelpers = parseOptionAsStrings(options, 'generateClientHelpers', name);
if (generateClientHelpers && !generateClientHelpers.every((v) => ['react', 'next'].includes(v))) {
throw new PluginError(name, `Option "generateClientHelpers" only support values "react" and "next"`);
}
Expand All @@ -50,10 +54,15 @@ export async function generate(model: Model, options: PluginOptions, dmmf: DMMF.

const prismaClientDmmf = dmmf;

const modelOperations = prismaClientDmmf.mappings.modelOperations;
const models = prismaClientDmmf.datamodel.models;
let modelOperations = prismaClientDmmf.mappings.modelOperations;
if (generateModels) {
modelOperations = modelOperations.filter((mo) => generateModels.includes(mo.model));
}

// TODO: remove this legacy code that deals with "@Gen.hide" comment syntax inherited
// from original code
const hiddenModels: string[] = [];
resolveModelsComments(models, hiddenModels);
resolveModelsComments(prismaClientDmmf.datamodel.models, hiddenModels);

const zodSchemasImport = (options.zodSchemasImport as string) ?? '@zenstackhq/runtime/zod';
createAppRouter(
Expand Down Expand Up @@ -472,24 +481,3 @@ function createHelper(outDir: string) {
);
checkRead.formatText();
}

function parseOptionAsStrings(options: PluginOptions, optionaName: string) {
const value = options[optionaName];
if (value === undefined) {
return undefined;
} else if (typeof value === 'string') {
// comma separated string
return value
.split(',')
.filter((i) => !!i)
.map((i) => i.trim());
} else if (Array.isArray(value) && value.every((i) => typeof i === 'string')) {
// string array
return value as string[];
} else {
throw new PluginError(
name,
`Invalid "${optionaName}" option: must be a comma-separated string or an array of strings`
);
}
}
129 changes: 129 additions & 0 deletions packages/plugins/trpc/tests/trpc.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -285,4 +285,133 @@ model post_item {
}
);
});

it('generate for selected models and actions', async () => {
const { projectDir } = await loadSchema(
`
datasource db {
provider = 'postgresql'
url = env('DATABASE_URL')
}
generator js {
provider = 'prisma-client-js'
}
plugin trpc {
provider = '${process.cwd()}/dist'
output = '$projectRoot/trpc'
generateModels = ['Post']
generateModelActions = ['findMany', 'update']
}
model User {
id String @id
email String @unique
posts Post[]
}
model Post {
id String @id
title String
author User? @relation(fields: [authorId], references: [id])
authorId String?
}
model Foo {
id String @id
value Int
}
`,
{
addPrelude: false,
pushDb: false,
extraDependencies: [`${origDir}/dist`, '@trpc/client', '@trpc/server'],
compile: true,
}
);

expect(fs.existsSync(path.join(projectDir, 'trpc/routers/User.router.ts'))).toBeFalsy();
expect(fs.existsSync(path.join(projectDir, 'trpc/routers/Foo.router.ts'))).toBeFalsy();
expect(fs.existsSync(path.join(projectDir, 'trpc/routers/Post.router.ts'))).toBeTruthy();

const postRouterContent = fs.readFileSync(path.join(projectDir, 'trpc/routers/Post.router.ts'), 'utf8');
expect(postRouterContent).toContain('findMany:');
expect(postRouterContent).toContain('update:');
expect(postRouterContent).not.toContain('findUnique:');
expect(postRouterContent).not.toContain('create:');

// trpc plugin passes "generateModels" option down to implicitly enabled zod plugin

expect(
fs.existsSync(path.join(projectDir, 'node_modules/.zenstack/zod/input/PostInput.schema.js'))
).toBeTruthy();
// zod for User is generated due to transitive dependency
expect(
fs.existsSync(path.join(projectDir, 'node_modules/.zenstack/zod/input/UserInput.schema.js'))
).toBeTruthy();
expect(fs.existsSync(path.join(projectDir, 'node_modules/.zenstack/zod/input/FooInput.schema.js'))).toBeFalsy();
});

it('generate for selected models with zod plugin declared', async () => {
const { projectDir } = await loadSchema(
`
datasource db {
provider = 'postgresql'
url = env('DATABASE_URL')
}
generator js {
provider = 'prisma-client-js'
}
plugin zod {
provider = '@core/zod'
}
plugin trpc {
provider = '${process.cwd()}/dist'
output = '$projectRoot/trpc'
generateModels = ['Post']
generateModelActions = ['findMany', 'update']
}
model User {
id String @id
email String @unique
posts Post[]
}
model Post {
id String @id
title String
author User? @relation(fields: [authorId], references: [id])
authorId String?
}
model Foo {
id String @id
value Int
}
`,
{
addPrelude: false,
pushDb: false,
extraDependencies: [`${origDir}/dist`, '@trpc/client', '@trpc/server'],
compile: true,
}
);

// trpc plugin's "generateModels" shouldn't interfere in this case

expect(
fs.existsSync(path.join(projectDir, 'node_modules/.zenstack/zod/input/PostInput.schema.js'))
).toBeTruthy();
expect(
fs.existsSync(path.join(projectDir, 'node_modules/.zenstack/zod/input/UserInput.schema.js'))
).toBeTruthy();
expect(
fs.existsSync(path.join(projectDir, 'node_modules/.zenstack/zod/input/FooInput.schema.js'))
).toBeTruthy();
});
});
42 changes: 35 additions & 7 deletions packages/schema/src/cli/plugin-runner.ts
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,7 @@ export class PluginRunner {
}

// "@core/access-policy" has implicit requirements
let zodImplicitlyAdded = false;
if ([...plugins, ...corePlugins].find((p) => p.provider === '@core/access-policy')) {
// make sure "@core/model-meta" is enabled
if (!corePlugins.find((p) => p.provider === '@core/model-meta')) {
Expand All @@ -193,25 +194,52 @@ export class PluginRunner {
// '@core/zod' plugin is auto-enabled by "@core/access-policy"
// if there're validation rules
if (!corePlugins.find((p) => p.provider === '@core/zod') && this.hasValidation(options.schema)) {
zodImplicitlyAdded = true;
corePlugins.push({ provider: '@core/zod', options: { modelOnly: true } });
}
}

// core plugins introduced by dependencies
plugins
.flatMap((p) => p.dependencies)
.forEach((dep) => {
plugins.forEach((plugin) => {
// TODO: generalize this
const isTrpcPlugin =
plugin.provider === '@zenstackhq/trpc' ||
// for testing
(process.env.ZENSTACK_TEST && plugin.provider.includes('trpc'));

for (const dep of plugin.dependencies) {
if (dep.startsWith('@core/')) {
const existing = corePlugins.find((p) => p.provider === dep);
if (existing) {
// reset options to default
existing.options = undefined;
// TODO: generalize this
if (existing.provider === '@core/zod') {
// Zod plugin can be automatically enabled in `modelOnly` mode, however
// other plugin (tRPC) for now requires it to run in full mode
existing.options = {};

if (
isTrpcPlugin &&
zodImplicitlyAdded // don't do it for user defined zod plugin
) {
// pass trpc plugin's `generateModels` option down to zod plugin
existing.options.generateModels = plugin.options.generateModels;
}
}
} else {
// add core dependency
corePlugins.push({ provider: dep });
const toAdd = { provider: dep, options: {} as Record<string, unknown> };

// TODO: generalize this
if (dep === '@core/zod' && isTrpcPlugin) {
// pass trpc plugin's `generateModels` option down to zod plugin
toAdd.options.generateModels = plugin.options.generateModels;
}

corePlugins.push(toAdd);
}
}
});
}
});

return corePlugins;
}
Expand Down
71 changes: 64 additions & 7 deletions packages/schema/src/plugins/zod/generator.ts
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import {
isEnumFieldReference,
isForeignKeyField,
isFromStdlib,
parseOptionAsStrings,
resolvePath,
saveProject,
} from '@zenstackhq/sdk';
Expand All @@ -21,6 +22,7 @@ import { streamAllContents } from 'langium';
import path from 'path';
import { Project } from 'ts-morph';
import { upperCaseFirst } from 'upper-case-first';
import { name } from '.';
import { getDefaultOutputFolder } from '../plugin-utils';
import Transformer from './transformer';
import removeDir from './utils/removeDir';
Expand All @@ -44,12 +46,26 @@ export async function generate(
output = resolvePath(output, options);
await handleGeneratorOutputValue(output);

// calculate the models to be excluded
const excludeModels = getExcludedModels(model, options);

const prismaClientDmmf = dmmf;

const modelOperations = prismaClientDmmf.mappings.modelOperations;
const inputObjectTypes = prismaClientDmmf.schema.inputObjectTypes.prisma;
const outputObjectTypes = prismaClientDmmf.schema.outputObjectTypes.prisma;
const models: DMMF.Model[] = prismaClientDmmf.datamodel.models;
const modelOperations = prismaClientDmmf.mappings.modelOperations.filter(
(o) => !excludeModels.find((e) => e === o.model)
);

// TODO: better way of filtering than string startsWith?
const inputObjectTypes = prismaClientDmmf.schema.inputObjectTypes.prisma.filter(
(type) => !excludeModels.find((e) => type.name.toLowerCase().startsWith(e.toLocaleLowerCase()))
);
const outputObjectTypes = prismaClientDmmf.schema.outputObjectTypes.prisma.filter(
(type) => !excludeModels.find((e) => type.name.toLowerCase().startsWith(e.toLowerCase()))
);

const models: DMMF.Model[] = prismaClientDmmf.datamodel.models.filter(
(m) => !excludeModels.find((e) => e === m.name)
);

// whether Prisma's Unchecked* series of input types should be generated
const generateUnchecked = options.noUncheckedInput !== true;
Expand All @@ -73,7 +89,7 @@ export async function generate(
dataSource?.fields.find((f) => f.name === 'provider')?.value
) as ConnectorType;

await generateModelSchemas(project, model, output);
await generateModelSchemas(project, model, output, excludeModels);

if (options.modelOnly !== true) {
// detailed object schemas referenced from input schemas
Expand Down Expand Up @@ -120,6 +136,45 @@ export async function generate(
}
}

function getExcludedModels(model: Model, options: PluginOptions) {
// resolve "generateModels" option
const generateModels = parseOptionAsStrings(options, 'generateModels', name);
if (generateModels) {
if (options.modelOnly === true) {
// no model reference needs to be considered, directly exclude any model not included
return model.declarations
.filter((d) => isDataModel(d) && !generateModels.includes(d.name))
.map((m) => m.name);
} else {
// calculate a transitive closure of models to be included
const todo = getDataModels(model).filter((dm) => generateModels.includes(dm.name));
const included = new Set<DataModel>();
while (todo.length > 0) {
// eslint-disable-next-line @typescript-eslint/no-non-null-assertion
const dm = todo.pop()!;
included.add(dm);

// add referenced models to the todo list
dm.fields
.map((f) => f.type.reference?.ref)
.filter((type): type is DataModel => isDataModel(type))
.forEach((type) => {
if (!included.has(type)) {
todo.push(type);
}
});
}

// finally find the models to be excluded
return getDataModels(model)
.filter((dm) => !included.has(dm))
.map((m) => m.name);
}
} else {
return [];
}
}

async function handleGeneratorOutputValue(output: string) {
// create the output directory and delete contents that might exist from a previous run
await fs.mkdir(output, { recursive: true });
Expand Down Expand Up @@ -184,10 +239,12 @@ async function generateObjectSchemas(
);
}

async function generateModelSchemas(project: Project, zmodel: Model, output: string) {
async function generateModelSchemas(project: Project, zmodel: Model, output: string, excludedModels: string[]) {
const schemaNames: string[] = [];
for (const dm of getDataModels(zmodel)) {
schemaNames.push(await generateModelSchema(dm, project, output));
if (!excludedModels.includes(dm.name)) {
schemaNames.push(await generateModelSchema(dm, project, output));
}
}

project.createSourceFile(
Expand Down
Loading

0 comments on commit a5d15a3

Please sign in to comment.