Source code for p2pfl.communication.protocols.grpc.grpc_communication_protocol

#
# This file is part of the federated_learning_p2p (p2pfl) distribution (see https://github.com/pguijas/p2pfl).
# Copyright (c) 2022 Pedro Guijas Bravo.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, version 3.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
# General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
#

"""GRPC communication protocol."""

from functools import wraps
from typing import Any, Callable, Dict, List, Optional, Union

from p2pfl.communication.commands.command import Command
from p2pfl.communication.commands.message.heartbeat_command import HeartbeatCommand  # Need to decouple this command
from p2pfl.communication.protocols.communication_protocol import CommunicationProtocol
from p2pfl.communication.protocols.exceptions import ProtocolNotStartedError
from p2pfl.communication.protocols.gossiper import Gossiper
from p2pfl.communication.protocols.grpc.address import AddressParser
from p2pfl.communication.protocols.grpc.grpc_client import GrpcClient
from p2pfl.communication.protocols.grpc.grpc_neighbors import GrpcNeighbors
from p2pfl.communication.protocols.grpc.grpc_server import GrpcServer
from p2pfl.communication.protocols.grpc.proto import node_pb2
from p2pfl.communication.protocols.heartbeater import Heartbeater
from p2pfl.settings import Settings


def running(func):
    """
    Guard a protocol method so it only executes while the server is up.

    Args:
        func: The method to guard. Its first positional argument is the
            protocol instance, which must expose ``_server.is_running()``.

    Returns:
        The wrapped method, with ``func``'s metadata preserved via ``wraps``.

    Raises:
        ProtocolNotStartedError: Raised by the wrapper when
            ``self._server.is_running()`` is falsy.

    """

    @wraps(func)
    def guarded(self, *args, **kwargs):
        server_up = self._server.is_running()
        if server_up:
            return func(self, *args, **kwargs)
        raise ProtocolNotStartedError("The protocol has not been started.")

    return guarded
