Skip to content

Commit

Permalink
Changed test case in models.py
Browse files Browse the repository at this point in the history
  • Loading branch information
Gaurav S Deshmukh committed Sep 26, 2023
1 parent 6f606c2 commit a2cfa7b
Showing 1 changed file with 2 additions and 11 deletions.
13 changes: 2 additions & 11 deletions src/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -374,7 +374,7 @@ def get_embeddings(self, data_objects):
atoms = read(data_root_path / "Pt_3_Rh_9_-7-7-S.cif")
datapoint = AtomsDatapoints(atoms)
datapoint.process_data(
z_cutoffs=[13.0, 20.0],
layer_cutoffs=[3, 6],
node_features=[
["atomic_number", "dband_center"],
["atomic_number", "reactivity"],
Expand All @@ -392,35 +392,26 @@ def get_embeddings(self, data_objects):
partition_configs = [
{
"n_conv": 3,
"n_hidden": 3,
"hidden_size": 30,
"conv_size": 40,
"dropout": 0.1,
"num_node_features": data_objects[0].num_node_features,
"num_edge_features": data_objects[0].num_edge_features,
"conv_type": "CGConv",
},
{
"n_conv": 3,
"n_hidden": 3,
"hidden_size": 30,
"conv_size": 40,
"dropout": 0.1,
"num_node_features": data_objects[1].num_node_features,
"num_edge_features": data_objects[1].num_edge_features,
"conv_type": "CGConv",
},
{
"n_conv": 3,
"n_hidden": 3,
"hidden_size": 30,
"conv_size": 40,
"dropout": 0.1,
"num_node_features": data_objects[2].num_node_features,
"num_edge_features": data_objects[2].num_edge_features,
"conv_type": "CGConv",
},
]
net = MultiGCN(partition_configs)
net = SlabGCN(partition_configs)
result_dict = net(data_objects)
print(result_dict)

0 comments on commit a2cfa7b

Please sign in to comment.