in Dask Distributed

I have been struggling with this problem and would greatly appreciate any help. I am not sure where to go from here.

I am using dask to parallelize a least-cost-path computation that uses scikit-image's MCP class.

The class is written in Cython, and because Dask expects intermediate results to be serializable, I implemented a Wrapper that "re-creates" the MCP object during deserialization.
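To make the deserialization behaviour concrete, here is a minimal, stdlib-only sketch of the same `__reduce__` pattern. `make_resource` and `calls` are hypothetical names introduced only for illustration; `make_resource` stands in for the factory that builds the MCP object.

```python
import pickle

calls = []  # records every time the factory runs


def make_resource():
    # Hypothetical stand-in for an expensive, unpicklable object such as MCP.
    calls.append("built")
    return object()


class Wrapper:
    def __init__(self, factory):
        self.factory = factory
        self.resource = factory()

    def __reduce__(self):
        # Pickle only the factory; the resource is rebuilt when unpickling
        # calls Wrapper(factory) again.
        return (self.__class__, (self.factory,))


original = Wrapper(make_resource)
clone = pickle.loads(pickle.dumps(original))

print(len(calls))                           # 2: once for original, once for clone
print(clone.resource is original.resource)  # False: the clone has its own resource
```

Each unpickling produces an independent resource, so separate processes never share one. Threads inside a single process that hold the same unpickled Wrapper instance do still share its resource.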

When I run the code without dask, or with dask's single-threaded scheduler, it takes longer but the results come back as expected.

However, when I switch to running with processes or threads (still using dask distributed), I get no errors, but a bunch of np.inf values show up in my results.

In addition, the results themselves are inconsistent with those from the single-threaded run.
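These symptoms would be consistent with several threads in one worker sharing a single stateful object, although nothing shown here confirms that this is the cause. Under that assumption, one way to test the hypothesis is to give every thread its own instance. The following is a minimal stdlib sketch, not a Dask-specific fix; `FakeSolver`, `get_solver`, and `task` are hypothetical names, with `FakeSolver` standing in for the MCP object.

```python
import threading
from concurrent.futures import ThreadPoolExecutor


class FakeSolver:
    """Hypothetical stand-in for a stateful object such as MCP."""

    def __init__(self):
        # Remember which thread built this instance.
        self.owner = threading.get_ident()


_local = threading.local()  # attribute storage that is separate per thread


def get_solver():
    # Build the solver lazily, once per thread, instead of sharing one.
    if not hasattr(_local, "solver"):
        _local.solver = FakeSolver()
    return _local.solver


def task(_):
    solver = get_solver()
    # True when the solver in use was created by the current thread.
    return solver.owner == threading.get_ident()


with ThreadPoolExecutor(max_workers=4) as pool:
    results = list(pool.map(task, range(32)))

print(all(results))
```

A comparable check in the Dask setup would be to create the cluster with `threads_per_worker=1`, so that each worker process runs one task at a time, and see whether the np.inf values disappear.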

Relevant code snippets:

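# Imports used by the snippets below
import dask
import numpy as np
import pandas as pd
import rasterio
from dask.distributed import Client, LocalCluster, wait
from skimage import graph

# Create a client locally
if cluster_type == 'local':
    try:
        # Reuse a scheduler that is already listening on this port
        client = Client('127.0.0.1:8786')
    except OSError:
        # No scheduler reachable: start a local cluster instead
        cluster = LocalCluster(n_workers=8, processes=True, threads_per_worker=8, scheduler_port=8786)
        client = Client(cluster)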
## Create wrapper for MCP
# Creates a wrapper for Cython MCP Class
class Wrapper(object):
    def __init__(self,get_mcp):
        self.get_mcp = get_mcp 
        self.mcp = get_mcp()

    def __reduce__(self):
        # https://stackoverflow.com/questions/19855156/whats-the-exact-usage-of-reduce-in-pickler
        # When unpickled, the MCP object is rebuilt by calling get_mcp again
        return (self.__class__, (self.get_mcp,))


def load_mcp():
    print("...loading mcp...")
    inR = rasterio.open(friction_raster_path)
    inD = inR.read()[0,:,:]
    # Important: multiply by the pixel size in meters (30) in order to get correct measurements.
    # Note: np.float128 is platform-dependent and is not available in every NumPy build.
    inD = np.array(inD,dtype=np.float128) * 30
    inD = np.array(inD,dtype=np.float32)
    inD = np.nan_to_num(inD)
    mcp = graph.MCP_Geometric(inD)
    return mcp


# Init the wrapper for MCP
wrapper = Wrapper(load_mcp)

# Only reload inR here to do the crs check
inR = rasterio.open(friction_raster_path)
# Get costs from origin to dests
def get_costs_for_origin(wrapper,origin_id:str,origin_coords:tuple,dests:pd.DataFrame):
    # Todo - dests should be a list of tuples only
    res=[]
    origin_coords = [origin_coords]
    ends = dests.MCP_DESTS_COORDS.to_list()
    costs,traceback = wrapper.mcp.find_costs(starts=origin_coords,ends=ends)#ends=destinations.MCP_DESTS_COORDS.to_list())
    for idx,dest in enumerate(dests.to_dict(orient='records')):
        dest_coords = dest['MCP_DESTS_COORDS']
        tt = costs[dest_coords[0],dest_coords[1]]
        if tt > 9999999999:
            print(dest['id'])
            print(tt)
            raise ValueError("INF")
        res.append(
            {"d_id": dest['id'],"d_tt": tt}
        )
            
    return {"o_id": origin_id,"o_tfan": res}
# Run on distributed scheduler using processes
def run_async(wrapper:Wrapper,origins_d:pd.DataFrame,dests_d:pd.DataFrame):
    # broadcast the wrapper to all nodes
    wrapper = client.scatter(wrapper,broadcast=True)
    wait(wrapper)

    # broadcast destinations to all nodes.
    dests_d = client.scatter(dests_d,broadcast=True)
    wait(dests_d)

    #https://docs.dask.org/en/latest/futures.html
    tasks = []
    for idx,origin in enumerate(origins_d):
        print(f"Origin {idx} of {len(origins_d)}")
        task = dask.delayed(get_costs_for_origin)(
            wrapper=wrapper,origin_id = origin['id'],origin_coords = origin['MCP_DESTS_COORDS'],dests=dests_d)#client.submit(get_costs_for_origin,wrapper,ogin,dests)
        tasks.append(task)
    #all_res = client.gather(futures)
    all_res_dsk = dask.compute(*tasks)
    all_res_dsk = list(all_res_dsk)
    return all_res_dsk

I assume this has something to do with the MCP class, but I cannot figure out what might be causing the INF values.

Thanks in advance, everyone!
