Skip to content

Commit

Permalink
Allow specifying that certain talks must schedule after other talks.
Browse files Browse the repository at this point in the history
  • Loading branch information
lukegb committed Feb 4, 2024
1 parent ea8db68 commit d9a17be
Show file tree
Hide file tree
Showing 2 changed files with 76 additions and 0 deletions.
36 changes: 36 additions & 0 deletions slotmachine/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ class Talk:
preferred_venues: set[VenueID] = field(default_factory=set)
allowed_slots: set[Slot] = field(default_factory=set)
preferred_slots: set[Slot] = field(default_factory=set)
must_schedule_after: set[TalkID] = field(default_factory=set)


class SlotMachine(object):
Expand Down Expand Up @@ -87,6 +88,34 @@ def active(self, slot: Slot, talk_id: TalkID, venue: VenueID) -> pulp.LpVariable
self.var_cache[name] = variable
return variable

def start_slot(self, talk_id: TalkID) -> pulp.LpVariable:
"""A variable that is the number of the slot that talk ID is scheduled to begin."""
name = "S_start_%d" % (talk_id,)
if name in self.var_cache:
return self.var_cache[name]

variable = pulp.LpVariable(name, cat="Integer")
definition = pulp.lpSum(
self.start_var(slot, talk_id, venue) * slot
for slot in self.talks_by_id[talk_id].allowed_slots
for venue in self.talks_by_id[talk_id].venues
)
self.problem.addConstraint(variable == definition)
self.var_cache[name] = variable
return variable

def end_slot(self, talk_id: TalkID) -> pulp.LpVariable:
"""A variable that is the number of the slot that talk ID is scheduled to begin."""
name = "S_end_%d" % (talk_id,)
if name in self.var_cache:
return self.var_cache[name]

variable = pulp.LpVariable(name, cat="Integer")
definition = self.start_slot(talk_id) + self.talks_by_id[talk_id].duration
self.problem.addConstraint(variable == definition)
self.var_cache[name] = variable
return variable

def get_problem(
self, venues: set[VenueID], talks: list[Talk], old_talks: list[OldTalk]
) -> pulp.LpProblem:
Expand All @@ -107,6 +136,13 @@ def get_problem(
== 1
)

# Talks which must precede other talks do that
for talk in talks:
for schedule_before in talk.must_schedule_after:
self.problem.addConstraint(
self.end_slot(schedule_before) <= self.start_slot(talk.id)
)

# At most one talk may be active in a given venue and slot.
for v in venues:
for slot in self.slots_available:
Expand Down
40 changes: 40 additions & 0 deletions tests/test_slotmachine.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,13 +24,15 @@ def talk(
venues: list[int],
speakers: list[str],
slots: Iterable[Slot] | Iterable[int],
must_schedule_after: list[int] = [],
) -> Talk:
return Talk(
id=TalkID(id),
duration=SlotCount(duration),
venues={VenueID(vid) for vid in venues},
speakers=speakers,
allowed_slots={Slot(s) for s in slots},
must_schedule_after={TalkID(t) for t in must_schedule_after},
)


Expand Down Expand Up @@ -253,3 +255,41 @@ def test_talk_clash(self):

# Talk 1 must now be in slot 3 or 4
self.assertIn(talks_slots[1], [3, 4])

def test_must_schedule_after(self):
avail_slots = SlotMachine.calculate_slots(
parser.parse("2016-08-06 13:00"),
parser.parse("2016-08-06 13:00"),
parser.parse("2016-08-06 15:00"),
)
_talk = partial(talk, slots=avail_slots[:], venues=[101])
talk_defs = [
_talk(
id=1, duration=3 + 1, speakers=["Speaker 1"], must_schedule_after=[2]
),
_talk(
id=2, duration=2 + 1, speakers=["Speaker 2"], must_schedule_after=[3]
),
_talk(
id=3, duration=2 + 1, speakers=["Speaker 3"], must_schedule_after=[4]
),
_talk(id=4, duration=2 + 1, speakers=["Speaker 4"]),
]
old_talks = [(0, 1, 101), (3, 2, 101), (6, 3, 101), (9, 4, 101)]
solved = self.schedule_and_basic_asserts(
talk_defs, avail_slots, old_talks=old_talks
)

slots, talks, venues = unzip(solved)
talks_slots = dict(zip(talks, slots))

# The talks are now in the reverse order of the one they were in in old_talks.
self.assertEqual(
{
4: 0,
3: 3,
2: 6,
1: 9,
},
talks_slots,
)

0 comments on commit d9a17be

Please sign in to comment.