diff --git a/src/models.py b/src/models.py index f814d6d..e52025a 100644 --- a/src/models.py +++ b/src/models.py @@ -374,7 +374,7 @@ def get_embeddings(self, data_objects): atoms = read(data_root_path / "Pt_3_Rh_9_-7-7-S.cif") datapoint = AtomsDatapoints(atoms) datapoint.process_data( - z_cutoffs=[13.0, 20.0], + layer_cutoffs=[3, 6], node_features=[ ["atomic_number", "dband_center"], ["atomic_number", "reactivity"], @@ -392,35 +392,26 @@ def get_embeddings(self, data_objects): partition_configs = [ { "n_conv": 3, - "n_hidden": 3, - "hidden_size": 30, "conv_size": 40, - "dropout": 0.1, "num_node_features": data_objects[0].num_node_features, "num_edge_features": data_objects[0].num_edge_features, "conv_type": "CGConv", }, { "n_conv": 3, - "n_hidden": 3, - "hidden_size": 30, "conv_size": 40, - "dropout": 0.1, "num_node_features": data_objects[1].num_node_features, "num_edge_features": data_objects[1].num_edge_features, "conv_type": "CGConv", }, { "n_conv": 3, - "n_hidden": 3, - "hidden_size": 30, "conv_size": 40, - "dropout": 0.1, "num_node_features": data_objects[2].num_node_features, "num_edge_features": data_objects[2].num_edge_features, "conv_type": "CGConv", }, ] - net = MultiGCN(partition_configs) + net = SlabGCN(partition_configs) result_dict = net(data_objects) print(result_dict)