微信公众号搜"智元新知"关注
微信扫一扫可直接关注哦!

GCP TPU上的TPUEmbedding性能

如何解决GCP TPU上TPUEmbedding性能低的问题

我正在GCP TPU v3上使用单个TPU核心测试TPUEmbedding的性能。我发现只能获得大约1–2 GB/s的内存带宽,与规格(900 GB/s)相比非常低。想知道代码哪里出了问题。使用的TensorFlow版本为'2.3.0-dev20200620'。

要运行代码,您需要设置环境变量TPU_IP(代码会将其拼接为 grpc://$TPU_IP 作为TPU地址)。

import time
import tensorflow as tf
import itertools
import numpy as np
import os
import sys

from tensorflow.python.ops import init_ops_v2
from tensorflow.python.tpu import tpu_embedding_v2
from tensorflow.python.tpu import tpu_embedding_v2_utils
from tensorflow.python.distribute import tpu_strategy
from tensorflow.python.tpu import tpu_strategy_util
from tensorflow.python.distribute.cluster_resolver import tpu_cluster_resolver
from tensorflow.python.eager import def_function
from tensorflow.python.eager import remote
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.ops import array_ops
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import array_ops
from tensorflow.python.distribute import distribute_lib
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.util import nest

# Benchmark dimensions.
batch = 16384        # rows of the final reduced result (see the reshape in test_fn)
nnz = 30             # ids per row that get summed together
em = 128             # embedding dimension (TableConfig.dim)
features = 1000000   # vocabulary size of the embedding table

# The embedding layer receives one id per row, so its batch is batch * nnz.
batch_size = batch * nnz
# Random ids in [0, features): exactly one embedding batch worth of data.
feature_watched_values = np.random.randint(0, features, (batch_size,))
resolver = None

table_test = tpu_embedding_v2_utils.TableConfig(
    vocabulary_size=features,
    dim=em,
    initializer=None,
    combiner='sum',
    name='test',
)
# A single FeatureConfig, not a tuple; presumably dequeue() then returns a
# single (un-nested) activation tensor.
feature_config = tpu_embedding_v2_utils.FeatureConfig(table=table_test, name='watched')

def get_strategy():
    """Connect to the TPU named by ``$TPU_IP`` and build a one-core TPUStrategy.

    Returns:
        A ``TPUStrategy`` whose device assignment uses
        ``computation_shape=[1, 1, 1]`` with a single replica.

    Raises:
        KeyError: if the ``TPU_IP`` environment variable is not set.
    """
    try:
        tpu_ip = os.environ["TPU_IP"]
    except KeyError:
        raise KeyError(
            "TPU_IP environment variable must be set to the TPU address "
            "(it is used as grpc://$TPU_IP)") from None

    # Local name distinct from the module-level `resolver` to avoid shadowing.
    tpu_resolver = tpu_cluster_resolver.TPUClusterResolver(tpu="grpc://" + tpu_ip)
    remote.connect_to_cluster(tpu_resolver)
    topology = tpu_strategy_util.initialize_tpu_system(tpu_resolver)
    # Use the public API; `tf.python.*` is not part of the public TF2
    # namespace and is not reliably reachable as an attribute of `tf`.
    device_assignment = tf.tpu.experimental.DeviceAssignment.build(
        topology, computation_shape=[1, 1, 1], num_replicas=1)

    return tpu_strategy.TPUStrategy(tpu_resolver, device_assignment=device_assignment)

def create_strategy_and_mid_level():
    """Build the TPU strategy and the TPUEmbedding mid-level API object.

    The SGD optimizer (learning rate 0.1) and the TPUEmbedding are both
    created inside ``strategy.scope()``.

    Returns:
        The 3-tuple ``(strategy, embedding, optimizer)``.
    """
    strategy = get_strategy()
    with strategy.scope():
        sgd = tpu_embedding_v2_utils.SGD(learning_rate=0.1)
        mid_level = tpu_embedding_v2.TPUEmbedding(
            feature_config=feature_config,
            batch_size=batch_size,
            optimizer=sgd,
        )
    return strategy, mid_level, sgd

# create_strategy_and_mid_level() returns three values. `embedding` must be
# bound at module level because test_fn enqueues into and dequeues from it.
strategy, embedding, optimizer = create_strategy_and_mid_level()
# Not currently read: test_fn passes training=False to enqueue directly.
training = False

def create_dense_input_fn(strategy, include_weights=False, weight=0.5):
    """Return an ``input_fn`` for ``experimental_distribute_datasets_from_function``.

    The dataset yields dense 1-D batches of ``batch_size`` ids taken from the
    module-level ``feature_watched_values``, repeated forever.

    Args:
        strategy: unused; kept for interface compatibility.
        include_weights: unused; kept for interface compatibility.
        weight: unused; kept for interface compatibility.

    Returns:
        A function ``input_fn(ctx)`` that builds the dataset.
    """
    def input_fn(ctx):
        # Only one replica is configured, so the InputContext is not used
        # to shard the data or derive a per-replica batch size.
        del ctx
        # drop_remainder=True gives the batch dimension a static size
        # (TPU-side code generally needs static shapes). It never drops data
        # here, because the dataset repeats forever.
        return (dataset_ops.DatasetV2.from_tensor_slices(feature_watched_values)
                .repeat()
                .batch(batch_size, drop_remainder=True))
    return input_fn

def get_replica_numpy(structured, strategy, replica_id):
    """Pick one replica's local value from every leaf of ``structured``.

    NOTE(review): despite the name, nothing is converted to NumPy; the leaves
    stay whatever ``strategy.experimental_local_results`` returns. Also, when
    there is exactly one local result, the whole 1-tuple is returned rather
    than its element, so callers must index ``[0]`` themselves in that case.
    """
    def _pick(value):
        local_values = strategy.experimental_local_results(value)
        return local_values if len(local_values) == 1 else local_values[replica_id]

    return nest.map_structure(_pick, structured)

input_fn = create_dense_input_fn(strategy)
# The ids are consumed on the host by embedding.enqueue(), not by the TPU
# computation. The TPUEmbedding documentation asks for device prefetching to
# be disabled for datasets fed to enqueue.
dist = strategy.experimental_distribute_datasets_from_function(
    input_fn,
    options=distribute_lib.InputOptions(experimental_prefetch_to_device=False))
dist_iter = iter(dist)

@tf.function
def test_fn():
    """Run one embedding lookup step.

    Enqueues the next batch of ids from the module-level ``dist_iter`` into
    ``embedding``. It then runs ``step`` under ``strategy.run``, which
    dequeues the activations and sums them over the ``nnz`` axis.

    Returns:
        The value returned by ``strategy.run(step)``. The per-replica result
        has shape ``[batch, em]``.
    """
    def step():
        # Python print() inside tf.function runs only while tracing.
        print("In STEPs")
        # Inside strategy.run the dequeued value is already replica-local, so
        # no replica selection is needed. feature_config is a single
        # FeatureConfig, so this is presumably one [batch * nnz, em] tensor.
        activation = embedding.dequeue()
        res = tf.math.reduce_sum(tf.reshape(activation, [batch, nnz, em]), axis=1)
        print("RES device : ", res.device)
        return res

    embedding.enqueue(next(dist_iter), training=False)
    return strategy.run(step)

def test_dense_lookup(steps=4, warmups=1):
    """Time ``test_fn`` and print the effective activation bandwidth.

    Args:
        steps: number of timed calls to ``test_fn``; must be >= 1.
        warmups: number of untimed calls made first. This keeps the first
            call's tracing/compilation out of the measurement.

    Raises:
        ValueError: if ``steps`` is less than 1.
    """
    if steps < 1:
        raise ValueError("steps must be >= 1")

    # Warm up outside the timed region. Previously the compile step was
    # timed while only `steps` steps were credited, understating bandwidth.
    for _ in range(warmups):
        res = test_fn()
    if warmups > 0:
        res.numpy()  # wait for the warm-up work to finish before timing

    start = time.perf_counter()
    for _ in range(steps):
        res = test_fn()
    end0 = time.perf_counter()  # all timed steps dispatched
    res.numpy()  # block until the last step's result is materialized
    end = time.perf_counter()

    # Activation bytes produced per step: float32 values of shape [batch * nnz, em].
    total_bytes = batch * nnz * em * tf.float32.size
    elapsed = end - start
    print("Test batch = ", batch, " nnz = ", nnz, ",em = ", em)
    print(" RES shape: ", res.shape)
    print("Whole loop time is : ", end0 - start, elapsed)
    print("TPU: total bytes {0},mem bw {1:.3f} GB/s".format(
        total_bytes, total_bytes * 1.0 * steps / elapsed / 1.0e9))
    
# Script entry: run the benchmark once, then report that the script finished.
test_dense_lookup()

print("done")

版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。