
Commit

It works
Gaurav S Deshmukh committed Sep 23, 2023
1 parent 39ab2ca commit 57196b1
Showing 1 changed file with 87 additions and 21 deletions.
108 changes: 87 additions & 21 deletions src/models.py
@@ -4,7 +4,7 @@
 import torch.nn as nn
 import torch_geometric.nn as gnn
 
-class MultiGCN(gnn.MessagePassing):
+class MultiGCN(nn.Module):
     """Class to customize the graph neural network."""
     def __init__(self, partition_configs):
         """Initialize the graph neural network.
@@ -24,6 +24,8 @@ def __init__(self, partition_configs):
             num_node_features (number of node features, int), num_edge_features
             (number of edge features, int).
         """
+        super().__init__()
+
         # Store hyperparameters
         self.n_conv = [config["n_conv"] for config in partition_configs]
         self.n_hidden = [config["n_hidden"] for config in partition_configs]
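
A note on the two changes above: subclassing nn.Module (rather than gnn.MessagePassing, which the class did not need, since the CGConv layers handle message passing) requires the added super().__init__() call, because nn.Module.__setattr__ depends on registries created in Module.__init__. Assigning a submodule before that call raises AttributeError. A minimal standalone sketch:

import torch.nn as nn

class Bad(nn.Module):
    def __init__(self):
        # AttributeError: cannot assign module before Module.__init__() call
        self.lin = nn.Linear(2, 2)

class Good(nn.Module):
    def __init__(self):
        super().__init__()          # creates the parameter/module registries first
        self.lin = nn.Linear(2, 2)  # registered as a submodule correctly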
@@ -35,7 +37,7 @@ def __init__(self, partition_configs):
             config["num_node_features"] for config in partition_configs
         ]
         self.num_edge_features = [
-            config["num_node_features"] for config in partition_configs
+            config["num_edge_features"] for config in partition_configs
         ]
         self.n_partitions = len(partition_configs)
 
@@ -44,7 +46,7 @@ def __init__(self, partition_configs):
         self.init_transform = []
         for i in range(self.n_partitions):
             self.init_transform.append(
-                nn.ModuleList(
+                nn.Sequential(
                     nn.Linear(self.num_node_features[i], self.conv_size[i]),
                     nn.LeakyReLU(inplace=True),
                 )
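
The nn.ModuleList → nn.Sequential swap here (and in the pool and hidden transforms below) is the core fix in this commit: nn.ModuleList registers submodules but defines no forward(), so calling it fails, while nn.Sequential chains its children's forward calls. A minimal sketch:

import torch
import torch.nn as nn

x = torch.randn(2, 4)

container = nn.ModuleList([nn.Linear(4, 8), nn.LeakyReLU(inplace=True)])
# container(x) raises NotImplementedError: ModuleList has no forward()

block = nn.Sequential(nn.Linear(4, 8), nn.LeakyReLU(inplace=True))
y = block(x)  # works: (2, 4) -> (2, 8)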
@@ -56,13 +58,13 @@ def __init__(self, partition_configs):
         # Pooling layers
         self.pool_layers = []
         for i in range(self.n_partitions):
-            self.pool_layers.append(gnn.pool.global_addpool())
+            self.pool_layers.append(gnn.pool.global_add_pool)
 
         # Pool transform
         self.pool_transform = []
         for i in range(self.n_partitions):
             self.pool_transform.append(
-                nn.ModuleList(
+                nn.Sequential(
                     nn.Linear(self.conv_size[i], self.hidden_size[i]),
                     nn.LeakyReLU(inplace=True),
                 )
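
Two fixes in this hunk: the misspelled global_addpool becomes global_add_pool, and the function itself is stored rather than the result of calling it at construction time, since PyG's global pooling operators are plain functions applied in forward(). In recent PyTorch Geometric versions, batch=None treats all nodes as one graph, matching the batch=None call in the new forward() below:

import torch
from torch_geometric.nn.pool import global_add_pool

x = torch.ones(5, 3)                  # 5 nodes, 3 features, a single graph
out = global_add_pool(x, batch=None)  # sums over nodes -> shape (1, 3), all entries 5.0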
@@ -72,16 +74,16 @@ def __init__(self, partition_configs):
         self.hidden_layers = []
         for i in range(self.n_partitions):
             self.hidden_layers.append(
-                nn.ModuleList([
+                nn.Sequential(*([
                     nn.Linear(self.hidden_size[i], self.hidden_size[i]),
                     nn.LeakyReLU(inplace=True),
-                    nn.Dropout(p=self.dropout),
-                ] * (self.hidden_layers - 1) +
+                    nn.Dropout(p=self.dropout[i]),
+                ] * (self.n_hidden[i] - 1) +
                 [
                     nn.Linear(self.hidden_size[i], 1),
                     nn.LeakyReLU(inplace=True),
-                    nn.Dropout(p=self.dropout),
-                ]
+                    nn.Dropout(p=self.dropout[i]),
+                ])
                 )
             )

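One subtlety in the pattern above: Python's list multiplication repeats references, so [Linear, LeakyReLU, Dropout] * (self.n_hidden[i] - 1) reuses the same nn.Linear instance in every repetition, and those hidden blocks share weights. If independent layers are intended, fresh instances must be constructed per repetition; a sketch with hypothetical sizes:

import torch.nn as nn

shared = nn.Sequential(*([nn.Linear(8, 8), nn.LeakyReLU(inplace=True)] * 2))
print(shared[0] is shared[2])  # True: one Linear object, applied twice

fresh = nn.Sequential(
    *[m for _ in range(2) for m in (nn.Linear(8, 8), nn.LeakyReLU(inplace=True))]
)
print(fresh[0] is fresh[2])    # False: each repetition builds new modules
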
@@ -95,16 +97,16 @@ def init_conv_layers(self):
         self.conv_layers = []
         for i in range(self.n_partitions):
             part_conv_layers = []
-            for j in range(self.n_conv):
+            for j in range(self.n_conv[i]):
                 conv_layer = [
                     gnn.CGConv(
-                        channels=self.num_node_features[i],
+                        channels=self.conv_size[i],
                         dim=self.num_edge_features[i],
-                        batch_norm=True
+                        batch_norm=True,
                     ),
                     nn.LeakyReLU(inplace=True)
                 ]
-                part_conv_layers.append(conv_layer)
+                part_conv_layers.extend(conv_layer)
 
             self.conv_layers.append(nn.ModuleList(part_conv_layers))

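The fixes in init_conv_layers line up with the data flow: channels now matches conv_size, because the init transform above already projects node features from num_node_features to conv_size and CGConv preserves the channel width; n_conv is indexed per partition; and extend replaces append so the ModuleList holds modules rather than nested lists, which the isinstance dispatch in the new forward() relies on. A standalone shape check:

import torch
import torch_geometric.nn as gnn

conv = gnn.CGConv(channels=40, dim=1, batch_norm=True)  # conv_size=40, one edge feature
x = torch.randn(6, 40)                                  # 6 nodes with conv_size features
edge_index = torch.tensor([[0, 1, 2], [1, 2, 3]])       # 3 directed edges
edge_attr = torch.randn(3, 1)
out = conv(x, edge_index, edge_attr)                    # channel width preserved: (6, 40)
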
@@ -127,27 +129,91 @@ def forward(self, data_objects):
         # For each data object
         for i, data in enumerate(data_objects):
             # Apply initial transform
-            conv_data = self.init_transform[i](data)
+            conv_data = self.init_transform[i](data.x.to(torch.float32))
 
             # Apply convolutional layers
             for layer in self.conv_layers[i]:
-                conv_data = layer(conv_data)
+                if isinstance(layer, gnn.MessagePassing):
+                    conv_data = layer(x=conv_data, edge_index=data.edge_index,
+                                      edge_attr=data.edge_attr)
+                else:
+                    conv_data = layer(conv_data)
 
             # Apply pooling layer
-            pooled_data = self.pool_layers[i](conv_data)
+            pooled_data = self.pool_layers[i](x=conv_data, batch=None)
 
             # Apply pool-to-hidden transform
             hidden_data = self.pool_transform[i](pooled_data)
 
             # Apply hidden layers
-            for layer in self.hidden_layers[i]:
-                hidden_data = layer(hidden_data)
+            hidden_data = self.hidden_layers[i](hidden_data)
 
             # Save contribution
             contributions.append(hidden_data)
 
         # Apply final transformation
-        output = self.final_lin_transform(*contributions)
+        contributions = torch.cat(contributions, dim=-1)
+        output = self.final_lin_transform(contributions)
 
         return {"output": output, "contributions": contributions}
+
+
+if __name__ == "__main__":
+    from ase.io import read
+    from data import AtomsDatapoints
+    from constants import REPO_PATH
+    from pathlib import Path
+    # Test for one tensor
+    # Create datapoints
+    data_root_path = Path(REPO_PATH) / "data" / "S_calcs"
+    atoms = read(data_root_path / "Pt_3_Rh_9_-7-7-S.cif")
+    datapoint = AtomsDatapoints(atoms)
+    datapoint.process_data(
+        z_cutoffs=[13.0, 20.0],
+        node_features=[
+            ["atomic_number", "dband_center"],
+            ["atomic_number", "reactivity"],
+            ["atomic_number", "reactivity"],
+        ],
+        edge_features=[
+            ["bulk_bond_distance"],
+            ["surface_bond_distance"],
+            ["adsorbate_bond_distance"],
+        ],
+    )
+    data_objects = datapoint.get(0)
+
+    # Get result
+    partition_configs = [
+        {
+            "n_conv": 3,
+            "n_hidden": 3,
+            "hidden_size": 30,
+            "conv_size": 40,
+            "dropout": 0.1,
+            "num_node_features": data_objects[0].num_node_features,
+            "num_edge_features": data_objects[0].num_edge_features,
+            "conv_type": "CGConv",
+        },
+        {
+            "n_conv": 3,
+            "n_hidden": 3,
+            "hidden_size": 30,
+            "conv_size": 40,
+            "dropout": 0.1,
+            "num_node_features": data_objects[1].num_node_features,
+            "num_edge_features": data_objects[1].num_edge_features,
+            "conv_type": "CGConv",
+        },
+        {
+            "n_conv": 3,
+            "n_hidden": 3,
+            "hidden_size": 30,
+            "conv_size": 40,
+            "dropout": 0.1,
+            "num_node_features": data_objects[2].num_node_features,
+            "num_edge_features": data_objects[2].num_edge_features,
+            "conv_type": "CGConv",
+        }
+    ]
+    net = MultiGCN(partition_configs)
+    result_dict = net(data_objects)
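
final_lin_transform is defined outside this diff. Assuming it is a single linear layer over the concatenated per-partition outputs (an assumption; its definition is not shown), the new torch.cat makes the shapes line up where argument unpacking with * could not:

import torch
import torch.nn as nn

n_partitions = 3
final_lin_transform = nn.Linear(n_partitions, 1)  # assumed definition, not in the diff

# One (1, 1) contribution per partition, as produced by the hidden heads above
contributions = [torch.randn(1, 1) for _ in range(n_partitions)]
stacked = torch.cat(contributions, dim=-1)  # (1, 3)
output = final_lin_transform(stacked)       # (1, 1)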
