当前位置 : 主页 > 编程语言 > python >

根据baselines库修改的运行输入参数的解析代码

来源:互联网 收集:自由互联 发布时间:2022-06-15
如题: def arg_parser(): """ Create an empty argparse.ArgumentParser. """ import argparse parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) parser.add_argument('--env', help='environment ID', type=str, default='Reacher-v2') ……（完整代码见下文）


如题:

def arg_parser(args=None):
    """
    Declare the run-time options and parse the known ones.

    Despite the name, this does not return a bare parser: it registers the
    options below on an ``argparse.ArgumentParser`` and immediately calls
    ``parse_known_args`` on it, so options it does not recognise are handed
    back instead of causing an error.

    :param args: list of argument strings to parse. ``None`` (the default)
        makes argparse read ``sys.argv[1:]``, which is the original behavior.
    :return: a ``(namespace, unknown_args)`` tuple; ``unknown_args`` is the
        list of strings that no declared option consumed.
    """
    import argparse
    # ArgumentDefaultsHelpFormatter appends each option's default to --help.
    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)

    parser.add_argument('--env', help='environment ID', type=str, default='Reacher-v2')
    parser.add_argument('--env_type', help='type of environment, used when the environment type cannot be automatically determined', type=str)
    parser.add_argument('--seed', help='RNG seed', type=int, default=None)
    parser.add_argument('--alg', help='Algorithm', type=str, default='ppo2')
    parser.add_argument('--num_timesteps', type=float, default=1e6)
    parser.add_argument('--network', help='network type (mlp, cnn, lstm, cnn_lstm, conv_only)', default=None)
    parser.add_argument('--gamestate', help='game state to load (so far only used in retro games)', default=None)
    parser.add_argument('--num_env', help='Number of environment copies being run in parallel. When not specified, set to number of cpus for Atari, and to 1 for Mujoco', default=None, type=int)
    parser.add_argument('--reward_scale', help='Reward scale factor. Default: 1.0', default=1.0, type=float)
    parser.add_argument('--save_path', help='Path to save trained model to', default=None, type=str)
    parser.add_argument('--save_video_interval', help='Save video every x steps (0 = disabled)', default=0, type=int)
    parser.add_argument('--save_video_length', help='Length of recorded video. Default: 200', default=200, type=int)
    parser.add_argument('--log_path', help='Directory to save learning curve data.', default=None, type=str)
    parser.add_argument('--play', default=False, action='store_true')

    # parse_known_args (not parse_args) so extra --key=value pairs survive
    # and can be merged later instead of aborting the program.
    return parser.parse_known_args(args)

def parse_unknown_args(args):
    """
    Collect arguments not consumed by the arg parser into a dictionary.

    Two spellings are recognised:

    * ``--key=value``: split on the first ``=`` only, so the value may itself
      contain ``=`` (``--expr=a=b`` gives ``{'expr': 'a=b'}``);
    * ``--key value``: the next token not starting with ``--`` is the value.

    A bare ``--key`` with no value token after it (last in the list, or
    directly followed by another ``--`` option) is ignored. Stray tokens that
    do not follow a pending bare key are also ignored. Values are returned
    unconverted, as strings.

    :param args: iterable of leftover argument strings.
    :return: dict mapping option names (without the leading ``--``) to their
        string values; a later occurrence of a key overwrites an earlier one.
    """
    retval = {}
    pending_key = None  # key of a bare ``--key`` still waiting for its value
    for arg in args:
        if arg.startswith('--'):
            key, sep, value = arg[2:].partition('=')
            if sep:
                retval[key] = value
                # An inline value completes this option; a following
                # positional token must not be attached to it.
                pending_key = None
            else:
                pending_key = key
        elif pending_key is not None:
            retval[pending_key] = arg
            pending_key = None

    return retval

def parse_cmdline_kwargs(args, unknown_args):
    '''
    Merge leftover command-line arguments into ``args``, evaluating each value
    as a Python expression when possible.

    :param args: namespace produced by ``arg_parser``; it is updated in place
        and also returned. An existing attribute with the same name as an
        extra key is overwritten.
    :param unknown_args: list of leftover strings in ``--key=value`` or
        ``--key value`` form, split into pairs by ``parse_unknown_args``.
    :return: the same ``args`` object, with one attribute per extra key whose
        value is ``eval(text)`` if that succeeds, otherwise the raw text
        (e.g. ``--cde=1+99`` gives ``100``, ``--aaa=me`` gives ``'me'``).
    '''
    def parse(v):
        """Return ``eval(v)`` if it succeeds, else the original string ``v``."""
        # Explicit check instead of ``assert`` so it survives ``python -O``.
        if not isinstance(v, str):
            raise TypeError('expected str, got %s' % type(v).__name__)
        try:
            # SECURITY: eval executes arbitrary Python taken from the command
            # line. This is tolerable only because whoever launches the script
            # already controls the process; never route untrusted input here.
            return eval(v)
        except (NameError, SyntaxError, TypeError, ValueError,
                ArithmeticError, LookupError, AttributeError):
            # Not a usable expression (``me``, ``Pong-v0``, ``1/0``, ...):
            # keep the value as the literal text the user typed.
            return v

    vars(args).update({k: parse(v) for k, v in parse_unknown_args(unknown_args).items()})
    return args


if __name__ == '__main__':
    # Demo entry point, guarded so importing this module does not parse
    # sys.argv or print anything. Shows the declared options, then the
    # namespace after the extra --key=value / --key value pairs are merged.
    args, unknown_args = arg_parser()

    print(args)
    args = parse_cmdline_kwargs(args, unknown_args)
    print(args)


运行:

python test.py --aaa=me --xxx=11.11  --abc=True   --cde=1+99

解析结果:

Namespace(alg='ppo2', env='Reacher-v2', env_type=None, gamestate=None, log_path=None, network=None, num_env=None, num_timesteps=1000000.0, play=False, reward_scale=1.0, save_path=None, save_video_interval=0, save_video_length=200, seed=None)

Namespace(aaa='me', abc=True, alg='ppo2', cde=100, env='Reacher-v2', env_type=None, gamestate=None, log_path=None, network=None, num_env=None, num_timesteps=1000000.0, play=False, reward_scale=1.0, save_path=None, save_video_interval=0, save_video_length=200, seed=None, xxx=11.11)

=======================================


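这是一种比较规范的运行参数解析写法：已声明的参数由 argparse 解析；未声明的 --key=value（或 --key value）参数会被收集起来，并尽量按 Python 表达式求值（如 1+99 得到 100，True 得到布尔值，无法求值的 me 则保留为字符串）。最后所有参数统一合并到 args 中，方便后续代码通过 args.xxx 调用。注意：求值使用了 eval，只应在命令行输入可信时使用。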


网友评论