# Training hyper-parameters.
batch_size = 500   # samples per mini-batch
lr = 0.01          # learning rate handed to the optimizer
max_epochs = 100   # number of full passes over the training set
num_workers = 0    # 0 = load batches in the main process
# "cuda" selects the current default GPU.  A hard-coded index such as
# "cuda:1" is an invalid ordinal on single-GPU machines.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Columns used by the model: five predictors followed by the regression
# target (the last entry).
variables = ['number_of_reviews', 'price', 'accommodates',
             'host_response_rate', 'host_acceptance_rate', 'review_scores_rating']


class USDataset(Dataset):
    """Map-style dataset exposing (features, target) pairs from a DataFrame.

    The feature and target column names are derived from ``variables`` so the
    list of columns is maintained in a single place.

    Parameters
    ----------
    df : pandas.DataFrame
        Pre-processed data containing every column named in ``variables``.
    """

    # Predictor columns (all but the last entry) and the target column.
    feature_columns = variables[:-1]
    target_column = variables[-1]

    def __init__(self, df):
        self.df = df
        # Store float32 arrays: torch layers default to float32 parameters,
        # so float64 frames would otherwise fail with a dtype mismatch.
        self.info = df[self.feature_columns].to_numpy(dtype=np.float32)
        self.target = df[self.target_column].to_numpy(dtype=np.float32)

    def __getitem__(self, index):
        """Return ``(features, target)`` for the sample at ``index``."""
        return self.info[index], self.target[index]

    def __len__(self):
        """Return the number of samples in the dataset."""
        return len(self.target)
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
host_response_ratehost_acceptance_rateaccommodatespricenumber_of_reviewsreview_scores_rating
01.000.332.0120.090.04.50
11.000.982.090.0351.04.58
21.000.982.066.067.04.52
31.000.981.033.0297.04.70
51.001.002.045.042.04.98
.....................
2032521.000.934.0152.01.04.00
2032531.000.972.045.01.03.00
2032541.000.972.040.01.01.00
2032760.990.992.043.01.05.00
2033081.001.003.0110.01.05.00
\n", + "

134835 rows × 6 columns

# Load the listings file, keeping only the modelling columns.
df = pd.read_csv('../data/2022-01(US_25).csv', usecols=variables)

# Strip '$' and ',' characters from the price strings before converting them
# to numbers.  A raw-string character class avoids the invalid escape
# sequences ('\$', '\,') that non-raw literals produce.
df['price'] = df['price'].replace(r'[$,]', '', regex=True).astype(float)

# Strip the '%' sign from both rate columns and turn percentages into
# fractions.
rate_columns = ['host_response_rate', 'host_acceptance_rate']
df[rate_columns] = df[rate_columns].replace(
    '%', '', regex=True).astype(float) * 0.01

