在之前寫(xiě)的一篇文章 TensorFlow,從一個(gè) Android Demo 開(kāi)始 中通過(guò)編譯官方的 Demo 接觸到了 TensorFlow 實(shí)際使用場(chǎng)景劣挫。這篇文章打算從一個(gè)Android 開(kāi)發(fā)者的角度切入拢驾,看看構(gòu)建一個(gè)基于 TensorFlow 的 Android 應(yīng)用的完整流程。
相關(guān)代碼可查看:GitHub 項(xiàng)目地址
通過(guò) TensorFlow 用已有模型構(gòu)建 Android 應(yīng)用
在 Google 的 TensorFlow examples project 中,有一個(gè) Sample 叫作 TF Classify伪冰,它通過(guò)使用 Google Inception 模型對(duì)實(shí)時(shí)的相機(jī)圖像幀進(jìn)行分類,并顯示展示當(dāng)前圖像的分類推斷結(jié)果樟蠕。
下面我們就基于這個(gè)現(xiàn)有模型贮聂,在 Android 平臺(tái)上實(shí)現(xiàn)一個(gè)可以對(duì)物品進(jìn)行分類的圖像識(shí)別應(yīng)用。
獲取數(shù)據(jù)模型
這里可以直接下載 Google 提供的一個(gè)數(shù)據(jù)模型 inception5h.zip 寨辩,其中 .pb
后綴的文件是已經(jīng)訓(xùn)練好的模型寂汇,而 .txt
對(duì)應(yīng)的是訓(xùn)練數(shù)據(jù)包含的所有標(biāo)簽。
這個(gè)模型可對(duì) 1008 種物品識(shí)別分類捣染,具體有哪些類可以查看標(biāo)簽信息骄瓣,至于每個(gè)類別到底訓(xùn)練了多少?gòu)垐D片就不得而知了。
在 Android 項(xiàng)目中引入 TensorFlow
跟在項(xiàng)目中集成其他第三庫(kù)一樣,先在 build.gradle
中添加對(duì) TensorFlow 的依賴榕栏。
compile 'org.tensorflow:tensorflow-android:1.6.0'
這里我們直接使用了 Google 為我們編譯好的 TensorFlow 現(xiàn)成庫(kù)了畔勤,如果你想自行對(duì) TensorFlow 進(jìn)行 NDK 交叉編譯得到庫(kù)文件也可以。
圖像識(shí)別功能的實(shí)現(xiàn)
復(fù)制模型文件到項(xiàng)目 assets
文件夾:
如下圖所示扒磁,我們?cè)陧?xiàng)目 assets
文件夾下創(chuàng)建一個(gè) model
文件夾庆揪,并把之前下載的 inception5h.zip
解壓后的全部文件復(fù)制到該文件夾下。
添加模型調(diào)用的相關(guān)類
因?yàn)槲覀円獙?shí)現(xiàn)的功能和官方 demo 相似妨托,只是訓(xùn)練的有所模型不同缸榛。既然對(duì)模型的使用方式是一樣的,那這里就直接使用 Google demo 項(xiàng)目中提供的 Classifier.java 和 TensorFlowImageClassifier.java 這兩個(gè)類來(lái)實(shí)現(xiàn)兰伤。
我們可以先跳過(guò)這部分內(nèi)容的具體實(shí)現(xiàn)内颗,等到對(duì)整體流程有個(gè)大致認(rèn)識(shí)后再回過(guò)頭來(lái)消化掉,這樣可以更好地去理解敦腔。
這里我們重點(diǎn)關(guān)注下面兩個(gè)方法均澳,一個(gè)是 TensorFlowImageClassifier
的靜態(tài)方法 create
方法:
/**
* Initializes a native TensorFlow session for classifying images.
*
* @param assetManager The asset manager to be used to load assets.
* @param modelFilename The filepath of the model GraphDef protocol buffer.
* @param labelFilename The filepath of label file for classes.
* @param inputSize The input size. A square image of inputSize x inputSize is assumed.
* @param imageMean The assumed mean of the image values.
* @param imageStd The assumed std of the image values.
* @param inputName The label of the image input node.
* @param outputName The label of the output node.
* @throws IOException
*/
public static Classifier create(AssetManager assetManager, String modelFilename, String labelFilename,
int inputSize, int imageMean, float imageStd, String inputName, String outputName)
該方法需要傳入模型相關(guān)的參數(shù)進(jìn)行初始化,完成后返回一個(gè) Classifier
實(shí)例符衔。
通過(guò) Classifier
對(duì)象找前,我們可以調(diào)用其 recognizeImage
方法來(lái)識(shí)別我們傳入的 bitmap
圖像數(shù)據(jù),該方法會(huì)返回圖像類別后對(duì)物品類別進(jìn)行推斷的標(biāo)簽結(jié)果:
/**
* 進(jìn)行圖片識(shí)別
*/
public List<Recognition> recognizeImage(final Bitmap bitmap)
相關(guān)主要功能代碼的實(shí)現(xiàn):
相關(guān)代碼可查看:GitHub 項(xiàng)目地址
public class MainActivity extends AppCompatActivity implements View.OnClickListener {
...
// 模型相關(guān)配置
private static final int INPUT_SIZE = 224;
private static final int IMAGE_MEAN = 117;
private static final float IMAGE_STD = 1;
private static final String INPUT_NAME = "input";
private static final String OUTPUT_NAME = "output";
private static final String MODEL_FILE = "file:///android_asset/model/tensorflow_inception_graph.pb";
private static final String LABEL_FILE = "file:///android_asset/model/imagenet_comp_graph_label_strings.txt";
private Executor executor;
private Uri currentTakePhotoUri;
private TextView result;
private ImageView ivPicture;
private Classifier classifier;
@Override
protected void onCreate(Bundle savedInstanceState) {
super.onCreate(savedInstanceState);
if (!isTaskRoot()) {
finish();
}
setContentView(R.layout.activity_main);
findViewById(R.id.iv_choose_picture).setOnClickListener(this);
findViewById(R.id.iv_take_photo).setOnClickListener(this);
ivPicture = findViewById(R.id.iv_picture);
result = findViewById(R.id.tv_classifier_info);
// 避免耗時(shí)任務(wù)占用 CPU 時(shí)間片造成UI繪制卡頓判族,提升啟動(dòng)頁(yè)面加載速度
Looper.myQueue().addIdleHandler(idleHandler);
}
/**
* 主線程消息隊(duì)列空閑時(shí)(視圖第一幀繪制完成時(shí))處理耗時(shí)事件
*/
MessageQueue.IdleHandler idleHandler = new MessageQueue.IdleHandler() {
@Override
public boolean queueIdle() {
// 初始化 Classifier
if (classifier == null) {
// 創(chuàng)建 TensorFlowImageClassifier
classifier = TensorFlowImageClassifier.create(MainActivity.this.getAssets(),
MODEL_FILE, LABEL_FILE, INPUT_SIZE, IMAGE_MEAN, IMAGE_STD, INPUT_NAME, OUTPUT_NAME);
}
// 初始化線程池
executor = new ScheduledThreadPoolExecutor(1, new ThreadFactory() {
@Override
public Thread newThread(@NonNull Runnable r) {
Thread thread = new Thread(r);
thread.setDaemon(true);
thread.setName("ThreadPool-ImageClassifier");
return thread;
}
});
// 請(qǐng)求權(quán)限
requestMultiplePermissions();
// 返回 false 時(shí)只會(huì)回調(diào)一次
return false;
}
};
@Override
public void onClick(View view) {
switch (view.getId()) {
case R.id.iv_choose_picture :
choosePicture();
break;
case R.id.iv_take_photo :
takePhoto();
break;
default:break;
}
}
/**
* 選擇一張圖片并裁剪獲得一個(gè)小圖
*/
private void choosePicture() {
Intent intent = new Intent(Intent.ACTION_GET_CONTENT);
intent.setType("image/*");
startActivityForResult(intent, PICTURE_REQUEST_CODE);
}
/**
* 使用系統(tǒng)相機(jī)拍照
*/
private void takePhoto() {
if (ContextCompat.checkSelfPermission(this, Manifest.permission.CAMERA) != PackageManager.PERMISSION_GRANTED) {
ActivityCompat.requestPermissions(this, new String[]{Manifest.permission.CAMERA}, CAMERA_PERMISSIONS_REQUEST_CODE);
} else {
openSystemCamera();
}
}
/**
* 打開(kāi)系統(tǒng)相機(jī)
*/
private void openSystemCamera() {
//調(diào)用系統(tǒng)相機(jī)
Intent takePhotoIntent = new Intent();
takePhotoIntent.setAction(MediaStore.ACTION_IMAGE_CAPTURE);
//這句作用是如果沒(méi)有相機(jī)則該應(yīng)用不會(huì)閃退躺盛,要是不加這句則當(dāng)系統(tǒng)沒(méi)有相機(jī)應(yīng)用的時(shí)候該應(yīng)用會(huì)閃退
if (takePhotoIntent.resolveActivity(getPackageManager()) == null) {
Toast.makeText(this, "當(dāng)前系統(tǒng)沒(méi)有可用的相機(jī)應(yīng)用", Toast.LENGTH_SHORT).show();
return;
}
String fileName = "TF_" + System.currentTimeMillis() + ".jpg";
File photoFile = new File(FileUtil.getPhotoCacheFolder(), fileName);
if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.N) {
//通過(guò)FileProvider創(chuàng)建一個(gè)content類型的Uri
currentTakePhotoUri = FileProvider.getUriForFile(this, "gdut.bsx.tensorflowtraining.fileprovider", photoFile);
//對(duì)目標(biāo)應(yīng)用臨時(shí)授權(quán)該 Uri 所代表的文件
takePhotoIntent.addFlags(Intent.FLAG_GRANT_READ_URI_PERMISSION);
} else {
currentTakePhotoUri = Uri.fromFile(photoFile);
}
//將拍照結(jié)果保存至 outputFile 的Uri中,不保留在相冊(cè)中
takePhotoIntent.putExtra(MediaStore.EXTRA_OUTPUT, currentTakePhotoUri);
startActivityForResult(takePhotoIntent, TAKE_PHOTO_REQUEST_CODE);
}
/**
* 處理圖片
* @param imageUri
*/
private void handleInputPhoto(Uri imageUri) {
// 加載圖片
GlideApp.with(MainActivity.this).asBitmap().listener(new RequestListener<Bitmap>() {
@Override
public boolean onLoadFailed(@Nullable GlideException e, Object model, Target<Bitmap> target, boolean isFirstResource) {
Log.d(TAG,"handleInputPhoto onLoadFailed");
Toast.makeText(MainActivity.this, "圖片加載失敗", Toast.LENGTH_SHORT).show();
return false;
}
@Override
public boolean onResourceReady(Bitmap resource, Object model, Target<Bitmap> target, DataSource dataSource, boolean isFirstResource) {
Log.d(TAG,"handleInputPhoto onResourceReady");
startImageClassifier(resource);
return false;
}
}).load(imageUri).into(ivPicture);
result.setText("Processing...");
}
/**
* 開(kāi)始圖片識(shí)別匹配
* @param bitmap
*/
private void startImageClassifier(final Bitmap bitmap) {
executor.execute(new Runnable() {
@Override
public void run() {
try {
Log.i(TAG, Thread.currentThread().getName() + " startImageClassifier");
Bitmap croppedBitmap = getScaleBitmap(bitmap, INPUT_SIZE);
final List<Classifier.Recognition> results = classifier.recognizeImage(croppedBitmap);
Log.i(TAG, "startImageClassifier results: " + results);
runOnUiThread(new Runnable() {
@Override
public void run() {
result.setText(String.format("results: %s", results));
}
});
} catch (IOException e) {
Log.e(TAG, "startImageClassifier getScaleBitmap " + e.getMessage());
}
}
});
}
/**
* 請(qǐng)求相機(jī)和外部存儲(chǔ)權(quán)限
*/
private void requestMultiplePermissions() {
String storagePermission = Manifest.permission.WRITE_EXTERNAL_STORAGE;
String cameraPermission = Manifest.permission.CAMERA;
int hasStoragePermission = ActivityCompat.checkSelfPermission(this, storagePermission);
int hasCameraPermission = ActivityCompat.checkSelfPermission(this, cameraPermission);
List<String> permissions = new ArrayList<>();
if (hasStoragePermission != PackageManager.PERMISSION_GRANTED) {
permissions.add(storagePermission);
}
if (hasCameraPermission != PackageManager.PERMISSION_GRANTED) {
permissions.add(cameraPermission);
}
if (!permissions.isEmpty()) {
String[] params = permissions.toArray(new String[permissions.size()]);
ActivityCompat.requestPermissions(this, params, PERMISSIONS_REQUEST);
}
}
@Override
protected void onActivityResult(int requestCode, int resultCode, Intent data) {
super.onActivityResult(requestCode, resultCode, data);
if (resultCode == RESULT_OK) {
if (requestCode == PICTURE_REQUEST_CODE) {
// 處理選擇的圖片
handleInputPhoto(data.getData());
} else if (requestCode == OPEN_SETTING_REQUEST_COED){
requestMultiplePermissions();
} else if (requestCode == TAKE_PHOTO_REQUEST_CODE) {
// 如果拍照成功形帮,加載圖片并識(shí)別
handleInputPhoto(currentTakePhotoUri);
}
}
}
/**
* 對(duì)圖片進(jìn)行縮放
* @param bitmap
* @param size
* @return
* @throws IOException
*/
private static Bitmap getScaleBitmap(Bitmap bitmap, int size) throws IOException {
int width = bitmap.getWidth();
int height = bitmap.getHeight();
float scaleWidth = ((float) size) / width;
float scaleHeight = ((float) size) / height;
Matrix matrix = new Matrix();
matrix.postScale(scaleWidth, scaleHeight);
return Bitmap.createBitmap(bitmap, 0, 0, width, height, matrix, true);
}
}
運(yùn)行效果
圖片選擇和拍照獲取界面:
物品識(shí)別結(jié)果展示界面:
相關(guān)代碼可查看:GitHub 項(xiàng)目地址
是不是覺(jué)得通過(guò) TensorFlow 在現(xiàn)有的數(shù)據(jù)模型基礎(chǔ)下颗品,我們可以很簡(jiǎn)單就完成了一個(gè)簡(jiǎn)單的圖像識(shí)別應(yīng)用。
在使用這個(gè)模型來(lái)推斷物品類型的過(guò)程中沃缘,發(fā)現(xiàn)好像有時(shí)候準(zhǔn)確率不是那么高躯枢,這時(shí)候改怎么辦。如果說(shuō)只是想識(shí)別一些特定種類的物品槐臀,哪有又該怎么辦锄蹂?
在之前一篇文章中我有提到過(guò),機(jī)器學(xué)習(xí)是依靠對(duì)大量有標(biāo)簽的樣本數(shù)據(jù)進(jìn)行反復(fù)訓(xùn)練后才逐步得到的最佳模型水慨。對(duì)未知無(wú)標(biāo)簽樣本的推斷依賴這個(gè)模型的準(zhǔn)確程度得糜。所以我們可以通過(guò)對(duì)現(xiàn)有模型進(jìn)行遷移訓(xùn)練(retrain)來(lái)定制我們自己的模型。
下面就通過(guò)對(duì)現(xiàn)有的 Google Inception-V3 模型進(jìn)行 retrain 晰洒,對(duì) 5 種花朵樣本數(shù)據(jù)的進(jìn)行訓(xùn)練朝抖,來(lái)完成一個(gè)可以識(shí)別五種花朵的模型。
具體實(shí)現(xiàn)方式可以參考我的另外一篇文章:通過(guò)遷移訓(xùn)練來(lái)定制 TensorFlow 模型