p2pfl.learning.aggregators.aggregator module

Abstract aggregator.

class p2pfl.learning.aggregators.aggregator.Aggregator(disable_partial_aggregation=False)[source]

Bases: NodeComponent

Abstract base class for all aggregators.

Important

We do not recommend inheriting directly from this class. Instead, inherit from:
  • WeightAggregator: for neural network aggregation (FedAvg, etc.)

  • TreeAggregator: for tree ensemble aggregation (FedXgbBagging, etc.)

Parameters:

disable_partial_aggregation (bool) – Whether to disable partial aggregation.

SUPPORTS_PARTIAL_AGGREGATION

Whether partial aggregation is supported.

SUPPORTS_PARTIAL_AGGREGATION: bool = False
add_model(model)[source]

Add a model. The first model added starts the aggregation run (and its timeout).

Parameters:

model (P2PFLModel) – Model to add.

Return type:

list[str]

Returns:

List of contributors.

addr: str
aggregate(models)[source]

Validate and aggregate the models.

Automatically calls validate_models() before delegating to _aggregate().

Parameters:

models (list[P2PFLModel]) – List of models to aggregate.

Return type:

P2PFLModel

Returns:

The aggregated model.
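The template pattern described above (aggregate() validating first, then delegating to _aggregate()) can be sketched with a minimal, self-contained stand-in; this is an illustration of the pattern, not the real p2pfl base class, and MeanAggregator is a hypothetical subclass:

```python
class SketchAggregator:
    """Simplified illustration of the template pattern: the public
    aggregate() validates, then delegates to the subclass hook."""

    def aggregate(self, models):
        self.validate_models(models)    # raises on incompatible input
        return self._aggregate(models)  # subclass-provided combination

    def validate_models(self, models):
        if not models:
            raise ValueError("no models to aggregate")

    def _aggregate(self, models):
        raise NotImplementedError


class MeanAggregator(SketchAggregator):
    """Hypothetical subclass: only _aggregate() needs to be written."""

    def _aggregate(self, models):
        return sum(models) / len(models)


print(MeanAggregator().aggregate([1.0, 2.0, 3.0]))  # 2.0
```

Subclasses never call validate_models() themselves; overriding only _aggregate() keeps validation guaranteed on every public entry point.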

clear()[source]

Clear the aggregation (remove trainset and release locks).

Return type:

None

get_aggregated_models()[source]

Get the list of aggregated models.

Return type:

list[str]

Returns:

Names of the nodes that collaborated to produce the model.

get_missing_models()[source]

Obtain missing models for the aggregation.

Return type:

set

Returns:

A set of missing models.

get_model(except_nodes)[source]

Get the appropriate aggregation result, depending on whether the aggregator supports partial aggregation.

Parameters:

except_nodes – List of nodes to exclude from the aggregation.

Return type:

P2PFLModel

get_required_callbacks()[source]

Get the required callbacks for the aggregation.

Return type:

list[str]

Returns:

List of required callbacks.

partial_aggregation: bool
set_nodes_to_aggregate(nodes_to_aggregate)[source]

Set the list of nodes to aggregate. Be careful: setting new nodes discards the current aggregation.

Parameters:

nodes_to_aggregate (list[str]) – List of nodes to aggregate. Empty for no aggregation.

Raises:

RuntimeError – If the aggregation is running.

Return type:

None

validate_models(models)[source]

Validate that all models are compatible with this aggregator.

Parameters:

models (list[P2PFLModel]) – List of models to validate.

Raises:

IncompatibleModelError – If any model is incompatible with this aggregator.

Return type:

None

wait_and_get_aggregation(timeout=300)[source]

Wait for aggregation to finish.

Parameters:

timeout (int) – Timeout in seconds.

Return type:

P2PFLModel

Returns:

Aggregated model.

Raises:

Exception – If, while waiting for an aggregated model, several models were received instead.
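The lifecycle formed by set_nodes_to_aggregate(), add_model(), get_missing_models(), and wait_and_get_aggregation() can be sketched with a small, self-contained stand-in. MiniAggregator below is not the real p2pfl class: it uses scalar "models", an unweighted mean, and a threading.Event in place of the real locking, purely to show the call order:

```python
import threading


class MiniAggregator:
    """Simplified stand-in mimicking the Aggregator lifecycle."""

    def __init__(self):
        self._models = {}   # contributor name -> model payload
        self._expected = [] # trainset: nodes whose models we wait for
        self._done = threading.Event()

    def set_nodes_to_aggregate(self, nodes):
        # Setting new nodes discards any in-progress aggregation.
        self._expected = list(nodes)
        self._models.clear()
        self._done.clear()

    def add_model(self, name, model):
        self._models[name] = model
        if set(self._expected) <= set(self._models):
            self._done.set()  # all expected models received
        return list(self._models)  # contributors so far

    def get_missing_models(self):
        return set(self._expected) - set(self._models)

    def wait_and_get_aggregation(self, timeout=300):
        if not self._done.wait(timeout):
            raise TimeoutError("aggregation timed out")
        # Plain mean of scalar payloads stands in for real aggregation.
        return sum(self._models.values()) / len(self._models)


agg = MiniAggregator()
agg.set_nodes_to_aggregate(["node-a", "node-b"])
agg.add_model("node-a", 2.0)
print(agg.get_missing_models())                  # {'node-b'}
agg.add_model("node-b", 4.0)
print(agg.wait_and_get_aggregation(timeout=1))   # 3.0
```

The real class follows the same shape: configure the trainset, feed models as they arrive, and block on wait_and_get_aggregation() until every expected contributor is in or the timeout expires.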

exception p2pfl.learning.aggregators.aggregator.IncompatibleModelError[source]

Bases: Exception

Exception raised when a model type is incompatible with the aggregator.

exception p2pfl.learning.aggregators.aggregator.NoModelsToAggregateError[source]

Bases: Exception

Exception raised when there are no models to aggregate.

class p2pfl.learning.aggregators.aggregator.TreeAggregator(disable_partial_aggregation=False)[source]

Bases: Aggregator

Base class for aggregators that work with tree ensemble models.

Inherit from this class for aggregators that:
  • Combine trees via bagging, boosting, or cycling

  • Work with XGBoost models

  • Expect serialized tree structures

The validation is automatic via the template pattern: aggregate() calls validate_models() before delegating to _aggregate().

Example

>>> class MyTreeAggregator(TreeAggregator):
...     def _aggregate(self, models):
...         # Validation already done by aggregate()
...         # ... your tree combination logic
...         pass
Parameters:

disable_partial_aggregation (bool)

class p2pfl.learning.aggregators.aggregator.WeightAggregator(disable_partial_aggregation=False)[source]

Bases: Aggregator

Base class for aggregators that work with neural network models.

Inherit from this class for aggregators that:
  • Average or combine weight tensors

  • Work with PyTorch, TensorFlow, Flax models

  • Expect list[np.ndarray] of float32/float64 parameter arrays

The validation is automatic via the template pattern: aggregate() calls validate_models() before delegating to _aggregate().

Example

>>> class MyAggregator(WeightAggregator):
...     def _aggregate(self, models):
...         # Validation already done by aggregate()
...         # ... your averaging logic
...         pass
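The averaging logic that a WeightAggregator subclass would place in _aggregate() typically reduces to a layer-wise mean over each model's list[np.ndarray] parameters. The helper below is a hypothetical, unweighted sketch (the real aggregators may weight by sample count); the layer shapes are illustrative:

```python
import numpy as np


def average_weights(models: list[list[np.ndarray]]) -> list[np.ndarray]:
    """Unweighted element-wise mean across models, one entry per layer."""
    # zip(*models) groups the i-th layer of every model together.
    return [np.mean(layers, axis=0) for layers in zip(*models)]


model_a = [np.array([1.0, 2.0]), np.array([[1.0]])]
model_b = [np.array([3.0, 4.0]), np.array([[3.0]])]
avg = average_weights([model_a, model_b])
print(avg[0])  # [2. 3.]
```

A weighted variant (FedAvg proper) would scale each model's layers by its fraction of the total training samples before summing.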
Parameters:

disable_partial_aggregation (bool)