p2pfl.learning.frameworks.flax.flax_dataset module

Flax Dataset export strategy.

class p2pfl.learning.frameworks.flax.flax_dataset.FlaxExportStrategy

Bases: DataExportStrategy

Export strategy for JAX/Flax datasets.

static export(data, batch_size=None, num_workers=0, **kwargs)

Export the data using the JAX/Flax strategy.

Parameters:
  • data (Dataset) – The data to export. Transforms should already be applied to the dataset via set_transform.

  • batch_size (Optional[int]) – The batch size to use for the exported data.

  • num_workers (int) – The number of workers to use for the exported data.

  • kwargs – Additional keyword arguments.

Return type:

Generator[Tuple[Array, Array], Any, None]

Returns:

A generator that yields the exported data as tuples of two arrays.
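The following is a minimal sketch of the kind of batching such an export could perform; it is not the p2pfl implementation. It assumes each row of the already-transformed dataset is a dict with hypothetical "image" and "label" keys, treats batch_size=None as "one batch containing every row" (an assumption; the real default behaviour may differ), and ignores num_workers. NumPy arrays stand in for JAX arrays; to match the documented Generator[Tuple[Array, Array], Any, None] return type, each output could be wrapped with jax.numpy.asarray.

```python
from typing import Any, Dict, Generator, Iterable, List, Optional, Tuple

import numpy as np

# Hypothetical column names; the real dataset may use different keys.
FEATURE_KEY = "image"
LABEL_KEY = "label"


def export_batches(
    rows: Iterable[Dict[str, Any]],
    batch_size: Optional[int] = None,
) -> Generator[Tuple[np.ndarray, np.ndarray], None, None]:
    """Group already-transformed rows into (features, labels) batches.

    Assumption for this sketch: batch_size=None yields a single batch
    holding every row.
    """
    if batch_size is not None and batch_size <= 0:
        raise ValueError("batch_size must be positive")
    xs: List[Any] = []
    ys: List[Any] = []
    for row in rows:
        xs.append(row[FEATURE_KEY])
        ys.append(row[LABEL_KEY])
        # Emit a full batch as soon as enough rows have been collected.
        if batch_size is not None and len(xs) == batch_size:
            yield np.stack(xs), np.asarray(ys)
            xs, ys = [], []
    # Emit the final, possibly smaller, batch.
    if xs:
        yield np.stack(xs), np.asarray(ys)
```

With five rows and batch_size=2, this sketch yields three batches whose feature arrays have 2, 2 and 1 rows respectively; the last batch is kept rather than dropped, which is a design choice of the sketch.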