p2pfl.learning.frameworks.flax.flax_model module
Flax Model for P2PFL.
- class p2pfl.learning.frameworks.flax.flax_model.FlaxModel(model, init_params, params=None, num_samples=None, contributors=None, additional_info=None)
Bases: P2PFLModel
P2PFL model abstraction for Flax.
- build_copy(**kwargs)
Build a copy of the model.
- Parameters:
**kwargs – Keyword arguments passed to the model initialization.
- Return type:
- Returns:
A copy of the model.
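The idea behind `build_copy(**kwargs)` can be sketched with a toy stand-in class. This is a minimal illustration under stated assumptions, not p2pfl's implementation: `ToyModel` is hypothetical, and it assumes that a copy keeps the same network definition while keyword arguments (here `params`, `num_samples`, `contributors`, mirroring names in the `FlaxModel` constructor) override the corresponding initialization values.

```python
from copy import deepcopy
from typing import Any, Dict, List, Optional


class ToyModel:
    """Hypothetical stand-in for a P2PFL model wrapper (not the real FlaxModel)."""

    def __init__(
        self,
        model: Any,
        init_params: Dict[str, Any],
        params: Optional[Dict[str, Any]] = None,
        num_samples: Optional[int] = None,
        contributors: Optional[List[str]] = None,
    ) -> None:
        self.model = model
        self.init_params = init_params
        # Assumption: fall back to the initial parameters when none are given.
        self.params = params if params is not None else init_params
        self.num_samples = num_samples if num_samples is not None else 0
        self.contributors = contributors if contributors is not None else []

    def build_copy(self, **kwargs: Any) -> "ToyModel":
        # Reuse the same network definition and a deep copy of the initial
        # parameters; any keyword argument overrides the matching constructor argument.
        return ToyModel(self.model, deepcopy(self.init_params), **kwargs)


original = ToyModel("mlp", {"w": [1.0, 2.0]}, num_samples=32, contributors=["node-a"])
copy_ = original.build_copy(num_samples=64, contributors=["node-b"])
print(copy_.num_samples, copy_.contributors, copy_.params)  # 64 ['node-b'] {'w': [1.0, 2.0]}
```

Deep-copying the initial parameters keeps the copy independent, so changing one model's state does not leak into the other.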
- decode_parameters(data)
Decode the parameters of the model.
- Parameters:
data (bytes) – The encoded parameters of the model.
- Return type:
Tuple[List[ndarray], Dict[str, Any]]
- encode_parameters(params=None)
Encode the parameters of the model.
- Parameters:
params (Optional[List[ndarray]]) – The parameters of the model.
- Return type:
bytes
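Together, `encode_parameters` and `decode_parameters` turn a list of NumPy arrays into `bytes` and back, with decoding also yielding a `Dict[str, Any]`. The sketch below shows one way such a round trip could work. It assumes `pickle` as the wire format and treats the dictionary as free-form additional information; both are illustrative assumptions, and `encode` and `decode` are hypothetical helpers, not p2pfl's methods.

```python
import pickle
from typing import Any, Dict, List, Optional, Tuple

import numpy as np


def encode(params: List[np.ndarray], info: Optional[Dict[str, Any]] = None) -> bytes:
    # Serialize the arrays together with a metadata dictionary (assumed format).
    return pickle.dumps({"params": params, "additional_info": info or {}})


def decode(data: bytes) -> Tuple[List[np.ndarray], Dict[str, Any]]:
    # Only unpickle data from trusted peers: pickle can execute arbitrary code.
    payload = pickle.loads(data)
    return payload["params"], payload["additional_info"]


weights = [np.ones((784, 256), dtype=np.float32), np.zeros(256, dtype=np.float32)]
blob = encode(weights, {"round": 3})
decoded, info = decode(blob)
print(type(blob).__name__, [w.shape for w in decoded], info)
# bytes [(784, 256), (256,)] {'round': 3}
```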
- get_framework()
Retrieve the model framework name.
- Return type:
str
- Returns:
The name of the model framework.
- get_parameters()
Get the parameters of the model.
- Return type:
List[ndarray]
- Returns:
The parameters of the model as a list of NumPy arrays.
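Flax keeps parameters in a nested dictionary (a pytree), while `get_parameters` returns a flat list of NumPy arrays. The sketch below shows one way to flatten such a tree deterministically and rebuild it. It is a pure-NumPy illustration that mimics the sorted-key order `jax.tree_util.tree_flatten` uses for dictionaries; `flatten` and `unflatten` are hypothetical helpers, the layer names are only Flax-style examples, and p2pfl may do this differently.

```python
from typing import Any, Dict, List

import numpy as np

Tree = Dict[str, Any]


def flatten(tree: Tree) -> List[np.ndarray]:
    # Depth-first traversal with keys visited in sorted order, so the
    # resulting list order is deterministic across nodes.
    leaves: List[np.ndarray] = []
    for key in sorted(tree):
        value = tree[key]
        if isinstance(value, dict):
            leaves.extend(flatten(value))
        else:
            leaves.append(np.asarray(value))
    return leaves


def unflatten(template: Tree, leaves: List[np.ndarray]) -> Tree:
    # Rebuild a tree with the same structure as `template` from a flat list,
    # consuming leaves in the same sorted-key order used by `flatten`.
    iterator = iter(leaves)

    def build(node: Tree) -> Tree:
        out: Tree = {}
        for key in sorted(node):
            out[key] = build(node[key]) if isinstance(node[key], dict) else next(iterator)
        return out

    return build(template)


params = {
    "Dense_1": {"kernel": np.zeros((256, 128)), "bias": np.zeros(128)},
    "Dense_0": {"kernel": np.zeros((784, 256)), "bias": np.zeros(256)},
}
flat = flatten(params)
print([a.shape for a in flat])  # [(256,), (784, 256), (128,), (256, 128)]
```

A fixed traversal order matters in federated learning: every node must list the arrays in the same order, or aggregation would mix unrelated layers.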
- set_parameters(params)
Set the parameters of the model.
- Parameters:
params (Union[List[ndarray], bytes]) – The parameters of the model.
- Raises:
ModelNotMatchingError – If the parameters don't match the model.
- Return type:
None
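`set_parameters` accepts either a list of arrays or `bytes` and raises `ModelNotMatchingError` on a mismatch. The sketch below illustrates that contract under stated assumptions: `bytes` input is assumed to be a pickled list of arrays, "matching" is assumed to mean the same number of arrays with the same shapes, and `ModelNotMatchingError` is a local stand-in class. To stay self-contained, the hypothetical `set_parameters` function returns the validated arrays instead of mutating a model.

```python
import pickle
from typing import List, Union

import numpy as np


class ModelNotMatchingError(Exception):
    """Local stand-in for p2pfl's exception of the same name."""


def set_parameters(
    current: List[np.ndarray], new: Union[List[np.ndarray], bytes]
) -> List[np.ndarray]:
    # Assumption: bytes input is a pickled list of arrays.
    if isinstance(new, bytes):
        new = pickle.loads(new)
    if len(new) != len(current):
        raise ModelNotMatchingError(f"expected {len(current)} arrays, got {len(new)}")
    for i, (old, arr) in enumerate(zip(current, new)):
        if old.shape != np.shape(arr):
            raise ModelNotMatchingError(
                f"array {i}: expected shape {old.shape}, got {np.shape(arr)}"
            )
    # Copy (and cast to the model's dtypes) so later changes to `new` cannot affect the model.
    return [np.array(arr, dtype=old.dtype) for old, arr in zip(current, new)]


model_params = [np.zeros((784, 256)), np.zeros(256)]
updated = set_parameters(model_params, pickle.dumps([np.ones((784, 256)), np.ones(256)]))
try:
    set_parameters(model_params, [np.ones((10, 10)), np.ones(256)])
except ModelNotMatchingError as err:
    print(err)  # array 0: expected shape (784, 256), got (10, 10)
```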
- class p2pfl.learning.frameworks.flax.flax_model.MLP(hidden_sizes=(256, 128), out_channels=10, parent=<flax.linen.module._Sentinel object>, name=None)
Bases: Module
Multilayer Perceptron (MLP) for MNIST classification using Flax.
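To make `hidden_sizes` and `out_channels` concrete, the sketch below traces the computation such an MLP performs with the defaults `(256, 128)` and `10`, written in plain NumPy rather than Flax. The ReLU activation, the flattening of 28x28 MNIST images to 784 features, and the random placeholder weights are assumptions for illustration (a real module learns its weights); the source does not state the activation or layer details, and `mlp_forward` is a hypothetical helper.

```python
from typing import Sequence

import numpy as np


def mlp_forward(
    x: np.ndarray,
    hidden_sizes: Sequence[int] = (256, 128),
    out_channels: int = 10,
    seed: int = 0,
) -> np.ndarray:
    rng = np.random.default_rng(seed)
    # Flatten each image: (batch, 28, 28) -> (batch, 784).
    h = x.reshape(x.shape[0], -1)
    # One dense layer per hidden size, each followed by ReLU (assumed activation).
    for width in hidden_sizes:
        w = rng.normal(0.0, 0.01, size=(h.shape[1], width))
        b = np.zeros(width)
        h = np.maximum(h @ w + b, 0.0)
    # Final dense layer produces one logit per class, with no activation.
    w_out = rng.normal(0.0, 0.01, size=(h.shape[1], out_channels))
    return h @ w_out + np.zeros(out_channels)


batch = np.random.default_rng(1).random((4, 28, 28))
logits = mlp_forward(batch)
print(logits.shape)  # (4, 10)
```

With the defaults, activations go 784 -> 256 -> 128 -> 10, so `out_channels` matches the ten MNIST digit classes.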
- name: Optional[str] = None
- out_channels: int = 10
- parent: Union[Module, Scope, _Sentinel, None] = None
- scope: Scope | None = None
-
name: