微信公众号搜"智元新知"关注
微信扫一扫可直接关注哦!

如何在Tensorflow Fedarated中加载Fashion MNIST数据集?

如何解决如何在Tensorflow Fedarated中加载Fashion MNIST数据集?

我正在与Tensorflow联合进行一个项目。我已经设法使用TensorFlow联合学习模拟提供的库来加载,训练和测试一些数据集。

例如,我加载emnist数据集

emnist_train,emnist_test = tff.simulation.datasets.emnist.load_data()

,它获得了load_data()作为tff.simulation.ClientData实例返回的数据集。该界面使我可以遍历客户端ID,并允许我选择数据的子集进行仿真。

len(emnist_train.client_ids)

3383


emnist_train.element_type_structure


OrderedDict([('pixels',TensorSpec(shape=(28,28),dtype=tf.float32,name=None)),('label',TensorSpec(shape=(),dtype=tf.int32,name=None))])


example_dataset = emnist_train.create_tf_dataset_for_client(
    emnist_train.client_ids[0])

我正在尝试使用Keras加载fashion_mnist数据集以执行一些联合操作:

fashion_train,fashion_test=tf.keras.datasets.fashion_mnist.load_data()

但是我得到这个错误

AttributeError: 'tuple' object has no attribute 'element_spec'

因为Keras返回一个Numpy数组的元组,而不是像以前那样返回tff.simulation.ClientData:

def tff_model_fn() -> tff.learning.Model:
    return tff.learning.from_keras_model(
        keras_model=factory.retrieve_model(True),input_spec=fashion_test.element_spec,loss=loss_builder(),metrics=metrics_builder())

iterative_process = tff.learning.build_federated_averaging_process(
    tff_model_fn,Parameters.server_adam_optimizer_fn,Parameters.client_adam_optimizer_fn)
server_state = iterative_process.initialize()

总结一下,

  1. 是否可以通过Keras Tuple Numpy数组创建tff.simulation.ClientData元组元素?

  2. 我想到的另一种解决方案是使用 tff.simulation.HDF5ClientData并加载 手动以HDF5格式(train.h5,test.h5)的适当文件获取tff.simulation.ClientData,但是我的问题是我找不到fashion_mnist HDF5文件格式的网址,我的意思是就像训练和测试一样:

          fileprefix = 'fed_emnist_digitsonly'
          sha256 = '55333deb8546765427c385710ca5e7301e16f4ed8b60c1dc5ae224b42bd5b14b'
          filename = fileprefix + '.tar.bz2'
          path = tf.keras.utils.get_file(
              filename,origin='https://storage.googleapis.com/tff-datasets-public/' + filename,file_hash=sha256,hash_algorithm='sha256',extract=True,archive_format='tar',cache_dir=cache_dir)
    
          dir_path = os.path.dirname(path)
          train_client_data = hdf5_client_data.HDF5ClientData(
              os.path.join(dir_path,fileprefix + '_train.h5'))
          test_client_data = hdf5_client_data.HDF5ClientData(
              os.path.join(dir_path,fileprefix + '_test.h5'))
    
          return train_client_data,test_client_data
    

我的最终目标是使fashion_mnist数据集与TensorFlow联合学习一起工作。

解决方法

您在正确的轨道上。回顾一下:tff.simulation.dataset API返回的数据集是tff.simulation.ClientData对象。 tf.keras.datasets.fashion_mnist.load_data返回的对象是tuple个numpy数组。

因此,需要实现一个tff.simulation.ClientData来包装tf.keras.datasets.fashion_mnist.load_data返回的数据集。以前有关实现ClientData对象的一些问题:

这确实需要回答一个重要的问题:Fashion MNIST数据应如何划分为单个用户?数据集不包含可用于分区的要素。研究人员提出了几种方法来对数据进行综合分区,例如为每个参与者随机抽取一些标签,但这将对模型训练产生很大影响,并且有助于在此处进行一些思考。

版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。

相关推荐


Selenium Web驱动程序和Java。元素在(x,y)点处不可单击。其他元素将获得点击?
Python-如何使用点“。” 访问字典成员?
Java 字符串是不可变的。到底是什么意思?
Java中的“ final”关键字如何工作?(我仍然可以修改对象。)
“loop:”在Java代码中。这是什么,为什么要编译?
java.lang.ClassNotFoundException:sun.jdbc.odbc.JdbcOdbcDriver发生异常。为什么?
这是用Java进行XML解析的最佳库。
Java的PriorityQueue的内置迭代器不会以任何特定顺序遍历数据结构。为什么?
如何在Java中聆听按键时移动图像。
Java“Program to an interface”。这是什么意思?