使用具有相同超参数的 gpytorch 重现高斯过程的预测协方差的问题

如何解决使用具有相同超参数的 gpytorch 重现高斯过程的预测协方差的问题

我需要构建一个函数来给出高斯过程的后验协方差。这个想法是使用 GPytorch 训练一个 GP,然后获取学习的超参数,并将它们传递到我的内核函数中。 (出于多种原因,我不能直接使用 GPyTorch)。

现在的问题是我无法重现预测。这里是我写的代码。我已经研究了一整天,但我找不到问题所在。你知道我做错了什么吗?

        from gpytorch.mlls import ExactMarginalLogLikelihood 
        import numpy as np
        import gpytorch
        import torch
        train_x1 = torch.linspace(0,0.95,50) + 0.05 * torch.rand(50)
        train_y1 = torch.sin(train_x1 * (2 * np.pi)) + 0.2 * torch.randn_like(train_x1)

        n_datapoints = train_x1.shape[0]

        def kernel_rbf(x1,x2,c,l):
            # my RBF function
            if x1.shape is ():
                x1 = np.atleast_2d(x1)
            if x2.shape is ():
                x2 = np.atleast_2d(x2)
            return c * np.exp(- np.matmul((x1 - x2).T,(x1 - x2)) / (2 * l ** 2))

        class ExactGPModel(gpytorch.models.ExactGP):
            def __init__(self,train_x,train_y,likelihood):
                super().__init__(train_x,likelihood)

                lengthscale_prior = gpytorch.priors.GammaPrior(3.0,6.0)
                outputscale_prior = gpytorch.priors.GammaPrior(2.0,0.15)

                self.mean_module = gpytorch.means.ConstantMean()
                self.covar_module = gpytorch.kernels.ScaleKernel(
                    gpytorch.kernels.RBFKernel(lengthscale_prior=lengthscale_prior),outputscale_prior=outputscale_prior)

            def forward(self,x):
                mean_x = self.mean_module(x)
                covar_x = self.covar_module(x)
                return gpytorch.distributions.Multivariatenormal(mean_x,covar_x)

        likelihood = gpytorch.likelihoods.GaussianLikelihood()
        model = ExactGPModel(train_x1,train_y1,likelihood)

        # Find optimal model hyperparameters
        model.train()
        likelihood.train()

        mll = ExactMarginalLogLikelihood(likelihood,model)

        # Use the Adam optimizer
        optimizer = torch.optim.Adam(model.parameters(),lr=0.1)  # Includes GaussianLikelihood parameters
        training_iterations = 50
        for i in range(training_iterations):
            optimizer.zero_grad()
            output = model(*model.train_inputs)
            loss = -mll(output,model.train_targets)
            loss.backward()
            print('Iter %d/%d - Loss: %.3f' % (i + 1,training_iterations,loss.item()))
            optimizer.step()

        # Get the learned hyperparameters 
        outputscale = model.covar_module.outputscale.item()
        lengthscale = model.covar_module.base_kernel.lengthscale.item()
        noise = likelihood.noise_covar.noise.item()

        train_x1 = train_x1.numpy()
        train_y1 = train_y1.numpy()

        # Get covariance train points
        K = np.zeros((n_datapoints,n_datapoints))
        for i in range(n_datapoints):
            for j in range(n_datapoints):
                K[i,j] = kernel_rbf(train_x1[i],train_x1[j],outputscale,lengthscale)

        # Add noise
        K += noise ** 2 * np.eye(n_datapoints)

        # Get covariance train-test points
        x_test = torch.rand(1,1)
        Ks = np.zeros((n_datapoints,1))
        for i in range(n_datapoints):
            Ks[i] = kernel_rbf(train_x1[i],x_test.numpy(),lengthscale)

        # Get variance test points
        Kss = kernel_rbf(x_test.numpy(),lengthscale)

        L = np.linalg.cholesky(K)
        v = np.linalg.solve(L,Ks)
        var = Kss - np.matmul(v.T,v)

        model.eval()
        likelihood.eval()
        with gpytorch.settings.fast_pred_var():
            y_preds = likelihood(model(x_test))

        print(f"Predicted variance with gpytorch:{y_preds.variance.item()}")
        print(f"Predicted variance with my kernel:{var}")

解决方法

我发现了错误:

  1. 噪声不是平方,所以它是 K += noise * np.eye(n_datapoints) 而不是 K += noise**2 * np.eye(n_datapoints)
  2. 我忘记在 $$ K** $$ 中添加噪声项,即 Kss += noise

版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。

相关推荐


Selenium Web驱动程序和Java。元素在(x,y)点处不可单击。其他元素将获得点击?
Python-如何使用点“。” 访问字典成员?
Java 字符串是不可变的。到底是什么意思?
Java中的“ final”关键字如何工作?(我仍然可以修改对象。)
“loop:”在Java代码中。这是什么,为什么要编译?
java.lang.ClassNotFoundException:sun.jdbc.odbc.JdbcOdbcDriver发生异常。为什么?
这是用Java进行XML解析的最佳库。
Java的PriorityQueue的内置迭代器不会以任何特定顺序遍历数据结构。为什么?
如何在Java中聆听按键时移动图像。
Java“Program to an interface”。这是什么意思?
Java在半透明框架/面板/组件上重新绘画。
Java“ Class.forName()”和“ Class.forName()。newInstance()”之间有什么区别?
在此环境中不提供编译器。也许是在JRE而不是JDK上运行?
Java用相同的方法在一个类中实现两个接口。哪种接口方法被覆盖?
Java 什么是Runtime.getRuntime()。totalMemory()和freeMemory()?
java.library.path中的java.lang.UnsatisfiedLinkError否*****。dll
JavaFX“位置是必需的。” 即使在同一包装中
Java 导入两个具有相同名称的类。怎么处理?
Java 是否应该在HttpServletResponse.getOutputStream()/。getWriter()上调用.close()?
Java RegEx元字符(。)和普通点?