# Cast every column to float32 and drop each row that has a missing value in
# any column.  Column order is left untouched because later cells split
# features and target by position.
df = df.astype(np.float32).dropna()
df
"2 0.297689 0.393235 -0.807638 -0.258453 \n", + "3 0.297689 0.534159 -0.101531 -0.366559 \n", + "4 0.297689 0.393235 -0.454584 -0.099897 \n", + "... ... ... ... ... \n", + "94379 0.297689 -0.499281 -0.101531 0.313307 \n", + "94380 0.297689 0.440210 2.016788 0.697684 \n", + "94381 0.297689 0.534159 0.604575 -0.222417 \n", + "94382 0.297689 0.252311 -0.807638 -0.301695 \n", + "94383 0.297689 -0.687179 -0.807638 -0.318512 \n", + "\n", + " number_of_reviews review_scores_rating \n", + "0 -0.629684 3.00 \n", + "1 0.060760 4.93 \n", + "2 -0.629684 4.00 \n", + "3 2.045787 4.91 \n", + "4 0.295018 4.83 \n", + "... ... ... \n", + "94379 -0.407756 5.00 \n", + "94380 -0.531050 5.00 \n", + "94381 0.036101 4.69 \n", + "94382 1.145743 4.85 \n", + "94383 0.726545 4.83 \n", + "\n", + "[94384 rows x 6 columns]\n", + "============================================\n", + "df_test:\n", + " host_response_rate host_acceptance_rate accommodates price \\\n", + "0 0.297689 0.581133 -1.160691 -0.392985 \n", + "1 0.297689 -0.969026 0.604575 0.714500 \n", + "2 0.297689 0.581133 -0.807638 -0.042241 \n", + "3 0.095805 -1.626670 0.604575 3.104842 \n", + "4 0.297689 -0.452307 -0.807638 -0.289683 \n", + "... ... ... ... ... \n", + "40446 0.095805 0.534159 -0.101531 -0.244039 \n", + "40447 -0.038785 0.534159 1.310681 1.353526 \n", + "40448 0.297689 -0.217434 -0.101531 0.197994 \n", + "40449 0.297689 0.440210 0.604575 0.358952 \n", + "40450 0.297689 0.581133 -0.807638 -0.159956 \n", + "\n", + " number_of_reviews review_scores_rating \n", + "0 0.726545 4.76 \n", + "1 -0.580367 4.60 \n", + "2 3.352699 4.92 \n", + "3 -0.617355 2.50 \n", + "4 -0.592696 5.00 \n", + "... ... ... 
# Separate predictors from the target by position: every column except the
# last is an input, the last column is the value to predict.
col_df = df.columns
info, target = df.iloc[:, :-1].values, df.iloc[:, -1].values

# Hold out 30% of the rows; the fixed seed keeps the split reproducible.
info_train, info_test, target_train, target_test = train_test_split(
    info, target, test_size=0.3, random_state=420)

# Standardise the predictors with statistics fitted on the training rows
# only, then apply the same transform to the held-out rows.
stdscaler = StandardScaler()
info_train = stdscaler.fit_transform(info_train)
info_test = stdscaler.transform(info_test)

# Re-attach the (unscaled) target as the last column and restore the
# original column labels.
df_train = pd.DataFrame(
    np.column_stack((info_train, target_train)), columns=col_df)
df_test = pd.DataFrame(
    np.column_stack((info_test, target_test)), columns=col_df)
print("df_train:\n", df_train)
print("============================================")
print("df_test:\n", df_test)

