# Logistic regression on the Fashion-MNIST dataset (基于 FashionMnist 数据集的 LogisticRegression)
#!--*-- coding:utf-8 --*--
import os
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
from torchvision.datasets import FashionMNIST
# Fashion-MNIST train/test splits. ToTensor() turns each PIL image into a float
# tensor scaled to [0, 1]. The dataset root may be overridden with the
# FASHION_MNIST_ROOT environment variable; download=False means the files must
# already exist under that root.
data_root = os.environ.get('FASHION_MNIST_ROOT', '/home/ai/Desktop')
to_tensor = transforms.ToTensor()
train_datas = FashionMNIST(root=data_root, train=True, transform=to_tensor, download=False)
val_datas = FashionMNIST(root=data_root, train=False, transform=to_tensor, download=False)
# Logistic-regression network
class LogisticRegression(nn.Module):
    """Multinomial logistic regression: a single linear layer producing class logits.

    The output is deliberately left as raw (unnormalized) logits, because
    ``nn.CrossEntropyLoss`` applies ``log_softmax`` internally. Squashing the
    logits with a sigmoid first bounds them to (0, 1), so the softmax inside the
    loss can never become confident (with 10 classes the best achievable
    probability is e / (e + 9) ~= 0.23). Use ``torch.softmax(logits, dim=1)``
    when probabilities are needed.

    Args:
        input_size: number of features per sample (e.g. 28 * 28 flattened pixels).
        num_classes: number of output classes.
    """

    def __init__(self, input_size: int, num_classes: int):
        super().__init__()
        self.logistic_regression = nn.Linear(input_size, num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return class logits; ``x``'s last dimension must equal ``input_size``."""
        return self.logistic_regression(x)
num_workers = 12
batch_size = 128
input_size = 28 * 28
num_classes = 10
learning_rate = 0.01
momentum = 0.9
num_epochs = 100
# Run on the GPU when one is available, otherwise fall back to the CPU
# instead of crashing on machines without CUDA.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = LogisticRegression(input_size, num_classes)
model.to(device)
print(model)
# NOTE: with num_workers > 0 and the "spawn" multiprocessing start method
# (the Windows/macOS default), DataLoader workers re-import this module; wrap
# the script in an `if __name__ == "__main__":` guard or set num_workers = 0 there.
train_loader = DataLoader(dataset=train_datas, batch_size=batch_size, shuffle=True,
                          num_workers=num_workers)
# Evaluation order does not affect accuracy, so the validation set is not shuffled.
valid_loader = DataLoader(dataset=val_datas, batch_size=batch_size, shuffle=False,
                          num_workers=num_workers)
# CrossEntropyLoss expects raw logits and integer class labels.
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate, momentum=momentum)
# Train on whatever device the model's parameters live on, so this loop does
# not hard-code "cuda".
device = next(model.parameters()).device
model.train()
for epoch in range(num_epochs):
    for i, (images, labels) in enumerate(train_loader):
        # Flatten each image into an input_size-long vector for the linear layer.
        images = images.view(-1, input_size).to(device)
        labels = labels.to(device)
        outputs = model(images)
        loss = criterion(outputs, labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        if i % 100 == 0:
            # loss is a 0-dim tensor; .item() extracts the Python float
            # (indexing it with [0] raises in modern PyTorch).
            print('epoch = %d ---- iter %d ---- current loss = %.5f' % (epoch, i, loss.item()))
# Model evaluation on the held-out split.
device = next(model.parameters()).device
model.eval()
correct = 0
total = 0
# No gradients are needed for evaluation; skipping autograd saves memory and time.
with torch.no_grad():
    for images, labels in valid_loader:
        images = images.view(-1, input_size).to(device)
        labels = labels.to(device)
        outputs = model(images)
        # The predicted class is the index of the largest logit.
        predicted = outputs.argmax(dim=1)
        total += labels.size(0)
        # Accumulate as a Python int rather than a device tensor.
        correct += (predicted == labels).sum().item()
print('accuracy of the model %.2f' % (100 * correct / total))
print('Done.')
# Reported result: after 100 epochs the model reaches about 81% accuracy (100 个 epoch 得到的模型精度约为 81%).