#!/usr/bin/env python3
# Dump SentiBoard data (envelope aware and synchronized)

import serial
import pathlib
import argparse
import time
import struct
from dataclasses import dataclass

# Length in bytes of a package header: magic b'^B' (offsets 0-1), 16-bit
# payload length (offsets 2-3) and a Fletcher-8 checksum of bytes 0-5
# (offsets 6-7). Offsets 4-5 are not interpreted by this script.
SB_HEADER_SIZE = 8
# Fewest buffered bytes the reader loop requires before it attempts to parse.
# NOTE(review): presumably the smallest possible package size — confirm.
MIN_SB_PACKAGE_SIZE = 22
# Largest expected size of a single package (8 KiB).
# NOTE(review): meaning taken from the name; confirm against the SentiBoard
# protocol specification.
MAX_SB_PACKAGE_SIZE = 1024 * 8

class TaskFinishedException(Exception):
    """Signals a normal end of capture, e.g. a package-count or time limit.

    Raised by write_packages() and check_timeout(); dump_data() catches it,
    prints the message and stops without reconnecting.
    """

    def __init__(self, msg):
        super().__init__(msg)

def fletcher8(data : list):
    """Compute the 8-bit Fletcher checksum of *data*.

    *data* is any iterable of integers (a bytes slice in this script).
    Returns the pair (sum1, sum2): sum1 is the running sum of the values and
    sum2 the running sum of sum1, both kept modulo 256 at every step.
    """
    sum1 = sum2 = 0
    for value in data:
        sum1 = (sum1 + value) % 256
        sum2 = (sum2 + sum1) % 256
    return sum1, sum2

def prepare_log_path(path : pathlib.Path, format : str, index : int, max_files : int = 999) -> pathlib.Path:
    '''Pick the log file path to write to inside directory *path*.

    Candidate names are ``format.format(i)`` for i = index .. index + max_files - 1;
    the first candidate that does not exist yet is returned. If *format* ignores
    the index (two consecutive candidates are the same name), that single name
    is returned even though it exists, so it will be overwritten.

    The directory *path* and any missing parents are created if needed.

    Raises FileExistsError if every one of the max_files candidates exists.
    '''
    path.mkdir(parents=True, exist_ok=True)

    prev_path = None
    for i in range(index, index + max_files):
        new_path = path.joinpath(format.format(i))
        # Same name as the previous candidate: the template has no index
        # placeholder, so reuse (overwrite) that one fixed file.
        if new_path == prev_path or not new_path.exists():
            return new_path
        prev_path = new_path

    raise FileExistsError(
        f'All {max_files} log file names starting at index {index} already exist in {path}')

def read_package(data, args):
    '''Try to parse one SentiBoard package at the start of *data*.

    Package layout, as parsed below:
      * bytes 0-1: magic b'^B'
      * bytes 2-3: payload length N
      * bytes 6-7: Fletcher-8 checksum of bytes 0-5
      * bytes 8 .. 8+N-1: payload
      * next 2 bytes: Fletcher-8 checksum of the payload
      * 0-3 padding bytes so the whole package length is a multiple of 4

    Returns (package, length) on success, where package is data[:length]
    including checksums and padding. On failure returns (None, skip):
    skip == 0 means more data is needed; skip > 0 means the bytes at the front
    are not a valid package and at least that many should be dropped.

    args: only args.verbose is read (enables the "need more data" message).
    '''
    # Check magic bytes (^B)
    if data[:2] != b'^B':
        return None, 1

    # The header checksum and length field need the full header.
    if len(data) < SB_HEADER_SIZE:
        if args.verbose:
            print('Not enough data in buffer')
        return None, 0

    if fletcher8(data[:6]) != (data[6], data[7]):
        print('Header checksum failed')
        return None, 2

    # '<H' = explicit little-endian, standard 2-byte size; native 'H' would
    # depend on the host's byte order.
    # NOTE(review): assumes the board sends the length little-endian — confirm.
    data_len = struct.unpack('<H', data[2:4])[0]
    total_len = SB_HEADER_SIZE + data_len
    padding_len = (4 - ((total_len + 2) % 4)) % 4
    final_length = total_len + 2 + padding_len

    # An oversized length means a corrupt header that happened to pass its
    # checksum (assuming MAX_SB_PACKAGE_SIZE is the real protocol limit).
    # Resync now instead of buffering up to ~64 KiB waiting for it.
    if final_length > MAX_SB_PACKAGE_SIZE:
        print('Package length exceeds maximum')
        return None, 2

    if len(data) < final_length:
        if args.verbose:
            print('Not enough data in buffer')
        return None, 0

    payload = data[SB_HEADER_SIZE:total_len]
    if fletcher8(payload) != (data[total_len], data[total_len + 1]):
        print('Data checksum failed')
        return None, 2

    return data[:final_length], final_length


