In [None]:
# Import packages needed
import math
import os 
import glob
import re
import matplotlib.pyplot as plt
from matplotlib import gridspec
import pandas as pd
import numpy as np
import scipy as sp
import seaborn as sb
import copy
from functools import reduce #https://stackoverflow.com/a/595409
import operator #https://stackoverflow.com/a/595409
import scipy.stats #https://stackoverflow.com/questions/12412895/how-to-calculate-probability-in-a-normal-distribution-given-mean-standard-devi
from scipy.ndimage import gaussian_filter
from pathlib import Path #https://stackoverflow.com/a/273227

# Functions definitions   

In [None]:
# Quick sanity check of np.quantile: the median of 0..5 is 2.5
v = list(range(6))
np.quantile(v, 0.5)

## Re-orient all tails with injury site on the right

In [None]:

# Reorient data: rotate and set the zero to the .99 quantile
def getQuantiles(df, quantiles=[0.0,0.01,0.05,0.5,0.95,0.99,1.0], column='Position X', set='Class A'):
    df_filtered = df[df['Set'] == set]
    return np.quantile(df_filtered[column], quantiles)

def computeInterquantileDistance(df, quantiles, column='Position X', set='Class A'):
    """Sum of the absolute gaps between consecutive quantiles `quantiles` of `column` (rows of `set` only).

    For a monotone list of quantile levels this is the distance between the first and last quantile.
    """
    quantileValues = getQuantiles(df, quantiles, column, set)
    return np.abs(np.diff(quantileValues)).sum()

def isCutOnRight(df, quantiles=[0.0,0.01,0.05], column='Position X'):
    """Compare the spread of the low tail of `column` with the spread of the mirrored high tail.

    Uses the 'Class A' rows only. Returns True when the spread over `quantiles` is not larger
    than the spread over `1 - quantiles`.
    """
    mirroredLevels = np.subtract(1, quantiles)
    lowTailSpread = computeInterquantileDistance(df, quantiles, column, set='Class A')
    highTailSpread = computeInterquantileDistance(df, mirroredLevels, column, set='Class A')
    return lowTailSpread <= highTailSpread

def computeQuartileBalance(df, quantiles, column='Position X', set='Class A'):
    """Absolute difference between the low-tail spread (`quantiles`) and the high-tail spread (`1 - quantiles`) of `column`.

    0 means both tails spread equally. Despite the name, `quantiles` can be any list of quantile
    levels, not only quartiles.
    """
    lowSpread, highSpread = (
        computeInterquantileDistance(df, levels, column, set)
        for levels in (quantiles, np.subtract(1, quantiles))
    )
    return np.abs(lowSpread - highSpread)
    # Alternative that was tried: np.abs(lowSpread**2 - highSpread**2)

def isXMoreUmbalancedThanY(df, quantiles=[0.0,0.01,0.05], set='Class A'):
    """True when the tail imbalance (computeQuartileBalance) along X is at least that along Y."""
    balance = {axis: computeQuartileBalance(df, quantiles, 'Position ' + axis, set) for axis in ('X', 'Y')}
    return balance['X'] >= balance['Y']

def width(df, quantile=0.01, set='Class B'):
    """Robust X extent of the `set` rows: distance between the `quantile` and `1 - quantile` quantiles of 'Position X'."""
    levels = [quantile, 1-quantile]
    return computeInterquantileDistance(df, levels, 'Position X', set)

def height(df, quantile=0.01, set='Class B'):
    """Robust Y extent of the `set` rows: distance between the `quantile` and `1 - quantile` quantiles of 'Position Y'."""
    levels = [quantile, 1-quantile]
    return computeInterquantileDistance(df, levels, 'Position Y', set)

def aspectRatio(df, quantile=0.01, set='Class B'):
    """Robust width / height of the `set` rows (> 1 means wider than tall)."""
    return width(df, quantile, set) / height(df, quantile, set)

# def isHorizontal(fishData, quantile=0.01, threshold=25.0): # Threshold is the min width or height in um
#     df = fishData['DAPI']
#     try:
#         w = width(df, quantile, 'Class B')
#         h = height(df, quantile, 'Class B')
#         ar = aspectRatio(df, quantile, 'Class B')
#         if w < threshold or h < threshold:
#             arA = aspectRatio(df, quantile, 'Class A')
#             fname = os.path.basename(fishData['path'])
#             print(">>> WARN: %s - Notochord too small! Checking aspect ratio of whole fin instead: wB=%f, hB=%f, arB=%f, arA=%f" %(fname,w,h,ar,arA))
#             return arA < 1.0    # While the notochord must be horizontal in aspect, the whole fin view is usually vertical, so that's our best bet when the notochord is too small
#         else:
#             return ar > 1.0
#     except Exception as e:
#         print('Error: failed at fish %s - probably missing Set column' %(fishData['path']))
#         print(df)
#         raise e 

# def mirrorDataX(fishData):
#     """Check if cut is on the right and, if not, rotate the dataFrames to get it there!"""
#     df_DAPI, df_EdU = fishData['DAPI'], fishData['EdU']
#     if not isCutOnRight(df_DAPI):
#         fname = os.path.basename(fishData['path'])
#         # print(">>> Mirroring X - %s" %(fname))
#         df_DAPI['Position X'] = -1 * df_DAPI['Position X']
#         df_EdU['Position X'] = -1 * df_EdU['Position X']
#         fishData['flags'] += 'M'
#     return df_DAPI, df_EdU

# def flipDataXY(fishData):
#     """Check if X/Y axes are properly oriented and, if not, flip them!"""
#     df_DAPI, df_EdU = fishData['DAPI'], fishData['EdU']
#     # if not isXMoreUmbalancedThanY(df_DAPI, set='Class B'):
#     # if not isXMoreUmbalancedThanY(df_DAPI, [0.0,0.01]):
#     # Flip according to notochord aspect ratio
#     if not isHorizontal(fishData, 0.05):    # i.e. if w/h<1 = taller than wide
#         fname = os.path.basename(fishData['path'])
#         # print(">>> Flipping X and Y - %s" %(fname))
#         df_DAPI['Position X'], df_DAPI['Position Y'] = df_DAPI['Position Y'], df_DAPI['Position X']
#         df_EdU['Position X'], df_EdU['Position Y'] = df_EdU['Position Y'], df_EdU['Position X']
#         fishData['flags'] += 'F'
#     return df_DAPI, df_EdU

# def zeroOnQuantile(fishData, quantile=0.99, column='Position X'):
#     df_DAPI, df_EdU = fishData['DAPI'], fishData['EdU']
#     q = getQuantiles(df_DAPI, quantile, column)
#     df_DAPI[column] = df_DAPI[column] - q
#     df_EdU[column] = df_EdU[column] - q
#     return df_DAPI, df_EdU

# def reorientData(fishData):
#     for fish in fishData:
#         # First flip X Y...
#         fish['DAPI'], fish['EdU'] = flipDataXY(fish)
#         # ...then mirror the data along X if necessary...
#         fish['DAPI'], fish['EdU'] = mirrorDataX(fish)
#         # ...then set the zero on the 0.99 quantile
#         fish['DAPI'], fish['EdU'] = zeroOnQuantile(fish)
#     return fishData

def reorientData(fishData): # Here fishData is a vector of RawData
    """Call `reorientData()` on every element of `fishData` (in place) and return the same sequence."""
    for rawFish in fishData:
        rawFish.reorientData()
    return fishData


## New class for Raw data handling 


