{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# K-Nearest Neighbors Classification Lab"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"In this exercise we use a telecom company's customer churn dataset, `e2.1_Orange_Telecom_Churn_Data.csv` (stored in the current directory). We first read in the dataset and do some preprocessing, then use a K-nearest neighbors model to predict, from each customer's characteristics, whether that customer will churn."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Step 1:\n",
"* Read the dataset into the variable `data` and display its first 5 rows.\n",
"* Drop the three columns `\"state\"`, `\"area_code\"` and `\"phone_number\"`."
]
},
{
"cell_type": "code",
"execution_count": 88,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>state</th>\n",
" <th>account_length</th>\n",
" <th>area_code</th>\n",
" <th>phone_number</th>\n",
" <th>intl_plan</th>\n",
" <th>voice_mail_plan</th>\n",
" <th>number_vmail_messages</th>\n",
" <th>total_day_minutes</th>\n",
" <th>total_day_calls</th>\n",
" <th>total_day_charge</th>\n",
" <th>total_eve_minutes</th>\n",
" <th>total_eve_calls</th>\n",
" <th>total_eve_charge</th>\n",
" <th>total_night_minutes</th>\n",
" <th>total_night_calls</th>\n",
" <th>total_night_charge</th>\n",
" <th>total_intl_minutes</th>\n",
" <th>total_intl_calls</th>\n",
" <th>total_intl_charge</th>\n",
" <th>number_customer_service_calls</th>\n",
" <th>churned</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>KS</td>\n",
" <td>128</td>\n",
" <td>415</td>\n",
" <td>382-4657</td>\n",
" <td>no</td>\n",
" <td>yes</td>\n",
" <td>25</td>\n",
" <td>265.1</td>\n",
" <td>110</td>\n",
" <td>45.07</td>\n",
" <td>197.4</td>\n",
" <td>99</td>\n",
" <td>16.78</td>\n",
" <td>244.7</td>\n",
" <td>91</td>\n",
" <td>11.01</td>\n",
" <td>10.0</td>\n",
" <td>3</td>\n",
" <td>2.70</td>\n",
" <td>1</td>\n",
" <td>False</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>OH</td>\n",
" <td>107</td>\n",
" <td>415</td>\n",
" <td>371-7191</td>\n",
" <td>no</td>\n",
" <td>yes</td>\n",
" <td>26</td>\n",
" <td>161.6</td>\n",
" <td>123</td>\n",
" <td>27.47</td>\n",
" <td>195.5</td>\n",
" <td>103</td>\n",
" <td>16.62</td>\n",
" <td>254.4</td>\n",
" <td>103</td>\n",
" <td>11.45</td>\n",
" <td>13.7</td>\n",
" <td>3</td>\n",
" <td>3.70</td>\n",
" <td>1</td>\n",
" <td>False</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>NJ</td>\n",
" <td>137</td>\n",
" <td>415</td>\n",
" <td>358-1921</td>\n",
" <td>no</td>\n",
" <td>no</td>\n",
" <td>0</td>\n",
" <td>243.4</td>\n",
" <td>114</td>\n",
" <td>41.38</td>\n",
" <td>121.2</td>\n",
" <td>110</td>\n",
" <td>10.30</td>\n",
" <td>162.6</td>\n",
" <td>104</td>\n",
" <td>7.32</td>\n",
" <td>12.2</td>\n",
" <td>5</td>\n",
" <td>3.29</td>\n",
" <td>0</td>\n",
" <td>False</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>OH</td>\n",
" <td>84</td>\n",
" <td>408</td>\n",
" <td>375-9999</td>\n",
" <td>yes</td>\n",
" <td>no</td>\n",
" <td>0</td>\n",
" <td>299.4</td>\n",
" <td>71</td>\n",
" <td>50.90</td>\n",
" <td>61.9</td>\n",
" <td>88</td>\n",
" <td>5.26</td>\n",
" <td>196.9</td>\n",
" <td>89</td>\n",
" <td>8.86</td>\n",
" <td>6.6</td>\n",
" <td>7</td>\n",
" <td>1.78</td>\n",
" <td>2</td>\n",
" <td>False</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>OK</td>\n",
" <td>75</td>\n",
" <td>415</td>\n",
" <td>330-6626</td>\n",
" <td>yes</td>\n",
" <td>no</td>\n",
" <td>0</td>\n",
" <td>166.7</td>\n",
" <td>113</td>\n",
" <td>28.34</td>\n",
" <td>148.3</td>\n",
" <td>122</td>\n",
" <td>12.61</td>\n",
" <td>186.9</td>\n",
" <td>121</td>\n",
" <td>8.41</td>\n",
" <td>10.1</td>\n",
" <td>3</td>\n",
" <td>2.73</td>\n",
" <td>3</td>\n",
" <td>False</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" state account_length ... number_customer_service_calls churned\n",
"0 KS 128 ... 1 False\n",
"1 OH 107 ... 1 False\n",
"2 NJ 137 ... 0 False\n",
"3 OH 84 ... 2 False\n",
"4 OK 75 ... 3 False\n",
"\n",
"[5 rows x 21 columns]"
]
},
"execution_count": 88,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Read the dataset into the variable data and display its first 5 rows\n",
"import pandas as pd\n",
"data = pd.read_csv('e2.1_Orange_Telecom_Churn_Data.csv')\n",
"data.head()"
]
},
{
"cell_type": "code",
"execution_count": 89,
"metadata": {},
"outputs": [],
"source": [
"# Drop the \"state\", \"area_code\" and \"phone_number\" columns\n",
"data = data.drop(columns=['state', 'area_code', 'phone_number'])"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Step 2:\n",
"* Some columns hold categorical values, namely `'intl_plan'`, `'voice_mail_plan'` and `'churned'`. Convert these three columns to numeric values."
]
},
{
"cell_type": "code",
"execution_count": 90,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>account_length</th>\n",
" <th>intl_plan</th>\n",
" <th>voice_mail_plan</th>\n",
" <th>number_vmail_messages</th>\n",
" <th>total_day_minutes</th>\n",
" <th>total_day_calls</th>\n",
" <th>total_day_charge</th>\n",
" <th>total_eve_minutes</th>\n",
" <th>total_eve_calls</th>\n",
" <th>total_eve_charge</th>\n",
" <th>total_night_minutes</th>\n",
" <th>total_night_calls</th>\n",
" <th>total_night_charge</th>\n",
" <th>total_intl_minutes</th>\n",
" <th>total_intl_calls</th>\n",
" <th>total_intl_charge</th>\n",
" <th>number_customer_service_calls</th>\n",
" <th>churned</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>128</td>\n",
" <td>0</td>\n",
" <td>1</td>\n",
" <td>25</td>\n",
" <td>265.1</td>\n",
" <td>110</td>\n",
" <td>45.07</td>\n",
" <td>197.4</td>\n",
" <td>99</td>\n",
" <td>16.78</td>\n",
" <td>244.7</td>\n",
" <td>91</td>\n",
" <td>11.01</td>\n",
" <td>10.0</td>\n",
" <td>3</td>\n",
" <td>2.70</td>\n",
" <td>1</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>107</td>\n",
" <td>0</td>\n",
" <td>1</td>\n",
" <td>26</td>\n",
" <td>161.6</td>\n",
" <td>123</td>\n",
" <td>27.47</td>\n",
" <td>195.5</td>\n",
" <td>103</td>\n",
" <td>16.62</td>\n",
" <td>254.4</td>\n",
" <td>103</td>\n",
" <td>11.45</td>\n",
" <td>13.7</td>\n",
" <td>3</td>\n",
" <td>3.70</td>\n",
" <td>1</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>137</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>243.4</td>\n",
" <td>114</td>\n",
" <td>41.38</td>\n",
" <td>121.2</td>\n",
" <td>110</td>\n",
" <td>10.30</td>\n",
" <td>162.6</td>\n",
" <td>104</td>\n",
" <td>7.32</td>\n",
" <td>12.2</td>\n",
" <td>5</td>\n",
" <td>3.29</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>84</td>\n",
" <td>1</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>299.4</td>\n",
" <td>71</td>\n",
" <td>50.90</td>\n",
" <td>61.9</td>\n",
" <td>88</td>\n",
" <td>5.26</td>\n",
" <td>196.9</td>\n",
" <td>89</td>\n",
" <td>8.86</td>\n",
" <td>6.6</td>\n",
" <td>7</td>\n",
" <td>1.78</td>\n",
" <td>2</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>75</td>\n",
" <td>1</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>166.7</td>\n",
" <td>113</td>\n",
" <td>28.34</td>\n",
" <td>148.3</td>\n",
" <td>122</td>\n",
" <td>12.61</td>\n",
" <td>186.9</td>\n",
" <td>121</td>\n",
" <td>8.41</td>\n",
" <td>10.1</td>\n",
" <td>3</td>\n",
" <td>2.73</td>\n",
" <td>3</td>\n",
" <td>0</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" account_length intl_plan ... number_customer_service_calls churned\n",
"0 128 0 ... 1 0\n",
"1 107 0 ... 1 0\n",
"2 137 0 ... 0 0\n",
"3 84 1 ... 2 0\n",
"4 75 1 ... 3 0\n",
"\n",
"[5 rows x 18 columns]"
]
},
"execution_count": 90,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from sklearn.preprocessing import LabelBinarizer\n",
"\n",
"# Encode the yes/no columns and the True/False 'churned' column as 0/1\n",
"lb = LabelBinarizer()\n",
"\n",
"for col in ['intl_plan', 'voice_mail_plan', 'churned']:\n",
"    # lb is refitted on each column; ravel() flattens the (n, 1) output to 1-D\n",
"    data[col] = lb.fit_transform(data[col]).ravel()\n",
"data.head(5)"
]
},
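{
"cell_type": "markdown",
"metadata": {},
"source": [
"The loop above relies on `LabelBinarizer` turning each two-valued column into a single 0/1 column. As a quick sanity check, the sketch below runs the same loop on a small made-up frame, `toy` (hypothetical data, not rows from the CSV), and compares the result with an explicit `map`. `LabelBinarizer` sorts the labels it sees, so `'no'` becomes 0 and `'yes'` becomes 1, and `False`/`True` become 0/1. Because the same `lb` object is refitted on every column, `lb.classes_` only describes the last column it encoded."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import pandas as pd\n",
"from sklearn.preprocessing import LabelBinarizer\n",
"\n",
"# Hypothetical toy data with the same kinds of values as the churn dataset\n",
"toy = pd.DataFrame({\n",
"    'intl_plan': ['no', 'yes', 'no'],\n",
"    'voice_mail_plan': ['yes', 'no', 'no'],\n",
"    'churned': [False, True, False],\n",
"})\n",
"\n",
"# Explicit 0/1 mapping, used only to cross-check the binarizer\n",
"expected = pd.DataFrame({\n",
"    'intl_plan': toy['intl_plan'].map({'no': 0, 'yes': 1}),\n",
"    'voice_mail_plan': toy['voice_mail_plan'].map({'no': 0, 'yes': 1}),\n",
"    'churned': toy['churned'].astype(int),\n",
"})\n",
"\n",
"lb = LabelBinarizer()\n",
"for col in ['intl_plan', 'voice_mail_plan', 'churned']:\n",
"    # lb is refitted on this column; ravel() flattens the (n, 1) output to 1-D\n",
"    toy[col] = lb.fit_transform(toy[col]).ravel()\n",
"\n",
"print((toy.values == expected.values).all())  # True if both encodings agree\n",
"print(lb.classes_)  # classes of the last column fitted ('churned')"
]
},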
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Step 3:\n",
"* Separate the data in the `churned` column from the data in all the other columns, i.e. create two data sets: `X_data` (every column except `churned`) and `y_data` (the `churned` column).\n",
"* Scale `X_data` with one of the scaling methods covered in the lecture slides."
]
},
{
"cell_type": "code",
"execution_count": 91,
"metadata": {},
"outputs": [],
"source": [
"# Build X_data (features) and y_data (target) as NumPy arrays\n",
"X_data = data.drop('churned', axis=1).values\n",
"y_data = data['churned'].values"
]
},
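{
"cell_type": "markdown",
"metadata": {},
"source": [
"Before judging any accuracy figure, it helps to know how balanced `y_data` is. The `support` column of the report in Step 5 lists 4293 non-churners (0) and 707 churners (1), so a model that always predicts class 0 (no churn) would already be right 4293 / 5000 ≈ 85.9% of the time. The sketch below rebuilds a label vector with exactly those counts as a stand-in for `y_data` and scores scikit-learn's `DummyClassifier` (`strategy='most_frequent'`) on it as such a baseline. In the notebook itself, `np.bincount(y_data)` gives the counts directly."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"from sklearn.dummy import DummyClassifier\n",
"\n",
"# Stand-in for y_data, built from the class counts in the Step 5 report\n",
"y = np.array([0] * 4293 + [1] * 707)\n",
"X = np.zeros((len(y), 1))  # DummyClassifier ignores the feature values\n",
"\n",
"counts = np.bincount(y)\n",
"print(counts)            # number of samples per class\n",
"print(counts / len(y))   # class proportions\n",
"\n",
"# Baseline: always predict the most frequent class (0 = no churn)\n",
"baseline = DummyClassifier(strategy='most_frequent').fit(X, y)\n",
"print(baseline.score(X, y))"
]
},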
{
"cell_type": "code",
"execution_count": 92,
"metadata": {},
"outputs": [],
"source": [
"# Scale X_data: standardize each feature to zero mean and unit variance\n",
"from sklearn.preprocessing import StandardScaler\n",
"stdSc = StandardScaler()\n",
"X_data = stdSc.fit_transform(X_data)"
]
},
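{
"cell_type": "markdown",
"metadata": {},
"source": [
"`StandardScaler` rescales each column $j$ to $z_{ij} = (x_{ij} - \\mu_j) / \\sigma_j$, where $\\mu_j$ is the column mean and $\\sigma_j$ the population standard deviation (`ddof=0`). Scaling matters for K-nearest neighbors because distances would otherwise be dominated by large-valued columns such as `total_day_minutes` (hundreds) rather than small ones such as `number_customer_service_calls` (single digits). The sketch below checks the formula on a small made-up array. It only illustrates the formula and is not part of the exercise."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"from sklearn.preprocessing import StandardScaler\n",
"\n",
"# Made-up data: two features on very different scales\n",
"X = np.array([[1.0, 200.0],\n",
"              [2.0, 300.0],\n",
"              [3.0, 400.0]])\n",
"\n",
"X_std = StandardScaler().fit_transform(X)\n",
"\n",
"# Manual z-scores; np.std defaults to ddof=0, matching StandardScaler\n",
"manual = (X - X.mean(axis=0)) / X.std(axis=0)\n",
"\n",
"print(np.allclose(X_std, manual))  # True\n",
"print(X_std.mean(axis=0))          # approximately 0 for each column\n",
"print(X_std.std(axis=0))           # approximately 1 for each column"
]
},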
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Step 4:\n",
"* Create a K-nearest neighbors model with k=3 and fit it to `X_data` and `y_data`."
]
},
{
"cell_type": "code",
"execution_count": 93,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<style>#sk-container-id-6 {color: black;background-color: white;}#sk-container-id-6 pre{padding: 0;}#sk-container-id-6 div.sk-toggleable {background-color: white;}#sk-container-id-6 label.sk-toggleable__label {cursor: pointer;display: block;width: 100%;margin-bottom: 0;padding: 0.3em;box-sizing: border-box;text-align: center;}#sk-container-id-6 label.sk-toggleable__label-arrow:before {content: \"▸\";float: left;margin-right: 0.25em;color: #696969;}#sk-container-id-6 label.sk-toggleable__label-arrow:hover:before {color: black;}#sk-container-id-6 div.sk-estimator:hover label.sk-toggleable__label-arrow:before {color: black;}#sk-container-id-6 div.sk-toggleable__content {max-height: 0;max-width: 0;overflow: hidden;text-align: left;background-color: #f0f8ff;}#sk-container-id-6 div.sk-toggleable__content pre {margin: 0.2em;color: black;border-radius: 0.25em;background-color: #f0f8ff;}#sk-container-id-6 input.sk-toggleable__control:checked~div.sk-toggleable__content {max-height: 200px;max-width: 100%;overflow: auto;}#sk-container-id-6 input.sk-toggleable__control:checked~label.sk-toggleable__label-arrow:before {content: \"▾\";}#sk-container-id-6 div.sk-estimator input.sk-toggleable__control:checked~label.sk-toggleable__label {background-color: #d4ebff;}#sk-container-id-6 div.sk-label input.sk-toggleable__control:checked~label.sk-toggleable__label {background-color: #d4ebff;}#sk-container-id-6 input.sk-hidden--visually {border: 0;clip: rect(1px 1px 1px 1px);clip: rect(1px, 1px, 1px, 1px);height: 1px;margin: -1px;overflow: hidden;padding: 0;position: absolute;width: 1px;}#sk-container-id-6 div.sk-estimator {font-family: monospace;background-color: #f0f8ff;border: 1px dotted black;border-radius: 0.25em;box-sizing: border-box;margin-bottom: 0.5em;}#sk-container-id-6 div.sk-estimator:hover {background-color: #d4ebff;}#sk-container-id-6 div.sk-parallel-item::after {content: \"\";width: 100%;border-bottom: 1px solid gray;flex-grow: 1;}#sk-container-id-6 div.sk-label:hover label.sk-toggleable__label {background-color: #d4ebff;}#sk-container-id-6 div.sk-serial::before {content: \"\";position: absolute;border-left: 1px solid gray;box-sizing: border-box;top: 0;bottom: 0;left: 50%;z-index: 0;}#sk-container-id-6 div.sk-serial {display: flex;flex-direction: column;align-items: center;background-color: white;padding-right: 0.2em;padding-left: 0.2em;position: relative;}#sk-container-id-6 div.sk-item {position: relative;z-index: 1;}#sk-container-id-6 div.sk-parallel {display: flex;align-items: stretch;justify-content: center;background-color: white;position: relative;}#sk-container-id-6 div.sk-item::before, #sk-container-id-6 div.sk-parallel-item::before {content: \"\";position: absolute;border-left: 1px solid gray;box-sizing: border-box;top: 0;bottom: 0;left: 50%;z-index: -1;}#sk-container-id-6 div.sk-parallel-item {display: flex;flex-direction: column;z-index: 1;position: relative;background-color: white;}#sk-container-id-6 div.sk-parallel-item:first-child::after {align-self: flex-end;width: 50%;}#sk-container-id-6 div.sk-parallel-item:last-child::after {align-self: flex-start;width: 50%;}#sk-container-id-6 div.sk-parallel-item:only-child::after {width: 0;}#sk-container-id-6 div.sk-dashed-wrapped {border: 1px dashed gray;margin: 0 0.4em 0.5em 0.4em;box-sizing: border-box;padding-bottom: 0.4em;background-color: white;}#sk-container-id-6 div.sk-label label {font-family: monospace;font-weight: bold;display: inline-block;line-height: 1.2em;}#sk-container-id-6 div.sk-label-container {text-align: center;}#sk-container-id-6 div.sk-container {/* jupyter's `normalize.less` sets `[hidden] { display: none; }` but bootstrap.min.css set `[hidden] { display: none !important; }` so we also need the `!important` here to be able to override the default hidden behavior on the sphinx rendered scikit-learn.org. See: https://github.com/scikit-learn/scikit-learn/issues/21755 */display: inline-block !important;position: relative;}#sk-container-id-6 div.sk-text-repr-fallback {display: none;}</style><div id=\"sk-container-id-6\" class=\"sk-top-container\"><div class=\"sk-text-repr-fallback\"><pre>KNeighborsClassifier(n_neighbors=3)</pre><b>In a Jupyter environment, please rerun this cell to show the HTML representation or trust the notebook. <br />On GitHub, the HTML representation is unable to render, please try loading this page with nbviewer.org.</b></div><div class=\"sk-container\" hidden><div class=\"sk-item\"><div class=\"sk-estimator sk-toggleable\"><input class=\"sk-toggleable__control sk-hidden--visually\" id=\"sk-estimator-id-6\" type=\"checkbox\" checked><label for=\"sk-estimator-id-6\" class=\"sk-toggleable__label sk-toggleable__label-arrow\">KNeighborsClassifier</label><div class=\"sk-toggleable__content\"><pre>KNeighborsClassifier(n_neighbors=3)</pre></div></div></div></div></div>"
],
"text/plain": [
"KNeighborsClassifier(n_neighbors=3)"
]
},
"execution_count": 93,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Create a 3-NN model and train it\n",
"from sklearn.neighbors import KNeighborsClassifier\n",
"knn = KNeighborsClassifier(n_neighbors=3)\n",
"knn.fit(X_data, y_data)\n"
]
},
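{
"cell_type": "markdown",
"metadata": {},
"source": [
"With its defaults (`weights='uniform'` and the Minkowski metric with `p=2`, i.e. Euclidean distance), `KNeighborsClassifier(n_neighbors=3)` predicts a point's class by majority vote among its 3 nearest training points. The from-scratch sketch below spells that rule out with NumPy on hypothetical random data and compares it with scikit-learn's predictions. `knn_predict` is a helper written for illustration, not a scikit-learn function. It uses brute-force distances, whereas scikit-learn may choose a k-d tree or ball tree. Those are exact searches too, so they change the speed but not which neighbors are found."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from collections import Counter\n",
"\n",
"import numpy as np\n",
"from sklearn.neighbors import KNeighborsClassifier\n",
"\n",
"\n",
"def knn_predict(X_train, y_train, X_query, k=3):\n",
"    # Hypothetical helper: majority vote among the k nearest training points (Euclidean)\n",
"    preds = []\n",
"    for x in X_query:\n",
"        dists = np.sqrt(((X_train - x) ** 2).sum(axis=1))  # distance to every training point\n",
"        nearest = np.argsort(dists)[:k]                      # indices of the k closest points\n",
"        votes = Counter(y_train[nearest].tolist())\n",
"        preds.append(votes.most_common(1)[0][0])             # most frequent label wins\n",
"    return np.array(preds)\n",
"\n",
"\n",
"# Hypothetical data: 2 features, label is 1 when x0 + x1 > 0\n",
"rng = np.random.default_rng(0)\n",
"X_train = rng.normal(size=(50, 2))\n",
"y_train = (X_train[:, 0] + X_train[:, 1] > 0).astype(int)\n",
"X_query = rng.normal(size=(10, 2))\n",
"\n",
"ours = knn_predict(X_train, y_train, X_query, k=3)\n",
"sk = KNeighborsClassifier(n_neighbors=3).fit(X_train, y_train).predict(X_query)\n",
"print(ours)\n",
"print(sk)\n",
"print(np.array_equal(ours, sk))"
]
},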
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Step 5:\n",
"* Use the K-nearest neighbors model trained in the previous step to predict the same data set, `X_data`, and evaluate the accuracy of the predictions."
]
},
{
"cell_type": "code",
"execution_count": 94,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
" precision recall f1-score support\n",
"\n",
" 0 0.94 0.99 0.97 4293\n",
" 1 0.93 0.62 0.74 707\n",
"\n",
" accuracy 0.94 5000\n",
" macro avg 0.94 0.80 0.85 5000\n",
"weighted avg 0.94 0.94 0.93 5000\n",
"\n"
]
}
],
"source": [
"# Predict on the training data and evaluate\n",
"from sklearn.metrics import classification_report\n",
"y_pred = knn.predict(X_data)\n",
"print(classification_report(y_data, y_pred))\n"
]
},
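{
"cell_type": "markdown",
"metadata": {},
"source": [
"In the report above, precision for a class is $TP / (TP + FP)$, recall is $TP / (TP + FN)$, and the F1 score is their harmonic mean, $2PR / (P + R)$. For the churn class (1), a recall of 0.62 means that about 38% of the churners were predicted as non-churners, even though overall accuracy is 0.94. The sketch below derives these quantities from a confusion matrix on eight made-up labels and cross-checks them against `sklearn.metrics`."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"from sklearn.metrics import confusion_matrix, f1_score, precision_score, recall_score\n",
"\n",
"# Made-up true labels and predictions (1 = churned)\n",
"toy_true = np.array([0, 0, 0, 1, 1, 1, 0, 1])\n",
"toy_pred = np.array([0, 0, 1, 1, 0, 1, 0, 1])\n",
"\n",
"# For binary labels [0, 1], ravel() yields TN, FP, FN, TP in that order\n",
"tn, fp, fn, tp = confusion_matrix(toy_true, toy_pred).ravel()\n",
"\n",
"precision = tp / (tp + fp)\n",
"recall = tp / (tp + fn)\n",
"f1 = 2 * precision * recall / (precision + recall)\n",
"\n",
"print(tn, fp, fn, tp)         # 3 1 1 3\n",
"print(precision, recall, f1)  # 0.75 0.75 0.75\n",
"print(np.isclose(precision, precision_score(toy_true, toy_pred)),\n",
"      np.isclose(recall, recall_score(toy_true, toy_pred)),\n",
"      np.isclose(f1, f1_score(toy_true, toy_pred)))"
]
},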
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Step 6:\n",
"* Build a model that also uses `n_neighbors=3`, but weights each of the K neighbors' votes by the inverse of its distance (`weights='distance'`) when combining them. Again compute this model's prediction accuracy on `X_data`.\n",
"* Build another K-nearest neighbors model that uses uniform weights but sets the exponent parameter of the Minkowski distance to 1 (`p=1`), i.e. uses the Manhattan distance."
]
},
{
"cell_type": "code",
"execution_count": 95,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
" precision recall f1-score support\n",
"\n",
" 0 1.00 1.00 1.00 4293\n",
" 1 1.00 1.00 1.00 707\n",
"\n",
" accuracy 1.00 5000\n",
" macro avg 1.00 1.00 1.00 5000\n",
"weighted avg 1.00 1.00 1.00 5000\n",
"\n"
]
}
],
"source": [
"# n_neighbors=3, weights='distance'\n",
"knn = KNeighborsClassifier(n_neighbors=3, weights='distance')\n",
"knn.fit(X_data, y_data)\n",
"y_pred = knn.predict(X_data)\n",
"print(classification_report(y_data, y_pred))\n"
]
},
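{
"cell_type": "markdown",
"metadata": {},
"source": [
"The perfect 1.00 scores for `weights='distance'` come from predicting the training data itself. With distance weighting, each neighbor's vote counts in proportion to the inverse of its distance. Every training point is its own nearest neighbor at distance 0, and scikit-learn then gives that zero-distance neighbor all of the weight, so every training point simply gets its own label back. The sketch below reproduces this on hypothetical data whose labels are random noise, so there is nothing genuine to learn."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"from sklearn.neighbors import KNeighborsClassifier\n",
"\n",
"# Hypothetical data: random features and random labels (no real signal)\n",
"rng = np.random.default_rng(1)\n",
"X = rng.normal(size=(200, 2))\n",
"y = (rng.random(200) < 0.3).astype(int)\n",
"\n",
"uniform = KNeighborsClassifier(n_neighbors=3).fit(X, y)\n",
"weighted = KNeighborsClassifier(n_neighbors=3, weights='distance').fit(X, y)\n",
"\n",
"# Query the training points themselves: column 0 is each point's closest neighbor\n",
"dist, idx = weighted.kneighbors(X)\n",
"print(np.all(idx[:, 0] == np.arange(len(X))))  # each point's nearest neighbor is itself\n",
"print(np.allclose(dist[:, 0], 0))               # ... at distance 0\n",
"\n",
"print(uniform.score(X, y))   # training accuracy with a plain majority vote\n",
"print(weighted.score(X, y))  # 1.0: the zero-distance self-match takes all the weight"
]
},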
{
"cell_type": "code",
"execution_count": 96,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
" precision recall f1-score support\n",
"\n",
" 0 0.94 0.99 0.97 4293\n",
" 1 0.94 0.63 0.75 707\n",
"\n",
" accuracy 0.94 5000\n",
" macro avg 0.94 0.81 0.86 5000\n",
"weighted avg 0.94 0.94 0.94 5000\n",
"\n"
]
}
],
"source": [
"# n_neighbors=3, p=1 (Manhattan distance), uniform weights\n",
"knn = KNeighborsClassifier(n_neighbors=3, p=1)\n",
"knn.fit(X_data, y_data)\n",
"y_pred = knn.predict(X_data)\n",
"print(classification_report(y_data, y_pred))\n"
]
},
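{
"cell_type": "markdown",
"metadata": {},
"source": [
"The `p` parameter is the exponent of the Minkowski distance $d_p(\\mathbf{a}, \\mathbf{b}) = \\left( \\sum_j |a_j - b_j|^p \\right)^{1/p}$. `p=2` (the default) gives the Euclidean distance and `p=1` the Manhattan distance. For the points (0, 0) and (3, 4), Manhattan gives 3 + 4 = 7 while Euclidean gives √(9 + 16) = 5. The sketch below computes both with a small helper, `minkowski` (a hypothetical name written for this example), and cross-checks them with `sklearn.metrics.pairwise_distances`."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"from sklearn.metrics import pairwise_distances\n",
"\n",
"\n",
"def minkowski(a, b, p):\n",
"    # Hypothetical helper: Minkowski distance with exponent p\n",
"    return (np.abs(a - b) ** p).sum() ** (1.0 / p)\n",
"\n",
"\n",
"a = np.array([0.0, 0.0])\n",
"b = np.array([3.0, 4.0])\n",
"\n",
"print(minkowski(a, b, 1))  # 7.0 (Manhattan)\n",
"print(minkowski(a, b, 2))  # 5.0 (Euclidean)\n",
"print(pairwise_distances([a], [b], metric='manhattan'))  # [[7.]]\n",
"print(pairwise_distances([a], [b], metric='euclidean'))  # [[5.]]"
]
},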
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Step 7:\n",
"* Vary K from 1 to 20 and train 20 different K-nearest neighbors models, using uniform weights (the default). The exponent parameter of the Minkowski distance (`p`) may be set to 1 or 2, as long as it is the same for all models. Store each model's accuracy together with its `k` value in a list or dictionary.\n",
"* Plot `accuracy` against `k`. What do you observe at `k=1`? Why?"
]
},
{
"cell_type": "code",
"execution_count": 97,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[(1, 1.0), (2, 0.9218), (3, 0.9396), (4, 0.912), (5, 0.928), (6, 0.9094), (7, 0.9194), (8, 0.9052), (9, 0.9144), (10, 0.9024), (11, 0.9112), (12, 0.9002), (13, 0.9086), (14, 0.8984), (15, 0.9032), (16, 0.8944), (17, 0.8998), (18, 0.8934), (19, 0.8996), (20, 0.8928)]\n"
]
}
],
"source": [
"# Training-set accuracy for k = 1..20, stored as (k, accuracy) pairs\n",
"scores = []\n",
"for k in range(1, 21):\n",
"    knn = KNeighborsClassifier(n_neighbors=k)\n",
"    knn.fit(X_data, y_data)\n",
"    scores.append((k, knn.score(X_data, y_data)))\n",
"print(scores)\n"
]
},
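{
"cell_type": "markdown",
"metadata": {},
"source": [
"The list above measures every model on its own training data. At `k=1` the single neighbor consulted for a training point is that very point (distance 0). Each prediction therefore reproduces the true label and the accuracy is exactly 1.0, which reflects memorization rather than evidence that `k=1` generalizes. A fairer way to compare values of `k` is to score each model on data it was not fitted on. The sketch below does this with 5-fold cross-validation (`cross_val_score`) and keeps `StandardScaler` inside a pipeline so the scaler is refitted on each training fold. It runs on synthetic data from `make_classification` (17 features, roughly 86% of samples in class 0) as an assumed stand-in for the churn table. To apply it to the notebook's data, one could pass `data.drop('churned', axis=1).values` and `y_data` in place of `X` and `y`."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from sklearn.datasets import make_classification\n",
"from sklearn.model_selection import cross_val_score\n",
"from sklearn.neighbors import KNeighborsClassifier\n",
"from sklearn.pipeline import make_pipeline\n",
"from sklearn.preprocessing import StandardScaler\n",
"\n",
"# Synthetic stand-in for the churn data: 17 features, about 86% of samples in class 0\n",
"X, y = make_classification(n_samples=1000, n_features=17, weights=[0.86],\n",
"                           random_state=0)\n",
"\n",
"cv_scores = {}\n",
"for k in range(1, 21):\n",
"    # The scaler sits inside the pipeline, so each fold is scaled with statistics\n",
"    # computed from its own training part only\n",
"    model = make_pipeline(StandardScaler(), KNeighborsClassifier(n_neighbors=k))\n",
"    cv_scores[k] = cross_val_score(model, X, y, cv=5, scoring='accuracy').mean()\n",
"\n",
"best_k = max(cv_scores, key=cv_scores.get)\n",
"print(cv_scores[1])               # cross-validated accuracy at k=1\n",
"print(best_k, cv_scores[best_k])  # k with the highest cross-validated accuracy"
]
},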
{
"cell_type": "code",
"execution_count": 98,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[<matplotlib.lines.Line2D at 0x7fc1cc197c40>]"
]
},
"execution_count": 98,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "",
"text/plain": [
"<Figure size 640x480 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"%matplotlib inline\n",
"import matplotlib.pyplot as plt\n",
"plt.rcParams['font.size'] = 14\n",
"plt.xlabel('k')\n",
"plt.ylabel('accuracy')\n",
"plt.plot([k for k, v in scores], [v for k, v in scores], 'o-')"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.8"
},
"vscode": {
"interpreter": {
"hash": "1f0d395e06aa83586067b19165efc9b683889967164248deef4bbf1fa27cfb00"
}
}
},
"nbformat": 4,
"nbformat_minor": 2
}