Skip to content

Commit

Permalink
Added get_embeddings
Browse files Browse the repository at this point in the history
  • Loading branch information
Gaurav S Deshmukh committed Sep 26, 2023
1 parent aaa7a7c commit 0f913ee
Showing 1 changed file with 36 additions and 0 deletions.
36 changes: 36 additions & 0 deletions src/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -320,7 +320,43 @@ def forward(self, data_objects):
output = self.hidden_layers(hidden_data)

return {"output": output}

def get_embeddings(self, data_objects):
    """Compute the pooled embedding tensor for each partition graph.

    Parameters
    ----------
    data_objects: list
        List of data objects, each corresponding to a graph of a partition
        of an atomic structure.

    Returns
    -------
    embeddings: list
        List of embedding tensors, one per input data object.
    """
    embeddings = []
    for idx, graph in enumerate(data_objects):
        # Project the raw node features with this partition's input transform.
        features = self.init_transform[idx](graph.x)

        # Run the partition-specific convolutional stack; only
        # message-passing layers consume the graph connectivity.
        for layer in self.conv_layers[idx]:
            if isinstance(layer, gnn.MessagePassing):
                features = layer(
                    x=features,
                    edge_index=graph.edge_index,
                    edge_attr=graph.edge_attr,
                )
            else:
                features = layer(features)

        # Pool node-level features into a single per-graph embedding.
        pooled = self.pool_layers[idx](x=features, index=graph.batch)
        embeddings.append(pooled)

    return embeddings

if __name__ == "__main__":
from pathlib import Path
Expand Down

0 comments on commit 0f913ee

Please sign in to comment.