Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature/SK-491 | Make aggregator plugin-configurable #469

Merged
merged 5 commits into from
Jun 14, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions fedn/cli/run_cmd.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,8 +245,9 @@ def reducer_cmd(ctx, host, port, secret_key, local_package, name, init):
@click.option('-c', '--max_clients', required=False, default=30, help='The maximal number of client connections allowed.')
@click.option('-in', '--init', required=False, default=None,
help='Path to configuration file to (re)init combiner.')
@click.option('-a', '--aggregator', required=False, default='fedavg', help='Filename of the aggregator module to use.')
@click.pass_context
def combiner_cmd(ctx, discoverhost, discoverport, token, name, host, port, fqdn, secure, verify, max_clients, init):
def combiner_cmd(ctx, discoverhost, discoverport, token, name, host, port, fqdn, secure, verify, max_clients, init, aggregator):
"""

:param ctx:
Expand All @@ -261,7 +262,8 @@ def combiner_cmd(ctx, discoverhost, discoverport, token, name, host, port, fqdn,
:param init:
"""
config = {'discover_host': discoverhost, 'discover_port': discoverport, 'token': token, 'host': host,
'port': port, 'fqdn': fqdn, 'name': name, 'secure': secure, 'verify': verify, 'max_clients': max_clients, 'init': init}
'port': port, 'fqdn': fqdn, 'name': name, 'secure': secure, 'verify': verify, 'max_clients': max_clients,
'init': init, 'aggregator': aggregator}

if config['init']:
apply_config(config)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,15 +1,18 @@
import importlib
import json
import queue
from abc import ABC, abstractmethod

import fedn.common.net.grpc.fedn_pb2 as fedn

AGGREGATOR_PLUGIN_PATH = "fedn.network.combiner.aggregators.{}"

class Aggregator(ABC):

class AggregatorBase(ABC):
""" Abstract class defining an aggregator. """

@abstractmethod
def __init__(self, id, storage, server, modelservice, control):
def __init__(self, storage, server, modelservice, control):
""" Initialize the aggregator.

:param id: A reference to id of :class: `fedn.network.combiner.Combiner`
Expand All @@ -25,7 +28,6 @@ def __init__(self, id, storage, server, modelservice, control):
"""
self.name = self.__class__.__name__
self.storage = storage
self.id = id
self.server = server
self.modelservice = modelservice
self.control = control
Expand Down Expand Up @@ -105,3 +107,24 @@ def next_model_update(self, helper):
data['round_id'] = config['round_id']

return model_next, data, model_id


def get_aggregator(aggregator_module_name, storage, server, modelservice, control):
    """ Load an aggregator plugin module and return an instance of its aggregator.

    The module is looked up by formatting ``aggregator_module_name`` into
    ``AGGREGATOR_PLUGIN_PATH`` and importing the resulting dotted path. The
    imported module must expose a class named ``Aggregator``.

    :param aggregator_module_name: The name of the aggregator plugin module.
    :type aggregator_module_name: str
    :param storage: Model repository for :class: `fedn.network.combiner.Combiner`
    :type storage: class: `fedn.common.storage.s3.s3repo.S3ModelRepository`
    :param server: A handle to the Combiner class :class: `fedn.network.combiner.Combiner`
    :type server: class: `fedn.network.combiner.Combiner`
    :param modelservice: A handle to the model service :class: `fedn.network.combiner.modelservice.ModelService`
    :type modelservice: class: `fedn.network.combiner.modelservice.ModelService`
    :param control: A handle to the :class: `fedn.network.combiner.round.RoundController`
    :type control: class: `fedn.network.combiner.round.RoundController`
    :return: An aggregator instance.
    :rtype: class: `fedn.combiner.aggregators.AggregatorBase`
    :raises ModuleNotFoundError: If no plugin module with the given name can be imported.
    :raises AttributeError: If the plugin module does not define ``Aggregator``.
    """
    # Resolve the plugin's dotted module path and import it dynamically.
    plugin_module = importlib.import_module(
        AGGREGATOR_PLUGIN_PATH.format(aggregator_module_name))
    # Every plugin module is expected to name its implementation `Aggregator`.
    aggregator_cls = plugin_module.Aggregator
    return aggregator_cls(storage, server, modelservice, control)
10 changes: 5 additions & 5 deletions fedn/fedn/network/combiner/aggregators/fedavg.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
import fedn.common.net.grpc.fedn_pb2 as fedn
from fedn.network.combiner.aggregators.aggregator import Aggregator
from fedn.network.combiner.aggregators.aggregatorbase import AggregatorBase


class FedAvg(Aggregator):
class Aggregator(AggregatorBase):
""" Local SGD / Federated Averaging (FedAvg) aggregator. Computes a weighted mean
of parameter updates.

Expand All @@ -19,12 +19,12 @@ class FedAvg(Aggregator):

"""

def __init__(self, id, storage, server, modelservice, control):
def __init__(self, storage, server, modelservice, control):
"""Constructor method"""

super().__init__(id, storage, server, modelservice, control)
super().__init__(storage, server, modelservice, control)

self.name = "FedAvg"
self.name = "fedavg"

