删除文件 共享民宿平台担保交易房子评分的影响研究/torch_reg.ipynb
This commit is contained in:
@@ -1,863 +0,0 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "a31d3c40-3593-4ca3-980e-578ee51e171a",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# 用神经网络进行回归预测"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 1,
|
||||
"id": "738c03b4-3ca0-4a87-9143-53ddea5179be",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import numpy as np\n",
|
||||
"import pandas as pd\n",
|
||||
"import torch\n",
|
||||
"import torch.nn as nn\n",
|
||||
"import torch.optim as optim\n",
|
||||
"from torch.utils.data import Dataset, DataLoader\n",
|
||||
"from sklearn.preprocessing import StandardScaler\n",
|
||||
"from sklearn.model_selection import train_test_split"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "637063c9-5924-416c-8201-adaae8514ffd",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"设置神经网络超参数:批大小为500,学习率为0.01,共训练100个epoch(每个epoch内对每个批次各进行一次正向传播和反向传播)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 2,
|
||||
"id": "b909615c-783d-4bfb-acf6-1bc3bacfbb18",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# torch参数\n",
|
||||
"batch_size = 500\n",
|
||||
"lr = 0.01\n",
|
||||
"max_epochs = 100\n",
|
||||
"num_workers = 0\n",
|
||||
"device = torch.device(\"cuda:1\" if torch.cuda.is_available() else \"cpu\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "904cfade-5667-440a-a13c-34fdf57f7b0f",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"定义所需数据集类"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"id": "f6e65156-8535-49c2-8d9d-ecb8639855c9",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"variables = ['number_of_reviews', 'price', 'review_scores_rating']\n",
|
||||
"# 数据集类\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"class USDataset(Dataset):\n",
|
||||
" def __init__(self, df):\n",
|
||||
" '''\n",
|
||||
" 初始化\n",
|
||||
" df: 处理后的数据集\n",
|
||||
" '''\n",
|
||||
" self.df = df\n",
|
||||
" self.info = df[['number_of_reviews', 'price']].values\n",
|
||||
" self.target = df['review_scores_rating'].values\n",
|
||||
"\n",
|
||||
" def __getitem__(self, index):\n",
|
||||
" '''\n",
|
||||
" 根据编号返回信息\n",
|
||||
" index: 样本编号\n",
|
||||
" '''\n",
|
||||
" info = self.info[index]\n",
|
||||
" target = self.target[index]\n",
|
||||
" return info, target\n",
|
||||
"\n",
|
||||
" def __len__(self):\n",
|
||||
" '''\n",
|
||||
" 返回数据集样本个数\n",
|
||||
" '''\n",
|
||||
" return len(self.df)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "fc574f2e-6443-48af-9fee-dcf5fb29af58",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"读入数据,去掉price列中的美元符号和千位分隔符,并将各变量转换为float32;由于样本量很大,直接删除有缺失值的样本。"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 4,
|
||||
"id": "17f8aedb-abf1-4622-98ec-d525a779274b",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/html": [
|
||||
"<div>\n",
|
||||
"<style scoped>\n",
|
||||
" .dataframe tbody tr th:only-of-type {\n",
|
||||
" vertical-align: middle;\n",
|
||||
" }\n",
|
||||
"\n",
|
||||
" .dataframe tbody tr th {\n",
|
||||
" vertical-align: top;\n",
|
||||
" }\n",
|
||||
"\n",
|
||||
" .dataframe thead th {\n",
|
||||
" text-align: right;\n",
|
||||
" }\n",
|
||||
"</style>\n",
|
||||
"<table border=\"1\" class=\"dataframe\">\n",
|
||||
" <thead>\n",
|
||||
" <tr style=\"text-align: right;\">\n",
|
||||
" <th></th>\n",
|
||||
" <th>price</th>\n",
|
||||
" <th>number_of_reviews</th>\n",
|
||||
" <th>review_scores_rating</th>\n",
|
||||
" </tr>\n",
|
||||
" </thead>\n",
|
||||
" <tbody>\n",
|
||||
" <tr>\n",
|
||||
" <th>0</th>\n",
|
||||
" <td>120.0</td>\n",
|
||||
" <td>90.0</td>\n",
|
||||
" <td>4.50</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>1</th>\n",
|
||||
" <td>90.0</td>\n",
|
||||
" <td>351.0</td>\n",
|
||||
" <td>4.58</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>2</th>\n",
|
||||
" <td>66.0</td>\n",
|
||||
" <td>67.0</td>\n",
|
||||
" <td>4.52</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>3</th>\n",
|
||||
" <td>33.0</td>\n",
|
||||
" <td>297.0</td>\n",
|
||||
" <td>4.70</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>4</th>\n",
|
||||
" <td>125.0</td>\n",
|
||||
" <td>58.0</td>\n",
|
||||
" <td>4.96</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>...</th>\n",
|
||||
" <td>...</td>\n",
|
||||
" <td>...</td>\n",
|
||||
" <td>...</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>203252</th>\n",
|
||||
" <td>152.0</td>\n",
|
||||
" <td>1.0</td>\n",
|
||||
" <td>4.00</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>203253</th>\n",
|
||||
" <td>45.0</td>\n",
|
||||
" <td>1.0</td>\n",
|
||||
" <td>3.00</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>203254</th>\n",
|
||||
" <td>40.0</td>\n",
|
||||
" <td>1.0</td>\n",
|
||||
" <td>1.00</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>203276</th>\n",
|
||||
" <td>43.0</td>\n",
|
||||
" <td>1.0</td>\n",
|
||||
" <td>5.00</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>203308</th>\n",
|
||||
" <td>110.0</td>\n",
|
||||
" <td>1.0</td>\n",
|
||||
" <td>5.00</td>\n",
|
||||
" </tr>\n",
|
||||
" </tbody>\n",
|
||||
"</table>\n",
|
||||
"<p>162429 rows × 3 columns</p>\n",
|
||||
"</div>"
|
||||
],
|
||||
"text/plain": [
|
||||
" price number_of_reviews review_scores_rating\n",
|
||||
"0 120.0 90.0 4.50\n",
|
||||
"1 90.0 351.0 4.58\n",
|
||||
"2 66.0 67.0 4.52\n",
|
||||
"3 33.0 297.0 4.70\n",
|
||||
"4 125.0 58.0 4.96\n",
|
||||
"... ... ... ...\n",
|
||||
"203252 152.0 1.0 4.00\n",
|
||||
"203253 45.0 1.0 3.00\n",
|
||||
"203254 40.0 1.0 1.00\n",
|
||||
"203276 43.0 1.0 5.00\n",
|
||||
"203308 110.0 1.0 5.00\n",
|
||||
"\n",
|
||||
"[162429 rows x 3 columns]"
|
||||
]
|
||||
},
|
||||
"execution_count": 4,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"# 读入数据\n",
|
||||
"df = pd.read_csv('./data/2022-01(US_25).csv', usecols=variables)\n",
|
||||
"df['price'] = df['price'].replace('\\$', '', regex=True)\n",
|
||||
"df['price'] = df['price'].replace('\\,', '', regex=True).astype(float)\n",
|
||||
"df[['number_of_reviews']] = df[['number_of_reviews']].astype(float)\n",
|
||||
"for col in variables:\n",
|
||||
" df[col] = df[col].astype(np.float32)\n",
|
||||
" df = df[np.isnan(df[col]) != 1]\n",
|
||||
"df"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "758d8171-9310-4a51-b9a6-b0fa6274ffe1",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"按7:3固定划分训练集和测试集(random_state=420),再对特征做标准化:标准化参数只在训练集上拟合,然后应用到测试集;目标变量review_scores_rating不做标准化"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 5,
|
||||
"id": "2cbe7cab-42cd-4b3b-bf11-080cc0a0f4c4",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"df_train:\n",
|
||||
" price number_of_reviews review_scores_rating\n",
|
||||
"0 -0.331270 -0.218031 4.71\n",
|
||||
"1 -0.030076 0.450276 4.99\n",
|
||||
"2 0.425112 -0.603593 5.00\n",
|
||||
"3 -0.324476 -0.590741 2.50\n",
|
||||
"4 0.017481 -0.577889 4.00\n",
|
||||
"... ... ... ...\n",
|
||||
"113695 -0.381091 0.128974 4.98\n",
|
||||
"113696 -0.220303 -0.282292 4.96\n",
|
||||
"113697 -0.165953 -0.603593 3.00\n",
|
||||
"113698 -0.152365 -0.539333 4.00\n",
|
||||
"113699 0.162417 -0.590741 5.00\n",
|
||||
"\n",
|
||||
"[113700 rows x 3 columns]\n",
|
||||
"============================================\n",
|
||||
"df_test:\n",
|
||||
" price number_of_reviews review_scores_rating\n",
|
||||
"0 -0.249743 0.411720 4.84\n",
|
||||
"1 -0.211245 -0.565037 5.00\n",
|
||||
"2 -0.376562 -0.603593 5.00\n",
|
||||
"3 -0.088956 -0.539333 4.33\n",
|
||||
"4 0.517962 0.013306 4.90\n",
|
||||
"... ... ... ...\n",
|
||||
"48724 0.028804 -0.192327 4.91\n",
|
||||
"48725 -0.367503 -0.500777 5.00\n",
|
||||
"48726 -0.360710 -0.436517 4.86\n",
|
||||
"48727 -0.084426 -0.603593 1.00\n",
|
||||
"48728 0.067303 -0.603593 5.00\n",
|
||||
"\n",
|
||||
"[48729 rows x 3 columns]\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"# 固定划分测试集和训练集\n",
|
||||
"col_df = df.columns\n",
|
||||
"info = df.iloc[:, :-1].values\n",
|
||||
"target = df.iloc[:, -1].values\n",
|
||||
"# 标准化\n",
|
||||
"stdscaler = StandardScaler()\n",
|
||||
"info_train, info_test, target_train, target_test = train_test_split(\n",
|
||||
" info, target, test_size=0.3, random_state=420)\n",
|
||||
"info_train = stdscaler.fit_transform(info_train)\n",
|
||||
"info_test = stdscaler.transform(info_test)\n",
|
||||
"df_train = np.hstack((info_train, target_train.reshape(-1, 1)))\n",
|
||||
"df_test = np.hstack((info_test, target_test.reshape(-1, 1)))\n",
|
||||
"df_train = pd.DataFrame(df_train, columns=col_df)\n",
|
||||
"df_test = pd.DataFrame(df_test, columns=col_df)\n",
|
||||
"print(\"df_train:\\n\", df_train)\n",
|
||||
"print(\"============================================\")\n",
|
||||
"print(\"df_test:\\n\", df_test)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "1a9661ee-3dc4-4d5c-8a8d-bab3dc299d41",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"用划分好的训练集和测试集构建Dataset与DataLoader(测试集在后续训练过程中作为验证集使用)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 6,
|
||||
"id": "a10021bb-4e08-49b2-8269-114b8ce37be6",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# 初始化数据集\n",
|
||||
"train_data = USDataset(df_train)\n",
|
||||
"test_data = USDataset(df_test)\n",
|
||||
"train_loader = DataLoader(train_data, batch_size=batch_size,\n",
|
||||
" num_workers=num_workers, shuffle=True, drop_last=True)\n",
|
||||
"test_loader = DataLoader(test_data, batch_size=batch_size,\n",
|
||||
" num_workers=num_workers, shuffle=False)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "41a15589-9aaf-4541-a603-30a9f433c385",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"设计并初始化网络:网络是含一个输入层(2个特征)、三个隐藏层(分别为15、10、6个神经元)和一个输出层(1个输出)的前馈神经网络,隐藏层的激活函数是Sigmoid函数;初始化时权重取自均值为0.1、标准差为1的正态分布,偏置置零"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 7,
|
||||
"id": "764f28bc-81dd-4924-9039-5eceba54d739",
|
||||
"metadata": {
|
||||
"scrolled": true
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"[Parameter containing:\n",
|
||||
"tensor([[-0.3798, 0.4787],\n",
|
||||
" [-0.6687, 0.3603],\n",
|
||||
" [-0.4871, 0.1701],\n",
|
||||
" [ 0.1127, -0.2635],\n",
|
||||
" [ 0.6846, 0.6578],\n",
|
||||
" [-0.2820, 0.6887],\n",
|
||||
" [ 0.6581, -0.6334],\n",
|
||||
" [ 0.1057, -0.2249],\n",
|
||||
" [ 0.0650, 0.6174],\n",
|
||||
" [ 0.1768, -0.5373],\n",
|
||||
" [-0.5809, 0.6241],\n",
|
||||
" [ 0.1711, 0.3307],\n",
|
||||
" [ 0.0546, 0.6606],\n",
|
||||
" [ 0.0528, 0.1988],\n",
|
||||
" [ 0.6816, -0.2682]], device='cuda:0', requires_grad=True), Parameter containing:\n",
|
||||
"tensor([-0.0835, -0.5775, -0.2021, 0.6027, -0.6389, 0.3593, 0.1442, 0.0746,\n",
|
||||
" -0.6356, 0.3029, -0.5204, -0.4724, 0.2451, -0.0103, -0.6133],\n",
|
||||
" device='cuda:0', requires_grad=True), Parameter containing:\n",
|
||||
"tensor([[-0.2190, 0.1091, 0.1611, 0.0540, 0.2386, -0.0399, 0.0086, -0.2273,\n",
|
||||
" 0.1421, 0.0286, 0.0507, -0.1721, 0.1504, -0.1714, 0.0113],\n",
|
||||
" [ 0.1312, 0.0581, 0.1805, 0.1385, -0.1394, -0.2539, -0.0047, 0.0753,\n",
|
||||
" 0.2134, 0.1119, 0.0534, -0.1073, 0.2128, 0.1909, 0.2403],\n",
|
||||
" [ 0.0414, 0.1597, -0.2470, -0.0118, 0.1226, -0.1825, 0.2529, 0.2161,\n",
|
||||
" 0.0246, 0.2333, -0.0733, 0.1474, 0.2101, 0.0246, 0.0514],\n",
|
||||
" [ 0.1203, 0.0301, 0.1293, -0.0800, -0.1533, -0.1585, -0.0815, -0.1242,\n",
|
||||
" 0.1415, -0.1520, 0.1287, -0.0701, -0.2560, -0.0094, -0.2361],\n",
|
||||
" [-0.1045, 0.0645, 0.2181, -0.0610, -0.0585, -0.1396, 0.2380, 0.2466,\n",
|
||||
" -0.2495, -0.0690, 0.1078, -0.1838, 0.0231, -0.0264, -0.1335],\n",
|
||||
" [ 0.0328, 0.2338, 0.0761, 0.2368, -0.0891, -0.1581, 0.1605, -0.0699,\n",
|
||||
" 0.1658, -0.0683, -0.2559, -0.0122, 0.1695, 0.2399, -0.1445],\n",
|
||||
" [-0.2489, 0.1559, -0.0083, 0.0980, -0.1653, 0.0678, -0.1833, -0.0217,\n",
|
||||
" -0.0746, 0.0633, -0.0559, 0.0334, 0.1273, -0.1259, -0.1527],\n",
|
||||
" [ 0.2253, -0.0182, 0.1613, -0.1602, -0.2205, -0.2047, 0.2046, -0.0824,\n",
|
||||
" 0.1261, -0.1015, 0.2257, -0.1276, -0.0197, -0.0505, -0.2408],\n",
|
||||
" [ 0.1883, -0.0590, -0.1909, -0.0710, -0.0304, -0.0085, -0.1848, 0.0480,\n",
|
||||
" -0.0447, 0.1001, 0.2330, -0.1368, -0.1478, -0.1676, 0.1942],\n",
|
||||
" [ 0.0554, -0.1686, -0.2217, -0.0343, 0.0218, 0.2300, 0.0303, 0.0711,\n",
|
||||
" -0.0932, -0.1791, 0.0978, -0.1902, -0.2322, -0.1179, -0.1592]],\n",
|
||||
" device='cuda:0', requires_grad=True), Parameter containing:\n",
|
||||
"tensor([-0.0498, -0.1153, 0.0559, 0.1391, -0.0231, 0.1752, 0.1824, -0.1185,\n",
|
||||
" -0.1845, 0.1776], device='cuda:0', requires_grad=True), Parameter containing:\n",
|
||||
"tensor([[-0.1260, -0.2593, -0.2229, 0.0065, 0.0831, -0.2702, -0.0523, 0.2718,\n",
|
||||
" 0.2547, 0.0824],\n",
|
||||
" [ 0.2193, -0.2204, 0.2573, 0.2804, 0.2895, -0.1308, 0.0558, -0.2521,\n",
|
||||
" -0.2073, -0.2508],\n",
|
||||
" [ 0.2057, -0.2991, 0.0932, -0.1904, 0.2148, -0.1540, 0.0756, -0.1294,\n",
|
||||
" -0.2314, -0.3153],\n",
|
||||
" [ 0.0636, -0.1243, 0.0715, 0.0041, -0.0620, -0.0212, 0.0812, 0.1220,\n",
|
||||
" -0.1704, 0.1767],\n",
|
||||
" [ 0.0371, -0.1627, 0.1664, 0.0266, 0.2154, -0.2726, -0.2169, -0.2568,\n",
|
||||
" 0.0378, -0.1306],\n",
|
||||
" [-0.0030, -0.1174, -0.0169, -0.2378, 0.2668, -0.1097, -0.0550, -0.3043,\n",
|
||||
" 0.1614, 0.2220]], device='cuda:0', requires_grad=True), Parameter containing:\n",
|
||||
"tensor([ 0.2263, -0.1245, -0.0600, -0.1303, 0.2717, -0.2886], device='cuda:0',\n",
|
||||
" requires_grad=True), Parameter containing:\n",
|
||||
"tensor([[-0.0118, -0.0705, -0.3342, -0.0979, -0.0048, -0.1048]],\n",
|
||||
" device='cuda:0', requires_grad=True), Parameter containing:\n",
|
||||
"tensor([0.3463], device='cuda:0', requires_grad=True)]\n",
|
||||
"==================================================\n",
|
||||
"[Parameter containing:\n",
|
||||
"tensor([[ 0.1109, -0.5283],\n",
|
||||
" [-0.3936, -0.6835],\n",
|
||||
" [ 1.6326, -1.0302],\n",
|
||||
" [ 1.3571, 1.2765],\n",
|
||||
" [-2.0980, 1.2285],\n",
|
||||
" [-0.2051, 0.3710],\n",
|
||||
" [-0.8627, 0.1593],\n",
|
||||
" [ 0.2369, 0.0855],\n",
|
||||
" [ 1.7011, -1.2683],\n",
|
||||
" [ 0.0239, -0.0731],\n",
|
||||
" [ 0.3759, 0.6586],\n",
|
||||
" [ 0.1748, 1.2761],\n",
|
||||
" [ 0.7805, -0.3940],\n",
|
||||
" [-1.7492, -0.3534],\n",
|
||||
" [-0.1294, 2.1254]], device='cuda:0', requires_grad=True), Parameter containing:\n",
|
||||
"tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n",
|
||||
" device='cuda:0', requires_grad=True), Parameter containing:\n",
|
||||
"tensor([[-1.0957, 1.0378, 0.3581, 0.6767, -0.0237, 0.6258, -0.6001, 0.4281,\n",
|
||||
" 0.0584, -0.7526, -0.7916, 0.6348, 0.2372, -0.0493, -1.0156],\n",
|
||||
" [ 0.5839, 0.1032, 1.3735, -0.1767, -1.5280, 0.9311, 1.2752, 0.0721,\n",
|
||||
" -0.0195, -1.0396, -1.2622, -0.2573, 1.1950, 1.2330, -1.1773],\n",
|
||||
" [-1.7392, -0.7176, 0.2932, -1.0405, 1.1690, -0.2543, -0.3711, -0.5608,\n",
|
||||
" 0.2679, -0.1302, 1.1744, 0.6578, 0.2809, -0.4317, 0.3950],\n",
|
||||
" [ 0.4147, -0.5277, -1.3396, 0.6607, 0.3036, 1.3448, -0.6774, 0.2073,\n",
|
||||
" 1.4682, 1.1449, 0.9241, 1.1630, 0.8778, -0.8538, 1.1645],\n",
|
||||
" [-0.2414, -1.1791, -0.9192, -1.3586, 0.7888, -1.2873, -0.7585, 0.1478,\n",
|
||||
" -0.2461, -1.7109, -1.3293, -0.1607, -0.0435, 0.7059, 0.8339],\n",
|
||||
" [ 0.5776, -1.4641, -0.8456, 1.4503, -0.0230, 0.9254, -1.1829, -0.7663,\n",
|
||||
" -0.5102, -1.7416, -0.2776, 1.8553, 0.0287, 1.3071, -0.0577],\n",
|
||||
" [-2.0167, 0.5340, 0.2045, -1.0659, 0.1410, -0.2611, 1.8272, -0.3517,\n",
|
||||
" -1.3150, 0.9499, 1.4403, 2.6355, 1.0081, -0.1915, -1.7888],\n",
|
||||
" [ 0.5569, -0.7237, -0.2397, 1.7227, 1.0570, 0.1441, -0.0228, 1.0863,\n",
|
||||
" 0.5095, -0.1426, -1.6044, -0.4300, 1.5130, 0.7247, -0.1672],\n",
|
||||
" [-0.2202, 0.8956, -0.5899, 0.6426, 1.0715, 2.8716, 0.6675, -0.3307,\n",
|
||||
" 0.4560, -1.6545, 0.6043, 1.4544, 0.3106, 2.0657, 0.4080],\n",
|
||||
" [-1.5387, 0.4301, 1.4037, 1.2928, 0.9105, 0.2546, -1.9276, -0.9019,\n",
|
||||
" -0.0115, 0.4342, -0.1543, 1.0857, 0.7773, 0.0530, 1.0929]],\n",
|
||||
" device='cuda:0', requires_grad=True), Parameter containing:\n",
|
||||
"tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], device='cuda:0',\n",
|
||||
" requires_grad=True), Parameter containing:\n",
|
||||
"tensor([[ 3.2868, 2.1744, -0.4586, 0.3916, 1.0370, 0.2906, -0.6354, -1.7215,\n",
|
||||
" -1.3256, 1.4936],\n",
|
||||
" [ 0.5906, -1.0098, -1.7117, 0.8507, 0.3034, 1.7005, -2.9092, -0.1906,\n",
|
||||
" -1.4068, 2.3292],\n",
|
||||
" [ 0.2435, -0.9280, -0.4463, -0.4837, -2.3409, -0.5112, 0.6242, 0.9315,\n",
|
||||
" 1.6558, -1.3278],\n",
|
||||
" [ 0.4638, 0.5903, -0.7866, 1.3214, 0.6855, 0.7143, 1.2887, -0.2633,\n",
|
||||
" -0.0572, 1.0767],\n",
|
||||
" [ 1.8814, -1.2857, 0.4642, -0.2101, 2.0882, 1.3320, -0.3117, -0.7563,\n",
|
||||
" 0.7099, 0.2467],\n",
|
||||
" [ 0.7781, -1.1896, -0.2684, 1.3948, -0.8702, 0.3317, 0.6835, 0.2728,\n",
|
||||
" 0.3942, 1.5071]], device='cuda:0', requires_grad=True), Parameter containing:\n",
|
||||
"tensor([0., 0., 0., 0., 0., 0.], device='cuda:0', requires_grad=True), Parameter containing:\n",
|
||||
"tensor([[ 1.5669, 0.3820, 1.0288, 1.5494, -0.9633, -0.8551]],\n",
|
||||
" device='cuda:0', requires_grad=True), Parameter containing:\n",
|
||||
"tensor([0.], device='cuda:0', requires_grad=True)]\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"def initialize(self):\n",
|
||||
" '''\n",
|
||||
" 初始化网络参数\n",
|
||||
" '''\n",
|
||||
" for m in self.modules():\n",
|
||||
" if isinstance(m, nn.Linear):\n",
|
||||
" torch.nn.init.normal_(m.weight.data, 0.1)\n",
|
||||
" if m.bias is not None:\n",
|
||||
" torch.nn.init.zeros_(m.bias.data)\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"class Net(nn.Module):\n",
|
||||
" '''\n",
|
||||
" 网络结构\n",
|
||||
" '''\n",
|
||||
"\n",
|
||||
" def __init__(self, **kwargs):\n",
|
||||
" super(Net, self).__init__()\n",
|
||||
" self.fc = nn.Sequential(\n",
|
||||
" nn.Linear(2, 15),\n",
|
||||
" nn.Sigmoid(),\n",
|
||||
" nn.Linear(15, 10),\n",
|
||||
" nn.Sigmoid(),\n",
|
||||
" nn.Linear(10, 6),\n",
|
||||
" nn.Sigmoid(),\n",
|
||||
" nn.Linear(6, 1)\n",
|
||||
" )\n",
|
||||
"\n",
|
||||
" def forward(self, x):\n",
|
||||
" x = x.view(-1, 2)\n",
|
||||
" x = self.fc(x)\n",
|
||||
" return x.squeeze(-1)\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"# 初始化网络\n",
|
||||
"model = Net()\n",
|
||||
"model = model.cuda()\n",
|
||||
"print(list(model.parameters()))\n",
|
||||
"initialize(model)\n",
|
||||
"print(\"==================================================\")\n",
|
||||
"print(list(model.parameters()))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "20fcccfe-3f82-44c3-bba2-9c66fb15cecd",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"定义损失函数为均方误差损失函数,优化算法为随机梯度下降"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 8,
|
||||
"id": "7b530a85-50b1-4d0f-b47e-14831d0ee85d",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# 损失函数\n",
|
||||
"criterion = nn.MSELoss()\n",
|
||||
"# 优化器\n",
|
||||
"optimizer = optim.SGD(model.parameters(), lr=lr)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 9,
|
||||
"id": "8694737f-afb5-4e58-ad14-3220971b142f",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def train(epoch):\n",
|
||||
" '''\n",
|
||||
" 训练器\n",
|
||||
"\n",
|
||||
" Parameters\n",
|
||||
" ----------\n",
|
||||
" epoch : int\n",
|
||||
"\n",
|
||||
" Returns\n",
|
||||
" -------\n",
|
||||
" None.\n",
|
||||
"\n",
|
||||
" '''\n",
|
||||
" model.train()\n",
|
||||
" train_loss = 0\n",
|
||||
" for info, target in train_loader:\n",
|
||||
" info = info.cuda()\n",
|
||||
" target = target.cuda()\n",
|
||||
" optimizer.zero_grad()\n",
|
||||
" output = model(info)\n",
|
||||
" loss = criterion(output, target)\n",
|
||||
" loss.backward()\n",
|
||||
" optimizer.step()\n",
|
||||
" train_loss += loss.item()*info.size(0)\n",
|
||||
" train_loss = train_loss/len(train_loader.dataset)\n",
|
||||
" print(\"epoch:%d, train loss:%.4f\" % (epoch, train_loss))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 10,
|
||||
"id": "bb6be6f8-16e6-4c2b-a56c-232b09faa654",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def validate(epoch):\n",
|
||||
" '''\n",
|
||||
" 验证\n",
|
||||
"\n",
|
||||
" Parameters\n",
|
||||
" ----------\n",
|
||||
" epoch : int\n",
|
||||
"\n",
|
||||
" Returns\n",
|
||||
" -------\n",
|
||||
" None.\n",
|
||||
"\n",
|
||||
" '''\n",
|
||||
" model.eval()\n",
|
||||
" val_loss = 0\n",
|
||||
" with torch.no_grad():\n",
|
||||
" for info, target in test_loader:\n",
|
||||
" info, target = info.cuda(), target.cuda()\n",
|
||||
" output = model(info)\n",
|
||||
" loss = criterion(output, target)\n",
|
||||
" val_loss += loss.item()*info.size(0)\n",
|
||||
" val_loss = val_loss/len(test_loader.dataset)\n",
|
||||
" print(\"epoch:%d, validation loss:%.4f\" %\n",
|
||||
" (epoch, val_loss))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 11,
|
||||
"id": "40532ea8-acfe-469a-aa9f-14d3bec2619d",
|
||||
"metadata": {
|
||||
"scrolled": true
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"epoch:1, train loss:0.5139\n",
|
||||
"epoch:1, validation loss:0.3730\n",
|
||||
"epoch:2, train loss:0.3496\n",
|
||||
"epoch:2, validation loss:0.3698\n",
|
||||
"epoch:3, train loss:0.3471\n",
|
||||
"epoch:3, validation loss:0.3677\n",
|
||||
"epoch:4, train loss:0.3450\n",
|
||||
"epoch:4, validation loss:0.3662\n",
|
||||
"epoch:5, train loss:0.3439\n",
|
||||
"epoch:5, validation loss:0.3652\n",
|
||||
"epoch:6, train loss:0.3429\n",
|
||||
"epoch:6, validation loss:0.3644\n",
|
||||
"epoch:7, train loss:0.3425\n",
|
||||
"epoch:7, validation loss:0.3635\n",
|
||||
"epoch:8, train loss:0.3418\n",
|
||||
"epoch:8, validation loss:0.3631\n",
|
||||
"epoch:9, train loss:0.3417\n",
|
||||
"epoch:9, validation loss:0.3625\n",
|
||||
"epoch:10, train loss:0.3411\n",
|
||||
"epoch:10, validation loss:0.3621\n",
|
||||
"epoch:11, train loss:0.3405\n",
|
||||
"epoch:11, validation loss:0.3620\n",
|
||||
"epoch:12, train loss:0.3406\n",
|
||||
"epoch:12, validation loss:0.3617\n",
|
||||
"epoch:13, train loss:0.3393\n",
|
||||
"epoch:13, validation loss:0.3615\n",
|
||||
"epoch:14, train loss:0.3401\n",
|
||||
"epoch:14, validation loss:0.3616\n",
|
||||
"epoch:15, train loss:0.3396\n",
|
||||
"epoch:15, validation loss:0.3610\n",
|
||||
"epoch:16, train loss:0.3400\n",
|
||||
"epoch:16, validation loss:0.3608\n",
|
||||
"epoch:17, train loss:0.3390\n",
|
||||
"epoch:17, validation loss:0.3607\n",
|
||||
"epoch:18, train loss:0.3392\n",
|
||||
"epoch:18, validation loss:0.3608\n",
|
||||
"epoch:19, train loss:0.3396\n",
|
||||
"epoch:19, validation loss:0.3606\n",
|
||||
"epoch:20, train loss:0.3396\n",
|
||||
"epoch:20, validation loss:0.3604\n",
|
||||
"epoch:21, train loss:0.3396\n",
|
||||
"epoch:21, validation loss:0.3604\n",
|
||||
"epoch:22, train loss:0.3389\n",
|
||||
"epoch:22, validation loss:0.3603\n",
|
||||
"epoch:23, train loss:0.3389\n",
|
||||
"epoch:23, validation loss:0.3601\n",
|
||||
"epoch:24, train loss:0.3387\n",
|
||||
"epoch:24, validation loss:0.3602\n",
|
||||
"epoch:25, train loss:0.3390\n",
|
||||
"epoch:25, validation loss:0.3600\n",
|
||||
"epoch:26, train loss:0.3393\n",
|
||||
"epoch:26, validation loss:0.3600\n",
|
||||
"epoch:27, train loss:0.3385\n",
|
||||
"epoch:27, validation loss:0.3599\n",
|
||||
"epoch:28, train loss:0.3388\n",
|
||||
"epoch:28, validation loss:0.3598\n",
|
||||
"epoch:29, train loss:0.3388\n",
|
||||
"epoch:29, validation loss:0.3598\n",
|
||||
"epoch:30, train loss:0.3388\n",
|
||||
"epoch:30, validation loss:0.3597\n",
|
||||
"epoch:31, train loss:0.3387\n",
|
||||
"epoch:31, validation loss:0.3596\n",
|
||||
"epoch:32, train loss:0.3390\n",
|
||||
"epoch:32, validation loss:0.3596\n",
|
||||
"epoch:33, train loss:0.3384\n",
|
||||
"epoch:33, validation loss:0.3596\n",
|
||||
"epoch:34, train loss:0.3389\n",
|
||||
"epoch:34, validation loss:0.3595\n",
|
||||
"epoch:35, train loss:0.3385\n",
|
||||
"epoch:35, validation loss:0.3594\n",
|
||||
"epoch:36, train loss:0.3381\n",
|
||||
"epoch:36, validation loss:0.3594\n",
|
||||
"epoch:37, train loss:0.3380\n",
|
||||
"epoch:37, validation loss:0.3593\n",
|
||||
"epoch:38, train loss:0.3387\n",
|
||||
"epoch:38, validation loss:0.3593\n",
|
||||
"epoch:39, train loss:0.3388\n",
|
||||
"epoch:39, validation loss:0.3592\n",
|
||||
"epoch:40, train loss:0.3385\n",
|
||||
"epoch:40, validation loss:0.3592\n",
|
||||
"epoch:41, train loss:0.3386\n",
|
||||
"epoch:41, validation loss:0.3593\n",
|
||||
"epoch:42, train loss:0.3384\n",
|
||||
"epoch:42, validation loss:0.3595\n",
|
||||
"epoch:43, train loss:0.3373\n",
|
||||
"epoch:43, validation loss:0.3590\n",
|
||||
"epoch:44, train loss:0.3376\n",
|
||||
"epoch:44, validation loss:0.3591\n",
|
||||
"epoch:45, train loss:0.3385\n",
|
||||
"epoch:45, validation loss:0.3590\n",
|
||||
"epoch:46, train loss:0.3381\n",
|
||||
"epoch:46, validation loss:0.3589\n",
|
||||
"epoch:47, train loss:0.3382\n",
|
||||
"epoch:47, validation loss:0.3591\n",
|
||||
"epoch:48, train loss:0.3381\n",
|
||||
"epoch:48, validation loss:0.3589\n",
|
||||
"epoch:49, train loss:0.3381\n",
|
||||
"epoch:49, validation loss:0.3588\n",
|
||||
"epoch:50, train loss:0.3379\n",
|
||||
"epoch:50, validation loss:0.3587\n",
|
||||
"epoch:51, train loss:0.3383\n",
|
||||
"epoch:51, validation loss:0.3587\n",
|
||||
"epoch:52, train loss:0.3379\n",
|
||||
"epoch:52, validation loss:0.3587\n",
|
||||
"epoch:53, train loss:0.3378\n",
|
||||
"epoch:53, validation loss:0.3587\n",
|
||||
"epoch:54, train loss:0.3379\n",
|
||||
"epoch:54, validation loss:0.3585\n",
|
||||
"epoch:55, train loss:0.3380\n",
|
||||
"epoch:55, validation loss:0.3586\n",
|
||||
"epoch:56, train loss:0.3378\n",
|
||||
"epoch:56, validation loss:0.3585\n",
|
||||
"epoch:57, train loss:0.3378\n",
|
||||
"epoch:57, validation loss:0.3586\n",
|
||||
"epoch:58, train loss:0.3379\n",
|
||||
"epoch:58, validation loss:0.3587\n",
|
||||
"epoch:59, train loss:0.3374\n",
|
||||
"epoch:59, validation loss:0.3583\n",
|
||||
"epoch:60, train loss:0.3380\n",
|
||||
"epoch:60, validation loss:0.3583\n",
|
||||
"epoch:61, train loss:0.3375\n",
|
||||
"epoch:61, validation loss:0.3587\n",
|
||||
"epoch:62, train loss:0.3375\n",
|
||||
"epoch:62, validation loss:0.3583\n",
|
||||
"epoch:63, train loss:0.3377\n",
|
||||
"epoch:63, validation loss:0.3582\n",
|
||||
"epoch:64, train loss:0.3373\n",
|
||||
"epoch:64, validation loss:0.3584\n",
|
||||
"epoch:65, train loss:0.3374\n",
|
||||
"epoch:65, validation loss:0.3581\n",
|
||||
"epoch:66, train loss:0.3373\n",
|
||||
"epoch:66, validation loss:0.3581\n",
|
||||
"epoch:67, train loss:0.3378\n",
|
||||
"epoch:67, validation loss:0.3581\n",
|
||||
"epoch:68, train loss:0.3370\n",
|
||||
"epoch:68, validation loss:0.3581\n",
|
||||
"epoch:69, train loss:0.3370\n",
|
||||
"epoch:69, validation loss:0.3582\n",
|
||||
"epoch:70, train loss:0.3374\n",
|
||||
"epoch:70, validation loss:0.3580\n",
|
||||
"epoch:71, train loss:0.3374\n",
|
||||
"epoch:71, validation loss:0.3580\n",
|
||||
"epoch:72, train loss:0.3375\n",
|
||||
"epoch:72, validation loss:0.3579\n",
|
||||
"epoch:73, train loss:0.3373\n",
|
||||
"epoch:73, validation loss:0.3580\n",
|
||||
"epoch:74, train loss:0.3371\n",
|
||||
"epoch:74, validation loss:0.3583\n",
|
||||
"epoch:75, train loss:0.3376\n",
|
||||
"epoch:75, validation loss:0.3578\n",
|
||||
"epoch:76, train loss:0.3369\n",
|
||||
"epoch:76, validation loss:0.3577\n",
|
||||
"epoch:77, train loss:0.3371\n",
|
||||
"epoch:77, validation loss:0.3577\n",
|
||||
"epoch:78, train loss:0.3373\n",
|
||||
"epoch:78, validation loss:0.3577\n",
|
||||
"epoch:79, train loss:0.3373\n",
|
||||
"epoch:79, validation loss:0.3578\n",
|
||||
"epoch:80, train loss:0.3374\n",
|
||||
"epoch:80, validation loss:0.3577\n",
|
||||
"epoch:81, train loss:0.3368\n",
|
||||
"epoch:81, validation loss:0.3576\n",
|
||||
"epoch:82, train loss:0.3365\n",
|
||||
"epoch:82, validation loss:0.3577\n",
|
||||
"epoch:83, train loss:0.3371\n",
|
||||
"epoch:83, validation loss:0.3576\n",
|
||||
"epoch:84, train loss:0.3371\n",
|
||||
"epoch:84, validation loss:0.3575\n",
|
||||
"epoch:85, train loss:0.3371\n",
|
||||
"epoch:85, validation loss:0.3574\n",
|
||||
"epoch:86, train loss:0.3370\n",
|
||||
"epoch:86, validation loss:0.3574\n",
|
||||
"epoch:87, train loss:0.3371\n",
|
||||
"epoch:87, validation loss:0.3574\n",
|
||||
"epoch:88, train loss:0.3369\n",
|
||||
"epoch:88, validation loss:0.3574\n",
|
||||
"epoch:89, train loss:0.3364\n",
|
||||
"epoch:89, validation loss:0.3573\n",
|
||||
"epoch:90, train loss:0.3367\n",
|
||||
"epoch:90, validation loss:0.3573\n",
|
||||
"epoch:91, train loss:0.3367\n",
|
||||
"epoch:91, validation loss:0.3575\n",
|
||||
"epoch:92, train loss:0.3367\n",
|
||||
"epoch:92, validation loss:0.3573\n",
|
||||
"epoch:93, train loss:0.3372\n",
|
||||
"epoch:93, validation loss:0.3572\n",
|
||||
"epoch:94, train loss:0.3368\n",
|
||||
"epoch:94, validation loss:0.3572\n",
|
||||
"epoch:95, train loss:0.3369\n",
|
||||
"epoch:95, validation loss:0.3571\n",
|
||||
"epoch:96, train loss:0.3366\n",
|
||||
"epoch:96, validation loss:0.3571\n",
|
||||
"epoch:97, train loss:0.3368\n",
|
||||
"epoch:97, validation loss:0.3571\n",
|
||||
"epoch:98, train loss:0.3367\n",
|
||||
"epoch:98, validation loss:0.3571\n",
|
||||
"epoch:99, train loss:0.3367\n",
|
||||
"epoch:99, validation loss:0.3573\n",
|
||||
"epoch:100, train loss:0.3365\n",
|
||||
"epoch:100, validation loss:0.3571\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"# 训练过程\n",
|
||||
"for epoch in range(1, max_epochs+1):\n",
|
||||
" train(epoch)\n",
|
||||
" validate(epoch)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "e3f438d3-813f-4db5-b8e9-da71d54730d4",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3 (ipykernel)",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.10.12"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 5
|
||||
}
|
||||
Reference in New Issue
Block a user