
Loading your own model in TensorFlow.js for object detection


I have a question about detecting objects through a webcam with tensorflow.js. At the moment I am using the pre-trained coco-ssd model.


<!DOCTYPE html>
<html lang="en">
  <head>
    <title>Multiple object detection using pre trained model in tensorflow.js</title>
    <meta charset="utf-8">
    <meta http-equiv="X-UA-Compatible" content="IE=edge">
    <meta name="viewport" content="width=device-width,initial-scale=1">
    <!-- Import the webpage's stylesheet -->
    <link rel="stylesheet" href="style.css">
  </head>
  <body>
    <h1>Multiple object detection using pre trained model in tensorflow.js</h1>

    <p>Wait for the model to load before clicking the button to enable the webcam - at which point it will become visible to use.</p>
    <section id="demos" class="invisible">

      <p>Hold some objects up close to your webcam to get a real-time classification! When ready click "enable webcam" below and accept access to the webcam when the browser asks (check the top left of your window)</p>
      <div id="liveView" class="camView">
        <button id="webcamButton">Enable Webcam</button>
        <video id="webcam" autoplay width="640" height="480"></video>
      </div>
    </section>

    <!-- Import tensorflow.js library -->
    <script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs/dist/tf.min.js" type="text/javascript"></script>
    <!-- Load the coco-ssd model to use to recognize things in images -->
    <script src="https://cdn.jsdelivr.net/npm/@tensorflow-models/coco-ssd"></script>
    <!-- Import the page's JavaScript to do some stuff -->
    <script src="script.js" defer></script>
  </body>
</html>


const video = document.getElementById('webcam');
const liveView = document.getElementById('liveView');
const demosSection = document.getElementById('demos');
const enableWebcamButton = document.getElementById('webcamButton');

// Check if webcam access is supported.
function getUserMediaSupported() {
  return !!(navigator.mediaDevices &&
    navigator.mediaDevices.getUserMedia);
}

// If the webcam is supported, add an event listener to the button so that
// a click calls the enableCam function defined below.
if (getUserMediaSupported()) {
  enableWebcamButton.addEventListener('click', enableCam);
} else {
  console.warn('getUserMedia() is not supported by your browser');
}

// Enable the live webcam view and start classification.
function enableCam(event) {
  // Only continue if the COCO-SSD has finished loading.
  if (!model) {
    return;
  }

  // Hide the button once clicked.
  event.target.style.display = 'none';

  // getUserMedia parameters to force video but not audio.
  const constraints = {
    video: true
  };

  // Activate the webcam stream and start predicting once frames arrive.
  navigator.mediaDevices.getUserMedia(constraints).then(function(stream) {
    video.srcObject = stream;
    video.addEventListener('loadeddata', predictWebcam);
  });
}

// Store the resulting model in the global scope of our app.
var model = undefined;

// Before we can use COCO-SSD class we must wait for it to finish
// loading. Machine Learning models can be large and take a moment 
// to get everything needed to run.
// Note: cocoSsd is an external object loaded from our index.html
// script tag import so ignore any warning in Glitch.
cocoSsd.load().then(function (loadedModel) {
  model = loadedModel;
  // Show demo section now model is ready to use.
  demosSection.classList.remove('invisible');
});

var children = [];

function predictWebcam() {
  // Now let's start classifying a frame in the stream.
  model.detect(video).then(function (predictions) {
    // Remove any highlighting we did in the previous frame.
    for (let i = 0; i < children.length; i++) {
      liveView.removeChild(children[i]);
    }
    children.splice(0);

    // Now let's loop through predictions and draw them to the live view if
    // they have a high confidence score.
    for (let n = 0; n < predictions.length; n++) {
      // If we are over 66% sure we classified it right, draw it!
      if (predictions[n].score > 0.66) {
        const p = document.createElement('p');
        p.innerText = predictions[n].class  + ' - with '
            + Math.round(parseFloat(predictions[n].score) * 100)
            + '% confidence.';
        p.style = 'margin-left: ' + predictions[n].bbox[0] + 'px; margin-top: '
            + (predictions[n].bbox[1] - 10) + 'px; width: '
            + (predictions[n].bbox[2] - 10) + 'px; top: 0; left: 0;';

        const highlighter = document.createElement('div');
        // Assumes style.css defines a .highlighter rule with absolute positioning.
        highlighter.setAttribute('class', 'highlighter');
        highlighter.style = 'left: ' + predictions[n].bbox[0] + 'px; top: '
            + predictions[n].bbox[1] + 'px; width: '
            + predictions[n].bbox[2] + 'px; height: '
            + predictions[n].bbox[3] + 'px;';

        // Add the new elements to the live view and remember them for cleanup.
        liveView.appendChild(highlighter);
        liveView.appendChild(p);
        children.push(highlighter);
        children.push(p);
      }
    }

    // Call this function again to keep predicting when the browser is ready.
    window.requestAnimationFrame(predictWebcam);
  });
}

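Now I want to customize the script to use my own model, which I previously created and trained with TensorFlow for Python. I have already converted it to .json format with the converter tfjs_convert.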

Files of my own model



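You can load the model from its .json file with loadGraphModel from @tensorflow/tfjs-converter. With the @tensorflow/tfjs bundle that the page already includes, the same function is exposed as tf.loadGraphModel.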

For instance, see this example.
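Independently of the linked example, here is a minimal sketch of how the cocoSsd.load() call in the question could be swapped for a converted graph model. The path in MODEL_URL and the helper names loadCustomModel and runOnFrame are hypothetical. The sketch also assumes the model accepts a batched int32 image tensor of shape [1, height, width, 3], which may not hold for every model.

```javascript
// Hypothetical location of the converted model; model.json must be served
// from the same folder as its weight shard (.bin) files.
const MODEL_URL = 'model/model.json';

// Load the converted graph model. `tf` is passed in explicitly so the sketch
// does not rely on a global; in the page from the question it is the global
// `tf` object created by the tf.min.js script tag.
function loadCustomModel(tf, modelUrl = MODEL_URL) {
  return tf.loadGraphModel(modelUrl);
}

// Run the model on the current video frame and return its raw outputs.
// Assumption: the model takes a batched int32 image tensor of shape
// [1, height, width, 3]; other models may need resizing or a float cast.
async function runOnFrame(tf, model, video) {
  // tidy() disposes the intermediate 3D tensor and keeps only the batched one.
  const input = tf.tidy(() => tf.browser.fromPixels(video).expandDims(0));
  try {
    // executeAsync is required when the graph contains control-flow ops,
    // which is common for detection models.
    return await model.executeAsync(input);
  } finally {
    input.dispose();
  }
}

// Usage in the page's script (browser only):
//   loadCustomModel(tf).then(function (loadedModel) {
//     model = loadedModel;
//     demosSection.classList.remove('invisible');
//   });
```

Unlike the coco-ssd wrapper, a graph model has no detect() method, so predictWebcam would need to call something like runOnFrame and interpret the returned tensors itself. Which tensor holds boxes, scores and classes depends on how the model was exported.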


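The drawing loop in the question reads a class name, a score and a pixel bounding box from each prediction, whereas a model loaded with loadGraphModel returns raw tensors. The sketch below shows one way to bridge that gap once the tensors have been read into plain arrays. It rests on assumptions that are not in the original post: boxes are normalized [ymin, xmin, ymax, xmax] rows, classes are numeric ids, and labels is a hypothetical id-to-name lookup for your own classes. The function name toPredictions is also made up for illustration.

```javascript
// Convert raw detection outputs into coco-ssd-style prediction objects
// ({ bbox: [x, y, width, height], class, score }).
// Assumptions: `boxes` holds normalized [ymin, xmin, ymax, xmax] rows,
// `classes` holds numeric ids, and `labels` maps ids to display names.
function toPredictions(boxes, scores, classes, labels, frameWidth, frameHeight) {
  const predictions = [];
  for (let i = 0; i < scores.length; i++) {
    const [yMin, xMin, yMax, xMax] = boxes[i];
    predictions.push({
      // Scale normalized coordinates to pixels of the video element.
      bbox: [
        xMin * frameWidth,
        yMin * frameHeight,
        (xMax - xMin) * frameWidth,
        (yMax - yMin) * frameHeight,
      ],
      // Fall back to the numeric id when no label is known.
      class: labels[classes[i]] || String(classes[i]),
      score: scores[i],
    });
  }
  return predictions;
}

// One detection covering the right half of a 640x480 frame, vertically centered.
const demo = toPredictions(
  [[0.25, 0.5, 0.75, 1.0]], [0.9], [1], { 1: 'my-object' }, 640, 480);
console.log(demo[0].bbox); // [ 320, 120, 320, 240 ]
```

The plain arrays can be obtained from the output tensors with their array() method. The result can then go through the same score threshold and drawing code that the question uses for coco-ssd predictions.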