import itasca as it

it.command("python-reset-state false")
import multiprocessing

max_threads = multiprocessing.cpu_count()
import numpy as np
import platform
import timeit
import json
import socket
import os
import subprocess
import six
import re
import tempfile
import ssl

# To run (and time) the whole script from an IPython prompt:
#   %time exec(open("_PFC3D_speed_test.py").read())
software_name = "PFC3D"  # prefix of the version string built by get_version()
benchmark_version = 1.0  # version of this benchmark script; sent with the results


def get_WMIC_data(wmic_command, output=None):
    """
    Run the Windows WMIC command and return the header-result pairs as
    a list of dicts.

    WMIC prints a fixed-width table: the first non-blank line holds the
    column headers, and the column at which each header starts is where
    that column starts in every following row.

    :param wmic_command: arguments appended to ``wmic`` (e.g. ``"cpu"``).
    :param output: optional, already-captured WMIC output text. When given,
        WMIC is not run and this text is parsed instead (useful for saved
        output, or on machines without WMIC).
    :return: one dict per non-blank data row, mapping each header to the
        whitespace-stripped cell text; ``[]`` if there is no output.
    Errors from starting or running WMIC propagate to the caller.
    """
    if output is None:
        # STARTF_USESHOWWINDOW with the default wShowWindow (SW_HIDE) keeps a
        # console window from being shown while WMIC runs.
        startupinfo = subprocess.STARTUPINFO()
        startupinfo.dwFlags |= subprocess.STARTF_USESHOWWINDOW
        output = subprocess.check_output(
            "wmic " + wmic_command, startupinfo=startupinfo
        ).decode("utf-8", "replace")
    # WMIC separates rows with blank lines ("\r\r\n"); drop them.
    lines = [s for s in output.splitlines() if s.strip()]
    if not lines:
        return []
    # Locate every header by its real start column (the last header is found
    # whole even when the line has no trailing padding). The last column is
    # open-ended so values wider than their header are not truncated.
    header_matches = list(re.finditer(r"\S+", lines[0]))
    headers = [m.group() for m in header_matches]
    starts = [m.start() for m in header_matches]
    ends = starts[1:] + [None]
    return [
        {name: line[start:end].strip()
         for name, start, end in zip(headers, starts, ends)}
        for line in lines[1:]
    ]


def get_lscpu_data(output=None):
    """ Run the Linux lscpu command (CPU info) and return the header-result pairs as
    a list of dicts.

    Each ``key: value`` line of the output becomes one entry of a single dict.
    Only the first ``:`` separates key from value, so values that contain
    colons are kept whole. Lines without a colon (e.g. blank lines) are skipped.

    :param output: optional, already-captured ``lscpu`` output text. When
        given, ``lscpu`` is not run and this text is parsed instead.
    :return: a one-element list holding the key/value dict.
    Errors from starting or running ``lscpu`` propagate to the caller.
    """
    if output is None:
        # No shell needed to run a single fixed command.
        output = subprocess.check_output(["lscpu"]).decode("utf-8", "replace")
    row = {}
    for line in output.splitlines():
        key, sep, value = line.partition(":")
        if not sep:
            continue  # not a "key: value" line
        row[key.strip()] = value.strip()
    return [row]


# def get_wmi_data(wmi_component):
#     import wmi
#     c = wmi.WMI()
#
#     if wmi_component == 'cpu':
#         cpu_info = []
#         for cpu in c.Win32_Processor():
#             cpu_details = {
#                 "Name": cpu.Name,
#                 "L3CacheSize": cpu.L3CacheSize,  # in KB
#                 "Cores": cpu.NumberOfCores,
#                 "Threads": cpu.NumberOfLogicalProcessors,
#                 "Architecture": cpu.SystemArchitecture,
#                 "Manufacturer": cpu.Manufacturer
#             }
#             cpu_info.append(cpu_details)
#         return cpu_info
#
#     if wmi_component == 'memorychip':
#         memory_info = []
#
#         # Fetch memory details
#         for mem in c.Win32_PhysicalMemory():
#             memory_details = {
#                 "Capacity": mem.Capacity,  # in bytes
#                 "Speed": mem.Speed,  # in MHz
#                 "Manufacturer": mem.Manufacturer
#             }
#             memory_info.append(memory_details)
#         return memory_info


def get_CPU_info():
    """
    Return a list of dicts describing the CPU(s), for the benchmark report.

    Uses ``lscpu`` on POSIX systems and ``wmic cpu`` on Windows. This data is
    informational only and is gathered after the benchmark has already run,
    so a failure here must not abort the script. If the native tool fails, or
    the platform is neither, fall back to a single record built from
    ``platform.processor()``. If even that fails, return an empty list.
    """
    try:
        if os.name == "posix":
            return get_lscpu_data()
        if os.name == "nt":
            # NOTE: a ``wmi``-module fallback (see the commented-out
            # get_wmi_data) could be tried before the generic one below.
            return get_WMIC_data("cpu")
    except Exception:
        pass  # deliberate: use the portable fallback below
    try:
        return [{'Name': platform.processor(), 'L3CacheSize': ''}]
    except Exception:
        return []


