Permalink
Cannot retrieve contributors at this time
Name already in use
A tag already exists with the provided branch name. Many Git commands accept both tag and branch names, so creating this branch may cause unexpected behavior. Are you sure you want to create this branch?
ece50024_mini_challenge/train_final.py
Go to file. This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
88 lines (69 sloc)
2.71 KB
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import pandas as pd | |
from sklearn.model_selection import train_test_split | |
from tensorflow.keras.applications import MobileNet | |
from tensorflow.keras.models import Model | |
from tensorflow.keras.layers import Dense, GlobalAveragePooling2D | |
from tensorflow.keras.optimizers import Adam | |
from tensorflow.keras.preprocessing.image import ImageDataGenerator | |
import tensorflow as tf | |
def train_mobilenet(
    *,
    csv_path: str = "./new_train_mtcnn.csv",
    image_dir: str = "./train_mtcnn/",
    model_path: str = "model_mobilenet_25epochs.h5",
    epochs: int = 25,
    batch_size: int = 32,
    image_size: tuple = (224, 224),
    validation_fraction: float = 0.2,
    seed: int = 42,
) -> None:
    """Fine-tune an ImageNet-pretrained MobileNet on a labelled image set.

    The CSV at ``csv_path`` must provide a ``File Name`` column (image paths
    relative to ``image_dir``) and a ``Category`` column (class labels). The
    rows are split into training and validation sets, the training images are
    augmented on the fly, and a MobileNet backbone with a new softmax head is
    trained, evaluated on the validation set and saved to ``model_path``.

    Called with no arguments, the function uses the same paths and
    hyper-parameters as the original hard-coded script.

    Args:
        csv_path: CSV file listing image file names and their categories.
        image_dir: Directory that the ``File Name`` entries are relative to.
        model_path: Destination file for the trained model.
        epochs: Number of training epochs.
        batch_size: Batch size used by both data generators.
        image_size: ``(height, width)`` every image is resized to.
        validation_fraction: Fraction of rows held out for validation.
        seed: Seed for the train/validation split and the training shuffle.
    """
    df = pd.read_csv(csv_path, delimiter=',')

    # Fix the label -> index mapping from the FULL dataframe. Left to infer
    # it on their own, each generator would build a mapping from only the
    # classes present in its split, so a class missing from one split would
    # silently misalign training and validation labels.
    class_names = sorted(df['Category'].unique())
    num_classes = len(class_names)

    # Split data into training and validation sets
    train_df, val_df = train_test_split(
        df, test_size=validation_fraction, random_state=seed
    )

    # Arguments shared by both generators so they cannot drift apart.
    flow_kwargs = dict(
        directory=image_dir,
        x_col='File Name',
        y_col='Category',
        target_size=image_size,
        batch_size=batch_size,
        class_mode='categorical',
        classes=class_names,
    )

    # The generators stream batches from disk instead of loading every image
    # into memory at once; only the training stream is augmented.
    train_datagen = ImageDataGenerator(
        rescale=1./255,
        rotation_range=20,
        width_shift_range=0.2,
        height_shift_range=0.2,
        shear_range=0.2,
        zoom_range=0.2,
        horizontal_flip=True,
        fill_mode='nearest',
    )
    train_generator = train_datagen.flow_from_dataframe(
        dataframe=train_df,
        shuffle=True,
        seed=seed,
        **flow_kwargs,
    )

    # Validation data is only rescaled and kept in a fixed order.
    val_datagen = ImageDataGenerator(rescale=1./255)
    val_generator = val_datagen.flow_from_dataframe(
        dataframe=val_df,
        shuffle=False,
        **flow_kwargs,
    )

    # Load pre-trained MobileNet without its ImageNet classification head.
    # The input shape is stated explicitly so it matches what the generators
    # produce rather than being left undefined.
    base_model = MobileNet(
        weights='imagenet',
        include_top=False,
        input_shape=(image_size[0], image_size[1], 3),
    )

    # New classification head, sized from the data rather than hard-coded.
    x = base_model.output
    x = GlobalAveragePooling2D()(x)
    x = Dense(1024, activation='relu')(x)
    predictions = Dense(num_classes, activation='softmax')(x)

    # Compile the model
    model = Model(inputs=base_model.input, outputs=predictions)
    model.compile(
        optimizer=Adam(),
        loss='categorical_crossentropy',
        metrics=['accuracy'],
    )

    # Train the model
    model.fit(train_generator, epochs=epochs, validation_data=val_generator)

    # Evaluate the model
    loss, accuracy = model.evaluate(val_generator)
    print(f'Validation Loss: {loss}, Validation Accuracy: {accuracy}')

    # Save model to predict testing data
    model.save(model_path)
# Run the training pipeline only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    train_mobilenet()