def metadata_pipe(self, metapath: Path, labels: list,
                  namepattern: str = "metadata{}.json") -> dict:
    """
    Load the metadata of each target label from a set of JSON files and
    return them as a hierarchical dictionary.

    One file is read per label; its name is obtained by formatting
    ``namepattern`` with the label and resolving it against ``metapath``.

    Args:
        metapath (Path): Directory containing the metadata files.
        labels (list): List of target labels.
        namepattern (str): Name pattern for the metadata files; ``{}`` is
            replaced by the label.

    Returns:
        dict: Mapping of each label to the decoded content of its JSON
        file. Empty when ``labels`` is empty.

    Raises:
        FileNotFoundError: If the metadata file of a label does not exist.
        json.JSONDecodeError: If a metadata file is not valid JSON.
    """
    # Imported locally so the method does not depend on a module-level
    # ``import json`` that this change does not add itself.
    import json

    metadata = {}
    for label in labels:
        # JSON text is UTF-8; do not rely on the platform default encoding.
        with open(metapath / namepattern.format(label), "r",
                  encoding="utf-8") as f:
            metadata[label] = json.load(f)

    return metadata


def spectrogram_pipe_hdf5(self, specpath: Path, labels: list,
                          namepattern: str = "averaged_spectrogram{}.hdf5"
                          ) -> np.ndarray:
    """
    Load one spectrogram per label from HDF5 files, combine each with the
    metadata of its label and return the result as a Spark DataFrame.

    The metadata is read through ``metadata_pipe`` from the same directory
    as the spectrograms, using that method's default file name pattern.

    Args:
        specpath (Path): Directory containing the HDF5 spectrogram files
            (and the JSON metadata files).
        labels (list): List of target labels.
        namepattern (str): Name pattern for the spectrogram files; ``{}``
            is replaced by the label.

    Returns:
        The object returned by ``self.spark.createDataFrame`` with one row
        per label. Each row holds ``label``, the metadata fields of that
        label and ``spectrogram`` (the full ``'spectrogram'`` dataset as
        nested lists).
        NOTE(review): the ``np.ndarray`` return annotation does not match
        the returned DataFrame; it is kept unchanged for interface
        stability.
    """
    metadata = self.metadata_pipe(specpath, labels)

    spectrograms = []
    for label in labels:
        filename = namepattern.format(label)
        with h5py.File(specpath / filename, 'r') as f:
            spectrogram = f['spectrogram'][:].tolist()
        # Build a fresh dict instead of mutating the loaded metadata.
        # ``label`` is seeded first so the column survives even when the
        # metadata file has no "label" entry; a "label" entry in the
        # metadata still takes precedence.
        # Assumes the top-level JSON value of each metadata file is an
        # object (dict) -- TODO confirm against the metadata writer.
        fields = {"label": label, **metadata[label],
                  "spectrogram": spectrogram}
        spectrograms.append(Row(**fields))

    return self.spark.createDataFrame(spectrograms)