Skip to content

Commit

Permalink
Adds GPU support for Cloudburst (#45)
Browse files Browse the repository at this point in the history
  • Loading branch information
vsreekanti authored May 6, 2020
1 parent 676f318 commit 1bd83d9
Show file tree
Hide file tree
Showing 19 changed files with 594 additions and 229 deletions.
9 changes: 8 additions & 1 deletion cloudburst/client/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,8 @@ def register(self, function, name):
else:
raise RuntimeError(f'Unexpected error while registering function: {resp}.')

def register_dag(self, name, functions, connections, colocated=[]):
def register_dag(self, name, functions, connections, gpu_functions=[],
batching_functions=[], colocated=[]):
'''
Registers a new DAG with the system. This operation will fail if any of
the functions provided cannot be identified in the system.
Expand Down Expand Up @@ -186,6 +187,12 @@ def register_dag(self, name, functions, connections, colocated=[]):
fname = function
invalids = []

if function in gpu_functions:
ref.gpu = True

if function in batching_functions:
ref.batching = True

ref.name = fname
for invalid in invalids:
ref.invalid_results.append(serializer.dump(invalid))
Expand Down
187 changes: 120 additions & 67 deletions cloudburst/server/executor/call.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,18 @@ def _exec_func_normal(kvs, func, args, user_lib, cache):
processed += (arg,)
args = processed

refs = list(filter(lambda a: isinstance(a, CloudburstReference), args))
if all([type(arg) == list for arg in args]): # A batching request.
refs = []

# For a batching request, we pull out the references in each sublist of
# arguments.
for arg in args:
arg_refs = list(
filter(lambda a: isinstance(a, CloudburstReference), arg))
refs.extend(arg_refs)
else:
# For non-batching requests, we just filter all of the arguments.
refs = list(filter(lambda a: isinstance(a, CloudburstReference), args))

if refs:
refs = _resolve_ref_normal(refs, kvs, cache)
Expand All @@ -123,9 +134,20 @@ def _run_function(func, refs, args, user_lib):
# If any of the arguments are references, we insert the resolved reference
# instead of the raw value.
for arg in args:
if isinstance(arg, CloudburstReference):
func_args += (refs[arg.key],)
# The standard non-batching approach to resolving references. We simply
# take the KV-pairs and swap in the actual values for the references.
if type(arg) != list:
if isinstance(arg, CloudburstReference):
func_args += (refs[arg.key],)
else:
func_args += (arg,)
else:
# The batching approach: We look at each value to check if it's a
# ref then append the whole list to the argument set.
for idx, val in enumerate(arg):
if isinstance(val, CloudburstReference):
arg[idx] = refs[val.key]

func_args += (arg,)

return func(*func_args)
Expand Down Expand Up @@ -219,29 +241,32 @@ def _resolve_ref_causal(refs, kvs, schedule, key_version_locations,
return kv_pairs


def exec_dag_function(pusher_cache, kvs, trigger_sets, function, schedules,
                      user_library, dag_runtimes, cache, schedulers, batching):
    """Execute one DAG function for one or more requests and record runtimes.

    The consistency mode of the first schedule selects the execution path:
    NORMAL goes to `_exec_dag_function_normal`, anything else goes to
    `_exec_dag_function_causal`.

    Args:
        pusher_cache: Forwarded unchanged to the mode-specific executor.
        kvs: Forwarded unchanged to the mode-specific executor.
        trigger_sets: One collection of triggers per schedule, in the same
            order as `schedules`.
        function: The function to run.
        schedules: The schedules (one per request) being served by this call.
        user_library: Forwarded unchanged to the mode-specific executor.
        dag_runtimes: Dict mapping DAG name to a list of end-to-end runtimes
            in seconds; updated in place.
        cache: Forwarded to the NORMAL executor only.
        schedulers: Forwarded to the NORMAL executor only.
        batching: Forwarded to the NORMAL executor only.

    Returns:
        The list of per-request success flags produced by the executor.
    """
    if schedules[0].consistency == NORMAL:
        finished, successes = _exec_dag_function_normal(pusher_cache, kvs,
                                                        trigger_sets, function,
                                                        schedules,
                                                        user_library, cache,
                                                        schedulers, batching)
    else:
        finished, successes = _exec_dag_function_causal(pusher_cache, kvs,
                                                        trigger_sets, function,
                                                        schedules, user_library)

    # If finished is true, that means that this executor finished the DAG
    # request. We will report the end-to-end latency for each request that
    # completed successfully.
    if finished:
        for schedule, success in zip(schedules, successes):
            if success:
                runtime = time.time() - schedule.start_time
                dag_runtimes.setdefault(schedule.dag.name, []).append(runtime)

    return successes


def _construct_trigger(sid, fname, result):
Expand All @@ -257,67 +282,95 @@ def _construct_trigger(sid, fname, result):
return trigger


def _exec_dag_function_normal(pusher_cache, kvs, trigger_sets, function,
                              schedules, user_lib, cache, schedulers,
                              batching):
    """Run a DAG function in NORMAL consistency mode for one or more requests.

    Arguments for each request are gathered from its schedule and its
    triggers, the function is executed once (on batched arguments when
    `batching` is true), and each result is then either forwarded to the
    downstream functions or, if this function is a sink, delivered through
    the continuation, the response address, or the KVS.

    Args:
        pusher_cache: Provides sockets via `get(address)`.
        kvs: Key-value store client; used to resolve references and to store
            sink results.
        trigger_sets: One collection of triggers per schedule, in the same
            order as `schedules`.
        function: The function to execute.
        schedules: One schedule per request. All are expected to target the
            same function; the name is read from the first one.
        user_lib: Passed through to the function executor.
        cache: Passed through to the function executor.
        schedulers: Used to pick the address continuations are sent to.
        batching: Whether the requests are executed as a single batch.

    Returns:
        A tuple `(is_sink, successes)`. `is_sink` is False once any result
        has been forwarded downstream. `successes` holds one boolean per
        processed request; False marks a MULTIEXEC result that matched one of
        the function's invalid results.
    """
    fname = schedules[0].target_function

    # We construct farg_sets to have a request by request set of arguments.
    # That is, each element in farg_sets will have all the arguments for one
    # invocation.
    farg_sets = []
    for schedule, trigger_set in zip(schedules, trigger_sets):
        fargs = list(schedule.arguments[fname].values)

        for trigger in trigger_set:
            fargs += list(trigger.arguments.values)

        farg_sets.append([serializer.load(arg) for arg in fargs])

    if batching:
        # Transpose request-major arguments into one list per positional
        # argument. Each position needs its OWN list: `[[]] * n` would repeat
        # a single list object n times, so every argument would end up with
        # the values of all arguments.
        fargs = [[] for _ in farg_sets[0]]
        for farg_set in farg_sets:
            for idx, val in enumerate(farg_set):
                fargs[idx].append(val)
    else:  # There will only be one thing in farg_sets
        fargs = farg_sets[0]

    result_list = _exec_func_normal(kvs, function, fargs, user_lib, cache)

    # Only a batched call may return one result per request. A non-batched
    # result is always a single value, even when that value is itself a list;
    # unpacking it would forward only its first element.
    if not batching or not isinstance(result_list, list):
        result_list = [result_list]

    successes = []
    is_sink = True

    for schedule, result in zip(schedules, result_list):
        this_ref = None
        for ref in schedule.dag.functions:
            if ref.name == fname:
                this_ref = ref  # There must be a match.

        if this_ref.type == MULTIEXEC:
            if serializer.dump(result) in this_ref.invalid_results:
                successes.append(False)
                continue

        successes.append(True)
        new_trigger = _construct_trigger(schedule.id, fname, result)
        for conn in schedule.dag.connections:
            if conn.source == fname:
                is_sink = False
                new_trigger.target_function = conn.sink

                dest_ip = schedule.locations[conn.sink]
                sckt = pusher_cache.get(sutils.get_dag_trigger_address(dest_ip))
                sckt.send(new_trigger.SerializeToString())

        if is_sink:
            # The messages below report schedule.id, the id of the request
            # being handled. The loop variable left over from argument
            # gathering would name the wrong request when batching and would
            # be unbound if no trigger had been received.
            if schedule.continuation.name:
                cont = schedule.continuation
                cont.id = schedule.id
                cont.result = serializer.dump(result)

                logging.info('Sending continuation to scheduler for DAG %s.' %
                             (schedule.id))
                sckt = pusher_cache.get(
                    utils.get_continuation_address(schedulers))
                sckt.send(cont.SerializeToString())
            elif schedule.response_address:
                sckt = pusher_cache.get(schedule.response_address)
                logging.info('DAG %s (ID %s) result returned to requester.' %
                             (schedule.dag.name, schedule.id))
                sckt.send(serializer.dump(result))
            else:
                lattice = serializer.dump_lattice(result)
                output_key = schedule.output_key if schedule.output_key \
                    else schedule.id
                logging.info('DAG %s (ID %s) result in KVS at %s.' %
                             (schedule.dag.name, schedule.id, output_key))
                kvs.put(output_key, lattice)

    return is_sink, successes


# Causal mode does not currently support batching, so there should only ever be
# one trigger set and one schedule.
def _exec_dag_function_causal(pusher_cache, kvs, triggers, function, schedule,
user_lib):
schedule = schedule[0]
triggers = triggers[0]

fname = schedule.target_function
fargs = list(schedule.arguments[fname].values)

Expand Down Expand Up @@ -424,7 +477,7 @@ def _exec_dag_function_causal(pusher_cache, kvs, triggers, function, schedule,
sckt = pusher_cache.get(gc_address)
sckt.send_string(schedule.client_id)

return is_sink, success
return is_sink, [success]


def _compute_children_read_set(schedule):
Expand Down
15 changes: 12 additions & 3 deletions cloudburst/server/executor/pin.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@


def pin(pin_socket, pusher_cache, kvs, status, function_cache, runtimes,
exec_counts, user_library, local):
exec_counts, user_library, local, batching):
serialized = pin_socket.recv()
pin_msg = PinFunction()
pin_msg.ParseFromString(serialized)
Expand All @@ -35,7 +35,7 @@ def pin(pin_socket, pusher_cache, kvs, status, function_cache, runtimes,
or not status.running)):
sutils.error.SerializeToString()
sckt.send(sutils.error.SerializeToString())
return
return batching

func = utils.retrieve_function(pin_msg.name, kvs, user_library)

Expand All @@ -54,9 +54,18 @@ def pin(pin_socket, pusher_cache, kvs, status, function_cache, runtimes,
runtimes[name] = []
exec_counts[name] = 0
logging.info('Adding function %s to my local pinned functions.' % (name))


if pin_msg.batching and len(status.functions) > 1:
raise RuntimeError('There is more than one pinned function (we are'
+ ' operating in local mode), and the function'
+ ' attempting to be pinned has batching enabled. This'
+ ' is not allowed -- you can only use batching in'
+ ' cluster mode or in local mode with one function.')

sckt.send(sutils.ok_resp)

return pin_msg.batching


def unpin(unpin_socket, status, function_cache, runtimes, exec_counts):
name = unpin_socket.recv_string()
Expand Down
Loading

0 comments on commit 1bd83d9

Please sign in to comment.