import argparse


def get_param(argv=None):
    """Build the GAN training argument parser and parse the command line.

    Args:
        argv: Optional list of argument strings to parse. When ``None`` (the
            default), argparse falls back to ``sys.argv[1:]``, exactly as the
            original zero-argument call did.

    Returns:
        argparse.Namespace: The parsed arguments. If
        ``--discriminator_learning_rate`` was not given, it is set to the value
        of ``--generator_learning_rate``, as that option's help text promises.

    Raises:
        SystemExit: Raised by argparse when a required option is missing, a
            value fails type conversion or ``choices`` validation, or
            ``--help`` is requested.
    """
    parser = argparse.ArgumentParser(description="GAN Training Arguments")

    # Training schedule.
    parser.add_argument("--epochs", "-e", type=int, required=True, help="Total number of epochs")
    parser.add_argument("--batch_size", "-b", type=int, required=True, help="Size of single batch")
    parser.add_argument("--save_interval", "-s", type=int, required=True, help="Interval for saving weights")
    parser.add_argument("--sample_interval", type=int, required=True, help="Interval for saving inference result")

    # Runtime / checkpointing.
    parser.add_argument("--device", "-d", type=str, default="cpu", choices=["cpu", "cuda"], help="Device to use for computation")
    parser.add_argument("--load", "-l", type=str, help="Path to previous weights for continuing training")

    # Generator hyper-parameters.
    parser.add_argument("--generator_learning_rate", "-g_lr", type=float, required=True, help="Learning rate of generator")
    parser.add_argument("--generator_learning_miniepoch", "-g_epoch", type=int, default=1, help="Number of times generator trains in a single epoch")
    parser.add_argument("--generator_attentivernn_blocks", "-g_arnn_b", type=int, default=1, help="Number of blocks of RNN in attention network")
    parser.add_argument("--generator_resnet_depth", "-g_depth", type=int, default=1, help="Depth of ResNet in each attention RNN blocks")

    # Discriminator hyper-parameters.
    parser.add_argument("--discriminator_learning_rate", "-d_lr", type=float, help="Learning rate of discriminator. If not given, it is assumed to be the same as the generator")

    args = parser.parse_args(argv)

    # Honour the documented fallback: without this, an omitted
    # --discriminator_learning_rate would leave None in the namespace,
    # contradicting the option's help text.
    if args.discriminator_learning_rate is None:
        args.discriminator_learning_rate = args.generator_learning_rate

    return args