In [None]:
class RawData:
    """DAPI and EdU spot positions of one fish (tables as exported by Imaris), re-oriented on construction.

    Construction pipeline (see `reorientData`):
      1. if `allowMirrorFlip`: swap X/Y when `isHorizontal` is False (flag 'F'), then mirror X
         when `isCutOnRight` is False (flag 'M');
      2. zero 'Position X' on the 0.99 quantile of the 'Class A' DAPI spots;
      3. rotate: `alignCutRounds` rounds of `alignCutVertically` if `alignCut`, otherwise one
         `alignNotochordHorizontally`;
      4. zero 'Position X' on the same quantile again;
      5. store the 0.99 'Position X' quantile of the 'Class B' DAPI spots as `notochordOffset`.

    `a + b` concatenates the data of both objects as stored (zeroed on the cut estimate);
    `a * b` concatenates them after zeroing each on its own `notochordOffset`. The combined
    tables are then re-processed by the constructor (no mirror/flip, no cut-alignment rounds).
    """
    def __init__(self, df_DAPI, df_EdU, path,
                 alignCut=False, alignCutRounds=5, allowMirrorFlip=True): # data exactly as in the Imaris files
        """Store copies of both channel tables and re-orient them.

        Parameters
        ----------
        df_DAPI, df_EdU : DataFrames with at least 'Position X', 'Position Y' and 'Set' columns.
        path : identifier of the fish (folder of its files, or "sum" for combined objects).
        alignCut : use `alignCutVertically` instead of `alignNotochordHorizontally`.
        alignCutRounds : number of `alignCutVertically` rounds (only used if `alignCut`).
        allowMirrorFlip : allow the X/Y swap and X mirroring steps.
        """
        # Work on copies: the re-orientation steps below assign columns in place and would
        # otherwise silently modify the DataFrames owned by the caller.
        self.DAPI = df_DAPI.copy()
        self.EdU = df_EdU.copy()
        self.path = path
        self.flags = '' # filled by the re-orientation ('F' = X/Y flipped, 'M' = X mirrored)
        self.notochordOffset = 0.0 # set by computeNotochordOffset below
        self.alignCut = alignCut
        self.alignCutRounds = alignCutRounds
        self.allowMirrorFlip = allowMirrorFlip
        self.reorientData() # Always reorient and zero on cut estimation
        self.computeNotochordOffset() # sets the notochord offset

    def _add(self, other, alignToNotochord=False):
        """Return a new RawData whose tables are the row-wise concatenation of both operands."""
        A_DAPI, A_EdU = self.getChannels(alignToNotochord)
        B_DAPI, B_EdU = other.getChannels(alignToNotochord)
        all_DAPI = pd.concat([A_DAPI, B_DAPI])
        all_EdU = pd.concat([A_EdU, B_EdU])
        # Both operands must have been processed the same way
        assert self.alignCut==other.alignCut
        # There is no single file path for a union of fish. The operands are already oriented,
        # so no cut-alignment rounds and no mirror/flip are applied to the combination.
        return RawData(all_DAPI, all_EdU, "sum", self.alignCut, 0, allowMirrorFlip=False)

    def __add__(self, other):
        return self._add(other, alignToNotochord=False)

    def __mul__(self, other): # also a sum, but of the data with X zeroed on the notochord
        return self._add(other, alignToNotochord=True)

    @classmethod
    def sum(cls, dataVec):
        """Combine all elements of `dataVec` with `+` (cut-zeroed data)."""
        return sum(dataVec[1:], dataVec[0])

    @classmethod
    def prod(cls, dataVec):
        """Combine all elements of `dataVec` with `*` (notochord-zeroed data)."""
        return reduce(operator.mul, dataVec[1:], dataVec[0])

    # re-orienting the data
    def isHorizontal(self, quantile=0.01, threshold=25.0): # Threshold is the min width or height in um
        """True if the DAPI data look horizontal.

        Uses the aspect ratio of the 'Class B' spots (> 1 means horizontal). When their width or
        height is below `threshold`, uses the 'Class A' aspect ratio instead (< 1 means horizontal).
        """
        df = self.DAPI
        try:
            w = width(df, quantile, 'Class B')
            h = height(df, quantile, 'Class B')
            ar = aspectRatio(df, quantile, 'Class B')
            if w < threshold or h < threshold:
                arA = aspectRatio(df, quantile, 'Class A')
                fname = os.path.basename(self.path)
                print(">>> WARN: %s - Notochord too small! Checking aspect ratio of whole fin instead: wB=%f, hB=%f, arB=%f, arA=%f" %(fname,w,h,ar,arA))
                return arA < 1.0    # While the notochord must be horizontal in aspect, the whole fin view is usually vertical, so that's our best bet when the notochord is too small
            else:
                return ar > 1.0
        except Exception:
            print('Error: failed at fish %s - probably missing Set column' %(self.path))
            print(df)
            raise

    def mirrorDataX(self):
        """Negate 'Position X' of both channels when `isCutOnRight` is False (adds flag 'M')."""
        df_DAPI, df_EdU = self.DAPI, self.EdU
        if not isCutOnRight(df_DAPI):
            df_DAPI['Position X'] = -1 * df_DAPI['Position X']
            df_EdU['Position X'] = -1 * df_EdU['Position X']
            self.flags += 'M'
        return df_DAPI, df_EdU

    @staticmethod
    def _swapXY(df):
        """Swap the 'Position X' and 'Position Y' columns of `df` in place."""
        # Explicit copies make the swap independent of whether this pandas version writes
        # `df[col] = ...` into the existing column storage (which a still-referenced Series
        # view of that column would then see) or replaces the column.
        x = df['Position X'].to_numpy(copy=True)
        y = df['Position Y'].to_numpy(copy=True)
        df['Position X'] = y
        df['Position Y'] = x

    def flipDataXY(self):
        """Swap X and Y of both channels when the notochord aspect ratio says the data are vertical (adds flag 'F')."""
        df_DAPI, df_EdU = self.DAPI, self.EdU
        # Flip according to notochord aspect ratio (earlier attempts used isXMoreUmbalancedThanY)
        if not self.isHorizontal(0.05):    # i.e. if w/h<1 = taller than wide
            RawData._swapXY(df_DAPI)
            RawData._swapXY(df_EdU)
            self.flags += 'F'
        return df_DAPI, df_EdU

    def _zeroOnQuantile(self, df, q, column='Position X'):
        """Subtract `q` from `column` of `df` in place."""
        df[column] = df[column] - q

    def zeroOnQuantile(self, quantile=0.99, column='Position X', set='Class A'):
        """Shift `column` of both channels so the `quantile` of the DAPI `set` spots is at 0."""
        df_DAPI, df_EdU = self.DAPI, self.EdU
        q = getQuantiles(df_DAPI, quantile, column, set=set)
        self._zeroOnQuantile(df_DAPI, q, column=column)
        self._zeroOnQuantile(df_EdU, q, column=column)
        return df_DAPI, df_EdU

    def reorientData(self):
        """Run the re-orientation pipeline described in the class docstring; returns self."""
        if self.allowMirrorFlip:
            # First flip X Y...
            self.DAPI, self.EdU = self.flipDataXY()
            # ...then mirror the data along X if necessary...
            self.DAPI, self.EdU = self.mirrorDataX()
        # ...then set the zero on the 0.99 quantile
        self.DAPI, self.EdU = self.zeroOnQuantile()
        if self.alignCut:
            # ...then test aligning cut vertically
            for i in range(self.alignCutRounds):
                self.DAPI, self.EdU = self.alignCutVertically()
        else:
            self.DAPI, self.EdU = self.alignNotochordHorizontally()
        # In any case, zero on the quantile of the whole fin again (rotation moves the origin)
        self.DAPI, self.EdU = self.zeroOnQuantile()
        return self

    def computeNotochordOffset(self):
        """Store the 0.99 'Position X' quantile of the 'Class B' DAPI spots in `notochordOffset`."""
        q = getQuantiles(self.DAPI, 0.99, set='Class B')
        self.notochordOffset = q

    # Interface for retrieving the data (path, flags, channels).
    # Stored data are always zeroed on the cut estimate; the notochord-zeroed variant is computed on request.
    def getPath(self):
        return self.path

    def getFlags(self):
        return self.flags

    def _getChannels(self):
        """Deep copies of the DAPI and EdU tables (callers may modify them freely)."""
        return copy.deepcopy(self.DAPI), copy.deepcopy(self.EdU)

    def getChannelsFromNotochord(self, set=None):
        """Copies of both channels with 'Position X' shifted by `notochordOffset`, optionally filtered by `set`."""
        DAPI, EdU = self._getChannels()
        q = self.notochordOffset
        self._zeroOnQuantile(DAPI, q)
        self._zeroOnQuantile(EdU, q)
        return RawData.filter(DAPI, set), RawData.filter(EdU, set)

    def getChannels(self, fromNotochord=False, set=None):
        """Copies of (DAPI, EdU), zeroed on the cut or (if `fromNotochord`) on the notochord, optionally filtered by `set`."""
        if fromNotochord:
            return self.getChannelsFromNotochord(set=set)
        else:
            DAPI, EdU = self._getChannels()
            return RawData.filter(DAPI, set), RawData.filter(EdU, set)

    @classmethod
    def filter(cls, data, set=None):
        """Rows of `data` whose 'Set' equals `set`; `data` itself when `set` is None."""
        if set is not None:
            return data[data['Set'] == set]
        else:
            return data

    @classmethod
    def pca(cls, data):
        """PCA of the ('Position X', 'Position Y') points of `data`.

        Returns the centroid O, the square roots of the singular values of the centred points
        (NOTE(review): these are not standard deviations; they are only used as arrow lengths
        for plotting) and the principal directions Vt (one per row).
        """
        X = data[['Position X', 'Position Y']].values
        O = np.mean(X, axis=0)
        Z = X - O # mean centering
        U,S,Vt = np.linalg.svd(Z)
        return O, np.sqrt(S), Vt

    @classmethod
    def rotate(cls, data, O, v):
        """Return a copy of `data` rotated about O so that the unit vector `v` points along +X.

        The returned coordinates are relative to O (O maps to (0, 0)).
        """
        newData = copy.deepcopy(data)
        assert np.isclose(np.linalg.norm(v),1.0), "Vector v must be a unit vector, norm=%f" %(np.linalg.norm(v))
        R = np.array([[v[0], v[1]], [-v[1], v[0]]])  # R @ v == (1, 0)
        X = data[['Position X', 'Position Y']].values
        rotX = (R @ (X - O).T).T
        newData['Position X'] = rotX[:,0]
        newData['Position Y'] = rotX[:,1]
        return newData

    def alignCutVertically(self):
        """Rotate both channels using a PCA of the right-most quarter (X above the 0.75 quantile) of the 'Class A' DAPI spots.

        The principal direction whose X component is larger is rotated onto +X.
        """
        data = RawData.filter(self.DAPI, 'Class A')
        qThr = getQuantiles(data, 0.75, 'Position X')
        dataRight = data[data['Position X'] > qThr]
        O, S, Vt = RawData.pca(dataRight)
        v = Vt[1,:] # Pick the second vector (should be the horizontal one)
        if np.abs(v[0]) < np.abs(Vt[0,0]): # If not horizontal, pick the first
            v = Vt[0,:]
        if v[0] < 0:    # Make sure it's pointing right
            v = -v
        return RawData.rotate(self.DAPI, O, v), RawData.rotate(self.EdU, O, v)

    def alignNotochordHorizontally(self):
        """Rotate both channels so that the more horizontal principal direction of the 'Class B' DAPI spots lies on +X."""
        data = RawData.filter(self.DAPI, 'Class B')
        O, S, Vt = RawData.pca(data)
        v = Vt[0,:] # Pick the first vector (should be the horizontal one)
        if np.abs(v[0]) < np.abs(Vt[1,0]): # If not horizontal, pick the second
            v = Vt[1,:]
        if v[0] < 0:    # Make sure it's pointing right
            v = -v
        return RawData.rotate(self.DAPI, O, v), RawData.rotate(self.EdU, O, v)

    def getChannel(self, channel, set=None, fromNotochord=False):
        """A copy of the 'DAPI' or 'EdU' table (see getChannels); None for any other name."""
        DAPI, EdU = self.getChannels(fromNotochord, set)
        if channel == 'DAPI':
            return DAPI
        elif channel == 'EdU':
            return EdU
        else:
            return None

    def countDAPI(self):
        return len(self.DAPI.index)

    def countEdU(self):
        return len(self.EdU.index)

        


