From d89769720f3d75be2a3de67ca7a819b51abc4303 Mon Sep 17 00:00:00 2001
From: Han Zhu
Date: Tue, 24 Sep 2024 14:41:02 -0400
Subject: [PATCH] use relative path to script

---
 .gitignore |  13 ++-
 contour.py | 239 +++++++++++++++++++++++++++++++++++++++++++++++++++++
 main.py    |   5 +-
 3 files changed, 253 insertions(+), 4 deletions(-)
 create mode 100644 contour.py

diff --git a/.gitignore b/.gitignore
index cb4e637..ebb8870 100644
--- a/.gitignore
+++ b/.gitignore
@@ -161,4 +161,15 @@ cython_debug/
 CryoREAD_Predict_Result/
 output_test/
-examples/
\ No newline at end of file
+examples/
+Tutorial5.0/
+test_outdir/
+test_cryoread_26363/
+job019_run_it050_class001_test/
+job052_run_it100_class003/
+22647_test_gmm/
+26363_test_gmm/
+CryoREAD/test_cryoread_22647/
+CryoREAD/test_cryoread_26363/
+CryoREAD/job019_run_it050_class001/
+CryoREAD/job052_run_it100_class003/
\ No newline at end of file
diff --git a/contour.py b/contour.py
new file mode 100644
index 0000000..4a34411
--- /dev/null
+++ b/contour.py
@@ -0,0 +1,239 @@
+from pathlib import Path
+
+import mrcfile
+import numpy as np
+from sklearn import mixture
+import os
+
+from skimage.morphology import ball, opening
+from skimage.filters import rank
+from skimage.util import img_as_ubyte
+from scipy.ndimage import zoom
+
+import matplotlib.pyplot as plt
+
+# resolve CryoREAD relative to this file instead of the working directory
+# CRYOREAD_PATH = "./CryoREAD/main.py"
+CURR_SCRIPT_PATH = Path(__file__).absolute().parent
+CRYOREAD_PATH = CURR_SCRIPT_PATH / "CryoREAD" / "main.py"
+
+
+def save_mrc(orig_map_path, data, out_path):
+    # copy voxel size, start indices, origin, and axis mapping from the
+    # original map so the new volume stays aligned with it
+    with mrcfile.open(orig_map_path, permissive=True) as orig_map:
+        with mrcfile.new(out_path, data=data.astype(np.float32), overwrite=True) as mrc:
+            mrc.voxel_size = orig_map.voxel_size
+            mrc.header.nxstart = orig_map.header.nxstart
+            mrc.header.nystart = orig_map.header.nystart
+            mrc.header.nzstart = orig_map.header.nzstart
+            mrc.header.origin = orig_map.header.origin
+            mrc.header.mapc = orig_map.header.mapc
+            mrc.header.mapr = orig_map.header.mapr
+            mrc.header.maps = orig_map.header.maps
+            mrc.update_header_stats()
+            mrc.update_header_from_data()
+            mrc.flush()
+
+
+def gen_features(map_array):
+    non_zero_data = map_array[np.nonzero(map_array)]
+    data_normalized = (map_array - map_array.min()) * 2 / (map_array.max() - map_array.min()) - 1
+    local_grad_norm = rank.gradient(img_as_ubyte(data_normalized), ball(3))
+    local_grad_norm = local_grad_norm[np.nonzero(map_array)]
+    local_grad_norm = (local_grad_norm - local_grad_norm.min()) / (local_grad_norm.max() - local_grad_norm.min())
+    non_zero_data_normalized = (non_zero_data - non_zero_data.min()) / (non_zero_data.max() - non_zero_data.min())
+
+    # stack the flattened data and gradient
+    local_grad_norm = np.reshape(local_grad_norm, (-1, 1))
+    non_zero_data_normalized = np.reshape(non_zero_data_normalized, (-1, 1))
+    features = np.hstack((non_zero_data_normalized, local_grad_norm))
+
+    return features
+
+
+def gmm_mask(input_map_path, output_folder, num_components=2, use_grad=False, n_init=1, plot_all=False):
+    print("input_map_path", input_map_path)
+    print("output_folder", output_folder)
+
+    os.makedirs(output_folder, exist_ok=True)
+
+    print("Opening map file")
+
+    with mrcfile.open(input_map_path, permissive=True) as mrc:
+        map_data = mrc.data.copy()
+
+    print("Input map shape:", map_data.shape)
+
+    non_zero_data = map_data[np.nonzero(map_data)]
+
+    data_normalized = (map_data - map_data.min()) * 2 / (map_data.max() - map_data.min()) - 1
+
+    print("Non-zero data shape:", non_zero_data.shape)
+
+    # zoom (downsample) to handle large maps
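+    # The zoom factor below comes from a voxel budget: scaling each axis by
+    # (2e6 / N) ** (1 / 3) scales the total voxel count by roughly 2e6 / N,
+    # so the GMM is fit on about 2e6 voxels rather than the full grid.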
+    if len(non_zero_data) >= 5e6:
+        print("Map is too large, resizing...")
+
+        # resample
+        zoom_factor = (2e6 / len(non_zero_data)) ** (1 / 3)
+        print("Resample with zoom factor:", zoom_factor)
+
+        map_data_zoomed = zoom(map_data, zoom_factor, order=3, mode="grid-constant", grid_mode=False)
+        data_normalized_zoomed = (map_data_zoomed - map_data_zoomed.min()) * 2 / (map_data_zoomed.max() - map_data_zoomed.min()) - 1
+        non_zero_data_zoomed = map_data_zoomed[np.nonzero(map_data_zoomed)]
+
+        print("Shape after resample:", data_normalized_zoomed.shape)
+
+        print("Calculating gradient")
+        local_grad_norm_zoomed = rank.gradient(img_as_ubyte(data_normalized_zoomed), ball(3))
+        local_grad_norm_zoomed = local_grad_norm_zoomed[np.nonzero(map_data_zoomed)]
+        local_grad_norm_zoomed = (local_grad_norm_zoomed - local_grad_norm_zoomed.min()) / (
+            local_grad_norm_zoomed.max() - local_grad_norm_zoomed.min()
+        )
+
+        non_zero_data_normalized_zoomed = (non_zero_data_zoomed - non_zero_data_zoomed.min()) / (
+            non_zero_data_zoomed.max() - non_zero_data_zoomed.min()
+        )
+
+        local_grad_norm_zoomed = np.reshape(local_grad_norm_zoomed, (-1, 1))
+        non_zero_data_normalized_zoomed = np.reshape(non_zero_data_normalized_zoomed, (-1, 1))
+        # print(non_zero_data_normalized_zoomed.shape, local_grad_norm_zoomed.shape)
+        data_zoomed = np.hstack((non_zero_data_normalized_zoomed, local_grad_norm_zoomed))
+
+    # calculate local gradient norm (rank.gradient is a local morphological gradient)
+    local_grad_norm = rank.gradient(img_as_ubyte(data_normalized), ball(3))
+    local_grad_norm = local_grad_norm[np.nonzero(map_data)]
+
+    # min-max normalization
+    local_grad_norm = (local_grad_norm - local_grad_norm.min()) / (local_grad_norm.max() - local_grad_norm.min())
+    non_zero_data_normalized = (non_zero_data - non_zero_data.min()) / (non_zero_data.max() - non_zero_data.min())
+
+    # stack the flattened data and gradient
+    local_grad_norm = np.reshape(local_grad_norm, (-1, 1))
+    non_zero_data_normalized = np.reshape(non_zero_data_normalized, (-1, 1))
+    data = np.hstack((non_zero_data_normalized, local_grad_norm))
+
+    print("Fitting GMM")
+
+    # fit the GMM
+    g = mixture.BayesianGaussianMixture(n_components=num_components, max_iter=500, n_init=n_init, tol=1e-2)
+
+    # fit on the downsampled features when the map was resized, but always
+    # predict on the full-resolution features
+    if use_grad:
+        data_to_fit = data_zoomed if len(non_zero_data) >= 5e6 else data
+    else:
+        data_to_fit = non_zero_data_normalized_zoomed if len(non_zero_data) >= 5e6 else non_zero_data_normalized
+    print("Fitting feature shape:", data_to_fit.shape)
+    g.fit(data_to_fit)
+    print("Predicting, feature shape:", data.shape)
+    preds = g.predict(data)
+
+    if plot_all:
+        fig, ax = plt.subplots(1, 1, figsize=(10, 3))
+        all_datas = []
+        for pred in np.unique(preds):
+            mask = np.zeros_like(map_data)
+            mask[np.nonzero(map_data)] = preds == pred
+            masked_map_data = map_data * mask
+            new_data_non_zero = masked_map_data[np.nonzero(masked_map_data)]
+            all_datas.append(new_data_non_zero.flatten())
+            mean = np.mean(new_data_non_zero)
+            ax.axvline(mean, linestyle="--", color="k", label=f"Mean_{pred}")
+        # label only the components that actually appear in preds, so the
+        # label list matches all_datas in length
+        labels = [f"Component {i}" for i in np.unique(preds)]
+        ax.hist(all_datas, alpha=0.5, bins=256, density=True, log=True, label=labels, stacked=True)
+        ax.set_yscale("log")
+        ax.legend(loc="upper right")
+        ax.set_xlabel("Map Density Value")
+        ax.set_ylabel("Density (log scale)")
+        ax.set_title("Stacked Histogram by Component")
+        fig.tight_layout()
+        # print("Saving figure to", os.path.join(output_folder, "hist_by_component.png"))
+        fig.savefig(os.path.join(output_folder, Path(input_map_path).stem + "_hist_by_components.png"))
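+
+    # g.means_ has shape (n_components, n_features); feature 0 is the
+    # min-max-normalized density, so the component whose mean density is
+    # closest to 0 is treated as background noise below.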
+
+    # choose the component whose normalized density mean is closest to 0
+    ind_noise = np.argmin(np.abs(g.means_[:, 0].flatten()))
+    # ind_max = np.argmax(g.covariances_[:, 0, 0].flatten())
+
+    print("Means: ", g.means_.shape, g.means_[:, 0], g.means_[:, 1])
+
+    # print("Variances: ", g.covariances_.shape, g.covariances_[ind_noise, 0, 0])
+    # print("Std: ", np.sqrt(g.covariances_[ind_noise, 0, 0]))
+
+    # generate a mask that keeps every component except the noise component
+    mask = np.zeros_like(map_data)
+    mask[np.nonzero(map_data)] = preds != ind_noise
+
+    noise_comp = map_data[np.nonzero(map_data)][preds == ind_noise]
+    revised_contour = np.max(noise_comp)
+
+    prot_comp = map_data[np.nonzero(map_data)][preds != ind_noise]  # kept for inspection
+
+    print("Revised contour:", revised_contour)
+    print("Remaining mask region size in voxels:", np.count_nonzero(mask))
+
+    # morphological opening removes small isolated specks from the mask
+    mask = opening(mask.astype(bool), ball(3))
+    masked_map_data = map_data * mask
+    new_data_non_zero = masked_map_data[np.nonzero(masked_map_data)]
+
+    # recompute density + gradient features on the masked map
+    new_fit_data = gen_features(masked_map_data)
+    print("Fitting feature shape:", new_fit_data.shape)
+
+    # fit the GMM again on the new data
+    g2 = mixture.BayesianGaussianMixture(n_components=2, max_iter=500, n_init=n_init, tol=1e-2)
+    g2.fit(new_fit_data)
+
+    # predict the new data; this time take the lowest-variance component as noise
+    new_preds = g2.predict(new_fit_data)
+    # ind_noise_second = np.argmin(np.abs(g2.means_[:, 0].flatten()))
+    ind_noise_second = np.argmin(g2.covariances_[:, 0, 0].flatten())
+    noise_comp_2 = masked_map_data[np.nonzero(masked_map_data)][new_preds == ind_noise_second]
+    prot_comp_2 = masked_map_data[np.nonzero(masked_map_data)][new_preds != ind_noise_second]  # kept for inspection
+    # revised_contour_2 = np.median(prot_comp_2)
+    revised_contour_2 = np.max(noise_comp_2)
+
+    print("Revised contour (high):", revised_contour_2)
+
+    # save the new data
+    save_mrc(input_map_path, masked_map_data, os.path.join(output_folder, Path(input_map_path).stem + "_mask.mrc"))
+
+    mask_percent = np.count_nonzero(masked_map_data > 1e-8) / np.count_nonzero(map_data > 1e-8)
+
+    # plot the histogram
+    fig, ax = plt.subplots(figsize=(10, 2))
+    ax.hist(non_zero_data.flatten(), alpha=0.5, bins=256, density=False, log=True, label="Original")
+    ax.hist(new_data_non_zero.flatten(), alpha=0.5, bins=256, density=False, log=True, label="Masked")
+    ax.axvline(revised_contour, label="Revised Contour")
+    ax.axvline(revised_contour_2, label="Revised Contour (High)", linestyle="dashed")
+    ax.legend()
+    plt.title(input_map_path)
+    plt.savefig(os.path.join(output_folder, Path(input_map_path).stem + "_hist_overall.png"))
+
+    out_txt = os.path.join(output_folder, Path(input_map_path).stem + "_revised_contour.txt")
+
+    with open(out_txt, "w") as f:
+        f.write(f"Revised contour: {revised_contour}\n")
+        f.write(f"Revised contour (high): {revised_contour_2}\n")
+        f.write(f"Masked percentage: {mask_percent}\n")
+
+    # return revised contour level and mask percent
+    return revised_contour, mask_percent
+
+
+if __name__ == "__main__":
+    import argparse
+
+    parser = argparse.ArgumentParser()
+    parser.add_argument("-i", "--input_map_path", type=str, default=None)
+    parser.add_argument("-o", "--output_folder", type=str, default=None)
+    parser.add_argument("-p", "--plot_all", action="store_true")
+    parser.add_argument("-n", "--num_components", type=int, default=2)
+    parser.add_argument("-r", "--resolution", type=float, default=None)
+    args = parser.parse_args()
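+
+    # use_grad and n_init are fixed here; --resolution is parsed but not
+    # currently consumed by gmm_mask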
+    revised_contour, mask_percent = gmm_mask(
+        input_map_path=args.input_map_path,
+        output_folder=args.output_folder,
+        num_components=args.num_components,
+        use_grad=True,
+        n_init=3,
+        plot_all=args.plot_all,
+    )
diff --git a/main.py b/main.py
index 3933bcc..9d9735f 100644
--- a/main.py
+++ b/main.py
@@ -20,9 +20,8 @@
     logger.info("Input job folder path: ", args.F)

-    CRYOREAD_PATH = "./CryoREAD/main.py"
-
-    # mrc_files = glob(args.F + "/*.mrc")
+    CURR_SCRIPT_PATH = Path(__file__).absolute().parent
+    CRYOREAD_PATH = CURR_SCRIPT_PATH / "CryoREAD" / "main.py"
     mrc_files = args.F

     # check if files exists
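
A minimal usage sketch for the new standalone script (illustrative paths,
assuming it is run from the repository root; the flags are those defined in
the argparse block above):

    python contour.py -i input_map.mrc -o output_dir -p -n 2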