import numpy as np
from scipy import ndimage, optimize
import collections
from .cell_describe import node, edge, colony
class data:
    """
    One time step of an Allen Institute for Cell Science skeleton pickle,
    with helpers to turn its branches into ``node``/``edge`` objects, clean
    up skeletonization artifacts, and compute tensions and pressures via
    ``colony``.
    """

    def __init__(self, v, t):
        """
        Data class made specifically for the pickle file format used at the
        Allen Institute for Cell Science.

        Parameters
        ----------
        v : sequence
            Data structure obtained after loading the pickle file. Judging by
            how this class indexes it: ``v[0][t]`` is a label image (pixels
            equal to 2 are shown by ``plot``), ``v[1][t]`` is an array of
            points indexed ``[point, column]`` (column 1 is read as x and
            column 0 as y), and ``v[2][t]`` is a list of branches, each a
            sequence of point indices.
        t : int
            Time step to read from ``v``.
        """
        self.v = v
        self.t = t
        # Number of branches available at this time step.
        self.length = len(self.v[2][self.t])

    def _branch_coordinate(self, index, f_or_l, column):
        """Shared implementation of ``x``/``y``: read ``column`` of branch ``index``."""
        points = self.v[1][self.t]
        branch = self.v[2][self.t][index]
        if f_or_l == "first":
            return points[branch[0], column]
        if f_or_l == "last":
            return points[branch[-1], column]
        # Any other value: every point along the branch.
        return points[branch, column]

    def x(self, index, f_or_l=None):
        """
        Return the x co-ordinate(s) of branch ``index``.

        Parameters
        ----------
        index : int
            Index of the branch in the list of branches at this time step.
        f_or_l : {"first", "last", None}, optional
            Which end of the branch to read. Any other value (including the
            default ``None``) returns the x co-ordinates of every point
            along the branch.

        Returns
        -------
        A single co-ordinate for "first"/"last", otherwise an array.
        """
        # Column 1 of the point array is used as x.
        return self._branch_coordinate(index, f_or_l, 1)

    def y(self, index, f_or_l=None):
        """
        Return the y co-ordinate(s) of branch ``index``.

        Parameters
        ----------
        index : int
            Index of the branch in the list of branches at this time step.
        f_or_l : {"first", "last", None}, optional
            Which end of the branch to read. Any other value (including the
            default ``None``) returns the y co-ordinates of every point
            along the branch.

        Returns
        -------
        A single co-ordinate for "first"/"last", otherwise an array.
        """
        # Column 0 of the point array is used as y.
        return self._branch_coordinate(index, f_or_l, 0)

    def add_node(self, index, f_or_l):
        """
        Create a ``node`` located at one end of branch ``index``.

        Parameters
        ----------
        index : int
            Branch index.
        f_or_l : {"first", "last"}
            Which end of the branch the node sits on.
        """
        return node((self.x(index, f_or_l), self.y(index, f_or_l)))

    def add_edge(self, node_a, node_b, index=None, x=None, y=None):
        """
        Create an ``edge`` between two nodes and fit a circular arc to it.

        Parameters
        ----------
        node_a, node_b : node
            Nodes at the two ends of the edge. ``x[0], y[0]`` is expected to
            lie at ``node_a`` and ``x[-1], y[-1]`` at ``node_b``.
        index : int, optional
            Branch index. If given, ``x`` and ``y`` are read from the data
            and any passed ``x``/``y`` are ignored.
        x, y : array-like, optional
            Co-ordinates along the edge; required when ``index`` is None.

        Returns
        -------
        edge
            Its end-node order is chosen from the sign of the cross product
            of (first point - centre) and (last point - centre): positive
            gives ``edge(node_b, node_a, ...)``, otherwise
            ``edge(node_a, node_b, ...)``. If the fitted radius is not
            positive, the edge gets radius ``None``. ``center_of_circle`` and
            ``curve_fit_residual`` are set on the returned edge.
        """
        if index is not None:
            x = self.x(index, None)
            y = self.y(index, None)
        # Least-squares circle fit: radius and centre (xc, yc).
        radius, xc, yc, residu_2 = self.fit(x, y)
        x1, y1, x2, y2 = x[0], y[0], x[-1], y[-1]
        # z-component of (p1 - c) x (p2 - c). Written out explicitly because
        # np.cross on 2-D vectors is deprecated as of NumPy 2.0.
        cr = (x1 - xc) * (y2 - yc) - (y1 - yc) * (x2 - xc)
        # Half the chord length between the two end points.
        a = 0.5 * np.linalg.norm(np.subtract([x2, y2], [x1, y1]))
        if radius > 0:
            edge_radius = radius
            if not a < radius:
                # No circle of this radius passes through both ends (the fit
                # is an impossible arc); enlarge it to half-chord + 5.
                rnd = a - radius + 5
                edge_radius = radius + rnd
            if cr > 0:
                ed = edge(node_b, node_a, edge_radius, None, None, x, y)
            else:
                ed = edge(node_a, node_b, edge_radius, None, None, x, y)
        else:
            # No usable radius: leave it as None.
            ed = edge(node_a, node_b, None, None, None, x, y)
        ed.center_of_circle = [xc, yc]
        ed.curve_fit_residual = residu_2
        return ed

    def fit(self, x, y):
        """
        Fit a circle to 2-D points by least squares.

        The centre is found with ``scipy.optimize.leastsq``, starting from
        the centroid of the points, by minimising the deviation of each
        point's distance to the centre from the mean distance. The radius is
        that mean distance.

        Parameters
        ----------
        x, y : array-like
            Point co-ordinates (same length).

        Returns
        -------
        radius, xc, yc, residual
            ``residual`` is the sum of squared deviations of the point
            distances from ``radius``.
        """
        x = np.asarray(x)
        y = np.asarray(y)

        def calc_r(xc, yc):
            """Distance of each point from the centre (xc, yc)."""
            return np.sqrt((x - xc) ** 2 + (y - yc) ** 2)

        def f_2(c):
            """Geometric distance of each point from the mean circle centred at c."""
            ri = calc_r(*c)
            return ri - ri.mean()

        center_estimate = np.mean(x), np.mean(y)
        center, _ier = optimize.leastsq(f_2, center_estimate)
        xc, yc = center
        ri = calc_r(xc, yc)
        radius = ri.mean()
        residual = np.sum((ri - radius) ** 2)
        return radius, xc, yc, residual

    @staticmethod
    def _snap_node(candidate, nodes, cutoff):
        """
        Return the existing node nearest to ``candidate`` if it lies within
        ``cutoff``; otherwise append ``candidate`` to ``nodes`` and return it.
        """
        dists = [np.linalg.norm(np.subtract(n.loc, candidate.loc)) for n in nodes]
        # all() of an empty list is True, so the first node is always added.
        if all(d >= cutoff for d in dists):
            nodes.append(candidate)
            return candidate
        return nodes[dists.index(min(dists))]

    def post_processing(self, cutoff, num=None):
        """
        Build nodes and edges from the branches, merging nearby nodes, then
        clean up skeletonization artifacts.

        Branch end points closer than ``cutoff`` to an existing node are
        replaced by the nearest such node. Branches whose two ends collapse
        onto the same location produce no edge. The result is then passed
        through ``remove_dangling_edges``, ``remove_small_cells`` and
        ``remove_two_edge_connections`` in that order.

        Parameters
        ----------
        cutoff : float
            Distance within which nodes are merged.
        num : int, optional
            Number of branches to consider (falsy means all of them).

        Returns
        -------
        nodes, edges, new_edges
            ``new_edges`` are the edges created by ``remove_small_cells``.
        """
        nodes, edges = [], []
        if not num:
            num = self.length
        for index in range(num):
            node_a = self.add_node(index, "first")
            node_b = self.add_node(index, "last")
            node_a = self._snap_node(node_a, nodes, cutoff)
            # node_b is snapped after node_a has been registered, so a very
            # short branch collapses onto node_a and is skipped below.
            node_b = self._snap_node(node_b, nodes, cutoff)
            if node_a.loc != node_b.loc:
                edges.append(self.add_edge(node_a, node_b, index))
        # Step 1: remove small stray edges sticking out of other edges, and
        # edges connected to nothing else (common skeletonization error).
        nodes, edges = self.remove_dangling_edges(nodes, edges)
        # Step 2: collapse small cells, which appear where the skeleton put
        # several nodes at a place that should hold one.
        nodes, edges, new_edges = self.remove_small_cells(nodes, edges)
        # Step 3: no force balance is possible at a node with 2 edges, so
        # merge such edge pairs.
        nodes, edges = self.remove_two_edge_connections(nodes, edges)
        return nodes, edges, new_edges

    def remove_dangling_edges(self, nodes, edges):
        """
        Clean up nodes connected to a single edge.

        * A short edge (``straight_length < 3``) whose far node also has an
          edge at 85-95 degrees to it (per ``edge_angle``) is detached from
          that far node and removed, together with its free node.
        * An edge whose two end nodes are both connected only to it is
          removed, together with both nodes.

        ``nodes`` and ``edges`` are modified in place and also returned.
        """
        n_1_edges = [n for n in nodes if len(n.edges) == 1]
        if not n_1_edges:
            return nodes, edges
        # The single edge of each such node (aligned with n_1_edges) ...
        e_1 = [e for f in n_1_edges for e in f.edges]
        # ... and the node at the other end of that edge.
        n_1_edges_b = [n for j, f in enumerate(e_1) for n in f.nodes if n != n_1_edges[j]]
        for j, e in enumerate(e_1):
            if e.straight_length < 3:
                far_node = n_1_edges_b[j]
                other_edges = [a for a in far_node.edges if a != e]
                perps = [b for b in other_edges if 85 < abs(e.edge_angle(b)) < 95]
                if perps:
                    other_node = [n for n in e.nodes if n != far_node][0]
                    e.kill_edge(far_node)
                    if e in edges:
                        edges.remove(e)
                    if other_node in nodes:
                        nodes.remove(other_node)
        # An edge listed twice in e_1 has a single-edge node at both ends,
        # i.e. it is connected to nothing else.
        repeated_edges = [item for item, count in collections.Counter(e_1).items() if count > 1]
        for e in repeated_edges:
            edges.remove(e)
            nodes.remove(e.node_a)
            nodes.remove(e.node_b)
        return nodes, edges

    @staticmethod
    def _orient_towards(x, y, loc):
        """
        Return ``(x, y)`` as arrays ordered so that the LAST point is the end
        of the polyline nearest to ``loc`` (an ``(x, y)`` pair). Ties keep
        the given order.
        """
        x = np.asarray(x)
        y = np.asarray(y)
        d_first = np.hypot(x[0] - loc[0], y[0] - loc[1])
        d_last = np.hypot(x[-1] - loc[0], y[-1] - loc[1])
        if d_first < d_last:
            return x[::-1], y[::-1]
        return x, y

    def remove_two_edge_connections(self, nodes, edges):
        """
        Merge the two edges meeting at every node connected to exactly 2
        edges into a single edge between the outer nodes.

        A pair is merged when ``0 < |edge_angle| < 180``. The co-ordinates of
        both edges are joined into one polyline (outer node -> shared node ->
        outer node) and refitted through ``add_edge``. If ``add_edge`` raises
        ``AssertionError``, the outer nodes' edge and tension-vector lists
        are restored and the node is kept.

        ``nodes`` and ``edges`` are modified in place and also returned.
        """
        n_2 = [n for n in nodes if len(n.edges) == 2]
        for n in n_2:
            edge_0, edge_1 = n.edges[0], n.edges[1]
            angle = edge_0.edge_angle(edge_1)
            if not 0 < abs(angle) < 180:
                continue
            # Outer (non-shared) node of each edge.
            node_a = [a for a in edge_0.nodes if a != n][0]
            node_b = [a for a in edge_1.nodes if a != n][0]
            if node_a is node_b:
                # Both edges lead to the same neighbour; merging would build a
                # self-loop and the index bookkeeping below would be wrong.
                continue
            # Detach edge_0 from node_a and edge_1 from node_b, together with
            # the tension vectors stored at the same positions.
            ind_a = node_a.edges.index(edge_0)
            ind_b = node_b.edges.index(edge_1)
            tv_a = node_a.tension_vectors.pop(ind_a)
            tv_b = node_b.tension_vectors.pop(ind_b)
            node_a.edges.remove(edge_0)
            node_b.edges.remove(edge_1)
            # Orient edge_0 as node_a -> n and edge_1 as n -> node_b, judged by
            # distance to n (exact end-point equality fails once end points
            # have been snapped to merged nodes).
            x1, y1 = self._orient_towards(*edge_0.co_ordinates, n.loc)
            x2, y2 = self._orient_towards(*edge_1.co_ordinates, n.loc)
            new_x = np.append(x1, x2[::-1])
            new_y = np.append(y1, y2[::-1])
            try:
                new_edge = self.add_edge(node_a, node_b, None, new_x, new_y)
            except AssertionError:
                # Presumably raised by the edge constructor for an invalid
                # edge: undo the detachment so the graph stays consistent.
                node_a.edges.insert(ind_a, edge_0)
                node_a.tension_vectors.insert(ind_a, tv_a)
                node_b.edges.insert(ind_b, edge_1)
                node_b.tension_vectors.insert(ind_b, tv_b)
                continue
            edges.remove(edge_0)
            edges.remove(edge_1)
            nodes.remove(n)
            edges.append(new_edge)
        return nodes, edges

    def remove_small_cells(self, nodes, edges):
        """
        Collapse every cell whose ``perimeter()`` is below 150 into a single
        node at the centroid of the cell's nodes.

        The cell's own edges are removed. Every other edge touching a cell
        node is replaced by a new edge from its outer node to the new centroid
        node, with the centroid appended to its co-ordinates. Nodes left
        without edges are dropped.

        Returns
        -------
        nodes, edges, new_edges
            ``new_edges`` lists the replacement edges created here.
        """
        cells = self.find_cycles(edges)
        # Arbitrary small perimeter that cleans up most skeletonization
        # errors in the AICS dataset.
        cutoff_perim = 150
        small_cells = [cell for cell in cells if cell.perimeter() < cutoff_perim]
        new_edges = []
        for cell in small_cells:
            # Detach each cell edge from both of its nodes and drop it.
            # (kill_edge(n) is assumed to remove the edge and its tension
            # vector from n, as the membership guards suggest.)
            for ed in cell.edges:
                if ed in ed.node_a.edges:
                    ed.kill_edge(ed.node_a)
                if ed in ed.node_b.edges:
                    ed.kill_edge(ed.node_b)
                if ed in edges:
                    edges.remove(ed)
            # The whole cell becomes one node at the centroid of its nodes.
            locs = [n.loc for n in cell.nodes]
            cx = np.mean([loc[0] for loc in locs])
            cy = np.mean([loc[1] for loc in locs])
            new_node = node((cx, cy))
            # Cell nodes that are now isolated simply disappear.
            for n in cell.nodes:
                if len(n.edges) == 0 and n in nodes:
                    nodes.remove(n)
            # Re-route every remaining (external) edge of a cell node to new_node.
            for n in cell.nodes:
                # Iterate over a copy: kill_edge(n) removes ned from n.edges,
                # and mutating the list being iterated would skip edges.
                for ned in list(n.edges):
                    node_b = [a for a in ned.nodes if a != n][0]
                    # Order the points node_b -> n so appending the centroid
                    # extends the path towards the new node.
                    x1, y1 = self._orient_towards(*ned.co_ordinates, n.loc)
                    new_x1, new_y1 = np.append(x1, cx), np.append(y1, cy)
                    ned.kill_edge(n)
                    ned.kill_edge(node_b)
                    if ned in edges:
                        edges.remove(ned)
                    new_edge = self.add_edge(node_b, new_node, None, new_x1, new_y1)
                    new_edges.append(new_edge)
                    edges.append(new_edge)
                if n in nodes:
                    nodes.remove(n)
            nodes.append(new_node)
        # Drop any node not connected to an edge. Rebuild rather than remove
        # while iterating, which would skip consecutive orphans.
        nodes[:] = [n for n in nodes if len(n.edges) > 0]
        return nodes, edges, new_edges

    @staticmethod
    def find_cycles(edges):
        """
        Find the unique cells (cycles) formed by ``edges``.

        For every edge, ``which_cell(edges, side, 300)`` is asked for the cell
        on side 0 and on side 1, so each edge contributes at most 2 cells.
        A cell is kept only if no previously found cell has the same set of
        edges. For each kept cell, the cell is assigned to ``ed.cells`` of
        every edge ``ed`` in it.

        Returns
        -------
        list
            The unique cells, in discovery order.
        """
        # Max iterations for which_cell's recursive cycle search.
        max_iter = 300
        cells = []
        for e in edges:
            for side in (0, 1):
                cell = e.which_cell(edges, side, max_iter)
                if not cell:
                    continue
                cell_edges = set(cell.edges)
                if any(cell_edges == set(c.edges) for c in cells):
                    continue
                for ed in cell.edges:
                    ed.cells = cell
                cells.append(cell)
        return cells

    def compute(self, cutoff, nodes=None, edges=None):
        """
        Run the full computation.

        Steps:
        (1) If neither ``nodes`` nor ``edges`` is given, build them with
            ``post_processing(cutoff)``.
        (2) Find cells with ``find_cycles``.
        (3) Build a ``colony`` from the cells, the edges that have a radius,
            and the nodes.
        (4) ``calculate_tension()``. Outlier removal of tensions more than
            3 std from the mean is currently disabled.
        (5) ``calculate_pressure()``.

        Returns
        -------
        col1, tensions, pressures, p_t, p_p, a_mat, b_mat
        """
        if nodes is None and edges is None:
            nodes, edges, _ = self.post_processing(cutoff, None)
        cells = self.find_cycles(edges)
        # Only edges with a fitted radius take part in the colony.
        edges2 = [e for e in edges if e.radius is not None]
        col1 = colony(cells, edges2, nodes)
        tensions, p_t, a_mat = col1.calculate_tension()
        # NOTE: recomputing after deleting edges with outlier tensions
        # (col1.remove_outliers) is intentionally not performed here.
        pressures, p_p, b_mat = col1.calculate_pressure()
        return col1, tensions, pressures, p_t, p_p, a_mat, b_mat

    def plot(self, ax, type=None, num=None, **kwargs):
        """
        Plot the data set.

        Parameters
        ----------
        ax : matplotlib-style axes
            Axes to plot on.
        type : str
            What to plot: "edge_and_node" (also "edges_and_nodes",
            "edge_and_nodes", "edges_and_node"), "node"/"nodes",
            "edge"/"edges", or "image"/"images".
        num : int, optional
            Number of branches to plot (falsy means all of them).
        **kwargs
            Forwarded to ``ax.plot`` for the branch lines.
        """
        if not num:
            num = self.length
        if type in ("edge_and_node", "edges_and_nodes", "edge_and_nodes", "edges_and_node"):
            for i in range(num):
                ax.plot(self.x(i, None), self.y(i, None), **kwargs)
                ax.plot(self.x(i, "first"), self.y(i, "first"), 'ok')
                ax.plot(self.x(i, "last"), self.y(i, "last"), 'ok')
        elif type in ("node", "nodes"):
            for i in range(num):
                ax.plot(self.x(i, "first"), self.y(i, "first"), 'ok')
                ax.plot(self.x(i, "last"), self.y(i, "last"), 'ok')
        elif type in ("edge", "edges"):
            for i in range(num):
                ax.plot(self.x(i, None), self.y(i, None), **kwargs)
        elif type in ("image", "images"):
            # Pixels labelled 2; rotation by 0 degrees keeps the orientation.
            img = ndimage.rotate(self.v[0][self.t] == 2, 0)
            # Plot the image with origin at lower left.
            ax.imshow(img, origin='lower')
        # NOTE(review): indentation was lost in the source this was recovered
        # from; confirm whether these limits apply to every plot type (as
        # here) or only to the image branch.
        ax.set(xlim=[0, 1000], ylim=[0, 1000], aspect=1)