Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Relu emu slow #35

Merged
merged 7 commits into from
Jan 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions .vscode/launch.json
Original file line number Diff line number Diff line change
Expand Up @@ -177,12 +177,12 @@
"mode": "debug",
"program": "${workspaceFolder}/samples/kmeans",
"args": [
"-timing",
// "-timing",
"-points=1024",
"-features=32",
"-clusters=5",
"-max-iter=5",
"-report-all",
"-max-iter=4",
// "-report-all",
],
},
{
Expand Down
23 changes: 21 additions & 2 deletions emu/computeunit.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,8 @@ type ComputeUnit struct {
GlobalMemStorage *mem.Storage

ToDispatcher sim.Port

finishedMapWGReqs []string
}

// ControlPort returns the port that can receive controlling messages from the
Expand Down Expand Up @@ -366,15 +368,32 @@ func (cu *ComputeUnit) resolveBarrier(wg *kernels.WorkGroup) {

func (cu *ComputeUnit) handleWGCompleteEvent(evt *WGCompleteEvent) error {
delete(cu.wfs, evt.Req.WorkGroup)
found := false
for _, r := range cu.finishedMapWGReqs {
if r == evt.Req.ID {
found = true
break
}
}
if !found {
cu.finishedMapWGReqs = append(cu.finishedMapWGReqs, evt.Req.ID)
}

if len(cu.wfs) != 0 {
return nil
}

req := protocol.WGCompletionMsgBuilder{}.
WithRspTo(evt.Req.ID).
WithSrc(cu.ToDispatcher).
WithDst(evt.Req.Src).
WithSendTime(evt.Time()).
WithRspTo(cu.finishedMapWGReqs).
Build()

err := cu.ToDispatcher.Send(req)
if err != nil {
if err == nil {
cu.finishedMapWGReqs = nil
} else {
newEvent := NewWGCompleteEvent(cu.Freq.NextTick(evt.Time()),
cu, evt.Req)
cu.Engine.Schedule(newEvent)
Expand Down
6 changes: 3 additions & 3 deletions protocol/cuprotocol.go
Original file line number Diff line number Diff line change
Expand Up @@ -275,7 +275,7 @@ func (b MapWGReqBuilder) Build() *MapWGReq {
// execution
type WGCompletionMsg struct {
sim.MsgMeta
RspTo string
RspTo []string
}

// Meta returns the meta data associated with the MapWGReq.
Expand All @@ -287,7 +287,7 @@ func (r *WGCompletionMsg) Meta() *sim.MsgMeta {
type WGCompletionMsgBuilder struct {
sendTime sim.VTimeInSec
src, dst sim.Port
rspTo string
rspTo []string
}

// WithSendTime sets the send time.
Expand Down Expand Up @@ -316,7 +316,7 @@ func (b WGCompletionMsgBuilder) WithDst(

// WithRspTo sets rspTo
func (b WGCompletionMsgBuilder) WithRspTo(
rspTo string,
rspTo []string,
) WGCompletionMsgBuilder {
b.rspTo = rspTo
return b
Expand Down
44 changes: 28 additions & 16 deletions timing/cp/internal/dispatching/dispatcher.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package dispatching

import (
"fmt"
"log"

"github.com/sarchlab/akita/v3/monitoring"
"github.com/sarchlab/akita/v3/sim"
Expand Down Expand Up @@ -120,28 +121,39 @@ func (d *DispatcherImpl) processMessagesFromCU(now sim.VTimeInSec) bool {

switch msg := msg.(type) {
case *protocol.WGCompletionMsg:
location, ok := d.inflightWGs[msg.RspTo]
if !ok {
return false
count := 0
for _, rspToID := range msg.RspTo {
_, ok := d.inflightWGs[rspToID]
if ok {
count += 1
}
}

d.alg.FreeResources(location)
delete(d.inflightWGs, msg.RspTo)
d.numCompletedWGs++
if d.numCompletedWGs == d.alg.NumWG() {
d.cycleLeft = d.constantKernelOverhead
if count == 0 {
return false
} else if count < len(msg.RspTo) {
log.Panic("In emulation all finished WGs from more than one dispatcher")
}

d.dispatchingPort.Retrieve(now)

originalReq := d.originalReqs[msg.RspTo]
delete(d.originalReqs, msg.RspTo)
tracing.TraceReqFinalize(originalReq, d)

if d.progressBar != nil {
d.progressBar.MoveInProgressToFinished(1)
for _, rspToID := range msg.RspTo {
location := d.inflightWGs[rspToID]
d.alg.FreeResources(location)
delete(d.inflightWGs, rspToID)
d.numCompletedWGs++
if d.numCompletedWGs == d.alg.NumWG() {
d.cycleLeft = d.constantKernelOverhead
}

originalReq := d.originalReqs[rspToID]
delete(d.originalReqs, rspToID)
tracing.TraceReqFinalize(originalReq, d)

if d.progressBar != nil {
d.progressBar.MoveInProgressToFinished(1)
}
}

d.dispatchingPort.Retrieve(now)
return true
}

Expand Down
6 changes: 3 additions & 3 deletions timing/cp/internal/dispatching/dispatcher_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,7 @@ var _ = Describe("Dispatcher", func() {
dispatcher.inflightWGs[mapWGReq.ID] = location
dispatcher.originalReqs[mapWGReq.ID] = mapWGReq

wgCompletionMsg := &protocol.WGCompletionMsg{RspTo: mapWGReq.ID}
wgCompletionMsg := &protocol.WGCompletionMsg{RspTo: []string{mapWGReq.ID}}

dispatcher.numDispatchedWGs = 64
dispatcher.numCompletedWGs = 48
Expand Down Expand Up @@ -197,7 +197,7 @@ var _ = Describe("Dispatcher", func() {
dispatcher.inflightWGs[mapWGReq.ID] = location
dispatcher.originalReqs[mapWGReq.ID] = mapWGReq

wgCompletionMsg := &protocol.WGCompletionMsg{RspTo: mapWGReq.ID}
wgCompletionMsg := &protocol.WGCompletionMsg{RspTo: []string{mapWGReq.ID}}

dispatcher.numDispatchedWGs = 64
dispatcher.numCompletedWGs = 63
Expand Down Expand Up @@ -227,7 +227,7 @@ var _ = Describe("Dispatcher", func() {
mapWGReq := protocol.MapWGReqBuilder{}.Build()
// dispatcher.inflightWGs[mapWGReq.ID] = location

wgCompletionMsg := &protocol.WGCompletionMsg{RspTo: mapWGReq.ID}
wgCompletionMsg := &protocol.WGCompletionMsg{RspTo: []string{mapWGReq.ID}}

dispatcher.numDispatchedWGs = 64
dispatcher.numCompletedWGs = 48
Expand Down
2 changes: 1 addition & 1 deletion timing/cu/scheduler.go
Original file line number Diff line number Diff line change
Expand Up @@ -353,7 +353,7 @@ func (s *SchedulerImpl) sendWGCompletionMessage(
WithSendTime(now).
WithSrc(s.cu.ToACE).
WithDst(dispatcher).
WithRspTo(mapReq.ID).
WithRspTo([]string{mapReq.ID}).
Build()

err := s.cu.ToACE.Send(msg)
Expand Down
Loading