diff --git a/src/models.py b/src/models.py index 53e1829..14b66e9 100644 --- a/src/models.py +++ b/src/models.py @@ -292,7 +292,7 @@ def forward(self, data_objects): Dictionary containing "output" and "contributions". """ # For each data object - pools = [] + embeddings = [] for i, data in enumerate(data_objects): # Apply initial transform conv_data = self.init_transform[i](data.x) @@ -310,11 +310,11 @@ def forward(self, data_objects): # Apply pooling layer pooled_data = self.pool_layers[i](x=conv_data, index=data.batch) - pools.append(pooled_data) + embeddings.append(pooled_data) # Apply pool-to-hidden transform - pools = torch.cat(pools, dim=1) - hidden_data = self.pool_transform(pools) + embedding = torch.cat(embeddings, dim=1) + hidden_data = self.pool_transform(embedding) # Apply hidden layers output = self.hidden_layers(hidden_data)