最大流最小割问题

可以参考这篇博客

graph cut思想

能量函数

$$E(L) = a\,R(L) + B(L)$$

其中前一项是区域项,后一项是边界项。优化的目标是找到使能量函数取得最小值的分割结果。

首先需要标注一些种子点,指明它们属于前景还是背景,然后根据前景种子点和背景种子点分别构建概率直方图。
区域项的含义是:p节点属于前景的区域代价等于p节点的取值在前景直方图中概率的负对数,即 $R_p(\text{前景}) = -\ln \Pr(I_p \mid \text{前景})$,背景同理。

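而边界项的计算公式为

$$B_{\{p,q\}} \propto \exp\left(-\frac{(I_p - I_q)^2}{2\sigma^2}\right)\cdot\frac{1}{\mathrm{dist}(p,q)}$$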

也就是说随着颜色差异和距离的增大,B在减小。

graph cut的目的是区分哪些像素是前景、哪些是背景,从而完成分割。因此可以采用最大流最小割的思想:将源点看作前景,将汇点看作背景,图像中的像素作为其余节点,然后寻找把前景和背景分开的最小割。

定义完点之后还需要定义流量,流量分为端点到像素点之间的流量和像素点到像素点之间的流量:

  • 端点到像素点: 这一项对应前面公式中的R(L)。如果是前景种子点,则它到前景端点的流量为无穷大,到背景端点的流量为0;背景种子点则相反。其他像素点到源点和汇点的流量按上面的区域项公式计算。
  • 像素点到像素点: 这一项是公式中的B(L),遵从上面公式

代码实现

import os

import cv2
import numpy as np
import maxflow
import matplotlib.pyplot as plt
from medpy import metric
from PIL import Image, ImageDraw

# Whether each mouse button is currently held down (updated by the mouse callback).
left_mouse_down, right_mouse_down = False, False
# Index of the stroke that the current drag writes to, one counter per stroke list.
foreground_index, background_index = 0, 0
# Recorded strokes; each stroke is a list of (x, y) points.
foreground_lines, background_lines = [], []


class GraphMaker:
    """Seed-based graph-cut segmentation of a single image.

    The user supplies foreground and background seed pixels; the class builds
    a pixel graph (terminal capacities for seeds, colour-similarity capacities
    between 4-connected neighbours), runs max-flow/min-cut on it and exposes
    the result as ``segment_overlay`` (2-D, 1 = foreground) and ``mask``
    (boolean, same shape as the image).
    """

    # Seed type codes accepted by ``add_seed``; the same values mark seed
    # pixels in ``self.graph``.
    foreground = 1
    background = 0
    segmented = 1
    # Value of ``self.graph`` for pixels that are not seeds.
    default = 0.5
    # Terminal capacity used to pin a seed pixel to its terminal.
    MAXIMUM = 1000000000

    def __init__(self, filename):
        """Load ``filename`` and start with no seeds and an empty graph."""
        self.image = None
        self.graph = None
        self.segment_overlay = None
        self.mask = None
        self.current_overlay = None
        self.filename: str = filename
        self.load_image(filename)
        self.background_seeds = []
        self.foreground_seeds = []
        # 3-vectors, one component per colour channel (filled by find_averages).
        self.background_average = np.zeros(3)
        self.foreground_average = np.zeros(3)
        self.nodes = []
        self.edges = []

    def load_image(self, filename):
        """Read the image at ``filename`` and reset the segmentation state.

        Raises:
            ValueError: if OpenCV cannot read the file (``cv2.imread``
                returned ``None``).
        """
        image = cv2.imread(filename)
        if image is None:
            raise ValueError("Could not read image: {}".format(filename))
        self.filename = filename
        self.image = image
        self.graph = None
        self.segment_overlay = np.zeros(image.shape[:2])
        self.mask = None

    def add_seed(self, x, y, type):
        """Record pixel ``(x, y)`` as a seed of the given ``type``.

        ``type`` is ``GraphMaker.foreground`` or ``GraphMaker.background``
        (the parameter name shadows the builtin but is kept for callers).
        Duplicate seeds and coordinates outside the image are ignored.
        """
        if self.image is None:
            print('Please load an image before adding seeds.')
            return
        height, width = self.image.shape[:2]
        # Out-of-range coordinates (e.g. from a drag that leaves the window)
        # would index the wrong pixel or raise later on, so drop them here.
        if not (0 <= x < width and 0 <= y < height):
            return
        if type == self.background:
            if (x, y) not in self.background_seeds:
                self.background_seeds.append((x, y))
        elif type == self.foreground:
            if (x, y) not in self.foreground_seeds:
                self.foreground_seeds.append((x, y))

    def create_graph(self):
        """Build ``self.nodes`` and ``self.edges`` from the current seeds.

        Prints a message and returns without building anything when either
        seed list is empty.
        """
        if not self.background_seeds or not self.foreground_seeds:
            print("Please enter at least one foreground and background seed.")
            return

        print("Making graph")
        print("Finding foreground and background averages")
        self.find_averages()

        print("Populating nodes and edges")
        self.populate_graph()

    def find_averages(self):
        """Mark seed pixels in ``self.graph`` and average their colours.

        ``self.graph`` is filled with ``default`` and then set to 0 at
        background seeds and 1 at foreground seeds (foreground wins when a
        pixel is in both lists). The per-channel seed averages are stored in
        ``background_average`` / ``foreground_average``.
        """
        self.graph = np.zeros((self.image.shape[0], self.image.shape[1]))
        print(self.graph.shape)
        self.graph.fill(self.default)  # non-seed pixels start at 0.5
        self.background_average = np.zeros(3)
        self.foreground_average = np.zeros(3)

        # Seeds are (x, y); arrays are indexed [row, column] = [y, x]. The
        # label is written to the same pixel whose colour is sampled.
        for x, y in self.background_seeds:
            self.graph[y, x] = 0
            self.background_average += self.image[y, x]
        if self.background_seeds:
            # Not used by the visible graph construction; a histogram-based
            # regional term R(x) would be the place to use it.
            self.background_average /= len(self.background_seeds)

        for x, y in self.foreground_seeds:
            self.graph[y, x] = 1
            self.foreground_average += self.image[y, x]
        if self.foreground_seeds:
            self.foreground_average /= len(self.foreground_seeds)

    @staticmethod
    def _edge_weight(pixel_a, pixel_b):
        """Return 1 / (1 + squared colour distance) between two pixels."""
        # Convert to float first: uint8 subtraction and squaring would wrap
        # around modulo 256 and distort the distance.
        diff = pixel_a.astype(np.float64) - pixel_b.astype(np.float64)
        return 1 / (1 + np.sum(diff * diff))

    def populate_graph(self):
        """Fill ``self.nodes`` and ``self.edges`` from ``self.graph``.

        ``nodes`` holds ``(node_index, source_capacity, sink_capacity)`` and
        ``edges`` holds ``(node_index, neighbour_index, capacity)`` for every
        pair of horizontally or vertically adjacent pixels.
        """
        self.nodes = []
        self.edges = []
        shape = self.image.shape
        for (y, x), value in np.ndenumerate(self.graph):
            index = self.get_node_num(x, y, shape)
            if value == 0.0:
                # Background seed: pinned to the source terminal.
                self.nodes.append((index, self.MAXIMUM, 0))
            elif value == 1.0:
                # Foreground seed: pinned to the sink terminal.
                self.nodes.append((index, 0, self.MAXIMUM))
            else:
                # Ordinary pixel: no terminal capacity.
                self.nodes.append((index, 0, 0))

        height, width = self.graph.shape
        for (y, x), _ in np.ndenumerate(self.graph):
            my_index = self.get_node_num(x, y, shape)
            # Each pixel links to its right and lower neighbour where they
            # exist, so border rows/columns keep their in-border edges.
            if x + 1 < width:
                neighbor_index = self.get_node_num(x + 1, y, shape)
                g = self._edge_weight(self.image[y, x], self.image[y, x + 1])
                self.edges.append((my_index, neighbor_index, g))
            if y + 1 < height:
                neighbor_index = self.get_node_num(x, y + 1, shape)
                g = self._edge_weight(self.image[y, x], self.image[y + 1, x])
                self.edges.append((my_index, neighbor_index, g))

    def cut_graph(self):
        """Run max-flow on the built graph and store the segmentation.

        Sets ``segment_overlay`` to 1 and ``mask`` to True for every pixel
        that ends up in segment 1 (the sink side, i.e. foreground). Prints a
        message and returns when no graph has been built yet.
        """
        if not self.nodes:
            print("Please create the graph before cutting it.")
            return
        self.segment_overlay = np.zeros_like(self.segment_overlay)
        self.mask = np.zeros_like(self.image, dtype=bool)
        g = maxflow.Graph[float](len(self.nodes), len(self.edges))
        nodelist = g.add_nodes(len(self.nodes))  # one graph node per pixel

        for index, source_capacity, sink_capacity in self.nodes:
            # Terminal links: source is background, sink is foreground.
            g.add_tedge(nodelist[index], source_capacity, sink_capacity)

        for first, second, capacity in self.edges:
            # Pixel-to-pixel links carry the same capacity in both directions.
            g.add_edge(nodelist[first], nodelist[second], capacity, capacity)

        flow = g.maxflow()  # value of the maximum flow / minimum cut
        print("maximum flow is {}".format(flow))

        for index in range(len(self.nodes)):
            if g.get_segment(nodelist[index]) == 1:  # foreground side of the cut
                x, y = self.get_xy(index, self.image.shape)
                self.segment_overlay[y, x] = 1
                self.mask[y, x] = True  # all channels of this pixel

    def swap_overlay(self, overlay_num):
        """Remember ``overlay_num`` as the currently selected overlay."""
        self.current_overlay = overlay_num

    def save_image(self, outfilename):
        """Write the segmented image and a copy with the seed strokes drawn.

        The segmented image (original pixels inside the mask, black outside)
        goes to ``outfilename``; the stroke visualisation goes to the same
        path with its extension replaced by ``"storke.jpg"``. Strokes are
        read from the module-level ``foreground_lines``/``background_lines``.

        Returns:
            ``self.segment_overlay``, or ``None`` when no segmentation exists.
        """
        if self.mask is None:
            print('Please segment the image before saving.')
            return
        print(outfilename)
        to_save = np.zeros_like(self.image)
        np.copyto(to_save, self.image, where=self.mask)
        cv2.imwrite(outfilename, to_save)

        save_stroke = self.image.copy()
        for foreground_line in foreground_lines:
            self.draw_polyline(save_stroke, foreground_line, (0, 0, 255))
        for background_line in background_lines:
            self.draw_polyline(save_stroke, background_line, (255, 0, 0))
        # splitext instead of [:-4] so extensions of any length are handled.
        # NOTE: the "storke" spelling is kept so existing output names stay stable.
        stem = os.path.splitext(outfilename)[0]
        cv2.imwrite(stem + "storke.jpg", save_stroke)
        return self.segment_overlay

    @staticmethod
    def evaluate(prediction_path, reference_path):
        """Print Dice, HD95, sensitivity, specificity and PPV for two images.

        Both files are read with ``cv2.imread`` and passed to the
        ``medpy.metric.binary`` functions. The last printed value is the
        positive predictive value (precision).

        Raises:
            ValueError: if either image cannot be read.
        """
        reference = cv2.imread(reference_path)
        if reference is None:
            raise ValueError("Could not read image: {}".format(reference_path))
        prediction = cv2.imread(prediction_path)
        if prediction is None:
            raise ValueError("Could not read image: {}".format(prediction_path))
        dice = metric.binary.dc(prediction, reference)
        hd = metric.binary.hd95(prediction, reference)
        sensitivity = metric.binary.sensitivity(prediction, reference)
        specificity = metric.binary.specificity(prediction, reference)
        precision = metric.binary.positive_predictive_value(prediction, reference)
        print("{:.2f} {:.2f} {:.2f} {:.2f} {:.2f}".format(dice, hd, sensitivity, specificity, precision))

    @staticmethod
    def get_node_num(x, y, array_shape):
        """Return the row-major node index of pixel ``(x, y)``."""
        return y * array_shape[1] + x

    @staticmethod
    def get_xy(nodenum, array_shape):
        """Return the ``(x, y)`` pixel coordinates of node ``nodenum``."""
        # Integer division keeps the result exact for any index size.
        row, column = divmod(nodenum, array_shape[1])
        return column, row

    @staticmethod
    def draw_polyline(img, lines, color):
        """Draw consecutive points of ``lines`` on ``img`` as connected segments."""
        for start, end in zip(lines, lines[1:]):
            cv2.line(img, start, end, color=color, thickness=3)


def onMouseClick(event, x, y, flags, param):
    """OpenCV mouse callback that records seed strokes.

    Dragging with the left button adds foreground seeds, dragging with the
    right button adds background seeds. Each button press starts a new stroke
    in ``foreground_lines`` / ``background_lines``; every mouse move while the
    button is held adds ``(x, y)`` to that stroke and to ``param`` as a seed.

    Args:
        event: OpenCV mouse event code.
        x, y: pointer position reported with the event.
        flags: OpenCV event flags (unused).
        param: object providing ``add_seed(x, y, type)`` and the
            ``foreground`` / ``background`` type codes.
    """
    global left_mouse_down
    global right_mouse_down
    global foreground_index
    global background_index
    if event == cv2.EVENT_LBUTTONDOWN:
        foreground_lines.append([])
        left_mouse_down = True
    elif event == cv2.EVENT_LBUTTONUP:
        # Derive the counter from the list so a button-up without a matching
        # button-down cannot push it past the last stroke.
        foreground_index = len(foreground_lines)
        left_mouse_down = False
    elif event == cv2.EVENT_RBUTTONDOWN:
        background_lines.append([])
        right_mouse_down = True
    elif event == cv2.EVENT_RBUTTONUP:
        background_index = len(background_lines)
        right_mouse_down = False
    elif event == cv2.EVENT_MOUSEMOVE:
        # The stroke being drawn is always the most recently started one.
        if left_mouse_down:
            param.add_seed(x, y, param.foreground)
            foreground_lines[-1].append((x, y))
        elif right_mouse_down:
            param.add_seed(x, y, param.background)
            background_lines[-1].append((x, y))


if __name__ == '__main__':
    # For every image under data/img: let the user draw seed strokes, segment
    # the image, save the result under data/res and score it against the
    # reference mask of the same base name in data/mask.
    image_root = "data/img"
    result_dir = "data/res"
    mask_dir = "data/mask"
    # Make sure the output directory exists before anything is written to it.
    os.makedirs(result_dir, exist_ok=True)

    for path, dir_list, file_list in os.walk(image_root):
        for file in file_list:
            image_path = os.path.join(path, file)
            img = cv2.imread(image_path)
            if img is None:
                # Not an image OpenCV can read; skip it instead of crashing.
                print("Skipping unreadable file {}".format(image_path))
                continue

            # Reset the stroke state shared with the mouse callback.
            left_mouse_down = False
            right_mouse_down = False
            foreground_index = 0
            background_index = 0
            foreground_lines.clear()
            background_lines.clear()

            marker = GraphMaker(image_path)
            cv2.imshow(file, img)
            cv2.setMouseCallback(file, onMouseClick, marker)
            cv2.waitKey(0)
            cv2.destroyAllWindows()

            if not marker.foreground_seeds or not marker.background_seeds:
                # Without both kinds of seeds no graph is built, so there is
                # nothing meaningful to cut, save or evaluate.
                print("No foreground/background seeds for {}, skipping.".format(file))
                continue

            marker.create_graph()
            marker.cut_graph()
            result_path = os.path.join(result_dir, file)
            marker.save_image(result_path)
            # splitext handles extensions of any length (e.g. ".jpeg").
            mask_path = os.path.join(mask_dir, os.path.splitext(file)[0] + ".png")
            marker.evaluate(result_path, mask_path)