Skip to content

Commit

Permalink
Fixed final layer weights
Browse files Browse the repository at this point in the history
  • Loading branch information
Gaurav S Deshmukh committed Sep 23, 2023
1 parent 79a9704 commit fd0ec97
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 3 deletions.
4 changes: 2 additions & 2 deletions src/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ def process_data(

# Convert to Data object
data_obj = Data(
x=feat_dict["node_tensor"],
x=feat_dict["node_tensor"].to(torch.float32),
edge_index=feat_dict["edge_indices"],
edge_attr=feat_dict["edge_tensor"],
y=torch.Tensor([self.map_name_prop[name]]),
Expand Down Expand Up @@ -245,7 +245,7 @@ def process_data(

# Convert to Data object
data_obj = Data(
x=feat_dict["node_tensor"],
x=feat_dict["node_tensor"].to(torch.float32),
edge_index=feat_dict["edge_indices"],
edge_attr=feat_dict["edge_tensor"],
)
Expand Down
5 changes: 4 additions & 1 deletion src/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,9 @@ def __init__(self, partition_configs):

# Final linear layer
# TODO: replace 1 with multiple outputs
self.final_lin_transform = nn.Linear(self.n_partitions, 1)
self.final_lin_transform = nn.Linear(self.n_partitions, 1, bias=False)
with torch.no_grad():
self.final_lin_transform.weight.copy_(torch.ones(self.n_partitions))

def init_conv_layers(self):
"""Initialize convolutional layers."""
Expand Down Expand Up @@ -230,3 +232,4 @@ def forward(self, data_objects):
]
net = MultiGCN(partition_configs)
result_dict = net(data_objects)
print(result_dict)

0 comments on commit fd0ec97

Please sign in to comment.