本文基于简书文章《优化算法笔记(六)遗传算法》(jianshu.com)进行实现,建议先阅读该文了解算法原理。
输出结果如下:
GA.gif
实现代码如下:
# 遗传算法 (genetic algorithm)
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
from PIL import Image
import shutil
import os
import glob
import random
# Matplotlib font settings so the Chinese labels and the minus sign display correctly.
plt.rcParams.update({
    'font.sans-serif': ['SimHei'],  # font used for the Chinese legend/label text
    'axes.unicode_minus': False,  # so negative tick labels display correctly with this font
})
def plot_jpg(start, point_best, err, m, n, lower, upper, title, save_dir=None, total_iter=None):
    """Draw one animation frame for generation ``m`` and save it as a PNG.

    The upper panel scatters the population (each point labelled with its
    index), the best individual of this generation and the target point; the
    lower panel plots the best-distance history ``err``.

    Args:
        start: population array; rows are individuals, columns 0 and 1 are
            plotted as x and y.
        point_best: target point; ``point_best[0]`` / ``point_best[1]`` are plotted.
        err: best distance per generation so far; must be non-empty.
        m: generation number, shown in the x label and used in the file name
            ``tmp_{m:04}.png``.
        n: how many individuals (from the first row) get an index label.
        lower, upper: per-dimension search bounds; the first two set the axes.
        title: title of the upper panel.
        save_dir: output directory; ``None`` means the module-level ``tmp_path``.
        total_iter: x-axis length of the history panel; ``None`` means the
            module-level ``max_iter``.
    """
    if save_dir is None:
        save_dir = tmp_path
    if total_iter is None:
        total_iter = max_iter
    span_x = upper[0] - lower[0]
    span_y = upper[1] - lower[1]
    point_g = min(start.tolist(), key=target)  # individual with the lowest objective value

    fig = plt.figure(figsize=(8, 12))
    gs = gridspec.GridSpec(3, 2)

    # Upper two thirds: population scatter.
    ax1 = plt.subplot(gs[:2, :2])
    ax1.scatter(start[:, 0], start[:, 1], alpha=0.3, color='green', s=20, label='当前位置')  # current positions
    ax1.scatter(point_g[0], point_g[1], alpha=1, color='blue', s=20, label='当前最优点')  # best of this generation
    ax1.scatter(point_best[0], point_best[1], alpha=0.3, color='red', label='目标点')  # target point
    # Offset index labels by 1% of the range (2 units for the default +-100 box).
    for i in range(min(n, len(start))):
        ax1.text(start[i][0] + 0.01 * span_x, start[i][1] + 0.01 * span_y, f'{i}',
                 alpha=0.3, fontsize=10, color='red')
    ax1.grid(True, color='gray', linestyle='-.', linewidth=0.5)
    # Pad by 10% of the range on each side. This equals the old ``bound * 1.2`` for
    # boxes symmetric about 0, but does not clip points when e.g. the lower bound is positive.
    ax1.set_xlim(lower[0] - 0.1 * span_x, upper[0] + 0.1 * span_x)
    ax1.set_ylim(lower[1] - 0.1 * span_y, upper[1] + 0.1 * span_y)
    ax1.set_xlabel(f'iter:{m} dist: {err[-1]:.8f}')
    ax1.set_title(title)
    ax1.legend(loc='lower right', bbox_to_anchor=(1, 0), ncol=1)

    # Bottom third: best-distance history.
    ax2 = plt.subplot(gs[2, :])
    ax2.plot(range(len(err)), err, marker='o', markersize=5)
    ax2.grid(True, color='gray', linestyle='-.', linewidth=0.5)
    ax2.set_xlim(0, max(total_iter, 1))  # avoid identical limits when total_iter == 0
    # ceil(max(err)) is 0 once every distance is exactly 0, which would give
    # identical y-limits; fall back to 1 in that case.
    y_top = np.ceil(max(err))
    ax2.set_ylim(0, y_top if y_top > 0 else 1)
    ax2.set_xticks(range(0, total_iter, 5))

    # Same file as before ('./tmp/tmp_XXXX.png' with the default tmp_path), but it now follows tmp_path.
    fig.savefig(os.path.join(save_dir, f'tmp_{m:04}.png'))
    plt.close(fig)
# 目标函数 (objective function)
def target(point):
    """Return the squared Euclidean distance from ``point`` to the target ``(a, b)``.

    Only the first two coordinates of ``point`` are used; ``a`` and ``b`` are
    module-level globals. Lower values are better (callers take ``min`` over it).
    """
    dx = point[0] - a
    dy = point[1] - b
    return dx ** 2 + dy ** 2
