Skip to content

Commit

Permalink
Train script for GPU
Browse files Browse the repository at this point in the history
  • Loading branch information
lim185 committed Jan 3, 2025
1 parent f8268e2 commit 721f2f8
Showing 1 changed file with 70 additions and 29 deletions.
99 changes: 70 additions & 29 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,63 +14,87 @@
import jax
import numpy as np
from pipe.pipe import SpectrogramPipe
from pyspark.ml.feature import StringIndexer
from pyspark.sql import SparkSession
from pyspark.ml.feature import StringIndexer, IndexToString
from pyspark.sql import SparkSession, functions
import tensorflow as tf
import keras
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix

from jax.experimental import mesh_utils
from jax.sharding import Mesh
from jax.sharding import NamedSharding
from jax.sharding import PartitionSpec
from model.model import TimeSeriesTransformer as TSTF

from model.autoencoder import Autoencoder
# data parameters
# Train / validation / test fractions handed to DataFrame.randomSplit
# via get_data(spark, split=SPLIT); must sum to 1.0.
SPLIT = [0.9, 0.05, 0.05]

# model parameters
# NOTE(review): these look like transformer hyperparameters for the
# imported TimeSeriesTransformer (TSTF) — its constructor is not visible
# here, so confirm each name maps to the expected argument.
HEAD_SIZE = 32
NUM_HEADS = 12
FF_DIM = 32
NUM_TRANSFORMER_BLOCKS = 12
MLP_UNITS = [128]
DROPOUT = 0.3
MLP_DROPOUT = 0.3

# training parameters
# Consumed by model.fit(..., batch_size=BATCH_SIZE, epochs=EPOCHS) in main().
BATCH_SIZE = 8
EPOCHS = 800

def trim(dataframe, column):
    """Collect one DataFrame column and reshape it into spectrogram frames.

    Parameters
    ----------
    dataframe : pyspark.sql.DataFrame (anything exposing select(...).collect())
        Source frame holding the spectrogram column.
    column : str
        Name of the column to collect.

    Returns
    -------
    numpy.ndarray
        Array of shape ``(n_samples, 32, 130)``.  ``numpy.reshape`` raises
        ``ValueError`` if the collected element count is not a multiple of
        32 * 130.

    Notes
    -----
    The frame shape (32, 130) presumably matches the output of
    SpectrogramPipe — TODO confirm against the pipe's spectrogram size.
    """
    # collect() returns a list of Row-like records; np.array flattens the
    # nesting and reshape regroups into fixed-size frames.  (The stale
    # pre-diff shape (-1, 26, 130, 1) from the removed line is dropped.)
    rows = dataframe.select(column).collect()
    return np.array(rows).reshape(-1, 32, 130)

def get_data(spark, split=[0.99, 0.005, 0.005]):
path = Path("/app/datadump/train")
path = Path("/app/workdir")

labels = []
with open(path / "train.csv", "r") as file:
for line in file:
labels.append(line.strip().split(",")[0])

pipe = SpectrogramPipe(spark)
pipe = SpectrogramPipe(spark, filetype="matfiles")
data = pipe.spectrogram_pipe(path, labels)

indexer = StringIndexer(inputCol="treatment", outputCol="treatment_index")
indexed = indexer.fit(data).transform(data)

selected = indexed.select("treatment", "treatment_index").distinct()
selected = selected.sort("treatment_index")
index_max = selected.agg(functions.max("treatment_index")).collect()[0][0]

train_df, validation_df, test_df = indexed.randomSplit(split, seed=42)
print(train_df.count())
print(validation_df.count())
print(test_df.count())

trainx = trim(train_df, "spectrogram")
trainy = np.array(train_df.select("treatment_index").collect())
trainy = np.array(train_df.select("treatment_index").collect()).astype(int)
_trainy = np.zeros((len(trainy), int(index_max+1)))
for index, value in enumerate(trainy):
_trainy[index, value] = 1.
#_trainy[np.arange(trainy.shape[0]), trainy] = 1.
trainy = _trainy
del _trainy

valx = trim(validation_df, "spectrogram")
valy = np.array(validation_df.select("treatment_index").collect())
valy = np.array(validation_df.select("treatment_index").collect()) \
.astype(int)
_valy = np.zeros((len(valy), int(index_max+1)))
for index, value in enumerate(valy):
_valy[index, value] = 1.
#_valy[np.arange(valy.shape[0]), valy] = 1.
valy = _valy
del _valy

testx = trim(test_df, "spectrogram")
testy = np.array(test_df.select("treatment_index").collect())

return ((trainx, trainy), (valx, valy), (testx, testy))
def get_model():
model = Autoencoder()
return model
testy = np.array(test_df.select("treatment_index").collect()).astype(int)
_testy = np.zeros((len(testy), int(index_max+1)))
for index, value in enumerate(testy):

def main():
# jax mesh setup
devices = jax.devices("gpu")
mesh = keras.distribution.DeviceMesh(
shape=(2,), axis_names=["model"], devices=devices
shape=(len(devices),), axis_names=["model"], devices=devices
)
layout_map = keras.distribution.LayoutMap(mesh)
layout_map["dense.*kernel"] = (None, "model")
Expand All @@ -80,9 +104,9 @@ def main():
layout_map["dense.*activity_regularizer"] = (None,)
layout_map["dense.*kernel_constraint"] = (None, "model")
layout_map["dense.*bias_constraint"] = ("model",)
layout_map["conv2d.*kernel"] = (None, None, None, "model")
layout_map["conv2d.*kernel_regularizer"] = (None, None, None, "model")
layout_map["conv2d.*bias_regularizer"] = ("model",)
layout_map["conv1d.*kernel"] = (None, None, None, "model")
layout_map["conv1d.*kernel_regularizer"] = (None, None, None, "model")
layout_map["conv1d.*bias_regularizer"] = ("model",)

model_parallel = keras.distribution.ModelParallel(
layout_map=layout_map
Expand All @@ -91,16 +115,33 @@ def main():

spark = SparkSession.builder.appName("train").getOrCreate()

train_set, validation_set, test_set = get_data(spark)
indices, train_set, validation_set, test_set = get_data(spark, split=SPLIT)

model = get_model()
n_classes = indices.count()
model = get_model(train_set[0].shape[1:], n_classes)
model.compile(optimizer=keras.optimizers.Adam(learning_rate=4e-4),
loss="categorical_crossentropy",
metrics=["categorical_accuracy"]
)
model.summary()

model.compile(optimizer="adam", loss="mean_squared_error")
start = time.time()
model.fit(x=train_set[0], y=train_set[0], batch_size=2, epochs=50)
model.fit(x=train_set[0], y=train_set[1],
validation_data=(validation_set[0], validation_set[1]),
batch_size=BATCH_SIZE, epochs=EPOCHS)
end = time.time()
print("Training time: ", end - start)

# Test model performance
test_loss, test_accuracy = model.evaluate(test_set[0], test_set[1])
test_predict = model.predict(test_set[0])
print(f"Test loss: {test_loss}, test accuracy: {test_accuracy}")

conf_matrix = confusion_matrix(np.argmax(test_predict, axis=1),
np.argmax(test_set[1], axis=1))
plt.imshow(conf_matrix, origin="upper")
plt.gca().set_aspect("equal")
plt.savefig("confusion_matrix.png")

return

Expand Down

0 comments on commit 721f2f8

Please sign in to comment.