#!/usr/bin/env python3
#
# Native python routines for Galois field calculations.
#
# These calculations can be done much faster in C/C++ (and with vectorization), but
# this code helps illustrate how Galois field math works.  Also, it's helpful if
# you want to do small amounts of calculation without working with native C/C++
# code.
#

import random

class GaloisException(Exception):
    '''
    Exception raised by the Galois field classes in this module.

    The object passed to the constructor is kept in the `value` attribute, and
    str() of the exception is the repr() of that object.
    '''
    def __init__ (self, value):
        # Keep the payload around so handlers can inspect it directly.
        self.value = value

    def __str__ (self):
        payload = self.value
        return repr(payload)

class GaloisNumber:
    '''
    Class to represent a number in a Galois (finite) field.

    The class supports "normal" syntax for the addition, subtraction, multiplication,
    division, additive inverse (-), and multiplicative inverse (~) operations.

    Arithmetic operators (including the in-place forms such as += and *=) return a new
    GaloisNumber rather than modifying an operand.  Only assign() changes an instance
    in place.  Combining numbers whose fields compare unequal raises GaloisException.
    '''
    def __init__ (self, x, field=None):
        '''
        Create a field element.

        x may be another GaloisNumber (its field and value are copied and the field
        argument is ignored), an int, or a bytes object (interpreted as a big-endian
        integer).  For int and bytes, field must be a GaloisFieldLog or GaloisFieldDirect
        instance.

        Raises ValueError if the field is missing/invalid, if x has an unsupported type,
        or if the resulting value does not fit in the field.
        '''
        if isinstance(x, GaloisNumber):
            self.field = x.field
            self.value = x.value
        elif not isinstance(field, (GaloisFieldLog, GaloisFieldDirect)):
            raise ValueError ("Must specify field")
        else:
            self.field = field
            if isinstance(x, bytes):
                x = int.from_bytes (x, 'big')
            elif isinstance(x, bool) or not isinstance(x, int):
                # bool is a subclass of int, but is not a meaningful field value.
                raise ValueError ("Must be an integer or bytes")
            self.assign (x)

    def _check_field (self, other):
        # Raise GaloisException unless other belongs to a field equal to ours.
        if self.field != other.field:
            raise GaloisException ("Field elements from different fields")

    def assign (self, v):
        '''
        Assign a new integer value to this Galois field number.  The number must be valid in
        the field with which the GaloisNumber instance was defined.

        Valid values are 0 .. field.size - 1; anything else raises ValueError.
        '''
        # field.size is the number of elements, so field.size itself is already out of range.
        if v < 0 or v >= self.field.size:
            raise ValueError ("Value {0} is outside field".format (v))
        self.value = v

    def mul_region (self, arr, init_value = None):
        '''
        Multiply each element of arr by this number and add the corresponding element of
        init_value, returning the results as a new list.

        If init_value is None or empty, zeros are used.  Results are paired up with zip(),
        so the returned list is as long as the shorter of arr and init_value.

        Raises ValueError if any item of arr is not a GaloisNumber, or if its field is of
        a different class than this number's field.
        '''
        if not init_value:
            # A single shared zero is fine: addition never mutates its operands.
            init_value = [GaloisNumber (0, self.field)] * len (arr)
        if any (type(x) != GaloisNumber or type(x.field) != type(self.field) for x in arr):
            raise ValueError ("All items must be Galois numbers in this field!")
        return [x * self + i for x, i in zip (arr, init_value)]

    def copy (self):
        '''Return a new GaloisNumber with the same value and field.'''
        return GaloisNumber (self.value, self.field)

    def __add__ (self, other):
        # Addition in this representation is bitwise XOR of the values.
        self._check_field (other)
        return GaloisNumber (self.value ^ other.value, self.field)

    def __iadd__ (self, other):
        return self + other

    def __sub__ (self, other):
        # Subtraction is implemented as addition (XOR is its own inverse).
        return self + other

    def __isub__ (self, other):
        return self + other

    def __invert__ (self):
        # ~x is the multiplicative inverse, delegated to the field.
        return self.field.invert (self)

    def __neg__ (self):
        # Every element is its own additive inverse, so -x is just a copy of x.
        return GaloisNumber (self)

    def __mul__ (self, other):
        self._check_field (other)
        return self.field.multiply (self, other)

    def __imul__ (self, other):
        return self * other

    def __floordiv__ (self, other):
        self._check_field (other)
        return self.field.divide (self, other)

    def __truediv__ (self, other):
        # Both division operators mean the same thing here.
        return self // other

    def __eq__ (self, other):
        '''
        Compare values.  Comparing against something that is not a GaloisNumber yields
        NotImplemented (so == evaluates to False); comparing numbers from unequal fields
        raises GaloisException.
        '''
        if not isinstance(other, GaloisNumber):
            return NotImplemented
        self._check_field (other)
        return self.value == other.value

    def __repr__ (self):
        return self.field.fmt (self.value)

    def to_bytes (self):
        '''Return the value as big-endian bytes, using the fewest whole bytes that hold field.bits bits.'''
        return (self.value).to_bytes ((self.field.bits + 7) // 8, byteorder='big')

class GaloisFieldLog:
    '''
    Pure python implementation of Galois (finite) field arithmetic routines using log/antilog
    tables.

    There only needs to be one instantiation of the field for a given set of parameters,
    but elements from different field instances with the same parameters may be mixed.
    Fields compare (and hash) by their bits, polynomial and alpha values.
    '''
    # Field widths, in bits, that this implementation accepts.
    field_widths = (4, 6, 8, 12, 16)
    # Polynomial used for each width when the caller does not supply one.
    poly_defaults = {4: 0x13, 6: 0x43, 8: 0x11d, 12:0x1053, 16: 0x1100b}
    # Bounds the number of random multiply/divide round trips in self_test().
    multiply_test_size = 10000
    def __init__ (self, bits, primitive_polynomial = None, repr_prefix = 'G', alpha = 1):
        '''
        Create a Galois field using log/antilog tables for arithmetic.

        bits must be one of field_widths, otherwise GaloisException is raised.
        primitive_polynomial defaults to poly_defaults[bits] when not given (or falsy).
        repr_prefix is prepended to the zero-padded hex value produced by fmt().
        alpha is stored and takes part in field comparison; this class does not otherwise
        use it.
        '''
        if bits not in self.field_widths:
            raise GaloisException ("Field widths supported: {0}".format (self.field_widths))
        self.bits = bits
        self.size = (1 << bits)
        self.prim = self.poly_defaults[bits] if not primitive_polynomial else primitive_polynomial
        # Round the hex digit count up so widths that are not a multiple of 4 (e.g. 6 bits)
        # are still padded to the full width of their largest value.
        self.value_format = repr_prefix + '{:0>' + str((bits + 3) // 4) + 'x}'
        self.alpha = alpha
        # Set up the log and anti-log tables
        self.log_tbl = [0] * self.size
        self.antilog_tbl = [0] * (self.size - 1)
        b = 1
        for i in range (self.size - 1):
            # b is 2 raised to the power i, reduced by the polynomial.
            self.log_tbl[b] = i
            self.antilog_tbl[i] = b
            b <<= 1
            if b >= self.size:
                b ^= self.prim

    def __eq__ (self, other):
        # Objects lacking the field parameters are simply "not equal" rather than an error.
        try:
            theirs = (other.bits, other.prim, other.alpha)
        except AttributeError:
            return NotImplemented
        return (self.bits, self.prim, self.alpha) == theirs

    def __hash__ (self):
        # Defining __eq__ alone would make fields unhashable; hash the same parameters
        # that __eq__ compares.
        return hash ((self.bits, self.prim, self.alpha))

    def fmt (self, v):
        '''Return the display string for integer value v: the prefix followed by zero-padded hex.'''
        return self.value_format.format (v)

    def multiply (self, v1, v2):
        '''Return v1 * v2 as a new GaloisNumber in this field, via log-table lookup.'''
        a = v1.value
        b = v2.value
        if a == 0 or b == 0:
            # Zero has no logarithm, so it is handled separately.
            return GaloisNumber (0, self)
        log_sum = (self.log_tbl[a] + self.log_tbl[b]) % (self.size - 1)
        return GaloisNumber (self.antilog_tbl[log_sum], self)

    def invert (self, v):
        '''
        Return the multiplicative inverse of v as a new GaloisNumber.

        Zero has no inverse; this method returns zero for it instead of raising, so
        dividing by zero yields zero.
        '''
        if v.value == 0:
            return GaloisNumber (0, self)
        order = self.size - 1
        # The modulo maps log 0 (the value 1) back to index 0, keeping the index in range.
        return GaloisNumber (self.antilog_tbl[(order - self.log_tbl[v.value]) % order], self)

    def divide (self, v1, v2):
        '''Return v1 / v2, computed as v1 times the inverse of v2.'''
        return self.multiply (v1, self.invert(v2))

    def self_test (self):
        '''
        Check inverses and identities for every non-zero element, then run random
        multiply/divide round trips.  Failures raise AssertionError; returns True otherwise.
        '''
        v = GaloisNumber (0, self)
        g_0 = GaloisNumber (0, self)
        g_1 = GaloisNumber (1, self)
        for i in range (1, self.size):
            v.assign (i)
            assert v * ~v == g_1, "Multiplicative inverse failed at {}".format (i)
            assert g_0 - v == -v, "Additive inverse failed at {}".format (i)
            assert v * g_1 == v, "Multiplicative identity failed at {}".format (i)
        vb = GaloisNumber (0, self)
        for _ in range (1, self.multiply_test_size):
            v.assign (random.randint (1, self.size - 1))
            vb.assign (random.randint (1, self.size - 1))
            product = v * vb
            assert product / v == vb, "Multiplication failed for {} * {}".format(v.value,vb.value)
            assert product / vb == v, "Multiplication failed for {} * {}".format(v.value,vb.value)
        return True

class GaloisFieldDirect:
    '''
    Pure python implementation of Galois (finite) field arithmetic routines using direct
    arithmetic (no log tables).

    There only needs to be one instantiation of the field for a given set of parameters,
    but elements from different field instances with the same parameters may be mixed.
    Fields compare (and hash) by their bits, polynomial and alpha values.
    '''
    # Field widths, in bits, that this implementation accepts.
    field_widths = (4, 6, 8, 12, 16, 32, 64)
    # Polynomial used for each width when the caller does not supply one.
    poly_defaults = {4: 0x13, 6: 0x43, 8: 0x11d, 12:0x1053, 16: 0x1100b, 32: 0x1000000c5, 64: 0x1000000000000001b}
    # Fields smaller than this are tested exhaustively by self_test(); it also bounds
    # the number of random checks.
    max_test_size = 5000
    def __init__ (self, bits, primitive_polynomial = None, repr_prefix = 'G', alpha = 1):
        '''
        Create a Galois field using direct arithmetic.  No log tables or inverses to
        precalculate, since the field might be too large to store them

        bits must be one of field_widths, otherwise GaloisException is raised.
        primitive_polynomial defaults to poly_defaults[bits] when not given (or falsy).
        repr_prefix is prepended to the zero-padded hex value produced by fmt().
        alpha is stored and takes part in field comparison; this class does not otherwise
        use it.
        '''
        if bits not in self.field_widths:
            raise GaloisException ("Field widths supported: {0}".format (self.field_widths))
        self.bits = bits
        self.size = (1 << bits)
        self.prim = self.poly_defaults[bits] if not primitive_polynomial else primitive_polynomial
        # Round the hex digit count up so widths that are not a multiple of 4 (e.g. 6 bits)
        # are still padded to the full width of their largest value.
        self.value_format = repr_prefix + '{:0>' + str((bits + 3) // 4) + 'x}'
        self.alpha = alpha

    def __eq__ (self, other):
        # Objects lacking the field parameters are simply "not equal" rather than an error.
        try:
            theirs = (other.bits, other.prim, other.alpha)
        except AttributeError:
            return NotImplemented
        return (self.bits, self.prim, self.alpha) == theirs

    def __hash__ (self):
        # Defining __eq__ alone would make fields unhashable; hash the same parameters
        # that __eq__ compares.
        return hash ((self.bits, self.prim, self.alpha))

    def fmt (self, v):
        '''Return the display string for integer value v: the prefix followed by zero-padded hex.'''
        return self.value_format.format (v)

    def multiply (self, v1, v2):
        '''Return v1 * v2 as a new GaloisNumber in this field.'''
        return GaloisNumber (self.direct_multiply (v1.value, v2.value), self)

    def direct_multiply (self, a, b):
        '''
        Multiply two raw integer field values (each expected to be in 0 .. size - 1) and
        return the raw integer product, using shift-and-XOR with reduction by the polynomial.
        '''
        # Multiplication is commutative, and it's faster if we use the smaller value as the
        # multiplier since we can exit the while loop sooner.
        if b > a:
            a, b = b, a
        # Lowest multiplier bit selects the unshifted multiplicand.
        result = a if b & 1 else 0
        b >>= 1
        while b != 0:
            # Double the multiplicand, reducing whenever it overflows the field width.
            a <<= 1
            if a >= self.size:
                a ^= self.prim
            if b & 1:
                result ^= a
            b >>= 1
        return result

    def invert (self, v):
        '''
        Calculate inverse(v) by computing v^(field_size-2).
        This is just v^2 * v^4 ... v^(field_size / 2), so calculation time is proportional
        to field width in bits.

        Zero has no inverse; this method returns zero for it instead of raising, so
        dividing by zero yields zero.
        '''
        if v.value == 0:
            return GaloisNumber(0, self)
        elif v.value == 1:
            return GaloisNumber (1, self)
        inv = 1
        sq =  v.value
        for _ in range (1, self.bits):
            # sq runs through v^2, v^4, ...; inv accumulates their product.
            sq = self.direct_multiply (sq, sq)
            inv = self.direct_multiply (inv, sq)
        return GaloisNumber (inv, self)

    def divide (self, v1, v2):
        '''Return v1 / v2, computed as v1 times the inverse of v2.'''
        return self.multiply (v1, self.invert(v2))

    def self_test (self):
        '''
        Check inverses and identities (exhaustively for small fields, on random non-zero
        elements otherwise), then run random multiply/divide round trips.  Failures raise
        AssertionError; returns True otherwise.
        '''
        v = GaloisNumber (0, self)
        g_0 = GaloisNumber (0, self)
        g_1 = GaloisNumber (1, self)
        small_field = self.size < self.max_test_size
        n_tests = (self.size - 1) if small_field else self.max_test_size
        for i in range (0, n_tests):
            # The +1 keeps the tested element non-zero in both modes.
            v.assign ((i if small_field else random.randint (0, self.size - 2)) + 1)
            assert v * ~v == g_1, "Multiplicative inverse failed at {}".format (i)
            assert g_0 - v == -v, "Additive inverse failed at {}".format (i)
            assert v * g_1 == v, "Multiplicative identity failed at {}".format (i)
        vb = GaloisNumber (0, self)
        for _ in range (1, self.max_test_size):
            v.assign (random.randint (1, self.size - 1))
            vb.assign (random.randint (1, self.size - 1))
            product = v * vb
            assert product / v == vb, "Multiplication failed for {} * {}".format(v.value,vb.value)
            assert product / vb == v, "Multiplication failed for {} * {}".format(v.value,vb.value)
        return True

if __name__ == '__main__':
    # Each suite is (banner, field class, success message).  For every width a class
    # supports we print a sample addition and then run that field's self test.
    suites = (
        ('\nTesting direct fields...........', GaloisFieldDirect, "{0} bit field (direct) passed!"),
        ('\nTesting log fields...........', GaloisFieldLog, "{0} bit field (log) passed!"),
    )
    for banner, field_class, passed_msg in suites:
        print (banner)
        for width in field_class.field_widths:
            field = field_class (width)
            g0 = GaloisNumber (2, field)
            g1 = GaloisNumber (7, field)
            total = g0 + g1
            print ('{0} + {1} = {2}'.format (g0, g1, total))
            if field.self_test ():
                print (passed_msg.format (width))