import argparse


def get_param(argv=None):
    """Parse the command-line arguments for GAN training.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default), ``argparse`` reads ``sys.argv[1:]``, matching the
            original behaviour of this function.

    Returns:
        argparse.Namespace: The parsed arguments. If
        ``--discriminator_learning_rate`` was not supplied, it is set to the
        value of ``--generator_learning_rate``, as the option's help text
        promises.

    Raises:
        SystemExit: Raised by ``argparse`` when required arguments are missing
            or a value fails type/choice validation (also on ``--help``).
    """
    parser = argparse.ArgumentParser(description="GAN Training Arguments")

    parser.add_argument("--epochs", "-e", type=int, required=True,
                        help="Total number of epochs")
    parser.add_argument("--batch_size", "-b", type=int, required=True,
                        help="Size of single batch")
    parser.add_argument("--save_interval", "-s", type=int, required=True,
                        help="Interval for saving weights")
    parser.add_argument("--sample_interval", type=int, required=True,
                        help="Interval for saving inference result")
    parser.add_argument("--device", "-d", type=str, default="cpu", choices=["cpu", "cuda"],
                        help="Device to use for computation")
    parser.add_argument("--load", "-l", type=str, default=None,
                        help="Path to previous weights for continuing training")
    parser.add_argument("--generator_learning_rate", "-g_lr", type=float, required=True,
                        help="Learning rate of generator")
    parser.add_argument("--generator_learning_miniepoch", "-g_epoch", type=int, default=1,
                        help="Number of times generator trains in a single epoch")
    parser.add_argument("--generator_attentivernn_blocks", "-g_arnn_b", type=int, default=1,
                        help="Number of blocks of RNN in attention network")
    parser.add_argument("--generator_resnet_depth", "-g_depth", type=int, default=1,
                        help="Depth of ResNet in each attention RNN block")
    parser.add_argument("--discriminator_learning_rate", "-d_lr", type=float, default=None,
                        help="Learning rate of discriminator. If not given, it is assumed to be"
                             " the same as the generator")

    args = parser.parse_args(argv)

    # Honour the documented fallback: an omitted discriminator learning rate
    # means "use the generator's learning rate".
    if args.discriminator_learning_rate is None:
        args.discriminator_learning_rate = args.generator_learning_rate
    return args
