Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve performance of graph traversal #2023

10 changes: 8 additions & 2 deletions source/MaterialXCore/Traversal.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ GraphIterator& GraphIterator::operator++()
// Traverse to the first upstream edge of this element.
_stack.emplace_back(_upstreamElem, 0);
Edge nextEdge = _upstreamElem->getUpstreamEdge(0);
if (nextEdge && nextEdge.getUpstreamElement())
if (nextEdge && nextEdge.getUpstreamElement() && !skipOrMarkAsVisited(nextEdge))
{
extendPathUpstream(nextEdge.getUpstreamElement(), nextEdge.getConnectingElement());
return *this;
Expand All @@ -140,7 +140,7 @@ GraphIterator& GraphIterator::operator++()
if (parentFrame.second + 1 < parentFrame.first->getUpstreamEdgeCount())
{
Edge nextEdge = parentFrame.first->getUpstreamEdge(++parentFrame.second);
if (nextEdge && nextEdge.getUpstreamElement())
if (nextEdge && nextEdge.getUpstreamElement() && !skipOrMarkAsVisited(nextEdge))
{
extendPathUpstream(nextEdge.getUpstreamElement(), nextEdge.getConnectingElement());
return *this;
Expand Down Expand Up @@ -177,6 +177,12 @@ void GraphIterator::returnPathDownstream(ElementPtr upstreamElem)
_connectingElem = ElementPtr();
}

bool GraphIterator::skipOrMarkAsVisited(const Edge& edge)
{
auto [it, inserted] = _visitedEdges.emplace(edge);
return !inserted;
}

//
// InheritanceIterator methods
//
Expand Down
2 changes: 2 additions & 0 deletions source/MaterialXCore/Traversal.h
Original file line number Diff line number Diff line change
Expand Up @@ -316,12 +316,14 @@ class MX_CORE_API GraphIterator
private:
void extendPathUpstream(ElementPtr upstreamElem, ElementPtr connectingElem);
void returnPathDownstream(ElementPtr upstreamElem);
bool skipOrMarkAsVisited(const Edge&);

private:
ElementPtr _upstreamElem;
ElementPtr _connectingElem;
ElementSet _pathElems;
vector<StackFrame> _stack;
std::set<Edge> _visitedEdges;
bool _prune;
size_t _holdCount;
};
Expand Down
12 changes: 9 additions & 3 deletions source/MaterialXGenShader/ShaderGraph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -916,7 +916,7 @@ void ShaderGraph::optimize()
ShaderOutput* upstreamPort = outputSocket->getConnection();
if (upstreamPort && upstreamPort->getNode() != this)
{
for (ShaderGraphEdge edge : ShaderGraph::traverseUpstream(upstreamPort))
for (ShaderGraphEdge edge : traverseUpstream(upstreamPort))
{
ShaderNode* node = edge.upstream->getNode();
if (usedNodesSet.count(node) == 0)
Expand Down Expand Up @@ -1206,7 +1206,7 @@ ShaderGraphEdgeIterator& ShaderGraphEdgeIterator::operator++()
ShaderInput* input = _upstream->getNode()->getInput(0);
ShaderOutput* output = input->getConnection();

if (output && !output->getNode()->isAGraph())
if (output && !output->getNode()->isAGraph() && !skipOrMarkAsVisited({ output, input }))
{
extendPathUpstream(output, input);
return *this;
Expand Down Expand Up @@ -1234,7 +1234,7 @@ ShaderGraphEdgeIterator& ShaderGraphEdgeIterator::operator++()
ShaderInput* input = parentFrame.first->getNode()->getInput(++parentFrame.second);
ShaderOutput* output = input->getConnection();

if (output && !output->getNode()->isAGraph())
if (output && !output->getNode()->isAGraph() && !skipOrMarkAsVisited({ output, input }))
{
extendPathUpstream(output, input);
return *this;
Expand Down Expand Up @@ -1275,4 +1275,10 @@ void ShaderGraphEdgeIterator::returnPathDownstream(ShaderOutput* upstream)
_downstream = nullptr;
}

bool ShaderGraphEdgeIterator::skipOrMarkAsVisited(ShaderGraphEdge edge)
{
auto [it, inserted] = _visitedEdges.emplace(edge);
return !inserted;
}

MATERIALX_NAMESPACE_END
18 changes: 18 additions & 0 deletions source/MaterialXGenShader/ShaderGraph.h
Original file line number Diff line number Diff line change
Expand Up @@ -209,6 +209,22 @@ class MX_GENSHADER_API ShaderGraphEdge
downstream(down)
{
}

bool operator==(const ShaderGraphEdge& rhs) const
{
return upstream == rhs.upstream && downstream == rhs.downstream;
}

bool operator!=(const ShaderGraphEdge& rhs) const
{
return !(*this == rhs);
}

bool operator<(const ShaderGraphEdge& rhs) const
{
return std::tie(upstream, downstream) < std::tie(rhs.upstream, rhs.downstream);
}

ShaderOutput* upstream;
ShaderInput* downstream;
};
Expand Down Expand Up @@ -254,12 +270,14 @@ class MX_GENSHADER_API ShaderGraphEdgeIterator
private:
void extendPathUpstream(ShaderOutput* upstream, ShaderInput* downstream);
void returnPathDownstream(ShaderOutput* upstream);
bool skipOrMarkAsVisited(ShaderGraphEdge);

ShaderOutput* _upstream;
ShaderInput* _downstream;
using StackFrame = std::pair<ShaderOutput*, size_t>;
std::vector<StackFrame> _stack;
std::set<ShaderOutput*> _path;
std::set<ShaderGraphEdge> _visitedEdges;
};

MATERIALX_NAMESPACE_END
Expand Down