import argparse


def get_param(argv=None):
    """Parse the command-line arguments for GAN training.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default), ``sys.argv[1:]`` is used, exactly as
            ``argparse.ArgumentParser.parse_args`` does.

    Returns:
        argparse.Namespace: the parsed arguments. If
        ``discriminator_learning_rate`` was not supplied, it is set to
        ``generator_learning_rate``, as the option's help text promises.

    Raises:
        SystemExit: raised by argparse when a required option is missing or
            a value cannot be converted to its declared type.
    """
    parser = argparse.ArgumentParser(description="GAN Training Arguments")
    parser.add_argument("--epochs", "-e", type=int, required=True,
                        help="Total number of epochs")
    parser.add_argument("--batch_size", "-b", type=int, required=True,
                        help="Size of single batch")
    parser.add_argument("--save_interval", "-s", type=int, required=True,
                        help="Interval for saving weights")
    parser.add_argument("--sample_interval", type=int, required=True,
                        help="Interval for saving inference result")
    parser.add_argument("--device", "-d", type=str, default="cpu",
                        choices=["cpu", "cuda"],
                        help="Device to use for computation")
    parser.add_argument("--load", "-l", type=str,
                        help="Path to previous weights for continuing training")
    parser.add_argument("--generator_learning_rate", "-g_lr", type=float,
                        required=True, help="Learning rate of generator")
    parser.add_argument("--generator_learning_miniepoch", "-g_epoch", type=int,
                        default=1,
                        help="Number of times generator trains in a single epoch")
    parser.add_argument("--generator_attentivernn_blocks", "-g_arnn_b",
                        type=int, default=1,
                        help="Number of blocks of RNN in attention network")
    parser.add_argument("--generator_resnet_depth", "-g_depth", type=int,
                        default=1,
                        help="Depth of ResNet in each attention RNN blocks")
    parser.add_argument("--discriminator_learning_rate", "-d_lr", type=float,
                        help="Learning rate of discriminator. If not given, "
                             "it is assumed to be the same as the generator")
    args = parser.parse_args(argv)

    # The help text promises this fallback, but argparse alone would leave
    # the value as None. Apply the documented default here.
    if args.discriminator_learning_rate is None:
        args.discriminator_learning_rate = args.generator_learning_rate

    return args