随机森林回归

Signed-off-by: 张卓立 <13190677+zhang-zhuoli@user.noreply.gitee.com>
This commit is contained in:
张卓立
2023-07-15 12:05:21 +00:00
committed by Gitee
parent df2e5b3cbb
commit a395c9c6e0
@@ -0,0 +1,278 @@
{
"cells": [
{
"cell_type": "markdown",
"id": "67537bcc-9ece-42d8-a6a7-924966604450",
"metadata": {},
"source": [
"# 随机森林回归"
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "8f588675-9e05-45c7-9203-15c52f7ddd05",
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"import pandas as pd\n",
"from sklearn.ensemble import RandomForestRegressor, AdaBoostRegressor\n",
"from sklearn.metrics import mean_squared_error\n",
"from sklearn.preprocessing import StandardScaler\n",
"from sklearn.model_selection import train_test_split"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "0a02061e-296a-4f6a-91af-c7fa27d46f17",
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>host_response_rate</th>\n",
" <th>host_acceptance_rate</th>\n",
" <th>accommodates</th>\n",
" <th>price</th>\n",
" <th>number_of_reviews</th>\n",
" <th>review_scores_rating</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>1.00</td>\n",
" <td>0.33</td>\n",
" <td>2.0</td>\n",
" <td>120.0</td>\n",
" <td>90.0</td>\n",
" <td>4.50</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>1.00</td>\n",
" <td>0.98</td>\n",
" <td>2.0</td>\n",
" <td>90.0</td>\n",
" <td>351.0</td>\n",
" <td>4.58</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>1.00</td>\n",
" <td>0.98</td>\n",
" <td>2.0</td>\n",
" <td>66.0</td>\n",
" <td>67.0</td>\n",
" <td>4.52</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>1.00</td>\n",
" <td>0.98</td>\n",
" <td>1.0</td>\n",
" <td>33.0</td>\n",
" <td>297.0</td>\n",
" <td>4.70</td>\n",
" </tr>\n",
" <tr>\n",
" <th>5</th>\n",
" <td>1.00</td>\n",
" <td>1.00</td>\n",
" <td>2.0</td>\n",
" <td>45.0</td>\n",
" <td>42.0</td>\n",
" <td>4.98</td>\n",
" </tr>\n",
" <tr>\n",
" <th>...</th>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>203252</th>\n",
" <td>1.00</td>\n",
" <td>0.93</td>\n",
" <td>4.0</td>\n",
" <td>152.0</td>\n",
" <td>1.0</td>\n",
" <td>4.00</td>\n",
" </tr>\n",
" <tr>\n",
" <th>203253</th>\n",
" <td>1.00</td>\n",
" <td>0.97</td>\n",
" <td>2.0</td>\n",
" <td>45.0</td>\n",
" <td>1.0</td>\n",
" <td>3.00</td>\n",
" </tr>\n",
" <tr>\n",
" <th>203254</th>\n",
" <td>1.00</td>\n",
" <td>0.97</td>\n",
" <td>2.0</td>\n",
" <td>40.0</td>\n",
" <td>1.0</td>\n",
" <td>1.00</td>\n",
" </tr>\n",
" <tr>\n",
" <th>203276</th>\n",
" <td>0.99</td>\n",
" <td>0.99</td>\n",
" <td>2.0</td>\n",
" <td>43.0</td>\n",
" <td>1.0</td>\n",
" <td>5.00</td>\n",
" </tr>\n",
" <tr>\n",
" <th>203308</th>\n",
" <td>1.00</td>\n",
" <td>1.00</td>\n",
" <td>3.0</td>\n",
" <td>110.0</td>\n",
" <td>1.0</td>\n",
" <td>5.00</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>134835 rows × 6 columns</p>\n",
"</div>"
],
"text/plain": [
" host_response_rate host_acceptance_rate accommodates price \\\n",
"0 1.00 0.33 2.0 120.0 \n",
"1 1.00 0.98 2.0 90.0 \n",
"2 1.00 0.98 2.0 66.0 \n",
"3 1.00 0.98 1.0 33.0 \n",
"5 1.00 1.00 2.0 45.0 \n",
"... ... ... ... ... \n",
"203252 1.00 0.93 4.0 152.0 \n",
"203253 1.00 0.97 2.0 45.0 \n",
"203254 1.00 0.97 2.0 40.0 \n",
"203276 0.99 0.99 2.0 43.0 \n",
"203308 1.00 1.00 3.0 110.0 \n",
"\n",
" number_of_reviews review_scores_rating \n",
"0 90.0 4.50 \n",
"1 351.0 4.58 \n",
"2 67.0 4.52 \n",
"3 297.0 4.70 \n",
"5 42.0 4.98 \n",
"... ... ... \n",
"203252 1.0 4.00 \n",
"203253 1.0 3.00 \n",
"203254 1.0 1.00 \n",
"203276 1.0 5.00 \n",
"203308 1.0 5.00 \n",
"\n",
"[134835 rows x 6 columns]"
]
},
"execution_count": 2,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"variables = ['number_of_reviews', 'price', 'accommodates',\n",
" 'host_response_rate', 'host_acceptance_rate', 'review_scores_rating']\n",
"df = pd.read_csv('../data/2022-01(US_25).csv', usecols=variables)\n",
"df['price'] = df['price'].replace('\\$', '', regex=True)\n",
"df['price'] = df['price'].replace('\\,', '', regex=True).astype(float)\n",
"df[['host_response_rate', 'host_acceptance_rate']] = df[['host_response_rate',\n",
" 'host_acceptance_rate']].replace('\\%', '', regex=True).astype(float)*0.01\n",
"df[['number_of_reviews']] = df[['number_of_reviews']].astype(float)\n",
"for col in variables:\n",
" df[col] = df[col].astype(np.float32)\n",
" df = df[np.isnan(df[col]) != 1]\n",
"df"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "35569391-e849-4a5e-adf5-25dfd849b2a0",
"metadata": {},
"outputs": [],
"source": [
"# 固定划分训练集和测试集\n",
"info = df.iloc[:, :-1].values\n",
"target = df.iloc[:, -1].values\n",
"# 标准化\n",
"stdscaler = StandardScaler()\n",
"info_train, info_test, target_train, target_test = train_test_split(\n",
" info, target, test_size=0.3,shuffle=True, random_state=420)\n",
"info_train = stdscaler.fit_transform(info_train)\n",
"info_test = stdscaler.transform(info_test)"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "7ca33f2e-bcc8-4234-a27b-fa7e1df04030",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"MSE:0.2333\n"
]
}
],
"source": [
"# randomforest回归\n",
"rf = RandomForestRegressor(n_estimators=100, random_state=0)\n",
"rf.fit(info_train, target_train)\n",
"target_pred = rf.predict(info_test)\n",
"print(\"MSE:%.4f\" % mean_squared_error(target_test, target_pred))"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.12"
}
},
"nbformat": 4,
"nbformat_minor": 5
}