微信公众号搜"智元新知"关注
微信扫一扫可直接关注哦!

在 python 中在 xarray 上应用 R 函数 qmap 加速

如何解决在 python 中在 xarray 上应用 R 函数 qmap 加速

我正在尝试在 Python 中的网格数据集上使用在 R 中实现的偏差校正函数。我在网上找到了一个循环遍历每个网格点的示例。

import pickle
import numpy as np
import xarray as xr
import pandas as pd
import matplotlib.pyplot as plt
import sys
from rpy2.robjects.packages import importr
import rpy2.robjects.numpy2ri
rpy2.robjects.numpy2ri.activate()
qmap = importr('qmap')
from rpy2.robjects import pandas2ri
from rpy2.robjects import r
import glob
sys.setrecursionlimit(10000)
import traceback

def bias_correction(x,y):
    q_map = qmap.fitQmap(x,y,method="RQUANT",qstep=0.01,wett_day=False)
    qm1 = qmap.doQmap(y,q_map)
    bias_corrected_output = {}
    bias_corrected_output['params'] = q_map
    bias_corrected_output['outputs'] = qm1
    return bias_corrected_output

def bias_correction_model(y,q_map):
    qm1 = qmap.doQmap(y,q_map)
    bias_corrected_output = {}
    bias_corrected_output['outputs'] = qm1
    return bias_corrected_output

observed = 'temp_CRUJRA_1951-1985.nc'
prcp_hist = 'temp_ACCESS-ESM1-5_1951-1985.nc'
prcp_LIG = 'temp_LIG.nc'

observed = xr.open_dataset(observed)
model_hist = xr.open_dataset(prcp_hist)
model_LIG = xr.open_dataset(prcp_LIG)

observed = observed['temp']
model_hist = model_hist['temp']
model_LIG = model_LIG['temp']

lats = observed.lat.values
lons = observed.lon.values

bias_corrected_results_hist = np.zeros([len(model_hist.time.values),len(model_hist.lat.values),len(model_hist.lon.values)])
bias_corrected_results_hist[:] = np.nan

bias_corrected_results_LIG = np.zeros([len(model_LIG.time.values),len(model_LIG.lat.values),len(model_LIG.lon.values)])
bias_corrected_results_LIG[:] = np.nan

model_hist_values = model_hist.values
hist_dict = {}
hist_dict['time'] = model_hist.time.values
hist_dict['lon'] =  model_hist.lon.values
hist_dict['lat'] =  model_hist.lat.values

modelLIG_values = model_LIG.values
LIG_dict = {}
LIG_dict['time'] = model_LIG.time.values
LIG_dict['lon'] =  model_LIG.lon.values
LIG_dict['lat'] =  model_LIG.lat.values

observation_attr_values = observed.values

correct_params = []
for i,lat in enumerate(lats):
    for j,lon in enumerate(lons):
        params_dict = {}
        if np.isnan(model_hist_values[0,i,j]) or np.isnan(observation_attr_values[0,j]):
            bias_corrected_results_hist[:,j] = np.nan
            params_dict['lat'] = lat
            params_dict['lon'] = lon
            params_dict['params'] = np.nan
        else:
            try:
                y = model_hist_values[:,j]
                x = observation_attr_values[:,j]

                y_LIG = modelLIG_values[:,j]

                temp = bias_correction(x,y)
                q_map = temp['params']
                temp_LIG = bias_correction_model(y_LIG,q_map)

                bias_corrected_results_LIG[:,j] = temp_LIG['outputs']
                bias_corrected_results_hist[:,j] = temp['outputs']

                if i%5==0 and j%5==0:
                    print(lat,lon)
                params_dict['lat'] = lat
                params_dict['lon'] = lon
                params_dict['params'] = temp['params']

            except:
                bias_corrected_results_hist[:,j] = np.nan
                bias_corrected_results_LIG[:,j] = np.nan

                params_dict['lat'] = lat
                params_dict['lon'] = lon
                params_dict['params'] = np.nan

        correct_params.append(params_dict)

ds_hist = xr.Dataset({'temp': (('time','lat','lon'),bias_corrected_results_hist)},coords={'lat': lats,'lon': lons,'time':hist_dict['time'] })
ds_sspLIG = xr.Dataset({'temp': (('time',bias_corrected_results_LIG)},'time':LIG_dict['time'] })

ds_hist.to_netcdf('hist_temp_cor.nc')
ds_sspLIG.to_netcdf('LIG_temp_cor.nc')

对于尺寸为 420(时间)、360(纬度)和 720(经度)的输入文件,这需要大约 1.5 小时。有没有办法让代码更高效?循环遍历每个网格点使它变得如此缓慢,但我不知道如何解决这个问题。

版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。