
Commit

update gmm contour
Han Zhu committed Aug 27, 2024
1 parent b59b995 commit 5949b7d
Showing 3 changed files with 88 additions and 62 deletions.
2 changes: 1 addition & 1 deletion CryoREAD/main.py
@@ -132,7 +132,7 @@ def init_save_path(origin_map_path):
         mask_map_path = os.path.join(save_path, "mask_protein.mrc")
         from data_processing.Gen_MaskDRNA_map import Gen_MaskProtein_map
 
-        Gen_MaskProtein_map(chain_prob, cur_map_path, mask_map_path, params["contour"], threshold=0.3)
+        Gen_MaskProtein_map(chain_prob, cur_map_path, mask_map_path, params["contour"], threshold=0.2)
     if params["prediction_only"]:
         print(
             "Our prediction results are saved in %s with mrc format for visualization check." % save_path_2nd_stage)
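For orientation: `threshold` here presumably gates the per-voxel protein probability used to build `mask_protein.mrc`, so lowering it from 0.3 to 0.2 would keep more voxels in the mask. A minimal sketch of that kind of probability-threshold masking, assuming `chain_prob` is a probability array aligned with the map; the real `Gen_MaskProtein_map` in `data_processing/Gen_MaskDRNA_map.py` may differ:

```
import numpy as np

def mask_protein_sketch(chain_prob, map_data, contour, threshold=0.2):
    # Hypothetical re-implementation for illustration only:
    # keep voxels that are both above the map contour level and
    # predicted as protein with probability >= threshold.
    keep = (chain_prob >= threshold) & (map_data >= contour)
    return np.where(keep, map_data, 0.0)
```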
2 changes: 1 addition & 1 deletion README.md
@@ -39,7 +39,7 @@ python main.py -F ./Class3D/job052/class1.mrc ./Class3D/job052/class2.mrc ./Clas
 -n: Number of intializations (Optional, 3 by default)
 ```
 
-## Example for generating
+## Example for Auto Contouring
 
 ```
 python gmm_contour.py -i ./Class3D/job052/class1.mrc -o ./output_folder -p
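For reference, based on the gmm_contour.py changes below, a run like the one above would be expected to leave these files in the output folder (names derived from the input file stem; the per-component histogram is only written with `-p`):

```
output_folder/
├── class1_mask.mrc                 # map with the GMM noise component masked out
├── class1_hist_by_components.png   # only written with -p/--plot_all
├── class1_hist_overall.png
└── class1_revised_contour.txt      # revised contour levels and masked percentage
```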
146 changes: 86 additions & 60 deletions gmm_contour.py
@@ -29,15 +29,26 @@ def save_mrc(orig_map_path, data, out_path):
         mrc.flush()
 
 
-def gmm_mask(input_map_path, output_folder, num_components=3, use_grad=False, n_init=1, plot_all=False):
+def gen_features(map_array):
+    non_zero_data = map_array[np.nonzero(map_array)]
+    data_normalized = (map_array - map_array.min()) * 2 / (map_array.max() - map_array.min()) - 1
+    local_grad_norm = rank.gradient(img_as_ubyte(data_normalized), ball(3))
+    local_grad_norm = local_grad_norm[np.nonzero(map_array)]
+    local_grad_norm = (local_grad_norm - local_grad_norm.min()) / (local_grad_norm.max() - local_grad_norm.min())
+    non_zero_data_normalized = (non_zero_data - non_zero_data.min()) / (non_zero_data.max() - non_zero_data.min())
+
+    # stack the flattened data and gradient
+    local_grad_norm = np.reshape(local_grad_norm, (-1, 1))
+    non_zero_data_normalized = np.reshape(non_zero_data_normalized, (-1, 1))
+    features = np.hstack((non_zero_data_normalized, local_grad_norm))
+
+    return features
+
+
+def gmm_mask(input_map_path, output_folder, num_components=2, use_grad=False, n_init=1, plot_all=False):
     print("input_map_path", input_map_path)
     print("output_folder", output_folder)
 
-    # if os.path.exists(output_folder):
-    #     # print("Output file already exists")
-    #     raise ValueError("Output FOLD already exists")
-    #     return None, None
-
     os.makedirs(output_folder, exist_ok=True)
 
     print("Opening map file")
@@ -55,15 +66,14 @@ def gmm_mask(input_map_path, output_folder, num_components=3, use_grad=False, n_
 
     # Zooming to handling large maps
     if len(non_zero_data) >= 5e6:
-        print("Map is too large")
+        print("Map is too large, resizing...")
 
         # resample
         zoom_factor = (2e6 / len(non_zero_data)) ** (1 / 3)
         print("Resample with zoom factor:", zoom_factor)
 
-        map_data_zoomed = zoom(map_data, zoom_factor, order=3, mode="grid-constant", grid_mode=True)
-        data_normalized_zoomed = (map_data_zoomed - map_data_zoomed.min()) * 2 / (
-            map_data_zoomed.max() - map_data_zoomed.min()) - 1
+        map_data_zoomed = zoom(map_data, zoom_factor, order=3, mode="grid-constant", grid_mode=False)
+        data_normalized_zoomed = (map_data_zoomed - map_data_zoomed.min()) * 2 / (map_data_zoomed.max() - map_data_zoomed.min()) - 1
         non_zero_data_zoomed = map_data_zoomed[np.nonzero(map_data_zoomed)]
 
         print("Shape after resample:", data_normalized_zoomed.shape)
@@ -72,11 +82,11 @@ def gmm_mask(input_map_path, output_folder, num_components=3, use_grad=False, n_
         local_grad_norm_zoomed = rank.gradient(img_as_ubyte(data_normalized_zoomed), ball(3))
         local_grad_norm_zoomed = local_grad_norm_zoomed[np.nonzero(map_data_zoomed)]
         local_grad_norm_zoomed = (local_grad_norm_zoomed - local_grad_norm_zoomed.min()) / (
-                local_grad_norm_zoomed.max() - local_grad_norm_zoomed.min()
+            local_grad_norm_zoomed.max() - local_grad_norm_zoomed.min()
         )
 
         non_zero_data_normalized_zoomed = (non_zero_data_zoomed - non_zero_data_zoomed.min()) / (
-                non_zero_data_zoomed.max() - non_zero_data_zoomed.min()
+            non_zero_data_zoomed.max() - non_zero_data_zoomed.min()
         )
 
         local_grad_norm_zoomed = np.reshape(local_grad_norm_zoomed, (-1, 1))
@@ -100,7 +110,7 @@ def gmm_mask(input_map_path, output_folder, num_components=3, use_grad=False, n_
     print("Fitting GMM")
 
     # fit the GMM
-    g = mixture.BayesianGaussianMixture(n_components=num_components, max_iter=200, n_init=n_init, tol=1e-2)
+    g = mixture.BayesianGaussianMixture(n_components=num_components, max_iter=500, n_init=n_init, tol=1e-2)
 
     if use_grad:
         data_to_fit = data_zoomed if len(non_zero_data) >= 5e6 else data
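For readers new to the API, a self-contained sketch of what this fit does, with synthetic two-column features standing in for the script's density/gradient matrix. scikit-learn's `BayesianGaussianMixture` can drive the weights of unused components toward zero, which is why a generous `n_components` is safe:

```
import numpy as np
from sklearn import mixture

rng = np.random.default_rng(0)
# Two synthetic clusters standing in for "noise" and "signal" voxels.
noise = rng.normal(loc=[0.1, 0.2], scale=0.05, size=(5000, 2))
signal = rng.normal(loc=[0.7, 0.6], scale=0.1, size=(5000, 2))
X = np.vstack([noise, signal])

g = mixture.BayesianGaussianMixture(n_components=2, max_iter=500, n_init=3, tol=1e-2)
labels = g.fit(X).predict(X)  # component index per sample
print(g.means_)               # (2, 2): mean density/gradient per component
```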
@@ -113,82 +123,92 @@ def gmm_mask(input_map_path, output_folder, num_components=3, use_grad=False, n_
 
     if plot_all:
         fig, ax = plt.subplots(1, 1, figsize=(10, 3))
+        all_datas = []
         for pred in np.unique(preds):
             mask = np.zeros_like(map_data)
             mask[np.nonzero(map_data)] = preds == pred
-            new_data = map_data * mask
-            new_data_non_zero = new_data[np.nonzero(new_data)]
-            ax.hist(new_data_non_zero.flatten(), alpha=0.5, bins=256, density=False, log=True, label=f"Masked_{pred}")
-            # plot mean
-            # mean = g.means_[pred, 0]
-            # ax.axvline(mean, label=f"Mean_{pred}")
-        ax.legend(loc="upper right")
+            masked_map_data = map_data * mask
+            new_data_non_zero = masked_map_data[np.nonzero(masked_map_data)]
+            all_datas.append(new_data_non_zero.flatten())
+            mean = np.mean(new_data_non_zero)
+            ax.axvline(mean, linestyle="--", color="k", label=f"Mean_{pred}")
+        labels = [f"Component {i}" for i in range(num_components)]
+        ax.hist(all_datas, alpha=0.5, bins=256, density=True, log=True, label=labels, stacked=True)
+        ax.set_yscale("log")
+        ax.legend(loc="upper right")
+        ax.set_xlabel("Map Density Value")
+        ax.set_ylabel("Density (log scale)")
+        ax.set_title("Stacked Histogram by Component")
+        fig.tight_layout()
         # print("Saving figure to", os.path.join(output_folder, "hist_by_component.png"))
         fig.savefig(os.path.join(output_folder, Path(input_map_path).stem + "_hist_by_components.png"))
 
-    # generate a mask to keep only the component with the largest variance
-    mask = np.zeros_like(map_data)
-    # mask[np.nonzero(masked_prot_data)] = (preds == np.argmax(g.means_[:, 0].flatten()))
-
-    # ind = np.argpartition(g.means_[:, 0].flatten(), -3)[-3:]
-    # choose ind that is closest to 0
-    ind = np.argmin(np.abs(g.means_[:, 0].flatten()))
+    # choose ind that is closest to 0, and ind that has the highest mean
+    ind_noise = np.argmin(np.abs(g.means_[:, 0].flatten()))
+    # ind_max = np.argmax(g.covariances_[:, 0, 0].flatten())
 
-    print("ind to remove", ind)
-    print("Means: ", g.means_.shape, g.means_[:, 0], g.means_[:, 1])
-
-    # mask[np.nonzero(map_data)] = preds in ind
+    print(
+        "Means: ",
+        g.means_.shape,
+        g.means_[:, 0],
+    )
+    print("Variances: ", g.covariances_.shape, g.covariances_[:, 0, 0])
+    # print("Variances: ", g.covariances_.shape, g.covariances_[ind_noise, 0, 0])
+    # print("Std: ", np.sqrt(g.covariances_[ind_noise, 0, 0]))
 
-    # mask[np.nonzero(map_data)] = (preds == ind[0]) | (preds == ind[1]) | (preds == ind[2])
-    mask[np.nonzero(map_data)] = (preds != ind)
+    # generate a mask to keep only the component without the noise
+    mask = np.zeros_like(map_data)
+    mask[np.nonzero(map_data)] = preds != ind_noise
 
-    noise_comp = map_data[np.nonzero(map_data)][preds == ind]
-    # 98 percentile
-    # revised_contour = np.percentile(noise_comp, 98)
+    noise_comp = map_data[np.nonzero(map_data)][preds == ind_noise]
     revised_contour = np.max(noise_comp)
 
-    print("Revised contour", revised_contour)
+    prot_comp = map_data[np.nonzero(map_data)][preds != ind_noise]
 
-    print("Remaining mask region size in voxels", np.count_nonzero(mask))
+    print("Revised contour:", revised_contour)
+    print("Remaining mask region size in voxels:", np.count_nonzero(mask))
 
-    # use opening to remove small artifacts
+    # use opening to remove small holes
     mask = opening(mask.astype(bool), ball(3))
-    new_data = map_data * mask
-    new_data_non_zero = new_data[np.nonzero(new_data)]
+    masked_map_data = map_data * mask
+    new_data_non_zero = masked_map_data[np.nonzero(masked_map_data)]
 
-    # save the new data
-    save_mrc(input_map_path, new_data,
-             os.path.join(output_folder, Path(input_map_path).stem + "_mask.mrc"))
+    # calcualte new gradient norm
+    new_fit_data = gen_features(masked_map_data)
+    print("Fitting feature shape:", new_fit_data.shape)
 
-    # if use_grad == True:
-    #     # use 1 sigma cutoff from the masked data
-    #     # revised_contour = np.mean(new_data_non_zero) + np.std(new_data_non_zero)
-    #     # use median cutoff from the masked data, could be other percentile
-    #     revised_contour = np.percentile(new_data_non_zero, 50)
-    # else:
-    #     revised_contour = np.min(new_data[new_data > 1e-8])
+    # fit the GMM again on the new data
+    g2 = mixture.BayesianGaussianMixture(n_components=2, max_iter=500, n_init=n_init, tol=1e-2)
+    g2.fit(new_fit_data)
 
-    mask_percent = np.count_nonzero(new_data > 1e-8) / np.count_nonzero(map_data > 1e-8)
+    # predict the new data
+    new_preds = g2.predict(new_fit_data)
+    # ind_noise_second = np.argmin(np.abs(g2.means_[:, 0].flatten()))
+    ind_noise_second = np.argmin(g2.covariances_[:, 0, 0].flatten())
+    noise_comp_2 = masked_map_data[np.nonzero(masked_map_data)][new_preds == ind_noise_second]
+    prot_comp_2 = masked_map_data[np.nonzero(masked_map_data)][new_preds != ind_noise_second]
+    # revised_contour_2 = np.median(prot_comp_2)
+    revised_contour_2 = np.max(noise_comp_2)
+
+    print("Revised contour (high):", revised_contour_2)
+
+    # save the new data
+    save_mrc(input_map_path, masked_map_data, os.path.join(output_folder, Path(input_map_path).stem + "_mask.mrc"))
+
+    mask_percent = np.count_nonzero(masked_map_data > 1e-8) / np.count_nonzero(map_data > 1e-8)
 
     # plot the histogram
     fig, ax = plt.subplots(figsize=(10, 2))
     ax.hist(non_zero_data.flatten(), alpha=0.5, bins=256, density=False, log=True, label="Original")
     ax.hist(new_data_non_zero.flatten(), alpha=0.5, bins=256, density=False, log=True, label="Masked")
     ax.axvline(revised_contour, label="Revised Contour")
+    ax.axvline(revised_contour_2, label="Revised Contour (High)", linestyle="dashed")
     ax.legend()
     plt.title(input_map_path)
     plt.savefig(os.path.join(output_folder, Path(input_map_path).stem + "_hist_overall.png"))
 
     out_txt = os.path.join(output_folder, Path(input_map_path).stem + "_revised_contour.txt")
 
     with open(out_txt, "w") as f:
-        f.write(f"{revised_contour} {mask_percent}")
+        f.write(f"Revised contour: {revised_contour}\n")
+        f.write(f"Revised contour (high): {revised_contour_2}\n")
+        f.write(f"Masked percentage: {mask_percent}\n")
 
     # return revised contour level and mask percent
     return revised_contour, mask_percent
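The revised contour is ultimately just a density level for the original map. A hedged sketch of consuming the text file written above; the line format follows the `f.write` calls in this diff, but the parsing helper itself is hypothetical:

```
# Hypothetical helper: read back the "Revised contour: <value>" line
# and binarize a map at that level.
def load_revised_contour(txt_path):
    with open(txt_path) as f:
        for line in f:
            if line.startswith("Revised contour:"):
                return float(line.split(":", 1)[1])
    raise ValueError("no revised contour line found")

level = load_revised_contour("output_folder/class1_revised_contour.txt")
# map_data as loaded in the earlier sketch; True marks voxels above the contour
binary_mask = map_data >= level
```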
@@ -201,7 +221,13 @@ def gmm_mask(input_map_path, output_folder, num_components=3, use_grad=False, n_
     parser.add_argument("-i", "--input_map_path", type=str, default=None)
     parser.add_argument("-o", "--output_folder", type=str, default=None)
     parser.add_argument("-p", "--plot_all", action="store_true")
-    parser.add_argument("-n", "--num_components", type=int, default=3)
+    parser.add_argument("-n", "--num_components", type=int, default=2)
    args = parser.parse_args()
-    revised_contour, mask_percent = gmm_mask(input_map_path=args.input_map_path, output_folder=args.output_folder,
-                                             num_components=3, use_grad=True, n_init=3, plot_all=args.plot_all)
+    revised_contour, mask_percent = gmm_mask(
+        input_map_path=args.input_map_path,
+        output_folder=args.output_folder,
+        num_components=args.num_components,
+        use_grad=True,
+        n_init=3,
+        plot_all=args.plot_all,
+    )
