p2pfl.learning.frameworks.tensorflow.callbacks.scaffold_callback module
Callback for SCAFFOLD operations (Keras).
- class p2pfl.learning.frameworks.tensorflow.callbacks.scaffold_callback.SCAFFOLDCallback
Bases: Callback, P2PFLCallback
Callback for SCAFFOLD operations to use with TensorFlow Keras.
At the beginning of training, the callback initializes the control variates and replaces the optimizer with a custom one that applies the control-variate adjustments to each gradient step. After training, it updates the local control variate (c_i) and computes the deltas.
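For reference, the post-training step can be sketched with the "option II" update from the original SCAFFOLD paper (Karimireddy et al.): c_i+ = c_i - c + (x - y) / (K * lr), with deltas (y - x) and (c_i+ - c_i). This is a minimal sketch under stated assumptions, not p2pfl's code: weights are flattened into plain Python lists of floats, and the function name `update_local_control_variate` is hypothetical. Whether p2pfl uses exactly this variant is not confirmed by this page.

```python
from typing import List, Tuple


def update_local_control_variate(
    x: List[float],    # global weights received at the start of the round
    y: List[float],    # local weights after num_steps corrected steps
    c: List[float],    # global (server) control variate
    c_i: List[float],  # local control variate before this update
    num_steps: int,    # K: number of local optimizer steps taken
    lr: float,         # local learning rate used for those steps
) -> Tuple[List[float], List[float], List[float]]:
    """Hypothetical helper: return (c_i_new, delta_y, delta_c), SCAFFOLD option II."""
    if num_steps <= 0 or lr <= 0:
        raise ValueError("num_steps and lr must be positive")
    scale = 1.0 / (num_steps * lr)
    # c_i+ = c_i - c + (x - y) / (K * lr)
    c_i_new = [ci - cg + (xw - yw) * scale for ci, cg, xw, yw in zip(c_i, c, x, y)]
    # Model delta: y - x
    delta_y = [yw - xw for xw, yw in zip(x, y)]
    # Control-variate delta: c_i+ - c_i
    delta_c = [new - old for new, old in zip(c_i_new, c_i)]
    return c_i_new, delta_y, delta_c


c_i_new, delta_y, delta_c = update_local_control_variate(
    x=[1.0, 2.0], y=[0.5, 1.0], c=[0.0, 0.5], c_i=[0.5, 0.5], num_steps=4, lr=0.25
)
print(c_i_new, delta_y, delta_c)  # [1.0, 1.0] [-0.5, -1.0] [0.5, 0.5]
```

The step count K appearing in the denominator is why the callback needs to track how many local steps were actually taken.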
- on_train_batch_end(batch, logs=None)
Increment the local step counter after each batch.
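- Parameters:
batch (Any) – The batch (Keras passes the integer index of the batch within the current epoch).
logs (Optional[Dict[str, Any]]) – The logs (Keras passes the aggregated metric results up to this batch).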
- Return type:
None
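The step counter can be sketched as follows. To stay dependency-free, this hypothetical `StepCountingCallback` only mirrors the Keras hook names; in real use it would subclass `keras.callbacks.Callback` so that `model.fit()` invokes the hooks. The attribute name `local_steps` is an assumption, not p2pfl's.

```python
from typing import Any, Dict, Optional


class StepCountingCallback:
    """Hypothetical stand-in for a keras.callbacks.Callback subclass."""

    def __init__(self) -> None:
        self.local_steps = 0

    def on_train_begin(self, logs: Optional[Dict[str, Any]] = None) -> None:
        # Reset the counter at the start of each local training round.
        self.local_steps = 0

    def on_train_batch_end(self, batch: Any, logs: Optional[Dict[str, Any]] = None) -> None:
        # One optimizer step was applied for this batch.
        self.local_steps += 1


# Simulate what model.fit() would do: 3 epochs of 4 batches each.
cb = StepCountingCallback()
cb.on_train_begin()
for epoch in range(3):
    for batch_index in range(4):  # Keras restarts the batch index every epoch
        cb.on_train_batch_end(batch_index, logs={"loss": 0.0})
print(cb.local_steps)  # 12
```

Incrementing a counter, rather than reading `batch`, matters because the batch index restarts at zero each epoch, while the total number of steps across all epochs is what the control-variate update needs.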
- on_train_begin(logs=None)
Initialize the control variates and replace the optimizer with a custom one.
- Return type:
None
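What the custom optimizer does on each step can be sketched with the SCAFFOLD corrected update y <- y - lr * (g - c_i + c). This is a minimal plain-SGD sketch, assuming flattened weights as lists of floats; the class `ControlVariateSGD` and its `step` method are hypothetical and are not the Keras `Optimizer.apply_gradients` API or p2pfl's actual optimizer. Starting c_i at zeros when none is provided is also an assumption here.

```python
from typing import List, Optional


class ControlVariateSGD:
    """Hypothetical SGD variant applying the SCAFFOLD gradient correction."""

    def __init__(self, learning_rate: float, c: List[float],
                 c_i: Optional[List[float]] = None) -> None:
        self.learning_rate = learning_rate
        self.c = list(c)  # global (server) control variate
        # Assumption: a client without a local control variate starts from zeros.
        self.c_i = list(c_i) if c_i is not None else [0.0] * len(c)

    def step(self, weights: List[float], grads: List[float]) -> List[float]:
        if not (len(weights) == len(grads) == len(self.c) == len(self.c_i)):
            raise ValueError("weights, grads and control variates must align")
        # Corrected step: y <- y - lr * (g - c_i + c)
        return [
            w - self.learning_rate * (g - ci + cg)
            for w, g, ci, cg in zip(weights, grads, self.c_i, self.c)
        ]


opt = ControlVariateSGD(learning_rate=0.5, c=[0.25, 0.0], c_i=[0.5, 0.5])
print(opt.step(weights=[2.0, 2.0], grads=[1.0, 1.0]))  # [1.625, 1.75]
```

The correction term (c - c_i) nudges each local step toward the global update direction, which is how SCAFFOLD counteracts client drift on heterogeneous data.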