Skip to content
Permalink
main
Switch branches/tags

Name already in use

A tag already exists with the provided branch name. Many Git commands accept both tag and branch names, so creating this branch may cause unexpected behavior. Are you sure you want to create this branch?
Go to file
 
 
Cannot retrieve contributors at this time
import pandas as pd
from sklearn.model_selection import train_test_split
from tensorflow.keras.applications import MobileNet
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Dense, GlobalAveragePooling2D
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.preprocessing.image import ImageDataGenerator
import tensorflow as tf
def _build_mobilenet_classifier(num_classes):
    """Return a compiled ImageNet-pretrained MobileNet with a new softmax head.

    No layer is frozen here, so ``fit`` fine-tunes the whole network, not just
    the newly added head.

    Args:
        num_classes: Width of the output softmax layer.

    Returns:
        A ``tf.keras.Model`` compiled with Adam and categorical cross-entropy.
    """
    # Pre-trained MobileNet without its original 1000-way ImageNet classifier.
    base_model = MobileNet(weights='imagenet', include_top=False)
    x = base_model.output
    x = GlobalAveragePooling2D()(x)
    x = Dense(1024, activation='relu')(x)
    predictions = Dense(num_classes, activation='softmax')(x)
    model = Model(inputs=base_model.input, outputs=predictions)
    model.compile(optimizer=Adam(), loss='categorical_crossentropy', metrics=['accuracy'])
    return model


def train_mobilenet(
    csv_path="./new_train_mtcnn.csv",
    image_dir="./train_mtcnn/",
    epochs=25,
    batch_size=32,
    target_size=(224, 224),
    model_path=None,
):
    """Fine-tune MobileNet on the labelled image CSV, report validation metrics, save the model.

    The CSV is read with pandas and must have a ``File Name`` column (image
    file names, resolved against ``image_dir``) and a ``Category`` column
    (class label). 80% of the rows are used for augmented training and 20% for
    validation, split with a fixed seed.

    Args:
        csv_path: CSV listing the images and their labels.
        image_dir: Directory that the ``File Name`` entries are relative to.
        epochs: Number of training epochs.
        batch_size: Images per batch, for both training and validation.
        target_size: ``(height, width)`` every image is resized to.
        model_path: Output file for ``model.save``. Defaults to
            ``model_mobilenet_{epochs}epochs.h5``, which is
            ``model_mobilenet_25epochs.h5`` for the default epoch count.

    Returns:
        The trained Keras model (it is also saved to ``model_path``).
    """
    df = pd.read_csv(csv_path, delimiter=',')

    # One sorted label list shared by BOTH generators. Otherwise each generator
    # infers the classes from its own split. If the random split leaves a
    # category out of one side, the one-hot indices of training and validation
    # batches no longer refer to the same class.
    classes = sorted(df['Category'].unique())

    # Split data into training and validation sets (fixed seed for reproducibility).
    train_df, val_df = train_test_split(df, test_size=0.2, random_state=42)

    # Arguments shared by the training and validation generators.
    flow_args = dict(
        directory=image_dir,
        x_col='File Name',
        y_col='Category',
        classes=classes,
        target_size=target_size,
        batch_size=batch_size,
        class_mode='categorical',
    )

    # The training generator streams augmented batches from disk instead of
    # loading every image into memory at once.
    # NOTE(review): Keras' MobileNet preprocess_input scales pixels to [-1, 1],
    # whereas this uses [0, 1]. It is kept as-is because the prediction code
    # presumably applies the same 1/255 rescale. Confirm before changing.
    train_datagen = ImageDataGenerator(
        rescale=1./255,
        rotation_range=20,
        width_shift_range=0.2,
        height_shift_range=0.2,
        shear_range=0.2,
        zoom_range=0.2,
        horizontal_flip=True,
        fill_mode='nearest'
    )
    train_generator = train_datagen.flow_from_dataframe(
        dataframe=train_df,
        shuffle=True,
        seed=42,
        **flow_args
    )

    # Validation images are only rescaled (no augmentation) and kept in order.
    val_datagen = ImageDataGenerator(rescale=1./255)
    val_generator = val_datagen.flow_from_dataframe(
        dataframe=val_df,
        shuffle=False,
        **flow_args
    )

    # The head is sized from the data (the original hard-coded 100 celebrity classes).
    model = _build_mobilenet_classifier(len(classes))

    model.fit(train_generator, epochs=epochs, validation_data=val_generator)

    loss, accuracy = model.evaluate(val_generator)
    print(f'Validation Loss: {loss}, Validation Accuracy: {accuracy}')

    # Save the model so it can be used to predict the testing data.
    if model_path is None:
        model_path = f'model_mobilenet_{epochs}epochs.h5'
    model.save(model_path)
    return model
def _main():
    """Command-line entry point: run the training job with its default settings."""
    train_mobilenet()


if __name__ == "__main__":
    # Training only starts when the file is executed directly, never on import.
    _main()