How to extract model hyperparameters from spark.ml in PySpark?

How to solve it

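I ran into this problem too. I found that, for reasons I don't understand, you have to go through the model's underlying Java object. So just do the following: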

from pyspark.ml.tuning import ParamGridBuilder, CrossValidator
from pyspark.ml.regression import LinearRegression
from pyspark.ml.evaluation import RegressionEvaluator

# Score each candidate by mean absolute error.
evaluator = RegressionEvaluator(metricName="mae")
lr = LinearRegression()
# A single-point grid, so the "best" parameters are known in advance.
grid = (ParamGridBuilder()
        .addGrid(lr.maxIter, [500])
        .addGrid(lr.regParam, [0])
        .addGrid(lr.elasticNetParam, [1])
        .build())
lr_cv = CrossValidator(estimator=lr, estimatorParamMaps=grid,
                       evaluator=evaluator, numFolds=3)
# your_training_set_here is a placeholder for your own training DataFrame.
lrModel = lr_cv.fit(your_training_set_here)
bestModel = lrModel.bestModel

Then print the parameters you need:

>>> print('Best Param (regParam): ', bestModel._java_obj.getRegParam())
0
>>> print('Best Param (MaxIter): ', bestModel._java_obj.getMaxIter())
500
>>> print('Best Param (elasticNetParam): ', bestModel._java_obj.getElasticNetParam())
1
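
Because the JVM getter names are case-sensitive (for example getRegParam and getElasticNetParam), a small helper can build the name from the param name instead of typing it by hand. This is only a sketch: java_getter_name and get_java_param are hypothetical helpers, not part of PySpark, and they assume the fitted model exposes the private _java_obj attribute, which may change between Spark versions.

```python
def java_getter_name(param_name):
    """Build the JVM getter name for a param, e.g. 'regParam' -> 'getRegParam'."""
    return "get" + param_name[0].upper() + param_name[1:]


def get_java_param(model, param_name):
    """Read a param value from the JVM object wrapped by a fitted model.

    Assumption: `model` has the private `_java_obj` attribute used in the
    answer above. This is not a public PySpark API.
    """
    getter = getattr(model._java_obj, java_getter_name(param_name))
    return getter()


# Hypothetical usage with the model fitted above:
# for name in ("regParam", "maxIter", "elasticNetParam"):
#     print(name, get_java_param(bestModel, name))
```

The helper only removes the risk of mistyping a getter; it still depends on the same workaround.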

The same applies to other methods such as extractParamMap(). They should fix this soon.
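
An alternative that avoids the private _java_obj attribute is to pair each entry of the parameter grid with its averaged cross-validation metric and pick the best one. The sketch below assumes a Spark version whose CrossValidatorModel exposes avgMetrics in the same order as estimatorParamMaps; best_param_map is a hypothetical helper, not a PySpark function.

```python
def best_param_map(param_maps, avg_metrics, larger_is_better=True):
    """Return (param_map, metric) for the best-scoring grid entry.

    param_maps:       the list passed as estimatorParamMaps (the `grid`).
    avg_metrics:      one averaged metric per grid entry, in the same order.
    larger_is_better: False for error metrics such as "mae" or "rmse".
    """
    if len(param_maps) != len(avg_metrics):
        raise ValueError("param_maps and avg_metrics must have the same length")
    pick = max if larger_is_better else min
    # Compare on the metric only; ties resolve to the earliest grid entry.
    return pick(zip(param_maps, avg_metrics), key=lambda pair: pair[1])


# Hypothetical usage with a fitted CrossValidatorModel (assumes avgMetrics exists):
# best_map, best_metric = best_param_map(
#     grid, cvModel.avgMetrics, evaluator.isLargerBetter())
# for param, value in best_map.items():
#     print(param.name, value)
```

Passing evaluator.isLargerBetter() keeps the direction consistent with the evaluator, since an error metric such as mae should be minimized rather than maximized.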

Original question

I'm tinkering with some cross-validation code from the PySpark documentation, and trying to get PySpark to tell me which model was selected:

from pyspark.ml.classification import LogisticRegression
from pyspark.ml.evaluation import BinaryClassificationEvaluator
from pyspark.mllib.linalg import Vectors  # on Spark 2.0+, import from pyspark.ml.linalg instead
from pyspark.ml.tuning import ParamGridBuilder, CrossValidator

# Toy dataset: five (features, label) rows repeated ten times.
dataset = sqlContext.createDataFrame(
    [(Vectors.dense([0.0]), 0.0),
     (Vectors.dense([0.4]), 1.0),
     (Vectors.dense([0.5]), 0.0),
     (Vectors.dense([0.6]), 1.0),
     (Vectors.dense([1.0]), 1.0)] * 10,
    ["features", "label"])
lr = LogisticRegression()
# Candidate regularization strengths for cross-validation to choose from.
grid = ParamGridBuilder().addGrid(lr.regParam, [0.1, 0.01, 0.001, 0.0001]).build()
evaluator = BinaryClassificationEvaluator()
cv = CrossValidator(estimator=lr, estimatorParamMaps=grid, evaluator=evaluator)
cvModel = cv.fit(dataset)

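Running this in the PySpark shell, I can get the logistic regression model's coefficients, but I can't seem to find the value of lr.regParam that the cross-validation procedure selected. Any ideas?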

In [3]: cvModel.bestModel.coefficients
Out[3]: DenseVector([3.1573])

In [4]: cvModel.bestModel.explainParams()
Out[4]: ''

In [5]: cvModel.bestModel.extractParamMap()
Out[5]: {}

In [15]: cvModel.params
Out[15]: []

In [36]: cvModel.bestModel.params
Out[36]: []
