1. 程式人生 > >遺傳演算法解決推箱子問題

遺傳演算法解決推箱子問題

遺傳演算法

遺傳演算法(Genetic Algorithm)是一類借鑑生物界的進化規律(適者生存,優勝劣汰遺傳機制)演化而來的隨機化搜尋方法。最早聽說這個演算法是在一門公選課上,當時瞭解的還包括蟻群演算法等。總之,這種演算法通過模擬自然界物種的繁衍,來尋找適宜生存的種群,達到尋找相對優解的過程。這種方法可以很好的避免我們的演算法找到區域性最優解之後就停滯不前。

推箱子問題

推箱子相比大家都玩過,在地圖上,玩家控制小人把地圖上的箱子推到指定的位置。這個問題看似十分簡單,但當地圖變得十分大,箱子非常多的時候,這個問題實際上並不好解決。本文旨在使用遺傳演算法,自動求解推箱子問題,雖然從實際效果來看,不甚理想,可改進的地方很多,但是也可看作對遺傳演算法和數學建模的一次很好的初級實踐。

本文推箱子地圖的表示參考這個網站:http://sokoban.cn/ 裡面推箱子的問題都十分有趣,可以嘗試一下。

遺傳演算法的流程

圖片來自超詳細的遺傳演算法解析

下面,我們根據演算法流程,逐步求解該問題。

遺傳演算法求解推箱子問題

(1)編碼

我們定義0代表向上,1代表向下,2代表向左,3代表向右。因此給定一個數組,便可確定玩家的所有動作。

(2)推箱子問題環境程式設計

根據上文關於推箱子地圖的定義,@代表玩家,#代表牆,-代表空地,·代表目標點,$代表箱子。為方便起見,目標點我直接通過一個靜態陣列宣告。

----#####----------
----#---#----------
----##--#----------
--###--###---------
--#--#-#-#---------
###################
#-------##----$-.$#
#-------##
[email protected]
$# #####---##----$-.$# --------########### ----#######--------

上面為我們這次測試用的地圖(極其簡單)。

target_points = [[6, 16], [6, 17], [7, 16], [7, 17], [8, 16], [8, 17]]

# 讀取推箱子地圖
def read_to_matrix():
    matrix = []
    f = open("test.txt")
    line = f.readline()
    while line:
        matrix_line = []
        for char in line:
            matrix_line.append(char)
        matrix.append(matrix_line)
        line = f.readline()
    np.array(matrix)
    f.close()
    return matrix


# 尋找玩家位置
def find_person(matrix):
    for i in range(0, 11):
        for j in range(0, 19):
            if matrix[i][j] == '@':
                return i, j
    return -1, -1


# 計算箱子到目標點的最近曼哈頓距離
def calculate_points(matrix):
    points = 0
    for i in range(0, 11):
        for j in range(0, 19):
            if matrix[i][j] == '$':
                distances = []
                for point in target_points:
                    distance = abs(i - point[0]) + abs(j - point[1])
                    distances.append(distance)
                points = points + min(distances)
    return points


# 移動
def move(matrix, direction):
    if direction == 0:
        if up(matrix):
            return True
        return False
    elif direction == 1:
        if down(matrix):
            return True
        return False
    elif direction == 2:
        if left(matrix):
            return True
        return False
    elif direction == 3:
        if right(matrix):
            return True
        return False


def up(matrix):
    i, j = find_person(matrix)
    if i > 0:
        up_i = i - 1
        if matrix[up_i][j] == '-' or matrix[up_i][j] == '.':
            matrix[i][j] = '-'
            matrix[up_i][j] = '@'
            return True
        elif matrix[up_i][j] == '#':
            return False
        elif matrix[up_i][j] == '$':
            if up_i > 0:
                box_up = up_i - 1
                if matrix[box_up][j] == '-' or matrix[box_up][j] == '.':
                    matrix[i][j] = '-'
                    matrix[up_i][j] = '@'
                    matrix[box_up][j] = '$'
                    return True
                else:
                    return False
            else:
                return False
    else:
        return False


def down(matrix):
    i, j = find_person(matrix)
    if i < 10:
        down_i = i + 1
        if matrix[down_i][j] == '-' or matrix[down_i][j] == '.':
            matrix[i][j] = '-'
            matrix[down_i][j] = '@'
            return True
        elif matrix[down_i][j] == '#':
            return False
        elif matrix[down_i][j] == '$':
            if down_i < 10:
                box_down = down_i - 1
                if matrix[box_down][j] == '-' or matrix[box_down][j] == '.':
                    matrix[i][j] = '-'
                    matrix[down_i][j] = '@'
                    matrix[box_down][j] = '$'
                    return True
                else:
                    return False
            else:
                return False
    else:
        return False


