微信公众号搜"智元新知"关注
微信扫一扫可直接关注哦!

如何处理pysparkml中具有大量唯一值的分类特征

如何解决如何处理pysparkml中具有大量唯一值的分类特征

我正在使用 pysparkml 库及其模型来解决回归问题,并且我的数据具有一些具有大量唯一值(超过 1000)的分类特征。 处理它们的正确决定是什么?

据说几乎所有地方都使用 OneHotEncoder,但在 ohe 之后将有超过 10000 个稀疏列,并且后续建模需要很长时间。使用我的数据(400 万行)和集群配置,花费了超过 14 小时,但我没有得到结果。 在这种情况下,升级集群效率不高,因为我曾经查看 Ganglia 报告,其中显示了集群负载,并且内存使用率和 cpu 使用率都低于最大可用量的 20%。

我读到的另一个变体是在结果列上使用 OneHotEncoder + PCA。但它似乎工作更长时间,我认为这种方式不太正确,因为 PCA 是为连续变量设计的。

也许还有其他变体如何处理此类分类特征,例如一些 LabelEncoder(StringIndexer 本身不是 Laber 编码器,因为它留下了关于分类信息的额外元数据)

OneHotEncoder 使用的代码

  indexers = []
  for name in strings_to_index:
      indexers.append(StringIndexer(inputCol=name,outputCol=name+'_index',handleInvalid ='skip'))
      feature_list.append(name+'_ohe_enc')
      feature_list.remove(name)
  encoder = OneHotEncoderEstimator(inputCols=[name +'_index' for name in strings_to_index],outputCols=[name +'_ohe_enc' for name in strings_to_index])
  assembler = VectorAssembler(inputCols=feature_list,outputCol="features")
  rf = RandomForestRegressor(labelCol="label",featuresCol="features",cacheNodeIds=True,seed = 42)
  

  paramGrid = ParamGridBuilder() \
        .addGrid(rf.numTrees,[10,20]) \
        .addGrid(rf.maxDepth,[5,10]) \
      .build()


  evaluator = RegressionEvaluator(labelCol="label",predictionCol="prediction",metricName="mae")

  # Train model
  crossval = CrossValidator(estimator=rf,estimatorParamMaps=paramGrid,evaluator=evaluator,numFolds=2)
  
  pipeline = Pipeline(stages=indexers+[encoder,assembler,crossval])
  cvModel_rf = pipeline.fit(data_train)

配置

  • Databricks 运行时版本 6.4(包括 Apache Spark 2.4.5、Scala 2.11)
  • 驱动程序节点 Standard_DS4_v2
  • 3 个工作节点 Standard_DS4_v2
  • 使用 pyspark.ml 库进行建模和编码

解决方法

在pyspark中如何使用Hashing https://spark.apache.org/docs/2.2.0/api/python/pyspark.ml.html#pyspark.ml.feature.HashingTF

这会将一列中的 1000 个不同值压缩为 n 个(选择 2^X)个不同值

版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。

相关推荐


Selenium Web驱动程序和Java。元素在(x,y)点处不可单击。其他元素将获得点击?
Python-如何使用点“。” 访问字典成员?
Java 字符串是不可变的。到底是什么意思?
Java中的“ final”关键字如何工作?(我仍然可以修改对象。)
“loop:”在Java代码中。这是什么,为什么要编译?
java.lang.ClassNotFoundException:sun.jdbc.odbc.JdbcOdbcDriver发生异常。为什么?
这是用Java进行XML解析的最佳库。
Java的PriorityQueue的内置迭代器不会以任何特定顺序遍历数据结构。为什么?
如何在Java中聆听按键时移动图像。
Java“Program to an interface”。这是什么意思?