Skip to content

Commit

Permalink
Merge pull request #19 from GreeleyGroup/enh/train
Browse files Browse the repository at this point in the history
ENH: Add Model class for model training and validation
  • Loading branch information
deshmukg authored Sep 24, 2023
2 parents 2781f1a + b212fae commit b6b8292
Show file tree
Hide file tree
Showing 5 changed files with 504 additions and 3 deletions.
1 change: 1 addition & 0 deletions src/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -352,6 +352,7 @@ def load_datapoints(atoms, process_dict):
# ])
print(dataset[0][-1].x)
print(dataset.df_name_idx.head())
print(dataset[0][-1].name)

# Create datapoint
atoms = read(data_root_path / "Pt_3_Rh_9_-7-7-S.cif")
Expand Down
2 changes: 1 addition & 1 deletion src/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ def init_conv_layers(self):
gnn.CGConv(
channels=self.conv_size[i],
dim=self.num_edge_features[i],
batch_norm=True,
batch_norm=False,
),
nn.LeakyReLU(inplace=True),
]
Expand Down
6 changes: 4 additions & 2 deletions src/samplers.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,9 +67,11 @@ def create_samplers(self, sample_config):
randomizer.shuffle(idx_array)

# Get indices
train_size = int(np.ceil(sample_config["train"] * self.dataset_size))
if sample_config["train"] < 1.:
train_size = int(np.ceil(sample_config["train"] * self.dataset_size))
train_idx = idx_array[:train_size]
val_size = int(np.ceil(sample_config["val"] * self.dataset_size))
if sample_config["val"] < 1.:
val_size = int(np.floor(sample_config["val"] * self.dataset_size))
val_idx = idx_array[train_size : train_size + val_size]
test_idx = idx_array[train_size + val_size :]

Expand Down
Loading

0 comments on commit b6b8292

Please sign in to comment.