diff --git a/train.py b/train.py index c82e0ef..6315016 100644 --- a/train.py +++ b/train.py @@ -15,7 +15,7 @@ def main(): labels = [] with open(path / "train.csv", "r") as file: for line in file: - labels.append(line.strip().split(",")) + labels.append(line.strip().split(",")[0]) spark = SparkSession.builder.appName("train").getOrCreate() pipe = SpectrogramPipe(spark)