diff --git a/pipe/enumsets.py b/pipe/enumsets.py
new file mode 100644
index 0000000..c699968
--- /dev/null
+++ b/pipe/enumsets.py
@@ -0,0 +1,23 @@
+# -*- coding: utf-8 -*-
+"""
+Enumerations of the data types read from onekey output files.
+"""
+
+from enum import Enum, IntEnum
+
+class FileType(IntEnum):
+    HDF5 = 0
+    MAT = 1
+    SHARD = 2
+
+class DataKind(Enum):
+    BB = {"Full_Name": "Backscatter Brightness"}
+    FPS = {"Full_Name": "Framerate"}
+    NCNT = {"Full_Name": "Foreground Pixel Count"}
+    NSD = {"Full_Name": "Normalized Standard Deviation"}
+    SGRAM = {"Full_Name": "Spectrogram"}
+    SPEC = {"Full_Name": "Spectra"}
+    TREAT = {"Full_Name": "Treatment"}
+    TARGET = {"Full_Name": "Target"}
+
+# EOF
diff --git a/pipe/extract.py b/pipe/extract.py
index 849246e..9f7aa4b 100644
--- a/pipe/extract.py
+++ b/pipe/extract.py
@@ -7,6 +7,7 @@ import os
 from pathlib import Path
 import typing
+from typing import List
 
 import cv2 as cv
 import h5py
@@ -14,6 +15,8 @@ from pyspark.sql import SparkSession, Row, DataFrame
 import scipy as sp
 
+from pipe.enumsets import FileType, DataKind
+
 
 def extract(spark: SparkSession) -> DataFrame:
     """
     First step of the ETL pipeline. It reads the list of .mat files from
@@ -53,7 +56,7 @@ def image_pipe(spark: SparkSession, imagepath: Path, namepattern: str,
     return images
 
 
-class SpectrogramReader:
+class FileReader:
     """
     Class to read spectrograms and metadata from different file formats
     based on user specified filetype.
@@ -64,16 +67,17 @@
         'shards', and 'matfiles'.
     """
 
-    def __init__(self, spark: SparkSession, filetype: str = "hdf5"):
+    def __init__(self, spark: SparkSession, filetype: FileType):
         self.spark = spark
-        if filetype == "hdf5":
-            self.spectrogram_read = self.spectrogram_read_hdf5
-        elif filetype == "shards":
-            self.spectrogram_read = self.spectrogram_read_shards
-        elif filetype == "matfiles":
-            self.spectrogram_read = self.spectrogram_read_matfiles
-        else:
-            raise ValueError
+        match filetype:
+            case FileType.HDF5:
+                self.spectrogram_read = self.read_hdf5
+            case FileType.SHARD:
+                self.spectrogram_read = self.read_shards
+            case FileType.MAT:
+                self.spectrogram_read = self.read_matfiles
+            case _:
+                raise ValueError(f"Expected a FileType, got {filetype!r}")
 
     def metadata_read(self, metapath: Path, labels:list,
                       namepattern: str="metadata{}.json") -> dict:
@@ -97,52 +101,63 @@ def metadata_read(self, metapath: Path, labels:list,
         return metadata
 
-    def spectrogram_read_matfiles(self, specpath: Path, labels:list,
-                                  default_size: tuple = (32, 130),
-                                  pad_value: float = 0.) \
-                                  -> DataFrame:
+    def read_matfiles(self, specpath: Path, labels: List[str],
+                      datakinds: List[DataKind],
+                      default_size: tuple = (32, 130),
+                      pad_value: float = 0.) -> DataFrame:
         """
-        Loads spectrograms for each stack iteration from a set of mat files,
+        Loads data for each stack iteration from a set of mat files,
         and turns it into a spark-friendly format.
 
         Args:
-            labels (list): List of target labels.
+            specpath (Path): Path to the spectrogram files.
+            labels (List[str]): List of target labels.
+            datakinds (List[DataKind]): List of data kinds to extract.
+            default_size (tuple): Default size for the spectrograms.
+            pad_value (float): Value to use for padding.
 
         Returns:
-            DataFrame: Spark DataFrame containing the spectrograms and
-                       associated metadata.
+            DataFrame: Spark DataFrame containing the requested data.
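+
+        Example (illustrative call; paths and label names are hypothetical)::
+
+            reader = FileReader(spark, FileType.MAT)
+            df = reader.read_matfiles(Path("data"), ["stack0.mat"],
+                                      [DataKind.SPEC, DataKind.TARGET])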
""" - spectrograms = [] + data = [] row = {} + for label in labels: matdata = sp.io.loadmat(specpath/"matfiles"/label) - row["treatment"] = matdata["header"][0][0][4][0].lower() - try: + if DataKind.TREATMENT in datakinds: + row["treatment"] = matdata["header"][0][0][4][0].lower() + if DataKind.TARGET in datakinds: row["target"] = matdata["header"][0][0][2][0].lower() - except IndexError: - row["target"] = "unknown" - row["label"] = label - spectrogram = np.array(matdata["SP"][0]) - if len(spectrogram.shape) == 3: - spectrogram = spectrogram[0] - if spectrogram.shape[0] > default_size[0]: - spectrogram = spectrogram[:default_size[0], :] - spectrogram = np.pad( - spectrogram, - ((default_size[0] - spectrogram.shape[0], 0), - (default_size[1] - spectrogram.shape[1], 0)), - mode="constant", constant_values=pad_value) - spectrogram[np.isnan(spectrogram)] = 0. - spectrogram[np.abs(spectrogram) == np.inf] = 0. - spectrogram = spectrogram / np.sum(spectrogram) - row["spectrogram"] = spectrogram.tolist() - spectrograms.append(Row(**row)) + if DataKind.FPS in datakinds: + row["fps"] = 2*float(matdata["header"][0][0][15][0]) + if DataKind.BB in datakinds: + row["bb"] = matdata["bb"] + if DataKind.NSD in datakinds: + row["nsd"] = matdata["nsd"] + if DataKind.NCNT in datakinds: + row["ncnt"] = matdata["ncnt"] + if DataKind.SPEC in datakinds: + spectra = np.array(matdata["SP"][0]) + if len(spectra.shape) == 3: + spectra = spectrogram[0] + if spectra.shape[0] > default_size[0]: + spectra = spectra[:default_size[0], :] + spectra = np.pad( + spectra, + ((default_size[0] - spectra.shape[0], 0), + (default_size[1] - spectra.shape[1], 0)), + mode="constant", constant_values=pad_value) + spectra[np.isnan(spectra)] = 0. + spectra[np.abs(spectra) == np.inf] = 0. + spectra = spectra / np.sum(spectra) + row["spectra"] = spectra.tolist() + data.append(Row(**row)) return self.spark.createDataFrame(spectrograms) - def spectrogram_read_hdf5(self, specpath: Path, labels: list, - namepattern:str="averaged_spectrogram{}.hdf5") \ - -> DataFrame: + def read_hdf5(self, specpath: Path, labels: list, + namepattern:str="averaged_spectrogram{}.hdf5") \ + -> DataFrame: """ Loads spectrograms for each stack iteration from an hdf5 data file, and turns it into a spark-friendly format. @@ -167,9 +182,9 @@ def spectrogram_read_hdf5(self, specpath: Path, labels: list, return self.spark.createDataFrame(spectrograms) - def spectrogram_read_shards(self, specpath: Path, namepattern: str, - stacksize: int, freq_samples: int) \ - -> DataFrame: + def read_shards(self, specpath: Path, namepattern: str, + stacksize: int, freq_samples: int) \ + -> DataFrame: """ Loads spectrograms for each stack iteration from a set of shard files, and turns it into a spark-friendly format. diff --git a/train.py b/train.py index 9474007..ddf37f1 100644 --- a/train.py +++ b/train.py @@ -198,7 +198,7 @@ def main(): categories[key].values()) plt.gcf().set_size_inches(len(categories[key])/10+4, len(categories[key])/10+3) - plt.savefig(f"/app/workdir/confusion_matrix_{key}.png", + plt.savefig(f"/app/workdir/figures/confusion_matrix_{key}.png", bbox_inches="tight") plt.close() with open(f"confusion_matrix_{key}.json", 'w') as f: