feat(framework) Add run_config to templates
charlesbvll committed Jul 18, 2024
1 parent a20e4c9 commit 29bc77e
Showing 9 changed files with 45 additions and 15 deletions.
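The commit applies one pattern across the nine changed files: hyperparameters that were previously hardcoded in the client templates (local epochs, batch size, learning rate, model dimensions) move into the [tool.flwr.app.config] table of the matching pyproject.toml template, and each client reads them back at runtime through context.run_config, casting the string values to the types it needs. The sketch below illustrates that read-and-cast step with the keys introduced in this commit; the helper name read_hyperparameters is illustrative and not part of the diff.

    # Illustrative helper (not part of this commit): gather the string-valued
    # run_config entries declared below in pyproject.toml and cast them.
    from flwr.common import Context


    def read_hyperparameters(context: Context) -> dict:
        return {
            "local_epochs": int(context.run_config["local-epochs"]),
            "batch_size": int(context.run_config["batch-size"]),
            "lr": float(context.run_config["lr"]),
            # An empty string is falsy, so verbose = "" in pyproject.toml means False
            "verbose": bool(context.run_config.get("verbose")),
        }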
10 changes: 7 additions & 3 deletions src/py/flwr/cli/new/templates/app/code/client.hf.py.tpl
@@ -30,7 +30,11 @@ class FlowerClient(NumPyClient):
 
     def fit(self, parameters, config):
         self.set_parameters(parameters)
-        train(self.net, self.trainloader, epochs=1)
+        train(
+            self.net,
+            self.trainloader,
+            epochs=int(self.context.run_config["local-epochs"]),
+        )
         return self.get_parameters(config={}), len(self.trainloader), {}
 
     def evaluate(self, parameters, config):
@@ -45,8 +49,8 @@ def client_fn(context: Context):
         CHECKPOINT, num_labels=2
     ).to(DEVICE)
 
-    partition_id = int(context.node_config['partition-id'])
-    num_partitions = int(context.node_config['num-partitions'])
+    partition_id = int(context.node_config["partition-id"])
+    num_partitions = int(context.node_config["num-partitions"])
     trainloader, valloader = load_data(partition_id, num_partitions)
 
     # Return Client instance
18 changes: 10 additions & 8 deletions src/py/flwr/cli/new/templates/app/code/client.mlx.py.tpl
@@ -20,17 +20,19 @@ from $import_name.task import (
 # Define Flower Client and client_fn
 class FlowerClient(NumPyClient):
     def __init__(self, data):
-        num_layers = 2
-        hidden_dim = 32
+        num_layers = int(self.context.run_config["num-layers"])
+        hidden_dim = int(self.context.run_config["hidden-dim"])
         num_classes = 10
-        batch_size = 256
-        num_epochs = 1
-        learning_rate = 1e-1
+        batch_size = int(self.context.run_config["batch-size"])
+        learning_rate = float(self.context.run_config["lr"])
+        num_epochs = int(self.context.run_config["local-epochs"])
 
         self.train_images, self.train_labels, self.test_images, self.test_labels = data
-        self.model = MLP(num_layers, self.train_images.shape[-1], hidden_dim, num_classes)
-        self.optimizer = optim.SGD(learning_rate=learning_rate)
-        self.loss_and_grad_fn = nn.value_and_grad(self.model, loss_fn)
+        self.model = MLP(
+            num_layers, self.train_images.shape[-1], hidden_dim, num_classes
+        )
+        self.optimizer = optim.SGD(learning_rate=learning_rate)
+        self.loss_and_grad_fn = nn.value_and_grad(self.model, loss_fn)
         self.num_epochs = num_epochs
         self.batch_size = batch_size
8 changes: 7 additions & 1 deletion src/py/flwr/cli/new/templates/app/code/client.pytorch.py.tpl
@@ -23,7 +23,13 @@ class FlowerClient(NumPyClient):
 
     def fit(self, parameters, config):
         set_weights(self.net, parameters)
-        results = train(self.net, self.trainloader, self.valloader, 1, DEVICE)
+        results = train(
+            self.net,
+            self.trainloader,
+            self.valloader,
+            int(self.context.run_config["local-epochs"]),
+            DEVICE,
+        )
         return get_weights(self.net), len(self.trainloader.dataset), results
 
     def evaluate(self, parameters, config):
3 changes: 2 additions & 1 deletion src/py/flwr/cli/new/templates/app/code/client.sklearn.py.tpl
@@ -67,10 +67,11 @@ class FlowerClient(NumPyClient):
 
         return loss, len(self.X_test), {"accuracy": accuracy}
 
-fds = FederatedDataset(dataset="mnist", partitioners={"train": 2})
 
 def client_fn(context: Context):
     partition_id = int(context.node_config["partition-id"])
+    num_partitions = int(context.node_config["num-partitions"])
+    fds = FederatedDataset(dataset="mnist", partitioners={"train": num_partitions})
     dataset = fds.load_partition(partition_id, "train").with_format("numpy")
 
     X, y = dataset["image"].reshape((len(dataset), -1)), dataset["label"]
11 changes: 9 additions & 2 deletions src/py/flwr/cli/new/templates/app/code/client.tensorflow.py.tpl
@@ -20,7 +20,13 @@ class FlowerClient(NumPyClient):
 
     def fit(self, parameters, config):
         self.model.set_weights(parameters)
-        self.model.fit(self.x_train, self.y_train, epochs=1, batch_size=32, verbose=0)
+        self.model.fit(
+            self.x_train,
+            self.y_train,
+            epochs=int(self.context.run_config["local-epochs"]),
+            batch_size=int(self.context.run_config["batch-size"]),
+            verbose=bool(self.context.run_config.get("verbose")),
+        )
         return self.model.get_weights(), len(self.x_train), {}
 
     def evaluate(self, parameters, config):
@@ -34,7 +40,8 @@ def client_fn(context: Context):
     net = load_model()
 
     partition_id = int(context.node_config["partition-id"])
-    x_train, y_train, x_test, y_test = load_data(partition_id, 2)
+    num_partitions = int(context.node_config["num-partitions"])
+    x_train, y_train, x_test, y_test = load_data(partition_id, num_partitions)
 
     # Return Client instance
     return FlowerClient(net, x_train, y_train, x_test, y_test).to_client()
1 change: 1 addition & 0 deletions src/py/flwr/cli/new/templates/app/pyproject.hf.toml.tpl
@@ -29,6 +29,7 @@ clientapp = "$import_name.client_app:app"
 
 [tool.flwr.app.config]
 num-server-rounds = "3"
+local-epochs = "1"
 
 [tool.flwr.federations]
 default = "localhost"
5 changes: 5 additions & 0 deletions src/py/flwr/cli/new/templates/app/pyproject.mlx.toml.tpl
@@ -26,6 +26,11 @@ clientapp = "$import_name.client_app:app"
 
 [tool.flwr.app.config]
 num-server-rounds = "3"
+local-epochs = "1"
+num-layers = "2"
+hidden-dim = "32"
+batch-size = "256"
+lr = "0.1"
 
 [tool.flwr.federations]
 default = "localhost"
@@ -26,6 +26,7 @@ clientapp = "$import_name.client_app:app"
 
 [tool.flwr.app.config]
 num-server-rounds = "3"
+local-epochs = "1"
 
 [tool.flwr.federations]
 default = "localhost"
@@ -25,6 +25,9 @@ clientapp = "$import_name.client_app:app"
 
 [tool.flwr.app.config]
 num-server-rounds = "3"
+local-epochs = "1"
+batch-size = "32"
+verbose = "" # Empty string means False
 
 [tool.flwr.federations]
 default = "localhost"
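A note on the verbose default just above: since the config values in these templates are plain strings, the TensorFlow client converts the flag with ordinary Python truthiness, via bool(self.context.run_config.get("verbose")) in the fit() hunk earlier in this diff. A quick demonstration of that convention with stand-in values:

    run_config = {"verbose": ""}            # the default declared above
    print(bool(run_config.get("verbose")))  # False: verbose output disabled
    run_config["verbose"] = "1"             # any non-empty string
    print(bool(run_config.get("verbose")))  # True: verbose output enabled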
