From a2cfa7b3ae1e25bbac1088cf2412c7a3a6b64236 Mon Sep 17 00:00:00 2001 From: Gaurav S Deshmukh Date: Tue, 26 Sep 2023 14:32:40 -0400 Subject: [PATCH] Changed test case in models.py --- src/models.py | 13 ++----------- 1 file changed, 2 insertions(+), 11 deletions(-) diff --git a/src/models.py b/src/models.py index f814d6d..e52025a 100644 --- a/src/models.py +++ b/src/models.py @@ -374,7 +374,7 @@ def get_embeddings(self, data_objects): atoms = read(data_root_path / "Pt_3_Rh_9_-7-7-S.cif") datapoint = AtomsDatapoints(atoms) datapoint.process_data( - z_cutoffs=[13.0, 20.0], + layer_cutoffs=[3, 6], node_features=[ ["atomic_number", "dband_center"], ["atomic_number", "reactivity"], @@ -392,35 +392,26 @@ def get_embeddings(self, data_objects): partition_configs = [ { "n_conv": 3, - "n_hidden": 3, - "hidden_size": 30, "conv_size": 40, - "dropout": 0.1, "num_node_features": data_objects[0].num_node_features, "num_edge_features": data_objects[0].num_edge_features, "conv_type": "CGConv", }, { "n_conv": 3, - "n_hidden": 3, - "hidden_size": 30, "conv_size": 40, - "dropout": 0.1, "num_node_features": data_objects[1].num_node_features, "num_edge_features": data_objects[1].num_edge_features, "conv_type": "CGConv", }, { "n_conv": 3, - "n_hidden": 3, - "hidden_size": 30, "conv_size": 40, - "dropout": 0.1, "num_node_features": data_objects[2].num_node_features, "num_edge_features": data_objects[2].num_edge_features, "conv_type": "CGConv", }, ] - net = MultiGCN(partition_configs) + net = SlabGCN(partition_configs) result_dict = net(data_objects) print(result_dict)