Paper: https://arxiv.org/pdf/2001.07685.pdf
Code: https://github.com/google-research/fixmatch
Overview
Proposed by Sohn et al., FixMatch combines pseudo-labeling and consistency regularization into a greatly simplified pipeline, and it achieves state-of-the-art results on a wide range of benchmarks.
As usual, a supervised model is trained on the labeled images with a cross-entropy loss. For each unlabeled image, two views are produced: one with weak augmentation and one with strong augmentation. The weakly augmented view is passed through the model, and the probability of the most confident class is compared against a threshold. If it exceeds the threshold, that class is taken as the pseudo-label. The strongly augmented view is then passed through the model, and its prediction is compared against the pseudo-label with a cross-entropy loss. The supervised and unsupervised losses are combined to optimize the model.
[Figure: FixMatch pipeline diagram]
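The unlabeled-data objective can be written very compactly. Below is a minimal sketch (not the authors' reference implementation) of the FixMatch unlabeled loss for a generic multi-class classifier, assuming `model` returns raw class logits; `threshold` plays the same role as `cfg.threshold` in the segmentation training code below, which adapts the same idea per pixel.

import torch
import torch.nn.functional as F

def fixmatch_unlabeled_loss(model, images_weak, images_strong, threshold=0.95):
    # Pseudo-labels come from the weakly augmented view (no gradients needed here).
    with torch.no_grad():
        probs = torch.softmax(model(images_weak), dim=-1)
        max_probs, pseudo_labels = probs.max(dim=-1)
        mask = (max_probs >= threshold).float()  # keep only confident predictions
    # Cross-entropy between the strong-view predictions and the pseudo-labels,
    # masked so that only confident samples contribute.
    logits_strong = model(images_strong)
    loss = F.cross_entropy(logits_strong, pseudo_labels, reduction='none')
    return (loss * mask).mean()

The total training loss is the supervised cross-entropy plus this term scaled by a weight lambda_u, exactly as in train_one_epoch_semi below.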
Code
train.py
import os
import time

import torch

# NOTE: get_model, RAdam, Lookahead, CyclicCosAnnealingLR, LearningRateWarmUP,
# ComboLoss, GridMask, build_loader, build_loader_v2, get_predictions,
# compute_error, mkdir, cutmix, fmix_seg and `device` are imported from the
# project's own modules (omitted here).

class Trainer(object):
    def __init__(self, cfg):
        self.cfg = cfg
        ########## hyper-parameter setting #################
        self.net = get_model(cfg.num_classes, cfg.model_name).to(device)
        optimizer = RAdam(params=self.net.parameters(), lr=cfg.lr, weight_decay=0.0001)
        self.optimizer = Lookahead(optimizer)
        milestones = [5 + x * 60 for x in range(5)]
        # print(f'milestones:{milestones}')
        scheduler_c = CyclicCosAnnealingLR(optimizer, milestones=milestones, eta_min=5e-5)
        self.scheduler = LearningRateWarmUP(optimizer=optimizer, target_iteration=5, target_lr=0.003,
                                            after_scheduler=scheduler_c)
        self.criterion = ComboLoss().to(device)
        self.G = GridMask(True, True)
        self.best_acc = -100

    def load_net(self, path):
        # the checkpoint stores the full model object under "model_state"
        self.net = torch.load(path, map_location='cuda:0')["model_state"]
        # self.best_acc = torch.load(path, map_location='cuda:0')["best_acc"]
        # print(f'best_acc: {self.best_acc}')
    def train_one_epoch(self, loader):
        num_samples = 0
        running_loss = 0
        trn_error = 0
        self.net.train()
        for images, masks in loader:
            if self.cfg.cutMix:
                images, masks = cutmix(images, masks)
            if self.cfg.fmix:
                w, h = images.size(-1), images.size(-2)
                images, masks = fmix_seg(images, masks, alpha=1., decay_power=3., shape=(w, h))
            images = images.to(device, dtype=torch.float)
            if self.cfg.Grid:
                images = self.G(images)
            masks = torch.squeeze(masks.to(device))
            # print("images'size:{},masks'size:{}".format(images.size(),masks.size()))
            num_samples += int(images.size(0))
            self.optimizer.zero_grad()
            outputs, cls = self.net(images)
            loss = self.criterion(outputs, masks, cls)
            loss.backward()
            batch_loss = loss.item()
            self.optimizer.step()
            running_loss += batch_loss
            pred = get_predictions(outputs)
            masks = masks.type(torch.cuda.LongTensor)
            masks = masks.data.cpu()
            trn_error += compute_error(pred, masks)
        return running_loss / len(loader), trn_error / len(loader)
    def validate(self, loader):
        num_samples = 0
        running_loss = 0
        trn_error = 0
        self.net.eval()
        with torch.no_grad():  # no gradients needed during validation
            for images, masks in loader:
                images = images.to(device, dtype=torch.float)
                masks = torch.squeeze(masks.to(device))
                num_samples += int(images.size(0))
                outputs, cls = self.net(images)
                loss = self.criterion(outputs, masks, cls)
                batch_loss = loss.item()
                running_loss += batch_loss
                pred = get_predictions(outputs)
                masks = masks.type(torch.cuda.LongTensor)
                masks = masks.data.cpu()
                trn_error += compute_error(pred, masks)
        return running_loss / len(loader), trn_error / len(loader)
    def train(self):
        mkdir(self.cfg.model_save_path)
        ########## prepare dataset ################################
        train_loader, val_loader, test_loader = build_loader(self.cfg)
        for epoch in range(self.cfg.num_epochs):
            print("Epoch: {}/{}".format(epoch + 1, self.cfg.num_epochs))
            # optimizer.step()
            self.scheduler.step(epoch)
            #################### train ####################################
            train_loss, train_error = self.train_one_epoch(train_loader)
            start = time.strftime("%H:%M:%S")
            print(
                f"epoch:{epoch + 1}/{self.cfg.num_epochs} | ⏰: {start} ",
                f"Training Loss: {train_loss:.4f}.. ",
                f"Training Acc: {1 - train_error:.4f}.. ",
            )
            ###################### valid ##################################
            val_loss, val_error = self.validate(val_loader)
            start = time.strftime("%H:%M:%S")
            print(
                f"epoch:{epoch + 1}/{self.cfg.num_epochs} | ⏰: {start} ",
                f"Validation Loss: {val_loss:.4f}.. ",
                f"Validation Acc: {1 - val_error:.4f}.. ",
            )
            # keep the checkpoint with the best validation accuracy
            if 1 - val_error > self.best_acc:
                state = {
                    "epoch": epoch + 1,
                    "model_state": self.net,
                    "best_acc": 1 - val_error
                }
                checkpoint = f'{self.cfg.model_name}_best.pth'
                torch.save(state, os.path.join(self.cfg.model_save_path, checkpoint))  # save model
                print("The model has been saved successfully!")
                self.best_acc = 1 - val_error
    def train_one_epoch_semi(self, trainloader, testloader):
        running_loss = 0
        trn_error = 0
        # iterate over labeled and unlabeled batches in parallel
        loader = zip(trainloader, testloader)
        self.net.train()
        for data_x, data_u in loader:
            images_x, targets_x = data_x        # labeled image / mask pair
            images_u_w, images_u_s = data_u     # weakly / strongly augmented unlabeled views
            # cpu ==> gpu
            images_x = images_x.to(device, dtype=torch.float)
            targets_x = torch.squeeze(targets_x.to(device))
            images_u_w = images_u_w.to(device, dtype=torch.float)
            images_u_s = images_u_s.to(device, dtype=torch.float)
            if self.cfg.Grid:
                images_x = self.G(images_x)
                images_u_s = self.G(images_u_s)
            # print("images'size:{},masks'size:{}".format(images.size(),masks.size()))
            self.optimizer.zero_grad()
            outputs_x, cls_x = self.net(images_x)
            outputs_u_w, cls_u_w = self.net(images_u_w)
            outputs_u_s, cls_u_s = self.net(images_u_s)
            # get pseudo-labels by thresholding the weak-view predictions
            # (assumes outputs_u_w are per-pixel probabilities)
            targets_u = outputs_u_w.ge(self.cfg.threshold).float()
            # supervised loss on labeled data
            loss_x = self.criterion(outputs_x, targets_x, cls_x)
            # unsupervised loss: strong-view predictions vs. pseudo-labels,
            # masked to the confident (above-threshold) pixels
            loss_u = (self.criterion(outputs_u_s, torch.squeeze(targets_u), cls_u_s, reduction='none')
                      * torch.squeeze(targets_u)).mean()
            loss = loss_x + self.cfg.lambda_u * loss_u
            loss.backward()
            batch_loss = loss.item()
            self.optimizer.step()
            running_loss += batch_loss
            pred = get_predictions(outputs_x)
            masks = targets_x.type(torch.cuda.LongTensor)
            masks = masks.data.cpu()
            trn_error += compute_error(pred, masks)
        return running_loss / len(trainloader), trn_error / len(trainloader)
    def train_semi(self):
        # start from the best supervised checkpoint
        self.load_net(f'{self.cfg.model_save_path}/{self.cfg.model_name}_best.pth')
        model_save_path = self.cfg.model_save_path + '_semi'
        mkdir(model_save_path)
        ########## prepare dataset ################################
        train_loader, val_loader, test_loader = build_loader_v2(self.cfg)
        for epoch in range(self.cfg.num_epochs):
            print("Epoch: {}/{}".format(epoch + 1, self.cfg.num_epochs))
            # optimizer.step()
            self.scheduler.step(epoch)
            #################### train ####################################
            train_loss, train_error = self.train_one_epoch_semi(train_loader, test_loader)
            start = time.strftime("%H:%M:%S")
            print(
                f"epoch:{epoch + 1}/{self.cfg.num_epochs} | ⏰: {start} ",
                f"Training Loss: {train_loss:.4f}.. ",
                f"Training Acc: {1 - train_error:.4f}.. ",
            )
            ###################### valid ##################################
            val_loss, val_error = self.validate(val_loader)
            start = time.strftime("%H:%M:%S")
            print(
                f"epoch:{epoch + 1}/{self.cfg.num_epochs} | ⏰: {start} ",
                f"Validation Loss: {val_loss:.4f}.. ",
                f"Validation Acc: {1 - val_error:.4f}.. ",
            )
            if 1 - val_error > self.best_acc:
                state = {
                    "epoch": epoch + 1,
                    "model_state": self.net,
                    "best_acc": 1 - val_error
                }
                checkpoint = f'{self.cfg.model_name}_best.pth'
                torch.save(state, os.path.join(model_save_path, checkpoint))  # save model
                print("The model has been saved successfully!")
                self.best_acc = 1 - val_error
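For reference, here is a hypothetical way to drive the Trainer end to end. The config fields mirror those accessed in the class above; SimpleNamespace is only a stand-in for the project's real config object, and the values are placeholders.

from types import SimpleNamespace

cfg = SimpleNamespace(
    num_classes=2, model_name='unet', lr=3e-3, num_epochs=100,
    model_save_path='./checkpoints', threshold=0.95, lambda_u=1.0,
    cutMix=False, fmix=False, Grid=False,
)
trainer = Trainer(cfg)
trainer.train()        # supervised training on labeled data
trainer.train_semi()   # FixMatch-style fine-tuning with unlabeled data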
dataset.py
from torch.utils.data import Dataset, DataLoader
from torch.utils.data.sampler import SubsetRandomSampler
import torch
import torchvision
from torchvision.transforms import Compose
import numpy as np
import cv2 as cv
import os
from random import sample
from utils.transforms import *
from utils.randaugment import *
from utils.grid import Grid

def img_to_tensor(img):
    # HWC numpy array -> CHW tensor
    tensor = torch.from_numpy(img.transpose((2, 0, 1)))
    return tensor

def to_monochrome(x):
    # x_ = x.convert('L')
    x_ = np.array(x).astype(np.float32)  # convert image to monochrome
    return x_

def to_tensor(x):
    x_ = np.expand_dims(x, axis=0)
    x_ = torch.from_numpy(x_)
    return x_

ImageToTensor = torchvision.transforms.ToTensor

def custom_blur_demo(image):
    kernel = np.array([[0, -1, 0], [-1, 5, -1], [0, -1, 0]], np.float32)  # sharpening kernel
    dst = cv.filter2D(image, -1, kernel=kernel)
    return dst
class SasDataset(Dataset):
    def __init__(self, root, mode='train', is_ndvi=False):
        self.root = root
        self.mode = mode
        self.is_ndvi = is_ndvi
        self.imgList = sorted(img for img in os.listdir(self.root))
        self.transform = DualCompose([
            RandomFlip(),
            RandomRotate90(),
            Rotate(),
            Shift(),
            CoarseDropout()
        ])
        self.RA = RandomAugment(2, 10)
        self.imgTransforms = Compose([img_to_tensor])
        self.maskTransforms = Compose([
            torchvision.transforms.Lambda(to_monochrome),
            torchvision.transforms.Lambda(to_tensor),
        ])

    def __getitem__(self, idx):
        imgPath = os.path.join(self.root, self.imgList[idx])
        img = np.load(imgPath)
        img = custom_blur_demo(img)
        imgName = os.path.split(imgPath)[-1].split('.')[0]
        if self.mode == 'test':
            batch_data = {'img': self.imgTransforms(img), 'file_name': imgName}
            return batch_data
        labelPath = imgPath.replace('images', 'labels').replace('npy', 'png')
        mask = cv.imread(labelPath) / 255
        # data augmentation
        if self.mode == 'train':
            img, mask = self.transform(img, mask)
            img = self.RA(img)
        # img, mask = img.astype(np.float), mask.astype(np.float)
        w, h = mask.shape[:2]
        mask = mask[:, :, 0]
        mask = np.reshape(mask, (w, h, 1)).transpose((2, 0, 1))
        return self.imgTransforms(img), self.maskTransforms(np.squeeze(mask))

    def __len__(self):
        return len(self.imgList)
class USasDataset(Dataset):
    def __init__(self, root, mode='train'):
        self.root = root
        self.mode = mode
        self.imgList = sorted(img for img in os.listdir(self.root))
        self.transform = DualCompose([
            RandomFlip(),
            RandomRotate90(),
            Rotate(),
            Shift(),
            # Cutout(num_holes=20, max_h_size=20, max_w_size=20, fill_value=0)
        ])
        self.RA = RandomAugment(2, 10)
        self.imgTransforms = Compose([ImageToTensor()])

    def __getitem__(self, idx):
        imgPath = os.path.join(self.root, self.imgList[idx])
        img = np.load(imgPath)
        img = custom_blur_demo(img)
        mask = np.zeros_like(img)  # dummy mask; unlabeled images have no ground truth
        # weak data augmentation
        img_w, _ = self.transform(img, mask)
        # strong data augmentation applied on top of the weak view
        img_s = self.RA(img_w)
        return self.imgTransforms(img_w), self.imgTransforms(img_s)

    def __len__(self):
        return len(self.imgList)
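A rough sketch of how the two datasets might be wired into loaders for train_one_epoch_semi. The repository's build_loader_v2 presumably does something similar; the paths and batch sizes here are placeholders, not values from the original code.

from torch.utils.data import DataLoader

labeled_set = SasDataset(root='data/train/images', mode='train')
unlabeled_set = USasDataset(root='data/unlabeled/images', mode='train')

labeled_loader = DataLoader(labeled_set, batch_size=8, shuffle=True)
unlabeled_loader = DataLoader(unlabeled_set, batch_size=8, shuffle=True)

# train_one_epoch_semi zips the two loaders, so each step sees one labeled
# batch (image, mask) and one unlabeled batch (weak view, strong view).
for (images_x, masks_x), (images_u_w, images_u_s) in zip(labeled_loader, unlabeled_loader):
    pass  # forward/backward as in Trainer.train_one_epoch_semi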