From eaba7dc8210258fe3679b28cb396e3235d121741 Mon Sep 17 00:00:00 2001 From: Han Zhu Date: Fri, 3 May 2024 17:18:19 -0400 Subject: [PATCH] update main --- .gitignore | 2 + .gitmodules | 3 - README.md | 23 ++++++ environment.yml | 24 ++++++ main.py | 98 ++++++++++++++++++++++++ map_utils.py | 192 ++++++++++++++++++++++++++++++++++++++++++++++++ 6 files changed, 339 insertions(+), 3 deletions(-) delete mode 100644 .gitmodules create mode 100644 environment.yml create mode 100644 main.py create mode 100644 map_utils.py diff --git a/.gitignore b/.gitignore index 68bc17f..7415b93 100644 --- a/.gitignore +++ b/.gitignore @@ -158,3 +158,5 @@ cython_debug/ # and can be added to the global gitignore or merged into this file. For a more nuclear # option (not recommended) you can uncomment the following to ignore the entire idea folder. #.idea/ + +CryoREAD_Predict_Result/ \ No newline at end of file diff --git a/.gitmodules b/.gitmodules deleted file mode 100644 index 96335fb..0000000 --- a/.gitmodules +++ /dev/null @@ -1,3 +0,0 @@ -[submodule "CryoREAD"] - path = CryoREAD - url = https://github.com/kiharalab/CryoREAD.git diff --git a/README.md b/README.md index 9ff227e..0b97136 100644 --- a/README.md +++ b/README.md @@ -1,2 +1,25 @@ # AutoClass3D A deep learning based tool to automatically select the best reconstructed 3D maps within a group of maps. 
"""Select the best reconstructed 3D map among a folder of MRC files.

Runs CryoREAD segmentation on every input map, then ranks the maps by the
real-space correlation between the segmented map and the predicted protein
probability mask (FSC=0.5 cutoff is reported alongside).
"""
from pathlib import Path
import subprocess
from glob import glob
from loguru import logger
import argparse
import os

from map_utils import calc_map_ccc, calculate_fsc

if __name__ == "__main__":

    logger.add("AutoClass3D.log")

    parser = argparse.ArgumentParser()
    parser.add_argument("-F", type=str, help="Input job folder path containing all MRC files", required=True)
    parser.add_argument("-G", type=str, help="GPU ID to use for prediction", required=True)
    parser.add_argument("-J", type=str, help="Job name / output folder name", required=True)

    args = parser.parse_args()

    logger.info("Input job folder path: " + args.F)

    CRYOREAD_PATH = "./CryoREAD/main.py"

    mrc_files = glob(args.F + "/*.mrc")

    # Fail early with a clear message instead of silently producing an
    # empty ranking when the folder has no .mrc files.
    if not mrc_files:
        logger.error("No .mrc files found in " + args.F)
        raise SystemExit(1)

    OUTDIR = str(Path("./CryoREAD_Predict_Result").absolute() / args.J)
    os.makedirs(OUTDIR, exist_ok=True)

    logger.info("MRC files count: " + str(len(mrc_files)))
    logger.info("MRC files path:\n" + "\n".join(mrc_files))

    # Run CryoREAD on each map and collect (map, real-space CC, FSC cutoff).
    map_list = []

    for mrc_file in mrc_files:

        curr_out_dir = OUTDIR + "/" + Path(mrc_file).stem.split(".")[0]

        seg_map_path = curr_out_dir + "/input_segment.mrc"
        prot_prob_path = curr_out_dir + "/mask_protein.mrc"

        # Only (re-)run CryoREAD when its outputs are missing (cheap resume).
        if not os.path.exists(seg_map_path) or not os.path.exists(prot_prob_path):
            logger.info(f"Running CryoREAD prediction on {mrc_file}")
            cmd = [
                "python",
                CRYOREAD_PATH,
                "--mode=0",
                f"-F={mrc_file}",
                "--contour=0",
                f"--gpu={args.G}",
                "--batch_size=4",
                "--prediction_only",
                "--resolution=8.0",
                f"--output={curr_out_dir}",
            ]
            process = subprocess.Popen(
                cmd,
                shell=False,
                stdout=subprocess.PIPE,
                stderr=subprocess.PIPE,
                universal_newlines=True,
            )

            # communicate() drains both pipes concurrently; reading stdout
            # line-by-line to EOF while stderr stays piped (as the previous
            # version did) can deadlock once the child fills the stderr
            # buffer.  Output is logged after the run completes.
            stdout, stderr = process.communicate()
            for line in stdout.splitlines():
                if line.strip():
                    logger.info(line.strip())  # Log stdout
            for line in stderr.splitlines():
                if line.strip():
                    logger.error(line.strip())  # Log stderr
            if process.returncode != 0:
                logger.error(f"CryoREAD exited with code {process.returncode} for {mrc_file}")

        real_space_cc = calc_map_ccc(seg_map_path, prot_prob_path)[0]
        x, fsc, cutoff_05, cutoff_0143 = calculate_fsc(seg_map_path, prot_prob_path)
        map_list.append([mrc_file, real_space_cc, cutoff_05])

    # Rank by real-space correlation, best map first.  The original code
    # duplicated the identical log statement in an idx==0 branch; one loop
    # emits the same output.
    map_list.sort(key=lambda entry: entry[1], reverse=True)
    for mrc_file, real_space_cc, golden_standard_fsc in map_list:
        logger.opt(colors=True).info(
            "Input map: "
            + mrc_file
            + ", Real space CC: "
            + f"{real_space_cc:.4f}"
            + ", Golden standard FSC: "
            + f"{golden_standard_fsc:.4f}"
        )
