diff --git a/decodertest.py b/decodertest.py
index 6c5b63e..70188cc 100644
--- a/decodertest.py
+++ b/decodertest.py
@@ -27,39 +27,7 @@ from model.model import CompoundModel, DecoderModel
 from pipe.etl import etl, read
 import train.decoder_train as dt
-
-def trim(dataframe, column):
-
-    ndarray = np.array(dataframe.select(column).collect()) \
-        .reshape(-1, 32, 130)
-
-    return ndarray
-
-def transform(spark, dataframe, keys):
-    dataframe = dataframe.withColumn(
-        "index", functions.monotonically_increasing_id()
-    )
-    bundle = {key: [
-        arr.tolist()
-        for arr in OneHotEncoder(sparse_output=False) \
-            .fit_transform(dataframe.select(key).collect())
-        ] for key in keys
-    }
-
-    bundle = [dict(zip(bundle.keys(), values))
-              for values in zip(*bundle.values())]
-    schema = types.StructType([
-        types.StructField(key, types.ArrayType(types.FloatType()), True)
-        for key in keys
-    ])
-    newframe = spark.createDataFrame(bundle, schema=schema).withColumn(
-        "index", functions.monotonically_increasing_id()
-    )
-    for key in keys:
-        dataframe = dataframe.withColumnRenamed(key, f"{key}_str")
-    dataframe = dataframe.join(newframe, on="index", how="inner")
-
-    return dataframe
+from visualize.plot import spectra_plot
 
 
 def build_dict(df, key):
     df = df.select(key, f"{key}_str").distinct()
@@ -68,37 +36,6 @@ def build_dict(df, key):
         lambda row: (str(np.argmax(row[key])), row[f"{key}_str"])
     ).collectAsMap()
 
-
-def get_data(spark, split=[0.99, 0.005, 0.005]):
-    path = Path("/app/workdir")
-
-    labels = []
-    with open(path / "train.csv", "r") as file:
-        for line in file:
-            labels.append(line.strip().split(",")[0])
-
-    pipe = SpectrogramPipe(spark, filetype="matfiles")
-    data = pipe.spectrogram_pipe(path, labels)
-    data.select("treatment").replace("virus", "cpv") \
-        .replace("cont", "pbs") \
-        .replace("control", "pbs") \
-        .replace("dld", "pbs").distinct()
-
-    data = transform(spark, data, ["treatment", "target"])
-    category_dict = {
-        key: build_dict(data, key) for key in ["treatment", "target"]
-    }
-    splits = data.randomSplit(split, seed=42)
-    trainx, valx, testx = (trim(dset, "spectra") for dset in splits)
-    trainy, valy, testy = (
-        [np.array(dset.select("treatment").collect()).squeeze(),
-         np.array(dset.select("target").collect()).squeeze()]
-        for dset in splits
-    )
-
-
-    return ((trainx, trainy), (valx, valy), (testx, testy), category_dict)
-
 def multi_gpu():
     devices = jax.devices("gpu")
     mesh = keras.distribution.DeviceMesh(
@@ -147,14 +84,7 @@ def main():
     metrics = {key: None for key in keys}
 
     metrics, test_predict = dt.test_decoder(decoder, test_set, metrics)
-    plt.pcolor(test_predict[0], cmap='bwr', clim=(-1, 1))
-    plt.savefig("sample_spectra.png")
-    plt.close()
-    plt.pcolor(test_predict[0] - np.mean(np.array(test_predict[0])[:4], axis=0),
-               cmap='bwr', clim=(-1, 1))
-    plt.savefig("sample_spectrogram.png")
-    plt.close()
-    exit()
+    spectra_plot(test_predict[0], name=None)
 
 
 if __name__ == "__main__":
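
Note on the new helper: the inline matplotlib block removed from main() is replaced by spectra_plot from visualize.plot, which is not included in this patch. The sketch below is only an assumption of what that helper might look like if it simply wraps the removed pcolor/savefig logic; the handling of the name argument and the output file names are guesses inferred from the call site spectra_plot(test_predict[0], name=None) and from the deleted code.

# visualize/plot.py -- hypothetical sketch, not the actual module shipped with this change
import numpy as np
import matplotlib.pyplot as plt


def spectra_plot(spectra, name=None):
    """Save a raw and a baseline-subtracted pcolor image of one predicted spectrogram.

    Mirrors the plotting block removed from decodertest.main(); `name`, when
    given, is assumed to prefix the output file names.
    """
    spectra = np.asarray(spectra)
    prefix = f"{name}_" if name else "sample_"

    # Raw prediction, same colormap and color limits as the removed code.
    plt.pcolor(spectra, cmap="bwr", clim=(-1, 1))
    plt.savefig(f"{prefix}spectra.png")
    plt.close()

    # Subtract the mean of the first four rows as a crude baseline,
    # matching the removed sample_spectrogram.png plot.
    plt.pcolor(spectra - np.mean(spectra[:4], axis=0), cmap="bwr", clim=(-1, 1))
    plt.savefig(f"{prefix}spectrogram.png")
    plt.close()

With a helper of roughly that shape, the new single call in main() reproduces the two PNGs the deleted block wrote, without the hard exit() at the end.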