微信公众号搜"智元新知"关注
微信扫一扫可直接关注哦!

试图了解ML上的示例脚本

如何解决试图了解ML上的示例脚本

我正在尝试研究有关机器学习的示例脚本:Common pitfalls in interpretation of coefficients of linear models,但在理解某些步骤时遇到了麻烦。脚本的开头看起来像这样:

import numpy as np
import scipy as sp
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

from sklearn.datasets import fetch_openml

survey = fetch_openml(data_id=534,as_frame=True)

# We identify features `X` and targets `y`: the column WAGE is our
# target variable (i.e.,the variable which we want to predict).
X = survey.data[survey.feature_names]
X.describe(include="all")

X.head()

# Our target for prediction is the wage.
y = survey.target.values.ravel()
survey.target.head()

from sklearn.model_selection import train_test_split

X_train,X_test,y_train,y_test = train_test_split(X,y,random_state=42)

train_dataset = X_train.copy()
train_dataset.insert(0,"WAGE",y_train)
_ = sns.pairplot(train_dataset,kind='reg',diag_kind='kde')

我的问题出在行上

y = survey.target.values.ravel()
survey.target.head()

如果我们在这些行之后立即检查survey.target.head(),则输出

Out[36]: 
0    5.10
1    4.95
2    6.67
3    4.00
4    7.50
Name: WAGE,dtype: float64

模型如何知道WAGE是目标变量?不必明确声明吗?

解决方法

survey.target.values.ravel()的目的是使数组变平,但是在此示例中没有必要。 survey.target是pd系列(即1列数据框),survey.target.values是一个numpy数组。因为survey.target中只有1列,所以您可以同时使用它进行训练/测试拆分。

type(survey.target)
pandas.core.series.Series

type(survey.target.values)
numpy.ndarray

如果我们仅使用Survey.target,则可以看到回归将起作用:

y = survey.target

X_train,X_test,y_train,y_test = train_test_split(X,y,random_state=42)

train_dataset = X_train.copy()
train_dataset.insert(0,"WAGE",y_train)
sns.pairplot(train_dataset,kind='reg',diag_kind='kde')

enter image description here

如果您有另一个数据集,例如虹膜,我想将花瓣宽度相对于其余区域进行回归。您可以使用方括号[]来调用data.frame的列:

from sklearn.datasets import load_iris
from sklearn.linear_model import LinearRegression

dat = load_iris(as_frame=True).frame

X = dat[['sepal length (cm)','sepal width (cm)','petal length (cm)']]
y = dat[['petal width (cm)']]

X_train,random_state=42)

LR = LinearRegression()
LR.fit(X_train,y_train)
plt.scatter(x=y_test,y=LR.predict(X_test))

enter image description here

版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。