/*
 * (C) Copyright 2017
 * Timesys Corporation
 *
 * SPDX-License-Identifier:	GPL-2.0+
 */

#include <common.h>
#include <command.h>
#include <ext4fs.h>
#include <fs.h>
#include <image.h>
#include <hash.h>
#include <malloc.h>
#include <memalign.h>
#include <mmc.h>
#include <u-boot/sha256.h>

#define HASH_BYTES SHA256_SUM_LEN
#define READ_BUF_SIZE   (8 * SZ_1M)
#define FIT_MMC_NODE "mmc"
#define FIT_MMC_HASH_PROP "mmc-hash"

/*
 * do_mmc_hash_verify() - handler for the "mmc_hash_verify" command
 *
 * Hashes the ext filesystem area of an MMC partition with SHA-256 and compares
 * the digest against the FIT_MMC_HASH_PROP property of the FIT_MMC_NODE image
 * node in a FIT image that is already in memory.
 *
 * @cmdtp: command table entry (unused)
 * @flag:  command flags (unused)
 * @argc:  must be 4 (command name + 3 arguments)
 * @argv:  argv[1] = FIT image address, argv[2] = MMC device number,
 *         argv[3] = MMC partition number; all three are parsed as hexadecimal
 *
 * Return: CMD_RET_USAGE on a wrong argument count, CMD_RET_SUCCESS when the
 *         computed digest equals the one stored in the FIT image,
 *         CMD_RET_FAILURE on any error or on a digest mismatch.
 */
