import os

import cv2
import matplotlib.pyplot as plt
import maxflow
import numpy as np
from medpy import metric
from PIL import Image, ImageDraw
# Stroke state shared between the OpenCV mouse callback (onMouseClick) and the
# main loop, which resets it before each image.
left_mouse_down, right_mouse_down = False, False
foreground_index, background_index = 0, 0
# One list of (x, y) points per stroke drawn with the left/right mouse button.
foreground_lines: list = []
background_lines: list = []
class GraphMaker:
    """Interactive graph-cut segmentation of a single image.

    Every pixel becomes a graph node linked to its right and lower neighbour
    (4-connectivity) with weight ``1 / (1 + ||I_p - I_q||^2)``. Background seeds
    are tied to the source terminal and foreground seeds to the sink terminal
    with capacity ``MAXIMUM``. The minimum cut is computed with PyMaxflow, and
    nodes that end up on the sink side form the foreground mask.
    """

    foreground = 1
    background = 0
    segmented = 1
    default = 0.5  # graph value of a pixel that is not a seed of either type
    MAXIMUM = 1000000000  # terminal capacity used to hard-constrain seeds

    def __init__(self, filename):
        """Load *filename* and start with empty seed and graph state."""
        self.image = None
        self.graph = None
        self.segment_overlay = None
        self.mask = None
        self.current_overlay = None
        self.filename: str = filename
        self.load_image(filename)
        self.background_seeds = []
        self.foreground_seeds = []
        # Per-channel mean colour of each seed set; recomputed by find_averages.
        self.background_average = np.zeros(3)
        self.foreground_average = np.zeros(3)
        self.nodes = []
        self.edges = []

    def load_image(self, filename):
        """Read *filename* with OpenCV and reset graph/segmentation state.

        Raises:
            OSError: if OpenCV cannot read the file (``cv2.imread`` returns None).
        """
        self.filename = filename
        self.image = cv2.imread(filename)
        if self.image is None:
            raise OSError("Could not read image: {}".format(filename))
        self.graph = None
        self.segment_overlay = np.zeros(self.image.shape[:2])
        self.mask = None

    def add_seed(self, x, y, type):
        """Record pixel (x, y) as a seed of *type* (``background`` or ``foreground``).

        Duplicate seeds and coordinates outside the image are ignored. Mouse
        callbacks can report positions outside the image while a drag leaves the
        window. The parameter name ``type`` shadows the builtin but is kept for
        backward compatibility with keyword callers.
        """
        if self.image is None:
            print('Please load an image before adding seeds.')
            return
        height, width = self.image.shape[:2]
        if not (0 <= x < width and 0 <= y < height):
            return
        if type == self.background:
            seeds = self.background_seeds
        elif type == self.foreground:
            seeds = self.foreground_seeds
        else:
            return
        if (x, y) not in seeds:
            seeds.append((x, y))

    def create_graph(self):
        """Build the node and edge lists from the current seeds.

        Does nothing (after printing a message) unless at least one seed of each
        type exists.
        """
        if not self.background_seeds or not self.foreground_seeds:
            print("Please enter at least one foreground and background seed.")
            return

        print("Making graph")
        print("Finding foreground and background averages")
        self.find_averages()

        print("Populating nodes and edges")
        self.populate_graph()

    def find_averages(self):
        """Mark seeds in ``self.graph`` and compute each seed set's mean colour.

        ``self.graph`` gets one value per pixel: ``background`` (0) or
        ``foreground`` (1) at seed pixels, and ``default`` (0.5) elsewhere. A
        pixel seeded with both types ends up foreground.
        """
        self.graph = np.full(self.image.shape[:2], self.default, dtype=np.float64)
        print(self.graph.shape)
        self.background_average = np.zeros(3)
        self.foreground_average = np.zeros(3)

        # Seeds are (x, y) mouse coordinates, so index rows with y and columns with x.
        for x, y in self.background_seeds:
            self.graph[y, x] = self.background
            self.background_average += self.image[y, x]
        if self.background_seeds:
            self.background_average /= len(self.background_seeds)

        for x, y in self.foreground_seeds:
            self.graph[y, x] = self.foreground
            self.foreground_average += self.image[y, x]
        if self.foreground_seeds:
            self.foreground_average /= len(self.foreground_seeds)

    def populate_graph(self):
        """Fill ``self.nodes`` and ``self.edges`` from ``self.graph`` and the image.

        ``self.nodes`` holds ``(node_id, source_capacity, sink_capacity)`` in
        node-id order. Background seeds get source capacity ``MAXIMUM``,
        foreground seeds get sink capacity ``MAXIMUM``, and all other pixels get 0/0.

        ``self.edges`` holds ``(node_a, node_b, weight)`` for every horizontally
        and vertically adjacent pixel pair, including pairs in the last row and
        column.
        """
        height, width = self.graph.shape
        flat = self.graph.ravel()
        source_caps = np.where(flat == self.background, self.MAXIMUM, 0)
        sink_caps = np.where(flat == self.foreground, self.MAXIMUM, 0)
        self.nodes = list(zip(range(flat.size), source_caps.tolist(), sink_caps.tolist()))

        # Work in float: uint8 subtraction/squaring would silently wrap around.
        pixels = self.image.astype(np.float64)
        if pixels.ndim == 2:
            pixels = pixels[:, :, np.newaxis]
        node_ids = np.arange(height * width).reshape(height, width)

        self.edges = []
        # Right neighbours.
        self.edges.extend(self._edges_between(
            node_ids[:, :-1], node_ids[:, 1:], pixels[:, :-1], pixels[:, 1:]))
        # Lower neighbours.
        self.edges.extend(self._edges_between(
            node_ids[:-1, :], node_ids[1:, :], pixels[:-1, :], pixels[1:, :]))

    @staticmethod
    def _edges_between(ids_a, ids_b, pix_a, pix_b):
        """Return ``(id_a, id_b, 1 / (1 + ||pix_a - pix_b||^2))`` for aligned grids."""
        weights = 1.0 / (1.0 + np.sum((pix_a - pix_b) ** 2, axis=-1))
        return list(zip(ids_a.ravel().tolist(), ids_b.ravel().tolist(), weights.ravel().tolist()))

    def cut_graph(self):
        """Run max-flow/min-cut and store the resulting segmentation.

        Sets ``self.segment_overlay`` (1.0 for foreground pixels, 0.0 otherwise) and
        ``self.mask`` (a boolean array shaped like the image). Does nothing if
        ``create_graph`` has not produced any nodes.
        """
        if not self.nodes:
            print("Please create the graph before cutting it.")
            return

        g = maxflow.Graph[float](len(self.nodes), len(self.edges))
        nodelist = g.add_nodes(len(self.nodes))
        for node_id, source_cap, sink_cap in self.nodes:
            g.add_tedge(nodelist[node_id], source_cap, sink_cap)
        for node_a, node_b, weight in self.edges:
            g.add_edge(nodelist[node_a], nodelist[node_b], weight, weight)

        flow = g.maxflow()
        print("maximum flow is {}".format(flow))

        # Segment 1 (sink side) is the foreground, since foreground seeds feed the sink.
        height, width = self.image.shape[:2]
        segments = np.array([g.get_segment(nodelist[i]) for i in range(len(self.nodes))])
        foreground = (segments == 1).reshape(height, width)
        self.segment_overlay = foreground.astype(np.float64)
        if self.image.ndim == 3:
            self.mask = np.repeat(foreground[:, :, np.newaxis], self.image.shape[2], axis=2)
        else:
            self.mask = foreground.copy()

    def swap_overlay(self, overlay_num):
        """Remember which overlay should be displayed."""
        self.current_overlay = overlay_num

    def save_image(self, outfilename):
        """Save the segmentation result and a copy of the image with the user's strokes.

        Writes the image pixels inside ``self.mask`` (black elsewhere) to
        *outfilename*. It also writes the original image with foreground strokes in
        red and background strokes in blue (BGR) to ``<stem>storke.jpg``. The
        output directory is created if needed. Returns ``self.segment_overlay``, or
        None if the image has not been segmented yet.
        """
        if self.mask is None:
            print('Please segment the image before saving.')
            return
        print(outfilename)

        out_dir = os.path.dirname(outfilename)
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)

        to_save = np.zeros_like(self.image)
        np.copyto(to_save, self.image, where=self.mask)
        if not cv2.imwrite(outfilename, to_save):
            print("Failed to write {}".format(outfilename))

        save_stroke = self.image.copy()
        for foreground_line in foreground_lines:
            self.draw_polyline(save_stroke, foreground_line, (0, 0, 255))
        for background_line in background_lines:
            self.draw_polyline(save_stroke, background_line, (255, 0, 0))
        # NOTE: "storke" is kept as-is so existing output file names do not change.
        stroke_path = os.path.splitext(outfilename)[0] + "storke.jpg"
        if not cv2.imwrite(stroke_path, save_stroke):
            print("Failed to write {}".format(stroke_path))
        return self.segment_overlay

    @staticmethod
    def _read_binary_mask(path):
        """Read *path* and return a 2-D boolean mask (True where any channel is non-zero).

        Raises:
            OSError: if OpenCV cannot read the file.
        """
        img = cv2.imread(path)
        if img is None:
            raise OSError("Could not read image: {}".format(path))
        return np.any(img, axis=-1) if img.ndim == 3 else img.astype(bool)

    @staticmethod
    def evaluate(prediction_path, reference_path):
        """Compare a saved segmentation against a reference mask and print metrics.

        Both images are reduced to 2-D boolean masks before the medpy metrics are
        computed: Dice, 95th-percentile Hausdorff distance, sensitivity,
        specificity and positive predictive value. Collapsing to 2-D keeps the
        colour axis from being treated as a spatial dimension. The metrics are
        printed on one line and returned as a tuple.

        NOTE(review): the prediction is the original pixels inside the mask, so
        genuinely black foreground pixels, or compression noise if it was saved in
        a lossy format such as JPEG, can perturb the binarised prediction.

        Raises:
            OSError: if either image cannot be read.
        """
        reference = GraphMaker._read_binary_mask(reference_path)
        prediction = GraphMaker._read_binary_mask(prediction_path)
        dice = metric.binary.dc(prediction, reference)
        hd = metric.binary.hd95(prediction, reference)
        sensitivity = metric.binary.sensitivity(prediction, reference)
        specificity = metric.binary.specificity(prediction, reference)
        # Positive predictive value (precision), not accuracy.
        ppv = metric.binary.positive_predictive_value(prediction, reference)
        print("{:.2f} {:.2f} {:.2f} {:.2f} {:.2f}".format(dice, hd, sensitivity, specificity, ppv))
        return dice, hd, sensitivity, specificity, ppv

    @staticmethod
    def get_node_num(x, y, array_shape):
        """Return the row-major node id of pixel (x, y) in an image of *array_shape*."""
        return y * array_shape[1] + x

    @staticmethod
    def get_xy(nodenum, array_shape):
        """Inverse of ``get_node_num``: return ``(x, y)`` for node *nodenum*."""
        return nodenum % array_shape[1], nodenum // array_shape[1]

    @staticmethod
    def draw_polyline(img, lines, color):
        """Draw consecutive points of *lines* onto *img* as 3-px segments in *color*."""
        for start, end in zip(lines, lines[1:]):
            cv2.line(img, start, end, color=color, thickness=3)
