p2pfl.learning.frameworks.flax.flax_dataset module

Flax Dataset export strategy.

class p2pfl.learning.frameworks.flax.flax_dataset.FlaxExportStrategy

Bases: DataExportStrategy

Export strategy for JAX/Flax datasets.

static export(data, transforms=None, batch_size=1, num_workers=0, **kwargs)

Export the data for use with JAX/Flax, as a generator of batches of JAX arrays.

Parameters:
  • data (Dataset) – The data to export.

  • transforms (Optional[Callable]) – The transforms to apply to the data.

  • batch_size (int) – The batch size to use for the exported data.

  • num_workers (int) – The number of workers to use for the exported data.

  • kwargs – Additional keyword arguments.

Return type:

Generator[Tuple[Array, Array], Any, None]

Returns:

The exported data, yielded one batch at a time as a tuple of two JAX arrays.
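The following is a minimal sketch of what a generator with this return type can look like. It is not p2pfl's implementation. The helper name export_batches is hypothetical. It assumes the data is an in-memory sequence of (features, label) pairs rather than the real Dataset object, and it applies transforms to the features only. num_workers and **kwargs are left out because a plain Python generator has no worker pool.

```python
from typing import Any, Callable, Generator, Optional, Sequence, Tuple

import jax
import jax.numpy as jnp


def export_batches(
    examples: Sequence[Tuple[Any, Any]],
    transforms: Optional[Callable[[Any], Any]] = None,
    batch_size: int = 1,
) -> Generator[Tuple[jax.Array, jax.Array], Any, None]:
    """Yield (inputs, labels) batches as JAX arrays.

    Hypothetical helper: `examples` stands in for the real Dataset and is
    assumed to be a sequence of (features, label) pairs.
    """
    if batch_size < 1:
        raise ValueError("batch_size must be >= 1")
    for start in range(0, len(examples), batch_size):
        # The final chunk may be smaller than batch_size.
        chunk = examples[start : start + batch_size]
        # Apply the optional transform to each example's features (an assumption).
        xs = [transforms(x) if transforms is not None else x for x, _ in chunk]
        ys = [y for _, y in chunk]
        # Stack per-example features into one array with a leading batch axis.
        yield jnp.stack([jnp.asarray(x) for x in xs]), jnp.asarray(ys)


# Usage: 5 examples with batch_size=2 give batches of 2, 2 and 1.
data = [([float(i), float(i) + 1.0], i % 2) for i in range(5)]
batches = list(export_batches(data, transforms=lambda x: [v / 10.0 for v in x], batch_size=2))
print(len(batches))          # 3
print(batches[0][0].shape)   # (2, 2)
print(batches[-1][0].shape)  # (1, 2)
```

Because the result is a generator, batches are produced lazily. A training loop can iterate over it directly and pass each (inputs, labels) pair to a jitted Flax step without materializing the whole dataset as one array.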