diff --git a/共享民宿平台担保交易房子评分的影响研究/torch_reg.ipynb b/共享民宿平台担保交易房子评分的影响研究/torch_reg.ipynb
deleted file mode 100644
index f3d96ab..0000000
--- a/共享民宿平台担保交易房子评分的影响研究/torch_reg.ipynb
+++ /dev/null
@@ -1,863 +0,0 @@
-{
- "cells": [
- {
- "cell_type": "markdown",
- "id": "a31d3c40-3593-4ca3-980e-578ee51e171a",
- "metadata": {},
- "source": [
- "# 用神经网络进行回归预测"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 1,
- "id": "738c03b4-3ca0-4a87-9143-53ddea5179be",
- "metadata": {},
- "outputs": [],
- "source": [
- "import numpy as np\n",
- "import pandas as pd\n",
- "import torch\n",
- "import torch.nn as nn\n",
- "import torch.optim as optim\n",
- "from torch.utils.data import Dataset, DataLoader\n",
- "from sklearn.preprocessing import StandardScaler\n",
- "from sklearn.model_selection import train_test_split"
- ]
- },
- {
- "cell_type": "markdown",
- "id": "637063c9-5924-416c-8201-adaae8514ffd",
- "metadata": {},
- "source": [
- "设置神经网络超参数,批大小为500,学习率为0.01,一共分别进行100次正向传播和反向传播"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 2,
- "id": "b909615c-783d-4bfb-acf6-1bc3bacfbb18",
- "metadata": {},
- "outputs": [],
- "source": [
- "# torch参数\n",
- "batch_size = 500\n",
- "lr = 0.01\n",
- "max_epochs = 100\n",
- "num_workers = 0\n",
- "device = torch.device(\"cuda:1\" if torch.cuda.is_available() else \"cpu\")"
- ]
- },
- {
- "cell_type": "markdown",
- "id": "904cfade-5667-440a-a13c-34fdf57f7b0f",
- "metadata": {},
- "source": [
- "定义所需数据集类"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 3,
- "id": "f6e65156-8535-49c2-8d9d-ecb8639855c9",
- "metadata": {},
- "outputs": [],
- "source": [
- "variables = ['number_of_reviews', 'price', 'review_scores_rating']\n",
- "# 数据集类\n",
- "\n",
- "\n",
- "class USDataset(Dataset):\n",
- " def __init__(self, df):\n",
- " '''\n",
- " 初始化\n",
- " df: 处理后的数据集\n",
- " '''\n",
- " self.df = df\n",
- " self.info = df[['number_of_reviews', 'price']].values\n",
- " self.target = df['review_scores_rating'].values\n",
- "\n",
- " def __getitem__(self, index):\n",
- " '''\n",
- " 根据编号返回信息\n",
- " index: 样本编号\n",
- " '''\n",
- " info = self.info[index]\n",
- " target = self.target[index]\n",
- " return info, target\n",
- "\n",
- " def __len__(self):\n",
- " '''\n",
- " 返回数据集样本个数\n",
- " '''\n",
- " return len(self.df)"
- ]
- },
- {
- "cell_type": "markdown",
- "id": "fc574f2e-6443-48af-9fee-dcf5fb29af58",
- "metadata": {},
- "source": [
- "读入数据,由于样本量很大,直接删除有缺失值的样本。"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 4,
- "id": "17f8aedb-abf1-4622-98ec-d525a779274b",
- "metadata": {},
- "outputs": [
- {
- "data": {
- "text/html": [
- "
\n",
- "\n",
- "
\n",
- " \n",
- " \n",
- " | \n",
- " price | \n",
- " number_of_reviews | \n",
- " review_scores_rating | \n",
- "
\n",
- " \n",
- " \n",
- " \n",
- " | 0 | \n",
- " 120.0 | \n",
- " 90.0 | \n",
- " 4.50 | \n",
- "
\n",
- " \n",
- " | 1 | \n",
- " 90.0 | \n",
- " 351.0 | \n",
- " 4.58 | \n",
- "
\n",
- " \n",
- " | 2 | \n",
- " 66.0 | \n",
- " 67.0 | \n",
- " 4.52 | \n",
- "
\n",
- " \n",
- " | 3 | \n",
- " 33.0 | \n",
- " 297.0 | \n",
- " 4.70 | \n",
- "
\n",
- " \n",
- " | 4 | \n",
- " 125.0 | \n",
- " 58.0 | \n",
- " 4.96 | \n",
- "
\n",
- " \n",
- " | ... | \n",
- " ... | \n",
- " ... | \n",
- " ... | \n",
- "
\n",
- " \n",
- " | 203252 | \n",
- " 152.0 | \n",
- " 1.0 | \n",
- " 4.00 | \n",
- "
\n",
- " \n",
- " | 203253 | \n",
- " 45.0 | \n",
- " 1.0 | \n",
- " 3.00 | \n",
- "
\n",
- " \n",
- " | 203254 | \n",
- " 40.0 | \n",
- " 1.0 | \n",
- " 1.00 | \n",
- "
\n",
- " \n",
- " | 203276 | \n",
- " 43.0 | \n",
- " 1.0 | \n",
- " 5.00 | \n",
- "
\n",
- " \n",
- " | 203308 | \n",
- " 110.0 | \n",
- " 1.0 | \n",
- " 5.00 | \n",
- "
\n",
- " \n",
- "
\n",
- "
162429 rows × 3 columns
\n",
- "
"
- ],
- "text/plain": [
- " price number_of_reviews review_scores_rating\n",
- "0 120.0 90.0 4.50\n",
- "1 90.0 351.0 4.58\n",
- "2 66.0 67.0 4.52\n",
- "3 33.0 297.0 4.70\n",
- "4 125.0 58.0 4.96\n",
- "... ... ... ...\n",
- "203252 152.0 1.0 4.00\n",
- "203253 45.0 1.0 3.00\n",
- "203254 40.0 1.0 1.00\n",
- "203276 43.0 1.0 5.00\n",
- "203308 110.0 1.0 5.00\n",
- "\n",
- "[162429 rows x 3 columns]"
- ]
- },
- "execution_count": 4,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
- "source": [
- "# 读入数据\n",
- "df = pd.read_csv('./data/2022-01(US_25).csv', usecols=variables)\n",
- "df['price'] = df['price'].replace('\\$', '', regex=True)\n",
- "df['price'] = df['price'].replace('\\,', '', regex=True).astype(float)\n",
- "df[['number_of_reviews']] = df[['number_of_reviews']].astype(float)\n",
- "for col in variables:\n",
- " df[col] = df[col].astype(np.float32)\n",
- " df = df[np.isnan(df[col]) != 1]\n",
- "df"
- ]
- },
- {
- "cell_type": "markdown",
- "id": "758d8171-9310-4a51-b9a6-b0fa6274ffe1",
- "metadata": {},
- "source": [
- "将数据标准化"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 5,
- "id": "2cbe7cab-42cd-4b3b-bf11-080cc0a0f4c4",
- "metadata": {},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "df_train:\n",
- " price number_of_reviews review_scores_rating\n",
- "0 -0.331270 -0.218031 4.71\n",
- "1 -0.030076 0.450276 4.99\n",
- "2 0.425112 -0.603593 5.00\n",
- "3 -0.324476 -0.590741 2.50\n",
- "4 0.017481 -0.577889 4.00\n",
- "... ... ... ...\n",
- "113695 -0.381091 0.128974 4.98\n",
- "113696 -0.220303 -0.282292 4.96\n",
- "113697 -0.165953 -0.603593 3.00\n",
- "113698 -0.152365 -0.539333 4.00\n",
- "113699 0.162417 -0.590741 5.00\n",
- "\n",
- "[113700 rows x 3 columns]\n",
- "============================================\n",
- "df_test:\n",
- " price number_of_reviews review_scores_rating\n",
- "0 -0.249743 0.411720 4.84\n",
- "1 -0.211245 -0.565037 5.00\n",
- "2 -0.376562 -0.603593 5.00\n",
- "3 -0.088956 -0.539333 4.33\n",
- "4 0.517962 0.013306 4.90\n",
- "... ... ... ...\n",
- "48724 0.028804 -0.192327 4.91\n",
- "48725 -0.367503 -0.500777 5.00\n",
- "48726 -0.360710 -0.436517 4.86\n",
- "48727 -0.084426 -0.603593 1.00\n",
- "48728 0.067303 -0.603593 5.00\n",
- "\n",
- "[48729 rows x 3 columns]\n"
- ]
- }
- ],
- "source": [
- "# 固定划分测试集和训练集\n",
- "col_df = df.columns\n",
- "info = df.iloc[:, :-1].values\n",
- "target = df.iloc[:, -1].values\n",
- "# 标准化\n",
- "stdscaler = StandardScaler()\n",
- "info_train, info_test, target_train, target_test = train_test_split(\n",
- " info, target, test_size=0.3, random_state=420)\n",
- "info_train = stdscaler.fit_transform(info_train)\n",
- "info_test = stdscaler.transform(info_test)\n",
- "df_train = np.hstack((info_train, target_train.reshape(-1, 1)))\n",
- "df_test = np.hstack((info_test, target_test.reshape(-1, 1)))\n",
- "df_train = pd.DataFrame(df_train, columns=col_df)\n",
- "df_test = pd.DataFrame(df_test, columns=col_df)\n",
- "print(\"df_train:\\n\", df_train)\n",
- "print(\"============================================\")\n",
- "print(\"df_test:\\n\", df_test)"
- ]
- },
- {
- "cell_type": "markdown",
- "id": "1a9661ee-3dc4-4d5c-8a8d-bab3dc299d41",
- "metadata": {},
- "source": [
- "设置训练集和验证集"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 6,
- "id": "a10021bb-4e08-49b2-8269-114b8ce37be6",
- "metadata": {},
- "outputs": [],
- "source": [
- "# 初始化数据集\n",
- "train_data = USDataset(df_train)\n",
- "test_data = USDataset(df_test)\n",
- "train_loader = DataLoader(train_data, batch_size=batch_size,\n",
- " num_workers=num_workers, shuffle=True, drop_last=True)\n",
- "test_loader = DataLoader(test_data, batch_size=batch_size,\n",
- " num_workers=num_workers, shuffle=False)"
- ]
- },
- {
- "cell_type": "markdown",
- "id": "41a15589-9aaf-4541-a603-30a9f433c385",
- "metadata": {},
- "source": [
- "设计并初始化网络,其中神经网络是有一个输入层,一个输出层,三个隐藏层的前馈神经网络,激活函数是Sigmoid函数"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 7,
- "id": "764f28bc-81dd-4924-9039-5eceba54d739",
- "metadata": {
- "scrolled": true
- },
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "[Parameter containing:\n",
- "tensor([[-0.3798, 0.4787],\n",
- " [-0.6687, 0.3603],\n",
- " [-0.4871, 0.1701],\n",
- " [ 0.1127, -0.2635],\n",
- " [ 0.6846, 0.6578],\n",
- " [-0.2820, 0.6887],\n",
- " [ 0.6581, -0.6334],\n",
- " [ 0.1057, -0.2249],\n",
- " [ 0.0650, 0.6174],\n",
- " [ 0.1768, -0.5373],\n",
- " [-0.5809, 0.6241],\n",
- " [ 0.1711, 0.3307],\n",
- " [ 0.0546, 0.6606],\n",
- " [ 0.0528, 0.1988],\n",
- " [ 0.6816, -0.2682]], device='cuda:0', requires_grad=True), Parameter containing:\n",
- "tensor([-0.0835, -0.5775, -0.2021, 0.6027, -0.6389, 0.3593, 0.1442, 0.0746,\n",
- " -0.6356, 0.3029, -0.5204, -0.4724, 0.2451, -0.0103, -0.6133],\n",
- " device='cuda:0', requires_grad=True), Parameter containing:\n",
- "tensor([[-0.2190, 0.1091, 0.1611, 0.0540, 0.2386, -0.0399, 0.0086, -0.2273,\n",
- " 0.1421, 0.0286, 0.0507, -0.1721, 0.1504, -0.1714, 0.0113],\n",
- " [ 0.1312, 0.0581, 0.1805, 0.1385, -0.1394, -0.2539, -0.0047, 0.0753,\n",
- " 0.2134, 0.1119, 0.0534, -0.1073, 0.2128, 0.1909, 0.2403],\n",
- " [ 0.0414, 0.1597, -0.2470, -0.0118, 0.1226, -0.1825, 0.2529, 0.2161,\n",
- " 0.0246, 0.2333, -0.0733, 0.1474, 0.2101, 0.0246, 0.0514],\n",
- " [ 0.1203, 0.0301, 0.1293, -0.0800, -0.1533, -0.1585, -0.0815, -0.1242,\n",
- " 0.1415, -0.1520, 0.1287, -0.0701, -0.2560, -0.0094, -0.2361],\n",
- " [-0.1045, 0.0645, 0.2181, -0.0610, -0.0585, -0.1396, 0.2380, 0.2466,\n",
- " -0.2495, -0.0690, 0.1078, -0.1838, 0.0231, -0.0264, -0.1335],\n",
- " [ 0.0328, 0.2338, 0.0761, 0.2368, -0.0891, -0.1581, 0.1605, -0.0699,\n",
- " 0.1658, -0.0683, -0.2559, -0.0122, 0.1695, 0.2399, -0.1445],\n",
- " [-0.2489, 0.1559, -0.0083, 0.0980, -0.1653, 0.0678, -0.1833, -0.0217,\n",
- " -0.0746, 0.0633, -0.0559, 0.0334, 0.1273, -0.1259, -0.1527],\n",
- " [ 0.2253, -0.0182, 0.1613, -0.1602, -0.2205, -0.2047, 0.2046, -0.0824,\n",
- " 0.1261, -0.1015, 0.2257, -0.1276, -0.0197, -0.0505, -0.2408],\n",
- " [ 0.1883, -0.0590, -0.1909, -0.0710, -0.0304, -0.0085, -0.1848, 0.0480,\n",
- " -0.0447, 0.1001, 0.2330, -0.1368, -0.1478, -0.1676, 0.1942],\n",
- " [ 0.0554, -0.1686, -0.2217, -0.0343, 0.0218, 0.2300, 0.0303, 0.0711,\n",
- " -0.0932, -0.1791, 0.0978, -0.1902, -0.2322, -0.1179, -0.1592]],\n",
- " device='cuda:0', requires_grad=True), Parameter containing:\n",
- "tensor([-0.0498, -0.1153, 0.0559, 0.1391, -0.0231, 0.1752, 0.1824, -0.1185,\n",
- " -0.1845, 0.1776], device='cuda:0', requires_grad=True), Parameter containing:\n",
- "tensor([[-0.1260, -0.2593, -0.2229, 0.0065, 0.0831, -0.2702, -0.0523, 0.2718,\n",
- " 0.2547, 0.0824],\n",
- " [ 0.2193, -0.2204, 0.2573, 0.2804, 0.2895, -0.1308, 0.0558, -0.2521,\n",
- " -0.2073, -0.2508],\n",
- " [ 0.2057, -0.2991, 0.0932, -0.1904, 0.2148, -0.1540, 0.0756, -0.1294,\n",
- " -0.2314, -0.3153],\n",
- " [ 0.0636, -0.1243, 0.0715, 0.0041, -0.0620, -0.0212, 0.0812, 0.1220,\n",
- " -0.1704, 0.1767],\n",
- " [ 0.0371, -0.1627, 0.1664, 0.0266, 0.2154, -0.2726, -0.2169, -0.2568,\n",
- " 0.0378, -0.1306],\n",
- " [-0.0030, -0.1174, -0.0169, -0.2378, 0.2668, -0.1097, -0.0550, -0.3043,\n",
- " 0.1614, 0.2220]], device='cuda:0', requires_grad=True), Parameter containing:\n",
- "tensor([ 0.2263, -0.1245, -0.0600, -0.1303, 0.2717, -0.2886], device='cuda:0',\n",
- " requires_grad=True), Parameter containing:\n",
- "tensor([[-0.0118, -0.0705, -0.3342, -0.0979, -0.0048, -0.1048]],\n",
- " device='cuda:0', requires_grad=True), Parameter containing:\n",
- "tensor([0.3463], device='cuda:0', requires_grad=True)]\n",
- "==================================================\n",
- "[Parameter containing:\n",
- "tensor([[ 0.1109, -0.5283],\n",
- " [-0.3936, -0.6835],\n",
- " [ 1.6326, -1.0302],\n",
- " [ 1.3571, 1.2765],\n",
- " [-2.0980, 1.2285],\n",
- " [-0.2051, 0.3710],\n",
- " [-0.8627, 0.1593],\n",
- " [ 0.2369, 0.0855],\n",
- " [ 1.7011, -1.2683],\n",
- " [ 0.0239, -0.0731],\n",
- " [ 0.3759, 0.6586],\n",
- " [ 0.1748, 1.2761],\n",
- " [ 0.7805, -0.3940],\n",
- " [-1.7492, -0.3534],\n",
- " [-0.1294, 2.1254]], device='cuda:0', requires_grad=True), Parameter containing:\n",
- "tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n",
- " device='cuda:0', requires_grad=True), Parameter containing:\n",
- "tensor([[-1.0957, 1.0378, 0.3581, 0.6767, -0.0237, 0.6258, -0.6001, 0.4281,\n",
- " 0.0584, -0.7526, -0.7916, 0.6348, 0.2372, -0.0493, -1.0156],\n",
- " [ 0.5839, 0.1032, 1.3735, -0.1767, -1.5280, 0.9311, 1.2752, 0.0721,\n",
- " -0.0195, -1.0396, -1.2622, -0.2573, 1.1950, 1.2330, -1.1773],\n",
- " [-1.7392, -0.7176, 0.2932, -1.0405, 1.1690, -0.2543, -0.3711, -0.5608,\n",
- " 0.2679, -0.1302, 1.1744, 0.6578, 0.2809, -0.4317, 0.3950],\n",
- " [ 0.4147, -0.5277, -1.3396, 0.6607, 0.3036, 1.3448, -0.6774, 0.2073,\n",
- " 1.4682, 1.1449, 0.9241, 1.1630, 0.8778, -0.8538, 1.1645],\n",
- " [-0.2414, -1.1791, -0.9192, -1.3586, 0.7888, -1.2873, -0.7585, 0.1478,\n",
- " -0.2461, -1.7109, -1.3293, -0.1607, -0.0435, 0.7059, 0.8339],\n",
- " [ 0.5776, -1.4641, -0.8456, 1.4503, -0.0230, 0.9254, -1.1829, -0.7663,\n",
- " -0.5102, -1.7416, -0.2776, 1.8553, 0.0287, 1.3071, -0.0577],\n",
- " [-2.0167, 0.5340, 0.2045, -1.0659, 0.1410, -0.2611, 1.8272, -0.3517,\n",
- " -1.3150, 0.9499, 1.4403, 2.6355, 1.0081, -0.1915, -1.7888],\n",
- " [ 0.5569, -0.7237, -0.2397, 1.7227, 1.0570, 0.1441, -0.0228, 1.0863,\n",
- " 0.5095, -0.1426, -1.6044, -0.4300, 1.5130, 0.7247, -0.1672],\n",
- " [-0.2202, 0.8956, -0.5899, 0.6426, 1.0715, 2.8716, 0.6675, -0.3307,\n",
- " 0.4560, -1.6545, 0.6043, 1.4544, 0.3106, 2.0657, 0.4080],\n",
- " [-1.5387, 0.4301, 1.4037, 1.2928, 0.9105, 0.2546, -1.9276, -0.9019,\n",
- " -0.0115, 0.4342, -0.1543, 1.0857, 0.7773, 0.0530, 1.0929]],\n",
- " device='cuda:0', requires_grad=True), Parameter containing:\n",
- "tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], device='cuda:0',\n",
- " requires_grad=True), Parameter containing:\n",
- "tensor([[ 3.2868, 2.1744, -0.4586, 0.3916, 1.0370, 0.2906, -0.6354, -1.7215,\n",
- " -1.3256, 1.4936],\n",
- " [ 0.5906, -1.0098, -1.7117, 0.8507, 0.3034, 1.7005, -2.9092, -0.1906,\n",
- " -1.4068, 2.3292],\n",
- " [ 0.2435, -0.9280, -0.4463, -0.4837, -2.3409, -0.5112, 0.6242, 0.9315,\n",
- " 1.6558, -1.3278],\n",
- " [ 0.4638, 0.5903, -0.7866, 1.3214, 0.6855, 0.7143, 1.2887, -0.2633,\n",
- " -0.0572, 1.0767],\n",
- " [ 1.8814, -1.2857, 0.4642, -0.2101, 2.0882, 1.3320, -0.3117, -0.7563,\n",
- " 0.7099, 0.2467],\n",
- " [ 0.7781, -1.1896, -0.2684, 1.3948, -0.8702, 0.3317, 0.6835, 0.2728,\n",
- " 0.3942, 1.5071]], device='cuda:0', requires_grad=True), Parameter containing:\n",
- "tensor([0., 0., 0., 0., 0., 0.], device='cuda:0', requires_grad=True), Parameter containing:\n",
- "tensor([[ 1.5669, 0.3820, 1.0288, 1.5494, -0.9633, -0.8551]],\n",
- " device='cuda:0', requires_grad=True), Parameter containing:\n",
- "tensor([0.], device='cuda:0', requires_grad=True)]\n"
- ]
- }
- ],
- "source": [
- "def initialize(self):\n",
- " '''\n",
- " 初始化网络参数\n",
- " '''\n",
- " for m in self.modules():\n",
- " if isinstance(m, nn.Linear):\n",
- " torch.nn.init.normal_(m.weight.data, 0.1)\n",
- " if m.bias is not None:\n",
- " torch.nn.init.zeros_(m.bias.data)\n",
- "\n",
- "\n",
- "class Net(nn.Module):\n",
- " '''\n",
- " 网络结构\n",
- " '''\n",
- "\n",
- " def __init__(self, **kwargs):\n",
- " super(Net, self).__init__()\n",
- " self.fc = nn.Sequential(\n",
- " nn.Linear(2, 15),\n",
- " nn.Sigmoid(),\n",
- " nn.Linear(15, 10),\n",
- " nn.Sigmoid(),\n",
- " nn.Linear(10, 6),\n",
- " nn.Sigmoid(),\n",
- " nn.Linear(6, 1)\n",
- " )\n",
- "\n",
- " def forward(self, x):\n",
- " x = x.view(-1, 2)\n",
- " x = self.fc(x)\n",
- " return x.squeeze(-1)\n",
- "\n",
- "\n",
- "# 初始化网络\n",
- "model = Net()\n",
- "model = model.cuda()\n",
- "print(list(model.parameters()))\n",
- "initialize(model)\n",
- "print(\"==================================================\")\n",
- "print(list(model.parameters()))"
- ]
- },
- {
- "cell_type": "markdown",
- "id": "20fcccfe-3f82-44c3-bba2-9c66fb15cecd",
- "metadata": {},
- "source": [
- "定义损失函数为均方误差损失函数,优化算法为随机梯度下降"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 8,
- "id": "7b530a85-50b1-4d0f-b47e-14831d0ee85d",
- "metadata": {},
- "outputs": [],
- "source": [
- "# 损失函数\n",
- "criterion = nn.MSELoss()\n",
- "# 优化器\n",
- "optimizer = optim.SGD(model.parameters(), lr=lr)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 9,
- "id": "8694737f-afb5-4e58-ad14-3220971b142f",
- "metadata": {},
- "outputs": [],
- "source": [
- "def train(epoch):\n",
- " '''\n",
- " 训练器\n",
- "\n",
- " Parameters\n",
- " ----------\n",
- " epoch : int\n",
- "\n",
- " Returns\n",
- " -------\n",
- " None.\n",
- "\n",
- " '''\n",
- " model.train()\n",
- " train_loss = 0\n",
- " for info, target in train_loader:\n",
- " info = info.cuda()\n",
- " target = target.cuda()\n",
- " optimizer.zero_grad()\n",
- " output = model(info)\n",
- " loss = criterion(output, target)\n",
- " loss.backward()\n",
- " optimizer.step()\n",
- " train_loss += loss.item()*info.size(0)\n",
- " train_loss = train_loss/len(train_loader.dataset)\n",
- " print(\"epoch:%d, train loss:%.4f\" % (epoch, train_loss))"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 10,
- "id": "bb6be6f8-16e6-4c2b-a56c-232b09faa654",
- "metadata": {},
- "outputs": [],
- "source": [
- "def validate(epoch):\n",
- " '''\n",
- " 验证\n",
- "\n",
- " Parameters\n",
- " ----------\n",
- " epoch : int\n",
- "\n",
- " Returns\n",
- " -------\n",
- " None.\n",
- "\n",
- " '''\n",
- " model.eval()\n",
- " val_loss = 0\n",
- " with torch.no_grad():\n",
- " for info, target in test_loader:\n",
- " info, target = info.cuda(), target.cuda()\n",
- " output = model(info)\n",
- " loss = criterion(output, target)\n",
- " val_loss += loss.item()*info.size(0)\n",
- " val_loss = val_loss/len(test_loader.dataset)\n",
- " print(\"epoch:%d, validation loss:%.4f\" %\n",
- " (epoch, val_loss))"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 11,
- "id": "40532ea8-acfe-469a-aa9f-14d3bec2619d",
- "metadata": {
- "scrolled": true
- },
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "epoch:1, train loss:0.5139\n",
- "epoch:1, validation loss:0.3730\n",
- "epoch:2, train loss:0.3496\n",
- "epoch:2, validation loss:0.3698\n",
- "epoch:3, train loss:0.3471\n",
- "epoch:3, validation loss:0.3677\n",
- "epoch:4, train loss:0.3450\n",
- "epoch:4, validation loss:0.3662\n",
- "epoch:5, train loss:0.3439\n",
- "epoch:5, validation loss:0.3652\n",
- "epoch:6, train loss:0.3429\n",
- "epoch:6, validation loss:0.3644\n",
- "epoch:7, train loss:0.3425\n",
- "epoch:7, validation loss:0.3635\n",
- "epoch:8, train loss:0.3418\n",
- "epoch:8, validation loss:0.3631\n",
- "epoch:9, train loss:0.3417\n",
- "epoch:9, validation loss:0.3625\n",
- "epoch:10, train loss:0.3411\n",
- "epoch:10, validation loss:0.3621\n",
- "epoch:11, train loss:0.3405\n",
- "epoch:11, validation loss:0.3620\n",
- "epoch:12, train loss:0.3406\n",
- "epoch:12, validation loss:0.3617\n",
- "epoch:13, train loss:0.3393\n",
- "epoch:13, validation loss:0.3615\n",
- "epoch:14, train loss:0.3401\n",
- "epoch:14, validation loss:0.3616\n",
- "epoch:15, train loss:0.3396\n",
- "epoch:15, validation loss:0.3610\n",
- "epoch:16, train loss:0.3400\n",
- "epoch:16, validation loss:0.3608\n",
- "epoch:17, train loss:0.3390\n",
- "epoch:17, validation loss:0.3607\n",
- "epoch:18, train loss:0.3392\n",
- "epoch:18, validation loss:0.3608\n",
- "epoch:19, train loss:0.3396\n",
- "epoch:19, validation loss:0.3606\n",
- "epoch:20, train loss:0.3396\n",
- "epoch:20, validation loss:0.3604\n",
- "epoch:21, train loss:0.3396\n",
- "epoch:21, validation loss:0.3604\n",
- "epoch:22, train loss:0.3389\n",
- "epoch:22, validation loss:0.3603\n",
- "epoch:23, train loss:0.3389\n",
- "epoch:23, validation loss:0.3601\n",
- "epoch:24, train loss:0.3387\n",
- "epoch:24, validation loss:0.3602\n",
- "epoch:25, train loss:0.3390\n",
- "epoch:25, validation loss:0.3600\n",
- "epoch:26, train loss:0.3393\n",
- "epoch:26, validation loss:0.3600\n",
- "epoch:27, train loss:0.3385\n",
- "epoch:27, validation loss:0.3599\n",
- "epoch:28, train loss:0.3388\n",
- "epoch:28, validation loss:0.3598\n",
- "epoch:29, train loss:0.3388\n",
- "epoch:29, validation loss:0.3598\n",
- "epoch:30, train loss:0.3388\n",
- "epoch:30, validation loss:0.3597\n",
- "epoch:31, train loss:0.3387\n",
- "epoch:31, validation loss:0.3596\n",
- "epoch:32, train loss:0.3390\n",
- "epoch:32, validation loss:0.3596\n",
- "epoch:33, train loss:0.3384\n",
- "epoch:33, validation loss:0.3596\n",
- "epoch:34, train loss:0.3389\n",
- "epoch:34, validation loss:0.3595\n",
- "epoch:35, train loss:0.3385\n",
- "epoch:35, validation loss:0.3594\n",
- "epoch:36, train loss:0.3381\n",
- "epoch:36, validation loss:0.3594\n",
- "epoch:37, train loss:0.3380\n",
- "epoch:37, validation loss:0.3593\n",
- "epoch:38, train loss:0.3387\n",
- "epoch:38, validation loss:0.3593\n",
- "epoch:39, train loss:0.3388\n",
- "epoch:39, validation loss:0.3592\n",
- "epoch:40, train loss:0.3385\n",
- "epoch:40, validation loss:0.3592\n",
- "epoch:41, train loss:0.3386\n",
- "epoch:41, validation loss:0.3593\n",
- "epoch:42, train loss:0.3384\n",
- "epoch:42, validation loss:0.3595\n",
- "epoch:43, train loss:0.3373\n",
- "epoch:43, validation loss:0.3590\n",
- "epoch:44, train loss:0.3376\n",
- "epoch:44, validation loss:0.3591\n",
- "epoch:45, train loss:0.3385\n",
- "epoch:45, validation loss:0.3590\n",
- "epoch:46, train loss:0.3381\n",
- "epoch:46, validation loss:0.3589\n",
- "epoch:47, train loss:0.3382\n",
- "epoch:47, validation loss:0.3591\n",
- "epoch:48, train loss:0.3381\n",
- "epoch:48, validation loss:0.3589\n",
- "epoch:49, train loss:0.3381\n",
- "epoch:49, validation loss:0.3588\n",
- "epoch:50, train loss:0.3379\n",
- "epoch:50, validation loss:0.3587\n",
- "epoch:51, train loss:0.3383\n",
- "epoch:51, validation loss:0.3587\n",
- "epoch:52, train loss:0.3379\n",
- "epoch:52, validation loss:0.3587\n",
- "epoch:53, train loss:0.3378\n",
- "epoch:53, validation loss:0.3587\n",
- "epoch:54, train loss:0.3379\n",
- "epoch:54, validation loss:0.3585\n",
- "epoch:55, train loss:0.3380\n",
- "epoch:55, validation loss:0.3586\n",
- "epoch:56, train loss:0.3378\n",
- "epoch:56, validation loss:0.3585\n",
- "epoch:57, train loss:0.3378\n",
- "epoch:57, validation loss:0.3586\n",
- "epoch:58, train loss:0.3379\n",
- "epoch:58, validation loss:0.3587\n",
- "epoch:59, train loss:0.3374\n",
- "epoch:59, validation loss:0.3583\n",
- "epoch:60, train loss:0.3380\n",
- "epoch:60, validation loss:0.3583\n",
- "epoch:61, train loss:0.3375\n",
- "epoch:61, validation loss:0.3587\n",
- "epoch:62, train loss:0.3375\n",
- "epoch:62, validation loss:0.3583\n",
- "epoch:63, train loss:0.3377\n",
- "epoch:63, validation loss:0.3582\n",
- "epoch:64, train loss:0.3373\n",
- "epoch:64, validation loss:0.3584\n",
- "epoch:65, train loss:0.3374\n",
- "epoch:65, validation loss:0.3581\n",
- "epoch:66, train loss:0.3373\n",
- "epoch:66, validation loss:0.3581\n",
- "epoch:67, train loss:0.3378\n",
- "epoch:67, validation loss:0.3581\n",
- "epoch:68, train loss:0.3370\n",
- "epoch:68, validation loss:0.3581\n",
- "epoch:69, train loss:0.3370\n",
- "epoch:69, validation loss:0.3582\n",
- "epoch:70, train loss:0.3374\n",
- "epoch:70, validation loss:0.3580\n",
- "epoch:71, train loss:0.3374\n",
- "epoch:71, validation loss:0.3580\n",
- "epoch:72, train loss:0.3375\n",
- "epoch:72, validation loss:0.3579\n",
- "epoch:73, train loss:0.3373\n",
- "epoch:73, validation loss:0.3580\n",
- "epoch:74, train loss:0.3371\n",
- "epoch:74, validation loss:0.3583\n",
- "epoch:75, train loss:0.3376\n",
- "epoch:75, validation loss:0.3578\n",
- "epoch:76, train loss:0.3369\n",
- "epoch:76, validation loss:0.3577\n",
- "epoch:77, train loss:0.3371\n",
- "epoch:77, validation loss:0.3577\n",
- "epoch:78, train loss:0.3373\n",
- "epoch:78, validation loss:0.3577\n",
- "epoch:79, train loss:0.3373\n",
- "epoch:79, validation loss:0.3578\n",
- "epoch:80, train loss:0.3374\n",
- "epoch:80, validation loss:0.3577\n",
- "epoch:81, train loss:0.3368\n",
- "epoch:81, validation loss:0.3576\n",
- "epoch:82, train loss:0.3365\n",
- "epoch:82, validation loss:0.3577\n",
- "epoch:83, train loss:0.3371\n",
- "epoch:83, validation loss:0.3576\n",
- "epoch:84, train loss:0.3371\n",
- "epoch:84, validation loss:0.3575\n",
- "epoch:85, train loss:0.3371\n",
- "epoch:85, validation loss:0.3574\n",
- "epoch:86, train loss:0.3370\n",
- "epoch:86, validation loss:0.3574\n",
- "epoch:87, train loss:0.3371\n",
- "epoch:87, validation loss:0.3574\n",
- "epoch:88, train loss:0.3369\n",
- "epoch:88, validation loss:0.3574\n",
- "epoch:89, train loss:0.3364\n",
- "epoch:89, validation loss:0.3573\n",
- "epoch:90, train loss:0.3367\n",
- "epoch:90, validation loss:0.3573\n",
- "epoch:91, train loss:0.3367\n",
- "epoch:91, validation loss:0.3575\n",
- "epoch:92, train loss:0.3367\n",
- "epoch:92, validation loss:0.3573\n",
- "epoch:93, train loss:0.3372\n",
- "epoch:93, validation loss:0.3572\n",
- "epoch:94, train loss:0.3368\n",
- "epoch:94, validation loss:0.3572\n",
- "epoch:95, train loss:0.3369\n",
- "epoch:95, validation loss:0.3571\n",
- "epoch:96, train loss:0.3366\n",
- "epoch:96, validation loss:0.3571\n",
- "epoch:97, train loss:0.3368\n",
- "epoch:97, validation loss:0.3571\n",
- "epoch:98, train loss:0.3367\n",
- "epoch:98, validation loss:0.3571\n",
- "epoch:99, train loss:0.3367\n",
- "epoch:99, validation loss:0.3573\n",
- "epoch:100, train loss:0.3365\n",
- "epoch:100, validation loss:0.3571\n"
- ]
- }
- ],
- "source": [
- "# 训练过程\n",
- "for epoch in range(1, max_epochs+1):\n",
- " train(epoch)\n",
- " validate(epoch)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "e3f438d3-813f-4db5-b8e9-da71d54730d4",
- "metadata": {},
- "outputs": [],
- "source": []
- }
- ],
- "metadata": {
- "kernelspec": {
- "display_name": "Python 3 (ipykernel)",
- "language": "python",
- "name": "python3"
- },
- "language_info": {
- "codemirror_mode": {
- "name": "ipython",
- "version": 3
- },
- "file_extension": ".py",
- "mimetype": "text/x-python",
- "name": "python",
- "nbconvert_exporter": "python",
- "pygments_lexer": "ipython3",
- "version": "3.10.12"
- }
- },
- "nbformat": 4,
- "nbformat_minor": 5
-}