用 Python 撰寫訓練模型，判斷貓或狗
1. 安裝必要的庫
pip install tensorflow keras numpy matplotlib
2. 準備數據集
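下載並準備數據集。Kaggle 提供了大量已標記的貓狗照片數據集（例如 Dogs vs. Cats）。將數據集解壓到工作目錄中，並依類別整理成以下資料夾結構：data/train/cats、data/train/dogs、data/validation/cats、data/validation/dogs。後面使用的 flow_from_directory 會依子資料夾名稱自動產生標籤（依名稱的字母順序，cats 為 0、dogs 為 1），因此每個類別的圖片都必須放在各自的子資料夾中。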
3. 建立模型
import tensorflow as tf
from tensorflow.keras.applications import MobileNetV2
from tensorflow.keras.layers import Dense, GlobalAveragePooling2D
from tensorflow.keras.models import Model

# Pretrained ImageNet feature extractor without its 1000-class classifier head.
# The explicit input_shape matches target_size=(224, 224) used by the data
# generators and avoids Keras' "input_shape is undefined" fallback.
base_model = MobileNetV2(
    input_shape=(224, 224, 3),
    weights='imagenet',
    include_top=False,
)

# New binary-classification head on top of the backbone's feature maps:
# global average pooling -> hidden ReLU layer -> one sigmoid unit giving the
# probability of the class with label 1.
x = base_model.output
x = GlobalAveragePooling2D()(x)
x = Dense(1024, activation='relu')(x)
predictions = Dense(1, activation='sigmoid')(x)
model = Model(inputs=base_model.input, outputs=predictions)

# Freeze every layer of the pretrained backbone so only the new head is
# trained. This must happen before compile() for it to take effect.
base_model.trainable = False

model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])
4. 訓練模型
from tensorflow.keras.applications.mobilenet_v2 import preprocess_input
from tensorflow.keras.preprocessing.image import ImageDataGenerator

# MobileNetV2's ImageNet weights were trained on pixels scaled to [-1, 1], and
# preprocess_input performs exactly that scaling. A plain rescale=1./255 would
# give [0, 1] inputs that do not match what the frozen backbone expects.
train_datagen = ImageDataGenerator(preprocessing_function=preprocess_input)
validation_datagen = ImageDataGenerator(preprocessing_function=preprocess_input)

# flow_from_directory derives labels from the subdirectory names under each
# root (e.g. data/train/cats, data/train/dogs); class indices follow the
# alphanumeric order of those names. class_mode='binary' yields 0/1 labels
# for the model's single sigmoid output.
train_generator = train_datagen.flow_from_directory(
    'data/train',
    target_size=(224, 224),
    batch_size=32,
    class_mode='binary',
)
validation_generator = validation_datagen.flow_from_directory(
    'data/validation',
    target_size=(224, 224),
    batch_size=32,
    class_mode='binary',
)

# steps_per_epoch / validation_steps are deliberately omitted: Keras then uses
# len(generator) == ceil(n / batch_size). This way the final partial batch is
# not dropped, and a split with fewer than batch_size images does not produce
# 0 steps, which would be an error.
model.fit(
    train_generator,
    validation_data=validation_generator,
    epochs=10,
)
留言
張貼留言