import tensorflow as tf
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.applications import MobileNetV2
from tensorflow.keras.layers import Dense, GlobalAveragePooling2D, Dropout
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import Adam
# Parameters
IMG_SIZE = 224
BATCH_SIZE = 32
EPOCHS = 10
NUM_CLASSES = 5 # e.g., stripes, polka dots, floral, geometric, plaid
# Paths (update with your dataset path)
train_dir = 'data/train' # directory with subfolders for each class
val_dir = 'data/val'
# Data Augmentation and Preprocessing
train_datagen = ImageDataGenerator(
    rescale=1./255,
    rotation_range=20,
    width_shift_range=0.2,
    height_shift_range=0.2,
    horizontal_flip=True,
    zoom_range=0.2
)
val_datagen = ImageDataGenerator(rescale=1./255)
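# Note (optional): MobileNetV2 was pretrained on inputs scaled to [-1, 1]. As an
# alternative to rescale=1./255, its own preprocessing function can be plugged in:
#   from tensorflow.keras.applications.mobilenet_v2 import preprocess_input
#   train_datagen = ImageDataGenerator(preprocessing_function=preprocess_input, ...)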
# Data generators
train_generator = train_datagen.flow_from_directory(
    train_dir,
    target_size=(IMG_SIZE, IMG_SIZE),
    batch_size=BATCH_SIZE,
    class_mode='categorical'
)
val_generator = val_datagen.flow_from_directory(
    val_dir,
    target_size=(IMG_SIZE, IMG_SIZE),
    batch_size=BATCH_SIZE,
    class_mode='categorical'
)
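# Sanity check: label indices come from the subfolder names, so confirm the
# mapping matches the expected classes before training.
print("Class indices:", train_generator.class_indices)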
# Load MobileNetV2 pretrained on ImageNet (without top layers)
base_model = MobileNetV2(weights='imagenet', include_top=False,
                         input_shape=(IMG_SIZE, IMG_SIZE, 3))
# Freeze base model layers
base_model.trainable = False
# Add custom classification head
x = base_model.output
x = GlobalAveragePooling2D()(x)
x = Dropout(0.3)(x)
predictions = Dense(NUM_CLASSES, activation='softmax')(x)
# Create the full model
model = Model(inputs=base_model.input, outputs=predictions)
# Compile model
model.compile(optimizer=Adam(learning_rate=1e-4),
              loss='categorical_crossentropy',
              metrics=['accuracy'])
# Train the model
history = model.fit(
    train_generator,
    validation_data=val_generator,
    epochs=EPOCHS
)
# Optionally, unfreeze the later layers and fine-tune them at a lower learning rate
base_model.trainable = True
for layer in base_model.layers[:100]:
    layer.trainable = False
model.compile(optimizer=Adam(learning_rate=1e-5),
              loss='categorical_crossentropy',
              metrics=['accuracy'])
fine_tune_epochs = 5
total_epochs = EPOCHS + fine_tune_epochs
history_fine = model.fit(
    train_generator,
    validation_data=val_generator,
    epochs=total_epochs,
    initial_epoch=history.epoch[-1]
)
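# Optional sketch: plot accuracy across both training phases (assumes matplotlib
# is installed; it is not required by the training code above).
import matplotlib.pyplot as plt

acc = history.history['accuracy'] + history_fine.history['accuracy']
val_acc = history.history['val_accuracy'] + history_fine.history['val_accuracy']

plt.plot(acc, label='train accuracy')
plt.plot(val_acc, label='val accuracy')
plt.axvline(EPOCHS - 1, linestyle='--', color='gray', label='start of fine-tuning')
plt.xlabel('epoch')
plt.ylabel('accuracy')
plt.legend()
plt.savefig('training_curves.png')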
# Save model
model.save('pattern_sense_model.h5')
print("Model saved!")
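# Optional sketch: classify a single image with the saved model. The path
# 'sample.jpg' is a hypothetical placeholder; replace it with a real image file.
import numpy as np
from tensorflow.keras.preprocessing import image

loaded_model = tf.keras.models.load_model('pattern_sense_model.h5')
# Recover class names in index order from the training generator's mapping
class_names = sorted(train_generator.class_indices, key=train_generator.class_indices.get)

img = image.load_img('sample.jpg', target_size=(IMG_SIZE, IMG_SIZE))
x = image.img_to_array(img) / 255.0   # same rescaling as the data generators
x = np.expand_dims(x, axis=0)         # add batch dimension

probs = loaded_model.predict(x)[0]
print(f"Predicted pattern: {class_names[np.argmax(probs)]} ({probs.max():.2%})")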