def calc_map_ccc(input_mrc, input_pred, center=True, overlap_only=False):
    """Compute the cross-correlation and overlap percentage of two MRC maps.

    Parameters:
        input_mrc (str): Path to the reference MRC file.
        input_pred (str): Path to the prediction MRC file.
        center (bool, optional): If True, subtract each map's mean before
            correlating (Pearson-style CC). Defaults to True.
        overlap_only (bool, optional): If True, correlate only the voxels
            where both maps are positive. Defaults to False.

    Returns:
        float: The cross-correlation coefficient (0.0 if either map is empty).
        float: The overlap percentage (0.0 if either map is all zeros).
    """
    # Open the MRC files and copy their data
    with mrcfile.open(input_mrc) as mrc:
        mrc_data = mrc.data.copy()
    with mrcfile.open(input_pred) as mrc:
        pred_data = mrc.data.copy()

    # Denominator of the overlap ratio: the smaller non-zero voxel count.
    min_count = np.min([np.count_nonzero(mrc_data), np.count_nonzero(pred_data)])

    # Voxels where both maps have positive density.
    overlap = mrc_data * pred_data > 0.0

    if overlap_only:
        mrc_data = mrc_data[overlap]
        pred_data = pred_data[overlap]

    # Center the data if specified
    if center:
        mrc_data = mrc_data - np.mean(mrc_data)
        pred_data = pred_data - np.mean(pred_data)

    # Guard against an all-zero map (the original divided by zero here).
    overlap_percent = np.sum(overlap) / min_count if min_count > 0 else 0.0

    # Guard the norm product the same way to avoid NaN on empty input.
    denom = np.sqrt(np.sum(mrc_data**2) * np.sum(pred_data**2))
    ccc = np.sum(mrc_data * pred_data) / denom if denom > 0 else 0.0

    return ccc, overlap_percent


# Compute FSC between two volumes, adapted from cryodrgn.
import numpy as np
import torch
from torch.fft import fftshift, ifftshift, fft2, fftn, ifftn


def normalize(img, mean=0, std=None, std_n=None):
    """Return (img - mean) / std.

    When std is not given, it is estimated from the first std_n samples
    (computing std over the full tensor can be memory-hungry).
    """
    if std is None:
        std = torch.std(img[:std_n, ...])
    return (img - mean) / std


def fft2_center(img):
    """Centered 2D FFT: zero frequency in the middle of the last two axes."""
    return fftshift(fft2(fftshift(img, dim=(-1, -2))), dim=(-1, -2))


def fftn_center(img):
    """Centered N-D FFT: zero frequency in the middle of the volume."""
    return fftshift(fftn(fftshift(img)))


def ifftn_center(img):
    """Centered N-D inverse FFT; accepts complex numpy arrays or tensors."""
    if isinstance(img, np.ndarray):
        # Note: We can't just typecast a complex ndarray using torch.Tensor(array)!
        img = torch.complex(torch.Tensor(img.real), torch.Tensor(img.imag))
    x = ifftshift(img)
    y = ifftn(x)
    z = ifftshift(y)
    return z


def ht2_center(img):
    """Centered 2D Hartley transform (real - imag of the centered FFT)."""
    _img = fft2_center(img)
    return _img.real - _img.imag


def htn_center(img):
    """Centered N-D Hartley transform."""
    _img = fftshift(fftn(fftshift(img)))
    return _img.real - _img.imag


def iht2_center(img):
    """Centered 2D inverse Hartley transform (normalized by image area)."""
    img = fft2_center(img)
    img /= img.shape[-1] * img.shape[-2]
    return img.real - img.imag


def ihtn_center(img):
    """Centered N-D inverse Hartley transform (normalized by element count)."""
    img = fftshift(img)
    img = fftn(img)
    img = fftshift(img)
    img /= torch.prod(torch.tensor(img.shape, device=img.device))
    return img.real - img.imag


