Skip to content

Commit

Permalink
update main
Browse files Browse the repository at this point in the history
  • Loading branch information
Han Zhu committed May 3, 2024
1 parent 0107945 commit eaba7dc
Show file tree
Hide file tree
Showing 6 changed files with 339 additions and 3 deletions.
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -158,3 +158,5 @@ cython_debug/
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/

CryoREAD_Predict_Result/
3 changes: 0 additions & 3 deletions .gitmodules

This file was deleted.

23 changes: 23 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,2 +1,25 @@
# AutoClass3D
A deep-learning-based tool that automatically selects the best reconstructed 3D maps from a group of maps.


## Installation
Clone the repository:
```
git clone github.itap.purdue.edu/kiharalab/AutoClass3D
```
Create the conda environment:
```
conda env create -f environment.yml
```

## Arguments (All required)
```
-F: Path to the folder that contains all the input MRC files
-G: The GPU ID to use for the computation
-J: The Job Name
```

## Example
```
python main.py -F ./Class3D/job052 -G 1 -J job052_select
```
24 changes: 24 additions & 0 deletions environment.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
name: AutoClass3D
channels:
- pytorch
- nvidia
- conda-forge
- anaconda
- defaults
dependencies:
- cudatoolkit=11.8
- pip
- python=3.10
- pytorch
- pytorch-cuda=11.8
- pip:
- biopython
- numba
- numpy
- scipy
- tqdm
- ortools
- mrcfile
- progress
- numba-progress
- loguru
98 changes: 98 additions & 0 deletions main.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
from pathlib import Path
import subprocess
from glob import glob
from loguru import logger
import argparse
import os
from map_utils import calc_map_ccc, calculate_fsc

if __name__ == "__main__":
    # Script entry point: run CryoREAD on every MRC map found in the job folder,
    # score each map against its predicted protein mask, and log the maps ranked
    # by real-space cross-correlation (best first).

    # Local import: only this entry point needs it.
    import threading

    def _log_stderr(stream):
        """Forward every line of the child's stderr to the log as an error."""
        for line in stream:
            logger.error(line.strip())

    logger.add("AutoClass3D.log")

    parser = argparse.ArgumentParser()
    parser.add_argument("-F", type=str, help="Input job folder path containing all MRC files", required=True)
    parser.add_argument("-G", type=str, help="GPU ID to use for prediction", required=True)
    parser.add_argument("-J", type=str, help="Job name / output folder name", required=True)

    args = parser.parse_args()

    logger.info("Input job folder path: " + args.F)

    CRYOREAD_PATH = "./CryoREAD/main.py"

    # glob() returns files in arbitrary order; sort so runs and logs are reproducible.
    mrc_files = sorted(glob(args.F + "/*.mrc"))

    OUTDIR = str(Path("./CryoREAD_Predict_Result").absolute() / args.J)
    os.makedirs(OUTDIR, exist_ok=True)

    logger.info("MRC files count: " + str(len(mrc_files)))
    logger.info("MRC files path:\n" + "\n".join(mrc_files))
    if not mrc_files:
        logger.warning("No MRC files found in " + args.F)

    # run CryoREAD

    map_list = []

    for mrc_file in mrc_files:

        curr_out_dir = OUTDIR + "/" + Path(mrc_file).stem.split(".")[0]

        seg_map_path = curr_out_dir + "/input_segment.mrc"
        prot_prob_path = curr_out_dir + "/mask_protein.mrc"

        # Reuse results of a previous run when both output maps already exist.
        if not (os.path.exists(seg_map_path) and os.path.exists(prot_prob_path)):
            logger.info(f"Running CryoREAD prediction on {mrc_file}")
            cmd = [
                "python",
                CRYOREAD_PATH,
                "--mode=0",
                f"-F={mrc_file}",
                "--contour=0",
                f"--gpu={args.G}",
                "--batch_size=4",
                "--prediction_only",
                "--resolution=8.0",
                f"--output={curr_out_dir}",
            ]
            process = subprocess.Popen(
                cmd,
                shell=False,
                stdout=subprocess.PIPE,
                stderr=subprocess.PIPE,
                universal_newlines=True,
            )

            # Drain stderr concurrently with stdout. Reading one pipe to EOF before
            # touching the other can deadlock once the child fills the unread pipe.
            stderr_thread = threading.Thread(target=_log_stderr, args=(process.stderr,), daemon=True)
            stderr_thread.start()

            for line in process.stdout:
                logger.info(line.strip())  # Log stdout

            stderr_thread.join()
            rc = process.wait()
            if rc != 0:
                logger.error(f"CryoREAD exited with code {rc} for {mrc_file}")

        # A failed prediction leaves no maps to score; skip it instead of crashing
        # so the remaining maps are still ranked.
        if not (os.path.exists(seg_map_path) and os.path.exists(prot_prob_path)):
            logger.error(f"CryoREAD output maps are missing for {mrc_file}; skipping this map")
            continue

        real_space_cc = calc_map_ccc(seg_map_path, prot_prob_path)[0]
        x, fsc, cutoff_05, cutoff_0143 = calculate_fsc(seg_map_path, prot_prob_path)
        map_list.append([mrc_file, real_space_cc, cutoff_05])

    # Rank by real-space CC, best map first; the top entry is highlighted in blue.
    map_list.sort(key=lambda x: x[1], reverse=True)
    for idx, (mrc_file, real_space_cc, golden_standard_fsc) in enumerate(map_list):
        if idx == 0:
            logger.opt(colors=True).info(
                "Input map: "
                + f"<blue>{mrc_file}</blue>"
                + ", Real space CC: "
                + f"<blue>{real_space_cc:.4f}</blue>"
                + ", Golden standard FSC: "
                + f"<blue>{golden_standard_fsc:.4f}</blue>"
            )
        else:
            logger.opt(colors=True).info(
                "Input map: "
                + mrc_file
                + ", Real space CC: "
                + f"{real_space_cc:.4f}"
                + ", Golden standard FSC: "
                + f"{golden_standard_fsc:.4f}"
            )
