#!/usr/bin/env python

import datetime
import os
import random
import shutil
import signal
import string
import sys

import docker
import pandas as pd
import SimpleITK as sitk

# Evaluation workspace, all next to this script: INPUT is mounted into the
# container, GT holds the reference labels, PREDICT receives container output.
DIR = os.path.dirname(os.path.abspath(__file__))
INPUT, GT, PREDICT = (os.path.join(DIR, sub) for sub in ('INPUT', 'GT', 'PREDICT'))

# Every run starts without a previous prediction directory (absence is fine).
shutil.rmtree(PREDICT, ignore_errors=True)

# Shared Docker client built from the environment.
client = docker.from_env()

# Limits applied to each evaluation container: memlock -1 (no limit) and a
# stack limit of 67108864 == 64 * 1024 * 1024.
ulimits = [
    docker.types.Ulimit(name='memlock', soft=-1, hard=-1),
    docker.types.Ulimit(
        name='stack',
        soft=64 * 1024 * 1024,
        hard=64 * 1024 * 1024,
    ),
]

# tmpfs mount at /tmp inside the container, with an empty option string.
tmpfs = {'/tmp': ''}


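# Disabled SIGALRM timeout handler; the commented-out signal.signal()/alarm()
# setup in load_and_run refers to it. NOTE: it reads `name`, which is not
# defined at module scope, so the container name would have to be made
# available to it before it could be re-enabled.
# def _handle_timeout(signum, frame):
#     # raise TimeoutError('function timeout')
#     c = client.containers.get(name)
#     print(datetime.datetime.now(), 'removing', name, c)
#     c.stop()
#     c.remove()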


def container_run(image, input, output, device_id='0'):
    """Run ``image`` once in a new GPU-enabled container and wait for it.

    The container gets a random 10-letter name and the following setup:

    * ``input`` mounted read-only at /INPUT;
    * ``output`` mounted read-write at /OUTPUT;
    * host IPC mode;
    * the module-level ``ulimits`` and ``tmpfs`` settings.

    It is auto-removed by the daemon once it exits.

    Args:
        image: image reference accepted by ``client.containers.run`` (a
            name/tag string or an Image object returned by
            ``client.images.load``).
        input: host directory with the cases to process.
        output: host directory that receives the container's results; it is
            created here if it does not exist yet.
        device_id: GPU id exposed to the container (default '0').

    Returns:
        The return value of ``client.containers.run``. With ``detach`` left at
        its default, docker-py documents this as the container's logs.
    """
    name = ''.join(random.choice(string.ascii_letters) for _ in range(10))
    print(datetime.datetime.now(), 'container will be named', name)

    # Pre-create the output directory. A missing bind-mount source would
    # otherwise be created by the Docker daemon itself (usually owned by
    # root), which can leave this script unable to write into it afterwards.
    os.makedirs(output, exist_ok=True)

    volumes = {
        os.path.abspath(input): {'bind': '/INPUT', 'mode': 'ro'},
        os.path.abspath(output): {'bind': '/OUTPUT', 'mode': 'rw'},
    }

    # Expose only the requested GPU to the container.
    device_requests = [
        docker.types.DeviceRequest(
            device_ids=[device_id],
            capabilities=[['gpu']],
        ),
    ]

    print(datetime.datetime.now(), device_requests)
    print(datetime.datetime.now(), 'running', image)

    result = client.containers.run(
        image,
        device_requests=device_requests,
        ipc_mode='host',
        ulimits=ulimits,
        # Daemon-side removal when the container's process exits.
        auto_remove=True,
        name=name,
        volumes=volumes,
        tmpfs=tmpfs,
    )
    return result


def load_and_run(im, input, output, device_id='0'):
    """Run a submission image on ``input`` and write its results to ``output``.

    If ``im`` is a path to an existing file, the image is loaded from that
    archive with ``client.images.load``. The loaded image is force-removed
    again afterwards, whether or not the run succeeded. Otherwise ``im`` is
    passed through unchanged as the image reference.

    Args:
        im: path to an image archive, or an image reference string.
        input: host directory mounted read-only as /INPUT.
        output: host directory mounted read-write as /OUTPUT.
        device_id: GPU id forwarded to ``container_run``.

    Returns:
        Whatever ``container_run`` returns.
    """
    # timeout_sec = 60*60
    # signal.signal(signal.SIGALRM, _handle_timeout)
    # signal.alarm(timeout_sec)

    loaded = os.path.isfile(im)
    if loaded:
        print(datetime.datetime.now(), 'loading', im)
        with open(im, 'rb') as f:
            image = client.images.load(f)[0]
    else:
        image = im
        # image = client.images.pull(im)

    try:
        result = container_run(image, input, output, device_id)
    finally:
        # Cancel any pending SIGALRM (the timeout setup above is disabled).
        signal.alarm(0)
        # Clean up an image we loaded ourselves even when the run fails, so a
        # crashing submission does not leave its image behind on the host.
        if loaded:
            client.images.remove(image.id, force=True)

    print(datetime.datetime.now(), result)
    return result


def overlap(gt, pred):
    """Return overlap scores for label 1 between a reference and a prediction.

    Args:
        gt: ground-truth label image.
        pred: predicted label image, expected to be aligned with ``gt``.

    Returns:
        dict with 'precision', 'recall', 'dice' and 'jaccard' for label 1.
        A precision or recall outside [0, 1] is reported as 0. This includes
        NaN, which fails every comparison.
    """
    def _unit_or_zero(value):
        # Collapses NaN and any out-of-range value to 0.
        return value if 0 <= value <= 1 else 0

    stats = sitk.LabelOverlapMeasuresImageFilter()
    # The prediction is passed as the source image, the ground truth as target.
    stats.Execute(pred, gt)

    # SimpleITK 2.3 renamed GetFalsePositiveError to GetFalseDiscoveryRate;
    # support both spellings.
    get_fdr = (stats.GetFalseDiscoveryRate
               if hasattr(stats, 'GetFalseDiscoveryRate')
               else stats.GetFalsePositiveError)

    return {
        'precision': _unit_or_zero(1 - get_fdr(1)),
        'recall': _unit_or_zero(1 - stats.GetFalseNegativeError(1)),
        'dice': stats.GetDiceCoefficient(1),
        'jaccard': stats.GetJaccardCoefficient(1),
    }


def _zero_case(name):
    """Return the all-zero score row recorded for a case that cannot be scored."""
    return {'name': name, 'precision': 0, 'recall': 0, 'dice': 0}


def _score_case(name, gt_path, pred_path):
    """Score one ground-truth/prediction file pair and return its result row.

    Both images are read as float32 and binarised with
    ``BinaryThreshold(img, 0.5, 300)``. SimpleITK's default inside/outside
    values are 1/0, so voxels in [0.5, 300] become label 1.

    The prediction is then given the ground truth's origin, spacing and
    direction, so that small header differences do not make the overlap
    filter reject the pair. The two images presumably still need the same
    size.
    """
    lab = sitk.ReadImage(gt_path, sitk.sitkFloat32)
    seg = sitk.ReadImage(pred_path, sitk.sitkFloat32)
    lab = sitk.BinaryThreshold(lab, 0.5, 300)
    seg = sitk.BinaryThreshold(seg, 0.5, 300)

    # Compare on the voxel grid: copy the full geometry, including direction.
    # Copying only origin and spacing still fails on direction mismatches.
    seg.SetOrigin(lab.GetOrigin())
    seg.SetSpacing(lab.GetSpacing())
    seg.SetDirection(lab.GetDirection())

    scores = overlap(lab, seg)
    return {
        'name': name,
        'precision': scores['precision'],
        'recall': scores['recall'],
        'dice': scores['dice'],
    }


def measure(GT_DIR, PREDICT_DIR):
    """Score every case in ``GT_DIR`` against ``PREDICT_DIR`` and save the table.

    Each entry of ``GT_DIR``, in sorted order, yields one row named by the part
    of its file name before the first '.'. An all-zero row is recorded
    instead, rather than aborting the evaluation, in two cases:

    * the entry's name does not contain '.nii.gz';
    * no prediction file with the same name exists in ``PREDICT_DIR``.

    The table is written to ``PREDICT_DIR/-0.xlsx``.

    Returns:
        pandas.DataFrame with columns name, precision, recall, dice.
    """
    print(GT_DIR, PREDICT_DIR)

    rows = []
    for item in sorted(os.listdir(GT_DIR)):
        case = item.split('.')[0]
        pred_path = os.path.join(PREDICT_DIR, item)
        if '.nii.gz' not in item:
            row = _zero_case(case)
        elif not os.path.isfile(pred_path):
            # A crashed or incomplete submission scores 0 on this case rather
            # than losing the scores of every other case.
            print(datetime.datetime.now(), 'missing prediction', pred_path)
            row = _zero_case(case)
        else:
            row = _score_case(case, os.path.join(GT_DIR, item), pred_path)
        print(row)
        rows.append(row)

    measures = pd.DataFrame(rows)
    # PREDICT_DIR may not exist if the container produced no output at all.
    os.makedirs(PREDICT_DIR, exist_ok=True)
    measures.to_excel(os.path.join(PREDICT_DIR, '-0.xlsx'))

    return measures

def eval_image(image, input, gt, predict, device_id):
    """Run ``image`` on ``input``, then score ``predict`` against ``gt``.

    The container's own return value is discarded; only the scores are kept.

    Returns:
        The DataFrame produced by ``measure``.
    """
    load_and_run(image, input, predict, device_id)
    return measure(gt, predict)


if __name__ == '__main__':
    # Usage: <script> docker_image [GPU_id]; GPU_id defaults to '0'.
    argc = len(sys.argv)
    if argc < 2:
        print('Usage:', sys.argv[0], 'docker_image [GPU_id]')
        print(' docker_image: sample.tar.zst | xfuren/icts2023-sample:1126')
        print(' GPU_id      : default 0')
        # Use sys.exit rather than the site-provided exit(), which is meant for
        # the interactive shell and is absent under `python -S`. Return a
        # non-zero status so calling scripts can detect the usage error.
        sys.exit(1)
    device_id = sys.argv[2] if argc > 2 else '0'
    eval_image(sys.argv[1], INPUT, GT, PREDICT, device_id)
