pytorch的安卓部署

慈云数据 6个月前 (05-13) 技术支持 88 0

Pytorch

PyTorch是一个开源的Python机器学习库,基于Torch,用于自然语言处理应用程序

首先是使用pytorch

cuda 和 cudnn

cuDNN是NVIDIA专门针对深度神经网络中的基础操作而设计的基于GPU的加速库。cuDNN为深度神经网络中的标准流程提供了高度优化的实现方式,例如卷积、池化、归一化以及激活层的前向以及后向过程。当开发者们需要用到深度学习GPU加速时,才安装cuDNN库,工作速度相较CPU快很多。

需要安装cuda 和 cudnn 库的卸载与安装,可以参考以下教程:

cuda和cudnn安装

安装完成后,可按照下面教程安装pytorch:

此处需要用到anaconda:Anaconda是一个安装、管理Python相关包的软件,其包含了conda、Python等180多个科学包及其依赖项,可以用于在同一个机器上安装不同版本的软件包及其依赖,并能够在不同的环境之间切换。

pytorch安装

另外,如需在pycharm进行操作,可安装以下教程配置虚拟环境:

pycharm配置讯环境(此步需要完成pytorch环境的安装,有anaconda环境)

所有环境准备好以后,查看配置是否完成:

import torch
# 查看版本
print(torch.__version__)
# 查看gup是否可用
print(torch.cuda.is_available())
# 返回gpu个数
print(torch.cuda.device_count())
# 查看对应cuda版本号
print(torch.backends.cudnn.version())
print(torch.version.cuda)

在这里插入图片描述

现在就开始准备移植所需的文件吧。

pt模型和pth模型的不同

PT模型是一种完整的模型文件,不仅包含了模型的参数,还包括了模型的结构,可以直接被加载到模型中,开始进行训练和预测。而pth文件则只保存了模型的参数,因此在加载模型时需要重新定义模型结构。

所以,需要一个安卓端的pt模型

这里我们使用到图像识别的resnet101模型,新建一个py文件,运行以下代码:

模型转换