192 changes: 192 additions & 0 deletions map_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,192 @@
import mrcfile
import numpy as np


def calc_map_ccc(input_mrc, input_pred, center=True, overlap_only=False):
    """
    Calculate the cross-correlation coefficient (CCC) and overlap fraction of two density maps.

    The CCC is sum(a * b) / sqrt(sum(a**2) * sum(b**2)); with ``center=True`` the
    mean of each map is subtracted first.

    Parameters:
        input_mrc (str or numpy.ndarray): Path to the MRC file, or an already loaded density array.
        input_pred (str or numpy.ndarray): Path to the prediction MRC file, or an already loaded
            density array of the same shape.
        center (bool, optional): If True, subtract each map's mean before correlating. Defaults to True.
        overlap_only (bool, optional): If True, correlate only the voxels where the product of
            the two maps is positive. Defaults to False.
    Returns:
        float: The calculated CCC, or 0.0 when it is undefined (no voxels selected, or one of
            the maps has zero norm after optional centering).
        float: The overlap fraction: voxels with a positive product divided by the smaller of the
            two non-zero voxel counts, or 0.0 when either map has no non-zero voxels.
    """

    def _load(source):
        # Work in float64: integer-mode maps would overflow in the products below,
        # and float32 sums lose precision on large volumes.
        if isinstance(source, np.ndarray):
            return np.asarray(source, dtype=np.float64)
        with mrcfile.open(source) as mrc:
            return np.array(mrc.data, dtype=np.float64)

    mrc_data = _load(input_mrc)
    pred_data = _load(input_pred)

    # Smaller of the two non-zero voxel counts, used to normalize the overlap
    min_count = min(np.count_nonzero(mrc_data), np.count_nonzero(pred_data))

    # Voxels where both maps are non-zero with the same sign
    overlap = mrc_data * pred_data > 0.0
    overlap_count = np.count_nonzero(overlap)

    if overlap_only:
        mrc_data = mrc_data[overlap]
        pred_data = pred_data[overlap]

    # Calculate the overlap fraction (guard against maps with no non-zero voxels)
    overlap_percent = float(overlap_count) / float(min_count) if min_count > 0 else 0.0

    # Nothing left to correlate (e.g. overlap_only with no overlapping voxels)
    if mrc_data.size == 0:
        return 0.0, overlap_percent

    # Center the data if specified
    if center:
        mrc_data = mrc_data - np.mean(mrc_data)
        pred_data = pred_data - np.mean(pred_data)

    # Calculate the CCC; a zero-norm map has no defined correlation, so report 0.0
    # instead of NaN (NaN would break sorting of the scores by callers).
    denominator = np.sqrt(np.sum(mrc_data**2) * np.sum(pred_data**2))
    if not denominator > 0.0:
        return 0.0, overlap_percent
    ccc = float(np.sum(mrc_data * pred_data) / denominator)

    return ccc, overlap_percent


