#!/usr/bin/env python3

import sys
import secrets
from galois import GaloisNumber as GN
import galois
from collections import namedtuple

# Field used for all share arithmetic below. The argument 8 presumably selects
# GF(2^8), i.e. one field element per byte of the secret, which is how
# _gf_unpack/_gf_pack treat values — TODO confirm against the galois module.
gf = galois.GaloisFieldLog (8)

def _gf_unpack (v):
    '''
    Lift every byte of the bytes-like `v` into a GaloisNumber in `gf`.
    Returns a tuple of field elements in the same order as the input bytes.
    '''
    elements = []
    for byte in v:
        elements.append (GN(byte, gf))
    return tuple (elements)
def _gf_pack (v):
    '''
    Serialize a sequence of field elements back into a single bytes object
    by concatenating each element's to_bytes() output in order.
    '''
    packed = bytearray ()
    for element in v:
        packed += element.to_bytes ()
    return bytes (packed)

class Secret:
    '''
    A fixed-length secret (or share payload) held as raw bytes.

    The value is exactly Secret.value_len bytes long. It may be supplied as
    bytes, a bytearray, or a hex string; when omitted (None) a fresh random
    value is drawn from the `secrets` CSPRNG.
    '''
    value_len = 32
    def __init__ (self, value = None):
        '''
        value: None for a random secret, or bytes / bytearray / hex str
               encoding exactly value_len bytes.
        Raises ValueError on an unsupported type, malformed hex, or a
        wrong length.
        '''
        # Compare against None explicitly: an empty bytes/str must be
        # rejected as the wrong length, not silently swapped for a random
        # secret.
        if value is None:
            self.value = secrets.token_bytes (Secret.value_len)
        elif isinstance (value, str):
            # bytes.fromhex raises ValueError on malformed hex input.
            self.value = bytes.fromhex (value)
        elif isinstance (value, (bytes, bytearray)):
            # Copy into immutable bytes so later mutation of a caller's
            # bytearray cannot alter the stored secret.
            self.value = bytes (value)
        else:
            raise ValueError("Secret must be bytes or hex string")
        if len (self.value) != Secret.value_len:
            raise ValueError("Secret must be exactly {0} bytes long".format (Secret.value_len))
    def __eq__ (self, other):
        # Constant-time comparison, so equality checks do not leak through
        # timing how many leading bytes of two secrets match.
        if not isinstance (other, Secret):
            return NotImplemented
        return secrets.compare_digest (self.value, other.value)
    def __hash__ (self):
        # Must agree with __eq__: equal values hash equally.
        return hash (self.value)
    def __str__ (self):
        return self.value.hex ()
    def __repr__ (self):
        # NOTE(review): repr exposes the raw secret as hex; beware of logging it.
        return str(self)
    def __len__ (self):
        return len (self.value)

class SSS_Share:
    '''
    One share of a split secret: an x-coordinate (share_id) paired with the
    polynomial values at that point (secret, a Secret instance).
    '''
    def __init__ (self, share_id, secret):
        '''
        share_id: share identifier in 1-254, the same range
                  sss_generate_shares accepts.
        secret:   the share payload; must be a Secret instance.
        Raises ValueError on an out-of-range ID or a non-Secret payload.
        '''
        # 0 is excluded: evaluating the sharing polynomial at x = 0 yields
        # the secret itself. 254 must be accepted, since sss_generate_shares
        # allows it.
        if not 1 <= share_id <= 254:
            raise ValueError("Invalid share ID")
        if not isinstance (secret, Secret):
            raise ValueError ("Invalid secret")
        self.share_id = share_id
        self.secret = secret
    def __str__ (self):
        return 'ID {0:2}: {1}'.format (self.share_id, self.secret)
    def __repr__ (self):
        return str(self)