def format_prefixed(num, base=1024):
    """Render *num* compactly with a k/m/b multiplier suffix.

    Values below 2*base are printed as-is followed by a space. Larger values
    are divided by base, base**2 or base**3 (suffix 'k', 'm', 'b') and shown
    with two decimals; each unit is used until the value reaches twice the
    next unit. Use base=1000 for counts and the default 1024 for byte sizes.
    """
    if num < base * 2:
        return f'{num} '

    scale = base
    for suffix in ('k', 'm'):
        if num < scale * base * 2:
            return f'{num / scale:0.2f}{suffix}'
        scale *= base

    return f'{num / scale:0.2f}b'


@dataclass
class PackageStats:
    '''Mutable counters and timestamps updated by the capture loop.'''
    total_data : int = 0  # bytes of valid packages written to the output file
    data_skipped : int = 0  # bytes dropped since the last valid package
    n_packages : int = 0  # number of valid packages written
    total_skipped : int = 0  # cumulative skipped bytes, added when a valid package follows
    last_data_print : float = 0.0  # time.time() of the last progress line
    last_error_print : float = 0.0  # time.time() of the last "Out of sync" message
    first_data_timestamp : float = 0.0  # time.time() when MIN_SB_PACKAGE_SIZE bytes were first buffered; 0 = not yet

def print_data(package_stats):
    """Redraw the one-line progress report in place.

    Shows packages written, bytes written and, only when non-zero, the total
    bytes skipped. The line starts with a carriage return and has no newline,
    so each call overwrites the previous report.
    """
    pkg_text = format_prefixed(package_stats.n_packages, 1000)
    byte_text = format_prefixed(package_stats.total_data)

    if package_stats.total_skipped:
        skip_text = format_prefixed(package_stats.total_skipped)
        line = f'\rWritten packages:bytes:skipped: {pkg_text}:\t{byte_text}B:\t{skip_text}B'
    else:
        line = f'\rWritten packages:bytes: {pkg_text}:\t{byte_text}B'

    print(line, end='')


def write_packages(datafile, data, package_stats, args):
    '''Write every complete package found at the front of *data* to *datafile*.

    Packages are parsed with read_package(). Bytes it rejects are dropped and
    counted in package_stats.data_skipped. Once the next valid package is
    found, they are reported (at most once per args.report_delay seconds) and
    added to package_stats.total_skipped.

    Returns the bytes that could not be consumed yet: an incomplete package,
    or fewer than MIN_SB_PACKAGE_SIZE bytes.

    Raises TaskFinishedException once args.count > 0 packages have been
    written.
    '''
    now = time.time()
    # '>=' so a buffer holding exactly one minimum-size package is parsed now;
    # this agrees with write_to_file()'s `len(data) < MIN_SB_PACKAGE_SIZE`
    # guard.
    while len(data) >= MIN_SB_PACKAGE_SIZE:
        package_data, bytes_read = read_package(data, args)

        # Not enough data in the buffer to read the full package
        if bytes_read == 0:
            return data

        data = data[bytes_read:]
        if package_data is None:
            package_stats.data_skipped += bytes_read
            continue

        if package_stats.data_skipped > 0:
            # Back in sync: report (rate-limited) and bank the skipped count.
            if now - package_stats.last_error_print > args.report_delay:
                print(f'Out of sync - skipped {package_stats.data_skipped} bytes')
                package_stats.last_error_print = now
            package_stats.total_skipped += package_stats.data_skipped
            package_stats.data_skipped = 0

        datafile.write(package_data)
        package_stats.n_packages += 1
        package_stats.total_data += len(package_data)

        if args.count > 0 and package_stats.n_packages >= args.count:
            raise TaskFinishedException(f'Package limit reached: {package_stats.n_packages}')
    return data

def check_timeout(args, package_stats, now=None):
    '''Raise TaskFinishedException once the capture has run past args.timeout.

    Elapsed time is measured from package_stats.first_data_timestamp. Nothing
    happens when args.timeout <= 0 (limit disabled), when the limit has not
    been exceeded yet, or when no data has arrived
    (first_data_timestamp <= 0).

    now: current time in seconds; defaults to time.time().
    '''
    current = time.time() if now is None else now
    started = package_stats.first_data_timestamp
    elapsed = current - started

    if not (args.timeout > 0 and elapsed > args.timeout):
        return
    if started <= 0:
        return

    msg = f'Collection timeout reached: {elapsed: 0.2f} / {args.timeout} s'
    raise TaskFinishedException(msg)