def get_RAM_info():
    """
    Return a list of dicts describing the installed memory modules.

    Only implemented on Windows (``wmic memorychip``). Every other platform,
    and any failure to run WMIC, gives an empty list so the benchmark report
    can still be built and uploaded.
    """
    if os.name != "nt":
        return []
    try:
        return get_WMIC_data("memorychip")
    except Exception:
        # NOTE: a ``wmi``-module fallback (see the commented-out
        # get_wmi_data) could be tried here.
        return []


# Program version numbers; (0, 0) until get_version() has been called.
code_major, code_minor = 0, 0


def get_version():
    """
    Query the running program's version and return ``"PFC3D_<major>.<minor>"``.

    The major/minor numbers are copied into the FISH globals ``v0`` and ``v1``
    and read back through ``it.fish``. As a side effect they are also stored
    in the module globals ``code_major`` and ``code_minor``.
    """
    global code_major, code_minor
    for fish_var, fish_expr in (("v0", "version.code.major"),
                                ("v1", "version.code.minor")):
        it.command("[global {} = {}]".format(fish_var, fish_expr))
    code_major = it.fish.get("v0")
    code_minor = it.fish.get("v1")
    return "{}_{}.{}".format(software_name, code_major, code_minor)


def create_initial_model(model, rad, save=True, pre_cycle=50):
    """
    Build the benchmark ball assembly from scratch and optionally save it.

    Starts a new small-strain model in the domain [-10, 10] and generates a
    hexagonal packing of balls of radius ``rad`` in the box [-5, 5]. The outer
    layer of balls (within ``2*rad`` of any box face) goes into group
    ``'boundary'``, and those balls have their velocity and spin fixed. Balls
    with x in [-5, 0] get their radius multiplied by 1.001; the trailing
    ``not`` presumably excludes the 'boundary' group (confirm against the
    range-syntax docs), which would seed overlaps so there is contact activity
    to time. The ``model`` commands are then applied and ``model clean all``
    is run.

    :param model: PFC command text inserted verbatim after the assembly is
        generated (e.g. an entry of ``models`` selecting the contact model).
    :param rad: ball radius. It is passed to FISH as the global ``rad``, which
        also sets the thickness of the boundary layer.
    :param save: if true, save the state as ``'timing_test'`` so it can be
        restored later with ``model restore``.
    :param pre_cycle: currently unused; kept for interface compatibility.
    """
    # ``rad`` is now substituted into the FISH global (it used to be
    # hard-coded to 0.05, silently ignoring the argument).
    it.command("""
               model new
               model largestrain off
               fish automatic-create off
               model domain extent -10 10
               [global rad = {rad}]
               ball generate hexagonal radius [rad] box -5 5
               ball attribute density 1000.0
               ball group 'boundary' range position-x [-5+2.0*rad] [5-2.0*rad] not
               ball group 'boundary' range position-y [-5+2.0*rad] [5-2.0*rad] not  
               ball group 'boundary' range position-z [-5+2.0*rad] [5-2.0*rad] not  
               ball fix velocity spin range group 'boundary' 
               ball attribute radius multiply 1.001 range position-x -5.0 0.0 group 'boundary' not 
               {model}
               model clean all
    """.format(model=model, rad=rad))
    if save:
        it.command("model save 'timing_test'")


# Contact-model setups to benchmark, keyed by the name passed to run_large().
# Each snippet is inserted into create_initial_model()'s command text and uses
# ``contact cmat default`` to choose the contact model and set its ``kn``
# property to 1e6.
models = {"linear": """
                contact cmat default model linear property kn 1e6
                """,
          "softbond": """
                contact cmat default model softbond property kn 1e6
            """
          }

# Wall-clock stamp (timeit.default_timer) taken when cycling starts.
# Written by start_timer(), read by measure_cycling_speed().
start_time = 0


def start_timer(*args):
    """
    One-shot cycle-point callback that records when cycling starts.

    measure_cycling_speed() registers it at cycle point -1. On its first
    invocation it unregisters itself and then stamps ``start_time``, so only
    that first call counts. Any arguments supplied by the callback mechanism
    are ignored.
    """
    global start_time
    it.remove_callback("start_timer", -1)
    start_time = timeit.default_timer()


def measure_cycling_speed(step_count):
    """
    Cycle the current model ``step_count`` steps and return its throughput.

    The clock starts in the ``start_timer`` callback rather than here.
    Presumably this excludes command/setup overhead before the first cycle;
    confirm against the callback documentation.

    :return: contacts * steps / elapsed seconds / 1000, i.e. thousands of
        contact-steps per second. The contact count is read after cycling.
    """
    it.set_callback("start_timer", -1)
    it.command("model cycle {}".format(step_count))
    elapsed = timeit.default_timer() - start_time
    contact_steps = it.contact.count() * float(step_count)
    return contact_steps / elapsed / 1000.0


