p2pfl.learning.frameworks.flax.flax_dataset module
Flax Dataset export strategy.
- class p2pfl.learning.frameworks.flax.flax_dataset.FlaxExportStrategy
Bases: DataExportStrategy
Export strategy for JAX/Flax datasets.
- static export(data, batch_size=None, num_workers=0, **kwargs)
Export the data using the JAX/Flax strategy.
- Parameters:
data (Dataset) – The data to export. Transforms should already be applied to the dataset via set_transform.
batch_size (int | None) – The batch size to use for the exported data.
num_workers (int) – The number of workers to use for the exported data.
kwargs – Additional keyword arguments.
- Return type:
Generator[tuple[ndarray, ndarray], Any, None]
- Returns:
The exported data.
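
The return type above says that export yields (ndarray, ndarray) tuples, one per batch. The sketch below illustrates that contract with a hypothetical stand-in, export_batches, built on plain NumPy arrays. It is not the p2pfl implementation: the function name, the toy data, and the treatment of batch_size=None as "one batch containing everything" are all assumptions made for illustration.

```python
from collections.abc import Generator
from typing import Any, Optional

import numpy as np


def export_batches(
    features: np.ndarray,
    labels: np.ndarray,
    batch_size: Optional[int] = None,
) -> Generator[tuple[np.ndarray, np.ndarray], Any, None]:
    """Hypothetical stand-in for FlaxExportStrategy.export (not p2pfl code)."""
    n = len(features)
    # Assumption: batch_size=None yields the whole dataset as a single batch.
    step = batch_size if batch_size is not None else max(n, 1)
    for start in range(0, n, step):
        # Each yielded item is a (features, labels) pair of NumPy arrays.
        yield features[start:start + step], labels[start:start + step]


# Toy data: 10 samples with 4 features each, plus one integer label per sample.
x = np.arange(40, dtype=np.float32).reshape(10, 4)
y = np.arange(10, dtype=np.int32)

# Consume the generator the way a training loop would, recording batch shapes.
batch_shapes = [(xb.shape, yb.shape) for xb, yb in export_batches(x, y, batch_size=4)]
print(batch_shapes)
```

With the real class, the corresponding call would be FlaxExportStrategy.export(data, batch_size=4), where data is a Dataset whose transforms have already been set. Because the result is a generator, it is exhausted after one pass, so a multi-epoch loop would presumably need to call export again for each epoch.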