VGG 应用：猫狗大战——基于 VGG16 的猫狗图像分类

一、数据集的处理与加载

class CatDogDataset(Dataset):
    """Kaggle "Dogs vs. Cats" image dataset with a seeded train/valid split.

    Every file directly inside ``data_dir`` whose name ends with ``.jpg`` is a
    sample. Names starting with ``'cat'`` get label 0; every other name gets
    label 1.

    Args:
        data_dir: Directory that holds the images.
        mode: ``"train"`` selects the first ``split_n`` fraction of the
            shuffled file list; ``"valid"`` selects the remainder.
        split_n: Fraction of the files assigned to the training split.
        rng_seed: Seed for the shuffle. Two instances built from the same
            directory and seed (one per mode) get complementary splits.
        transform: Optional callable applied to each PIL image in
            ``__getitem__``.
    """

    def __init__(self, data_dir, mode="train", split_n=0.9, rng_seed=620, transform=None):
        self.mode = mode
        self.data_dir = data_dir
        self.rng_seed = rng_seed
        self.split_n = split_n
        # List of (image_path, label) pairs; the DataLoader reads samples by index.
        self.data_info = self._get_img_info()
        self.transform = transform

    def __getitem__(self, index):
        """Load sample ``index`` and return ``(image, label)``."""
        path_img, label = self.data_info[index]
        img = Image.open(path_img).convert('RGB')  # pixel values 0~255
        if self.transform is not None:
            img = self.transform(img)  # e.g. augmentation and conversion to a tensor
        return img, label

    def __len__(self):
        """Return the number of samples in this split.

        Raises:
            Exception: if the split is empty (usually a wrong ``data_dir``).
        """
        if len(self.data_info) == 0:
            raise Exception("\ndata_dir:{} is a empty dir! Please checkout your path to images!".format(self.data_dir))
        return len(self.data_info)

    def _get_img_info(self):
        """Build the (path, label) list for ``self.mode``.

        Raises:
            ValueError: if ``self.mode`` is neither ``"train"`` nor ``"valid"``.
        """
        # Validate before touching the filesystem.
        if self.mode not in ("train", "valid"):
            raise ValueError("self.mode 无法识别,仅支持(train, valid)")

        # os.listdir order is filesystem-dependent; sort so that the seeded
        # shuffle below yields the same split on every machine.
        img_names = sorted(n for n in os.listdir(self.data_dir) if n.endswith('.jpg'))
        # A private generator keeps dataset construction from reseeding the
        # global `random` module that other code may rely on.
        random.Random(self.rng_seed).shuffle(img_names)
        img_labels = [0 if n.startswith('cat') else 1 for n in img_names]

        split_idx = int(len(img_labels) * self.split_n)  # e.g. 25000 * 0.9 = 22500
        if self.mode == "train":
            img_set, label_set = img_names[:split_idx], img_labels[:split_idx]
        else:
            img_set, label_set = img_names[split_idx:], img_labels[split_idx:]

        return [(os.path.join(self.data_dir, n), l) for n, l in zip(img_set, label_set)]
    # Preprocessing pipelines, datasets and DataLoaders for training and validation.
# Per-channel normalization statistics (the standard ImageNet mean/std values).
norm_mean = [0.485, 0.456, 0.406]
norm_std = [0.229, 0.224, 0.225]

# Train: scale the shorter side to 256 (an int size keeps the aspect ratio),
# centre-crop to 256x256, then a random 224x224 crop and horizontal flip.
train_transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(256),
    transforms.RandomCrop(224),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.ToTensor(),
    transforms.Normalize(norm_mean, norm_std),
])

normalizes = transforms.Normalize(norm_mean, norm_std)
_to_tensor = transforms.ToTensor()


def _stack_crops(crops):
    """Convert the PIL crops produced by TenCrop into one (ncrops, C, H, W) tensor."""
    return torch.stack([normalizes(_to_tensor(crop)) for crop in crops])


# Valid: use the same aspect-preserving resize + 256 centre crop as training,
# so train and valid images share geometry instead of being squashed to a
# square. Then take ten 224x224 crops; the validation loop averages the
# logits over them. A named function (not a lambda) keeps the transform
# picklable, which DataLoader needs for num_workers > 0 under "spawn".
valid_transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(256),
    transforms.TenCrop(224, vertical_flip=False),
    transforms.Lambda(_stack_crops),
])

# Both datasets use the same data_dir and default seed, so they form
# complementary train/valid splits of the same file list.
train_data = CatDogDataset(data_dir=data_dir, mode="train", transform=train_transform)
valid_data = CatDogDataset(data_dir=data_dir, mode="valid", transform=valid_transform)

