diff --git a/decodertest.py b/decodertest.py index 70188cc..e150970 100644 --- a/decodertest.py +++ b/decodertest.py @@ -29,13 +29,6 @@ import train.decoder_train as dt from visualize.plot import spectra_plot -def build_dict(df, key): - df = df.select(key, f"{key}_str").distinct() - - return df.rdd.map( - lambda row: (str(np.argmax(row[key])), row[f"{key}_str"]) - ).collectAsMap() - def multi_gpu(): devices = jax.devices("gpu") mesh = keras.distribution.DeviceMesh(