"""Compute FSC between two volumes, adapted from cryodrgn"""
import numpy as np
import torch
from torch.fft import fftshift, ifftshift, fft2, fftn, ifftn


def normalize(img, mean=0, std=None, std_n=None):
    """Return ``(img - mean) / std``.

    When ``std`` is None it is estimated with ``torch.std`` over the first
    ``std_n`` entries along the leading axis (all entries when ``std_n`` is None).
    """
    # Estimating std on a leading slice keeps memory use down for large inputs.
    scale = torch.std(img[:std_n, ...]) if std is None else std
    return (img - mean) / scale


def fft2_center(img):
    """Centered 2-D FFT over the last two axes: shift, transform, shift back."""
    last_two_axes = (-1, -2)
    shifted = fftshift(img, dim=last_two_axes)
    spectrum = fft2(shifted)
    return fftshift(spectrum, dim=last_two_axes)


def fftn_center(img):
    """Centered N-D FFT over all axes: shift, transform, shift back."""
    shifted = fftshift(img)
    spectrum = fftn(shifted)
    return fftshift(spectrum)


def ifftn_center(img):
    """Centered N-D inverse FFT over all axes.

    Accepts a torch tensor or a NumPy array; a NumPy array is first converted
    to a complex torch tensor.
    """
    if isinstance(img, np.ndarray):
        # A complex ndarray cannot be typecast with torch.Tensor(array) directly,
        # so build the complex tensor from its real and imaginary parts.
        real_part = torch.Tensor(img.real)
        imag_part = torch.Tensor(img.imag)
        img = torch.complex(real_part, imag_part)
    return ifftshift(ifftn(ifftshift(img)))


def ht2_center(img):
    """Centered 2-D Hartley transform: real minus imaginary part of ``fft2_center(img)``."""
    spectrum = fft2_center(img)
    real_part, imag_part = spectrum.real, spectrum.imag
    return real_part - imag_part


def htn_center(img):
    """Centered N-D Hartley transform: real minus imaginary part of the centered FFT."""
    spectrum = fftn(fftshift(img))
    spectrum = fftshift(spectrum)
    return spectrum.real - spectrum.imag


def iht2_center(img):
    """Centered 2-D inverse Hartley transform.

    Applies ``fft2_center``, divides by the number of elements in the last two
    axes, and returns the real part minus the imaginary part.
    """
    spectrum = fft2_center(img)
    height, width = spectrum.shape[-2], spectrum.shape[-1]
    spectrum /= width * height
    return spectrum.real - spectrum.imag


def ihtn_center(img):
    """Centered N-D inverse Hartley transform.

    Takes the centered FFT over all axes, divides by the total element count,
    and returns the real part minus the imaginary part.
    """
    spectrum = fftshift(fftn(fftshift(img)))
    # Total number of elements, computed on the same device as the data.
    n_elements = torch.prod(torch.tensor(spectrum.shape, device=spectrum.device))
    spectrum /= n_elements
    return spectrum.real - spectrum.imag


def symmetrize_ht(ht):
    """Extend a D x D transform (or a batch of them) to (D+1) x (D+1).

    The extra last row and last column are copies of the first row and first
    column, and the new corner is a copy of the first corner. A 2-D input is
    treated as a batch of one; a batch of one is returned as a 2-D tensor.
    D must be even.
    """
    batch = ht[None, ...] if ht.ndim == 2 else ht
    assert batch.ndim == 3
    count = batch.shape[0]
    size = batch.shape[-1]

    padded = torch.empty((count, size + 1, size + 1), dtype=batch.dtype, device=batch.device)
    padded[:, :size, :size] = batch

    assert size % 2 == 0
    padded[:, size, :] = padded[:, 0, :]  # last row mirrors the first row
    padded[:, :, size] = padded[:, :, 0]  # last column mirrors the first column
    padded[:, size, size] = padded[:, 0, 0]  # last corner mirrors the first corner

    return padded[0, ...] if count == 1 else padded


