#!/usr/bin/env python
import h5py
import numpy as np
from skimage import filters,io,img_as_float,exposure,morphology,segmentation,measure,feature,color
from scipy import ndimage as nd
import itertools
import pickle
import multiprocessing
import argparse
from pysmFISH import utils
from pysmFISH import object_based_segmentation
def _overlapping_chunks_coords(padded_chunks_coords, nr_chunks, nc_chunks):
    """
    Compute the regions shared by neighbouring padded chunks.

    Parameters:
    -----------
    padded_chunks_coords: sequence
        Coordinates of the padded chunks, one entry per chunk, listed row by
        row over the chunk grid. Each entry is indexed as
        (row_tl, row_br, col_tl, col_br).
    nr_chunks: int
        Number of chunk rows in the grid.
    nc_chunks: int
        Number of chunk columns in the grid.

    Returns:
    -----------
    overlaps: list
        (row_tl, row_br, col_tl, col_br) tuples. First the overlaps between
        horizontally adjacent chunks (visited row by row), then the overlaps
        between vertically adjacent chunks (visited column by column).

    Raises:
    -----------
    ValueError
        If the number of chunks does not match nr_chunks * nc_chunks.
    """
    if len(padded_chunks_coords) != nr_chunks * nc_chunks:
        raise ValueError(
            'expected {} chunks for a {}x{} grid, got {}'.format(
                nr_chunks * nc_chunks, nr_chunks, nc_chunks,
                len(padded_chunks_coords)))

    overlaps = []

    # Left/right neighbours: rows of the left chunk, from the left edge of
    # the right chunk up to the right edge of the left chunk.
    for row in range(nr_chunks):
        for col in range(nc_chunks - 1):
            left = padded_chunks_coords[row * nc_chunks + col]
            right = padded_chunks_coords[row * nc_chunks + col + 1]
            overlaps.append((left[0], left[1], right[2], left[3]))

    # Top/bottom neighbours: columns of the top chunk, from the top edge of
    # the bottom chunk down to the bottom edge of the top chunk.
    for col in range(nc_chunks):
        for row in range(nr_chunks - 1):
            top = padded_chunks_coords[row * nc_chunks + col]
            bottom = padded_chunks_coords[(row + 1) * nc_chunks + col]
            overlaps.append((bottom[0], top[1], top[2], top[3]))

    return overlaps


def _group_intersecting_couples(couples):
    """
    Group couples of labels that are connected through shared labels.

    Two couples end up in the same group when they share a label, directly
    or through a chain of other couples (connected components).

    Parameters:
    -----------
    couples: sequence
        Tuples of labels of intersecting objects.

    Returns:
    -----------
    groups: list
        One list of couples per connected component. Every non-empty input
        couple appears in exactly one group, including the last component.
    """
    parent = {}

    def find(label):
        # Register unseen labels as their own root.
        parent.setdefault(label, label)
        root = label
        while parent[root] != root:
            root = parent[root]
        # Path compression keeps later lookups short.
        while parent[label] != root:
            parent[label], label = root, parent[label]
        return root

    for couple in couples:
        if not couple:
            continue
        anchor = find(couple[0])
        for label in couple[1:]:
            root = find(label)
            if root != anchor:
                parent[root] = anchor

    groups = {}
    for couple in couples:
        if couple:
            groups.setdefault(find(couple[0]), []).append(couple)
    return list(groups.values())


