Skip to content

Commit

Permalink
Fixed codestyle
Browse files Browse the repository at this point in the history
  • Loading branch information
Gaurav S Deshmukh committed Sep 22, 2023
1 parent 986380f commit e3c2241
Showing 1 changed file with 63 additions and 45 deletions.
108 changes: 63 additions & 45 deletions src/data.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,23 @@
"Store graph data using PyTorch Geometric abstractions."
"""Store graph data using PyTorch Geometric abstractions."""

import csv
from pathlib import Path

import pandas as pd
import torch
import tqdm

from ase.io import read
from ase import Atoms
from ase.io import read
from torch_geometric.data import Data, Dataset

from utils import partition_structure, featurize_atoms
from featurizers import OneHotEncoder
from constants import REPO_PATH
from featurizers import OneHotEncoder
from utils import featurize_atoms, partition_structure


class AtomsDataset(Dataset):
"""Class to hold a dataset containing graphs of atomic_structures."""

def __init__(self, root, prop_csv):
"""Initialize an AtomsDataset.
Expand Down Expand Up @@ -50,17 +51,21 @@ def __init__(self, root, prop_csv):
self.props.append(float(row[1]))

# Create name to property map
self.map_name_prop = {
name: prop for name, prop in zip(self.names, self.props)
}
self.map_name_prop = {name: prop for name, prop in zip(self.names, self.props)}

# Load index.csv if processed
self.index_path = self.processed_path / "index.csv"
if self.processed_status():
self.df_name_idx = pd.read_csv(self.index_path)

def process_data(self, z_cutoffs, node_features, edge_features, max_atoms=None,
encoder=OneHotEncoder()):
def process_data(
self,
z_cutoffs,
node_features,
edge_features,
max_atoms=None,
encoder=OneHotEncoder(),
):
"""Process raw data in the root directory into PyTorch Data and save.
Each atomic structure in the root directory is partitioned based on the
Expand All @@ -76,12 +81,12 @@ def process_data(self, z_cutoffs, node_features, edge_features, max_atoms=None,
length of z_cutoffs.
node_features: list[list]
List of lists of node featurization methods to be used for each
partition. For e.g., specify [["atomic_number", "dband_center"],
partition. For e.g., specify [["atomic_number", "dband_center"],
["atomic_number", "reactivity"], ["atomic_number", "reactivity"]] for
a typical bulk + surface + adsorbate partition.
edge_features: list[list]
List of lists of edge featurization methods to be used for each
partition. For e.g., specify [["bulk_bond_distance"],
partition. For e.g., specify [["bulk_bond_distance"],
["surface_bond_distance"], ["adsorbate_bond_distance"]] for
a typical bulk + surface + adsorbate partition.
max_atoms: int (default = None)
Expand All @@ -96,8 +101,9 @@ def process_data(self, z_cutoffs, node_features, edge_features, max_atoms=None,
)

# Iterate over files and process them
for i, name in tqdm.tqdm(enumerate(self.names), desc="Processing data",
total=len(self.names)):
for i, name in tqdm.tqdm(
enumerate(self.names), desc="Processing data", total=len(self.names)
):
# Map index to name
self.df_name_idx.loc[i, "index"] = i
self.df_name_idx.loc[i, "name"] = name
Expand All @@ -120,15 +126,15 @@ def process_data(self, z_cutoffs, node_features, edge_features, max_atoms=None,
node_features=node_features[j],
edge_features=edge_features[j],
max_atoms=max_atoms,
encoder=encoder
encoder=encoder,
)

# Convert to Data object
data_obj = Data(
x=feat_dict["node_tensor"],
edge_index=feat_dict["edge_indices"],
edge_attr=feat_dict["edge_tensor"],
y=torch.Tensor([self.map_name_prop[name]])
y=torch.Tensor([self.map_name_prop[name]]),
)
data_objects.append(data_obj)

Expand All @@ -137,32 +143,34 @@ def process_data(self, z_cutoffs, node_features, edge_features, max_atoms=None,

# Save name-index dataframe
self.df_name_idx.to_csv(self.index_path, index=None)

def len(self):
"""Return size of the dataset."""
return len(self.names)

def get(self, i):
    """Load and return the contents of ``data_{i}.pt`` from the processed path."""
    data_file = self.processed_path / f"data_{i}.pt"
    return torch.load(data_file)

def processed_status(self):
    """Check if the dataset is processed.

    The dataset counts as processed when the file at ``self.index_path``
    exists.

    Returns
    -------
    bool
        True if the index file exists, False otherwise.
    """
    # Path.exists() already yields a bool, so no if/else wrapper is needed.
    return Path(self.index_path).exists()


class AtomsDatapoints:
"""Class to hold atomic structures as a datapoints (without targets).
The main difference between this class and AtomsDataset is that this is
initialized with a list of atoms objects (as opposed to a directory with
files containing atomic structures) without any targets specified. This is
useful to make predictions on atomic structures for which true target values
are not known, i.e., previously unseen structures.
"""

def __init__(self, atoms):
"""Initialize an AtomsDatapoint.
Expand All @@ -181,8 +189,14 @@ def __init__(self, atoms):
self.atoms = atoms
self.data = []

def process_data(self, z_cutoffs, node_features, edge_features, max_atoms=None,
encoder=OneHotEncoder()):
def process_data(
self,
z_cutoffs,
node_features,
edge_features,
max_atoms=None,
encoder=OneHotEncoder(),
):
"""Process list of Atoms objects into PyTorch Data and save.
Each atomic structure in the root directory is partitioned based on the
Expand All @@ -198,12 +212,12 @@ def process_data(self, z_cutoffs, node_features, edge_features, max_atoms=None,
length of z_cutoffs.
node_features: list[list]
List of lists of node featurization methods to be used for each
partition. For e.g., specify [["atomic_number", "dband_center"],
partition. For e.g., specify [["atomic_number", "dband_center"],
["atomic_number", "reactivity"], ["atomic_number", "reactivity"]] for
a typical bulk + surface + adsorbate partition.
edge_features: list[list]
List of lists of edge featurization methods to be used for each
partition. For e.g., specify [["bulk_bond_distance"],
partition. For e.g., specify [["bulk_bond_distance"],
["surface_bond_distance"], ["adsorbate_bond_distance"]] for
a typical bulk + surface + adsorbate partition.
max_atoms: int (default is None)
Expand All @@ -226,7 +240,7 @@ def process_data(self, z_cutoffs, node_features, edge_features, max_atoms=None,
node_features=node_features[j],
edge_features=edge_features[j],
max_atoms=max_atoms,
encoder=encoder
encoder=encoder,
)

# Convert to Data object
Expand All @@ -243,12 +257,13 @@ def process_data(self, z_cutoffs, node_features, edge_features, max_atoms=None,
def len(self):
"""Return size of the dataset."""
return len(self.data)

def get(self, i):
"""Fetch the processed graph(s) at the i-th index."""
data_objects = self.data[i]
return data_objects


def load_dataset(root, prop_csv, process_dict=None):
"""Load an AtomsDataset at the path given by root.
Expand All @@ -273,7 +288,7 @@ def load_dataset(root, prop_csv, process_dict=None):
node_features, edge_features, max_atoms (optional), encoder (optional).
Refer to the documentation of process_atoms for more information regarding
these parameters.
Returns
-------
dataset: AtomsDataset
Expand All @@ -285,6 +300,7 @@ def load_dataset(root, prop_csv, process_dict=None):

return dataset


def load_datapoints(atoms, process_dict):
"""Load AtomsDatapoints for the provided ase.Atoms or list of ase.Atoms.
Expand All @@ -297,13 +313,13 @@ def load_datapoints(atoms, process_dict):
----------
atoms: ase.Atoms object or a list of ase.Atoms objects
Structures for which predictions are to be made.
process_dict: dict
process_dict: dict
Parameters to process the provided Atoms objects into graphs.
This should contain the following keys: z_cutoffs, node_features,
edge_features, max_atoms (optional), encoder (optional). Refer to the
This should contain the following keys: z_cutoffs, node_features,
edge_features, max_atoms (optional), encoder (optional). Refer to the
documentation of process_atoms for more information regarding these
parameters.
Returns
-------
datapoints: AtomsDatapoints
Expand All @@ -312,7 +328,7 @@ def load_datapoints(atoms, process_dict):
datapoints = AtomsDatapoints(atoms)
if process_dict is not None:
datapoints.process_data(**process_dict)

return datapoints


Expand All @@ -323,7 +339,7 @@ def load_datapoints(atoms, process_dict):

# Create dataset
dataset = AtomsDataset(data_root_path, prop_csv_path)
# dataset.process_data(z_cutoffs=[13., 20.],
# dataset.process_data(z_cutoffs=[13., 20.],
# node_features=[
# ["atomic_number", "dband_center"],
# ["atomic_number", "reactivity"],
Expand All @@ -340,15 +356,17 @@ def load_datapoints(atoms, process_dict):
# Create datapoint
atoms = read(data_root_path / "Pt_3_Rh_9_-7-7-S.cif")
datapoint = AtomsDatapoints(atoms)
datapoint.process_data(z_cutoffs=[13., 20.],
node_features=[
["atomic_number", "dband_center"],
["atomic_number", "reactivity"],
["atomic_number", "reactivity"],
],
edge_features=[
["bulk_bond_distance"],
["surface_bond_distance"],
["adsorbate_bond_distance"],
])
print(datapoint.get(0))
datapoint.process_data(
z_cutoffs=[13.0, 20.0],
node_features=[
["atomic_number", "dband_center"],
["atomic_number", "reactivity"],
["atomic_number", "reactivity"],
],
edge_features=[
["bulk_bond_distance"],
["surface_bond_distance"],
["adsorbate_bond_distance"],
],
)
print(datapoint.get(0))

0 comments on commit e3c2241

Please sign in to comment.