{
|
|||
|
"cells": [
|
|||
|
{
|
|||
|
"cell_type": "markdown",
|
|||
|
"metadata": {},
|
|||
|
"source": [
|
|||
|
"# K-Nearest Neighbors Classification Lab"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "markdown",
|
|||
|
"metadata": {},
|
|||
|
"source": [
|
|||
|
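"In this exercise we use a telecom customer churn dataset, `e2.1_Orange_Telecom_Churn_Data.csv` (stored in the current directory). We first read the dataset and preprocess it, then use a K-nearest neighbors model to predict from each customer's characteristics whether that customer will churn."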
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "markdown",
|
|||
|
"metadata": {},
|
|||
|
"source": [
|
|||
|
"### Step 1:\n",
"* Read the dataset into the variable `data` and show its first 5 rows.\n",
"* Drop the three columns `\"state\"`, `\"area_code\"`, and `\"phone_number\"`."
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 88,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"data": {
|
|||
|
"text/html": [
|
|||
|
"<div>\n",
|
|||
|
"<style scoped>\n",
|
|||
|
" .dataframe tbody tr th:only-of-type {\n",
|
|||
|
" vertical-align: middle;\n",
|
|||
|
" }\n",
|
|||
|
"\n",
|
|||
|
" .dataframe tbody tr th {\n",
|
|||
|
" vertical-align: top;\n",
|
|||
|
" }\n",
|
|||
|
"\n",
|
|||
|
" .dataframe thead th {\n",
|
|||
|
" text-align: right;\n",
|
|||
|
" }\n",
|
|||
|
"</style>\n",
|
|||
|
"<table border=\"1\" class=\"dataframe\">\n",
|
|||
|
" <thead>\n",
|
|||
|
" <tr style=\"text-align: right;\">\n",
|
|||
|
" <th></th>\n",
|
|||
|
" <th>state</th>\n",
|
|||
|
" <th>account_length</th>\n",
|
|||
|
" <th>area_code</th>\n",
|
|||
|
" <th>phone_number</th>\n",
|
|||
|
" <th>intl_plan</th>\n",
|
|||
|
" <th>voice_mail_plan</th>\n",
|
|||
|
" <th>number_vmail_messages</th>\n",
|
|||
|
" <th>total_day_minutes</th>\n",
|
|||
|
" <th>total_day_calls</th>\n",
|
|||
|
" <th>total_day_charge</th>\n",
|
|||
|
" <th>total_eve_minutes</th>\n",
|
|||
|
" <th>total_eve_calls</th>\n",
|
|||
|
" <th>total_eve_charge</th>\n",
|
|||
|
" <th>total_night_minutes</th>\n",
|
|||
|
" <th>total_night_calls</th>\n",
|
|||
|
" <th>total_night_charge</th>\n",
|
|||
|
" <th>total_intl_minutes</th>\n",
|
|||
|
" <th>total_intl_calls</th>\n",
|
|||
|
" <th>total_intl_charge</th>\n",
|
|||
|
" <th>number_customer_service_calls</th>\n",
|
|||
|
" <th>churned</th>\n",
|
|||
|
" </tr>\n",
|
|||
|
" </thead>\n",
|
|||
|
" <tbody>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>0</th>\n",
|
|||
|
" <td>KS</td>\n",
|
|||
|
" <td>128</td>\n",
|
|||
|
" <td>415</td>\n",
|
|||
|
" <td>382-4657</td>\n",
|
|||
|
" <td>no</td>\n",
|
|||
|
" <td>yes</td>\n",
|
|||
|
" <td>25</td>\n",
|
|||
|
" <td>265.1</td>\n",
|
|||
|
" <td>110</td>\n",
|
|||
|
" <td>45.07</td>\n",
|
|||
|
" <td>197.4</td>\n",
|
|||
|
" <td>99</td>\n",
|
|||
|
" <td>16.78</td>\n",
|
|||
|
" <td>244.7</td>\n",
|
|||
|
" <td>91</td>\n",
|
|||
|
" <td>11.01</td>\n",
|
|||
|
" <td>10.0</td>\n",
|
|||
|
" <td>3</td>\n",
|
|||
|
" <td>2.70</td>\n",
|
|||
|
" <td>1</td>\n",
|
|||
|
" <td>False</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>1</th>\n",
|
|||
|
" <td>OH</td>\n",
|
|||
|
" <td>107</td>\n",
|
|||
|
" <td>415</td>\n",
|
|||
|
" <td>371-7191</td>\n",
|
|||
|
" <td>no</td>\n",
|
|||
|
" <td>yes</td>\n",
|
|||
|
" <td>26</td>\n",
|
|||
|
" <td>161.6</td>\n",
|
|||
|
" <td>123</td>\n",
|
|||
|
" <td>27.47</td>\n",
|
|||
|
" <td>195.5</td>\n",
|
|||
|
" <td>103</td>\n",
|
|||
|
" <td>16.62</td>\n",
|
|||
|
" <td>254.4</td>\n",
|
|||
|
" <td>103</td>\n",
|
|||
|
" <td>11.45</td>\n",
|
|||
|
" <td>13.7</td>\n",
|
|||
|
" <td>3</td>\n",
|
|||
|
" <td>3.70</td>\n",
|
|||
|
" <td>1</td>\n",
|
|||
|
" <td>False</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>2</th>\n",
|
|||
|
" <td>NJ</td>\n",
|
|||
|
" <td>137</td>\n",
|
|||
|
" <td>415</td>\n",
|
|||
|
" <td>358-1921</td>\n",
|
|||
|
" <td>no</td>\n",
|
|||
|
" <td>no</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>243.4</td>\n",
|
|||
|
" <td>114</td>\n",
|
|||
|
" <td>41.38</td>\n",
|
|||
|
" <td>121.2</td>\n",
|
|||
|
" <td>110</td>\n",
|
|||
|
" <td>10.30</td>\n",
|
|||
|
" <td>162.6</td>\n",
|
|||
|
" <td>104</td>\n",
|
|||
|
" <td>7.32</td>\n",
|
|||
|
" <td>12.2</td>\n",
|
|||
|
" <td>5</td>\n",
|
|||
|
" <td>3.29</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>False</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>3</th>\n",
|
|||
|
" <td>OH</td>\n",
|
|||
|
" <td>84</td>\n",
|
|||
|
" <td>408</td>\n",
|
|||
|
" <td>375-9999</td>\n",
|
|||
|
" <td>yes</td>\n",
|
|||
|
" <td>no</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>299.4</td>\n",
|
|||
|
" <td>71</td>\n",
|
|||
|
" <td>50.90</td>\n",
|
|||
|
" <td>61.9</td>\n",
|
|||
|
" <td>88</td>\n",
|
|||
|
" <td>5.26</td>\n",
|
|||
|
" <td>196.9</td>\n",
|
|||
|
" <td>89</td>\n",
|
|||
|
" <td>8.86</td>\n",
|
|||
|
" <td>6.6</td>\n",
|
|||
|
" <td>7</td>\n",
|
|||
|
" <td>1.78</td>\n",
|
|||
|
" <td>2</td>\n",
|
|||
|
" <td>False</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>4</th>\n",
|
|||
|
" <td>OK</td>\n",
|
|||
|
" <td>75</td>\n",
|
|||
|
" <td>415</td>\n",
|
|||
|
" <td>330-6626</td>\n",
|
|||
|
" <td>yes</td>\n",
|
|||
|
" <td>no</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>166.7</td>\n",
|
|||
|
" <td>113</td>\n",
|
|||
|
" <td>28.34</td>\n",
|
|||
|
" <td>148.3</td>\n",
|
|||
|
" <td>122</td>\n",
|
|||
|
" <td>12.61</td>\n",
|
|||
|
" <td>186.9</td>\n",
|
|||
|
" <td>121</td>\n",
|
|||
|
" <td>8.41</td>\n",
|
|||
|
" <td>10.1</td>\n",
|
|||
|
" <td>3</td>\n",
|
|||
|
" <td>2.73</td>\n",
|
|||
|
" <td>3</td>\n",
|
|||
|
" <td>False</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" </tbody>\n",
|
|||
|
"</table>\n",
|
|||
|
"</div>"
|
|||
|
],
|
|||
|
"text/plain": [
|
|||
|
" state account_length ... number_customer_service_calls churned\n",
|
|||
|
"0 KS 128 ... 1 False\n",
|
|||
|
"1 OH 107 ... 1 False\n",
|
|||
|
"2 NJ 137 ... 0 False\n",
|
|||
|
"3 OH 84 ... 2 False\n",
|
|||
|
"4 OK 75 ... 3 False\n",
|
|||
|
"\n",
|
|||
|
"[5 rows x 21 columns]"
|
|||
|
]
|
|||
|
},
|
|||
|
"execution_count": 88,
|
|||
|
"metadata": {},
|
|||
|
"output_type": "execute_result"
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
|
|||
|
"# Read the dataset into the variable data and show its first 5 rows\n",
|
|||
|
"import pandas as pd\n",
|
|||
|
"data = pd.read_csv('e2.1_Orange_Telecom_Churn_Data.csv')\n",
|
|||
|
"data.head()"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 89,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [],
|
|||
|
"source": [
|
|||
|
"# Drop the \"state\", \"area_code\" and \"phone_number\" columns\n",
"data = data.drop(columns=['state', 'area_code', 'phone_number'])"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "markdown",
|
|||
|
"metadata": {},
|
|||
|
"source": [
|
|||
|
"### Step 2:\n",
"* Some columns contain categorical values, namely `'intl_plan'`, `'voice_mail_plan'`, and `'churned'`. Convert these three columns to numeric values."
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 90,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"data": {
|
|||
|
"text/html": [
|
|||
|
"<div>\n",
|
|||
|
"<style scoped>\n",
|
|||
|
" .dataframe tbody tr th:only-of-type {\n",
|
|||
|
" vertical-align: middle;\n",
|
|||
|
" }\n",
|
|||
|
"\n",
|
|||
|
" .dataframe tbody tr th {\n",
|
|||
|
" vertical-align: top;\n",
|
|||
|
" }\n",
|
|||
|
"\n",
|
|||
|
" .dataframe thead th {\n",
|
|||
|
" text-align: right;\n",
|
|||
|
" }\n",
|
|||
|
"</style>\n",
|
|||
|
"<table border=\"1\" class=\"dataframe\">\n",
|
|||
|
" <thead>\n",
|
|||
|
" <tr style=\"text-align: right;\">\n",
|
|||
|
" <th></th>\n",
|
|||
|
" <th>account_length</th>\n",
|
|||
|
" <th>intl_plan</th>\n",
|
|||
|
" <th>voice_mail_plan</th>\n",
|
|||
|
" <th>number_vmail_messages</th>\n",
|
|||
|
" <th>total_day_minutes</th>\n",
|
|||
|
" <th>total_day_calls</th>\n",
|
|||
|
" <th>total_day_charge</th>\n",
|
|||
|
" <th>total_eve_minutes</th>\n",
|
|||
|
" <th>total_eve_calls</th>\n",
|
|||
|
" <th>total_eve_charge</th>\n",
|
|||
|
" <th>total_night_minutes</th>\n",
|
|||
|
" <th>total_night_calls</th>\n",
|
|||
|
" <th>total_night_charge</th>\n",
|
|||
|
" <th>total_intl_minutes</th>\n",
|
|||
|
" <th>total_intl_calls</th>\n",
|
|||
|
" <th>total_intl_charge</th>\n",
|
|||
|
" <th>number_customer_service_calls</th>\n",
|
|||
|
" <th>churned</th>\n",
|
|||
|
" </tr>\n",
|
|||
|
" </thead>\n",
|
|||
|
" <tbody>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>0</th>\n",
|
|||
|
" <td>128</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>1</td>\n",
|
|||
|
" <td>25</td>\n",
|
|||
|
" <td>265.1</td>\n",
|
|||
|
" <td>110</td>\n",
|
|||
|
" <td>45.07</td>\n",
|
|||
|
" <td>197.4</td>\n",
|
|||
|
" <td>99</td>\n",
|
|||
|
" <td>16.78</td>\n",
|
|||
|
" <td>244.7</td>\n",
|
|||
|
" <td>91</td>\n",
|
|||
|
" <td>11.01</td>\n",
|
|||
|
" <td>10.0</td>\n",
|
|||
|
" <td>3</td>\n",
|
|||
|
" <td>2.70</td>\n",
|
|||
|
" <td>1</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>1</th>\n",
|
|||
|
" <td>107</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>1</td>\n",
|
|||
|
" <td>26</td>\n",
|
|||
|
" <td>161.6</td>\n",
|
|||
|
" <td>123</td>\n",
|
|||
|
" <td>27.47</td>\n",
|
|||
|
" <td>195.5</td>\n",
|
|||
|
" <td>103</td>\n",
|
|||
|
" <td>16.62</td>\n",
|
|||
|
" <td>254.4</td>\n",
|
|||
|
" <td>103</td>\n",
|
|||
|
" <td>11.45</td>\n",
|
|||
|
" <td>13.7</td>\n",
|
|||
|
" <td>3</td>\n",
|
|||
|
" <td>3.70</td>\n",
|
|||
|
" <td>1</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>2</th>\n",
|
|||
|
" <td>137</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>243.4</td>\n",
|
|||
|
" <td>114</td>\n",
|
|||
|
" <td>41.38</td>\n",
|
|||
|
" <td>121.2</td>\n",
|
|||
|
" <td>110</td>\n",
|
|||
|
" <td>10.30</td>\n",
|
|||
|
" <td>162.6</td>\n",
|
|||
|
" <td>104</td>\n",
|
|||
|
" <td>7.32</td>\n",
|
|||
|
" <td>12.2</td>\n",
|
|||
|
" <td>5</td>\n",
|
|||
|
" <td>3.29</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>3</th>\n",
|
|||
|
" <td>84</td>\n",
|
|||
|
" <td>1</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>299.4</td>\n",
|
|||
|
" <td>71</td>\n",
|
|||
|
" <td>50.90</td>\n",
|
|||
|
" <td>61.9</td>\n",
|
|||
|
" <td>88</td>\n",
|
|||
|
" <td>5.26</td>\n",
|
|||
|
" <td>196.9</td>\n",
|
|||
|
" <td>89</td>\n",
|
|||
|
" <td>8.86</td>\n",
|
|||
|
" <td>6.6</td>\n",
|
|||
|
" <td>7</td>\n",
|
|||
|
" <td>1.78</td>\n",
|
|||
|
" <td>2</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>4</th>\n",
|
|||
|
" <td>75</td>\n",
|
|||
|
" <td>1</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>166.7</td>\n",
|
|||
|
" <td>113</td>\n",
|
|||
|
" <td>28.34</td>\n",
|
|||
|
" <td>148.3</td>\n",
|
|||
|
" <td>122</td>\n",
|
|||
|
" <td>12.61</td>\n",
|
|||
|
" <td>186.9</td>\n",
|
|||
|
" <td>121</td>\n",
|
|||
|
" <td>8.41</td>\n",
|
|||
|
" <td>10.1</td>\n",
|
|||
|
" <td>3</td>\n",
|
|||
|
" <td>2.73</td>\n",
|
|||
|
" <td>3</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" </tbody>\n",
|
|||
|
"</table>\n",
|
|||
|
"</div>"
|
|||
|
],
|
|||
|
"text/plain": [
|
|||
|
" account_length intl_plan ... number_customer_service_calls churned\n",
|
|||
|
"0 128 0 ... 1 0\n",
|
|||
|
"1 107 0 ... 1 0\n",
|
|||
|
"2 137 0 ... 0 0\n",
|
|||
|
"3 84 1 ... 2 0\n",
|
|||
|
"4 75 1 ... 3 0\n",
|
|||
|
"\n",
|
|||
|
"[5 rows x 18 columns]"
|
|||
|
]
|
|||
|
},
|
|||
|
"execution_count": 90,
|
|||
|
"metadata": {},
|
|||
|
"output_type": "execute_result"
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
|
|||
|
"from sklearn.preprocessing import LabelBinarizer\n",
|
|||
|
"\n",
|
|||
|
"lb = LabelBinarizer()\n",
|
|||
|
"\n",
|
|||
|
"for col in ['intl_plan', 'voice_mail_plan', 'churned']:\n",
"    # fit_transform returns an (n, 1) array for a two-class column; flatten it to 1-D\n",
"    data[col] = lb.fit_transform(data[col]).ravel()\n",
|
|||
|
"data.head(5)"
|
|||
|
]
|
|||
|
},
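For a column with exactly two distinct values, `LabelBinarizer` sorts those values and encodes the second one as 1 and the first as 0, so `'no'`/`'yes'` become 0/1 and `False`/`True` become 0/1, matching the table above. A minimal NumPy sketch of that mapping, assuming every column has exactly two distinct values (`binary_encode` is a hypothetical helper, not part of scikit-learn):

```python
import numpy as np


def binary_encode(values):
    """Map a two-valued column to 0/1, giving 1 to the larger sorted value.

    Hypothetical helper that mimics LabelBinarizer for the two-class case;
    it raises if the column does not have exactly two distinct values.
    """
    values = np.asarray(values)
    classes = np.unique(values)  # sorted distinct values, like LabelBinarizer.classes_
    if len(classes) != 2:
        raise ValueError(f"expected 2 distinct values, got {len(classes)}")
    return (values == classes[1]).astype(int)


print(binary_encode(["no", "yes", "no"]))  # [0 1 0]
print(binary_encode([False, True, True]))  # [0 1 1]
```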
|
|||
|
{
|
|||
|
"cell_type": "markdown",
|
|||
|
"metadata": {},
|
|||
|
"source": [
|
|||
|
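"### Step 3:\n",
"* Separate the `churned` column from all the other columns, i.e., create `X_data` (every column except `churned`) and `y_data` (the `churned` column).\n",
"* Scale `X_data` with one of the scaling methods covered in the lecture slides."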
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 91,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [],
|
|||
|
"source": [
|
|||
|
"# Build X_data (features) and y_data (labels)\n",
|
|||
|
"X_data = data.drop('churned', axis=1).values\n",
|
|||
|
"y_data = data['churned'].values"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 92,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [],
|
|||
|
"source": [
|
|||
|
"# Standardize X_data: zero mean and unit variance for each column\n",
"from sklearn.preprocessing import StandardScaler\n",
"stdSc = StandardScaler()\n",
"X_data = stdSc.fit_transform(X_data)"
|
|||
|
]
|
|||
|
},
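`StandardScaler` rescales each column as `z = (x - mean) / std`, using the population standard deviation (`ddof=0`). Scaling matters for K-nearest neighbors because otherwise the distance is dominated by large-valued columns such as the `total_*_minutes` features. A NumPy sketch of the same transform with a hypothetical `standardize` helper; like `StandardScaler`, it leaves zero-variance columns unscaled rather than dividing by zero:

```python
import numpy as np


def standardize(X):
    """Return (X - column mean) / column std, mirroring StandardScaler's transform.

    Uses the population standard deviation (ddof=0). Columns with zero
    variance keep a scale of 1, so they become all zeros rather than NaN.
    """
    X = np.asarray(X, dtype=float)
    mean = X.mean(axis=0)
    std = X.std(axis=0)      # ddof=0, as StandardScaler uses
    std[std == 0.0] = 1.0    # avoid division by zero for constant columns
    return (X - mean) / std


X = np.array([[1.0, 200.0, 5.0],
              [2.0, 100.0, 5.0],
              [3.0, 300.0, 5.0]])
Z = standardize(X)
print(Z.mean(axis=0))  # approximately [0, 0, 0]
print(Z.std(axis=0))   # approximately [1, 1, 0]; the constant column stays 0
```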
|
|||
|
{
|
|||
|
"cell_type": "markdown",
|
|||
|
"metadata": {},
|
|||
|
"source": [
|
|||
|
"### Step 4:\n",
"* Create a K-nearest neighbors model with k=3 and fit it to `X_data` and `y_data`."
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 93,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"data": {
|
|||
|
"text/html": [
|
|||
|
"<style>#sk-container-id-6 {color: black;background-color: white;}#sk-container-id-6 pre{padding: 0;}#sk-container-id-6 div.sk-toggleable {background-color: white;}#sk-container-id-6 label.sk-toggleable__label {cursor: pointer;display: block;width: 100%;margin-bottom: 0;padding: 0.3em;box-sizing: border-box;text-align: center;}#sk-container-id-6 label.sk-toggleable__label-arrow:before {content: \"▸\";float: left;margin-right: 0.25em;color: #696969;}#sk-container-id-6 label.sk-toggleable__label-arrow:hover:before {color: black;}#sk-container-id-6 div.sk-estimator:hover label.sk-toggleable__label-arrow:before {color: black;}#sk-container-id-6 div.sk-toggleable__content {max-height: 0;max-width: 0;overflow: hidden;text-align: left;background-color: #f0f8ff;}#sk-container-id-6 div.sk-toggleable__content pre {margin: 0.2em;color: black;border-radius: 0.25em;background-color: #f0f8ff;}#sk-container-id-6 input.sk-toggleable__control:checked~div.sk-toggleable__content {max-height: 200px;max-width: 100%;overflow: auto;}#sk-container-id-6 input.sk-toggleable__control:checked~label.sk-toggleable__label-arrow:before {content: \"▾\";}#sk-container-id-6 div.sk-estimator input.sk-toggleable__control:checked~label.sk-toggleable__label {background-color: #d4ebff;}#sk-container-id-6 div.sk-label input.sk-toggleable__control:checked~label.sk-toggleable__label {background-color: #d4ebff;}#sk-container-id-6 input.sk-hidden--visually {border: 0;clip: rect(1px 1px 1px 1px);clip: rect(1px, 1px, 1px, 1px);height: 1px;margin: -1px;overflow: hidden;padding: 0;position: absolute;width: 1px;}#sk-container-id-6 div.sk-estimator {font-family: monospace;background-color: #f0f8ff;border: 1px dotted black;border-radius: 0.25em;box-sizing: border-box;margin-bottom: 0.5em;}#sk-container-id-6 div.sk-estimator:hover {background-color: #d4ebff;}#sk-container-id-6 div.sk-parallel-item::after {content: \"\";width: 100%;border-bottom: 1px solid gray;flex-grow: 1;}#sk-container-id-6 div.sk-label:hover 
label.sk-toggleable__label {background-color: #d4ebff;}#sk-container-id-6 div.sk-serial::before {content: \"\";position: absolute;border-left: 1px solid gray;box-sizing: border-box;top: 0;bottom: 0;left: 50%;z-index: 0;}#sk-container-id-6 div.sk-serial {display: flex;flex-direction: column;align-items: center;background-color: white;padding-right: 0.2em;padding-left: 0.2em;position: relative;}#sk-container-id-6 div.sk-item {position: relative;z-index: 1;}#sk-container-id-6 div.sk-parallel {display: flex;align-items: stretch;justify-content: center;background-color: white;position: relative;}#sk-container-id-6 div.sk-item::before, #sk-container-id-6 div.sk-parallel-item::before {content: \"\";position: absolute;border-left: 1px solid gray;box-sizing: border-box;top: 0;bottom: 0;left: 50%;z-index: -1;}#sk-container-id-6 div.sk-parallel-item {display: flex;flex-direction: column;z-index: 1;position: relative;background-color: white;}#sk-container-id-6 div.sk-parallel-item:first-child::after {align-self: flex-end;width: 50%;}#sk-container-id-6 div.sk-parallel-item:last-child::after {align-self: flex-start;width: 50%;}#sk-container-id-6 div.sk-parallel-item:only-child::after {width: 0;}#sk-container-id-6 div.sk-dashed-wrapped {border: 1px dashed gray;margin: 0 0.4em 0.5em 0.4em;box-sizing: border-box;padding-bottom: 0.4em;background-color: white;}#sk-container-id-6 div.sk-label label {font-family: monospace;font-weight: bold;display: inline-block;line-height: 1.2em;}#sk-container-id-6 div.sk-label-container {text-align: center;}#sk-container-id-6 div.sk-container {/* jupyter's `normalize.less` sets `[hidden] { display: none; }` but bootstrap.min.css set `[hidden] { display: none !important; }` so we also need the `!important` here to be able to override the default hidden behavior on the sphinx rendered scikit-learn.org. 
See: https://github.com/scikit-learn/scikit-learn/issues/21755 */display: inline-block !important;position: relative;}#sk-container-id-6 div.sk-text-repr-fallback {display: none;}</style><div id=\"sk-container-id-6\" class=\"sk-top-container\
|
|||
|
],
|
|||
|
"text/plain": [
|
|||
|
"KNeighborsClassifier(n_neighbors=3)"
|
|||
|
]
|
|||
|
},
|
|||
|
"execution_count": 93,
|
|||
|
"metadata": {},
|
|||
|
"output_type": "execute_result"
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
|
|||
|
"# Create a 3-NN model and train it\n",
|
|||
|
"from sklearn.neighbors import KNeighborsClassifier\n",
|
|||
|
"knn = KNeighborsClassifier(n_neighbors=3)\n",
|
|||
|
"knn.fit(X_data, y_data)\n"
|
|||
|
]
|
|||
|
},
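For each query point, a K-nearest neighbors classifier finds the `k` training points closest to it (Euclidean distance by default in `KNeighborsClassifier`) and predicts the most common label among them. The brute-force NumPy sketch below illustrates that idea; `knn_predict` is a hypothetical helper, and unlike scikit-learn it does not use tree-based neighbor search. Ties in the vote go to the smallest label, which, to our understanding, is also how scikit-learn's uniform-weight vote resolves them.

```python
import numpy as np


def knn_predict(X_train, y_train, X_query, k=3):
    """Brute-force k-NN classification with uniform (majority) voting.

    Assumes integer class labels 0..C-1. Ties in the vote go to the
    smallest label because np.argmax returns the first maximum.
    """
    X_train = np.asarray(X_train, dtype=float)
    y_train = np.asarray(y_train)
    X_query = np.asarray(X_query, dtype=float)
    n_classes = int(y_train.max()) + 1
    preds = []
    for x in X_query:
        dists = np.sqrt(((X_train - x) ** 2).sum(axis=1))  # Euclidean distances
        nearest = np.argsort(dists, kind="stable")[:k]      # indices of the k closest points
        votes = np.bincount(y_train[nearest], minlength=n_classes)
        preds.append(int(np.argmax(votes)))
    return np.array(preds)


X_train = np.array([[0.0, 0.0], [0.0, 1.0], [1.0, 0.0],
                    [5.0, 5.0], [5.0, 6.0], [6.0, 5.0]])
y_train = np.array([0, 0, 0, 1, 1, 1])
print(knn_predict(X_train, y_train, [[0.5, 0.5], [5.5, 5.5]], k=3))  # [0 1]
```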
|
|||
|
{
|
|||
|
"cell_type": "markdown",
|
|||
|
"metadata": {},
|
|||
|
"source": [
|
|||
|
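"### Step 5:\n",
"* Use the K-nearest neighbors model trained in the previous step to predict labels for the same data it was trained on, `X_data`, and evaluate the accuracy of the predictions."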
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 94,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"name": "stdout",
|
|||
|
"output_type": "stream",
|
|||
|
"text": [
|
|||
|
" precision recall f1-score support\n",
|
|||
|
"\n",
|
|||
|
" 0 0.94 0.99 0.97 4293\n",
|
|||
|
" 1 0.93 0.62 0.74 707\n",
|
|||
|
"\n",
|
|||
|
" accuracy 0.94 5000\n",
|
|||
|
" macro avg 0.94 0.80 0.85 5000\n",
|
|||
|
"weighted avg 0.94 0.94 0.93 5000\n",
|
|||
|
"\n"
|
|||
|
]
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
|
|||
|
"# Predict and evaluate\n",
|
|||
|
"from sklearn.metrics import classification_report\n",
|
|||
|
"y_pred = knn.predict(X_data)\n",
|
|||
|
"print(classification_report(y_data, y_pred))\n"
|
|||
|
]
|
|||
|
},
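Each class row of `classification_report` gives precision = TP / (TP + FP), recall = TP / (TP + FN), and F1 = the harmonic mean of the two; `support` is the number of samples that truly belong to that class. For the churn class (1) above, a recall of 0.62 means the model identified 62% of the customers who actually churned. Because the model is evaluated on the same data it was fitted on, these numbers are optimistic. A minimal sketch of the formulas for one class (`binary_metrics` is a hypothetical helper):

```python
import numpy as np


def binary_metrics(y_true, y_pred, positive=1):
    """Precision, recall, F1 and support for one class."""
    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred)
    tp = np.sum((y_pred == positive) & (y_true == positive))  # correctly flagged
    fp = np.sum((y_pred == positive) & (y_true != positive))  # false alarms
    fn = np.sum((y_pred != positive) & (y_true == positive))  # missed positives
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
    support = int(np.sum(y_true == positive))
    return precision, recall, f1, support


y_true = [1, 1, 1, 1, 0, 0, 0, 0]
y_pred = [1, 1, 0, 0, 0, 0, 0, 1]
print(binary_metrics(y_true, y_pred))  # precision 2/3, recall 0.5, F1 4/7, support 4
```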
|
|||
|
{
|
|||
|
"cell_type": "markdown",
|
|||
|
"metadata": {},
|
|||
|
"source": [
|
|||
|
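"### Step 6:\n",
"* Build another model with `n_neighbors=3`, but weight each neighbor's vote by its distance (closer neighbors count more) when combining the K neighbors' predictions. Compute this model's prediction accuracy on `X_data` as well.\n",
"* Build another K-nearest neighbors model that uses uniform weights but sets the exponent of the Minkowski distance to 1 (`p=1`), i.e., uses the Manhattan distance."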
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 95,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"name": "stdout",
|
|||
|
"output_type": "stream",
|
|||
|
"text": [
|
|||
|
" precision recall f1-score support\n",
|
|||
|
"\n",
|
|||
|
" 0 1.00 1.00 1.00 4293\n",
|
|||
|
" 1 1.00 1.00 1.00 707\n",
|
|||
|
"\n",
|
|||
|
" accuracy 1.00 5000\n",
|
|||
|
" macro avg 1.00 1.00 1.00 5000\n",
|
|||
|
"weighted avg 1.00 1.00 1.00 5000\n",
|
|||
|
"\n"
|
|||
|
]
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
|
|||
|
"# n_neighbors=3, weights='distance'\n",
|
|||
|
"knn = KNeighborsClassifier(n_neighbors=3, weights='distance')\n",
|
|||
|
"knn.fit(X_data, y_data)\n",
|
|||
|
"y_pred = knn.predict(X_data)\n",
|
|||
|
"print(classification_report(y_data, y_pred))\n"
|
|||
|
]
|
|||
|
},
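With `weights='distance'`, each neighbor's vote is weighted by the inverse of its distance. When the query point is itself in the training set, its distance to itself is 0, and scikit-learn handles this case by giving all of the weight to the zero-distance neighbor(s). That is why the report above shows a perfect 1.00 on `X_data`: every training point is simply predicted with its own label. This is a consequence of evaluating on the training data, not evidence that the model generalizes. The sketch below reproduces that voting rule with a hypothetical `weighted_vote` helper:

```python
import numpy as np


def weighted_vote(dists, labels, n_classes):
    """Inverse-distance vote over the k nearest neighbors of one query point.

    If any neighbor is at distance 0, only the zero-distance neighbors vote
    (each with weight 1), mirroring scikit-learn's handling of weights='distance'.
    """
    dists = np.asarray(dists, dtype=float)
    labels = np.asarray(labels)
    zero = dists == 0.0
    if zero.any():
        weights = zero.astype(float)
    else:
        weights = 1.0 / dists
    scores = np.bincount(labels, weights=weights, minlength=n_classes)
    return int(np.argmax(scores))


# The query coincides with a training point labelled 1 (distance 0),
# while its two other neighbors are labelled 0.
print(weighted_vote([0.0, 0.2, 0.3], [1, 0, 0], n_classes=2))  # 1
# Without an exact match, closer neighbors simply count more.
print(weighted_vote([0.1, 0.5, 0.6], [1, 0, 0], n_classes=2))  # 1
```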
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 96,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"name": "stdout",
|
|||
|
"output_type": "stream",
|
|||
|
"text": [
|
|||
|
" precision recall f1-score support\n",
|
|||
|
"\n",
|
|||
|
" 0 0.94 0.99 0.97 4293\n",
|
|||
|
" 1 0.94 0.63 0.75 707\n",
|
|||
|
"\n",
|
|||
|
" accuracy 0.94 5000\n",
|
|||
|
" macro avg 0.94 0.81 0.86 5000\n",
|
|||
|
"weighted avg 0.94 0.94 0.94 5000\n",
|
|||
|
"\n"
|
|||
|
]
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
|
|||
|
"# n_neighbors=3, p=1\n",
|
|||
|
"knn = KNeighborsClassifier(n_neighbors=3, p=1)\n",
|
|||
|
"knn.fit(X_data, y_data)\n",
|
|||
|
"y_pred = knn.predict(X_data)\n",
|
|||
|
"print(classification_report(y_data, y_pred))\n"
|
|||
|
]
|
|||
|
},
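The Minkowski distance between two points is `(sum_i |a_i - b_i| ** p) ** (1 / p)`. With `p=2` it is the Euclidean distance (`KNeighborsClassifier` defaults to `metric='minkowski'` with `p=2`), and with `p=1` it is the Manhattan distance used by the second model above. A small NumPy sketch with a hypothetical `minkowski` helper:

```python
import numpy as np


def minkowski(a, b, p=2):
    """Minkowski distance: (sum |a_i - b_i|^p)^(1/p)."""
    diff = np.abs(np.asarray(a, dtype=float) - np.asarray(b, dtype=float))
    return float((diff ** p).sum() ** (1.0 / p))


a, b = [0.0, 0.0], [3.0, 4.0]
print(minkowski(a, b, p=1))  # 7.0, Manhattan distance: |3| + |4|
print(minkowski(a, b, p=2))  # 5.0, Euclidean distance: sqrt(9 + 16)
```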
|
|||
|
{
|
|||
|
"cell_type": "markdown",
|
|||
|
"metadata": {},
|
|||
|
"source": [
|
|||
|
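"### Step 7:\n",
"* Vary K from 1 to 20 and train 20 different K-nearest neighbors models, using uniform weights (the default). The Minkowski exponent `p` may be set to 1 or 2, as long as it is the same for every model. Store each model's accuracy together with its `k` value in a list or dictionary.\n",
"* Plot `accuracy` against `k`. What do you observe at `k=1`? Why?"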
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 97,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"name": "stdout",
|
|||
|
"output_type": "stream",
|
|||
|
"text": [
|
|||
|
"[(1, 1.0), (2, 0.9218), (3, 0.9396), (4, 0.912), (5, 0.928), (6, 0.9094), (7, 0.9194), (8, 0.9052), (9, 0.9144), (10, 0.9024), (11, 0.9112), (12, 0.9002), (13, 0.9086), (14, 0.8984), (15, 0.9032), (16, 0.8944), (17, 0.8998), (18, 0.8934), (19, 0.8996), (20, 0.8928)]\n"
|
|||
|
]
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
|
|||
|
"scores = []\n",
|
|||
|
"for k in range(1, 21):\n",
|
|||
|
" knn = KNeighborsClassifier(n_neighbors=k)\n",
|
|||
|
" knn.fit(X_data, y_data)\n",
|
|||
|
"    scores.append((k, knn.score(X_data, y_data)))\n",
|
|||
|
"print(scores)\n"
|
|||
|
]
|
|||
|
},
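At `k=1` the training accuracy is exactly 1.0 because the model is scored on the same points it memorized: each point's nearest neighbor is itself, at distance 0, so it is always predicted with its own label (this can only fail if identical feature vectors carry different labels). As `k` grows, other points join the vote and training accuracy falls, which says nothing about accuracy on unseen customers; a held-out test set or cross-validation is needed for that. The printed list also shows every even `k` scoring below its odd neighbors; one plausible reason is that with uniform weights a tied vote is resolved in favor of the smaller label, 0 (not churned). The sketch below checks the `k=1` claim on random data with a brute-force 1-NN (hypothetical `one_nn_predict` helper; the labels are deliberately pure noise):

```python
import numpy as np

rng = np.random.default_rng(0)
X = rng.normal(size=(200, 5))     # continuous random features; points are almost surely distinct
y = rng.integers(0, 2, size=200)  # random 0/1 labels with no relation to X


def one_nn_predict(X_train, y_train, X_query):
    """Brute-force 1-nearest-neighbor prediction with Euclidean distance."""
    # Pairwise squared distances, shape (n_query, n_train)
    d2 = ((X_query[:, None, :] - X_train[None, :, :]) ** 2).sum(axis=2)
    return y_train[np.argmin(d2, axis=1)]


train_acc = np.mean(one_nn_predict(X, y, X) == y)
print(train_acc)  # 1.0 even though the labels are noise
```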
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 98,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"data": {
|
|||
|
"text/plain": [
|
|||
|
"[<matplotlib.lines.Line2D at 0x7fc1cc197c40>]"
|
|||
|
]
|
|||
|
},
|
|||
|
"execution_count": 98,
|
|||
|
"metadata": {},
|
|||
|
"output_type": "execute_result"
|
|||
|
},
|
|||
|
{
|
|||
|
"data": {
|
|||
|
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAlEAAAG6CAYAAAAoFxMCAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjYuMCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy89olMNAAAACXBIWXMAAA9hAAAPYQGoP6dpAABqfElEQVR4nO3deXxM5/4H8M+ZyWQiO1kkiJAQmoo9llI7oQtabdHauqhq+6uWlksX1C3V6r2qvdWWWzvlVlXtsROE2MWWCBIikX2yT5KZ8/tjzFSa/WTW+Lxfr7x+P3POnPPMjTQfz/M930cQRVEEEREREdWIzNIDICIiIrJFDFFEREREEjBEEREREUnAEEVEREQkAUMUERERkQQMUUREREQSMEQRERERSWBn6QHUZVqtFvfu3YOLiwsEQbD0cIiIiKgaRFFETk4OGjVqBJms4vkmhigTunfvHvz8/Cw9DCIiIpLgzp07aNKkSYXHGaJMyMXFBYDum+Dq6mrh0RAREVF1ZGdnw8/Pz/B7vCIMUSakX8JzdXVliCIiIrIxVZXisLCciIiISAKGKCIiIiIJGKKIiIiIJGCIIiIiIpKAIYqIiIhIAoYoIiIiIgkYooiIiIgkYIgiIiIiksAqQ9TatWsxadIkdO7cGUqlEoIgYOXKlTW+jlarxXfffYeQkBDUq1cPXl5eGD16NG7evFnhe/bs2YPevXvDxcUFrq6u6Nu3L/bv31+LT0NERER1kVWGqE8++QQ///wz4uPj4evrK/k6kyZNwnvvvQdRFPHee+9h8ODB+P333xEaGorY2Ngy569duxaDBw/G1atXMWHCBIwfPx6XL1/GwIED8dtvv9XmIxmNRiviRFw6tp5PxIm4dGi0oqWHRERE9EgSRFG0ut/C+/btQ8uWLeHv748vv/wSM2fOxIoVKzBhwoRqX+PgwYPo168fevXqhb1798Le3h4AsGvXLjz11FMYNGgQ9uzZYzg/MzMTAQEBsLOzw7lz5wwbDt69excdOnQAANy8ebPKfXQelp2dDTc3N6hUKqNs+7I7Oglzt11BkqrQ8JqvmwNmPxuMwW2kh00iIiL6S3V/f1vlTNSAAQPg7+9fq2ssW7YMADBv3jxDgAKAIUOGoE+fPggPD0dCQoLh9f/973/IysrC//3f/5XasblJkyZ49913kZaWhi1bttRqTLWxOzoJk9eeLRWgACBZVYjJa89id3SShUZGRET0aLLKEGUMhw4dgpOTE3r06FHmWFhYGADg8OHDpc4HgEGDBlXrfHPSaEXM3XYF5U0Z6l+bu+0Kl/aIiIjMqE6GqLy8PCQlJaF58+aQy+Vljrds2RIAStVF6f9//bGqzi+PWq1GdnZ2qS9jOHUro8wM1MNEAEmqQpy6lWGU+xEREVHV6mSIUqlUAAA3N7dyj+vXN/XnVfWe8s4vz4IFC+Dm5mb48vPzq/ngy5GSU3GAknIeERER1V6dDFGWMnPmTKhUKsPXnTt3jHJdbxcHo55HREREtWdn6QGYgn42qaKZI/0y28OzTg+/x8PDo8rzy6NUKqFUKqUNuhJdmjeAr5sDklWF5dZFCQB83BzQpXkDo9+biIiIylcnZ6KcnJzg6+uLW7duQaPRlDleXv1TZXVPldVLmYNcJmD2s8EAdIHpYfo/z342GHLZ348SERGRqdTJEAUAvXv3Rl5eHo4dO1bmmL4/VK9evUqdDwDh4eEVnq8/xxIGt/HF0jEd4eNWesnOx80BS8d0ZJ8oIiIiM7P5EJWWloZr164hLS2t1OtvvvkmAODTTz9FUVGR4fVdu3bh0KFDGDRoUKleVC+99BLc3Nzw3Xff4e7du4bX7969i++//x6enp547rnnTPxpKje4jS8iZvTDuG66cXcLaICIGf0YoIiIiCzAKmuili9fjoiICADApUuXDK/pezn17NkTb7zxBgDg+++/x9y5czF79mzMmTPHcI2+ffvijTfewPLly9GxY0c8/f
TTSEpKwsaNG9GgQQN89913pe5Zv359fP/99xg7diw6duyIkSNHAgA2btyI9PR0bNy4sUbdyk1FLhMQ2rwBVkfGQyuCS3hEREQWYpUhKiIiAqtWrSr12rFjx0otzelDVGV++uknhISE4Oeff8a3334LZ2dnPPfcc/jiiy8QGBhY5vwxY8bA09MT8+fPx4oVKyAIAjp16oRPPvkEAwYMqP0HMxJvF13xemqO2sIjISIienRZ5d55dYWx987Tu5WWh76LDsHJXo7Lnw822nWJiIjIxvfOo8rpZ6LyijTIU5dYeDRERESPJoYoG+SktIOTvW47Gy7pERERWQZDlI3yejAblcIQRUREZBEMUTZKv8UL98sjIiKyDIYoG+Xl+mAmKpszUURERJbAEGWjDG0OchmiiIiILIEhykYZaqI4E0VERGQRDFE2ijVRRERElsUQZaPYtZyIiMiyGKJslLcrWxwQERFZEkOUjdIv52XkFaFYo7XwaIiIiB49DFE2yr2eAnYyAQCQxif0iIiIzI4hykbJZAKf0CMiIrIghigb5s2tX4iIiCyGIcqGebHNARERkcUwRNkw/RN6bHNARERkfgxRNszLmct5RERElsIQZcO8uQkxERGRxTBE2TB9r6hU1kQRERGZHUOUDePTeURERJbDEGXD9H2i0nLV0GpFC4+GiIjo0cIQZcM8HxSWF2tEZBUUW3g0REREjxaGKBtmbydDAyd7AOwVRUREZG4MUTbOm1u/EBERWQRDlI3T10Wx4SYREZF5MUTZOC8+oUdERGQRDFE2zpv75xEREVkEQ5SNY68oIiIiy2CIsnGGTYhZWE5ERGRWDFE2zrD1Sy5DFBERkTkxRNk4Q2F5NmuiiIiIzIkhysbpa6LyijTIU5dYeDRERESPDoYoG+ektIOTvRwAi8uJiIjMiSGqDvB2fdDmgEt6REREZsMQVQd4PdiImMXlRERE5sMQVQd4uXL/PCIiInNjiKoD2HCTiIjI/Bii6gBu/UJERGR+DFF1gH4mKpUzUURERGbDEFUHeDFEERERmR1DVB2g3z+PNVFERETmwxBVB+hrojLyilBUorXwaIiIiB4NDFF1QH1HBRRyAQCQxl5RREREZsEQVQcIgmBouMklPSIiIvNgiKojWFxORERkXgxRdYQXe0URERGZFUNUHeHNrV+IiIjMiiGqjuDWL0REROZltSEqKioKTz31FNzd3eHk5IRu3bph06ZNNbrG1atX8corr8DHxwdKpRL+/v6YMmUKMjIyyj2/pKQEv/zyC7p37w4vLy+4uLggODgY06dPR3JysjE+lsmwJoqIiMi87Cw9gPIcPHgQYWFhcHBwwKhRo+Di4oLNmzdj5MiRuHPnDqZNm1blNSIjIzFgwAAUFBRg2LBhCAwMxPnz57FkyRLs3r0bx48fh4eHR6n3jBw5Er///jtatGiBUaNGQalUIjIyEl9//TXWrl2Ls2fPwsfHx1Qfu1b0vaJSWRNFRERkHqKVKS4uFgMDA0WlUimeO3fO8HpWVpYYFBQk2tvbi7dv367yOm3atBEBiFu3bi31+ldffSUCECdNmlTq9ZMnT4oAxC5duohFRUWljr333nsiAHHu3Lk1+iwqlUoEIKpUqhq9T4rzCZmi/4ztYrf5+0x+LyIiorqsur+/rW4578CBA4iLi8PLL7+M9u3bG153c3PDrFmzUFRUhFWrVlV6jbi4OERHRyM0NBRDhw4tdWzatGnw8PDAmjVrkJeXZ3j95s2bAIABAwZAoVCUes8zzzwDAEhNTa3NRzMpfWF5ao4aWq1o4dEQERHVfVYXog4dOgQAGDRoUJljYWFhAIDDhw9Xeg19/VLz5s3LHJPJZGjatCny8/MRGRlpeP3xxx8HAOzbtw/FxcWl3rN9+3YAQP/+/av5KczP80GzzRKtiMz8IguPhoiIqO6zupqo2NhYAEDLli3LHPPx8YGzs7PhnIp4enoCAG7dulXmmFarRUJCAgAgJibGEI
xCQkIwZcoUfPvttwgODsaQIUOgVCpx4sQJnDlzBnPnzsXw4cMrva9arYZa/Vdhd3Z2dqXnG5NCLkMDJ3tk5BUhNVc
|
|||
|
"text/plain": [
|
|||
|
"<Figure size 640x480 with 1 Axes>"
|
|||
|
]
|
|||
|
},
|
|||
|
"metadata": {},
|
|||
|
"output_type": "display_data"
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
|
|||
|
"%matplotlib inline\n",
|
|||
|
"import matplotlib.pyplot as plt\n",
|
|||
|
"plt.rcParams['font.size'] = 14\n",
|
|||
|
"plt.xlabel('k')\n",
|
|||
|
"plt.ylabel('accuracy')\n",
|
|||
|
"plt.plot([k for k, v in scores], [v for k, v in scores], 'o-')"
|
|||
|
]
|
|||
|
}
|
|||
|
],
|
|||
|
"metadata": {
|
|||
|
"kernelspec": {
|
|||
|
"display_name": "Python 3.10.8 ('.venv': venv)",
|
|||
|
"language": "python",
|
|||
|
"name": "python3"
|
|||
|
},
|
|||
|
"language_info": {
|
|||
|
"codemirror_mode": {
|
|||
|
"name": "ipython",
|
|||
|
"version": 3
|
|||
|
},
|
|||
|
"file_extension": ".py",
|
|||
|
"mimetype": "text/x-python",
|
|||
|
"name": "python",
|
|||
|
"nbconvert_exporter": "python",
|
|||
|
"pygments_lexer": "ipython3",
|
|||
|
"version": "3.10.8"
|
|||
|
},
|
|||
|
"vscode": {
|
|||
|
"interpreter": {
|
|||
|
"hash": "1f0d395e06aa83586067b19165efc9b683889967164248deef4bbf1fa27cfb00"
|
|||
|
}
|
|||
|
}
|
|||
|
},
|
|||
|
"nbformat": 4,
|
|||
|
"nbformat_minor": 2
|
|||
|
}
|