diff --git a/src/graphs.py b/src/graphs.py index 76c8048..fdead15 100644 --- a/src/graphs.py +++ b/src/graphs.py @@ -107,13 +107,13 @@ def get_node_tensor(self): """ pass - def get_adjacency_tensor(self): - """Get the adjacency matrix of the graph as a PyTorch Tensor. + def get_edge_tensor(self): + """Get the edge matrix of the graph as a PyTorch Tensor. Returns ------- - adj_matrix: torch.Tensor - Adjacency tensor + edge_matrix: torch.Tensor + Edge tensor """ pass