第二组
This commit is contained in:
@@ -0,0 +1,316 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 15,
|
||||
"metadata": {
|
||||
"collapsed": true
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"array([[4., 1., 1., ..., 0., 0., 0.],\n",
|
||||
" [4., 1., 1., ..., 0., 0., 0.],\n",
|
||||
" [4., 0., 3., ..., 0., 0., 0.],\n",
|
||||
" ...,\n",
|
||||
" [4., 1., 1., ..., 0., 0., 0.],\n",
|
||||
" [3., 1., 1., ..., 0., 0., 0.],\n",
|
||||
" [4., 1., 1., ..., 0., 0., 0.]])"
|
||||
]
|
||||
},
|
||||
"execution_count": 15,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"import torch\n",
|
||||
"import torch.nn as nn\n",
|
||||
"import torch.nn.functional as F\n",
|
||||
"from torch.utils.data import Dataset, DataLoader, random_split\n",
|
||||
"from torchvision import transforms, datasets, models\n",
|
||||
"import pandas as pd\n",
|
||||
"import sklearn\n",
|
||||
"df=pd.read_csv(\"Analysis.csv\").drop(\"Unnamed: 0\",axis=1)\n",
|
||||
"df=df.values\n",
|
||||
"df"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 16,
|
||||
"metadata": {
|
||||
"collapsed": false
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Running on the CPU\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"torch.manual_seed(0) # 设置随机种子, 用于复现\n",
|
||||
"torch.cuda.is_available()\n",
|
||||
"# 超参数\n",
|
||||
"EPOCH = 100 # 前向后向传播迭代次数\n",
|
||||
"LR = 0.001 # 学习率 learning rate\n",
|
||||
"BATCH_SIZE = 50 # 批量训练时候一次送入数据的size\n",
|
||||
"DOWNLOAD_MNIST = True\n",
|
||||
"if torch.cuda.is_available():\n",
|
||||
" device = torch.device(\"cuda:0\") # you can continue going on here, like cuda:1 cuda:2....etc.\n",
|
||||
" print(\"Running on the GPU\")\n",
|
||||
"else:\n",
|
||||
" device = torch.device(\"cpu\")\n",
|
||||
" print(\"Running on the CPU\")\n",
|
||||
"train_size = int(len(df) * 0.8) # 这里按照8:2进行训练和测试\n",
|
||||
"test_size = len(df) - train_size\n",
|
||||
"train_dataset, test_dataset = random_split(df, [train_size, test_size])\n",
|
||||
"train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True)\n",
|
||||
"test_loader = DataLoader(test_dataset, batch_size=128, shuffle=False)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 17,
|
||||
"metadata": {
|
||||
"collapsed": false
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import torch\n",
|
||||
"import torch.nn as nn\n",
|
||||
"from torch.nn import functional as F\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"class focal_loss_multi(nn.Module):\n",
|
||||
" def __init__(self, alpha=[0.3, 0.3, 0.15,0.05,0.15], gamma=2, num_classes=5, size_average=True):\n",
|
||||
" super(focal_loss_multi, self).__init__()\n",
|
||||
" self.size_average = size_average\n",
|
||||
" if isinstance(alpha, (float, int)): # 只设置第一类别的权重\n",
|
||||
" self.alpha = torch.zeros(num_classes)\n",
|
||||
" self.alpha[0] += alpha\n",
|
||||
" self.alpha[1:] += (1 - alpha) # self.alpha = [0.25,0.75,0.75,0.75,0.75]\n",
|
||||
" if isinstance(alpha, list): # 全部权重自己设置\n",
|
||||
" self.alpha = torch.Tensor(alpha)\n",
|
||||
" self.gamma = gamma\n",
|
||||
"\n",
|
||||
" def forward(self, inputs, targets):\n",
|
||||
" alpha = torch.tensor(self.alpha).cuda()\n",
|
||||
" N = inputs.size(0)\n",
|
||||
" C = inputs.size(1)\n",
|
||||
" # 下面这些只是为了获取四个样本的概率probs\n",
|
||||
" # 如模型中有softmax,则不需要下一行代码\n",
|
||||
" P = F.softmax(inputs, dim=1)\n",
|
||||
"\n",
|
||||
" class_mask = inputs.data.new(N, C).fill_(0) # 生成和input一样shape的tensor\n",
|
||||
" class_mask = class_mask.requires_grad_() # 加入梯度计算\n",
|
||||
" ids = targets.view(-1, 1) # 获取目标的索引\n",
|
||||
" alpha = alpha.gather(0, ids.view(-1))\n",
|
||||
" # one hot\n",
|
||||
" class_mask.data.scatter_(1, ids.data, 1.) # 利用scatter将索引丢给mask\n",
|
||||
" probs = (P * class_mask).sum(1).view(-1, 1)\n",
|
||||
" # focal loss公式\n",
|
||||
" log_p = probs.log()\n",
|
||||
" loss = torch.pow((1 - probs), self.gamma) * log_p\n",
|
||||
" batch_loss = (-alpha * loss).t()\n",
|
||||
"\n",
|
||||
" # batch loss求平均\n",
|
||||
" if self.size_average:\n",
|
||||
" loss = batch_loss.mean()\n",
|
||||
" else:\n",
|
||||
" loss = batch_loss.sum()\n",
|
||||
" return loss"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 18,
|
||||
"metadata": {
|
||||
"collapsed": false
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"ename": "AssertionError",
|
||||
"evalue": "Torch not compiled with CUDA enabled",
|
||||
"output_type": "error",
|
||||
"traceback": [
|
||||
"\u001b[1;31m---------------------------------------------------------------------------\u001b[0m",
|
||||
"\u001b[1;31mAssertionError\u001b[0m Traceback (most recent call last)",
|
||||
"Cell \u001b[1;32mIn[18], line 21\u001b[0m\n\u001b[0;32m 19\u001b[0m device \u001b[39m=\u001b[39m torch\u001b[39m.\u001b[39mdevice(\u001b[39m\"\u001b[39m\u001b[39mcuda:0\u001b[39m\u001b[39m\"\u001b[39m)\n\u001b[0;32m 20\u001b[0m fnn \u001b[39m=\u001b[39m FNN()\n\u001b[1;32m---> 21\u001b[0m fnn \u001b[39m=\u001b[39m fnn\u001b[39m.\u001b[39;49mto(device)\n\u001b[0;32m 22\u001b[0m optimizer \u001b[39m=\u001b[39m torch\u001b[39m.\u001b[39moptim\u001b[39m.\u001b[39mAdam(fnn\u001b[39m.\u001b[39mparameters(), lr\u001b[39m=\u001b[39mLR) \u001b[39m# 定义优化器\u001b[39;00m\n\u001b[0;32m 23\u001b[0m loss_func \u001b[39m=\u001b[39m focal_loss_multi() \u001b[39m# 定义损失函数\u001b[39;00m\n",
|
||||
"File \u001b[1;32mc:\\Users\\王瑞恒\\AppData\\Local\\Programs\\Python\\Python310\\lib\\site-packages\\torch\\nn\\modules\\module.py:927\u001b[0m, in \u001b[0;36mModule.to\u001b[1;34m(self, *args, **kwargs)\u001b[0m\n\u001b[0;32m 923\u001b[0m \u001b[39mreturn\u001b[39;00m t\u001b[39m.\u001b[39mto(device, dtype \u001b[39mif\u001b[39;00m t\u001b[39m.\u001b[39mis_floating_point() \u001b[39mor\u001b[39;00m t\u001b[39m.\u001b[39mis_complex() \u001b[39melse\u001b[39;00m \u001b[39mNone\u001b[39;00m,\n\u001b[0;32m 924\u001b[0m non_blocking, memory_format\u001b[39m=\u001b[39mconvert_to_format)\n\u001b[0;32m 925\u001b[0m \u001b[39mreturn\u001b[39;00m t\u001b[39m.\u001b[39mto(device, dtype \u001b[39mif\u001b[39;00m t\u001b[39m.\u001b[39mis_floating_point() \u001b[39mor\u001b[39;00m t\u001b[39m.\u001b[39mis_complex() \u001b[39melse\u001b[39;00m \u001b[39mNone\u001b[39;00m, non_blocking)\n\u001b[1;32m--> 927\u001b[0m \u001b[39mreturn\u001b[39;00m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49m_apply(convert)\n",
|
||||
"File \u001b[1;32mc:\\Users\\王瑞恒\\AppData\\Local\\Programs\\Python\\Python310\\lib\\site-packages\\torch\\nn\\modules\\module.py:579\u001b[0m, in \u001b[0;36mModule._apply\u001b[1;34m(self, fn)\u001b[0m\n\u001b[0;32m 577\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39m_apply\u001b[39m(\u001b[39mself\u001b[39m, fn):\n\u001b[0;32m 578\u001b[0m \u001b[39mfor\u001b[39;00m module \u001b[39min\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mchildren():\n\u001b[1;32m--> 579\u001b[0m module\u001b[39m.\u001b[39;49m_apply(fn)\n\u001b[0;32m 581\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39mcompute_should_use_set_data\u001b[39m(tensor, tensor_applied):\n\u001b[0;32m 582\u001b[0m \u001b[39mif\u001b[39;00m torch\u001b[39m.\u001b[39m_has_compatible_shallow_copy_type(tensor, tensor_applied):\n\u001b[0;32m 583\u001b[0m \u001b[39m# If the new tensor has compatible tensor type as the existing tensor,\u001b[39;00m\n\u001b[0;32m 584\u001b[0m \u001b[39m# the current behavior is to change the tensor in-place using `.data =`,\u001b[39;00m\n\u001b[1;32m (...)\u001b[0m\n\u001b[0;32m 589\u001b[0m \u001b[39m# global flag to let the user control whether they want the future\u001b[39;00m\n\u001b[0;32m 590\u001b[0m \u001b[39m# behavior of overwriting the existing tensor or not.\u001b[39;00m\n",
|
||||
"File \u001b[1;32mc:\\Users\\王瑞恒\\AppData\\Local\\Programs\\Python\\Python310\\lib\\site-packages\\torch\\nn\\modules\\module.py:602\u001b[0m, in \u001b[0;36mModule._apply\u001b[1;34m(self, fn)\u001b[0m\n\u001b[0;32m 598\u001b[0m \u001b[39m# Tensors stored in modules are graph leaves, and we don't want to\u001b[39;00m\n\u001b[0;32m 599\u001b[0m \u001b[39m# track autograd history of `param_applied`, so we have to use\u001b[39;00m\n\u001b[0;32m 600\u001b[0m \u001b[39m# `with torch.no_grad():`\u001b[39;00m\n\u001b[0;32m 601\u001b[0m \u001b[39mwith\u001b[39;00m torch\u001b[39m.\u001b[39mno_grad():\n\u001b[1;32m--> 602\u001b[0m param_applied \u001b[39m=\u001b[39m fn(param)\n\u001b[0;32m 603\u001b[0m should_use_set_data \u001b[39m=\u001b[39m compute_should_use_set_data(param, param_applied)\n\u001b[0;32m 604\u001b[0m \u001b[39mif\u001b[39;00m should_use_set_data:\n",
|
||||
"File \u001b[1;32mc:\\Users\\王瑞恒\\AppData\\Local\\Programs\\Python\\Python310\\lib\\site-packages\\torch\\nn\\modules\\module.py:925\u001b[0m, in \u001b[0;36mModule.to.<locals>.convert\u001b[1;34m(t)\u001b[0m\n\u001b[0;32m 922\u001b[0m \u001b[39mif\u001b[39;00m convert_to_format \u001b[39mis\u001b[39;00m \u001b[39mnot\u001b[39;00m \u001b[39mNone\u001b[39;00m \u001b[39mand\u001b[39;00m t\u001b[39m.\u001b[39mdim() \u001b[39min\u001b[39;00m (\u001b[39m4\u001b[39m, \u001b[39m5\u001b[39m):\n\u001b[0;32m 923\u001b[0m \u001b[39mreturn\u001b[39;00m t\u001b[39m.\u001b[39mto(device, dtype \u001b[39mif\u001b[39;00m t\u001b[39m.\u001b[39mis_floating_point() \u001b[39mor\u001b[39;00m t\u001b[39m.\u001b[39mis_complex() \u001b[39melse\u001b[39;00m \u001b[39mNone\u001b[39;00m,\n\u001b[0;32m 924\u001b[0m non_blocking, memory_format\u001b[39m=\u001b[39mconvert_to_format)\n\u001b[1;32m--> 925\u001b[0m \u001b[39mreturn\u001b[39;00m t\u001b[39m.\u001b[39;49mto(device, dtype \u001b[39mif\u001b[39;49;00m t\u001b[39m.\u001b[39;49mis_floating_point() \u001b[39mor\u001b[39;49;00m t\u001b[39m.\u001b[39;49mis_complex() \u001b[39melse\u001b[39;49;00m \u001b[39mNone\u001b[39;49;00m, non_blocking)\n",
|
||||
"File \u001b[1;32mc:\\Users\\王瑞恒\\AppData\\Local\\Programs\\Python\\Python310\\lib\\site-packages\\torch\\cuda\\__init__.py:211\u001b[0m, in \u001b[0;36m_lazy_init\u001b[1;34m()\u001b[0m\n\u001b[0;32m 207\u001b[0m \u001b[39mraise\u001b[39;00m \u001b[39mRuntimeError\u001b[39;00m(\n\u001b[0;32m 208\u001b[0m \u001b[39m\"\u001b[39m\u001b[39mCannot re-initialize CUDA in forked subprocess. To use CUDA with \u001b[39m\u001b[39m\"\u001b[39m\n\u001b[0;32m 209\u001b[0m \u001b[39m\"\u001b[39m\u001b[39mmultiprocessing, you must use the \u001b[39m\u001b[39m'\u001b[39m\u001b[39mspawn\u001b[39m\u001b[39m'\u001b[39m\u001b[39m start method\u001b[39m\u001b[39m\"\u001b[39m)\n\u001b[0;32m 210\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mnot\u001b[39;00m \u001b[39mhasattr\u001b[39m(torch\u001b[39m.\u001b[39m_C, \u001b[39m'\u001b[39m\u001b[39m_cuda_getDeviceCount\u001b[39m\u001b[39m'\u001b[39m):\n\u001b[1;32m--> 211\u001b[0m \u001b[39mraise\u001b[39;00m \u001b[39mAssertionError\u001b[39;00m(\u001b[39m\"\u001b[39m\u001b[39mTorch not compiled with CUDA enabled\u001b[39m\u001b[39m\"\u001b[39m)\n\u001b[0;32m 212\u001b[0m \u001b[39mif\u001b[39;00m _cudart \u001b[39mis\u001b[39;00m \u001b[39mNone\u001b[39;00m:\n\u001b[0;32m 213\u001b[0m \u001b[39mraise\u001b[39;00m \u001b[39mAssertionError\u001b[39;00m(\n\u001b[0;32m 214\u001b[0m \u001b[39m\"\u001b[39m\u001b[39mlibcudart functions unavailable. It looks like you have a broken build?\u001b[39m\u001b[39m\"\u001b[39m)\n",
|
||||
"\u001b[1;31mAssertionError\u001b[0m: Torch not compiled with CUDA enabled"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"class FNN(nn.Module):\n",
|
||||
" def __init__(self):\n",
|
||||
" super(FNN, self).__init__()\n",
|
||||
" self.len = 256\n",
|
||||
" channel_list = [35, 20, 10, 5]\n",
|
||||
" self.fc1 = nn.Linear(35, 20)\n",
|
||||
" self.fc2 = nn.Linear(20, 10)\n",
|
||||
" self.fc3 = nn.Linear(10, 5)\n",
|
||||
"\n",
|
||||
" def forward(self, x):\n",
|
||||
" x = x.to(torch.float32)\n",
|
||||
" x = F.relu(self.fc1(x))\n",
|
||||
" x = F.relu(self.fc2(x))\n",
|
||||
" x = self.fc3(x)\n",
|
||||
"\n",
|
||||
" return x\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"device = torch.device(\"cuda:0\")\n",
|
||||
"fnn = FNN()\n",
|
||||
"fnn = fnn.to(device)\n",
|
||||
"optimizer = torch.optim.Adam(fnn.parameters(), lr=LR) # 定义优化器\n",
|
||||
"loss_func = focal_loss_multi() # 定义损失函数"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 14,
|
||||
"metadata": {
|
||||
"collapsed": false
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"ename": "RuntimeError",
|
||||
"evalue": "mat1 and mat2 shapes cannot be multiplied (128x319 and 35x20)",
|
||||
"output_type": "error",
|
||||
"traceback": [
|
||||
"\u001b[1;31m---------------------------------------------------------------------------\u001b[0m",
|
||||
"\u001b[1;31mRuntimeError\u001b[0m Traceback (most recent call last)",
|
||||
"Cell \u001b[1;32mIn[14], line 19\u001b[0m\n\u001b[0;32m 16\u001b[0m \u001b[39m# 清空梯度缓存\u001b[39;00m\n\u001b[0;32m 17\u001b[0m optimizer\u001b[39m.\u001b[39mzero_grad()\n\u001b[1;32m---> 19\u001b[0m outputs \u001b[39m=\u001b[39m fnn(inputs)\n\u001b[0;32m 21\u001b[0m loss \u001b[39m=\u001b[39m loss_func(outputs, labels\u001b[39m.\u001b[39mlong())\n\u001b[0;32m 22\u001b[0m loss\u001b[39m.\u001b[39mbackward()\n",
|
||||
"File \u001b[1;32mc:\\Users\\王瑞恒\\AppData\\Local\\Programs\\Python\\Python310\\lib\\site-packages\\torch\\nn\\modules\\module.py:1130\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[1;34m(self, *input, **kwargs)\u001b[0m\n\u001b[0;32m 1126\u001b[0m \u001b[39m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[0;32m 1127\u001b[0m \u001b[39m# this function, and just call forward.\u001b[39;00m\n\u001b[0;32m 1128\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mnot\u001b[39;00m (\u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_backward_hooks \u001b[39mor\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_forward_hooks \u001b[39mor\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_forward_pre_hooks \u001b[39mor\u001b[39;00m _global_backward_hooks\n\u001b[0;32m 1129\u001b[0m \u001b[39mor\u001b[39;00m _global_forward_hooks \u001b[39mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[1;32m-> 1130\u001b[0m \u001b[39mreturn\u001b[39;00m forward_call(\u001b[39m*\u001b[39m\u001b[39minput\u001b[39m, \u001b[39m*\u001b[39m\u001b[39m*\u001b[39mkwargs)\n\u001b[0;32m 1131\u001b[0m \u001b[39m# Do not call functions when jit is used\u001b[39;00m\n\u001b[0;32m 1132\u001b[0m full_backward_hooks, non_full_backward_hooks \u001b[39m=\u001b[39m [], []\n",
|
||||
"Cell \u001b[1;32mIn[12], line 12\u001b[0m, in \u001b[0;36mFNN.forward\u001b[1;34m(self, x)\u001b[0m\n\u001b[0;32m 10\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39mforward\u001b[39m(\u001b[39mself\u001b[39m, x):\n\u001b[0;32m 11\u001b[0m x \u001b[39m=\u001b[39m x\u001b[39m.\u001b[39mto(torch\u001b[39m.\u001b[39mfloat32)\n\u001b[1;32m---> 12\u001b[0m x \u001b[39m=\u001b[39m F\u001b[39m.\u001b[39mrelu(\u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mfc1(x))\n\u001b[0;32m 13\u001b[0m x \u001b[39m=\u001b[39m F\u001b[39m.\u001b[39mrelu(\u001b[39mself\u001b[39m\u001b[39m.\u001b[39mfc2(x))\n\u001b[0;32m 14\u001b[0m x \u001b[39m=\u001b[39m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mfc3(x)\n",
|
||||
"File \u001b[1;32mc:\\Users\\王瑞恒\\AppData\\Local\\Programs\\Python\\Python310\\lib\\site-packages\\torch\\nn\\modules\\module.py:1130\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[1;34m(self, *input, **kwargs)\u001b[0m\n\u001b[0;32m 1126\u001b[0m \u001b[39m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[0;32m 1127\u001b[0m \u001b[39m# this function, and just call forward.\u001b[39;00m\n\u001b[0;32m 1128\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mnot\u001b[39;00m (\u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_backward_hooks \u001b[39mor\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_forward_hooks \u001b[39mor\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_forward_pre_hooks \u001b[39mor\u001b[39;00m _global_backward_hooks\n\u001b[0;32m 1129\u001b[0m \u001b[39mor\u001b[39;00m _global_forward_hooks \u001b[39mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[1;32m-> 1130\u001b[0m \u001b[39mreturn\u001b[39;00m forward_call(\u001b[39m*\u001b[39m\u001b[39minput\u001b[39m, \u001b[39m*\u001b[39m\u001b[39m*\u001b[39mkwargs)\n\u001b[0;32m 1131\u001b[0m \u001b[39m# Do not call functions when jit is used\u001b[39;00m\n\u001b[0;32m 1132\u001b[0m full_backward_hooks, non_full_backward_hooks \u001b[39m=\u001b[39m [], []\n",
|
||||
"File \u001b[1;32mc:\\Users\\王瑞恒\\AppData\\Local\\Programs\\Python\\Python310\\lib\\site-packages\\torch\\nn\\modules\\linear.py:114\u001b[0m, in \u001b[0;36mLinear.forward\u001b[1;34m(self, input)\u001b[0m\n\u001b[0;32m 113\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39mforward\u001b[39m(\u001b[39mself\u001b[39m, \u001b[39minput\u001b[39m: Tensor) \u001b[39m-\u001b[39m\u001b[39m>\u001b[39m Tensor:\n\u001b[1;32m--> 114\u001b[0m \u001b[39mreturn\u001b[39;00m F\u001b[39m.\u001b[39;49mlinear(\u001b[39minput\u001b[39;49m, \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mweight, \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mbias)\n",
|
||||
"\u001b[1;31mRuntimeError\u001b[0m: mat1 and mat2 shapes cannot be multiplied (128x319 and 35x20)"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"import time\n",
|
||||
"import matplotlib.pyplot as plt\n",
|
||||
"\n",
|
||||
"start = time.time()\n",
|
||||
"running_loss_list = []\n",
|
||||
"test_loss_list = []\n",
|
||||
"\n",
|
||||
"for epoch in range(50):\n",
|
||||
"\n",
|
||||
" running_loss = 0.0\n",
|
||||
" test_loss = 0.0\n",
|
||||
" for data in train_loader:\n",
|
||||
" # 获取输入数据\n",
|
||||
" inputs = data[:, 1:].to(device)\n",
|
||||
" labels = (data[:, 0]-1).to(device)\n",
|
||||
" # 清空梯度缓存\n",
|
||||
" optimizer.zero_grad()\n",
|
||||
"\n",
|
||||
" outputs = fnn(inputs)\n",
|
||||
"\n",
|
||||
" loss = loss_func(outputs, labels.long())\n",
|
||||
" loss.backward()\n",
|
||||
" optimizer.step()\n",
|
||||
"\n",
|
||||
" # 打印统计信息\n",
|
||||
" running_loss += loss.item()\n",
|
||||
"\n",
|
||||
" running_loss_list.append(running_loss)\n",
|
||||
" test_loss_list.append(test_loss)\n",
|
||||
" for i, data in enumerate(test_loader, 0):\n",
|
||||
" # 获取输入数据\n",
|
||||
" inputs = data[:, 1:].to(device)\n",
|
||||
" labels = (data[:, 0]-1).to(device)\n",
|
||||
"\n",
|
||||
" # 清空梯度缓存\n",
|
||||
" outputs = fnn(inputs)\n",
|
||||
" loss = loss_func(outputs, labels.long())\n",
|
||||
" # 打印统计信息\n",
|
||||
" test_loss += loss.item()\n",
|
||||
" running_loss_list.append(running_loss)\n",
|
||||
" test_loss_list.append(test_loss)\n",
|
||||
" print(f\"epoch:{epoch},loss: {running_loss}\")\n",
|
||||
"plt.style.use(\"ggplot\") # matplotlib的美化样式\n",
|
||||
"plt.figure()\n",
|
||||
"plt.title(\"loss and accuracy\")\n",
|
||||
"plt.xlabel(\"epoch\")\n",
|
||||
"plt.ylabel(\"loss\")\n",
|
||||
"plt.plot(test_loss_list,label='train')\n",
|
||||
"plt.plot(running_loss_list,label='train')\n",
|
||||
"plt.show()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 7,
|
||||
"metadata": {
|
||||
"collapsed": false
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Accuracy of the network on the test images: 59 %\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"correct = 0\n",
|
||||
"total = 0\n",
|
||||
"with torch.no_grad():\n",
|
||||
" for data in test_loader:\n",
|
||||
" inputs = data[:, 1:].to(device)\n",
|
||||
" labels = (data[:, 0]-1).to(device)\n",
|
||||
" outputs = fnn(inputs)\n",
|
||||
" _, predicted = torch.max(outputs.data, 1)\n",
|
||||
" total += labels.size(0)\n",
|
||||
" correct += (predicted == labels).sum().item()\n",
|
||||
"\n",
|
||||
"print('Accuracy of the network on the test images: %d %%' % (100 * correct / total))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"collapsed": false
|
||||
},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.10.7"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 0
|
||||
}
|
||||
Reference in New Issue
Block a user