#
# This file is part of the federated_learning_p2p (p2pfl) distribution (see https://github.com/pguijas/p2pfl).
# Copyright (c) 2022 Pedro Guijas Bravo.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, version 3.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
# General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
#
"""GRPC communication protocol."""
from functools import wraps
from typing import Any, Callable, Dict, List, Optional, Union
from p2pfl.communication.commands.command import Command
from p2pfl.communication.commands.message.heartbeat_command import HeartbeatCommand # Need to decouple this command
from p2pfl.communication.protocols.communication_protocol import CommunicationProtocol
from p2pfl.communication.protocols.exceptions import ProtocolNotStartedError
from p2pfl.communication.protocols.gossiper import Gossiper
from p2pfl.communication.protocols.grpc.address import AddressParser
from p2pfl.communication.protocols.grpc.grpc_client import GrpcClient
from p2pfl.communication.protocols.grpc.grpc_neighbors import GrpcNeighbors
from p2pfl.communication.protocols.grpc.grpc_server import GrpcServer
from p2pfl.communication.protocols.grpc.proto import node_pb2
from p2pfl.communication.protocols.heartbeater import Heartbeater
from p2pfl.settings import Settings
def running(func):
    """
    Guard a protocol method so it only executes while the server is running.

    Args:
        func: Method to wrap. Its first positional argument must be the
            protocol instance, which exposes ``_server.is_running()``.

    Returns:
        The wrapped method (metadata preserved via ``functools.wraps``).

    Raises:
        ProtocolNotStartedError: At call time, when ``_server.is_running()``
            reports that the server is not running.

    """

    @wraps(func)
    def _guarded(self, *args, **kwargs):
        server_up = self._server.is_running()
        if server_up:
            return func(self, *args, **kwargs)
        raise ProtocolNotStartedError("The protocol has not been started.")

    return _guarded
