Preface
Prisma took off shortly after its launch in 2016. The app uses neural networks and artificial intelligence to apply artistic effects to ordinary photos.
That same year, Google published the paper "A Learned Representation for Artistic Style", which can blend multiple artistic styles into a photo in a single forward pass, and which optimizes memory use and computation speed enough to run quickly on mobile devices.
Recently, while looking into integrating TensorFlow with iOS, I found that Google has open-sourced the paper's implementation along with trained weights. In other words, by writing our own forward-pass graph and loading those trained parameters, we can quickly build a Prisma-like app.
Below is how to run your own "Prisma" on an iPhone.
Preparation
- Install TensorFlow. The official site has a detailed tutorial, so I won't repeat it here.
- Set up an iOS + TensorFlow project. You can follow the steps on GitHub, or configure it against the official demo app. (This process has plenty of pitfalls; after a few attempts you should get it working.)
- Download the model. This post uses image_stylization, which Google has open-sourced on GitHub.
- Download the trained parameters. Google provides two checkpoints:
Monet
Varied
Monet was trained on 10 artistic styles, Varied on 32.
You can also train on your own artistic images, but that requires downloading the VGG weights and the ImageNet data and training yourself, which takes quite a while.
Building the computation graph
Although Google provides the model's source code, it doesn't export a computation graph ready to ship to mobile. The Android demo does include a generated pb file, though, so if writing the graph yourself feels like too much trouble, you can copy that file straight into your iOS project.
I created a Python project and added the files that model.py depends on from Google's source tree.
My graph-building code is as follows:
import numpy as np
import tensorflow as tf
import os

from matplotlib import pyplot
from matplotlib.pyplot import imshow

import model
import ops

num_styles = 32
imgWidth = 512
imgHeight = 512
channel = 3

checkpoint = "/Users/Jiao/Desktop/TFProject/style-image/checkpoint/multistyle-pastiche-generator-varied.ckpt"

# Placeholders for the content image batch and the per-style weight vector.
inputImage = tf.placeholder(tf.float32, shape=[None, imgWidth, imgHeight, channel], name="input")
styles = tf.placeholder(tf.float32, shape=[num_styles], name="style")

with tf.name_scope(""):
    # Build the style-transfer network; the style weights are injected
    # through the weighted instance normalization parameters.
    transform = model.transform(inputImage,
                                normalizer_fn=ops.weighted_instance_norm,
                                normalizer_params={
                                    'weights': styles,
                                    'num_categories': num_styles,
                                    'center': True,
                                    'scale': True})

model_saver = tf.train.Saver(tf.global_variables())

with tf.Session() as sess:
    # Write the graph definition (structure only, no weights) to disk.
    tf.train.write_graph(sess.graph_def, "/Users/Jiao/Desktop/TFProject/style-image/protobuf", "input.pb")

    # Optional local test: restore the checkpoint, blend two styles,
    # and preview the stylized output.
    #checkpoint = os.path.expanduser(checkpoint)
    #if tf.gfile.IsDirectory(checkpoint):
    #    checkpoint = tf.train.latest_checkpoint(checkpoint)
    #    tf.logging.info('loading latest checkpoint file: {}'.format(checkpoint))
    #model_saver.restore(sess, checkpoint)
    #newstyle = np.zeros([num_styles], dtype=np.float32)
    #newstyle[18] = 0.5
    #newstyle[17] = 0.5
    #newImage = np.zeros((1, imgWidth, imgHeight, channel))
    #style_image = transform.eval(feed_dict={inputImage: newImage, styles: newstyle})
    #style_image = style_image[0]
    #imshow(style_image)
    #pyplot.show()
Here the input nodes are input and style, and the output node is transformer/expand/conv3/conv/Sigmoid inside the model.
At this point the model's computation graph has been saved to a local folder.
Next, merge the graph with the parameters in the ckpt file to produce a single pb file usable on mobile. This step is covered in my previous article "iOS+Tensorflow實現(xiàn)圖像識別" (iOS + TensorFlow image recognition) and is easy to do.
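For reference, here is a minimal sketch of that merge using TensorFlow's bundled freeze_graph tool. The paths are assumptions matching the ones above, and I name the output rounded_graph.pb only because that is the file the iOS project below loads:

from tensorflow.python.tools import freeze_graph

# Bake the checkpoint weights into the graph as constants, producing a
# single self-contained .pb for mobile.
freeze_graph.freeze_graph(
    input_graph="/Users/Jiao/Desktop/TFProject/style-image/protobuf/input.pb",
    input_saver="",
    input_binary=False,  # tf.train.write_graph wrote the graph as text
    input_checkpoint="/Users/Jiao/Desktop/TFProject/style-image/checkpoint/multistyle-pastiche-generator-varied.ckpt",
    output_node_names="transformer/expand/conv3/conv/Sigmoid",
    restore_op_name="save/restore_all",
    filename_tensor_name="save/Const:0",
    output_graph="/Users/Jiao/Desktop/TFProject/style-image/protobuf/rounded_graph.pb",
    clear_devices=True,
    initializer_nodes="")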
iOS project
If you already set up the iOS + TF project during the preparation above, all you need to do here is import the final generated pb file. The project structure looks like this:
To use the pb file on iOS, I imported the tensorflow_utils class that Google provides; its LoadModel method quickly creates a session containing the computation graph.
- (void)viewDidLoad {
    [super viewDidLoad];
    tensorflow::Status load_status;
    load_status = LoadModel(@"rounded_graph", @"pb", &tf_session);
    if (!load_status.ok()) {
        LOG(FATAL) << "Couldn't load model: " << load_status;
    }
    currentStyle = 0;
    isDone = true;
    _styleImageView.layer.borderColor = [UIColor grayColor].CGColor;
    _styleImageView.layer.borderWidth = 0.5;
    _ogImageView.layer.borderColor = [UIColor grayColor].CGColor;
    _ogImageView.layer.borderWidth = 0.5;
}
Finally, grab an image, run the model, and display the resulting artistic image. The input image has to be converted to a bitmap so we can read its raw data, and displaying the output is a similar process in reverse. (Here contentNode, styleNode, and outputNode are strings holding the node names from earlier: input, style, and transformer/expand/conv3/conv/Sigmoid.) The code is as follows:
- (void)runCnn:(UIImage *)compressedImg
{
    unsigned char *pixels = [self getImagePixel:compressedImg];
    int image_channels = 4; // the bitmap is RGBA; the model only wants RGB
    tensorflow::Tensor image_tensor(
        tensorflow::DT_FLOAT,
        tensorflow::TensorShape(
            {1, wanted_input_height, wanted_input_width, wanted_input_channels}));
    auto image_tensor_mapped = image_tensor.tensor<float, 4>();
    tensorflow::uint8 *in = pixels;
    float *out = image_tensor_mapped.data();
    // Copy the RGBA bitmap into the RGB float tensor, row by row.
    for (int y = 0; y < wanted_input_height; ++y) {
        float *out_row = out + (y * wanted_input_width * wanted_input_channels);
        for (int x = 0; x < wanted_input_width; ++x) {
            // Row-major indexing: y selects the row, x the pixel within it.
            tensorflow::uint8 *in_pixel =
                in + (y * wanted_input_width * image_channels) + (x * image_channels);
            float *out_pixel = out_row + (x * wanted_input_channels);
            for (int c = 0; c < wanted_input_channels; ++c) {
                out_pixel[c] = in_pixel[c];
            }
        }
    }
    // One-hot style vector: all the weight goes on the selected style.
    tensorflow::Tensor style(tensorflow::DT_FLOAT, tensorflow::TensorShape({32}));
    float *style_data = style.tensor<float, 1>().data();
    memset(style_data, 0, sizeof(float) * 32);
    style_data[currentStyle] = 1;
    if (tf_session.get()) {
        std::vector<tensorflow::Tensor> outputs;
        tensorflow::Status run_status = tf_session->Run(
            {{contentNode, image_tensor}, {styleNode, style}},
            {outputNode}, {}, &outputs);
        if (!run_status.ok()) {
            LOG(ERROR) << "Running model failed: " << run_status;
            isDone = true;
            free(pixels);
        } else {
            float *styledData = outputs[0].tensor<float, 4>().data();
            UIImage *styledImg = [self createImage:styledData];
            dispatch_async(dispatch_get_main_queue(), ^{
                _styleImageView.image = styledImg;
                dispatch_after(dispatch_time(DISPATCH_TIME_NOW, (int64_t)(0.3 * NSEC_PER_SEC)), dispatch_get_main_queue(), ^{
                    isDone = true;
                    free(pixels);
                });
            });
        }
    }
}
- (unsigned char *)getImagePixel:(UIImage *)image
{
    // The image passed in is expected to already be scaled to 512x512.
    int width = image.size.width;
    int height = image.size.height;
    CGColorSpaceRef colorSpace = CGColorSpaceCreateDeviceRGB();
    unsigned char *rawData = (unsigned char *)calloc(height * width * 4, sizeof(unsigned char));
    NSUInteger bytesPerPixel = 4;
    NSUInteger bytesPerRow = bytesPerPixel * width;
    NSUInteger bitsPerComponent = 8;
    CGContextRef context = CGBitmapContextCreate(rawData, width, height,
                                                 bitsPerComponent, bytesPerRow, colorSpace,
                                                 kCGImageAlphaPremultipliedLast | kCGBitmapByteOrder32Big);
    CGColorSpaceRelease(colorSpace);
    CGContextDrawImage(context, CGRectMake(0, 0, width, height), image.CGImage);
    // Show the redrawn original; release the CGImage to avoid a leak.
    CGImageRef ogRef = CGBitmapContextCreateImage(context);
    UIImage *ogImg = [UIImage imageWithCGImage:ogRef];
    CGImageRelease(ogRef);
    dispatch_async(dispatch_get_main_queue(), ^{
        _ogImageView.image = ogImg;
    });
    CGContextRelease(context);
    // The caller is responsible for free()ing the returned buffer.
    return rawData;
}
- (UIImage *)createImage:(float *)pixels
{
    unsigned char *rawData = (unsigned char *)calloc(wanted_input_height * wanted_input_width * 4, sizeof(unsigned char));
    // Convert the model's float RGB output (0..1 after the sigmoid) back
    // into an RGBA byte bitmap.
    for (int y = 0; y < wanted_input_height; ++y) {
        unsigned char *out_row = rawData + (y * wanted_input_width * 4);
        for (int x = 0; x < wanted_input_width; ++x) {
            // Row-major indexing, mirroring the input copy above.
            float *in_pixel = pixels + (y * wanted_input_width * 3) + (x * 3);
            unsigned char *out_pixel = out_row + (x * 4);
            for (int c = 0; c < wanted_input_channels; ++c) {
                out_pixel[c] = in_pixel[c] * 255;
            }
            out_pixel[3] = UINT8_MAX; // opaque alpha
        }
    }
    CGColorSpaceRef colorSpace = CGColorSpaceCreateDeviceRGB();
    NSUInteger bytesPerPixel = 4;
    NSUInteger bytesPerRow = bytesPerPixel * wanted_input_width;
    NSUInteger bitsPerComponent = 8;
    CGContextRef context = CGBitmapContextCreate(rawData, wanted_input_width, wanted_input_height,
                                                 bitsPerComponent, bytesPerRow, colorSpace,
                                                 kCGImageAlphaPremultipliedLast | kCGBitmapByteOrder32Big);
    CGColorSpaceRelease(colorSpace);
    CGImageRef imgRef = CGBitmapContextCreateImage(context);
    UIImage *retImg = [UIImage imageWithCGImage:imgRef];
    CGImageRelease(imgRef);
    CGContextRelease(context);
    free(rawData);
    return retImg;
}
To clarify: as defined in the Python project earlier, my input and output images are 512×512, so wanted_input_width, wanted_input_height, and wanted_input_channels above are 512, 512, and 3.
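Before deploying to the phone, you can sanity-check the frozen graph on the desktop. This is a sketch under my own assumptions (the pb file sits in the working directory under the name rounded_graph.pb); it feeds a dummy 512×512 image with raw 0–255 values, matching what the iOS code feeds, plus a one-hot style vector:

import numpy as np
import tensorflow as tf

# Load the frozen, self-contained graph.
graph_def = tf.GraphDef()
with tf.gfile.GFile("rounded_graph.pb", "rb") as f:
    graph_def.ParseFromString(f.read())

with tf.Graph().as_default() as graph:
    tf.import_graph_def(graph_def, name="")

image = np.random.randint(0, 256, size=(1, 512, 512, 3)).astype(np.float32)
style = np.zeros(32, dtype=np.float32)
style[0] = 1.0  # all weight on style 0; fractional weights blend styles

with tf.Session(graph=graph) as sess:
    out = sess.run("transformer/expand/conv3/conv/Sigmoid:0",
                   feed_dict={"input:0": image, "style:0": style})
print(out.shape)  # expect (1, 512, 512, 3), values in [0, 1]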
Connect an iPhone and run the project
Finally, plug in your phone and run the app, and you can create your own artistic images.
Here are a few screenshots of the results: