## This file is part of the federated_learning_p2p (p2pfl) distribution# (see https://github.com/pguijas/p2pfl).# Copyright (c) 2022 Pedro Guijas Bravo.## This program is free software: you can redistribute it and/or modify# it under the terms of the GNU General Public License as published by# the Free Software Foundation, version 3.## This program is distributed in the hope that it will be useful, but# WITHOUT ANY WARRANTY; without even the implied warranty of# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU# General Public License for more details.## You should have received a copy of the GNU General Public License# along with this program. If not, see <http://www.gnu.org/licenses/>.#"""NodeLearning Interface - Template Pattern."""fromabcimportABC,abstractmethodfromtypingimportDict,List,Optional,Unionimportnumpyasnpfromp2pfl.learning.aggregators.aggregatorimportAggregatorfromp2pfl.learning.dataset.p2pfl_datasetimportP2PFLDatasetfromp2pfl.learning.frameworks.callbackimportP2PFLCallbackfromp2pfl.learning.frameworks.callback_factoryimportCallbackFactoryfromp2pfl.learning.frameworks.p2pfl_modelimportP2PFLModel
class Learner(ABC):
    """
    Template to implement learning processes, including metric monitoring during training.

    Concrete learners provide the framework-specific behaviour by implementing
    :meth:`fit`, :meth:`interrupt_fit`, :meth:`evaluate` and :meth:`get_framework`.

    Args:
        model: The model of the learner.
        data: The data of the learner.
        self_addr: The address of the learner.
        aggregator: Optional aggregator. When provided, the callbacks it requires for this
            learner's framework are created through ``CallbackFactory.create_callbacks``.

    """

    def __init__(
        self,
        model: P2PFLModel,
        data: P2PFLDataset,
        self_addr: str = "unknown-node",
        aggregator: Optional[Aggregator] = None,
    ) -> None:
        """Initialize the learner."""
        self.model: P2PFLModel = model
        self.data: P2PFLDataset = data
        self._self_addr = self_addr
        self.callbacks: List[P2PFLCallback] = []
        # Explicit None check: any aggregator that is passed in must get its callbacks,
        # regardless of how the object evaluates in a boolean context.
        if aggregator is not None:
            # get_framework() is abstract and resolved on the concrete subclass.
            self.callbacks = CallbackFactory.create_callbacks(framework=self.get_framework(), aggregator=aggregator)
        self.epochs: int = 1  # Default epochs

    def set_addr(self, addr: str) -> None:
        """
        Set the address of the learner.

        Args:
            addr: The address of the learner.

        """
        self._self_addr = addr

    def set_model(self, model: Union[P2PFLModel, List[np.ndarray], bytes]) -> None:
        """
        Set the model of the learner.

        A ``P2PFLModel`` replaces the current model object. A list of arrays or a ``bytes``
        payload is instead loaded into the current model through ``set_parameters``.
        In both cases the callbacks are then refreshed from the model's additional info.

        Args:
            model: The new model, or the parameters to load into the current model.

        Raises:
            TypeError: If ``model`` is not a ``P2PFLModel``, a ``list`` or ``bytes``.

        """
        if isinstance(model, P2PFLModel):
            self.model = model
        elif isinstance(model, (list, bytes)):
            self.model.set_parameters(model)
        else:
            # Fail loudly rather than silently keeping the old weights.
            raise TypeError(
                f"Unsupported model type {type(model).__name__!r}: expected P2PFLModel, list of arrays or bytes."
            )
        # Update callbacks with model info
        self.update_callbacks_with_model_info()

    def get_model(self) -> P2PFLModel:
        """
        Get the model of the learner.

        Returns:
            The model of the learner.

        """
        return self.model

    def set_data(self, data: P2PFLDataset) -> None:
        """
        Set the data of the learner. It is used to fit the model.

        Args:
            data: The data of the learner.

        """
        self.data = data

    def get_data(self) -> P2PFLDataset:
        """
        Get the data of the learner.

        Returns:
            The data of the learner.

        """
        return self.data

    def set_epochs(self, epochs: int) -> None:
        """
        Set the number of epochs of the model.

        Args:
            epochs: The number of epochs of the model.

        """
        self.epochs = epochs

    def update_callbacks_with_model_info(self) -> None:
        """
        Update the callbacks with the model additional information.

        Each callback receives the entry of ``self.model.get_info()`` stored under its
        ``get_name()``; callbacks with no such entry are left untouched.
        """
        # NOTE(review): get_info() is assumed to return a mapping keyed by callback name.
        new_info = self.model.get_info()
        for callback in self.callbacks:
            callback_name = callback.get_name()
            # Only the lookup is guarded, so a KeyError raised inside a callback still surfaces.
            try:
                callback_info = new_info[callback_name]
            except KeyError:
                continue
            callback.set_info(callback_info)

    def add_callback_info_to_model(self) -> None:
        """Add the additional information from the callbacks to the model, keyed by callback name."""
        for callback in self.callbacks:
            self.model.add_info(callback.get_name(), callback.get_info())

    @abstractmethod
    def fit(self) -> P2PFLModel:
        """
        Fit the model.

        Returns:
            The model after fitting.

        """
        pass

    @abstractmethod
    def interrupt_fit(self) -> None:
        """Interrupt the fit process."""
        pass

    @abstractmethod
    def evaluate(self) -> Dict[str, float]:
        """
        Evaluate the model with its current parameters.

        Returns:
            The evaluation results.

        """
        pass

    @abstractmethod
    def get_framework(self) -> str:
        """
        Retrieve the name of the learning framework used by the learner.

        The value is passed as ``framework`` to ``CallbackFactory.create_callbacks``.

        Returns:
            The framework name.

        """
        pass