Skip to content

Commit

Permalink
Kronecker-constrained AMEn prototype under implementation
Browse files Browse the repository at this point in the history
  • Loading branch information
szdan97 committed Oct 15, 2020
1 parent cd94df7 commit 62629cd
Show file tree
Hide file tree
Showing 23 changed files with 1,966 additions and 370 deletions.
4 changes: 0 additions & 4 deletions .idea/encodings.xml

This file was deleted.

36 changes: 0 additions & 36 deletions .idea/gradle.xml

This file was deleted.

11 changes: 0 additions & 11 deletions .idea/misc.xml

This file was deleted.

7 changes: 0 additions & 7 deletions .idea/vcs.xml

This file was deleted.

2 changes: 1 addition & 1 deletion delta
Submodule delta updated from e82c71 to dc1c42
53 changes: 34 additions & 19 deletions src/main/kotlin/MDDExtensions/MDDtoTT.kt
Original file line number Diff line number Diff line change
Expand Up @@ -7,28 +7,43 @@ import solver.TensorTrain
fun MddHandle.toTensorTrain(): TensorTrain {
val cores = arrayListOf<CoreTensor>()

var nextNodes = Array(this.variableHandle.variable.get().domainSize) {this[it]}.toSet().toList()
val firstCore = CoreTensor(this.variableHandle.variable.get().domainSize, 1, nextNodes.size)
for (i in 0 until this.size()) {
firstCore[i][nextNodes.indexOf(this[i])] = 1.0
}
cores.add(firstCore)
val levels = arrayListOf<HashSet<MddHandle>>()
val domainSizes = arrayListOf<Int>()

while(!nextNodes[0].isTerminal) {
val currNodes = nextNodes
nextNodes = currNodes.flatMap { node -> Array(node.size()) { node[it] }.toList() }.toSet().toList()
if(nextNodes[0].isTerminal) nextNodes = nextNodes.filter { it.data == true }
val newCore = CoreTensor(currNodes[0].variableHandle.variable.get().domainSize, currNodes.size, nextNodes.size)
for ((idx, currNode) in currNodes.withIndex()) {
for (edge in 0 until currNode.size()) {
val targetIdx = nextNodes.indexOf(currNode[edge])
if(targetIdx != -1)
newCore[edge][idx, targetIdx] = 1.0
var currNodes = hashSetOf(this)
levels.add(currNodes)
while (!currNodes.any(MddHandle::isTerminal)) {
val nexts = hashSetOf<MddHandle>()
val domainSize = currNodes.first().variableHandle.variable.get().domainSize
domainSizes.add(domainSize)
for(n in currNodes) {
for(i in 0 until domainSize) {
nexts.add(n[i])
}
}
cores.add(newCore)
levels.add(nexts)
currNodes = nexts
}
levels.last().removeIf { it.isTerminalZero }

for(l in levels.size-2 downTo 0) {
levels[l].removeIf { n ->
(0 until domainSizes[l]).none {levels[l+1].contains(n[it])}
}
}

val tensorTrain = TensorTrain(cores)
return tensorTrain
val levelLists = levels.map { it.toList() }
for (l in 0 until levelLists.size-1) {
val curr = levelLists[l]
val next = levelLists[l+1]
val core = CoreTensor( domainSizes[l], curr.size, next.size)
for(i in 0 until domainSizes[l]) {
for((idx, n) in curr.withIndex()) {
val target = next.indexOf(n[i])
if(target != -1) core[i][idx, target] = 1.0
}
}
cores.add(core)
}
return TensorTrain(cores)
}
132 changes: 54 additions & 78 deletions src/main/kotlin/Main.kt
Original file line number Diff line number Diff line change
@@ -1,86 +1,51 @@

import MDDExtensions.GSCompaction
import MDDExtensions.toTensorTrain
import benchmark.generateKanban
import benchmark.generateLongKanban
import faulttree.BasicEvent
import faulttree.FaultTree
import faulttree.FaultTreeNode
import gspn.*
import gspn.rateexpressions.Constant
import hu.bme.mit.delta.mdd.LatticeDefinition
import gspn.PetriNet
import gspn.Place
import gspn.Transition
import gspn.arc
import hu.bme.mit.delta.mdd.MddBuilder
import hu.bme.mit.delta.mdd.MddVariable
import hu.bme.mit.delta.mdd.MddVariableDescriptor
import hu.bme.mit.delta.mdd.MddHandle
import org.ejml.simple.SimpleMatrix
import solver.ALSSolve
import solver.TTSquareMatrix
import solver.TTVector
import solver.solvers.AMEnALSSolve
import java.util.*
import kotlin.math.abs

fun main(args: Array<String>) {

val rand = Random()
val N = generateLongKanban(4, 2, {rand.nextDouble()*9.0+1.0})
N.run {
computeCapacities()
val variableOrder = GSPN.mddFactory.createMddVariableOrder(LatticeDefinition.forSets())
var last: MddVariable? = null
for (place in places) {
last =
if (last == null) variableOrder.createOnTop(MddVariableDescriptor.create(place.name, place.capacity + 1))
else variableOrder.createBelow(last, MddVariableDescriptor.create(place.name, place.capacity + 1))
val N = PetriNet {
val A = p("A", 1)
val B = p("B", 0)
timed("AtoB", 30.0) {
input(arc(A, 1))
out(arc(B, 1))
}
timed("BtoA", 10.0) {
input(arc(B, 1))
out(arc(A, 1))
}
val capacities = places.map(Place::capacity)
val reachableMdd = stateSpace.reachableStatesRoot().toDelta(variableOrder)
var p0mdd = stateSpace.calculateTangible().toDelta(variableOrder)
val p0mask: TTSquareMatrix = TTSquareMatrix.diag(TTVector(p0mdd.toTensorTrain()))
val R0 = transitions.filterIsInstance<ExponentialTransition>().map { it.toTT(variableOrder, places) }.reduce(TTSquareMatrix::plus)
val threshold = 1e-8
val p0maskRounded = p0mask.copy()
p0maskRounded.tt.roundAbsolute(threshold/R0.frobenius())
val R0Rounded = R0.copy()
R0Rounded.tt.roundAbsolute(threshold/p0mask.frobenius())
val res: TTSquareMatrix = p0mask*R0
val resCopy = res.copy()
var start = System.currentTimeMillis()
resCopy.tt.roundAbsolute(1e-12)
val fullSVDTime = System.currentTimeMillis()-start
val resOtherCopy = res.copy()
start = System.currentTimeMillis()
resOtherCopy.tt.roundAbsolute(1e-12, true)
val iterSVDTime = System.currentTimeMillis()-start
val i = 0 // NOP
}
return

val l1 = 1.0
val l2 = 0.1
val l3 = 2.0
val l4 = 0.2
val p = arrayListOf(
Place("p0", 2, 1),
Place("p1", 2, 1),
Place("p2", 1, 0)
)
val t = arrayListOf<Transition>(
ExponentialTransition("t1", arrayListOf(Arc.ConstantArc(p[0],1)), arrayListOf(Arc.ConstantArc(p[1], 1)), arrayListOf(Arc.ConstantArc(p[2], 1)), Constant(4.0)),
ExponentialTransition("t2", arrayListOf(Arc.ConstantArc(p[1],1)), arrayListOf(Arc.ConstantArc(p[0], 1)), arrayListOf(), Constant(1.0)),
ExponentialTransition("t3", arrayListOf(), arrayListOf(Arc.ConstantArc(p[2], 1)), arrayListOf(Arc.ConstantArc(p[2], 1)), Constant(2.0)),
ExponentialTransition("t4", arrayListOf(Arc.ConstantArc(p[2],1)), arrayListOf(), arrayListOf(), Constant(3.0))
)
val g = generateKanban(1, {1.0})
val R1 = g.getRateMatrix()
val R2 = g.getRateMatrix(useCompaction = true)
// g.computeCapacities()
// val mdd = g.stateSpace.reachableStatesRoot()
// val order = JavaMddFactory.getDefault().createMddVariableOrder(LatticeDefinition.forSets())
// for (place in g.places.reversed()) {
// order.createOnTop(MddVariableDescriptor.create(place.name, place.capacity+1))
// }
// val deltamdd = mdd.toDelta(order, 0)
return
val sparse = true
// val sparse = false
val ss = if(!sparse) N.getSteadyStateDistribution(true, 0.0) { A ->
AMEnALSSolve(
A = A,
y = TTVector.zeros(A.modes),
residualThreshold = 1e-7,
maxSweeps = 50,
enrichmentRank = 2,
normalize = true,
verbose = true,
useApproxResidualForStopping = false
)
} else N.getSteadyStateDistributionSparse(true, false)
ss.printElements()
}

private fun compactionTest() {
Expand Down Expand Up @@ -109,11 +74,11 @@ private fun compactionTest() {

fun faultTreeGrowthTest() {
val rand = Random(123)
var topNode: FaultTreeNode = BasicEvent("ev0", rand.nextDouble()*10.0)
var topNode: FaultTreeNode = BasicEvent("ev0", rand.nextDouble() * 10.0)
for (i in 1..40) {
println("Number of leaves: ${i+1}")
topNode = topNode and BasicEvent("ev$i", rand.nextDouble()*10.0)
if(i < 28) continue
println("Number of leaves: ${i + 1}")
topNode = topNode and BasicEvent("ev$i", rand.nextDouble() * 10.0)
if (i < 28) continue
val ft = FaultTree(topNode)
val A = ft.getModifiedGenerator()
A.tt.roundRelative(1e-30)
Expand All @@ -122,7 +87,7 @@ fun faultTreeGrowthTest() {
val ones = TTVector.ones(b.modes)
var x0 = TTVector.ones(b.modes)
for (j in 0 until r) {
x0 = x0+x0.hadamard(ones)
x0 = x0 + x0.hadamard(ones)
}
x0.divAssign(r.toDouble())
val relativeThreshold = 0.0001
Expand Down Expand Up @@ -150,16 +115,15 @@ fun generateSPN(nPlaces: Int, nTransitions: Int, capacities: Int, minRate: Doubl
fun report(A: TTSquareMatrix, b: TTVector, x: TTVector, threshold: Double) =
report(A::times, b, x, threshold)

fun report(linearMap: (TTVector)->TTVector, b: TTVector, x: TTVector, threshold: Double) {
fun report(linearMap: (TTVector) -> TTVector, b: TTVector, x: TTVector, threshold: Double) {
println("results:")
val resNorm = (b-linearMap(x)).norm()
println("residual norm: $resNorm ${if(resNorm<threshold) "<" else ">"} $threshold (threshold)")
println("relative residual norm: ${resNorm/b.norm()}")
val resNorm = (b - linearMap(x)).norm()
println("residual norm: $resNorm ${if (resNorm < threshold) "<" else ">"} $threshold (threshold)")
println("relative residual norm: ${resNorm / b.norm()}")
print("solution vector: ")
if(x.numElements < 100) {
if (x.numElements < 100) {
x.printElements(); println()
}
else
} else
println("First element: ${x[0]}")
println("TT ranks of the result: ${x.tt.cores.map { it.rows }}")
print("Non-nullness in absorbing states: ${x.tt.hadamard((TTVector.ones(x.modes) - b).tt).frobenius()}")
Expand All @@ -178,4 +142,16 @@ fun SimpleMatrix.roundZeros(threshold: Double = 1E-14) {
if (abs(this[i, j]) < threshold) this[i, j] = 0.0
}
}
}

fun <R> MddHandle.mapTuples(f: (List<Int>) -> R): List<R> {
fun mapTuplesHelper(prefix: List<Int>, node: MddHandle): List<R> {
if(node.isTerminalZero) return listOf()
if(node.isTerminal) return listOf(f(prefix))
val res = arrayListOf<R>()
for(i in 0 until this.variableHandle.variable.get().domainSize)
res.addAll(mapTuplesHelper(prefix + i, node[i]))
return res
}
return mapTuplesHelper(listOf(), this)
}
2 changes: 1 addition & 1 deletion src/main/kotlin/benchmark/examples.kt
Original file line number Diff line number Diff line change
Expand Up @@ -263,7 +263,7 @@ fun generateLongKanban(numLargeBlocks: Int, N: Int, getNextRate: () -> Double) =

val tsynch23_4 = immediate("tsyncs23_4") {
input(arc(pout[0]), arc(pout[1]), arc(p[2]))
out(arc(p[1]), arc(p[2]), arc(pm[2]))
out(arc(p[0]), arc(p[1]), arc(pm[2]))
}
}

Expand Down
3 changes: 2 additions & 1 deletion src/main/kotlin/cli/Calc.kt
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,8 @@ class Calc : CliktCommand(help =
b,
residualThreshold = threshold,
maxSweeps = momentArgs.sweeps ?: 0,
enrichmentRank = momentArgs.enrichmentRank ?: 1
enrichmentRank = momentArgs.enrichmentRank ?: 1,
useApproxResidualForStopping = false
)
}
else -> throw RuntimeException("Unknown solver")
Expand Down
Loading

0 comments on commit 62629cd

Please sign in to comment.