In [None]:
# New class to set the view for distribution heatmap plots: origin, size and resolution of the
# plotted region, with conversions between µm coordinates and pixel indices.

class PlotView:
    """Rectangular plotting region in µm, sampled on a square pixel grid.

    Arrays are stored in (row, column) = (Y, X) order: `origin` = [y0, x0] (µm),
    `size` = [height, width] (µm), `shape` = [rows, columns] (px). Points passed to
    `getCoordinates` / `isInside` are given in (X, Y) order.
    """
    def __init__(self, origin, resolution, size): #size in µm
        self.origin = origin # µm, [y0, x0]
        self.resolution = resolution # µm/px
        self.size = size # µm, [height, width]
        self.shape = self.computeShape() # px, [rows, columns]

    def __repr__(self):
        return "PlotView(origin=%s, resolution=%s, size=%s) : shape=%s" %(self.origin, self.resolution, self.size, self.shape)

    @classmethod
    def extrema(cls, X, q=0.0):
        """The `q` and `1 - q` quantiles of X (min and max for q=0)."""
        return np.quantile(X, [q, 1-q])

    @classmethod
    def computeAxisView(cls, X, padding=0.05): # X in µm, padding is a fraction of the size
        """(start, length) of the range of X, widened on each side by `padding` times the data range."""
        Xmin, Xmax = cls.extrema(X)
        l = Xmax - Xmin
        Xmin = Xmin - padding*l
        Xmax = Xmax + padding*l
        Xsize = Xmax - Xmin
        return Xmin, Xsize

    def _getXAxisView(self):
        return self.origin[1], self.size[1]

    def _getYAxisView(self):
        return self.origin[0], self.size[0]

    def getYAxisLimits(self, multiplier=1.0):
        """(low, high) Y limits of an interval `multiplier` times the view height, centred on the view."""
        o, l = self._getYAxisView()
        mid = o + l/2
        d = l * multiplier
        a = mid - d/2
        b = mid + d/2
        return a, b

    @classmethod
    def fromData(cls, X, Y, resolution): # X, Y in µm, resolution in µm/px
        """View covering the (padded) ranges of X and Y."""
        xOrigin, xSize = cls.computeAxisView(X)
        yOrigin, ySize = cls.computeAxisView(Y)
        origin = np.array([yOrigin, xOrigin])
        size = np.array([ySize, xSize])
        return cls(origin, resolution, size)

    @classmethod
    def fromDataFrame(cls, df, xlabel, ylabel, resolution):
        X = df[xlabel]
        Y = df[ylabel]
        return cls.fromData(X, Y, resolution)

    @classmethod
    def union(cls, views):
        """Smallest view containing every view in `views` (all must share the same resolution)."""
        assert len(views) > 0, "Need at least one view"
        resolution = views[0].getResolution()
        resolutions = np.array([v.getResolution() for v in views])
        assert np.all(resolutions == resolution), "All views must have the same resolution"
        origins = np.array([v.getOrigin() for v in views])
        ends = origins + np.array([v.getSize() for v in views])
        origin = np.min(origins, axis=0)
        # Reach the farthest end, not just the largest size: a view starting further along
        # would otherwise be cut off.
        size = np.max(ends, axis=0) - origin
        return cls(origin, resolution, size)

    @classmethod
    def decomposedUnion(cls, xAffectingViews, yAffectingViews):
        """View whose X range is the union of `xAffectingViews` and whose Y range is the union of `yAffectingViews`."""
        xView = cls.union(xAffectingViews)
        yView = cls.union(yAffectingViews)
        assert np.all(np.equal(xView.getResolution(), yView.getResolution())), "X and Y views must have the same resolution"
        resolution = xView.getResolution()
        xOrigin, xSize = xView._getXAxisView()
        yOrigin, ySize = yView._getYAxisView()
        origin = np.array([yOrigin, xOrigin])
        size = np.array([ySize, xSize])
        return cls(origin, resolution, size)

    def getOrigin(self):
        return self.origin

    def getShape(self):
        return self.shape

    def getSize(self):
        return self.size

    def getResolution(self):
        return self.resolution

    def getCoordinates(self, p): # p = (X, Y) in µm
        """Pixel indices (row, column) of the point p = (X, Y)."""
        p = np.flip(p) # (X, Y) -> (Y, X), i.e. (row, column) order
        res = np.subtract(p, self.origin)
        res = np.floor_divide(res, self.resolution).astype(int)
        return res # Coordinates in px

    def isInside(self, p):
        """True if the point p = (X, Y) falls on a pixel of this view."""
        c = self.getCoordinates(p)
        return np.all(c >= 0) and np.all(c < self.shape)

    def getSigma(self, sigma): # input is sigma in µm
        return np.divide(sigma, self.resolution) # output is sigma in px

    def computeShape(self):
        """Number of pixels along (Y, X): floor(size / resolution) + 1, so a trailing partial pixel is included."""
        # `size` is already in (Y, X) order, matching the (row, column) indices of getCoordinates
        # (checked by isInside) and shape[1] = number of columns (used by getXLinspace).
        shape = np.add(np.floor_divide(self.size, self.resolution), 1)
        return np.rint(shape).astype(int)

    def getImshowExtent(self): # This lets imshow take care of ticks automatically
        """[left, right, bottom, top] in µm, as expected by imshow's `extent`."""
        origin = self.origin
        extent = np.add(origin, self.size)
        return [origin[1], extent[1], origin[0], extent[0]]

    def getXLinspace(self):
        """One X coordinate (µm) per pixel column, spanning the view."""
        imshowExtent = self.getImshowExtent()
        return np.linspace(imshowExtent[0], imshowExtent[1], self.shape[1])