def onMouseClick(event, x, y, flags, param):
    """OpenCV mouse callback that records seed strokes on a GraphMaker.

    Left-button drags add foreground seeds and right-button drags add background
    seeds to *param*, a GraphMaker. Each drag's points are also appended to the
    current stroke in ``foreground_lines`` or ``background_lines``. Points always
    go to the most recent stroke, so stray or double-click button-up events can
    no longer desynchronise a stroke counter from the stroke list.
    ``foreground_index`` and ``background_index`` still count completed strokes.
    """
    global left_mouse_down, right_mouse_down
    global foreground_index, background_index
    global foreground_lines, background_lines

    if event == cv2.EVENT_LBUTTONDOWN:
        foreground_lines.append([])
        left_mouse_down = True
    elif event == cv2.EVENT_LBUTTONUP:
        # Only count a stroke if one was actually in progress.
        if left_mouse_down:
            foreground_index = foreground_index + 1
        left_mouse_down = False
    elif event == cv2.EVENT_RBUTTONDOWN:
        background_lines.append([])
        right_mouse_down = True
    elif event == cv2.EVENT_RBUTTONUP:
        if right_mouse_down:
            background_index = background_index + 1
        right_mouse_down = False
    elif event == cv2.EVENT_MOUSEMOVE:
        if left_mouse_down:
            param.add_seed(x, y, param.foreground)
            # The lists may have been cleared while the button was held.
            if not foreground_lines:
                foreground_lines.append([])
            foreground_lines[-1].append((x, y))
        elif right_mouse_down:
            param.add_seed(x, y, param.background)
            if not background_lines:
                background_lines.append([])
            background_lines[-1].append((x, y))
