From e289ff3c0c054c76a7f63ac7e07c5f8fbee5d849 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E5=BC=A0=E5=8D=93=E7=AB=8B?=
<13190677+zhang-zhuoli@user.noreply.gitee.com>
Date: Sat, 15 Jul 2023 12:05:42 +0000
Subject: [PATCH] =?UTF-8?q?=E5=89=8D=E9=A6=88=E7=A5=9E=E7=BB=8F=E7=BD=91?=
=?UTF-8?q?=E7=BB=9C=E5=9B=9E=E5=BD=92?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
Signed-off-by: 张卓立 <13190677+zhang-zhuoli@user.noreply.gitee.com>
---
.../torch_reg.ipynb | 944 ++++++++++++++++++
1 file changed, 944 insertions(+)
create mode 100644 共享民宿平台担保交易房子评分的影响研究/torch_reg.ipynb
diff --git a/共享民宿平台担保交易房子评分的影响研究/torch_reg.ipynb b/共享民宿平台担保交易房子评分的影响研究/torch_reg.ipynb
new file mode 100644
index 0000000..632f392
--- /dev/null
+++ b/共享民宿平台担保交易房子评分的影响研究/torch_reg.ipynb
@@ -0,0 +1,944 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "id": "a31d3c40-3593-4ca3-980e-578ee51e171a",
+ "metadata": {},
+ "source": [
+ "# 用神经网络进行回归预测"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "id": "738c03b4-3ca0-4a87-9143-53ddea5179be",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import numpy as np\n",
+ "import pandas as pd\n",
+ "import torch\n",
+ "import torch.nn as nn\n",
+ "import torch.optim as optim\n",
+ "from torch.utils.data import Dataset, DataLoader\n",
+ "from sklearn.preprocessing import StandardScaler\n",
+ "from sklearn.model_selection import train_test_split"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "637063c9-5924-416c-8201-adaae8514ffd",
+ "metadata": {},
+ "source": [
+ "设置神经网络超参数,批大小为500,学习率为0.01,一共分别进行100次正向传播和反向传播"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 2,
+ "id": "b909615c-783d-4bfb-acf6-1bc3bacfbb18",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# torch参数\n",
+ "batch_size = 500\n",
+ "lr = 0.01\n",
+ "max_epochs = 100\n",
+ "num_workers = 0\n",
+ "device = torch.device(\"cuda:1\" if torch.cuda.is_available() else \"cpu\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "904cfade-5667-440a-a13c-34fdf57f7b0f",
+ "metadata": {},
+ "source": [
+ "定义所需数据集类"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 3,
+ "id": "f6e65156-8535-49c2-8d9d-ecb8639855c9",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "variables = ['number_of_reviews', 'price', 'accommodates',\n",
+ " 'host_response_rate', 'host_acceptance_rate', 'review_scores_rating']\n",
+ "# 数据集类\n",
+ "\n",
+ "\n",
+ "class USDataset(Dataset):\n",
+ " def __init__(self, df):\n",
+ " '''\n",
+ " 初始化\n",
+ " df: 处理后的数据集\n",
+ " '''\n",
+ " self.df = df\n",
+ " self.info = df[['number_of_reviews', 'price', 'accommodates',\n",
+ " 'host_response_rate', 'host_acceptance_rate']].values\n",
+ " self.target = df['review_scores_rating'].values\n",
+ "\n",
+ " def __getitem__(self, index):\n",
+ " '''\n",
+ " 根据编号返回信息\n",
+ " index: 样本编号\n",
+ " '''\n",
+ " info = self.info[index]\n",
+ " target = self.target[index]\n",
+ " return info, target\n",
+ "\n",
+ " def __len__(self):\n",
+ " '''\n",
+ " 返回数据集样本个数\n",
+ " '''\n",
+ " return len(self.df)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "fc574f2e-6443-48af-9fee-dcf5fb29af58",
+ "metadata": {},
+ "source": [
+ "读入数据,由于样本量很大,直接删除有缺失值的样本。"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 4,
+ "id": "17f8aedb-abf1-4622-98ec-d525a779274b",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "
\n",
+ "\n",
+ "
\n",
+ " \n",
+ " \n",
+ " | \n",
+ " host_response_rate | \n",
+ " host_acceptance_rate | \n",
+ " accommodates | \n",
+ " price | \n",
+ " number_of_reviews | \n",
+ " review_scores_rating | \n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " | 0 | \n",
+ " 1.00 | \n",
+ " 0.33 | \n",
+ " 2.0 | \n",
+ " 120.0 | \n",
+ " 90.0 | \n",
+ " 4.50 | \n",
+ "
\n",
+ " \n",
+ " | 1 | \n",
+ " 1.00 | \n",
+ " 0.98 | \n",
+ " 2.0 | \n",
+ " 90.0 | \n",
+ " 351.0 | \n",
+ " 4.58 | \n",
+ "
\n",
+ " \n",
+ " | 2 | \n",
+ " 1.00 | \n",
+ " 0.98 | \n",
+ " 2.0 | \n",
+ " 66.0 | \n",
+ " 67.0 | \n",
+ " 4.52 | \n",
+ "
\n",
+ " \n",
+ " | 3 | \n",
+ " 1.00 | \n",
+ " 0.98 | \n",
+ " 1.0 | \n",
+ " 33.0 | \n",
+ " 297.0 | \n",
+ " 4.70 | \n",
+ "
\n",
+ " \n",
+ " | 5 | \n",
+ " 1.00 | \n",
+ " 1.00 | \n",
+ " 2.0 | \n",
+ " 45.0 | \n",
+ " 42.0 | \n",
+ " 4.98 | \n",
+ "
\n",
+ " \n",
+ " | ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ "
\n",
+ " \n",
+ " | 203252 | \n",
+ " 1.00 | \n",
+ " 0.93 | \n",
+ " 4.0 | \n",
+ " 152.0 | \n",
+ " 1.0 | \n",
+ " 4.00 | \n",
+ "
\n",
+ " \n",
+ " | 203253 | \n",
+ " 1.00 | \n",
+ " 0.97 | \n",
+ " 2.0 | \n",
+ " 45.0 | \n",
+ " 1.0 | \n",
+ " 3.00 | \n",
+ "
\n",
+ " \n",
+ " | 203254 | \n",
+ " 1.00 | \n",
+ " 0.97 | \n",
+ " 2.0 | \n",
+ " 40.0 | \n",
+ " 1.0 | \n",
+ " 1.00 | \n",
+ "
\n",
+ " \n",
+ " | 203276 | \n",
+ " 0.99 | \n",
+ " 0.99 | \n",
+ " 2.0 | \n",
+ " 43.0 | \n",
+ " 1.0 | \n",
+ " 5.00 | \n",
+ "
\n",
+ " \n",
+ " | 203308 | \n",
+ " 1.00 | \n",
+ " 1.00 | \n",
+ " 3.0 | \n",
+ " 110.0 | \n",
+ " 1.0 | \n",
+ " 5.00 | \n",
+ "
\n",
+ " \n",
+ "
\n",
+ "
134835 rows × 6 columns
\n",
+ "
"
+ ],
+ "text/plain": [
+ " host_response_rate host_acceptance_rate accommodates price \\\n",
+ "0 1.00 0.33 2.0 120.0 \n",
+ "1 1.00 0.98 2.0 90.0 \n",
+ "2 1.00 0.98 2.0 66.0 \n",
+ "3 1.00 0.98 1.0 33.0 \n",
+ "5 1.00 1.00 2.0 45.0 \n",
+ "... ... ... ... ... \n",
+ "203252 1.00 0.93 4.0 152.0 \n",
+ "203253 1.00 0.97 2.0 45.0 \n",
+ "203254 1.00 0.97 2.0 40.0 \n",
+ "203276 0.99 0.99 2.0 43.0 \n",
+ "203308 1.00 1.00 3.0 110.0 \n",
+ "\n",
+ " number_of_reviews review_scores_rating \n",
+ "0 90.0 4.50 \n",
+ "1 351.0 4.58 \n",
+ "2 67.0 4.52 \n",
+ "3 297.0 4.70 \n",
+ "5 42.0 4.98 \n",
+ "... ... ... \n",
+ "203252 1.0 4.00 \n",
+ "203253 1.0 3.00 \n",
+ "203254 1.0 1.00 \n",
+ "203276 1.0 5.00 \n",
+ "203308 1.0 5.00 \n",
+ "\n",
+ "[134835 rows x 6 columns]"
+ ]
+ },
+ "execution_count": 4,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "# 读入数据\n",
+ "df = pd.read_csv('../data/2022-01(US_25).csv', usecols=variables)\n",
+ "df['price'] = df['price'].replace('\\$', '', regex=True)\n",
+ "df['price'] = df['price'].replace('\\,', '', regex=True).astype(float)\n",
+ "df[['host_response_rate', 'host_acceptance_rate']] = df[['host_response_rate',\n",
+ " 'host_acceptance_rate']].replace('\\%', '', regex=True).astype(float)*0.01\n",
+ "df[['number_of_reviews']] = df[['number_of_reviews']].astype(float)\n",
+ "for col in variables:\n",
+ " df[col] = df[col].astype(np.float32)\n",
+ " df = df[np.isnan(df[col]) != 1]\n",
+ "df"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "758d8171-9310-4a51-b9a6-b0fa6274ffe1",
+ "metadata": {},
+ "source": [
+ "将数据标准化"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 5,
+ "id": "2cbe7cab-42cd-4b3b-bf11-080cc0a0f4c4",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "df_train:\n",
+ " host_response_rate host_acceptance_rate accommodates price \\\n",
+ "0 0.297689 0.440210 -0.101531 -0.147944 \n",
+ "1 0.297689 0.440210 -0.101531 -0.272867 \n",
+ "2 0.297689 0.393235 -0.807638 -0.258453 \n",
+ "3 0.297689 0.534159 -0.101531 -0.366559 \n",
+ "4 0.297689 0.393235 -0.454584 -0.099897 \n",
+ "... ... ... ... ... \n",
+ "94379 0.297689 -0.499281 -0.101531 0.313307 \n",
+ "94380 0.297689 0.440210 2.016788 0.697684 \n",
+ "94381 0.297689 0.534159 0.604575 -0.222417 \n",
+ "94382 0.297689 0.252311 -0.807638 -0.301695 \n",
+ "94383 0.297689 -0.687179 -0.807638 -0.318512 \n",
+ "\n",
+ " number_of_reviews review_scores_rating \n",
+ "0 -0.629684 3.00 \n",
+ "1 0.060760 4.93 \n",
+ "2 -0.629684 4.00 \n",
+ "3 2.045787 4.91 \n",
+ "4 0.295018 4.83 \n",
+ "... ... ... \n",
+ "94379 -0.407756 5.00 \n",
+ "94380 -0.531050 5.00 \n",
+ "94381 0.036101 4.69 \n",
+ "94382 1.145743 4.85 \n",
+ "94383 0.726545 4.83 \n",
+ "\n",
+ "[94384 rows x 6 columns]\n",
+ "============================================\n",
+ "df_test:\n",
+ " host_response_rate host_acceptance_rate accommodates price \\\n",
+ "0 0.297689 0.581133 -1.160691 -0.392985 \n",
+ "1 0.297689 -0.969026 0.604575 0.714500 \n",
+ "2 0.297689 0.581133 -0.807638 -0.042241 \n",
+ "3 0.095805 -1.626670 0.604575 3.104842 \n",
+ "4 0.297689 -0.452307 -0.807638 -0.289683 \n",
+ "... ... ... ... ... \n",
+ "40446 0.095805 0.534159 -0.101531 -0.244039 \n",
+ "40447 -0.038785 0.534159 1.310681 1.353526 \n",
+ "40448 0.297689 -0.217434 -0.101531 0.197994 \n",
+ "40449 0.297689 0.440210 0.604575 0.358952 \n",
+ "40450 0.297689 0.581133 -0.807638 -0.159956 \n",
+ "\n",
+ " number_of_reviews review_scores_rating \n",
+ "0 0.726545 4.76 \n",
+ "1 -0.580367 4.60 \n",
+ "2 3.352699 4.92 \n",
+ "3 -0.617355 2.50 \n",
+ "4 -0.592696 5.00 \n",
+ "... ... ... \n",
+ "40446 -0.617355 5.00 \n",
+ "40447 -0.629684 5.00 \n",
+ "40448 -0.185828 4.97 \n",
+ "40449 -0.494061 4.67 \n",
+ "40450 0.368994 4.96 \n",
+ "\n",
+ "[40451 rows x 6 columns]\n"
+ ]
+ }
+ ],
+ "source": [
+ "# 固定划分测试集和训练集\n",
+ "col_df = df.columns\n",
+ "info = df.iloc[:, :-1].values\n",
+ "target = df.iloc[:, -1].values\n",
+ "# 标准化\n",
+ "stdscaler = StandardScaler()\n",
+ "info_train, info_test, target_train, target_test = train_test_split(\n",
+ " info, target, test_size=0.3, random_state=420)\n",
+ "info_train = stdscaler.fit_transform(info_train)\n",
+ "info_test = stdscaler.transform(info_test)\n",
+ "df_train = np.hstack((info_train, target_train.reshape(-1, 1)))\n",
+ "df_test = np.hstack((info_test, target_test.reshape(-1, 1)))\n",
+ "df_train = pd.DataFrame(df_train, columns=col_df)\n",
+ "df_test = pd.DataFrame(df_test, columns=col_df)\n",
+ "print(\"df_train:\\n\", df_train)\n",
+ "print(\"============================================\")\n",
+ "print(\"df_test:\\n\", df_test)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "1a9661ee-3dc4-4d5c-8a8d-bab3dc299d41",
+ "metadata": {},
+ "source": [
+ "设置训练集和验证集"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 6,
+ "id": "a10021bb-4e08-49b2-8269-114b8ce37be6",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# 初始化数据集\n",
+ "train_data = USDataset(df_train)\n",
+ "test_data = USDataset(df_test)\n",
+ "train_loader = DataLoader(train_data, batch_size=batch_size,\n",
+ " num_workers=num_workers, shuffle=True, drop_last=True)\n",
+ "test_loader = DataLoader(test_data, batch_size=batch_size,\n",
+ " num_workers=num_workers, shuffle=False)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "41a15589-9aaf-4541-a603-30a9f433c385",
+ "metadata": {},
+ "source": [
+ "设计并初始化网络,其中神经网络是有一个输入层,一个输出层,三个隐藏层的前馈神经网络,激活函数是Sigmoid函数"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 7,
+ "id": "764f28bc-81dd-4924-9039-5eceba54d739",
+ "metadata": {
+ "scrolled": true
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "[Parameter containing:\n",
+ "tensor([[ 0.1981, -0.0925, 0.1532, -0.2224, 0.1906],\n",
+ " [-0.4302, -0.2503, -0.1189, 0.4182, -0.3575],\n",
+ " [ 0.1431, 0.3615, 0.2464, -0.0292, -0.3012],\n",
+ " [-0.2803, 0.2226, 0.2143, 0.2458, -0.2719],\n",
+ " [ 0.1669, 0.1216, -0.0148, 0.2795, -0.2338],\n",
+ " [-0.0047, 0.1909, -0.1634, 0.3766, 0.3018],\n",
+ " [ 0.4288, 0.3602, -0.3976, -0.1180, 0.3682],\n",
+ " [-0.1992, 0.1626, 0.3656, -0.4359, 0.2141],\n",
+ " [-0.3708, 0.1608, 0.0614, 0.0473, -0.3628],\n",
+ " [-0.3157, 0.1828, 0.3053, -0.2029, -0.3746],\n",
+ " [ 0.0809, -0.1369, 0.2581, 0.4280, 0.2290],\n",
+ " [-0.3749, -0.1776, 0.1791, -0.3965, -0.1430],\n",
+ " [ 0.3633, 0.1174, -0.4268, 0.1386, 0.3422],\n",
+ " [ 0.3554, -0.3894, -0.4382, 0.2104, 0.4403],\n",
+ " [-0.0714, -0.1196, -0.0719, 0.0933, 0.1312]], device='cuda:0',\n",
+ " requires_grad=True), Parameter containing:\n",
+ "tensor([-0.3285, 0.3194, 0.4210, 0.3921, 0.2482, 0.0288, -0.2423, 0.3984,\n",
+ " -0.4366, -0.0518, 0.2578, -0.2965, 0.3803, -0.1613, -0.3905],\n",
+ " device='cuda:0', requires_grad=True), Parameter containing:\n",
+ "tensor([[ 0.2330, 0.2382, 0.1661, -0.0358, -0.2211, -0.2119, 0.2325, 0.2296,\n",
+ " 0.1116, -0.1443, -0.1474, 0.1758, 0.1321, -0.1886, 0.1001],\n",
+ " [ 0.1013, -0.1303, -0.1558, -0.1694, -0.0610, 0.0903, 0.2512, 0.2028,\n",
+ " -0.0437, -0.0142, -0.0613, -0.1092, -0.2330, 0.1156, -0.0872],\n",
+ " [-0.2243, 0.1081, 0.1204, -0.1407, 0.2454, 0.1751, -0.0659, -0.0778,\n",
+ " 0.1699, 0.2420, 0.1921, -0.1972, 0.0297, -0.1185, 0.1182],\n",
+ " [ 0.1126, -0.1653, -0.2057, -0.0083, 0.0714, -0.1995, 0.1577, -0.2197,\n",
+ " -0.0100, -0.1859, -0.2013, -0.2470, -0.1084, -0.0784, -0.0567],\n",
+ " [-0.2133, 0.0452, 0.0957, 0.1400, -0.1520, -0.0957, 0.1586, 0.0450,\n",
+ " -0.1405, -0.1795, -0.1409, 0.2279, -0.1048, -0.0113, 0.1268],\n",
+ " [-0.2536, -0.1572, 0.2436, -0.1054, -0.2366, -0.1888, 0.0218, -0.0935,\n",
+ " -0.1377, 0.0630, 0.0064, -0.0734, 0.0873, 0.1178, -0.1568],\n",
+ " [-0.0283, 0.2286, 0.2539, 0.1228, 0.2385, -0.2216, -0.1502, -0.2049,\n",
+ " 0.0286, 0.1939, 0.0313, 0.1531, 0.2514, -0.2186, 0.0171],\n",
+ " [ 0.1501, 0.0752, -0.1921, -0.1191, -0.0748, 0.2097, 0.2306, -0.2150,\n",
+ " -0.2348, 0.0575, -0.0489, 0.0678, -0.1273, -0.1990, -0.0770],\n",
+ " [-0.0807, -0.0379, -0.1252, -0.1873, -0.0793, 0.0908, 0.1573, 0.1212,\n",
+ " 0.0937, 0.1422, -0.1502, -0.0369, 0.0321, -0.0830, 0.0571],\n",
+ " [-0.1586, -0.0179, 0.2351, 0.1576, 0.0692, -0.0726, -0.0875, -0.0382,\n",
+ " -0.2227, -0.0080, 0.0904, 0.1201, 0.2029, 0.0351, 0.0133]],\n",
+ " device='cuda:0', requires_grad=True), Parameter containing:\n",
+ "tensor([-0.0141, 0.0566, -0.2259, -0.1533, -0.0400, 0.0475, -0.1080, -0.2453,\n",
+ " -0.1625, 0.1345], device='cuda:0', requires_grad=True), Parameter containing:\n",
+ "tensor([[-0.2120, -0.1994, -0.0052, 0.2988, -0.2471, 0.0135, 0.1283, -0.0201,\n",
+ " 0.1182, -0.1972],\n",
+ " [-0.0328, 0.0113, 0.1010, 0.0589, -0.1486, 0.2598, 0.1771, 0.0474,\n",
+ " -0.0413, 0.1537],\n",
+ " [ 0.2083, -0.1183, -0.2833, -0.3092, 0.2081, 0.2566, 0.1134, -0.1159,\n",
+ " -0.2981, -0.2882],\n",
+ " [-0.1033, -0.2929, -0.2451, -0.2850, 0.3157, 0.3092, 0.0539, -0.2594,\n",
+ " 0.1327, -0.0194],\n",
+ " [ 0.1040, -0.3053, -0.1769, -0.2137, -0.0262, 0.2108, 0.0255, -0.1202,\n",
+ " 0.0413, 0.2326],\n",
+ " [-0.0355, -0.2424, 0.1064, 0.0394, -0.1730, 0.2130, 0.1174, -0.1641,\n",
+ " -0.1420, -0.0496]], device='cuda:0', requires_grad=True), Parameter containing:\n",
+ "tensor([ 0.2195, -0.0579, 0.1234, 0.1060, -0.2140, 0.0288], device='cuda:0',\n",
+ " requires_grad=True), Parameter containing:\n",
+ "tensor([[ 0.2707, 0.1434, 0.0318, -0.0551, -0.1369, -0.1496]],\n",
+ " device='cuda:0', requires_grad=True), Parameter containing:\n",
+ "tensor([-0.0661], device='cuda:0', requires_grad=True)]\n",
+ "==================================================\n",
+ "[Parameter containing:\n",
+ "tensor([[-6.9726e-02, 1.7731e+00, -7.6372e-02, -7.1058e-01, -7.1123e-04],\n",
+ " [-7.1336e-01, -5.4361e-01, -2.6582e-01, 5.6964e-01, -4.9271e-01],\n",
+ " [-1.3366e+00, -6.1715e-01, -8.0915e-01, -1.0871e+00, -1.7130e+00],\n",
+ " [ 1.1439e+00, 1.1702e+00, -5.0160e-02, -6.6409e-01, -9.0478e-02],\n",
+ " [-1.2307e-01, 9.7143e-01, -3.2899e-01, 1.0481e+00, 8.7611e-02],\n",
+ " [-3.6240e-01, 6.6012e-01, -1.7746e+00, 1.7587e-01, 3.9182e-01],\n",
+ " [ 9.5627e-01, -1.4544e+00, 8.6239e-01, 6.4864e-01, 1.2865e+00],\n",
+ " [-9.4586e-01, -2.7911e+00, 1.3601e+00, -8.6802e-01, 3.7288e-01],\n",
+ " [ 6.0642e-01, 8.3700e-01, 4.6751e-01, 9.6069e-01, 1.0403e+00],\n",
+ " [-2.4263e-01, -1.1883e+00, -1.1396e+00, 1.4794e+00, -6.1805e-01],\n",
+ " [-4.5019e-01, 3.1949e-01, 2.2522e-01, -1.4746e+00, -1.5766e+00],\n",
+ " [-1.0685e+00, -8.6054e-01, -4.0435e-01, 1.1669e-01, -5.0217e-01],\n",
+ " [ 3.2439e-01, -6.3227e-01, -9.7261e-01, 4.4522e-01, -1.1293e+00],\n",
+ " [-7.5776e-01, -5.1496e-01, 1.3075e+00, -3.3717e-01, 3.9504e-01],\n",
+ " [ 8.5924e-01, -1.0643e-01, -1.2336e+00, 1.7424e+00, -3.5817e-01]],\n",
+ " device='cuda:0', requires_grad=True), Parameter containing:\n",
+ "tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n",
+ " device='cuda:0', requires_grad=True), Parameter containing:\n",
+ "tensor([[-0.5137, -0.3243, 0.4541, 0.4884, -0.3226, -0.9424, 1.2633, -0.5411,\n",
+ " -0.4047, -0.2105, -1.2065, 0.1194, -0.8706, -0.3806, -0.3470],\n",
+ " [ 1.1079, -0.6157, 0.0802, 0.1673, -1.1413, 0.3360, 2.0812, -0.4530,\n",
+ " 1.9244, -0.1222, 0.8667, -0.4083, -1.5679, 1.3790, -0.7961],\n",
+ " [-1.4472, 0.9576, 0.8538, 0.3888, -2.0439, -0.0737, -1.1856, -0.7158,\n",
+ " 2.7370, 2.5444, -0.7023, 1.4209, 0.0901, 0.8644, 0.7805],\n",
+ " [-0.0899, 0.7839, -0.3986, -0.0622, 0.5134, 0.8943, -0.7492, -1.7582,\n",
+ " 0.7121, 1.2855, -0.4851, 0.0736, -0.2256, 0.4992, -0.7854],\n",
+ " [ 2.5287, 0.5654, 0.7835, -0.6008, -1.9796, 1.6482, -1.0928, -0.0595,\n",
+ " -0.0564, -0.0729, -0.2910, -0.9512, -1.3852, -0.2895, -0.8505],\n",
+ " [-1.4534, -1.3051, 0.1817, 0.5341, 1.1192, -0.0837, -0.5133, -1.2826,\n",
+ " 1.5556, -1.0946, -1.0002, 0.3343, 0.1432, 1.8591, 0.0199],\n",
+ " [ 1.2811, -1.2248, 0.5370, 2.1831, 1.1538, 0.9405, 1.3957, 0.2959,\n",
+ " 1.1088, 1.6196, -0.1388, -0.0548, -0.2318, 0.1342, -0.1255],\n",
+ " [ 1.2155, -0.1889, -0.3502, -0.3293, 0.3099, 2.5889, 1.1579, 0.2868,\n",
+ " -0.1206, 0.3678, 0.2955, 0.7390, -1.1026, -0.5398, 0.0348],\n",
+ " [-0.9189, 0.8586, 1.3321, -0.7081, 0.0048, 0.7643, -0.5209, 0.2017,\n",
+ " 0.8094, -1.4677, 1.0291, 0.1726, 0.3298, 1.6261, -0.1650],\n",
+ " [ 0.7050, -0.3248, -0.2147, -0.4022, -0.3151, 0.9486, 1.3281, -0.2464,\n",
+ " -0.2804, 1.7853, -0.8694, -1.3244, 0.4455, 0.6532, 1.1969]],\n",
+ " device='cuda:0', requires_grad=True), Parameter containing:\n",
+ "tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], device='cuda:0',\n",
+ " requires_grad=True), Parameter containing:\n",
+ "tensor([[ 0.7268, -0.5613, 0.3554, -0.7328, 1.7000, -0.1115, -0.4507, 0.0304,\n",
+ " 0.5701, 1.0746],\n",
+ " [-0.7969, 0.7418, 1.5317, -1.0003, -0.2625, 1.7585, -3.1004, -1.2840,\n",
+ " -0.2230, 0.9398],\n",
+ " [-0.2167, -0.4646, -0.1701, -1.1754, 1.2224, -0.1925, -1.0030, -1.3171,\n",
+ " 0.7915, -1.7573],\n",
+ " [ 1.1589, -0.5555, 0.6436, -0.1423, -0.0412, -0.2644, 0.0179, 1.9647,\n",
+ " -1.0822, 0.4064],\n",
+ " [ 0.2140, -0.6242, 1.9719, -1.6243, -0.7416, 1.2082, 0.9993, 2.4522,\n",
+ " 1.2368, -0.9525],\n",
+ " [ 0.9530, 0.5625, -0.6094, -0.3213, 0.6820, -0.6440, 1.5492, -0.4234,\n",
+ " -0.3488, -0.1487]], device='cuda:0', requires_grad=True), Parameter containing:\n",
+ "tensor([0., 0., 0., 0., 0., 0.], device='cuda:0', requires_grad=True), Parameter containing:\n",
+ "tensor([[-1.2324, -0.2007, -1.9850, -1.1963, 0.4371, -0.8892]],\n",
+ " device='cuda:0', requires_grad=True), Parameter containing:\n",
+ "tensor([0.], device='cuda:0', requires_grad=True)]\n"
+ ]
+ }
+ ],
+ "source": [
+ "def initialize(self):\n",
+ " '''\n",
+ " 初始化网络参数\n",
+ " '''\n",
+ " for m in self.modules():\n",
+ " if isinstance(m, nn.Linear):\n",
+ " torch.nn.init.normal_(m.weight.data, 0.1)\n",
+ " if m.bias is not None:\n",
+ " torch.nn.init.zeros_(m.bias.data)\n",
+ "\n",
+ "\n",
+ "class Net(nn.Module):\n",
+ " '''\n",
+ " 网络结构\n",
+ " '''\n",
+ "\n",
+ " def __init__(self, **kwargs):\n",
+ " super(Net, self).__init__()\n",
+ " self.fc = nn.Sequential(\n",
+ " nn.Linear(5, 15),\n",
+ " nn.Sigmoid(),\n",
+ " nn.Linear(15, 10),\n",
+ " nn.Sigmoid(),\n",
+ " nn.Linear(10, 6),\n",
+ " nn.Sigmoid(),\n",
+ " nn.Linear(6, 1)\n",
+ " )\n",
+ "\n",
+ " def forward(self, x):\n",
+ " x = x.view(-1, 5)\n",
+ " x = self.fc(x)\n",
+ " return x.squeeze(-1)\n",
+ "\n",
+ "\n",
+ "# 初始化网络\n",
+ "model = Net()\n",
+ "model = model.cuda()\n",
+ "print(list(model.parameters()))\n",
+ "initialize(model)\n",
+ "print(\"==================================================\")\n",
+ "print(list(model.parameters()))"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "20fcccfe-3f82-44c3-bba2-9c66fb15cecd",
+ "metadata": {},
+ "source": [
+ "定义损失函数为均方误差损失函数,优化算法为随机梯度下降"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 8,
+ "id": "7b530a85-50b1-4d0f-b47e-14831d0ee85d",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# 损失函数\n",
+ "criterion = nn.MSELoss()\n",
+ "# 优化器\n",
+ "optimizer = optim.SGD(model.parameters(), lr=lr)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 9,
+ "id": "8694737f-afb5-4e58-ad14-3220971b142f",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def train(epoch):\n",
+ " '''\n",
+ " 训练器\n",
+ "\n",
+ " Parameters\n",
+ " ----------\n",
+ " epoch : int\n",
+ "\n",
+ " Returns\n",
+ " -------\n",
+ " None.\n",
+ "\n",
+ " '''\n",
+ " model.train()\n",
+ " train_loss = 0\n",
+ " for info, target in train_loader:\n",
+ " info = info.cuda()\n",
+ " target = target.cuda()\n",
+ " optimizer.zero_grad()\n",
+ " output = model(info)\n",
+ " loss = criterion(output, target)\n",
+ " loss.backward()\n",
+ " optimizer.step()\n",
+ " train_loss += loss.item()*info.size(0)\n",
+ " train_loss = train_loss/len(train_loader.dataset)\n",
+ " print(\"epoch:%d, train loss:%.4f\" % (epoch, train_loss))"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 10,
+ "id": "bb6be6f8-16e6-4c2b-a56c-232b09faa654",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def validate(epoch):\n",
+ " '''\n",
+ " 验证\n",
+ "\n",
+ " Parameters\n",
+ " ----------\n",
+ " epoch : int\n",
+ "\n",
+ " Returns\n",
+ " -------\n",
+ " None.\n",
+ "\n",
+ " '''\n",
+ " model.eval()\n",
+ " val_loss = 0\n",
+ " with torch.no_grad():\n",
+ " for info, target in test_loader:\n",
+ " info, target = info.cuda(), target.cuda()\n",
+ " output = model(info)\n",
+ " loss = criterion(output, target)\n",
+ " val_loss += loss.item()*info.size(0)\n",
+ " val_loss = val_loss/len(test_loader.dataset)\n",
+ " print(\"epoch:%d, validation loss:%.4f\" %\n",
+ " (epoch, val_loss))"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 11,
+ "id": "40532ea8-acfe-469a-aa9f-14d3bec2619d",
+ "metadata": {
+ "scrolled": true
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "epoch:1, train loss:2.0057\n",
+ "epoch:1, validation loss:0.2197\n",
+ "epoch:2, train loss:0.2096\n",
+ "epoch:2, validation loss:0.2196\n",
+ "epoch:3, train loss:0.2098\n",
+ "epoch:3, validation loss:0.2196\n",
+ "epoch:4, train loss:0.2098\n",
+ "epoch:4, validation loss:0.2196\n",
+ "epoch:5, train loss:0.2100\n",
+ "epoch:5, validation loss:0.2195\n",
+ "epoch:6, train loss:0.2099\n",
+ "epoch:6, validation loss:0.2195\n",
+ "epoch:7, train loss:0.2101\n",
+ "epoch:7, validation loss:0.2195\n",
+ "epoch:8, train loss:0.2099\n",
+ "epoch:8, validation loss:0.2194\n",
+ "epoch:9, train loss:0.2099\n",
+ "epoch:9, validation loss:0.2194\n",
+ "epoch:10, train loss:0.2099\n",
+ "epoch:10, validation loss:0.2194\n",
+ "epoch:11, train loss:0.2093\n",
+ "epoch:11, validation loss:0.2193\n",
+ "epoch:12, train loss:0.2099\n",
+ "epoch:12, validation loss:0.2193\n",
+ "epoch:13, train loss:0.2100\n",
+ "epoch:13, validation loss:0.2193\n",
+ "epoch:14, train loss:0.2097\n",
+ "epoch:14, validation loss:0.2193\n",
+ "epoch:15, train loss:0.2094\n",
+ "epoch:15, validation loss:0.2192\n",
+ "epoch:16, train loss:0.2091\n",
+ "epoch:16, validation loss:0.2192\n",
+ "epoch:17, train loss:0.2098\n",
+ "epoch:17, validation loss:0.2192\n",
+ "epoch:18, train loss:0.2089\n",
+ "epoch:18, validation loss:0.2191\n",
+ "epoch:19, train loss:0.2091\n",
+ "epoch:19, validation loss:0.2191\n",
+ "epoch:20, train loss:0.2096\n",
+ "epoch:20, validation loss:0.2190\n",
+ "epoch:21, train loss:0.2092\n",
+ "epoch:21, validation loss:0.2190\n",
+ "epoch:22, train loss:0.2084\n",
+ "epoch:22, validation loss:0.2190\n",
+ "epoch:23, train loss:0.2093\n",
+ "epoch:23, validation loss:0.2190\n",
+ "epoch:24, train loss:0.2088\n",
+ "epoch:24, validation loss:0.2190\n",
+ "epoch:25, train loss:0.2088\n",
+ "epoch:25, validation loss:0.2189\n",
+ "epoch:26, train loss:0.2090\n",
+ "epoch:26, validation loss:0.2189\n",
+ "epoch:27, train loss:0.2092\n",
+ "epoch:27, validation loss:0.2189\n",
+ "epoch:28, train loss:0.2086\n",
+ "epoch:28, validation loss:0.2188\n",
+ "epoch:29, train loss:0.2087\n",
+ "epoch:29, validation loss:0.2189\n",
+ "epoch:30, train loss:0.2089\n",
+ "epoch:30, validation loss:0.2188\n",
+ "epoch:31, train loss:0.2090\n",
+ "epoch:31, validation loss:0.2187\n",
+ "epoch:32, train loss:0.2087\n",
+ "epoch:32, validation loss:0.2187\n",
+ "epoch:33, train loss:0.2092\n",
+ "epoch:33, validation loss:0.2187\n",
+ "epoch:34, train loss:0.2092\n",
+ "epoch:34, validation loss:0.2187\n",
+ "epoch:35, train loss:0.2091\n",
+ "epoch:35, validation loss:0.2186\n",
+ "epoch:36, train loss:0.2089\n",
+ "epoch:36, validation loss:0.2186\n",
+ "epoch:37, train loss:0.2086\n",
+ "epoch:37, validation loss:0.2186\n",
+ "epoch:38, train loss:0.2090\n",
+ "epoch:38, validation loss:0.2186\n",
+ "epoch:39, train loss:0.2088\n",
+ "epoch:39, validation loss:0.2185\n",
+ "epoch:40, train loss:0.2087\n",
+ "epoch:40, validation loss:0.2185\n",
+ "epoch:41, train loss:0.2087\n",
+ "epoch:41, validation loss:0.2186\n",
+ "epoch:42, train loss:0.2085\n",
+ "epoch:42, validation loss:0.2185\n",
+ "epoch:43, train loss:0.2088\n",
+ "epoch:43, validation loss:0.2184\n",
+ "epoch:44, train loss:0.2089\n",
+ "epoch:44, validation loss:0.2184\n",
+ "epoch:45, train loss:0.2089\n",
+ "epoch:45, validation loss:0.2184\n",
+ "epoch:46, train loss:0.2086\n",
+ "epoch:46, validation loss:0.2184\n",
+ "epoch:47, train loss:0.2089\n",
+ "epoch:47, validation loss:0.2183\n",
+ "epoch:48, train loss:0.2084\n",
+ "epoch:48, validation loss:0.2183\n",
+ "epoch:49, train loss:0.2080\n",
+ "epoch:49, validation loss:0.2183\n",
+ "epoch:50, train loss:0.2088\n",
+ "epoch:50, validation loss:0.2183\n",
+ "epoch:51, train loss:0.2086\n",
+ "epoch:51, validation loss:0.2183\n",
+ "epoch:52, train loss:0.2082\n",
+ "epoch:52, validation loss:0.2182\n",
+ "epoch:53, train loss:0.2080\n",
+ "epoch:53, validation loss:0.2182\n",
+ "epoch:54, train loss:0.2088\n",
+ "epoch:54, validation loss:0.2182\n",
+ "epoch:55, train loss:0.2086\n",
+ "epoch:55, validation loss:0.2181\n",
+ "epoch:56, train loss:0.2082\n",
+ "epoch:56, validation loss:0.2181\n",
+ "epoch:57, train loss:0.2088\n",
+ "epoch:57, validation loss:0.2181\n",
+ "epoch:58, train loss:0.2084\n",
+ "epoch:58, validation loss:0.2181\n",
+ "epoch:59, train loss:0.2085\n",
+ "epoch:59, validation loss:0.2180\n",
+ "epoch:60, train loss:0.2085\n",
+ "epoch:60, validation loss:0.2180\n",
+ "epoch:61, train loss:0.2084\n",
+ "epoch:61, validation loss:0.2180\n",
+ "epoch:62, train loss:0.2082\n",
+ "epoch:62, validation loss:0.2180\n",
+ "epoch:63, train loss:0.2081\n",
+ "epoch:63, validation loss:0.2180\n",
+ "epoch:64, train loss:0.2081\n",
+ "epoch:64, validation loss:0.2179\n",
+ "epoch:65, train loss:0.2078\n",
+ "epoch:65, validation loss:0.2179\n",
+ "epoch:66, train loss:0.2079\n",
+ "epoch:66, validation loss:0.2179\n",
+ "epoch:67, train loss:0.2080\n",
+ "epoch:67, validation loss:0.2179\n",
+ "epoch:68, train loss:0.2080\n",
+ "epoch:68, validation loss:0.2178\n",
+ "epoch:69, train loss:0.2082\n",
+ "epoch:69, validation loss:0.2178\n",
+ "epoch:70, train loss:0.2079\n",
+ "epoch:70, validation loss:0.2178\n",
+ "epoch:71, train loss:0.2083\n",
+ "epoch:71, validation loss:0.2178\n",
+ "epoch:72, train loss:0.2081\n",
+ "epoch:72, validation loss:0.2177\n",
+ "epoch:73, train loss:0.2079\n",
+ "epoch:73, validation loss:0.2177\n",
+ "epoch:74, train loss:0.2076\n",
+ "epoch:74, validation loss:0.2177\n",
+ "epoch:75, train loss:0.2081\n",
+ "epoch:75, validation loss:0.2177\n",
+ "epoch:76, train loss:0.2081\n",
+ "epoch:76, validation loss:0.2176\n",
+ "epoch:77, train loss:0.2080\n",
+ "epoch:77, validation loss:0.2176\n",
+ "epoch:78, train loss:0.2079\n",
+ "epoch:78, validation loss:0.2176\n",
+ "epoch:79, train loss:0.2081\n",
+ "epoch:79, validation loss:0.2176\n",
+ "epoch:80, train loss:0.2078\n",
+ "epoch:80, validation loss:0.2175\n",
+ "epoch:81, train loss:0.2076\n",
+ "epoch:81, validation loss:0.2175\n",
+ "epoch:82, train loss:0.2077\n",
+ "epoch:82, validation loss:0.2175\n",
+ "epoch:83, train loss:0.2077\n",
+ "epoch:83, validation loss:0.2175\n",
+ "epoch:84, train loss:0.2080\n",
+ "epoch:84, validation loss:0.2175\n",
+ "epoch:85, train loss:0.2074\n",
+ "epoch:85, validation loss:0.2174\n",
+ "epoch:86, train loss:0.2077\n",
+ "epoch:86, validation loss:0.2174\n",
+ "epoch:87, train loss:0.2079\n",
+ "epoch:87, validation loss:0.2174\n",
+ "epoch:88, train loss:0.2079\n",
+ "epoch:88, validation loss:0.2173\n",
+ "epoch:89, train loss:0.2075\n",
+ "epoch:89, validation loss:0.2173\n",
+ "epoch:90, train loss:0.2075\n",
+ "epoch:90, validation loss:0.2173\n",
+ "epoch:91, train loss:0.2070\n",
+ "epoch:91, validation loss:0.2173\n",
+ "epoch:92, train loss:0.2079\n",
+ "epoch:92, validation loss:0.2173\n",
+ "epoch:93, train loss:0.2075\n",
+ "epoch:93, validation loss:0.2172\n",
+ "epoch:94, train loss:0.2078\n",
+ "epoch:94, validation loss:0.2172\n",
+ "epoch:95, train loss:0.2068\n",
+ "epoch:95, validation loss:0.2172\n",
+ "epoch:96, train loss:0.2072\n",
+ "epoch:96, validation loss:0.2172\n",
+ "epoch:97, train loss:0.2076\n",
+ "epoch:97, validation loss:0.2172\n",
+ "epoch:98, train loss:0.2079\n",
+ "epoch:98, validation loss:0.2171\n",
+ "epoch:99, train loss:0.2074\n",
+ "epoch:99, validation loss:0.2171\n",
+ "epoch:100, train loss:0.2074\n",
+ "epoch:100, validation loss:0.2171\n"
+ ]
+ }
+ ],
+ "source": [
+ "# 训练过程\n",
+ "for epoch in range(1, max_epochs+1):\n",
+ " train(epoch)\n",
+ " validate(epoch)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "e3f438d3-813f-4db5-b8e9-da71d54730d4",
+ "metadata": {},
+ "outputs": [],
+ "source": []
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3 (ipykernel)",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.10.12"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}