def symmetrize_ht(ht):
    """Pad a (D, D) or (n, D, D) Hartley transform to (D+1, D+1) per image.

    The added last row/column duplicates the first row/column so that the
    transform is explicitly symmetric about the center. D must be even.
    """
    if ht.ndim == 2:
        ht = ht[np.newaxis, ...]
    assert ht.ndim == 3
    n = ht.shape[0]

    D = ht.shape[-1]
    sym_ht = torch.empty((n, D + 1, D + 1), dtype=ht.dtype, device=ht.device)
    sym_ht[:, 0:-1, 0:-1] = ht

    assert D % 2 == 0
    sym_ht[:, -1, :] = sym_ht[:, 0, :]  # last row is the first row
    sym_ht[:, :, -1] = sym_ht[:, :, 0]  # last col is the first col
    sym_ht[:, -1, -1] = sym_ht[:, 0, 0]  # last corner is first corner

    if n == 1:
        sym_ht = sym_ht[0, ...]

    return sym_ht
def calculate_fsc(vol1_f, vol2_f, Apix=1.0, output_f=None):
    """Compute the Fourier Shell Correlation (FSC) between two MRC volumes.

    Parameters:
        vol1_f (str): Path to the first MRC volume.
        vol2_f (str): Path to the second MRC volume.
        Apix (float, optional): Angstroms per voxel. Defaults to 1.0.
        output_f (str, optional): If given, write the (freq, fsc) table there.

    Returns:
        tuple: (x, fsc, cutoff_05, cutoff_0143) where x are spatial
        frequencies (1/voxel), fsc is the per-shell correlation, and the
        cutoffs are resolutions (Angstrom) at FSC=0.5 and FSC=0.143.
        If the curve never drops below a threshold, the corresponding
        cutoff defaults to the Nyquist limit, 2 * Apix.
    """

    import mrcfile

    with mrcfile.open(vol1_f, permissive=True) as v1:
        vol1 = v1.data.copy()
    with mrcfile.open(vol2_f, permissive=True) as v2:
        vol2 = v2.data.copy()

    assert vol1.shape == vol2.shape

    # Pad to a cube if the box is non-cubic.
    padding_xyz = np.max(vol1.shape) - vol1.shape

    vol1 = np.pad(vol1, ((0, padding_xyz[0]), (0, padding_xyz[1]), (0, padding_xyz[2])), mode="constant")
    vol2 = np.pad(vol2, ((0, padding_xyz[0]), (0, padding_xyz[1]), (0, padding_xyz[2])), mode="constant")

    # Even box size is required by the centered-FFT convention below.
    if vol1.shape[0] % 2 != 0:
        vol1 = np.pad(vol1, ((0, 1), (0, 1), (0, 1)), mode="constant")
        vol2 = np.pad(vol2, ((0, 1), (0, 1), (0, 1)), mode="constant")

    vol1 = torch.from_numpy(vol1).to(torch.float32)
    vol2 = torch.from_numpy(vol2).to(torch.float32)

    D = vol1.shape[0]
    x = np.arange(-D // 2, D // 2)
    x2, x1, x0 = np.meshgrid(x, x, x, indexing="ij")
    coords = np.stack((x0, x1, x2), -1)
    # Radial distance of each voxel from the box center.
    r = (coords**2).sum(-1) ** 0.5

    assert r[D // 2, D // 2, D // 2] == 0.0, r[D // 2, D // 2, D // 2]

    vol1 = fftn_center(vol1)
    vol2 = fftn_center(vol2)

    # Correlate shell by shell, growing the radius one voxel at a time.
    prev_mask = np.zeros((D, D, D), dtype=bool)
    fsc = [1.0]
    for i in range(1, D // 2):
        mask = r < i
        shell = np.where(mask & np.logical_not(prev_mask))
        v1 = vol1[shell]
        v2 = vol2[shell]
        p = np.vdot(v1, v2) / (np.vdot(v1, v1) * np.vdot(v2, v2)) ** 0.5
        fsc.append(float(p.real))
        prev_mask = mask
    fsc = np.asarray(fsc)
    x = np.arange(D // 2) / D  # spatial frequency in 1/voxel

    res = np.stack((x, fsc), 1)
    if output_f:
        np.savetxt(output_f, res)

    # BUG FIX: np.where(...) returns a one-element *tuple*, so the original
    # `if w:` was always True and `w[0][0]` raised IndexError (or left the
    # cutoff unbound) when the curve never dropped below the threshold.
    # Check the index array's size instead and fall back to Nyquist.
    cutoff_05 = 2 * Apix
    w = np.where(fsc < 0.5)[0]
    if w.size > 0:
        cutoff_05 = 1 / x[w[0]] * Apix

    cutoff_0143 = 2 * Apix
    w = np.where(fsc < 0.143)[0]
    if w.size > 0:
        cutoff_0143 = 1 / x[w[0]] * Apix

    return x, fsc, cutoff_05, cutoff_0143