p2pfl.learning.frameworks.pytorch.callbacks.scaffold_callback module

Callback for SCAFFOLD operations (PyTorch Lightning).

class p2pfl.learning.frameworks.pytorch.callbacks.scaffold_callback.SCAFFOLDCallback

Bases: Callback, P2PFLCallback

Callback for SCAFFOLD operations, for use with PyTorch Lightning.

At the beginning of training, the callback stores the global model and the initial learning rate. Then, after each optimization step, it applies the SCAFFOLD control variate adjustment to the local model parameters.
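The flow described above can be sketched without PyTorch or Lightning. The code below is a hypothetical, framework-free stand-in (`ScaffoldSketch`, `local_round` and `after_optimizer_step` are illustrative names, not p2pfl API): parameters are plain lists of floats, the control variates `c_i` and `c` are assumed to be given, and a plain SGD step stands in for the optimizer. In the real callback the adjustment runs in `on_before_zero_grad`; here it is called explicitly after each step to keep the ordering obvious.

```python
import copy


class ScaffoldSketch:
    """Hypothetical mirror of the callback's hooks on lists of floats."""

    def __init__(self, c_i, c):
        self.c_i = list(c_i)      # client control variate (assumed given)
        self.c = list(c)          # global control variate (assumed given)
        self.global_params = []   # snapshot of the global model x
        self.lr = None            # local learning rate eta_l
        self.steps = 0            # number of local steps K

    def on_train_start(self, params, lr):
        # Snapshot the global model and the initial learning rate.
        self.global_params = copy.deepcopy(params)
        self.lr = lr
        self.steps = 0

    def on_train_batch_start(self, lr):
        # Re-read the learning rate in case it changed.
        self.lr = lr

    def after_optimizer_step(self, params):
        # y_i <- y_i + eta_l * c_i - eta_l * c, applied in place.
        for k in range(len(params)):
            params[k] += self.lr * self.c_i[k] - self.lr * self.c[k]
        self.steps += 1


def local_round(params, grad_fn, lr, num_steps, cb):
    """Run `num_steps` plain SGD steps with the SCAFFOLD adjustment."""
    cb.on_train_start(params, lr)
    for _ in range(num_steps):
        cb.on_train_batch_start(lr)
        grads = grad_fn(params)
        for k in range(len(params)):
            params[k] -= lr * grads[k]  # optimizer step: y_i <- y_i - eta_l * g_i
        cb.after_optimizer_step(params)
    return params


# Usage on f(y) = 0.5 * y**2 (gradient g = y), with c_i = 0.5 and c = 0.0.
cb = ScaffoldSketch(c_i=[0.5], c=[0.0])
params = local_round([1.0], lambda p: [p[0]], lr=0.1, num_steps=1, cb=cb)
```

Because the snapshot is taken with `copy.deepcopy`, `cb.global_params` still holds the starting model after the in-place local updates.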

static get_name()

Return the name of the callback.

Return type:

str

on_before_zero_grad(trainer, pl_module, optimizer)

Modify model by applying control variate adjustment.

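Because the optimizer has already taken the local step y_i ← y_i - eta_l * g_i, the control variate adjustment is applied on top of it: y_i ← y_i + eta_l * c_i - eta_l * c, where eta_l is the local learning rate, c_i the client control variate and c the global (server) control variate.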

Parameters:
  • trainer (Trainer) – The trainer.

  • pl_module (LightningModule) – The model.

  • optimizer (Optimizer) – The optimizer.

Return type:

None
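For plain SGD, applying the adjustment after the optimizer step gives the same result as the corrected SCAFFOLD step y_i ← y_i - eta_l * (g_i - c_i + c). The sketch below checks this numerically on lists of floats; both helper names are hypothetical. The equivalence assumes a vanilla SGD update: with momentum or adaptive optimizers such as Adam, adding the correction after the step is not the same as feeding the corrected gradient to the optimizer.

```python
def sgd_then_correct(y, g, c_i, c, lr):
    """Two-stage update: optimizer step first, then the control variate adjustment."""
    stepped = [yk - lr * gk for yk, gk in zip(y, g)]  # y_i - eta_l * g_i
    return [sk + lr * cik - lr * ck for sk, cik, ck in zip(stepped, c_i, c)]


def corrected_gradient_step(y, g, c_i, c, lr):
    """Single SCAFFOLD step: y_i - eta_l * (g_i - c_i + c)."""
    return [yk - lr * (gk - cik + ck) for yk, gk, cik, ck in zip(y, g, c_i, c)]


y, g = [1.0, -2.0], [0.3, 0.4]
c_i, c = [0.1, -0.2], [0.05, 0.0]
a = sgd_then_correct(y, g, c_i, c, lr=0.1)
b = corrected_gradient_step(y, g, c_i, c, lr=0.1)
```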

on_train_batch_start(trainer, pl_module, batch, batch_idx)

Store the current learning rate.

Parameters:
  • trainer (Trainer) – The trainer.

  • pl_module (LightningModule) – The model.

  • batch (Any) – The batch.

  • batch_idx (int) – The batch index.

Return type:

None
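Reading the rate on every batch, rather than only once at the start, keeps the stored value in sync with any learning-rate scheduler. A minimal sketch: PyTorch optimizers expose `param_groups`, a list of dicts with an `"lr"` key, which is mimicked here with plain dicts. `read_lr`, `step_decay` and the halving schedule are hypothetical, and reading only the first parameter group is an assumption.

```python
def read_lr(param_groups):
    """Return the learning rate of the first parameter group.

    Mimics torch.optim.Optimizer.param_groups (a list of dicts with an "lr"
    key); looking only at group 0 is an assumption of this sketch.
    """
    return param_groups[0]["lr"]


def step_decay(param_groups, step, every=2, gamma=0.5):
    """Hypothetical scheduler: multiply each group's lr by gamma every `every` steps."""
    if step > 0 and step % every == 0:
        for group in param_groups:
            group["lr"] *= gamma


param_groups = [{"lr": 0.1}]
seen = []
for batch_idx in range(5):
    step_decay(param_groups, batch_idx)  # the scheduler may change the rate
    seen.append(read_lr(param_groups))   # value a per-batch hook would store
```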

on_train_end(trainer, pl_module)

Restore the global model.

Parameters:
  • trainer (Trainer) – The trainer.

  • pl_module (LightningModule) – The model.

Return type:

None
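This page only states that the global model is restored. For context, in the SCAFFOLD algorithm (Karimireddy et al., 2020, "Option II") the end of local training is where the client refreshes its control variate from the stored global model x, the trained local model y_i, the number of local steps K and eta_l: c_i+ ← c_i - c + (x - y_i) / (K * eta_l), and reports Δy_i = y_i - x and Δc_i = c_i+ - c_i. The sketch below implements that formula on lists of floats; `scaffold_client_update` is a hypothetical helper, and whether this callback performs exactly these steps is not confirmed here.

```python
def scaffold_client_update(x, y, c_i, c, num_steps, lr):
    """SCAFFOLD Option II client update (hypothetical helper).

    x: global model at round start, y: local model after training,
    num_steps: local steps K, lr: local learning rate eta_l.
    Returns (new_c_i, delta_y, delta_c).
    """
    if num_steps <= 0 or lr <= 0:
        raise ValueError("need at least one local step and a positive learning rate")
    # c_i+ = c_i - c + (x - y_i) / (K * eta_l)
    new_c_i = [
        cik - ck + (xk - yk) / (num_steps * lr)
        for cik, ck, xk, yk in zip(c_i, c, x, y)
    ]
    delta_y = [yk - xk for yk, xk in zip(y, x)]          # model update y_i - x
    delta_c = [n - o for n, o in zip(new_c_i, c_i)]      # control variate update
    return new_c_i, delta_y, delta_c


new_c_i, dy, dc = scaffold_client_update(
    x=[1.0], y=[0.5], c_i=[0.2], c=[0.1], num_steps=2, lr=0.25
)
```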

on_train_start(trainer, pl_module)

Store the global model and the initial learning rate.

Parameters:
  • trainer (Trainer) – The trainer.

  • pl_module (LightningModule) – The model.

Return type:

None
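Storing the global model means taking a copy, not a reference: optimizer steps update parameters in place, so a stored reference would silently follow the local updates and the global model would be lost. The sketch uses nested lists as stand-ins for parameter tensors and illustrative values; with PyTorch, `copy.deepcopy` or `[p.detach().clone() for p in pl_module.parameters()]` would play the same role.

```python
import copy

params = [[1.0, 2.0], [3.0]]      # stand-in for per-layer parameter tensors
aliased = params                  # reference only: follows later in-place updates
snapshot = copy.deepcopy(params)  # independent copy of the global model
initial_lr = 0.1                  # illustrative initial learning rate

# A local training step updates parameters in place, as optimizer.step() does.
params[0][0] -= initial_lr * 5.0
```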