Skip to content

Commit

Permalink
Merge pull request #16 from GreeleyGroup/enh/model
Browse files Browse the repository at this point in the history
ENH: Added MultiGCN model
  • Loading branch information
deshmukg authored Sep 23, 2023
2 parents 836799a + fd0ec97 commit 945c494
Show file tree
Hide file tree
Showing 2 changed files with 237 additions and 2 deletions.
4 changes: 2 additions & 2 deletions src/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ def process_data(

# Convert to Data object
data_obj = Data(
x=feat_dict["node_tensor"],
x=feat_dict["node_tensor"].to(torch.float32),
edge_index=feat_dict["edge_indices"],
edge_attr=feat_dict["edge_tensor"],
y=torch.Tensor([self.map_name_prop[name]]),
Expand Down Expand Up @@ -245,7 +245,7 @@ def process_data(

# Convert to Data object
data_obj = Data(
x=feat_dict["node_tensor"],
x=feat_dict["node_tensor"].to(torch.float32),
edge_index=feat_dict["edge_indices"],
edge_attr=feat_dict["edge_tensor"],
)
Expand Down
235 changes: 235 additions & 0 deletions src/models.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,235 @@
"""Graph neural network models."""

import torch
import torch.nn as nn
import torch_geometric.nn as gnn


class MultiGCN(nn.Module):
    """Multi-partition graph neural network.

    One independent GCN branch is built per partition of the atomic
    structure; each branch produces a scalar "contribution", and the
    contributions are combined by a final (ones-initialized) linear layer.
    """

    def __init__(self, partition_configs):
        """Initialize the graph neural network.

        Parameters
        ----------
        partition_configs: List[Dict]
            List of dictionaries containing parameters for the GNN for each
            partition. The number of different GNNs are judged based on the
            size of the list. Each partition config should contain the following
            keys: n_conv (number of convolutional layers, int), n_hidden (number
            of hidden layers, int), conv_size (feature size before convolution, int)
            hidden_size (nodes per hidden layer node, int), dropout (dropout
            probability for hidden layers, float), conv_type (type of convolution
            layer, str; currently only "CGConv" is supported), pool_type
            (type of pooling layer, str; currently "add" and "mean" are supported,
            defaults to "add" when omitted), num_node_features (number of node
            features, int), num_edge_features (number of edge features, int).
        """
        super().__init__()

        # Store hyperparameters (one entry per partition).
        self.n_conv = [config["n_conv"] for config in partition_configs]
        self.n_hidden = [config["n_hidden"] for config in partition_configs]
        self.hidden_size = [config["hidden_size"] for config in partition_configs]
        self.conv_size = [config["conv_size"] for config in partition_configs]
        self.conv_type = [config["conv_type"] for config in partition_configs]
        self.dropout = [config["dropout"] for config in partition_configs]
        # "add" is the default so configs written before pool_type existed
        # keep their previous behavior (global_add_pool).
        self.pool_type = [
            config.get("pool_type", "add") for config in partition_configs
        ]
        self.num_node_features = [
            config["num_node_features"] for config in partition_configs
        ]
        self.num_edge_features = [
            config["num_edge_features"] for config in partition_configs
        ]
        self.n_partitions = len(partition_configs)

        # Initialize layers. NOTE: containers must be nn.ModuleList (not a
        # plain Python list), otherwise the sub-layers are never registered
        # as submodules and their parameters are invisible to
        # net.parameters(), optimizers, and .to(device).

        # Initial transform: node features -> convolution feature size.
        self.init_transform = nn.ModuleList(
            nn.Sequential(
                nn.Linear(self.num_node_features[i], self.conv_size[i]),
                nn.LeakyReLU(inplace=True),
            )
            for i in range(self.n_partitions)
        )

        # Convolutional layers
        self.init_conv_layers()

        # Pooling layers: plain functions (no parameters), so a Python list
        # is fine here. Selected per partition from pool_type.
        pool_fns = {
            "add": gnn.pool.global_add_pool,
            "mean": gnn.pool.global_mean_pool,
        }
        self.pool_layers = []
        for i in range(self.n_partitions):
            try:
                self.pool_layers.append(pool_fns[self.pool_type[i]])
            except KeyError:
                raise ValueError(
                    f"Unsupported pool_type '{self.pool_type[i]}'; "
                    "supported types are 'add' and 'mean'."
                ) from None

        # Pool transform: pooled convolution features -> hidden size.
        self.pool_transform = nn.ModuleList(
            nn.Sequential(
                nn.Linear(self.conv_size[i], self.hidden_size[i]),
                nn.LeakyReLU(inplace=True),
            )
            for i in range(self.n_partitions)
        )

        # Hidden layers: (n_hidden - 1) hidden blocks followed by a final
        # block projecting to a single scalar. Each block is constructed
        # fresh inside the loop — multiplying a list of modules
        # ([Linear, ...] * k) would reuse the SAME module objects and
        # silently tie the weights of all hidden layers.
        self.hidden_layers = nn.ModuleList()
        for i in range(self.n_partitions):
            blocks = []
            for _ in range(self.n_hidden[i] - 1):
                blocks.extend(
                    [
                        nn.Linear(self.hidden_size[i], self.hidden_size[i]),
                        nn.LeakyReLU(inplace=True),
                        nn.Dropout(p=self.dropout[i]),
                    ]
                )
            # NOTE(review): the LeakyReLU/Dropout after the final scalar
            # projection constrains and randomizes the branch output —
            # preserved from the original; confirm this is intended.
            blocks.extend(
                [
                    nn.Linear(self.hidden_size[i], 1),
                    nn.LeakyReLU(inplace=True),
                    nn.Dropout(p=self.dropout[i]),
                ]
            )
            self.hidden_layers.append(nn.Sequential(*blocks))

        # Final linear layer combining per-partition contributions.
        # Initialized to all-ones so the initial output is the plain sum.
        # TODO: replace 1 with multiple outputs
        self.final_lin_transform = nn.Linear(self.n_partitions, 1, bias=False)
        with torch.no_grad():
            self.final_lin_transform.weight.copy_(torch.ones(self.n_partitions))

    def init_conv_layers(self):
        """Initialize convolutional layers (one stack per partition)."""
        self.conv_layers = nn.ModuleList()
        for i in range(self.n_partitions):
            part_conv_layers = []
            for _ in range(self.n_conv[i]):
                # TODO Add possibility of changing convolutional layers
                # (self.conv_type is stored but currently only CGConv is built).
                part_conv_layers.extend(
                    [
                        gnn.CGConv(
                            channels=self.conv_size[i],
                            dim=self.num_edge_features[i],
                            batch_norm=True,
                        ),
                        nn.LeakyReLU(inplace=True),
                    ]
                )
            self.conv_layers.append(nn.ModuleList(part_conv_layers))

    def forward(self, data_objects):
        """Forward pass of the network(s).

        Parameters
        ----------
        data_objects: list
            List of data objects, each corresponding to a graph of a partition
            of an atomic structure.

        Returns
        -------
        dict
            Dictionary containing "output" (combined prediction) and
            "contributions" (per-partition scalars, concatenated).
        """
        # Initialize empty list for contributions
        contributions = []
        # For each data object
        for i, data in enumerate(data_objects):
            # Apply initial transform
            conv_data = self.init_transform[i](data.x)

            # Apply convolutional layers; only message-passing layers take
            # the graph connectivity, activations take the features alone.
            for layer in self.conv_layers[i]:
                if isinstance(layer, gnn.MessagePassing):
                    conv_data = layer(
                        x=conv_data,
                        edge_index=data.edge_index,
                        edge_attr=data.edge_attr,
                    )
                else:
                    conv_data = layer(conv_data)

            # Apply pooling layer (batch=None pools the whole graph as one).
            pooled_data = self.pool_layers[i](x=conv_data, batch=None)

            # Apply pool-to-hidden transform
            hidden_data = self.pool_transform[i](pooled_data)

            # Apply hidden layers
            hidden_data = self.hidden_layers[i](hidden_data)

            # Save contribution
            contributions.append(hidden_data)

        # Apply final transformation
        contributions = torch.cat(contributions, dim=-1)
        output = self.final_lin_transform(contributions)

        return {"output": output, "contributions": contributions}


if __name__ == "__main__":
    from pathlib import Path

    from ase.io import read

    from constants import REPO_PATH
    from data import AtomsDatapoints

    # Smoke test: build one datapoint from a CIF file, split it into
    # partitions, and run the model on the resulting graphs.
    data_root_path = Path(REPO_PATH) / "data" / "S_calcs"
    structure = read(data_root_path / "Pt_3_Rh_9_-7-7-S.cif")
    datapoint = AtomsDatapoints(structure)
    datapoint.process_data(
        z_cutoffs=[13.0, 20.0],
        node_features=[
            ["atomic_number", "dband_center"],
            ["atomic_number", "reactivity"],
            ["atomic_number", "reactivity"],
        ],
        edge_features=[
            ["bulk_bond_distance"],
            ["surface_bond_distance"],
            ["adsorbate_bond_distance"],
        ],
    )
    data_objects = datapoint.get(0)

    # One identical hyperparameter config per partition, with the feature
    # counts taken from each partition's own graph.
    partition_configs = [
        {
            "n_conv": 3,
            "n_hidden": 3,
            "hidden_size": 30,
            "conv_size": 40,
            "dropout": 0.1,
            "num_node_features": graph.num_node_features,
            "num_edge_features": graph.num_edge_features,
            "conv_type": "CGConv",
        }
        for graph in data_objects
    ]
    net = MultiGCN(partition_configs)
    result_dict = net(data_objects)
    print(result_dict)

0 comments on commit 945c494

Please sign in to comment.