上篇文章《机器视觉实战4:OpenCV Android环境搭建(喂饭版)》中介绍了如何使用Android Studio搭建OpenCV开发环境,本节基于之前搭建好的环境开发一个基于神经网络的目标检测App。

准备模型

首先从这里下载已经训练好的模型文件:

  • deploy.prototxt:神经网络结构的描述文件
  • mobilenet_iter_73000.caffemodel:神经网络的参数信息

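这个模型是使用Caffe实现的Google MobileNet SSD检测模型。有个Caffe Model Zoo项目,收集了很多已经训练好的模型,有兴趣的可以看一下。下载好模型之后,在app/src/main/下面创建一个assets目录,把两个模型文件放进去。注意:下文代码是按MobileNetSSD_deploy.prototxt这个文件名加载网络结构文件的,所以需要把deploy.prototxt重命名为MobileNetSSD_deploy.prototxt(或者相应修改代码中的文件名)。至此,模型的准备工作就完成了。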

编写代码

布局文件activity_main.xml:

<?xml version="1.0" encoding="utf-8"?>
<!-- Main screen layout: two buttons in a row at the top and an image area below them. -->
<androidx.constraintlayout.widget.ConstraintLayout xmlns:android="http://schemas.android.com/apk/res/android"
    xmlns:app="http://schemas.android.com/apk/res-auto"
    xmlns:tools="http://schemas.android.com/tools"
    android:layout_width="match_parent"
    android:layout_height="match_parent"
    tools:context=".MainActivity">

    <!-- "Select image" button, pinned to the top-left corner of the parent.
         MainActivity looks it up as R.id.imageSelect and opens an image chooser on click. -->
    <Button
        android:id="@+id/imageSelect"
        android:layout_width="wrap_content"
        android:layout_height="wrap_content"
        android:layout_marginStart="32dp"
        android:layout_marginLeft="32dp"
        android:layout_marginTop="16dp"
        android:text="@string/image_select"
        app:layout_constraintStart_toStartOf="parent"
        app:layout_constraintTop_toTopOf="parent" />

    <!-- "Recognize" button, placed between the end of the select button and the parent's end edge.
         MainActivity looks it up as R.id.recognize and runs detection on click. -->
    <Button
        android:id="@+id/recognize"
        android:layout_width="wrap_content"
        android:layout_height="wrap_content"
        android:layout_marginStart="16dp"
        android:layout_marginLeft="16dp"
        android:layout_marginTop="16dp"
        android:text="@string/recognize"
        app:layout_constraintEnd_toEndOf="parent"
        app:layout_constraintStart_toEndOf="@+id/imageSelect"
        app:layout_constraintTop_toTopOf="parent" />

    <!-- Image area below the select button, horizontally constrained to both parent edges.
         MainActivity sets the picked bitmap (and later the annotated result) on this view.
         NOTE(review): width/height are fixed dp values, so the view does not adapt to screen size. -->
    <ImageView
        android:id="@+id/imageView"
        android:layout_width="387dp"
        android:layout_height="259dp"
        android:layout_marginStart="8dp"
        android:layout_marginLeft="8dp"
        android:layout_marginTop="22dp"
        android:layout_marginEnd="8dp"
        android:layout_marginRight="8dp"
        android:contentDescription="images"
        app:layout_constraintEnd_toEndOf="parent"
        app:layout_constraintStart_toStartOf="parent"
        app:layout_constraintTop_toBottomOf="@+id/imageSelect" />


</androidx.constraintlayout.widget.ConstraintLayout>

刚接触安卓开发没几天,布局是瞎写的,仅考虑了功能。

MainActivity.java代码:

package com.niyanchun.demo;

import androidx.annotation.Nullable;
import androidx.appcompat.app.AppCompatActivity;

import android.annotation.SuppressLint;
import android.content.Context;
import android.content.Intent;
import android.content.res.AssetManager;
import android.graphics.Bitmap;
import android.graphics.BitmapFactory;
import android.net.Uri;
import android.os.Bundle;
import android.util.Log;
import android.widget.Button;
import android.widget.EditText;
import android.widget.ImageView;
import android.widget.TextView;
import android.widget.Toast;

import org.opencv.android.OpenCVLoader;
import org.opencv.android.Utils;
import org.opencv.core.Core;
import org.opencv.core.Mat;
import org.opencv.core.Point;
import org.opencv.core.Scalar;
import org.opencv.core.Size;
import org.opencv.dnn.Dnn;
import org.opencv.dnn.Net;
import org.opencv.imgproc.Imgproc;

import java.io.BufferedInputStream;
import java.io.File;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.InputStream;

/**
 * Single-screen demo activity: lets the user pick an image, runs a Caffe MobileNet-SSD
 * network on it through OpenCV's DNN module, and draws the detections onto the image.
 *
 * <p>Everything here runs on the main thread, including model loading and inference.
 */
@SuppressLint("SetTextI18n")
public class MainActivity extends AppCompatActivity {

    /** Upper bound, in pixels, for either side of the bitmap decoded from the picked image. */
    private static final int MAX_SIZE = 1024;

    /** Request code used to recognise the image chooser result in {@link #onActivityResult}. */
    private static final int PICK_IMAGE_REQUEST = 1;

    /** Labels indexed by the class id found in column 1 of each detection row. */
    private static final String[] classNames = {"background",
            "aeroplane", "bicycle", "bird", "boat",
            "bottle", "bus", "car", "cat", "chair",
            "cow", "diningtable", "dog", "horse",
            "motorbike", "person", "pottedplant",
            "sheep", "sofa", "train", "tvmonitor"};

    private ImageView imageView;

    /** Currently selected image; {@code null} until the user has picked one successfully. */
    private Bitmap image;

    /** Loaded network; {@code null} until {@link #loadModel()} has succeeded. */
    private Net net = null;

    /**
     * Initialises OpenCV, looks up the views and wires the two buttons.
     */
    @Override
    protected void onCreate(Bundle savedInstanceState) {
        super.onCreate(savedInstanceState);
        setContentView(R.layout.activity_main);

        if (OpenCVLoader.initDebug()) {
            Log.i("CV", "load OpenCV Library Successful.");
        } else {
            Log.e("CV", "load OpenCV Library Failed.");
        }

        imageView = findViewById(R.id.imageView);
        imageView.setScaleType(ImageView.ScaleType.FIT_CENTER);

        Button selectBtn = findViewById(R.id.imageSelect);
        selectBtn.setOnClickListener(v -> {
            Intent intent = new Intent();
            intent.setType("image/*");
            intent.setAction(Intent.ACTION_GET_CONTENT);
            startActivityForResult(Intent.createChooser(intent, "选择图片"), PICK_IMAGE_REQUEST);
        });

        Button recognizeBtn = findViewById(R.id.recognize);
        recognizeBtn.setOnClickListener(v -> {
            // loadModel() is synchronous and runs on this same (main) thread, so a null net
            // here means an earlier load failed. Sleeping until it becomes non-null would
            // block the UI thread forever; retry once instead and give up if it fails again.
            if (net == null) {
                loadModel();
            }
            if (net == null) {
                // loadModel() has already told the user about the failure.
                return;
            }
            recognize();
        });
    }

    /**
     * Loads the model if it has not been loaded yet.
     */
    @Override
    protected void onResume() {
        super.onResume();
        loadModel();
    }

    /**
     * Handles the image chooser result: decodes the picked image, down-sampled so that
     * neither side exceeds {@link #MAX_SIZE}, and shows it in the image view.
     */
    @Override
    protected void onActivityResult(int requestCode, int resultCode, @Nullable Intent data) {
        super.onActivityResult(requestCode, resultCode, data);

        if (requestCode != PICK_IMAGE_REQUEST || resultCode != RESULT_OK
                || data == null || data.getData() == null) {
            return;
        }

        Uri uri = data.getData();
        try {
            Log.d("image-decode", "start to decode selected image now...");

            // First pass: read only the image bounds, no pixel data is allocated.
            BitmapFactory.Options options = new BitmapFactory.Options();
            options.inJustDecodeBounds = true;
            try (InputStream boundsStream = getContentResolver().openInputStream(uri)) {
                BitmapFactory.decodeStream(boundsStream, null, options);
            }

            // Double the sample size until both decoded sides fit into MAX_SIZE.
            int inSampleSize = 1;
            while (options.outWidth / inSampleSize > MAX_SIZE
                    || options.outHeight / inSampleSize > MAX_SIZE) {
                inSampleSize *= 2;
            }

            options.inSampleSize = inSampleSize;
            options.inJustDecodeBounds = false;
            options.inPreferredConfig = Bitmap.Config.ARGB_8888;
            // recognize() writes the annotated pixels back into this bitmap via
            // Utils.matToBitmap, so ask the decoder for a mutable one.
            options.inMutable = true;

            // Second pass: the first stream has been consumed, so open a fresh one.
            try (InputStream pixelStream = getContentResolver().openInputStream(uri)) {
                image = BitmapFactory.decodeStream(pixelStream, null, options);
            }
            imageView.setImageBitmap(image);
        } catch (Exception e) {
            Log.e("image-decode", "decode image error", e);
        }
    }

    /**
     * Loads the Caffe network from the bundled assets unless it is already loaded.
     *
     * <p>On failure to copy either model file, {@link #net} stays {@code null} and the user
     * is notified with a toast.
     */
    private void loadModel() {
        if (net != null) {
            return;
        }

        Toast.makeText(this, "开始加载模型...", Toast.LENGTH_LONG).show();
        String proto = getPath("MobileNetSSD_deploy.prototxt", this);
        String weights = getPath("mobilenet_iter_73000.caffemodel", this);
        // getPath() returns "" when the asset could not be copied; do not hand that to OpenCV.
        if (proto.isEmpty() || weights.isEmpty()) {
            Log.e("model", "load model failed: model files could not be copied from assets.");
            Toast.makeText(this, "模型加载失败!", Toast.LENGTH_LONG).show();
            return;
        }

        net = Dnn.readNetFromCaffe(proto, weights);
        Log.i("model", "load model successfully.");
        Toast.makeText(this, "模型加载成功!", Toast.LENGTH_LONG).show();
    }

    /**
     * Runs the network on the selected image and draws a box and label for every detection
     * whose confidence exceeds the threshold. The annotated image replaces the displayed one.
     *
     * <p>Does nothing except show a toast when no image has been selected or when nothing
     * was detected. Expects {@link #net} to be loaded.
     */
    private void recognize() {
        if (image == null) {
            Toast.makeText(this, "请先选择图片!", Toast.LENGTH_LONG).show();
            return;
        }

        // The network's input layer expects 300x300 images.
        final int IN_WIDTH = 300;
        final int IN_HEIGHT = 300;
        final double IN_SCALE_FACTOR = 0.007843;
        final double MEAN_VAL = 127.5;
        final double THRESHOLD = 0.2;

        Mat imageMat = new Mat();
        Mat blob = null;
        Mat rawDetections = null;
        Mat detections = null;
        try {
            Utils.bitmapToMat(image, imageMat);
            Imgproc.cvtColor(imageMat, imageMat, Imgproc.COLOR_RGBA2RGB);
            blob = Dnn.blobFromImage(imageMat, IN_SCALE_FACTOR,
                    new Size(IN_WIDTH, IN_HEIGHT),
                    new Scalar(MEAN_VAL, MEAN_VAL, MEAN_VAL),
                    false, false);
            net.setInput(blob);
            rawDetections = net.forward();

            int cols = imageMat.cols();
            int rows = imageMat.rows();
            // One row per detection, 7 values each; the columns read below are
            // 1 = class id, 2 = confidence, 3..6 = left/top/right/bottom scaled by image size.
            detections = rawDetections.reshape(1, (int) rawDetections.total() / 7);

            boolean detected = false;
            for (int i = 0; i < detections.rows(); ++i) {
                double confidenceTmp = detections.get(i, 2)[0];
                if (confidenceTmp <= THRESHOLD) {
                    continue;
                }
                detected = true;

                int classId = (int) detections.get(i, 1)[0];
                int left = (int) (detections.get(i, 3)[0] * cols);
                int top = (int) (detections.get(i, 4)[0] * rows);
                int right = (int) (detections.get(i, 5)[0] * cols);
                int bottom = (int) (detections.get(i, 6)[0] * rows);

                // Draw rectangle around detected object.
                Imgproc.rectangle(imageMat, new Point(left, top), new Point(right, bottom),
                        new Scalar(0, 255, 0), 4);

                // Guard the lookup so an unexpected class id cannot crash the activity.
                String className = (classId >= 0 && classId < classNames.length)
                        ? classNames[classId]
                        : "class " + classId;
                String label = className + ": " + confidenceTmp;
                int[] baseLine = new int[1];
                Size labelSize = Imgproc.getTextSize(label, Core.FONT_HERSHEY_COMPLEX, 0.5, 5, baseLine);
                // Draw background for label.
                Imgproc.rectangle(imageMat, new Point(left, top - labelSize.height),
                        new Point(left + labelSize.width, top + baseLine[0]),
                        new Scalar(255, 255, 255), Core.FILLED);
                // Write class name and confidence.
                Imgproc.putText(imageMat, label, new Point(left, top),
                        Core.FONT_HERSHEY_COMPLEX, 0.5, new Scalar(0, 0, 0));
            }

            if (!detected) {
                Toast.makeText(this, "没有检测到目标!", Toast.LENGTH_LONG).show();
                return;
            }

            Utils.matToBitmap(imageMat, image);
            imageView.setImageBitmap(image);
        } finally {
            // Free the native buffers now instead of waiting for garbage collection.
            releaseMat(detections);
            releaseMat(rawDetections);
            releaseMat(blob);
            releaseMat(imageMat);
        }
    }

    /** Releases {@code mat} if it is not {@code null}. */
    private static void releaseMat(Mat mat) {
        if (mat != null) {
            mat.release();
        }
    }

    /**
     * Copies an asset into the app's private files directory so that it can be opened
     * through a regular file path.
     *
     * @param file    name of the asset; also used as the name of the copy
     * @param context context providing the asset manager and the files directory
     * @return absolute path of the copy, or an empty string if the copy failed
     */
    private static String getPath(String file, Context context) {
        Log.i("getPath", "start upload file " + file);
        AssetManager assetManager = context.getAssets();
        File outFile = new File(context.getFilesDir(), file);
        // Copy in chunks: available() is only an estimate and a single read() is not
        // guaranteed to fill the buffer. Both streams are closed even when copying fails.
        try (InputStream in = assetManager.open(file);
             FileOutputStream out = new FileOutputStream(outFile)) {
            byte[] buffer = new byte[8192];
            int count;
            while ((count = in.read(buffer)) != -1) {
                out.write(buffer, 0, count);
            }
        } catch (IOException ex) {
            Log.e("getPath", "Failed to upload a file", ex);
            return "";
        }
        Log.i("getPath", "upload file " + file + "done");
        // Return a path to file which may be read in common way.
        return outFile.getAbsolutePath();
    }
}

代码中的一些关键点说明如下:

  • loadModel:实现了模型的加载,OpenCV提供了readNetFromCaffe方法用于加载Caffe训练的模型,其输入就是两个模型文件。
  • onActivityResult:实现了选择图片后的图片处理和展示。
  • recognize:实现利用加载的模型进行目标检测,并根据检测结果用框画出目标的位置。和之前的基于HOG特征的目标检测类似。

然后点击运行,效果如下:

result

可以看到,检测到了显示器、盆栽、猫、人等。对安卓还不太熟,后面有时间了弄一个从摄像头视频中实时检测的App玩玩。

Reference: