InvalidArgumentError: 发现 2 个根错误 (0) 无效参数:indices[10,0] = 101102 不在 [0, 101102)

如何解决InvalidArgumentError: 发现 2 个根错误 (0) 无效参数:indices[10,0] = 101102 不在 [0, 101102)

我正在尝试通过在 MovieLens 数据集上训练神经协同过滤 (NCF) 网络来创建电影推荐系统。我对 NCF 的实现是

def NCF(num_users,num_items,gmf_embedding_dim,mlp_embedding_dim):
    # Define input vectors for embedding
    u_input = Input(shape = [1,])
    i_input = Input(shape = [1,])

    # GMF embedding
    u_embedding_gmf = Embedding(input_dim = num_users,output_dim = gmf_embedding_dim(u_input)
    u_vec_gmf = Flatten()(u_embedding_gmf)

    i_embedding_gmf = Embedding(input_dim = num_items,output_dim = gmf_embedding_dim(i_input)
    i_vec_gmf = Flatten()(i_embedding_gmf)

    # MLP embedding
    u_embedding_mlp = Embedding(input_dim = num_users,output_dim = mlp_embedding_dim(u_input)
    u_vec_mlp = Flatten()(u_embedding_mlp)

    i_embedding_mlp = Embedding(input_dim = num_items,output_dim = mlp_embedding_dim(i_input)
    i_vec_mlp = Flatten()(i_embedding_mlp)

    # GMF path
    gmf_output = Dot(axes = 1)([u_vec_gmf,i_vec_gmf])

    # MLP path
    mlp_input_concat = Concatenate()([u_vec_mlp,i_vec_mlp])

    mlp_dense_1 = Dense(units = 128,activation = "relu")(mlp_input_concat)
    mlp_bn_1 = Batchnormalization()(mlp_dense_1)
    mlp_drop_1 = Dropout(0.3)(mlp_bn_1)

    mlp_dense_2 = Dense(units = 64,activation = "relu")(mlp_drop_1)
    mlp_bn_2 = Batchnormalization()(mlp_dense_2)
    mlp_output = Dropout(0.3)(mlp_bn_2)

    # Concatenate GMF and MLP pathways
    paths_concat = Concatenate()([gmf_output,mlp_output])

    # Prediction
    output = Dense(units = 1,activation = "sigmoid")(paths_concat)

    # Create model 

    return Model(inputs = [u_input,i_input],outputs = output)

我创建了一个函数来处理我的训练

def train(model,x_train,y_train,x_valid,y_valid,batch_size,epochs,save_name,checkpoint_path,history_path,lr = 0.001,lr_decay = True):

    if isfile(join(history_path,save_name)):
        return

    model.compile(loss = BinaryCrossentropy(),optimizer = Adam(learning_rate = lr),metrics["accuracy"])

    best_checkpoint = ModelCheckpoint(filepath = join(checkpoint_path,save_name),monitor = "val_loss",save_best_only = True)

    history_csv = CSVLogger(join(history_path,save_name))

    early_stop = EarlyStopping(monitor = "val_loss",patience = 30,restore_best_weights = True)

    lr_decay_callback = ReduceLROnPlateau(monitor = "val_loss",patience = 10,factor = 0.5,min_lr = 0.000001)

    callback_list = [best_checkpoint,history_csv,early_stop]

    if lr_decay:
        callback_list.append(lr_decay_callback)
    
    model.fit(x = x_train,y = y_train,validation_data = (x_valid,y_valid),epochs = epochs,callbacks = callback_list,batch_size = batch_size)

使用

编码准备用于嵌入层的 user_ID 和 movie_ID 值
enc = LabelEncoder()
train_set["user_ID"] = enc.fit_transform(train_set["user_ID"].values)
enc = LabelEncoder()
train_set["movie_ID"] = enc.fit_transform(train_set["movie_ID"].values)

enc = LabelEncoder()
valid_set["user_ID"] = enc.fit_transform(valid_set["user_ID"].values)
enc = LabelEncoder()
valid_set["movie_ID"] = enc.fit_transform(valid_set["movie_ID"].values)

enc = LabelEncoder()
test_set["user_ID"] = enc.fit_transform(test_set["user_ID"].values)
enc = LabelEncoder()
test_set["movie_ID"] = enc.fit_transform(test_set["movie_ID"].values)

然后开始训练

train(model = NCF(num_users = train_set["user_ID"].nunique() + 1,num_items = 
     train_set["movie_ID"].nunique() + 1,gmf_embedding_dim = 10,mlp_embedding_dim = 10),x_train = [train_set["user_ID"],train_set["movie_ID"]],y_train =
     train_set["interaction"],x_valid = [valid_set["user_ID"],valid_set["movie_ID"]],y_valid = 
     valid_set["interaction"],batch_size = (train_set.shape[0])/10,epochs = 50,save_name = "NCF_1",checkpoint_path = "D:/Movie Recommendation System Project/model data/checkpoints",history_path = "D:/Movie Recommendation System Project/model data/training history")

直到第一个 epoch 的最后一批,我才收到错误

InvalidArgumentError:发现 2 个根错误

(0) 无效参数:indices[10,0] = 101102 不在 [0,101102) [[节点functional_9/embedding_16/embedding_lookup(定义在D:/电影推荐系统项目/架构和培训\training_and_evaluation.py:38)]] [[functional_9/embedding_18/embedding_lookup/_16]]

(1) 无效参数:indices[10,101102) [[节点功能_9/embedding_16/embedding_lookup(定义在D:/电影推荐系统项目/架构和培训\training_and_evaluation.py:38)]]

0 次成功操作。

忽略了 0 个派生错误。 [操作:__inference_test_function_529078]

这是一个与我在上次尝试结束时收到的错误非常相似的错误,其中值是 101101 而不是 101102。作为一个天真的解决方案,我尝试将 1 添加到我的 num_users 和 num_movies 值,但现在错误消息中的值似乎只是增加了 1。我觉得我在这里缺少关于嵌入层的一些明显或基本的东西。有人可以帮忙吗?

解决方法

我相信这个错误的发生是由于嵌入层遇到了一个它没有预料到的值。当您调用 NCF 函数时,您传递的是训练集的唯一用户数。而是计算完整数据集中的唯一用户数并将其发送到 NCF 函数。

例如:

total_num_users = train_set["user_ID"].nunique() + valid_set["user_ID"].nunique() + test_set["user_ID"]..]nunique()

train(model = NCF(num_users = total_num_users,num_items = 
     train_set["movie_ID"].nunique() + 1,gmf_embedding_dim = 10,mlp_embedding_dim = 10),x_train = [train_set["user_ID"],train_set["movie_ID"]],y_train =
     train_set["interaction"],x_valid = [valid_set["user_ID"],valid_set["movie_ID"]],y_valid = 
     valid_set["interaction"],batch_size = (train_set.shape[0])/10,epochs = 50,save_name = "NCF_1",checkpoint_path = "D:/Movie Recommendation System Project/model data/checkpoints",history_path = "D:/Movie Recommendation System Project/model data/training history")

确保对您嵌入的其他分类变量采用相同的方法。

版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。

相关推荐


Selenium Web驱动程序和Java。元素在(x,y)点处不可单击。其他元素将获得点击?
Python-如何使用点“。” 访问字典成员?
Java 字符串是不可变的。到底是什么意思?
Java中的“ final”关键字如何工作?(我仍然可以修改对象。)
“loop:”在Java代码中。这是什么,为什么要编译?
java.lang.ClassNotFoundException:sun.jdbc.odbc.JdbcOdbcDriver发生异常。为什么?
这是用Java进行XML解析的最佳库。
Java的PriorityQueue的内置迭代器不会以任何特定顺序遍历数据结构。为什么?
如何在Java中聆听按键时移动图像。
Java“Program to an interface”。这是什么意思?
Java在半透明框架/面板/组件上重新绘画。
Java“ Class.forName()”和“ Class.forName()。newInstance()”之间有什么区别?
在此环境中不提供编译器。也许是在JRE而不是JDK上运行?
Java用相同的方法在一个类中实现两个接口。哪种接口方法被覆盖?
Java 什么是Runtime.getRuntime()。totalMemory()和freeMemory()?
java.library.path中的java.lang.UnsatisfiedLinkError否*****。dll
JavaFX“位置是必需的。” 即使在同一包装中
Java 导入两个具有相同名称的类。怎么处理?
Java 是否应该在HttpServletResponse.getOutputStream()/。getWriter()上调用.close()?
Java RegEx元字符(。)和普通点?