
Commit 6f606c2

Fixed codestyle

Gaurav S Deshmukh committed Sep 26, 2023
1 parent 0f913ee commit 6f606c2
Showing 2 changed files with 33 additions and 29 deletions.
38 changes: 20 additions & 18 deletions src/models.py
@@ -97,9 +97,9 @@ def __init__(self, partition_configs):
         self.fin1_act = nn.LeakyReLU()
         self.final_lin_transform = nn.Linear(50, 1)
         self.final_lin_act = nn.LeakyReLU()
-        #with torch.no_grad():
+        # with torch.no_grad():
         #    self.final_lin_transform.weight.copy_(torch.ones(self.n_partitions))
-        #for p in self.final_lin_transform.parameters():
+        # for p in self.final_lin_transform.parameters():
         #    p.requires_grad = False
 
     def init_conv_layers(self):
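An aside on the commented-out block just above: the pattern it sketches (copying fixed values into a layer's weights under torch.no_grad(), then turning off requires_grad) is the standard way to pin an output layer in PyTorch. A minimal sketch with hypothetical sizes, not code from this commit:

import torch
import torch.nn as nn

n_partitions = 5  # hypothetical value
final_lin_transform = nn.Linear(n_partitions, 1, bias=False)

# Copy fixed values into the weights without recording the op in autograd.
with torch.no_grad():
    final_lin_transform.weight.copy_(torch.ones(1, n_partitions))

# Freeze the layer so gradients are never computed for its parameters.
for p in final_lin_transform.parameters():
    p.requires_grad = False

In the commit itself this block stays commented out; only the comment spacing changes.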
@@ -173,7 +173,8 @@ def forward(self, data_objects):
         output = self.final_lin_act(output)
 
         return {"output": output, "contributions": contributions}
 
+
 class SlabGCN(nn.Module):
     """Class to customize the graph neural network."""
@@ -243,20 +244,20 @@ def __init__(self, partition_configs, global_config):

         # Hidden layers
         self.hidden_layers = nn.Sequential(
-            *(
-                [
-                    nn.Linear(self.hidden_size, self.hidden_size),
-                    nn.LeakyReLU(inplace=True),
-                    nn.Dropout(p=self.dropout),
-                ]
-                * (self.n_hidden - 1)
-                + [
-                    nn.Linear(self.hidden_size, 1),
-                    nn.LeakyReLU(inplace=True),
-                    nn.Dropout(p=self.dropout),
-                ]
-            )
-        )
+            *(
+                [
+                    nn.Linear(self.hidden_size, self.hidden_size),
+                    nn.LeakyReLU(inplace=True),
+                    nn.Dropout(p=self.dropout),
+                ]
+                * (self.n_hidden - 1)
+                + [
+                    nn.Linear(self.hidden_size, 1),
+                    nn.LeakyReLU(inplace=True),
+                    nn.Dropout(p=self.dropout),
+                ]
+            )
+        )
 
     def init_conv_layers(self):
         """Initialize convolutional layers."""
@@ -320,7 +321,7 @@ def forward(self, data_objects):
         output = self.hidden_layers(hidden_data)
 
         return {"output": output}
 
     def get_embeddings(self, data_objects):
         """Get the pooled embeddings of each partition.
@@ -358,6 +359,7 @@ def get_embeddings(self, data_objects):

         return embeddings
 
+
 if __name__ == "__main__":
     from pathlib import Path

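One detail of the re-indented hidden_layers block above is easy to miss: Python list multiplication repeats references to the same objects, so `[...] * (self.n_hidden - 1)` registers the same Linear, activation, and dropout modules at several positions, which ties their weights. A minimal sketch with hypothetical sizes, not from this repository, demonstrating the behavior:

import torch.nn as nn

hidden_size, n_hidden, dropout = 50, 3, 0.1  # hypothetical values

# Same construction as the diff above: one block replicated (n_hidden - 1) times.
mlp = nn.Sequential(
    *(
        [
            nn.Linear(hidden_size, hidden_size),
            nn.LeakyReLU(inplace=True),
            nn.Dropout(p=dropout),
        ]
        * (n_hidden - 1)
        + [
            nn.Linear(hidden_size, 1),
            nn.LeakyReLU(inplace=True),
            nn.Dropout(p=dropout),
        ]
    )
)

# List multiplication copies references, so positions 0 and 3 hold one module
# and the repeated hidden layers share a single weight matrix.
assert mlp[0] is mlp[3]
print(sum(p.numel() for p in mlp.parameters()))  # each shared weight counted once

Whether the weight sharing is intended is not something the diff itself establishes; if independent layers were wanted, building the list in a comprehension (so each iteration constructs fresh modules) would be the usual fix.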
24 changes: 13 additions & 11 deletions workflows/basic_train_val_test.py
@@ -18,17 +18,19 @@
 dataset = AtomsDataset(root=dataset_path, prop_csv=prop_csv_path)
 
 # Process dataset
-dataset.process_data(layer_cutoffs=[3, 6],
-                     node_features=[
-                         ["atomic_number", "dband_center", "coordination"],
-                         ["atomic_number", "reactivity", "coordination"],
-                         ["atomic_number", "reactivity", "coordination"],
-                     ],
-                     edge_features=[
-                         ["bulk_bond_distance"],
-                         ["surface_bond_distance"],
-                         ["adsorbate_bond_distance"],
-                     ])
+dataset.process_data(
+    layer_cutoffs=[3, 6],
+    node_features=[
+        ["atomic_number", "dband_center", "coordination"],
+        ["atomic_number", "reactivity", "coordination"],
+        ["atomic_number", "reactivity", "coordination"],
+    ],
+    edge_features=[
+        ["bulk_bond_distance"],
+        ["surface_bond_distance"],
+        ["adsorbate_bond_distance"],
+    ],
+)
 
 # Create sampler
 sample_config = {"train": 0.8, "val": 0.1, "test": 0.1}
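Since the commit message claims a codestyle-only change, one way to sanity-check that is to compare the ASTs of a file before and after: formatting, indentation, and comments never appear in the parse tree. A small sketch, with hypothetical file paths standing in for the two revisions:

import ast

# Hypothetical paths to the pre- and post-commit versions of the file.
old_src = open("models_before.py").read()
new_src = open("models_after.py").read()

# ast.dump normalizes away whitespace and comments, so equal dumps mean
# the two revisions compile to the same program.
assert ast.dump(ast.parse(old_src)) == ast.dump(ast.parse(new_src))
print("formatting-only change confirmed")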
