p2pfl.learning.frameworks.flax.flax_model module

Flax Model for P2PFL.

class p2pfl.learning.frameworks.flax.flax_model.FlaxModel(model, init_params, params=None, num_samples=None, contributors=None, additional_info=None)

Bases: P2PFLModel

P2PFL model abstraction for Flax.

build_copy(**kwargs)

Build a copy of the model.

Parameters:

**kwargs – Keyword arguments used to initialize the copied model.

Return type:

P2PFLModel

Returns:

A copy of the model.
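A minimal sketch of the copy pattern, assuming that `build_copy` reuses the original model's architecture and takes everything else (parameters, sample count, contributors) from the keyword arguments. `TinyModel` is a hypothetical stand-in written in plain NumPy, not p2pfl's `FlaxModel`, and the real method may forward its arguments differently.

```python
from typing import List, Optional

import numpy as np


class TinyModel:
    """Hypothetical stand-in for a P2PFL model wrapper (not the real FlaxModel)."""

    def __init__(self, architecture: str, params: Optional[List[np.ndarray]] = None,
                 num_samples: Optional[int] = None,
                 contributors: Optional[List[str]] = None) -> None:
        self.architecture = architecture
        self.params = params or []
        self.num_samples = num_samples
        self.contributors = contributors or []

    def build_copy(self, **kwargs) -> "TinyModel":
        # Assumption: the architecture is shared with the original, while
        # parameters and metadata come only from the caller's keyword arguments.
        return self.__class__(self.architecture, **kwargs)


original = TinyModel("mlp", [np.zeros((2, 2))], num_samples=10, contributors=["node-a"])
clone = original.build_copy(params=[np.ones((2, 2))], num_samples=5)
print(clone.architecture, clone.num_samples, clone.contributors)  # mlp 5 []
```

Under this assumption the copy is a new, independent instance: metadata that is not passed again (here `contributors`) falls back to its default rather than being inherited.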

decode_parameters(data)

Decode the parameters of the model.

Parameters:

data (bytes) – The encoded parameters of the model.

Return type:

Tuple[List[ndarray], Dict[str, Any]]

encode_parameters(params=None)

Encode the parameters of the model.

Parameters:

params (Optional[List[ndarray]]) – The parameters of the model.

Return type:

bytes
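The two methods are easiest to read as a round trip. The sketch below assumes a pickle-based wire format that bundles the parameter list with an additional-info dictionary, which is consistent with the Tuple[List[ndarray], Dict[str, Any]] return type of decode_parameters. The standalone functions and the payload layout are hypothetical, and p2pfl's actual encoding may differ. Because unpickling can execute arbitrary code, a format like this should only be decoded from trusted peers.

```python
import pickle
from typing import Any, Dict, List, Optional, Tuple

import numpy as np


def encode_parameters(params: List[np.ndarray],
                      info: Optional[Dict[str, Any]] = None) -> bytes:
    # Hypothetical layout: a pickled (parameters, additional_info) tuple.
    return pickle.dumps((params, info or {}))


def decode_parameters(data: bytes) -> Tuple[List[np.ndarray], Dict[str, Any]]:
    # Reverse of encode_parameters: unpack the list of arrays and the info dict.
    params, info = pickle.loads(data)
    return params, info


weights = [np.arange(6, dtype=np.float32).reshape(2, 3), np.zeros(3, dtype=np.float32)]
payload = encode_parameters(weights, {"round": 1})
decoded, info = decode_parameters(payload)
print(info, all(np.array_equal(a, b) for a, b in zip(weights, decoded)))  # {'round': 1} True
```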

get_framework()

Retrieve the model framework name.

Return type:

str

Returns:

The name of the model framework.

get_parameters()

Get the parameters of the model.

Return type:

List[ndarray]

Returns:

The parameters of the model as a list of NumPy arrays.
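Flax keeps parameters in a nested dictionary (for example {"params": {"Dense_0": {"kernel": ..., "bias": ...}}}), so returning a flat list of NumPy arrays requires a deterministic traversal order that every peer agrees on. The sketch below uses depth-first traversal with sorted keys, which is the order `jax.tree_util.tree_leaves` uses for dictionary leaves. It is written in plain NumPy to stay self-contained, and `flatten_params` is a hypothetical helper, not p2pfl's implementation.

```python
from typing import Any, Dict, List

import numpy as np


def flatten_params(tree: Dict[str, Any]) -> List[np.ndarray]:
    """Collect leaf arrays depth-first, visiting dictionary keys in sorted order."""
    leaves: List[np.ndarray] = []
    for key in sorted(tree):
        value = tree[key]
        if isinstance(value, dict):
            leaves.extend(flatten_params(value))  # descend into a sub-module
        else:
            leaves.append(np.asarray(value))  # a kernel or bias array
    return leaves


params = {
    "Dense_1": {"kernel": np.ones((4, 2)), "bias": np.zeros(2)},
    "Dense_0": {"kernel": np.ones((3, 4)), "bias": np.zeros(4)},
}
print([p.shape for p in flatten_params(params)])  # [(4,), (3, 4), (2,), (4, 2)]
```

Sorting makes the list order independent of dictionary insertion order, so two nodes holding the same architecture produce lists that line up element by element during aggregation.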

set_parameters(params)

Set the parameters of the model.

Parameters:

params (Union[List[ndarray], bytes]) – The parameters of the model.

Raises:

ModelNotMatchingError – If parameters don’t match the model.

Return type:

None
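Setting parameters is the inverse of flattening them: walk the current structure in the same sorted-key order, check each incoming array against the existing shape, and rebuild the tree. The sketch assumes that a shape or count mismatch is what makes parameters "not match", and that bytes input is first decoded as a pickled (parameters, info) tuple. `ModelNotMatchingError` is redefined locally so the block runs on its own, and `unflatten_params` is a hypothetical helper, not p2pfl's code.

```python
import pickle
from typing import Any, Dict, List, Union

import numpy as np


class ModelNotMatchingError(Exception):
    """Local stand-in for p2pfl's exception of the same name."""


def unflatten_params(template: Dict[str, Any],
                     params: Union[List[np.ndarray], bytes]) -> Dict[str, Any]:
    if isinstance(params, bytes):
        # Assumption: bytes carry a pickled (parameters, additional_info) tuple.
        params, _info = pickle.loads(params)
    leaves = iter(params)

    def rebuild(node: Dict[str, Any]) -> Dict[str, Any]:
        out: Dict[str, Any] = {}
        for key in sorted(node):  # same order as the flattening step
            value = node[key]
            if isinstance(value, dict):
                out[key] = rebuild(value)
                continue
            try:
                new = np.asarray(next(leaves))
            except StopIteration:
                raise ModelNotMatchingError("too few parameter arrays") from None
            if new.shape != np.shape(value):
                raise ModelNotMatchingError(f"{key}: expected {np.shape(value)}, got {new.shape}")
            out[key] = new
        return out

    tree = rebuild(template)
    if next(leaves, None) is not None:
        raise ModelNotMatchingError("too many parameter arrays")
    return tree


template = {"Dense_0": {"kernel": np.zeros((3, 4)), "bias": np.zeros(4)}}
restored = unflatten_params(template, [np.ones(4), np.ones((3, 4))])  # bias first (sorted keys)
print(restored["Dense_0"]["kernel"].shape)  # (3, 4)
try:
    unflatten_params(template, [np.ones(5), np.ones((3, 4))])
except ModelNotMatchingError as err:
    print("rejected:", err)
```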

class p2pfl.learning.frameworks.flax.flax_model.MLP(hidden_sizes=(256, 128), out_channels=10, parent=<flax.linen.module._Sentinel object>, name=None)

Bases: Module

Multilayer Perceptron (MLP) for MNIST classification using Flax.

hidden_sizes: Tuple[int, int] = (256, 128)
name: Optional[str] = None
out_channels: int = 10
parent: Union[Module, Scope, _Sentinel, None] = None
scope: Scope | None = None
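To make the attribute defaults concrete, the sketch below reproduces in NumPy the computation an MLP with hidden_sizes=(256, 128) and out_channels=10 would perform on flattened 28×28 MNIST images. The ReLU activations, the input flattening, and the random initialization are assumptions, since the docstring only states the layer sizes. The real class is a Flax `Module` whose weights are created by Flax, not by this code; the (in_features, out_features) kernel layout mirrors Flax's dense layers.

```python
from typing import List, Sequence, Tuple

import numpy as np

Layer = Tuple[np.ndarray, np.ndarray]  # (kernel, bias)


def init_mlp(rng: np.random.Generator, in_features: int,
             hidden_sizes: Sequence[int] = (256, 128),
             out_channels: int = 10) -> List[Layer]:
    """Create (kernel, bias) pairs for each dense layer; the scaling is illustrative."""
    sizes = [in_features, *hidden_sizes, out_channels]
    layers: List[Layer] = []
    for fan_in, fan_out in zip(sizes[:-1], sizes[1:]):
        kernel = rng.normal(0.0, 1.0 / np.sqrt(fan_in), size=(fan_in, fan_out))
        layers.append((kernel, np.zeros(fan_out)))
    return layers


def mlp_forward(layers: List[Layer], x: np.ndarray) -> np.ndarray:
    x = x.reshape(x.shape[0], -1)  # flatten each image to a vector
    for kernel, bias in layers[:-1]:
        x = np.maximum(x @ kernel + bias, 0.0)  # dense layer + ReLU (assumed activation)
    kernel, bias = layers[-1]
    return x @ kernel + bias  # raw logits, one per class


rng = np.random.default_rng(0)
layers = init_mlp(rng, in_features=28 * 28)
logits = mlp_forward(layers, rng.random((4, 28, 28)))
print(logits.shape)  # (4, 10)
```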