In [None]:
# Find the fish names (file prefixes) present in one time-point folder
def findFish(folder, amputated=True):
    """Return the sorted, de-duplicated fish prefixes of the amputated (or non-amputated) entries in `folder`.

    Two naming schemes are supported:
      - old: 'fish<N>_<condition>_<T>h'
      - new: '<prefix>_<T>h_<condition>_<Letter><N>'
    where <condition> is 'amputated' or 'nonamputated'.

    Raises AssertionError if no entry matches.
    """
    condition = 'amputated' if amputated else 'nonamputated'
    # The leading '_' keeps '*_amputated_*' from matching '..._nonamputated_...'
    fishListDir = glob.glob(folder + '/*_' + condition + '_*')
    try:
        # Old naming scheme
        fishList = [ re.search('fish[0-9]+_' + condition + '_[0-9]+h', x).group(0) for x in fishListDir ]
    except AttributeError:
        # re.search returned None for at least one entry: use the new naming scheme
        fishList = [ re.search('[^/]+' + '_[0-9]+h_' + condition + '_[A-Z][0-9]+', x).group(0) for x in fishListDir ]
    assert len(fishList)>0, "FATAL: couldn't find any fish file matching the regexp"
    return sorted(set(fishList)) # unique names, alphabetical order

# Find the X/Y position files of one fish (one sub-folder per channel)
def findPositionFiles(directory, fishPrefix):
    """Alphabetically sorted paths matching '<directory><fishPrefix>_*/*_Position_[XY].csv'.

    `directory` is concatenated as-is, so it needs a trailing '/'.
    """
    pattern = ''.join([directory, fishPrefix, '_*/*_Position_[XY].csv'])
    print(pattern) #debug
    return sorted(glob.glob(pattern)) # alphabetical order

# Read the data of one channel
def readChannelFilesXY(file_X, file_Y):
    """Read the X and Y position tables of one channel and join them on 'ID'.

    The first 3 lines of each file are skipped (they are not part of the table). 'Set 1' /
    'Set 2' columns are renamed to 'Set'.
    """
    positionsX = pd.read_csv(file_X, skiprows=3)
    positionsY = pd.read_csv(file_Y, skiprows=3)
    # Only 'Position Y' is taken from the Y table; every other column is duplicated in the X table
    merged = positionsX.merge(positionsY[['ID', 'Position Y']], on='ID')
    return merged.rename(columns={"Set 1": "Set", "Set 2": "Set"})

# Read the data of one fish
def readFishFilesXY(DAPI_X, DAPI_Y, EdU_X, EdU_Y, amputated=True):
    """Read the four position files of one fish and wrap them in a RawData object.

    The fish path stored in RawData is the folder of `DAPI_X`.
    NOTE(review): `amputated` is currently ignored; cut alignment is disabled (alignCut=False)
    for every fish. An earlier variant passed alignCut=amputated.
    """
    df_DAPI, df_EdU = [readChannelFilesXY(fileX, fileY) for fileX, fileY in ((DAPI_X, DAPI_Y), (EdU_X, EdU_Y))]
    return RawData(df_DAPI, df_EdU, os.path.dirname(DAPI_X), alignCut=False)

# From a directory and a list of fish prefixes, read every fish into a RawData object
def readFishFiles(directory, prefixList, amputated=True):
    """Return one RawData per prefix in `prefixList` (fish with fewer than 4 position files are skipped).

    The first 4 files in alphabetical order are used (DAPI X/Y, then EdU X/Y); later folders
    that may exist for other channels are ignored. Any error while reading a fish is reported
    and re-raised.
    """
    fishData = []
    for p in prefixList:
        positionFiles = findPositionFiles(directory, p)
        if len(positionFiles) < 4:
            # Missing channel files: skip this fish. This used to be detected indirectly by
            # catching TypeError, which also silently skipped genuine TypeErrors from reading.
            print('Error: failed to read fish files')
            print(p, positionFiles)
            continue
        try:
            fish_rawData = readFishFilesXY(*positionFiles[:4], amputated=amputated)
        except Exception:
            print('Error: failed to read fish files')
            print(p, positionFiles)
            raise
        fishData.append(fish_rawData)
    return fishData

## New class for KDE
All data for each fish, plus averages, standard deviations and normalisation.

In [None]:

# One object per channel: a PDF scaled by the number of cells, plus pointwise math (ratio, mean, standard deviation)

class ChannelDensity:
    """Wrapper around a density function `x -> value`, evaluated pointwise (scalar or array x)."""
    def __init__(self, density):
        self.density = density

    @classmethod
    def fromKDE(cls, KDE, n): # n = number of cells
        """Density equal to `n * KDE.pdf(x)`; it integrates to `n`."""
        density = lambda x : n * KDE.pdf(x)
        return cls(density) # called as ChannelDensity.fromKDE(KDE, n)

    @classmethod
    def fromData(cls, X):
        """Gaussian KDE of the 1-D values X, scaled by len(X)."""
        n = len(X)
        KDE = sp.stats.gaussian_kde(X)
        return cls.fromKDE(KDE, n)

    @classmethod
    def fromDataFrame(cls, df):
        """Density of the 'Position X' column of df."""
        X = df['Position X']
        return cls.fromData(X)

    @classmethod
    def fromRawData(cls, fish, fromNotochord=False, set='Class A'):
        """(DAPI density, EdU density) of the `set` spots of a RawData fish along 'Position X'."""
        df_DAPI, df_EdU = fish.getChannels(fromNotochord=fromNotochord)
        try:
            df_DAPI = df_DAPI[df_DAPI['Set'] == set]
            density_DAPI = cls.fromDataFrame(df_DAPI)
        except Exception as e:
            print('Error: failed at fish %s, channel DAPI' %(fish.getPath()))
            print(df_DAPI)
            raise e
        try:
            df_EdU = df_EdU[df_EdU['Set'] == set]
            density_EdU = cls.fromDataFrame(df_EdU)
        except Exception as e:
            print('Error: failed at fish %s, channel EdU' %(fish.getPath()))
            raise e
        return density_DAPI, density_EdU

    def __truediv__(self, other): # x / y invokes x.__truediv__(y)
        """Pointwise ratio of the two densities."""
        density = lambda x : self.density(x)/other.density(x)
        return ChannelDensity(density)

    @classmethod
    def mean(cls, *args): # any number of densities, collected in args (varargs)
        """Pointwise mean of the given densities."""
        # axis=0 averages across densities only; without it an array x would be collapsed
        # to a single number averaged over all positions as well.
        density = lambda x : np.mean([a.density(x) for a in args], axis=0)
        return ChannelDensity(density)

    @classmethod
    def stdev(cls, *args):
        """Pointwise (population) standard deviation of the given densities."""
        density = lambda x : np.std([a.density(x) for a in args], axis=0)
        return ChannelDensity(density)

# From fishData, return a list with one dict of densities per fish
def computeFishDensities(fishData, fromNotochord=False):
    """For every RawData fish, build {'path': ..., 'DAPI': ChannelDensity, 'EdU': ChannelDensity}."""
    densitiesPerFish = []
    for fish in fishData:
        fishPath = fish.getPath()
        dapiDensity, eduDensity = ChannelDensity.fromRawData(fish, fromNotochord=fromNotochord)
        densitiesPerFish.append({'path': fishPath, 'DAPI': dapiDensity, 'EdU': eduDensity})
    return densitiesPerFish

def getChannel(array, channelName):
    """Return `item[channelName]` for every item of `array`, as a list."""
    return list(map(operator.itemgetter(channelName), array))