# Wrap both frames as datasets and build the batch loaders.  Training batches
# are shuffled and an incomplete final batch is dropped; evaluation keeps
# the original order and every sample.
train_data = USDataset(df_train)
test_data = USDataset(df_test)
train_loader = DataLoader(
    train_data,
    batch_size=batch_size,
    num_workers=num_workers,
    shuffle=True,
    drop_last=True,
)
test_loader = DataLoader(
    test_data,
    batch_size=batch_size,
    num_workers=num_workers,
    shuffle=False,
)
"stdout", + "output_type": "stream", + "text": [ + "[Parameter containing:\n", + "tensor([[ 0.1981, -0.0925, 0.1532, -0.2224, 0.1906],\n", + " [-0.4302, -0.2503, -0.1189, 0.4182, -0.3575],\n", + " [ 0.1431, 0.3615, 0.2464, -0.0292, -0.3012],\n", + " [-0.2803, 0.2226, 0.2143, 0.2458, -0.2719],\n", + " [ 0.1669, 0.1216, -0.0148, 0.2795, -0.2338],\n", + " [-0.0047, 0.1909, -0.1634, 0.3766, 0.3018],\n", + " [ 0.4288, 0.3602, -0.3976, -0.1180, 0.3682],\n", + " [-0.1992, 0.1626, 0.3656, -0.4359, 0.2141],\n", + " [-0.3708, 0.1608, 0.0614, 0.0473, -0.3628],\n", + " [-0.3157, 0.1828, 0.3053, -0.2029, -0.3746],\n", + " [ 0.0809, -0.1369, 0.2581, 0.4280, 0.2290],\n", + " [-0.3749, -0.1776, 0.1791, -0.3965, -0.1430],\n", + " [ 0.3633, 0.1174, -0.4268, 0.1386, 0.3422],\n", + " [ 0.3554, -0.3894, -0.4382, 0.2104, 0.4403],\n", + " [-0.0714, -0.1196, -0.0719, 0.0933, 0.1312]], device='cuda:0',\n", + " requires_grad=True), Parameter containing:\n", + "tensor([-0.3285, 0.3194, 0.4210, 0.3921, 0.2482, 0.0288, -0.2423, 0.3984,\n", + " -0.4366, -0.0518, 0.2578, -0.2965, 0.3803, -0.1613, -0.3905],\n", + " device='cuda:0', requires_grad=True), Parameter containing:\n", + "tensor([[ 0.2330, 0.2382, 0.1661, -0.0358, -0.2211, -0.2119, 0.2325, 0.2296,\n", + " 0.1116, -0.1443, -0.1474, 0.1758, 0.1321, -0.1886, 0.1001],\n", + " [ 0.1013, -0.1303, -0.1558, -0.1694, -0.0610, 0.0903, 0.2512, 0.2028,\n", + " -0.0437, -0.0142, -0.0613, -0.1092, -0.2330, 0.1156, -0.0872],\n", + " [-0.2243, 0.1081, 0.1204, -0.1407, 0.2454, 0.1751, -0.0659, -0.0778,\n", + " 0.1699, 0.2420, 0.1921, -0.1972, 0.0297, -0.1185, 0.1182],\n", + " [ 0.1126, -0.1653, -0.2057, -0.0083, 0.0714, -0.1995, 0.1577, -0.2197,\n", + " -0.0100, -0.1859, -0.2013, -0.2470, -0.1084, -0.0784, -0.0567],\n", + " [-0.2133, 0.0452, 0.0957, 0.1400, -0.1520, -0.0957, 0.1586, 0.0450,\n", + " -0.1405, -0.1795, -0.1409, 0.2279, -0.1048, -0.0113, 0.1268],\n", + " [-0.2536, -0.1572, 0.2436, -0.1054, -0.2366, -0.1888, 0.0218, -0.0935,\n", + " -0.1377, 
0.0630, 0.0064, -0.0734, 0.0873, 0.1178, -0.1568],\n", + " [-0.0283, 0.2286, 0.2539, 0.1228, 0.2385, -0.2216, -0.1502, -0.2049,\n", + " 0.0286, 0.1939, 0.0313, 0.1531, 0.2514, -0.2186, 0.0171],\n", + " [ 0.1501, 0.0752, -0.1921, -0.1191, -0.0748, 0.2097, 0.2306, -0.2150,\n", + " -0.2348, 0.0575, -0.0489, 0.0678, -0.1273, -0.1990, -0.0770],\n", + " [-0.0807, -0.0379, -0.1252, -0.1873, -0.0793, 0.0908, 0.1573, 0.1212,\n", + " 0.0937, 0.1422, -0.1502, -0.0369, 0.0321, -0.0830, 0.0571],\n", + " [-0.1586, -0.0179, 0.2351, 0.1576, 0.0692, -0.0726, -0.0875, -0.0382,\n", + " -0.2227, -0.0080, 0.0904, 0.1201, 0.2029, 0.0351, 0.0133]],\n", + " device='cuda:0', requires_grad=True), Parameter containing:\n", + "tensor([-0.0141, 0.0566, -0.2259, -0.1533, -0.0400, 0.0475, -0.1080, -0.2453,\n", + " -0.1625, 0.1345], device='cuda:0', requires_grad=True), Parameter containing:\n", + "tensor([[-0.2120, -0.1994, -0.0052, 0.2988, -0.2471, 0.0135, 0.1283, -0.0201,\n", + " 0.1182, -0.1972],\n", + " [-0.0328, 0.0113, 0.1010, 0.0589, -0.1486, 0.2598, 0.1771, 0.0474,\n", + " -0.0413, 0.1537],\n", + " [ 0.2083, -0.1183, -0.2833, -0.3092, 0.2081, 0.2566, 0.1134, -0.1159,\n", + " -0.2981, -0.2882],\n", + " [-0.1033, -0.2929, -0.2451, -0.2850, 0.3157, 0.3092, 0.0539, -0.2594,\n", + " 0.1327, -0.0194],\n", + " [ 0.1040, -0.3053, -0.1769, -0.2137, -0.0262, 0.2108, 0.0255, -0.1202,\n", + " 0.0413, 0.2326],\n", + " [-0.0355, -0.2424, 0.1064, 0.0394, -0.1730, 0.2130, 0.1174, -0.1641,\n", + " -0.1420, -0.0496]], device='cuda:0', requires_grad=True), Parameter containing:\n", + "tensor([ 0.2195, -0.0579, 0.1234, 0.1060, -0.2140, 0.0288], device='cuda:0',\n", + " requires_grad=True), Parameter containing:\n", + "tensor([[ 0.2707, 0.1434, 0.0318, -0.0551, -0.1369, -0.1496]],\n", + " device='cuda:0', requires_grad=True), Parameter containing:\n", + "tensor([-0.0661], device='cuda:0', requires_grad=True)]\n", + "==================================================\n", + "[Parameter containing:\n", + 
"tensor([[-6.9726e-02, 1.7731e+00, -7.6372e-02, -7.1058e-01, -7.1123e-04],\n", + " [-7.1336e-01, -5.4361e-01, -2.6582e-01, 5.6964e-01, -4.9271e-01],\n", + " [-1.3366e+00, -6.1715e-01, -8.0915e-01, -1.0871e+00, -1.7130e+00],\n", + " [ 1.1439e+00, 1.1702e+00, -5.0160e-02, -6.6409e-01, -9.0478e-02],\n", + " [-1.2307e-01, 9.7143e-01, -3.2899e-01, 1.0481e+00, 8.7611e-02],\n", + " [-3.6240e-01, 6.6012e-01, -1.7746e+00, 1.7587e-01, 3.9182e-01],\n", + " [ 9.5627e-01, -1.4544e+00, 8.6239e-01, 6.4864e-01, 1.2865e+00],\n", + " [-9.4586e-01, -2.7911e+00, 1.3601e+00, -8.6802e-01, 3.7288e-01],\n", + " [ 6.0642e-01, 8.3700e-01, 4.6751e-01, 9.6069e-01, 1.0403e+00],\n", + " [-2.4263e-01, -1.1883e+00, -1.1396e+00, 1.4794e+00, -6.1805e-01],\n", + " [-4.5019e-01, 3.1949e-01, 2.2522e-01, -1.4746e+00, -1.5766e+00],\n", + " [-1.0685e+00, -8.6054e-01, -4.0435e-01, 1.1669e-01, -5.0217e-01],\n", + " [ 3.2439e-01, -6.3227e-01, -9.7261e-01, 4.4522e-01, -1.1293e+00],\n", + " [-7.5776e-01, -5.1496e-01, 1.3075e+00, -3.3717e-01, 3.9504e-01],\n", + " [ 8.5924e-01, -1.0643e-01, -1.2336e+00, 1.7424e+00, -3.5817e-01]],\n", + " device='cuda:0', requires_grad=True), Parameter containing:\n", + "tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n", + " device='cuda:0', requires_grad=True), Parameter containing:\n", + "tensor([[-0.5137, -0.3243, 0.4541, 0.4884, -0.3226, -0.9424, 1.2633, -0.5411,\n", + " -0.4047, -0.2105, -1.2065, 0.1194, -0.8706, -0.3806, -0.3470],\n", + " [ 1.1079, -0.6157, 0.0802, 0.1673, -1.1413, 0.3360, 2.0812, -0.4530,\n", + " 1.9244, -0.1222, 0.8667, -0.4083, -1.5679, 1.3790, -0.7961],\n", + " [-1.4472, 0.9576, 0.8538, 0.3888, -2.0439, -0.0737, -1.1856, -0.7158,\n", + " 2.7370, 2.5444, -0.7023, 1.4209, 0.0901, 0.8644, 0.7805],\n", + " [-0.0899, 0.7839, -0.3986, -0.0622, 0.5134, 0.8943, -0.7492, -1.7582,\n", + " 0.7121, 1.2855, -0.4851, 0.0736, -0.2256, 0.4992, -0.7854],\n", + " [ 2.5287, 0.5654, 0.7835, -0.6008, -1.9796, 1.6482, -1.0928, -0.0595,\n", + " -0.0564, 
def initialize(self):
    """Re-initialise every ``nn.Linear`` layer contained in ``self``.

    Weights are drawn from a normal distribution with mean 0 and standard
    deviation 0.1; biases (when present) are set to zero.

    Parameters
    ----------
    self : torch.nn.Module
        Module whose sub-modules are visited.
    """
    for m in self.modules():
        if isinstance(m, nn.Linear):
            # The signature is normal_(tensor, mean, std): a positional 0.1
            # is taken as the mean and leaves std at 1.0, which saturates the
            # sigmoid activations.  Pass both by keyword.
            nn.init.normal_(m.weight, mean=0.0, std=0.1)
            if m.bias is not None:
                nn.init.zeros_(m.bias)


class Net(nn.Module):
    """Feed-forward regressor: 5 inputs -> 15 -> 10 -> 6 -> 1 output.

    Each hidden layer is followed by a sigmoid; the output layer is linear.
    """

    def __init__(self, **kwargs):
        # **kwargs is accepted for call-site compatibility and is not used.
        super().__init__()
        self.fc = nn.Sequential(
            nn.Linear(5, 15),
            nn.Sigmoid(),
            nn.Linear(15, 10),
            nn.Sigmoid(),
            nn.Linear(10, 6),
            nn.Sigmoid(),
            nn.Linear(6, 1),
        )

    def forward(self, x):
        """Return one prediction per row of ``x`` (reshaped to ``(-1, 5)``)."""
        x = x.view(-1, 5)
        # Drop the trailing singleton dimension so the output lines up with
        # a 1-D target tensor.
        return self.fc(x).squeeze(-1)


# Build the network on the GPU when one is available, otherwise on the CPU
# (a bare .cuda() call raises on CPU-only machines).
model = Net()
model = model.to(torch.device("cuda" if torch.cuda.is_available() else "cpu"))
print(list(model.parameters()))
initialize(model)
print("==================================================")
print(list(model.parameters()))

# Mean-squared-error loss optimised with plain stochastic gradient descent.
criterion = nn.MSELoss()
optimizer = optim.SGD(model.parameters(), lr=lr)


def train(epoch):
    """Run one training pass over ``train_loader`` and print the mean loss.

    Parameters
    ----------
    epoch : int
        Epoch number, used only in the printed message.

    Returns
    -------
    None.
    """
    model.train()
    # Send each batch to whichever device currently holds the model.
    run_device = next(model.parameters()).device
    loss_sum = 0.0
    seen = 0
    for info, target in train_loader:
        info = info.to(run_device)
        target = target.to(run_device)
        optimizer.zero_grad()
        output = model(info)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()
        # criterion returns the batch mean; weight it by the batch size so
        # the final figure is a per-sample average.
        loss_sum += loss.item() * info.size(0)
        seen += info.size(0)
    # Average over the samples actually processed: a loader created with
    # drop_last=True skips the final partial batch, so dividing by the full
    # dataset length would under-report the loss.
    train_loss = loss_sum / max(seen, 1)
    print("epoch:%d, train loss:%.4f" % (epoch, train_loss))


def validate(epoch):
    """Evaluate the model on ``test_loader`` and print the mean loss.

    Parameters
    ----------
    epoch : int
        Epoch number, used only in the printed message.

    Returns
    -------
    None.
    """
    model.eval()
    run_device = next(model.parameters()).device
    loss_sum = 0.0
    seen = 0
    # No gradients are needed for evaluation.
    with torch.no_grad():
        for info, target in test_loader:
            info, target = info.to(run_device), target.to(run_device)
            output = model(info)
            loss = criterion(output, target)
            loss_sum += loss.item() * info.size(0)
            seen += info.size(0)
    val_loss = loss_sum / max(seen, 1)
    print("epoch:%d, validation loss:%.4f" %
          (epoch, val_loss))
validation loss:0.2193\n", + "epoch:12, train loss:0.2099\n", + "epoch:12, validation loss:0.2193\n", + "epoch:13, train loss:0.2100\n", + "epoch:13, validation loss:0.2193\n", + "epoch:14, train loss:0.2097\n", + "epoch:14, validation loss:0.2193\n", + "epoch:15, train loss:0.2094\n", + "epoch:15, validation loss:0.2192\n", + "epoch:16, train loss:0.2091\n", + "epoch:16, validation loss:0.2192\n", + "epoch:17, train loss:0.2098\n", + "epoch:17, validation loss:0.2192\n", + "epoch:18, train loss:0.2089\n", + "epoch:18, validation loss:0.2191\n", + "epoch:19, train loss:0.2091\n", + "epoch:19, validation loss:0.2191\n", + "epoch:20, train loss:0.2096\n", + "epoch:20, validation loss:0.2190\n", + "epoch:21, train loss:0.2092\n", + "epoch:21, validation loss:0.2190\n", + "epoch:22, train loss:0.2084\n", + "epoch:22, validation loss:0.2190\n", + "epoch:23, train loss:0.2093\n", + "epoch:23, validation loss:0.2190\n", + "epoch:24, train loss:0.2088\n", + "epoch:24, validation loss:0.2190\n", + "epoch:25, train loss:0.2088\n", + "epoch:25, validation loss:0.2189\n", + "epoch:26, train loss:0.2090\n", + "epoch:26, validation loss:0.2189\n", + "epoch:27, train loss:0.2092\n", + "epoch:27, validation loss:0.2189\n", + "epoch:28, train loss:0.2086\n", + "epoch:28, validation loss:0.2188\n", + "epoch:29, train loss:0.2087\n", + "epoch:29, validation loss:0.2189\n", + "epoch:30, train loss:0.2089\n", + "epoch:30, validation loss:0.2188\n", + "epoch:31, train loss:0.2090\n", + "epoch:31, validation loss:0.2187\n", + "epoch:32, train loss:0.2087\n", + "epoch:32, validation loss:0.2187\n", + "epoch:33, train loss:0.2092\n", + "epoch:33, validation loss:0.2187\n", + "epoch:34, train loss:0.2092\n", + "epoch:34, validation loss:0.2187\n", + "epoch:35, train loss:0.2091\n", + "epoch:35, validation loss:0.2186\n", + "epoch:36, train loss:0.2089\n", + "epoch:36, validation loss:0.2186\n", + "epoch:37, train loss:0.2086\n", + "epoch:37, validation loss:0.2186\n", + "epoch:38, train 
loss:0.2090\n", + "epoch:38, validation loss:0.2186\n", + "epoch:39, train loss:0.2088\n", + "epoch:39, validation loss:0.2185\n", + "epoch:40, train loss:0.2087\n", + "epoch:40, validation loss:0.2185\n", + "epoch:41, train loss:0.2087\n", + "epoch:41, validation loss:0.2186\n", + "epoch:42, train loss:0.2085\n", + "epoch:42, validation loss:0.2185\n", + "epoch:43, train loss:0.2088\n", + "epoch:43, validation loss:0.2184\n", + "epoch:44, train loss:0.2089\n", + "epoch:44, validation loss:0.2184\n", + "epoch:45, train loss:0.2089\n", + "epoch:45, validation loss:0.2184\n", + "epoch:46, train loss:0.2086\n", + "epoch:46, validation loss:0.2184\n", + "epoch:47, train loss:0.2089\n", + "epoch:47, validation loss:0.2183\n", + "epoch:48, train loss:0.2084\n", + "epoch:48, validation loss:0.2183\n", + "epoch:49, train loss:0.2080\n", + "epoch:49, validation loss:0.2183\n", + "epoch:50, train loss:0.2088\n", + "epoch:50, validation loss:0.2183\n", + "epoch:51, train loss:0.2086\n", + "epoch:51, validation loss:0.2183\n", + "epoch:52, train loss:0.2082\n", + "epoch:52, validation loss:0.2182\n", + "epoch:53, train loss:0.2080\n", + "epoch:53, validation loss:0.2182\n", + "epoch:54, train loss:0.2088\n", + "epoch:54, validation loss:0.2182\n", + "epoch:55, train loss:0.2086\n", + "epoch:55, validation loss:0.2181\n", + "epoch:56, train loss:0.2082\n", + "epoch:56, validation loss:0.2181\n", + "epoch:57, train loss:0.2088\n", + "epoch:57, validation loss:0.2181\n", + "epoch:58, train loss:0.2084\n", + "epoch:58, validation loss:0.2181\n", + "epoch:59, train loss:0.2085\n", + "epoch:59, validation loss:0.2180\n", + "epoch:60, train loss:0.2085\n", + "epoch:60, validation loss:0.2180\n", + "epoch:61, train loss:0.2084\n", + "epoch:61, validation loss:0.2180\n", + "epoch:62, train loss:0.2082\n", + "epoch:62, validation loss:0.2180\n", + "epoch:63, train loss:0.2081\n", + "epoch:63, validation loss:0.2180\n", + "epoch:64, train loss:0.2081\n", + "epoch:64, validation 
loss:0.2179\n", + "epoch:65, train loss:0.2078\n", + "epoch:65, validation loss:0.2179\n", + "epoch:66, train loss:0.2079\n", + "epoch:66, validation loss:0.2179\n", + "epoch:67, train loss:0.2080\n", + "epoch:67, validation loss:0.2179\n", + "epoch:68, train loss:0.2080\n", + "epoch:68, validation loss:0.2178\n", + "epoch:69, train loss:0.2082\n", + "epoch:69, validation loss:0.2178\n", + "epoch:70, train loss:0.2079\n", + "epoch:70, validation loss:0.2178\n", + "epoch:71, train loss:0.2083\n", + "epoch:71, validation loss:0.2178\n", + "epoch:72, train loss:0.2081\n", + "epoch:72, validation loss:0.2177\n", + "epoch:73, train loss:0.2079\n", + "epoch:73, validation loss:0.2177\n", + "epoch:74, train loss:0.2076\n", + "epoch:74, validation loss:0.2177\n", + "epoch:75, train loss:0.2081\n", + "epoch:75, validation loss:0.2177\n", + "epoch:76, train loss:0.2081\n", + "epoch:76, validation loss:0.2176\n", + "epoch:77, train loss:0.2080\n", + "epoch:77, validation loss:0.2176\n", + "epoch:78, train loss:0.2079\n", + "epoch:78, validation loss:0.2176\n", + "epoch:79, train loss:0.2081\n", + "epoch:79, validation loss:0.2176\n", + "epoch:80, train loss:0.2078\n", + "epoch:80, validation loss:0.2175\n", + "epoch:81, train loss:0.2076\n", + "epoch:81, validation loss:0.2175\n", + "epoch:82, train loss:0.2077\n", + "epoch:82, validation loss:0.2175\n", + "epoch:83, train loss:0.2077\n", + "epoch:83, validation loss:0.2175\n", + "epoch:84, train loss:0.2080\n", + "epoch:84, validation loss:0.2175\n", + "epoch:85, train loss:0.2074\n", + "epoch:85, validation loss:0.2174\n", + "epoch:86, train loss:0.2077\n", + "epoch:86, validation loss:0.2174\n", + "epoch:87, train loss:0.2079\n", + "epoch:87, validation loss:0.2174\n", + "epoch:88, train loss:0.2079\n", + "epoch:88, validation loss:0.2173\n", + "epoch:89, train loss:0.2075\n", + "epoch:89, validation loss:0.2173\n", + "epoch:90, train loss:0.2075\n", + "epoch:90, validation loss:0.2173\n", + "epoch:91, train 
# Main loop: epochs are numbered from 1 to max_epochs inclusive, and each
# epoch is one training pass followed by one evaluation pass.
first_epoch, last_epoch = 1, max_epochs
for epoch in range(first_epoch, last_epoch + 1):
    train(epoch)
    validate(epoch)