How to build a federated_averaging_process from a custom federated dataset loaded from a CSV file

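My question is a follow-up to this one: How to create federated dataset from a CSV file?

I managed to load a federated dataset from a given CSV file and to load both the training and the test data.

My question now is how to reproduce a working example that builds an iterative process performing custom federated averaging on this data.

Here is my code, but it does not work: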

import collections
import os

import numpy as np
import pandas as pd
import tensorflow as tf
import tensorflow_federated as tff
from absl import app
from tensorflow.keras import layers

from src.main import Parameters


def main(args):
    working_dir = "D:/User/Documents/GitHub/TriaBaseMLBackup/input/fakehdfs/nms/ystr=2016/ymstr=1/ymdstr=26"
    client_id_colname = 'counter'
    SHUFFLE_BUFFER = 1000
    NUM_EPOCHS = 1

    for root,dirs,files in os.walk(working_dir):
        file_list = []

        for filename in files:
            if filename.endswith('.csv'):
                file_list.append(os.path.join(root,filename))
        df_list = []
        for file in file_list:
            df = pd.read_csv(file,delimiter="|",usecols=[1,2,6,7],header=None,na_values=["NIL"],na_filter=True,names=["meas_info","counter","value","time"],index_col='time')
            df_list.append(df[["value"]])

        if df_list:
            rawdata = pd.concat(df_list)

    client_ids = df.get(client_id_colname)
    train_client_ids = client_ids.sample(frac=0.5).tolist()
    test_client_ids = [x for x in client_ids if x not in train_client_ids]

    def create_tf_dataset_for_client_fn(client_id):
        # a function which takes a client_id and returns a
        # tf.data.Dataset for that client
        client_data = df[df['value'] == client_id]
        features = ['meas_info','counter']
        LABEL_COLUMN = 'value'
        dataset = tf.data.Dataset.from_tensor_slices(
            (collections.OrderedDict(client_data[features].to_dict('list')),client_data[LABEL_COLUMN].to_list())
        )
        global input_spec
        input_spec = dataset.element_spec
        dataset = dataset.shuffle(SHUFFLE_BUFFER).batch(1).repeat(NUM_EPOCHS)
        return dataset

    train_data = tff.simulation.ClientData.from_clients_and_fn(
        client_ids=train_client_ids,create_tf_dataset_for_client_fn=create_tf_dataset_for_client_fn
    )
    test_data = tff.simulation.ClientData.from_clients_and_fn(
        client_ids=test_client_ids,create_tf_dataset_for_client_fn=create_tf_dataset_for_client_fn
    )
    example_dataset = train_data.create_tf_dataset_for_client(
        train_data.client_ids[0]
    )
    # split client id into train and test clients
    loss_builder = tf.keras.losses.SparseCategoricalCrossentropy
    metrics_builder = lambda: [tf.keras.metrics.SparseCategoricalAccuracy()]
    tff_model = tf.keras.Sequential([
        layers.Dense(64),layers.Dense(1)
    ])

    def retrieve_model():
        model = tf.keras.models.Sequential([
            tf.keras.layers.LSTM(2,input_shape=(1,2),return_sequences=True),
            tf.keras.layers.Dense(256,activation=tf.nn.relu),
            tf.keras.layers.Activation(tf.nn.softmax),
        ])

        return model

    def tff_model_fn() -> tff.learning.Model:
        return tff.learning.from_keras_model(
            keras_model=retrieve_model(),input_spec=example_dataset.element_spec,loss=loss_builder(),metrics=metrics_builder())

    iterative_process = tff.learning.build_federated_averaging_process(
        tff_model_fn,Parameters.server_adam_optimizer_fn,Parameters.client_adam_optimizer_fn)
    server_state = iterative_process.initialize()

    for round_num in range(Parameters.FLAGS.total_rounds):
        sampled_clients = np.random.choice(
            train_data.client_ids,size=Parameters.FLAGS.train_clients_per_round,replace=False)
        sampled_train_data = [
            train_data.create_tf_dataset_for_client(client)
            for client in sampled_clients
        ]
        server_state,metrics = iterative_process.next(server_state,sampled_train_data)
        train_metrics = metrics['train']
        print(metrics)


if __name__ == '__main__':
    app.run(main)


def start():
    app.run(main)

This is the error I get, but I think my problem goes beyond this one error. What am I doing wrong here?

ValueError: The top-level structure in `input_spec` must contain exactly two top-level elements,as it must specify type information for both inputs to and predictions from the model. You passed input spec {'meas_info': TensorSpec(shape=(None,),dtype=tf.float32,name=None),'counter': TensorSpec(shape=(None,'value': TensorSpec(shape=(None,name=None)}.

Thanks to @Zachary Garrett, I resolved the error above with his help by adding these lines of code:

        client_data = df[df['value'] == client_id]
        features = ['meas_info','counter']
        LABEL_COLUMN = 'value'
        dataset = tf.data.Dataset.from_tensor_slices(
            (collections.OrderedDict(client_data[features].to_dict('list')),client_data[LABEL_COLUMN].to_list())
        )
        global input_spec
        input_spec = dataset.element_spec
        dataset = dataset.shuffle(SHUFFLE_BUFFER).batch(1).repeat(NUM_EPOCHS)
        return dataset

The problem that tff.learning.build_federated_averaging_process now raises is this:

ValueError: Layer sequential expects 1 inputs,but it received 2 input tensors. Inputs received: [<tf.Tensor 'batch_input:0' shape=() dtype=float32>,<tf.Tensor 'batch_input_1:0' shape=() dtype=float32>]

What am I missing this time? Maybe something in the order of the layers here:

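def retrieve_model():
    model = tf.keras.models.Sequential([
        tf.keras.layers.LSTM(2,input_shape=(1,2),return_sequences=True),
        tf.keras.layers.Dense(256,activation=tf.nn.relu),
        tf.keras.layers.Activation(tf.nn.softmax),
    ])

    return model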

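Solution

Processes in the tff.learning package generally expect datasets that yield sequences (tuples or lists) of the form (x, y). x and y can each be a single tensor or a nested structure of tensors (dict, list, etc.).

A simple way to check the format of a dataset is to print its .element_spec attribute.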
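As a minimal sketch of that check, the snippet below compares the spec of a dict-only dataset with that of an (x, y) dataset. The column values are made up for illustration and are not the asker's data; only the column names come from the question.

```python
import collections

import tensorflow as tf

# Toy columns standing in for the CSV data (values are made up).
columns = collections.OrderedDict(
    meas_info=[1.0, 2.0, 3.0],
    counter=[10.0, 20.0, 30.0],
    value=[0.0, 1.0, 0.0],
)

# A dataset built from a single dict yields one dict per element.
dict_only = tf.data.Dataset.from_tensor_slices(columns)
print(dict_only.element_spec)

# Splitting the features from the label yields a 2-tuple (x, y) per element.
features = collections.OrderedDict(
    (name, columns[name]) for name in ("meas_info", "counter")
)
xy = tf.data.Dataset.from_tensor_slices((features, columns["value"]))
print(xy.element_spec)
```

The first spec is a single mapping with three entries, which is the shape of input the ValueError in the question complains about. The second is a tuple with two top-level elements.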

From the code in the question, I suspect the dataset yields only a dict, because of this line:

dataset = tf.data.Dataset.from_tensor_slices(client_data.to_dict('list'))

This does not separate x (the features) and y (the label) in the way TFF expects. Something like the following may work:

FEATURE_COLUMNS = [...]
LABEL_COLUMN = '...'
dataset = tf.data.Dataset.from_tensor_slices(
  (client_data[FEATURE_COLUMNS].to_dict('list'),client_data[LABEL_COLUMN].to_list())
)
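The follow-up error in the question ("Layer sequential expects 1 inputs, but it received 2 input tensors") is consistent with x being a dict of two tensors, while a Keras Sequential model takes a single input. One possible way around it is to pack the feature columns into one tensor whose per-example shape matches the LSTM's input_shape=(1, 2). This is a sketch, assuming both feature columns are numeric. The DataFrame and its values are hypothetical stand-ins for one client's rows.

```python
import numpy as np
import pandas as pd
import tensorflow as tf

# Hypothetical stand-in for one client's rows (values are made up).
client_data = pd.DataFrame({
    "meas_info": [1.0, 2.0, 3.0, 4.0],
    "counter": [10.0, 20.0, 30.0, 40.0],
    "value": [0, 1, 0, 1],
})
FEATURE_COLUMNS = ["meas_info", "counter"]
LABEL_COLUMN = "value"

# Shape (num_rows, 1, 2): one timestep with two features per example.
x = client_data[FEATURE_COLUMNS].to_numpy(dtype=np.float32).reshape(-1, 1, 2)
y = client_data[LABEL_COLUMN].to_numpy(dtype=np.int32)

# x is now a single tensor, so the model receives one input per batch.
dataset = tf.data.Dataset.from_tensor_slices((x, y)).batch(2)
print(dataset.element_spec)
```

The resulting element_spec is still a 2-tuple, so it can be passed as input_spec to tff.learning.from_keras_model in the same way as in the question's code.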
