diff --git a/analysis.py b/analysis.py index f8b7cfa..ec6a808 100644 --- a/analysis.py +++ b/analysis.py @@ -45,7 +45,12 @@ def pca(data, features): if __name__ == "__main__": spark = SparkSession.builder.appName("train").getOrCreate() - data = load(spark, split=[0.9, 0.5, 0.5]) + SPLIT = [0.9, 0.05, 0.05] + load_from_scratch = False + if load_from_scratch: + data = etl(spark, split=SPLIT) + else: + data = read(spark) pca(data) category_distribution(data)