Data type reorganized as enum and extract segment expanded to load other data kinds
lim185 committed Oct 1, 2025
1 parent b8dcff2 commit c751d57
Showing 3 changed files with 84 additions and 46 deletions.
23 changes: 23 additions & 0 deletions pipe/enumsets.py
@@ -0,0 +1,23 @@
# -*- coding: utf-8 -*-
"""
Enumerations of file types and data kinds loaded from onekey output files.
"""

from enum import Enum, IntEnum

class FileType(IntEnum):
    HDF5 = 0
    MAT = 1
    SHARD = 2

class DataKind(Enum):
    BB = {"Full_Name": "Backscatter Brightness"}
    FPS = {"Full_Name": "Framerate"}
    NCNT = {"Full_Name": "Foreground Pixel Count"}
    NSD = {"Full_Name": "Normalized Standard Deviation"}
    SGRAM = {"Full_Name": "Spectrogram"}
    SPEC = {"Full_Name": "Spectra"}
    TREAT = {"Full_Name": "Treatment"}
    TARGET = {"Full_Name": "Target"}

# EOF
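
For orientation, here is a minimal usage sketch of the two enums (hypothetical, not part of this commit): FileType members behave like plain ints via IntEnum, and each DataKind value carries a display name under the "Full_Name" key.

# Hypothetical usage, not part of this commit.
from pipe.enumsets import DataKind, FileType

wanted = [DataKind.SPEC, DataKind.FPS, DataKind.TARGET]

filetype = FileType(1)           # IntEnum: integer config values map to members
assert filetype is FileType.MAT

for kind in wanted:              # enum values are dicts holding the display name
    print(kind.name, "->", kind.value["Full_Name"])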
105 changes: 60 additions & 45 deletions pipe/extract.py
@@ -7,13 +7,16 @@
import os
from pathlib import Path
import typing
from typing import List

import cv2 as cv
import h5py
import numpy as np
from pyspark.sql import SparkSession, Row, DataFrame
import scipy as sp

from pipe.enumsets import FileType, DataKind

def extract(spark: SparkSession) -> DataFrame:
"""
First step of the ETL pipeline. It reads the list of .mat files from
@@ -53,7 +56,7 @@ def image_pipe(spark: SparkSession, imagepath: Path, namepattern: str,

    return images

-class SpectrogramReader:
+class FileReader:
    """
    Class to read spectrograms and metadata from different file formats based
    on user-specified filetype.
@@ -64,16 +67,17 @@ class SpectrogramReader:
    'shards', and 'matfiles'.
    """

-    def __init__(self, spark: SparkSession, filetype: str = "hdf5"):
+    def __init__(self, spark: SparkSession, filetype: FileType):
        self.spark = spark
-        if filetype == "hdf5":
-            self.spectrogram_read = self.spectrogram_read_hdf5
-        elif filetype == "shards":
-            self.spectrogram_read = self.spectrogram_read_shards
-        elif filetype == "matfiles":
-            self.spectrogram_read = self.spectrogram_read_matfiles
-        else:
-            raise ValueError
+        match filetype:
+            case FileType.HDF5:
+                self.spectrogram_read = self.read_hdf5
+            case FileType.SHARD:
+                self.spectrogram_read = self.read_shards
+            case FileType.MAT:
+                self.spectrogram_read = self.read_matfiles
+            case _:
+                raise ValueError(f"Expected a FileType, got {filetype!r}")
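
The match block turns the constructor into a small dispatch table keyed on FileType. A hedged sketch of a caller (assumed, not part of this diff; the SparkSession setup is illustrative only):

# Hypothetical caller, not part of this commit; assumes a working Spark install.
from pyspark.sql import SparkSession

from pipe.enumsets import FileType
from pipe.extract import FileReader

spark = SparkSession.builder.appName("extract-demo").getOrCreate()
reader = FileReader(spark, FileType.MAT)   # binds reader.spectrogram_read to read_matfiles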

    def metadata_read(self, metapath: Path, labels: list,
                      namepattern: str = "metadata{}.json") -> dict:
@@ -97,52 +101,63 @@ def metadata_read(self, metapath: Path, labels:list,

        return metadata

-    def spectrogram_read_matfiles(self, specpath: Path, labels:list,
-                                  default_size: tuple = (32, 130),
-                                  pad_value: float = 0.) \
-            -> DataFrame:
+    def read_matfiles(self, specpath: Path, labels: List[str],
+                      datakinds: List[DataKind],
+                      default_size: tuple = (32, 130),
+                      pad_value: float = 0.) -> DataFrame:
"""
Loads spectrograms for each stack iteration from a set of mat files,
Loads data for each stack iteration from a set of mat files,
and turns it into a spark-friendly format.
Args:
labels (list): List of target labels.
labels (List[str]): List of target labels.
specpath (Path): Path to the spectrogram files.
default_size (tuple): Default size for the spectrograms.
pad_value (float): Value to use for padding.
datakinds (List[DataKind]): List of data kinds to extract.
Returns:
DataFrame: Spark DataFrame containing the spectrograms and
associated metadata.
DataFrame: Spark DataFrame containing the requested data.
"""
-        spectrograms = []
+        data = []
        row = {}

        for label in labels:
            matdata = sp.io.loadmat(specpath/"matfiles"/label)
-            row["treatment"] = matdata["header"][0][0][4][0].lower()
            try:
+                if DataKind.TREAT in datakinds:
+                    row["treatment"] = matdata["header"][0][0][4][0].lower()
+                if DataKind.TARGET in datakinds:
+                    row["target"] = matdata["header"][0][0][2][0].lower()
            except IndexError:
                row["target"] = "unknown"
            row["label"] = label
-            spectrogram = np.array(matdata["SP"][0])
-            if len(spectrogram.shape) == 3:
-                spectrogram = spectrogram[0]
-            if spectrogram.shape[0] > default_size[0]:
-                spectrogram = spectrogram[:default_size[0], :]
-            spectrogram = np.pad(
-                spectrogram,
-                ((default_size[0] - spectrogram.shape[0], 0),
-                 (default_size[1] - spectrogram.shape[1], 0)),
-                mode="constant", constant_values=pad_value)
-            spectrogram[np.isnan(spectrogram)] = 0.
-            spectrogram[np.abs(spectrogram) == np.inf] = 0.
-            spectrogram = spectrogram / np.sum(spectrogram)
-            row["spectrogram"] = spectrogram.tolist()
-            spectrograms.append(Row(**row))
+            if DataKind.FPS in datakinds:
+                row["fps"] = 2*float(matdata["header"][0][0][15][0])
+            if DataKind.BB in datakinds:
+                row["bb"] = matdata["bb"]
+            if DataKind.NSD in datakinds:
+                row["nsd"] = matdata["nsd"]
+            if DataKind.NCNT in datakinds:
+                row["ncnt"] = matdata["ncnt"]
+            if DataKind.SPEC in datakinds:
+                spectra = np.array(matdata["SP"][0])
+                if len(spectra.shape) == 3:
+                    spectra = spectra[0]
+                if spectra.shape[0] > default_size[0]:
+                    spectra = spectra[:default_size[0], :]
+                spectra = np.pad(
+                    spectra,
+                    ((default_size[0] - spectra.shape[0], 0),
+                     (default_size[1] - spectra.shape[1], 0)),
+                    mode="constant", constant_values=pad_value)
+                spectra[np.isnan(spectra)] = 0.
+                spectra[np.abs(spectra) == np.inf] = 0.
+                spectra = spectra / np.sum(spectra)
+                row["spectra"] = spectra.tolist()
+            data.append(Row(**row))

        return self.spark.createDataFrame(data)
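
The DataKind.SPEC branch pads each array up to default_size along the leading edges, zeroes out NaN and infinite entries, and normalizes to unit sum. A self-contained numpy illustration of that transform, using a made-up input array:

import numpy as np

default_size, pad_value = (32, 130), 0.
spectra = np.full((20, 100), 2.0)              # made-up 20x100 input
spectra = np.pad(
    spectra,
    ((default_size[0] - spectra.shape[0], 0),  # rows inserted before the data
     (default_size[1] - spectra.shape[1], 0)), # columns inserted before the data
    mode="constant", constant_values=pad_value)
spectra[np.isnan(spectra)] = 0.
spectra[np.abs(spectra) == np.inf] = 0.
spectra = spectra / np.sum(spectra)
print(spectra.shape, spectra.sum())            # (32, 130) 1.0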

-    def spectrogram_read_hdf5(self, specpath: Path, labels: list,
-                              namepattern: str = "averaged_spectrogram{}.hdf5") \
-            -> DataFrame:
+    def read_hdf5(self, specpath: Path, labels: list,
+                  namepattern: str = "averaged_spectrogram{}.hdf5") \
+            -> DataFrame:
"""
Loads spectrograms for each stack iteration from an hdf5 data file,
and turns it into a spark-friendly format.
Expand All @@ -167,9 +182,9 @@ def spectrogram_read_hdf5(self, specpath: Path, labels: list,

        return self.spark.createDataFrame(spectrograms)

-    def spectrogram_read_shards(self, specpath: Path, namepattern: str,
-                                stacksize: int, freq_samples: int) \
-            -> DataFrame:
+    def read_shards(self, specpath: Path, namepattern: str,
+                    stacksize: int, freq_samples: int) \
+            -> DataFrame:
"""
Loads spectrograms for each stack iteration from a set of shard files,
and turns it into a spark-friendly format.
2 changes: 1 addition & 1 deletion train.py
@@ -198,7 +198,7 @@ def main():
categories[key].values())
plt.gcf().set_size_inches(len(categories[key])/10+4,
len(categories[key])/10+3)
plt.savefig(f"/app/workdir/confusion_matrix_{key}.png",
plt.savefig(f"/app/workdir/figures/confusion_matrix_{key}.png",
bbox_inches="tight")
plt.close()
with open(f"confusion_matrix_{key}.json", 'w') as f:
Expand Down
