p2pfl.learning.aggregators.aggregator module
Abstract aggregator.
- class p2pfl.learning.aggregators.aggregator.Aggregator(disable_partial_aggregation=False)
Bases: NodeComponent
Abstract base class for all aggregators.
Important
We do not recommend inheriting directly from this class. Instead, inherit from:
- WeightAggregator: for neural network aggregation (FedAvg, etc.)
- TreeAggregator: for tree ensemble aggregation (FedXgbBagging, etc.)
- Parameters:
disable_partial_aggregation (bool) – Whether to disable partial aggregation.
- SUPPORTS_PARTIAL_AGGREGATION: bool = False
Whether partial aggregation is supported.
- add_model(model)
Add a model. The first model to be added starts the run method (timeout).
- Parameters:
model (P2PFLModel) – Model to add.
- Return type:
list[str]
- Returns:
List of contributors.
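The contributor bookkeeping implied by add_model() can be pictured with the minimal sketch below. It is an illustration under stated assumptions, not the p2pfl code: ToyModel, ToyAggregatorState, the started_at timestamp standing in for "the first model starts the run method (timeout)", and the rule that unexpected or duplicate contributions are ignored are all hypothetical.

```python
from __future__ import annotations

import time
from dataclasses import dataclass


@dataclass
class ToyModel:
    # Hypothetical stand-in for P2PFLModel: only the contributing node names.
    contributors: list[str]


class ToyAggregatorState:
    """Hypothetical sketch of add_model() bookkeeping (not the p2pfl code)."""

    def __init__(self, nodes_to_aggregate: list[str]) -> None:
        self.nodes_to_aggregate = set(nodes_to_aggregate)
        self.models: list[ToyModel] = []
        self.started_at: float | None = None

    def add_model(self, model: ToyModel) -> list[str]:
        # Stand-in for "the first model starts the run method (timeout)".
        if self.started_at is None:
            self.started_at = time.monotonic()
        seen = {c for m in self.models for c in m.contributors}
        new = set(model.contributors)
        # Assumption: accept only expected nodes that have not contributed yet.
        if new <= self.nodes_to_aggregate and not (new & seen):
            self.models.append(model)
            seen |= new
        # Return every contributor collected so far.
        return sorted(seen)
```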
- addr: str
- aggregate(models)
Validate and aggregate the models.
Automatically calls validate_models() before delegating to _aggregate().
- Parameters:
models (list[P2PFLModel]) – List of models to aggregate.
- Return type:
P2PFLModel
- Returns:
The aggregated model.
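The validate-then-delegate flow of aggregate() is an instance of the template method pattern. The sketch below reproduces that ordering with hypothetical names (SketchAggregator, MeanOfFloats) and local stand-ins for the two exceptions documented in this module; treating plain floats as "models" and raising NoModelsToAggregateError on an empty list are assumptions for illustration.

```python
from __future__ import annotations

from abc import ABC, abstractmethod


class IncompatibleModelError(Exception):
    """Local stand-in mirroring the exception documented in this module."""


class NoModelsToAggregateError(Exception):
    """Local stand-in mirroring the exception documented in this module."""


class SketchAggregator(ABC):
    """Hypothetical base class: aggregate() is fixed, _aggregate() is the hook."""

    def aggregate(self, models: list) -> object:
        # Assumption: an empty list is rejected before validation.
        if not models:
            raise NoModelsToAggregateError("no models to aggregate")
        self.validate_models(models)  # fixed step 1: validation
        return self._aggregate(models)  # fixed step 2: subclass logic

    def validate_models(self, models: list) -> None:
        # Default: accept everything; subclasses tighten the check.
        pass

    @abstractmethod
    def _aggregate(self, models: list) -> object:
        """Subclasses implement only the combination logic."""


class MeanOfFloats(SketchAggregator):
    """Hypothetical subclass whose 'models' are plain floats."""

    def validate_models(self, models: list) -> None:
        for m in models:
            if not isinstance(m, float):
                raise IncompatibleModelError(f"unsupported model: {m!r}")

    def _aggregate(self, models: list) -> float:
        # Validation has already run by the time this is called.
        return sum(models) / len(models)
```

Subclasses never override aggregate() itself, so validation cannot be skipped by accident.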
- get_aggregated_models()
Get the list of aggregated models.
- Return type:
list[str]
- Returns:
Names of the nodes that collaborated to obtain the model.
- get_missing_models()
Obtain the missing models for the aggregation.
- Return type:
set
- Returns:
A set of missing models.
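One plausible reading of get_missing_models() is a set difference between the nodes scheduled for aggregation and the nodes whose models have already arrived. The helper below is hypothetical and assumes both are available as plain lists of node names.

```python
from __future__ import annotations


def missing_models(nodes_to_aggregate: list[str], contributors: list[str]) -> set[str]:
    """Hypothetical helper: nodes whose models have not been received yet."""
    return set(nodes_to_aggregate) - set(contributors)
```

For example, with nodes a, b and c scheduled and only b received, the result is the set {"a", "c"}.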
- get_model(except_nodes)
Get the corresponding aggregation, depending on whether the aggregator supports partial aggregation.
- Parameters:
except_nodes – List of nodes to exclude from the aggregation.
- Return type:
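The dispatch in get_model() can be sketched as follows. This is a hypothetical illustration, not the p2pfl code: models are (contributors, num_samples, value) tuples, use_partial stands for "SUPPORTS_PARTIAL_AGGREGATION is true and partial aggregation is not disabled", and the rule that a partial aggregation skips any model already containing an excluded node is an assumption.

```python
from __future__ import annotations


def get_model_sketch(
    models: list[tuple[list[str], int, float]],
    except_nodes: list[str],
    use_partial: bool,
) -> tuple[list[str], float]:
    """Hypothetical get_model(): partial vs. full aggregation.

    Each model is a (contributors, num_samples, value) tuple.
    """
    if use_partial:
        # Assumption: skip any model that already includes an excluded node.
        excluded = set(except_nodes)
        selected = [m for m in models if not excluded & set(m[0])]
    else:
        # Without partial support, every collected model is aggregated.
        selected = models
    if not selected:
        raise ValueError("no models left to aggregate")
    total = sum(n for _, n, _ in selected)
    value = sum(n * v for _, n, v in selected) / total  # sample-weighted mean
    contributors = sorted(c for cs, _, _ in selected for c in cs)
    return contributors, value
```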
- get_required_callbacks()
Get the required callbacks for the aggregation.
- Return type:
list[str]
- Returns:
List of required callbacks.
- partial_aggregation: bool
- set_nodes_to_aggregate(nodes_to_aggregate)
Set the list of node names to aggregate. Be careful: setting new nodes discards the current aggregation.
- Parameters:
nodes_to_aggregate (list[str]) – List of nodes to aggregate. Empty for no aggregation.
- Raises:
RuntimeError – If the aggregation is running.
- Return type:
None
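The two documented behaviors of set_nodes_to_aggregate() (reject changes while an aggregation is running, and discard any partial progress otherwise) can be sketched as below. NodeSelection, its running flag and its models list are hypothetical; the lock is an assumption about how concurrent callers might be serialized.

```python
from __future__ import annotations

import threading


class NodeSelection:
    """Hypothetical sketch of set_nodes_to_aggregate() semantics."""

    def __init__(self) -> None:
        self._lock = threading.Lock()
        self.running = False  # assumed to be set by whoever starts a round
        self.nodes_to_aggregate: list[str] = []
        self.models: list[object] = []

    def set_nodes_to_aggregate(self, nodes_to_aggregate: list[str]) -> None:
        with self._lock:
            if self.running:
                # Documented: changing nodes mid-aggregation is an error.
                raise RuntimeError("cannot change nodes while aggregation is running")
            self.nodes_to_aggregate = list(nodes_to_aggregate)
            # Documented caveat: any collected models are discarded.
            self.models = []
```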
- validate_models(models)
Validate that all models are compatible with this aggregator.
- Parameters:
models (list[P2PFLModel]) – List of models to validate.
- Raises:
IncompatibleModelError – If any model is incompatible with this aggregator.
- Return type:
None
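What "compatible" means is up to each aggregator. As a hypothetical example of a check a weight-based aggregator might perform, the function below accepts only models given as lists of float32/float64 NumPy arrays whose shapes match the first model. The shape rule, the function name, and the raw-list model representation are assumptions; IncompatibleModelError is a local stand-in.

```python
from __future__ import annotations

import numpy as np


class IncompatibleModelError(Exception):
    """Local stand-in for the exception documented in this module."""


def validate_weight_models(models: list[list[np.ndarray]]) -> None:
    """Hypothetical validation for list[np.ndarray] parameter layouts."""
    reference = None
    for i, params in enumerate(models):
        if not all(isinstance(p, np.ndarray) for p in params):
            raise IncompatibleModelError(f"model {i}: parameters must be numpy arrays")
        if any(p.dtype not in (np.float32, np.float64) for p in params):
            raise IncompatibleModelError(f"model {i}: expected float32/float64 arrays")
        shapes = [p.shape for p in params]
        if reference is None:
            reference = shapes  # first model defines the expected layout
        elif shapes != reference:
            raise IncompatibleModelError(f"model {i}: shapes differ from model 0")
```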
- exception p2pfl.learning.aggregators.aggregator.IncompatibleModelError
Bases: Exception
Exception raised when a model type is incompatible with the aggregator.
- exception p2pfl.learning.aggregators.aggregator.NoModelsToAggregateError
Bases: Exception
Exception raised when there are no models to aggregate.
- class p2pfl.learning.aggregators.aggregator.TreeAggregator(disable_partial_aggregation=False)
Bases: Aggregator
Base class for aggregators that work with tree ensemble models.
Inherit from this class for aggregators that:
- Combine trees via bagging, boosting, or cycling
- Work with XGBoost models
- Expect serialized tree structures
Validation is automatic via the template method pattern: aggregate() calls validate_models() before delegating to _aggregate().
Example
>>> from p2pfl.learning.aggregators.aggregator import TreeAggregator
>>> class MyTreeAggregator(TreeAggregator):
...     def _aggregate(self, models):
...         # Validation already done by aggregate()
...         # ... your tree combination logic ...
...         pass
- Parameters:
disable_partial_aggregation (bool)
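As a toy illustration of the "combine trees via bagging" case, the sketch below treats each model as a list of already-deserialized trees (plain dicts) and builds the global ensemble by concatenating every client's trees in arrival order, dropping exact duplicates. The dict tree representation, the deduplication rule, and the function name bag_trees are hypothetical; this is not the FedXgbBagging implementation.

```python
from __future__ import annotations

import json


def bag_trees(models: list[list[dict]]) -> list[dict]:
    """Hypothetical bagging: union of all clients' trees, in arrival order."""
    ensemble: list[dict] = []
    seen: set[str] = set()
    for trees in models:
        for tree in trees:
            # Canonical JSON string so identical trees compare equal.
            key = json.dumps(tree, sort_keys=True)
            if key not in seen:
                seen.add(key)
                ensemble.append(tree)
    return ensemble
```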
- class p2pfl.learning.aggregators.aggregator.WeightAggregator(disable_partial_aggregation=False)
Bases: Aggregator
Base class for aggregators that work with neural network models.
Inherit from this class for aggregators that:
- Average or combine weight tensors
- Work with PyTorch, TensorFlow, or Flax models
- Expect list[np.ndarray] of float32/float64 parameter arrays
Validation is automatic via the template method pattern: aggregate() calls validate_models() before delegating to _aggregate().
Example
>>> from p2pfl.learning.aggregators.aggregator import WeightAggregator
>>> class MyAggregator(WeightAggregator):
...     def _aggregate(self, models):
...         # Validation already done by aggregate()
...         # ... your averaging logic ...
...         pass
- Parameters:
disable_partial_aggregation (bool)
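For the averaging case, the following sketch shows sample-weighted FedAvg over parameters in the list[np.ndarray] layout described above. It is a standalone function rather than a WeightAggregator subclass, and passing per-model sample counts as a separate num_samples list is an assumption about where those weights come from.

```python
from __future__ import annotations

import numpy as np


def fedavg(models: list[list[np.ndarray]], num_samples: list[int]) -> list[np.ndarray]:
    """Sample-weighted average of per-layer parameter arrays (FedAvg sketch)."""
    if not models:
        raise ValueError("no models to aggregate")
    total = float(sum(num_samples))
    averaged = []
    for layer in range(len(models[0])):
        # Weighted sum of the same layer across all models.
        acc = sum(n * params[layer] for params, n in zip(models, num_samples))
        # Keep the original dtype (float32 or float64).
        averaged.append((acc / total).astype(models[0][layer].dtype))
    return averaged
```

Weighting by sample count lets clients with more local data pull the global model proportionally harder; with equal counts this reduces to a plain mean.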