当前位置 : 主页 > 编程语言 > java >

PyTorch 预测

来源:互联网 收集:自由互联 发布时间:2022-07-04
本次测试输入 dog.png。# Coding by ajupyter from PIL import Image from torch import nn import torch import torchvision class Model(nn.Module): def __init__(self): super(Model, self).__init__() self.model = nn.Sequentia……（摘要截断，完整代码见下文）

本次测试的输入图片为 dog.png。
Pytorch预测_ide

# Coding by ajupyter
from PIL import Image
from torch import nn
import torch
import torchvision


class Model(nn.Module):
    """Small convolutional classifier mapping 3x32x32 images to 10 scores.

    Three stages of (5x5 convolution with padding=2, then 2x2 max-pool) shrink
    the spatial size 32 -> 16 -> 8 -> 4, giving 64 * 4 * 4 features. Two
    linear layers then map those features to 10 output scores.

    Note: there is no ReLU or other activation between layers; max-pooling is
    the only non-linearity. Layers must stay in this exact order under
    ``self.model``, because checkpoint ``state_dict`` keys (``model.0.weight``,
    ``model.2.weight``, ...) are derived from their positions.
    """

    def __init__(self):
        super().__init__()
        stages = []
        # (in_channels, out_channels) for each conv + pool stage, in order.
        for c_in, c_out in ((3, 32), (32, 32), (32, 64)):
            stages.append(nn.Conv2d(in_channels=c_in, out_channels=c_out, kernel_size=5, stride=1, padding=2))
            stages.append(nn.MaxPool2d(kernel_size=2))
        head = [
            nn.Flatten(),
            nn.Linear(64 * 4 * 4, 64),
            nn.Linear(64, 10),
        ]
        self.model = nn.Sequential(*stages, *head)

    def forward(self, x):
        """Run the layer stack on batch ``x`` and return the raw scores."""
        return self.model(x)


# Label names indexed by the argmax of the model output. Presumably the
# CIFAR-10 class order used at training time (the checkpoint name says
# "cifa10") -- confirm against the training script.
CLASSES = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')


def decode_prediction(logits, classes=CLASSES):
    """Return the label in *classes* for the highest-scoring entry of *logits*.

    *logits* must hold scores for exactly one sample, e.g. shape
    ``(1, num_classes)``. The argmax is taken over dim 1 and converted with
    ``.item()``, which raises if more than one element remains.
    """
    index = logits.argmax(1).item()
    return classes[index]


def main(weights_path='cifa10_model-epoch19-test_loss699.5733557939529', image_path='dog.png'):
    """Load trained weights, classify one image, and print the intermediate results.

    Prints the input batch shape, the raw model output, the argmax index, and
    the predicted label as ``res:<label>``.
    """
    model = Model()
    # map_location='cpu' lets a checkpoint saved from a GPU load on a CPU-only
    # machine; load_state_dict copies the tensors into the CPU model either way.
    model.load_state_dict(torch.load(weights_path, map_location='cpu'))
    model.eval()

    transform = torchvision.transforms.Compose([
        torchvision.transforms.Resize(size=(32, 32)),
        torchvision.transforms.ToTensor(),
    ])

    # Image.open keeps the file handle open (lazy loading); the with-block
    # guarantees it is closed once the pixels have been converted.
    with Image.open(image_path) as image:
        # PNGs may carry an alpha channel (RGBA) or be palette-based; the first
        # Conv2d expects exactly 3 input channels, so normalise to RGB.
        tensor = transform(image.convert('RGB'))

    batch = tensor.unsqueeze(0)  # add batch dim: (3, 32, 32) -> (1, 3, 32, 32)
    print(batch.shape)  # torch.Size([1, 3, 32, 32])

    with torch.no_grad():
        res = model(batch)
    print(res)
    print(res.argmax(1))  # index of the highest score across the class dimension
    print(f'res:{decode_prediction(res)}')


if __name__ == "__main__":
    main()


上一篇:27. 移除元素
下一篇:没有了
网友评论