In [None]:
# New class for binned EdU/DAPI counts along 'Position X' (bin edges in µm)
class BinnedData:
    """One value per bin over fixed bin edges, with per-bin arithmetic.

    `bins` holds the edges (len(values) + 1 of them); `values` holds histogram counts, or
    ratios / means / standard deviations for derived objects.
    """
    def __init__(self, bins, values):
        assert len(bins)==len(values) + 1, "Number of bins and values don't match"
        self.bins = bins
        self.values = values

    @classmethod
    def fromData(cls, bins, X):
        """Histogram of the values X over `bins` (passed to np.histogram)."""
        values, bins = np.histogram(X, bins)
        return cls(bins, values)

    @classmethod
    def fromDataFrame(cls, bins, df):
        """Histogram of the 'Position X' column of df."""
        X = df['Position X']
        return cls.fromData(bins, X)

    @classmethod
    def fromRawData(cls, bins, fish, fromNotochord=False, set='Class A'):
        """(DAPI histogram, EdU histogram) of the `set` spots of a RawData fish."""
        df_DAPI, df_EdU = fish.getChannels(fromNotochord=fromNotochord)
        try:
            df_DAPI = df_DAPI[df_DAPI['Set'] == set]
            hist_DAPI = cls.fromDataFrame(bins, df_DAPI)
        except Exception as e:
            print('Error: failed at fish %s, channel DAPI' %(fish.getPath()))
            print(df_DAPI)
            raise e
        try:
            df_EdU = df_EdU[df_EdU['Set'] == set]
            hist_EdU = cls.fromDataFrame(bins, df_EdU)
        except Exception as e:
            print('Error: failed at fish %s, channel EdU' %(fish.getPath()))
            raise e
        return hist_DAPI, hist_EdU

    def getBinCenter(self):
        """Midpoint of every bin."""
        edges = np.asarray(self.bins)  # also works when bins was given as a list
        return (edges[:-1] + edges[1:]) / 2

    def __truediv__(self, other): # x / y invokes x.__truediv__(y)
        """Per-bin ratio; bins must match. Empty bins in `other` give inf/nan (NumPy warns)."""
        assert np.array_equal(self.bins, other.bins)
        values = np.divide( self.values, other.values )
        return BinnedData(self.bins, values)

    @staticmethod
    def _stack(cls, args):
        """Validate that all args are `cls` objects with identical bins; return (bins, bins x args matrix)."""
        assert len(args)>0, "Feed me some args, plz, i'm hungry!"
        a0 = args[0]
        assert type(a0)==cls, "%s is not %s" %(type(a0), cls)
        for a in args[1:]:
            assert np.array_equal(a0.bins, a.bins), "Ca s'appelle un bin, ce pour m'avoir menti!"
        return a0.bins, np.array( [ a.values for a in args ] ).transpose()

    @classmethod
    def mean(cls, *args): # any number of objects, collected in args (varargs)
        """Per-bin mean of the values of all args."""
        bins, M = BinnedData._stack(cls, args)
        return cls(bins, np.mean(M, axis=1))

    @classmethod
    def stdev(cls, *args):
        """Per-bin (population) standard deviation of the values of all args."""
        bins, M = BinnedData._stack(cls, args)
        return cls(bins, M.std(axis=1))

In [None]:
# For every time point, find the fish files and store the data and their densities in dictionaries
def loadDataAndDensities(directory, timepoints, amputated=True):
    """Load all fish of every time point under `directory` (sub-folders '<t>h').

    Returns three dicts keyed by time point:
      - data: list of RawData per time point,
      - densities: per-fish density dicts with X zeroed on the cut estimate,
      - densitiesNotochord: the same with X zeroed on the notochord offset.
    """
    data = {}
    densities = {}
    densitiesNotochord = {}
    for t in timepoints:
        print(">>> Time point: %d" %(t))
        # Build the path from the `directory` argument (this previously read the global
        # `base_directory`, ignoring the argument). The trailing '/' matters: findPositionFiles
        # appends the fish prefix directly to it.
        timepointDirectory = './%s/%dh/' %(directory, t)
        fishes = findFish(timepointDirectory, amputated=amputated)  # prefix list
        print(">>> Number of fish: %d" %(len(fishes))) # check number of images
        # RawData already re-orients each fish on construction
        fishData = readFishFiles(timepointDirectory, fishes, amputated=amputated)
        data[t] = fishData
        densities[t] = computeFishDensities(fishData, fromNotochord=False)
        densitiesNotochord[t] = computeFishDensities(fishData, fromNotochord=True)
    return data, densities, densitiesNotochord

def computeBinnedRatios(bins, fishData, fromNotochord=False):
    """For every RawData fish, build {'path': ..., 'DAPI': BinnedData, 'EdU': BinnedData} histograms over `bins`.

    Despite the name, the ratios themselves are not computed here (divide the 'EdU' and 'DAPI' entries).
    """
    histogramsPerFish = []
    for fish in fishData:
        fishPath = fish.getPath()
        dapiHisto, eduHisto = BinnedData.fromRawData(bins, fish, fromNotochord=fromNotochord)
        histogramsPerFish.append({'path': fishPath, 'DAPI': dapiHisto, 'EdU': eduHisto})
    return histogramsPerFish

def computeBinnedRatiosTime(bins, data, fromNotochord=False):
    """Apply computeBinnedRatios to the fish list of every time point in `data` (dict keyed by time point)."""
    histos = {}
    for t, fishAtTime in data.items():
        print(">>> Time point: %d" %(t))
        histos[t] = computeBinnedRatios(bins, fishAtTime, fromNotochord=fromNotochord)
    return histos




In [None]:
# Here we load the data
experiment = "WT"
# base_directory = 'test_data'    # test
base_directory = 'data/'+experiment         # prod
plots_directory = 'plots_'+experiment
Path(plots_directory).mkdir(parents=True, exist_ok=True) # Create plots directory if it doesn't exist
timePoints = [1,3,5,7,9,11]
# timePoints = [1]
dataNonAmputated, densitiesNonAmputated, densitiesNonAmputatedNotochord = loadDataAndDensities(base_directory, timePoints, amputated=False)
try:
    dataAmputated, densitiesAmputated, densitiesAmputatedNotochord = loadDataAndDensities(base_directory, timePoints, amputated=True)
except Exception as e:
    # Fall back to copies of the non-amputated data (presumably for experiments without amputated
    # fish), but report why instead of hiding every error behind a bare `except:`.
    print(">>> WARN: failed to load amputated data (%s: %s) - using copies of the non-amputated data instead" %(type(e).__name__, e))
    dataAmputated, densitiesAmputated, densitiesAmputatedNotochord = copy.deepcopy(dataNonAmputated), copy.deepcopy(densitiesNonAmputated), copy.deepcopy(densitiesNonAmputatedNotochord)
# Bin edges in µm along 'Position X' (20 µm bins; use a step of 50 for coarser bins)
bins = list(range(-200, 1, 20))         # cut-zeroed data: X = 0 at the cut estimate
binsNoto = list(range(-100, 51, 20))    # notochord-zeroed data: X = 0 at the notochord offset
ratiosAmputated = computeBinnedRatiosTime(bins, dataAmputated, fromNotochord=False)
ratiosNonAmputated = computeBinnedRatiosTime(bins, dataNonAmputated, fromNotochord=False)
ratiosAmputatedNoto = computeBinnedRatiosTime(binsNoto, dataAmputated, fromNotochord=True)
ratiosNonAmputatedNoto = computeBinnedRatiosTime(binsNoto, dataNonAmputated, fromNotochord=True)

In [None]:
# Plot all fins to check orientation
def getGrid(n, maxColumns=5):  # https://stackoverflow.com/a/39248503
    """(rows, columns) of a near-square subplot grid for n panels, with at most `maxColumns` columns."""
    columns = min(math.ceil(math.sqrt(n)), maxColumns)
    rows = math.ceil(n / columns)
    return rows, columns

def plotPCA(ax, data, color):
    """Draw the first principal axis of ``data`` as an arrow on ``ax``.

    ``RawData.pca(data)`` is unpacked as ``(origin, singular_values, Vt)``. The
    arrow starts at ``origin``; its components are row 0 of
    ``diag(singular_values) @ Vt``, i.e. the first row of ``Vt`` scaled by the
    first singular value.

    Returns the artist created by ``ax.quiver``.
    """
    origin, singularValues, Vt = RawData.pca(data)
    scaledAxes = np.diag(singularValues) @ Vt  # row k: k-th axis scaled by its singular value
    arrowStart = np.array([origin])[0]
    arrowVector = scaledAxes[0, :]
    # To draw every axis instead, pass all rows of scaledAxes with matching origins.
    return ax.quiver(*arrowStart, *arrowVector, color=color, scale=0.5, scale_units='xy', angles='xy', width=0.02)

