
Loading your own model for object detection in TensorFlow.js


I have a question about detecting objects from a webcam with TensorFlow.js. Currently I am using the pre-trained model coco-ssd.

index.html:

<html lang="en">
  <head>
    <title>Multiple object detection using pre trained model in tensorflow.js</title>
    <meta charset="utf-8">
    <meta http-equiv="X-UA-Compatible" content="IE=edge">
    <meta name="viewport" content="width=device-width, initial-scale=1">
    <!-- Import the webpage's stylesheet -->
    <link rel="stylesheet" href="style.css">
  </head>  
  <body>
    <h1>Multiple object detection using pre trained model in tensorflow.js</h1>

    <p>Wait for the model to load before clicking the button to enable the webcam - at which point it will become visible to use.</p>
    
    <section id="demos" class="invisible">

      <p>Hold some objects up close to your webcam to get a real-time classification! When ready, click "Enable Webcam" below and accept access to the webcam when the browser asks (check the top left of your window).</p>
      
      <div id="liveView" class="camView">
        <button id="webcamButton">Enable Webcam</button>
        <video id="webcam" autoplay width="640" height="480"></video>
      </div>
    </section>

    <!-- Import tensorflow.js library -->
    <script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs/dist/tf.min.js" type="text/javascript"></script>
    <!-- Load the coco-ssd model to use to recognize things in images -->
    <script src="https://cdn.jsdelivr.net/npm/@tensorflow-models/coco-ssd"></script>
    
    <!-- Import the page's JavaScript to do some stuff -->
    <script src="script.js" defer></script>
  </body>
</html>

script.js:

const video = document.getElementById('webcam');
const liveView = document.getElementById('liveView');
const demosSection = document.getElementById('demos');
const enableWebcamButton = document.getElementById('webcamButton');

// Check if webcam access is supported.
function getUserMediaSupported() {
  return !!(navigator.mediaDevices &&
    navigator.mediaDevices.getUserMedia);
}

// If webcam is supported, add an event listener to the button so that,
// when the user wants to activate it, it calls the enableCam function
// defined in the next step.
if (getUserMediaSupported()) {
  enableWebcamButton.addEventListener('click', enableCam);
} else {
  console.warn('getUserMedia() is not supported by your browser');
}

// Enable the live webcam view and start classification.
function enableCam(event) {
  // Only continue if the COCO-SSD has finished loading.
  if (!model) {
    return;
  }
  
  // Hide the button once clicked.
  event.target.classList.add('removed');  
  
  // getUserMedia parameters to force video but not audio.
  const constraints = {
    video: true
  };

  // Activate the webcam stream.
  navigator.mediaDevices.getUserMedia(constraints).then(function(stream) {
    video.srcObject = stream;
    video.addEventListener('loadeddata', predictWebcam);
  });
}


// Store the resulting model in the global scope of our app.
var model = undefined;

// Before we can use COCO-SSD class we must wait for it to finish
// loading. Machine Learning models can be large and take a moment 
// to get everything needed to run.
// Note: cocoSsd is an external object loaded from our index.html
// script tag import so ignore any warning in Glitch.
cocoSsd.load().then(function (loadedModel) {
  model = loadedModel;
  // Show the demo section now that the model is ready to use.
  demosSection.classList.remove('invisible');
});


var children = [];

function predictWebcam() {
  // Now let's start classifying a frame in the stream.
  model.detect(video).then(function (predictions) {
    // Remove any highlighting we drew for the previous frame.
    for (let i = 0; i < children.length; i++) {
      liveView.removeChild(children[i]);
    }
    children.splice(0);
    
    // Now lets loop through predictions and draw them to the live view if
    // they have a high confidence score.
    for (let n = 0; n < predictions.length; n++) {
      // If we are over 66% sure we classified it right, draw it!
      if (predictions[n].score > 0.66) {
        const p = document.createElement('p');
        p.innerText = predictions[n].class + ' - with '
            + Math.round(parseFloat(predictions[n].score) * 100)
            + '% confidence.';
        // coco-ssd returns bbox as [x, y, width, height] in pixels.
        p.style = 'margin-left: ' + predictions[n].bbox[0] + 'px; margin-top: '
            + (predictions[n].bbox[1] - 10) + 'px; width: '
            + (predictions[n].bbox[2] - 10) + 'px; top: 0; left: 0;';

        const highlighter = document.createElement('div');
        highlighter.setAttribute('class', 'Highlighter');
        highlighter.style = 'left: ' + predictions[n].bbox[0] + 'px; top: '
            + predictions[n].bbox[1] + 'px; width: '
            + predictions[n].bbox[2] + 'px; height: '
            + predictions[n].bbox[3] + 'px;';

        liveView.appendChild(highlighter);
        liveView.appendChild(p);
        children.push(highlighter);
        children.push(p);
      }
    }
    
    // Call this function again to keep predicting when the browser is ready.
    window.requestAnimationFrame(predictWebcam);
  });
}

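Now I want to adapt the script to use my own model, which I previously built and trained with TensorFlow in Python. I have already converted it to the .json format with the converter tensorflowjs_converter.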

(Screenshot: the files of my own converted model)

How do I change my code so that it uses my own model instead? I have tried a few things, but unfortunately I haven't made any progress.

Solution

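You can load the model from its JSON file with loadGraphModel from @tensorflow/tfjs-converter; the tf.min.js bundle that index.html already imports exposes it as tf.loadGraphModel. A GraphModel has no detect() method like coco-ssd, so you run it with execute() or executeAsync() (the latter is required when the graph contains dynamic ops such as the non-max suppression found in most SSD detectors) and turn the raw output tensors into boxes yourself.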
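A minimal sketch of that conversion step, assuming your model was exported with the TensorFlow Object Detection API and, like coco-ssd, accepts an integer image batch of shape [1, height, width, 3] and returns normalized boxes in [ymin, xmin, ymax, xmax] order plus scores and class ids. The output node names, the label map and the 0.5 threshold are hypothetical placeholders; inspect `model.outputs` after loading to see what your converted model actually produces. The helper returns objects shaped like coco-ssd's `{ bbox: [x, y, width, height], class, score }`, so the drawing code in predictWebcam can stay as it is.

```javascript
/**
 * Converts raw SSD-style detection outputs into coco-ssd-shaped predictions:
 * { bbox: [x, y, width, height] in pixels, class, score }.
 * ASSUMPTION: boxes are normalized [ymin, xmin, ymax, xmax] (TF Object
 * Detection API convention); verify this against your own model.
 */
function toPredictions(boxes, scores, classes, width, height, labels, minScore = 0.5) {
  const predictions = [];
  for (let i = 0; i < scores.length; i++) {
    if (scores[i] < minScore) continue; // skip low-confidence detections
    const [ymin, xmin, ymax, xmax] = boxes.slice(i * 4, i * 4 + 4);
    predictions.push({
      bbox: [xmin * width, ymin * height, (xmax - xmin) * width, (ymax - ymin) * height],
      // HYPOTHETICAL label map (class id -> name); falls back to the raw id.
      class: labels[classes[i]] ?? String(classes[i]),
      score: scores[i],
    });
  }
  return predictions;
}

/**
 * Runs a GraphModel on one video frame and returns coco-ssd-style predictions.
 * `outputNames` is HYPOTHETICAL: your model's boxes, scores and classes output
 * node names, in that order (any extra names after those are ignored).
 */
async function detectWithGraphModel(tf, model, video, labels, outputNames, minScore = 0.5) {
  // [height, width, 3] int32 frame -> [1, height, width, 3] batch.
  const input = tf.tidy(() => tf.browser.fromPixels(video).expandDims(0));
  const [, height, width] = input.shape;
  // executeAsync handles dynamic ops (e.g. non-max suppression) in the graph.
  const outputs = await model.executeAsync(input, outputNames);
  input.dispose();
  // Copy boxes, scores and classes to plain typed arrays, then free GPU memory.
  const [boxes, scores, classes] = await Promise.all(
    outputs.slice(0, 3).map((t) => t.data()));
  tf.dispose(outputs);
  return toPredictions(boxes, scores, classes, width, height, labels, minScore);
}
```

To wire it in, `cocoSsd.load()` would become `tf.loadGraphModel('model/model.json')` (the path is only an example; model.json and its .bin weight shards must be served over HTTP from the same folder, not opened via file://), and `model.detect(video)` would become `detectWithGraphModel(tf, model, video, LABELS, OUTPUT_NAMES)`, where `LABELS` and `OUTPUT_NAMES` are your own hypothetical label map and output names. The coco-ssd script tag can then be dropped. Keeping the coco-ssd output shape is a deliberate choice: only the two calls change, and the 0.66 threshold and overlay code keep working.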

I like this example.