if __name__ == '__main__':
    # Interactive batch run: for every image under data/img, let the user draw
    # seeds, segment it, save the result to data/res and, when a reference mask
    # exists in data/mask (<stem>.png), print evaluation metrics.
    image_root = "data/img"
    result_dir = "data/res"
    mask_dir = "data/mask"
    os.makedirs(result_dir, exist_ok=True)

    for path, _dirs, file_list in os.walk(image_root):
        for file in file_list:
            image_path = os.path.join(path, file)
            img = cv2.imread(image_path)
            if img is None:
                # os.walk also yields non-image files; skip them instead of crashing.
                print("Skipping unreadable file {}".format(image_path))
                continue

            # Reset the stroke state shared with onMouseClick for this image.
            left_mouse_down = False
            right_mouse_down = False
            foreground_index = 0
            background_index = 0
            foreground_lines.clear()
            background_lines.clear()

            marker = GraphMaker(image_path)
            cv2.imshow(file, img)
            cv2.setMouseCallback(file, onMouseClick, marker)
            cv2.waitKey(0)
            cv2.destroyAllWindows()

            marker.create_graph()
            if not marker.nodes:
                # create_graph already printed why (missing seeds); nothing to cut.
                continue
            marker.cut_graph()

            result_path = os.path.join(result_dir, file)
            marker.save_image(result_path)

            mask_path = os.path.join(mask_dir, os.path.splitext(file)[0] + ".png")
            if not os.path.exists(mask_path):
                print("No reference mask at {}; skipping evaluation".format(mask_path))
                continue
            marker.evaluate(result_path, mask_path)
|