def scatterPlotFish(curFish, ax, set='Class A'):
    """Scatter-plot the spots of one fish with quantile guides and PCA arrows.

    Draws on ``ax``, in this order, the DAPI 'Class A', EdU 'Class A' and
    DAPI 'Class B' spots. It then adds red vertical lines at quantiles of the
    DAPI 'Class A' X positions and green horizontal lines at quantiles of the
    DAPI 'Class B' Y positions. Finally it draws the first principal axis of
    the two DAPI sets. The title is the file name of the fish, followed by its
    flags when there are any.

    Parameters
    ----------
    curFish : object exposing ``getFlags()``, ``getPath()`` and ``getChannels()``.
    ax : matplotlib Axes to draw on.
    set : str
        Currently unused (the classes plotted are hard-coded); kept so callers
        passing ``set=...`` keep working.
    """
    flags = curFish.getFlags()
    title = os.path.basename(curFish.getPath())
    if flags:
        title = title + " - " + flags  # reuse the flags fetched above instead of calling getFlags() again
    ax.title.set_text(title)

    df_DAPI, df_EdU = curFish.getChannels()
    df_DAPIclassA = RawData.filter(df_DAPI, 'Class A')
    df_DAPIclassB = RawData.filter(df_DAPI, 'Class B')
    df_EdUclassA = RawData.filter(df_EdU, 'Class A')

    Qax = np.quantile(df_DAPIclassA['Position X'], [0.0, 0.01, 0.05, 0.5, 0.95, 0.99, 1.0])
    Qb = np.quantile(df_DAPIclassB['Position Y'], [0.0, 0.01, 0.5, 0.99, 1.0])
    # Keep the drawing order (DAPI A, EdU A, DAPI B): it fixes the colours used.
    for df in (df_DAPIclassA, df_EdUclassA, df_DAPIclassB):
        sb.scatterplot(df, x='Position X', y='Position Y', ax=ax)
    for q in Qax:
        ax.axvline(q, color='r')
    for q in Qb:
        ax.axhline(q, color='g')
    plotPCA(ax, df_DAPIclassA, 'tab:blue')
    plotPCA(ax, df_DAPIclassB, 'tab:green')

def plotAllSpots(fishData, set='Class A'):
    """Plot every fish of ``fishData`` in its own panel of a near-square grid.

    Parameters
    ----------
    fishData : list
        Fish objects accepted by ``scatterPlotFish``.
    set : str
        Forwarded to ``scatterPlotFish``.

    Returns
    -------
    (fig, axes)
        The figure and the Axes returned by ``plt.subplots``: a single Axes for a
        1x1 grid, otherwise a numpy array. Panels beyond ``len(fishData)`` stay empty.

    Raises
    ------
    ValueError
        If ``fishData`` is empty.
    """
    N = len(fishData)
    if N == 0:
        raise ValueError("plotAllSpots: fishData is empty, nothing to plot")
    nRows, nCols = getGrid(N)  # compute the grid once
    print(N, (nRows, nCols)) #debug
    fig, axes = plt.subplots(nRows, nCols, figsize=(6 * nCols, 6 * nRows), subplot_kw=dict(box_aspect=1),
                             sharex=True, sharey=True, layout="constrained")
    # For a 1x1 grid plt.subplots returns a bare Axes, which has no .flatten();
    # np.atleast_1d(...).ravel() iterates both that case and the array case.
    for i, ax in enumerate(np.atleast_1d(axes).ravel()):
        if i >= N:
            continue
        scatterPlotFish(fishData[i], ax, set=set)
    return fig, axes

# merge df of all fish

In [None]:
# Plot amputated raw data: the first 15 fins plus their aggregate (RawData.sum).
finsAmputated = [fin for fins in dataAmputated.values() for fin in fins]  # concatenate all time points
# Check before slicing/aggregating: after appending the aggregate the list is never empty.
assert len(finsAmputated) > 0, "FATAL: No fish data found! Check the data folder"
finsAmputated = finsAmputated[0:15]
S = RawData.sum(finsAmputated)
finsAmputated.append(S)
os.makedirs(plots_directory, exist_ok=True)  # the folder the figure is saved to
fig, axes = plotAllSpots(finsAmputated)
fig.savefig('%s/amputated_fins.pdf' % (plots_directory))

In [None]:
# Plot non amputated raw data: the first 15 fins plus their aggregate.
finsNonAmputated = [fin for fins in dataNonAmputated.values() for fin in fins]  # concatenate all time points
assert len(finsNonAmputated) > 0, "FATAL: No non-amputated fish data found! Check the data folder"
finsNonAmputated = finsNonAmputated[0:15]
# NOTE: non-amputated fins are aggregated with RawData.prod (amputated ones with
# RawData.sum), as in the other cells of this notebook.
S = RawData.prod(finsNonAmputated)
finsNonAmputated.append(S)
fig, axes = plotAllSpots(finsNonAmputated)
fig.savefig('%s/non_amputated_fins.pdf' % (plots_directory))

In [None]:
# plots with respect to cut
def plotTimePointDensities(ax, fishDensities, start=-200, stop=0, nPoints=100, **kwargs):
    """Plot the mean EdU/DAPI density ratio of one time point with a +/-1 SD band.

    Parameters
    ----------
    ax : matplotlib Axes to draw on.
    fishDensities :
        Densities accepted by ``getChannel`` (one 'DAPI' and one 'EdU' channel).
    start, stop : float
        Range of positions over which the densities are sampled.
    nPoints : int
        Number of evenly spaced sample points in ``[start, stop]`` (default 100).
    **kwargs :
        Forwarded to ``ax.plot``. They are also forwarded to ``ax.fill_between``,
        except ``label`` and ``alpha``. The band uses alpha ``0.2``, multiplied by
        ``kwargs['alpha']`` when given.

    Returns
    -------
    The list of lines returned by ``ax.plot``.
    """
    DAPI = getChannel(fishDensities, 'DAPI')
    EdU = getChannel(fishDensities, 'EdU')
    ratio = np.divide(EdU, DAPI)  # element-wise EdU/DAPI
    X = np.linspace(start, stop, nPoints)
    # Build the mean / standard-deviation densities once, not once per sample point.
    meanDensity = ChannelDensity.mean(*ratio)
    stdevDensity = ChannelDensity.stdev(*ratio)  # standard deviation of KDEs
    Y = np.array([meanDensity.density(s) for s in X])
    error = np.array([stdevDensity.density(s) for s in X])
    out = ax.plot(X, Y, **kwargs)
    fillkwargs = {k: v for k, v in kwargs.items() if k not in ('label', 'alpha')}
    fillalpha = 0.2 * kwargs.get('alpha', 1.0)
    ax.fill_between(X, Y - error, Y + error, **fillkwargs, alpha=fillalpha)
    ax.set_ylim(bottom=0.0, top=0.25)
    return out

# Figure: EdU/DAPI density ratio vs. position w.r.t. the cut, uncut vs. cut fins.
fig, ax = plt.subplots(1, 2, subplot_kw=dict(box_aspect=1), sharex=True, sharey=True, layout="constrained")
axUncut, axCut = ax.flatten()
axUncut.title.set_text("Uncut")
axCut.title.set_text("Cut")
palette = sb.color_palette("rocket_r", n_colors=len(timePoints), as_cmap=True)
# Draw from the last time point to the first, so earlier time points end up on top.
for i, t in reversed(list(enumerate(timePoints))):
    color = palette(i / len(timePoints))
    plotTimePointDensities(axUncut, densitiesNonAmputated[t], color=color, label="%dh" % (t))
    plotTimePointDensities(axCut, densitiesAmputated[t], color=color, label="%dh" % (t))

fig.suptitle('Distribution of proliferative cells %s' % (experiment), fontsize=16, y=0.95)
axCut.legend(loc="upper right")
fig.supxlabel('Position w.r.t. cut (µm)', y=0.1)
# plt.ylabel would label the last-created (right) axes; label the figure instead.
fig.supylabel("Proliferative cells/total")
fig.savefig('%s/distribution_analysis_from_injury.pdf' % (plots_directory))


