Skip to content

Commit

Permalink
Improvement to codegen of functional graphs in MDL (AcademySoftwareFo…
Browse files Browse the repository at this point in the history
…undation#1243)

This change list improves on codegen of functional graphs in MDL, adding code for default geomprops on the inputs of functions generated from graphs.
  • Loading branch information
niklasharrysson authored Feb 14, 2023
1 parent 444236f commit 0bfe3da
Show file tree
Hide file tree
Showing 6 changed files with 96 additions and 135 deletions.
28 changes: 14 additions & 14 deletions source/MaterialXGenMdl/MdlShaderGenerator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,20 +31,6 @@ MATERIALX_NAMESPACE_BEGIN
namespace
{

std::unordered_map<string, string> GEOMPROP_DEFINITIONS =
{
{ "Pobject", "base::transform_point(base::coordinate_internal, base::coordinate_object, state::position())" },
{ "Pworld", "base::transform_point(base::coordinate_internal, base::coordinate_world, state::position())" },
{ "Nobject", "base::transform_normal(base::coordinate_internal, base::coordinate_object, state::normal())" },
{ "Nworld", "base::transform_normal(base::coordinate_internal, base::coordinate_world, state::normal())" },
{ "Tobject", "base::transform_vector(base::coordinate_internal, base::coordinate_object, state::texture_tangent_u(0))" },
{ "Tworld", "base::transform_vector(base::coordinate_internal, base::coordinate_world, state::texture_tangent_u(0))" },
{ "Bobject", "base::transform_vector(base::coordinate_internal, base::coordinate_object, state::texture_tangent_v(0))" },
{ "Bworld", "base::transform_vector(base::coordinate_internal, base::coordinate_world, state::texture_tangent_v(0))" },
{ "UV0", "float2(state::texture_coordinate(0).x, state::texture_coordinate(0).y)" },
{ "Vworld", "state::direction()" }
};

const string MDL_VERSION = "1.6";

const vector<string> DEFAULT_IMPORTS =
Expand All @@ -67,6 +53,20 @@ const vector<string> DEFAULT_IMPORTS =

const string MdlShaderGenerator::TARGET = "genmdl";

const std::unordered_map<string, string> MdlShaderGenerator::GEOMPROP_DEFINITIONS =
{
{ "Pobject", "state::transform_point(state::coordinate_internal, state::coordinate_object, state::position())" },
{ "Pworld", "state::transform_point(state::coordinate_internal, state::coordinate_world, state::position())" },
{ "Nobject", "state::transform_normal(state::coordinate_internal, state::coordinate_object, state::normal())" },
{ "Nworld", "state::transform_normal(state::coordinate_internal, state::coordinate_world, state::normal())" },
{ "Tobject", "state::transform_vector(state::coordinate_internal, state::coordinate_object, state::texture_tangent_u(0))" },
{ "Tworld", "state::transform_vector(state::coordinate_internal, state::coordinate_world, state::texture_tangent_u(0))" },
{ "Bobject", "state::transform_vector(state::coordinate_internal, state::coordinate_object, state::texture_tangent_v(0))" },
{ "Bworld", "state::transform_vector(state::coordinate_internal, state::coordinate_world, state::texture_tangent_v(0))" },
{ "UV0", "float2(state::texture_coordinate(0).x, state::texture_coordinate(0).y)" },
{ "Vworld", "state::direction()" }
};

//
// MdlShaderGenerator methods
//
Expand Down
3 changes: 3 additions & 0 deletions source/MaterialXGenMdl/MdlShaderGenerator.h
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,9 @@ class MX_GENMDL_API MdlShaderGenerator : public ShaderGenerator
/// Unique identifier for this generator target
static const string TARGET;

/// Map of code snippets for geomprops in MDL.
static const std::unordered_map<string, string> GEOMPROP_DEFINITIONS;

protected:
// Create and initialize a new MDL shader for shader generation.
ShaderPtr createShader(const string& name, ElementPtr element, GenContext& context) const;
Expand Down
82 changes: 6 additions & 76 deletions source/MaterialXGenMdl/Nodes/ClosureCompoundNodeMdl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,62 +24,20 @@ void ClosureCompoundNodeMdl::addClassification(ShaderNode& node) const
node.addClassification(_rootGraph->getClassification());
}

void ClosureCompoundNodeMdl::emitFunctionDefinition(const ShaderNode&, GenContext& context, ShaderStage& stage) const
void ClosureCompoundNodeMdl::emitFunctionDefinition(const ShaderNode& node, GenContext& context, ShaderStage& stage) const
{
DEFINE_SHADER_STAGE(stage, Stage::PIXEL)
{
const ShaderGenerator& shadergen = context.getShaderGenerator();
const Syntax& syntax = shadergen.getSyntax();

const bool isMaterialExpr = (_rootGraph->hasClassification(ShaderNode::Classification::CLOSURE) ||
_rootGraph->hasClassification(ShaderNode::Classification::SHADER));

// Emit functions for all child nodes
shadergen.emitFunctionDefinitions(*_rootGraph, context, stage);

if (!_returnStruct.empty())
{
// Define the output struct.
shadergen.emitLine("struct " + _returnStruct, stage, false);
shadergen.emitScopeBegin(stage, Syntax::CURLY_BRACKETS);
for (const ShaderGraphOutputSocket* output : _rootGraph->getOutputSockets())
{
shadergen.emitLine(syntax.getTypeName(output->getType()) + " mxp_" + output->getName(), stage);
}
shadergen.emitScopeEnd(stage, true);
shadergen.emitLineBreak(stage);

// Begin function signature.
shadergen.emitLine(_returnStruct + " " + _functionName, stage, false);
}
else
{
// Begin function signature.
const ShaderGraphOutputSocket* outputSocket = _rootGraph->getOutputSocket();
const string& outputType = syntax.getTypeName(outputSocket->getType());
shadergen.emitLine(outputType + " " + _functionName, stage, false);
}

shadergen.emitScopeBegin(stage, Syntax::PARENTHESES);

const string uniformPrefix = syntax.getUniformQualifier() + " ";

// Emit all inputs
int count = int(_rootGraph->numInputSockets());
for (ShaderGraphInputSocket* input : _rootGraph->getInputSockets())
{
const string& qualifier = input->isUniform() || input->getType() == Type::FILENAME ? uniformPrefix : EMPTY_STRING;
const string& type = syntax.getTypeName(input->getType());
const string value = (input->getValue() ?
syntax.getValue(input->getType(), *input->getValue()) :
syntax.getDefaultValue(input->getType()));

const string& delim = --count > 0 ? Syntax::COMMA : EMPTY_STRING;
shadergen.emitLine(qualifier + type + " " + input->getVariable() + " = " + value + delim, stage, false);
}

// End function signature.
shadergen.emitScopeEnd(stage);
// Emit function signature.
emitFunctionSignature(node, context, stage);

// Special case for material expresions.
if (isMaterialExpr)
Expand Down Expand Up @@ -149,39 +107,11 @@ void ClosureCompoundNodeMdl::emitFunctionCall(const ShaderNode& node, GenContext
{
const ShaderGenerator& shadergen = context.getShaderGenerator();

// Emit calls for any closure dependencies upstream from this node.
// First emit calls for any closure dependencies upstream from this node.
shadergen.emitDependentFunctionCalls(node, context, stage, ShaderNode::Classification::CLOSURE);

// Begin function call.
if (!_returnStruct.empty())
{
// Emit the struct multioutput.
const string resultVariableName = node.getName() + "_result";
shadergen.emitLineBegin(stage);
shadergen.emitString(_returnStruct + " " + resultVariableName + " = ", stage);
}
else
{
// Emit the single output.
shadergen.emitLineBegin(stage);
shadergen.emitOutput(node.getOutput(0), true, false, context, stage);
shadergen.emitString(" = ", stage);
}

shadergen.emitString(_functionName + "(", stage);

// Emit inputs.
string delim = "";
for (ShaderInput* input : node.getInputs())
{
shadergen.emitString(delim, stage);
shadergen.emitInput(input, context, stage);
delim = ", ";
}

// End function call
shadergen.emitString(")", stage);
shadergen.emitLineEnd(stage);
// Then emit this nodes function call.
CompoundNodeMdl::emitFunctionCall(node, context, stage);
}
}

Expand Down
111 changes: 66 additions & 45 deletions source/MaterialXGenMdl/Nodes/CompoundNodeMdl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
//

#include <MaterialXGenMdl/Nodes/CompoundNodeMdl.h>
#include <MaterialXGenMdl/MdlShaderGenerator.h>

#include <MaterialXGenShader/HwShaderGenerator.h>
#include <MaterialXGenShader/ShaderGenerator.h>
Expand All @@ -29,62 +30,20 @@ void CompoundNodeMdl::initialize(const InterfaceElement& element, GenContext& co
}
}

void CompoundNodeMdl::emitFunctionDefinition(const ShaderNode&, GenContext& context, ShaderStage& stage) const
void CompoundNodeMdl::emitFunctionDefinition(const ShaderNode& node, GenContext& context, ShaderStage& stage) const
{
DEFINE_SHADER_STAGE(stage, Stage::PIXEL)
{
const ShaderGenerator& shadergen = context.getShaderGenerator();
const Syntax& syntax = shadergen.getSyntax();

const bool isMaterialExpr = (_rootGraph->hasClassification(ShaderNode::Classification::CLOSURE) ||
_rootGraph->hasClassification(ShaderNode::Classification::SHADER));

// Emit functions for all child nodes
shadergen.emitFunctionDefinitions(*_rootGraph, context, stage);

if (!_returnStruct.empty())
{
// Define the output struct.
shadergen.emitLine("struct " + _returnStruct, stage, false);
shadergen.emitScopeBegin(stage, Syntax::CURLY_BRACKETS);
for (const ShaderGraphOutputSocket* output : _rootGraph->getOutputSockets())
{
shadergen.emitLine(syntax.getTypeName(output->getType()) + " mxp_" + output->getName(), stage);
}
shadergen.emitScopeEnd(stage, true);
shadergen.emitLineBreak(stage);

// Begin function signature.
shadergen.emitLine(_returnStruct + " " + _functionName, stage, false);
}
else
{
// Begin function signature.
const ShaderGraphOutputSocket* outputSocket = _rootGraph->getOutputSocket();
const string& outputType = syntax.getTypeName(outputSocket->getType());
shadergen.emitLine(outputType + " " + _functionName, stage, false);
}

shadergen.emitScopeBegin(stage, Syntax::PARENTHESES);

const string uniformPrefix = syntax.getUniformQualifier() + " ";

// Emit all inputs
int count = int(_rootGraph->numInputSockets());
for (ShaderGraphInputSocket* input : _rootGraph->getInputSockets())
{
const string& qualifier = input->isUniform() || input->getType() == Type::FILENAME ? uniformPrefix : EMPTY_STRING;
const string& type = syntax.getTypeName(input->getType());
const string value = (input->getValue() ?
syntax.getValue(input->getType(), *input->getValue()) :
syntax.getDefaultValue(input->getType()));

const string& delim = --count > 0 ? Syntax::COMMA : EMPTY_STRING;
shadergen.emitLine(qualifier + type + " " + input->getVariable() + " = " + value + delim, stage, false);
}

// End function signature.
shadergen.emitScopeEnd(stage);
// Emit function signature.
emitFunctionSignature(node, context, stage);

// Special case for material expresions.
if (isMaterialExpr)
Expand Down Expand Up @@ -169,4 +128,66 @@ void CompoundNodeMdl::emitFunctionCall(const ShaderNode& node, GenContext& conte
}
}

void CompoundNodeMdl::emitFunctionSignature(const ShaderNode&, GenContext& context, ShaderStage& stage) const
{
const ShaderGenerator& shadergen = context.getShaderGenerator();
const Syntax& syntax = shadergen.getSyntax();

if (!_returnStruct.empty())
{
// Define the output struct.
shadergen.emitLine("struct " + _returnStruct, stage, false);
shadergen.emitScopeBegin(stage, Syntax::CURLY_BRACKETS);
for (const ShaderGraphOutputSocket* output : _rootGraph->getOutputSockets())
{
shadergen.emitLine(syntax.getTypeName(output->getType()) + " mxp_" + output->getName(), stage);
}
shadergen.emitScopeEnd(stage, true);
shadergen.emitLineBreak(stage);

// Begin function signature.
shadergen.emitLine(_returnStruct + " " + _functionName, stage, false);
}
else
{
// Begin function signature.
const ShaderGraphOutputSocket* outputSocket = _rootGraph->getOutputSocket();
const string& outputType = syntax.getTypeName(outputSocket->getType());
shadergen.emitLine(outputType + " " + _functionName, stage, false);
}

shadergen.emitScopeBegin(stage, Syntax::PARENTHESES);

const string uniformPrefix = syntax.getUniformQualifier() + " ";

// Emit all inputs
int count = int(_rootGraph->numInputSockets());
for (ShaderGraphInputSocket* input : _rootGraph->getInputSockets())
{
const string& qualifier = input->isUniform() || input->getType() == Type::FILENAME ? uniformPrefix : EMPTY_STRING;
const string& type = syntax.getTypeName(input->getType());

string value = input->getValue() ? syntax.getValue(input->getType(), *input->getValue(), true) : EMPTY_STRING;
const string& geomprop = input->getGeomProp();
if (!geomprop.empty())
{
auto it = MdlShaderGenerator::GEOMPROP_DEFINITIONS.find(geomprop);
if (it != MdlShaderGenerator::GEOMPROP_DEFINITIONS.end())
{
value = it->second;
}
}
if (value.empty())
{
value = syntax.getDefaultValue(input->getType(), true);
}

const string& delim = --count > 0 ? Syntax::COMMA : EMPTY_STRING;
shadergen.emitLine(qualifier + type + " " + input->getVariable() + " = " + value + delim, stage, false);
}

// End function signature.
shadergen.emitScopeEnd(stage);
}

MATERIALX_NAMESPACE_END
2 changes: 2 additions & 0 deletions source/MaterialXGenMdl/Nodes/CompoundNodeMdl.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@ class MX_GENMDL_API CompoundNodeMdl : public CompoundNode
void emitFunctionCall(const ShaderNode& node, GenContext& context, ShaderStage& stage) const override;

protected:
void emitFunctionSignature(const ShaderNode& node, GenContext& context, ShaderStage& stage) const;

string _returnStruct;
};

Expand Down
5 changes: 5 additions & 0 deletions source/MaterialXGenShader/ShaderGraph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,11 @@ void ShaderGraph::addInputSockets(const InterfaceElement& elem, GenContext& cont
{
inputSocket->setUniform();
}
GeomPropDefPtr geomprop = input->getDefaultGeomProp();
if (geomprop)
{
inputSocket->setGeomProp(geomprop->getName());
}
}
}

Expand Down

0 comments on commit 0bfe3da

Please sign in to comment.