From 5949b7d8dccd903cb4a833537c1c7d82624d50b3 Mon Sep 17 00:00:00 2001
From: Han Zhu
Date: Tue, 27 Aug 2024 17:33:02 -0400
Subject: [PATCH] update gmm contour

---
 CryoREAD/main.py |   2 +-
 README.md        |   2 +-
 gmm_contour.py   | 146 ++++++++++++++++++++++++++++------------------
 3 files changed, 88 insertions(+), 62 deletions(-)

diff --git a/CryoREAD/main.py b/CryoREAD/main.py
index d4b3374..c3eb2ce 100644
--- a/CryoREAD/main.py
+++ b/CryoREAD/main.py
@@ -132,7 +132,7 @@ def init_save_path(origin_map_path):
         mask_map_path = os.path.join(save_path, "mask_protein.mrc")
         from data_processing.Gen_MaskDRNA_map import Gen_MaskProtein_map

-        Gen_MaskProtein_map(chain_prob, cur_map_path, mask_map_path, params["contour"], threshold=0.3)
+        Gen_MaskProtein_map(chain_prob, cur_map_path, mask_map_path, params["contour"], threshold=0.2)
     if params["prediction_only"]:
         print(
             "Our prediction results are saved in %s with mrc format for visualization check." % save_path_2nd_stage)

diff --git a/README.md b/README.md
index 88b6c56..f04f0cb 100644
--- a/README.md
+++ b/README.md
@@ -39,7 +39,7 @@ python main.py -F ./Class3D/job052/class1.mrc ./Class3D/job052/class2.mrc ./Clas
 -n: Number of intializations (Optional, 3 by default)
 ```

-## Example for generating
+## Example for Auto Contouring

 ```
 python gmm_contour.py -i ./Class3D/job052/class1.mrc -o ./output_folder -p

diff --git a/gmm_contour.py b/gmm_contour.py
index 6db66bc..61e417e 100644
--- a/gmm_contour.py
+++ b/gmm_contour.py
@@ -29,15 +29,26 @@ def save_mrc(orig_map_path, data, out_path):
         mrc.flush()


-def gmm_mask(input_map_path, output_folder, num_components=3, use_grad=False, n_init=1, plot_all=False):
+def gen_features(map_array):
+    non_zero_data = map_array[np.nonzero(map_array)]
+    data_normalized = (map_array - map_array.min()) * 2 / (map_array.max() - map_array.min()) - 1
+    local_grad_norm = rank.gradient(img_as_ubyte(data_normalized), ball(3))
+    local_grad_norm = local_grad_norm[np.nonzero(map_array)]
+    local_grad_norm = (local_grad_norm - local_grad_norm.min()) / (local_grad_norm.max() - local_grad_norm.min())
+    non_zero_data_normalized = (non_zero_data - non_zero_data.min()) / (non_zero_data.max() - non_zero_data.min())
+
+    # stack the flattened data and gradient
+    local_grad_norm = np.reshape(local_grad_norm, (-1, 1))
+    non_zero_data_normalized = np.reshape(non_zero_data_normalized, (-1, 1))
+    features = np.hstack((non_zero_data_normalized, local_grad_norm))
+
+    return features
+
+
+def gmm_mask(input_map_path, output_folder, num_components=2, use_grad=False, n_init=1, plot_all=False):
     print("input_map_path", input_map_path)
     print("output_folder", output_folder)
-    # if os.path.exists(output_folder):
-    #     # print("Output file already exists")
-    #     raise ValueError("Output FOLD already exists")
-    #     return None, None
-
     os.makedirs(output_folder, exist_ok=True)

     print("Opening map file")

@@ -55,15 +66,14 @@ def gmm_mask(input_map_path, output_folder, num_components=3, use_grad=False, n_

     # Zooming to handling large maps
     if len(non_zero_data) >= 5e6:
-        print("Map is too large")
+        print("Map is too large, resizing...")
         # resample
         zoom_factor = (2e6 / len(non_zero_data)) ** (1 / 3)
         print("Resample with zoom factor:", zoom_factor)
-        map_data_zoomed = zoom(map_data, zoom_factor, order=3, mode="grid-constant", grid_mode=True)
-        data_normalized_zoomed = (map_data_zoomed - map_data_zoomed.min()) * 2 / (
-            map_data_zoomed.max() - map_data_zoomed.min()) - 1
+        map_data_zoomed = zoom(map_data, zoom_factor, order=3, mode="grid-constant", grid_mode=False)
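+        # scale to [-1, 1], the float range img_as_ubyte accepts, before taking the local gradient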
+        data_normalized_zoomed = (map_data_zoomed - map_data_zoomed.min()) * 2 / (map_data_zoomed.max() - map_data_zoomed.min()) - 1
         non_zero_data_zoomed = map_data_zoomed[np.nonzero(map_data_zoomed)]
         print("Shape after resample:", data_normalized_zoomed.shape)

@@ -72,11 +82,11 @@ def gmm_mask(input_map_path, output_folder, num_components=3, use_grad=False, n_
         local_grad_norm_zoomed = rank.gradient(img_as_ubyte(data_normalized_zoomed), ball(3))
         local_grad_norm_zoomed = local_grad_norm_zoomed[np.nonzero(map_data_zoomed)]
         local_grad_norm_zoomed = (local_grad_norm_zoomed - local_grad_norm_zoomed.min()) / (
-                local_grad_norm_zoomed.max() - local_grad_norm_zoomed.min()
+            local_grad_norm_zoomed.max() - local_grad_norm_zoomed.min()
         )

         non_zero_data_normalized_zoomed = (non_zero_data_zoomed - non_zero_data_zoomed.min()) / (
-                non_zero_data_zoomed.max() - non_zero_data_zoomed.min()
+            non_zero_data_zoomed.max() - non_zero_data_zoomed.min()
         )

         local_grad_norm_zoomed = np.reshape(local_grad_norm_zoomed, (-1, 1))

@@ -100,7 +110,7 @@ def gmm_mask(input_map_path, output_folder, num_components=3, use_grad=False, n_

     print("Fitting GMM")
     # fit the GMM
-    g = mixture.BayesianGaussianMixture(n_components=num_components, max_iter=200, n_init=n_init, tol=1e-2)
+    g = mixture.BayesianGaussianMixture(n_components=num_components, max_iter=500, n_init=n_init, tol=1e-2)

     if use_grad:
         data_to_fit = data_zoomed if len(non_zero_data) >= 5e6 else data

@@ -113,74 +123,82 @@ def gmm_mask(input_map_path, output_folder, num_components=3, use_grad=False, n_

     if plot_all:
         fig, ax = plt.subplots(1, 1, figsize=(10, 3))
+        all_datas = []
         for pred in np.unique(preds):
             mask = np.zeros_like(map_data)
             mask[np.nonzero(map_data)] = preds == pred
-            new_data = map_data * mask
-            new_data_non_zero = new_data[np.nonzero(new_data)]
-            ax.hist(new_data_non_zero.flatten(), alpha=0.5, bins=256, density=False, log=True, label=f"Masked_{pred}")
-            # plot mean
-            # mean = g.means_[pred, 0]
-            # ax.axvline(mean, label=f"Mean_{pred}")
-            ax.legend(loc="upper right")
+            masked_map_data = map_data * mask
+            new_data_non_zero = masked_map_data[np.nonzero(masked_map_data)]
+            all_datas.append(new_data_non_zero.flatten())
+            mean = np.mean(new_data_non_zero)
+            ax.axvline(mean, linestyle="--", color="k", label=f"Mean_{pred}")
+        labels = [f"Component {i}" for i in range(num_components)]
+        ax.hist(all_datas, alpha=0.5, bins=256, density=True, log=True, label=labels, stacked=True)
+        ax.set_yscale("log")
+        ax.legend(loc="upper right")
+        ax.set_xlabel("Map Density Value")
+        ax.set_ylabel("Density (log scale)")
+        ax.set_title("Stacked Histogram by Component")
         fig.tight_layout()
         # print("Saving figure to", os.path.join(output_folder, "hist_by_component.png"))
         fig.savefig(os.path.join(output_folder, Path(input_map_path).stem + "_hist_by_components.png"))

-    # generate a mask to keep only the component with the largest variance
-    mask = np.zeros_like(map_data)
-    # mask[np.nonzero(masked_prot_data)] = (preds == np.argmax(g.means_[:, 0].flatten()))
-
-    # ind = np.argpartition(g.means_[:, 0].flatten(), -3)[-3:]
-    # choose ind that is closest to 0
-    ind = np.argmin(np.abs(g.means_[:, 0].flatten()))
+    # choose ind that is closest to 0, and ind that has the highest mean
+    ind_noise = np.argmin(np.abs(g.means_[:, 0].flatten()))
+    # ind_max = np.argmax(g.covariances_[:, 0, 0].flatten())

-    print("ind to remove", ind)
+    print("Means: ", g.means_.shape, g.means_[:, 0], g.means_[:, 1])

-    # mask[np.nonzero(map_data)] = preds in ind
-    print(
-        "Means: ",
-        g.means_.shape,
-        g.means_[:, 0],
-    )
-    print("Variances: ", g.covariances_.shape, g.covariances_[:, 0, 0])
+    # print("Variances: ", g.covariances_.shape, g.covariances_[ind_noise, 0, 0])
+    # print("Std: ", np.sqrt(g.covariances_[ind_noise, 0, 0]))

-    # mask[np.nonzero(map_data)] = (preds == ind[0]) | (preds == ind[1]) | (preds == ind[2])
-    mask[np.nonzero(map_data)] = (preds != ind)
+    # generate a mask to keep only the component without the noise
+    mask = np.zeros_like(map_data)
+    mask[np.nonzero(map_data)] = preds != ind_noise

-    noise_comp = map_data[np.nonzero(map_data)][preds == ind]
-    # 98 percentile
-    # revised_contour = np.percentile(noise_comp, 98)
+    noise_comp = map_data[np.nonzero(map_data)][preds == ind_noise]
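+    # the brightest voxel still assigned to the noise component sets the (low) revised contour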
     revised_contour = np.max(noise_comp)
-    print("Revised contour", revised_contour)
+    prot_comp = map_data[np.nonzero(map_data)][preds != ind_noise]

-    print("Remaining mask region size in voxels", np.count_nonzero(mask))
+    print("Revised contour:", revised_contour)
+    print("Remaining mask region size in voxels:", np.count_nonzero(mask))

-    # use opening to remove small artifacts
+    # use opening to remove small holes
    mask = opening(mask.astype(bool), ball(3))
-    new_data = map_data * mask
-    new_data_non_zero = new_data[np.nonzero(new_data)]
+    masked_map_data = map_data * mask
+    new_data_non_zero = masked_map_data[np.nonzero(masked_map_data)]

-    # save the new data
-    save_mrc(input_map_path, new_data,
-             os.path.join(output_folder, Path(input_map_path).stem + "_mask.mrc"))
+    # calculate new gradient norm
+    new_fit_data = gen_features(masked_map_data)
+    print("Fitting feature shape:", new_fit_data.shape)

-    # if use_grad == True:
-    #     # use 1 sigma cutoff from the masked data
-    #     # revised_contour = np.mean(new_data_non_zero) + np.std(new_data_non_zero)
-    #     # use median cutoff from the masked data, could be other percentile
-    #     revised_contour = np.percentile(new_data_non_zero, 50)
-    # else:
-    #     revised_contour = np.min(new_data[new_data > 1e-8])
+    # fit the GMM again on the new data
+    g2 = mixture.BayesianGaussianMixture(n_components=2, max_iter=500, n_init=n_init, tol=1e-2)
+    g2.fit(new_fit_data)

-    mask_percent = np.count_nonzero(new_data > 1e-8) / np.count_nonzero(map_data > 1e-8)
+    # predict the new data
+    new_preds = g2.predict(new_fit_data)
+    # ind_noise_second = np.argmin(np.abs(g2.means_[:, 0].flatten()))
+    ind_noise_second = np.argmin(g2.covariances_[:, 0, 0].flatten())
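+    # assumption: the tightest (smallest-variance) component is residual background inside the mask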
+    noise_comp_2 = masked_map_data[np.nonzero(masked_map_data)][new_preds == ind_noise_second]
+    prot_comp_2 = masked_map_data[np.nonzero(masked_map_data)][new_preds != ind_noise_second]
+    # revised_contour_2 = np.median(prot_comp_2)
+    revised_contour_2 = np.max(noise_comp_2)
+
+    print("Revised contour (high):", revised_contour_2)
+
+    # save the new data
+    save_mrc(input_map_path, masked_map_data, os.path.join(output_folder, Path(input_map_path).stem + "_mask.mrc"))
+
+    mask_percent = np.count_nonzero(masked_map_data > 1e-8) / np.count_nonzero(map_data > 1e-8)

     # plot the histogram
     fig, ax = plt.subplots(figsize=(10, 2))
     ax.hist(non_zero_data.flatten(), alpha=0.5, bins=256, density=False, log=True, label="Original")
     ax.hist(new_data_non_zero.flatten(), alpha=0.5, bins=256, density=False, log=True, label="Masked")
     ax.axvline(revised_contour, label="Revised Contour")
+    ax.axvline(revised_contour_2, label="Revised Contour (High)", linestyle="dashed")
     ax.legend()
     plt.title(input_map_path)
     plt.savefig(os.path.join(output_folder, Path(input_map_path).stem + "_hist_overall.png"))

@@ -188,7 +206,9 @@ def gmm_mask(input_map_path, output_folder, num_components=3, use_grad=False, n_

     out_txt = os.path.join(output_folder, Path(input_map_path).stem + "_revised_contour.txt")
     with open(out_txt, "w") as f:
-        f.write(f"{revised_contour} {mask_percent}")
+        f.write(f"Revised contour: {revised_contour}\n")
+        f.write(f"Revised contour (high): {revised_contour_2}\n")
+        f.write(f"Masked percentage: {mask_percent}\n")

     # return revised contour level and mask percent
     return revised_contour, mask_percent

@@ -201,7 +221,13 @@ def gmm_mask(input_map_path, output_folder, num_components=3, use_grad=False, n_
     parser.add_argument("-i", "--input_map_path", type=str, default=None)
     parser.add_argument("-o", "--output_folder", type=str, default=None)
     parser.add_argument("-p", "--plot_all", action="store_true")
-    parser.add_argument("-n", "--num_components", type=int, default=3)
+    parser.add_argument("-n", "--num_components", type=int, default=2)

     args = parser.parse_args()
-    revised_contour, mask_percent = gmm_mask(input_map_path=args.input_map_path, output_folder=args.output_folder,
-                                             num_components=3, use_grad=True, n_init=3, plot_all=args.plot_all)
+    revised_contour, mask_percent = gmm_mask(
+        input_map_path=args.input_map_path,
+        output_folder=args.output_folder,
+        num_components=args.num_components,
+        use_grad=True,
+        n_init=3,
+        plot_all=args.plot_all,
+    )