From 0bfe3da4175548aa1e26c54763736886731b2f6e Mon Sep 17 00:00:00 2001 From: Niklas Harrysson Date: Wed, 15 Feb 2023 00:29:08 +0100 Subject: [PATCH] Improvement to codegen of functional graphs in MDL (#1243) This change list improves on codegen of functional graphs in MDL, adding code for default geomprops on the inputs of functions generated from graphs. --- source/MaterialXGenMdl/MdlShaderGenerator.cpp | 28 ++--- source/MaterialXGenMdl/MdlShaderGenerator.h | 3 + .../Nodes/ClosureCompoundNodeMdl.cpp | 82 +------------ .../MaterialXGenMdl/Nodes/CompoundNodeMdl.cpp | 111 +++++++++++------- .../MaterialXGenMdl/Nodes/CompoundNodeMdl.h | 2 + source/MaterialXGenShader/ShaderGraph.cpp | 5 + 6 files changed, 96 insertions(+), 135 deletions(-) diff --git a/source/MaterialXGenMdl/MdlShaderGenerator.cpp b/source/MaterialXGenMdl/MdlShaderGenerator.cpp index 688530bde5..aa32682866 100644 --- a/source/MaterialXGenMdl/MdlShaderGenerator.cpp +++ b/source/MaterialXGenMdl/MdlShaderGenerator.cpp @@ -31,20 +31,6 @@ MATERIALX_NAMESPACE_BEGIN namespace { -std::unordered_map GEOMPROP_DEFINITIONS = -{ - { "Pobject", "base::transform_point(base::coordinate_internal, base::coordinate_object, state::position())" }, - { "Pworld", "base::transform_point(base::coordinate_internal, base::coordinate_world, state::position())" }, - { "Nobject", "base::transform_normal(base::coordinate_internal, base::coordinate_object, state::normal())" }, - { "Nworld", "base::transform_normal(base::coordinate_internal, base::coordinate_world, state::normal())" }, - { "Tobject", "base::transform_vector(base::coordinate_internal, base::coordinate_object, state::texture_tangent_u(0))" }, - { "Tworld", "base::transform_vector(base::coordinate_internal, base::coordinate_world, state::texture_tangent_u(0))" }, - { "Bobject", "base::transform_vector(base::coordinate_internal, base::coordinate_object, state::texture_tangent_v(0))" }, - { "Bworld", "base::transform_vector(base::coordinate_internal, base::coordinate_world, state::texture_tangent_v(0))" }, - { "UV0", "float2(state::texture_coordinate(0).x, state::texture_coordinate(0).y)" }, - { "Vworld", "state::direction()" } -}; - const string MDL_VERSION = "1.6"; const vector DEFAULT_IMPORTS = @@ -67,6 +53,20 @@ const vector DEFAULT_IMPORTS = const string MdlShaderGenerator::TARGET = "genmdl"; +const std::unordered_map MdlShaderGenerator::GEOMPROP_DEFINITIONS = +{ + { "Pobject", "state::transform_point(state::coordinate_internal, state::coordinate_object, state::position())" }, + { "Pworld", "state::transform_point(state::coordinate_internal, state::coordinate_world, state::position())" }, + { "Nobject", "state::transform_normal(state::coordinate_internal, state::coordinate_object, state::normal())" }, + { "Nworld", "state::transform_normal(state::coordinate_internal, state::coordinate_world, state::normal())" }, + { "Tobject", "state::transform_vector(state::coordinate_internal, state::coordinate_object, state::texture_tangent_u(0))" }, + { "Tworld", "state::transform_vector(state::coordinate_internal, state::coordinate_world, state::texture_tangent_u(0))" }, + { "Bobject", "state::transform_vector(state::coordinate_internal, state::coordinate_object, state::texture_tangent_v(0))" }, + { "Bworld", "state::transform_vector(state::coordinate_internal, state::coordinate_world, state::texture_tangent_v(0))" }, + { "UV0", "float2(state::texture_coordinate(0).x, state::texture_coordinate(0).y)" }, + { "Vworld", "state::direction()" } +}; + // // MdlShaderGenerator methods // diff --git a/source/MaterialXGenMdl/MdlShaderGenerator.h b/source/MaterialXGenMdl/MdlShaderGenerator.h index fd192787f4..f1431b8916 100644 --- a/source/MaterialXGenMdl/MdlShaderGenerator.h +++ b/source/MaterialXGenMdl/MdlShaderGenerator.h @@ -44,6 +44,9 @@ class MX_GENMDL_API MdlShaderGenerator : public ShaderGenerator /// Unique identifier for this generator target static const string TARGET; + /// Map of code snippets for geomprops in MDL. + static const std::unordered_map GEOMPROP_DEFINITIONS; + protected: // Create and initialize a new MDL shader for shader generation. ShaderPtr createShader(const string& name, ElementPtr element, GenContext& context) const; diff --git a/source/MaterialXGenMdl/Nodes/ClosureCompoundNodeMdl.cpp b/source/MaterialXGenMdl/Nodes/ClosureCompoundNodeMdl.cpp index 6566d6627a..d5d2a3e533 100644 --- a/source/MaterialXGenMdl/Nodes/ClosureCompoundNodeMdl.cpp +++ b/source/MaterialXGenMdl/Nodes/ClosureCompoundNodeMdl.cpp @@ -24,12 +24,11 @@ void ClosureCompoundNodeMdl::addClassification(ShaderNode& node) const node.addClassification(_rootGraph->getClassification()); } -void ClosureCompoundNodeMdl::emitFunctionDefinition(const ShaderNode&, GenContext& context, ShaderStage& stage) const +void ClosureCompoundNodeMdl::emitFunctionDefinition(const ShaderNode& node, GenContext& context, ShaderStage& stage) const { DEFINE_SHADER_STAGE(stage, Stage::PIXEL) { const ShaderGenerator& shadergen = context.getShaderGenerator(); - const Syntax& syntax = shadergen.getSyntax(); const bool isMaterialExpr = (_rootGraph->hasClassification(ShaderNode::Classification::CLOSURE) || _rootGraph->hasClassification(ShaderNode::Classification::SHADER)); @@ -37,49 +36,8 @@ void ClosureCompoundNodeMdl::emitFunctionDefinition(const ShaderNode&, GenContex // Emit functions for all child nodes shadergen.emitFunctionDefinitions(*_rootGraph, context, stage); - if (!_returnStruct.empty()) - { - // Define the output struct. - shadergen.emitLine("struct " + _returnStruct, stage, false); - shadergen.emitScopeBegin(stage, Syntax::CURLY_BRACKETS); - for (const ShaderGraphOutputSocket* output : _rootGraph->getOutputSockets()) - { - shadergen.emitLine(syntax.getTypeName(output->getType()) + " mxp_" + output->getName(), stage); - } - shadergen.emitScopeEnd(stage, true); - shadergen.emitLineBreak(stage); - - // Begin function signature. - shadergen.emitLine(_returnStruct + " " + _functionName, stage, false); - } - else - { - // Begin function signature. - const ShaderGraphOutputSocket* outputSocket = _rootGraph->getOutputSocket(); - const string& outputType = syntax.getTypeName(outputSocket->getType()); - shadergen.emitLine(outputType + " " + _functionName, stage, false); - } - - shadergen.emitScopeBegin(stage, Syntax::PARENTHESES); - - const string uniformPrefix = syntax.getUniformQualifier() + " "; - - // Emit all inputs - int count = int(_rootGraph->numInputSockets()); - for (ShaderGraphInputSocket* input : _rootGraph->getInputSockets()) - { - const string& qualifier = input->isUniform() || input->getType() == Type::FILENAME ? uniformPrefix : EMPTY_STRING; - const string& type = syntax.getTypeName(input->getType()); - const string value = (input->getValue() ? - syntax.getValue(input->getType(), *input->getValue()) : - syntax.getDefaultValue(input->getType())); - - const string& delim = --count > 0 ? Syntax::COMMA : EMPTY_STRING; - shadergen.emitLine(qualifier + type + " " + input->getVariable() + " = " + value + delim, stage, false); - } - - // End function signature. - shadergen.emitScopeEnd(stage); + // Emit function signature. + emitFunctionSignature(node, context, stage); // Special case for material expresions. if (isMaterialExpr) @@ -149,39 +107,11 @@ void ClosureCompoundNodeMdl::emitFunctionCall(const ShaderNode& node, GenContext { const ShaderGenerator& shadergen = context.getShaderGenerator(); - // Emit calls for any closure dependencies upstream from this node. + // First emit calls for any closure dependencies upstream from this node. shadergen.emitDependentFunctionCalls(node, context, stage, ShaderNode::Classification::CLOSURE); - // Begin function call. - if (!_returnStruct.empty()) - { - // Emit the struct multioutput. - const string resultVariableName = node.getName() + "_result"; - shadergen.emitLineBegin(stage); - shadergen.emitString(_returnStruct + " " + resultVariableName + " = ", stage); - } - else - { - // Emit the single output. - shadergen.emitLineBegin(stage); - shadergen.emitOutput(node.getOutput(0), true, false, context, stage); - shadergen.emitString(" = ", stage); - } - - shadergen.emitString(_functionName + "(", stage); - - // Emit inputs. - string delim = ""; - for (ShaderInput* input : node.getInputs()) - { - shadergen.emitString(delim, stage); - shadergen.emitInput(input, context, stage); - delim = ", "; - } - - // End function call - shadergen.emitString(")", stage); - shadergen.emitLineEnd(stage); + // Then emit this nodes function call. + CompoundNodeMdl::emitFunctionCall(node, context, stage); } } diff --git a/source/MaterialXGenMdl/Nodes/CompoundNodeMdl.cpp b/source/MaterialXGenMdl/Nodes/CompoundNodeMdl.cpp index af4267a481..df6d61d7f8 100644 --- a/source/MaterialXGenMdl/Nodes/CompoundNodeMdl.cpp +++ b/source/MaterialXGenMdl/Nodes/CompoundNodeMdl.cpp @@ -4,6 +4,7 @@ // #include +#include #include #include @@ -29,12 +30,11 @@ void CompoundNodeMdl::initialize(const InterfaceElement& element, GenContext& co } } -void CompoundNodeMdl::emitFunctionDefinition(const ShaderNode&, GenContext& context, ShaderStage& stage) const +void CompoundNodeMdl::emitFunctionDefinition(const ShaderNode& node, GenContext& context, ShaderStage& stage) const { DEFINE_SHADER_STAGE(stage, Stage::PIXEL) { const ShaderGenerator& shadergen = context.getShaderGenerator(); - const Syntax& syntax = shadergen.getSyntax(); const bool isMaterialExpr = (_rootGraph->hasClassification(ShaderNode::Classification::CLOSURE) || _rootGraph->hasClassification(ShaderNode::Classification::SHADER)); @@ -42,49 +42,8 @@ void CompoundNodeMdl::emitFunctionDefinition(const ShaderNode&, GenContext& cont // Emit functions for all child nodes shadergen.emitFunctionDefinitions(*_rootGraph, context, stage); - if (!_returnStruct.empty()) - { - // Define the output struct. - shadergen.emitLine("struct " + _returnStruct, stage, false); - shadergen.emitScopeBegin(stage, Syntax::CURLY_BRACKETS); - for (const ShaderGraphOutputSocket* output : _rootGraph->getOutputSockets()) - { - shadergen.emitLine(syntax.getTypeName(output->getType()) + " mxp_" + output->getName(), stage); - } - shadergen.emitScopeEnd(stage, true); - shadergen.emitLineBreak(stage); - - // Begin function signature. - shadergen.emitLine(_returnStruct + " " + _functionName, stage, false); - } - else - { - // Begin function signature. - const ShaderGraphOutputSocket* outputSocket = _rootGraph->getOutputSocket(); - const string& outputType = syntax.getTypeName(outputSocket->getType()); - shadergen.emitLine(outputType + " " + _functionName, stage, false); - } - - shadergen.emitScopeBegin(stage, Syntax::PARENTHESES); - - const string uniformPrefix = syntax.getUniformQualifier() + " "; - - // Emit all inputs - int count = int(_rootGraph->numInputSockets()); - for (ShaderGraphInputSocket* input : _rootGraph->getInputSockets()) - { - const string& qualifier = input->isUniform() || input->getType() == Type::FILENAME ? uniformPrefix : EMPTY_STRING; - const string& type = syntax.getTypeName(input->getType()); - const string value = (input->getValue() ? - syntax.getValue(input->getType(), *input->getValue()) : - syntax.getDefaultValue(input->getType())); - - const string& delim = --count > 0 ? Syntax::COMMA : EMPTY_STRING; - shadergen.emitLine(qualifier + type + " " + input->getVariable() + " = " + value + delim, stage, false); - } - - // End function signature. - shadergen.emitScopeEnd(stage); + // Emit function signature. + emitFunctionSignature(node, context, stage); // Special case for material expresions. if (isMaterialExpr) @@ -169,4 +128,66 @@ void CompoundNodeMdl::emitFunctionCall(const ShaderNode& node, GenContext& conte } } +void CompoundNodeMdl::emitFunctionSignature(const ShaderNode&, GenContext& context, ShaderStage& stage) const +{ + const ShaderGenerator& shadergen = context.getShaderGenerator(); + const Syntax& syntax = shadergen.getSyntax(); + + if (!_returnStruct.empty()) + { + // Define the output struct. + shadergen.emitLine("struct " + _returnStruct, stage, false); + shadergen.emitScopeBegin(stage, Syntax::CURLY_BRACKETS); + for (const ShaderGraphOutputSocket* output : _rootGraph->getOutputSockets()) + { + shadergen.emitLine(syntax.getTypeName(output->getType()) + " mxp_" + output->getName(), stage); + } + shadergen.emitScopeEnd(stage, true); + shadergen.emitLineBreak(stage); + + // Begin function signature. + shadergen.emitLine(_returnStruct + " " + _functionName, stage, false); + } + else + { + // Begin function signature. + const ShaderGraphOutputSocket* outputSocket = _rootGraph->getOutputSocket(); + const string& outputType = syntax.getTypeName(outputSocket->getType()); + shadergen.emitLine(outputType + " " + _functionName, stage, false); + } + + shadergen.emitScopeBegin(stage, Syntax::PARENTHESES); + + const string uniformPrefix = syntax.getUniformQualifier() + " "; + + // Emit all inputs + int count = int(_rootGraph->numInputSockets()); + for (ShaderGraphInputSocket* input : _rootGraph->getInputSockets()) + { + const string& qualifier = input->isUniform() || input->getType() == Type::FILENAME ? uniformPrefix : EMPTY_STRING; + const string& type = syntax.getTypeName(input->getType()); + + string value = input->getValue() ? syntax.getValue(input->getType(), *input->getValue(), true) : EMPTY_STRING; + const string& geomprop = input->getGeomProp(); + if (!geomprop.empty()) + { + auto it = MdlShaderGenerator::GEOMPROP_DEFINITIONS.find(geomprop); + if (it != MdlShaderGenerator::GEOMPROP_DEFINITIONS.end()) + { + value = it->second; + } + } + if (value.empty()) + { + value = syntax.getDefaultValue(input->getType(), true); + } + + const string& delim = --count > 0 ? Syntax::COMMA : EMPTY_STRING; + shadergen.emitLine(qualifier + type + " " + input->getVariable() + " = " + value + delim, stage, false); + } + + // End function signature. + shadergen.emitScopeEnd(stage); +} + MATERIALX_NAMESPACE_END diff --git a/source/MaterialXGenMdl/Nodes/CompoundNodeMdl.h b/source/MaterialXGenMdl/Nodes/CompoundNodeMdl.h index 1699f0fbfe..4ac9a5add3 100644 --- a/source/MaterialXGenMdl/Nodes/CompoundNodeMdl.h +++ b/source/MaterialXGenMdl/Nodes/CompoundNodeMdl.h @@ -23,6 +23,8 @@ class MX_GENMDL_API CompoundNodeMdl : public CompoundNode void emitFunctionCall(const ShaderNode& node, GenContext& context, ShaderStage& stage) const override; protected: + void emitFunctionSignature(const ShaderNode& node, GenContext& context, ShaderStage& stage) const; + string _returnStruct; }; diff --git a/source/MaterialXGenShader/ShaderGraph.cpp b/source/MaterialXGenShader/ShaderGraph.cpp index 035f70540e..5b7571af31 100644 --- a/source/MaterialXGenShader/ShaderGraph.cpp +++ b/source/MaterialXGenShader/ShaderGraph.cpp @@ -56,6 +56,11 @@ void ShaderGraph::addInputSockets(const InterfaceElement& elem, GenContext& cont { inputSocket->setUniform(); } + GeomPropDefPtr geomprop = input->getDefaultGeomProp(); + if (geomprop) + { + inputSocket->setGeomProp(geomprop->getName()); + } } }