diff --git a/pipe/extract.py b/pipe/extract.py index 9f7aa4b..8396d0c 100644 --- a/pipe/extract.py +++ b/pipe/extract.py @@ -27,7 +27,8 @@ def extract(spark: SparkSession) -> DataFrame: """ path = Path("/app/workdir") - labels = [] + labels = [DataKind.BB, DataKind.FPS, DataKind.NSD, DataKind.NCNT, + DataKind.SPEC, DataKind.TARGET, DataKind.TREATMENT] with open(path / "train.csv", "r") as file: for line in file: labels.append(line.strip().split(",")[0])