p2pfl.learning.frameworks.pytorch.lightning_dataset module
PyTorch dataset integration.
- class p2pfl.learning.frameworks.pytorch.lightning_dataset.PyTorchExportStrategy
Bases: DataExportStrategy
Export strategy for PyTorch tensors.
- static export(data, batch_size=None, num_workers=0, **kwargs)
Export the data as a PyTorch DataLoader.
- Parameters:
data (Dataset) – The data to export. Transforms should already be applied to the dataset via set_transform.
batch_size (int | None) – The batch size to use for the exported data.
num_workers (int) – The number of workers to use for loading the exported data.
kwargs – Additional keyword arguments.
- Return type:
DataLoader
- Returns:
The exported data.
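As a rough illustration of what exporting to a DataLoader involves, the sketch below wraps an already-transformed dataset in a torch `DataLoader`. It is a minimal sketch under stated assumptions, not p2pfl's implementation: the function name `export_to_dataloader`, the fallback `default_batch_size=32`, and the use of a `TensorDataset` as input are all hypothetical. It does reflect one real PyTorch behavior worth handling: `DataLoader` treats `batch_size=None` as "disable automatic batching", so a `None` batch size needs an explicit fallback if batched output is wanted.

```python
from typing import Optional

import torch
from torch.utils.data import DataLoader, Dataset, TensorDataset


def export_to_dataloader(
    data: Dataset,
    batch_size: Optional[int] = None,
    num_workers: int = 0,
    default_batch_size: int = 32,  # hypothetical fallback, not p2pfl's setting
) -> DataLoader:
    """Hypothetical stand-in for PyTorchExportStrategy.export.

    Assumes the export simply wraps an already-transformed dataset
    in a torch DataLoader.
    """
    # DataLoader(batch_size=None) would yield individual samples instead of
    # batches, so fall back to an explicit (assumed) default here.
    if batch_size is None:
        batch_size = default_batch_size
    return DataLoader(data, batch_size=batch_size, num_workers=num_workers)


# Ten placeholder 28x28 single-channel images with integer labels 0..9.
images = torch.zeros(10, 1, 28, 28)
labels = torch.arange(10)
loader = export_to_dataloader(TensorDataset(images, labels), batch_size=4)

# Collect the size of each batch; the last batch holds the remainder.
batch_sizes = [x.shape[0] for x, _ in loader]
print(batch_sizes)  # [4, 4, 2]
```

With 10 samples and `batch_size=4`, the loader yields three batches, the last one partial because `drop_last` defaults to `False`. Setting `num_workers` above 0 only changes how batches are loaded (in worker subprocesses), not their contents.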
- class p2pfl.learning.frameworks.pytorch.lightning_dataset.TorchvisionDatasetFactory
Bases: object
Factory class for loading torchvision datasets in P2PFL.
- static get_mnist(cache_dir, train=True, download=True)
Get the MNIST dataset from torchvision.
- Parameters:
cache_dir (str | Path) – The directory where the dataset will be stored.
train (bool) – Whether to get the training dataset (True) or the test dataset (False).
download (bool) – Whether to download the dataset.
- Return type: