p2pfl.learning.frameworks.pytorch.callbacks.scaffold_callback module
Callback for SCAFFOLD operations (PyTorch Lightning).
- class p2pfl.learning.frameworks.pytorch.callbacks.scaffold_callback.SCAFFOLDCallback
Bases: Callback, P2PFLCallback
Callback for SCAFFOLD operations, for use with PyTorch Lightning.
At the beginning of training, the callback stores the global model and the initial learning rate. Then, after optimization, it applies the SCAFFOLD control variate adjustment to the model parameters.
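The description does not say how the stored global model and learning rate are used. One plausible reason is the client-side control variate refresh from the SCAFFOLD paper (Karimireddy et al., 2020). Under "option II", after $K$ local steps the client sets $c_i^+ = c_i - c + \frac{1}{K \eta_l}(x - y_i)$, where $x$ is the global model received at the start of the round. That formula needs exactly these two stored quantities. The following is a minimal stdlib sketch of that update. It assumes the parameters are flat lists of floats, and the function name `update_local_control_variate` is hypothetical, not part of p2pfl.

```python
from typing import List


def update_local_control_variate(
    c_i: List[float],
    c: List[float],
    global_model: List[float],
    local_model: List[float],
    num_local_steps: int,
    lr: float,
) -> List[float]:
    """Hypothetical helper: SCAFFOLD option II, c_i+ = c_i - c + (x - y_i) / (K * lr)."""
    if num_local_steps <= 0 or lr <= 0:
        raise ValueError("num_local_steps and lr must be positive")
    scale = 1.0 / (num_local_steps * lr)
    return [
        cij - cj + scale * (xj - yj)
        for cij, cj, xj, yj in zip(c_i, c, global_model, local_model)
    ]


# Illustration: K corrected SGD steps with a constant gradient g (made-up values).
x = [1.0]  # global model received at the start of the round
c_i, c, g = [0.2], [0.3], [0.7]
K, lr = 4, 0.05
y = list(x)
for _ in range(K):
    # Corrected local step y <- y - lr * (g - c_i + c)
    y = [yj - lr * (gj - cij + cj) for yj, gj, cij, cj in zip(y, g, c_i, c)]

new_c_i = update_local_control_variate(c_i, c, x, y, K, lr)
print(new_c_i)  # close to g, up to floating-point rounding
```

With a constant gradient, $(x - y_i)/(K\eta_l)$ equals the average corrected step $g - c_i + c$. Option II therefore returns $g$, which is the quantity the local control variate is meant to estimate.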
- on_before_zero_grad(trainer, pl_module, optimizer)
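Modify the model by applying the control variate adjustment.
Because the optimizer has already updated $y_i$ using the gradient $g_i$, the remaining control variate adjustment is:
$$y_i \leftarrow y_i + \eta_l c_i - \eta_l c$$
where $\eta_l$ is the local learning rate, $c_i$ the local control variate and $c$ the global control variate.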
- Parameters:
  - trainer (Trainer) – The trainer.
  - pl_module (LightningModule) – The model.
  - optimizer (Optimizer) – The optimizer.
- Return type:
None
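The two-stage update used by `on_before_zero_grad` can be checked numerically. In SCAFFOLD, a local step is $y_i \leftarrow y_i - \eta_l (g_i(y_i) - c_i + c)$. If the optimizer is plain SGD (no momentum, no weight decay), it has already applied $y_i \leftarrow y_i - \eta_l g_i(y_i)$, so adding $\eta_l c_i - \eta_l c$ afterwards recovers the full step. The stdlib sketch below compares the two paths. The helper names are hypothetical, and the parameters are modeled as lists of floats.

```python
from typing import List


def sgd_step(y: List[float], grad: List[float], lr: float) -> List[float]:
    """Plain SGD update y <- y - lr * g (no momentum, no weight decay)."""
    return [yj - lr * gj for yj, gj in zip(y, grad)]


def control_variate_adjustment(
    y: List[float], c_i: List[float], c: List[float], lr: float
) -> List[float]:
    """Post-step correction y <- y + lr * c_i - lr * c."""
    return [yj + lr * cij - lr * cj for yj, cij, cj in zip(y, c_i, c)]


def scaffold_step(
    y: List[float], grad: List[float], c_i: List[float], c: List[float], lr: float
) -> List[float]:
    """Full SCAFFOLD local step y <- y - lr * (g - c_i + c)."""
    return [yj - lr * (gj - cij + cj) for yj, gj, cij, cj in zip(y, grad, c_i, c)]


# Example values (made up for illustration).
y = [1.0, -2.0, 0.5]
grad = [0.3, -0.1, 0.8]
c_i = [0.05, 0.2, -0.1]
c = [0.1, -0.05, 0.0]
lr = 0.1

two_stage = control_variate_adjustment(sgd_step(y, grad, lr), c_i, c, lr)
one_stage = scaffold_step(y, grad, c_i, c, lr)
print(all(abs(a - b) < 1e-12 for a, b in zip(two_stage, one_stage)))  # True
```

With momentum, weight decay, or an adaptive optimizer such as Adam, the optimizer's update is no longer $\eta_l g_i$. In that case, the two-stage form is not exactly the SCAFFOLD step.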
- on_train_batch_start(trainer, pl_module, batch, batch_idx)
Store the current learning rate at the start of each training batch.
- Parameters:
  - trainer (Trainer) – The trainer.
  - pl_module (LightningModule) – The model.
  - batch (Any) – The batch.
  - batch_idx (int) – The batch index.
- Return type:
None
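The two hooks work together: one records $\eta_l$, and the other uses it for the correction. The sketch below is a hypothetical, torch-free mirror of that interplay, not the p2pfl implementation. `ToyOptimizer` imitates the `param_groups` layout of a torch optimizer (a list of dicts with `"params"` and `"lr"` keys) but runs plain SGD over lists of floats. `ScaffoldHooksSketch` records the learning rate at batch start and applies the correction after the step. It assumes the learning rate can be read from the first parameter group. It also calls the hooks in a simplified order (record, step, correct) rather than reproducing Lightning's exact hook sequence.

```python
from typing import Any, Dict, List, Optional


class ToyOptimizer:
    """Hypothetical stand-in for a torch optimizer: plain SGD over a list of floats."""

    def __init__(self, params: List[float], lr: float) -> None:
        # Mirrors the torch layout: a list of parameter-group dicts.
        self.param_groups: List[Dict[str, Any]] = [{"params": params, "lr": lr}]

    def step(self, grads: List[float]) -> None:
        group = self.param_groups[0]
        for j, g in enumerate(grads):
            group["params"][j] -= group["lr"] * g  # y <- y - lr * g


class ScaffoldHooksSketch:
    """Hypothetical mirror of the two callback hooks (not the p2pfl code)."""

    def __init__(self, c_i: List[float], c: List[float]) -> None:
        self.c_i = c_i  # local control variate
        self.c = c  # global control variate
        self.lr: Optional[float] = None

    def on_train_batch_start(self, optimizer: ToyOptimizer) -> None:
        # Record the learning rate currently in effect for this batch.
        self.lr = optimizer.param_groups[0]["lr"]

    def on_before_zero_grad(self, optimizer: ToyOptimizer) -> None:
        # Apply y <- y + lr * c_i - lr * c after the optimizer step.
        if self.lr is None:
            raise RuntimeError("learning rate has not been recorded yet")
        params = optimizer.param_groups[0]["params"]
        for j in range(len(params)):
            params[j] += self.lr * self.c_i[j] - self.lr * self.c[j]


# Usage with made-up values.
opt = ToyOptimizer([1.0, 2.0], lr=0.1)
hooks = ScaffoldHooksSketch(c_i=[0.5, 0.0], c=[0.0, 0.5])
hooks.on_train_batch_start(opt)
opt.step([1.0, 1.0])
hooks.on_before_zero_grad(opt)
print(opt.param_groups[0]["params"])  # approximately [0.95, 1.85]
```

Reading the learning rate on every batch, rather than only once, is one way to keep the correction consistent with a scheduler that changes $\eta_l$ during training.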