def sss_generate_shares (secret, min_shares, share_ids):
    '''
    Split a secret into shares with Shamir's Secret Sharing, byte by byte.

    secret:     secret value to share - must be of type Secret
    min_shares: minimum number of shares to reconstruct secret; must be at
                least 1 and no more than the number of share IDs
    share_ids:  list (iterable) of distinct share identifiers, each 1-254
    Returns a list of SSS_Share, one per share ID, in the given order.
    Raises ValueError on an out-of-range or duplicate share ID, or an
    unusable min_shares.
    '''
    # Materialize first: a one-shot iterator would otherwise be exhausted by
    # the validation pass below and no shares would be produced at all.
    share_ids = list (share_ids)
    if any (not 1 <= x <= 254 for x in share_ids):
        raise ValueError ("Share IDs must be between 1-254")
    # Repeated x-coordinates make Lagrange interpolation at recovery time
    # produce a wrong secret.
    if len (set (share_ids)) != len (share_ids):
        raise ValueError ("Share IDs must be unique")
    # With more required shares than are issued, the secret can never be
    # recovered from this batch.
    if not 1 <= min_shares <= len (share_ids):
        raise ValueError ("min_shares must be between 1 and the number of share IDs")
    # coeff[0] is the secret (the polynomial's value at x = 0). The other
    # min_shares-1 coefficients are fresh random bytes, giving a polynomial
    # of degree min_shares-1 for every byte position.
    coeff = [_gf_unpack (secret.value)]
    coeff.extend (_gf_unpack (Secret().value) for _ in range (min_shares - 1))
    shares = list()
    for i in share_ids:
        v = GN(i, gf)
        m = GN(1, gf)  # v**k for the k-th coefficient
        accum = [GN(0, gf)] * len (secret)
        for a in coeff:
            # mul_region presumably returns accum + m * a element-wise
            # (TODO confirm against the galois module).
            accum = m.mul_region (a, accum)
            m *= v
        shares.append (SSS_Share(i, Secret (_gf_pack(accum))))
    return shares


def sss_recover_secret (shares, share_id = 0):
    '''
    Interpolate the sharing polynomial through `shares` and evaluate it at
    `share_id`.

    shares:   iterable of SSS_Share with distinct share IDs. No check is
              made that at least min_shares were supplied; too few yields an
              incorrect result rather than an error.
    share_id: x-coordinate to evaluate at; 0 (default) recovers the secret.
    Returns a Secret when share_id is 0, otherwise an SSS_Share with that ID.
    Raises ValueError if no shares are given or share IDs repeat.
    '''
    # Accept any iterable; the body indexes and iterates it more than once.
    shares = list (shares)
    if not shares:
        raise ValueError ("At least one share is required")
    # Duplicate x-coordinates would be skipped in each other's basis product
    # and double-counted in the sum, silently corrupting the result.
    if len ({s.share_id for s in shares}) != len (shares):
        raise ValueError ("Share IDs must be unique")
    shrs = tuple (_gf_unpack (s.secret.value) for s in shares)
    ids = tuple (GN(s.share_id, gf) for s in shares)
    accum = [GN(0, gf)] * len (shares[0].secret)
    sid = GN(share_id, gf)
    for (x, shr) in zip (ids, shrs):
        # Lagrange basis polynomial for x, evaluated at sid:
        #   prod over i != x of (sid - i) / (x - i)
        prod = GN(1, gf)
        for i in ids:
            if i != x:
                prod *= (sid - i) / (x - i)
        accum = prod.mul_region (shr, accum)
    secret = Secret(_gf_pack (accum))
    if share_id == 0:
        return secret
    return SSS_Share (share_id, secret)


def sss_generate_new_share (shares, share_id):
    '''
    Generate a new share with the passed share_id, compatible with
    the shares passed in.

    shares:   existing shares from one split (at least min_shares of them,
              or the new share will not be compatible).
    share_id: identifier for the new share, 1-254.
    Returns an SSS_Share.
    Raises ValueError if share_id is out of range.
    '''
    # Reject 0 up front: sss_recover_secret treats 0 as "return the secret",
    # which would leak the secret itself from an API meant to mint shares.
    if not 1 <= share_id <= 254:
        raise ValueError ("Share IDs must be between 1-254")
    return sss_recover_secret (shares, share_id)