diff --git a/共享民宿平台担保交易房子评分的影响研究/torch_reg.ipynb b/共享民宿平台担保交易房子评分的影响研究/torch_reg.ipynb deleted file mode 100644 index f3d96ab..0000000 --- a/共享民宿平台担保交易房子评分的影响研究/torch_reg.ipynb +++ /dev/null @@ -1,863 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "id": "a31d3c40-3593-4ca3-980e-578ee51e171a", - "metadata": {}, - "source": [ - "# 用神经网络进行回归预测" - ] - }, - { - "cell_type": "code", - "execution_count": 1, - "id": "738c03b4-3ca0-4a87-9143-53ddea5179be", - "metadata": {}, - "outputs": [], - "source": [ - "import numpy as np\n", - "import pandas as pd\n", - "import torch\n", - "import torch.nn as nn\n", - "import torch.optim as optim\n", - "from torch.utils.data import Dataset, DataLoader\n", - "from sklearn.preprocessing import StandardScaler\n", - "from sklearn.model_selection import train_test_split" - ] - }, - { - "cell_type": "markdown", - "id": "637063c9-5924-416c-8201-adaae8514ffd", - "metadata": {}, - "source": [ - "设置神经网络超参数,批大小为500,学习率为0.01,一共分别进行100次正向传播和反向传播" - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "id": "b909615c-783d-4bfb-acf6-1bc3bacfbb18", - "metadata": {}, - "outputs": [], - "source": [ - "# torch参数\n", - "batch_size = 500\n", - "lr = 0.01\n", - "max_epochs = 100\n", - "num_workers = 0\n", - "device = torch.device(\"cuda:1\" if torch.cuda.is_available() else \"cpu\")" - ] - }, - { - "cell_type": "markdown", - "id": "904cfade-5667-440a-a13c-34fdf57f7b0f", - "metadata": {}, - "source": [ - "定义所需数据集类" - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "id": "f6e65156-8535-49c2-8d9d-ecb8639855c9", - "metadata": {}, - "outputs": [], - "source": [ - "variables = ['number_of_reviews', 'price', 'review_scores_rating']\n", - "# 数据集类\n", - "\n", - "\n", - "class USDataset(Dataset):\n", - " def __init__(self, df):\n", - " '''\n", - " 初始化\n", - " df: 处理后的数据集\n", - " '''\n", - " self.df = df\n", - " self.info = df[['number_of_reviews', 'price']].values\n", - " self.target = df['review_scores_rating'].values\n", - "\n", - " def __getitem__(self, index):\n", - " '''\n", - " 根据编号返回信息\n", - " index: 样本编号\n", - " '''\n", - " info = self.info[index]\n", - " target = self.target[index]\n", - " return info, target\n", - "\n", - " def __len__(self):\n", - " '''\n", - " 返回数据集样本个数\n", - " '''\n", - " return len(self.df)" - ] - }, - { - "cell_type": "markdown", - "id": "fc574f2e-6443-48af-9fee-dcf5fb29af58", - "metadata": {}, - "source": [ - "读入数据,由于样本量很大,直接删除有缺失值的样本。" - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "id": "17f8aedb-abf1-4622-98ec-d525a779274b", - "metadata": {}, - "outputs": [ - { - "data": { - "text/html": [ - "
\n", - "\n", - "\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
pricenumber_of_reviewsreview_scores_rating
0120.090.04.50
190.0351.04.58
266.067.04.52
333.0297.04.70
4125.058.04.96
............
203252152.01.04.00
20325345.01.03.00
20325440.01.01.00
20327643.01.05.00
203308110.01.05.00
\n", - "

162429 rows × 3 columns

\n", - "
" - ], - "text/plain": [ - " price number_of_reviews review_scores_rating\n", - "0 120.0 90.0 4.50\n", - "1 90.0 351.0 4.58\n", - "2 66.0 67.0 4.52\n", - "3 33.0 297.0 4.70\n", - "4 125.0 58.0 4.96\n", - "... ... ... ...\n", - "203252 152.0 1.0 4.00\n", - "203253 45.0 1.0 3.00\n", - "203254 40.0 1.0 1.00\n", - "203276 43.0 1.0 5.00\n", - "203308 110.0 1.0 5.00\n", - "\n", - "[162429 rows x 3 columns]" - ] - }, - "execution_count": 4, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "# 读入数据\n", - "df = pd.read_csv('./data/2022-01(US_25).csv', usecols=variables)\n", - "df['price'] = df['price'].replace('\\$', '', regex=True)\n", - "df['price'] = df['price'].replace('\\,', '', regex=True).astype(float)\n", - "df[['number_of_reviews']] = df[['number_of_reviews']].astype(float)\n", - "for col in variables:\n", - " df[col] = df[col].astype(np.float32)\n", - " df = df[np.isnan(df[col]) != 1]\n", - "df" - ] - }, - { - "cell_type": "markdown", - "id": "758d8171-9310-4a51-b9a6-b0fa6274ffe1", - "metadata": {}, - "source": [ - "将数据标准化" - ] - }, - { - "cell_type": "code", - "execution_count": 5, - "id": "2cbe7cab-42cd-4b3b-bf11-080cc0a0f4c4", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "df_train:\n", - " price number_of_reviews review_scores_rating\n", - "0 -0.331270 -0.218031 4.71\n", - "1 -0.030076 0.450276 4.99\n", - "2 0.425112 -0.603593 5.00\n", - "3 -0.324476 -0.590741 2.50\n", - "4 0.017481 -0.577889 4.00\n", - "... ... ... ...\n", - "113695 -0.381091 0.128974 4.98\n", - "113696 -0.220303 -0.282292 4.96\n", - "113697 -0.165953 -0.603593 3.00\n", - "113698 -0.152365 -0.539333 4.00\n", - "113699 0.162417 -0.590741 5.00\n", - "\n", - "[113700 rows x 3 columns]\n", - "============================================\n", - "df_test:\n", - " price number_of_reviews review_scores_rating\n", - "0 -0.249743 0.411720 4.84\n", - "1 -0.211245 -0.565037 5.00\n", - "2 -0.376562 -0.603593 5.00\n", - "3 -0.088956 -0.539333 4.33\n", - "4 0.517962 0.013306 4.90\n", - "... ... ... ...\n", - "48724 0.028804 -0.192327 4.91\n", - "48725 -0.367503 -0.500777 5.00\n", - "48726 -0.360710 -0.436517 4.86\n", - "48727 -0.084426 -0.603593 1.00\n", - "48728 0.067303 -0.603593 5.00\n", - "\n", - "[48729 rows x 3 columns]\n" - ] - } - ], - "source": [ - "# 固定划分测试集和训练集\n", - "col_df = df.columns\n", - "info = df.iloc[:, :-1].values\n", - "target = df.iloc[:, -1].values\n", - "# 标准化\n", - "stdscaler = StandardScaler()\n", - "info_train, info_test, target_train, target_test = train_test_split(\n", - " info, target, test_size=0.3, random_state=420)\n", - "info_train = stdscaler.fit_transform(info_train)\n", - "info_test = stdscaler.transform(info_test)\n", - "df_train = np.hstack((info_train, target_train.reshape(-1, 1)))\n", - "df_test = np.hstack((info_test, target_test.reshape(-1, 1)))\n", - "df_train = pd.DataFrame(df_train, columns=col_df)\n", - "df_test = pd.DataFrame(df_test, columns=col_df)\n", - "print(\"df_train:\\n\", df_train)\n", - "print(\"============================================\")\n", - "print(\"df_test:\\n\", df_test)" - ] - }, - { - "cell_type": "markdown", - "id": "1a9661ee-3dc4-4d5c-8a8d-bab3dc299d41", - "metadata": {}, - "source": [ - "设置训练集和验证集" - ] - }, - { - "cell_type": "code", - "execution_count": 6, - "id": "a10021bb-4e08-49b2-8269-114b8ce37be6", - "metadata": {}, - "outputs": [], - "source": [ - "# 初始化数据集\n", - "train_data = USDataset(df_train)\n", - "test_data = USDataset(df_test)\n", - "train_loader = DataLoader(train_data, batch_size=batch_size,\n", - " num_workers=num_workers, shuffle=True, drop_last=True)\n", - "test_loader = DataLoader(test_data, batch_size=batch_size,\n", - " num_workers=num_workers, shuffle=False)" - ] - }, - { - "cell_type": "markdown", - "id": "41a15589-9aaf-4541-a603-30a9f433c385", - "metadata": {}, - "source": [ - "设计并初始化网络,其中神经网络是有一个输入层,一个输出层,三个隐藏层的前馈神经网络,激活函数是Sigmoid函数" - ] - }, - { - "cell_type": "code", - "execution_count": 7, - "id": "764f28bc-81dd-4924-9039-5eceba54d739", - "metadata": { - "scrolled": true - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "[Parameter containing:\n", - "tensor([[-0.3798, 0.4787],\n", - " [-0.6687, 0.3603],\n", - " [-0.4871, 0.1701],\n", - " [ 0.1127, -0.2635],\n", - " [ 0.6846, 0.6578],\n", - " [-0.2820, 0.6887],\n", - " [ 0.6581, -0.6334],\n", - " [ 0.1057, -0.2249],\n", - " [ 0.0650, 0.6174],\n", - " [ 0.1768, -0.5373],\n", - " [-0.5809, 0.6241],\n", - " [ 0.1711, 0.3307],\n", - " [ 0.0546, 0.6606],\n", - " [ 0.0528, 0.1988],\n", - " [ 0.6816, -0.2682]], device='cuda:0', requires_grad=True), Parameter containing:\n", - "tensor([-0.0835, -0.5775, -0.2021, 0.6027, -0.6389, 0.3593, 0.1442, 0.0746,\n", - " -0.6356, 0.3029, -0.5204, -0.4724, 0.2451, -0.0103, -0.6133],\n", - " device='cuda:0', requires_grad=True), Parameter containing:\n", - "tensor([[-0.2190, 0.1091, 0.1611, 0.0540, 0.2386, -0.0399, 0.0086, -0.2273,\n", - " 0.1421, 0.0286, 0.0507, -0.1721, 0.1504, -0.1714, 0.0113],\n", - " [ 0.1312, 0.0581, 0.1805, 0.1385, -0.1394, -0.2539, -0.0047, 0.0753,\n", - " 0.2134, 0.1119, 0.0534, -0.1073, 0.2128, 0.1909, 0.2403],\n", - " [ 0.0414, 0.1597, -0.2470, -0.0118, 0.1226, -0.1825, 0.2529, 0.2161,\n", - " 0.0246, 0.2333, -0.0733, 0.1474, 0.2101, 0.0246, 0.0514],\n", - " [ 0.1203, 0.0301, 0.1293, -0.0800, -0.1533, -0.1585, -0.0815, -0.1242,\n", - " 0.1415, -0.1520, 0.1287, -0.0701, -0.2560, -0.0094, -0.2361],\n", - " [-0.1045, 0.0645, 0.2181, -0.0610, -0.0585, -0.1396, 0.2380, 0.2466,\n", - " -0.2495, -0.0690, 0.1078, -0.1838, 0.0231, -0.0264, -0.1335],\n", - " [ 0.0328, 0.2338, 0.0761, 0.2368, -0.0891, -0.1581, 0.1605, -0.0699,\n", - " 0.1658, -0.0683, -0.2559, -0.0122, 0.1695, 0.2399, -0.1445],\n", - " [-0.2489, 0.1559, -0.0083, 0.0980, -0.1653, 0.0678, -0.1833, -0.0217,\n", - " -0.0746, 0.0633, -0.0559, 0.0334, 0.1273, -0.1259, -0.1527],\n", - " [ 0.2253, -0.0182, 0.1613, -0.1602, -0.2205, -0.2047, 0.2046, -0.0824,\n", - " 0.1261, -0.1015, 0.2257, -0.1276, -0.0197, -0.0505, -0.2408],\n", - " [ 0.1883, -0.0590, -0.1909, -0.0710, -0.0304, -0.0085, -0.1848, 0.0480,\n", - " -0.0447, 0.1001, 0.2330, -0.1368, -0.1478, -0.1676, 0.1942],\n", - " [ 0.0554, -0.1686, -0.2217, -0.0343, 0.0218, 0.2300, 0.0303, 0.0711,\n", - " -0.0932, -0.1791, 0.0978, -0.1902, -0.2322, -0.1179, -0.1592]],\n", - " device='cuda:0', requires_grad=True), Parameter containing:\n", - "tensor([-0.0498, -0.1153, 0.0559, 0.1391, -0.0231, 0.1752, 0.1824, -0.1185,\n", - " -0.1845, 0.1776], device='cuda:0', requires_grad=True), Parameter containing:\n", - "tensor([[-0.1260, -0.2593, -0.2229, 0.0065, 0.0831, -0.2702, -0.0523, 0.2718,\n", - " 0.2547, 0.0824],\n", - " [ 0.2193, -0.2204, 0.2573, 0.2804, 0.2895, -0.1308, 0.0558, -0.2521,\n", - " -0.2073, -0.2508],\n", - " [ 0.2057, -0.2991, 0.0932, -0.1904, 0.2148, -0.1540, 0.0756, -0.1294,\n", - " -0.2314, -0.3153],\n", - " [ 0.0636, -0.1243, 0.0715, 0.0041, -0.0620, -0.0212, 0.0812, 0.1220,\n", - " -0.1704, 0.1767],\n", - " [ 0.0371, -0.1627, 0.1664, 0.0266, 0.2154, -0.2726, -0.2169, -0.2568,\n", - " 0.0378, -0.1306],\n", - " [-0.0030, -0.1174, -0.0169, -0.2378, 0.2668, -0.1097, -0.0550, -0.3043,\n", - " 0.1614, 0.2220]], device='cuda:0', requires_grad=True), Parameter containing:\n", - "tensor([ 0.2263, -0.1245, -0.0600, -0.1303, 0.2717, -0.2886], device='cuda:0',\n", - " requires_grad=True), Parameter containing:\n", - "tensor([[-0.0118, -0.0705, -0.3342, -0.0979, -0.0048, -0.1048]],\n", - " device='cuda:0', requires_grad=True), Parameter containing:\n", - "tensor([0.3463], device='cuda:0', requires_grad=True)]\n", - "==================================================\n", - "[Parameter containing:\n", - "tensor([[ 0.1109, -0.5283],\n", - " [-0.3936, -0.6835],\n", - " [ 1.6326, -1.0302],\n", - " [ 1.3571, 1.2765],\n", - " [-2.0980, 1.2285],\n", - " [-0.2051, 0.3710],\n", - " [-0.8627, 0.1593],\n", - " [ 0.2369, 0.0855],\n", - " [ 1.7011, -1.2683],\n", - " [ 0.0239, -0.0731],\n", - " [ 0.3759, 0.6586],\n", - " [ 0.1748, 1.2761],\n", - " [ 0.7805, -0.3940],\n", - " [-1.7492, -0.3534],\n", - " [-0.1294, 2.1254]], device='cuda:0', requires_grad=True), Parameter containing:\n", - "tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n", - " device='cuda:0', requires_grad=True), Parameter containing:\n", - "tensor([[-1.0957, 1.0378, 0.3581, 0.6767, -0.0237, 0.6258, -0.6001, 0.4281,\n", - " 0.0584, -0.7526, -0.7916, 0.6348, 0.2372, -0.0493, -1.0156],\n", - " [ 0.5839, 0.1032, 1.3735, -0.1767, -1.5280, 0.9311, 1.2752, 0.0721,\n", - " -0.0195, -1.0396, -1.2622, -0.2573, 1.1950, 1.2330, -1.1773],\n", - " [-1.7392, -0.7176, 0.2932, -1.0405, 1.1690, -0.2543, -0.3711, -0.5608,\n", - " 0.2679, -0.1302, 1.1744, 0.6578, 0.2809, -0.4317, 0.3950],\n", - " [ 0.4147, -0.5277, -1.3396, 0.6607, 0.3036, 1.3448, -0.6774, 0.2073,\n", - " 1.4682, 1.1449, 0.9241, 1.1630, 0.8778, -0.8538, 1.1645],\n", - " [-0.2414, -1.1791, -0.9192, -1.3586, 0.7888, -1.2873, -0.7585, 0.1478,\n", - " -0.2461, -1.7109, -1.3293, -0.1607, -0.0435, 0.7059, 0.8339],\n", - " [ 0.5776, -1.4641, -0.8456, 1.4503, -0.0230, 0.9254, -1.1829, -0.7663,\n", - " -0.5102, -1.7416, -0.2776, 1.8553, 0.0287, 1.3071, -0.0577],\n", - " [-2.0167, 0.5340, 0.2045, -1.0659, 0.1410, -0.2611, 1.8272, -0.3517,\n", - " -1.3150, 0.9499, 1.4403, 2.6355, 1.0081, -0.1915, -1.7888],\n", - " [ 0.5569, -0.7237, -0.2397, 1.7227, 1.0570, 0.1441, -0.0228, 1.0863,\n", - " 0.5095, -0.1426, -1.6044, -0.4300, 1.5130, 0.7247, -0.1672],\n", - " [-0.2202, 0.8956, -0.5899, 0.6426, 1.0715, 2.8716, 0.6675, -0.3307,\n", - " 0.4560, -1.6545, 0.6043, 1.4544, 0.3106, 2.0657, 0.4080],\n", - " [-1.5387, 0.4301, 1.4037, 1.2928, 0.9105, 0.2546, -1.9276, -0.9019,\n", - " -0.0115, 0.4342, -0.1543, 1.0857, 0.7773, 0.0530, 1.0929]],\n", - " device='cuda:0', requires_grad=True), Parameter containing:\n", - "tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], device='cuda:0',\n", - " requires_grad=True), Parameter containing:\n", - "tensor([[ 3.2868, 2.1744, -0.4586, 0.3916, 1.0370, 0.2906, -0.6354, -1.7215,\n", - " -1.3256, 1.4936],\n", - " [ 0.5906, -1.0098, -1.7117, 0.8507, 0.3034, 1.7005, -2.9092, -0.1906,\n", - " -1.4068, 2.3292],\n", - " [ 0.2435, -0.9280, -0.4463, -0.4837, -2.3409, -0.5112, 0.6242, 0.9315,\n", - " 1.6558, -1.3278],\n", - " [ 0.4638, 0.5903, -0.7866, 1.3214, 0.6855, 0.7143, 1.2887, -0.2633,\n", - " -0.0572, 1.0767],\n", - " [ 1.8814, -1.2857, 0.4642, -0.2101, 2.0882, 1.3320, -0.3117, -0.7563,\n", - " 0.7099, 0.2467],\n", - " [ 0.7781, -1.1896, -0.2684, 1.3948, -0.8702, 0.3317, 0.6835, 0.2728,\n", - " 0.3942, 1.5071]], device='cuda:0', requires_grad=True), Parameter containing:\n", - "tensor([0., 0., 0., 0., 0., 0.], device='cuda:0', requires_grad=True), Parameter containing:\n", - "tensor([[ 1.5669, 0.3820, 1.0288, 1.5494, -0.9633, -0.8551]],\n", - " device='cuda:0', requires_grad=True), Parameter containing:\n", - "tensor([0.], device='cuda:0', requires_grad=True)]\n" - ] - } - ], - "source": [ - "def initialize(self):\n", - " '''\n", - " 初始化网络参数\n", - " '''\n", - " for m in self.modules():\n", - " if isinstance(m, nn.Linear):\n", - " torch.nn.init.normal_(m.weight.data, 0.1)\n", - " if m.bias is not None:\n", - " torch.nn.init.zeros_(m.bias.data)\n", - "\n", - "\n", - "class Net(nn.Module):\n", - " '''\n", - " 网络结构\n", - " '''\n", - "\n", - " def __init__(self, **kwargs):\n", - " super(Net, self).__init__()\n", - " self.fc = nn.Sequential(\n", - " nn.Linear(2, 15),\n", - " nn.Sigmoid(),\n", - " nn.Linear(15, 10),\n", - " nn.Sigmoid(),\n", - " nn.Linear(10, 6),\n", - " nn.Sigmoid(),\n", - " nn.Linear(6, 1)\n", - " )\n", - "\n", - " def forward(self, x):\n", - " x = x.view(-1, 2)\n", - " x = self.fc(x)\n", - " return x.squeeze(-1)\n", - "\n", - "\n", - "# 初始化网络\n", - "model = Net()\n", - "model = model.cuda()\n", - "print(list(model.parameters()))\n", - "initialize(model)\n", - "print(\"==================================================\")\n", - "print(list(model.parameters()))" - ] - }, - { - "cell_type": "markdown", - "id": "20fcccfe-3f82-44c3-bba2-9c66fb15cecd", - "metadata": {}, - "source": [ - "定义损失函数为均方误差损失函数,优化算法为随机梯度下降" - ] - }, - { - "cell_type": "code", - "execution_count": 8, - "id": "7b530a85-50b1-4d0f-b47e-14831d0ee85d", - "metadata": {}, - "outputs": [], - "source": [ - "# 损失函数\n", - "criterion = nn.MSELoss()\n", - "# 优化器\n", - "optimizer = optim.SGD(model.parameters(), lr=lr)" - ] - }, - { - "cell_type": "code", - "execution_count": 9, - "id": "8694737f-afb5-4e58-ad14-3220971b142f", - "metadata": {}, - "outputs": [], - "source": [ - "def train(epoch):\n", - " '''\n", - " 训练器\n", - "\n", - " Parameters\n", - " ----------\n", - " epoch : int\n", - "\n", - " Returns\n", - " -------\n", - " None.\n", - "\n", - " '''\n", - " model.train()\n", - " train_loss = 0\n", - " for info, target in train_loader:\n", - " info = info.cuda()\n", - " target = target.cuda()\n", - " optimizer.zero_grad()\n", - " output = model(info)\n", - " loss = criterion(output, target)\n", - " loss.backward()\n", - " optimizer.step()\n", - " train_loss += loss.item()*info.size(0)\n", - " train_loss = train_loss/len(train_loader.dataset)\n", - " print(\"epoch:%d, train loss:%.4f\" % (epoch, train_loss))" - ] - }, - { - "cell_type": "code", - "execution_count": 10, - "id": "bb6be6f8-16e6-4c2b-a56c-232b09faa654", - "metadata": {}, - "outputs": [], - "source": [ - "def validate(epoch):\n", - " '''\n", - " 验证\n", - "\n", - " Parameters\n", - " ----------\n", - " epoch : int\n", - "\n", - " Returns\n", - " -------\n", - " None.\n", - "\n", - " '''\n", - " model.eval()\n", - " val_loss = 0\n", - " with torch.no_grad():\n", - " for info, target in test_loader:\n", - " info, target = info.cuda(), target.cuda()\n", - " output = model(info)\n", - " loss = criterion(output, target)\n", - " val_loss += loss.item()*info.size(0)\n", - " val_loss = val_loss/len(test_loader.dataset)\n", - " print(\"epoch:%d, validation loss:%.4f\" %\n", - " (epoch, val_loss))" - ] - }, - { - "cell_type": "code", - "execution_count": 11, - "id": "40532ea8-acfe-469a-aa9f-14d3bec2619d", - "metadata": { - "scrolled": true - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "epoch:1, train loss:0.5139\n", - "epoch:1, validation loss:0.3730\n", - "epoch:2, train loss:0.3496\n", - "epoch:2, validation loss:0.3698\n", - "epoch:3, train loss:0.3471\n", - "epoch:3, validation loss:0.3677\n", - "epoch:4, train loss:0.3450\n", - "epoch:4, validation loss:0.3662\n", - "epoch:5, train loss:0.3439\n", - "epoch:5, validation loss:0.3652\n", - "epoch:6, train loss:0.3429\n", - "epoch:6, validation loss:0.3644\n", - "epoch:7, train loss:0.3425\n", - "epoch:7, validation loss:0.3635\n", - "epoch:8, train loss:0.3418\n", - "epoch:8, validation loss:0.3631\n", - "epoch:9, train loss:0.3417\n", - "epoch:9, validation loss:0.3625\n", - "epoch:10, train loss:0.3411\n", - "epoch:10, validation loss:0.3621\n", - "epoch:11, train loss:0.3405\n", - "epoch:11, validation loss:0.3620\n", - "epoch:12, train loss:0.3406\n", - "epoch:12, validation loss:0.3617\n", - "epoch:13, train loss:0.3393\n", - "epoch:13, validation loss:0.3615\n", - "epoch:14, train loss:0.3401\n", - "epoch:14, validation loss:0.3616\n", - "epoch:15, train loss:0.3396\n", - "epoch:15, validation loss:0.3610\n", - "epoch:16, train loss:0.3400\n", - "epoch:16, validation loss:0.3608\n", - "epoch:17, train loss:0.3390\n", - "epoch:17, validation loss:0.3607\n", - "epoch:18, train loss:0.3392\n", - "epoch:18, validation loss:0.3608\n", - "epoch:19, train loss:0.3396\n", - "epoch:19, validation loss:0.3606\n", - "epoch:20, train loss:0.3396\n", - "epoch:20, validation loss:0.3604\n", - "epoch:21, train loss:0.3396\n", - "epoch:21, validation loss:0.3604\n", - "epoch:22, train loss:0.3389\n", - "epoch:22, validation loss:0.3603\n", - "epoch:23, train loss:0.3389\n", - "epoch:23, validation loss:0.3601\n", - "epoch:24, train loss:0.3387\n", - "epoch:24, validation loss:0.3602\n", - "epoch:25, train loss:0.3390\n", - "epoch:25, validation loss:0.3600\n", - "epoch:26, train loss:0.3393\n", - "epoch:26, validation loss:0.3600\n", - "epoch:27, train loss:0.3385\n", - "epoch:27, validation loss:0.3599\n", - "epoch:28, train loss:0.3388\n", - "epoch:28, validation loss:0.3598\n", - "epoch:29, train loss:0.3388\n", - "epoch:29, validation loss:0.3598\n", - "epoch:30, train loss:0.3388\n", - "epoch:30, validation loss:0.3597\n", - "epoch:31, train loss:0.3387\n", - "epoch:31, validation loss:0.3596\n", - "epoch:32, train loss:0.3390\n", - "epoch:32, validation loss:0.3596\n", - "epoch:33, train loss:0.3384\n", - "epoch:33, validation loss:0.3596\n", - "epoch:34, train loss:0.3389\n", - "epoch:34, validation loss:0.3595\n", - "epoch:35, train loss:0.3385\n", - "epoch:35, validation loss:0.3594\n", - "epoch:36, train loss:0.3381\n", - "epoch:36, validation loss:0.3594\n", - "epoch:37, train loss:0.3380\n", - "epoch:37, validation loss:0.3593\n", - "epoch:38, train loss:0.3387\n", - "epoch:38, validation loss:0.3593\n", - "epoch:39, train loss:0.3388\n", - "epoch:39, validation loss:0.3592\n", - "epoch:40, train loss:0.3385\n", - "epoch:40, validation loss:0.3592\n", - "epoch:41, train loss:0.3386\n", - "epoch:41, validation loss:0.3593\n", - "epoch:42, train loss:0.3384\n", - "epoch:42, validation loss:0.3595\n", - "epoch:43, train loss:0.3373\n", - "epoch:43, validation loss:0.3590\n", - "epoch:44, train loss:0.3376\n", - "epoch:44, validation loss:0.3591\n", - "epoch:45, train loss:0.3385\n", - "epoch:45, validation loss:0.3590\n", - "epoch:46, train loss:0.3381\n", - "epoch:46, validation loss:0.3589\n", - "epoch:47, train loss:0.3382\n", - "epoch:47, validation loss:0.3591\n", - "epoch:48, train loss:0.3381\n", - "epoch:48, validation loss:0.3589\n", - "epoch:49, train loss:0.3381\n", - "epoch:49, validation loss:0.3588\n", - "epoch:50, train loss:0.3379\n", - "epoch:50, validation loss:0.3587\n", - "epoch:51, train loss:0.3383\n", - "epoch:51, validation loss:0.3587\n", - "epoch:52, train loss:0.3379\n", - "epoch:52, validation loss:0.3587\n", - "epoch:53, train loss:0.3378\n", - "epoch:53, validation loss:0.3587\n", - "epoch:54, train loss:0.3379\n", - "epoch:54, validation loss:0.3585\n", - "epoch:55, train loss:0.3380\n", - "epoch:55, validation loss:0.3586\n", - "epoch:56, train loss:0.3378\n", - "epoch:56, validation loss:0.3585\n", - "epoch:57, train loss:0.3378\n", - "epoch:57, validation loss:0.3586\n", - "epoch:58, train loss:0.3379\n", - "epoch:58, validation loss:0.3587\n", - "epoch:59, train loss:0.3374\n", - "epoch:59, validation loss:0.3583\n", - "epoch:60, train loss:0.3380\n", - "epoch:60, validation loss:0.3583\n", - "epoch:61, train loss:0.3375\n", - "epoch:61, validation loss:0.3587\n", - "epoch:62, train loss:0.3375\n", - "epoch:62, validation loss:0.3583\n", - "epoch:63, train loss:0.3377\n", - "epoch:63, validation loss:0.3582\n", - "epoch:64, train loss:0.3373\n", - "epoch:64, validation loss:0.3584\n", - "epoch:65, train loss:0.3374\n", - "epoch:65, validation loss:0.3581\n", - "epoch:66, train loss:0.3373\n", - "epoch:66, validation loss:0.3581\n", - "epoch:67, train loss:0.3378\n", - "epoch:67, validation loss:0.3581\n", - "epoch:68, train loss:0.3370\n", - "epoch:68, validation loss:0.3581\n", - "epoch:69, train loss:0.3370\n", - "epoch:69, validation loss:0.3582\n", - "epoch:70, train loss:0.3374\n", - "epoch:70, validation loss:0.3580\n", - "epoch:71, train loss:0.3374\n", - "epoch:71, validation loss:0.3580\n", - "epoch:72, train loss:0.3375\n", - "epoch:72, validation loss:0.3579\n", - "epoch:73, train loss:0.3373\n", - "epoch:73, validation loss:0.3580\n", - "epoch:74, train loss:0.3371\n", - "epoch:74, validation loss:0.3583\n", - "epoch:75, train loss:0.3376\n", - "epoch:75, validation loss:0.3578\n", - "epoch:76, train loss:0.3369\n", - "epoch:76, validation loss:0.3577\n", - "epoch:77, train loss:0.3371\n", - "epoch:77, validation loss:0.3577\n", - "epoch:78, train loss:0.3373\n", - "epoch:78, validation loss:0.3577\n", - "epoch:79, train loss:0.3373\n", - "epoch:79, validation loss:0.3578\n", - "epoch:80, train loss:0.3374\n", - "epoch:80, validation loss:0.3577\n", - "epoch:81, train loss:0.3368\n", - "epoch:81, validation loss:0.3576\n", - "epoch:82, train loss:0.3365\n", - "epoch:82, validation loss:0.3577\n", - "epoch:83, train loss:0.3371\n", - "epoch:83, validation loss:0.3576\n", - "epoch:84, train loss:0.3371\n", - "epoch:84, validation loss:0.3575\n", - "epoch:85, train loss:0.3371\n", - "epoch:85, validation loss:0.3574\n", - "epoch:86, train loss:0.3370\n", - "epoch:86, validation loss:0.3574\n", - "epoch:87, train loss:0.3371\n", - "epoch:87, validation loss:0.3574\n", - "epoch:88, train loss:0.3369\n", - "epoch:88, validation loss:0.3574\n", - "epoch:89, train loss:0.3364\n", - "epoch:89, validation loss:0.3573\n", - "epoch:90, train loss:0.3367\n", - "epoch:90, validation loss:0.3573\n", - "epoch:91, train loss:0.3367\n", - "epoch:91, validation loss:0.3575\n", - "epoch:92, train loss:0.3367\n", - "epoch:92, validation loss:0.3573\n", - "epoch:93, train loss:0.3372\n", - "epoch:93, validation loss:0.3572\n", - "epoch:94, train loss:0.3368\n", - "epoch:94, validation loss:0.3572\n", - "epoch:95, train loss:0.3369\n", - "epoch:95, validation loss:0.3571\n", - "epoch:96, train loss:0.3366\n", - "epoch:96, validation loss:0.3571\n", - "epoch:97, train loss:0.3368\n", - "epoch:97, validation loss:0.3571\n", - "epoch:98, train loss:0.3367\n", - "epoch:98, validation loss:0.3571\n", - "epoch:99, train loss:0.3367\n", - "epoch:99, validation loss:0.3573\n", - "epoch:100, train loss:0.3365\n", - "epoch:100, validation loss:0.3571\n" - ] - } - ], - "source": [ - "# 训练过程\n", - "for epoch in range(1, max_epochs+1):\n", - " train(epoch)\n", - " validate(epoch)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "e3f438d3-813f-4db5-b8e9-da71d54730d4", - "metadata": {}, - "outputs": [], - "source": [] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3 (ipykernel)", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.10.12" - } - }, - "nbformat": 4, - "nbformat_minor": 5 -}