前馈神经网络回归
Signed-off-by: 张卓立 <13190677+zhang-zhuoli@user.noreply.gitee.com>
This commit is contained in:
@@ -0,0 +1,944 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "a31d3c40-3593-4ca3-980e-578ee51e171a",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# 用神经网络进行回归预测"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 1,
|
||||
"id": "738c03b4-3ca0-4a87-9143-53ddea5179be",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import numpy as np\n",
|
||||
"import pandas as pd\n",
|
||||
"import torch\n",
|
||||
"import torch.nn as nn\n",
|
||||
"import torch.optim as optim\n",
|
||||
"from torch.utils.data import Dataset, DataLoader\n",
|
||||
"from sklearn.preprocessing import StandardScaler\n",
|
||||
"from sklearn.model_selection import train_test_split"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "637063c9-5924-416c-8201-adaae8514ffd",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"设置神经网络超参数,批大小为500,学习率为0.01,一共分别进行100次正向传播和反向传播"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 2,
|
||||
"id": "b909615c-783d-4bfb-acf6-1bc3bacfbb18",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# torch参数\n",
|
||||
"batch_size = 500\n",
|
||||
"lr = 0.01\n",
|
||||
"max_epochs = 100\n",
|
||||
"num_workers = 0\n",
|
||||
"device = torch.device(\"cuda:1\" if torch.cuda.is_available() else \"cpu\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "904cfade-5667-440a-a13c-34fdf57f7b0f",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"定义所需数据集类"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"id": "f6e65156-8535-49c2-8d9d-ecb8639855c9",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"variables = ['number_of_reviews', 'price', 'accommodates',\n",
|
||||
" 'host_response_rate', 'host_acceptance_rate', 'review_scores_rating']\n",
|
||||
"# 数据集类\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"class USDataset(Dataset):\n",
|
||||
" def __init__(self, df):\n",
|
||||
" '''\n",
|
||||
" 初始化\n",
|
||||
" df: 处理后的数据集\n",
|
||||
" '''\n",
|
||||
" self.df = df\n",
|
||||
" self.info = df[['number_of_reviews', 'price', 'accommodates',\n",
|
||||
" 'host_response_rate', 'host_acceptance_rate']].values\n",
|
||||
" self.target = df['review_scores_rating'].values\n",
|
||||
"\n",
|
||||
" def __getitem__(self, index):\n",
|
||||
" '''\n",
|
||||
" 根据编号返回信息\n",
|
||||
" index: 样本编号\n",
|
||||
" '''\n",
|
||||
" info = self.info[index]\n",
|
||||
" target = self.target[index]\n",
|
||||
" return info, target\n",
|
||||
"\n",
|
||||
" def __len__(self):\n",
|
||||
" '''\n",
|
||||
" 返回数据集样本个数\n",
|
||||
" '''\n",
|
||||
" return len(self.df)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "fc574f2e-6443-48af-9fee-dcf5fb29af58",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"读入数据,由于样本量很大,直接删除有缺失值的样本。"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 4,
|
||||
"id": "17f8aedb-abf1-4622-98ec-d525a779274b",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/html": [
|
||||
"<div>\n",
|
||||
"<style scoped>\n",
|
||||
" .dataframe tbody tr th:only-of-type {\n",
|
||||
" vertical-align: middle;\n",
|
||||
" }\n",
|
||||
"\n",
|
||||
" .dataframe tbody tr th {\n",
|
||||
" vertical-align: top;\n",
|
||||
" }\n",
|
||||
"\n",
|
||||
" .dataframe thead th {\n",
|
||||
" text-align: right;\n",
|
||||
" }\n",
|
||||
"</style>\n",
|
||||
"<table border=\"1\" class=\"dataframe\">\n",
|
||||
" <thead>\n",
|
||||
" <tr style=\"text-align: right;\">\n",
|
||||
" <th></th>\n",
|
||||
" <th>host_response_rate</th>\n",
|
||||
" <th>host_acceptance_rate</th>\n",
|
||||
" <th>accommodates</th>\n",
|
||||
" <th>price</th>\n",
|
||||
" <th>number_of_reviews</th>\n",
|
||||
" <th>review_scores_rating</th>\n",
|
||||
" </tr>\n",
|
||||
" </thead>\n",
|
||||
" <tbody>\n",
|
||||
" <tr>\n",
|
||||
" <th>0</th>\n",
|
||||
" <td>1.00</td>\n",
|
||||
" <td>0.33</td>\n",
|
||||
" <td>2.0</td>\n",
|
||||
" <td>120.0</td>\n",
|
||||
" <td>90.0</td>\n",
|
||||
" <td>4.50</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>1</th>\n",
|
||||
" <td>1.00</td>\n",
|
||||
" <td>0.98</td>\n",
|
||||
" <td>2.0</td>\n",
|
||||
" <td>90.0</td>\n",
|
||||
" <td>351.0</td>\n",
|
||||
" <td>4.58</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>2</th>\n",
|
||||
" <td>1.00</td>\n",
|
||||
" <td>0.98</td>\n",
|
||||
" <td>2.0</td>\n",
|
||||
" <td>66.0</td>\n",
|
||||
" <td>67.0</td>\n",
|
||||
" <td>4.52</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>3</th>\n",
|
||||
" <td>1.00</td>\n",
|
||||
" <td>0.98</td>\n",
|
||||
" <td>1.0</td>\n",
|
||||
" <td>33.0</td>\n",
|
||||
" <td>297.0</td>\n",
|
||||
" <td>4.70</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>5</th>\n",
|
||||
" <td>1.00</td>\n",
|
||||
" <td>1.00</td>\n",
|
||||
" <td>2.0</td>\n",
|
||||
" <td>45.0</td>\n",
|
||||
" <td>42.0</td>\n",
|
||||
" <td>4.98</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>...</th>\n",
|
||||
" <td>...</td>\n",
|
||||
" <td>...</td>\n",
|
||||
" <td>...</td>\n",
|
||||
" <td>...</td>\n",
|
||||
" <td>...</td>\n",
|
||||
" <td>...</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>203252</th>\n",
|
||||
" <td>1.00</td>\n",
|
||||
" <td>0.93</td>\n",
|
||||
" <td>4.0</td>\n",
|
||||
" <td>152.0</td>\n",
|
||||
" <td>1.0</td>\n",
|
||||
" <td>4.00</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>203253</th>\n",
|
||||
" <td>1.00</td>\n",
|
||||
" <td>0.97</td>\n",
|
||||
" <td>2.0</td>\n",
|
||||
" <td>45.0</td>\n",
|
||||
" <td>1.0</td>\n",
|
||||
" <td>3.00</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>203254</th>\n",
|
||||
" <td>1.00</td>\n",
|
||||
" <td>0.97</td>\n",
|
||||
" <td>2.0</td>\n",
|
||||
" <td>40.0</td>\n",
|
||||
" <td>1.0</td>\n",
|
||||
" <td>1.00</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>203276</th>\n",
|
||||
" <td>0.99</td>\n",
|
||||
" <td>0.99</td>\n",
|
||||
" <td>2.0</td>\n",
|
||||
" <td>43.0</td>\n",
|
||||
" <td>1.0</td>\n",
|
||||
" <td>5.00</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>203308</th>\n",
|
||||
" <td>1.00</td>\n",
|
||||
" <td>1.00</td>\n",
|
||||
" <td>3.0</td>\n",
|
||||
" <td>110.0</td>\n",
|
||||
" <td>1.0</td>\n",
|
||||
" <td>5.00</td>\n",
|
||||
" </tr>\n",
|
||||
" </tbody>\n",
|
||||
"</table>\n",
|
||||
"<p>134835 rows × 6 columns</p>\n",
|
||||
"</div>"
|
||||
],
|
||||
"text/plain": [
|
||||
" host_response_rate host_acceptance_rate accommodates price \\\n",
|
||||
"0 1.00 0.33 2.0 120.0 \n",
|
||||
"1 1.00 0.98 2.0 90.0 \n",
|
||||
"2 1.00 0.98 2.0 66.0 \n",
|
||||
"3 1.00 0.98 1.0 33.0 \n",
|
||||
"5 1.00 1.00 2.0 45.0 \n",
|
||||
"... ... ... ... ... \n",
|
||||
"203252 1.00 0.93 4.0 152.0 \n",
|
||||
"203253 1.00 0.97 2.0 45.0 \n",
|
||||
"203254 1.00 0.97 2.0 40.0 \n",
|
||||
"203276 0.99 0.99 2.0 43.0 \n",
|
||||
"203308 1.00 1.00 3.0 110.0 \n",
|
||||
"\n",
|
||||
" number_of_reviews review_scores_rating \n",
|
||||
"0 90.0 4.50 \n",
|
||||
"1 351.0 4.58 \n",
|
||||
"2 67.0 4.52 \n",
|
||||
"3 297.0 4.70 \n",
|
||||
"5 42.0 4.98 \n",
|
||||
"... ... ... \n",
|
||||
"203252 1.0 4.00 \n",
|
||||
"203253 1.0 3.00 \n",
|
||||
"203254 1.0 1.00 \n",
|
||||
"203276 1.0 5.00 \n",
|
||||
"203308 1.0 5.00 \n",
|
||||
"\n",
|
||||
"[134835 rows x 6 columns]"
|
||||
]
|
||||
},
|
||||
"execution_count": 4,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"# 读入数据\n",
|
||||
"df = pd.read_csv('../data/2022-01(US_25).csv', usecols=variables)\n",
|
||||
"df['price'] = df['price'].replace('\\$', '', regex=True)\n",
|
||||
"df['price'] = df['price'].replace('\\,', '', regex=True).astype(float)\n",
|
||||
"df[['host_response_rate', 'host_acceptance_rate']] = df[['host_response_rate',\n",
|
||||
" 'host_acceptance_rate']].replace('\\%', '', regex=True).astype(float)*0.01\n",
|
||||
"df[['number_of_reviews']] = df[['number_of_reviews']].astype(float)\n",
|
||||
"for col in variables:\n",
|
||||
" df[col] = df[col].astype(np.float32)\n",
|
||||
" df = df[np.isnan(df[col]) != 1]\n",
|
||||
"df"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "758d8171-9310-4a51-b9a6-b0fa6274ffe1",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"将数据标准化"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 5,
|
||||
"id": "2cbe7cab-42cd-4b3b-bf11-080cc0a0f4c4",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"df_train:\n",
|
||||
" host_response_rate host_acceptance_rate accommodates price \\\n",
|
||||
"0 0.297689 0.440210 -0.101531 -0.147944 \n",
|
||||
"1 0.297689 0.440210 -0.101531 -0.272867 \n",
|
||||
"2 0.297689 0.393235 -0.807638 -0.258453 \n",
|
||||
"3 0.297689 0.534159 -0.101531 -0.366559 \n",
|
||||
"4 0.297689 0.393235 -0.454584 -0.099897 \n",
|
||||
"... ... ... ... ... \n",
|
||||
"94379 0.297689 -0.499281 -0.101531 0.313307 \n",
|
||||
"94380 0.297689 0.440210 2.016788 0.697684 \n",
|
||||
"94381 0.297689 0.534159 0.604575 -0.222417 \n",
|
||||
"94382 0.297689 0.252311 -0.807638 -0.301695 \n",
|
||||
"94383 0.297689 -0.687179 -0.807638 -0.318512 \n",
|
||||
"\n",
|
||||
" number_of_reviews review_scores_rating \n",
|
||||
"0 -0.629684 3.00 \n",
|
||||
"1 0.060760 4.93 \n",
|
||||
"2 -0.629684 4.00 \n",
|
||||
"3 2.045787 4.91 \n",
|
||||
"4 0.295018 4.83 \n",
|
||||
"... ... ... \n",
|
||||
"94379 -0.407756 5.00 \n",
|
||||
"94380 -0.531050 5.00 \n",
|
||||
"94381 0.036101 4.69 \n",
|
||||
"94382 1.145743 4.85 \n",
|
||||
"94383 0.726545 4.83 \n",
|
||||
"\n",
|
||||
"[94384 rows x 6 columns]\n",
|
||||
"============================================\n",
|
||||
"df_test:\n",
|
||||
" host_response_rate host_acceptance_rate accommodates price \\\n",
|
||||
"0 0.297689 0.581133 -1.160691 -0.392985 \n",
|
||||
"1 0.297689 -0.969026 0.604575 0.714500 \n",
|
||||
"2 0.297689 0.581133 -0.807638 -0.042241 \n",
|
||||
"3 0.095805 -1.626670 0.604575 3.104842 \n",
|
||||
"4 0.297689 -0.452307 -0.807638 -0.289683 \n",
|
||||
"... ... ... ... ... \n",
|
||||
"40446 0.095805 0.534159 -0.101531 -0.244039 \n",
|
||||
"40447 -0.038785 0.534159 1.310681 1.353526 \n",
|
||||
"40448 0.297689 -0.217434 -0.101531 0.197994 \n",
|
||||
"40449 0.297689 0.440210 0.604575 0.358952 \n",
|
||||
"40450 0.297689 0.581133 -0.807638 -0.159956 \n",
|
||||
"\n",
|
||||
" number_of_reviews review_scores_rating \n",
|
||||
"0 0.726545 4.76 \n",
|
||||
"1 -0.580367 4.60 \n",
|
||||
"2 3.352699 4.92 \n",
|
||||
"3 -0.617355 2.50 \n",
|
||||
"4 -0.592696 5.00 \n",
|
||||
"... ... ... \n",
|
||||
"40446 -0.617355 5.00 \n",
|
||||
"40447 -0.629684 5.00 \n",
|
||||
"40448 -0.185828 4.97 \n",
|
||||
"40449 -0.494061 4.67 \n",
|
||||
"40450 0.368994 4.96 \n",
|
||||
"\n",
|
||||
"[40451 rows x 6 columns]\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"# 固定划分测试集和训练集\n",
|
||||
"col_df = df.columns\n",
|
||||
"info = df.iloc[:, :-1].values\n",
|
||||
"target = df.iloc[:, -1].values\n",
|
||||
"# 标准化\n",
|
||||
"stdscaler = StandardScaler()\n",
|
||||
"info_train, info_test, target_train, target_test = train_test_split(\n",
|
||||
" info, target, test_size=0.3, random_state=420)\n",
|
||||
"info_train = stdscaler.fit_transform(info_train)\n",
|
||||
"info_test = stdscaler.transform(info_test)\n",
|
||||
"df_train = np.hstack((info_train, target_train.reshape(-1, 1)))\n",
|
||||
"df_test = np.hstack((info_test, target_test.reshape(-1, 1)))\n",
|
||||
"df_train = pd.DataFrame(df_train, columns=col_df)\n",
|
||||
"df_test = pd.DataFrame(df_test, columns=col_df)\n",
|
||||
"print(\"df_train:\\n\", df_train)\n",
|
||||
"print(\"============================================\")\n",
|
||||
"print(\"df_test:\\n\", df_test)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "1a9661ee-3dc4-4d5c-8a8d-bab3dc299d41",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"设置训练集和验证集"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 6,
|
||||
"id": "a10021bb-4e08-49b2-8269-114b8ce37be6",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# 初始化数据集\n",
|
||||
"train_data = USDataset(df_train)\n",
|
||||
"test_data = USDataset(df_test)\n",
|
||||
"train_loader = DataLoader(train_data, batch_size=batch_size,\n",
|
||||
" num_workers=num_workers, shuffle=True, drop_last=True)\n",
|
||||
"test_loader = DataLoader(test_data, batch_size=batch_size,\n",
|
||||
" num_workers=num_workers, shuffle=False)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "41a15589-9aaf-4541-a603-30a9f433c385",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"设计并初始化网络,其中神经网络是有一个输入层,一个输出层,三个隐藏层的前馈神经网络,激活函数是Sigmoid函数"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 7,
|
||||
"id": "764f28bc-81dd-4924-9039-5eceba54d739",
|
||||
"metadata": {
|
||||
"scrolled": true
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"[Parameter containing:\n",
|
||||
"tensor([[ 0.1981, -0.0925, 0.1532, -0.2224, 0.1906],\n",
|
||||
" [-0.4302, -0.2503, -0.1189, 0.4182, -0.3575],\n",
|
||||
" [ 0.1431, 0.3615, 0.2464, -0.0292, -0.3012],\n",
|
||||
" [-0.2803, 0.2226, 0.2143, 0.2458, -0.2719],\n",
|
||||
" [ 0.1669, 0.1216, -0.0148, 0.2795, -0.2338],\n",
|
||||
" [-0.0047, 0.1909, -0.1634, 0.3766, 0.3018],\n",
|
||||
" [ 0.4288, 0.3602, -0.3976, -0.1180, 0.3682],\n",
|
||||
" [-0.1992, 0.1626, 0.3656, -0.4359, 0.2141],\n",
|
||||
" [-0.3708, 0.1608, 0.0614, 0.0473, -0.3628],\n",
|
||||
" [-0.3157, 0.1828, 0.3053, -0.2029, -0.3746],\n",
|
||||
" [ 0.0809, -0.1369, 0.2581, 0.4280, 0.2290],\n",
|
||||
" [-0.3749, -0.1776, 0.1791, -0.3965, -0.1430],\n",
|
||||
" [ 0.3633, 0.1174, -0.4268, 0.1386, 0.3422],\n",
|
||||
" [ 0.3554, -0.3894, -0.4382, 0.2104, 0.4403],\n",
|
||||
" [-0.0714, -0.1196, -0.0719, 0.0933, 0.1312]], device='cuda:0',\n",
|
||||
" requires_grad=True), Parameter containing:\n",
|
||||
"tensor([-0.3285, 0.3194, 0.4210, 0.3921, 0.2482, 0.0288, -0.2423, 0.3984,\n",
|
||||
" -0.4366, -0.0518, 0.2578, -0.2965, 0.3803, -0.1613, -0.3905],\n",
|
||||
" device='cuda:0', requires_grad=True), Parameter containing:\n",
|
||||
"tensor([[ 0.2330, 0.2382, 0.1661, -0.0358, -0.2211, -0.2119, 0.2325, 0.2296,\n",
|
||||
" 0.1116, -0.1443, -0.1474, 0.1758, 0.1321, -0.1886, 0.1001],\n",
|
||||
" [ 0.1013, -0.1303, -0.1558, -0.1694, -0.0610, 0.0903, 0.2512, 0.2028,\n",
|
||||
" -0.0437, -0.0142, -0.0613, -0.1092, -0.2330, 0.1156, -0.0872],\n",
|
||||
" [-0.2243, 0.1081, 0.1204, -0.1407, 0.2454, 0.1751, -0.0659, -0.0778,\n",
|
||||
" 0.1699, 0.2420, 0.1921, -0.1972, 0.0297, -0.1185, 0.1182],\n",
|
||||
" [ 0.1126, -0.1653, -0.2057, -0.0083, 0.0714, -0.1995, 0.1577, -0.2197,\n",
|
||||
" -0.0100, -0.1859, -0.2013, -0.2470, -0.1084, -0.0784, -0.0567],\n",
|
||||
" [-0.2133, 0.0452, 0.0957, 0.1400, -0.1520, -0.0957, 0.1586, 0.0450,\n",
|
||||
" -0.1405, -0.1795, -0.1409, 0.2279, -0.1048, -0.0113, 0.1268],\n",
|
||||
" [-0.2536, -0.1572, 0.2436, -0.1054, -0.2366, -0.1888, 0.0218, -0.0935,\n",
|
||||
" -0.1377, 0.0630, 0.0064, -0.0734, 0.0873, 0.1178, -0.1568],\n",
|
||||
" [-0.0283, 0.2286, 0.2539, 0.1228, 0.2385, -0.2216, -0.1502, -0.2049,\n",
|
||||
" 0.0286, 0.1939, 0.0313, 0.1531, 0.2514, -0.2186, 0.0171],\n",
|
||||
" [ 0.1501, 0.0752, -0.1921, -0.1191, -0.0748, 0.2097, 0.2306, -0.2150,\n",
|
||||
" -0.2348, 0.0575, -0.0489, 0.0678, -0.1273, -0.1990, -0.0770],\n",
|
||||
" [-0.0807, -0.0379, -0.1252, -0.1873, -0.0793, 0.0908, 0.1573, 0.1212,\n",
|
||||
" 0.0937, 0.1422, -0.1502, -0.0369, 0.0321, -0.0830, 0.0571],\n",
|
||||
" [-0.1586, -0.0179, 0.2351, 0.1576, 0.0692, -0.0726, -0.0875, -0.0382,\n",
|
||||
" -0.2227, -0.0080, 0.0904, 0.1201, 0.2029, 0.0351, 0.0133]],\n",
|
||||
" device='cuda:0', requires_grad=True), Parameter containing:\n",
|
||||
"tensor([-0.0141, 0.0566, -0.2259, -0.1533, -0.0400, 0.0475, -0.1080, -0.2453,\n",
|
||||
" -0.1625, 0.1345], device='cuda:0', requires_grad=True), Parameter containing:\n",
|
||||
"tensor([[-0.2120, -0.1994, -0.0052, 0.2988, -0.2471, 0.0135, 0.1283, -0.0201,\n",
|
||||
" 0.1182, -0.1972],\n",
|
||||
" [-0.0328, 0.0113, 0.1010, 0.0589, -0.1486, 0.2598, 0.1771, 0.0474,\n",
|
||||
" -0.0413, 0.1537],\n",
|
||||
" [ 0.2083, -0.1183, -0.2833, -0.3092, 0.2081, 0.2566, 0.1134, -0.1159,\n",
|
||||
" -0.2981, -0.2882],\n",
|
||||
" [-0.1033, -0.2929, -0.2451, -0.2850, 0.3157, 0.3092, 0.0539, -0.2594,\n",
|
||||
" 0.1327, -0.0194],\n",
|
||||
" [ 0.1040, -0.3053, -0.1769, -0.2137, -0.0262, 0.2108, 0.0255, -0.1202,\n",
|
||||
" 0.0413, 0.2326],\n",
|
||||
" [-0.0355, -0.2424, 0.1064, 0.0394, -0.1730, 0.2130, 0.1174, -0.1641,\n",
|
||||
" -0.1420, -0.0496]], device='cuda:0', requires_grad=True), Parameter containing:\n",
|
||||
"tensor([ 0.2195, -0.0579, 0.1234, 0.1060, -0.2140, 0.0288], device='cuda:0',\n",
|
||||
" requires_grad=True), Parameter containing:\n",
|
||||
"tensor([[ 0.2707, 0.1434, 0.0318, -0.0551, -0.1369, -0.1496]],\n",
|
||||
" device='cuda:0', requires_grad=True), Parameter containing:\n",
|
||||
"tensor([-0.0661], device='cuda:0', requires_grad=True)]\n",
|
||||
"==================================================\n",
|
||||
"[Parameter containing:\n",
|
||||
"tensor([[-6.9726e-02, 1.7731e+00, -7.6372e-02, -7.1058e-01, -7.1123e-04],\n",
|
||||
" [-7.1336e-01, -5.4361e-01, -2.6582e-01, 5.6964e-01, -4.9271e-01],\n",
|
||||
" [-1.3366e+00, -6.1715e-01, -8.0915e-01, -1.0871e+00, -1.7130e+00],\n",
|
||||
" [ 1.1439e+00, 1.1702e+00, -5.0160e-02, -6.6409e-01, -9.0478e-02],\n",
|
||||
" [-1.2307e-01, 9.7143e-01, -3.2899e-01, 1.0481e+00, 8.7611e-02],\n",
|
||||
" [-3.6240e-01, 6.6012e-01, -1.7746e+00, 1.7587e-01, 3.9182e-01],\n",
|
||||
" [ 9.5627e-01, -1.4544e+00, 8.6239e-01, 6.4864e-01, 1.2865e+00],\n",
|
||||
" [-9.4586e-01, -2.7911e+00, 1.3601e+00, -8.6802e-01, 3.7288e-01],\n",
|
||||
" [ 6.0642e-01, 8.3700e-01, 4.6751e-01, 9.6069e-01, 1.0403e+00],\n",
|
||||
" [-2.4263e-01, -1.1883e+00, -1.1396e+00, 1.4794e+00, -6.1805e-01],\n",
|
||||
" [-4.5019e-01, 3.1949e-01, 2.2522e-01, -1.4746e+00, -1.5766e+00],\n",
|
||||
" [-1.0685e+00, -8.6054e-01, -4.0435e-01, 1.1669e-01, -5.0217e-01],\n",
|
||||
" [ 3.2439e-01, -6.3227e-01, -9.7261e-01, 4.4522e-01, -1.1293e+00],\n",
|
||||
" [-7.5776e-01, -5.1496e-01, 1.3075e+00, -3.3717e-01, 3.9504e-01],\n",
|
||||
" [ 8.5924e-01, -1.0643e-01, -1.2336e+00, 1.7424e+00, -3.5817e-01]],\n",
|
||||
" device='cuda:0', requires_grad=True), Parameter containing:\n",
|
||||
"tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n",
|
||||
" device='cuda:0', requires_grad=True), Parameter containing:\n",
|
||||
"tensor([[-0.5137, -0.3243, 0.4541, 0.4884, -0.3226, -0.9424, 1.2633, -0.5411,\n",
|
||||
" -0.4047, -0.2105, -1.2065, 0.1194, -0.8706, -0.3806, -0.3470],\n",
|
||||
" [ 1.1079, -0.6157, 0.0802, 0.1673, -1.1413, 0.3360, 2.0812, -0.4530,\n",
|
||||
" 1.9244, -0.1222, 0.8667, -0.4083, -1.5679, 1.3790, -0.7961],\n",
|
||||
" [-1.4472, 0.9576, 0.8538, 0.3888, -2.0439, -0.0737, -1.1856, -0.7158,\n",
|
||||
" 2.7370, 2.5444, -0.7023, 1.4209, 0.0901, 0.8644, 0.7805],\n",
|
||||
" [-0.0899, 0.7839, -0.3986, -0.0622, 0.5134, 0.8943, -0.7492, -1.7582,\n",
|
||||
" 0.7121, 1.2855, -0.4851, 0.0736, -0.2256, 0.4992, -0.7854],\n",
|
||||
" [ 2.5287, 0.5654, 0.7835, -0.6008, -1.9796, 1.6482, -1.0928, -0.0595,\n",
|
||||
" -0.0564, -0.0729, -0.2910, -0.9512, -1.3852, -0.2895, -0.8505],\n",
|
||||
" [-1.4534, -1.3051, 0.1817, 0.5341, 1.1192, -0.0837, -0.5133, -1.2826,\n",
|
||||
" 1.5556, -1.0946, -1.0002, 0.3343, 0.1432, 1.8591, 0.0199],\n",
|
||||
" [ 1.2811, -1.2248, 0.5370, 2.1831, 1.1538, 0.9405, 1.3957, 0.2959,\n",
|
||||
" 1.1088, 1.6196, -0.1388, -0.0548, -0.2318, 0.1342, -0.1255],\n",
|
||||
" [ 1.2155, -0.1889, -0.3502, -0.3293, 0.3099, 2.5889, 1.1579, 0.2868,\n",
|
||||
" -0.1206, 0.3678, 0.2955, 0.7390, -1.1026, -0.5398, 0.0348],\n",
|
||||
" [-0.9189, 0.8586, 1.3321, -0.7081, 0.0048, 0.7643, -0.5209, 0.2017,\n",
|
||||
" 0.8094, -1.4677, 1.0291, 0.1726, 0.3298, 1.6261, -0.1650],\n",
|
||||
" [ 0.7050, -0.3248, -0.2147, -0.4022, -0.3151, 0.9486, 1.3281, -0.2464,\n",
|
||||
" -0.2804, 1.7853, -0.8694, -1.3244, 0.4455, 0.6532, 1.1969]],\n",
|
||||
" device='cuda:0', requires_grad=True), Parameter containing:\n",
|
||||
"tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], device='cuda:0',\n",
|
||||
" requires_grad=True), Parameter containing:\n",
|
||||
"tensor([[ 0.7268, -0.5613, 0.3554, -0.7328, 1.7000, -0.1115, -0.4507, 0.0304,\n",
|
||||
" 0.5701, 1.0746],\n",
|
||||
" [-0.7969, 0.7418, 1.5317, -1.0003, -0.2625, 1.7585, -3.1004, -1.2840,\n",
|
||||
" -0.2230, 0.9398],\n",
|
||||
" [-0.2167, -0.4646, -0.1701, -1.1754, 1.2224, -0.1925, -1.0030, -1.3171,\n",
|
||||
" 0.7915, -1.7573],\n",
|
||||
" [ 1.1589, -0.5555, 0.6436, -0.1423, -0.0412, -0.2644, 0.0179, 1.9647,\n",
|
||||
" -1.0822, 0.4064],\n",
|
||||
" [ 0.2140, -0.6242, 1.9719, -1.6243, -0.7416, 1.2082, 0.9993, 2.4522,\n",
|
||||
" 1.2368, -0.9525],\n",
|
||||
" [ 0.9530, 0.5625, -0.6094, -0.3213, 0.6820, -0.6440, 1.5492, -0.4234,\n",
|
||||
" -0.3488, -0.1487]], device='cuda:0', requires_grad=True), Parameter containing:\n",
|
||||
"tensor([0., 0., 0., 0., 0., 0.], device='cuda:0', requires_grad=True), Parameter containing:\n",
|
||||
"tensor([[-1.2324, -0.2007, -1.9850, -1.1963, 0.4371, -0.8892]],\n",
|
||||
" device='cuda:0', requires_grad=True), Parameter containing:\n",
|
||||
"tensor([0.], device='cuda:0', requires_grad=True)]\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"def initialize(self):\n",
|
||||
" '''\n",
|
||||
" 初始化网络参数\n",
|
||||
" '''\n",
|
||||
" for m in self.modules():\n",
|
||||
" if isinstance(m, nn.Linear):\n",
|
||||
" torch.nn.init.normal_(m.weight.data, 0.1)\n",
|
||||
" if m.bias is not None:\n",
|
||||
" torch.nn.init.zeros_(m.bias.data)\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"class Net(nn.Module):\n",
|
||||
" '''\n",
|
||||
" 网络结构\n",
|
||||
" '''\n",
|
||||
"\n",
|
||||
" def __init__(self, **kwargs):\n",
|
||||
" super(Net, self).__init__()\n",
|
||||
" self.fc = nn.Sequential(\n",
|
||||
" nn.Linear(5, 15),\n",
|
||||
" nn.Sigmoid(),\n",
|
||||
" nn.Linear(15, 10),\n",
|
||||
" nn.Sigmoid(),\n",
|
||||
" nn.Linear(10, 6),\n",
|
||||
" nn.Sigmoid(),\n",
|
||||
" nn.Linear(6, 1)\n",
|
||||
" )\n",
|
||||
"\n",
|
||||
" def forward(self, x):\n",
|
||||
" x = x.view(-1, 5)\n",
|
||||
" x = self.fc(x)\n",
|
||||
" return x.squeeze(-1)\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"# 初始化网络\n",
|
||||
"model = Net()\n",
|
||||
"model = model.cuda()\n",
|
||||
"print(list(model.parameters()))\n",
|
||||
"initialize(model)\n",
|
||||
"print(\"==================================================\")\n",
|
||||
"print(list(model.parameters()))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "20fcccfe-3f82-44c3-bba2-9c66fb15cecd",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"定义损失函数为均方误差损失函数,优化算法为随机梯度下降"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 8,
|
||||
"id": "7b530a85-50b1-4d0f-b47e-14831d0ee85d",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# 损失函数\n",
|
||||
"criterion = nn.MSELoss()\n",
|
||||
"# 优化器\n",
|
||||
"optimizer = optim.SGD(model.parameters(), lr=lr)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 9,
|
||||
"id": "8694737f-afb5-4e58-ad14-3220971b142f",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def train(epoch):\n",
|
||||
" '''\n",
|
||||
" 训练器\n",
|
||||
"\n",
|
||||
" Parameters\n",
|
||||
" ----------\n",
|
||||
" epoch : int\n",
|
||||
"\n",
|
||||
" Returns\n",
|
||||
" -------\n",
|
||||
" None.\n",
|
||||
"\n",
|
||||
" '''\n",
|
||||
" model.train()\n",
|
||||
" train_loss = 0\n",
|
||||
" for info, target in train_loader:\n",
|
||||
" info = info.cuda()\n",
|
||||
" target = target.cuda()\n",
|
||||
" optimizer.zero_grad()\n",
|
||||
" output = model(info)\n",
|
||||
" loss = criterion(output, target)\n",
|
||||
" loss.backward()\n",
|
||||
" optimizer.step()\n",
|
||||
" train_loss += loss.item()*info.size(0)\n",
|
||||
" train_loss = train_loss/len(train_loader.dataset)\n",
|
||||
" print(\"epoch:%d, train loss:%.4f\" % (epoch, train_loss))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 10,
|
||||
"id": "bb6be6f8-16e6-4c2b-a56c-232b09faa654",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def validate(epoch):\n",
|
||||
" '''\n",
|
||||
" 验证\n",
|
||||
"\n",
|
||||
" Parameters\n",
|
||||
" ----------\n",
|
||||
" epoch : int\n",
|
||||
"\n",
|
||||
" Returns\n",
|
||||
" -------\n",
|
||||
" None.\n",
|
||||
"\n",
|
||||
" '''\n",
|
||||
" model.eval()\n",
|
||||
" val_loss = 0\n",
|
||||
" with torch.no_grad():\n",
|
||||
" for info, target in test_loader:\n",
|
||||
" info, target = info.cuda(), target.cuda()\n",
|
||||
" output = model(info)\n",
|
||||
" loss = criterion(output, target)\n",
|
||||
" val_loss += loss.item()*info.size(0)\n",
|
||||
" val_loss = val_loss/len(test_loader.dataset)\n",
|
||||
" print(\"epoch:%d, validation loss:%.4f\" %\n",
|
||||
" (epoch, val_loss))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 11,
|
||||
"id": "40532ea8-acfe-469a-aa9f-14d3bec2619d",
|
||||
"metadata": {
|
||||
"scrolled": true
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"epoch:1, train loss:2.0057\n",
|
||||
"epoch:1, validation loss:0.2197\n",
|
||||
"epoch:2, train loss:0.2096\n",
|
||||
"epoch:2, validation loss:0.2196\n",
|
||||
"epoch:3, train loss:0.2098\n",
|
||||
"epoch:3, validation loss:0.2196\n",
|
||||
"epoch:4, train loss:0.2098\n",
|
||||
"epoch:4, validation loss:0.2196\n",
|
||||
"epoch:5, train loss:0.2100\n",
|
||||
"epoch:5, validation loss:0.2195\n",
|
||||
"epoch:6, train loss:0.2099\n",
|
||||
"epoch:6, validation loss:0.2195\n",
|
||||
"epoch:7, train loss:0.2101\n",
|
||||
"epoch:7, validation loss:0.2195\n",
|
||||
"epoch:8, train loss:0.2099\n",
|
||||
"epoch:8, validation loss:0.2194\n",
|
||||
"epoch:9, train loss:0.2099\n",
|
||||
"epoch:9, validation loss:0.2194\n",
|
||||
"epoch:10, train loss:0.2099\n",
|
||||
"epoch:10, validation loss:0.2194\n",
|
||||
"epoch:11, train loss:0.2093\n",
|
||||
"epoch:11, validation loss:0.2193\n",
|
||||
"epoch:12, train loss:0.2099\n",
|
||||
"epoch:12, validation loss:0.2193\n",
|
||||
"epoch:13, train loss:0.2100\n",
|
||||
"epoch:13, validation loss:0.2193\n",
|
||||
"epoch:14, train loss:0.2097\n",
|
||||
"epoch:14, validation loss:0.2193\n",
|
||||
"epoch:15, train loss:0.2094\n",
|
||||
"epoch:15, validation loss:0.2192\n",
|
||||
"epoch:16, train loss:0.2091\n",
|
||||
"epoch:16, validation loss:0.2192\n",
|
||||
"epoch:17, train loss:0.2098\n",
|
||||
"epoch:17, validation loss:0.2192\n",
|
||||
"epoch:18, train loss:0.2089\n",
|
||||
"epoch:18, validation loss:0.2191\n",
|
||||
"epoch:19, train loss:0.2091\n",
|
||||
"epoch:19, validation loss:0.2191\n",
|
||||
"epoch:20, train loss:0.2096\n",
|
||||
"epoch:20, validation loss:0.2190\n",
|
||||
"epoch:21, train loss:0.2092\n",
|
||||
"epoch:21, validation loss:0.2190\n",
|
||||
"epoch:22, train loss:0.2084\n",
|
||||
"epoch:22, validation loss:0.2190\n",
|
||||
"epoch:23, train loss:0.2093\n",
|
||||
"epoch:23, validation loss:0.2190\n",
|
||||
"epoch:24, train loss:0.2088\n",
|
||||
"epoch:24, validation loss:0.2190\n",
|
||||
"epoch:25, train loss:0.2088\n",
|
||||
"epoch:25, validation loss:0.2189\n",
|
||||
"epoch:26, train loss:0.2090\n",
|
||||
"epoch:26, validation loss:0.2189\n",
|
||||
"epoch:27, train loss:0.2092\n",
|
||||
"epoch:27, validation loss:0.2189\n",
|
||||
"epoch:28, train loss:0.2086\n",
|
||||
"epoch:28, validation loss:0.2188\n",
|
||||
"epoch:29, train loss:0.2087\n",
|
||||
"epoch:29, validation loss:0.2189\n",
|
||||
"epoch:30, train loss:0.2089\n",
|
||||
"epoch:30, validation loss:0.2188\n",
|
||||
"epoch:31, train loss:0.2090\n",
|
||||
"epoch:31, validation loss:0.2187\n",
|
||||
"epoch:32, train loss:0.2087\n",
|
||||
"epoch:32, validation loss:0.2187\n",
|
||||
"epoch:33, train loss:0.2092\n",
|
||||
"epoch:33, validation loss:0.2187\n",
|
||||
"epoch:34, train loss:0.2092\n",
|
||||
"epoch:34, validation loss:0.2187\n",
|
||||
"epoch:35, train loss:0.2091\n",
|
||||
"epoch:35, validation loss:0.2186\n",
|
||||
"epoch:36, train loss:0.2089\n",
|
||||
"epoch:36, validation loss:0.2186\n",
|
||||
"epoch:37, train loss:0.2086\n",
|
||||
"epoch:37, validation loss:0.2186\n",
|
||||
"epoch:38, train loss:0.2090\n",
|
||||
"epoch:38, validation loss:0.2186\n",
|
||||
"epoch:39, train loss:0.2088\n",
|
||||
"epoch:39, validation loss:0.2185\n",
|
||||
"epoch:40, train loss:0.2087\n",
|
||||
"epoch:40, validation loss:0.2185\n",
|
||||
"epoch:41, train loss:0.2087\n",
|
||||
"epoch:41, validation loss:0.2186\n",
|
||||
"epoch:42, train loss:0.2085\n",
|
||||
"epoch:42, validation loss:0.2185\n",
|
||||
"epoch:43, train loss:0.2088\n",
|
||||
"epoch:43, validation loss:0.2184\n",
|
||||
"epoch:44, train loss:0.2089\n",
|
||||
"epoch:44, validation loss:0.2184\n",
|
||||
"epoch:45, train loss:0.2089\n",
|
||||
"epoch:45, validation loss:0.2184\n",
|
||||
"epoch:46, train loss:0.2086\n",
|
||||
"epoch:46, validation loss:0.2184\n",
|
||||
"epoch:47, train loss:0.2089\n",
|
||||
"epoch:47, validation loss:0.2183\n",
|
||||
"epoch:48, train loss:0.2084\n",
|
||||
"epoch:48, validation loss:0.2183\n",
|
||||
"epoch:49, train loss:0.2080\n",
|
||||
"epoch:49, validation loss:0.2183\n",
|
||||
"epoch:50, train loss:0.2088\n",
|
||||
"epoch:50, validation loss:0.2183\n",
|
||||
"epoch:51, train loss:0.2086\n",
|
||||
"epoch:51, validation loss:0.2183\n",
|
||||
"epoch:52, train loss:0.2082\n",
|
||||
"epoch:52, validation loss:0.2182\n",
|
||||
"epoch:53, train loss:0.2080\n",
|
||||
"epoch:53, validation loss:0.2182\n",
|
||||
"epoch:54, train loss:0.2088\n",
|
||||
"epoch:54, validation loss:0.2182\n",
|
||||
"epoch:55, train loss:0.2086\n",
|
||||
"epoch:55, validation loss:0.2181\n",
|
||||
"epoch:56, train loss:0.2082\n",
|
||||
"epoch:56, validation loss:0.2181\n",
|
||||
"epoch:57, train loss:0.2088\n",
|
||||
"epoch:57, validation loss:0.2181\n",
|
||||
"epoch:58, train loss:0.2084\n",
|
||||
"epoch:58, validation loss:0.2181\n",
|
||||
"epoch:59, train loss:0.2085\n",
|
||||
"epoch:59, validation loss:0.2180\n",
|
||||
"epoch:60, train loss:0.2085\n",
|
||||
"epoch:60, validation loss:0.2180\n",
|
||||
"epoch:61, train loss:0.2084\n",
|
||||
"epoch:61, validation loss:0.2180\n",
|
||||
"epoch:62, train loss:0.2082\n",
|
||||
"epoch:62, validation loss:0.2180\n",
|
||||
"epoch:63, train loss:0.2081\n",
|
||||
"epoch:63, validation loss:0.2180\n",
|
||||
"epoch:64, train loss:0.2081\n",
|
||||
"epoch:64, validation loss:0.2179\n",
|
||||
"epoch:65, train loss:0.2078\n",
|
||||
"epoch:65, validation loss:0.2179\n",
|
||||
"epoch:66, train loss:0.2079\n",
|
||||
"epoch:66, validation loss:0.2179\n",
|
||||
"epoch:67, train loss:0.2080\n",
|
||||
"epoch:67, validation loss:0.2179\n",
|
||||
"epoch:68, train loss:0.2080\n",
|
||||
"epoch:68, validation loss:0.2178\n",
|
||||
"epoch:69, train loss:0.2082\n",
|
||||
"epoch:69, validation loss:0.2178\n",
|
||||
"epoch:70, train loss:0.2079\n",
|
||||
"epoch:70, validation loss:0.2178\n",
|
||||
"epoch:71, train loss:0.2083\n",
|
||||
"epoch:71, validation loss:0.2178\n",
|
||||
"epoch:72, train loss:0.2081\n",
|
||||
"epoch:72, validation loss:0.2177\n",
|
||||
"epoch:73, train loss:0.2079\n",
|
||||
"epoch:73, validation loss:0.2177\n",
|
||||
"epoch:74, train loss:0.2076\n",
|
||||
"epoch:74, validation loss:0.2177\n",
|
||||
"epoch:75, train loss:0.2081\n",
|
||||
"epoch:75, validation loss:0.2177\n",
|
||||
"epoch:76, train loss:0.2081\n",
|
||||
"epoch:76, validation loss:0.2176\n",
|
||||
"epoch:77, train loss:0.2080\n",
|
||||
"epoch:77, validation loss:0.2176\n",
|
||||
"epoch:78, train loss:0.2079\n",
|
||||
"epoch:78, validation loss:0.2176\n",
|
||||
"epoch:79, train loss:0.2081\n",
|
||||
"epoch:79, validation loss:0.2176\n",
|
||||
"epoch:80, train loss:0.2078\n",
|
||||
"epoch:80, validation loss:0.2175\n",
|
||||
"epoch:81, train loss:0.2076\n",
|
||||
"epoch:81, validation loss:0.2175\n",
|
||||
"epoch:82, train loss:0.2077\n",
|
||||
"epoch:82, validation loss:0.2175\n",
|
||||
"epoch:83, train loss:0.2077\n",
|
||||
"epoch:83, validation loss:0.2175\n",
|
||||
"epoch:84, train loss:0.2080\n",
|
||||
"epoch:84, validation loss:0.2175\n",
|
||||
"epoch:85, train loss:0.2074\n",
|
||||
"epoch:85, validation loss:0.2174\n",
|
||||
"epoch:86, train loss:0.2077\n",
|
||||
"epoch:86, validation loss:0.2174\n",
|
||||
"epoch:87, train loss:0.2079\n",
|
||||
"epoch:87, validation loss:0.2174\n",
|
||||
"epoch:88, train loss:0.2079\n",
|
||||
"epoch:88, validation loss:0.2173\n",
|
||||
"epoch:89, train loss:0.2075\n",
|
||||
"epoch:89, validation loss:0.2173\n",
|
||||
"epoch:90, train loss:0.2075\n",
|
||||
"epoch:90, validation loss:0.2173\n",
|
||||
"epoch:91, train loss:0.2070\n",
|
||||
"epoch:91, validation loss:0.2173\n",
|
||||
"epoch:92, train loss:0.2079\n",
|
||||
"epoch:92, validation loss:0.2173\n",
|
||||
"epoch:93, train loss:0.2075\n",
|
||||
"epoch:93, validation loss:0.2172\n",
|
||||
"epoch:94, train loss:0.2078\n",
|
||||
"epoch:94, validation loss:0.2172\n",
|
||||
"epoch:95, train loss:0.2068\n",
|
||||
"epoch:95, validation loss:0.2172\n",
|
||||
"epoch:96, train loss:0.2072\n",
|
||||
"epoch:96, validation loss:0.2172\n",
|
||||
"epoch:97, train loss:0.2076\n",
|
||||
"epoch:97, validation loss:0.2172\n",
|
||||
"epoch:98, train loss:0.2079\n",
|
||||
"epoch:98, validation loss:0.2171\n",
|
||||
"epoch:99, train loss:0.2074\n",
|
||||
"epoch:99, validation loss:0.2171\n",
|
||||
"epoch:100, train loss:0.2074\n",
|
||||
"epoch:100, validation loss:0.2171\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"# 训练过程\n",
|
||||
"for epoch in range(1, max_epochs+1):\n",
|
||||
" train(epoch)\n",
|
||||
" validate(epoch)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "e3f438d3-813f-4db5-b8e9-da71d54730d4",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3 (ipykernel)",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.10.12"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 5
|
||||
}
|
||||
Reference in New Issue
Block a user