Skip to content

Commit

Permalink
Added Model class
Browse files: browse the repository at this point in the history
  • Loading branch information
Gaurav S Deshmukh committed Sep 24, 2023
1 parent 2781f1a commit 4768241
Show file tree
Hide file tree
Showing 4 changed files with 41 additions and 3 deletions.
1 change: 1 addition & 0 deletions src/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -352,6 +352,7 @@ def load_datapoints(atoms, process_dict):
# ])
print(dataset[0][-1].x)
print(dataset.df_name_idx.head())
print(dataset[0][-1].name)

# Create datapoint
atoms = read(data_root_path / "Pt_3_Rh_9_-7-7-S.cif")
Expand Down
2 changes: 1 addition & 1 deletion src/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ def init_conv_layers(self):
gnn.CGConv(
channels=self.conv_size[i],
dim=self.num_edge_features[i],
batch_norm=True,
batch_norm=False,
),
nn.LeakyReLU(inplace=True),
]
Expand Down
6 changes: 4 additions & 2 deletions src/samplers.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,9 +67,11 @@ def create_samplers(self, sample_config):
randomizer.shuffle(idx_array)

# Get indices
train_size = int(np.ceil(sample_config["train"] * self.dataset_size))
if sample_config["train"] < 1.:
train_size = int(np.ceil(sample_config["train"] * self.dataset_size))
train_idx = idx_array[:train_size]
val_size = int(np.ceil(sample_config["val"] * self.dataset_size))
if sample_config["val"] < 1.:
val_size = int(np.floor(sample_config["val"] * self.dataset_size))
val_idx = idx_array[train_size : train_size + val_size]
test_idx = idx_array[train_size + val_size :]

Expand Down
35 changes: 35 additions & 0 deletions src/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,9 @@
import numpy as np
import torch

from torch.utils.data import SubsetRandomSampler
from torch_geometric.loader import DataLoader

from featurizers import (
OneHotEncoder,
list_of_edge_featurizers,
Expand Down Expand Up @@ -133,6 +136,38 @@ def featurize_atoms(
"edge_indices": edge_indices,
}

def create_dataloaders(proc_data, sample_idx, batch_size, num_proc=0):
    """Create training, validation, and/or test dataloaders.

    Parameters
    ----------
    proc_data: AtomsDataset or AtomsDatapoints
        Processed dataset object
    sample_idx: dict
        A dictionary with "train", "val", and "test" indices returned by a
        Sampler object. Each value may be any sized sequence of indices
        (e.g. a NumPy array, list, or tuple).
    batch_size: int
        Batch size
    num_proc: int (default = 0)
        Number of cores to be used for parallelization (forwarded to the
        DataLoader as ``num_workers``). Defaults to serial.

    Returns
    -------
    dataloader_dict: dict
        Dictionary of "train", "val", and "test" dataloaders. A split that
        has no indices keeps an empty list in place of a dataloader.

    Raises
    ------
    KeyError
        If ``sample_idx`` lacks one of the "train", "val", or "test" keys.
    """
    # Every split starts out as an empty list; only non-empty splits are
    # replaced by an actual dataloader below.
    dataloader_dict = {"train": [], "val": [], "test": []}

    for key in dataloader_dict:
        indices = sample_idx[key]
        # len() rather than .shape[0] so plain lists/tuples work as well as
        # arrays; an explicit length test is used because the truth value of
        # a multi-element NumPy array is ambiguous.
        if len(indices) > 0:
            sampler = SubsetRandomSampler(indices)
            dataloader_dict[key] = DataLoader(
                dataset=proc_data,
                batch_size=batch_size,
                sampler=sampler,
                num_workers=num_proc,
            )

    return dataloader_dict

if __name__ == "__main__":
from ase.io import read
Expand Down

0 comments on commit 4768241

Please sign in to comment.