Works well on GPU
Gaurav S Deshmukh committed Sep 25, 2023
1 parent 037ca9d commit b078dff
Showing 3 changed files with 79 additions and 37 deletions.
14 changes: 7 additions & 7 deletions src/models.py
@@ -45,7 +45,7 @@ def __init__(self, partition_configs):
 
         # Initialize layers
         # Initial transform
-        self.init_transform = []
+        self.init_transform = nn.ModuleList()
         for i in range(self.n_partitions):
             self.init_transform.append(
                 nn.Sequential(
@@ -58,12 +58,12 @@ def __init__(self, partition_configs):
         self.init_conv_layers()
 
         # Pooling layers
-        self.pool_layers = []
+        self.pool_layers = nn.ModuleList()
         for i in range(self.n_partitions):
-            self.pool_layers.append(gnn.pool.global_add_pool)
+            self.pool_layers.append(gnn.aggr.SumAggregation())
 
         # Pool transform
-        self.pool_transform = []
+        self.pool_transform = nn.ModuleList()
         for i in range(self.n_partitions):
             self.pool_transform.append(
                 nn.Sequential(
@@ -73,7 +73,7 @@
             )
 
         # Hidden layers
-        self.hidden_layers = []
+        self.hidden_layers = nn.ModuleList()
         for i in range(self.n_partitions):
             self.hidden_layers.append(
                 nn.Sequential(
@@ -101,7 +101,7 @@ def __init__(self, partition_configs):
 
     def init_conv_layers(self):
         """Initialize convolutional layers."""
-        self.conv_layers = []
+        self.conv_layers = nn.ModuleList()
         for i in range(self.n_partitions):
             part_conv_layers = []
             for j in range(self.n_conv[i]):
@@ -151,7 +151,7 @@ def forward(self, data_objects):
                 conv_data = layer(conv_data)
 
             # Apply pooling layer
-            pooled_data = self.pool_layers[i](x=conv_data, batch=data.batch)
+            pooled_data = self.pool_layers[i](x=conv_data, index=data.batch)
 
             # Apply pool-to-hidden transform
             hidden_data = self.pool_transform[i](pooled_data)
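
Note: the common thread in this file is swapping plain Python lists for nn.ModuleList. Layers stored in a bare list are invisible to PyTorch: model.parameters() skips them, so the optimizer never updates them, and model.cuda() leaves them on the CPU, which is exactly what breaks GPU runs. A minimal standalone sketch of the failure mode (not from this repo):

    import torch.nn as nn

    class ListModel(nn.Module):
        def __init__(self):
            super().__init__()
            # Plain list: the Linear layers are NOT registered as submodules,
            # so .parameters() is empty and .cuda()/.to() never move them.
            self.layers = [nn.Linear(4, 4) for _ in range(2)]

    class ModuleListModel(nn.Module):
        def __init__(self):
            super().__init__()
            # nn.ModuleList registers each layer with the parent module.
            self.layers = nn.ModuleList(nn.Linear(4, 4) for _ in range(2))

    print(len(list(ListModel().parameters())))        # 0: optimizer sees nothing
    print(len(list(ModuleListModel().parameters())))  # 4: weight + bias per layer

The pooling change follows from this: gnn.pool.global_add_pool is a bare function and cannot live inside an nn.ModuleList, while gnn.aggr.SumAggregation() is an nn.Module that computes the same sum-pooling and takes index= instead of batch=, matching the forward() change above.
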
57 changes: 46 additions & 11 deletions src/train.py
@@ -12,10 +12,18 @@
 
 class Standardizer:
     """Class to standardize targets."""
-    def __init__(self, X):
+    def __init__(self):
+        """
+        Class to standardize outputs.
+        Initialize with dummy values.
+        """
+        self.mean = 0
+        self.std = 0.1
+
+    def initialize(self, X):
+        """Initialize mean and std based on the given tensor.
         Parameters
         ----------
         X: torch.Tensor
@@ -59,6 +67,24 @@ def restore(self, Z):
         """
         X = self.mean + Z * self.std
         return X
 
+    def restore_cont(self, Z):
+        """
+        Restore a standardized contribution to the non-standardized contribution.
+        Parameters
+        ----------
+        Z: torch.Tensor
+            Tensor of standardized contributions
+        Returns
+        -------
+        X: torch.Tensor
+            Tensor of non-standardized contributions
+        """
+        X = (self.mean / Z.shape[0]) + Z * self.std
+        return X
+
     def get_state(self):
         """
@@ -124,7 +150,6 @@ def __init__(self, global_config, partition_configs, model_path):
         self.use_gpu = global_config["gpu"]
         if self.use_gpu:
             self.model.cuda()
-
         # Set loss function
         if global_config["loss_function"] == "mse":
             self.loss_fn = torch.nn.MSELoss()
@@ -147,7 +172,7 @@ def __init__(self, global_config, partition_configs, model_path):
             )
         elif global_config["optimizer"].lower().strip() == "sgd":
             self.optimizer = torch.optim.SGD(
-                self.model_parameters(),
+                self.model.parameters(),
                 lr=global_config["learning_rate"],
             )
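
Note: self.model_parameters() does not exist on this class, so selecting the "sgd" optimizer previously raised AttributeError at construction; self.model.parameters() is the standard nn.Module accessor. A hypothetical smoke test, borrowing the config names from the workflow file below:

    import torch

    # Hypothetical check: with the fix, an SGD trainer builds cleanly.
    global_config["optimizer"] = "sgd"
    model = Model(global_config, partition_configs, model_path)
    assert isinstance(model.optimizer, torch.optim.SGD)
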

@@ -159,6 +184,9 @@
         else:
             self.scheduler = None
 
+        # Set standardizer
+        self.standardizer = Standardizer()
+
     def make_directory_structure(self, model_path):
         """Make directory structure to store models and results."""
         self.model_path = Path(model_path)
@@ -175,7 +203,7 @@ def init_standardizer(self, targets):
         targets: np.ndarray or torch.Tensor
             Array of training outputs
         """
-        self.standardizer = Standardizer(torch.Tensor(targets))
+        self.standardizer.initialize(torch.Tensor(targets))
 
     def train_epoch(self, dataloader):
         """Train the model for a single epoch.
@@ -307,12 +335,14 @@ def predict(self, dataset, indices, return_targets=False):
         Returns
         -------
         prediction_dict: dict
-            Dictionary containing "targets", "predictions", and "indices" (copy of
-            predict_idx).
+            Dictionary containing "targets", "predictions", "contributions" and
+            "indices" (copy of predict_idx).
         """
         # Create arrays
+        n_partitions = len(dataset.get(indices[0]))
         targets = np.zeros(len(indices))
         predictions = np.zeros(len(indices))
+        contributions = np.zeros((len(indices), n_partitions))
 
         # Enable eval mode of model
         self.model.eval()
@@ -334,12 +364,15 @@
 
             # Compute prediction
             pred_dict = self.model(nn_input)
-            predictions[i] = self.standardizer.restore(pred_dict["output"].cpu())
+            predictions[i] = self.standardizer.restore(pred_dict["output"].cpu().detach())
+            conts_std = pred_dict["contributions"].cpu().detach()
+            contributions[i, :] = self.standardizer.restore_cont(conts_std).numpy().flatten()
 
         predictions_dict = {
             "targets": targets,
             "predictions": predictions,
             "indices": indices,
+            "contributions": contributions,
         }
 
         return predictions_dict
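
Note: the added .detach() calls matter because a tensor still attached to the autograd graph raises "Can't call numpy() on Tensor that requires grad"; eval() alone does not stop gradient tracking. An equivalent, arguably tidier alternative (a sketch of the inner loop, not the committed code) is to run prediction under torch.no_grad(), which skips building the graph entirely and makes .detach() redundant:

    # Sketch: the same loop body under no_grad(); no .detach() needed.
    self.model.eval()
    with torch.no_grad():
        pred_dict = self.model(nn_input)
        predictions[i] = self.standardizer.restore(pred_dict["output"].cpu())
        contributions[i, :] = (
            self.standardizer.restore_cont(pred_dict["contributions"].cpu())
            .numpy()
            .flatten()
        )
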
@@ -446,15 +479,17 @@ def train(self, epochs, dataloader_dict, verbose=False):
             self.save(i, best_status)
 
             # Save losses and metrics
-            train_losses.append(train_loss)
-            val_losses.append(val_loss)
+            train_losses.append(train_loss.cpu().detach().numpy())
+            val_losses.append(val_loss.cpu().detach().numpy())
             train_metrics.append(train_metric)
             val_metrics.append(val_metric)
 
             # Print, if verbose
             if verbose:
-                print(f"Epoch: [{i}]\tTraining Loss: [{train_loss}]\
-                      \tValidation Loss: [{val_loss}]")
+                print(f"Epoch: [{i}] Training loss: [{train_loss:.3f}] " + \
+                      f"Training metric: [{train_metric:.3f}] " + \
+                      f"Validation loss: [{val_loss:.3f}] " + \
+                      f"Validation metric: [{val_metric:.3f}]")
 
         # Load the best model
         self.load(best_status=True)
45 changes: 26 additions & 19 deletions workflows/basic_train_val_test.py
@@ -31,7 +31,7 @@
 #     ])
 
 # Create sampler
-sample_config = {"train": 0.6, "val": 0.2, "test": 0.2}
+sample_config = {"train": 0.8, "val": 0.1, "test": 0.1}
 sampler = RandomSampler(seed, dataset.len())
 sample_idx = sampler.create_samplers(sample_config)
 
@@ -40,39 +40,39 @@
 
 # Create model
 global_config = {
-    "gpu": False,
+    "gpu": True,
     "loss_function": "mse",
     "metric_function": "mae",
-    "learning_rate": 0.1,
+    "learning_rate": 0.001,
     "optimizer": "adam",
-    "lr_milestones": [3, 10]
+    "lr_milestones": [100]
 }
 partition_configs = [
     {
-        "n_conv": 3,
-        "n_hidden": 3,
-        "hidden_size": 30,
-        "conv_size": 40,
+        "n_conv": 5,
+        "n_hidden": 2,
+        "hidden_size": 50,
+        "conv_size": 50,
         "dropout": 0.1,
         "num_node_features": dataset[0][0].num_node_features,
         "num_edge_features": dataset[0][0].num_edge_features,
         "conv_type": "CGConv",
     },
     {
-        "n_conv": 3,
-        "n_hidden": 3,
-        "hidden_size": 30,
-        "conv_size": 40,
+        "n_conv": 5,
+        "n_hidden": 2,
+        "hidden_size": 50,
+        "conv_size": 50,
        "dropout": 0.1,
         "num_node_features": dataset[0][1].num_node_features,
         "num_edge_features": dataset[0][1].num_edge_features,
         "conv_type": "CGConv",
     },
     {
-        "n_conv": 3,
-        "n_hidden": 3,
-        "hidden_size": 30,
-        "conv_size": 40,
+        "n_conv": 5,
+        "n_hidden": 2,
+        "hidden_size": 50,
+        "conv_size": 50,
         "dropout": 0.1,
         "num_node_features": dataset[0][2].num_node_features,
         "num_edge_features": dataset[0][2].num_edge_features,
@@ -82,6 +82,13 @@
 
 model_path = REPO_PATH / "trained_models" / "S_binary_calcs"
 model = Model(global_config, partition_configs, model_path)
-model.init_standardizer([dataset[i][0].y for i in sample_idx["train"]])
-results_dict = model.train(100, dataloader_dict, verbose=True)
-print(results_dict)
+#model.init_standardizer([dataset[i][0].y for i in sample_idx["train"]])
+#results_dict = model.train(200, dataloader_dict, verbose=True)
+#print(f"Test metric: {results_dict['metric']['test']}")
+
+# Load model
+model.load(best_status=True)
+
+# Make predictions on a structure
+pred_dict = model.predict(dataset, [200], return_targets=True)
+print(pred_dict)
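
Note: with the training calls commented out, the workflow now exercises only the inference path: load the best checkpoint and predict a single structure (index 200), printing targets, predictions, and the new per-partition contributions. Given restore_cont() above, those contributions should sum to the overall prediction, assuming the model's output is the sum of its partition contributions. A hypothetical follow-up check:

    import numpy as np

    # Hypothetical: verify the contributions add up to the prediction.
    total = pred_dict["predictions"][0]
    parts = pred_dict["contributions"][0]
    print(f"Contributions: {parts} (sum {parts.sum():.3f}, prediction {total:.3f})")
    assert np.isclose(parts.sum(), total, rtol=1e-3)
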