# 选择 (selection)
def select(population):
    """Roulette-wheel selection: pick two distinct parents for every offspring slot.

    An individual's score is ``diag - dist``, where ``diag`` is the diagonal
    length of the search box (from ``lower_lim``/``upper_lim``) and ``dist`` is
    its distance to the target (``target(p) ** 0.5``), so closer individuals
    score higher. Scores are shifted so the worst individual keeps a small
    positive weight (0.01), and are then used directly as sampling weights.

    Args:
        population: sequence of individuals (ndarray rows or plain lists).

    Returns:
        List of ``[c1, c2]`` index pairs, one per individual, with ``c1 != c2``.

    Raises:
        ValueError: if there are fewer than two individuals, since two
            distinct parents cannot be drawn.
    """
    size = len(population)
    if size < 2:
        raise ValueError('select() needs at least two individuals')

    diag = sum((hi - lo) ** 2 for lo, hi in zip(lower_lim, upper_lim)) ** 0.5
    # Explicit ndarray: the old code subtracted a float from a Python list, which
    # only worked when target() happened to return numpy scalars.
    dist = np.array([target(p) for p in population], dtype=float) ** 0.5
    score = diag - dist
    # random.choices only needs relative weights, so no normalisation is required.
    weights = (score - (score.min() - 0.01)).tolist()

    indices = list(range(size))
    pairs = []
    for _ in range(size):
        c1 = random.choices(indices, weights, k=1)[0]
        # Draw the second parent from everyone except c1 so nobody mates with itself.
        others = indices[:c1] + indices[c1 + 1:]
        other_weights = weights[:c1] + weights[c1 + 1:]
        c2 = random.choices(others, other_weights, k=1)[0]
        pairs.append([c1, c2])
    return pairs
# 交叉 (crossover)
def cross(population, res, CR):
    """Create offspring by single-gene crossover.

    For each ``[c1, c2]`` pair in ``res`` the child starts as a copy of parent
    ``c1``; with probability ``CR`` one randomly chosen gene is replaced by the
    same gene of parent ``c2``. The number of children and the number of genes
    come from ``res`` and the parent rows, not from module globals.

    Args:
        population: 2-D array-like of parents, one row per individual.
        res: list of ``[c1, c2]`` parent-index pairs (as returned by ``select``).
        CR: crossover probability in [0, 1].

    Returns:
        ndarray with one row per pair in ``res``.
    """
    children = []
    for c1, c2 in res:
        child = list(population[c1])
        if random.random() < CR:
            k = random.randint(0, len(child) - 1)
            child[k] = population[c2][k]
        children.append(child)
    return np.array(children)
# 变异 (mutation)
def mutation(population, AR):
    """Randomly reset single genes.

    Each individual, with probability ``AR``, has one randomly chosen gene
    ``k`` replaced by a value drawn uniformly from
    ``[lower_lim[k], upper_lim[k])`` (module-level bounds).

    Args:
        population: 2-D array-like, one row per individual.
        AR: per-individual mutation probability in [0, 1].

    Returns:
        A new float ndarray; the input is never modified.
    """
    # np.array always copies; for list input the old ``list.copy()`` was shallow,
    # so the caller's rows were mutated in place.
    mutated = np.array(population, dtype=float)
    for row in mutated:  # rows are views, so writes land in ``mutated``
        if random.random() < AR:
            r, k = random.random(), random.randint(0, len(row) - 1)
            row[k] = r * (upper_lim[k] - lower_lim[k]) + lower_lim[k]
    return mutated
def GA():
    """Run the genetic algorithm, saving one PNG frame per generation.

    Reads the module-level settings ``n``, ``d``, ``CR``, ``AR``, ``max_iter``,
    ``lower_lim``, ``upper_lim``, ``point_best`` and ``tmp_path``. The
    ``tmp_path`` directory is deleted and recreated first, so frames from a
    previous run are removed.

    Returns:
        List of ``max_iter + 1`` best distances to the target: one for the
        initial population and one after each generation.
    """
    # Uniform random initial population inside the search box (same RNG draw as before).
    lower = np.asarray(lower_lim)
    upper = np.asarray(upper_lim)
    population = np.random.random(size=(n, d)) * (upper - lower) + lower

    if os.path.exists(tmp_path):
        shutil.rmtree(tmp_path)
    os.makedirs(tmp_path, exist_ok=True)

    def best_distance(pop):
        # Distance of the best individual = sqrt of the smallest squared distance.
        return min(target(p) for p in pop.tolist()) ** 0.5

    # Built once before the loop: it never changes, and the final plot_jpg call
    # needs it even when max_iter == 0 (previously an UnboundLocalError).
    title = f'GA\nn:{n} CR:{CR} AR:{AR} max_iter:{max_iter}'
    errors = [best_distance(population)]
    for generation in range(max_iter):
        plot_jpg(population, point_best, errors, generation, n, lower_lim, upper_lim, title)
        parents = select(population)
        population = mutation(cross(population, parents, CR), AR)
        errors.append(best_distance(population))
    plot_jpg(population, point_best, errors, max_iter, n, lower_lim, upper_lim, title)
    return errors
CR = 0.8  # crossover rate
AR = 0.05  # mutation rate
n = 20  # population size (number of individuals)
d = 2  # number of dimensions per individual
max_iter = 200  # number of generations
# Search box: per-dimension lower and upper bounds
lower_lim = [-100, -100]
upper_lim = [100, 100]
# Target point; target() measures squared distance to (a, b)
a, b = 0, 0
point_best = (a, b)
# Directory for the temporary per-generation frames
tmp_path = r'./tmp/'
err = GA()
# glob() returns files in arbitrary filesystem order; sort so the zero-padded
# names tmp_0000.png, tmp_0001.png, ... come out in generation order.
frames = sorted(glob.glob(os.path.join(tmp_path, '*.png')))[::5]
images = [Image.open(png) for png in frames]
im = images.pop(0)
im.save(r"./GA.gif", save_all=True, append_images=images, duration=500)
# Release the frame files now that the GIF has been written.
for frame in (im, *images):
    frame.close()
im = Image.open(r"./GA.gif")
im.show()
im.close()