Tensorflow 最小化功能无法正常工作

如何解决Tensorflow 最小化功能无法正常工作

我正在尝试像这样应用 tensorflow 的最小化函数

train_op = optimizer_dict[optimizer][0](*optimizer_dict[optimizer][1]).minimize(cost)

然而,自从 tensorflow 更新后,代码的要求似乎发生了变化。我已尝试适应新版本,但出现错误

Shape must be rank 1 but is rank 2 for '{{node BiasAdd_589}} = BiasAdd[T=DT_FLOAT,data_format="NHWC"](Placeholder_398,BiasAdd_589/ReadVariableOp)' with input shapes: [?,8],[8,1000].

我在下面提供了完整的功能,以便您可以了解问题的完整上下文。如果需要提供任何其他信息,请告诉我。

def train_tensorflow(sess,trX,trY,train_steps,full_train,train_size,net_type,transform_dict,loss_type,optimizer,optimizer_dict):
    '''
    Automatically constructs,trains,and tests a tensorflow neural network,returning the r squared value of the output.
    :param sess: A tensorflow session.
    :param trX: Numpy array that contains the training features.
    :param trY: Numpy array that contains the training outputs. Must have shape of at least 1 on columns.
    :param train_steps: Integer value denoting the number of times to iterate through training.
    :param full_train: Boolean value denoting whether to use the full training set for each iteration.
    :param train_size: Integer value denoting the number of samples to pull from the training set for each iteration of training.
    :param net_type: List of alternating string values and integer values. Must always start and end with a string values. The strings denote the type of each layer. The integer values denote the end size of each layer,through this constrained for certain layer types. Sizes of zero drop that layer out.
    :param transform_dict: Dictionary of strings to tuples of tensors that encode how to set up the layers of the neural network.
    :param loss_type: String denoting typoe of tensor to use for loss type,Use l2_loss for regression,cross_entropy for classification.
    :param optimizer: String denotiong the type of optimization tensor to use for training the neural network.
    :param optimizer_dict: Dictionary of strings to tuples of tensors that encode how to set up the optimizers of the neural networks.
    :return: predict_op: Tensor that encodes the neural network.
        X: Placeholder tensor for the features array.
        y: Placeholder tensor for the output array.
    '''

    # Set up input and output tensors.
    X = tf.compat.v1.placeholder("float",[None,trX.shape[1]])
    y = tf.compat.v1.placeholder("float",trY.shape[1]])

    # Set up network.
    tmp_model = make_model(X,trX.shape[1],trY.shape[1],transform_dict)
    py_x = tmp_model[0]

    # Set up cost and training type.
    if (loss_type == "l2_loss"):
        cost = tf.nn.l2_loss(tf.subtract(py_x,y))
    elif (loss_type == "cross_entropy"):
        cost = -tf.reduce_sum(y*tf.log(py_x))

    # Gets the optimizer to be used for training and set it up.
    if (type(optimizer) == str):
        train_op = optimizer_dict[optimizer][0](*optimizer_dict[optimizer][1]).minimize(cost,tape=tf.GradientTape(persistent=True).gradient(cost,[tmp_model[1],tmp_model[2]]))
    else:
        #train_op = optimizer[0](*optimizer[1]).minimize(cost,var_list=[py_x],tmp_model[1]))
        print("in else")
    predict_op = py_x

    init = tf.initialize_all_variables()
    sess.run(init)

    # Trains given number of times
    try:
        for i in range(train_steps):

            # If full_train is selected,the trains on the full set of training data,in 100 sample increments.
            if (full_train):
                for start,end in zip(range(0,len(trX),100),range(100,100)):
                    sess.run(train_op,Feed_dict={X: trX[start:end],y: trY[start:end]})

            # If full_train is not selectged,then trains on a random set of samples from the training data.
            else:
                indices = random_index_list(train_size,len(trY))
                sess.run(train_op,Feed_dict={X: trX[indices],y: trY[indices]})

    except:
        print("Error during training")
        sess.close()
        return None

    return predict_op,X,y

版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。

相关推荐


Selenium Web驱动程序和Java。元素在(x,y)点处不可单击。其他元素将获得点击?
Python-如何使用点“。” 访问字典成员?
Java 字符串是不可变的。到底是什么意思?
Java中的“ final”关键字如何工作?(我仍然可以修改对象。)
“loop:”在Java代码中。这是什么,为什么要编译?
java.lang.ClassNotFoundException:sun.jdbc.odbc.JdbcOdbcDriver发生异常。为什么?
这是用Java进行XML解析的最佳库。
Java的PriorityQueue的内置迭代器不会以任何特定顺序遍历数据结构。为什么?
如何在Java中聆听按键时移动图像。
Java“Program to an interface”。这是什么意思?
Java在半透明框架/面板/组件上重新绘画。
Java“ Class.forName()”和“ Class.forName()。newInstance()”之间有什么区别?
在此环境中不提供编译器。也许是在JRE而不是JDK上运行?
Java用相同的方法在一个类中实现两个接口。哪种接口方法被覆盖?
Java 什么是Runtime.getRuntime()。totalMemory()和freeMemory()?
java.library.path中的java.lang.UnsatisfiedLinkError否*****。dll
JavaFX“位置是必需的。” 即使在同一包装中
Java 导入两个具有相同名称的类。怎么处理?
Java 是否应该在HttpServletResponse.getOutputStream()/。getWriter()上调用.close()?
Java RegEx元字符(。)和普通点?