From 8358198f83000442fb7b18e0da68d32a587d8140 Mon Sep 17 00:00:00 2001 From: zhang117228 <13198271+zhang117228@user.noreply.gitee.com> Date: Thu, 13 Jul 2023 11:13:36 +0000 Subject: [PATCH] =?UTF-8?q?=E6=9C=8D=E8=A3=85=E5=88=86=E7=B1=BB?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: zhang117228 <13198271+zhang117228@user.noreply.gitee.com> --- .../homeworks/Fashion-MNIST分类.ipynb | 611 ++++++++++++++++++ 1 file changed, 611 insertions(+) create mode 100644 QuickDraw在线交互识别系统/homeworks/Fashion-MNIST分类.ipynb diff --git a/QuickDraw在线交互识别系统/homeworks/Fashion-MNIST分类.ipynb b/QuickDraw在线交互识别系统/homeworks/Fashion-MNIST分类.ipynb new file mode 100644 index 0000000..1c9cace --- /dev/null +++ b/QuickDraw在线交互识别系统/homeworks/Fashion-MNIST分类.ipynb @@ -0,0 +1,611 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "id": "f2b64e25", + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "import numpy as np\n", + "import pandas as pd\n", + "import torch\n", + "import torch.nn as nn\n", + "import torch.optim as optim\n", + "from torch.utils.data import Dataset,DataLoader" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "2e469261", + "metadata": {}, + "outputs": [], + "source": [ + "device=torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", + "batch_size=256\n", + "#num_workers代表并行数,windows必须填0,否则会卡死\n", + "num_workers=0\n", + "lr=1e-4\n", + "epochs=20" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "5002933c", + "metadata": {}, + "outputs": [], + "source": [ + "from torchvision import transforms\n", + "\n", + "image_size=28\n", + "data_transform=transforms.Compose([\n", + " transforms.ToPILImage(),\n", + " transforms.Resize(image_size),\n", + " transforms.ToTensor()\n", + "])" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "aa4ccfcc", + "metadata": {}, + "outputs": [], + "source": [ + "#读取数据\n", + "class FMDataset(Dataset):\n", + " def __init__(self,df,transform=None):\n", + " self.df=df\n", + " self.transform=transform\n", + " self.images=df.iloc[:,1:].values.astype(np.uint8)\n", + " self.labels=df.iloc[:,0].values\n", + " \n", + " def __len__(self):\n", + " return len(self.images)\n", + " \n", + " def __getitem__(self,idx):\n", + " image=self.images[idx].reshape(28,28,1)\n", + " label=int(self.labels[idx])\n", + " if self.transform is not None:\n", + " image=self.transform(image)\n", + " else:\n", + " image=torch.tensor(image/255.,dtype=torch.float)\n", + " label=torch.tensor(label,dtype=torch.long)\n", + " return image,label\n", + "\n", + "train_df=pd.read_csv(\"./FashionMNIST/fashion-mnist_train.csv\")\n", + "test_df=pd.read_csv(\"./FashionMNIST/fashion-mnist_test.csv\")\n", + "train_data=FMDataset(train_df,data_transform)\n", + "test_data=FMDataset(test_df,data_transform)" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "0b8040bb", + "metadata": {}, + "outputs": [], + "source": [ + "train_loader=DataLoader(train_data,batch_size=batch_size,shuffle=True,num_workers=num_workers,drop_last=True)\n", + "test_loader=DataLoader(test_data,batch_size=batch_size,shuffle=False,num_workers=num_workers)" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "d182f098", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "torch.Size([256, 1, 28, 28]) torch.Size([256])\n" + ] + }, + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 6, + "metadata": {}, + "output_type": "execute_result" + }, + { + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAPsAAAD4CAYAAAAq5pAIAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/YYfK9AAAACXBIWXMAAAsTAAALEwEAmpwYAAAP4ElEQVR4nO3db4id5ZnH8d/lJJM4iclkdjT/FNttJKwI2hBkwXXNUixWBC2kS/OiuCCbvqjQQl+suC/qS1m2LX2xFKarmC5dS6UNKshuVSJSkOKYZDXZmDURV9PEmQkhcSZ/zeTaF/O4THTOdU/Oc/5Nru8HhjPzXOc558rJ/OY559znfm5zdwG4+l3T7QYAdAZhB5Ig7EAShB1IgrADSSzq5J2ZGW/9z2Hx4sVhfcWKFWH9zJkzDWtnz55tqqdOWLQo/vVbtWpVWJ+cnAzr586du+KergbubnNtrxV2M7tP0s8k9Un6V3d/ss7tZbV69eqwfu+994b1PXv2NKzt3bu3mZY6YnBwMKxv3bo1rO/atSusv/vuu1fa0lWt6afxZtYn6V8kfUPSrZK2mdmtrWoMQGvVec1+p6RD7v6+u1+Q9GtJD7amLQCtVifs6yV9NOvnI9W2y5jZdjMbNbPRGvcFoKY6r9nnehPgC2/AufuIpBGJN+iAbqpzZD8i6aZZP98o6Wi9dgC0S52wvynpFjP7spn1S/q2pBda0xaAVmv6aby7XzSzRyX9p2aG3p529/0t6yyRBx54IKzv3r07rEfj8MPDw+G+x48fD+t1Rfe/YcOGcN/Dhw+H9S1btoR1ht4uV2uc3d1fkvRSi3oB0EZ8XBZIgrADSRB2IAnCDiRB2IEkCDuQREfns2NuH3/8cVgvzXcfHx9vWLv11ngi4smTJ8N6ycDAQFjv7+9v+r6jfSVpYmIirONyHNmBJAg7kARhB5Ig7EAShB1IgrADSTD01gLXXBP/zbx06VJYLw0hTU9Ph/WLFy82rEXDclL5dM6lenTfUnw65wsXLoT7RqfIlhh6u1Ic2YEkCDuQBGEHkiDsQBKEHUiCsANJEHYgCcbZW6A0jl6yb9++sL558+awHk2Rbec4uST19fWFdbM5Vw+e132XVrd95ZVXwjoux5EdSIKwA0kQdiAJwg4kQdiBJAg7kARhB5JgnL0HnDp1KqyXxrKXLFnSsObuTfX0mdI4fUnpdNCR6667LqyfP3++6dvOqNb/pJl9IGlS0rSki+4ef/oDQNe04sj+N+5+vAW3A6CNeM0OJFE37C7p92b2lpltn+sKZrbdzEbNbLTmfQGooe7T+Lvc/aiZ3SDpZTN7191fn30Fdx+RNCJJZlbv3SIATat1ZHf3o9XluKSdku5sRVMAWq/psJvZMjO77rPvJX1dUjxXE0DX1Hkav1rSzmq+8iJJ/+7u/9GSrnCZDz/8MKyvW7euYe306dPhvqX56qVz4pfG4aPPAAwODob7lnrDlWk67O7+vqTbW9gLgDZi6A1IgrADSRB2IAnCDiRB2IEkmOK6AJSGoKKpoKWht7rLTdexdOnSsF7qHVeGIzuQBGEHkiDsQBKEHUiCsANJEHYgCcIOJME4+wIwNTUV1qNx+NIYfWms+8KFC2G9NE4fnQZ7eno63PfMmTNhHVeGIzuQBGEHkiDsQBKEHUiCsANJEHYgCcIOJME4+wJQWrI5GksvjYPXPZV0ab571Punn34a7ss4e2txZAeSIOxAEoQdSIKwA0kQdiAJwg4kQdiBJBhnXwDqjEeXllQuzVcvKY3Dnz17tunbnpycbHpffFHxyG5mT5vZuJntm7VtyMxeNrP3qstV7W0TQF3zeRr/jKT7PrftMUmvuvstkl6tfgbQw4phd/fXJZ343OYHJe2ovt8h6aHWtgWg1Zp9zb7a3Y9JkrsfM7MbGl3RzLZL2t7k/QBokba/QefuI5JGJMnMvN33B2BuzQ69jZnZWkmqLsdb1xKAdmg27C9Ierj6/mFJz7emHQDtUnwab2bPStoiadjMjkj6kaQnJf3GzB6R9KGkb7Wzyew2bNgQ1u+5556GtV27doX7nj9/PqyvW7curE9MTIT16Lz0GzduDPeN1p2XpIMHD4Z1XK4Ydnff1qD0tRb3AqCN+LgskARhB5Ig7EAShB1IgrADSTDFtQfcdtttYX3Tpk1hPZpmWpomamZN37YkDQwMhPXFixc3VZOkoaGhsL558+awPjo62rBW9xTZCxFHdiAJwg4kQdiBJAg7kARhB5Ig7EAShB1IgnH2HlCawvraa6+F9a1btzasTU9Ph/suX748rJdOY10aj46m0N5wQ8OzmUmSXnzxxbB+9913h/VonP1qHEcv4cgOJEHYgSQIO5AEYQeSIOxAEoQdSIKwA0kwzt4Dli1bFtanpqbCejSv+7nnngv37e/vD+t9fX1h/dy5c2E9Gqdfu3ZtuG9puefSfeNyHNmBJAg7kARhB5Ig7EAShB1IgrADSRB2IAnG2ReA0pzyaFnk0nnhS+duL51fvY7Sksyl3kpz0qPbL51P/2pU/J80s6fNbNzM9s3a9oSZ/cnM9lZf97e3TQB1zefP9jOS7ptj+0/d/Y7q66XWtgWg1Yphd/fXJZ3oQC8A2qjOC7JHzezt6mn+qkZXMrPtZjZqZo1PCAag7ZoN+88lfUXSHZKOSfpxoyu6+4i7b3b3eBU+AG3VVNjdfczdp939kqRfSLqztW0BaLWmwm5ms+cmflPSvkbXBdAbiuPsZvaspC2Shs3siKQfSdpiZndIckkfSPpu+1pc+EprmJ8+fbrW7Udjxu4e7luaS18aCx8bGwvrpfPWRxYtin89S2PlK1eubHrfq1Ex7O6+bY7NT7WhFwBtxMdlgSQIO5AEYQeSIOxAEoQdSIIprh1QGr46cSKeerBqVcNPI0uqt/xwaRppaVnlw4cPh/XS9NzImjVrwnrpVNJDQ0MNa0eOHGmqp4WMIzuQBGEHkiDsQBKEHUiCsANJEHYgCcIOJME4ewesXr06rJfGi0tLG0fj+KVTSZemoJbG2S9evFjr9iPr168P6x999FFYHxwcbPq+r0Yc2YEkCDuQBGEHkiDsQBKEHUiCsANJEHYgCcbZO6A0n/348eNh/eabbw7rw8PDDWuffPJJuG9fX1/Tty2VPyMQzWcvLQddGuPfv39/WI/ms2fEkR1IgrADSRB2IAnCDiRB2IEkCDuQBGEHkmCcvQOWLl0a1icmJsL6jTfeGNYPHTrUsHbhwoVw39J540vnbi+N05fG4SPLly8P66XPEOByxSO7md1kZrvM7ICZ7Tez71fbh8zsZTN7r7qMVzIA0FXzeRp/UdIP3f0vJP2lpO+Z2a2SHpP0qrvfIunV6mcAPaoYdnc/5u67q+8nJR2QtF7Sg5J2VFfbIemhNvUIoAWu6DW7mX1J0lcl/VHSanc/Js38QTCzOT/IbGbbJW2v2SeAmuYddjNbLum3kn7g7p+UTmT4GXcfkTRS3YY30ySA+uY19GZmizUT9F+5+++qzWNmtraqr5U03p4WAbRC8chuM4fwpyQdcPefzCq9IOlhSU9Wl8+3pcMFoDR8VVIanipN9YxOVV1nyWRJuv7662vtHy0nXXrcBgYGwvrk5GRY7+/vb/q+6z5uvWg+T+PvkvQdSe+Y2d5q2+OaCflvzOwRSR9K+lZbOgTQEsWwu/sfJDV6gf611rYDoF34uCyQBGEHkiDsQBKEHUiCsANJMMW1BUrjwaUx29I4e53bL912aYpqabnp0pLM0RTbaBxckhYtin89S9N3o95Kj+mpU6fC+kLEkR1IgrADSRB2IAnCDiRB2IEkCDuQBGEHkmCcvQWWLVsW1usum1xa8vnMmTMNa9F8cql8muvBwcGwXlJnXvi1114b1s+fPx/Wo8el9O9mnB3AgkXYgSQIO5AEYQeSIOxAEoQdSIKwA0kwzt4CpXnZU1NTYf2aa+K/uaUx4dOnTzd926Xzpw8NDYX1Okr3vWTJklq3Hz0uq1bFiw6PjY3Vuu9exJEdSIKwA0kQdiAJwg4kQdiBJAg7kARhB5KYz/rsN0n6paQ1ki5JGnH3n5nZE5L+XtJEddXH3f2ldjXay1asWBHWjx8/Xuv266wlXprzXbrtffv2hfU6Vq5cGdZL5wkoieazu3ut216I5vOhmouSfujuu83sOklvmdnLVe2n7v7P7WsPQKvMZ332Y5KOVd9PmtkBSevb3RiA1rqi1+xm9iVJX5X0x2rTo2b2tpk9bWZzfv7QzLab2aiZjdZrFUAd8w67mS2X9FtJP3D3TyT9XNJXJN2hmSP/j+faz91H3H2zu2+u3y6AZs0r7Ga2WDNB/5W7/06S3H3M3afd/ZKkX0i6s31tAqirGHYzM0lPSTrg7j+ZtX3trKt9U1L73rYFUNt83o2/S9J3JL1jZnurbY9L2mZmd0hySR9I+m4b+lsQSlNQS0pTZNesWRPWS9NYIxs3bgzrmzZtCuulJZ2PHj3asHb77beH+05OTob1kmhIcnh4ONz34MGDte67F83n3fg/SLI5SinH1IGFik/QAUkQdiAJwg4kQdiBJAg7kARhB5LgVNItsGfPnrA+MDAQ1s+dOxfWo7HqUr00hfWNN94I6zt37gzrdabvPvPMM2G9NM5e+nzByZMnG9ai00xfrTiyA0kQdiAJwg4kQdiBJAg7kARhB5Ig7EAS1slT6prZhKT/nbVpWFK98yy3T6/21qt9SfTWrFb2drO7Xz9XoaNh/8Kdm4326rnperW3Xu1Lordmdao3nsYDSRB2IIluh32ky/cf6dXeerUvid6a1ZHeuvqaHUDndPvIDqBDCDuQRFfCbmb3mdlBMztkZo91o4dGzOwDM3vHzPZ2e326ag29cTPbN2vbkJm9bGbvVZdzrrHXpd6eMLM/VY/dXjO7v0u93WRmu8zsgJntN7PvV9u7+tgFfXXkcev4a3Yz65P0P5LulXRE0puStrn7f3e0kQbM7ANJm9296x/AMLO/ljQl6Zfuflu17Z8knXD3J6s/lKvc/R96pLcnJE11exnvarWitbOXGZf0kKS/Uxcfu6Cvv1UHHrduHNnvlHTI3d939wuSfi3pwS700fPc/XVJJz63+UFJO6rvd2jml6XjGvTWE9z9mLvvrr6flPTZMuNdfeyCvjqiG2FfL+mjWT8fUW+t9+6Sfm9mb5nZ9m43M4fV7n5MmvnlkXRDl/v5vOIy3p30uWXGe+axa2b587q6Efa5lpLqpfG/u9x9k6RvSPpe9XQV8zOvZbw7ZY5lxntCs8uf19WNsB+RdNOsn2+UFJ9RsYPc/Wh1OS5pp3pvKeqxz1bQrS7Hu9zP/+ulZbznWmZcPfDYdXP5826E/U1Jt5jZl82sX9K3Jb3QhT6+wMyWVW+cyMyWSfq6em8p6hckPVx9/7Ck57vYy2V6ZRnvRsuMq8uPXdeXP3f3jn9Jul8z78gflvSP3eihQV9/Lum/qq/93e5N0rOaeVr3qWaeET0i6c8kvSrpvepyqId6+zdJ70h6WzPBWtul3v5KMy8N35a0t/q6v9uPXdBXRx43Pi4LJMEn6IAkCDuQBGEHkiDsQBKEHUiCsANJEHYgif8Dxj/8IzJoOQYAAAAASUVORK5CYII=\n", + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + } + ], + "source": [ + "#数据可视化\n", + "import matplotlib.pyplot as plt\n", + "image,label=next(iter(train_loader))\n", + "print(image.shape,label.shape)\n", + "plt.imshow(image[0][0],cmap=\"gray\")" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "62ee436d", + "metadata": {}, + "outputs": [], + "source": [ + "class Net(nn.Module):\n", + " def __init__(self):\n", + " super(Net,self).__init__()\n", + " self.conv=nn.Sequential(\n", + " nn.Conv2d(1,32,5),\n", + " nn.ReLU(),\n", + " nn.MaxPool2d(2,stride=2),\n", + " nn.Dropout(0.3),\n", + " nn.Conv2d(32,64,5),\n", + " nn.ReLU(),\n", + " nn.MaxPool2d(2,stride=2),\n", + " nn.Dropout(0.3) \n", + " )\n", + " self.fc=nn.Sequential(\n", + " nn.Linear(64*4*4,512),\n", + " nn.ReLU(),\n", + " nn.Linear(512,10)\n", + " )\n", + " \n", + " def forward(self,x):\n", + " x=self.conv(x)\n", + " x=x.view(-1,64*4*4)\n", + " x=self.fc(x)\n", + " return x\n", + " \n", + "model=Net()\n", + "model=model.cuda()\n", + "criterion=nn.CrossEntropyLoss()\n", + "optimizer=optim.Adam(model.parameters(),lr=0.001)" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "a2bdc7b1", + "metadata": {}, + "outputs": [], + "source": [ + "def train(epoch):\n", + " model.train()\n", + " train_loss=0\n", + " for data,label in train_loader:\n", + " data,label=data.cuda(),label.cuda()\n", + " optimizer.zero_grad()\n", + " output=model(data)\n", + " loss=criterion(output,label)\n", + " loss.backward()\n", + " optimizer.step()\n", + " train_loss+=loss.item()*data.size(0)\n", + " train_loss=train_loss/len(train_loader.dataset)\n", + " print('Epoch:{} Training Loss:{:.6f}'.format(epoch,train_loss))" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "c0d891f1", + "metadata": {}, + "outputs": [], + "source": [ + "def val(epoch):\n", + " model.eval()\n", + " val_loss=0\n", + " gt_labels=[]\n", + " pred_labels=[]\n", + " with torch.no_grad():\n", + " for data,label in test_loader:\n", + " data,label=data.cuda(),label.cuda()\n", + " output=model(data)\n", + " preds=torch.argmax(output,1)\n", + " gt_labels.append(label.cpu().data.numpy())\n", + " pred_labels.append(preds.cpu().data.numpy())\n", + " loss=criterion(output,label)\n", + " val_loss+=loss.item()*data.size(0)\n", + " val_loss=val_loss/len(test_loader.dataset)\n", + " ge_labels,pred_labels=np.concatenate(gt_labels),np.concatenate(pred_labels)\n", + " acc=np.sum(ge_labels==pred_labels)/len(pred_labels)\n", + " print('Epoch:{} Validation Loss:{:.6f},Accuracy :{:6f}'.format(epoch,val_loss,acc))" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "c026db76", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\r", + " 0%| | 0/20 [00:00