微信公众号搜"智元新知"关注
微信扫一扫可直接关注哦!

如何重塑 Tensorflow 数据集中的数据?

如何解决如何重塑 Tensorflow 数据集中的数据?

我正在编写一个数据管道,将成批的时间序列序列和相应的标签输入到需要 3D 输入形状的 LSTM 模型中。我目前有以下几点:

def split(window):
    return window[:-label_length],window[-label_length]

dataset = tf.data.Dataset.from_tensor_slices(data.sin)
dataset = dataset.window(input_length + label_length,shift=label_shift,stride=1,drop_remainder=True)
dataset = dataset.flat_map(lambda window: window.batch(input_length + label_length))
dataset = dataset.map(split,num_parallel_calls=tf.data.AUTOTUNE)
dataset = dataset.cache()
dataset = dataset.shuffle(shuffle_buffer,seed=shuffle_seed,reshuffle_each_iteration=False)
dataset = dataset.batch(batch_size=batch_size,drop_remainder=True)
dataset = dataset.prefetch(tf.data.AUTOTUNE)

for x,y in dataset.take(1): x.shape 的结果形状是 (32,20),其中 32 是批量大小,20 是序列的长度,但我需要 (32,20,1) 的形状,其中额外的维度表示特征。

我的问题是如何重塑,理想情况下是在缓存数据之前传递给 split 函数dataset.map 函数中?

解决方法

这很简单。在您的拆分功能中执行此操作

def split(window):
    return window[:-label_length,tf.newaxis],window[-label_length,tf.newaxis,tf.newaxis]

版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。