TensorFlow.js运行现有模型进行图像分类


编程4046 阅0 评

这篇文章将教会你使用TensorFlow.js通过Javascript在浏览器上运行现有的模型进行图像分类预测。

TensorFlow.js

TensorFlow.js 是一个用于使用 JavaScript 进行机器学习开发的库,可以使用 JavaScript 开发机器学习模型,并直接在浏览器或 Node.js 中使用机器学习模型。

TensorFlow.js 支持以下三种方式进行机器学习模型的开发、训练以及预测:

运行现有模型
使用现成的 JavaScript 模型或转换 Python TensorFlow 模型以在浏览器中或 Node.js 下运行。

重新训练现有模型
使用自己的数据重新训练现有的机器学习模型。

开发机器学习模型
使用灵活且直观的 API 直接用 JavaScript 构建和训练模型。

可谓是极其的强悍呢~

更多的介绍以及教程,可到官网查阅:https://tensorflow.google.cn/js

实操

1. 创建空项目

创建项目工程可参考这篇文章:

为了方便,这里我们使用 Vue CLI 创建一个 TypeScript 的空项目,并修改 src/App.vue 文件中的内容为如下代码:

<template>
  <div id="app">
    TODO
  </div>
</template>

<script lang="ts">
import { Component, Vue } from 'vue-property-decorator'

@Component
export default class App extends Vue {}
</script>

启动项目:

npm run serve

访问地址:http://localhost:8080

2. 准备工作

  • 一张进行分类预测的图片,地址(可以是任意图):
https://github.com/tensorflow/tfjs-models/blob/master/mobilenet/demo/coffee.jpg

将图片下载下来保存到public目录下,并重命名为:coffee.jpg

3. 构建页面

<template>
  <div id="app">
    <div>
      <img id="img" src="/coffee.jpg" :width="imgSize" :height="imgSize" />
    </div>
    <Button @click="onPredict" :disabled="isRunning || !isLoadModel">预测</Button>
    <div v-html="text"></div>
  </div>
</template>

<script lang="ts">
import { Component, Vue } from 'vue-property-decorator'

@Component
export default class App extends Vue {
  // 图片尺寸
  imgSize = 224
  // 模型是否已加载
  isLoadModel = false
  // 正在预测
  isRunning = false
  // 结果
  text = ''

  onPredict () {
    if (!this.isLoadModel) {
      alert('模型加载失败,无法进行预测')
      return
    } else if (this.isRunning) {
      alert('当前正在预测中')
      return
    }
    this.predict()
  }
}
</script>

4. 安装TensorFlow.js

执行安装命令:

npm install @tensorflow/tfjs

当前安装的版本为:2.7.0

安装现有模型,这里使用 MobileNet:

npm install @tensorflow-models/mobilenet

MobileNet 模型项目地址:https://github.com/tensorflow/tfjs-models/tree/master/mobilenet

5. 使用TensorFlow.js

①. 引入库

import * as tf from '@tensorflow/tfjs'
import * as MobileNet from '@tensorflow-models/mobilenet'

②. 加载模型

模型会通过网络下载到浏览器中,似乎需要科学上网
// 图像分类模型  <-- 状态
model: tf.LayersModel | any

mounted () {
  this.text = '正在加载模型...'
  // 使用cpu
  tf.setBackend('cpu')
  // 加载模型
  MobileNet.load()
    .then(mobileNet => {
      this.model = mobileNet
      this.isLoadModel = true
      this.text = '模型加载成功'
    })
    .catch(e => {
      console.error(e)
      this.text = e.message
    })
}

③. 运行预测

/**
 * 预测
 */
predict () {
  this.isRunning = true
  this.text = '正在预测...'
  const img: any = document.getElementById('img')
  // 调用模型进行预测,取出前5个
  this.model.classify(img, 5)
      .then((result: Array<any>) => {
        // 输出
        const items: Array<string> = []
        result.forEach((item: any) => items.push(`${item.className}:${Math.round(100 * item.probability)}%`))
        this.text = items.join('<br>')
      })
      .catch((e: any) => this.text = e)
      .finally(() => this.isRunning = false)
}

6. 运行结果

tfjs_demo

模型

官方已支持并封装的模型:

Type Model Demo Details Install
Images
MobileNet
Classify images with labels from the ImageNet database. npm i @tensorflow-models/mobilenet
source
PoseNet
live A machine learning model which allows for real-time human pose estimation in the browser. See a detailed description here. npm i @tensorflow-models/posenet
source
Coco SSD
Object detection model that aims to localize and identify multiple objects in a single image. Based on the TensorFlow object detection API. npm i @tensorflow-models/coco-ssd
source
BodyPix
live Real-time person and body part segmentation in the browser using TensorFlow.js. npm i @tensorflow-models/body-pix
source
DeepLab v3
Semantic segmentation npm i @tensorflow-models/deeplab
source
Audio
Speech Commands
live Classify 1 second audio snippets from the speech commands dataset. npm i @tensorflow-models/speech-commands
source
Text
Universal Sentence Encoder
Encode text into a 512-dimensional embedding to be used as inputs to natural language processing tasks such as sentiment classification and textual similarity. npm i @tensorflow-models/universal-sentence-encoder
source
Text Toxicity
live Score the perceived impact a comment might have on a conversation, from "Very toxic" to "Very healthy". npm i @tensorflow-models/toxicity
source
General Utilities
KNN Classifier
This package provides a utility for creating a classifier using the K-Nearest Neighbors algorithm. Can be used for transfer learning. npm i @tensorflow-models/knn-classifier
source

详见模型库项目地址:https://github.com/tensorflow/tfjs-models

链接

TensorFlow.js 官网:https://tensorflow.google.cn/js
TensorFlow.js 已支持的模型:https://github.com/tensorflow/tfjs-models
TensorFlow.js API文档:https://js.tensorflow.org/api/lates

示例源码

附上本文源码地址:https://github.com/suimz/example-tfjs-imageclassify-existingmodel

最后更新 2021-06-03
评论 ( 0 )
OωO
隐私评论