
Getting different sets of configs across multiple calls in Ray Tune


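I'm trying to make my code reproducible. I've added np.random.seed(...) and random.seed(...), and I'm not using PyTorch or TensorFlow at the moment, so there is no scheduler or searcher that could introduce any randomness. The set of configs generated by the code below should therefore be identical across multiple calls. However, it isn't.

Can anyone help?

Thanks!

Here is the code: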

import ray
from ray import tune
import random
import numpy as np

def training_function(config, data_init):
    print('CONFIG:', config)
    tune.report(end_of_training=1, acc=0, f=0)

if __name__ == '__main__':
    ray.init(num_cpus=12)
    tune_config = {
        'sentence_classification': False,
        'norm_word_emb': tune.choice(['True', 'False']),
        'use_crf': tune.choice(['True', 'False']),
        'use_char': tune.choice(['True', 'False']),
        'word_seq_feature': tune.choice(['CNN', 'LSTM', 'GRU']),
        'char_seq_feature': tune.choice(['CNN', 'LSTM', 'GRU']),
        'seed_num': 1267,
    }
    data = {'a': 1}
    # Seed the global RNGs in the driver process
    tune_seed = tune_config['seed_num']
    random.seed(tune_seed)
    np.random.seed(tune_seed)
    n_samples = 15
    exp_name = 'experiment_name'
    analysis = tune.run(
        tune.with_parameters(training_function, data_init={'data': data}),
        name=exp_name,
        metric="f",
        mode="max",
        queue_trials=True,
        config=tune_config,
        num_samples=n_samples,
        resources_per_trial={"cpu": 1},
        checkpoint_at_end=True,
        max_failures=0,
    )

Solution

Reproducibility cannot be achieved with the function-based API (as of Ray v1.1.0; this may change).

Wait, but why?

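  1. tune.run creates an Experiment object and passes your function to it.
  2. The Experiment registers the trainable by calling register_trainable.
  3. register_trainable wraps your function using wrap_function.
  4. wrap_function creates a class-based API (a Ray Actor) by subclassing the FunctionRunner class.
  5. FunctionRunner's setup method has no access to any callbacks.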
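The chain above can be pictured with a deliberately simplified, in-process mock. `MiniFunctionRunner` and `mini_wrap_function` are hypothetical names that only imitate the shape of the wrapping (a plain function turned into a class whose `setup` receives nothing but `config`); they are not Ray's actual implementation, which runs the trainable inside a separate actor process.

```python
# Hypothetical, heavily simplified illustration of the wrapping chain.
# None of these names are Ray internals; they only mimic the shape:
# a plain function becomes a class whose setup() receives just `config`.

class MiniFunctionRunner:
    """Stand-in for a class-based trainable created from a function."""

    def __init__(self, config):
        # In Ray the instance would live in a separate actor process;
        # here everything runs in the current process.
        self.setup(config)

    def setup(self, config):
        # Only the config is available here: there is no hook for extra
        # user initialization logic such as seeding.
        self.config = config

    def step(self):
        raise NotImplementedError


def mini_wrap_function(train_fn):
    """Hypothetical analogue of wrapping a function into a runner class."""

    class WrappedRunner(MiniFunctionRunner):
        def step(self):
            # The user function only ever sees the config it was given.
            return train_fn(self.config)

    return WrappedRunner


def training_function(config):
    return {"acc": 0, "f": 0, "seen_seed": config["seed_num"]}


TrainableCls = mini_wrap_function(training_function)
runner = TrainableCls({"seed_num": 1267})
print(runner.step())  # {'acc': 0, 'f': 0, 'seen_seed': 1267}
```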

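Actors work in a simple way: they are distributed among workers and then initialized in separate processes via the setup method. That is why it is crucial to pass the seed and implement the initialization logic in a custom Trainable, as described in this answer. Seeding is needed because tune.choice is just a wrapper around random/np.random functions; you can see this in tune/sample.py.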
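The seeding argument can be illustrated without Ray. The sketch below is a minimal, dependency-free mock: `Choice`, `SEARCH_SPACE` and `sample_configs` are hypothetical names, not Tune's API, and it draws from the standard-library `random` module for simplicity (Tune's own samplers may use `numpy.random` instead). It only shows that when sampling is a thin wrapper over a global RNG, seeding that RNG in the process that does the sampling makes the sequence of configs repeatable.

```python
import random


# Hypothetical stand-in for a categorical entry like tune.choice([...]).
# Assumption: sampling is a thin wrapper over a global RNG; here we use
# the stdlib `random` module rather than numpy.random to stay dependency-free.
class Choice:
    def __init__(self, categories):
        self.categories = list(categories)

    def sample(self):
        return random.choice(self.categories)


SEARCH_SPACE = {
    "sentence_classification": False,
    "norm_word_emb": Choice(["True", "False"]),
    "use_crf": Choice(["True", "False"]),
    "use_char": Choice(["True", "False"]),
    "word_seq_feature": Choice(["CNN", "LSTM", "GRU"]),
    "char_seq_feature": Choice(["CNN", "LSTM", "GRU"]),
    "seed_num": 1267,
}


def sample_configs(space, n, seed):
    """Draw n configs after seeding the global RNG in *this* process."""
    random.seed(seed)
    # Dicts keep insertion order, so keys are always sampled in the same order.
    return [
        {k: (v.sample() if isinstance(v, Choice) else v) for k, v in space.items()}
        for _ in range(n)
    ]


run_a = sample_configs(SEARCH_SPACE, 15, seed=1267)
run_b = sample_configs(SEARCH_SPACE, 15, seed=1267)
print(run_a == run_b)  # True: same seed, same process, same sequence
```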

For example:


import ray
from ray import tune
import random
import numpy as np

class Tunable(tune.Trainable):
    def setup(self, config):
        # setup() runs inside the trial's own actor process,
        # so the RNGs are seeded there
        self.config = config
        self.seed = config['seed_num']
        random.seed(self.seed)
        np.random.seed(self.seed)

    def step(self):
        print('CONFIG:', self.config)
        # Report once and mark the trial as finished
        return {tune.result.DONE: True, 'acc': 0, 'f': 0}

if __name__ == '__main__':
    ray.init(num_cpus=12)
    tune_config = {
        'sentence_classification': False,
        'norm_word_emb': tune.choice(['True', 'False']),
        'use_crf': tune.choice(['True', 'False']),
        'use_char': tune.choice(['True', 'False']),
        'word_seq_feature': tune.choice(['CNN', 'LSTM', 'GRU']),
        'char_seq_feature': tune.choice(['CNN', 'LSTM', 'GRU']),
        'seed_num': 1267,
    }
    data = {'a': 1}
    tune_seed = tune_config['seed_num']
    n_samples = 15
    exp_name = 'experiment_name'
    analysis = tune.run(
        Tunable,
        name=exp_name,
        metric="f",
        mode="max",
        queue_trials=True,
        config=tune_config,
        num_samples=n_samples,
        resources_per_trial={"cpu": 1},
        checkpoint_at_end=False,
        max_failures=0,
    )
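As a variation on the same idea, the seed can travel inside the config and be applied in setup() through a private `random.Random` instance instead of reseeding the global RNG. This swaps in a per-instance RNG, which is a design choice rather than what the answer's code does. The sketch uses a hypothetical `MockTrainable` class (not `ray.tune.Trainable`) so it runs without Ray; it only shows that two trials built from the same config produce the same draws.

```python
import random


# Hypothetical mock of the Trainable pattern (NOT ray.tune.Trainable):
# the seed travels inside the config and is applied in setup(), i.e. in
# whatever process ends up running the trial.
class MockTrainable:
    def __init__(self, config):
        self.setup(config)

    def setup(self, config):
        self.config = config
        # Private RNG seeded from the config; unlike random.seed(), this
        # leaves the global RNG shared with other code untouched.
        self.rng = random.Random(config["seed_num"])

    def step(self):
        # Any randomness used during "training" comes from self.rng.
        return {"acc": self.rng.random(), "f": 0, "done": True}


first = MockTrainable({"seed_num": 1267}).step()
second = MockTrainable({"seed_num": 1267}).step()
print(first == second)  # True: same seed in config, same per-trial draws
```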

I see the seeding working as expected. Here is the script I ran:

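import ray
from ray import tune
import numpy as np
import random


def training_function(config, data_init):
    print('CONFIG:', config)
    tune.report(end_of_training=1, acc=0, f=0)

if __name__ == '__main__':
    # ray.init(num_cpus=12)
    tune_config = {
        'sentence_classification': False,
        'norm_word_emb': tune.choice(['True', 'False']),
        'use_crf': tune.choice(['True', 'False']),
        'use_char': tune.choice(['True', 'False']),
        'word_seq_feature': tune.choice(['CNN', 'LSTM', 'GRU']),
        'char_seq_feature': tune.choice(['CNN', 'LSTM', 'GRU']),
        'seed': 1267,
    }
    data = {'a': 1}
    # Seed the global RNGs in the driver process
    tune_seed = tune_config['seed']
    random.seed(tune_seed)
    np.random.seed(tune_seed)
    n_samples = 15
    analysis = tune.run(
        tune.with_parameters(training_function, data_init={'data': data}),
        # name=exp_name,
        metric="f",
        mode="max",
        config=tune_config,
        num_samples=n_samples,
        verbose=2,
    )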

And the output of subsequent runs:

Resources requested: 0/16 CPUs,0/0 GPUs,0.0/27.0 GiB heap,0.0/9.28 GiB objects
Current best trial: 84b84_00014 with f=0 and parameters={'sentence_classification': False,'norm_word_emb': 'False','use_crf': 'True','use_char': 'False','word_seq_feature': 'LSTM','char_seq_feature': 'GRU','seed': 1267}
Number of trials: 15/15 (15 TERMINATED)
+--------------------+------------+-------+--------------------+-----------------+------------+-----------+--------------------+--------+------------------+-------------------+-------+-----+
| Trial name         | status     | loc   | char_seq_feature   | norm_word_emb   | use_char   | use_crf   | word_seq_feature   |   iter |   total time (s) |   end_of_training |   acc |   f |
|--------------------+------------+-------+--------------------+-----------------+------------+-----------+--------------------+--------+------------------+-------------------+-------+-----|
| _inner_84b84_00000 | TERMINATED |       | LSTM               | True            | False      | False     | LSTM               |      1 |       0.00149202 |                 1 |     0 |   0 |
| _inner_84b84_00001 | TERMINATED |       | CNN                | False           | True       | False     | CNN                |      1 |       0.0014801  |                 1 |     0 |   0 |
| _inner_84b84_00002 | TERMINATED |       | GRU                | False           | False      | True      | GRU                |      1 |       0.00152397 |                 1 |     0 |   0 |
| _inner_84b84_00003 | TERMINATED |       | GRU                | False           | False      | False     | GRU                |      1 |       0.00165081 |                 1 |     0 |   0 |
| _inner_84b84_00004 | TERMINATED |       | CNN                | False           | False      | False     | CNN                |      1 |       0.00173998 |                 1 |     0 |   0 |
| _inner_84b84_00005 | TERMINATED |       | LSTM               | True            | True       | True      | CNN                |      1 |       0.00219083 |                 1 |     0 |   0 |
| _inner_84b84_00006 | TERMINATED |       | GRU                | True            | False      | False     | LSTM               |      1 |       0.00192428 |                 1 |     0 |   0 |
| _inner_84b84_00007 | TERMINATED |       | LSTM               | True            | False      | False     | CNN                |      1 |       0.00208902 |                 1 |     0 |   0 |
| _inner_84b84_00008 | TERMINATED |       | LSTM               | True            | True       | True      | GRU                |      1 |       0.00146484 |                 1 |     0 |   0 |
| _inner_84b84_00009 | TERMINATED |       | CNN                | False           | False      | True      | CNN                |      1 |       0.00152087 |                 1 |     0 |   0 |
| _inner_84b84_00010 | TERMINATED |       | LSTM               | False           | True       | False     | CNN                |      1 |       0.00124121 |                 1 |     0 |   0 |
| _inner_84b84_00011 | TERMINATED |       | LSTM               | True            | True       | True      | CNN                |      1 |       0.00124812 |                 1 |     0 |   0 |
| _inner_84b84_00012 | TERMINATED |       | LSTM               | True            | True       | True      | LSTM               |      1 |       0.00133514 |                 1 |     0 |   0 |
| _inner_84b84_00013 | TERMINATED |       | LSTM               | True            | False      | True      | CNN                |      1 |       0.00142407 |                 1 |     0 |   0 |
| _inner_84b84_00014 | TERMINATED |       | GRU                | False           | False      | True      | LSTM               |      1 |       0.00120211 |                 1 |     0 |   0 |
+--------------------+------------+-------+--------------------+-----------------+------------+-----------+--------------------+--------+------------------+-------------------+-------+-----+

Note that the trials and their configs are exactly the same, in the same order.
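To check this mechanically instead of eyeballing the table, one option is to hash the ordered list of trial configs from each run and compare the digests. `config_fingerprint` is a hypothetical helper; with Ray 1.x you could presumably collect the configs with something like `[t.config for t in analysis.trials]`, assuming the returned analysis object exposes its trials that way.

```python
import hashlib
import json


def config_fingerprint(configs):
    """Hypothetical helper: hash an ordered list of trial configs.

    Keys are sorted inside each dict, but the list order is preserved, so two
    runs only match if they produced the same configs in the same order.
    Assumption: configs come from e.g. [t.config for t in analysis.trials].
    """
    # default=str keeps non-JSON values (if any) from raising.
    payload = json.dumps(configs, sort_keys=True, default=str)
    return hashlib.sha256(payload.encode("utf-8")).hexdigest()


run_1 = [{"use_crf": "True", "seed": 1267}, {"use_crf": "False", "seed": 1267}]
run_2 = [{"seed": 1267, "use_crf": "True"}, {"seed": 1267, "use_crf": "False"}]
run_3 = list(reversed(run_1))

print(config_fingerprint(run_1) == config_fingerprint(run_2))  # True
print(config_fingerprint(run_1) == config_fingerprint(run_3))  # False
```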
