From a1836f27da8d9bdf93989161d890c7807d26397e Mon Sep 17 00:00:00 2001 From: Dawith Date: Sat, 7 Dec 2024 21:07:55 -0500 Subject: [PATCH] Pipe rewritten to identify files by tag --- pipe/pipe.py | 27 ++++++++++++--------------- 1 file changed, 12 insertions(+), 15 deletions(-) diff --git a/pipe/pipe.py b/pipe/pipe.py index 1840d81..b65837e 100644 --- a/pipe/pipe.py +++ b/pipe/pipe.py @@ -2,10 +2,12 @@ pipe.py """ +import os from pathlib import Path import typing import cv2 as cv +import h5py import numpy as np from pyspark.sql import SparkSession, Row @@ -26,35 +28,30 @@ def __init__(self, spark: SparkSession, filetype: str = "hdf5"): elif filetype == "shards": self.spectrogram_pipe = self.spectrogram_pipe_shards else: - raise ValueError(s"Invalid filetype {filetype}.") + raise ValueError - def spectrogram_pipe_hdf5(self, specpath: Path, freq_samples: int) - -> np.ndarray: + def spectrogram_pipe_hdf5(self, specpath: Path, labels: list, + namepattern:str="averaged_spectrogram{}.hdf5" + ) -> np.ndarray: """ Loads spectrograms for each stack iteration from an hdf5 data file, and turns it into a spark-friendly format. Args: specpath (Path): Path to the spectrogram files. - namepattern (str): Name pattern for the spectrogram files. - stacksize (int): Number of spectrograms in the stack. - freq_samples (int): Number of frequency samples in each - spectrogram. Returns: """ spectrograms = [] - for filename in os.listdir(specpath): - if not filename.endswith(".hdf5"): - continue + for label in labels: + filename = namepattern.format(label) with h5py.File(specpath/filename, 'r') as f: - spectrograms.append(Row(label=filename, - spectrogram=f['spectrogram'][:])) - - # Turn spectrogram into a spark dataframe. + spectrograms.append( + Row(label=label, + spectrogram=f['spectrogram'][:].tolist())) - return spectrograms + return self.spark.createDataFrame(spectrograms) def spectrogram_pipe_shards(self, specpath: Path, namepattern: str, stacksize: int, freq_samples: int) -> np.ndarray: