#!/bin/python3

import argparse
import json
import os
import re
import shlex
import time
from datetime import datetime

import arbor
import numpy as np

###############################################################################
# getTimestamp
# Returns a previously determined timestamp, or the timestamp of the current point in time
# - refresh [optional]: if True, forcibly retrieves a new timestamp; else, only returns a new timestamp if no previous one is known
# - return: timestamp in the format YY-MM-DD_HH-MM-SS
def getTimestamp(refresh = False):
	"""Return the cached timestamp string, creating the cached value on first use.

	The datetime is kept in the module-level variable 'timestamp_var'. It is only replaced when
	no timestamp exists yet, or when 'refresh' compares equal to True.
	"""
	global timestamp_var # module-level storage, so that the value persists between calls

	if "timestamp_var" not in globals():
		timestamp_var = datetime.now() # first call: no timestamp known yet
	elif timestamp_var and refresh == True:
		timestamp_var = datetime.now() # caller explicitly asked for a new timestamp

	return timestamp_var.strftime("%y-%m-%d_%H-%M-%S")

###############################################################################
# getDataPath
# Consumes a general description for the simulation and a file description, and returns 
# a path to a timestamped file in the output directory;  if no file description is provided, 
# returns the path to the output directory
# - sim_description: general description for the simulation
# - file_description [optional]: specific name and extension for the file
# - refresh [optional]: if True, enforces the retrieval of a new timestamp
# - return: path to a file in the output directory
def getDataPath(sim_description, file_description = "", refresh = False):
	"""Return the output directory "data_<timestamp>[ <sim_description>]", or a file path inside it.

	If 'file_description' is non-empty, the result is "<directory>/<timestamp>_<file_description>".
	The timestamp comes from getTimestamp(refresh).
	"""
	stamp = getTimestamp(refresh)

	folder = "data_" + stamp
	if sim_description != "":
		folder += " " + sim_description # the description is separated from the timestamp by a space

	if file_description != "":
		return os.path.join(folder, stamp + "_" + file_description)
	return folder

    
###############################################################################
# writeLog
# Writes string to the global log file 'logf' and prints it to the console
# - ostrs: the string(s) to be written/printed
def writeLog(*ostrs):
	"""Write the given values to the global log file 'logf' and print them to the console.

	The values are converted with str() and joined by single spaces, followed by a newline.
	In the log file, ANSI SGR escape sequences (e.g. '\\x1b[31m', '\\x1b[1;35m', '\\x1b[0m') are
	removed; the console output keeps them. 'logf' must have been opened beforehand.
	"""
	# remove console formatting; '[0-9;]*' also covers multi-parameter codes such as '\x1b[1;35m'
	plain_strs = [re.sub(r'\x1b\[[0-9;]*m', '', str(ostr)) for ostr in ostrs]
	logf.write(" ".join(plain_strs) + "\n")

	print(*ostrs)

#####################################
# NetworkRecipe
# Implementation of Arbor simulation recipe
class NetworkRecipe(arbor.recipe):
	"""Arbor recipe of a recurrent network with an excitatory and an inhibitory population.

	Every neuron is a single cylindrical cable segment that is used as a point neuron. Its dynamics
	are given by the density mechanism named in config["neuron"]["mechanism"]. Each incoming
	connection gets its own non-plastic 'expsyn_curr' synapse at the center of the cell. Neurons
	with gid < N_exc are excitatory; all others are inhibitory.

	The connectivity is stored in 'conn_matrix': row 'gid' describes the incoming connections of
	neuron 'gid', and a non-zero entry in column 'j' means that neuron 'j' projects to neuron 'gid'.
	"""

	# constructor
	# - config: dictionary containing configuration data
	# - raises ValueError: if a connections file is given whose matrix does not have the shape (N_tot, N_tot)
	def __init__(self, config):

		# The base C++ class constructor must be called first, to ensure that
		# all memory in the C++ class is initialized correctly. (see https://github.com/tetzlab/FIPPA/blob/main/STDP/arbor_lif_stdp.py)
		arbor.recipe.__init__(self)
		
		self.N_exc = int(config["populations"]["N_exc"]) # number of neurons in the excitatory population
		self.N_inh = int(config["populations"]["N_inh"]) # number of neurons in the inhibitory population
		self.N_tot = self.N_exc + self.N_inh # total number of neurons (excitatory and inhibitory)
		self.p_c = config["populations"]["p_c"] # probability of connection
		
		self.props = arbor.neuron_cable_properties() # initialize the cell properties to match Neuron's defaults 
		                                             # (cf. https://docs.arbor-sim.org/en/v0.5.2/tutorial/single_cell_recipe.html)
		
		cat = arbor.load_catalogue("./custom-catalogue.so") # load the catalogue of custom mechanisms
		cat.extend(arbor.default_catalogue(), "") # add the default catalogue
		self.props.catalogue = cat

		self.runtime = config["simulation"]["runtime"] # runtime of the simulation in ms
		self.dt = config["simulation"]["dt"] # duration of one timestep in ms
		self.neuron_config = config["neuron"]
		self.syn_config = config["synapses"]
		self.h_0 = self.syn_config["h_0"]
		self.w_ei = config["populations"]["w_ei"]
		self.w_ie = config["populations"]["w_ie"]
		self.w_ii = config["populations"]["w_ii"]

		conn_file = config["populations"].get("conn_file") # a missing entry is treated like an empty one
		if conn_file: # if a connections file is specified, load the connectivity matrix from that file
			# the file is transposed on loading: in the file, row j / column gid marks the connection j -> gid;
			# 'ndmin=2' keeps the matrix two-dimensional even for a single neuron
			self.conn_matrix = np.loadtxt(conn_file, ndmin=2).transpose()
			self.p_c = -1 # the connection probability is not used for a pre-defined matrix
			if self.conn_matrix.shape != (self.N_tot, self.N_tot):
				raise ValueError("connectivity matrix in '" + str(conn_file) + "' has shape " + str(self.conn_matrix.shape) +
				                 ", expected " + str((self.N_tot, self.N_tot)))
		else: # there is no pre-defined connectivity matrix -> generate one
			rng = np.random.default_rng() # random number generator
			self.conn_matrix = rng.random((self.N_tot, self.N_tot)) <= self.p_c # two-dim. array of booleans indicating the existence of any incoming connection
			self.conn_matrix[np.identity(self.N_tot, dtype=bool)] = 0 # remove self-couplings

	# _synapse_labels
	# Returns the labels of the synapse groups that receive the incoming connections of the neuron given by gid
	# (the first letter denotes the presynaptic, the second letter the postsynaptic population)
	# - gid: global identifier of the cell
	# - return: tuple (label for excitatory input, label for inhibitory input)
	def _synapse_labels(self, gid):

		if gid < self.N_exc:
			return "syn_ee", "syn_ie"
		return "syn_ei", "syn_ii"

	# _place_synapses
	# Places 'count' identical non-plastic 'expsyn_curr' synapses at the center of the cell, all under one label
	# - decor: the arbor.decor of the cell
	# - weight: synaptic weight 'w'
	# - R_leak: membrane resistance passed to the synapse mechanism as 'R_mem'
	# - count: number of synapses to place (one per incoming connection)
	# - label: label of the synapse group (used as connection target in connections_on)
	def _place_synapses(self, decor, weight, R_leak, count, label):

		mech_expsyn = arbor.mechanism('expsyn_curr')
		mech_expsyn.set('w', weight)
		mech_expsyn.set('R_mem', R_leak)
		mech_expsyn.set('tau', self.syn_config["tau_syn"])

		for _ in range(count):
			decor.place('"center"', arbor.synapse(mech_expsyn), label) # place synapse at the center of the soma (because: point neuron)
	
	# cell_kind
	# Defines the kind of the neuron given by gid
	# - gid: global identifier of the cell
	# - return: type of the cell
	def cell_kind(self, gid):
		
		return arbor.cell_kind.cable # note: implementation of arbor.cell_kind.lif is not ready to use yet

	# cell_description
	# Defines the morphology, cell mechanism, etc. of the neuron given by gid
	# - gid: global identifier of the cell
	# - return: description of the cell
	def cell_description(self, gid):

		# cylinder morphology
		tree = arbor.segment_tree()
		radius = self.neuron_config["radius"] # radius of cylinder (in µm)
		height = 2*radius # height of cylinder (in µm)
		tree.append(arbor.mnpos,
		            arbor.mpoint(-height/2, 0, 0, radius),
		            arbor.mpoint(height/2, 0, 0, radius),
		            tag=1)
		labels = arbor.label_dict({"center": "(location 0 0.5)"})
		area_m2 = 2 * np.pi * (radius * 1e-6) * (height * 1e-6) # surface area of the cylinder in m^2 (excluding the circle-shaped ends, since Arbor does not consider current flux there)
		area_cm2 = 2 * np.pi * (radius * 1e-4) * (height * 1e-4) # surface area of the cylinder in cm^2 (excluding the circle-shaped ends, since Arbor does not consider current flux there)
		i_factor = (1e-9/1e-3) / area_cm2 # conversion factor from nA to mA/cm^2; for point neurons
		c_mem = self.neuron_config["C_mem"] / area_m2 # specific capacitance in F/m^2, computed from absolute capacitance of a point neuron
		
		# cell mechanism
		decor = arbor.decor()
		decor.set_property(Vm=self.neuron_config["V_init"], cm=c_mem)
		mech_neuron = arbor.mechanism(self.neuron_config["mechanism"])
		R_leak = self.neuron_config["R_leak"]
		tau_mem = R_leak*10**9 * self.neuron_config["C_mem"] # membrane time constant in ms
		V_rev = self.neuron_config["V_rev"]
		V_reset = self.neuron_config["V_reset"]
		V_th = self.neuron_config["V_th"]
		mech_neuron.set("R_leak", R_leak)
		mech_neuron.set("R_reset", self.neuron_config["R_reset"])
		mech_neuron.set("I_0", 0) # set to zero (background input is applied via OU process ou_bg)
		mech_neuron.set("i_factor", i_factor)
		mech_neuron.set("V_rev", V_rev)
		mech_neuron.set("V_reset", V_reset)
		mech_neuron.set("V_th", V_th)
		mech_neuron.set("t_ref", self.neuron_config["t_ref"])
		decor.paint('(all)', arbor.density(mech_neuron))

		# parameter output (once, for the first neuron if it is excitatory)
		if gid == 0 and gid < self.N_exc:
			writeLog("area =", area_m2, "m^2")
			writeLog("i_factor =", i_factor, "(mA/cm^2) / (nA)")
			writeLog("c_mem =", c_mem, "F/m^2")
			writeLog("tau_mem =", tau_mem, "ms")

		# weights of the non-plastic synapses, depending on the population of the postsynaptic neuron
		if gid < self.N_exc:
			w_exc = self.h_0 # excitatory -> excitatory
			w_inh = -self.w_ie * self.h_0 # inhibitory -> excitatory
		else:
			w_exc = self.w_ei * self.h_0 # excitatory -> inhibitory
			w_inh = -self.w_ii * self.h_0 # inhibitory -> inhibitory

		# one synapse per incoming connection; a connection is any non-zero matrix entry
		# (the same criterion as in connections_on, so that the numbers of synapses and connections match)
		exc_label, inh_label = self._synapse_labels(gid)
		incoming = self.conn_matrix[gid]
		self._place_synapses(decor, w_exc, R_leak, np.count_nonzero(incoming[0:self.N_exc]), exc_label)
		self._place_synapses(decor, w_inh, R_leak, np.count_nonzero(incoming[self.N_exc:self.N_tot]), inh_label)
			
		# place spike detector
		decor.place('"center"', arbor.spike_detector(1e12), "spike_detector") # NOTE: set threshold for testing so high that definitely NO SPIKES WILL BE DETECTED AT ALL!
			
		return arbor.cable_cell(tree, labels, decor)
		
	# connections_on
	# Defines the list of incoming synaptic connections to the neuron given by gid
	# - gid: global identifier of the cell
	# - return: connections to the given neuron (weight 1, delay 't_ax_delay'), excitatory sources first, each in ascending gid order
	def connections_on(self, gid):

		rr = arbor.selection_policy.round_robin
		d0 = self.syn_config["t_ax_delay"] # delay time of the postsynaptic potential in ms
		
		connections = np.asarray(self.conn_matrix[gid], dtype=bool) # non-zero entries mark incoming connections
		assert not connections[gid] # check that there are no self-couplings

		pre_neurons = np.flatnonzero(connections) # gids of all presynaptic neurons (ascending)
		exc_pre_neurons = pre_neurons[pre_neurons < self.N_exc] # excitatory presynaptic neurons
		inh_pre_neurons = pre_neurons[pre_neurons >= self.N_exc] # inhibitory presynaptic neurons

		exc_label, inh_label = self._synapse_labels(gid)
		connections_list = [arbor.connection((src, "spike_detector"), (exc_label, rr), 1, d0) for src in exc_pre_neurons]
		connections_list += [arbor.connection((src, "spike_detector"), (inh_label, rr), 1, d0) for src in inh_pre_neurons]

		return connections_list

	# event_generators
	# Event generators for input to synapses (none are used)
	# - gid: global identifier of the cell
	# - return: events generated from Arbor schedule (empty list)
	def event_generators(self, gid):
		
		return []
		
	# global_properties
	# Sets properties that will be applied to all neurons of the specified kind
	# - kind: kind of the cells (only cable cells are used in this network)
	# - return: the cell properties 
	def global_properties(self, kind): 

		assert kind == arbor.cell_kind.cable # assert that all neurons are technically cable cells

		return self.props
	
	# num_cells
	# - return: the total number of cells in the network
	def num_cells(self):
		
		return self.N_tot

	# probes
	# - gid: global identifier of the cell
	# - return: the probes on the given cell (none are used)
	def probes(self, gid):

		return []
	