class GrpcCommunicationProtocol(CommunicationProtocol):
    """
    GRPC communication protocol.

    Args:
        addr: Address of the node.
        commands: Commands to add to the communication protocol.

    .. todo:: https://grpc.github.io/grpc/python/grpc_asyncio.html
    .. todo:: Decouple the heartbeat command.

    """

    def __init__(self, addr: str = "127.0.0.1", commands: Optional[List[Command]] = None) -> None:
        """Initialize the GRPC communication protocol."""
        # Parse IP address
        parsed_address = AddressParser(addr)
        self.addr = parsed_address.get_parsed_address()
        # Neighbors
        self._neighbors = GrpcNeighbors(self.addr)
        # GRPC Client
        self._client = GrpcClient(self.addr, self._neighbors)
        # Gossip
        self._gossiper = Gossiper(self.addr, self._client)
        # GRPC
        # NOTE(review): ``commands`` is handed to GrpcServer here (possibly as None) and
        # is registered again through ``add_command`` below; confirm GrpcServer accepts
        # None and that registering the same command twice is harmless.
        self._server = GrpcServer(self.addr, self._gossiper, self._neighbors, commands)
        # Heartbeat
        self._heartbeater = Heartbeater(self.addr, self._neighbors, self._client)
        # Commands
        self.add_command(HeartbeatCommand(self._heartbeater))
        if commands is None:
            commands = []
        self.add_command(commands)

    def get_address(self) -> str:
        """
        Get the address.

        Returns:
            The parsed address of this node.

        """
        return self.addr

    def start(self) -> None:
        """Start the GRPC communication protocol (server, heartbeater and gossiper, in that order)."""
        self._server.start()
        self._heartbeater.start()
        self._gossiper.start()

    @running
    def stop(self) -> None:
        """
        Stop the GRPC communication protocol.

        Stops the heartbeater and gossiper, clears the neighbor table and
        finally stops the server.

        Raises:
            ProtocolNotStartedError: If the protocol has not been started.

        """
        self._heartbeater.stop()
        self._gossiper.stop()
        self._neighbors.clear_neighbors()
        self._server.stop()

    def add_command(self, cmds: Union[Command, List[Command]]) -> None:
        """
        Add a command to the communication protocol.

        Args:
            cmds: The command, or list of commands, to add.

        """
        self._server.add_command(cmds)

    @running
    def connect(self, addr: str, non_direct: bool = False) -> bool:
        """
        Connect to a neighbor.

        Args:
            addr: The address to connect to.
            non_direct: The non direct flag.

        Returns:
            The result of ``GrpcNeighbors.add`` (presumably whether the
            connection was established; confirm against GrpcNeighbors).

        Raises:
            ProtocolNotStartedError: If the protocol has not been started.

        """
        return self._neighbors.add(addr, non_direct=non_direct)

    @running
    def disconnect(self, nei: str, disconnect_msg: bool = True) -> None:
        """
        Disconnect from a neighbor.

        Args:
            nei: The neighbor to disconnect from.
            disconnect_msg: The disconnect message flag.

        Raises:
            ProtocolNotStartedError: If the protocol has not been started.

        """
        self._neighbors.remove(nei, disconnect_msg=disconnect_msg)

    def build_msg(self, cmd: str, args: Optional[List[str]] = None, round: Optional[int] = None) -> Any:
        """
        Build a message.

        Args:
            cmd: The message.
            args: The arguments (defaults to an empty list).
            round: The round.

        Returns:
            The message built by ``GrpcClient.build_message``.

        """
        if args is None:
            args = []
        return self._client.build_message(cmd, args, round)

    def build_weights(
        self,
        cmd: str,
        round: int,
        serialized_model: bytes,
        contributors: Optional[List[str]] = None,
        weight: int = 1,
    ) -> Any:
        """
        Build weights.

        Args:
            cmd: The command.
            round: The round.
            serialized_model: The serialized model.
            contributors: The model contributors (defaults to an empty list).
            weight: The weight of the model (amount of samples used).

        Returns:
            The weights message built by ``GrpcClient.build_weights``.

        """
        if contributors is None:
            contributors = []
        return self._client.build_weights(cmd, round, serialized_model, contributors, weight)

    @running
    def send(
        self,
        nei: str,
        msg: node_pb2.RootMessage,
        raise_error: bool = False,
        remove_on_error: bool = True,
    ) -> None:
        """
        Send a message to a neighbor.

        Args:
            nei: The neighbor to send the message.
            msg: The message to send.
            raise_error: If raise error.
            remove_on_error: If remove on error.

        Raises:
            ProtocolNotStartedError: If the protocol has not been started.

        """
        self._client.send(nei, msg, raise_error=raise_error, remove_on_error=remove_on_error)

    @running
    def broadcast(self, msg: node_pb2.RootMessage, node_list: Optional[List[str]] = None) -> None:
        """
        Broadcast a message to all neighbors.

        Args:
            msg: The message to broadcast.
            node_list: Optional node list.

        Raises:
            ProtocolNotStartedError: If the protocol has not been started.

        """
        self._client.broadcast(msg, node_list)

    @running
    def get_neighbors(self, only_direct: bool = False) -> Dict[str, Any]:
        """
        Get the neighbors.

        Args:
            only_direct: The only direct flag.

        Returns:
            The neighbors as returned by ``GrpcNeighbors.get_all``
            (presumably keyed by neighbor address).

        Raises:
            ProtocolNotStartedError: If the protocol has not been started.

        """
        return self._neighbors.get_all(only_direct)

    @running
    def wait_for_termination(self) -> None:
        """
        Block until the underlying GRPC server terminates.

        Delegates to ``GrpcServer.wait_for_termination``.

        Raises:
            ProtocolNotStartedError: If the protocol has not been started.

        """
        self._server.wait_for_termination()

    @running
    def gossip_weights(
        self,
        early_stopping_fn: Callable[[], bool],
        get_candidates_fn: Callable[[], List[str]],
        status_fn: Callable[[], Any],
        model_fn: Callable[[str], Any],
        period: Optional[float] = None,
        create_connection: bool = False,
    ) -> None:
        """
        Gossip model weights.

        Args:
            early_stopping_fn: The early stopping function.
            get_candidates_fn: The get candidates function.
            status_fn: The status function.
            model_fn: The model function.
            period: The period (defaults to ``Settings.GOSSIP_MODELS_PERIOD``).
            create_connection: The create connection flag.

        Raises:
            ProtocolNotStartedError: If the protocol has not been started.

        """
        if period is None:
            period = Settings.GOSSIP_MODELS_PERIOD
        self._gossiper.gossip_weights(
            early_stopping_fn,
            get_candidates_fn,
            status_fn,
            model_fn,
            period,
            create_connection,
        )