随机森林回归
Signed-off-by: 张卓立 <13190677+zhang-zhuoli@user.noreply.gitee.com>
This commit is contained in:
@@ -0,0 +1,278 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "67537bcc-9ece-42d8-a6a7-924966604450",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# 随机森林回归"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 1,
|
||||
"id": "8f588675-9e05-45c7-9203-15c52f7ddd05",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import numpy as np\n",
|
||||
"import pandas as pd\n",
|
||||
"from sklearn.ensemble import RandomForestRegressor, AdaBoostRegressor\n",
|
||||
"from sklearn.metrics import mean_squared_error\n",
|
||||
"from sklearn.preprocessing import StandardScaler\n",
|
||||
"from sklearn.model_selection import train_test_split"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 2,
|
||||
"id": "0a02061e-296a-4f6a-91af-c7fa27d46f17",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/html": [
|
||||
"<div>\n",
|
||||
"<style scoped>\n",
|
||||
" .dataframe tbody tr th:only-of-type {\n",
|
||||
" vertical-align: middle;\n",
|
||||
" }\n",
|
||||
"\n",
|
||||
" .dataframe tbody tr th {\n",
|
||||
" vertical-align: top;\n",
|
||||
" }\n",
|
||||
"\n",
|
||||
" .dataframe thead th {\n",
|
||||
" text-align: right;\n",
|
||||
" }\n",
|
||||
"</style>\n",
|
||||
"<table border=\"1\" class=\"dataframe\">\n",
|
||||
" <thead>\n",
|
||||
" <tr style=\"text-align: right;\">\n",
|
||||
" <th></th>\n",
|
||||
" <th>host_response_rate</th>\n",
|
||||
" <th>host_acceptance_rate</th>\n",
|
||||
" <th>accommodates</th>\n",
|
||||
" <th>price</th>\n",
|
||||
" <th>number_of_reviews</th>\n",
|
||||
" <th>review_scores_rating</th>\n",
|
||||
" </tr>\n",
|
||||
" </thead>\n",
|
||||
" <tbody>\n",
|
||||
" <tr>\n",
|
||||
" <th>0</th>\n",
|
||||
" <td>1.00</td>\n",
|
||||
" <td>0.33</td>\n",
|
||||
" <td>2.0</td>\n",
|
||||
" <td>120.0</td>\n",
|
||||
" <td>90.0</td>\n",
|
||||
" <td>4.50</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>1</th>\n",
|
||||
" <td>1.00</td>\n",
|
||||
" <td>0.98</td>\n",
|
||||
" <td>2.0</td>\n",
|
||||
" <td>90.0</td>\n",
|
||||
" <td>351.0</td>\n",
|
||||
" <td>4.58</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>2</th>\n",
|
||||
" <td>1.00</td>\n",
|
||||
" <td>0.98</td>\n",
|
||||
" <td>2.0</td>\n",
|
||||
" <td>66.0</td>\n",
|
||||
" <td>67.0</td>\n",
|
||||
" <td>4.52</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>3</th>\n",
|
||||
" <td>1.00</td>\n",
|
||||
" <td>0.98</td>\n",
|
||||
" <td>1.0</td>\n",
|
||||
" <td>33.0</td>\n",
|
||||
" <td>297.0</td>\n",
|
||||
" <td>4.70</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>5</th>\n",
|
||||
" <td>1.00</td>\n",
|
||||
" <td>1.00</td>\n",
|
||||
" <td>2.0</td>\n",
|
||||
" <td>45.0</td>\n",
|
||||
" <td>42.0</td>\n",
|
||||
" <td>4.98</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>...</th>\n",
|
||||
" <td>...</td>\n",
|
||||
" <td>...</td>\n",
|
||||
" <td>...</td>\n",
|
||||
" <td>...</td>\n",
|
||||
" <td>...</td>\n",
|
||||
" <td>...</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>203252</th>\n",
|
||||
" <td>1.00</td>\n",
|
||||
" <td>0.93</td>\n",
|
||||
" <td>4.0</td>\n",
|
||||
" <td>152.0</td>\n",
|
||||
" <td>1.0</td>\n",
|
||||
" <td>4.00</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>203253</th>\n",
|
||||
" <td>1.00</td>\n",
|
||||
" <td>0.97</td>\n",
|
||||
" <td>2.0</td>\n",
|
||||
" <td>45.0</td>\n",
|
||||
" <td>1.0</td>\n",
|
||||
" <td>3.00</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>203254</th>\n",
|
||||
" <td>1.00</td>\n",
|
||||
" <td>0.97</td>\n",
|
||||
" <td>2.0</td>\n",
|
||||
" <td>40.0</td>\n",
|
||||
" <td>1.0</td>\n",
|
||||
" <td>1.00</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>203276</th>\n",
|
||||
" <td>0.99</td>\n",
|
||||
" <td>0.99</td>\n",
|
||||
" <td>2.0</td>\n",
|
||||
" <td>43.0</td>\n",
|
||||
" <td>1.0</td>\n",
|
||||
" <td>5.00</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>203308</th>\n",
|
||||
" <td>1.00</td>\n",
|
||||
" <td>1.00</td>\n",
|
||||
" <td>3.0</td>\n",
|
||||
" <td>110.0</td>\n",
|
||||
" <td>1.0</td>\n",
|
||||
" <td>5.00</td>\n",
|
||||
" </tr>\n",
|
||||
" </tbody>\n",
|
||||
"</table>\n",
|
||||
"<p>134835 rows × 6 columns</p>\n",
|
||||
"</div>"
|
||||
],
|
||||
"text/plain": [
|
||||
" host_response_rate host_acceptance_rate accommodates price \\\n",
|
||||
"0 1.00 0.33 2.0 120.0 \n",
|
||||
"1 1.00 0.98 2.0 90.0 \n",
|
||||
"2 1.00 0.98 2.0 66.0 \n",
|
||||
"3 1.00 0.98 1.0 33.0 \n",
|
||||
"5 1.00 1.00 2.0 45.0 \n",
|
||||
"... ... ... ... ... \n",
|
||||
"203252 1.00 0.93 4.0 152.0 \n",
|
||||
"203253 1.00 0.97 2.0 45.0 \n",
|
||||
"203254 1.00 0.97 2.0 40.0 \n",
|
||||
"203276 0.99 0.99 2.0 43.0 \n",
|
||||
"203308 1.00 1.00 3.0 110.0 \n",
|
||||
"\n",
|
||||
" number_of_reviews review_scores_rating \n",
|
||||
"0 90.0 4.50 \n",
|
||||
"1 351.0 4.58 \n",
|
||||
"2 67.0 4.52 \n",
|
||||
"3 297.0 4.70 \n",
|
||||
"5 42.0 4.98 \n",
|
||||
"... ... ... \n",
|
||||
"203252 1.0 4.00 \n",
|
||||
"203253 1.0 3.00 \n",
|
||||
"203254 1.0 1.00 \n",
|
||||
"203276 1.0 5.00 \n",
|
||||
"203308 1.0 5.00 \n",
|
||||
"\n",
|
||||
"[134835 rows x 6 columns]"
|
||||
]
|
||||
},
|
||||
"execution_count": 2,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"variables = ['number_of_reviews', 'price', 'accommodates',\n",
|
||||
" 'host_response_rate', 'host_acceptance_rate', 'review_scores_rating']\n",
|
||||
"df = pd.read_csv('../data/2022-01(US_25).csv', usecols=variables)\n",
|
||||
"df['price'] = df['price'].replace('\\$', '', regex=True)\n",
|
||||
"df['price'] = df['price'].replace('\\,', '', regex=True).astype(float)\n",
|
||||
"df[['host_response_rate', 'host_acceptance_rate']] = df[['host_response_rate',\n",
|
||||
" 'host_acceptance_rate']].replace('\\%', '', regex=True).astype(float)*0.01\n",
|
||||
"df[['number_of_reviews']] = df[['number_of_reviews']].astype(float)\n",
|
||||
"for col in variables:\n",
|
||||
" df[col] = df[col].astype(np.float32)\n",
|
||||
" df = df[np.isnan(df[col]) != 1]\n",
|
||||
"df"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"id": "35569391-e849-4a5e-adf5-25dfd849b2a0",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# 固定划分训练集和测试集\n",
|
||||
"info = df.iloc[:, :-1].values\n",
|
||||
"target = df.iloc[:, -1].values\n",
|
||||
"# 标准化\n",
|
||||
"stdscaler = StandardScaler()\n",
|
||||
"info_train, info_test, target_train, target_test = train_test_split(\n",
|
||||
" info, target, test_size=0.3,shuffle=True, random_state=420)\n",
|
||||
"info_train = stdscaler.fit_transform(info_train)\n",
|
||||
"info_test = stdscaler.transform(info_test)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 4,
|
||||
"id": "7ca33f2e-bcc8-4234-a27b-fa7e1df04030",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"MSE:0.2333\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"# randomforest回归\n",
|
||||
"rf = RandomForestRegressor(n_estimators=100, random_state=0)\n",
|
||||
"rf.fit(info_train, target_train)\n",
|
||||
"target_pred = rf.predict(info_test)\n",
|
||||
"print(\"MSE:%.4f\" % mean_squared_error(target_test, target_pred))"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3 (ipykernel)",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.10.12"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 5
|
||||
}
|
||||
Reference in New Issue
Block a user