#####################################
# arborNetworkConsolidation
# Runs simulation of a recurrent neural network with consolidation dynamics
# - config: configuration of model and simulation parameters (as a dictionary from JSON format)
def arborNetworkConsolidation(config):
	"""Set up and run the network simulation described by 'config'; return the NetworkRecipe used.

	Creates the timestamped output directory (see getDataPath). Archives the Python code, the
	'mechanisms/' directory and the configuration there. Writes the log to the global file handle
	'logf', which is closed even if the simulation raises.
	"""

	#####################################
	# create output directory, save code and config, and open log file
	s_desc = config['simulation']['short_description']
	out_path = getDataPath(s_desc, refresh=True)
	os.makedirs(out_path, exist_ok=True) # create the directory if it does not exist yet
	# 'out_path' contains the user-provided description, so it must be escaped before it is handed to the shell
	# (the wildcard '*.py' is deliberately left unquoted, so that the shell expands it)
	quoted_out_path = shlex.quote(out_path)
	os.system("cp -r *.py " + quoted_out_path) # archive the Python code
	os.system("cp -r mechanisms/ " + quoted_out_path) # archive the mechanism code
	with open(getDataPath(s_desc, "config.json"), "w") as config_file:
		json.dump(config, config_file, indent="\t")
	global logf # global handle to the log file (to need less code for output commands)
	logf = open(getDataPath(s_desc, "log.txt"), "w")

	try:
		#####################################
		# output of key parameters
		writeLog("\x1b[31mArbor network simulation " + getTimestamp() + " (Arbor version: " + str(arbor.__version__) + ")\n" + \
		         "|\n"
		         "\x1b[35mSimulated timespan, timestep:\x1b[37m " + str(config['simulation']['runtime']) + " ms, " + str(config['simulation']['dt']) + " ms\n" + \
		         "|\x1b[0m")

		#####################################
		# set up and run simulation
		recipe = NetworkRecipe(config)

		t_0 = time.time()
		clockseed = int(t_0*10000) # seed derived from the current time
		writeLog("Random seed " + str(clockseed))

		alloc = arbor.proc_allocation(threads=1, gpu_id=None) # select one thread and no GPU (default; cf. https://docs.arbor-sim.org/en/v0.7/python/hardware.html#arbor.proc_allocation)
		context = arbor.context(alloc, mpi=None) # constructs a local context without MPI connection
		meter_manager = arbor.meter_manager()
		meter_manager.start(context)

		domains = arbor.partition_load_balance(recipe, context) # constructs a domain_decomposition that distributes the cells in the model described by an arbor.recipe over the distributed and local hardware resources described by an arbor.context (cf. https://docs.arbor-sim.org/en/v0.5.2/python/domdec.html#arbor.partition_load_balance)
		meter_manager.checkpoint('load-balance', context)

		sim = arbor.simulation(recipe, context, domains, seed = clockseed)
		meter_manager.checkpoint('simulation-init', context)

		sim.progress_banner()
		sim.record(arbor.spike_recording.off)
		sim.run(tfinal=recipe.runtime, dt=recipe.dt)
		meter_manager.checkpoint('simulation-run', context)

		writeLog(arbor.meter_report(meter_manager, context))
	finally:
		logf.close() # close the log file also if the set-up or the simulation fails
	
	return recipe

#####################################
if __name__ == '__main__':

	# parse the commandline parameter 'config_file' (the remaining parameters are ignored in this first pass)
	parser = argparse.ArgumentParser()
	parser.add_argument('-config_file', required=True, help="configuration of the simulation parameters (JSON file)")
	args, _ = parser.parse_known_args()
	
	# load JSON object containing the parameter configuration as dictionary (the file is closed right after reading)
	with open(args.config_file, "r") as config_file:
		config = json.load(config_file)

	# parse the remaining commandline parameters
	parser.add_argument('-s_desc', type=str, help="short description")
	parser.add_argument('-runtime', type=float, help="runtime of the simulation in ms")
	parser.add_argument('-dt', type=float, help="duration of one timestep in ms")
	args = parser.parse_args()

	# in the dictionary containing the parameter configuration, modify the values provided by commandline arguments
	if args.s_desc is not None:
		config['simulation']['short_description'] = args.s_desc
	if args.runtime is not None:
		config['simulation']['runtime'] = args.runtime
	if args.dt is not None:
		config['simulation']['dt'] = args.dt

	# run the simulation
	arborNetworkConsolidation(config)
	
