p2pfl.learning.frameworks.pytorch.lightning_dataset module
PyTorch dataset integration.
- class p2pfl.learning.frameworks.pytorch.lightning_dataset.PyTorchExportStrategy
Bases: DataExportStrategy
Export strategy for PyTorch tensors.
- static export(data, transforms=None, batch_size=1, num_workers=0, **kwargs)
Export the data using the PyTorch strategy.
- Parameters:
  - data (Dataset) – The data to export.
  - transforms (Optional[Callable]) – The transforms to apply to the data.
  - batch_size (int) – The batch size to use for the exported data.
  - num_workers (int) – The number of workers to use for the exported data.
  - kwargs – Additional keyword arguments.
- Return type: DataLoader
- Returns: The exported data.
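To illustrate what the parameters above control, here is a minimal sketch of an export step that wraps a dataset, applies an optional transform, and returns a `DataLoader`. It is not the p2pfl implementation: `ToyDataset`, `TransformedDataset`, and `export_sketch` are hypothetical names, and the sketch assumes `data` is a `torch.utils.data.Dataset` yielding `(feature, label)` pairs, which may differ from the `Dataset` type the library actually expects.

```python
from typing import Callable, Optional

import torch
from torch.utils.data import DataLoader, Dataset


class ToyDataset(Dataset):
    """Hypothetical in-memory dataset of (feature, label) pairs."""

    def __init__(self, n: int = 10) -> None:
        self.x = torch.arange(n, dtype=torch.float32).unsqueeze(1)
        self.y = torch.arange(n) % 2

    def __len__(self) -> int:
        return len(self.x)

    def __getitem__(self, idx: int):
        return self.x[idx], self.y[idx]


class TransformedDataset(Dataset):
    """Applies an optional transform to each feature when it is accessed."""

    def __init__(self, base: Dataset, transforms: Optional[Callable] = None) -> None:
        self.base = base
        self.transforms = transforms

    def __len__(self) -> int:
        return len(self.base)

    def __getitem__(self, idx: int):
        x, y = self.base[idx]
        if self.transforms is not None:
            x = self.transforms(x)
        return x, y


def export_sketch(
    data: Dataset,
    transforms: Optional[Callable] = None,
    batch_size: int = 1,
    num_workers: int = 0,
    **kwargs,
) -> DataLoader:
    """Illustrative stand-in for an export step (assumption, not p2pfl code)."""
    wrapped = TransformedDataset(data, transforms)
    # Extra keyword arguments are forwarded to DataLoader (e.g. shuffle=True).
    return DataLoader(wrapped, batch_size=batch_size, num_workers=num_workers, **kwargs)


# Ten samples with batch_size=4 yield three batches (4, 4, and 2 samples).
loader = export_sketch(ToyDataset(10), transforms=lambda x: x * 2.0, batch_size=4)
batches = [(xb.shape, yb.shape) for xb, yb in loader]
```

In this sketch `num_workers=0` keeps loading in the main process, which also avoids pickling the lambda transform; with worker processes a module-level function would be the safer choice for `transforms`.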
- class p2pfl.learning.frameworks.pytorch.lightning_dataset.TorchvisionDatasetFactory
Bases: object
Factory class for loading torchvision datasets in P2PFL.
- static get_mnist(cache_dir, train=True, download=True)
Get the MNIST dataset from torchvision.
- Parameters:
  - cache_dir (Union[str, Path]) – The directory where the dataset will be stored.
  - train (bool) – Whether to get the training or test dataset.
  - download (bool) – Whether to download the dataset.
- Return type: