Author: Niels Jeppesen (niejep@dtu.dk)
This notebook reads the TXM data format. In our experience, NIfTI is much faster to read than TXM with the packages we use here.
The structure-tensor package we will be using is a 2D and 3D structure tensor package for Python, implemented by Vedrana A. Dahl and Niels Jeppesen.
To run this notebook there are some prerequisites that must be installed in our Python environment. We generally recommend creating a new Python environment to run the notebooks in, for instance using conda create -n <env_name> python=3.7. This way you avoid conflicts with packages in your existing environment.
Install the dependencies for this notebook using pip install numpy scipy scikit-image matplotlib nibabel tqdm structure-tensor. To start the notebook we will of course also need jupyter. Install the dxchange package from GitHub, see the note below.
Note: The current release of the dxchange package has a bug in the TXM reader. This is fixed in the GitHub master branch, so for now the package must be installed directly from GitHub using pip install git+https://github.com/data-exchange/dxchange.git. Its setup.py does not specify the package requirements; these are listed in the dxchange requirements.txt instead.
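For reference, the full setup could look like this (a sketch collecting the commands above; the environment name st-env is just an example):
conda create -n st-env python=3.7
conda activate st-env
pip install numpy scipy scikit-image matplotlib nibabel tqdm structure-tensor jupyter
pip install git+https://github.com/data-exchange/dxchange.git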
Now, let's go ahead and import the required modules. The structure_tensor_workers.py
file should be in the notebook folder.
import os
from multiprocessing import Pool
import matplotlib as mpl
import matplotlib.font_manager as fm
import matplotlib.pyplot as plt
import numpy as np
from dxchange import reader
from mpl_toolkits.axes_grid1 import make_axes_locatable
from mpl_toolkits.axes_grid1.anchored_artists import AnchoredSizeBar
from scipy import ndimage
from scipy.spatial.transform import Rotation as R
from skimage import filters
from structure_tensor import eig_special_3d, structure_tensor_3d
from tqdm import tqdm
from structure_tensor_workers import calculate_angles, get_crops, init_worker, structure_tensor_analysis_v1
Warning: dxfile module not found spefile module not found netCDF4 module not found EdfFile module not found astropy module not found
plt.rcParams['image.interpolation'] = 'nearest'
First, we'll load a sample of the data and some metadata. We will be using the following folder structure:
../notebooks contains the notebooks.
../originals should contain the original data files.
../tmp should contain any temporary files or output we create.
../notebooks/figures/<file_name> contains optionally saved figures.
# Set file name and path.
file_name = 'DY06_FoV2.9 B2_recon EV rotate 6.5 degrees.txm'
file_path = os.path.join('../originals/', file_name)
# Change this to True to save figures to folder.
save_figures = False
# Create folder to save figures in.
fig_path = os.path.join('figures', os.path.basename(file_name) + '-l3')
if save_figures and fig_path and not os.path.exists(fig_path):
os.makedirs(fig_path)
# Read the first slice and metadata.
sample, meta = reader.read_txm(file_path, slice_range=(1, None, None))
data_shape = (meta['number_of_images'], meta['image_height'], meta['image_width'])
data_type = sample.dtype
voxel_size = meta['pixel_size']
fiber_diameter = 7.3
print('Shape:', data_shape)
print('Data type:', data_type)
print('Voxel size:', voxel_size, 'μm')
print('Fiber diameter size:', fiber_diameter, 'μm')
Shape: (1003, 1013, 992) Data type: uint16 Voxel size: 2.866401433944702 μm Fiber diameter size: 7.3 μm
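The full volume is fairly large; a quick check of the raw size:
print('Full volume size (GB):', np.prod(data_shape, dtype=np.int64) * np.dtype(data_type).itemsize / 1e9)
That is roughly 2 GB of uint16 data, which is why we only read individual slices for inspection and memory-map the cropped volume below.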
Let's have a look at the data, to ensure it was loaded correctly.
fig, axs = plt.subplots(1, 2, figsize=(20, 20))
i = data_shape[0] // 3
img, _ = reader.read_txm(file_path, slice_range=((i, i + 1), None, None))
axs[0].imshow(img.squeeze(), cmap='gray')
i = data_shape[0] // 3 * 2
img, _ = reader.read_txm(file_path, slice_range=((i, i + 1), None, None))
axs[1].imshow(img.squeeze(), cmap='gray')
[(ax.set_xlabel('Y'), ax.set_ylabel('Z')) for ax in axs]
plt.show()
Clearly a large part of the data is not of interest to us, so let's crop out the area of interest. You can cut out any part of the data you want to analyze.
l2_start = 400
layer_thickness = int(3.3 / 6 / voxel_size * 1000)
x_slice = slice(75, data_shape[0] - 75)
y_slice = slice(l2_start + layer_thickness, l2_start + layer_thickness * 2)
z_slice = slice(155, data_shape[2] - 255)
fig, axs = plt.subplots(1, 2, figsize=(20, 20))
i = data_shape[0] // 3
img, _ = reader.read_txm(file_path, slice_range=((i, i + 1), y_slice, z_slice))
axs[0].imshow(img.squeeze(), cmap='gray')
i = data_shape[0] // 3 * 2
img, _ = reader.read_txm(file_path, slice_range=((i, i + 1), y_slice, z_slice))
axs[1].imshow(img.squeeze(), cmap='gray')
[(ax.set_xlabel('Y'), ax.set_ylabel('Z')) for ax in axs]
plt.show()
Create new shape.
data_shape = (x_slice.stop - x_slice.start, z_slice.stop - z_slice.start, y_slice.stop - y_slice.start)
print('New shape:', data_shape)
New shape: (853, 582, 191)
To work with the data, we will save it as a raw data file.
temp_folder = '../tmp/'
if temp_folder and not os.path.exists(temp_folder):
os.mkdir(temp_folder)
data_path = os.path.join(temp_folder, file_name + f'_{x_slice.start}-{x_slice.stop}_{y_slice.start}-{y_slice.stop}_{z_slice.start}-{z_slice.stop}.raw')
if not os.path.exists(data_path):
# Number of images read at a time.
# Reduce this to lower memory usage.
    chunk_size = 1024
    print('Creating new file:', data_path)
    data = np.memmap(data_path, dtype=data_type, shape=data_shape, mode='w+')
    for i in tqdm(range(0, data_shape[0], chunk_size)):
        x0 = i + x_slice.start
        x1 = min(x_slice.stop, x0 + chunk_size)
        chunk, _ = reader.read_txm(file_path, slice_range=((x0, x1), y_slice, z_slice))
        data[i:i + chunk_size] = np.rot90(chunk, k=1, axes=(1, 2))
else:
print('Using existing file:', data_path)
data = np.memmap(data_path, dtype=data_type, shape=data_shape, mode='r')
Using existing file: ../tmp/DY06_FoV2.9 B2_recon EV rotate 6.5 degrees.txm_75-928_591-782_155-737.raw
fig, axs = plt.subplots(2, 2, figsize=(20, 20))
axs[0, 0].imshow(data[data.shape[0] // 3], cmap='gray')
axs[0, 1].imshow(data[data.shape[0] // 3 * 2], cmap='gray')
[(ax.set_xlabel('Z'), ax.set_ylabel('Y')) for ax in axs[0]]
axs[1, 0].imshow(data[:, data.shape[1] // 2], cmap='gray')
axs[1, 0].set_xlabel('Y'), axs[1, 0].set_ylabel('X')
axs[1, 1].imshow(data[..., data.shape[-1] // 2], cmap='gray')
axs[1, 1].set_xlabel('Z'), axs[1, 1].set_ylabel('X')
plt.show()
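As mentioned in the introduction, NIfTI is much faster to read than TXM with the packages used here. If you prefer a standard format over the raw file, the cropped volume could also be written as NIfTI with nibabel (already in the dependency list). A minimal sketch, assuming an identity affine and the hypothetical file name below are acceptable:
import nibabel as nib

# Hypothetical NIfTI path next to the raw file; adjust as needed.
nifti_path = data_path.replace('.raw', '.nii.gz')
nib.save(nib.Nifti1Image(np.asarray(data), affine=np.eye(4)), nifti_path)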
To choose a good intensity threshold for the fibers we can use a histogram and choose the value manually. Alternatively we can use a method such as Otsu's threshold. For choosing the threshold we make sure only to include inside-sample data. This makes it easier to separate foreground (fibers) and background inside the sample.
threshold_bins = 1024
threshold_data = data[25:-25, 25:-25, 25:-25]
%%time
otsu_threshold = filters.threshold_otsu(threshold_data.reshape(-1, 1), nbins=threshold_bins)
hand_picked_threshold = 14400
print('Otsu threshold:', otsu_threshold)
print('Hand-picked threshold:', hand_picked_threshold)
Otsu threshold: 14453 Hand-picked threshold: 14400 CPU times: user 455 ms, sys: 150 ms, total: 605 ms Wall time: 605 ms
print('Otsu fiber fraction:', round(np.count_nonzero(threshold_data >= otsu_threshold) / threshold_data.size, 3))
print('Hand-picked fiber fraction:', round(np.count_nonzero(threshold_data >= hand_picked_threshold) / threshold_data.size, 3))
Otsu fiber fraction: 0.537 Hand-picked fiber fraction: 0.576
percentiles = [.5, 99.5]
percentile_values = np.percentile(threshold_data, percentiles)
print(f'Percentiles ({percentiles}): {percentile_values}')
Percentiles ([0.5, 99.5]): [13391. 15571.]
ax = plt.subplot(1, 1, 1)
ax.hist(threshold_data.flat, bins=threshold_bins, density=True)
ax.axvline(otsu_threshold, c='r')
ax.axvline(hand_picked_threshold, c='g')
for p in percentile_values:
ax.axvline(p, c='k', ls='--')
ax.set_xlim(11000, 19000)
plt.show()
We set $\sigma$ and $\rho$ based on the size of the fibers that we want to analyze. For more details see the related paper and the StructureTensorFiberAnalysisDemo notebook.
Then, we set the block size (crop_size), which determines the maximum size of the blocks we partition the data into for the ST calculation. The maximum block size depends on crop_size, the truncate value used for the Gaussian filters (default is 4), and $\rho$ or $\sigma$ (usually $\rho$, because it is the largest).
We also set the fiber_threshold and specify a list of devices to use for the calculation. The list determines how the blocks are distributed between devices for the calculations. If we have a dual-GPU system supporting CUDA, we can specify ['cuda:0', 'cuda:1'], which will distribute the blocks evenly between the two GPUs. ['cuda:0', 'cuda:1', 'cuda:1'] would move two thirds of the blocks to GPU 1, while ['cuda:0', 'cuda:1', 'cpu'] would move one third of the blocks to the CPU. Remember to update the device list to match your system resources. Specifying CUDA devices that are not available will result in exceptions and/or undefined behaviour. If you don't have CUDA GPUs available, just set device = ['cpu'] and specify the number of processes later.
The class_vectors are used to segment the voxels based on the orientations of the ST eigenvectors. It is a list of unit vectors, which represent each of the fiber classes. The first vector is considered the primary orientation and will be used for calculating orientation metrics for each voxel. For more details see the related paper and the StructureTensorFiberAnalysisDemo notebook.
Lastly, bins_per_degree
determines the number of bins used per degree when aggregating the orientation metrics.
rho = fiber_diameter / voxel_size
rho = round(rho, 2)
sigma = rho / 2
print('sigma:', sigma)
print('rho:', rho)
truncate = 4
kernel_radius = int(max(rho, sigma) * truncate + 0.5)
print('kernel_radius:', kernel_radius)
crop_size = 230
print('crop_size:', crop_size)
print('Maximum block size:', crop_size + int(4 * max(rho, sigma) + 0.5))
# Important: Listing invalid CUDA devices may result in exceptions.
device = ['cuda:0']
class_names = ['0']
class_vectors = np.array([[0, 0, 1]], dtype=np.float64)
class_vectors /= np.linalg.norm(class_vectors, axis=-1)[..., np.newaxis]
bins_per_degree = 10
sigma: 1.275 rho: 2.55 kernel_radius: 10 crop_size: 230 Maximum block size: 240
To create the mask for valid data we create a mask as before, except that we erode the mask by an amount equal to the kernel radius. This way we "remove" voxels that are affected by voxels without values (black/zero intensity).
# Create mask memory map.
mask_path = os.path.join(temp_folder, file_name + '_mask-l3.raw')
mask = np.memmap(mask_path, dtype=np.bool, shape=data.shape, mode='w+')
# Get mask for values different from 0.
mask[:] = data != 0
# Erode mask to keep only valid voxels (that are not affected by the edge).
mask[:] = ndimage.binary_erosion(mask, iterations=kernel_radius)
# Filter out the first and last slices on all axes.
ignore = 25
mask[:ignore] = False
mask[-ignore:] = False
mask[:, :ignore] = False
mask[:, -ignore:] = False
mask[..., :ignore] = False
mask[..., -ignore:] = False
Now that we have a threshold and data mask, we combine them into a fiber mask, which is True
for voxels which contain fibers.
fiber_threshold = otsu_threshold
mask &= data > fiber_threshold
We can plot the mask on top of the data. Here, red marks fiber voxels, while blue marks non-fiber voxels: background and voxels ignored for being too close to the edge of the data.
plt.figure(figsize=(20, 20))
ax = plt.subplot(2, 2, 1)
ax.imshow(data[data.shape[0] // 3], cmap='gray')
ax.imshow(mask[data.shape[0] // 3], alpha=.5, cmap='bwr', interpolation='nearest')
ax = plt.subplot(2, 2, 2)
ax.imshow(data[data.shape[0] // 3 * 2], cmap='gray')
ax.imshow(mask[data.shape[0] // 3 * 2], alpha=.5, cmap='bwr', interpolation='nearest')
ax = plt.subplot(2, 2, 3)
ax.imshow(data[:, data.shape[1] // 2], cmap='gray')
ax.imshow(mask[:, data.shape[1] // 2], alpha=.5, cmap='bwr', interpolation='nearest')
ax = plt.subplot(2, 2, 4)
ax.imshow(data[..., data.shape[-1] // 2], cmap='gray')
ax.imshow(mask[..., data.shape[-1] // 2], alpha=.5, cmap='bwr', interpolation='nearest')
plt.show()
We can use the get_crops function to partition the data into blocks (crops) with proper padding. However, here we will actually only use len(crops), as we will use multiprocessing to distribute the blocks across multiple devices. We may include a function just for calculating the number of blocks/crops at a later point.
# Get crops as memory views.
crops, crop_positions, crop_paddings = get_crops(data, max(sigma, rho), crop_size=crop_size, truncate=truncate)
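As a rough sanity check (this assumes get_crops tiles the volume into blocks of at most crop_size voxels per axis before adding padding, which is not guaranteed by its API), the number of blocks can be estimated from the data shape:
# Hypothetical estimate of the block count; compare with len(crops).
expected_blocks = int(np.prod([int(np.ceil(s / crop_size)) for s in data.shape]))
print('Estimated number of blocks:', expected_blocks)
print('Actual number of blocks:', len(crops))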
We will be using the structure_tensor_analysis_v1 function to calculate the structure tensor, $S$, and do the eigendecomposition for the blocks in parallel. We will be saving the following metrics to the disk using memory maps created below:
vec contains the ST eigenvector for each voxel.
theta contains the angle, $\theta$, between vec and class 0 at each voxel.
To calculate these metrics, structure_tensor_analysis_v1 uses the calculate_angles function, which is explained further in the StructureTensorFiberAnalysisDemo notebook. However, the metrics returned by structure_tensor_analysis_v1 have been aggregated (binned), as returning all the metrics for each block unaggregated would obviously consume large amounts of memory and be infeasible for large volumes. We will combine the aggregated data from each block afterwards.
In the code below we create memory-mapped files for the eigenvectors (vec) and the angle metric (theta). The structure_tensor_analysis_v1 function will use these to save the metrics for the volume straight to the disk, which may require a significant amount of disk space, but shouldn't result in memory issues. Besides the maps shown below, the function also supports saving $S$ and the eigenvalues; see structure_tensor_workers.py for details. In the example below we will be saving the results using reduced-precision data types (float16 and uint8) to save space. This is usually fine for visualization and most statistics. Saving the metrics to disk is optional. If you don't need the per-voxel orientations later, don't create the memory maps and remove the entries from the init_args dictionary below. This will save you a lot of disk space and probably reduce the processing time.
# Output names, dtypes and shapes.
map_names = ['vec', 'theta']
map_dtypes = [np.float16, np.uint8]
map_shapes = [(3, ) + data.shape, data.shape]
# Output paths.
map_paths = {n: data_path.replace('.raw', f'-{n}-{sigma}-{rho}-{fiber_threshold}.raw') for n in map_names}
map_paths['vec'] = data_path.replace('.raw', f'-vec-{sigma}-{rho}.raw')
# Create maps.
maps = {}
new_maps = {}
for n, dtype, shape in zip(map_names, map_dtypes, map_shapes):
path = map_paths[n]
create_map = not os.path.exists(path) or os.stat(path).st_size != np.product(shape) * np.dtype(dtype).itemsize
if create_map:
print(f'Creating memory mapped file:\n{path}')
mmap = np.memmap(path, dtype=dtype, shape=shape, mode='w+')
if np.issubdtype(dtype, np.floating) and n not in ['S', 'val', 'vec']:
mmap[:] = np.nan
new_maps[n] = (path, dtype)
maps[n] = np.memmap(path, dtype=dtype, shape=shape, mode='r')
# Get memory maps.
vec = maps['vec']
theta = maps['theta']
Now we're finally ready to perform the analysis. We will be using multiprocessing.Pool to distribute the work across multiple CPU cores. If specified in the device list, a worker will offload the majority of the work to a GPU, otherwise it'll do the calculations itself. Here, we create four processes for each device in the list (processes=4 * len(device)). As our device list contains a single CUDA device, we will be starting four processes, all using that GPU. Beware that this may require a significant amount of GPU memory. If you experience out-of-memory exceptions, either reduce crop_size and/or the number of processes per device. Change the number of processes to fit your system.
if __name__ == '__main__':
results = []
init_args = {'data': data_path,
'dtype': data.dtype,
'shape': data.shape,
'mask': mask_path,
'rho': rho,
'sigma': sigma,
'crop_size': crop_size,
'class_vectors': class_vectors,
'bins_per_degree': bins_per_degree,
'device': device,
'return_aggregate': True
}
# Add output maps to args.
for n in new_maps:
path, dtype = new_maps[n]
init_args[n] = path
init_args[f'{n}_dtype'] = dtype
with Pool(processes=4 * len(device), initializer=init_worker, initargs=(init_args,)) as pool:
for res in tqdm(pool.imap_unordered(structure_tensor_analysis_v1, range(len(crops)), chunksize=1), total=len(crops)):
results.append(res)
100%|██████████| 12/12 [00:16<00:00, 1.37s/it]
All shared variables are passed to the workers using the init_args dictionary and the init_worker function. This function is only called once per worker. When called, it creates all the memory maps needed to read and write data during the calculations, based on init_args. We use the pool.imap_unordered function to tell the workers to perform the calculations for a block, by passing the structure_tensor_analysis_v1 function along with the index of the block we want to analyze. The returned results are tuples of aggregated metrics, as described earlier. For more details see the structure_tensor_analysis_v1 code. Below we will show how to combine the aggregated metrics and display them as histograms.
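For reference, the general pattern looks roughly like the sketch below. This is not the actual implementation of structure_tensor_workers.py, just an illustration of how a pool initializer can hand shared arguments (such as memory-map paths) to the workers via module-level state; in the notebook the worker functions live in structure_tensor_workers.py so that the worker processes can import them.
from multiprocessing import Pool

import numpy as np

_worker_state = {}  # populated once per worker by the initializer

def example_init_worker(args):
    # Called once in each worker process; open memory maps here instead of pickling arrays.
    _worker_state['data'] = np.memmap(args['data'], dtype=args['dtype'], shape=args['shape'], mode='r')

def example_process_block(index):
    # Each task only receives a block index; the data itself is read from the shared memmap.
    block = _worker_state['data'][index]
    return index, float(block.mean())

if __name__ == '__main__':
    # Hypothetical toy data written to disk first.
    arr = np.memmap('example.raw', dtype=np.float32, shape=(4, 8, 8), mode='w+')
    arr[:] = np.random.rand(*arr.shape)
    arr.flush()
    args = {'data': 'example.raw', 'dtype': np.float32, 'shape': (4, 8, 8)}
    with Pool(processes=2, initializer=example_init_worker, initargs=(args,)) as pool:
        print(list(pool.imap_unordered(example_process_block, range(4))))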
Another option is to save all the metrics to disk (using memory maps as described earlier) with reasonable precision and perform the statistics directly on the full orientation volumes. This approach would be similar to what's done in the StructureTensorFiberAnalysisDemo notebook, except that there the data is kept in memory instead of on disk. If you use this approach you may simply ignore the aggregated data returned by structure_tensor_analysis_v1, but it will require a significant amount of disk space to store all the orientation and segmentation data. If the volumes are very big, working with them may also be slow and memory-intensive, even if you use memmap to access the data.
The results list contains the binned metrics for each separate block, so before we can show these, we have to combine them. First we will filter out elements with no data (i.e. blocks with no fibers) and convert the list to a NumPy array.
res = np.asarray([r[1] for r in results if r[1] is not None], dtype=np.object)
fiber_counts = res[:, :, 0].astype(np.float)
eta_os = res[:, :, 1].astype(np.float)
eta_os[np.isnan(eta_os)] = 0
eta_o = np.average(eta_os, weights=fiber_counts, axis=0)
eta_o = eta_o[0].item()
print(f'eta_o: {round(eta_o, 6)}')
eta_o: 0.99568
The histograms below show the distribution of the absolute angles between the fibers and the X-axis unit vector, along with the median angle. The resolution of the histograms is determined by bins_per_degree.
thetas = np.sum(res[:, :, 4], axis=0)
# Calculate theta for each class using the masks.
values = thetas[0]
theta_x = values
xs = np.arange(len(values)) / bins_per_degree
ys = values / (np.sum(values) / bins_per_degree)
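# Median angle: left edge of the first bin where the cumulative, normalized distribution exceeds 0.5.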
m = round(xs[np.argmax(np.cumsum(ys) / bins_per_degree > 0.5)], 1)
fig = plt.figure(figsize=(4.5, 2.5))
ax = plt.subplot(1, 1, 1, title='Angle from X')
ax.set_xlim([0, 10])
ax.axvline(m, c='b', ls='-', label=f'$Med={round(m, 2)}\degree$')
ax.bar(xs, ys, align='edge', width=1 / bins_per_degree)
ax.set_xlabel('Angle ($\degree$)')
ax.set_ylabel('Fraction')
plt.legend()
if save_figures:
ax.set_title(None)
plt.tight_layout()
plt.savefig(os.path.join(fig_path, f'theta.pdf'), bbox_inches='tight')
plt.tight_layout()
plt.show()
The histograms below show the distribution of the rotation around the Z-axis (in-XY-plane rotation), the mean and the standard deviation for each class.
To calculate the mean and standard deviation properly for the aggregated data, we create two functions. We need these to calculate the mean and standard deviation for the complete volume, based on the means and standard deviations for each block, which is what we have available in the res array.
def calc_mean(means, fiber_counts):
return np.average(means.astype(np.float), weights=fiber_counts, axis=0)
def calc_std(means, stds, fiber_counts):
mean = calc_mean(means, fiber_counts)
return np.sqrt((np.sum(fiber_counts * stds**2, axis=0) + np.sum(fiber_counts * (means - mean)**2, axis=0))
/ np.sum(fiber_counts, axis=0))
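The combined standard deviation above uses the standard pooled formula: with block fiber counts $n_i$, block means $\bar{x}_i$, block standard deviations $s_i$ and overall weighted mean $\bar{x}$, it computes
$$s = \sqrt{\frac{\sum_i n_i s_i^2 + \sum_i n_i (\bar{x}_i - \bar{x})^2}{\sum_i n_i}}$$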
Using these functions, we calculate the mean and standard deviation for the orientations in the XY-plane.
in_xy_angle_means = calc_mean(res[:, :, 5].astype(np.float), fiber_counts)
in_xy_angle_stds = calc_std(res[:, :, 5].astype(np.float), res[:, :, 6].astype(np.float), fiber_counts)
in_xy_angle = np.sum(res[:, :, 7], axis=0)
values = in_xy_angle[0]
theta_xy = values
mean = in_xy_angle_means[0]
std = in_xy_angle_stds[0]
xs = (np.arange(len(values)) - len(values) // 2) / bins_per_degree
ys = values / (np.sum(values) / bins_per_degree)
fig = plt.figure(figsize=(4.5, 2.5))
ax = plt.subplot(1, 1, 1, title='In-plane orientation (XY-plane)')
ax.set_xlim([-10, 10])
ax.bar(xs, ys, align='edge', width=1 / bins_per_degree)
ax.set_xlabel('Angle ($\degree$)')
ax.set_ylabel('Fraction')
ax.axvline(0, c='k', ls='--', label='$0\degree$')
ax.axvline(mean, c='r', ls='-', label=f'$\\bar{{x}}={round(mean, 2)}\degree$')
ax.axvline(mean - std, c='r', ls='--', label=f'$s=\pm{round(std, 2)}\degree$')
ax.axvline(mean + std, c='r', ls='--')
plt.legend()
if save_figures:
ax.set_title(None)
plt.tight_layout()
plt.savefig(os.path.join(fig_path, f'theta-in-xy.pdf'), bbox_inches='tight')
plt.tight_layout()
plt.show()
The histograms below show the distribution of the rotation around the Y-axis (in-XZ-plane rotation), the mean and the standard deviation for each class.
in_xz_angle_means = calc_mean(res[:, :, 11].astype(np.float), fiber_counts)
in_xz_angle_stds = calc_std(res[:, :, 11].astype(np.float), res[:, :, 12].astype(np.float), fiber_counts)
in_xz_angle = np.sum(res[:, :, 13], axis=0)
values = in_xz_angle[0]
theta_xz = values
mean = in_xz_angle_means[0]
std = in_xz_angle_stds[0]
xs = (np.arange(len(values)) - len(values) // 2) / bins_per_degree
ys = values / (np.sum(values) / bins_per_degree)
fig = plt.figure(figsize=(4.5, 2.5))
ax = plt.subplot(1, 1, 1, title='In-plane orientation (XZ-plane)')
ax.set_xlim([-10, 10])
ax.bar(xs, ys, align='edge', width=1 / bins_per_degree)
ax.set_xlabel('Angle ($\degree$)')
ax.set_ylabel('Fraction')
ax.axvline(0, c='k', ls='--', label='$0\degree$')
ax.axvline(mean, c='r', ls='-', label=f'$\\bar{{x}}={round(mean, 2)}\degree$')
ax.axvline(mean - std, c='r', ls='--', label=f'$s=\pm{round(std, 2)}\degree$')
ax.axvline(mean + std, c='r', ls='--')
plt.legend()
if save_figures:
ax.set_title(None)
plt.tight_layout()
plt.savefig(os.path.join(fig_path, f'theta-in-xz.pdf'), bbox_inches='tight')
plt.tight_layout()
plt.show()
The class fractions and orientation distributions are useful for determining if the material follows production specifications. Here we've chosen a specific set of metrics to calculate and plot, but many other distributions can just as easily be obtained and plotted.
As we chose to save the eigenvectors for each voxel in the data volume to the disk using the vec memory map, we can actually access all the eigenvectors along with the original data if we like. We also have orientation metrics available through the theta memory map.
Let's grab some slices of the data, along with their ST eigenvectors, theta values, and mask.
data_slices = [
data[data.shape[0] // 2].copy(),
data[:, data.shape[1] // 2].copy(),
data[:, :, data.shape[2] // 2].copy(),
]
vec = maps['vec']
vec_slices = [
vec[:, data.shape[0] // 2].copy(),
vec[:, :, data.shape[1] // 2].copy(),
vec[:, :, :, data.shape[2] // 2].copy(),
]
theta = maps['theta']
theta_slices = [
theta[data.shape[0] // 2].copy(),
theta[:, data.shape[1] // 2].copy(),
theta[:, :, data.shape[2] // 2].copy(),
]
mask_slices = [
mask[data.shape[0] // 2].copy(),
mask[:, data.shape[1] // 2].copy(),
mask[:, :, data.shape[2] // 2].copy(),
]
Let's have a look at the data slices and the corresponding theta slices.
fig, axs = plt.subplots(1, len(data_slices), figsize=(20, 20))
for img, ax in zip(data_slices, axs):
ax.imshow(img, cmap='gray')
plt.show()
fig, axs = plt.subplots(1, len(theta_slices), figsize=(20, 20))
for img, ax in zip(theta_slices, axs):
ax.imshow(img)
plt.show()
We can use the calculate_angles function to calculate the segmentation and orientation metrics for the slices we just read from the volumes.
First we define some plotting functions, then we calculate vec_class, eta_os, theta, out_xy_angle and in_xy_angle for each of the slices and plot them. If save_figures is True, the figures will be saved to the figures folder we created in the beginning. It is easy to create figures for other slices simply by changing which slices are extracted from the volumes.
def add_scalebar(ax, length=100, scale=1, unit='μm', text=None):
ax.set_xticks([])
ax.set_yticks([])
text = f'{length} {unit}' if text is None else text
fontprops = fm.FontProperties(size=18)
scalebar = AnchoredSizeBar(ax.transData,
length / scale, text, 'lower left',
pad=0.3,
color='white',
frameon=False,
size_vertical=3,
fontproperties=fontprops)
ax.add_artist(scalebar)
def savefig(ax, i, dpi=96):
title = ax.get_title()
ax.set_title(None)
title_clean = title.replace(':', '_')
# plt.savefig(os.path.join(fig_path, f'{title_clean}_{i}.svg'), bbox_inches='tight', pad_inches=0, dpi=dpi)
plt.savefig(os.path.join(fig_path, f'{title_clean}_{i}.pdf'), bbox_inches='tight', pad_inches=0, dpi=dpi)
ax.set_title(title)
NOTE: You can easily change which figures are made by commenting out the calls to fig_with_colorbar
in the show_metrics
function.
def fig_with_colorbar(i, d, o, title, alpha=0.5, cmap=None, vmin=None, vmax=None, ax=None, show=True, divider=None):
"""Creates a figure with data, overlay and color bar."""
o = np.rot90(o, k=-1)
d = np.rot90(d, k=-1)
fig_size = 5
if ax is None:
fig, ax = plt.subplots(1, 1, figsize=(20, fig_size))
if divider is None:
divider = make_axes_locatable(ax)
cax = divider.append_axes('right', size=0.2, pad=0.05)
ax.imshow(d, cmap='gray')
if np.issubdtype(o.dtype, np.integer):
cmap = plt.get_cmap('gist_rainbow', len(class_names))
im = ax.imshow(o, alpha=alpha, cmap=cmap, vmin=-.5, vmax=len(class_names) - .5)
cbar = ax.figure.colorbar(im, cax=cax, orientation='vertical', ticks=np.arange(len(class_names)))
cbar.ax.set_yticklabels(class_names)
else:
im = ax.imshow(o, alpha=alpha, cmap=cmap, vmin=vmin, vmax=vmax)
# Hack to avoid lines in colorbar due to bug when using alpha < 1.
c = plt.cm.get_cmap(cmap)(np.linspace(0, 1, 256))
c[:, :-1] = (1 - alpha) + alpha * c[:, :-1]
acm = mpl.colors.ListedColormap(c)
cbar = ax.figure.colorbar(plt.cm.ScalarMappable(None, acm), cax=cax, orientation='vertical')
cbar.set_ticks(np.linspace(0, 1, 11))
cbar.set_ticklabels([('%f' % x).rstrip('0').rstrip('.') for x in np.linspace(vmin, vmax, len(cbar.ax.get_yticklabels()))])
cbar.set_label('Angle ($\degree$)')
ax.set_title(title)
add_scalebar(ax, length=500, scale=voxel_size, text='0.5 mm')
if show:
if save_figures:
savefig(ax, i, dpi=300)
plt.show()
from matplotlib.colors import to_rgb
def show_metrics(data_slices, vec_slices, fiber_threshold=None):
for i, (d, v, m) in enumerate(zip(data_slices, vec_slices, mask_slices)):
d, v = np.asarray(d, dtype=np.float32), np.asarray(v, dtype=np.float32)
m = ~m
vec_class, eta_os, theta, out_xy_angle, in_xy_angle, in_xz_angle = calculate_angles(v, class_vectors, include_in_plane_xz=True)
vec_class = np.ma.masked_where(m, vec_class)
eta_os[m] = np.nan
theta[m] = np.nan
in_xy_angle[m] = np.nan
in_xz_angle[m] = np.nan
# fig_with_colorbar(i, d, vec_class, 'Class', alpha=0.7)
# fig_with_colorbar(i, d, eta_os, '$\eta_O$(X)', alpha=0.5, vmin=0, vmax=1)
# fig_with_colorbar(i, d, theta, 'Angle from X (0-90 deg.)', alpha=0.5, vmin=0, vmax=90)
# fig_with_colorbar(i, d, in_xy_angle, 'In-XY-plane orientation (-90-90 deg.)', cmap='bwr', alpha=0.5, vmin=-90, vmax=90)
# fig_with_colorbar(i, d, in_xz_angle, 'In-XZ-plane orientation (-90-90 deg.)', cmap='bwr', alpha=0.5, vmin=-90, vmax=90)
fig_with_colorbar(i, d, theta, 'Angle from X (0-10 deg.)', alpha=0.7, vmin=0, vmax=10)
fig_with_colorbar(i, d, in_xy_angle, 'In-XY-plane orientation (-10-10 deg.)', cmap='bwr', alpha=0.7, vmin=-10, vmax=10)
fig_with_colorbar(i, d, in_xz_angle, 'In-XZ-plane orientation (-10-10 deg.)', cmap='bwr', alpha=0.7, vmin=-10, vmax=10)
show_metrics(data_slices, vec_slices, fiber_threshold=fiber_threshold)