@@ -0,0 +1,600 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 1,
|
||||
"id": "1996849f",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"<class 'pandas.core.frame.DataFrame'>\n",
|
||||
"RangeIndex: 161 entries, 0 to 160\n",
|
||||
"Data columns (total 9 columns):\n",
|
||||
" # Column Non-Null Count Dtype \n",
|
||||
"--- ------ -------------- ----- \n",
|
||||
" 0 0 161 non-null float64\n",
|
||||
" 1 1 161 non-null float64\n",
|
||||
" 2 2 161 non-null float64\n",
|
||||
" 3 3 161 non-null float64\n",
|
||||
" 4 4 161 non-null float64\n",
|
||||
" 5 5 161 non-null float64\n",
|
||||
" 6 6 161 non-null float64\n",
|
||||
" 7 7 161 non-null float64\n",
|
||||
" 8 8 161 non-null float64\n",
|
||||
"dtypes: float64(9)\n",
|
||||
"memory usage: 11.4 KB\n",
|
||||
"Try(\n",
|
||||
" (unil1): Sequential(\n",
|
||||
" (0): Conv1d(1, 16, kernel_size=(3,), stride=(1,), padding=(1,))\n",
|
||||
" (1): ReLU()\n",
|
||||
" (2): Conv1d(16, 32, kernel_size=(3,), stride=(1,), padding=(1,))\n",
|
||||
" (3): ReLU()\n",
|
||||
" (4): Conv1d(32, 16, kernel_size=(3,), stride=(1,), padding=(1,))\n",
|
||||
" (5): ReLU()\n",
|
||||
" (6): Flatten(start_dim=1, end_dim=-1)\n",
|
||||
" (7): Linear(in_features=144, out_features=100, bias=True)\n",
|
||||
" (8): ReLU()\n",
|
||||
" (9): Dropout(p=0.5, inplace=False)\n",
|
||||
" (10): Linear(in_features=100, out_features=100, bias=True)\n",
|
||||
" (11): ReLU()\n",
|
||||
" (12): Dropout(p=0.5, inplace=False)\n",
|
||||
" (13): BatchNorm1d(100, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
||||
" (14): Linear(in_features=100, out_features=5, bias=True)\n",
|
||||
" )\n",
|
||||
")\n",
|
||||
"train,loss 0 -0.19760972261428833\n",
|
||||
"train,acc 0 0.175\n",
|
||||
"1.0 8 0.125\n",
|
||||
"0.0 8 0.0\n",
|
||||
"0.0 8 0.0\n",
|
||||
"0.0 8 0.0\n",
|
||||
"0.0 8 0.0\n",
|
||||
"0.0 8 0.0\n",
|
||||
"0.0 8 0.0\n",
|
||||
"0.0 8 0.0\n",
|
||||
"1.0 8 0.125\n",
|
||||
"0.0 8 0.0\n",
|
||||
"0.0 8 0.0\n",
|
||||
"1.0 8 0.125\n",
|
||||
"0.0 8 0.0\n",
|
||||
"0.0 8 0.0\n",
|
||||
"0.0 8 0.0\n",
|
||||
"0.0 8 0.0\n",
|
||||
"1.0 8 0.125\n",
|
||||
"1.0 8 0.125\n",
|
||||
"1.0 8 0.125\n",
|
||||
"0.0 8 0.0\n",
|
||||
"0.0 1 0.0\n",
|
||||
"test,acc 0 0.1059190031152648\n",
|
||||
"train,loss 1 -0.15465295314788818\n",
|
||||
"train,acc 1 0.2\n",
|
||||
"0.0 8 0.0\n",
|
||||
"0.0 8 0.0\n",
|
||||
"1.0 8 0.125\n",
|
||||
"0.0 8 0.0\n",
|
||||
"1.0 8 0.125\n",
|
||||
"0.0 8 0.0\n",
|
||||
"0.0 8 0.0\n",
|
||||
"1.0 8 0.125\n",
|
||||
"0.0 8 0.0\n",
|
||||
"0.0 8 0.0\n",
|
||||
"0.0 8 0.0\n",
|
||||
"0.0 8 0.0\n",
|
||||
"0.0 8 0.0\n",
|
||||
"1.0 8 0.125\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"C:\\Users\\16560\\AppData\\Local\\Temp\\ipykernel_5848\\317761143.py:95: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
|
||||
" x_train=torch.tensor(x_train,dtype=torch.float)\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"0.0 8 0.0\n",
|
||||
"1.0 8 0.125\n",
|
||||
"0.0 8 0.0\n",
|
||||
"1.0 8 0.125\n",
|
||||
"0.0 8 0.0\n",
|
||||
"0.0 8 0.0\n",
|
||||
"0.0 1 0.0\n",
|
||||
"test,acc 1 0.11838006230529595\n",
|
||||
"train,loss 2 -0.26933109760284424\n",
|
||||
"train,acc 2 0.15625\n",
|
||||
"0.0 8 0.0\n",
|
||||
"0.0 8 0.0\n",
|
||||
"0.0 8 0.0\n",
|
||||
"0.0 8 0.0\n",
|
||||
"0.0 8 0.0\n",
|
||||
"1.0 8 0.125\n",
|
||||
"0.0 8 0.0\n",
|
||||
"1.0 8 0.125\n",
|
||||
"1.0 8 0.125\n",
|
||||
"1.0 8 0.125\n",
|
||||
"1.0 8 0.125\n",
|
||||
"0.0 8 0.0\n",
|
||||
"0.0 8 0.0\n",
|
||||
"1.0 8 0.125\n",
|
||||
"0.0 8 0.0\n",
|
||||
"0.0 8 0.0\n",
|
||||
"0.0 8 0.0\n",
|
||||
"0.0 8 0.0\n",
|
||||
"0.0 8 0.0\n",
|
||||
"0.0 8 0.0\n",
|
||||
"0.0 1 0.0\n",
|
||||
"test,acc 2 0.09657320872274143\n",
|
||||
"train,loss 3 -0.21968130767345428\n",
|
||||
"train,acc 3 0.2\n",
|
||||
"1.0 8 0.125\n",
|
||||
"0.0 8 0.0\n",
|
||||
"0.0 8 0.0\n",
|
||||
"1.0 8 0.125\n",
|
||||
"0.0 8 0.0\n",
|
||||
"0.0 8 0.0\n",
|
||||
"0.0 8 0.0\n",
|
||||
"0.0 8 0.0\n",
|
||||
"0.0 8 0.0\n",
|
||||
"0.0 8 0.0\n",
|
||||
"0.0 8 0.0\n",
|
||||
"0.0 8 0.0\n",
|
||||
"1.0 8 0.125\n",
|
||||
"1.0 8 0.125\n",
|
||||
"1.0 8 0.125\n",
|
||||
"0.0 8 0.0\n",
|
||||
"0.0 8 0.0\n",
|
||||
"1.0 8 0.125\n",
|
||||
"0.0 8 0.0\n",
|
||||
"0.0 8 0.0\n",
|
||||
"0.0 1 0.0\n",
|
||||
"test,acc 3 0.11838006230529595\n",
|
||||
"train,loss 4 -0.24240297079086304\n",
|
||||
"train,acc 4 0.26875\n",
|
||||
"0.0 8 0.0\n",
|
||||
"1.0 8 0.125\n",
|
||||
"1.0 8 0.125\n",
|
||||
"3.0 8 0.375\n",
|
||||
"0.0 8 0.0\n",
|
||||
"0.0 8 0.0\n",
|
||||
"0.0 8 0.0\n",
|
||||
"1.0 8 0.125\n",
|
||||
"0.0 8 0.0\n",
|
||||
"0.0 8 0.0\n",
|
||||
"0.0 8 0.0\n",
|
||||
"0.0 8 0.0\n",
|
||||
"0.0 8 0.0\n",
|
||||
"0.0 8 0.0\n",
|
||||
"0.0 8 0.0\n",
|
||||
"0.0 8 0.0\n",
|
||||
"0.0 8 0.0\n",
|
||||
"0.0 8 0.0\n",
|
||||
"0.0 8 0.0\n",
|
||||
"0.0 8 0.0\n",
|
||||
"0.0 1 0.0\n",
|
||||
"test,acc 4 0.1526479750778816\n",
|
||||
"train,loss 5 -0.16728299856185913\n",
|
||||
"train,acc 5 0.2375\n",
|
||||
"0.0 8 0.0\n",
|
||||
"1.0 8 0.125\n",
|
||||
"0.0 8 0.0\n",
|
||||
"0.0 8 0.0\n",
|
||||
"0.0 8 0.0\n",
|
||||
"0.0 8 0.0\n",
|
||||
"1.0 8 0.125\n",
|
||||
"0.0 8 0.0\n",
|
||||
"1.0 8 0.125\n",
|
||||
"0.0 8 0.0\n",
|
||||
"0.0 8 0.0\n",
|
||||
"0.0 8 0.0\n",
|
||||
"1.0 8 0.125\n",
|
||||
"0.0 8 0.0\n",
|
||||
"0.0 8 0.0\n",
|
||||
"0.0 8 0.0\n",
|
||||
"1.0 8 0.125\n",
|
||||
"0.0 8 0.0\n",
|
||||
"0.0 8 0.0\n",
|
||||
"1.0 8 0.125\n",
|
||||
"0.0 1 0.0\n",
|
||||
"test,acc 5 0.13707165109034267\n",
|
||||
"train,loss 6 -0.19944357872009277\n",
|
||||
"train,acc 6 0.24375\n",
|
||||
"1.0 8 0.125\n",
|
||||
"0.0 8 0.0\n",
|
||||
"1.0 8 0.125\n",
|
||||
"0.0 8 0.0\n",
|
||||
"0.0 8 0.0\n",
|
||||
"0.0 8 0.0\n",
|
||||
"0.0 8 0.0\n",
|
||||
"0.0 8 0.0\n",
|
||||
"0.0 8 0.0\n",
|
||||
"0.0 8 0.0\n",
|
||||
"1.0 8 0.125\n",
|
||||
"0.0 8 0.0\n",
|
||||
"2.0 8 0.25\n",
|
||||
"0.0 8 0.0\n",
|
||||
"0.0 8 0.0\n",
|
||||
"0.0 8 0.0\n",
|
||||
"1.0 8 0.125\n",
|
||||
"0.0 8 0.0\n",
|
||||
"0.0 8 0.0\n",
|
||||
"0.0 8 0.0\n",
|
||||
"0.0 1 0.0\n",
|
||||
"test,acc 6 0.14018691588785046\n",
|
||||
"train,loss 7 -0.197749525308609\n",
|
||||
"train,acc 7 0.225\n",
|
||||
"1.0 8 0.125\n",
|
||||
"0.0 8 0.0\n",
|
||||
"0.0 8 0.0\n",
|
||||
"0.0 8 0.0\n",
|
||||
"0.0 8 0.0\n",
|
||||
"0.0 8 0.0\n",
|
||||
"0.0 8 0.0\n",
|
||||
"0.0 8 0.0\n",
|
||||
"0.0 8 0.0\n",
|
||||
"1.0 8 0.125\n",
|
||||
"1.0 8 0.125\n",
|
||||
"1.0 8 0.125\n",
|
||||
"0.0 8 0.0\n",
|
||||
"0.0 8 0.0\n",
|
||||
"1.0 8 0.125\n",
|
||||
"1.0 8 0.125\n",
|
||||
"0.0 8 0.0\n",
|
||||
"0.0 8 0.0\n",
|
||||
"0.0 8 0.0\n",
|
||||
"0.0 8 0.0\n",
|
||||
"0.0 1 0.0\n",
|
||||
"test,acc 7 0.1308411214953271\n",
|
||||
"train,loss 8 -0.17148229479789734\n",
|
||||
"train,acc 8 0.20625\n",
|
||||
"0.0 8 0.0\n",
|
||||
"1.0 8 0.125\n",
|
||||
"0.0 8 0.0\n",
|
||||
"0.0 8 0.0\n",
|
||||
"1.0 8 0.125\n",
|
||||
"0.0 8 0.0\n",
|
||||
"0.0 8 0.0\n",
|
||||
"0.0 8 0.0\n",
|
||||
"0.0 8 0.0\n",
|
||||
"1.0 8 0.125\n",
|
||||
"0.0 8 0.0\n",
|
||||
"0.0 8 0.0\n",
|
||||
"1.0 8 0.125\n",
|
||||
"0.0 8 0.0\n",
|
||||
"0.0 8 0.0\n",
|
||||
"2.0 8 0.25\n",
|
||||
"0.0 8 0.0\n",
|
||||
"0.0 8 0.0\n",
|
||||
"0.0 8 0.0\n",
|
||||
"0.0 8 0.0\n",
|
||||
"0.0 1 0.0\n",
|
||||
"test,acc 8 0.12149532710280374\n",
|
||||
"train,loss 9 -0.2207236886024475\n",
|
||||
"train,acc 9 0.18125\n",
|
||||
"0.0 8 0.0\n",
|
||||
"0.0 8 0.0\n",
|
||||
"1.0 8 0.125\n",
|
||||
"1.0 8 0.125\n",
|
||||
"0.0 8 0.0\n",
|
||||
"0.0 8 0.0\n",
|
||||
"0.0 8 0.0\n",
|
||||
"0.0 8 0.0\n",
|
||||
"1.0 8 0.125\n",
|
||||
"0.0 8 0.0\n",
|
||||
"0.0 8 0.0\n",
|
||||
"0.0 8 0.0\n",
|
||||
"0.0 8 0.0\n",
|
||||
"0.0 8 0.0\n",
|
||||
"0.0 8 0.0\n",
|
||||
"0.0 8 0.0\n",
|
||||
"1.0 8 0.125\n",
|
||||
"1.0 8 0.125\n",
|
||||
"1.0 8 0.125\n",
|
||||
"0.0 8 0.0\n",
|
||||
"0.0 1 0.0\n",
|
||||
"test,acc 9 0.10903426791277258\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"# -*- coding: utf-8 -*-\n",
|
||||
"\"\"\"\n",
|
||||
"Created on Fri Jul 14 09:58:52 2023\n",
|
||||
"\n",
|
||||
"@author: 16560\n",
|
||||
"\"\"\"\n",
|
||||
"import os\n",
|
||||
"import matplotlib.pyplot as plt\n",
|
||||
"from PIL import ImageStat \n",
|
||||
"from PIL import Image\n",
|
||||
"import numpy as np\n",
|
||||
"import pandas as pd\n",
|
||||
"import torch\n",
|
||||
"import torch.nn.functional as F\n",
|
||||
"import torch.optim as optim\n",
|
||||
"from torch.utils.data import DataLoader\n",
|
||||
"from torch import nn\n",
|
||||
"# 我们需要选取仅包括水样的图像部分来进行分析\n",
|
||||
"data_root_path='C:/Users/16560/Desktop/小学期/images/'\n",
|
||||
"# =============================================================================\n",
|
||||
"# plt.title('水色样本 '+imgFile+' 分辨率为'+str(img.size)+\" 类别标签 \"+str(imgFile[9]))\n",
|
||||
"# plt.show()\n",
|
||||
"# =============================================================================\n",
|
||||
"\n",
|
||||
"# 加载图像统计信息模块(注:也可以直接通过颜色通道来计算)\n",
|
||||
"\n",
|
||||
"# 遍历全体图像进行快速检查\n",
|
||||
"size = 100\n",
|
||||
"imgWidth = [] # 图像宽度\n",
|
||||
"imgHeight = [] # 图像高度\n",
|
||||
"imgRrange = [] # 图像红色通道极差\n",
|
||||
"imgGrange = [] # 图像绿色通道极差\n",
|
||||
"imgBrange = [] # 图像蓝色通道极差\n",
|
||||
"\n",
|
||||
"newImgs = [] # 获得选取后的图像作为模型训练和验证数据\n",
|
||||
"\n",
|
||||
"imgFiles = os.listdir(data_root_path)\n",
|
||||
"for imgFile in imgFiles:\n",
|
||||
" img = Image.open(os.path.join(data_root_path,imgFile))\n",
|
||||
" imgWidth.append(img.size[0])\n",
|
||||
" imgHeight.append(img.size[1])\n",
|
||||
" \n",
|
||||
" # 获得图像中心区域大小为size的图像块\n",
|
||||
" cx, cy = (int(i/2) for i in img.size)\n",
|
||||
" box = (cx-50, cy-50, cx+50, cy+50)\n",
|
||||
" region = img.crop(box)\n",
|
||||
" \n",
|
||||
" # 计算选取图像块的标准差\n",
|
||||
" stat = ImageStat.Stat(region)\n",
|
||||
" imgRrange.append(stat.extrema[0][1]-stat.extrema[0][0])\n",
|
||||
" imgGrange.append(stat.extrema[1][1]-stat.extrema[1][0])\n",
|
||||
" imgBrange.append(stat.extrema[2][1]-stat.extrema[2][0])\n",
|
||||
" \n",
|
||||
" newImgs.append(region)\n",
|
||||
"\n",
|
||||
"# 构建训练数据集和分类标签\n",
|
||||
"data = []\n",
|
||||
"dy = []\n",
|
||||
"for i, img in enumerate(newImgs):\n",
|
||||
" r, g, b = np.split(np.array(img), 3, axis = 2)\n",
|
||||
" \n",
|
||||
" #计算一阶矩\n",
|
||||
" r_m1 = np.mean(r)\n",
|
||||
" g_m1 = np.mean(g)\n",
|
||||
" b_m1 = np.mean(b)\n",
|
||||
" \n",
|
||||
" #二阶矩\n",
|
||||
" r_m2 = np.std(r)\n",
|
||||
" g_m2 = np.std(g)\n",
|
||||
" b_m2 = np.std(b)\n",
|
||||
" \n",
|
||||
" #三阶矩\n",
|
||||
" r_m3 = np.mean(abs(r - r.mean())**3)**(1/3)\n",
|
||||
" g_m3 = np.mean(abs(g - g.mean())**3)**(1/3)\n",
|
||||
" b_m3 = np.mean(abs(b - b.mean())**3)**(1/3)\n",
|
||||
" \n",
|
||||
" # 构造新数据集\n",
|
||||
" df = np.array([r_m1,g_m1,b_m1,r_m2,g_m2,b_m2,r_m3,g_m3,b_m3])\n",
|
||||
" data.append(df)\n",
|
||||
" \n",
|
||||
" # 保存对应的分类标签\n",
|
||||
" dy.append(int(imgFiles[i][0]))\n",
|
||||
"\n",
|
||||
"dy = np.array(dy)\n",
|
||||
"data = pd.DataFrame(np.array(data))\n",
|
||||
"data.info()\n",
|
||||
"data.head()\n",
|
||||
"\n",
|
||||
"for i in range(dy.size):\n",
|
||||
" dy[i]-=1\n",
|
||||
" \n",
|
||||
"y_train=torch.tensor(dy,dtype=torch.long)\n",
|
||||
"x_train = torch.from_numpy(data.values)\n",
|
||||
"x_train=x_train.reshape(161,1,9)\n",
|
||||
"x_train=torch.tensor(x_train,dtype=torch.float)\n",
|
||||
"\n",
|
||||
"epochs=10\n",
|
||||
"batch_size=8\n",
|
||||
"\n",
|
||||
"data_set=torch.utils.data.TensorDataset(x_train,y_train)\n",
|
||||
"train_loader=DataLoader(dataset=data_set,batch_size=batch_size,shuffle=True)\n",
|
||||
"test_loader=DataLoader(dataset=data_set,batch_size=batch_size,shuffle=True)\n",
|
||||
"\n",
|
||||
"class Try(nn.Module):\n",
|
||||
" def __init__(self):\n",
|
||||
" super(Try,self).__init__()\n",
|
||||
" self.unil1=nn.Sequential(\n",
|
||||
" nn.Conv1d(1,16,kernel_size=3,stride=1,padding=1),\n",
|
||||
" nn.ReLU(),\n",
|
||||
" nn.Conv1d(16,32,kernel_size=3,stride=1,padding=1),\n",
|
||||
" nn.ReLU(),\n",
|
||||
" nn.Conv1d(32,16,kernel_size=3,stride=1,padding=1),\n",
|
||||
" nn.ReLU(),\n",
|
||||
" nn.Flatten(),\n",
|
||||
" nn.Linear(144,100),\n",
|
||||
" nn.ReLU(),\n",
|
||||
" nn.Dropout(p=0.5),\n",
|
||||
" nn.Linear(100,100),\n",
|
||||
" nn.ReLU(),\n",
|
||||
" nn.Dropout(p=0.5),\n",
|
||||
" nn.BatchNorm1d(100),\n",
|
||||
" nn.Linear(100,5)\n",
|
||||
" )\n",
|
||||
" # tmp=torch.rand(5,6,5)\n",
|
||||
" # out=self.model(tmp)\n",
|
||||
" # print(out.shape)\n",
|
||||
" # print(out)\n",
|
||||
" def forward(self,x):\n",
|
||||
" logits=self.unil1(x)\n",
|
||||
" logits=F.softmax(logits,dim=1)\n",
|
||||
" return logits\n",
|
||||
" \n",
|
||||
"criteon=nn.NLLLoss()\n",
|
||||
"optimizer=optim.Adam(Try().parameters(),lr=1e-2,weight_decay=1e-2)\n",
|
||||
"model=Try()\n",
|
||||
"\n",
|
||||
"print(model)\n",
|
||||
"for epoch in range(epochs):\n",
|
||||
" total_correct=0\n",
|
||||
" total_num=0\n",
|
||||
" model.train()\n",
|
||||
" for batchidx,(train_data1,train_label1) in enumerate(train_loader):\n",
|
||||
" if train_label1.shape[0]>=8:\n",
|
||||
" logits=model(train_data1)\n",
|
||||
" loss=criteon(logits,train_label1)\n",
|
||||
" pred=logits.argmax(dim=1)\n",
|
||||
" total_correct+=torch.eq(pred,train_label1).float().sum().item()\n",
|
||||
" total_num+=train_data1.size(0)\n",
|
||||
" optimizer.zero_grad()\n",
|
||||
" loss.backward()\n",
|
||||
" optimizer.step()\n",
|
||||
" print('train,loss',epoch,loss.item())\n",
|
||||
" acc=total_correct/total_num\n",
|
||||
" print('train,acc',epoch,acc)\n",
|
||||
" model.eval() \n",
|
||||
" with torch.no_grad():\n",
|
||||
" for test_data1,test_label1 in test_loader:\n",
|
||||
" logits=model(test_data1)\n",
|
||||
" pred=logits.argmax(dim=1)\n",
|
||||
" sb=torch.eq(pred,test_label1)\n",
|
||||
" sb1=torch.eq(pred,test_label1)\n",
|
||||
" total_correct+=torch.eq(pred,test_label1).float().sum().item()\n",
|
||||
" total_num+=test_data1.size(0)\n",
|
||||
" #print(torch.eq(pred,test_label1))\n",
|
||||
" print(torch.eq(pred,test_label1).float().sum().item(),test_data1.size(0),torch.eq(pred,test_label1).float().sum().item()/test_data1.size(0))\n",
|
||||
" acc=total_correct/total_num\n",
|
||||
" print('test,acc',epoch,acc)\n",
|
||||
"\n",
|
||||
"# =============================================================================\n",
|
||||
"# def main():\n",
|
||||
"# net=Try()\n",
|
||||
"# tmp=torch.rand(16,1,9)\n",
|
||||
"# out=net(tmp)\n",
|
||||
"# print(out.shape)\n",
|
||||
"# \n",
|
||||
"# \n",
|
||||
"# if __name__ == '__main__':\n",
|
||||
"# main()\n",
|
||||
"# =============================================================================\n",
|
||||
"# =============================================================================\n",
|
||||
"# name_dict={'1':1,'2':2,'3':3,'4':4,'5':5}\n",
|
||||
"# data_root_path='C:/Users/16560/Desktop/小学期/images/'\n",
|
||||
"# test_file_path = data_root_path + \"test.txt\" #测试文件路径\n",
|
||||
"# train_file_path = data_root_path + \"train.txt\" # 训练文件路径\n",
|
||||
"# name_data_list = {} # 记录每个类别有哪些图片 key:水果名称 value:图片路径构成的列表\n",
|
||||
"# \n",
|
||||
"# def save_train_test_file(path, name):\n",
|
||||
"# if name not in name_data_list: # 该类别水果不在字典中,则新建一个列表插入字典\n",
|
||||
"# img_list = []\n",
|
||||
"# img_list.append(path) # 将图片路径存入列表\n",
|
||||
"# name_data_list[name] = img_list # 将图片列表插入字典\n",
|
||||
"# else: # 该类别水果在字典中,直接添加到列表\n",
|
||||
"# name_data_list[name].append(path)\n",
|
||||
"# \n",
|
||||
"# # 遍历数据集下面每个子目录,将图片路径写入上面的字典\n",
|
||||
"# dirs = os.listdir(data_root_path) # 列出数据集目下所有的文件和子目录\n",
|
||||
"# imgs=os.listdir(data_root_path)\n",
|
||||
"# for img in imgs:\n",
|
||||
"# name=img[0]\n",
|
||||
"# if name !='t':\n",
|
||||
"# save_train_test_file(data_root_path+img, #拼图片完整路径\n",
|
||||
"# name) # 以目录名称作为类别名称\n",
|
||||
"# \n",
|
||||
"# # 将name_data_list字典中的内容写入文件\n",
|
||||
"# ## 清空训练集和测试集文件\n",
|
||||
"# with open(test_file_path, \"w\") as f:\n",
|
||||
"# pass\n",
|
||||
"# \n",
|
||||
"# with open(train_file_path, \"w\") as f:\n",
|
||||
"# pass\n",
|
||||
"# \n",
|
||||
"# # 遍历字典,将字典中的内容写入训练集和测试集\n",
|
||||
"# for name, img_list in name_data_list.items():\n",
|
||||
"# i = 0\n",
|
||||
"# num = len(img_list) # 获取每个类别图片数量\n",
|
||||
"# print(\"%s: %d张\" % (name, num))\n",
|
||||
"# # 写训练集和测试集\n",
|
||||
"# for img in img_list:\n",
|
||||
"# if i % 10 == 0: # 每10笔写一笔测试集\n",
|
||||
"# with open(test_file_path, \"a\") as f: #以追加模式打开测试集文件\n",
|
||||
"# line = \"%s\\t%d\\n\" % (img, name_dict[name]) # 拼一行\n",
|
||||
"# f.write(line) # 写入文件\n",
|
||||
"# else: # 训练集\n",
|
||||
"# with open(train_file_path, \"a\") as f: #以追加模式打开测试集文件\n",
|
||||
"# line = \"%s\\t%d\\n\" % (img, name_dict[name]) # 拼一行\n",
|
||||
"# f.write(line) # 写入文件\n",
|
||||
"# \n",
|
||||
"# i += 1 # 计数器加1\n",
|
||||
"# \n",
|
||||
"# print(\"数据预处理完成.\")\n",
|
||||
"# \n",
|
||||
"# \n",
|
||||
"# =============================================================================\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 2,
|
||||
"id": "85aad8da",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"id": "8e9f8ce4",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"ename": "AttributeError",
|
||||
"evalue": "module 'matplotlib' has no attribute 'figure'",
|
||||
"output_type": "error",
|
||||
"traceback": [
|
||||
"\u001b[1;31m---------------------------------------------------------------------------\u001b[0m",
|
||||
"\u001b[1;31mAttributeError\u001b[0m Traceback (most recent call last)",
|
||||
"Input \u001b[1;32mIn [3]\u001b[0m, in \u001b[0;36m<cell line: 1>\u001b[1;34m()\u001b[0m\n\u001b[1;32m----> 1\u001b[0m \u001b[43mplt\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mfigure\u001b[49m(figsize\u001b[38;5;241m=\u001b[39m(\u001b[38;5;241m8\u001b[39m,\u001b[38;5;241m8\u001b[39m))\n",
|
||||
"File \u001b[1;32mD:\\anaconda\\lib\\site-packages\\matplotlib\\_api\\__init__.py:222\u001b[0m, in \u001b[0;36mcaching_module_getattr.<locals>.__getattr__\u001b[1;34m(name)\u001b[0m\n\u001b[0;32m 220\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m name \u001b[38;5;129;01min\u001b[39;00m props:\n\u001b[0;32m 221\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m props[name]\u001b[38;5;241m.\u001b[39m\u001b[38;5;21m__get__\u001b[39m(instance)\n\u001b[1;32m--> 222\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mAttributeError\u001b[39;00m(\n\u001b[0;32m 223\u001b[0m \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mmodule \u001b[39m\u001b[38;5;132;01m{\u001b[39;00m\u001b[38;5;28mcls\u001b[39m\u001b[38;5;241m.\u001b[39m\u001b[38;5;18m__module__\u001b[39m\u001b[38;5;132;01m!r}\u001b[39;00m\u001b[38;5;124m has no attribute \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mname\u001b[38;5;132;01m!r}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m)\n",
|
||||
"\u001b[1;31mAttributeError\u001b[0m: module 'matplotlib' has no attribute 'figure'"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": []
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "f6fa60f2",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3 (ipykernel)",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.9.12"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 5
|
||||
}
|
||||
@@ -0,0 +1,214 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 1,
|
||||
"id": "8f9469b5",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"ename": "FileNotFoundError",
|
||||
"evalue": "[Errno 2] No such file or directory: 'fashion-mnist_train.csv'",
|
||||
"output_type": "error",
|
||||
"traceback": [
|
||||
"\u001b[1;31m---------------------------------------------------------------------------\u001b[0m",
|
||||
"\u001b[1;31mFileNotFoundError\u001b[0m Traceback (most recent call last)",
|
||||
"Input \u001b[1;32mIn [1]\u001b[0m, in \u001b[0;36m<cell line: 108>\u001b[1;34m()\u001b[0m\n\u001b[0;32m 105\u001b[0m ave_acc\u001b[38;5;241m=\u001b[39mtal_acc\u001b[38;5;241m/\u001b[39mval_num\n\u001b[0;32m 106\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m ave_acc,ave_loss\n\u001b[1;32m--> 108\u001b[0m x_train, y_train \u001b[38;5;241m=\u001b[39m \u001b[43mload_data\u001b[49m\u001b[43m(\u001b[49m\u001b[43mtrain_name\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m 109\u001b[0m x_validation, y_validation \u001b[38;5;241m=\u001b[39m load_data(test_name)\n\u001b[0;32m 110\u001b[0m train_num \u001b[38;5;241m=\u001b[39m y_train\u001b[38;5;241m.\u001b[39mshape[\u001b[38;5;241m0\u001b[39m]\n",
|
||||
"Input \u001b[1;32mIn [1]\u001b[0m, in \u001b[0;36mload_data\u001b[1;34m(path)\u001b[0m\n\u001b[0;32m 20\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mload_data\u001b[39m(path):\n\u001b[0;32m 21\u001b[0m li \u001b[38;5;241m=\u001b[39m []\n\u001b[1;32m---> 22\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m \u001b[38;5;28;43mopen\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43mpath\u001b[49m\u001b[43m,\u001b[49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[38;5;124;43mr\u001b[39;49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[43m)\u001b[49m \u001b[38;5;28;01mas\u001b[39;00m f:\n\u001b[0;32m 23\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m i \u001b[38;5;129;01min\u001b[39;00m csv\u001b[38;5;241m.\u001b[39mreader(f):\n\u001b[0;32m 24\u001b[0m li\u001b[38;5;241m.\u001b[39mappend(i)\n",
|
||||
"\u001b[1;31mFileNotFoundError\u001b[0m: [Errno 2] No such file or directory: 'fashion-mnist_train.csv'"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"import csv\n",
|
||||
"import torch\n",
|
||||
"from torch import nn,squeeze\n",
|
||||
"import torch.optim as optim\n",
|
||||
"from torch.nn import functional as fun\n",
|
||||
"import numpy as np\n",
|
||||
"import matplotlib.pyplot as plt\n",
|
||||
"import time\n",
|
||||
"import os\n",
|
||||
"\n",
|
||||
"train_name = \"fashion-mnist_train.csv\"\n",
|
||||
"test_name = \"fashion-mnist_test.csv\"\n",
|
||||
"#训练结果存储地址\n",
|
||||
"save_path=\"./测试结果/\"\n",
|
||||
"\n",
|
||||
"# 自动导入所在路径的csv文件并训练\n",
|
||||
"# 训练后的结果包括训练完成的模型和损失与准确率随轮数的变化图\n",
|
||||
"# 训练结果将会存储在 [所在路径]/测试结果/[训练完成时日期] 下 \n",
|
||||
"\n",
|
||||
"def load_data(path):\n",
|
||||
" li = []\n",
|
||||
" with open(path,'r') as f:\n",
|
||||
" for i in csv.reader(f):\n",
|
||||
" li.append(i)\n",
|
||||
" del li[0]\n",
|
||||
" for i in range(len(li)):\n",
|
||||
" li[i] = [int(data) for data in li[i]]\n",
|
||||
" a = torch.tensor(li,dtype=torch.long).contiguous()\n",
|
||||
" labels = a[:,0].clone().detach()\n",
|
||||
" image = a[:,1:].clone().detach()\n",
|
||||
" image = image.reshape(-1,1,28,28).float()\n",
|
||||
" return image,labels\n",
|
||||
"\n",
|
||||
"class CNN(nn.Module):\n",
|
||||
" def __init__(self):\n",
|
||||
" super().__init__()\n",
|
||||
" #卷积层\n",
|
||||
" self.fea = nn.Sequential(\n",
|
||||
" nn.Conv2d(1,16,5),\n",
|
||||
" nn.ReLU(),\n",
|
||||
" nn.MaxPool2d(2,stride = 2),\n",
|
||||
" nn.Dropout(),\n",
|
||||
" nn.Conv2d(16,32,5),\n",
|
||||
" nn.ReLU(),\n",
|
||||
" nn.MaxPool2d(2,stride = 2),\n",
|
||||
" nn.Dropout(),\n",
|
||||
" nn.Flatten() \n",
|
||||
" )\n",
|
||||
"\n",
|
||||
" first_length = 32*4*4\n",
|
||||
" self.fc1=nn.Linear(first_length,128)\n",
|
||||
" self.fc2=nn.Linear(128,128)\n",
|
||||
" self.fc3=nn.Linear(128,10)\n",
|
||||
"\n",
|
||||
" self.dr1=nn.Dropout()\n",
|
||||
" self.dr2=nn.Dropout()\n",
|
||||
"\n",
|
||||
" self.bn=nn.BatchNorm1d(first_length)\n",
|
||||
" # self.cons = ConSca()\n",
|
||||
"\n",
|
||||
"\n",
|
||||
" def forward(self,x):\n",
|
||||
" \n",
|
||||
" x = self.fea(x)\n",
|
||||
" x = self.bn(x)\n",
|
||||
" x = fun.relu(self.fc1(x))\n",
|
||||
" x = self.dr1(x)\n",
|
||||
" x = fun.relu(self.fc2(x))\n",
|
||||
" x = self.dr2(x)\n",
|
||||
" x = fun.log_softmax(self.fc3(x),dim=1)\n",
|
||||
" return x\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"lossFun=nn.NLLLoss()\n",
|
||||
"epochs=10\n",
|
||||
"batch_size=500\n",
|
||||
"lr=1e-4\n",
|
||||
"train_loss=[]\n",
|
||||
"val_loss=[]\n",
|
||||
"train_acc=[]\n",
|
||||
"val_acc=[]\n",
|
||||
"net=CNN() \n",
|
||||
"optimizer=optim.Adam(net.parameters(),lr=lr)\n",
|
||||
"\n",
|
||||
"def verify(x_validation,y_validation,net=net,batch_size=batch_size):\n",
|
||||
" #分批次地计算的损失和准确度以减轻内存负担\n",
|
||||
" net.eval()\n",
|
||||
" val_num=y_validation.shape[0]\n",
|
||||
" tal_loss=0\n",
|
||||
" tal_acc=0\n",
|
||||
" for i in range(val_num//batch_size+1):\n",
|
||||
" if (i+1)*batch_size<=val_num:\n",
|
||||
" index=list(range(i*batch_size,(i+1)*batch_size))\n",
|
||||
" elif i*batch_size<val_num:\n",
|
||||
" index=list(range(i*batch_size,val_num))\n",
|
||||
" else:\n",
|
||||
" break \n",
|
||||
" val_y_predict=net.forward(x_validation[index])\n",
|
||||
" val_loss=lossFun(val_y_predict,y_validation[index])\n",
|
||||
" val_acc=(y_validation[index] == val_y_predict.argmax(dim=1)).float().sum()\n",
|
||||
" tal_loss=tal_loss+float(val_loss)\n",
|
||||
" tal_acc=tal_acc+float(val_acc)\n",
|
||||
" print(\"\\rcalc_acc&loss...%d/%d tal_acc=%f val_num=%d ave_acc_now=%f \"%(i+1,val_num//batch_size,tal_acc,val_num,tal_acc/val_num),end=\"\")\n",
|
||||
" ave_loss=tal_loss*batch_size/val_num\n",
|
||||
" ave_acc=tal_acc/val_num\n",
|
||||
" return ave_acc,ave_loss\n",
|
||||
"\n",
|
||||
"x_train, y_train = load_data(train_name)\n",
|
||||
"x_validation, y_validation = load_data(test_name)\n",
|
||||
"train_num = y_train.shape[0]\n",
|
||||
"\n",
|
||||
"for e in range(epochs):\n",
|
||||
" perm=torch.randperm(train_num)\n",
|
||||
" net.train()\n",
|
||||
" for i in range(train_num//batch_size):\n",
|
||||
" index=perm[i*batch_size:(i+1)*batch_size]\n",
|
||||
" net.zero_grad()\n",
|
||||
" y_predict=net.forward(x_train[index,...])\n",
|
||||
" loss=lossFun(y_predict,y_train[index])\n",
|
||||
" acc=(y_train[index] == y_predict.argmax(dim=1)).float().mean()\n",
|
||||
" loss.backward()\n",
|
||||
" optimizer.step()\n",
|
||||
" print(\"\\re:%d-%d/%d loss:%f acc:%f\"%(e+1,i+1,train_num//batch_size,loss,acc),end=\"\")\n",
|
||||
" print(\"\\r \",end=\"\")\n",
|
||||
" (tempacc,temploss)=verify(x_validation=x_train,y_validation=y_train)\n",
|
||||
" train_acc.append(tempacc)\n",
|
||||
" train_loss.append(temploss)\n",
|
||||
" (tempacc,temploss)=verify(x_validation=x_validation,y_validation=y_validation)\n",
|
||||
" val_acc.append(tempacc)\n",
|
||||
" val_loss.append(temploss)\n",
|
||||
" print(\"\\rEpoch:%d loss:%f val_loss:%f acc:%f val_acc:%f \"\\\n",
|
||||
" %(e+1,train_loss[e],val_loss[e],train_acc[e],val_acc[e]))\n",
|
||||
"\n",
|
||||
"timefornow = time.strftime(\"%Y%m%d_%H%M%S\", time.localtime())\n",
|
||||
"path_to_save=save_path+timefornow+'/' \n",
|
||||
"if not os.path.exists(path_to_save):\n",
|
||||
" os.makedirs(path_to_save)\n",
|
||||
"\n",
|
||||
"torch.save(net.state_dict(), path_to_save+\"model_parameter.pkl\")\n",
|
||||
"\n",
|
||||
"plt.figure()\n",
|
||||
"plt.title(\"Accuracy\")\n",
|
||||
"plt.plot(range(epochs),train_acc, color='red')\n",
|
||||
"plt.plot(range(epochs),val_acc, color='blue')\n",
|
||||
"plt.xlabel('epochs')\n",
|
||||
"plt.ylabel('acc')\n",
|
||||
"plt.legend(['acc', 'val_acc'])\n",
|
||||
"plt.savefig(path_to_save+\"acc.png\", dpi=300, bbox_inches='tight')\n",
|
||||
"plt.figure()\n",
|
||||
"plt.title(\"Loss\")\n",
|
||||
"plt.plot(range(epochs),train_loss, color='red')\n",
|
||||
"plt.plot(range(epochs),val_loss, color='blue')\n",
|
||||
"plt.xlabel('epochs')\n",
|
||||
"plt.ylabel('loss')\n",
|
||||
"plt.legend(['loss', 'val_loss'])\n",
|
||||
"plt.savefig(path_to_save+\"loss.png\", dpi=300, bbox_inches='tight')\n",
|
||||
"# plt.show()\n",
|
||||
"plt.close()\n",
|
||||
"print(\"本轮的结果已存入 %s\"%(path_to_save))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "c6437b8c",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3 (ipykernel)",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.9.12"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 5
|
||||
}
|
||||
Reference in New Issue
Block a user