如何解决input_fn 具有多个数据集的张量流估计器
我正在尝试构建一个联合学习模型,为此我需要训练多个模型,每个模型都有特定的数据。
然后我会在每一步之后将模型权重的平均值聚合成一个大的。
for round_comm in range(100):
for client in clients:
client_estimator = tf.estimator.Estimator(
model_fn=model_fn
)
# copy aggregated_weights at previous step in model
client_estimator.train(
input_fn=lambda: input_fn()
)
# Retrieve weights
# compute mean of weights
他们每个人都会训练 1 或 2 个 epoch,然后我将这些模型的所有权重聚合为一个大的
例如:
train Model_clientA on dataset_A (1st batch) for 1 epoch
train Model_clientB on dataset_B (1st batch) for 1 epoch
Aggregates every weights in Model_Clients_Agg.
Model_clientA and Model_clientB copy weights from Model_Clients_Agg.
then for next iteration :
train Model_clientA on dataset_A (next batch)
train Model_clientB on dataset_B (next batch)
etc. etc.
我已经试过了:
BATCH_SIZE = 512
EPOCHS = 4
def input_fn(data,epochs,batch_size):
# Convert the inputs to a Dataset.
dataset = tf.data.Dataset.from_tensor_slices(({'client_1':data[0],'client_2':data[1]},data[2]))
# Shuffle,repeat,and batch the examples.
SHUFFLE_SIZE = 1000
dataset = dataset.shuffle(SHUFFLE_SIZE).repeat(epochs).batch(batch_size)
dataset = dataset.prefetch(2)
# Return the dataset.
return dataset
但在这种情况下不合适。
如何提供可以适应多个数据集的 input_fn?
版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。