python 撰寫訓練模型,判斷貓或狗

1. 安裝必要的庫

pip install tensorflow keras numpy matplotlib


2. 準備數據集

下載並準備數據集。Kaggle提供了大量標記好的貓狗照片數據集。將數據集解壓到工作目錄中。


3. 建立模型

import tensorflow as tf

from tensorflow.keras.applications import MobileNetV2

from tensorflow.keras.layers import Dense, GlobalAveragePooling2D

from tensorflow.keras.models import Model


base_model = MobileNetV2(weights='imagenet', include_top=False)


x = base_model.output

x = GlobalAveragePooling2D()(x)

x = Dense(1024, activation='relu')(x)

predictions = Dense(1, activation='sigmoid')(x)


model = Model(inputs=base_model.input, outputs=predictions)

for layer in base_model.layers:

    layer.trainable = False

model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])


4. 訓練模型

from tensorflow.keras.preprocessing.image import ImageDataGenerator


train_datagen = ImageDataGenerator(rescale=1./255)

validation_datagen = ImageDataGenerator(rescale=1./255)


train_generator = train_datagen.flow_from_directory(

    'data/train',

    target_size=(224, 224),

    batch_size=32,

    class_mode='binary'

)


validation_generator = validation_datagen.flow_from_directory(

    'data/validation',

    target_size=(224, 224),

    batch_size=32,

    class_mode='binary'

)


model.fit(

    train_generator,

    steps_per_epoch=train_generator.n // train_generator.batch_size,

    validation_data=validation_generator,

    validation_steps=validation_generator.n // validation_generator.batch_size,

    epochs=10

)

5. 評估模型

test_generator = validation_datagen.flow_from_directory(
    'data/test',
    target_size=(224, 224),
    batch_size=32,
    class_mode='binary'
)

loss, accuracy = model.evaluate(test_generator)
print(f'Test Accuracy: {accuracy * 100:.2f}%')

6. 使用模型進行預測

from tensorflow.keras.preprocessing import image
import numpy as np

img_path = 'path_to_your_image.jpg'
img = image.load_img(img_path, target_size=(224, 224))
img_array = image.img_to_array(img)
img_array = np.expand_dims(img_array, axis=0)
img_array /= 255.

prediction = model.predict(img_array)
if prediction < 0.5:
    print("這是一隻貓")
else:
    print("這是一隻狗")

結論

以上步驟展示了如何使用AI和深度學習模型來判斷照片中的動物是貓還是狗。通過這些步驟,你可以訓練出一個準確的貓狗分類器並用於實際應用中。


留言

熱門文章