In [None]:
# Figure: EdU/DAPI density ratio vs. position w.r.t. the notochord, uncut vs. cut fins.
fig, ax = plt.subplots(1, 2, subplot_kw=dict(box_aspect=1), sharex=True, sharey=True, layout="constrained")
axUncut, axCut = ax.flatten()
axUncut.title.set_text("Uncut")
axCut.title.set_text("Cut")
palette = sb.color_palette("rocket_r", n_colors=len(timePoints), as_cmap=True)
# Draw from the last time point to the first, so earlier time points end up on top.
for i, t in reversed(list(enumerate(timePoints))):
    color = palette(i / len(timePoints))
    # Uncut fins are sampled up to 150 past the notochord, cut fins only up to 50.
    plotTimePointDensities(axUncut, densitiesNonAmputatedNotochord[t], -100, 150, color=color, label="%dh" % (t))
    plotTimePointDensities(axCut, densitiesAmputatedNotochord[t], -100, 50, color=color, label="%dh" % (t))

fig.suptitle('Distribution of proliferative cells %s' % (experiment), fontsize=16, y=0.95)
axCut.legend(loc="upper right")
fig.supxlabel('Position w.r.t. notochord (µm)', y=0.1)
# plt.ylabel would label the last-created (right) axes; label the figure instead.
fig.supylabel("Proliferative cells/total")
fig.savefig('%s/distribution_analysis_from_notochord.pdf' % (plots_directory))


