如何读取 TensorBuffer 的输出
我是 TensorFlow 和 Android 的新手。我正在尝试对来自外部加速度计的数据进行分类，以预测它是否与左、右、前、后或中心方向对齐。为此，我使用 Colab 训练了模型，将其转换为 TFLite 并添加到 Android 应用中，但我无法理解模型的输出。
package com.yogai.tensorflowlava;
import android.content.Context;
import android.os.Bundle;
import android.util.Log;
import android.view.View;
import android.widget.Button;
import android.widget.EditText;
import android.widget.Toast;
import android.widget.Toolbar;
import androidx.appcompat.app.AppCompatActivity;
import com.yogai.tensorflowlava.ml.Adxl345;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.util.Arrays;
import org.tensorflow.lite.DataType;
import org.tensorflow.lite.TensorFlowLite;
import org.tensorflow.lite.support.tensorbuffer.TensorBuffer;
public class MainActivity extends AppCompatActivity {
Button submitButton;
EditText editText;
String text;
String TAG = "My_APp";
// private Context context;
@Override
protected void onCreate(Bundle savedInstanceState) {
super.onCreate(savedInstanceState);
setContentView(R.layout.activity_main);
Context context;
editText = findViewById(R.id.editText);
submitButton = findViewById(R.id.submitButton);
submitButton.setonClickListener(new View.OnClickListener() {
@Override
public void onClick(View view) {
text = editText.getText().toString().trim();
//String str = "geekss@for@geekss";
// String[] arrOfStr = text.split(",",3);
//
// // String[] strings = new String[] {"1","2","3","4"};
// if (android.os.Build.VERSION.SDK_INT >= android.os.Build.VERSION_CODES.N) {
// Float[] floats = Arrays.stream(arrOfStr).map(Float::valueOf).toArray(Float[]::new);
// }
String[] parts = text.split(",");
float[] numbers = new float[parts.length];
for (int i = 0; i < parts.length; ++i) {
float number = Float.parseFloat(parts[i]);
float rounded = (int) Math.round(number * 1000) / 1000f;
numbers[i] = rounded;
}
// float[][] array2d = new float[1][3];
//
//
//
// for(int j=1;j<4;j++) {
// array2d[1][j] = numbers[j];
//
//
// }
// Float testValue = array2d[1][1]+1;
Log.d(TAG,String.valueOf(numbers[1]));
ByteBuffer.allocate(4).putFloat(numbers[0]).array();
byte[] byteArray= FloatArray2ByteArray(numbers);
ByteBuffer byteBuffer = ByteBuffer.wrap(byteArray);
getoutput(byteBuffer);
}
});
}
public static byte[] FloatArray2ByteArray(float[] values){
ByteBuffer buffer = ByteBuffer.allocate(4 * values.length);
for (float value : values){
buffer.putFloat(value);
}
return buffer.array();
}
private void getoutput(ByteBuffer byteBuffer) {
try {
Adxl345 model = Adxl345.newInstance(getApplicationContext());
// Creates inputs for reference.
TensorBuffer inputFeature0 = TensorBuffer.createFixedSize(new int[]{1,3},DataType.FLOAT32);
inputFeature0.loadBuffer(byteBuffer);
// Runs model inference and gets result.
Adxl345.Outputs outputs = model.process(inputFeature0);
TensorBuffer outputFeature0 = outputs.getoutputFeature0AsTensorBuffer();
String converted = new String(buffer.array(),"UTF-8");
Toast.makeText(this,"output: "+outputFeature0.toString(),Toast.LENGTH_SHORT).show();
// Releases model resources if no longer used.
model.close();
} catch (IOException e) {
// Todo Handle the exception
Toast.makeText(this,"Error: "+ e.toString(),Toast.LENGTH_SHORT).show();
}
}
}
版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。