class GrpcCommunicationProtocol(CommunicationProtocol):
    """
    GRPC communication protocol.

    Wires together the neighbor table, gRPC client, gossiper, gRPC server and
    heartbeater for a single node address.

    Args:
        addr: Address of the node.
        commands: Commands to add to the communication protocol.

    .. todo:: https://grpc.github.io/grpc/python/grpc_asyncio.html
    .. todo:: Decouple the heartbeat command.

    """

    def __init__(self, addr: str = "127.0.0.1", commands: Optional[List[Command]] = None) -> None:
        """
        Initialize the GRPC communication protocol.

        Args:
            addr: Address of the node (normalized through ``AddressParser``).
            commands: Extra commands to register besides the heartbeat command.

        """
        # Parse IP address
        parsed_address = AddressParser(addr)
        self.addr = parsed_address.get_parsed_address()
        # Neighbors
        self._neighbors = GrpcNeighbors(self.addr)
        # GRPC Client
        self._client = GrpcClient(self.addr, self._neighbors)
        # Gossip
        self._gossiper = Gossiper(self.addr, self._client)
        # GRPC server.
        # NOTE(review): ``commands`` (possibly None) is handed to the server here
        # and registered again through ``add_command`` below; confirm whether
        # GrpcServer already registers them so one of the two paths can go.
        self._server = GrpcServer(self.addr, self._gossiper, self._neighbors, commands)
        # Heartbeat
        self._heartbeater = Heartbeater(self.addr, self._neighbors, self._client)
        # Commands: the heartbeat command is always registered, then the caller's.
        self.add_command(HeartbeatCommand(self._heartbeater))
        if commands is None:
            commands = []
        self.add_command(commands)

    def get_address(self) -> str:
        """
        Get the address.

        Returns:
            The parsed address of this node.

        """
        return self.addr

    def start(self) -> None:
        """Start the GRPC communication protocol (server, heartbeater, gossiper)."""
        self._server.start()
        self._heartbeater.start()
        self._gossiper.start()

    @running
    def stop(self) -> None:
        """
        Stop the GRPC communication protocol.

        Stops the heartbeater and gossiper, clears the neighbor table and
        finally stops the server.

        Raises:
            ProtocolNotStartedError: If the protocol has not been started.

        """
        self._heartbeater.stop()
        self._gossiper.stop()
        self._neighbors.clear_neighbors()
        self._server.stop()

    def add_command(self, cmds: Union[Command, List[Command]]) -> None:
        """
        Add a command to the communication protocol.

        Args:
            cmds: A single command or a list of commands to register.

        """
        self._server.add_command(cmds)

    @running
    def connect(self, addr: str, non_direct: bool = False) -> bool:
        """
        Connect to a neighbor.

        Args:
            addr: The address to connect to.
            non_direct: The non direct flag.

        Returns:
            The boolean result of ``GrpcNeighbors.add`` (presumably whether the
            neighbor was added).

        Raises:
            ProtocolNotStartedError: If the protocol has not been started.

        """
        return self._neighbors.add(addr, non_direct=non_direct)

    @running
    def disconnect(self, nei: str, disconnect_msg: bool = True) -> None:
        """
        Disconnect from a neighbor.

        Args:
            nei: The neighbor to disconnect from.
            disconnect_msg: The disconnect message flag.

        Raises:
            ProtocolNotStartedError: If the protocol has not been started.

        """
        self._neighbors.remove(nei, disconnect_msg=disconnect_msg)

    def build_msg(self, cmd: str, args: Optional[List[str]] = None, round: Optional[int] = None) -> Any:
        """
        Build a message.

        Args:
            cmd: The command name.
            args: The arguments (defaults to an empty list).
            round: The round.

        Returns:
            The message built by ``GrpcClient.build_message``.

        """
        if args is None:
            args = []
        return self._client.build_message(cmd, args, round)

    def build_weights(
        self,
        cmd: str,
        round: int,
        serialized_model: bytes,
        contributors: Optional[List[str]] = None,
        weight: int = 1,
    ) -> Any:
        """
        Build weights.

        Args:
            cmd: The command.
            round: The round.
            serialized_model: The serialized model.
            contributors: The model contributors (defaults to an empty list).
            weight: The weight of the model (amount of samples used).

        Returns:
            The message built by ``GrpcClient.build_weights``.

        """
        if contributors is None:
            contributors = []
        return self._client.build_weights(cmd, round, serialized_model, contributors, weight)

    @running
    def send(
        self,
        nei: str,
        msg: node_pb2.RootMessage,
        raise_error: bool = False,
        remove_on_error: bool = True,
    ) -> None:
        """
        Send a message to a neighbor.

        Args:
            nei: The neighbor to send the message.
            msg: The message to send.
            raise_error: Forwarded to ``GrpcClient.send``; presumably whether a
                send failure is raised instead of handled.
            remove_on_error: Forwarded to ``GrpcClient.send``; presumably whether
                the neighbor is removed when sending fails.

        Raises:
            ProtocolNotStartedError: If the protocol has not been started.

        """
        self._client.send(nei, msg, raise_error=raise_error, remove_on_error=remove_on_error)

    @running
    def broadcast(self, msg: node_pb2.RootMessage, node_list: Optional[List[str]] = None) -> None:
        """
        Broadcast a message to all neighbors.

        Args:
            msg: The message to broadcast.
            node_list: Optional node list.

        Raises:
            ProtocolNotStartedError: If the protocol has not been started.

        """
        self._client.broadcast(msg, node_list)

    @running
    def get_neighbors(self, only_direct: bool = False) -> Dict[str, Any]:
        """
        Get the neighbors.

        Args:
            only_direct: The only direct flag.

        Returns:
            The mapping returned by ``GrpcNeighbors.get_all``.

        Raises:
            ProtocolNotStartedError: If the protocol has not been started.

        """
        return self._neighbors.get_all(only_direct)

    @running
    def wait_for_termination(self) -> None:
        """
        Wait for the underlying gRPC server to terminate.

        Delegates to ``GrpcServer.wait_for_termination`` (presumably blocking
        the caller until the server stops).

        Raises:
            ProtocolNotStartedError: If the protocol has not been started.

        """
        self._server.wait_for_termination()

    @running
    def gossip_weights(
        self,
        early_stopping_fn: Callable[[], bool],
        get_candidates_fn: Callable[[], List[str]],
        status_fn: Callable[[], Any],
        model_fn: Callable[[str], Any],
        period: Optional[float] = None,
        create_connection: bool = False,
    ) -> None:
        """
        Gossip model weights.

        Args:
            early_stopping_fn: The early stopping function.
            get_candidates_fn: The get candidates function.
            status_fn: The status function.
            model_fn: The model function.
            period: The period (defaults to ``Settings.GOSSIP_MODELS_PERIOD``).
            create_connection: The create connection flag.

        Raises:
            ProtocolNotStartedError: If the protocol has not been started.

        """
        if period is None:
            period = Settings.GOSSIP_MODELS_PERIOD
        self._gossiper.gossip_weights(
            early_stopping_fn,
            get_candidates_fn,
            status_fn,
            model_fn,
            period,
            create_connection,
        )