def write_to_file(com, datafile, args=None):
    '''Stream packages from *com* into *datafile* until the capture ends.

    com: port object whose read_all() returns the bytes currently available
        (a serial.Serial in this script).
    datafile: binary file object that receives the raw packages.
    args: parsed command-line options (wait_timeout, report_delay, count,
        timeout, verbose). When omitted, the module-level ``args`` assigned
        in the ``__main__`` block is used.

    Never returns normally. It ends by raising TaskFinishedException (count
    or timeout limit reached) or whatever error the port or file raises.
    '''
    if args is None:
        # Backward-compatible fallback: callers that pass only (com, datafile)
        # rely on the script-level global, which is missing when imported.
        args = globals().get('args')
        if args is None:
            raise TypeError('write_to_file() needs the parsed command-line args')

    package_stats = PackageStats()

    data = b''
    while True:
        chunk = com.read_all()
        data += chunk
        # Wait both when the buffer is short and when nothing new arrived.
        # Leftover bytes from a previous pass are ones write_packages() could
        # not consume yet (e.g. an incomplete package). Re-parsing them without
        # new input would only busy-spin the CPU.
        if not chunk or len(data) < MIN_SB_PACKAGE_SIZE:
            time.sleep(args.wait_timeout)
            check_timeout(args, package_stats)
            continue

        if package_stats.first_data_timestamp == 0:
            # args.timeout is measured from the first real data, not startup.
            package_stats.first_data_timestamp = time.time()

        data = write_packages(datafile, data, package_stats, args)
        now = time.time()
        if now - package_stats.last_data_print > args.report_delay:
            package_stats.last_data_print = now
            print_data(package_stats)
        check_timeout(args, package_stats, now)

def dump_data(args):
    '''Capture SentiBoard packages from args.device into a log file.

    Steps:
      1. Open the serial port.
      2. Discard pending input, unless args.no_flush is set.
      3. Pick a file name with prepare_log_path().
      4. Stream packages into that file.

    On any other error the port is closed. Unless args.reconnect_timeout < 0,
    the whole sequence is then retried after sleeping args.reconnect_timeout
    seconds. Returns when the task finishes (count/timeout limit), on Ctrl-C,
    or on an error with reconnect disabled.
    '''
    try:
        while True:
            print(f'Opening serial port: {args.device}')
            try:
                # The with-block closes the port on every exit path, so a
                # reconnect never leaves the previous handle open.
                with serial.Serial(args.device) as com:
                    if not args.no_flush:
                        # Drop whatever was buffered before logging started.
                        com.read_all()
                    filename = prepare_log_path(args.dir, args.format, args.index)
                    print(f'Writing serial data into {filename}')

                    with open(filename, 'wb') as f:
                        write_to_file(com, f)

            except TaskFinishedException as ex:
                print('\nTask finished: ', ex)
                return
            except Exception as ex:
                print()
                print(ex)
                if args.reconnect_timeout < 0:
                    return
                if args.reconnect_timeout > 0:
                    time.sleep(args.reconnect_timeout)
    except KeyboardInterrupt:
        # Outer handler so Ctrl-C during the reconnect sleep is handled too.
        print('\nShutting down due to keyboard interrupt')


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('-d', '--dir', type=pathlib.Path, default='data', help='Output directory')
    parser.add_argument('-D', '--device', default='/dev/ttySentiboard02', help='Input [D]evice')
    parser.add_argument('-c', '--count', type=int, default=-1, help='If enabled: capture this many packages')
    parser.add_argument('-t', '--timeout', type=float, default=-1, help='If enabled: capture data for this many seconds.')
    parser.add_argument('-f', '--format', default='data_{0:04}.senti',
                        help='Filename format template. Can take 1 parameter (index).')
    # store_true (default False): the port is flushed unless -F/--no-flush is
    # given, as the help text says. default=True with store_false inverted
    # this, so flushing happened only when --no-flush was passed.
    parser.add_argument('-F', '--no-flush', action='store_true', help='Don\'t flush all the data on the port before starting log')
    parser.add_argument('-i', '--index', type=int, default=1, help='The starting index of the filename')
    parser.add_argument('-v', '--verbose', action='store_true')
    parser.add_argument('-w', '--wait-timeout', type=float, default=0.01, help='Time to wait before next serial read')
    parser.add_argument('-a', '--reconnect-timeout', type=int, default=3, help='Serial reconnect timeout (-1 disable)')
    parser.add_argument('-z', '--report-delay', type=float, default=0.5, help='Minimum delay between print reports')
    # Kept as the module-level name `args`: write_to_file() may fall back to it.
    args = parser.parse_args()
    dump_data(args)
