微信公众号搜"智元新知"关注
微信扫一扫可直接关注哦!

类型错误:sample_chain() 得到了一个意外的关键字参数 'seed'——TensorFlow 2.0

如何解决类型错误:sample_chain() 得到了一个意外的关键字参数 'seed'——TensorFlow 2.0

在 MacOS 10.13.6 上使用 MCMC 时 Tensorflow 2.0 出错

控制台上的错误

2020-12-27 22:06:48.253835: I tensorflow/core/platform/cpu_feature_guard.cc:145] This TensorFlow binary is optimized with Intel(R) MKL-DNN to use the following cpu instructions in performance critical operations:  SSE4.1 SSE4.2 AVX
To enable them in non-MKL-DNN operations,rebuild TensorFlow with the appropriate compiler flags.
2020-12-27 22:06:48.254353: I tensorflow/core/common_runtime/process_util.cc:115] Creating new thread pool with default inter op setting: 4. Tune using inter_op_parallelism_threads for best performance.
objc[69111]: Class zmAppHelper is implemented in both /Library/ScriptingAdditions/zOLPluginInjection.osax/Contents/MacOS/zOLPluginInjection (0x1a48eaf4f0) and /Library/Application Support/Microsoft/ZoomOutlookPlugin/zOutlookPlugin64.bundle/Contents/MacOS/zOutlookPlugin64 (0x1a490e0518). One of the two will be used. Which one is undefined.
objc[69111]: class `ERCalendarEventEditorWindowController' not linked into application
Traceback (most recent call last):
  File "dc7.py",line 131,in <module>
    chains,kernel_results = run_chain(initial_state)
  File "/Users/ram/opt/anaconda3/envs/tf/lib/python3.7/site-packages/tensorflow_core/python/eager/def_function.py",line 457,in __call__
    result = self._call(*args,**kwds)
  File "/Users/ram/opt/anaconda3/envs/tf/lib/python3.7/site-packages/tensorflow_core/python/eager/def_function.py",line 503,in _call
    self._initialize(args,kwds,add_initializers_to=initializer_map)
  File "/Users/ram/opt/anaconda3/envs/tf/lib/python3.7/site-packages/tensorflow_core/python/eager/def_function.py",line 408,in _initialize
    *args,**kwds))
  File "/Users/ram/opt/anaconda3/envs/tf/lib/python3.7/site-packages/tensorflow_core/python/eager/function.py",line 1848,in _get_concrete_function_internal_garbage_collected
    graph_function,_,_ = self._maybe_define_function(args,kwargs)
  File "/Users/ram/opt/anaconda3/envs/tf/lib/python3.7/site-packages/tensorflow_core/python/eager/function.py",line 2150,in _maybe_define_function
    graph_function = self._create_graph_function(args,kwargs)
  File "/Users/ram/opt/anaconda3/envs/tf/lib/python3.7/site-packages/tensorflow_core/python/eager/function.py",line 2041,in _create_graph_function
    capture_by_value=self._capture_by_value),
  File "/Users/ram/opt/anaconda3/envs/tf/lib/python3.7/site-packages/tensorflow_core/python/framework/func_graph.py",line 915,in func_graph_from_py_func
    func_outputs = python_func(*func_args,**func_kwargs)
  File "/Users/ram/opt/anaconda3/envs/tf/lib/python3.7/site-packages/tensorflow_core/python/eager/def_function.py",line 358,in wrapped_fn
    return weak_wrapped_fn().__wrapped__(*args,**kwds)
  File "/Users/ram/opt/anaconda3/envs/tf/lib/python3.7/site-packages/tensorflow_core/python/framework/func_graph.py",line 905,in wrapper
    raise e.ag_error_metadata.to_exception(e)
TypeError: in converted code:

    dc7.py:117 run_chain  *
        return tfp.mcmc.sample_chain(

    **TypeError: sample_chain() got an unexpected keyword argument 'seed'**

版本

MacOS 10.13.6 High Sierra

tensorflow                2.0.0           mkl_py37hda344b4_0  
tensorflow-base           2.0.0           mkl_py37h66b1bf0_0  
tensorflow-estimator      2.0.0              pyh2649769_0  
tensorflow-probability    0.8.0                      py_0    conda-forge
jupyter_client            6.1.7                      py_0  
jupyter_core              4.7.0            py37hecd8cb5_0  
jupyterlab_pygments       0.1.2                      py_0  
ipython                   7.19.0           py37h01d92e1_0  
ipython_genutils          0.2.0              pyhd3eb1b0_1  
python                    3.7.9                h26836e1_0  
python-dateutil           2.8.1                      py_0  
python_abi                3.7                     1_cp37m    conda-forge

代码

# Set before the numeric libraries below are imported.
# NOTE(review): presumably works around the duplicate OpenMP runtime abort
# seen with MKL builds on macOS -- confirm it is still required.
import os
os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'

from pprint import pprint
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns

#import tensorflow as tf
#print(tf.__version__)

# Import the TF2 API surface explicitly and switch on v2 behavior.
import tensorflow.compat.v2 as tf
tf.enable_v2_behavior()

import tensorflow_probability as tfp

# Plot styling: seaborn defaults, 'talk' context at 0.7 font scale, and the
# viridis colormap for images.
sns.reset_defaults()
sns.set_context(context = 'talk',font_scale = 0.7)
plt.rcParams['image.cmap'] = 'viridis'

#%matplotlib inline

# Short aliases for the TFP sub-modules used by the rest of the script.
tfd = tfp.distributions
tfb = tfp.bijectors


#### ============================================

#@title Utils { display-mode: "form" }
def print_subclasses_from_module(module, base_class, maxwidth=80):
  """Print the names of the classes in `module` that subclass `base_class`.

  Names are visited in the order returned by `inspect.getmembers` and packed
  onto comma-joined lines. A name starts a new output line when the current
  line's length plus the name's length plus 2 would exceed `maxwidth`.

  Args:
    module: Object whose members are inspected.
    base_class: Class (or tuple of classes) passed to `issubclass`.
    maxwidth: Width budget used when deciding where to break lines.
  """
  import inspect

  lines = []
  for name, member in inspect.getmembers(module):
    if not (inspect.isclass(member) and issubclass(member, base_class)):
      continue
    if lines and len(lines[-1]) + len(name) + 2 <= maxwidth:
      # Still fits: extend the current line.
      lines[-1] = lines[-1] + "," + name
    else:
      # First name, or the current line is full.
      lines.append(name)
  print('\n'.join(lines))

# Generate some data
def f(x, w):
  """Evaluate the linear model (bias plus weighted features) on each column of `x`.

  A row of ones is prepended to `x` so that the first entry of `w` acts as the
  bias term and the whole model becomes a single matrix product.

  Args:
    x: Rank-2 array of features, one example per column (the callers in this
      script pass shape `[num_features, N]`).
    w: Weights whose last axis has length `num_features + 1` (bias first).
      Leading axes, if any, are treated as batch dimensions.

  Returns:
    Tensor of model outputs with the last axis indexing examples and any
    leading batch axes of `w` preserved.
  """
  # Pad x with 1's so we can add bias via matmul
  x = tf.pad(x, [[1, 0], [0, 0]], constant_values=1)
  # View w as a batch of column vectors so that adjoint-matmul computes w^T x.
  linop = tf.linalg.LinearOperatorFullMatrix(w[..., np.newaxis])
  result = linop.matmul(x, adjoint=True)
  # The product has a singleton row axis (one output row per weight vector);
  # index it away so the result broadcasts against per-example targets.
  return result[..., 0, :]

# Problem size and ground-truth parameters for the synthetic regression data.
num_features = 2
num_examples = 50
noise_scale = 0.5
true_w = np.array([-1.0, 2.0, 3.0])  # bias followed by one weight per feature

# Features are uniform on [-1, 1); targets are the noiseless model output plus
# Gaussian noise with standard deviation `noise_scale`.
xs = np.random.uniform(
    low=-1.0, high=1.0, size=[num_features, num_examples])
ys = f(xs, true_w) + np.random.normal(
    loc=0.0, scale=noise_scale, size=num_examples)

# Visualize the data set: scatter the noisy samples colored by target value.
plt.scatter(*xs, c=ys, s=100, linewidths=0)

# Evaluate the noiseless model on a 100x100 grid over [-1, 1]^2 for contours.
grid = np.meshgrid(*([np.linspace(-1, 1, 100)] * 2))
xs_grid = np.stack(grid, axis=0)
fs_grid = f(xs_grid.reshape([num_features, -1]), true_w)
fs_grid = np.reshape(fs_grid, [100, 100])
plt.colorbar()
# Contour arguments are X, Y, Z, number of levels.
plt.contour(xs_grid[0, ...], xs_grid[1, ...], fs_grid, 20, linewidths=1)
plt.show()

### Sampling the noise scale

# Define the joint_log_prob function,and our unnormalized posterior.
def joint_log_prob(w, sigma, x, y):
  """Return the joint log-density of weights, noise scale and observations.

  The model is
    w     ~ MultivariateNormalDiag(loc=0, scale_diag=1), dim num_features + 1
    sigma ~ LogNormal(1, 5)
    y_i   ~ Normal(f(x_i, w), sigma),  i = 1..N

  Args:
    w: Weights; last axis has length `num_features + 1`.
    sigma: Observation noise scale; a trailing axis is added internally so it
      broadcasts across examples.
    x: Features, passed through to `f`.
    y: Observed targets, one per example along the last axis.

  Returns:
    Sum of the log-prior of `w`, the log-prior of `sigma`, and the
    log-likelihood of `y` summed over the last (example) axis.
  """
  rv_w = tfd.MultivariateNormalDiag(
      loc=np.zeros(num_features + 1),
      scale_diag=np.ones(num_features + 1))

  rv_sigma = tfd.LogNormal(np.float64(1.), np.float64(5.))

  # sigma gets a trailing axis so one scale applies to every example.
  rv_y = tfd.Normal(f(x, w), sigma[..., np.newaxis])
  return (rv_w.log_prob(w) +
          rv_sigma.log_prob(sigma) +
          tf.reduce_sum(rv_y.log_prob(y), axis=-1))

# Create our unnormalized target density by currying x and y from the joint.
def unnormalized_posterior(w, sigma):
  """Return the joint log-density at `(w, sigma)` with the data `xs`, `ys` fixed.

  Both sampled quantities must be forwarded: `joint_log_prob` takes
  `(w, sigma, x, y)`.
  """
  return joint_log_prob(w, sigma, xs, ys)


# Hamiltonian Monte Carlo kernel targeting the unnormalized posterior, with an
# initial step size of 0.1 and 4 leapfrog steps per proposal.
hmc_kernel = tfp.mcmc.HamiltonianMonteCarlo(
    target_log_prob_fn=unnormalized_posterior,
    step_size=np.float64(.1),
    num_leapfrog_steps=4,
)



# Create a TransformedTransitionKernel: one bijector per state part, in the
# same order as the arguments of the target log-prob function.
transformed_kernel = tfp.mcmc.TransformedTransitionKernel(
    inner_kernel=hmc_kernel,
    bijector=[tfb.Identity(),                # w
              tfb.Invert(tfb.Softplus())])   # sigma


# Apply a simple step size adaptation during burnin
@tf.function
def run_chain(initial_state, num_results=1000, num_burnin_steps=500):
  """Sample from the transformed HMC kernel with step-size adaptation.

  The step size is adapted for the first 80% of the burn-in steps, targeting
  an acceptance probability of 0.75.

  NOTE: `sample_chain` is called with `seed=(0, 1)`. The
  tensorflow-probability 0.8.0 release rejects that keyword
  ("unexpected keyword argument 'seed'"); it is accepted with
  tensorflow 2.3.0 / tensorflow-probability 0.11.0.

  Args:
    initial_state: Starting state passed as `current_state`.
    num_results: Number of samples to keep after burn-in.
    num_burnin_steps: Number of initial steps to discard.

  Returns:
    Whatever `tfp.mcmc.sample_chain` returns when traced with the kernel
    results (the caller unpacks it as `chains, kernel_results`).
  """
  adaptation_steps = int(.8 * num_burnin_steps)
  step_adapting_kernel = tfp.mcmc.SimpleStepSizeAdaptation(
      transformed_kernel,
      num_adaptation_steps=adaptation_steps,
      target_accept_prob=np.float64(.75))

  def _trace_kernel_results(_, results):
    # Record the full kernel results at every step.
    return results

  return tfp.mcmc.sample_chain(
      num_results=num_results,
      num_burnin_steps=num_burnin_steps,
      current_state=initial_state,
      kernel=step_adapting_kernel,
      seed=(0, 1),
      trace_fn=_trace_kernel_results)


# Instead of a single set of initial w's,we create a batch of 8.
num_chains = 8
# One starting point per chain: zeros for the weights, 0.54 for sigma.
initial_state = [
    np.zeros([num_chains, num_features + 1]),
    .54 * np.ones([num_chains], dtype=np.float64),
]

chains, kernel_results = run_chain(initial_state)

r_hat = tfp.mcmc.potential_scale_reduction(chains)

# Unwrap the step-size adaptation and transformed-kernel layers to reach the
# results of the underlying HMC kernel.
hmc_results = kernel_results.inner_results.inner_results
print("Acceptance rate:", hmc_results.is_accepted.numpy().mean())
print("R-hat diagnostic (per w variable):", r_hat[0].numpy())
print("R-hat diagnostic (sigma):", r_hat[1].numpy())

w_chains, sigma_chains = chains

解决方法

我使用了不兼容的 tensorflow 和 tensorflow_probability 版本。在以下版本中,上述 TypeError 消失了:
ipython 7.19.0 pypi_0 pypi
ipython-genutils 0.2.0 pypi_0 pypi
python 3.7.9 h26836e1_0
python-dateutil 2.8.1 pypi_0 pypi
tensorboard 2.4.0 pypi_0 pypi
tensorboard-plugin-wit 1.7.0 pypi_0 pypi
tensorflow 2.3.0 pypi_0 pypi
tensorflow-estimator 2.3.0 pypi_0 pypi
tensorflow-probability 0.11.0 pypi_0 pypi

版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。