diff --git a/.gitignore b/.gitignore
index 68bc17f..7415b93 100644
--- a/.gitignore
+++ b/.gitignore
@@ -158,3 +158,5 @@ cython_debug/
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/
+
+CryoREAD_Predict_Result/
\ No newline at end of file
diff --git a/.gitmodules b/.gitmodules
deleted file mode 100644
index 96335fb..0000000
--- a/.gitmodules
+++ /dev/null
@@ -1,3 +0,0 @@
-[submodule "CryoREAD"]
- path = CryoREAD
- url = https://github.com/kiharalab/CryoREAD.git
diff --git a/README.md b/README.md
index 9ff227e..0b97136 100644
--- a/README.md
+++ b/README.md
@@ -1,2 +1,25 @@
# AutoClass3D
A deep learning based tool to automatically select the best reconstructed 3D maps within a group of maps.
+
+
+## Installation
+clone the repository:
+```
+git clone github.itap.purdue.edu/kiharalab/AutoClass3D
+```
+create conda environment:
+```
+conda env create -f environment.yml
+```
+
+## Arguments (All required)
+```
+-F: Path to the folder that contain all the input mrc files
+-G: The GPU ID to use for the computation
+-J: The Job Name
+```
+
+## Example
+```
+python main.py -F ./Class3D/job052 -G 1 -J job052_select
+```
\ No newline at end of file
diff --git a/environment.yml b/environment.yml
new file mode 100644
index 0000000..3b71572
--- /dev/null
+++ b/environment.yml
@@ -0,0 +1,24 @@
+name: AutoClass3D
+channels:
+ - pytorch
+ - nvidia
+ - conda-forge
+ - anaconda
+ - defaults
+dependencies:
+ - cudatoolkit=11.8
+ - pip
+ - python=3.10
+ - pytorch
+ - pytorch-cuda=11.8
+ - pip:
+ - biopython
+ - numba
+ - numpy
+ - scipy
+ - tqdm
+ - ortools
+ - mrcfile
+ - progress
+ - numba-progress
+ - loguru
diff --git a/main.py b/main.py
new file mode 100644
index 0000000..53d43f9
--- /dev/null
+++ b/main.py
@@ -0,0 +1,98 @@
+from pathlib import Path
+import subprocess
+from glob import glob
+from loguru import logger
+import argparse
+import os
+from map_utils import calc_map_ccc, calculate_fsc
+
+if __name__ == "__main__":
+
+ logger.add("AutoClass3D.log")
+
+ parser = argparse.ArgumentParser()
+ parser.add_argument("-F", type=str, help="Input job folder path containing all MRC files", required=True)
+ parser.add_argument("-G", type=str, help="GPU ID to use for prediction", required=True)
+ parser.add_argument("-J", type=str, help="Job name / output folder name", required=True)
+
+ args = parser.parse_args()
+
+ logger.info("Input job folder path: " + args.F)
+
+ CRYOREAD_PATH = "./CryoREAD/main.py"
+
+ mrc_files = glob(args.F + "/*.mrc")
+
+ OUTDIR = str(Path("./CryoREAD_Predict_Result").absolute() / args.J)
+ os.makedirs(OUTDIR, exist_ok=True)
+
+ # print(mrc_files)
+ logger.info("MRC files count: " + str(len(mrc_files)))
+ logger.info("MRC files path:\n" + "\n".join(mrc_files))
+
+ # run CryoREAD
+
+ map_list = []
+
+ for mrc_file in mrc_files:
+
+ curr_out_dir = OUTDIR + "/" + Path(mrc_file).stem.split(".")[0]
+
+ seg_map_path = curr_out_dir + "/input_segment.mrc"
+ prot_prob_path = curr_out_dir + "/mask_protein.mrc"
+
+ if not os.path.exists(seg_map_path) or not os.path.exists(prot_prob_path):
+ logger.info(f"Running CryoREAD prediction on {mrc_file}")
+ cmd = [
+ "python",
+ CRYOREAD_PATH,
+ "--mode=0",
+ f"-F={mrc_file}",
+ "--contour=0",
+ f"--gpu={args.G}",
+ f"--batch_size=4",
+ f"--prediction_only",
+ f"--resolution=8.0",
+ f"--output={curr_out_dir}",
+ ]
+ process = subprocess.Popen(cmd, shell=False, stdout=subprocess.PIPE, stderr=subprocess.PIPE, universal_newlines=True)
+
+ while True:
+ output = process.stdout.readline()
+ if output == "" and process.poll() is not None:
+ break
+ if output:
+ logger.info(output.strip()) # Log stdout
+
+ rc = process.poll()
+ while True:
+ err = process.stderr.readline()
+ if err == "" and process.poll() is not None:
+ break
+ if err:
+ logger.error(err.strip()) # Log stderr
+
+ real_space_cc = calc_map_ccc(seg_map_path, prot_prob_path)[0]
+ x, fsc, cutoff_05, cutoff_0143 = calculate_fsc(seg_map_path, prot_prob_path)
+ map_list.append([mrc_file, real_space_cc, cutoff_05])
+
+ map_list.sort(key=lambda x: x[1], reverse=True)
+ for idx, (mrc_file, real_space_cc, golden_standard_fsc) in enumerate(map_list):
+ if idx == 0:
+ logger.opt(colors=True).info(
+ "Input map: "
+ + f"{mrc_file}"
+ + ", Real space CC: "
+ + f"{real_space_cc:.4f}"
+ + ", Golden standard FSC: "
+ + f"{golden_standard_fsc:.4f}"
+ )
+ else:
+ logger.opt(colors=True).info(
+ "Input map: "
+ + mrc_file
+ + ", Real space CC: "
+ + f"{real_space_cc:.4f}"
+ + ", Golden standard FSC: "
+ + f"{golden_standard_fsc:.4f}"
+ )
diff --git a/map_utils.py b/map_utils.py
new file mode 100644
index 0000000..e018bb0
--- /dev/null
+++ b/map_utils.py
@@ -0,0 +1,192 @@
+import mrcfile
+import numpy as np
+
+
+def calc_map_ccc(input_mrc, input_pred, center=True, overlap_only=False):
+ """
+ Calculate the Concordance Correlation Coefficient (CCC) and overlap percentage of two input MRC files.
+
+ Parameters:
+ input_mrc (str): Path to the MRC file.
+ input_pred (str): Path to the prediction MRC file.
+ center (bool, optional): If True, center the data. Defaults to True.
+
+ Returns:
+ float: The calculated CCC.
+ float: The overlap percentage.
+ """
+ # Open the MRC files and copy their data
+ with mrcfile.open(input_mrc) as mrc:
+ mrc_data = mrc.data.copy()
+ with mrcfile.open(input_pred) as mrc:
+ pred_data = mrc.data.copy()
+
+ # mrc_data = np.where(mrc_data > 1e-8, mrc_data, 0.0)
+ # pred_data = np.where(pred_data > 1e-8, pred_data, 0.0)
+
+ # Determine the minimum count of non-zero values
+ min_count = np.min([np.count_nonzero(mrc_data), np.count_nonzero(pred_data)])
+
+ # Calculate the overlap of non-zero values
+ overlap = mrc_data * pred_data > 0.0
+
+ if overlap_only:
+ mrc_data = mrc_data[overlap]
+ pred_data = pred_data[overlap]
+
+ # Center the data if specified
+ if center:
+ mrc_data = mrc_data - np.mean(mrc_data)
+ pred_data = pred_data - np.mean(pred_data)
+
+ # Calculate the overlap percentage
+ overlap_percent = np.sum(overlap) / min_count
+
+ # Calculate the CCC
+ ccc = np.sum(mrc_data * pred_data) / np.sqrt(np.sum(mrc_data**2) * np.sum(pred_data**2))
+
+ return ccc, overlap_percent
+
+
+"""Compute FSC between two volumes, adapted from cryodrgn"""
+import numpy as np
+import torch
+from torch.fft import fftshift, ifftshift, fft2, fftn, ifftn
+
+
+def normalize(img, mean=0, std=None, std_n=None):
+ if std is None:
+ # Since std is a memory consuming process, use the first std_n samples for std determination
+ std = torch.std(img[:std_n, ...])
+
+ # logger.info(f"Normalized by {mean} +/- {std}")
+ return (img - mean) / std
+
+
+def fft2_center(img):
+ return fftshift(fft2(fftshift(img, dim=(-1, -2))), dim=(-1, -2))
+
+
+def fftn_center(img):
+ return fftshift(fftn(fftshift(img)))
+
+
+def ifftn_center(img):
+ if isinstance(img, np.ndarray):
+ # Note: We can't just typecast a complex ndarray using torch.Tensor(array) !
+ img = torch.complex(torch.Tensor(img.real), torch.Tensor(img.imag))
+ x = ifftshift(img)
+ y = ifftn(x)
+ z = ifftshift(y)
+ return z
+
+
+def ht2_center(img):
+ _img = fft2_center(img)
+ return _img.real - _img.imag
+
+
+def htn_center(img):
+ _img = fftshift(fftn(fftshift(img)))
+ return _img.real - _img.imag
+
+
+def iht2_center(img):
+ img = fft2_center(img)
+ img /= img.shape[-1] * img.shape[-2]
+ return img.real - img.imag
+
+
+def ihtn_center(img):
+ img = fftshift(img)
+ img = fftn(img)
+ img = fftshift(img)
+ img /= torch.prod(torch.tensor(img.shape, device=img.device))
+ return img.real - img.imag
+
+
+def symmetrize_ht(ht):
+ if ht.ndim == 2:
+ ht = ht[np.newaxis, ...]
+ assert ht.ndim == 3
+ n = ht.shape[0]
+
+ D = ht.shape[-1]
+ sym_ht = torch.empty((n, D + 1, D + 1), dtype=ht.dtype, device=ht.device)
+ sym_ht[:, 0:-1, 0:-1] = ht
+
+ assert D % 2 == 0
+ sym_ht[:, -1, :] = sym_ht[:, 0, :] # last row is the first row
+ sym_ht[:, :, -1] = sym_ht[:, :, 0] # last col is the first col
+ sym_ht[:, -1, -1] = sym_ht[:, 0, 0] # last corner is first corner
+
+ if n == 1:
+ sym_ht = sym_ht[0, ...]
+
+ return sym_ht
+
+
+def calculate_fsc(vol1_f, vol2_f, Apix=1.0, output_f=None):
+
+ import mrcfile
+
+ with mrcfile.open(vol1_f, permissive=True) as v1:
+ vol1 = v1.data.copy()
+ with mrcfile.open(vol2_f, permissive=True) as v2:
+ vol2 = v2.data.copy()
+
+ assert vol1.shape == vol2.shape
+
+ # pad if non-cubic
+ padding_xyz = np.max(vol1.shape) - vol1.shape
+
+ vol1 = np.pad(vol1, ((0, padding_xyz[0]), (0, padding_xyz[1]), (0, padding_xyz[2])), mode="constant")
+ vol2 = np.pad(vol2, ((0, padding_xyz[0]), (0, padding_xyz[1]), (0, padding_xyz[2])), mode="constant")
+
+ if vol1.shape[0] % 2 != 0:
+ vol1 = np.pad(vol1, ((0, 1), (0, 1), (0, 1)), mode="constant")
+ vol2 = np.pad(vol2, ((0, 1), (0, 1), (0, 1)), mode="constant")
+
+ vol1 = torch.from_numpy(vol1).to(torch.float32)
+ vol2 = torch.from_numpy(vol2).to(torch.float32)
+
+ D = vol1.shape[0]
+ x = np.arange(-D // 2, D // 2)
+ x2, x1, x0 = np.meshgrid(x, x, x, indexing="ij")
+ coords = np.stack((x0, x1, x2), -1)
+ r = (coords**2).sum(-1) ** 0.5
+
+ assert r[D // 2, D // 2, D // 2] == 0.0, r[D // 2, D // 2, D // 2]
+
+ vol1 = fftn_center(vol1)
+ vol2 = fftn_center(vol2)
+
+ prev_mask = np.zeros((D, D, D), dtype=bool)
+ fsc = [1.0]
+ for i in range(1, D // 2):
+ mask = r < i
+ shell = np.where(mask & np.logical_not(prev_mask))
+ v1 = vol1[shell]
+ v2 = vol2[shell]
+ p = np.vdot(v1, v2) / (np.vdot(v1, v1) * np.vdot(v2, v2)) ** 0.5
+ fsc.append(float(p.real))
+ prev_mask = mask
+ fsc = np.asarray(fsc)
+ x = np.arange(D // 2) / D
+
+ res = np.stack((x, fsc), 1)
+ if output_f:
+ np.savetxt(output_f, res)
+ else:
+ # logger.info(res)
+ pass
+
+ w = np.where(fsc < 0.5)
+ if w:
+ cutoff_05 = 1 / x[w[0][0]] * Apix
+
+ w = np.where(fsc < 0.143)
+ if w:
+ cutoff_0143 = 1 / x[w[0][0]] * Apix
+
+ return x, fsc, cutoff_05, cutoff_0143