The images and some of the content and code in this article are adapted from:
The overall training flow of an object detection network is roughly:

1. Prepare the dataloader
2. Build the model
3. Compute the loss
4. Set up the optimizer
5. Back-propagate the loss

First, import the necessary libraries and set the various hyperparameters:
```python
import time

import torch.backends.cudnn as cudnn
import torch.optim
import torch.utils.data

from model import tiny_detector, MultiBoxLoss
from datasets import PascalVOCDataset
from utils import *

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
cudnn.benchmark = True

# Data parameters
data_folder = '../../../dataset/VOCdevkit'  # data files root path
keep_difficult = True  # use objects considered difficult to detect?
n_classes = len(label_map)  # number of different types of objects

# Learning parameters
total_epochs = 100  # number of epochs to train
batch_size = 32  # batch size
workers = 4  # number of workers for loading data in the DataLoader
print_freq = 100  # print training status every __ batches
lr = 1e-3  # learning rate
decay_lr_at = [150, 190]  # decay learning rate after these many epochs
decay_lr_to = 0.1  # decay learning rate to this fraction of the existing learning rate
momentum = 0.9  # momentum
weight_decay = 5e-4  # weight decay
```
The training code is shown below:
```python
def main():
    """
    Training.
    """
    # Initialize model and optimizer
    # model
    model = tiny_detector(n_classes=n_classes)
    # loss function
    criterion = MultiBoxLoss(priors_cxcy=model.priors_cxcy)
    # optimizer: SGD
    optimizer = torch.optim.SGD(params=model.parameters(),
                                lr=lr,
                                momentum=momentum,
                                weight_decay=weight_decay)

    # Move to default device
    model = model.to(device)
    criterion = criterion.to(device)

    # Custom dataloaders
    train_dataset = PascalVOCDataset(data_folder,
                                     split='train',
                                     keep_difficult=keep_difficult)
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=batch_size,
                                               shuffle=True,
                                               collate_fn=train_dataset.collate_fn,
                                               num_workers=workers,
                                               pin_memory=True)

    # Epochs: iterate over the training data, calling the per-epoch training
    # method to run predict -> compute loss -> back-propagate
    for epoch in range(total_epochs):
        # Decay learning rate at particular epochs
        if epoch in decay_lr_at:
            adjust_learning_rate(optimizer, decay_lr_to)

        # One epoch's training
        train(train_loader=train_loader,
              model=model,
              criterion=criterion,
              optimizer=optimizer,
              epoch=epoch)

        # Save checkpoint
        save_checkpoint(epoch, model, optimizer)
```
The training logic for a single epoch is handled by a separate train method:
```python
def train(train_loader, model, criterion, optimizer, epoch):
    """
    One epoch's training.

    :param train_loader: DataLoader for training data
    :param model: model
    :param criterion: MultiBox loss
    :param optimizer: optimizer
    :param epoch: epoch number
    """
    model.train()  # training mode enables dropout

    batch_time = AverageMeter()  # forward prop. + back prop. time
    data_time = AverageMeter()  # data loading time
    losses = AverageMeter()  # loss

    start = time.time()

    # Batches
    for i, (images, boxes, labels, _) in enumerate(train_loader):
        data_time.update(time.time() - start)

        # Move to default device
        images = images.to(device)  # (batch_size (N), 3, 224, 224)
        boxes = [b.to(device) for b in boxes]
        labels = [l.to(device) for l in labels]

        # Forward prop.: predict locations and class scores
        predicted_locs, predicted_scores = model(images)  # (N, 441, 4), (N, 441, n_classes)

        # Loss: combined localization and classification loss
        loss = criterion(predicted_locs, predicted_scores, boxes, labels)  # scalar

        # Backward prop.
        optimizer.zero_grad()
        loss.backward()

        # Update model
        optimizer.step()

        losses.update(loss.item(), images.size(0))
        batch_time.update(time.time() - start)

        start = time.time()

        # Print status
        if i % print_freq == 0:
            print('Epoch: [{0}][{1}/{2}]\t'
                  'Batch Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                  'Data Time {data_time.val:.3f} ({data_time.avg:.3f})\t'
                  'Loss {loss.val:.4f} ({loss.avg:.4f})\t'.format(epoch, i, len(train_loader),
                                                                  batch_time=batch_time,
                                                                  data_time=data_time,
                                                                  loss=losses))
    del predicted_locs, predicted_scores, images, boxes, labels  # free some memory since their histories may be stored
```
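main() also relies on adjust_learning_rate and save_checkpoint from utils, which are not shown here. As a reference, a minimal sketch of the learning-rate decay helper, assuming it simply scales every parameter group's learning rate (which matches how decay_lr_to is used above), might look like:

```python
import torch

def adjust_learning_rate(optimizer, scale):
    """Scale the learning rate of every parameter group by `scale`.

    A hypothetical minimal version of the helper imported from utils;
    the real implementation may differ.
    """
    for param_group in optimizer.param_groups:
        param_group['lr'] = param_group['lr'] * scale
    print("DECAYING learning rate: new LR is %f" % (optimizer.param_groups[0]['lr'],))
```

With decay_lr_to = 0.1, each call multiplies the current learning rate by 0.1.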
With the code above, training can begin. Since I was riding on Colab's free GPU, which is not very fast, I only trained for 100 epochs. The training output is shown below; the final loss was 2.8288. With the preset 230 epochs of training, it would likely be lower:
```
Downloading: "https://download.pytorch.org/models/vgg16-397923af.pth" to /root/.cache/torch/hub/checkpoints/vgg16-397923af.pth
100% 528M/528M [00:15<00:00, 35.1MB/s]
Loaded base model.

/usr/local/lib/python3.6/dist-packages/torch/nn/_reduction.py:44: UserWarning: size_average and reduce args will be deprecated, please use reduction='none' instead.
  warnings.warn(warning.format(ret))

Epoch: [0][0/157]    Batch Time 77.123 (77.123)  Data Time 72.689 (72.689)  Loss 30.5775 (30.5775)
Epoch: [0][100/157]  Batch Time 19.484 (5.872)   Data Time 18.983 (5.381)   Loss 6.0357 (8.2452)
Epoch: [1][0/157]    Batch Time 5.594 (5.594)    Data Time 4.979 (4.979)    Loss 6.0772 (6.0772)
Epoch: [1][100/157]  Batch Time 0.496 (1.230)    Data Time 0.000 (0.684)    Loss 6.0197 (6.0285)
Epoch: [2][0/157]    Batch Time 6.235 (6.235)    Data Time 5.657 (5.657)    Loss 5.2538 (5.2538)
Epoch: [2][100/157]  Batch Time 1.465 (1.230)    Data Time 0.897 (0.685)    Loss 5.4543 (5.7394)
... (epochs 3 to 97 omitted) ...
Epoch: [98][0/157]   Batch Time 6.046 (6.046)    Data Time 5.471 (5.471)    Loss 3.0895 (3.0895)
Epoch: [98][100/157] Batch Time 2.745 (1.173)    Data Time 2.151 (0.632)    Loss 3.4441 (3.0436)
Epoch: [99][0/157]   Batch Time 4.769 (4.769)    Data Time 4.170 (4.170)    Loss 2.9008 (2.9008)
Epoch: [99][100/157] Batch Time 2.950 (1.175)    Data Time 2.324 (0.634)    Loss 2.8288 (3.0305)
```
The model does not predict box coordinates directly; it predicts encoded offsets relative to the anchors. The first step of post-processing is therefore to decode the output of the model's regression head into actual predicted boxes.
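The decoding is done by the gcxgcy_to_cxcy and cxcy_to_xy helpers from utils (used in detect_objects below). As a hedged sketch, assuming the common SSD-style encoding with variance constants 10 and 5, they might look like:

```python
import torch

# A sketch of the decoding helpers, assuming the common SSD-style encoding
# (the real gcxgcy_to_cxcy / cxcy_to_xy live in utils and may differ in detail).
def gcxgcy_to_cxcy(gcxgcy, priors_cxcy):
    # offsets (g_cx, g_cy, g_w, g_h) -> center-size coordinates (cx, cy, w, h)
    return torch.cat([
        gcxgcy[:, :2] * priors_cxcy[:, 2:] / 10 + priors_cxcy[:, :2],  # cx, cy
        torch.exp(gcxgcy[:, 2:] / 5) * priors_cxcy[:, 2:]], dim=1)     # w, h

def cxcy_to_xy(cxcy):
    # center-size (cx, cy, w, h) -> boundary coordinates (x_min, y_min, x_max, y_max)
    return torch.cat([cxcy[:, :2] - cxcy[:, 2:] / 2,
                      cxcy[:, :2] + cxcy[:, 2:] / 2], dim=1)

# sanity check: zero offsets decode back to the prior box itself,
# i.e. a prior centered at (0.5, 0.5) with size 0.2 decodes to [0.4, 0.4, 0.6, 0.6]
priors = torch.tensor([[0.5, 0.5, 0.2, 0.2]])
decoded = cxcy_to_xy(gcxgcy_to_cxcy(torch.zeros(1, 4), priors))
```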
Because a large number of prior boxes are preset, prediction produces many heavily overlapping boxes around each object, while the detection result should keep only one sufficiently accurate box per object. Some algorithm is therefore needed to deduplicate the detections; here we use NMS.
The rough steps of NMS (Non-Maximum Suppression) are:

1. Sort the candidate boxes by classification confidence and keep the highest-scoring box as a detection.
2. Compute the IoU between this box and every remaining box; boxes whose IoU exceeds a chosen threshold (e.g. 0.5) are considered suppressed and removed from the remaining set.
3. Repeat from step 1 on the remaining boxes until none are left.

This explanation of NMS is admittedly not that easy to follow; another article I read puts it more intuitively: NMS mainly serves to eliminate redundant prior boxes, keeping only the single highest-probability box per object as the final result. Since one image may contain multiple objects, NMS is essentially computing local maxima.
The code implementation is as follows:
```python
def detect_objects(self, predicted_locs, predicted_scores, min_score, max_overlap, top_k):
    """
    Decipher the 441 locations and class scores (output of the tiny_detector) to detect objects.

    For each class, perform Non-Maximum Suppression (NMS) on boxes that are above a minimum threshold.

    :param predicted_locs: predicted locations/boxes w.r.t the 441 prior boxes, a tensor of dimensions (N, 441, 4)
    :param predicted_scores: class scores for each of the encoded locations/boxes, a tensor of dimensions (N, 441, n_classes)
    :param min_score: minimum threshold for a box to be considered a match for a certain class
    :param max_overlap: maximum overlap two boxes can have so that the one with the lower score is not suppressed via NMS
    :param top_k: if there are a lot of resulting detection across all classes, keep only the top 'k'
    :return: detections (boxes, labels, and scores), lists of length batch_size
    """
    batch_size = predicted_locs.size(0)
    n_priors = self.priors_cxcy.size(0)
    predicted_scores = F.softmax(predicted_scores, dim=2)  # (N, 441, n_classes)

    # Lists to store final predicted boxes, labels, and scores for all images in batch
    all_images_boxes = list()
    all_images_labels = list()
    all_images_scores = list()

    assert n_priors == predicted_locs.size(1) == predicted_scores.size(1)

    for i in range(batch_size):
        # Decode object coordinates from the form we regressed predicted boxes to
        decoded_locs = cxcy_to_xy(
            gcxgcy_to_cxcy(predicted_locs[i], self.priors_cxcy))  # (441, 4), these are fractional pt. coordinates

        # Lists to store boxes and scores for this image
        image_boxes = list()
        image_labels = list()
        image_scores = list()

        max_scores, best_label = predicted_scores[i].max(dim=1)  # (441)

        # Check for each class, processing the boxes class by class
        for c in range(1, self.n_classes):
            # Keep only predicted boxes and scores where scores for this class are above the minimum score
            class_scores = predicted_scores[i][:, c]  # (441)
            score_above_min_score = class_scores > min_score  # torch.uint8 (byte) tensor, for indexing
            n_above_min_score = score_above_min_score.sum().item()
            if n_above_min_score == 0:
                continue
            class_scores = class_scores[score_above_min_score]  # (n_qualified), n_min_score <= 441
            class_decoded_locs = decoded_locs[score_above_min_score]  # (n_qualified, 4)

            # Sort predicted boxes and scores by classification confidence
            class_scores, sort_ind = class_scores.sort(dim=0, descending=True)  # (n_qualified), (n_min_score)
            class_decoded_locs = class_decoded_locs[sort_ind]  # (n_min_score, 4)

            # Find the overlap (IoU) between all pairs of predicted boxes
            overlap = find_jaccard_overlap(class_decoded_locs, class_decoded_locs)  # (n_qualified, n_min_score)

            # Non-Maximum Suppression (NMS)

            # A torch.uint8 (byte) tensor to keep track of which predicted boxes to suppress
            # 1 implies suppress, 0 implies don't suppress
            suppress = torch.zeros((n_above_min_score), dtype=torch.uint8).to(device)  # (n_qualified)

            # Consider each box in order of decreasing scores
            for box in range(class_decoded_locs.size(0)):
                # If this box is already marked for suppression
                if suppress[box] == 1:
                    continue

                # Suppress boxes whose overlaps (with current box) are greater than maximum overlap
                # Find such boxes and update suppress indices
                suppress = torch.max(suppress, (overlap[box] > max_overlap).to(torch.uint8))
                # The max operation retains previously suppressed boxes, like an 'OR' operation

                # Don't suppress this box, even though it has an overlap of 1 with itself
                suppress[box] = 0

            # Store only unsuppressed boxes for this class
            image_boxes.append(class_decoded_locs[1 - suppress])
            image_labels.append(torch.LongTensor((1 - suppress).sum().item() * [c]).to(device))
            image_scores.append(class_scores[1 - suppress])

        # If no object in any class is found, store a placeholder for 'background'
        if len(image_boxes) == 0:
            image_boxes.append(torch.FloatTensor([[0., 0., 1., 1.]]).to(device))
            image_labels.append(torch.LongTensor([0]).to(device))
            image_scores.append(torch.FloatTensor([0.]).to(device))

        # Concatenate into single tensors
        image_boxes = torch.cat(image_boxes, dim=0)  # (n_objects, 4)
        image_labels = torch.cat(image_labels, dim=0)  # (n_objects)
        image_scores = torch.cat(image_scores, dim=0)  # (n_objects)
        n_objects = image_scores.size(0)

        # Keep only the top k objects
        if n_objects > top_k:
            image_scores, sort_ind = image_scores.sort(dim=0, descending=True)
            image_scores = image_scores[:top_k]  # (top_k)
            image_boxes = image_boxes[sort_ind][:top_k]  # (top_k, 4)
            image_labels = image_labels[sort_ind][:top_k]  # (top_k)

        # Append to lists that store predicted boxes and scores for all images
        all_images_boxes.append(image_boxes)
        all_images_labels.append(image_labels)
        all_images_scores.append(image_scores)

    return all_images_boxes, all_images_labels, all_images_scores  # lists of length batch_size
```
The NMS part of the implementation above is a bit convoluted. The NMS implementation from Fast R-CNN is more concise and clear, and makes a good reference:
```python
# --------------------------------------------------------
# Fast R-CNN
# Copyright (c) 2015 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# Written by Ross Girshick
# --------------------------------------------------------
import numpy as np

# dets: detected boxes and their corresponding scores
# thresh: the chosen IoU threshold
def nms(dets, thresh):
    # box coordinates
    x1 = dets[:, 0]
    y1 = dets[:, 1]
    x2 = dets[:, 2]
    y2 = dets[:, 3]
    # box scores
    scores = dets[:, 4]

    areas = (x2 - x1 + 1) * (y2 - y1 + 1)  # area of each box
    order = scores.argsort()[::-1]  # sort by classification confidence

    keep = []  # record the boxes that are kept
    while order.size > 0:
        i = order[0]  # index of the box with the highest score
        keep.append(i)  # keep this round's highest-scoring box

        # compute the IoU of the remaining boxes with the current box
        xx1 = np.maximum(x1[i], x1[order[1:]])
        yy1 = np.maximum(y1[i], y1[order[1:]])
        xx2 = np.minimum(x2[i], x2[order[1:]])
        yy2 = np.minimum(y2[i], y2[order[1:]])

        w = np.maximum(0.0, xx2 - xx1 + 1)
        h = np.maximum(0.0, yy2 - yy1 + 1)
        inter = w * h
        ovr = inter / (areas[i] + areas[order[1:]] - inter)

        # keep only the boxes whose IoU is below the threshold
        inds = np.where(ovr <= thresh)[0]
        order = order[inds + 1]

    return keep
```
Once the model is trained, it can be used to run inference on a single image and obtain detection results.

First, import the necessary Python packages and load the trained model weights. Then define the preprocessing transforms: the test-time preprocessing should stay consistent with training, simply with the data-augmentation transforms removed.
The preprocessing needed here is therefore:

1. Resize the image to 224×224
2. Convert it to a Tensor (dividing pixel values by 255)
3. Normalize it
```python
# Set detect transforms (It's important to be consistent with training)
resize = transforms.Resize((224, 224))
to_tensor = transforms.ToTensor()
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
```
Next comes inference. The core flow is: transform the image, run a forward pass through the model, then post-process the outputs (decode the predicted offsets into boxes and filter redundant detections with NMS). The core code is:
```python
# Transform the image
image = normalize(to_tensor(resize(original_image)))

# Move to default device
image = image.to(device)

# Forward prop.
predicted_locs, predicted_scores = model(image.unsqueeze(0))

# Post process, get the final detect objects from our tiny detector output:
# first decode the model's output into concrete predicted boxes,
# then filter out redundant detections with NMS
det_boxes, det_labels, det_scores = model.detect_objects(predicted_locs, predicted_scores,
                                                         min_score=min_score,
                                                         max_overlap=max_overlap,
                                                         top_k=top_k)
```
Detection results obtained with the model:

The model handles some simple test images quite well, but its predictions on harder images are noticeably worse; in particular, having trained for only 100 epochs exposes a variety of problems.
The mAP metric

To understand mAP, first recall the four basic prediction outcomes:

- True Positive (TP): the ground truth is positive and the model predicts positive
- False Negative (FN): the ground truth is positive but the model predicts negative; this is the Type II error in statistics
- False Positive (FP): the ground truth is negative but the model predicts positive; this is the Type I error in statistics
- True Negative (TN): the ground truth is negative and the model predicts negative
In machine learning, a confusion matrix (also called a contingency table or error matrix) is a specific matrix used to visualize an algorithm's performance, typically in supervised learning (in unsupervised learning, a matching matrix is usually used instead). Each column represents the predicted class and each row represents the actual class. It makes it very easy to see whether classes are being confused (i.e., one class being mislabeled as another).

Suppose we have a system for classifying cats, dogs, and rabbits; the confusion matrix summarizes the algorithm's test results for further performance analysis. Suppose there are 27 animals in total: 8 cats, 6 dogs, and 13 rabbits. The resulting confusion matrix is shown in the table below:

Using the four secondary metrics above, the raw counts in the confusion matrix can be converted into ratios between 0 and 1, which makes standardized comparison easier.
The F1 Score is computed as:

F1 Score = 2PR / (P + R)

where P is Precision and R is Recall. The F1 Score combines Precision and Recall into a single number. It ranges from 0 to 1, where 1 represents the best possible model output and 0 the worst.
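These definitions are easy to check numerically. A small sketch with made-up counts (TP = 8, FP = 2, FN = 4 are purely illustrative, not from the example above):

```python
tp, fp, fn = 8, 2, 4  # illustrative counts

precision = tp / (tp + fp)  # fraction of predicted positives that are correct
recall = tp / (tp + fn)     # fraction of actual positives that are found
f1 = 2 * precision * recall / (precision + recall)

print(precision, recall, f1)
```

Note that F1 simplifies to 2TP / (2TP + FP + FN), so with these counts it equals 16/22 ≈ 0.727.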
The AP metric is Average Precision, i.e. the averaged precision.

mAP, Mean Average Precision, is the mean of the per-class AP values over all classes on the validation set, and is the standard metric for detection accuracy in object detection.
Computing AP in object detection requires the P-R curve: a 2D curve with Precision on the vertical axis and Recall on the horizontal axis, drawn from the precision and recall values obtained at different confidence thresholds, as shown in the figure below:

The overall trend of the P-R curve is that higher precision comes with lower recall. When recall reaches 1, the positive sample with the lowest confidence score has been included, and the precision at that point (positives divided by all samples at or above that threshold) is at its lowest. The area enclosed under the P-R curve is the AP value; generally speaking, the better the classifier, the higher the AP.
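How these P-R points arise can be sketched by sweeping a confidence threshold over a toy set of detections (the scores, correctness flags, and ground-truth count below are made up for illustration):

```python
import numpy as np

# Hypothetical detections sorted by confidence; `correct` flags whether each
# detection matches a ground-truth box.
scores = np.array([0.9, 0.8, 0.7, 0.6, 0.5])
correct = np.array([1, 1, 0, 1, 0])
n_gt = 4  # assumed total number of ground-truth objects

points = []
for t in scores:  # use each score in turn as the confidence threshold
    kept = correct[scores >= t]
    precision = kept.sum() / len(kept)  # correct detections / detections kept
    recall = kept.sum() / n_gt          # correct detections / ground truths
    points.append((precision, recall))
# lowering the threshold raises recall while (generally) lowering precision
```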
In object detection, a P-R curve can be drawn for each class from its Recall and Precision values; AP is the area under that curve, and mAP is the average of the APs over all classes.
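Computing AP as the area under the P-R curve can be sketched as follows. This uses all-point interpolation (one common convention; VOC2007 itself used 11-point interpolation), taking the running maximum of precision from the right to form the envelope before summing rectangular areas:

```python
import numpy as np

def average_precision(recall, precision):
    """Area under the P-R curve with all-point interpolation.

    `recall` and `precision` are matched arrays of P-R points,
    with recall in increasing order.
    """
    r = np.concatenate(([0.0], recall, [1.0]))
    p = np.concatenate(([0.0], precision, [0.0]))
    # make precision monotonically decreasing from right to left (the envelope)
    for i in range(len(p) - 2, -1, -1):
        p[i] = max(p[i], p[i + 1])
    # sum the rectangles where recall actually changes
    idx = np.where(r[1:] != r[:-1])[0]
    return float(np.sum((r[idx + 1] - r[idx]) * p[idx + 1]))

# e.g. two P-R points (R=0.5, P=1.0) and (R=1.0, P=0.5) give AP = 0.75
ap = average_precision(np.array([0.5, 1.0]), np.array([1.0, 0.5]))
```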
Running the eval.py script evaluates the model on the VOC2007 test set, with results as follows:
```
Evaluating: 0%| | 0/78 [00:00
```
The model's mAP comes out at 49.6, which is fairly low. This is mainly because I trained for a small number of epochs; moreover, the per-class scores for small objects are particularly low, showing that the model does poorly on small and densely packed objects.