import numpy as np
import matplotlib.pyplot as plt
from keras.preprocessing.image import ImageDataGenerator,load_img,img_to_array
from keras.layers import Conv2D,Flatten,MaxPooling2D,Dense
from keras.models import Sequential,load_model
import glob,os,random
import time
import keras
# Root directory of the image dataset. The functions below glob it as
# "<base_path>/*/*.jpg" and hand it to ImageDataGenerator.flow_from_directory.
base_path = "datasets"
def look_dataset_num():
    """Print the number of JPEG images in the dataset and preview a few.

    Globs ``<base_path>/*/*.jpg``, prints the image count, then shows up to
    six randomly chosen images in a 2x3 matplotlib grid. When no image is
    found, only the count (0) is printed and no figure is opened.
    """
    img_list = glob.glob(os.path.join(base_path, "*/*.jpg"))
    print(len(img_list))  # e.g. 2307

    if not img_list:
        # Nothing to preview; random.sample would raise on an empty population.
        return

    # random.sample raises ValueError when asked for more items than exist,
    # so clamp the preview size to the number of images available.
    sample_size = min(6, len(img_list))

    # Show a random selection of images, one per subplot.
    for i, img_path in enumerate(random.sample(img_list, sample_size)):
        img = img_to_array(load_img(img_path), dtype=np.uint8)
        plt.subplot(2, 3, i + 1)
        plt.imshow(img.squeeze())
    plt.show()
def crate_model():
    """Train a small CNN on the images under ``base_path`` and save it.

    Steps:
      1. Build augmented training and plain validation generators over
         ``base_path`` (90/10 split, 300x300 inputs, batch size 16).
      2. Train a four-stage Conv2D/MaxPooling2D network with a softmax head
         for 100 epochs.
      3. Save the model to ``rubbish/rubbish_model.h5``.
      4. Plot predictions against ground truth for one validation batch and
         print the elapsed wall-clock time.
    """
    start = time.time()

    batch_size = 16
    model_path = 'rubbish/rubbish_model.h5'
    # Create the output directory before training so that the save at the end
    # of a long run cannot fail because the folder is missing.
    os.makedirs(os.path.dirname(model_path), exist_ok=True)

    # Both generators must scale pixels by the same factor (1/255); otherwise
    # the network is validated on inputs scaled differently from training.
    train_datagen = ImageDataGenerator(
        rescale=1. / 255, shear_range=0.1, zoom_range=0.1,
        width_shift_range=0.1, height_shift_range=0.1, horizontal_flip=True,
        vertical_flip=True, validation_split=0.1)
    test_datagen = ImageDataGenerator(
        rescale=1. / 255, validation_split=0.1)

    train_generator = train_datagen.flow_from_directory(
        base_path, target_size=(300, 300), batch_size=batch_size,
        class_mode='categorical', subset='training', seed=0)
    # Example output: Found 2276 images belonging to 6 classes.
    validation_generator = test_datagen.flow_from_directory(
        base_path, target_size=(300, 300), batch_size=batch_size,
        class_mode='categorical', subset='validation', seed=0)
    # Example output: Found 251 images belonging to 6 classes.

    # Invert the name -> index mappings so predictions can be shown by name.
    a = {index: name for name, index in validation_generator.class_indices.items()}
    labels = {index: name for name, index in train_generator.class_indices.items()}

    # NOTE: the first print shows the validation label map and the second
    # shows the training ImageDataGenerator object, despite their captions.
    print('train_datagen ', a)
    # e.g. {0: 'cardboard', 1: 'glass', 2: 'metal', 3: 'paper', 4: 'plastic', 5: 'trash'}
    print('test_datagen', train_datagen)
    # e.g. <keras.preprocessing.image.ImageDataGenerator object at 0x...>
    print('labels', labels)
    # e.g. {0: 'cardboard', 1: 'glass', 2: 'metal', 3: 'paper', 4: 'plastic', 5: 'trash'}

    # Model definition and training.
    l2 = keras.regularizers.l2
    model = Sequential([
        Conv2D(filters=32, kernel_size=3, padding='same', activation='relu',
               input_shape=(300, 300, 3)),
        MaxPooling2D(pool_size=2),
        Conv2D(filters=64, kernel_size=3, padding='same', activation='relu',
               kernel_regularizer=l2(0.001)),
        MaxPooling2D(pool_size=2),
        Conv2D(filters=32, kernel_size=3, padding='same', activation='relu',
               kernel_regularizer=l2(0.001)),
        MaxPooling2D(pool_size=2),
        Conv2D(filters=32, kernel_size=3, padding='same', activation='relu',
               kernel_regularizer=l2(0.001)),
        MaxPooling2D(pool_size=2),
        Flatten(),
        Dense(64, activation='relu'),
        # One output unit per class directory found by the generator.
        Dense(len(labels), activation='softmax'),
    ])
    model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['acc'])

    # len(generator) is the number of batches per pass, so every epoch covers
    # the whole subset regardless of the dataset size or batch size.
    model.fit_generator(
        train_generator,
        epochs=100,
        steps_per_epoch=len(train_generator),
        validation_data=validation_generator,
        validation_steps=len(validation_generator))
    model.save(model_path)

    # Result preview: take one validation batch and show every image with its
    # predicted and true label.
    test_x, test_y = validation_generator[1]
    preds = model.predict(test_x)
    plt.figure(figsize=(16, 16))
    # The batch may hold fewer than 16 images; never index past its end.
    for i in range(min(len(test_x), 16)):
        plt.subplot(4, 4, i + 1)
        plt.title('pred:%s / truth:%s' % (labels[np.argmax(preds[i])], labels[np.argmax(test_y[i])]))
        plt.imshow(test_x[i])
    plt.show()

    end = time.time()
    t = end - start
    print('运行time', t)
def use_model():
    """Load the saved model and plot its predictions for one validation batch.

    Reads ``rubbish/rubbish_model.h5``, takes batch 1 (up to 36 images) from
    the validation subset under ``base_path`` and shows each image in a 6x6
    grid titled with the predicted and true class names.
    """
    batch_size = 36

    test_datagen = ImageDataGenerator(
        rescale=1. / 255, validation_split=0.1)
    validation_generator = test_datagen.flow_from_directory(
        base_path, target_size=(300, 300), batch_size=batch_size,
        class_mode='categorical', subset='validation', seed=0)

    # class_indices is derived from the class sub-directories of base_path, so
    # the validation generator provides the label map on its own; no training
    # generator is needed just to look up names.
    labels = {index: name for name, index in validation_generator.class_indices.items()}

    model = load_model('rubbish/rubbish_model.h5')

    test_x, test_y = validation_generator[1]
    preds = model.predict(test_x)

    plt.figure(figsize=(36, 36))
    # The batch may hold fewer than 36 images; never index past its end.
    for i in range(min(len(test_x), batch_size)):
        plt.subplot(6, 6, i + 1)
        plt.title('pred:%s / truth:%s' % (labels[np.argmax(preds[i])], labels[np.argmax(test_y[i])]))
        plt.imshow(test_x[i])
    plt.show()
if __name__ == '__main__':
    # Uncomment the step to run: dataset preview, training, or inference.
    # look_dataset_num()
    # crate_model()
    use_model()