当前位置 : 主页 > 编程语言 > python >

pytorch版本PSEnet训练并部署方式

来源:互联网 收集:自由互联 发布时间:2023-05-14
摘要:本文记录 PyTorch 版 PSENet 的环境配置、数据集制作(训练数据集、划分训练集与测试集)、训练以及部署测试流程。训练环境没有完全按照官方 README 配置,实际使用的环境见下文。
目录
  • 概述
  • 制作数据集
    • 1、训练的数据集
    • 2、将数据集分成训练集和测试集
    • 3、训练
    • 4、部署测试
  • 总结

    概述

    源码地址

    torch版本

    训练环境没有完全按照 PyTorch 版源码 README 中的要求配置,本人实际使用的环境如下:

    torch==1.9.1
    torchvision==0.10.1
    python==3.8.0
    cuda==10.2
    mmcv==0.2.12
    editdistance==0.5.3
    Polygon3==3.0.9.1
    pyclipper==1.3.0
    opencv-python==3.4.2.17
    Cython==0.29.24
    安装以上依赖后,执行 ./compile.sh 编译源码中的扩展模块。

    制作数据集

    1、训练的数据集

    使用 roLabelImg 进行旋转框标注,需要将标注结果转换为 ICDAR 2015(IC15)格式的数据。

    转换代码:

    import os
    from lxml import etree
    import numpy as np
    import math
    # Input folder of roLabelImg XML annotations and output folder for the
    # generated gt_*.txt files.
    src_xml = "ANN"
    txt_dir = "gt"
    xml_listdir = os.listdir(src_xml)
    xml_listpath = list(map(lambda fname: os.path.join(src_xml, fname), xml_listdir))
    def xml_out(xml_path):
        """Convert one roLabelImg XML file into ICDAR 2015 ground-truth lines.

        Every ``<object>`` carrying a ``<robndbox>`` (cx, cy, w, h, angle) becomes
        one line ``x1,y1,x2,y2,x3,y3,x4,y4,name``.  The four points are the
        corners of the rotated box in clockwise order (top-left, top-right,
        bottom-right, bottom-left of the unrotated box), rounded to integer pixels.

        :param xml_path: path to a roLabelImg XML file
        :return:         list of lines, each ending in "\\n"; empty when the file
                         contains no objects
        """
        gt_lines = []
        tree = etree.parse(xml_path)
        for obj in tree.findall("object"):
            name = obj.find("name").text
            robox = obj.find("robndbox")
            # Keep full float precision; round only the final corner coordinates.
            cx = float(robox.find("cx").text)
            cy = float(robox.find("cy").text)
            w = float(robox.find("w").text)
            h = float(robox.find("h").text)
            angle = float(robox.find("angle").text)  # used directly by cos/sin, i.e. radians
            cos_a = math.cos(angle)
            sin_a = math.sin(angle)
            half_w = 0.5 * w
            half_h = 0.5 * h
            # Corner offsets from the centre, clockwise in image coordinates
            # (y grows downward); ICDAR 2015 expects a non-self-intersecting quad.
            corners = (
                (-half_w, -half_h),
                (half_w, -half_h),
                (half_w, half_h),
                (-half_w, half_h),
            )
            coords = []
            for dx, dy in corners:
                # Standard 2-D rotation about (cx, cy). Note the "+ dy*cos" term:
                # "dx*sin - dy*cos" would mirror the box instead of rotating it.
                coords.append(round(dx * cos_a - dy * sin_a + cx))
                coords.append(round(dx * sin_a + dy * cos_a + cy))
            gt_lines.append(",".join(str(c) for c in coords) + "," + str(name) + "\n")
        # Return only after every object has been converted.
        return gt_lines
    def main():
        """Convert every roLabelImg XML in ``src_xml`` into ``txt_dir/gt_<stem>.txt``.

        Reads the module-level ``src_xml``, ``txt_dir`` and ``xml_listdir``.
        Non-``.xml`` entries are skipped. Output files are overwritten, so
        re-running the script never duplicates ground-truth lines.
        """
        os.makedirs(txt_dir, exist_ok=True)
        count = 0
        for xml_dir in xml_listdir:
            stem, ext = os.path.splitext(xml_dir)
            if ext.lower() != ".xml":
                continue  # stray non-annotation file in the folder
            gt_lines = xml_out(os.path.join(src_xml, xml_dir))
            txt_path = "gt_" + stem + ".txt"
            # "w" rather than "a+": appending would duplicate every box on re-runs.
            with open(os.path.join(txt_dir, txt_path), "w") as fd:
                fd.writelines(gt_lines)
            count += 1
            print("Write file %s" % str(count))
    if __name__ == "__main__":
        main()

    roLabelImg 生成的 xml 文件与 labelImg 的 xml 有所不同(旋转框保存在 robndbox 节点中,包含 cx、cy、w、h、angle 五个字段),因此转换代码需要根据所用的标注软件做相应调整。

    转换后的每行格式为 x1,y1,x2,y2,x3,y3,x4,y4,classes,四个顶点按 ICDAR 2015 的约定以顺时针顺序排列;classes 为检测的类别(文本内容)。如果是模糊、无法辨认的文本,classes 写为“###”,表示训练时忽略该区域。

    但需要重点注意:用这份源码做模糊训练(即标签写为“###”)时,loss 一直为 1。可能的原因是“###”区域在训练时被当作忽略区域,没有有效的监督像素。

    2、将数据集分成训练集和测试集

    数据集

    数据集可以按照源码中的默认路径存放,也可以修改源码中的数据集路径。

    PSENet-python3\dataset\psenet\psenet_ic15.py

    将该文件中的数据集路径修改为自己的数据集文件夹路径。

    3、训练

    CUDA_VISIBLE_DEVICES=0,1,2,3 python train.py config/psenet/psenet_r50_ic15_736.py

    其中配置文件可参考源码 README 中的说明,根据自己的需要自行选择。

    4、部署测试

    import torch
    import numpy as np
    import argparse
    import os
    import os.path as osp
    import sys
    import time
    import json
    from mmcv import Config
    import cv2
    from torchvision import transforms
    from dataset import build_data_loader
    from models import build_model
    from models.utils import fuse_module
    from utils import ResultFormat, AverageMeter
    def prepare_image(image, target_size, device="cuda:0", normalize=True):
        """Preprocess a BGR image (as returned by ``cv2.imread``) for PSENet inference.

        :param image:       original image in BGR channel order
        :param target_size: (width, height) handed to ``cv2.resize``
        :param device:      torch device the tensor is moved to (default "cuda:0")
        :param normalize:   apply ImageNet mean/std normalization after ``ToTensor``
        :return:            float tensor of shape (1, 3, height, width) on ``device``
        """
        img = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        # cv2.resize takes dsize as (width, height), not (height, width).
        img = cv2.resize(img, target_size)
        # ToTensor: HxWxC uint8 in [0, 255] -> CxHxW float in [0, 1]
        tensor = transforms.ToTensor()(img)
        if normalize:
            # NOTE(review): PSENet's dataset pipeline normalizes with these
            # ImageNet statistics right after ToTensor; inference must match
            # training — confirm against dataset/psenet/psenet_ic15.py.
            tensor = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                          std=[0.229, 0.224, 0.225])(tensor)
        # Add the batch dimension: (C, H, W) -> (1, C, H, W)
        tensor = tensor.unsqueeze(0)
        return tensor.to(torch.device(device))
    def report_speed(outputs, speed_meters):
        """Feed per-stage timings into their meters and print running averages.

        Every key of ``outputs`` whose name contains 'time' is pushed into the
        meter of the same name in ``speed_meters`` and added to a running total.
        The total goes into ``speed_meters['total_time']``, whose average drives
        the printed FPS figure.
        """
        timing_keys = [name for name in outputs if 'time' in name]
        elapsed = 0
        for name in timing_keys:
            value = outputs[name]
            elapsed += value
            meter = speed_meters[name]
            meter.update(value)
            print('%s: %.4f' % (name, meter.avg))
        overall = speed_meters['total_time']
        overall.update(elapsed)
        print('FPS: %.1f' % (1.0 / overall.avg))
    def load_model(cfg, checkpoint="psenet_r50_ic15_1024_finetune/checkpoint_580ep.pth.tar"):
        """Build the model from ``cfg``, load trained weights and fuse conv+bn.

        :param cfg:        mmcv ``Config``; ``cfg.model`` is passed to ``build_model``
        :param checkpoint: path to a checkpoint holding a ``'state_dict'`` entry;
                           ``None`` skips weight loading
        :return:           the model on CUDA, switched to eval mode, conv/bn fused
        :raises FileNotFoundError: if ``checkpoint`` is given but is not a file
        """
        model = build_model(cfg.model)
        model = model.cuda()
        model.eval()
        if checkpoint is not None:
            if not os.path.isfile(checkpoint):
                print("No checkpoint found at '{}'".format(checkpoint))
                raise FileNotFoundError(checkpoint)
            print("Loading model and optimizer from checkpoint '{}'".format(checkpoint))
            sys.stdout.flush()
            state = torch.load(checkpoint)
            # Checkpoints saved from a wrapped model (e.g. nn.DataParallel) prefix
            # every key with "module."; strip it only where present so
            # checkpoints saved without the wrapper load as well.
            prefix = "module."
            state_dict = {
                (key[len(prefix):] if key.startswith(prefix) else key): value
                for key, value in state['state_dict'].items()
            }
            model.load_state_dict(state_dict)
        # fuse conv and bn
        model = fuse_module(model)
        return model
    if __name__ == '__main__':
        src_dir = "testimg/"
        save_dir = "test_save/"
        # Network input size as (width, height), the order cv2.resize expects.
        target_w, target_h = 1376, 1024
        if not os.path.exists(save_dir):
            os.makedirs(save_dir)
        cfg = Config.fromfile("PSENet/config/psenet/psenet_r50_ic15_1024_finetune.py")
        for d in [cfg, cfg.data.test]:
            d.update(dict(
                report_speed=False
            ))
        if cfg.report_speed:
            speed_meters = dict(
                backbone_time=AverageMeter(500),
                neck_time=AverageMeter(500),
                det_head_time=AverageMeter(500),
                det_pse_time=AverageMeter(500),
                rec_time=AverageMeter(500),
                total_time=AverageMeter(500)
            )
        model = load_model(cfg)
        model.eval()
        count = 0
        for img_name in os.listdir(src_dir):
            img = cv2.imread(os.path.join(src_dir, img_name))
            if img is None:
                # cv2.imread returns None for files it cannot decode.
                print("skip unreadable image:", img_name)
                continue
            tensor = prepare_image(img, target_size=(target_w, target_h))
            # Both sizes are (height, width), the same order as img.shape[:2];
            # the resized tensor is target_h rows by target_w columns.
            img_metas = dict(
                org_img_size=torch.tensor([[img.shape[0], img.shape[1]]]),
                img_size=torch.tensor([[target_h, target_w]]),
            )
            data = dict(imgs=tensor, img_metas=img_metas, cfg=cfg)
            with torch.no_grad():
                outputs = model(**data)
            if cfg.report_speed:
                report_speed(outputs, speed_meters)
            for bbox in outputs['bboxes']:
                # Assumes each box is a flat x1,y1,x2,y2,... sequence (the
                # original read indices 0/1 and 4/5 as x/y pairs) — confirm for
                # the configured bbox_type. Drawing the full closed polygon keeps
                # the rotation that an axis-aligned rectangle would lose.
                pts = np.asarray(bbox, dtype=np.int32).reshape(-1, 1, 2)
                cv2.polylines(img, [pts], True, (0, 0, 255), 3)
            count = count + 1
            cv2.imwrite(os.path.join(save_dir, img_name), img)
            print("img test:", count)
    from dataset import build_data_loader
    from models import build_model
    from models.utils import fuse_module
    from utils import ResultFormat, AverageMeter

    (上面导入的 dataset、models、utils 等模块都包含在 PSENet 训练代码中,因此该脚本需要在能够导入这些模块的源码目录下运行。)

    总结

    以上为个人经验,希望能给大家一个参考,也希望大家多多支持自由互联。 

    网友评论