微信公众号搜"智元新知"关注
微信扫一扫可直接关注哦!

Python apply_async 不修改 numpy 数组

如何解决Python apply_async 不修改 numpy 数组

我正在尝试对大型 numpy 数组执行分类任务。对于上下文,数组是卫星图像的转换。 目标是将每个像素与其周围的邻居进行分类。因此,我在 50x50 的补丁上使用了 CNN。由于双循环滑动窗口的任务非常慢,我尝试使用多处理池。每个预测的结果存储在一个名为“classif”的 numpy 中,该 numpy 包含每个像素的类。 我知道代码由于多次打印而起作用,但它似乎没有修改数组“classif”。事实上,输出总是一个充满零的数组。

我想我并没有真正修改“classif”,而只是它的一个副本,但我似乎无法使用我在网上找到的方法访问它。

有人可以帮我解决这个问题吗?

X_test 是我的大数组

patch_width = 50               
patch_height = patch_width
pad= 25
NUM_PROCESSES=4



img_height,img_width,band = X_test.shape


N = X_test.shape[0]
P = (NUM_PROCESSES + 1) 
partitions = list(zip(np.linspace(0,N,P,dtype=int)[:-1],np.linspace(0,dtype=int)[1:]))
work = partitions[:-1]
work.append((partitions[-1][0],partitions[-1][1]))


M = X_test.shape[1]
partitions2 = list(zip(np.linspace(0,M,dtype=int)[1:]))

# Final range of indices should end +1 past last index for completeness
work2 = partitions2[:-1]
work2.append((partitions2[-1][0],partitions2[-1][1] ))


work3=list()
for i in work:
    for j in work2:
        n= list((j,i))
        work3.append(n)



X_class = np.pad(X_test,((pad,pad),(pad,(0,0)),'reflect')
        
shm = shared_memory.SharedMemory(create=True,size=X_class.nbytes)
X_class_shared = np.ndarray(X_class.shape,dtype=X_class.dtype,buffer=shm.buf)
X_class_shared =X_class.copy()

classif_shape = np.zeros((img_height,img_width),dtype=int)

shm2 = shared_memory.SharedMemory(create=True,size=classif_shape.nbytes)
classif = np.ndarray(classif_shape.shape,dtype=classif_shape.dtype,buffer=shm2.buf)


def predict(height,width):
    for i in range(width[0],width[1]):
        for j in range(height[0],height[1]):
            x=j 
            x_bis = j+patch_height
            y=i
            y_bis=i +patch_width 
            patch = X_class_shared[x:x_bis,y:y_bis,:]
            X = np.expand_dims(patch,axis=0)
            test = cnn.predict(X)
            rounded = np.argmax(test,axis=1)
            classif[j,i]=rounded



def dispatch_jobs(data,job_number):

    pool =Pool(job_number)
    for i,s in work3:
        pool.apply_async(predict,args=(s,i))
    print(classif)
    pool.close()
    pool.join()

 if __name__ == "__main__":
    dispatch_jobs(work3,NUM_PROCESSES)
    

版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。