def staining_segmentation():
    """
    Segment the selected staining and save the properties of the
    identified objects.

    The image is split in padded chunks that are segmented in parallel with
    dask. Objects of different chunks that intersect in the regions shared
    by neighbouring chunks are merged before the object properties are
    calculated. The result is pickled in the processing directory as
    'segmented_<segmentation_staining>_all_objs_properties.pkl'.

    The segmentation settings are read from the
    'Staining_segmentation.config.yaml' file in the processing directory.

    All the parameters are entered via argparse.

    Parameters:
    -----------
    scheduler: string
        tcp address of the dask.distributed scheduler
        (ex. tcp://192.168.0.4:7003). default = False. If False the process
        will run on the local computer using nCPUs-1 workers (at least one).
    path: string
        Path to the processing directory.
    processing_file: string
        Path to the file with the staining to process. The file is currently
        loaded with skimage.io.imread.
    segmentation_staining: string
        Staining to be segmented.
    """
    # The dask client classes are not imported at module level, so they are
    # brought into scope here.
    from dask.distributed import Client, LocalCluster

    # Inputs of the function
    parser = argparse.ArgumentParser(description='Segmentation script')
    parser.add_argument('-scheduler', default=False,
                        help='dask scheduler address ex. tcp://192.168.0.4:7003')
    parser.add_argument('-path', help='processing directory')
    parser.add_argument('-processing_file',
                        help='path to the file with the staning to process')
    parser.add_argument('-segmentation_staining',
                        help='staining to be segmented')
    args = parser.parse_args()

    processing_directory = args.path
    processing_file = args.processing_file
    segmentation_staining = args.segmentation_staining
    scheduler_address = args.scheduler

    if scheduler_address:
        # Start dask client on server or cluster
        client = Client(scheduler_address)
    else:
        # Start dask client on the local machine using all the available
        # cores -1. Always keep one worker so single-core machines do not
        # end up with an empty cluster.
        ncores = max(1, multiprocessing.cpu_count() - 1)
        cluster = LocalCluster(n_workers=ncores)
        client = Client(cluster)

    # Determine the operating system running the code
    os_windows, _ = utils.determine_os()

    # Check trailing slash in the processing directory
    processing_directory = utils.check_trailing_slash(processing_directory, os_windows)

    segmentation_parameters = utils.general_yaml_parser(
        processing_directory + 'Staining_segmentation.config.yaml')
    staining_parameters = segmentation_parameters[segmentation_staining]

    # Chunking parameters
    chunking_parameters = staining_parameters['image_chunking_parameters']
    chunk_size = chunking_parameters['chunk_size']
    percent_padding = chunking_parameters['percent_padding']

    # Segmentation parameters
    object_parameters = staining_parameters['segmentation_parameters']
    trimming = object_parameters['trimming']
    min_object_size = object_parameters['min_object_size']
    disk_radium_rank_filer = object_parameters['disk_radium_rank_filer']
    min_distance = object_parameters['min_distance']
    threshold_rel = object_parameters['threshold_rel']

    # Load the image (will be modified after the change to hdf5 input)
    img = io.imread(processing_file)

    # Image chunking. Only the grid shape and the padded coords are used.
    (nr_chunks, nc_chunks, _coords_chunks_list, Coords_Padded_Chunks_list,
     _r_coords_tl_all_padded, _c_coords_tl_all_padded,
     _r_coords_br_all_padded, _c_coords_br_all_padded) = \
        object_based_segmentation.image_chunking(img, chunk_size, percent_padding)

    # Calculate coords of the regions shared by neighbouring chunks
    Overlapping_chunks_coords = _overlapping_chunks_coords(
        Coords_Padded_Chunks_list, nr_chunks, nc_chunks)

    # This approach is used for testing. If the image gets too big to fit
    # in RAM then save the chunks and load them separately in each node.
    chunked_image_seq = [img[coords[0]:coords[1], coords[2]:coords[3]]
                         for coords in Coords_Padded_Chunks_list]

    # Run the segmentation
    futures_processes = client.map(object_based_segmentation.polyT_segmentation,
                                   chunked_image_seq,
                                   min_object_size=min_object_size,
                                   min_distance=min_distance,
                                   disk_radium_rank_filer=disk_radium_rank_filer,
                                   threshold_rel=threshold_rel,
                                   trimming=trimming)
    Objects_list = client.gather(futures_processes)

    # Recalculate labels and coords: labels become unique over the whole
    # image (1, 2, 3, ...) and each object is shifted by the (row, col) of
    # the top-left corner of its padded chunk.
    # NOTE(review): presumably the object coords are NumPy arrays so that
    # the addition broadcasts the offset -- confirm against polyT_segmentation.
    total_data_dict = dict()
    new_label = 0
    for chunk_coords, objs_chunk in zip(Coords_Padded_Chunks_list, Objects_list):
        chunk_offset = chunk_coords[0::2]
        for obj_coords in objs_chunk.values():
            new_label += 1
            total_data_dict[new_label] = obj_coords + chunk_offset

    # Calculate all the intersecting objects
    futures_processes = client.map(object_based_segmentation.OverlappingCouples,
                                   Overlapping_chunks_coords,
                                   TotalDataDict=total_data_dict)
    All_intersecting = client.gather(futures_processes)

    # Combine the results from the parallel processing and remove the
    # duplicates (first occurrence wins, so the order is reproducible)
    flatten_couple = list(dict.fromkeys(
        couple for grp in All_intersecting for couple in grp))

    # Put together the couples that share a label, so objects intersecting
    # across multiple regions are merged in one go.
    SavedCombinations = _group_intersecting_couples(flatten_couple)

    # Remove all the duplicated labels that intersect. In this case the
    # labels are merged. It will be nice to run an extra segmentation on
    # the merged objects. If it is too slow this step can be parallelised.
    # Each group is handed over as a list of couple tuples.
    CleanedDict = total_data_dict.copy()
    for group in SavedCombinations:
        SaveLab, RemoveLabs, NewCoords = \
            object_based_segmentation.IntersectionCouples(group, total_data_dict)
        for lab in RemoveLabs:
            del CleanedDict[lab]
        CleanedDict[SaveLab] = NewCoords

    # Calculate all objects properties
    all_objects_list = list(CleanedDict.items())
    futures_processes = client.map(object_based_segmentation.obj_properties_calculator,
                                   all_objects_list)
    all_objects_properties_list = client.gather(futures_processes)

    # Convert the list of dictionaries to a single dictionary
    all_objects_properties_dict = {k: v
                                   for d in all_objects_properties_list
                                   for k, v in d.items()}

    # Save all the objects
    segmented_objs_fname = (processing_directory + 'segmented_'
                            + segmentation_staining + '_all_objs_properties.pkl')
    with open(segmented_objs_fname, 'wb') as segmented_objs_file:
        pickle.dump(all_objects_properties_dict, segmented_objs_file)
# Run the segmentation only when the file is executed as a script, not when
# it is imported as a module.
if __name__ == "__main__":
    staining_segmentation()