diff --git a/src/models.py b/src/models.py index 14b66e9..b91078f 100644 --- a/src/models.py +++ b/src/models.py @@ -320,7 +320,43 @@ def forward(self, data_objects): output = self.hidden_layers(hidden_data) return {"output": output} + + def get_embeddings(self, data_objects): + """Get the pooled embeddings of each partition. + + Parameters + ---------- + data_objects: list + List of data objects, each corresponding to a graph of a partition + of an atomic structure. + + Returns + ------ + embeddings: list + List of embedding tensors. + """ + # For each data object + embeddings = [] + for i, data in enumerate(data_objects): + # Apply initial transform + conv_data = self.init_transform[i](data.x) + + # Apply convolutional layers + for layer in self.conv_layers[i]: + if isinstance(layer, gnn.MessagePassing): + conv_data = layer( + x=conv_data, + edge_index=data.edge_index, + edge_attr=data.edge_attr, + ) + else: + conv_data = layer(conv_data) + + # Apply pooling layer + pooled_data = self.pool_layers[i](x=conv_data, index=data.batch) + embeddings.append(pooled_data) + return embeddings if __name__ == "__main__": from pathlib import Path