diff --git a/src/data.py b/src/data.py index e9561af..01e92a1 100644 --- a/src/data.py +++ b/src/data.py @@ -131,7 +131,7 @@ def process_data( # Convert to Data object data_obj = Data( - x=feat_dict["node_tensor"], + x=feat_dict["node_tensor"].to(torch.float32), edge_index=feat_dict["edge_indices"], edge_attr=feat_dict["edge_tensor"], y=torch.Tensor([self.map_name_prop[name]]), @@ -245,7 +245,7 @@ def process_data( # Convert to Data object data_obj = Data( - x=feat_dict["node_tensor"], + x=feat_dict["node_tensor"].to(torch.float32), edge_index=feat_dict["edge_indices"], edge_attr=feat_dict["edge_tensor"], ) diff --git a/src/models.py b/src/models.py index 615ddc8..8ff3fd6 100644 --- a/src/models.py +++ b/src/models.py @@ -95,7 +95,9 @@ def __init__(self, partition_configs): # Final linear layer # TODO: replace 1 with multiple outputs - self.final_lin_transform = nn.Linear(self.n_partitions, 1) + self.final_lin_transform = nn.Linear(self.n_partitions, 1, bias=False) + with torch.no_grad(): + self.final_lin_transform.weight.copy_(torch.ones(self.n_partitions)) def init_conv_layers(self): """Initialize convolutional layers.""" @@ -230,3 +232,4 @@ def forward(self, data_objects): ] net = MultiGCN(partition_configs) result_dict = net(data_objects) + print(result_dict)