static int do_mmc_hash_verify(cmd_tbl_t *cmdtp, int flag, int argc, char * const argv[])
{
	int ret, cnt;
	int result = CMD_RET_FAILURE;

	/* MMC variables */
	struct mmc *mmc;
	u8 *tmp_buf = NULL;
	lbaint_t curr_block, last_block, read_size, max_read, total_size;
	ulong blocks_read;
	disk_partition_t info;
	int mmc_dev, mmc_part;
	/* Room for two 32-bit decimal numbers, the ':' separator and the NUL */
	char mmc_str[24];

	/* Hash variables */
	struct hash_algo *algo;
	void *ctx = NULL;
	int ctx_live = 0;	/* non-zero while ctx still has to be released */
	u8 calculated_hash_val[HASH_BYTES];
	/* Hardcoded, could be read from FIT image */
	const char *algo_name = "sha256";

	/* FIT image variables */
	int conf_node_off, len = 0;
	const u8 *fit_hash_val;
	const void *fit;

	if (argc != 4)
		return CMD_RET_USAGE;

	fit = (const void *)simple_strtoul(argv[1], NULL, 16);
	mmc_dev = simple_strtoul(argv[2], NULL, 16);
	mmc_part = simple_strtoul(argv[3], NULL, 16);

	/* STEP 1: MMC setup */
	mmc = find_mmc_device(mmc_dev);
	if (!mmc) {
		puts("mmc device not found!!\n");
		return CMD_RET_FAILURE;
	}

	if (mmc_init(mmc)) {
		puts("MMC init failed\n");
		return CMD_RET_FAILURE;
	}

	ret = get_partition_info(&mmc->block_dev, mmc_part, &info);
	if (ret) {
		puts("MMC partition error\n");
		return CMD_RET_FAILURE;
	}

	/*
	 * The block size is used as a divisor below, and a block larger than
	 * the bounce buffer would make max_read 0 so the read loop could never
	 * advance.
	 */
	if (info.blksz == 0 || info.blksz > READ_BUF_SIZE) {
		puts("unsupported MMC block size\n");
		return CMD_RET_FAILURE;
	}

	snprintf(mmc_str, sizeof(mmc_str), "%u:%u",
		 (unsigned int)mmc_dev, (unsigned int)mmc_part);
	if (fs_set_blk_dev("mmc", mmc_str, FS_TYPE_EXT)) {
		puts("unable to get fs\n");
		return CMD_RET_FAILURE;
	}

	/*
	 * NOTE(review): ext4fs_totalsize() is presumably the filesystem size
	 * in bytes, since it is converted to sectors right away - confirm
	 * against its definition.
	 */
	total_size = ext4fs_totalsize();
	/* Convert to sectors */
	total_size /= info.blksz;

	/* lbaint_t may be 32 or 64 bits wide, so print it via a fixed cast */
	printf("MMC: dev#%d, part#%d, Start: %llu, Size: %llu, Block Size: %lu\n",
	       mmc_dev, mmc_part, (unsigned long long)info.start,
	       (unsigned long long)total_size, info.blksz);

	/* STEP 2: Hash setup */
	ret = hash_progressive_lookup_algo(algo_name, &algo);
	if (ret)
		return CMD_RET_FAILURE;

	ret = algo->hash_init(algo, &ctx);
	if (ret)
		return CMD_RET_FAILURE;
	ctx_live = 1;

	/* STEP 3: Read mmc and compute hash */
	tmp_buf = malloc_cache_aligned(READ_BUF_SIZE);
	if (!tmp_buf) {
		puts("malloc failed\n");
		goto out;
	}

	curr_block = info.start;
	last_block = info.start + total_size;
	max_read = READ_BUF_SIZE / info.blksz;

	while (curr_block < last_block) {
		/* Read a full buffer, or whatever is left for the last chunk */
		if (curr_block + max_read >= last_block)
			read_size = last_block - curr_block;
		else
			read_size = max_read;

		/*
		 * block_read() reports the number of blocks transferred; a
		 * short read would leave stale data in the buffer, so it is
		 * treated as a failure just like reading nothing at all.
		 */
		blocks_read = mmc->block_dev.block_read(&mmc->block_dev, curr_block,
							read_size, tmp_buf);
		if (blocks_read != read_size) {
			printf("\nMMC read failed at block: %llu",
			       (unsigned long long)curr_block);
			goto out;
		}

		ret = algo->hash_update(algo, ctx, tmp_buf, read_size * info.blksz, 0);
		if (ret) {
			/*
			 * Per the hash_algo contract in hash.h, hash_update()
			 * releases the context itself when it fails.
			 */
			ctx_live = 0;
			puts("hash update failed\n");
			goto out;
		}

		curr_block += read_size;
	}
	/* Give the large bounce buffer back before the FIT lookup */
	free(tmp_buf);
	tmp_buf = NULL;

	ret = algo->hash_finish(algo, ctx, calculated_hash_val, algo->digest_size);
	ctx_live = 0;
	if (ret)
		goto out;

	/* STEP 4: Read hash from FIT image */
	conf_node_off = fit_image_get_node(fit, FIT_MMC_NODE);
	if (conf_node_off < 0) {
		printf("MMC Hash Check: %s: no such config\n", FIT_MMC_NODE);
		goto out;
	}

	fit_hash_val = fdt_getprop(fit, conf_node_off, FIT_MMC_HASH_PROP, &len);
	if (!fit_hash_val || len != HASH_BYTES) {
		puts("invalid hash length\n");
		goto out;
	}

	/* STEP 5: Compare hash */
	if (memcmp(fit_hash_val, calculated_hash_val, HASH_BYTES) != 0) {
		printf("Bad hash value for partition: %d\n", mmc_part);
		printf("MMC Hash: ");
		for (cnt = 0; cnt < HASH_BYTES; cnt++)
			printf("%02x", calculated_hash_val[cnt]);
		printf("\nFit Hash: ");
		for (cnt = 0; cnt < HASH_BYTES; cnt++)
			printf("%02x", fit_hash_val[cnt]);
		printf("\n");
		goto out;
	}

	result = CMD_RET_SUCCESS;

out:
	/*
	 * hash_finish() is what releases the context (see hash.h), so call it
	 * on the error paths that bailed out before reaching it; the digest it
	 * writes is simply discarded.
	 */
	if (ctx_live)
		algo->hash_finish(algo, ctx, calculated_hash_val, algo->digest_size);
	free(tmp_buf);
	return result;
}

/*
 * Register the "mmc_hash_verify" shell command, handled by
 * do_mmc_hash_verify(). All three arguments are mandatory (the handler
 * rejects any argc other than 4) and are parsed as hexadecimal numbers.
 */
U_BOOT_CMD(mmc_hash_verify, 4, 0, do_mmc_hash_verify,
	"Verify mmc partition hash with hash inside fit Image",
	"<fit_addr> <mmc_dev_number> <mmc_partition_number>\n"
	"\t eg: mmc_hash_verify 0x80100000 0 2\n"
	"\t where 0x80100000 is address of fitImage with mmc partition hash\n"
	"\t       0 2 correspond to mmcblk0p2\n"
	"\t all numbers are parsed as hexadecimal"
);
