前馈神经网络回归

Signed-off-by: 张卓立 <13190677+zhang-zhuoli@user.noreply.gitee.com>
This commit is contained in:
张卓立
2023-07-15 12:05:42 +00:00
committed by Gitee
parent a395c9c6e0
commit e289ff3c0c
@@ -0,0 +1,944 @@
{
"cells": [
{
"cell_type": "markdown",
"id": "a31d3c40-3593-4ca3-980e-578ee51e171a",
"metadata": {},
"source": [
"# 用神经网络进行回归预测"
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "738c03b4-3ca0-4a87-9143-53ddea5179be",
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"import pandas as pd\n",
"import torch\n",
"import torch.nn as nn\n",
"import torch.optim as optim\n",
"from torch.utils.data import Dataset, DataLoader\n",
"from sklearn.preprocessing import StandardScaler\n",
"from sklearn.model_selection import train_test_split"
]
},
{
"cell_type": "markdown",
"id": "637063c9-5924-416c-8201-adaae8514ffd",
"metadata": {},
"source": [
"设置神经网络超参数,批大小为500,学习率为0.01,一共分别进行100次正向传播和反向传播"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "b909615c-783d-4bfb-acf6-1bc3bacfbb18",
"metadata": {},
"outputs": [],
"source": [
"# torch参数\n",
"batch_size = 500\n",
"lr = 0.01\n",
"max_epochs = 100\n",
"num_workers = 0\n",
"device = torch.device(\"cuda:1\" if torch.cuda.is_available() else \"cpu\")"
]
},
{
"cell_type": "markdown",
"id": "904cfade-5667-440a-a13c-34fdf57f7b0f",
"metadata": {},
"source": [
"定义所需数据集类"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "f6e65156-8535-49c2-8d9d-ecb8639855c9",
"metadata": {},
"outputs": [],
"source": [
"variables = ['number_of_reviews', 'price', 'accommodates',\n",
" 'host_response_rate', 'host_acceptance_rate', 'review_scores_rating']\n",
"# 数据集类\n",
"\n",
"\n",
"class USDataset(Dataset):\n",
" def __init__(self, df):\n",
" '''\n",
" 初始化\n",
" df: 处理后的数据集\n",
" '''\n",
" self.df = df\n",
" self.info = df[['number_of_reviews', 'price', 'accommodates',\n",
" 'host_response_rate', 'host_acceptance_rate']].values\n",
" self.target = df['review_scores_rating'].values\n",
"\n",
" def __getitem__(self, index):\n",
" '''\n",
" 根据编号返回信息\n",
" index: 样本编号\n",
" '''\n",
" info = self.info[index]\n",
" target = self.target[index]\n",
" return info, target\n",
"\n",
" def __len__(self):\n",
" '''\n",
" 返回数据集样本个数\n",
" '''\n",
" return len(self.df)"
]
},
{
"cell_type": "markdown",
"id": "fc574f2e-6443-48af-9fee-dcf5fb29af58",
"metadata": {},
"source": [
"读入数据,由于样本量很大,直接删除有缺失值的样本。"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "17f8aedb-abf1-4622-98ec-d525a779274b",
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>host_response_rate</th>\n",
" <th>host_acceptance_rate</th>\n",
" <th>accommodates</th>\n",
" <th>price</th>\n",
" <th>number_of_reviews</th>\n",
" <th>review_scores_rating</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>1.00</td>\n",
" <td>0.33</td>\n",
" <td>2.0</td>\n",
" <td>120.0</td>\n",
" <td>90.0</td>\n",
" <td>4.50</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>1.00</td>\n",
" <td>0.98</td>\n",
" <td>2.0</td>\n",
" <td>90.0</td>\n",
" <td>351.0</td>\n",
" <td>4.58</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>1.00</td>\n",
" <td>0.98</td>\n",
" <td>2.0</td>\n",
" <td>66.0</td>\n",
" <td>67.0</td>\n",
" <td>4.52</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>1.00</td>\n",
" <td>0.98</td>\n",
" <td>1.0</td>\n",
" <td>33.0</td>\n",
" <td>297.0</td>\n",
" <td>4.70</td>\n",
" </tr>\n",
" <tr>\n",
" <th>5</th>\n",
" <td>1.00</td>\n",
" <td>1.00</td>\n",
" <td>2.0</td>\n",
" <td>45.0</td>\n",
" <td>42.0</td>\n",
" <td>4.98</td>\n",
" </tr>\n",
" <tr>\n",
" <th>...</th>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>203252</th>\n",
" <td>1.00</td>\n",
" <td>0.93</td>\n",
" <td>4.0</td>\n",
" <td>152.0</td>\n",
" <td>1.0</td>\n",
" <td>4.00</td>\n",
" </tr>\n",
" <tr>\n",
" <th>203253</th>\n",
" <td>1.00</td>\n",
" <td>0.97</td>\n",
" <td>2.0</td>\n",
" <td>45.0</td>\n",
" <td>1.0</td>\n",
" <td>3.00</td>\n",
" </tr>\n",
" <tr>\n",
" <th>203254</th>\n",
" <td>1.00</td>\n",
" <td>0.97</td>\n",
" <td>2.0</td>\n",
" <td>40.0</td>\n",
" <td>1.0</td>\n",
" <td>1.00</td>\n",
" </tr>\n",
" <tr>\n",
" <th>203276</th>\n",
" <td>0.99</td>\n",
" <td>0.99</td>\n",
" <td>2.0</td>\n",
" <td>43.0</td>\n",
" <td>1.0</td>\n",
" <td>5.00</td>\n",
" </tr>\n",
" <tr>\n",
" <th>203308</th>\n",
" <td>1.00</td>\n",
" <td>1.00</td>\n",
" <td>3.0</td>\n",
" <td>110.0</td>\n",
" <td>1.0</td>\n",
" <td>5.00</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>134835 rows × 6 columns</p>\n",
"</div>"
],
"text/plain": [
" host_response_rate host_acceptance_rate accommodates price \\\n",
"0 1.00 0.33 2.0 120.0 \n",
"1 1.00 0.98 2.0 90.0 \n",
"2 1.00 0.98 2.0 66.0 \n",
"3 1.00 0.98 1.0 33.0 \n",
"5 1.00 1.00 2.0 45.0 \n",
"... ... ... ... ... \n",
"203252 1.00 0.93 4.0 152.0 \n",
"203253 1.00 0.97 2.0 45.0 \n",
"203254 1.00 0.97 2.0 40.0 \n",
"203276 0.99 0.99 2.0 43.0 \n",
"203308 1.00 1.00 3.0 110.0 \n",
"\n",
" number_of_reviews review_scores_rating \n",
"0 90.0 4.50 \n",
"1 351.0 4.58 \n",
"2 67.0 4.52 \n",
"3 297.0 4.70 \n",
"5 42.0 4.98 \n",
"... ... ... \n",
"203252 1.0 4.00 \n",
"203253 1.0 3.00 \n",
"203254 1.0 1.00 \n",
"203276 1.0 5.00 \n",
"203308 1.0 5.00 \n",
"\n",
"[134835 rows x 6 columns]"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# 读入数据\n",
"df = pd.read_csv('../data/2022-01(US_25).csv', usecols=variables)\n",
"df['price'] = df['price'].replace('\\$', '', regex=True)\n",
"df['price'] = df['price'].replace('\\,', '', regex=True).astype(float)\n",
"df[['host_response_rate', 'host_acceptance_rate']] = df[['host_response_rate',\n",
" 'host_acceptance_rate']].replace('\\%', '', regex=True).astype(float)*0.01\n",
"df[['number_of_reviews']] = df[['number_of_reviews']].astype(float)\n",
"for col in variables:\n",
" df[col] = df[col].astype(np.float32)\n",
" df = df[np.isnan(df[col]) != 1]\n",
"df"
]
},
{
"cell_type": "markdown",
"id": "758d8171-9310-4a51-b9a6-b0fa6274ffe1",
"metadata": {},
"source": [
"将数据标准化"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "2cbe7cab-42cd-4b3b-bf11-080cc0a0f4c4",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"df_train:\n",
" host_response_rate host_acceptance_rate accommodates price \\\n",
"0 0.297689 0.440210 -0.101531 -0.147944 \n",
"1 0.297689 0.440210 -0.101531 -0.272867 \n",
"2 0.297689 0.393235 -0.807638 -0.258453 \n",
"3 0.297689 0.534159 -0.101531 -0.366559 \n",
"4 0.297689 0.393235 -0.454584 -0.099897 \n",
"... ... ... ... ... \n",
"94379 0.297689 -0.499281 -0.101531 0.313307 \n",
"94380 0.297689 0.440210 2.016788 0.697684 \n",
"94381 0.297689 0.534159 0.604575 -0.222417 \n",
"94382 0.297689 0.252311 -0.807638 -0.301695 \n",
"94383 0.297689 -0.687179 -0.807638 -0.318512 \n",
"\n",
" number_of_reviews review_scores_rating \n",
"0 -0.629684 3.00 \n",
"1 0.060760 4.93 \n",
"2 -0.629684 4.00 \n",
"3 2.045787 4.91 \n",
"4 0.295018 4.83 \n",
"... ... ... \n",
"94379 -0.407756 5.00 \n",
"94380 -0.531050 5.00 \n",
"94381 0.036101 4.69 \n",
"94382 1.145743 4.85 \n",
"94383 0.726545 4.83 \n",
"\n",
"[94384 rows x 6 columns]\n",
"============================================\n",
"df_test:\n",
" host_response_rate host_acceptance_rate accommodates price \\\n",
"0 0.297689 0.581133 -1.160691 -0.392985 \n",
"1 0.297689 -0.969026 0.604575 0.714500 \n",
"2 0.297689 0.581133 -0.807638 -0.042241 \n",
"3 0.095805 -1.626670 0.604575 3.104842 \n",
"4 0.297689 -0.452307 -0.807638 -0.289683 \n",
"... ... ... ... ... \n",
"40446 0.095805 0.534159 -0.101531 -0.244039 \n",
"40447 -0.038785 0.534159 1.310681 1.353526 \n",
"40448 0.297689 -0.217434 -0.101531 0.197994 \n",
"40449 0.297689 0.440210 0.604575 0.358952 \n",
"40450 0.297689 0.581133 -0.807638 -0.159956 \n",
"\n",
" number_of_reviews review_scores_rating \n",
"0 0.726545 4.76 \n",
"1 -0.580367 4.60 \n",
"2 3.352699 4.92 \n",
"3 -0.617355 2.50 \n",
"4 -0.592696 5.00 \n",
"... ... ... \n",
"40446 -0.617355 5.00 \n",
"40447 -0.629684 5.00 \n",
"40448 -0.185828 4.97 \n",
"40449 -0.494061 4.67 \n",
"40450 0.368994 4.96 \n",
"\n",
"[40451 rows x 6 columns]\n"
]
}
],
"source": [
"# 固定划分测试集和训练集\n",
"col_df = df.columns\n",
"info = df.iloc[:, :-1].values\n",
"target = df.iloc[:, -1].values\n",
"# 标准化\n",
"stdscaler = StandardScaler()\n",
"info_train, info_test, target_train, target_test = train_test_split(\n",
" info, target, test_size=0.3, random_state=420)\n",
"info_train = stdscaler.fit_transform(info_train)\n",
"info_test = stdscaler.transform(info_test)\n",
"df_train = np.hstack((info_train, target_train.reshape(-1, 1)))\n",
"df_test = np.hstack((info_test, target_test.reshape(-1, 1)))\n",
"df_train = pd.DataFrame(df_train, columns=col_df)\n",
"df_test = pd.DataFrame(df_test, columns=col_df)\n",
"print(\"df_train:\\n\", df_train)\n",
"print(\"============================================\")\n",
"print(\"df_test:\\n\", df_test)"
]
},
{
"cell_type": "markdown",
"id": "1a9661ee-3dc4-4d5c-8a8d-bab3dc299d41",
"metadata": {},
"source": [
"设置训练集和验证集"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "a10021bb-4e08-49b2-8269-114b8ce37be6",
"metadata": {},
"outputs": [],
"source": [
"# 初始化数据集\n",
"train_data = USDataset(df_train)\n",
"test_data = USDataset(df_test)\n",
"train_loader = DataLoader(train_data, batch_size=batch_size,\n",
" num_workers=num_workers, shuffle=True, drop_last=True)\n",
"test_loader = DataLoader(test_data, batch_size=batch_size,\n",
" num_workers=num_workers, shuffle=False)"
]
},
{
"cell_type": "markdown",
"id": "41a15589-9aaf-4541-a603-30a9f433c385",
"metadata": {},
"source": [
"设计并初始化网络,其中神经网络是有一个输入层,一个输出层,三个隐藏层的前馈神经网络,激活函数是Sigmoid函数"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "764f28bc-81dd-4924-9039-5eceba54d739",
"metadata": {
"scrolled": true
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[Parameter containing:\n",
"tensor([[ 0.1981, -0.0925, 0.1532, -0.2224, 0.1906],\n",
" [-0.4302, -0.2503, -0.1189, 0.4182, -0.3575],\n",
" [ 0.1431, 0.3615, 0.2464, -0.0292, -0.3012],\n",
" [-0.2803, 0.2226, 0.2143, 0.2458, -0.2719],\n",
" [ 0.1669, 0.1216, -0.0148, 0.2795, -0.2338],\n",
" [-0.0047, 0.1909, -0.1634, 0.3766, 0.3018],\n",
" [ 0.4288, 0.3602, -0.3976, -0.1180, 0.3682],\n",
" [-0.1992, 0.1626, 0.3656, -0.4359, 0.2141],\n",
" [-0.3708, 0.1608, 0.0614, 0.0473, -0.3628],\n",
" [-0.3157, 0.1828, 0.3053, -0.2029, -0.3746],\n",
" [ 0.0809, -0.1369, 0.2581, 0.4280, 0.2290],\n",
" [-0.3749, -0.1776, 0.1791, -0.3965, -0.1430],\n",
" [ 0.3633, 0.1174, -0.4268, 0.1386, 0.3422],\n",
" [ 0.3554, -0.3894, -0.4382, 0.2104, 0.4403],\n",
" [-0.0714, -0.1196, -0.0719, 0.0933, 0.1312]], device='cuda:0',\n",
" requires_grad=True), Parameter containing:\n",
"tensor([-0.3285, 0.3194, 0.4210, 0.3921, 0.2482, 0.0288, -0.2423, 0.3984,\n",
" -0.4366, -0.0518, 0.2578, -0.2965, 0.3803, -0.1613, -0.3905],\n",
" device='cuda:0', requires_grad=True), Parameter containing:\n",
"tensor([[ 0.2330, 0.2382, 0.1661, -0.0358, -0.2211, -0.2119, 0.2325, 0.2296,\n",
" 0.1116, -0.1443, -0.1474, 0.1758, 0.1321, -0.1886, 0.1001],\n",
" [ 0.1013, -0.1303, -0.1558, -0.1694, -0.0610, 0.0903, 0.2512, 0.2028,\n",
" -0.0437, -0.0142, -0.0613, -0.1092, -0.2330, 0.1156, -0.0872],\n",
" [-0.2243, 0.1081, 0.1204, -0.1407, 0.2454, 0.1751, -0.0659, -0.0778,\n",
" 0.1699, 0.2420, 0.1921, -0.1972, 0.0297, -0.1185, 0.1182],\n",
" [ 0.1126, -0.1653, -0.2057, -0.0083, 0.0714, -0.1995, 0.1577, -0.2197,\n",
" -0.0100, -0.1859, -0.2013, -0.2470, -0.1084, -0.0784, -0.0567],\n",
" [-0.2133, 0.0452, 0.0957, 0.1400, -0.1520, -0.0957, 0.1586, 0.0450,\n",
" -0.1405, -0.1795, -0.1409, 0.2279, -0.1048, -0.0113, 0.1268],\n",
" [-0.2536, -0.1572, 0.2436, -0.1054, -0.2366, -0.1888, 0.0218, -0.0935,\n",
" -0.1377, 0.0630, 0.0064, -0.0734, 0.0873, 0.1178, -0.1568],\n",
" [-0.0283, 0.2286, 0.2539, 0.1228, 0.2385, -0.2216, -0.1502, -0.2049,\n",
" 0.0286, 0.1939, 0.0313, 0.1531, 0.2514, -0.2186, 0.0171],\n",
" [ 0.1501, 0.0752, -0.1921, -0.1191, -0.0748, 0.2097, 0.2306, -0.2150,\n",
" -0.2348, 0.0575, -0.0489, 0.0678, -0.1273, -0.1990, -0.0770],\n",
" [-0.0807, -0.0379, -0.1252, -0.1873, -0.0793, 0.0908, 0.1573, 0.1212,\n",
" 0.0937, 0.1422, -0.1502, -0.0369, 0.0321, -0.0830, 0.0571],\n",
" [-0.1586, -0.0179, 0.2351, 0.1576, 0.0692, -0.0726, -0.0875, -0.0382,\n",
" -0.2227, -0.0080, 0.0904, 0.1201, 0.2029, 0.0351, 0.0133]],\n",
" device='cuda:0', requires_grad=True), Parameter containing:\n",
"tensor([-0.0141, 0.0566, -0.2259, -0.1533, -0.0400, 0.0475, -0.1080, -0.2453,\n",
" -0.1625, 0.1345], device='cuda:0', requires_grad=True), Parameter containing:\n",
"tensor([[-0.2120, -0.1994, -0.0052, 0.2988, -0.2471, 0.0135, 0.1283, -0.0201,\n",
" 0.1182, -0.1972],\n",
" [-0.0328, 0.0113, 0.1010, 0.0589, -0.1486, 0.2598, 0.1771, 0.0474,\n",
" -0.0413, 0.1537],\n",
" [ 0.2083, -0.1183, -0.2833, -0.3092, 0.2081, 0.2566, 0.1134, -0.1159,\n",
" -0.2981, -0.2882],\n",
" [-0.1033, -0.2929, -0.2451, -0.2850, 0.3157, 0.3092, 0.0539, -0.2594,\n",
" 0.1327, -0.0194],\n",
" [ 0.1040, -0.3053, -0.1769, -0.2137, -0.0262, 0.2108, 0.0255, -0.1202,\n",
" 0.0413, 0.2326],\n",
" [-0.0355, -0.2424, 0.1064, 0.0394, -0.1730, 0.2130, 0.1174, -0.1641,\n",
" -0.1420, -0.0496]], device='cuda:0', requires_grad=True), Parameter containing:\n",
"tensor([ 0.2195, -0.0579, 0.1234, 0.1060, -0.2140, 0.0288], device='cuda:0',\n",
" requires_grad=True), Parameter containing:\n",
"tensor([[ 0.2707, 0.1434, 0.0318, -0.0551, -0.1369, -0.1496]],\n",
" device='cuda:0', requires_grad=True), Parameter containing:\n",
"tensor([-0.0661], device='cuda:0', requires_grad=True)]\n",
"==================================================\n",
"[Parameter containing:\n",
"tensor([[-6.9726e-02, 1.7731e+00, -7.6372e-02, -7.1058e-01, -7.1123e-04],\n",
" [-7.1336e-01, -5.4361e-01, -2.6582e-01, 5.6964e-01, -4.9271e-01],\n",
" [-1.3366e+00, -6.1715e-01, -8.0915e-01, -1.0871e+00, -1.7130e+00],\n",
" [ 1.1439e+00, 1.1702e+00, -5.0160e-02, -6.6409e-01, -9.0478e-02],\n",
" [-1.2307e-01, 9.7143e-01, -3.2899e-01, 1.0481e+00, 8.7611e-02],\n",
" [-3.6240e-01, 6.6012e-01, -1.7746e+00, 1.7587e-01, 3.9182e-01],\n",
" [ 9.5627e-01, -1.4544e+00, 8.6239e-01, 6.4864e-01, 1.2865e+00],\n",
" [-9.4586e-01, -2.7911e+00, 1.3601e+00, -8.6802e-01, 3.7288e-01],\n",
" [ 6.0642e-01, 8.3700e-01, 4.6751e-01, 9.6069e-01, 1.0403e+00],\n",
" [-2.4263e-01, -1.1883e+00, -1.1396e+00, 1.4794e+00, -6.1805e-01],\n",
" [-4.5019e-01, 3.1949e-01, 2.2522e-01, -1.4746e+00, -1.5766e+00],\n",
" [-1.0685e+00, -8.6054e-01, -4.0435e-01, 1.1669e-01, -5.0217e-01],\n",
" [ 3.2439e-01, -6.3227e-01, -9.7261e-01, 4.4522e-01, -1.1293e+00],\n",
" [-7.5776e-01, -5.1496e-01, 1.3075e+00, -3.3717e-01, 3.9504e-01],\n",
" [ 8.5924e-01, -1.0643e-01, -1.2336e+00, 1.7424e+00, -3.5817e-01]],\n",
" device='cuda:0', requires_grad=True), Parameter containing:\n",
"tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n",
" device='cuda:0', requires_grad=True), Parameter containing:\n",
"tensor([[-0.5137, -0.3243, 0.4541, 0.4884, -0.3226, -0.9424, 1.2633, -0.5411,\n",
" -0.4047, -0.2105, -1.2065, 0.1194, -0.8706, -0.3806, -0.3470],\n",
" [ 1.1079, -0.6157, 0.0802, 0.1673, -1.1413, 0.3360, 2.0812, -0.4530,\n",
" 1.9244, -0.1222, 0.8667, -0.4083, -1.5679, 1.3790, -0.7961],\n",
" [-1.4472, 0.9576, 0.8538, 0.3888, -2.0439, -0.0737, -1.1856, -0.7158,\n",
" 2.7370, 2.5444, -0.7023, 1.4209, 0.0901, 0.8644, 0.7805],\n",
" [-0.0899, 0.7839, -0.3986, -0.0622, 0.5134, 0.8943, -0.7492, -1.7582,\n",
" 0.7121, 1.2855, -0.4851, 0.0736, -0.2256, 0.4992, -0.7854],\n",
" [ 2.5287, 0.5654, 0.7835, -0.6008, -1.9796, 1.6482, -1.0928, -0.0595,\n",
" -0.0564, -0.0729, -0.2910, -0.9512, -1.3852, -0.2895, -0.8505],\n",
" [-1.4534, -1.3051, 0.1817, 0.5341, 1.1192, -0.0837, -0.5133, -1.2826,\n",
" 1.5556, -1.0946, -1.0002, 0.3343, 0.1432, 1.8591, 0.0199],\n",
" [ 1.2811, -1.2248, 0.5370, 2.1831, 1.1538, 0.9405, 1.3957, 0.2959,\n",
" 1.1088, 1.6196, -0.1388, -0.0548, -0.2318, 0.1342, -0.1255],\n",
" [ 1.2155, -0.1889, -0.3502, -0.3293, 0.3099, 2.5889, 1.1579, 0.2868,\n",
" -0.1206, 0.3678, 0.2955, 0.7390, -1.1026, -0.5398, 0.0348],\n",
" [-0.9189, 0.8586, 1.3321, -0.7081, 0.0048, 0.7643, -0.5209, 0.2017,\n",
" 0.8094, -1.4677, 1.0291, 0.1726, 0.3298, 1.6261, -0.1650],\n",
" [ 0.7050, -0.3248, -0.2147, -0.4022, -0.3151, 0.9486, 1.3281, -0.2464,\n",
" -0.2804, 1.7853, -0.8694, -1.3244, 0.4455, 0.6532, 1.1969]],\n",
" device='cuda:0', requires_grad=True), Parameter containing:\n",
"tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], device='cuda:0',\n",
" requires_grad=True), Parameter containing:\n",
"tensor([[ 0.7268, -0.5613, 0.3554, -0.7328, 1.7000, -0.1115, -0.4507, 0.0304,\n",
" 0.5701, 1.0746],\n",
" [-0.7969, 0.7418, 1.5317, -1.0003, -0.2625, 1.7585, -3.1004, -1.2840,\n",
" -0.2230, 0.9398],\n",
" [-0.2167, -0.4646, -0.1701, -1.1754, 1.2224, -0.1925, -1.0030, -1.3171,\n",
" 0.7915, -1.7573],\n",
" [ 1.1589, -0.5555, 0.6436, -0.1423, -0.0412, -0.2644, 0.0179, 1.9647,\n",
" -1.0822, 0.4064],\n",
" [ 0.2140, -0.6242, 1.9719, -1.6243, -0.7416, 1.2082, 0.9993, 2.4522,\n",
" 1.2368, -0.9525],\n",
" [ 0.9530, 0.5625, -0.6094, -0.3213, 0.6820, -0.6440, 1.5492, -0.4234,\n",
" -0.3488, -0.1487]], device='cuda:0', requires_grad=True), Parameter containing:\n",
"tensor([0., 0., 0., 0., 0., 0.], device='cuda:0', requires_grad=True), Parameter containing:\n",
"tensor([[-1.2324, -0.2007, -1.9850, -1.1963, 0.4371, -0.8892]],\n",
" device='cuda:0', requires_grad=True), Parameter containing:\n",
"tensor([0.], device='cuda:0', requires_grad=True)]\n"
]
}
],
"source": [
"def initialize(self):\n",
" '''\n",
" 初始化网络参数\n",
" '''\n",
" for m in self.modules():\n",
" if isinstance(m, nn.Linear):\n",
" torch.nn.init.normal_(m.weight.data, 0.1)\n",
" if m.bias is not None:\n",
" torch.nn.init.zeros_(m.bias.data)\n",
"\n",
"\n",
"class Net(nn.Module):\n",
" '''\n",
" 网络结构\n",
" '''\n",
"\n",
" def __init__(self, **kwargs):\n",
" super(Net, self).__init__()\n",
" self.fc = nn.Sequential(\n",
" nn.Linear(5, 15),\n",
" nn.Sigmoid(),\n",
" nn.Linear(15, 10),\n",
" nn.Sigmoid(),\n",
" nn.Linear(10, 6),\n",
" nn.Sigmoid(),\n",
" nn.Linear(6, 1)\n",
" )\n",
"\n",
" def forward(self, x):\n",
" x = x.view(-1, 5)\n",
" x = self.fc(x)\n",
" return x.squeeze(-1)\n",
"\n",
"\n",
"# 初始化网络\n",
"model = Net()\n",
"model = model.cuda()\n",
"print(list(model.parameters()))\n",
"initialize(model)\n",
"print(\"==================================================\")\n",
"print(list(model.parameters()))"
]
},
{
"cell_type": "markdown",
"id": "20fcccfe-3f82-44c3-bba2-9c66fb15cecd",
"metadata": {},
"source": [
"定义损失函数为均方误差损失函数,优化算法为随机梯度下降"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "7b530a85-50b1-4d0f-b47e-14831d0ee85d",
"metadata": {},
"outputs": [],
"source": [
"# 损失函数\n",
"criterion = nn.MSELoss()\n",
"# 优化器\n",
"optimizer = optim.SGD(model.parameters(), lr=lr)"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "8694737f-afb5-4e58-ad14-3220971b142f",
"metadata": {},
"outputs": [],
"source": [
"def train(epoch):\n",
" '''\n",
" 训练器\n",
"\n",
" Parameters\n",
" ----------\n",
" epoch : int\n",
"\n",
" Returns\n",
" -------\n",
" None.\n",
"\n",
" '''\n",
" model.train()\n",
" train_loss = 0\n",
" for info, target in train_loader:\n",
" info = info.cuda()\n",
" target = target.cuda()\n",
" optimizer.zero_grad()\n",
" output = model(info)\n",
" loss = criterion(output, target)\n",
" loss.backward()\n",
" optimizer.step()\n",
" train_loss += loss.item()*info.size(0)\n",
" train_loss = train_loss/len(train_loader.dataset)\n",
" print(\"epoch:%d, train loss:%.4f\" % (epoch, train_loss))"
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "bb6be6f8-16e6-4c2b-a56c-232b09faa654",
"metadata": {},
"outputs": [],
"source": [
"def validate(epoch):\n",
" '''\n",
" 验证\n",
"\n",
" Parameters\n",
" ----------\n",
" epoch : int\n",
"\n",
" Returns\n",
" -------\n",
" None.\n",
"\n",
" '''\n",
" model.eval()\n",
" val_loss = 0\n",
" with torch.no_grad():\n",
" for info, target in test_loader:\n",
" info, target = info.cuda(), target.cuda()\n",
" output = model(info)\n",
" loss = criterion(output, target)\n",
" val_loss += loss.item()*info.size(0)\n",
" val_loss = val_loss/len(test_loader.dataset)\n",
" print(\"epoch:%d, validation loss:%.4f\" %\n",
" (epoch, val_loss))"
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "40532ea8-acfe-469a-aa9f-14d3bec2619d",
"metadata": {
"scrolled": true
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"epoch:1, train loss:2.0057\n",
"epoch:1, validation loss:0.2197\n",
"epoch:2, train loss:0.2096\n",
"epoch:2, validation loss:0.2196\n",
"epoch:3, train loss:0.2098\n",
"epoch:3, validation loss:0.2196\n",
"epoch:4, train loss:0.2098\n",
"epoch:4, validation loss:0.2196\n",
"epoch:5, train loss:0.2100\n",
"epoch:5, validation loss:0.2195\n",
"epoch:6, train loss:0.2099\n",
"epoch:6, validation loss:0.2195\n",
"epoch:7, train loss:0.2101\n",
"epoch:7, validation loss:0.2195\n",
"epoch:8, train loss:0.2099\n",
"epoch:8, validation loss:0.2194\n",
"epoch:9, train loss:0.2099\n",
"epoch:9, validation loss:0.2194\n",
"epoch:10, train loss:0.2099\n",
"epoch:10, validation loss:0.2194\n",
"epoch:11, train loss:0.2093\n",
"epoch:11, validation loss:0.2193\n",
"epoch:12, train loss:0.2099\n",
"epoch:12, validation loss:0.2193\n",
"epoch:13, train loss:0.2100\n",
"epoch:13, validation loss:0.2193\n",
"epoch:14, train loss:0.2097\n",
"epoch:14, validation loss:0.2193\n",
"epoch:15, train loss:0.2094\n",
"epoch:15, validation loss:0.2192\n",
"epoch:16, train loss:0.2091\n",
"epoch:16, validation loss:0.2192\n",
"epoch:17, train loss:0.2098\n",
"epoch:17, validation loss:0.2192\n",
"epoch:18, train loss:0.2089\n",
"epoch:18, validation loss:0.2191\n",
"epoch:19, train loss:0.2091\n",
"epoch:19, validation loss:0.2191\n",
"epoch:20, train loss:0.2096\n",
"epoch:20, validation loss:0.2190\n",
"epoch:21, train loss:0.2092\n",
"epoch:21, validation loss:0.2190\n",
"epoch:22, train loss:0.2084\n",
"epoch:22, validation loss:0.2190\n",
"epoch:23, train loss:0.2093\n",
"epoch:23, validation loss:0.2190\n",
"epoch:24, train loss:0.2088\n",
"epoch:24, validation loss:0.2190\n",
"epoch:25, train loss:0.2088\n",
"epoch:25, validation loss:0.2189\n",
"epoch:26, train loss:0.2090\n",
"epoch:26, validation loss:0.2189\n",
"epoch:27, train loss:0.2092\n",
"epoch:27, validation loss:0.2189\n",
"epoch:28, train loss:0.2086\n",
"epoch:28, validation loss:0.2188\n",
"epoch:29, train loss:0.2087\n",
"epoch:29, validation loss:0.2189\n",
"epoch:30, train loss:0.2089\n",
"epoch:30, validation loss:0.2188\n",
"epoch:31, train loss:0.2090\n",
"epoch:31, validation loss:0.2187\n",
"epoch:32, train loss:0.2087\n",
"epoch:32, validation loss:0.2187\n",
"epoch:33, train loss:0.2092\n",
"epoch:33, validation loss:0.2187\n",
"epoch:34, train loss:0.2092\n",
"epoch:34, validation loss:0.2187\n",
"epoch:35, train loss:0.2091\n",
"epoch:35, validation loss:0.2186\n",
"epoch:36, train loss:0.2089\n",
"epoch:36, validation loss:0.2186\n",
"epoch:37, train loss:0.2086\n",
"epoch:37, validation loss:0.2186\n",
"epoch:38, train loss:0.2090\n",
"epoch:38, validation loss:0.2186\n",
"epoch:39, train loss:0.2088\n",
"epoch:39, validation loss:0.2185\n",
"epoch:40, train loss:0.2087\n",
"epoch:40, validation loss:0.2185\n",
"epoch:41, train loss:0.2087\n",
"epoch:41, validation loss:0.2186\n",
"epoch:42, train loss:0.2085\n",
"epoch:42, validation loss:0.2185\n",
"epoch:43, train loss:0.2088\n",
"epoch:43, validation loss:0.2184\n",
"epoch:44, train loss:0.2089\n",
"epoch:44, validation loss:0.2184\n",
"epoch:45, train loss:0.2089\n",
"epoch:45, validation loss:0.2184\n",
"epoch:46, train loss:0.2086\n",
"epoch:46, validation loss:0.2184\n",
"epoch:47, train loss:0.2089\n",
"epoch:47, validation loss:0.2183\n",
"epoch:48, train loss:0.2084\n",
"epoch:48, validation loss:0.2183\n",
"epoch:49, train loss:0.2080\n",
"epoch:49, validation loss:0.2183\n",
"epoch:50, train loss:0.2088\n",
"epoch:50, validation loss:0.2183\n",
"epoch:51, train loss:0.2086\n",
"epoch:51, validation loss:0.2183\n",
"epoch:52, train loss:0.2082\n",
"epoch:52, validation loss:0.2182\n",
"epoch:53, train loss:0.2080\n",
"epoch:53, validation loss:0.2182\n",
"epoch:54, train loss:0.2088\n",
"epoch:54, validation loss:0.2182\n",
"epoch:55, train loss:0.2086\n",
"epoch:55, validation loss:0.2181\n",
"epoch:56, train loss:0.2082\n",
"epoch:56, validation loss:0.2181\n",
"epoch:57, train loss:0.2088\n",
"epoch:57, validation loss:0.2181\n",
"epoch:58, train loss:0.2084\n",
"epoch:58, validation loss:0.2181\n",
"epoch:59, train loss:0.2085\n",
"epoch:59, validation loss:0.2180\n",
"epoch:60, train loss:0.2085\n",
"epoch:60, validation loss:0.2180\n",
"epoch:61, train loss:0.2084\n",
"epoch:61, validation loss:0.2180\n",
"epoch:62, train loss:0.2082\n",
"epoch:62, validation loss:0.2180\n",
"epoch:63, train loss:0.2081\n",
"epoch:63, validation loss:0.2180\n",
"epoch:64, train loss:0.2081\n",
"epoch:64, validation loss:0.2179\n",
"epoch:65, train loss:0.2078\n",
"epoch:65, validation loss:0.2179\n",
"epoch:66, train loss:0.2079\n",
"epoch:66, validation loss:0.2179\n",
"epoch:67, train loss:0.2080\n",
"epoch:67, validation loss:0.2179\n",
"epoch:68, train loss:0.2080\n",
"epoch:68, validation loss:0.2178\n",
"epoch:69, train loss:0.2082\n",
"epoch:69, validation loss:0.2178\n",
"epoch:70, train loss:0.2079\n",
"epoch:70, validation loss:0.2178\n",
"epoch:71, train loss:0.2083\n",
"epoch:71, validation loss:0.2178\n",
"epoch:72, train loss:0.2081\n",
"epoch:72, validation loss:0.2177\n",
"epoch:73, train loss:0.2079\n",
"epoch:73, validation loss:0.2177\n",
"epoch:74, train loss:0.2076\n",
"epoch:74, validation loss:0.2177\n",
"epoch:75, train loss:0.2081\n",
"epoch:75, validation loss:0.2177\n",
"epoch:76, train loss:0.2081\n",
"epoch:76, validation loss:0.2176\n",
"epoch:77, train loss:0.2080\n",
"epoch:77, validation loss:0.2176\n",
"epoch:78, train loss:0.2079\n",
"epoch:78, validation loss:0.2176\n",
"epoch:79, train loss:0.2081\n",
"epoch:79, validation loss:0.2176\n",
"epoch:80, train loss:0.2078\n",
"epoch:80, validation loss:0.2175\n",
"epoch:81, train loss:0.2076\n",
"epoch:81, validation loss:0.2175\n",
"epoch:82, train loss:0.2077\n",
"epoch:82, validation loss:0.2175\n",
"epoch:83, train loss:0.2077\n",
"epoch:83, validation loss:0.2175\n",
"epoch:84, train loss:0.2080\n",
"epoch:84, validation loss:0.2175\n",
"epoch:85, train loss:0.2074\n",
"epoch:85, validation loss:0.2174\n",
"epoch:86, train loss:0.2077\n",
"epoch:86, validation loss:0.2174\n",
"epoch:87, train loss:0.2079\n",
"epoch:87, validation loss:0.2174\n",
"epoch:88, train loss:0.2079\n",
"epoch:88, validation loss:0.2173\n",
"epoch:89, train loss:0.2075\n",
"epoch:89, validation loss:0.2173\n",
"epoch:90, train loss:0.2075\n",
"epoch:90, validation loss:0.2173\n",
"epoch:91, train loss:0.2070\n",
"epoch:91, validation loss:0.2173\n",
"epoch:92, train loss:0.2079\n",
"epoch:92, validation loss:0.2173\n",
"epoch:93, train loss:0.2075\n",
"epoch:93, validation loss:0.2172\n",
"epoch:94, train loss:0.2078\n",
"epoch:94, validation loss:0.2172\n",
"epoch:95, train loss:0.2068\n",
"epoch:95, validation loss:0.2172\n",
"epoch:96, train loss:0.2072\n",
"epoch:96, validation loss:0.2172\n",
"epoch:97, train loss:0.2076\n",
"epoch:97, validation loss:0.2172\n",
"epoch:98, train loss:0.2079\n",
"epoch:98, validation loss:0.2171\n",
"epoch:99, train loss:0.2074\n",
"epoch:99, validation loss:0.2171\n",
"epoch:100, train loss:0.2074\n",
"epoch:100, validation loss:0.2171\n"
]
}
],
"source": [
"# 训练过程\n",
"for epoch in range(1, max_epochs+1):\n",
" train(epoch)\n",
" validate(epoch)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "e3f438d3-813f-4db5-b8e9-da71d54730d4",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.12"
}
},
"nbformat": 4,
"nbformat_minor": 5
}