Source code for p2pfl.communication.protocols.grpc.grpc_client

#
# This file is part of the federated_learning_p2p (p2pfl) distribution (see https://github.com/pguijas/p2pfl).
# Copyright (c) 2022 Pedro Guijas Bravo.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, version 3.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
# General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
#

"""GRPC client."""

import random
from datetime import datetime
from os.path import isfile
from typing import List, Optional

import grpc

from p2pfl.communication.protocols.client import Client
from p2pfl.communication.protocols.exceptions import CommunicationError, NeighborNotConnectedError
from p2pfl.communication.protocols.grpc.grpc_neighbors import GrpcNeighbors
from p2pfl.communication.protocols.grpc.proto import node_pb2, node_pb2_grpc
from p2pfl.management.logger import logger
from p2pfl.settings import Settings


[docs] class GrpcClient(Client): """ Implementation of the client side (i.e. who initiates the communication) of the GRPC communication protocol. Args: self_addr: Address of the node. neighbors: Neighbors of the node. """ def __init__(self, self_addr: str, neighbors: GrpcNeighbors) -> None: """Initialize the GRPC client.""" self.__self_addr = self_addr self.__neighbors = neighbors #### # Message Building ####
[docs] def build_message(self, cmd: str, args: Optional[List[str]] = None, round: Optional[int] = None) -> node_pb2.RootMessage: """ Build a RootMessage to send to the neighbors. Args: cmd: Command of the message. args: Arguments of the message. round: Round of the message. Returns: RootMessage to send. """ if round is None: round = -1 if args is None: args = [] hs = hash(str(cmd) + str(args) + str(datetime.now()) + str(random.randint(0, 100000))) args = [str(a) for a in args] return node_pb2.RootMessage( source=self.__self_addr, round=round, cmd=cmd, message=node_pb2.Message( ttl=Settings.TTL, hash=hs, args=args, ), )
[docs] def build_weights( self, cmd: str, round: int, serialized_model: bytes, contributors: Optional[List[str]] = None, weight: int = 1, ) -> node_pb2.RootMessage: """ Build a RootMessage with a Weights payload to send to the neighbors. Args: cmd: Command of the message. round: Round of the message. serialized_model: Serialized model to send. contributors: List of contributors. weight: Weight of the message (number of samples). Returns: RootMessage to send. """ if contributors is None: contributors = [] return node_pb2.RootMessage( source=self.__self_addr, round=round, cmd=cmd, weights=node_pb2.Weights( weights=serialized_model, contributors=contributors, num_samples=weight, ), )
#### # Message Sending ####
[docs] def send( self, nei: str, msg: node_pb2.RootMessage, create_connection: bool = False, raise_error: bool = False, remove_on_error: bool = True, ) -> None: """ Send a message to a neighbor. Args: nei (string): Neighbor address. msg (node_pb2.Message or node_pb2.Weights): Message to send. create_connection (bool): Create a connection if not exists. raise_error (bool): Raise error if an error occurs. remove_on_error (bool): Remove neighbor if an error occurs. """ channel = None try: # Get neighbor try: node_stub = self.__neighbors.get(nei)[1] except KeyError as e: raise NeighborNotConnectedError(f"Neighbor {nei} not found.") from e # Check if direct connection if node_stub is None and create_connection: if Settings.USE_SSL and isfile(Settings.SERVER_CRT): with open(Settings.CLIENT_KEY) as key_file, open(Settings.CLIENT_CRT) as crt_file, open(Settings.CA_CRT) as ca_file: private_key = key_file.read().encode() certificate_chain = crt_file.read().encode() root_certificates = ca_file.read().encode() creds = grpc.ssl_channel_credentials( root_certificates=root_certificates, private_key=private_key, certificate_chain=certificate_chain, ) channel = grpc.secure_channel(nei, creds) else: channel = grpc.insecure_channel(nei) node_stub = node_pb2_grpc.NodeServicesStub(channel) # Send if node_stub is not None: # Send message res = node_stub.send(msg, timeout=Settings.GRPC_TIMEOUT) else: raise NeighborNotConnectedError("Neighbor not directly connected (Stub not defined and create_connection is false).") if res.error: raise CommunicationError(f"Error while sending a message: {msg.cmd}: {res.error}") except Exception as e: # Remove neighbor logger.info( self.__self_addr, f"Cannot send message {msg.cmd} to {nei}. Error: {str(e)}", ) if remove_on_error: self.__neighbors.remove(nei, disconnect_msg=True) # Re-raise if raise_error: raise e finally: if channel is not None: channel.close()
[docs] def broadcast(self, msg: node_pb2.RootMessage, node_list: Optional[List[str]] = None) -> None: """ Broadcast a message to all the neighbors. Args: msg: Message to send. node_list: List of neighbors to send the message. If None, send to all the neighbors. """ # Node list nodes = node_list if node_list is not None else self.__neighbors.get_all(only_direct=True).keys() # Send for n in nodes: self.send(n, msg)