From 14a2e4cdcb64af8452feb163455446143b42ccf2 Mon Sep 17 00:00:00 2001 From: koray kavukcuoglu Date: Wed, 3 Apr 2013 22:07:50 +0100 Subject: [PATCH] add 'id' field to Node correct reverse/clone/topsort functions --- Node.lua | 1 + init.lua | 65 ++++++++++++++++++++++++++++++++++++++++--------------- utils.lua | 4 ++-- 3 files changed, 50 insertions(+), 20 deletions(-) diff --git a/Node.lua b/Node.lua index 84c52a5..459f9a8 100644 --- a/Node.lua +++ b/Node.lua @@ -13,6 +13,7 @@ local Node = torch.class('graph.Node') function Node:__init(d,p) self.data = d + self.id = 0 self.children = {} self.visited = false self.marked = false diff --git a/init.lua b/init.lua index 9bc15c6..737bca9 100644 --- a/init.lua +++ b/init.lua @@ -43,6 +43,8 @@ function Graph:add(edge) end -- add the edge to the node for parsing in nodes edge.from:add(edge.to) + edge.from.id = self.nodes[edge.from] + edge.to.id = self.nodes[edge.to] else for i,e in ipairs(edge) do self:add(e) @@ -55,15 +57,18 @@ end -- Note that primitive data types like numbers can not be shared function Graph:clone() local clone = graph.Graph() + local nodes = {} + for i,n in ipairs(self.nodes) do + table.insert(nodes,n.new(n.data)) + end for i,e in ipairs(self.edges) do - local from = graph.Node(e.from.data) - local to = graph.Node(e.to.data) - clone:add(graph.Edge(from,to)) + local from = nodes[self.nodes[e.from]] + local to = nodes[self.nodes[e.to]] + clone:add(e.new(from,to)) end return clone end - -- It returns a new graph where the edges are reversed. -- The nodes share the data. Note that primitive data types can -- not be shared. @@ -75,25 +80,23 @@ function Graph:reverse() mapnodes[e.to] = mapnodes[e.to] or e.to.new(e.to.data) local from = mapnodes[e.from] local to = mapnodes[e.to] - rg:add(graph.Edge(to,from)) + rg:add(e.new(to,from)) end - return rg + return rg,mapnodes end --[[ Topological Sort + ** This is not finished. OK for graphs with single root. ]]-- function Graph:topsort() - -- first clone the graph - local g = self:clone() - local nodes = g.nodes - local edges = g.edges - for i,node in ipairs(nodes) do - node.children = {} - end -- reverse the graph - local rg = self:reverse() + local rg,map = self:reverse() + local rmap = {} + for k,v in pairs(map) do + rmap[v] = k + end -- work on the sorted graph local sortednodes = {} @@ -105,7 +108,7 @@ function Graph:topsort() -- run for i,root in ipairs(rootnodes) do - root:dfs(function(node) table.insert(sortednodes,node) end) + root:dfs(function(node) table.insert(sortednodes,rmap[node]) end) end if #sortednodes ~= #self.nodes then @@ -137,16 +140,42 @@ function Graph:roots() return roots end -function Graph:todot() +-- find root nodes +function Graph:leaves() + local edges = self.edges + local leafnodes = {} + for i,edge in ipairs(edges) do + --table.insert(rootnodes,edge.from) + if not leafnodes[edge.to] then + leafnodes[edge.to] = #leafnodes+1 + end + end + for i,edge in ipairs(edges) do + if leafnodes[edge.from] then + leafnodes[edge.from] = nil + end + end + local leaves = {} + for leaf,i in pairs(leafnodes) do + table.insert(leaves, leaf) + end + table.sort(leaves,function(a,b) return self.nodes[a] < self.nodes[b] end ) + return leaves +end + +function Graph:todot(title) local nodes = self.nodes local edges = self.edges str = {} table.insert(str,'digraph G {\n') + if title then + table.insert(str,'labelloc="t";\nlabel="' .. title .. '";\n') + end table.insert(str,'node [shape = oval]; ') local nodelabels = {} for i,node in ipairs(nodes) do - local l = '"' .. (node:label() or 'n' .. i) .. '"' - nodelabels[node] = 'n' .. i + local l = '"' .. ( 'Node' .. node.id .. '\\n' .. node:label() ) .. '"' + nodelabels[node] = 'n' .. node.id table.insert(str, '\n' .. nodelabels[node] .. '[label=' .. l .. '];') end table.insert(str,'\n') diff --git a/utils.lua b/utils.lua index f4ba52b..22600c8 100644 --- a/utils.lua +++ b/utils.lua @@ -1,8 +1,8 @@ require 'qtsvg' -function graph.dot(g,fname) - local gv = g:todot() +function graph.dot(g,title,fname) + local gv = g:todot(title) local fngv = (fname or os.tmpname()) .. '.dot' local fgv = io.open(fngv,'w') fgv:write(gv)