題目來自莫煩 Python 教學
tips:
1)當你的算法總是不收斂,誒反正就是你怎麼改參數它都不收斂的時候,可能是fitness函數寫錯了(幽怨臉),問問自己,numpy矩陣操作對了嗎?打個輸出看看真的符合預期嗎?
2)把numpy數組裡的數字按照ASCII編碼變成字符串:
# The cast to int8 is essential: it stores exactly one byte per character code,
# so the raw buffer can be decoded as ASCII text.
row = np.array([123, 122, 98]).astype(np.int8)
# tobytes() is the replacement for ndarray.tostring(), which NumPy 2.0 removed.
row = row.tobytes().decode("ascii")
代碼實現效果:
源代碼:
import numpy as np
# Phrase the population evolves toward (the author jokes that it is a rather
# sappy line to use as input). Kept verbatim: it is compared against at runtime.
TARGET = 'Do you ever loved me'
DNA_SIZE = len(TARGET)   # one gene (character code) per character of TARGET
GENERATION = 10000       # upper bound on the generations run by the driver loop
CROSSOVER_RATE = 0.4     # probability that an individual is crossed with a mate
MUTATE_RATE = 0.01       # per-gene mutation probability
POP_SIZE = 300           # number of individuals in the population
DNA_BOUND = [32, 123]    # gene range handed to np.random.randint: [low, high)
# TARGET as an array of character codes. The binary mode of np.fromstring is
# deprecated, so the encoded bytes are read with np.frombuffer instead;
# .copy() keeps the array writable, as the fromstring result was.
TARGET_ARR = np.frombuffer(TARGET.encode('ascii'), dtype=np.uint8).copy()
class GA(object):
    """Genetic algorithm that evolves ASCII strings toward ``TARGET``.

    The population ``self.pop`` is a ``(POP_SIZE, DNA_SIZE)`` ``int8`` array;
    each row is one individual and each entry a character code drawn from
    ``[DNA_BOUND[0], DNA_BOUND[1])``.
    """

    def __init__(self):
        # Draw every individual independently. Repeating a single random row
        # POP_SIZE times would start the search with no diversity at all,
        # leaving crossover a no-op until mutation differentiates the rows.
        self.pop = np.random.randint(
            DNA_BOUND[0], DNA_BOUND[1], size=(POP_SIZE, DNA_SIZE)
        ).astype(np.int8)

    def getFitness(self, pop):
        """Return, for each row of ``pop``, the number of genes equal to TARGET.

        ``pop`` is a 2-D array with ``DNA_SIZE`` columns; the result is a 1-D
        integer array with one match count per row.
        """
        # np.frombuffer replaces the deprecated binary mode of np.fromstring.
        target = np.frombuffer(TARGET.encode('ascii'), dtype=np.int8)
        # Broadcasting compares every row with the target, so the target does
        # not have to be tiled to the population size.
        return np.sum(pop == target, axis=1)

    def select(self, fitness):
        """Resample the population with probability proportional to fitness.

        Draws ``POP_SIZE`` row indices with replacement and returns the
        corresponding rows of ``self.pop`` as a new array.
        """
        idx = np.random.choice(
            np.arange(POP_SIZE),
            size=POP_SIZE,
            replace=True,
            p=fitness / fitness.sum(),
        )
        return self.pop[idx]

    def mutate(self, child):
        """Mutate ``child`` in place and return it.

        Each gene is independently replaced, with probability ``MUTATE_RATE``,
        by a fresh value from ``[DNA_BOUND[0], DNA_BOUND[1])``.
        """
        mask = np.random.rand(len(child)) < MUTATE_RATE
        # One draw per mutated gene; this avoids assigning a size-1 array to a
        # scalar slot, which relies on deprecated array-to-scalar conversion.
        child[mask] = np.random.randint(
            DNA_BOUND[0], DNA_BOUND[1], size=int(mask.sum())
        )
        return child

    def crossover(self, parent, pop):
        """Cross ``parent`` in place with a random row of ``pop``; return it.

        With probability ``CROSSOVER_RATE`` a mate is chosen uniformly from
        ``pop`` and, at randomly chosen gene positions, the mate's genes
        overwrite the parent's. Otherwise ``parent`` is returned untouched.
        """
        if np.random.rand() < CROSSOVER_RATE:
            mate = pop[np.random.randint(0, len(pop))]
            # The builtin bool replaces the np.bool alias removed in NumPy 1.24.
            cross_points = np.random.randint(0, 2, size=len(parent)).astype(bool)
            parent[cross_points] = mate[cross_points]
        return parent

    def translateDNA(self, row):
        """Decode one individual (an array of one-byte codes) to an ASCII str."""
        # tobytes() replaces ndarray.tostring(), which NumPy 2.0 removed.
        return row.tobytes().decode('ascii')

    def evolution(self, gen):
        """Run one generation and return the best individual as a string.

        ``gen`` is only used in the progress message. The best individual is
        the fittest row of the population as it stood before selection.
        """
        # The small offset keeps the selection probabilities well defined even
        # when no individual matches the target anywhere (sum would be zero).
        fitness = self.getFitness(self.pop) + 1e-4
        # Look up the best row BEFORE selection: ``fitness`` is indexed by the
        # current population, and select() replaces self.pop with a resampled
        # array whose row order no longer corresponds to those indices.
        bestRes = self.translateDNA(self.pop[np.argmax(fitness)])
        print("Gen : ", gen, "best result:", bestRes, " target is ", TARGET)
        self.pop = self.select(fitness)
        # Mates are taken from a snapshot so rows already modified in this
        # generation do not leak into later crossovers.
        pop_copy = self.pop.copy()
        for parent in self.pop:
            # Each row is a view into self.pop; crossover and mutate both
            # modify it in place, so no write-back is needed.
            self.mutate(self.crossover(parent, pop_copy))
        return bestRes
if __name__ == '__main__':
    # Evolve one generation at a time, for at most GENERATION rounds, and stop
    # as soon as the best individual spells out TARGET exactly.
    solver = GA()
    for generation in range(GENERATION):
        best = solver.evolution(generation)
        if best == TARGET:
            break