import torch
import torchvision
model = torchvision.models.resnet101(pretrained=True)
model.eval()
example = torch.rand(1, 3, 224, 224)
traced_script_module = torch.jit.trace(model, example)
traced_script_module.save("model101.pt")
  • import torch 和 import torchvision 是导入PyTorch和PyTorch的计算机视觉库。
  • model = torchvision.models.resnet101(pretrained=True) 加载一个预训练的ResNet-101模型。pretrained=True表示使用预训练的权重,这些权重是在ImageNet数据集上训练得到的。
  • model.eval() 将模型设置为评估模式。这在推理时使用,例如在测试集或生产环境中。
  • example = torch.rand(1, 3, 224, 224) 创建一个随机输入张量,大小为[1, 3, 224, 224],模拟一个输入图像。
  • traced_script_module = torch.jit.trace(model, example) 使用torch.jit.trace将模型转换为一个TorchScript模型。TorchScript是一种可以优化PyTorch模型的方式,使其在没有Python运行环境的情况下运行。
  • traced_script_module.save(“model101.pt”) 将转换后的模型保存为"model101.pt"。

    标签获取

    这样,我们就获得了一个图像识别的pt模型。

    然后,需要找到该模型的标签文件,我们从以下项目中获取

    https://github.com/ethereon/caffe-tensorflow/blob/master/examples/imagenet/imagenet-classes.txt

    拿到txt标签文件以后,我们需要转换成java中String[]的形式,所以稍微处理一下,

    import re
    with open('imagenet-classes.txt','r',encoding='utf-8') as f:
        text = f.read()
    text = re.sub(r'^', '"', text, flags=re.M)
    text = re.sub(r'$', '",', text, flags=re.M)
    with open('imagenet-classes.txt','w',encoding='utf-8') as f:
        f.write(text)
    

    这里是给开头和结尾都加上了"(双引号)以及结尾用,(逗号)分隔。

    安卓端代码

    完成以后我们打开AndroidStudio

    新建一个空的项目

    在资源文件中放入我们的model101.pt文件:

    在这里插入图片描述

    然后Gradle导入pytorch的依赖:

    implementation("org.pytorch:pytorch_android:1.12.1")
    implementation("org.pytorch:pytorch_android_torchvision:1.12.1")
    

    完成以后修改activity

    
        
        
        
        
    
    
        
        
        
    
    

    MainAvity:

    public class MainActivity extends AppCompatActivity {
        Button takePictureBtn=null;
        Bitmap bitmap = null;
        Module module = null;
        // 将图片显示在界面上
        ImageView imageView = null;
        TextView textView = null;
        @Override
        protected void onCreate(Bundle savedInstanceState) {
            super.onCreate(savedInstanceState);
            setContentView(R.layout.activity_main);
            initView();
            //从assets中获得图片数据
            try {
                bitmap = BitmapFactory.decodeStream(getAssets().open("image.jpg"));
                imgClassify(bitmap);
            } catch (IOException e) {
                Log.e("PytorchHelloWorld", "Error reading assets", e);
                finish();
            }
        }
        private void imgClassify(Bitmap img){
            try {
                // 加载PyTorch序列化模型
                module = Module.load(assetFilePath(this, "model101.pt"));
            } catch (IOException e) {
                Log.e("PytorchHelloWorld", "Error reading assets", e);
                finish();
            }
            imageView.setImageBitmap(img);
            // 建立输入张量
            final Tensor inputTensor = TensorImageUtils.bitmapToFloat32Tensor(img,
                    TensorImageUtils.TORCHVISION_NORM_MEAN_RGB, TensorImageUtils.TORCHVISION_NORM_STD_RGB);
            // 运行模型推理
            final Tensor outputTensor = module.forward(IValue.from(inputTensor)).toTensor();
            // 获取推理结果
            final float[] scores = outputTensor.getDataAsFloatArray();
            //  获取概率最高的分类的索引号
            float maxScore = -Float.MAX_VALUE;
            int maxScoreIdx = -1;
            for (int i = 0; i  maxScore) {
                    maxScore = scores[i];
                    maxScoreIdx = i;
                }
            }
            //通过索引号获得分类名称
            String className = ImageNetClasses.IMAGENET_CLASSES[maxScoreIdx];
            textView.setText("识别结果为:"+className);
        }
        private void initView(){
            takePictureBtn = findViewById(R.id.button);
            // 将图片显示在界面上
            imageView = findViewById(R.id.image);
            // 显示分类名称
            textView = findViewById(R.id.text);
            takePictureBtn.setOnClickListener(new View.OnClickListener() {
                @Override
                public void onClick(View v) {
                    Intent intent = new Intent(MainActivity.this,TakePicturesActivity.class);
                    startActivity(intent);
                }
            });
        }
        //返回模型文件路径
        public static String assetFilePath(Context context, String assetName) throws IOException {
            File file = new File(context.getFilesDir(), assetName);
            if (file.exists() && file.length() > 0) {
                return file.getAbsolutePath();
            }
            try (InputStream is = context.getAssets().open(assetName)) {
                try (OutputStream os = new FileOutputStream(file)) {
                    byte[] buffer = new byte[4 * 1024];
                    int read;
                    while ((read = is.read(buffer)) != -1) {
                        os.write(buffer, 0, read);
                    }
                    os.flush();
                }
                return file.getAbsolutePath();
            }
        }
    }
    

    TakePicturesActivity:

    package com.example.newdemo05;
    import static com.example.newdemo05.MainActivity.assetFilePath;
    import androidx.annotation.NonNull;
    import androidx.appcompat.app.AppCompatActivity;
    import androidx.core.app.ActivityCompat;
    import androidx.core.content.ContextCompat;
    import android.content.ContentResolver;
    import android.content.Intent;
    import android.content.pm.PackageManager;
    import android.graphics.Bitmap;
    import android.graphics.BitmapFactory;
    import android.hardware.camera2.CameraAccessException;
    import android.hardware.camera2.CameraCharacteristics;
    import android.hardware.camera2.CameraDevice;
    import android.hardware.camera2.CameraManager;
    import android.hardware.camera2.params.StreamConfigurationMap;
    import android.net.Uri;
    import android.os.Bundle;
    import android.provider.MediaStore;
    import android.util.Log;
    import android.view.TextureView;
    import android.view.View;
    import android.widget.Button;
    import android.widget.ImageView;
    import android.widget.TextView;
    import com.example.newdemo05.classify.ImageNetClasses;
    import org.pytorch.IValue;
    import org.pytorch.Module;
    import org.pytorch.Tensor;
    import org.pytorch.torchvision.TensorImageUtils;
    import java.io.FileNotFoundException;
    import java.io.IOException;
    import java.io.InputStream;
    public class TakePicturesActivity extends AppCompatActivity {
        private TextureView textureView;
        private TextView textPrediction;
        private Button openCam;
        private ImageView v_img;
        Module module = null;
        private CameraDevice cameraDevice;
        @Override
        protected void onCreate(Bundle savedInstanceState) {
            super.onCreate(savedInstanceState);
            setContentView(R.layout.activity_take_pictures);
            initView();
        }
        private void initView(){
            v_img = findViewById(R.id.image);
            textPrediction = findViewById(R.id.text_prediction);
            openCam = findViewById(R.id.btn_capture);
            openCam.setOnClickListener(new View.OnClickListener() {
                @Override
                public void onClick(View v) {
                    startCamera();
                }
            });
        }
        // 打开相机
        public void startCamera(){
            Intent intent = new Intent(Intent.ACTION_PICK, null);
            //调用setDataAndType方法,指定了选择的数据类型为图片
            //设置数据的URI为MediaStore.Images.Media.EXTERNAL_CONTENT_URI,表示选择外部存储中的图片
            intent.setDataAndType(MediaStore.Images.Media.EXTERNAL_CONTENT_URI, "image/*");
            //调用startActivityForResult方法,将Intent发送给系统,并指定一个请求码为2,以便在之后的回调中处理用户选择的图片
            startActivityForResult(intent, 2);
        }
        @Override
        protected void onActivityResult(int requestCode, int resultCode, Intent data) {
            super.onActivityResult(requestCode, resultCode, data);
            if (requestCode == 2) {
                // 从相册返回的数据
                Log.e(this.getClass().getName(), "Result:" + data.toString());
                if (data != null) {
                    // 得到图片的全路径
                    Uri uri = data.getData();
                    v_img.setImageURI(uri);
                    ContentResolver cr = getContentResolver();
                    InputStream inputStream = null;
                    try {
                        inputStream = cr.openInputStream(uri);
                    } catch (FileNotFoundException e) {
                        throw new RuntimeException(e);
                    }
                    Bitmap bitmap = BitmapFactory.decodeStream(inputStream);
                    imgClassify(bitmap);
                    Log.e(this.getClass().getName(), "Uri:" + String.valueOf(uri));
                }
            }
        }
        private void imgClassify(Bitmap img){
            try {
                // 加载PyTorch序列化模型
                module = Module.load(assetFilePath(this, "model101.pt"));
            } catch (IOException e) {
                Log.e("PytorchHelloWorld", "Error reading assets", e);
                finish();
            }
            v_img.setImageBitmap(img);
            // 建立输入张量
            final Tensor inputTensor = TensorImageUtils.bitmapToFloat32Tensor(img,
                    TensorImageUtils.TORCHVISION_NORM_MEAN_RGB, TensorImageUtils.TORCHVISION_NORM_STD_RGB);
            // 运行模型推理
            final Tensor outputTensor = module.forward(IValue.from(inputTensor)).toTensor();
            // 获取推理结果
            final float[] scores = outputTensor.getDataAsFloatArray();
            //  获取概率最高的分类的索引号
            float maxScore = -Float.MAX_VALUE;
            int maxScoreIdx = -1;
            for (int i = 0; i  maxScore) {
                    maxScore = scores[i];
                    maxScoreIdx = i;
                }
            }
            //通过索引号获得分类名称
            String className = ImageNetClasses.IMAGENET_CLASSES[maxScoreIdx];
            textPrediction.setText("识别结果为:"+className);
        }
    }
    

    然后新建一个ImageNetClasses,来装我们的标签文件:

    public class ImageNetClasses {
      public static String[] IMAGENET_CLASSES = new String[]{};}
    

    然后负责粘贴处理好的标签txt文件数据:

    在这里插入图片描述

    具体代码逻辑如下:

    1. TakePicturesActivity: 主Activity,用于打开相机或相册,获取图片并显示。
    2. initView(): 初始化视图控件,如ImageView、TextView和Button。
    3. startCamera(): 打开系统相册,调用startActivityForResult获取图片。
    4. onActivityResult(): 接收选择的图片,将图片显示在ImageView上。
    5. imgClassify(): 使用PyTorch模型对图片进行图像分类,将结果显示在TextView上。
    6. Module: PyTorch的模型类,用于加载pt模型并进行推理。
    7. Tensor: PyTorch的张量类,用于转换Bitmap为模型输入张量。
    8. TensorImageUtils: PyTorch的图片工具类,用于图像数据的预处理。
    9. ImageNetClasses: 包含ImageNet数据集的所有类别名称。

      主要步骤是:

    10. 初始化视图控件
    11. 打开相册,获取图片uri
    12. 将图片解码为Bitmap
    13. 用TensorImageUtils处理为模型输入张量
    14. 加载PyTorch模型并进行推理
    15. 处理输出结果,显示分类类别名称
微信扫一扫加客服

微信扫一扫加客服

点击启动AI问答
Draggable Icon