Skip to content

Commit

Permalink
Unnecessary aux functions removed
Browse files Browse the repository at this point in the history
  • Loading branch information
Dawith committed Oct 27, 2025
1 parent 065e8b3 commit 4f443cf
Showing 1 changed file with 2 additions and 72 deletions.
74 changes: 2 additions & 72 deletions decodertest.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,39 +27,7 @@
from model.model import CompoundModel, DecoderModel
from pipe.etl import etl, read
import train.decoder_train as dt

def trim(dataframe, column, shape=(32, 130)):
    """Collect one dataframe column into a 3-D numpy array.

    Parameters
    ----------
    dataframe : object exposing ``select(column).collect()``
        (e.g. a pyspark.sql.DataFrame) holding the spectrogram rows.
    column : str
        Name of the column to extract.
    shape : tuple of int, optional
        Trailing per-sample shape; the leading (sample) axis is inferred.
        Defaults to ``(32, 130)``, the size previously hard-coded here,
        so existing callers are unaffected.

    Returns
    -------
    numpy.ndarray with shape ``(-1, *shape)``.
    """
    ndarray = np.array(dataframe.select(column).collect()) \
        .reshape(-1, *shape)

    return ndarray

def transform(spark, dataframe, keys):
    """One-hot encode the string columns named in ``keys``.

    For each key the raw string column is kept under ``f"{key}_str"`` and
    an ``array<float>`` one-hot column takes over the original name.

    Parameters
    ----------
    spark : SparkSession used to build the frame of encoded columns.
    dataframe : pyspark.sql.DataFrame containing every column in ``keys``.
    keys : iterable of str, the column names to encode.

    Returns
    -------
    pyspark.sql.DataFrame with, per key, the one-hot column ``key`` and
    the original strings under ``key_str`` (plus an ``index`` id column).
    """
    # Row id used later to join the encoded columns back onto the rows.
    dataframe = dataframe.withColumn(
        "index", functions.monotonically_increasing_id()
    )
    # key -> list of dense one-hot vectors (as plain Python lists).
    bundle = {key: [
        arr.tolist()
        for arr in OneHotEncoder(sparse_output=False) \
            .fit_transform(dataframe.select(key).collect())
    ] for key in keys
    }

    # Transpose the dict-of-columns into a list of per-row dicts, the
    # record form expected by createDataFrame below.
    bundle = [dict(zip(bundle.keys(), values))
        for values in zip(*bundle.values())]
    schema = types.StructType([
        types.StructField(key, types.ArrayType(types.FloatType()), True)
        for key in keys
    ])
    # NOTE(review): monotonically_increasing_id() is partition-dependent;
    # this assumes the ids generated for `newframe` line up with those
    # assigned to `dataframe` above -- confirm for multi-partition data.
    newframe = spark.createDataFrame(bundle, schema=schema).withColumn(
        "index", functions.monotonically_increasing_id()
    )
    # Free the original names for the one-hot columns; the join brings
    # the encoded arrays in under those names.
    for key in keys:
        dataframe = dataframe.withColumnRenamed(key, f"{key}_str")
    dataframe = dataframe.join(newframe, on="index", how="inner")

    return dataframe
from visualize.plot import spectra_plot

def build_dict(df, key):
df = df.select(key, f"{key}_str").distinct()
Expand All @@ -68,37 +36,6 @@ def build_dict(df, key):
lambda row: (str(np.argmax(row[key])), row[f"{key}_str"])
).collectAsMap()


def get_data(spark, split=(0.99, 0.005, 0.005)):
    """Load spectrogram data, encode its labels and split it three ways.

    Parameters
    ----------
    spark : SparkSession passed to the pipeline and the transforms.
    split : sequence of three floats, optional
        Relative weights for the train/validation/test random split.
        Default is a tuple rather than a list: a mutable default would
        be shared across calls.

    Returns
    -------
    ((trainx, trainy), (valx, valy), (testx, testy), category_dict)
        Each ``*x`` is an ndarray of spectra, each ``*y`` is a
        ``[treatment, target]`` pair of one-hot arrays, and
        ``category_dict`` maps one-hot argmax indices back to the
        original label strings for both keys.
    """
    path = Path("/app/workdir")

    # First CSV column of train.csv names the samples to load.
    labels = []
    with open(path / "train.csv", "r") as file:
        for line in file:
            labels.append(line.strip().split(",")[0])

    pipe = SpectrogramPipe(spark, filetype="matfiles")
    data = pipe.spectrogram_pipe(path, labels)
    # Normalize treatment synonyms.  The original code built a
    # replace()/distinct() chain but discarded its result, so the
    # renames never took effect; assign the replacement back so they do.
    data = data.replace(
        {"virus": "cpv", "cont": "pbs", "control": "pbs", "dld": "pbs"},
        subset=["treatment"],
    )

    data = transform(spark, data, ["treatment", "target"])
    category_dict = {
        key: build_dict(data, key) for key in ["treatment", "target"]
    }
    # randomSplit expects a list of weights; seed fixed for
    # reproducible splits.
    splits = data.randomSplit(list(split), seed=42)
    trainx, valx, testx = (trim(dset, "spectra") for dset in splits)
    trainy, valy, testy = (
        [np.array(dset.select("treatment").collect()).squeeze(),
         np.array(dset.select("target").collect()).squeeze()]
        for dset in splits
    )

    return ((trainx, trainy), (valx, valy), (testx, testy), category_dict)

def multi_gpu():
devices = jax.devices("gpu")
mesh = keras.distribution.DeviceMesh(
Expand Down Expand Up @@ -147,14 +84,7 @@ def main():
metrics = {key: None for key in keys}
metrics, test_predict = dt.test_decoder(decoder, test_set, metrics)

plt.pcolor(test_predict[0], cmap='bwr', clim=(-1, 1))
plt.savefig("sample_spectra.png")
plt.close()
plt.pcolor(test_predict[0] - np.mean(np.array(test_predict[0])[:4], axis=0),
cmap='bwr', clim=(-1, 1))
plt.savefig("sample_spectrogram.png")
plt.close()
exit()
spectra_plot(test_predict[0], name=None)


if __name__ == "__main__":
Expand Down

0 comments on commit 4f443cf

Please sign in to comment.