微信公众号搜"智元新知"关注
微信扫一扫可直接关注哦!

如何在 Keras 中用高阶张量编写激活函数?

如何解决如何在 Keras 中用高阶张量编写激活函数?

我想在 Keras 中创建一个特定的神经网络。在这个神经网络中,我使用由以下给出的层 $$ f(x) = C_k(\underbrace{x,\dots,x)}_{\times k}+\phi(w^\intercal x+b) $$ 表达式 $\phi(w^\intercal x+b)$ 只是带有权重和偏差的正常激活函数$C_k$一个 0-k 张量,即当你给它 $k$ 个副本时class="math-container">$x$ 它返回一个数字。对于 $k=0$ 这意味着 $C_k$一个数字,对于 $k=1$ 这意味着 $C_k(x)$ 等价于 $ c^\intercal x$ 对于某些向量 $c$,对于 $k=2$ > 这意味着 $C_k(x,x)$ 等于 $x^\intercal Ax$对于某些矩阵 $A$

我知道我应该如何实现 $\phi(w^\intercal x+b)$ 部分,以及如何进行加法。我坚持实现 $C_k$。我添加了定义 $C_k$ 的尝试。

class TensorLayer(keras.layers.Layer):
    def __init__(self,order,units=1,**kwargs):
        super().__init__(**kwargs)
        self.units = units
        self.order = order

    def build(self,input_shape):
        # delay deFinition of self.w to build() to make sure the shape matches the input.
        self.w = self.add_weight(
            shape=tuple([input_shape[-1] for _ in range(1,self.order+1)] + [self.units]),initializer="random_normal",trainable=True,)

    def call(self,inputs):
        out = self.w
        print("Order {}: weights are {}".format(self.order,out.shape))
        for i in range(1,self.order+1):

            # out = inputs @ out
            out = out @ inputs
            print("Order {}: after {} shape is {} for input {}".format(self.order,i,out.shape,inputs.shape))
        print("Order {}: output is {}".format(self.order,out.shape))
        return out

    def get_config(self):
        config = super(polynomialLayer,self).get_config()
        config.update({"units": self.order})
        return config

有趣的部分在 call 方法中。如果 inputs 只是一个向量,则 call 方法计算正确的输出在这种情况下,inputs一个形状为 (batch_size,vector_size) 的矩阵,它在构建层时被创建为 (None,vector_size)。然而,这意味着 call 方法的返回是无意义的。

也许当我手动将 inputs 解构为向量时,这会起作用,但这似乎非常低效。到目前为止,我还没有想出替代方案。

如何重新定义有效的 call 方法?任何帮助将不胜感激。

版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。

相关推荐


Selenium Web驱动程序和Java。元素在(x,y)点处不可单击。其他元素将获得点击?
Python-如何使用点“。” 访问字典成员?
Java 字符串是不可变的。到底是什么意思?
Java中的“ final”关键字如何工作?(我仍然可以修改对象。)
“loop:”在Java代码中。这是什么,为什么要编译?
java.lang.ClassNotFoundException:sun.jdbc.odbc.JdbcOdbcDriver发生异常。为什么?
这是用Java进行XML解析的最佳库。
Java的PriorityQueue的内置迭代器不会以任何特定顺序遍历数据结构。为什么?
如何在Java中聆听按键时移动图像。
Java“Program to an interface”。这是什么意思?