def run_large(constitutive_model):
    """
    Benchmark one contact model in small-strain mode, then in large-strain mode.

    Builds and saves the initial assembly with the commands from
    ``models[constitutive_model]`` (ball radius 0.05) and sets the thread count
    to ``max_threads``. It then times 50 cycles, restores the saved state,
    switches large-strain on and times 50 cycles again. Both speeds are printed.

    :param constitutive_model: key into ``models`` (e.g. ``"linear"``).
    :return: ``(small_strain_speed, large_strain_speed)`` as returned by
        measure_cycling_speed().
    :raises KeyError: if ``constitutive_model`` is not in ``models``.
    """
    n_steps = 50
    contact_setup = models[constitutive_model]
    create_initial_model(contact_setup, 0.05, save=True)
    it.command("program threads {}".format(max_threads))

    small_strain = measure_cycling_speed(n_steps)
    six.print_("{} large test small-strain speed ".format(constitutive_model), small_strain)

    # Go back to the saved starting state so both runs cycle the same model.
    it.command("\n    model restore 'timing_test'\n    model largestrain on\n    ")
    large_strain = measure_cycling_speed(n_steps)
    six.print_("{} large test large-strain speed ".format(constitutive_model), large_strain)
    return small_strain, large_strain


def upload_data(jsondata, timeout=60, verify_ssl=False):
    """
    POST a JSON document to the benchmark results service and return the reply.

    :param jsondata: JSON text (``str``) to send; it is encoded as UTF-8.
    :param timeout: seconds to wait on the connection before giving up, so a
        stalled network cannot hang the program indefinitely.
    :param verify_ssl: verify the server's TLS certificate and hostname.
        Off by default to preserve the original behaviour; the data sent is
        public benchmark output. Enable it where the local CA store allows.
    :return: the raw response body (``bytes``).
    Network and HTTP errors from ``urlopen`` propagate to the caller.
    """
    url = "https://ue9zylc163.execute-api.us-east-1.amazonaws.com/prod"
    payload = jsondata.encode('utf-8')
    req = six.moves.urllib.request.Request(url)
    req.add_header('Content-Type', 'application/json; charset=utf-8')
    req.add_header('Content-Length', len(payload))
    # create_default_context() negotiates the highest protocol both sides
    # support instead of pinning the deprecated PROTOCOL_TLSv1_2.
    ssl_context = ssl.create_default_context()
    if not verify_ssl:
        # check_hostname must be cleared before CERT_NONE, or ssl raises ValueError.
        ssl_context.check_hostname = False
        ssl_context.verify_mode = ssl.CERT_NONE
    response = six.moves.urllib.request.urlopen(req, payload, timeout, context=ssl_context)
    try:
        return response.read()
    finally:
        response.close()


# --- Script body: run the benchmarks, collect system info, upload results ---
version = get_version()

lin_large_small_strain_speed, lin_large_large_strain_speed = run_large("linear")
sb_large_small_strain_speed, sb_large_large_strain_speed = run_large("softbond")
op_sys = platform.system()
hostname = socket.gethostname()

# Speeds (thousands of contact-steps per second, from measure_cycling_speed())
# are truncated to ints for the report.
data_package = {"version": version,
                "op_sys": op_sys,
                "benchmark_version": benchmark_version,
                "hostname": hostname,
                "lin_large_ss_speed": int(lin_large_small_strain_speed),
                "lin_large_ls_speed": int(lin_large_large_strain_speed),
                "sb_large_ss_speed": int(sb_large_small_strain_speed),
                "sb_large_ls_speed": int(sb_large_large_strain_speed),
                "cpu_info": get_CPU_info(),
                "memory_info": get_RAM_info(),
                "vCPUs": int(max_threads)}

body = json.dumps(data_package)
try:
    six.print_(upload_data(body))
    six.print_("-" * 80)
    six.print_("Thank you for running the PFC3D speed test.")
    six.print_("You can see the results at benchmark.itascacloud.com")
    six.print_("-" * 80)
except Exception as err:
    # The upload is best effort. Never lose the (slow to produce) results:
    # save them to a file the user can send in by e-mail.
    six.print_(err)
    with tempfile.NamedTemporaryFile(mode="w",
                                     prefix="itasca_benchmark_",
                                     suffix=".json",
                                     delete=False) as tfile:
        six.print_(err, file=tfile)
        six.print_(body, file=tfile)
    # Report only after the file is closed, so its contents are on disk.
    six.print_("*" * 70)
    six.print_(
        "An error occurred while trying to upload the benchmark results.\nThe results have been written to {}\nPlease email this file to jfurtney@itascacg.com and r.legoc@itasca.fr.".format(
            tfile.name))
    six.print_("*" * 70)

# Remove the intermediate state saved by create_initial_model().
# NOTE(review): PFC3D's `model save` is expected to write "timing_test.p3sav"
# (".f3sav" is FLAC3D's extension); both are tried. Confirm the extension.
for save_file in ("timing_test.p3sav", "timing_test.f3sav"):
    try:
        os.remove(save_file)
    except OSError:
        pass  # absent or not removable: nothing to clean up