# Each validation sample expands to 10 crops, so a batch of 4 is 40 images.
train_loader = DataLoader(dataset=train_data, batch_size=BATCH_SIZE, shuffle=True)
valid_loader = DataLoader(dataset=valid_data, batch_size=4)

二、加载模型，定义损失函数与优化器

 # Model, loss function, optimizer and learning-rate schedule.
# Load VGG16 via get_vgg16 (defined elsewhere). Then swap classifier layer 6
# (presumably the final Linear of the VGG head; confirm against get_vgg16)
# for a new Linear that outputs `num_classes` logits.
vgg16_model = get_vgg16(path_state_dict, device, False)
num_ftrs = vgg16_model.classifier[6].in_features
vgg16_model.classifier[6] = nn.Linear(num_ftrs, num_classes)
vgg16_model.to(device)

criterion = nn.CrossEntropyLoss()

# flag = 1: the non-classifier parameters train at 0.1 * LR and the
# classifier at LR (fine-tuning); flag = 0: one LR for all parameters.
flag = 0
if flag:
    # Match parameters by id(): tensor `==` is element-wise, so plain `in`
    # membership on tensors is not usable. A set makes the lookup O(1).
    fc_params_id = {id(p) for p in vgg16_model.classifier.parameters()}
    base_params = [p for p in vgg16_model.parameters() if id(p) not in fc_params_id]
    optimizer = optim.SGD([
        {'params': base_params, 'lr': LR * 0.1},
        {'params': vgg16_model.classifier.parameters(), 'lr': LR},
    ], momentum=0.9)
else:
    optimizer = optim.SGD(vgg16_model.parameters(), lr=LR, momentum=0.9)

# Multiply each param group's LR by 0.1 every `lr_decay_step` scheduler steps.
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=lr_decay_step, gamma=0.1)

三、训练过程

 # Training loop with periodic ten-crop validation.
train_curve = list()
valid_curve = list()

for epoch in range(start_epoch + 1, MAX_EPOCH):
    loss_mean = 0.
    correct = 0.
    total = 0.

    vgg16_model.train()
    for i, data in enumerate(train_loader):
        # forward
        inputs, labels = data
        inputs, labels = inputs.to(device), labels.to(device)
        outputs = vgg16_model(inputs)

        # backward
        optimizer.zero_grad()
        loss = criterion(outputs, labels)
        loss.backward()

        # update weights
        optimizer.step()

        # Running accuracy over the epoch so far.
        predicted = outputs.detach().argmax(dim=1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

        # Logged loss is the mean over the last `log_interval` iterations.
        loss_value = loss.item()
        loss_mean += loss_value
        train_curve.append(loss_value)
        if (i + 1) % log_interval == 0:
            loss_mean = loss_mean / log_interval
            print("Training:Epoch[{:0>3}/{:0>3}] Iteration[{:0>3}/{:0>3}] Loss: {:.4f} Acc:{:.2%} lr:{}".format(
                epoch, MAX_EPOCH, i + 1, len(train_loader), loss_mean, correct / total, scheduler.get_last_lr()))
            loss_mean = 0.

    scheduler.step()  # update the learning rate once per epoch

    # validate the model
    if (epoch + 1) % val_interval == 0:
        correct_val = 0.
        total_val = 0.
        loss_val = 0.
        n_val_batches = len(valid_loader)
        vgg16_model.eval()
        with torch.no_grad():
            for data in valid_loader:
                inputs, labels = data
                inputs, labels = inputs.to(device), labels.to(device)

                # Inputs are (bs, ncrops, c, h, w): fold the crops into the
                # batch dimension, then average the logits over the crops.
                bs, ncrops, c, h, w = inputs.size()
                outputs = vgg16_model(inputs.view(-1, c, h, w))
                outputs_avg = outputs.view(bs, ncrops, -1).mean(1)

                loss = criterion(outputs_avg, labels)
                predicted = outputs_avg.argmax(dim=1)
                total_val += labels.size(0)
                correct_val += (predicted == labels).sum().item()
                loss_val += loss.item()

        # Skip reporting when the loader yielded nothing (avoids division by zero).
        if total_val:
            loss_val_mean = loss_val / n_val_batches
            valid_curve.append(loss_val_mean)
            print("Valid:\t Epoch[{:0>3}/{:0>3}] Iteration[{:0>3}/{:0>3}] Loss: {:.4f} Acc:{:.2%}".format(
                epoch, MAX_EPOCH, n_val_batches, n_val_batches, loss_val_mean, correct_val / total_val))
        vgg16_model.train()

本文链接:https://my.lmcjl.com/post/9324.html

展开阅读全文

4 评论

留下您的评论.