p2pfl.learning.frameworks.flax.flax_dataset module
Flax Dataset export strategy.
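The module's name, and the `DataExportStrategy` base class of `FlaxExportStrategy`, suggest a strategy pattern: each framework supplies its own subclass with a static `export` method, so calling code can convert a dataset without knowing how a particular framework wants its batches built. The framework-free sketch below illustrates that idea only. `ExportStrategy`, `ListExportStrategy`, and `export_with` are hypothetical names that are not part of p2pfl, and the real base class's interface may differ.

```python
from abc import ABC, abstractmethod
from typing import Any, Iterable, Iterator, List, Type


class ExportStrategy(ABC):
    """Hypothetical stand-in for a DataExportStrategy-style base class."""

    @staticmethod
    @abstractmethod  # abstractmethod must be the innermost decorator
    def export(data: Iterable[Any], batch_size: int = 1, **kwargs: Any) -> Any:
        """Convert `data` into the format a given framework consumes."""


class ListExportStrategy(ExportStrategy):
    """Hypothetical strategy that groups items into plain Python lists."""

    @staticmethod
    def export(data: Iterable[Any], batch_size: int = 1, **kwargs: Any) -> Iterator[List[Any]]:
        batch: List[Any] = []
        for item in data:
            batch.append(item)
            if len(batch) == batch_size:
                yield batch
                batch = []
        if batch:  # emit the final, possibly smaller batch
            yield batch


def export_with(strategy: Type[ExportStrategy], data: Iterable[Any], **kwargs: Any) -> Any:
    # The caller depends only on the shared static `export` signature,
    # so swapping frameworks means passing a different strategy class.
    return strategy.export(data, **kwargs)
```

For example, `list(export_with(ListExportStrategy, range(5), batch_size=2))` returns `[[0, 1], [2, 3], [4]]`.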
- class p2pfl.learning.frameworks.flax.flax_dataset.FlaxExportStrategy
Bases: DataExportStrategy
Export strategy for JAX/Flax datasets.
- static export(data, transforms=None, batch_size=1, num_workers=0, **kwargs)
Export the data as a generator of batches, each a tuple of two JAX arrays.
- Parameters:
  - data (Dataset) – The data to export.
  - transforms (Optional[Callable]) – The transforms to apply to the data.
  - batch_size (int) – The batch size to use for the exported data.
  - num_workers (int) – The number of workers to use for the exported data.
  - kwargs – Additional keyword arguments.
- Return type: Generator[Tuple[Array, Array], Any, None]
- Returns: The exported data.
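To make the documented return type concrete, here is a minimal sketch of a generator with the same shape, `Generator[Tuple[Array, Array], Any, None]`. It assumes the samples are already available as NumPy feature and label arrays, which may not match how p2pfl's `Dataset` stores them. `export_batches` is a hypothetical helper, not the library's implementation: it applies `transforms` once per batch (the real method may apply them per sample) and ignores `num_workers`, iterating in a single process.

```python
from typing import Any, Callable, Generator, Optional, Tuple

import jax
import jax.numpy as jnp
import numpy as np


def export_batches(
    features: np.ndarray,
    labels: np.ndarray,
    transforms: Optional[Callable[[np.ndarray], np.ndarray]] = None,
    batch_size: int = 1,
) -> Generator[Tuple[jax.Array, jax.Array], Any, None]:
    """Hypothetical sketch: yield (inputs, labels) batches as JAX arrays."""
    # Checked lazily, on the first call to next(), because this is a generator.
    if len(features) != len(labels):
        raise ValueError("features and labels must have the same length")
    if batch_size < 1:
        raise ValueError("batch_size must be >= 1")
    for start in range(0, len(features), batch_size):
        x = features[start : start + batch_size]
        y = labels[start : start + batch_size]
        if transforms is not None:
            x = transforms(x)  # applied to the whole batch in this sketch
        # Convert the host NumPy slices into JAX arrays.
        yield jnp.asarray(x), jnp.asarray(y)


features = np.arange(10, dtype=np.float32).reshape(5, 2)
labels = np.array([0, 1, 0, 1, 1], dtype=np.int32)
for x, y in export_batches(features, labels, transforms=lambda a: a / 10.0, batch_size=2):
    print(x.shape, y.shape)
```

With five samples and `batch_size=2`, the sketch yields batches of sizes 2, 2, and 1. Keeping the smaller final batch rather than dropping it is a choice made for this sketch; a training loop compiled with `jax.jit` may prefer fixed batch shapes to avoid recompilation.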