Skip to content

Commit

Permalink
Changed pools to embedding
Browse files Browse the repository at this point in the history
  • Loading branch information
Gaurav S Deshmukh committed Sep 26, 2023
1 parent a5314f2 commit aaa7a7c
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions src/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -292,7 +292,7 @@ def forward(self, data_objects):
Dictionary containing "output" and "contributions".
"""
# For each data object
pools = []
embeddings = []
for i, data in enumerate(data_objects):
# Apply initial transform
conv_data = self.init_transform[i](data.x)
Expand All @@ -310,11 +310,11 @@ def forward(self, data_objects):

# Apply pooling layer
pooled_data = self.pool_layers[i](x=conv_data, index=data.batch)
pools.append(pooled_data)
embeddings.append(pooled_data)

# Apply pool-to-hidden transform
pools = torch.cat(pools, dim=1)
hidden_data = self.pool_transform(pools)
embedding = torch.cat(embeddings, dim=1)
hidden_data = self.pool_transform(embedding)

# Apply hidden layers
output = self.hidden_layers(hidden_data)
Expand Down

0 comments on commit aaa7a7c

Please sign in to comment.