Source code for p2pfl.learning.frameworks.tensorflow.keras_dataset
## This file is part of the federated_learning_p2p (p2pfl) distribution# (see https://github.com/pguijas/p2pfl).# Copyright (c) 2024 Pedro Guijas Bravo.## This program is free software: you can redistribute it and/or modify# it under the terms of the GNU General Public License as published by# the Free Software Foundation, version 3.## This program is distributed in the hope that it will be useful, but# WITHOUT ANY WARRANTY; without even the implied warranty of# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU# General Public License for more details.## You should have received a copy of the GNU General Public License# along with this program. If not, see <http://www.gnu.org/licenses/>.#"""Keras dataset export strategy."""fromtypingimportCallable,List,Optionalimporttensorflowastf# type: ignorefromdatasetsimportDataset# type: ignorefromp2pfl.learning.dataset.p2pfl_datasetimportDataExportStrategy
class KerasExportStrategy(DataExportStrategy):
    """Export strategy for TensorFlow/Keras datasets."""

    @staticmethod
    def export(
        data: Dataset,
        transforms: Optional[Callable] = None,
        batch_size: int = 1,
        columns: Optional[List[str]] = None,
        label_cols: Optional[List[str]] = None,
        **kwargs,
    ) -> tf.data.Dataset:
        """
        Export the data as a TensorFlow Dataset.

        Args:
            data: The Hugging Face Dataset to export.
            transforms: Optional transformations to apply. Not supported yet;
                any value other than ``None`` raises ``NotImplementedError``.
            batch_size: The batch size for the TensorFlow Dataset.
            columns: The columns to include as features. Defaults to
                ``["image"]`` when ``None``.
            label_cols: The columns to use as labels. Defaults to
                ``["label"]`` when ``None``.
            **kwargs: Additional keyword arguments. They are accepted but
                currently ignored (presumably so callers can pass a common
                set of options to every export strategy — confirm against
                ``DataExportStrategy``).

        Returns:
            The TensorFlow Dataset built by ``data.to_tf_dataset``.

        Raises:
            NotImplementedError: If ``transforms`` is not ``None``.

        """
        # Reject unsupported options before doing any other work.
        if transforms is not None:
            raise NotImplementedError("Transforms are not yet supported for KerasExportStrategy.")

        # Defaults are resolved here instead of in the signature so that each
        # call gets a fresh list (no shared mutable default arguments).
        if label_cols is None:
            label_cols = ["label"]
        if columns is None:
            columns = ["image"]

        # Export Keras dataset
        return data.to_tf_dataset(
            batch_size=batch_size,
            columns=columns,
            label_cols=label_cols,
        )