From 0a4d49b68703465836cd47a29531bdb1781f3e3d Mon Sep 17 00:00:00 2001 From: Han Zhu Date: Wed, 8 May 2024 13:22:05 -0400 Subject: [PATCH] update resize and unify --- CryoREAD/data_processing/Resize_Map.py | 509 ++----------------------- CryoREAD/data_processing/Unify_Map.py | 21 +- 2 files changed, 54 insertions(+), 476 deletions(-) diff --git a/CryoREAD/data_processing/Resize_Map.py b/CryoREAD/data_processing/Resize_Map.py index ec244b3..8d21bcc 100644 --- a/CryoREAD/data_processing/Resize_Map.py +++ b/CryoREAD/data_processing/Resize_Map.py @@ -1,420 +1,48 @@ -import os import mrcfile import numpy as np import torch import torch.nn.functional as F -from multiprocessing import Pool, Lock -from multiprocessing.sharedctypes import Value, Array +import argparse -# global size,data,iterator,s -from numba import jit - -# https://github.com/jglaser/interp3d -import shutil -from numba_progress import ProgressBar - - -@jit(nopython=True, nogil=True) -def interpolate_fast(data, data_new, size, iterator1, iterator2, iterator3, prev_voxel_size, progress_proxy): - - for i in range(1, iterator1, 1): - progress_proxy.update(1) - # if(i%1==0): - # prefix="interpolating" - # display_size=60 - # x = int(display_size*i/iterator1) - # line=["#" for k in range(x)] - # #line="".join(line) - # line1=["." for k in range(size-x)] - # #line1="".join(line1) - # print(prefix, "[",line,line1,"] (", i, iterator1,")", - # end='\r', flush=True) - # print("Finished",i,iterator1) - for j in range(1, iterator2, 1): - for k in range(1, iterator3, 1): - count = [int(i / prev_voxel_size), int(j / prev_voxel_size), int(k / prev_voxel_size)] - e1 = count[0] + 1 - e2 = count[1] + 1 - e3 = count[2] + 1 - - if count[0] >= size[0] - 1: # or count[1]>=size[1]-1 or count[2]>=size[2]-1 ): - # print(count) - e1 = count[0] - continue - if count[1] >= size[1] - 1: - e2 = count[1] - continue - if count[2] >= size[2] - 1: - e3 = count[2] - continue - diff1 = [i - count[0] * prev_voxel_size, j - count[1] * prev_voxel_size, k - count[2] * prev_voxel_size] - diff2 = [e1 * prev_voxel_size - i, e2 * prev_voxel_size - j, e3 * prev_voxel_size - k] - # print(diff) - val1 = data[count[0], count[1], count[2]] - val2 = data[e1, count[1], count[2]] - val3 = data[e1, e2, count[2]] - val4 = data[count[0], e2, count[2]] - val5 = data[count[0], count[1], e3] - val6 = data[e1, count[1], e3] - val7 = data[e1, e2, e3] - val8 = data[count[0], e2, e3] - # val = (val1 + diff[0] * (val2 - val1) + diff[1] * (val4 - val1) + diff[2] * (val5 - val1) + diff[0] * - # diff[1] * (val1 - val2 + val3 - val4) + diff[0] * diff[2] * (val1 - val2 - val5 + val6) + diff[ - # 1] * diff[2] * ( - # val1 - val4 - val5 + val8) + diff[0] * diff[1] * diff[2] * ( - # val8 - val7 + val6 - val5 + val4 - val3 + val2 - val1)) - u1 = diff1[0] - u2 = diff2[0] - v1 = diff1[1] - v2 = diff2[1] - w1 = diff1[2] - w2 = diff2[2] - val = ( - ( - w2 * (v1 * (u1 * val3 + u2 * val4) + v2 * (u1 * val2 + u2 * val1)) - + w1 * (v1 * (u1 * val7 + u2 * val8) + v2 * (u1 * val6 + u2 * val5)) - ) - / (w1 + w2) - / (v1 + v2) - / (u1 + u2) - ) - data_new[i, j, k] = val - return data_new # np.float32(data_new) - - -@jit(nopython=True, nogil=True) -def interpolate_fast_general(data, data_new, size, iterator1, iterator2, iterator3, prev_voxel_size1, prev_voxel_size2, prev_voxel_size3): - for i in range(1, iterator1, 1): - # if(i%10==0): - # print("Finished",i,iterator1) - for j in range(1, iterator2, 1): - for k in range(1, iterator3, 1): - count = [int(i / prev_voxel_size1), int(j / prev_voxel_size2), int(k / 
prev_voxel_size3)] - e1 = count[0] + 1 - e2 = count[1] + 1 - e3 = count[2] + 1 - - if count[0] >= size[0] - 1: # or count[1]>=size[1]-1 or count[2]>=size[2]-1 ): - # print(count) - e1 = count[0] - continue - if count[1] >= size[1] - 1: - e2 = count[1] - continue - if count[2] >= size[2] - 1: - e3 = count[2] - continue - diff1 = [i - count[0] * prev_voxel_size1, j - count[1] * prev_voxel_size2, k - count[2] * prev_voxel_size3] - diff2 = [e1 * prev_voxel_size1 - i, e2 * prev_voxel_size2 - j, e3 * prev_voxel_size3 - k] - # print(diff) - val1 = data[count[0], count[1], count[2]] - val2 = data[e1, count[1], count[2]] - val3 = data[e1, e2, count[2]] - val4 = data[count[0], e2, count[2]] - val5 = data[count[0], count[1], e3] - val6 = data[e1, count[1], e3] - val7 = data[e1, e2, e3] - val8 = data[count[0], e2, e3] - # val = (val1 + diff[0] * (val2 - val1) + diff[1] * (val4 - val1) + diff[2] * (val5 - val1) + diff[0] * - # diff[1] * (val1 - val2 + val3 - val4) + diff[0] * diff[2] * (val1 - val2 - val5 + val6) + diff[ - # 1] * diff[2] * ( - # val1 - val4 - val5 + val8) + diff[0] * diff[1] * diff[2] * ( - # val8 - val7 + val6 - val5 + val4 - val3 + val2 - val1)) - u1 = diff1[0] - u2 = diff2[0] - v1 = diff1[1] - v2 = diff2[1] - w1 = diff1[2] - w2 = diff2[2] - val = ( - ( - w2 * (v1 * (u1 * val3 + u2 * val4) + v2 * (u1 * val2 + u2 * val1)) - + w1 * (v1 * (u1 * val7 + u2 * val8) + v2 * (u1 * val6 + u2 * val5)) - ) - / (w1 + w2) - / (v1 + v2) - / (u1 + u2) - ) - data_new[i, j, k] = val - return data_new # np.float32(data_new) - - -def interpolate_slow(data, data_new, size, iterator1, iterator2, iterator3, prev_voxel_size, progress_proxy): - for i in range(1, iterator1, 1): - progress_proxy.update(1) - # if(i%10==0): - # print("Finished",i,iterator1) - for j in range(1, iterator2, 1): - for k in range(1, iterator3, 1): - count = [int(i / prev_voxel_size), int(j / prev_voxel_size), int(k / prev_voxel_size)] - e1 = count[0] + 1 - e2 = count[1] + 1 - e3 = count[2] + 1 - - if count[0] >= size[0] - 1: # or count[1]>=size[1]-1 or count[2]>=size[2]-1 ): - # print(count) - e1 = count[0] - continue - if count[1] >= size[1] - 1: - e2 = count[1] - continue - if count[2] >= size[2] - 1: - e3 = count[2] - continue - diff1 = [i - count[0] * prev_voxel_size, j - count[1] * prev_voxel_size, k - count[2] * prev_voxel_size] - diff2 = [e1 * prev_voxel_size - i, e2 * prev_voxel_size - j, e3 * prev_voxel_size - k] - # print(diff) - val1 = data[count[0], count[1], count[2]] - val2 = data[e1, count[1], count[2]] - val3 = data[e1, e2, count[2]] - val4 = data[count[0], e2, count[2]] - val5 = data[count[0], count[1], e3] - val6 = data[e1, count[1], e3] - val7 = data[e1, e2, e3] - val8 = data[count[0], e2, e3] - # val = (val1 + diff[0] * (val2 - val1) + diff[1] * (val4 - val1) + diff[2] * (val5 - val1) + diff[0] * - # diff[1] * (val1 - val2 + val3 - val4) + diff[0] * diff[2] * (val1 - val2 - val5 + val6) + diff[ - # 1] * diff[2] * ( - # val1 - val4 - val5 + val8) + diff[0] * diff[1] * diff[2] * ( - # val8 - val7 + val6 - val5 + val4 - val3 + val2 - val1)) - u1 = diff1[0] - u2 = diff2[0] - v1 = diff1[1] - v2 = diff2[1] - w1 = diff1[2] - w2 = diff2[2] - val = ( - ( - w2 * (v1 * (u1 * val3 + u2 * val4) + v2 * (u1 * val2 + u2 * val1)) - + w1 * (v1 * (u1 * val7 + u2 * val8) + v2 * (u1 * val6 + u2 * val5)) - ) - / (w1 + w2) - / (v1 + v2) - / (u1 + u2) - ) - data_new[i, j, k] = val - return data_new - - -def Reform_Map_Voxel_Final(map_path, new_map_path): - from scipy.interpolate import RegularGridInterpolator - - if not 
os.path.exists(new_map_path): - with mrcfile.open(map_path, permissive=True) as mrc: - prev_voxel_size = mrc.voxel_size - prev_voxel_size_x = float(prev_voxel_size["x"]) - prev_voxel_size_y = float(prev_voxel_size["y"]) - prev_voxel_size_z = float(prev_voxel_size["z"]) - nx, ny, nz, nxs, nys, nzs, mx, my, mz = ( - mrc.header.nx, - mrc.header.ny, - mrc.header.nz, - mrc.header.nxstart, - mrc.header.nystart, - mrc.header.nzstart, - mrc.header.mx, - mrc.header.my, - mrc.header.mz, - ) - orig = mrc.header.origin - print("Origin:", orig) - print("Previous voxel size:", prev_voxel_size) - print("nx, ny, nz", nx, ny, nz) - print("nxs,nys,nzs", nxs, nys, nzs) - print("mx,my,mz", mx, my, mz) - data = mrc.data - data = np.swapaxes(data, 0, 2) - size = np.shape(data) - x = np.arange(size[0]) - y = np.arange(size[1]) - z = np.arange(size[2]) - my_interpolating_function = RegularGridInterpolator((x, y, z), data) - it_val_x = int(np.floor(size[0] * prev_voxel_size_x)) - it_val_y = int(np.floor(size[1] * prev_voxel_size_y)) - it_val_z = int(np.floor(size[2] * prev_voxel_size_z)) - print("Previouse size:", size, " Current map size:", [it_val_x, it_val_y, it_val_z]) - data_new = np.zeros([it_val_x, it_val_y, it_val_z]) - # from ops.progressbar import progressbar - from progress.bar import Bar - - bar = Bar("Preparing Input: ", max=int(it_val_x)) - for i in range(it_val_x): # progressbar(range(it_val_x), prefix="", size=60): - # if i%10==0: - # print("Resizing finished %d/%d"%(i,it_val_x)) - bar.next() - for j in range(it_val_y): - for k in range(it_val_z): - if i / prev_voxel_size_x >= size[0] - 1: - x_query = size[0] - 1 - else: - x_query = i / prev_voxel_size_x - - if j / prev_voxel_size_y >= size[1] - 1: - y_query = size[1] - 1 - else: - y_query = j / prev_voxel_size_y - if k / prev_voxel_size_z >= size[2] - 1: - z_query = size[2] - 1 - else: - z_query = k / prev_voxel_size_z - current_query = np.array([x_query, y_query, z_query]) - current_value = float(my_interpolating_function(current_query)) - data_new[i, j, k] = current_value - bar.finish() - data_new = np.swapaxes(data_new, 0, 2) - data_new = np.float32(data_new) - mrc_new = mrcfile.new(new_map_path, data=data_new, overwrite=True) - vsize = mrc_new.voxel_size - vsize.flags.writeable = True - vsize.x = 1.0 - vsize.y = 1.0 - vsize.z = 1.0 - mrc_new.voxel_size = vsize - mrc_new.update_header_from_data() - mrc_new.header.nxstart = nxs * prev_voxel_size_x - mrc_new.header.nystart = nys * prev_voxel_size_y - mrc_new.header.nzstart = nzs * prev_voxel_size_z - mrc_new.header.mapc = mrc.header.mapc - mrc_new.header.mapr = mrc.header.mapr - mrc_new.header.maps = mrc.header.maps - mrc_new.header.origin = orig - mrc_new.update_header_stats() - mrc.print_header() - mrc_new.print_header() - mrc_new.close() - del data - del data_new - return new_map_path +def my_reform_1a(input_mrc, output_mrc, use_gpu=False): + with torch.no_grad() and torch.cuda.amp.autocast(enabled=use_gpu): -def Reform_Map_Voxel(map_path, new_map_path): - if not os.path.exists(new_map_path): - with mrcfile.open(map_path, permissive=True) as mrc: - prev_voxel_size = mrc.voxel_size - # assert len(prev_voxel_size)==3 + with mrcfile.open(input_mrc, permissive=True) as orig_map: - if not (prev_voxel_size["x"] == prev_voxel_size["y"] and prev_voxel_size["x"] == prev_voxel_size["z"]): - print( - "Grid size of different axis is different, please specify --resize=1 in command line to call another slow process to deal with it!" 
- ) - exit(1) - prev_voxel_size = float(prev_voxel_size["x"]) - nx, ny, nz, nxs, nys, nzs, mx, my, mz = ( - mrc.header.nx, - mrc.header.ny, - mrc.header.nz, - mrc.header.nxstart, - mrc.header.nystart, - mrc.header.nzstart, - mrc.header.mx, - mrc.header.my, - mrc.header.mz, - ) - orig = mrc.header.origin - print("Origin:", orig) - print("Previous voxel size:", prev_voxel_size) - data = mrc.data - data = np.swapaxes(data, 0, 2) - size = np.shape(data) - if prev_voxel_size == 1: - shutil.copy(map_path, new_map_path) - return new_map_path - if prev_voxel_size < 1: - print( - "Grid size is smaller than 1, please specify --resize=1 in command line to call another slow process to deal with it!" - ) - exit(1) - it_val1 = int(np.floor(size[0] * prev_voxel_size)) - it_val2 = int(np.floor(size[1] * prev_voxel_size)) - it_val3 = int(np.floor(size[2] * prev_voxel_size)) - print("Previouse size:", size, " Current map size:", it_val1, it_val2, it_val3) - data_new = np.zeros([it_val1, it_val2, it_val3]) - data_new[0, 0, 0] = data[0, 0, 0] - data_new[it_val1 - 1, it_val2 - 1, it_val3 - 1] = data[size[0] - 1, size[1] - 1, size[2] - 1] - # iterator = Value('i', it_val) - # s = Value('d', float(prev_voxel_size)) - # pool = Pool(3) - # out_1d = pool.map(interpolate,enumerate(np.reshape(data_new, (iterator.value * iterator.value * iterator.value,)))) - # data_new = np.array(out_1d).reshape(iterator.value, iterator.value, iterator.value) - try: - with ProgressBar(total=it_val1) as progress: - data_new = interpolate_fast(data, data_new, size, it_val1, it_val2, it_val3, prev_voxel_size, progress) - except: - data_new = np.zeros([it_val1, it_val2, it_val3]) - data_new[0, 0, 0] = data[0, 0, 0] - data_new[it_val1 - 1, it_val2 - 1, it_val3 - 1] = data[size[0] - 1, size[1] - 1, size[2] - 1] - with ProgressBar(total=it_val1) as progress: - data_new = interpolate_slow(data, data_new, size, it_val1, it_val2, it_val3, prev_voxel_size, progress) - data_new = np.swapaxes(data_new, 0, 2) - data_new = np.float32(data_new) - mrc_new = mrcfile.new(new_map_path, data=data_new, overwrite=True) - vsize = mrc_new.voxel_size - vsize.flags.writeable = True - vsize.x = 1.0 - vsize.y = 1.0 - vsize.z = 1.0 - mrc_new.voxel_size = vsize - mrc_new.update_header_from_data() - mrc_new.header.nxstart = nxs * prev_voxel_size - mrc_new.header.nystart = nys * prev_voxel_size - mrc_new.header.nzstart = nzs * prev_voxel_size - mrc_new.header.mapc = mrc.header.mapc - mrc_new.header.mapr = mrc.header.mapr - mrc_new.header.maps = mrc.header.maps - mrc_new.header.origin = orig - mrc_new.update_header_stats() - mrc.print_header() - mrc_new.print_header() - mrc_new.close() - del data - del data_new - # del out_1d - return new_map_path + orig_voxel_size = np.array([orig_map.voxel_size.x, orig_map.voxel_size.y, orig_map.voxel_size.z]) + orig_data = torch.from_numpy(orig_map.data.copy()).unsqueeze(0).unsqueeze(0) -def Resize_Map(input_map_path, new_map_path): - try: - my_reform_1a(input_map_path, new_map_path, use_gpu=True) - except: - print("GPU reform failed, falling back to CPU") - my_reform_1a(input_map_path, new_map_path, use_gpu=False) - # try: - # Reform_Map_Voxel(input_map_path, new_map_path) - # except: - # try: - # Reform_Map_Voxel_Final(input_map_path, new_map_path) - # except: - # exit() - return new_map_path + orig_data = orig_data.cuda() if use_gpu else orig_data + print("Previous shape (ZYX): ", orig_data.shape) + print("Previous voxel size (ZXY): ", np.array([orig_map.voxel_size.z, orig_map.voxel_size.y, orig_map.voxel_size.x])) -def 
my_reform_1a(input_mrc, output_mrc, use_gpu=False): + # orig = np.array([orig_map.header.origin.x, orig_map.header.origin.y, orig_map.header.origin.z]) - with torch.no_grad() and torch.cuda.amp.autocast(enabled=use_gpu, dtype=torch.float16): + new_grid_size = np.array(orig_data.shape[2:]) * np.array([orig_map.voxel_size.z, orig_map.voxel_size.y, orig_map.voxel_size.x]) - with mrcfile.open(input_mrc, permissive=True) as orig_map: - voxel_size = np.array([orig_map.voxel_size.x, orig_map.voxel_size.y, orig_map.voxel_size.z]) + # print("New grid size (ZYX): ", new_grid_size) - # orig_data = torch.from_numpy(orig_map.data.copy().transpose((2, 1, 0))).unsqueeze(0).unsqueeze(0) - orig_data = torch.from_numpy(orig_map.data.copy()).unsqueeze(0).unsqueeze(0) + new_grid_size = np.floor(new_grid_size).astype(np.int32) # ZYX - orig_data = orig_data.cuda() if use_gpu else orig_data + print("New grid size (ZYX): ", new_grid_size) - print("Previous shape: ", orig_data.shape) - # orig = np.array([orig_map.header.origin.x, orig_map.header.origin.y, orig_map.header.origin.z]) + # for compatibility with torch 1.9 and below + kwargs = {"indexing": "ij"} if (torch.__version__.split(".")[0] >= "2" or torch.__version__.split(".")[0] >= "10") else {} - new_grid_size = np.floor(np.array(orig_data.shape[2:]) * voxel_size).astype(np.int32) + z = torch.arange(0, new_grid_size[0], device="cuda" if use_gpu else "cpu") / orig_voxel_size[2] / (orig_data.shape[2] - 1) * 2 - 1 + y = torch.arange(0, new_grid_size[1], device="cuda" if use_gpu else "cpu") / orig_voxel_size[1] / (orig_data.shape[3] - 1) * 2 - 1 + x = torch.arange(0, new_grid_size[2], device="cuda" if use_gpu else "cpu") / orig_voxel_size[0] / (orig_data.shape[4] - 1) * 2 - 1 - # Voodoo magic + # noinspection PyArgumentList new_grid = torch.stack( torch.meshgrid( - torch.arange(0, new_grid_size[2], device="cuda" if use_gpu else "cpu") / voxel_size[2] / (orig_data.shape[4] - 1) * 2 - - 1, - torch.arange(0, new_grid_size[1], device="cuda" if use_gpu else "cpu") / voxel_size[1] / (orig_data.shape[3] - 1) * 2 - - 1, - torch.arange(0, new_grid_size[0], device="cuda" if use_gpu else "cpu") / voxel_size[0] / (orig_data.shape[2] - 1) * 2 - - 1, - indexing="ij", + x, + y, + z, + **kwargs ), dim=-1, ) @@ -423,8 +51,8 @@ def my_reform_1a(input_mrc, output_mrc, use_gpu=False): new_data = F.grid_sample(orig_data, new_grid, mode="bilinear", align_corners=True).cpu().numpy()[0, 0] new_voxel_size = np.array((1.0, 1.0, 1.0)) - print("Real voxel size: ", new_voxel_size) - print("New shape: ", new_data.shape) + # print("Real voxel size: ", new_voxel_size) + # print("New shape: ", new_data.shape) new_data = new_data.transpose((2, 1, 0)) @@ -436,9 +64,9 @@ def my_reform_1a(input_mrc, output_mrc, use_gpu=False): vox_sizes.z = new_voxel_size[2] mrc.voxel_size = vox_sizes mrc.update_header_from_data() - mrc.header.nxstart = orig_map.header.nxstart * voxel_size[0] - mrc.header.nystart = orig_map.header.nystart * voxel_size[1] - mrc.header.nzstart = orig_map.header.nzstart * voxel_size[2] + mrc.header.nxstart = orig_map.header.nxstart * orig_voxel_size[0] + mrc.header.nystart = orig_map.header.nystart * orig_voxel_size[1] + mrc.header.nzstart = orig_map.header.nzstart * orig_voxel_size[2] mrc.header.origin = orig_map.header.origin mrc.header.mapc = orig_map.header.mapc mrc.header.mapr = orig_map.header.mapr @@ -446,73 +74,18 @@ def my_reform_1a(input_mrc, output_mrc, use_gpu=False): mrc.update_header_stats() mrc.flush() +def Resize_Map(input_map_path,new_map_path): + try: + 
my_reform_1a(input_map_path, new_map_path, use_gpu=True) + except: + print("GPU reform failed, falling back to CPU") + my_reform_1a(input_map_path, new_map_path, use_gpu=False) + return new_map_path if __name__ == "__main__": - import sys - - def progressbar(it, prefix="", size=60, out=sys.stdout): # Python3.3+ - count = len(it) - - def show(j): - x = int(size * j / count) - print("{}[{}{}] {}/{}".format(prefix, "#" * x, "." * (size - x), j, count), end="\r", file=out, flush=True) - - show(0) - for i, item in enumerate(it): - yield item - show(i + 1) - print("\n", flush=True, file=out) - - from numba_progress import ProgressBar - - data_new = np.zeros([200, 200, 200]) - data = np.zeros([100, 100, 100]) - size = [100, 100, 100] - it_val1 = 200 - it_val2 = 200 - it_val3 = 200 - prev_voxel_size = 0.5 - from scipy.interpolate import RegularGridInterpolator - - with ProgressBar(total=it_val1) as progress: - interpolate_fast(data, data_new, size, it_val1, it_val2, it_val3, prev_voxel_size, progress) - with ProgressBar(total=it_val1) as progress: - data_new2 = interpolate_slow(data, data_new, size, it_val1, it_val2, it_val3, prev_voxel_size, progress) - # from ops.progressbar import progressbar - it_val_x = it_val1 - it_val_y = it_val2 - it_val_z = it_val3 - prev_voxel_size_x = prev_voxel_size - prev_voxel_size_y = prev_voxel_size - prev_voxel_size_z = prev_voxel_size - x = np.arange(size[0]) - y = np.arange(size[1]) - z = np.arange(size[2]) - my_interpolating_function = RegularGridInterpolator((x, y, z), data) - from progress.bar import Bar - - bar = Bar("Preparing Input: ", max=int(it_val_x)) - for i in range(it_val_x): # progressbar(range(it_val_x), prefix="", size=60): - # if i%10==0: - # print("Resizing finished %d/%d"%(i,it_val_x)) - bar.next() - for j in range(it_val_y): - for k in range(it_val_z): - if i / prev_voxel_size_x >= size[0] - 1: - x_query = size[0] - 1 - else: - x_query = i / prev_voxel_size_x - - if j / prev_voxel_size_y >= size[1] - 1: - y_query = size[1] - 1 - else: - y_query = j / prev_voxel_size_y - if k / prev_voxel_size_z >= size[2] - 1: - z_query = size[2] - 1 - else: - z_query = k / prev_voxel_size_z - current_query = np.array([x_query, y_query, z_query]) - current_value = float(my_interpolating_function(current_query)) - data_new[i, j, k] = current_value + args = argparse.ArgumentParser() + args.add_argument("-i", "--input_map_path", type=str, default=None) + args.add_argument("-o", "--output_map_path", type=str, default=None) + args = args.parse_args() - bar.finish() + Resize_Map(args.input_map_path,args.output_map_path) diff --git a/CryoREAD/data_processing/Unify_Map.py b/CryoREAD/data_processing/Unify_Map.py index c9c8450..8416262 100644 --- a/CryoREAD/data_processing/Unify_Map.py +++ b/CryoREAD/data_processing/Unify_Map.py @@ -7,15 +7,14 @@ def Unify_Map(input_map_path, new_map_path): # Read MRC file mrc = mrcfile.open(input_map_path, permissive=True) - data = mrc.data.copy() voxel_size = np.asarray(mrc.voxel_size.tolist(), dtype=np.float32) origin = np.array(mrc.header.origin.tolist(), dtype=np.float32) nstart = np.asarray([mrc.header.nxstart, mrc.header.nystart, mrc.header.nzstart], dtype=np.float32) cella = np.array(mrc.header.cella.tolist(), dtype=np.float32) mapcrs = np.asarray([mrc.header.mapc, mrc.header.mapr, mrc.header.maps], dtype=int) - if np.sum(nstart) == 0: - return input_map_path + # if np.sum(nstart) == 0: + # return input_map_path mrc.print_header() mrc.close() @@ -26,8 +25,9 @@ def Unify_Map(input_map_path, new_map_path): nstart = 
np.asarray([nstart[i] for i in sort]) data = np.transpose(data, axes=2 - sort[::-1]) - # Move offsets from nstart to origin - origin = origin + nstart * voxel_size + # Move offsets from nstart to origin if origin is zero (MRC2000) + if np.sum(origin) == 0: + origin = origin + nstart * voxel_size # Save the unified map mrc_new = mrcfile.new(new_map_path, data=data, overwrite=True) @@ -43,7 +43,12 @@ def Unify_Map(input_map_path, new_map_path): if __name__ == "__main__": - input_map_path = Path('../example/21051.mrc') - new_map_path = Path('../example/21051_unified.mrc') - new_map_path = Unify_Map(input_map_path, new_map_path) + import argparse + + args = argparse.ArgumentParser() + args.add_argument("-i", "--input_map_path", type=str, default=None) + args.add_argument("-o", "--output_map_path", type=str, default=None) + args = args.parse_args() + new_map_path = Unify_Map(args.input_map_path, args.output_map_path) print(f"New map path is {new_map_path}") +
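
For reference, a minimal sketch of how the two patched entry points chain together. The driver script, file names, and package-style imports below are illustrative assumptions, not part of this patch; only the function names, signatures, and return values come from the diff above.

    # hypothetical_driver.py -- assumes CryoREAD/data_processing is importable as a package
    from CryoREAD.data_processing.Unify_Map import Unify_Map
    from CryoREAD.data_processing.Resize_Map import Resize_Map

    # Reorder the map axes and, for headers with a zero origin, fold nstart offsets into the origin.
    unified_path = Unify_Map("input.mrc", "input_unified.mrc")          # placeholder paths
    # Resample to a 1.0 voxel size; tries the GPU path first and falls back to CPU on failure.
    resized_path = Resize_Map(unified_path, "input_resized.mrc")
    print(resized_path)

With the argparse blocks added in this patch, the same two steps can also be run standalone, e.g. python Unify_Map.py -i input.mrc -o input_unified.mrc followed by python Resize_Map.py -i input_unified.mrc -o input_resized.mrc.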