def calculate_fsc(vol1_f, vol2_f, Apix=1.0, output_f=None):
    """
    Compute the Fourier Shell Correlation (FSC) curve between two MRC volumes.

    Both volumes are zero-padded at the end of each axis to a common cube with an
    even edge length D, transformed with a centered FFT, and correlated shell by
    shell.

    Parameters:
        vol1_f (str): Path to the first MRC volume.
        vol2_f (str): Path to the second MRC volume (must have the same shape).
        Apix (float, optional): Scale applied to 1 / frequency when reporting the
            cutoffs (presumably the pixel size in Angstrom). Defaults to 1.0.
        output_f (str, optional): If given, the (frequency, FSC) table is written
            there with ``np.savetxt``.
    Returns:
        numpy.ndarray: Shell frequencies ``arange(D // 2) / D`` (cycles per pixel).
        numpy.ndarray: FSC value per shell; the first entry is fixed to 1.0.
        float: ``Apix / frequency`` at the first shell with FSC < 0.5.
        float: ``Apix / frequency`` at the first shell with FSC < 0.143.
        When the curve never drops below a threshold, the Nyquist value
        ``2 * Apix`` is returned for that cutoff.
    """

    import mrcfile

    with mrcfile.open(vol1_f, permissive=True) as v1:
        vol1 = v1.data.copy()
    with mrcfile.open(vol2_f, permissive=True) as v2:
        vol2 = v2.data.copy()

    assert vol1.shape == vol2.shape, f"Volume shapes differ: {vol1.shape} vs {vol2.shape}"

    # pad if non-cubic
    padding_xyz = np.max(vol1.shape) - np.asarray(vol1.shape)
    pad_width = ((0, padding_xyz[0]), (0, padding_xyz[1]), (0, padding_xyz[2]))

    vol1 = np.pad(vol1, pad_width, mode="constant")
    vol2 = np.pad(vol2, pad_width, mode="constant")

    # the centered frequency grid below requires an even edge length
    if vol1.shape[0] % 2 != 0:
        vol1 = np.pad(vol1, ((0, 1), (0, 1), (0, 1)), mode="constant")
        vol2 = np.pad(vol2, ((0, 1), (0, 1), (0, 1)), mode="constant")

    vol1 = torch.from_numpy(vol1).to(torch.float32)
    vol2 = torch.from_numpy(vol2).to(torch.float32)

    # radial distance of every voxel from the centre of the shifted spectrum
    D = vol1.shape[0]
    x = np.arange(-D // 2, D // 2)
    x2, x1, x0 = np.meshgrid(x, x, x, indexing="ij")
    coords = np.stack((x0, x1, x2), -1)
    r = (coords**2).sum(-1) ** 0.5

    assert r[D // 2, D // 2, D // 2] == 0.0, r[D // 2, D // 2, D // 2]

    vol1 = fftn_center(vol1)
    vol2 = fftn_center(vol2)

    prev_mask = np.zeros((D, D, D), dtype=bool)
    fsc = [1.0]
    for i in range(1, D // 2):
        # shell i holds the voxels with radius < i not claimed by a smaller shell
        mask = r < i
        shell = np.where(mask & np.logical_not(prev_mask))
        v1 = vol1[shell]
        v2 = vol2[shell]
        p = np.vdot(v1, v2) / (np.vdot(v1, v1) * np.vdot(v2, v2)) ** 0.5
        fsc.append(float(p.real))
        prev_mask = mask
    fsc = np.asarray(fsc)
    x = np.arange(D // 2) / D

    res = np.stack((x, fsc), 1)
    if output_f:
        np.savetxt(output_f, res)

    def _cutoff(threshold):
        # np.where returns a tuple, which is always truthy; test the index array itself.
        below = np.where(fsc < threshold)[0]
        if below.size == 0:
            # The curve never crosses the threshold: report the Nyquist limit.
            return 2.0 * Apix
        return 1 / x[below[0]] * Apix

    cutoff_05 = _cutoff(0.5)
    cutoff_0143 = _cutoff(0.143)

    return x, fsc, cutoff_05, cutoff_0143

0 comments on commit eaba7dc

Please sign in to comment.