def left(matrix):
    i, j = find_person(matrix)
    if j > 0:
        left_j = j - 1
        if matrix[i][left_j] == '-' or matrix[i][left_j] == '.':
            matrix[i][j] = '-'
            matrix[i][left_j] = '@'
            return True
        elif matrix[i][left_j] == '#':
            return False
        elif matrix[i][left_j] == '$':
            if left_j > 0:
                box_left = left_j - 1
                if matrix[i][box_left] == '-' or matrix[i][box_left] == '.':
                    matrix[i][j] = '-'
                    matrix[i][left_j] = '@'
                    matrix[i][box_left] = '$'
                    return True
                else:
                    return False
            else:
                return False
    else:
        return False


def right(matrix):
    i, j = find_person(matrix)
    if j < 18:
        right_j = j + 1
        if matrix[i][right_j] == '-' or matrix[i][right_j] == '.':
            matrix[i][j] = '-'
            matrix[i][right_j] = '@'
            return True
        elif matrix[i][right_j] == '#':
            return False
        elif matrix[i][right_j] == '$':
            if right_j < 18:
                box_left = right_j + 1
                if matrix[i][box_left] == '-' or matrix[i][box_left] == '.':
                    matrix[i][j] = '-'
                    matrix[i][right_j] = '@'
                    matrix[i][box_left] = '$'
                    return True
                else:
                    return False
            else:
                return False
    else:
        return False


#根據陣列移動
def sequence_move(move_array, matrix):
    for direction in move_array:
        if move(matrix, direction):
            continue
        else:
            move_array.remove(direction)
    return calculate_points(matrix)

(3)確定程式所用引數

change_rate = 0.05        #變異率
expel_rate = 0.5          #淘汰率
initial_length = 10       #初始移動步數
max_length = 50           #最大移動步數
add_step = 10             #每次增加的步長
max_epoch = 2000          #最大迭代次數
sample = 100              #種群數量
mix_length = 3            #基因混合長度
animals = []

(4)初始化種群

class Animal:
    def __init__(self, move_array, points):
        self.move_array = move_array
        self.points = points


#初始化種群
for i in range(0,sample):
    change_matrix = read_to_matrix()
    move_array = []
    for j in range(0,initial_length):
        move_array.append(random.randint(0,3))
    points = sequence_move(move_array,change_matrix )
    animals.append(Animal(move_array,points))
animals.sort(key=lambda x:x.points,reverse=False)

(5)計算種群適應度

我們定義適應度函式為所有箱子到最近目標點的曼哈頓距離,當該距離為0時,我們可以確定所有的箱子到達目標點,遊戲成功。

其實在上述初始化種群的過程中,計算種群適應度已經完成,這個函式上面已經出現過了。函式如下:

# 計算箱子到目標點的最近曼哈頓距離
def calculate_points(matrix):
    points = 0
    for i in range(0, 11):
        for j in range(0, 19):
            if matrix[i][j] == '$':
                distances = []
                for point in target_points:
                    distance = abs(i - point[0]) + abs(j - point[1])
                    distances.append(distance)
                points = points + min(distances)
    return points

(6)選擇淘汰

這裡我們採用最最簡單的淘汰辦法(其實這樣不大好),直接根據淘汰率,選擇適應度最差的個體淘汰。

因此排序之後,直接:

animals = animals[0:50]

(7)變異

根據變異率,隨機的進行變異,考慮到這是一個不斷增長的序列,因此也要有一定機率隨機增加染色體的基因。

# 變異
def variation(animals):
    for animal in animals:
        if animal.points!=0:
            for m in range(0,len(animal.move_array)):
                ret = random.random()
                if ret<change_rate:
                    animal.move_array[m] = random.randint(0,3)
                elif ret<2*change_rate and len(animal.move_array)<max_length:
                    animal.move_array.insert(m,random.randint(0,3))
    return animals

(8)複製

根據mix_length,隨機選擇父本或母本的基因繼承,考慮到這是一個不斷增長的序列,在不超過最大序列長度限制的前提下,每次隨機在尾部增加幾位。

# 複製
def hybridize(parents):
    child = []
    father = parents[0]
    mather = parents[1]
    i = 0
    while i < len(father.move_array)-mix_length and i < len(mather.move_array)-mix_length:
        flag =  random.randint(0,1)
        if flag == 0:
            for j in range(0,mix_length):
                child.append(father.move_array[i+j])
        else:
            for j in range(0,mix_length):
                child.append(mather.move_array[i+j])
        i = i+mix_length
    if len(child)<max_length:
        for k in range(0,add_step):
            child.append(random.randint(0,3))
    return child

(9)迴圈往復

#種群迭代
for k in range(0,max_epoch):
    animals = animals[0:50]
    animals = variation(animals)    #變異
    for n in range(0,50):
        change_matrix = read_to_matrix()
        # 交配增殖
        parents = random.sample(animals, 2)
        child_array = hybridize(parents=parents)
        points = sequence_move(child_array,change_matrix)
        animals.append(Animal(child_array,points))
    animals.sort(key=lambda x: x.points,reverse=False)
    print(animals[0].points)
    if animals[0].points==0:
        print(animals[0].move_array)
        break

具體程式碼見我的github:genetic-algorithm

參考

超詳細的遺傳演算法解析

十分鐘搞懂遺傳演算法