def combine_models(self, helper=None, time_window=180, max_nr_models=100, delete_models=True):
"""Aggregate model updates in the queue by computing an incremental
Expand Down
16 changes: 6 additions & 10 deletions fedn/fedn/network/combiner/round.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import time
import uuid

from fedn.network.combiner.aggregators.fedavg import FedAvg
from fedn.network.combiner.aggregators.aggregatorbase import get_aggregator
from fedn.utils.helpers import get_helper


Expand All @@ -18,8 +18,8 @@ class RoundController:
The round controller receives round configurations from the global controller
and coordinates model updates and aggregation, and model validations.

:param id: A reference to id of :class: `fedn.network.combiner.Combiner`
:type id: str
:param aggregator_name: The name of the aggregator plugin module.
:type aggregator_name: str
:param storage: Model repository for :class: `fedn.network.combiner.Combiner`
:type storage: class: `fedn.common.storage.s3.s3repo.S3ModelRepository`
:param server: A handle to the Combiner class :class: `fedn.network.combiner.Combiner`
Expand All @@ -28,17 +28,13 @@ class RoundController:
:type modelservice: class: `fedn.network.combiner.modelservice.ModelService`
"""

def __init__(self, id, storage, server, modelservice):
def __init__(self, aggregator_name, storage, server, modelservice):

self.id = id
self.round_configs = queue.Queue()
self.storage = storage
self.server = server
self.modelservice = modelservice

# TODO, make runtime configurable
self.aggregator = FedAvg(
self.id, self.storage, self.server, self.modelservice, self)
self.aggregator = get_aggregator(aggregator_name, self.storage, self.server, self.modelservice, self)

def push_round_config(self, round_config):
"""Add a round_config (job description) to the inbox.
Expand Down Expand Up @@ -366,7 +362,7 @@ def run(self, polling_interval=1.0):
round_meta['time_exec_training'] = time.time() - \
tic
round_meta['status'] = "Success"
round_meta['name'] = self.id
round_meta['name'] = self.server.id
self.server.tracer.set_round_combiner_data(round_meta)
if round_config['delete_models_storage'] == 'True':
self.modelservice.models.delete(round_config['model_id'])
Expand Down
54 changes: 27 additions & 27 deletions fedn/fedn/network/combiner/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,11 +51,11 @@ def role_to_proto_role(role):
class Combiner(rpc.CombinerServicer, rpc.ReducerServicer, rpc.ConnectorServicer, rpc.ControlServicer):
""" Combiner gRPC server. """

def __init__(self, connect_config):
def __init__(self, config):
""" Initialize a Combiner.

:param connect_config: configuration for the combiner
:type connect_config: dict
:param config: configuration for the combiner
:type config: dict
"""

# Client queues
Expand All @@ -64,24 +64,24 @@ def __init__(self, connect_config):
self.modelservice = ModelService()

# Validate combiner name
match = re.search(VALID_NAME_REGEX, connect_config['name'])
match = re.search(VALID_NAME_REGEX, config['name'])
if not match:
raise ValueError('Unallowed character in combiner name. Allowed characters: a-z, A-Z, 0-9, _, -.')

self.id = connect_config['name']
self.id = config['name']
self.role = Role.COMBINER
self.max_clients = connect_config['max_clients']
self.max_clients = config['max_clients']

# Connector to announce combiner to discover service (reducer)
announce_client = ConnectorCombiner(host=connect_config['discover_host'],
port=connect_config['discover_port'],
myhost=connect_config['host'],
fqdn=connect_config['fqdn'],
myport=connect_config['port'],
token=connect_config['token'],
name=connect_config['name'],
secure=connect_config['secure'],
verify=connect_config['verify'])
announce_client = ConnectorCombiner(host=config['discover_host'],
port=config['discover_port'],
myhost=config['host'],
fqdn=config['fqdn'],
myport=config['port'],
token=config['token'],
name=config['name'],
secure=config['secure'],
verify=config['verify'])

response = None
while True:
Expand All @@ -92,41 +92,41 @@ def __init__(self, connect_config):
time.sleep(5)
continue
if status == Status.Assigned:
config = response
announce_config = response
print(
"COMBINER {0}: Announced successfully".format(self.id), flush=True)
break
if status == Status.UnAuthorized:
print(response, flush=True)
sys.exit("Exiting: Unauthorized")

cert = config['certificate']
key = config['key']
cert = announce_config['certificate']
key = announce_config['key']

if config['certificate']:
cert = base64.b64decode(config['certificate']) # .decode('utf-8')
key = base64.b64decode(config['key']) # .decode('utf-8')
if announce_config['certificate']:
cert = base64.b64decode(announce_config['certificate']) # .decode('utf-8')
key = base64.b64decode(announce_config['key']) # .decode('utf-8')

# Set up gRPC server configuration
grpc_config = {'port': connect_config['port'],
'secure': connect_config['secure'],
grpc_config = {'port': config['port'],
'secure': config['secure'],
'certificate': cert,
'key': key}

# Set up model repository
self.repository = S3ModelRepository(
config['storage']['storage_config'])
announce_config['storage']['storage_config'])

# Create gRPC server
self.server = Server(self, self.modelservice, grpc_config)

# Set up tracer for statestore
self.tracer = MongoTracer(
config['statestore']['mongo_config'], config['statestore']['network_id'])
announce_config['statestore']['mongo_config'], announce_config['statestore']['network_id'])

# Set up round controller
self.control = RoundController(
self.id, self.repository, self, self.modelservice)
self.control = RoundController(config['aggregator'], self.repository, self, self.modelservice)

# Start thread for round controller
threading.Thread(target=self.control.run, daemon=True).start()

Expand Down
2 changes: 0 additions & 2 deletions fedn/fedn/network/controller/control.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,6 @@ def __init__(self, message):
self.message = message
super().__init__(self.message)

# Exception class for when model is None


class NoModelException(Exception):
""" Exception class for when model is None """
Expand Down