資源簡介
mmdetection於2019年12月13日發布了新版本,其中在mmdet/apis/train.py中引入了torch.distributed,而torch.distributed在Windows下不受支援。因此,若要在Windows中訓練,需要將v1.0rc1版本的train.py與新版本的train.py進行合併,主要改動是移除對torch.distributed的依賴,並相應修改_non_dist_train。
代碼片段和文件信息
from?__future__?import?division
import?logging
import?random
import?numpy?as?np
import?re
from?collections?import?OrderedDict
import?torch
from?mmcv.runner?import?Runner?DistSamplerSeedHook?obj_from_dict
from?mmcv.parallel?import?MMDataParallel?MMDistributedDataParallel
from?mmdet?import?datasets
from?mmdet.core?import?(DistOptimizerHook?DistEvalmAPHook
????????????????????????CocoDistEvalRecallHook?CocoDistEvalmAPHook
????????????????????????Fp16OptimizerHook)
from?mmdet.datasets?import?build_dataloader?DATASETS
from?mmdet.models?import?RPN
#?from?.env?import?get_root_logger
def get_root_logger(log_file=None, log_level=logging.INFO):
    """Return the 'mmdet' logger, configuring it on first use.

    Single-process stand-in for the ``.env`` version. It does not query
    distributed rank information, so it works without ``torch.distributed``.

    Args:
        log_file (str, optional): If given, a ``FileHandler`` opened in
            write mode ('w') on this path is attached to the logger.
        log_level (int): Level passed to ``logging.basicConfig`` and to the
            optional file handler. Default: ``logging.INFO``.

    Returns:
        logging.Logger: The logger named 'mmdet'.
    """
    logger = logging.getLogger('mmdet')
    # If the logger (or an ancestor it propagates to) already has handlers,
    # assume it was configured earlier and return it unchanged.
    if logger.hasHandlers():
        return logger

    log_format = '%(asctime)s - %(levelname)s - %(message)s'
    logging.basicConfig(format=log_format, level=log_level)
    # Without torch.distributed there is only one process, which plays the
    # role of rank 0. The original code attached the file handler only on
    # rank 0, so here it is attached whenever a path is supplied. Previously
    # this whole branch was commented out and `log_file` was silently ignored.
    if log_file is not None:
        file_handler = logging.FileHandler(log_file, 'w')
        file_handler.setFormatter(logging.Formatter(log_format))
        file_handler.setLevel(log_level)
        logger.addHandler(file_handler)
    return logger
def set_random_seed(seed, deterministic=False):
    """Seed Python's `random`, NumPy, and PyTorch (CPU and all CUDA devices).

    Args:
        seed (int): Value handed to every seeding function.
        deterministic (bool): When True, also force cuDNN into deterministic
            mode (``torch.backends.cudnn.deterministic = True``) and disable
            its autotuner (``torch.backends.cudnn.benchmark = False``).
            Default: False.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seeder in seeders:
        seeder(seed)

    if not deterministic:
        return
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False
def parse_losses(losses):
    """Reduce raw model loss outputs to a total loss and loggable scalars.

    Args:
        losses (dict): Maps a name to either a tensor or a list of tensors.
            Each tensor is reduced with ``.mean()``; a list is reduced to
            the sum of its members' means.

    Returns:
        tuple: ``(loss, log_vars)``. ``loss`` is the sum of every reduced
        entry whose name contains the substring 'loss'. ``log_vars`` is an
        OrderedDict holding each reduced entry plus the total under 'loss',
        all converted to Python numbers with ``.item()``.

    Raises:
        TypeError: If a value is neither a tensor nor a list.
    """
    reduced = OrderedDict()
    for name, value in losses.items():
        if isinstance(value, torch.Tensor):
            reduced[name] = value.mean()
            continue
        if not isinstance(value, list):
            raise TypeError(
                '{} is not a tensor or list of tensors'.format(name))
        reduced[name] = sum(part.mean() for part in value)

    # NOTE(review): if no key contains 'loss' (or a list is empty), sum()
    # yields the int 0 and the `.item()` conversion below raises
    # AttributeError.
    total_loss = sum(term for key, term in reduced.items() if 'loss' in key)
    reduced['loss'] = total_loss

    log_vars = OrderedDict(
        (key, term.item()) for key, term in reduced.items())
    return total_loss, log_vars
def batch_processor(model, data, train_mode):
    """Run one forward pass and package the result for the training loop.

    Presumably passed as the ``batch_processor`` callback to mmcv's
    ``Runner`` (imported at the top of this file); confirm against the
    ``train_detector`` call site.

    Args:
        model: Callable returning a dict of losses when called with
            ``**data``.
        data (dict): Batch keyword arguments. Must contain 'img', whose
            ``.data`` attribute supports ``len()``. It is presumably an mmcv
            DataContainer (TODO confirm).
        train_mode (bool): Not used here; kept for the callback signature.

    Returns:
        dict: Keys 'loss', 'log_vars' and 'num_samples'.
    """
    raw_losses = model(**data)
    total_loss, loggable = parse_losses(raw_losses)
    sample_count = len(data['img'].data)
    return dict(loss=total_loss, log_vars=loggable, num_samples=sample_count)
def?train_detector(model
???????????????????dataset
???????????????????cfg
???????????????????distributed=False
??
評論
共有 條評論