p2pfl.learning.frameworks.pytorch.callbacks.fedprox_callback module

FedProx Callback for PyTorch Lightning.

class p2pfl.learning.frameworks.pytorch.callbacks.fedprox_callback.FedProxCallback

Bases: Callback, P2PFLCallback

PyTorch Lightning Callback to implement the FedProx algorithm.

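This callback modifies the gradients before each optimizer step by adding the gradient of the FedProx proximal term (mu/2) * ||w - w_t||^2, which is mu * (w - w_t). Here w are the current local parameters, w_t are the parameters stored at the start of training (the global model for the round), and mu is the proximal coefficient.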
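As a worked check of the formula above, the sketch below applies the correction element-wise to plain Python lists. The names proximal_term and fedprox_gradients are hypothetical and are not part of p2pfl; they only illustrate that adding mu * (w - w_t) to the task-loss gradient is the same as differentiating the penalty (mu/2) * ||w - w_t||^2.

```python
def proximal_term(weights, global_weights, mu):
    """Hypothetical helper: FedProx penalty (mu / 2) * ||w - w_t||^2 on flat lists."""
    return 0.5 * mu * sum((w - wt) ** 2 for w, wt in zip(weights, global_weights))


def fedprox_gradients(grads, weights, global_weights, mu):
    """Hypothetical helper: add the penalty's gradient, mu * (w - w_t), to grads."""
    return [g + mu * (w - wt) for g, w, wt in zip(grads, weights, global_weights)]


grads = [0.5, -1.0]           # gradients of the local task loss
weights = [1.0, 2.0]          # current local parameters w
global_weights = [0.0, 2.5]   # parameters w_t stored at train start
mu = 0.1                      # proximal coefficient

print(fedprox_gradients(grads, weights, global_weights, mu))  # approximately [0.6, -1.05]
```

Parameters that have not drifted from w_t receive no correction, and the pull back toward w_t grows linearly with the drift.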

static get_name()

Return the name of the callback.

Return type:

str

on_before_optimizer_step(trainer, pl_module, optimizer)

Add the proximal gradient term mu * (w - w_t) to the gradient of each model parameter, where w_t are the global parameters stored in on_train_start.

Parameters:
  • trainer (Trainer) – The trainer

  • pl_module (LightningModule) – The model

  • optimizer (Optimizer) – The optimizer

Return type:

None
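When this hook runs, backward() has already produced the task-loss gradients but the optimizer has not yet applied them, so the correction can be added in place to each parameter's .grad. The following is a minimal plain-PyTorch sketch, not p2pfl's implementation. It assumes the stored w_t is a list of tensors aligned with model.parameters(). add_proximal_gradient is a hypothetical helper, and the explicit loop stands in for what the Lightning Trainer would drive.

```python
import torch
from torch import nn


def add_proximal_gradient(model: nn.Module, global_params: list, mu: float) -> None:
    """Hypothetical helper: add mu * (w - w_t) to every parameter's .grad in place."""
    with torch.no_grad():  # the correction itself must not be tracked by autograd
        for param, global_param in zip(model.parameters(), global_params):
            if param.grad is None:
                continue  # parameter took no part in this backward pass
            param.grad.add_(param - global_param, alpha=mu)


torch.manual_seed(0)
model = nn.Linear(3, 1)
# w_t: a frozen copy of the parameters at the start of local training.
global_params = [p.detach().clone() for p in model.parameters()]
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
mu = 0.01
x, y = torch.randn(8, 3), torch.randn(8, 1)

for _ in range(5):
    optimizer.zero_grad()
    loss = nn.functional.mse_loss(model(x), y)
    loss.backward()                                  # task-loss gradients
    add_proximal_gradient(model, global_params, mu)  # role of on_before_optimizer_step
    optimizer.step()
```

Editing the gradients directly yields the same update as adding (mu/2) * ||w - w_t||^2 to the loss before backward(). It also avoids building an extra autograd graph for the penalty.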

on_train_start(trainer, pl_module)

Store the initial model parameters (w_t) and the proximal coefficient (mu).

Parameters:
  • trainer (Trainer) – The trainer

  • pl_module (LightningModule) – The model

Return type:

None
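The stored w_t must stay fixed while local training updates the model. A common way to guarantee this in PyTorch is to keep detached clones of the parameters. The sketch below assumes that approach. ProximalState is a hypothetical container, not p2pfl's class, and mu is passed in directly here because this page does not say where the callback reads it from.

```python
import torch
from torch import nn


class ProximalState:
    """Hypothetical container for what on_train_start stores: w_t and mu."""

    def __init__(self, mu: float) -> None:
        self.mu = mu
        self.global_params: list = []

    def snapshot(self, model: nn.Module) -> None:
        # detach() drops autograd history; clone() copies the storage so that
        # later in-place optimizer updates to the model do not alter w_t.
        self.global_params = [p.detach().clone() for p in model.parameters()]


torch.manual_seed(0)
model = nn.Linear(2, 1)
state = ProximalState(mu=0.01)
state.snapshot(model)  # what on_train_start would do before the first step
```

Storing only p.detach() without clone() would share memory with the live parameters. w - w_t would then always be zero and the proximal term would have no effect.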