
How to use the TensorFlow Recommenders retrieval task with a Keras data generator


I recently started working with the TensorFlow Recommenders package to build recommender systems. So far I have successfully built a ranking task that takes its input from a Keras data generator. However, I have not been able to use the same pipeline for a retrieval task, because the recommended approach to instantiating such a task involves passing a tf.data.Dataset, as shown below.

    self.task = tfrs.tasks.Retrieval(
        metrics=tfrs.metrics.FactorizedTopK(
            candidates=movies.batch(128).map(self.candidate_model),
        ),
    )
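
For context, that `movies` object in the documented example is a tf.data.Dataset, so the candidates argument ends up being a Dataset of item embeddings. Below is a minimal sketch of what that pattern looks like when the candidates are built from an in-memory array of item ids; the item count is taken from my debug output further down, while the toy item tower and its embedding size are just placeholders (my real item model is more complex):

    import tensorflow as tf
    import tensorflow_recommenders as tfrs

    num_items = 504       # number of unique items (see len(unique_item_ids) below)
    embedding_dim = 32    # placeholder size, not my real hyperparameter

    # Toy candidate tower: maps an item id in [0, num_items) to an embedding.
    item_model = tf.keras.Sequential([
        tf.keras.layers.Embedding(num_items, embedding_dim),
    ])

    # The documented pattern: candidates are a tf.data.Dataset of embedded items.
    candidates = tf.data.Dataset.from_tensor_slices(tf.range(num_items))
    task = tfrs.tasks.Retrieval(
        metrics=tfrs.metrics.FactorizedTopK(
            candidates=candidates.batch(128).map(item_model),
        ),
    )

This is exactly the step I cannot reproduce, because my pipeline feeds everything through Keras data generators rather than tf.data.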

After looking through the documentation and source code, I found that a factorized_top_k layer can also be passed instead. So I tried the following:

    class RetrievalModel(tfrs.Model):
        def __init__(self, unique_user_ids, unique_item_ids, time_features=[],
                     customer_features=[], item_features=[]):
            super().__init__()

            user_embedding_dim = item_embedding_dim = embedding_dim

            # The two-tower model needs user and item towers of the same size
            if len(item_features) - len(customer_features) > 0:
                user_embedding_dim = 2 * embedding_dim + 13

            # Compute embeddings for users.
            self.user_model = usermodel2D(unique_user_ids,
                                          customer_features=customer_features,
                                          time_features=time_features,
                                          embedding_dim=user_embedding_dim)

            # Compute embeddings for items.
            self.item_model = Itemmodel2D(unique_item_ids,
                                          item_features=item_features,
                                          embedding_dim=item_embedding_dim)

            self.candidate_layer = tfrs.layers.factorized_top_k.ScaNN(self.user_model)
            metrics = tfrs.metrics.FactorizedTopK(candidates=self.candidate_layer)
            self.task = tfrs.tasks.Retrieval(metrics=metrics)

        def call(self, inputs):
            user_embedding = self.user_model(inputs)
            item_embedding = self.item_model(inputs)
            self.candidate_layer.index(candidates=item_embedding)

            return user_embedding, item_embedding

        def compute_loss(self, inputs, training=False):
            user_embedding, item_embedding = self(inputs)
            return self.task(user_embedding, item_embedding, compute_metrics=not training)
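
For completeness, I fit this model directly from my Keras data generators, roughly as sketched below (heavily simplified from my Optuna tuning code; the Adagrad optimizer is just a placeholder, and the callbacks are omitted):

    # Simplified sketch of the training call; train_datagenerator and
    # val_datagenerator are instances of my DataGenerator class.
    model = RetrievalModel(unique_user_ids, unique_item_ids,
                           time_features=time_features,
                           customer_features=customer_features,
                           item_features=item_features)
    model.compile(optimizer=tf.keras.optimizers.Adagrad(learning_rate=0.1))  # placeholder optimizer
    history = model.fit(train_datagenerator,
                        epochs=trial_epochs,
                        validation_data=val_datagenerator,
                        validation_freq=1,
                        verbose=1)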

But that did not work either. So I am simply wondering how I can create a retrieval task without using a tf.data.Dataset. I would appreciate any feedback here. Thank you very much for your time!
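
One direction I have been wondering about, but have not verified, is wrapping the generator itself in a tf.data.Dataset via tf.data.Dataset.from_generator, so that the documented candidates pattern could still be used. A rough, untested sketch of what I mean (the feature names and dtypes are taken from the batches shown in the debug output below; the real generator yields more features):

    import tensorflow as tf

    # Hypothetical wrapper: expose a Keras data generator as a tf.data.Dataset.
    def generator_as_dataset(datagenerator):
        def gen():
            for batch in datagenerator:
                yield batch
        return tf.data.Dataset.from_generator(
            gen,
            output_signature={
                'user_id': tf.TensorSpec(shape=(None,), dtype=tf.int32),
                'item_id': tf.TensorSpec(shape=(None,), dtype=tf.int32),
            },
        )

    train_dataset = generator_as_dataset(train_datagenerator)

Even so, I would prefer a way to construct the retrieval task without going through tf.data at all, if that is possible.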

Here is some output and a few debug prints. I have simplified some parts to make them easier to read and to hide confidential information.

--- init ---
item_features: 4
customer_features: 3
time_features: 3
user_embedding_dim: 67
item_embedding_dim: 27
user: len(unique_user_ids) 135447
user: len(customer_features) 3
user: len(time_features) 3
item: len(unique_item_ids) 504
item: len(customer_features) 4
candidate_layer: <tensorflow_recommenders.layers.factorized_top_k.ScaNN object at 0x7fc6ea5b5240>
metrics: <tensorflow_recommenders.metrics.factorized_top_k.FactorizedTopK object at 0x7fc6ea5b5518>
task: <tensorflow_recommenders.tasks.retrieval.Retrieval object at 0x7fc6ea561c88>
train_datagenerator: <deep_learning.helpers.DataGenerator object at 0x7fc88ac82f60>

--- call ---
inputs: {'user_id': <tf.Tensor: shape=(1024,), dtype=int32, numpy=
array([158429273, 460546163, 144561824, ..., 130676640, 17285232, 111347467], dtype=int32)>, 'item_id': <tf.Tensor: shape=(1024,), numpy=
array([ 10903699, 484336382, 459214922, 945589400, 303642080], ... }
user: user_embedding.shape (1024,67)
user: context_embedding.shape (1024,12)
user: after align user_embedding.shape (1024,79)
user: user_embedding.shape (1024,79)
user: context_embedding.shape (1024,3)
user: after align user_embedding.shape (1024,82)
user_embedding.shape (1024,82)
item: item_embedding.shape (1024,27)
item: context_embedding.shape (1024,28)
item: after align context item_embedding.shape (1024,55)
item: item_embedding.shape (1024,55)
item: context_embedding.shape (1024,27)
item: after align context item_embedding.shape (1024,82)
item_embedding.shape (1024,82)

--- compute loss ---
inputs: {'user_id': <tf.Tensor 'IteratorGetNext:0' shape=(None,) dtype=int32>,'item_id': <tf.Tensor 'IteratorGetNext:1' shape=(None,...}

--- call ---
inputs: {'user_id': <tf.Tensor 'IteratorGetNext:0' shape=(None,...}

user: user_embedding.shape (None,67)
user: context_embedding.shape (None,12)
user: after align user_embedding.shape (None,79)
user: user_embedding.shape (None,79)
user: context_embedding.shape (None,3)
user: after align user_embedding.shape (None,82)
user_embedding.shape (None,82)

item: item_embedding.shape (None,27)
item: context_embedding.shape (None,28)
item: after align context item_embedding.shape (None,55)
item: item_embedding.shape (None,55)
item: context_embedding.shape (None,27)
item: after align context item_embedding.shape (None,82)
item_embedding.shape (None,82)

---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
<ipython-input-6-7b4e907381cd> in <module>
     11                              batch_size=batch_size,
     12                              use_implicit_Feedback=use_implicit_Feedback,
---> 13                              use_all_features=use_all_features)

~.../deep_learning/helpers.py in hyperparameter_tune_deep_learning(project_path,data_folder_name,split_suffix,study_name,n_trials,trial_epochs,batch_size,use_implicit_Feedback,use_all_features)
    741                                            use_all_features,monitor),
    742                    n_trials=n_trials,
--> 743                    callbacks=[tensorboard_callback])
    744 
    745     study_dir = project_path + '/data/hyperparameter_tuning_studies/'

~/workspace/conda/envs/tf_gpu_py36/lib/python3.6/site-packages/optuna/study.py in optimize(self,func,timeout,n_jobs,catch,callbacks,gc_after_trial,show_progress_bar)
    313             callbacks=callbacks,
    314             gc_after_trial=gc_after_trial,
--> 315             show_progress_bar=show_progress_bar,
    316         )
    317 

~/workspace/conda/envs/tf_gpu_py36/lib/python3.6/site-packages/optuna/_optimize.py in _optimize(study,show_progress_bar)
     63                 reseed_sampler_rng=False,
     64                 time_start=None,
---> 65                 progress_bar=progress_bar,
     66             )
     67         else:

~/workspace/conda/envs/tf_gpu_py36/lib/python3.6/site-packages/optuna/_optimize.py in _optimize_sequential(study,reseed_sampler_rng,time_start,progress_bar)
    154 
    155         try:
--> 156             trial = _run_trial(study,catch)
    157         except Exception:
    158             raise

~/workspace/conda/envs/tf_gpu_py36/lib/python3.6/site-packages/optuna/_optimize.py in _run_trial(study,catch)
    187 
    188     try:
--> 189         value = func(trial)
    190     except exceptions.TrialPruned as e:
    191         # Register the last intermediate value if present as the value of the trial.

~.../deep_learning/helpers.py in <lambda>(trial)
    739     study.optimize(lambda trial: objective(trial,train_datagenerator,val_datagenerator,
    740                                            time_features,customer_features,item_features,
--> 741                                            use_all_features,monitor),
    742                    n_trials=n_trials,
    743                    callbacks=[tensorboard_callback])

~.../deep_learning/helpers.py in objective(trial,time_features,use_all_features,monitor)
    632                   validation_data=val_datagenerator,validation_freq=1,
    633                   callbacks=[reduce_lr,trial_pruner],
--> 634                   verbose=1)
    635 
    636     return history.history[monitor][-1]

~/workspace/conda/envs/tf_gpu_py36/lib/python3.6/site-packages/tensorflow/python/keras/engine/training.py in fit(self,x,y,epochs,verbose,validation_split,validation_data,shuffle,class_weight,sample_weight,initial_epoch,steps_per_epoch,validation_steps,validation_batch_size,validation_freq,max_queue_size,workers,use_multiprocessing)
   1181                 _r=1):
   1182               callbacks.on_train_batch_begin(step)
-> 1183               tmp_logs = self.train_function(iterator)
   1184               if data_handler.should_sync:
   1185                 context.async_wait()

~/workspace/conda/envs/tf_gpu_py36/lib/python3.6/site-packages/tensorflow/python/eager/def_function.py in __call__(self,*args,**kwds)
    887 
    888       with OptionalXlaContext(self._jit_compile):
--> 889         result = self._call(*args,**kwds)
    890 
    891       new_tracing_count = self.experimental_get_tracing_count()

~/workspace/conda/envs/tf_gpu_py36/lib/python3.6/site-packages/tensorflow/python/eager/def_function.py in _call(self,**kwds)
    931       # This is the first call of __call__,so we have to initialize.
    932       initializers = []
--> 933       self._initialize(args,kwds,add_initializers_to=initializers)
    934     finally:
     935       # At this point we know that the initialization is complete (or less

~/workspace/conda/envs/tf_gpu_py36/lib/python3.6/site-packages/tensorflow/python/eager/def_function.py in _initialize(self,args,add_initializers_to)
    762     self._concrete_stateful_fn = (
    763         self._stateful_fn._get_concrete_function_internal_garbage_collected(  # pylint: disable=protected-access
--> 764             *args,**kwds))
    765 
    766     def invalid_creator_scope(*unused_args,**unused_kwds):

~/workspace/conda/envs/tf_gpu_py36/lib/python3.6/site-packages/tensorflow/python/eager/function.py in _get_concrete_function_internal_garbage_collected(self,**kwargs)
   3048       args,kwargs = None,None
   3049     with self._lock:
-> 3050       graph_function,_ = self._maybe_define_function(args,kwargs)
   3051     return graph_function
   3052 

~/workspace/conda/envs/tf_gpu_py36/lib/python3.6/site-packages/tensorflow/python/eager/function.py in _maybe_define_function(self,kwargs)
   3442 
   3443           self._function_cache.missed.add(call_context_key)
-> 3444           graph_function = self._create_graph_function(args,kwargs)
   3445           self._function_cache.primary[cache_key] = graph_function
   3446 

~/workspace/conda/envs/tf_gpu_py36/lib/python3.6/site-packages/tensorflow/python/eager/function.py in _create_graph_function(self,kwargs,override_flat_arg_shapes)
   3287             arg_names=arg_names,
   3288             override_flat_arg_shapes=override_flat_arg_shapes,
-> 3289             capture_by_value=self._capture_by_value),
   3290         self._function_attributes,
   3291         function_spec=self.function_spec,

~/workspace/conda/envs/tf_gpu_py36/lib/python3.6/site-packages/tensorflow/python/framework/func_graph.py in func_graph_from_py_func(name,python_func,signature,func_graph,autograph,autograph_options,add_control_dependencies,arg_names,op_return_value,collections,capture_by_value,override_flat_arg_shapes)
    997         _,original_func = tf_decorator.unwrap(python_func)
    998 
--> 999       func_outputs = python_func(*func_args,**func_kwargs)
   1000 
   1001       # invariant: `func_outputs` contains only Tensors,CompositeTensors,

~/workspace/conda/envs/tf_gpu_py36/lib/python3.6/site-packages/tensorflow/python/eager/def_function.py in wrapped_fn(*args,**kwds)
    670         # the function a weak reference to itself to avoid a reference cycle.
    671         with OptionalXlaContext(compile_with_xla):
--> 672           out = weak_wrapped_fn().__wrapped__(*args,**kwds)
    673         return out
    674 

~/workspace/conda/envs/tf_gpu_py36/lib/python3.6/site-packages/tensorflow/python/framework/func_graph.py in wrapper(*args,**kwargs)
    984           except Exception as e:  # pylint:disable=broad-except
     985             if hasattr(e,"ag_error_metadata"):
--> 986               raise e.ag_error_metadata.to_exception(e)
    987             else:
    988               raise

ValueError: in user code:

    /home/.../workspace/conda/envs/tf_gpu_py36/lib/python3.6/site-packages/tensorflow/python/keras/engine/training.py:855 train_function  *
        return step_function(self,iterator)
    /home/.../deep_learning/models.py:457 call  *
        index = self.candidate_layer.index(candidates=item_embedding)
    /home/.../workspace/conda/envs/tf_gpu_py36/lib/python3.6/site-packages/tensorflow_recommenders/layers/factorized_top_k.py:491 index  *
        identifiers = tf.range(candidates.shape[0])
    /home/.../workspace/conda/envs/tf_gpu_py36/lib/python3.6/site-packages/tensorflow/python/util/dispatch.py:206 wrapper  **
        return target(*args,**kwargs)
    /home/.../workspace/conda/envs/tf_gpu_py36/lib/python3.6/site-packages/tensorflow/python/ops/math_ops.py:1908 range
        limit = ops.convert_to_tensor(limit,dtype=dtype,name="limit")
    /home/.../workspace/conda/envs/tf_gpu_py36/lib/python3.6/site-packages/tensorflow/python/profiler/trace.py:163 wrapped
        return func(*args,**kwargs)
    /home/.../workspace/conda/envs/tf_gpu_py36/lib/python3.6/site-packages/tensorflow/python/framework/ops.py:1566 convert_to_tensor
        ret = conversion_func(value,name=name,as_ref=as_ref)
    /home/.../workspace/conda/envs/tf_gpu_py36/lib/python3.6/site-packages/tensorflow/python/framework/constant_op.py:339 _constant_tensor_conversion_function
        return constant(v,name=name)
    /home/.../workspace/conda/envs/tf_gpu_py36/lib/python3.6/site-packages/tensorflow/python/framework/constant_op.py:265 constant
        allow_broadcast=True)
    /home/.../workspace/conda/envs/tf_gpu_py36/lib/python3.6/site-packages/tensorflow/python/framework/constant_op.py:283 _constant_impl
        allow_broadcast=allow_broadcast))
    /home/.../workspace/conda/envs/tf_gpu_py36/lib/python3.6/site-packages/tensorflow/python/framework/tensor_util.py:445 make_tensor_proto
        raise ValueError("None values not supported.")

    ValueError: None values not supported.

