From 57196b1da0148936c58a0ca1213710b56b31fe4a Mon Sep 17 00:00:00 2001 From: Gaurav S Deshmukh Date: Fri, 22 Sep 2023 21:46:42 -0400 Subject: [PATCH] It works --- src/models.py | 108 ++++++++++++++++++++++++++++++++++++++++---------- 1 file changed, 87 insertions(+), 21 deletions(-) diff --git a/src/models.py b/src/models.py index a7f2e33..56321d2 100644 --- a/src/models.py +++ b/src/models.py @@ -4,7 +4,7 @@ import torch.nn as nn import torch_geometric.nn as gnn -class MultiGCN(gnn.MessagePassing): +class MultiGCN(nn.Module): """Class to customize the graph neural network.""" def __init__(self, partition_configs): """Initialize the graph neural network. @@ -24,6 +24,8 @@ def __init__(self, partition_configs): num_node_features (number of node features, int), num_edge_features (number of edge features, int). """ + super().__init__() + # Store hyperparameters self.n_conv = [config["n_conv"] for config in partition_configs] self.n_hidden = [config["n_hidden"] for config in partition_configs] @@ -35,7 +37,7 @@ def __init__(self, partition_configs): config["num_node_features"] for config in partition_configs ] self.num_edge_features = [ - config["num_node_features"] for config in partition_configs + config["num_edge_features"] for config in partition_configs ] self.n_partitions = len(partition_configs) @@ -44,7 +46,7 @@ def __init__(self, partition_configs): self.init_transform = [] for i in range(self.n_partitions): self.init_transform.append( - nn.ModuleList( + nn.Sequential( nn.Linear(self.num_node_features[i], self.conv_size[i]), nn.LeakyReLU(inplace=True), ) @@ -56,13 +58,13 @@ def __init__(self, partition_configs): # Pooling layers self.pool_layers = [] for i in range(self.n_partitions): - self.pool_layers.append(gnn.pool.global_addpool()) + self.pool_layers.append(gnn.pool.global_add_pool) # Pool transform self.pool_transform = [] for i in range(self.n_partitions): self.pool_transform.append( - nn.ModuleList( + nn.Sequential( nn.Linear(self.conv_size[i], self.hidden_size[i]), nn.LeakyReLU(inplace=True), ) @@ -72,16 +74,16 @@ def __init__(self, partition_configs): self.hidden_layers = [] for i in range(self.n_partitions): self.hidden_layers.append( - nn.ModuleList([ + nn.Sequential(*([ nn.Linear(self.hidden_size[i], self.hidden_size[i]), nn.LeakyReLU(inplace=True), - nn.Dropout(p=self.dropout), - ] * (self.hidden_layers - 1) + + nn.Dropout(p=self.dropout[i]), + ] * (self.n_hidden[i] - 1) + [ nn.Linear(self.hidden_size[i], 1), nn.LeakyReLU(inplace=True), - nn.Dropout(p=self.dropout), - ] + nn.Dropout(p=self.dropout[i]), + ]) ) ) @@ -95,16 +97,16 @@ def init_conv_layers(self): self.conv_layers = [] for i in range(self.n_partitions): part_conv_layers = [] - for j in range(self.n_conv): + for j in range(self.n_conv[i]): conv_layer = [ gnn.CGConv( - channels=self.num_node_features[i], + channels=self.conv_size[i], dim=self.num_edge_features[i], - batch_norm=True + batch_norm=True, ), nn.LeakyReLU(inplace=True) ] - part_conv_layers.append(conv_layer) + part_conv_layers.extend(conv_layer) self.conv_layers.append(nn.ModuleList(part_conv_layers)) @@ -127,27 +129,91 @@ def forward(self, data_objects): # For each data object for i, data in enumerate(data_objects): # Apply initial transform - conv_data = self.init_transform[i](data) + conv_data = self.init_transform[i](data.x.to(torch.float32)) # Apply convolutional layers for layer in self.conv_layers[i]: - conv_data = layer(conv_data) + if isinstance(layer, gnn.MessagePassing): + conv_data = layer(x=conv_data, edge_index=data.edge_index, + edge_attr=data.edge_attr) + else: + conv_data = layer(conv_data) # Apply pooling layer - pooled_data = self.pool_layers[i](conv_data) + pooled_data = self.pool_layers[i](x=conv_data, batch=None) # Apply pool-to-hidden transform hidden_data = self.pool_transform[i](pooled_data) # Apply hidden layers - for layer in self.hidden_layers[i]: - hidden_data = layer(hidden_data) + hidden_data = self.hidden_layers[i](hidden_data) # Save contribution contributions.append(hidden_data) # Apply final transformation - output = self.final_lin_transform(*contributions) + contributions = torch.cat(contributions, dim=-1) + output = self.final_lin_transform(contributions) return {"output": output, "contributions": contributions} - \ No newline at end of file + +if __name__ == "__main__": + from ase.io import read + from data import AtomsDatapoints + from constants import REPO_PATH + from pathlib import Path + # Test for one tensor + # Create datapoins + data_root_path = Path(REPO_PATH) / "data" / "S_calcs" + atoms = read(data_root_path / "Pt_3_Rh_9_-7-7-S.cif") + datapoint = AtomsDatapoints(atoms) + datapoint.process_data( + z_cutoffs=[13.0, 20.0], + node_features=[ + ["atomic_number", "dband_center"], + ["atomic_number", "reactivity"], + ["atomic_number", "reactivity"], + ], + edge_features=[ + ["bulk_bond_distance"], + ["surface_bond_distance"], + ["adsorbate_bond_distance"], + ], + ) + data_objects = datapoint.get(0) + + # Get result + partition_configs = [ + { + "n_conv": 3, + "n_hidden": 3, + "hidden_size": 30, + "conv_size": 40, + "dropout": 0.1, + "num_node_features": data_objects[0].num_node_features, + "num_edge_features": data_objects[0].num_edge_features, + "conv_type": "CGConv", + }, + { + "n_conv": 3, + "n_hidden": 3, + "hidden_size": 30, + "conv_size": 40, + "dropout": 0.1, + "num_node_features": data_objects[1].num_node_features, + "num_edge_features": data_objects[1].num_edge_features, + "conv_type": "CGConv", + }, + { + "n_conv": 3, + "n_hidden": 3, + "hidden_size": 30, + "conv_size": 40, + "dropout": 0.1, + "num_node_features": data_objects[2].num_node_features, + "num_edge_features": data_objects[2].num_edge_features, + "conv_type": "CGConv", + } + ] + net = MultiGCN(partition_configs) + result_dict = net(data_objects) \ No newline at end of file