p2pfl.learning.frameworks.flax.flax_learner module

Flax Learner for P2PFL.

class p2pfl.learning.frameworks.flax.flax_learner.FlaxLearner(model=None, data=None, aggregator=None)

Bases: Learner

Learner for Flax models in P2PFL.

Parameters:
  • model (P2PFLModel | None) – The FlaxModel instance.

  • data (P2PFLDataset | None) – The P2PFLDataset instance.

  • self_addr – The address of this node.

  • aggregator (Aggregator | None) – The Aggregator instance.

evaluate()

Evaluate the Flax model.

Return type:

dict[str, float]

fit()

Fit the model.

Return type:

P2PFLModel

property flax_model: FlaxModel

Retrieve the Flax model.

get_framework()

Retrieve the learner name.

Return type:

str

Returns:

The name of the learner class.

interrupt_fit()

Interrupt the fit process.

Return type:

None
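The reference does not say how interrupt_fit() stops a running fit(). One common way to make a training loop interruptible is a cooperative stop flag that the loop checks between epochs. The sketch below illustrates that general pattern only; InterruptibleLoop, on_epoch_end, and the epoch counter are hypothetical names and are not part of p2pfl.

```python
import threading


class InterruptibleLoop:
    """Hypothetical stand-in for a learner; not part of p2pfl."""

    def __init__(self, epochs):
        self.epochs = epochs
        self.completed = 0
        self._stop = threading.Event()  # set by interrupt_fit()

    def fit(self, on_epoch_end=None):
        """Run up to `epochs` epochs, stopping early if interrupted."""
        self._stop.clear()
        self.completed = 0
        for _ in range(self.epochs):
            if self._stop.is_set():
                break  # leave the loop at the next epoch boundary
            # ... one epoch of training would run here ...
            self.completed += 1
            if on_epoch_end is not None:
                on_epoch_end(self.completed)
        return self.completed

    def interrupt_fit(self):
        """Ask the running fit() to stop; returns immediately."""
        self._stop.set()


loop = InterruptibleLoop(epochs=5)


def stop_after_two(epoch):
    # Simulate an external interrupt arriving once two epochs are done.
    if epoch == 2:
        loop.interrupt_fit()


completed = loop.fit(on_epoch_end=stop_after_two)
```

With this design the interrupt takes effect at the next epoch boundary rather than in the middle of a step, and the flag is safe to set from another thread.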

train_step(state, x, y)

Perform a single training step.

Parameters:
  • state (TrainState)

  • x (ndarray)

  • y (ndarray)

Return type:

tuple[TrainState, float, float]