p2pfl.learning.frameworks.pytorch.lightning_dataset module

PyTorch dataset integration.

class p2pfl.learning.frameworks.pytorch.lightning_dataset.PyTorchExportStrategy

Bases: DataExportStrategy

Export strategy for PyTorch tensors.

static export(data, transforms=None, batch_size=1, num_workers=0, **kwargs)

Export the data as a PyTorch DataLoader.

Parameters:
  • data (Dataset) – The data to export.

  • transforms (Optional[Callable]) – The transforms to apply to the data.

  • batch_size (int) – The batch size to use for the exported data.

  • num_workers (int) – The number of workers to use for the exported data.

  • kwargs – Additional keyword arguments.

Return type:

DataLoader

Returns:

The exported data.
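A minimal sketch of what this export plausibly does, assuming it wraps a map-style torch.utils.data.Dataset in a torch.utils.data.DataLoader and forwards batch_size and num_workers to it. The names export_sketch and _TransformedDataset, the per-item application of transforms, and the forwarding of **kwargs to the DataLoader are all hypothetical choices for illustration, not p2pfl's implementation; the Dataset type that p2pfl actually passes in may also differ.

```python
from typing import Callable, Optional

import torch
from torch.utils.data import DataLoader, Dataset, TensorDataset


class _TransformedDataset(Dataset):
    """Hypothetical helper: applies a callable to each item of a wrapped dataset."""

    def __init__(self, base: Dataset, fn: Callable) -> None:
        self.base = base
        self.fn = fn

    def __len__(self) -> int:
        return len(self.base)

    def __getitem__(self, idx: int):
        # The transform runs lazily, once per item, when the loader fetches it.
        return self.fn(self.base[idx])


def export_sketch(
    data: Dataset,
    transforms: Optional[Callable] = None,
    batch_size: int = 1,
    num_workers: int = 0,
    **kwargs,
) -> DataLoader:
    """Hypothetical stand-in for PyTorchExportStrategy.export (not p2pfl's code)."""
    if transforms is not None:
        # Assumption: transforms are applied item by item before batching.
        data = _TransformedDataset(data, transforms)
    # Assumption: extra keyword arguments (e.g. shuffle=True) go to the DataLoader.
    return DataLoader(data, batch_size=batch_size, num_workers=num_workers, **kwargs)


# Usage: 10 one-feature samples with binary labels, exported in batches of 4.
features = torch.arange(10, dtype=torch.float32).unsqueeze(1)
labels = torch.arange(10) % 2
loader = export_sketch(TensorDataset(features, labels), batch_size=4)
batch_sizes = [x.shape[0] for x, _ in loader]  # [4, 4, 2]
```

Because DataLoader keeps the final partial batch by default (drop_last=False), 10 samples with batch_size=4 yield batches of 4, 4 and 2.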

class p2pfl.learning.frameworks.pytorch.lightning_dataset.TorchvisionDatasetFactory

Bases: object

Factory class for loading torchvision datasets in P2PFL.

static get_mnist(cache_dir, train=True, download=True)

Get the MNIST dataset from torchvision.

Parameters:
  • cache_dir (Union[str, Path]) – The directory where the dataset will be stored.

  • train (bool) – If True, get the training split; otherwise, get the test split.

  • download (bool) – Whether to download the dataset.

Return type:

P2PFLDataset