In [None]:
# Uncut vs. cut density ratio (w.r.t. the notochord), one panel per time point.
_ts = [1,3,5,7,9,11]
_l = len(_ts)
_c = 2
_r = -(_l // -_c) # Ceiling integer division: https://stackoverflow.com/a/17511341
_subplotScale = 5
fig, ax = plt.subplots(_r, _c,
                       subplot_kw=dict(box_aspect=1),
                       figsize=(_subplotScale * _c, _subplotScale * _r),  # figsize is (width, height)
                       sharex=True, sharey=True, layout="constrained")
axesFlat = ax.flatten()
# Colour by position in _ts rather than via timePoints.index(t): a later cell
# rebinds the global timePoints to [1, 7], which made re-runs raise ValueError.
palette = sb.color_palette("rocket_r", n_colors=_l, as_cmap=True)
for i, t in enumerate(_ts):
    color = palette(i / _l)
    curAx = axesFlat[i]
    curAx.title.set_text("%dh" % (t))
    plotTimePointDensities(curAx, densitiesNonAmputatedNotochord[t], -100, 150, color=color, alpha=1.0, label="Uncut", linestyle='dashed')
    plotTimePointDensities(curAx, densitiesAmputatedNotochord[t], -100, 50, color=color, alpha=0.5, label="Cut")
    curAx.legend(loc="upper right")
for unusedAx in axesFlat[_l:]:  # leftover panels when _l is odd
    unusedAx.axis('off')

fig.suptitle('Distribution of proliferative cells %s' % (experiment), fontsize=16)
fig.supxlabel('Position w.r.t. notochord (µm)')
# plt.ylabel would label only the last-created (bottom-right) axes.
fig.supylabel("Proliferative cells/total")
fig.savefig('%s/distribution_analysis_from_notochord_Comparison.pdf' % (plots_directory))


## Plot spatial distribution in xy

In [None]:
# Heatmaps (2D KDE) of the EdU 'Class A' spots w.r.t. the notochord, fins pooled per time point.
# squeeze=False keeps `ax` 2-D even with a single time point, so ax[i, 0] always works.
fig, ax = plt.subplots(len(timePoints), 2, subplot_kw=dict(box_aspect=1), sharex=True, sharey=True,
                       layout="constrained", figsize=(8, 15), squeeze=False)

for i, t in enumerate(timePoints):
    naAx = ax[i, 0]  # left column: non-amputated
    naAx.title.set_text("Non Amputated %dh" % t)
    naSum = RawData.prod(dataNonAmputated[t])
    dapi, edu = naSum.getChannelsFromNotochord(set='Class A')
    sb.kdeplot(data=edu, x='Position X', y='Position Y', fill=True, ax=naAx, cbar=True, cmap="rocket_r")
    Y = edu['Position Y']
    naAx.vlines(x=0, ymin=Y.min(), ymax=Y.max(), colors='black', linestyles='dashed', linewidth=0.8) # plots dash line at notochord position
    aAx = ax[i, 1]  # right column: amputated
    aAx.title.set_text("Amputated %dh" % t)
    aSum = RawData.sum(dataAmputated[t])
    dapi, edu = aSum.getChannelsFromNotochord(set='Class A')
    Y = edu['Position Y']
    sb.kdeplot(data=edu, x='Position X', y='Position Y', fill=True, ax=aAx, cbar=True, cmap="rocket_r")
    aAx.vlines(x=0, ymin=Y.min(), ymax=Y.max(), colors='black', linestyles='dashed', linewidth=0.8) # plots dash line at notochord position

# NOTE: the former plt.subplots_adjust(wspace=0, hspace=0) was dropped. Figures
# using layout="constrained" ignore it and emit a warning.
fig.savefig('%s/spatial_analysis_from_notochord.pdf' % (plots_directory))




In [None]:
# Test #2: accumulate the (X, Y) point coordinates into a matrix, then apply a Gaussian blur (weighted sum).
   
def printPointOn(M, p, view):   # p is a coordinate tuple
    """Add 1.0 to the cell of ``M`` that point ``p`` maps to, if it lies inside ``view``.

    ``view.isInside(p)`` decides whether the point is counted. Items [0] and [1]
    of ``view.getCoordinates(p)`` are the matrix indices. ``M`` is modified in
    place; nothing is returned.
    """
    if not view.isInside(p):
        return  # points outside the view are ignored
    idx = view.getCoordinates(p)
    first, second = idx[0], idx[1]
    M[first, second] += 1.0

# Make matrix
def makeCoordMatrix(X, Y, view):
    """Build a count matrix of the points ``(X[k], Y[k])`` over ``view``.

    Returns a float array of shape ``view.getShape()``. Each point is added to it
    by ``printPointOn``; points outside the view are dropped.
    """
    counts = np.zeros(view.getShape())
    for point in zip(X, Y):
        printPointOn(counts, point, view)  # mutates counts in place
    return counts

# Make function to smooth position of XY
def plotGaussPoints(ax, X, Y, sigma, view, cmax=0.001):
    """Draw a Gaussian-smoothed 2D histogram of the points ``(X, Y)`` on ``ax``.

    Steps:
    1. Count the points on the grid of ``view`` (see ``makeCoordMatrix``).
    2. Divide the counts by the number of points.
    3. Blur with ``gaussian_filter``, using the per-axis sigma returned by
       ``view.getSigma(sigma)``.
    4. Show the result with ``imshow``: colour range ``[0, cmax]``, origin at
       the lower left, extent from ``view.getImshowExtent()``.

    Parameters
    ----------
    ax : matplotlib Axes to draw on.
    X, Y : sequences of equal length with the point coordinates.
    sigma : blur width, converted by ``view.getSigma`` (callers pass µm).
    view : plot view defining the grid, the blur conversion and the extent.
    cmax : upper limit of the colour scale.

    Returns
    -------
    numpy.ndarray
        The normalised, blurred matrix (e.g. for marginal profiles).

    Raises
    ------
    AssertionError
        If ``X`` and ``Y`` have different lengths.
    """
    assert len(X) == len(Y), "X and Y must have the same length"
    nPoints = len(X)
    M = makeCoordMatrix(X, Y, view)
    # Normalise by the number of *points* (one entry of X/Y per spot), not by the
    # number of fish. With no points the matrix stays all zeros instead of 0/0 = NaN.
    if nPoints > 0:
        M = np.divide(M, nPoints)
    xysigma = view.getSigma(sigma)
    Mblur = gaussian_filter(M, sigma=xysigma)
    ax.imshow(Mblur,
              interpolation='sinc', interpolation_stage='data',
              cmap="rocket_r",
              vmin=0, vmax=cmax, origin='lower', extent=view.getImshowExtent())
    return Mblur

def plotGaussPointsDf(ax, df, xlabel, ylabel, sigma, plotView, cmax=0.001): # sigma in um
    """DataFrame front-end to ``plotGaussPoints``.

    Reads the X and Y coordinates from columns ``xlabel`` and ``ylabel`` of ``df``.
    All other arguments are forwarded unchanged. Returns whatever
    ``plotGaussPoints`` returns.
    """
    xs, ys = df[xlabel], df[ylabel]
    print(plotView) #debug
    return plotGaussPoints(ax, xs, ys, sigma, plotView, cmax=cmax)



In [None]:

# Gaussian-smoothed maps of the EdU 'Class A' spots w.r.t. the notochord, drawn on one view shared by all panels.
# NOTE(review): this rebinds the global `timePoints` (set to [1,3,5,7,9,11] when
# loading). Cells above that read `timePoints` only see [1, 7] if re-run after
# this one; the marginal-histogram cell below relies on this value.
timePoints = [1,7]
# squeeze=False keeps `ax` 2-D even with a single time point.
fig, ax = plt.subplots(len(timePoints), 2, sharex=True, sharey=True, layout="constrained", squeeze=False)
resolution = 5.0 # um per pixel
sigma = 10.0 # um
cmax = 0.001 # Max value of the colorscale

# Pool the fins of each time point once and keep the EdU spots (reused below for plotting).
nonAmputatedEdU = []
amputatedEdU = []
for t in timePoints:
    _dapi, naEdU = RawData.prod(dataNonAmputated[t]).getChannelsFromNotochord(set='Class A')
    nonAmputatedEdU.append(naEdU)
    _dapi, aEdU = RawData.sum(dataAmputated[t]).getChannelsFromNotochord(set='Class A')
    amputatedEdU.append(aEdU)

# One view per panel, then the common view covering all of them.
nonAmputatedPVs = [PlotView.fromDataFrame(df, 'Position X', 'Position Y', resolution) for df in nonAmputatedEdU]
amputatedPVs = [PlotView.fromDataFrame(df, 'Position X', 'Position Y', resolution) for df in amputatedEdU]
allPVs = np.concatenate([nonAmputatedPVs, amputatedPVs])
commonView = PlotView.decomposedUnion(nonAmputatedPVs, allPVs) # Here we get the common view for all the data
yLims = commonView.getYAxisLimits(0.75)

for i, t in enumerate(timePoints): # Here we actually plot with the common view
    naAx = ax[i, 0]  # left column: uncut
    naAx.title.set_text("Uncut %dh" % t)
    plotGaussPointsDf(naAx, nonAmputatedEdU[i], 'Position X', 'Position Y', sigma, commonView, cmax=cmax)
    naAx.vlines(x=0, ymin=yLims[0], ymax=yLims[1], colors='black', linestyles='dashed', linewidth=0.8) # plots dash line at notochord position
    aAx = ax[i, 1]  # right column: cut
    aAx.title.set_text("Cut %dh" % t)
    plotGaussPointsDf(aAx, amputatedEdU[i], 'Position X', 'Position Y', sigma, commonView, cmax=cmax)
    aAx.vlines(x=0, ymin=yLims[0], ymax=yLims[1], colors='black', linestyles='dashed', linewidth=0.8) # plots dash line at notochord position

fig.set_figheight(15)
fig.savefig('%s/spatial_analysis_gauss.pdf' % (plots_directory))


In [None]:
# Test #3: Gaussian-smoothed maps (as in the previous cell) with, under each map,
# its marginal profile along X (the blurred matrix summed over axis 0).

TPs = timePoints
lTP = len(TPs)
fig = plt.figure(layout='constrained')
# One (map, profile) row pair per time point; profile rows are 0.2x as tall as map rows.
gs = gridspec.GridSpec(lTP*2, 2, figure=fig, height_ratios=[1, 0.2]*lTP, width_ratios=[1, 1])

resolution = 5.0 # um per pixel
sigma = 10.0 # um

# Pool the fins of each time point once and keep the EdU spots (reused below for plotting).
nonAmputatedEdU = []
amputatedEdU = []
for t in TPs:
    _dapi, naEdU = RawData.prod(dataNonAmputated[t]).getChannelsFromNotochord(set='Class A')
    nonAmputatedEdU.append(naEdU)
    _dapi, aEdU = RawData.sum(dataAmputated[t]).getChannelsFromNotochord(set='Class A')
    amputatedEdU.append(aEdU)

nonAmputatedPVs = [PlotView.fromDataFrame(df, 'Position X', 'Position Y', resolution) for df in nonAmputatedEdU]
amputatedPVs = [PlotView.fromDataFrame(df, 'Position X', 'Position Y', resolution) for df in amputatedEdU]
allPVs = np.concatenate([nonAmputatedPVs, amputatedPVs])
commonView = PlotView.decomposedUnion(nonAmputatedPVs, allPVs) # Here we get the common view for all the data
yLims = commonView.getYAxisLimits(0.75)
commonX = commonView.getXLinspace()

AXes = np.empty((2*lTP, 2), dtype=object)  # all axes, laid out like the grid
for i, t in enumerate(TPs): # Here we actually plot with the common view
    naAx = fig.add_subplot(gs[2*i + 0, 0])     # non-amputated main plot
    aAx = fig.add_subplot(gs[2*i + 0, 1])      # amputated main plot
    AXes[2*i + 0, 0] = naAx
    AXes[2*i + 0, 1] = aAx
    naHisto = fig.add_subplot(gs[2*i + 1, 0])  # non-amputated marginal histo
    aHisto = fig.add_subplot(gs[2*i + 1, 1])   # amputated marginal histo
    AXes[2*i + 1, 0] = naHisto
    AXes[2*i + 1, 1] = aHisto

    naAx.title.set_text("Non Amputated %dh" % t)
    naM = plotGaussPointsDf(naAx, nonAmputatedEdU[i], 'Position X', 'Position Y', sigma, commonView)
    naAx.vlines(x=0, ymin=yLims[0], ymax=yLims[1], colors='black', linestyles='dashed', linewidth=0.8) # plots dash line at notochord position

    aAx.title.set_text("Amputated %dh" % t)
    aM = plotGaussPointsDf(aAx, amputatedEdU[i], 'Position X', 'Position Y', sigma, commonView)
    aAx.vlines(x=0, ymin=yLims[0], ymax=yLims[1], colors='black', linestyles='dashed', linewidth=0.8) # plots dash line at notochord position

    # Marginal profiles: sum each blurred matrix over axis 0; both use the same height.
    naY = naM.sum(axis=0)
    aY = aM.sum(axis=0)
    maxY = np.max((naY, aY))
    naHisto.plot(commonX, naY)
    naHisto.vlines(x=0, ymin=0.0, ymax=maxY, colors='black', linestyles='dashed', linewidth=0.8) # plots dash line at notochord position
    aHisto.plot(commonX, aY)
    aHisto.vlines(x=0, ymin=0.0, ymax=maxY, colors='black', linestyles='dashed', linewidth=0.8) # plots dash line at notochord position

    naAx.set_aspect('equal')
    aAx.set_aspect('equal')

    # Axis sharing: the two maps of a row share y, and the two profiles of a row
    # share y. Maps and profiles share x with the first row of their column.
    if i == 0:
        aAx.sharey(naAx)
        naHisto.sharex(naAx)
        aHisto.sharex(aAx)
        aHisto.sharey(naHisto)
    else:
        naAx.sharex(AXes[0, 0])
        aAx.sharex(AXes[0, 1])
        aAx.sharey(naAx)
        naHisto.sharex(AXes[0, 0])
        aHisto.sharex(AXes[0, 1])
        aHisto.sharey(naHisto)

    # Tick labels: y only on the left maps; x only on the last time point's row pair.
    plt.setp(aAx.get_yticklabels(), visible=False)
    plt.setp(naHisto.get_yticklabels(), visible=False)
    plt.setp(aHisto.get_yticklabels(), visible=False)
    if i == lTP-1:
        naHisto.set_xlabel('Position X from notochord [µm]')
        aHisto.set_xlabel('Position X from notochord [µm]')
    else:
        plt.setp(naAx.get_xticklabels(), visible=False)
        plt.setp(aAx.get_xticklabels(), visible=False)
        plt.setp(naHisto.get_xticklabels(), visible=False)
        plt.setp(aHisto.get_xticklabels(), visible=False)

scale = 2.5
fig.set_figheight(lTP*scale)
fig.set_figwidth(2*scale)
fig.savefig('%s/spatial_analysis_gauss_marginalHisto.pdf' % (plots_directory))


In [None]:
# Re-save the marginal-histogram figure of the previous cell and display it.
# Use fig.savefig, not plt.savefig. With the inline backend, figures are closed
# at the end of each cell, so here plt.savefig would save a new, empty figure.
scale = 2.5
fig.set_figheight(lTP*scale)
fig.set_figwidth(2*scale)
fig.savefig('%s/spatial_analysis_gauss_marginalHisto2.pdf' % (plots_directory))
fig