@@ -0,0 +1,201 @@
|
||||
import numpy as np
|
||||
import os
|
||||
import cv2
|
||||
import matplotlib.pyplot as plt
|
||||
import pandas as pd
|
||||
import torch
|
||||
import torchvision.transforms as transforms
|
||||
import torch.nn as nn
|
||||
import time
|
||||
from torch.utils.data import DataLoader, Dataset
|
||||
from sklearn.model_selection import train_test_split
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
|
||||
# 导入数据
|
||||
path = []
|
||||
data_set = [] # 全部数据
|
||||
label = [] # 全部标签
|
||||
path.append("Rice Disease/Bacterial leaf blight")
|
||||
path.append("Rice Disease/Brown spot")
|
||||
path.append("Rice Disease/Leaf smut")
|
||||
for i in range(3):
|
||||
img_dir = sorted(os.listdir(path[i]))
|
||||
for file in img_dir:
|
||||
img = cv2.imread(os.path.join(path[i], file))
|
||||
if img.shape[0] > img.shape[1]:
|
||||
img = np.rot90(img).copy()
|
||||
data_set.append(cv2.resize(img, (1024, 512)))
|
||||
label.append(i)
|
||||
|
||||
|
||||
# 划分训练集和验证集,训练集与验证集比例为7:3
|
||||
data_train, data_val, label_train, label_val = train_test_split(
|
||||
data_set, label, test_size=0.3
|
||||
)
|
||||
|
||||
# 训练集数据增强操作
|
||||
train_set_transform = transforms.Compose(
|
||||
[
|
||||
transforms.ToPILImage(),
|
||||
transforms.RandomRotation(15), # 随机旋转
|
||||
transforms.ToTensor()
|
||||
# Tensor格式为[3,512,1024]
|
||||
]
|
||||
)
|
||||
# 验证集无需旋转
|
||||
val_set_transorm = transforms.Compose([transforms.ToPILImage(), transforms.ToTensor()])
|
||||
|
||||
|
||||
class ImgDataset(Dataset):
|
||||
def __init__(self, x, y, transform=None):
|
||||
self.x = x
|
||||
self.y = torch.LongTensor(y)
|
||||
self.transform = transform
|
||||
|
||||
def __len__(self):
|
||||
return len(self.x)
|
||||
|
||||
def __getitem__(self, index):
|
||||
X = self.transform(self.x[index])
|
||||
Y = self.y[index]
|
||||
return X, Y
|
||||
|
||||
|
||||
batch_size = 5
|
||||
|
||||
TrainSet = ImgDataset(data_train, label_train, train_set_transform)
|
||||
ValidationSet = ImgDataset(data_val, label_val, val_set_transorm)
|
||||
|
||||
TrainLoader = DataLoader(TrainSet, batch_size, shuffle=True)
|
||||
ValidationLoader = DataLoader(ValidationSet, batch_size, shuffle=False)
|
||||
|
||||
print("DataSet complete")
|
||||
|
||||
|
||||
# 卷积神经网络结构
|
||||
class Net(nn.Module):
|
||||
def __init__(self):
|
||||
super(Net, self).__init__()
|
||||
# input维度[3,512,1024]
|
||||
# torch.nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding)
|
||||
# torch.nn.MaxPool2d(kernel_size, stride, padding)
|
||||
|
||||
# 卷积层
|
||||
self.cnn = nn.Sequential(
|
||||
# 第一层卷积与池化
|
||||
nn.Conv2d(3, 32, 7, 1, 3), # [32,512,1024]
|
||||
nn.BatchNorm2d(32), # 归一化操作
|
||||
nn.ReLU(),
|
||||
nn.MaxPool2d(2, 2, 0), # [32,256,512]
|
||||
# 第二次
|
||||
nn.Conv2d(32, 64, 3, 1, 1), # [64,256,512]
|
||||
nn.BatchNorm2d(64),
|
||||
nn.ReLU(),
|
||||
nn.MaxPool2d(2, 2, 0), # [64,128,256]
|
||||
# 第三次
|
||||
nn.Conv2d(64, 128, 3, 1, 1), # [128,128,256]
|
||||
nn.BatchNorm2d(128),
|
||||
nn.ReLU(),
|
||||
nn.MaxPool2d(2, 2, 0), # [128,64,128]
|
||||
# 第四次
|
||||
nn.Conv2d(128, 128, 3, 1, 1), # [128,64,128]
|
||||
nn.BatchNorm2d(128),
|
||||
nn.ReLU(),
|
||||
nn.MaxPool2d(2, 2, 0), # [128,32,64]
|
||||
# 第五次
|
||||
nn.Conv2d(128, 256, 3, 1, 1), # [256,32,64]
|
||||
nn.BatchNorm2d(256),
|
||||
nn.ReLU(),
|
||||
nn.MaxPool2d(2, 2, 0), # [256,16,32]
|
||||
# 第六次
|
||||
nn.Conv2d(256, 256, 3, 1, 1), # [256,16,32]
|
||||
nn.BatchNorm2d(256),
|
||||
nn.ReLU(),
|
||||
nn.MaxPool2d(2, 2, 0), # [256,8,16]
|
||||
# 第七次
|
||||
nn.Conv2d(256, 512, 3, 1, 1), # [512,8,16]
|
||||
nn.BatchNorm2d(512),
|
||||
nn.ReLU(),
|
||||
nn.MaxPool2d(2, 2, 0), # [512,4,8]
|
||||
)
|
||||
# 全连接层
|
||||
self.fc = nn.Sequential(
|
||||
nn.Linear(512 * 4 * 8, 1024),
|
||||
nn.ReLU(),
|
||||
nn.Linear(1024, 512),
|
||||
nn.ReLU(),
|
||||
nn.Linear(512, 64),
|
||||
nn.ReLU(),
|
||||
nn.Linear(64, 3),
|
||||
)
|
||||
|
||||
# 前向传播
|
||||
def forward(self, x):
|
||||
out = self.cnn(x)
|
||||
out = out.view(out.size()[0], -1) # 将多维数据改成一维
|
||||
out = self.fc(out)
|
||||
return out
|
||||
|
||||
|
||||
print("Net complete")
|
||||
|
||||
|
||||
model = Net().cuda()
|
||||
loss = nn.CrossEntropyLoss() # 交叉熵损失函数
|
||||
optimizer = torch.optim.Adam(model.parameters(), lr=0.000001) # Adam优化器
|
||||
num_epoch = 80 # 迭代次数
|
||||
writer = SummaryWriter(log_dir="runs/result") # 实例化SummaryWriter,用于记录训练过程中的数据并做可视化
|
||||
# init_img用于带入网络从而便于记录网络结构
|
||||
init_img = torch.zeros(1, 3, 512, 1024).cuda()
|
||||
writer.add_graph(model, init_img)
|
||||
for epoch in range(num_epoch):
|
||||
epoch_start_time = time.time()
|
||||
|
||||
train_acc = 0.0
|
||||
val_acc = 0.0
|
||||
train_loss = 0.0
|
||||
val_loss = 0.0
|
||||
|
||||
model.train()
|
||||
for i, data in enumerate(TrainLoader):
|
||||
optimizer.zero_grad()
|
||||
train_pred = model(data[0].cuda())
|
||||
batch_loss = loss(train_pred, data[1].cuda())
|
||||
batch_loss.backward()
|
||||
optimizer.step()
|
||||
|
||||
train_acc += np.sum(
|
||||
np.argmax(train_pred.cpu().data.numpy(), axis=1) == data[1].numpy()
|
||||
)
|
||||
train_loss += batch_loss.item()
|
||||
|
||||
model.eval()
|
||||
with torch.no_grad():
|
||||
for i, data in enumerate(ValidationLoader):
|
||||
val_pred = model(data[0].cuda())
|
||||
batch_loss = loss(val_pred, data[1].cuda())
|
||||
|
||||
val_acc += np.sum(
|
||||
np.argmax(val_pred.cpu().data.numpy(), axis=1) == data[1].numpy()
|
||||
)
|
||||
val_loss += batch_loss.item()
|
||||
|
||||
# 將結果 print 出來
|
||||
print(
|
||||
"[%03d/%03d] %2.2f sec(s) Train Acc: %3.6f Loss: %3.6f | Validation Acc: %3.6f Loss: %3.6f"
|
||||
% (
|
||||
epoch + 1,
|
||||
num_epoch,
|
||||
time.time() - epoch_start_time,
|
||||
train_acc / TrainSet.__len__(),
|
||||
train_loss / TrainSet.__len__(),
|
||||
val_acc / ValidationSet.__len__(),
|
||||
val_loss / ValidationSet.__len__(),
|
||||
)
|
||||
)
|
||||
|
||||
# 记录训练过程中的数据
|
||||
writer.add_scalar("Train Acc", train_acc / TrainSet.__len__(), epoch)
|
||||
writer.add_scalar("Train Loss", train_loss / TrainSet.__len__(), epoch)
|
||||
writer.add_scalar("Validation Acc", val_acc / ValidationSet.__len__(), epoch)
|
||||
writer.add_scalar("Validation Loss", val_loss / ValidationSet.__len__(), epoch)
|
||||
Reference in New Issue
Block a user