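Today's mainstream deep learning frameworks, TensorFlow, PyTorch and MXNet, are all built primarily around Python, so it is far from easy for Java engineers to build a deep learning application with the skills they already have. In this article we tackle that problem and implement an image classification service in very little code.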
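Scenario
【Object classification】
A client sends an HTTP request that passes an image URL to the backend service. The backend runs the image through a deep learning model and returns the predicted classes.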
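Install the native libraries
Using MXNet as an example
First download the native library files, picking the build that matches your machine. For example, for my GPU server I used the file below; the shorter entries after it appear to be paths under the same base URL, https://publish.djl.ai/mxnet-1.7.0-backport/: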
https://publish.djl.ai/mxnet-1.7.0-backport/win/cu102mkl/mxnet_61.dll.gz
win/mkl/libmxnet.dll.gz
win/common/libgcc_s_seh-1.dll.gz
win/common/libgfortran-3.dll.gz
win/common/libopenblas.dll.gz
win/common/libquadmath-0.dll.gz
Decompress the files into C:\Users\bigdata\.djl.ai\mxnet (bigdata is my Windows user name; use your own user directory).
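The download-and-decompress step can also be scripted with the JDK alone. This is a minimal sketch under two assumptions: the short entries in the list are paths relative to the same base URL as the first link, and the target folder is `.djl.ai\mxnet` under the current user's home directory (the `C:\Users\bigdata` part above is the author's own profile). `NativeLibFetcher` and its methods are hypothetical names, not part of DJL.

```java
import java.io.IOException;
import java.io.InputStream;
import java.net.URI;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.nio.file.StandardCopyOption;
import java.util.zip.GZIPInputStream;

public class NativeLibFetcher {
    // Base URL of the first download link; the other entries are resolved against it (assumption)
    static final String BASE = "https://publish.djl.ai/mxnet-1.7.0-backport/";

    // The files listed in the text for the author's Windows GPU server
    static final String[] ENTRIES = {
        "win/cu102mkl/mxnet_61.dll.gz",
        "win/mkl/libmxnet.dll.gz",
        "win/common/libgcc_s_seh-1.dll.gz",
        "win/common/libgfortran-3.dll.gz",
        "win/common/libopenblas.dll.gz",
        "win/common/libquadmath-0.dll.gz",
    };

    /** Resolves a relative entry such as "win/common/libopenblas.dll.gz" to a full URI. */
    static URI resolve(String relativePath) {
        return URI.create(BASE).resolve(relativePath);
    }

    /** Drops the ".gz" suffix: "libopenblas.dll.gz" becomes "libopenblas.dll". */
    static String targetName(String gzName) {
        return gzName.endsWith(".gz") ? gzName.substring(0, gzName.length() - 3) : gzName;
    }

    /** Decompresses a gzip stream into targetDir and returns the extracted file's path. */
    static Path extract(InputStream gzStream, String gzName, Path targetDir) throws IOException {
        Files.createDirectories(targetDir);
        Path target = targetDir.resolve(targetName(gzName));
        // Closing both resources also closes the raw stream if the gzip header is invalid
        try (InputStream raw = gzStream; InputStream in = new GZIPInputStream(raw)) {
            Files.copy(in, target, StandardCopyOption.REPLACE_EXISTING);
        }
        return target;
    }

    /** Decompresses a .gz file that has already been downloaded by hand. */
    static Path gunzip(Path gzFile, Path targetDir) throws IOException {
        return extract(Files.newInputStream(gzFile), gzFile.getFileName().toString(), targetDir);
    }

    public static void main(String[] args) throws IOException {
        // Equivalent of C:\Users\<your user>\.djl.ai\mxnet on Windows
        Path targetDir = Paths.get(System.getProperty("user.home"), ".djl.ai", "mxnet");
        for (String entry : ENTRIES) {
            URI uri = resolve(entry);
            String gzName = entry.substring(entry.lastIndexOf('/') + 1);
            Path extracted = extract(uri.toURL().openStream(), gzName, targetDir);
            System.out.println("Extracted " + uri + " -> " + extracted);
        }
    }
}
```

If you have already downloaded the archives in a browser, `gunzip(path, targetDir)` handles the decompression on its own without touching the network.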
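Set up the project
Create a Maven project in IntelliJ IDEA or Eclipse and add the following dependencies. The DJL artifacts reference a ${djl.version} property, which you need to define in the <properties> section of your pom.xml with the DJL release you are using.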
<dependency>
<groupId>commons-cli</groupId>
<artifactId>commons-cli</artifactId>
<version>1.4</version>
</dependency>
<dependency>
<groupId>org.apache.logging.log4j</groupId>
<artifactId>log4j-slf4j-impl</artifactId>
<version>2.12.1</version>
</dependency>
<dependency>
<groupId>com.google.code.gson</groupId>
<artifactId>gson</artifactId>
<version>2.8.5</version>
</dependency>
<dependency>
<groupId>ai.djl</groupId>
<artifactId>api</artifactId>
<version>${djl.version}</version>
</dependency>
<dependency>
<groupId>ai.djl</groupId>
<artifactId>basicdataset</artifactId>
<version>${djl.version}</version>
</dependency>
<dependency>
<groupId>ai.djl</groupId>
<artifactId>model-zoo</artifactId>
<version>${djl.version}</version>
</dependency>
<dependency>
<groupId>com.sparkjava</groupId>
<artifactId>spark-core</artifactId>
<version>2.8.0</version>
</dependency>
<dependency>
<groupId>ai.djl.mxnet</groupId>
<artifactId>mxnet-model-zoo</artifactId>
<version>${djl.version}</version>
</dependency>
<dependency>
<groupId>ai.djl.mxnet</groupId>
<artifactId>mxnet-engine</artifactId>
<version>${djl.version}</version>
</dependency>
<dependency>
<groupId>ai.djl.mxnet</groupId>
<artifactId>mxnet-native-auto</artifactId>
<version>1.7.0-a</version>
<scope>runtime</scope>
</dependency>
Load the model
Use the DJL ModelZoo to load a MobileNet model trained on ImageNet and run a classification prediction on the input image:
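import ai.djl.Application;
import ai.djl.MalformedModelException;
import ai.djl.inference.Predictor;
import ai.djl.modality.Classifications;
import ai.djl.modality.cv.util.BufferedImageUtils;
import ai.djl.repository.zoo.Criteria;
import ai.djl.repository.zoo.ModelNotFoundException;
import ai.djl.repository.zoo.ModelZoo;
import ai.djl.repository.zoo.ZooModel;
import ai.djl.training.util.ProgressBar;
import com.google.gson.Gson;
import java.awt.image.BufferedImage;
import java.io.IOException;
import java.net.URL;
import java.nio.file.Paths;
public class ImageNetTest {
// Shared predictor, created once when the class is loaded
private static final Predictor<BufferedImage, Classifications> predictor = load();
private static Predictor<BufferedImage, Classifications> load() {
// Select the MobileNet v1 (multiplier 0.75) image classifier trained on ImageNet
Criteria<BufferedImage, Classifications> criteria =
Criteria.builder()
.optApplication(Application.CV.IMAGE_CLASSIFICATION)
.setTypes(BufferedImage.class, Classifications.class)
.optFilter("multiplier", "0.75")
.optFilter("flavor", "v1")
.optFilter("dataset", "imagenet")
.optArtifactId("mobilenet")
.optProgress(new ProgressBar())
.build();
try {
ZooModel<BufferedImage, Classifications> model = ModelZoo.loadModel(criteria);
return model.newPredictor();
} catch (IOException | ModelNotFoundException | MalformedModelException e) {
// Fail fast instead of leaving the predictor null and hitting a NullPointerException later
throw new IllegalStateException("Failed to load the MobileNet model", e);
}
}
// Accepts an http(s) URL or a local file path and returns the top-3 classes as JSON
public static String predict(String imagePath) throws Exception {
BufferedImage image;
if (imagePath.startsWith("http")) {
image = BufferedImageUtils.fromUrl(new URL(imagePath));
} else {
image = BufferedImageUtils.fromFile(Paths.get(imagePath));
}
return new Gson().toJson(predictor.predict(image).topK(3));
}
public static void main(String[] args) throws Exception {
System.out.println(predict("src/test/resources/dog-cat.jpg"));
}
}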
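`Classifications.topK(3)` in the code above returns the three most probable classes. The plain-Java sketch below only illustrates that idea (sort by probability, keep the first k) with made-up labels and scores; it is not DJL's implementation, and `TopKSketch` and `Scored` are hypothetical names.

```java
import java.util.ArrayList;
import java.util.Comparator;
import java.util.List;

public class TopKSketch {
    /** A label with its probability; a hypothetical stand-in for DJL's classification entries. */
    static final class Scored {
        final String className;
        final double probability;

        Scored(String className, double probability) {
            this.className = className;
            this.probability = probability;
        }
    }

    /** Returns the k highest-probability entries, most probable first. */
    static List<Scored> topK(List<String> classNames, double[] probabilities, int k) {
        List<Scored> all = new ArrayList<>();
        for (int i = 0; i < probabilities.length; i++) {
            all.add(new Scored(classNames.get(i), probabilities[i]));
        }
        // Sort by probability, descending
        all.sort(Comparator.comparingDouble((Scored s) -> s.probability).reversed());
        return new ArrayList<>(all.subList(0, Math.min(k, all.size())));
    }

    public static void main(String[] args) {
        // Made-up labels and scores, for illustration only
        List<String> names = List.of("cat", "dog", "bird", "fish");
        double[] probs = {0.10, 0.70, 0.15, 0.05};
        for (Scored s : topK(names, probs, 3)) {
            System.out.println(s.className + " " + s.probability);
        }
        // prints "dog 0.7", "bird 0.15", "cat 0.1", one per line
    }
}
```

A full sort is O(n log n); for the 1,000 ImageNet classes that cost is negligible, which is why the sketch favors readability over a heap-based selection.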
Model download
Download the model and decompress the files into:
{your_os_user_root}\.djl.ai\cache\repo\model\cv\image_classification\ai\djl\mxnet\mobilenet\v1\0.75
Scan the QR code to download
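To check that the model files ended up in the folder above, a small JDK-only sketch can build that path from the `user.home` system property and list its contents. It assumes `{your_os_user_root}` means the current user's home directory; `ModelCachePath` is a hypothetical helper, not a DJL class.

```java
import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.stream.Stream;

public class ModelCachePath {
    /** The MobileNet folder from the text, resolved under the given user home directory. */
    static Path mobilenetDir(String userHome) {
        return Paths.get(userHome, ".djl.ai", "cache", "repo", "model", "cv",
                "image_classification", "ai", "djl", "mxnet", "mobilenet", "v1", "0.75");
    }

    public static void main(String[] args) throws IOException {
        Path dir = mobilenetDir(System.getProperty("user.home"));
        if (!Files.isDirectory(dir)) {
            System.out.println("Model folder not found: " + dir);
            return;
        }
        // List the decompressed model files so you can confirm the extraction worked
        try (Stream<Path> files = Files.list(dir)) {
            files.forEach(System.out::println);
        }
    }
}
```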
SimpleHttp
Use the Spark web framework (spark-core) to expose a RESTful API quickly
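import static spark.Spark.get;
import static spark.Spark.port;
public class SimpleHttp {
public static void main(String[] args) {
// Listen on 8899; call /img_classes/predict?img_url=<image URL or local path>
port(8899);
get("/img_classes/predict", (request, response) -> {
response.type("application/json");
return ImageNetTest.predict(request.queryParams("img_url"));
});
}
}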
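The endpoint can be called from any HTTP client. The sketch below uses the JDK 11 `java.net.http` client and assumes the service is running on `localhost:8899`; the default image address `https://example.com/bulldog.jpg` is a placeholder, so pass a real image URL as the first argument. The `img_url` value is URL-encoded because it is itself a URL embedded in a query string. `PredictClient` is a hypothetical name.

```java
import java.io.IOException;
import java.net.URI;
import java.net.URLEncoder;
import java.net.http.HttpClient;
import java.net.http.HttpRequest;
import java.net.http.HttpResponse;
import java.nio.charset.StandardCharsets;

public class PredictClient {
    /** Builds the request URI, URL-encoding the image address passed as img_url. */
    static URI predictUri(String host, int port, String imgUrl) {
        String query = "img_url=" + URLEncoder.encode(imgUrl, StandardCharsets.UTF_8);
        return URI.create("http://" + host + ":" + port + "/img_classes/predict?" + query);
    }

    public static void main(String[] args) throws IOException, InterruptedException {
        // Placeholder image address; pass your own as the first argument
        String imgUrl = args.length > 0 ? args[0] : "https://example.com/bulldog.jpg";
        HttpRequest request = HttpRequest.newBuilder(predictUri("localhost", 8899, imgUrl))
                .GET()
                .build();
        HttpResponse<String> response = HttpClient.newHttpClient()
                .send(request, HttpResponse.BodyHandlers.ofString());
        // The body is the JSON array produced by ImageNetTest.predict
        System.out.println(response.statusCode() + " " + response.body());
    }
}
```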
Test results
Bulldog