## This file is part of the federated_learning_p2p (p2pfl) distribution# (see https://github.com/pguijas/p2pfl).# Copyright (c) 2022 Pedro Guijas Bravo.## This program is free software: you can redistribute it and/or modify# it under the terms of the GNU General Public License as published by# the Free Software Foundation, version 3.## This program is distributed in the hope that it will be useful, but# WITHOUT ANY WARRANTY; without even the implied warranty of# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU# General Public License for more details.## You should have received a copy of the GNU General Public License# along with this program. If not, see <http://www.gnu.org/licenses/>.#"""NodeLearning Interface - Template Pattern."""fromabcimportABC,abstractmethodfromtypingimportDict,Optional,Unionimportnumpyasnpfromp2pfl.learning.aggregators.aggregatorimportAggregatorfromp2pfl.learning.dataset.p2pfl_datasetimportP2PFLDatasetfromp2pfl.learning.frameworks.callbackimportP2PFLCallbackfromp2pfl.learning.frameworks.callback_factoryimportCallbackFactoryfromp2pfl.learning.frameworks.p2pfl_modelimportP2PFLModelfromp2pfl.utils.node_componentimportNodeComponent,allow_no_addr_check
class Learner(ABC, NodeComponent):
    """
    Template to implement learning processes, including metric monitoring during training.

    Subclasses implement the framework-specific parts: :meth:`fit`,
    :meth:`interrupt_fit`, :meth:`evaluate` and :meth:`get_framework`.

    Args:
        model: The model of the learner. If ``None``, it must be set later with :meth:`set_model`.
        data: The data of the learner. If ``None``, it must be set later with :meth:`set_data`.
        aggregator: The aggregator used in the learning process. If given, the callbacks
            it requires for this learner's framework are instantiated.

    """

    def __init__(
        self,
        model: Optional[P2PFLModel] = None,
        data: Optional[P2PFLDataset] = None,
        aggregator: Optional[Aggregator] = None,
    ) -> None:
        """Initialize the learner."""
        # (addr) Super
        NodeComponent.__init__(self)

        # Callbacks list must exist before indicate_aggregator() extends it and
        # before set_model() propagates model info to it.
        self.callbacks: list[P2PFLCallback] = []
        # Explicit None checks: a falsy-but-valid object must not be skipped.
        if aggregator is not None:
            self.indicate_aggregator(aggregator)

        self.epochs: int = 1  # Default epochs

        # Model and data start unset; the getters raise until they are provided.
        self.__model: Optional[P2PFLModel] = None
        if model is not None:
            self.set_model(model)
        self.__data: Optional[P2PFLDataset] = None
        if data is not None:
            self.set_data(data)

    @allow_no_addr_check
    def set_model(self, model: Union[P2PFLModel, list[np.ndarray], bytes]) -> None:
        """
        Set the model of the learner.

        A ``P2PFLModel`` replaces the current model. A list of arrays or a bytes
        payload is passed to the current model's ``set_parameters``.

        Args:
            model: The model of the learner, or new parameters for the current model.

        Raises:
            ValueError: If parameters are given but no model has been set yet.
            TypeError: If ``model`` is not one of the supported types.

        """
        if isinstance(model, P2PFLModel):
            self.__model = model
        elif isinstance(model, (list, bytes)):
            self.get_model().set_parameters(model)
        else:
            # Previously ignored silently while still refreshing callbacks.
            raise TypeError(
                f"Unsupported model type {type(model).__name__}; "
                "expected P2PFLModel, list of numpy arrays or bytes."
            )

        # Update callbacks with model info
        self.update_callbacks_with_model_info()

    @allow_no_addr_check
    def get_model(self) -> P2PFLModel:
        """
        Get the model of the learner.

        Returns:
            The model of the learner.

        Raises:
            ValueError: If no model has been set.

        """
        if self.__model is None:
            raise ValueError("Model not initialized, please ensure to set the model before accessing it. Use .set_model() method.")
        return self.__model

    @allow_no_addr_check
    def set_data(self, data: P2PFLDataset) -> None:
        """
        Set the data of the learner. It is used to fit the model.

        Args:
            data: The data of the learner.

        """
        self.__data = data

    @allow_no_addr_check
    def get_data(self) -> P2PFLDataset:
        """
        Get the data of the learner.

        Returns:
            The data of the learner.

        Raises:
            ValueError: If no data has been set.

        """
        if self.__data is None:
            raise ValueError("Data not initialized, please ensure to set the data before accessing it. Use .set_data() method.")
        return self.__data

    @allow_no_addr_check
    def indicate_aggregator(self, aggregator: Aggregator) -> None:
        """
        Indicate to the learner the aggregators that are being used in order to instantiate the callbacks.

        The callbacks created for ``aggregator`` are appended to the existing ones.

        Args:
            aggregator: The aggregator used in the learning process.

        """
        if aggregator is not None:
            self.callbacks = self.callbacks + CallbackFactory.create_callbacks(
                framework=self.get_framework(), aggregator=aggregator
            )

    @allow_no_addr_check
    def set_epochs(self, epochs: int) -> None:
        """
        Set the number of epochs of the model.

        Args:
            epochs: The number of epochs of the model.

        """
        self.epochs = epochs

    @allow_no_addr_check
    def update_callbacks_with_model_info(self) -> None:
        """
        Update the callbacks with the model additional information.

        Each callback receives the model info stored under its name; callbacks
        with no stored entry are left untouched.
        """
        new_info = self.get_model().get_info()
        for callback in self.callbacks:
            callback_name = callback.get_name()
            # Only the lookup is guarded, so a KeyError raised inside
            # set_info() is no longer silently swallowed.
            try:
                callback_info = new_info[callback_name]
            except KeyError:
                continue
            callback.set_info(callback_info)

    @allow_no_addr_check
    def add_callback_info_to_model(self) -> None:
        """Add the additional information from the callbacks to the model, keyed by callback name."""
        for c in self.callbacks:
            self.get_model().add_info(c.get_name(), c.get_info())

    @abstractmethod
    def fit(self) -> P2PFLModel:
        """
        Fit the model.

        Returns:
            The model after fitting.

        """
        pass

    @abstractmethod
    def interrupt_fit(self) -> None:
        """Interrupt the fit process."""
        pass

    @abstractmethod
    def evaluate(self) -> Dict[str, float]:
        """
        Evaluate the model with actual parameters.

        Returns:
            The evaluation results.

        """
        pass

    @abstractmethod
    def get_framework(self) -> str:
        """
        Retrieve the name of the framework used by the learner.

        It is passed to ``CallbackFactory.create_callbacks`` to select the callbacks.

        Returns:
            The framework name.

        """
        pass