diff --git a/pipe/transform.py b/pipe/transform.py index 8bcb616..a66461b 100644 --- a/pipe/transform.py +++ b/pipe/transform.py @@ -80,22 +80,6 @@ def onehot(dataframe: DataFrame, keys: list) -> DataFrame: result = result.withColumnRenamed(column_name, f"{column_name}_str") result = result.withColumnRenamed(f"{column_name}_encoded", column_name) - """ - bundle = {key: [ - arr.tolist() - for arr in OneHotEncoder(sparse_output=False) \ - .fit_transform(dataframe.select(key).collect()) - ] for key in keys - } - - bundle = [dict(zip(bundle.keys(), values)) - for values in zip(*bundle.values())] - schema = types.StructType([ - types.StructField(key, types.ArrayType(types.FloatType()), True) - for key in keys - ]) - - return bundle, schema""" return result def transform(spark: SparkSession, dataframe: DataFrame, keys: list) \ @@ -107,16 +91,6 @@ def transform(spark: SparkSession, dataframe: DataFrame, keys: list) \ "index", functions.monotonically_increasing_id() ) - """ - bundle, schema = onehot(dataframe, keys) - newframe = spark.createDataFrame(bundle, schema=schema).withColumn( - "index", functions.monotonically_increasing_id() - ) - - for key in keys: - dataframe = dataframe.withColumnRenamed(key, f"{key}_str") - dataframe = dataframe.join(newframe, on="index", how="inner") - """ dataframe = onehot(dataframe, keys) return dataframe