-- SparseMatrix.lua
-- 
-- Copyright (c) 2023 Plexim GmbH
-- All rights reserved.

-- Module table returned at the end of this file; it exposes getBlock(),
-- which builds the FPGACoder block object.
local Module = {}

local Utils = require('blocks.BlockUtils')
local SparseMatrix = require('blocks.FPGACoder.SparseMatrix')
local YInfo = require('blocks.FPGACoder.YInfo')
local Utils = require('blocks.BlockUtils')

--- Decide whether a signal has a constant value at code-generation time.
-- The signal is constant when its code parses to a number. A signal without a
-- numeric value whose metadata reports it as unconnected counts as constant 0.
-- @param aSignal signal code
-- @param aMetaData optional metadata describing the signal
-- @return true/false for "is constant", followed by the value (or nil)
local function isSignalConst(aSignal, aMetaData)
  local value = Utils.stringToNumber(aSignal)
  if value == nil and aMetaData and Utils.isUnconnectedMetaData(aMetaData) then
    value = 0
  end
  return value ~= nil, value
end
  
--- Build the C initializer list of gate values for all DCS devices.
-- Each device contributes FLEXARRAY_DCS_GATE_VALUE_STATIC_1 / _STATIC_0 when
-- its gate signal is constant (positive / non-positive), otherwise
-- FLEXARRAY_DCS_GATE_VALUE_DYNAMIC.
-- @param aSignalKey device field holding the gate signal code
-- @param aMetaDataKey device field holding that signal's metadata
-- @return "0" when there are no DCS devices, else a comma-terminated list with
--   one entry per device, in device order
local function processGateSignals(aSignalKey, aMetaDataKey)
  local devices = Block.dcsDevices
  if #devices == 0 then
    return "0"
  end

  local entries = {}
  -- ipairs: entry k of the emitted C array must describe device k, and pairs
  -- does not guarantee ascending key order.
  for i, dev in ipairs(devices) do
    local isConst, n = isSignalConst(dev[aSignalKey], dev[aMetaDataKey])
    if isConst then
      entries[i] = (n > 0) and "FLEXARRAY_DCS_GATE_VALUE_STATIC_1, "
        or "FLEXARRAY_DCS_GATE_VALUE_STATIC_0, "
    else
      entries[i] = "FLEXARRAY_DCS_GATE_VALUE_DYNAMIC, "
    end
  end

  return table.concat(entries)
end

--- Build the comma-terminated list of PWM capture indices for UHB gate signals.
-- @param aSignals entries with fields InputSignal and InputMetaData
--   (NOTE(review): presumably a 1-based sequence, since output order matters)
-- @param aComponent component path, used in the error message
-- @return e.g. "3, -1, " -- the metadata port for signals coming from a
--   PWMCapture block, -1 for signals that are constant 0 or unconnected
-- @raise table { errMsg = ... } for any other kind of signal
local function extractPwmIndices(aSignals, aComponent)
  local parts = {}
  for _, s in pairs(aSignals) do
    local input = s.InputSignal
    local metaData = s.InputMetaData
    if metaData ~= nil and metaData["port"] ~= nil and metaData["port"] ~= "" then
      -- Append the port text directly; it is never re-used as a format string.
      parts[#parts + 1] = tostring(metaData["port"]) .. ", "
    elseif Utils.stringToNumber(input) == 0 or Utils.isUnconnectedMetaData(metaData) then
      parts[#parts + 1] = "-1, "
    else
      error({ errMsg = string.format("Error in %s: For FPGA simulation, power modules must " ..
        "be connected directly to PWMCapture blocks.", tostring(aComponent))
      })
    end
  end
  return table.concat(parts)
end

--- Build the comma-terminated list of UHB bias values.
-- @param aValues bias values, each of which must be 0 or 1
--   (NOTE(review): presumably a 1-based sequence, since output order matters)
-- @param aComponent component path, used in the error message
-- @return e.g. "0, 1, "
-- @raise table { errMsg = ... } if any value is neither 0 nor 1
local function extractBiasValues(aValues, aComponent)
  local parts = {}
  for _, b in pairs(aValues) do
    if b == 0 or b == 1 then
      parts[#parts + 1] = string.format("%d, ", b)
    else
      error({ errMsg = string.format("Error in %s: Inadmissable bias value (%s). FPGA simulation " ..
        "requires bias value to be either 0 or 1.", tostring(aComponent), tostring(b))
      })
    end
  end
  return table.concat(parts)
end

--- Turn a value caught by pcall into a readable error message.
-- @param aVal the error value: nil, a string, or a table with an errMsg field
-- @return the message string; "Unknown error information." for anything else
local function processError(aVal)
  if aVal == nil then
    return "Error information is nil"
  end
  local kind = type(aVal)
  if kind == "string" then
    return aVal
  end
  -- Only tables can carry errMsg; indexing a number or boolean error value
  -- would raise a fresh error while reporting the original one.
  if kind == "table" and aVal.errMsg ~= nil then
    return aVal.errMsg
  end
  return "Unknown error information."
end

--- Map an input index to its index among the inputs that are set per step.
-- Inputs whose metadata has sourceInfo.fromNanostep == 1 are skipped, so each
-- such input with a smaller index shifts this one down by one.
-- @param aOriginalIndex original input index (number or numeric string)
-- @param aInputSignals table of input signals keyed by input index
-- @param aInputMetaData optional table of input metadata keyed by input index
-- @return the adjusted index, as a string
local function psInputIndex(aOriginalIndex, aInputSignals, aInputMetaData)
  local index = tonumber(aOriginalIndex)
  local limit = tonumber(aOriginalIndex)
  local skipped = 0

  for key in pairs(aInputSignals) do
    local meta = aInputMetaData and aInputMetaData[key]
    local info = meta and meta.sourceInfo
    if info and info.fromNanostep == 1 and tonumber(key) < limit then
      skipped = skipped + 1
    end
  end

  if skipped > 0 then
    index = index - skipped
  end
  return tostring(index)
end

--- Create the FPGACoder block object.
-- Extends the generic block object from require('blocks.block') with the two
-- FlexArray code generators below. Both read the model description from the
-- global `Block` table.
-- @param globals forwarded to require('blocks.block').getBlock
-- @return the extended FPGACoder block object
function Module.getBlock(globals)
  local FPGACoder = require('blocks.block').getBlock(globals)

  --- Generate the code that reads FlexArray results (meter outputs and DCS
  -- conduction states).
  -- @return on success a table with OutputCode (empty StringList), MeterInfos
  --   and ConductionStateCode; a message string if Block.Version ~= 1; or
  --   nil plus { message = ... } when the RT Box simulation mode is 2 ("RCP").
  function FPGACoder:generateDFTCode()
    if (Block.Version ~= 1) then
      return "Incompatible version of Target Support Package detected. Please install" ..
      " the latest version of the RT Box Target Support Package."
    end

    -- Read the simulation mode from the RT Box target, or from the original
    -- target when the current target wraps an RT Box target.
    local simulationMode = 0
    if Target.Family == 'PLECS RT Box' then
      simulationMode = Target.Variables.simulationMode
    elseif Target.OriginalTarget and Target.OriginalTarget.Family == 'PLECS RT Box' then
      simulationMode = Target.OriginalTarget.Variables.simulationMode
    end
    if simulationMode == 2 then
      return nil, { message = 'FlexArray solver cannot be used when parameter @param:simulationMode: is set to "RCP".' }
    end

    --
    -- MeterInfos: keyed by meter index (as a string). Each entry holds the C
    -- expression reading the meter, its output row and its negation flag.
    --
    local yInfo = YInfo.fromHash(Block)
    local MeterInfos = {}
    for _, meterIndex in pairs(yInfo.mOldExternalMeters) do
      local offset = yInfo:getExternalMeterOffset(meterIndex)
      local negated = yInfo.mNegated[meterIndex]
      local prefix = negated and "-" or ""
      MeterInfos[tostring(meterIndex)] = {
        Code = "%sflexarrayGetOutput(%d)" % { prefix, offset },
        RowIdx = yInfo:getRowForMeter(meterIndex),
        Negated = negated
      }
    end

    --
    -- Conduction states: entry k reads the state of DCS device k, which has
    -- 0-based index k-1 on the C side. The entries must stay in device order.
    --
    local ConductionStateCode = StringList:new()
    for i = 1, #Block.dcsDevices do
      ConductionStateCode:append("flexarrayGetConductionState(%d)" % { i - 1 })
    end

    return {
      OutputCode = StringList:new(),
      MeterInfos = MeterInfos,
      ConductionStateCode = ConductionStateCode
    }
  end

  --- Generate the FlexArray model set-up code (InitCode) and the per-step
  -- input/gate update code (UpdateCode).
  -- @return a table { InitCode, UpdateCode } on success, or a message string
  --   (incompatible TSP version, all outputs unused, invalid UHB configuration)
  function FPGACoder:generateNonDFTCode()
    if (Block.Version ~= 1) then
      return "Incompatible version of Target Support Package detected. Please install" ..
      " the latest version of the RT Box Target Support Package."
    end

    --
    -- Model dimensions, matrices and boundary checks
    --
    local ndcs = #Block.dcsDevices
    local nu = Block.nu
    local nx = Block.nx
    local yInfo = YInfo.fromHash(Block)
    local ny = yInfo:numY()
    local mat_A = SparseMatrix.fromHash(Block.mat_A)
    local mat_B = SparseMatrix.fromHash(Block.mat_B)
    local mat_C = SparseMatrix.fromHash(Block.mat_C)
    local mat_D = SparseMatrix.fromHash(Block.mat_D)
    local mat_I = SparseMatrix.fromHash(Block.mat_I)

    if yInfo:numY_ext() == 0 then
      return "All outputs of the FPGA simulation are unused."
    end

    --
    -- InitCode: static C tables describing the model, the model struct and the
    -- flexarraySetup() call, all inside one C block "{ ... }". Every table gets
    -- a placeholder "0"/"-1" entry when it would otherwise be empty.
    --
    local InitCode = StringList:new()

    -- gate values (static 0 / static 1 / dynamic) per DCS device
    local dcsUpperGateValues = processGateSignals("upperGateSignal", "upperGateSignalMetaData")
    local dcsLowerGateValues = processGateSignals("lowerGateSignal", "lowerGateSignalMetaData")
    InitCode:append("{\n")
    InitCode:append("static const enum FlexarrayDcsGateValue dcs_upper_gate_values[] = {\n%s\n};\n" % { dcsUpperGateValues } )
    InitCode:append("static const enum FlexarrayDcsGateValue dcs_lower_gate_values[] = {\n%s\n};\n" % { dcsLowerGateValues } )

    -- Output rows of the external meters. numY_ext() is numeric and 0 is
    -- truthy in Lua, so compare explicitly. The early return above already
    -- excludes 0; the fallback only keeps the C array non-empty.
    local y_ext_indices = ""
    if yInfo:numY_ext() > 0 then
      for _, idx in pairs(yInfo.mNewExternalMeters) do
        y_ext_indices = y_ext_indices .. ("%d, " % { yInfo:getRowForMeter(idx) })
      end
    else
      y_ext_indices = "0"
    end
    InitCode:append("static const int y_ext_indices[] = {\n%s\n};\n" % { y_ext_indices } )

    -- initial state vector
    local dcsInitialStateValues = ""
    if nx > 0 then
      for _, s in pairs(Block.initialStateValues) do
        dcsInitialStateValues = dcsInitialStateValues .. ("%.9g, " % { s })
      end
    else
      dcsInitialStateValues = "0"
    end
    InitCode:append("static const float x_init[] = {\n%s\n};\n" % { dcsInitialStateValues } )

    local dcs_u_indices = ""
    local dcs_y_indices = ""
    local dcsInitialConductivity = ""
    local dcsStaticUpperVoltages = ""
    local dcsStaticLowerVoltages = ""
    local dcsGateLogic = ""
    local dcsResonatorFreq = ""

    -- Gate logic from the device's turn-on/turn-off capabilities: no gated
    -- turn-on -> diode, gated turn-on only -> thyristor, both -> IGCT.
    local function computeGateLogic(device)
      if not device.gatedTurnOn then
        return "FLEXARRAY_DCS_GATE_LOGIC_DIODE"
      elseif not device.gatedTurnOff then
        return "FLEXARRAY_DCS_GATE_LOGIC_THYRISTOR"
      else
        return "FLEXARRAY_DCS_GATE_LOGIC_IGCT"
      end
    end

    -- Per-device DCS tables; also collect the UHB data attached to devices.
    -- ipairs keeps entry k of every C array aligned with device k.
    local uhbDevices = {}
    if ndcs > 0 then
      for _, dev in ipairs(Block.dcsDevices) do
        dcs_u_indices = dcs_u_indices .. ("%d, " % { dev.sourceIndex })
        local rowIdx = yInfo:getRowForMeter(dev.meterIndex)
        dcs_y_indices = dcs_y_indices .. ("%d, " % { rowIdx })
        dcsInitialConductivity = dcsInitialConductivity ..
          ("%s, " % { "FLEXARRAY_DCS_CONDSTATE_" .. dev.initialConductivity })
        if dev.uhbData ~= nil then
          table.insert(uhbDevices, dev.uhbData)
        end
        dcsGateLogic = dcsGateLogic .. ("%s, " % { computeGateLogic(dev) })
        dcsStaticUpperVoltages = dcsStaticUpperVoltages .. ("%.18g, " % { dev.upperVoltageInit })
        dcsStaticLowerVoltages = dcsStaticLowerVoltages .. ("%.18g, " % { dev.lowerVoltageInit })
        dcsResonatorFreq = dcsResonatorFreq .. ("%.18g, " % { dev.fundamentalFrequency })
      end
    else
      dcs_u_indices = "-1"
      dcs_y_indices = "-1"
      dcsInitialConductivity = "0"
      dcsGateLogic = "0"
      dcsStaticUpperVoltages = "0"
      dcsStaticLowerVoltages = "0"
      dcsResonatorFreq = "0"
    end

    InitCode:append("static const enum FlexarrayDcsConductionState dcs_initial_conduction_state[] = {\n%s\n};\n" % { dcsInitialConductivity } )
    InitCode:append("static const int dcs_u_indices[] = {\n%s\n};\n" % { dcs_u_indices } )
    InitCode:append("static const int dcs_y_indices[] = {\n%s\n};\n" % { dcs_y_indices } )
    InitCode:append("static const enum FlexarrayDcsGateLogic dcs_gate_logic[] = {\n%s\n};\n" % { dcsGateLogic } )
    InitCode:append("static const double dcs_static_upper_voltages[] = {\n%s\n};\n" % { dcsStaticUpperVoltages } )
    InitCode:append("static const double dcs_static_lower_voltages[] = {\n%s\n};\n" % { dcsStaticLowerVoltages } )
    InitCode:append("static const double dcs_resonator_freq[] = {\n%s\n};\n" % { dcsResonatorFreq } )

    -- state-space matrices; C and D are restricted to yInfo.mOriginalIndices
    InitCode:append(mat_A:toCode("mat_A"))
    InitCode:append(mat_B:toCode("mat_B"))
    InitCode:append(mat_C:filter(yInfo.mOriginalIndices):toCode("mat_C"))
    InitCode:append(mat_D:filter(yInfo.mOriginalIndices):toCode("mat_D"))
    InitCode:append(mat_I:toCode("mat_I"))

    -- begin UHB code
    local uhbDcsIndices = ""
    local uhbWidths = ""
    local uhbVoltMeters = ""
    local uhbSources = ""
    local uhbUpperPwms = ""
    local uhbLowerPwms = ""
    local uhbUpperBias = ""
    local uhbLowerBias = ""
    local totalUhbWidth = 0

    if #uhbDevices > 0 then
      for _, uhb in ipairs(uhbDevices) do
        -- uhb.dcsIndex is 0-based; Block.dcsDevices is 1-based
        local component = Block.dcsDevices[uhb.dcsIndex + 1].componentPath
        uhbDcsIndices = uhbDcsIndices .. ("%d, " % { uhb.dcsIndex })
        uhbWidths = uhbWidths .. ("%d, " % { uhb.width })
        totalUhbWidth = totalUhbWidth + uhb.width

        for _, m in pairs(uhb.meterIndices) do
          uhbVoltMeters = uhbVoltMeters .. ("%d, " % { yInfo:getRowForMeter(m) })
        end

        for _, s in pairs(uhb.sourceIndices) do
          uhbSources = uhbSources .. ("%d, " % { s })
        end

        -- The extract* helpers raise { errMsg = ... } on invalid input;
        -- report that as this generator's error string.
        local ok, result = pcall(extractPwmIndices, uhb.upperPwms, component)
        if not ok then return processError(result) end
        uhbUpperPwms = uhbUpperPwms .. result
        ok, result = pcall(extractPwmIndices, uhb.lowerPwms, component)
        if not ok then return processError(result) end
        uhbLowerPwms = uhbLowerPwms .. result

        ok, result = pcall(extractBiasValues, uhb.upperBias, component)
        if not ok then return processError(result) end
        uhbUpperBias = uhbUpperBias .. result
        ok, result = pcall(extractBiasValues, uhb.lowerBias, component)
        if not ok then return processError(result) end
        uhbLowerBias = uhbLowerBias .. result
      end
    else
      uhbDcsIndices = "0"
      uhbWidths = "0"
      uhbVoltMeters = "0"
      uhbSources = "0"
      uhbUpperPwms = "0"
      uhbLowerPwms = "0"
      uhbUpperBias = "0"
      uhbLowerBias = "0"
    end

    -- UHB declarations
    InitCode:append([[
      static const int uhb_dcs_indices[] = { %(uhb_dcs_indices)s };
      static const int uhb_widths[] = { %(uhb_width)s };
      static const int uhb_voltmeters[] = { %(uhb_voltmeters)s };
      static const int uhb_current_sources[] = { %(uhb_sources)s };
      static const int uhb_upper_pwms[] = { %(uhb_upper_pwms)s };
      static const int uhb_lower_pwms[] = { %(uhb_lower_pwms)s };
      static const int uhb_upper_bias[] = { %(uhb_upper_bias)s };
      static const int uhb_lower_bias[] = { %(uhb_lower_bias)s };
    ]] % {
      uhb_dcs_indices = uhbDcsIndices,
      uhb_width = uhbWidths,
      uhb_voltmeters = uhbVoltMeters,
      uhb_sources = uhbSources,
      uhb_upper_pwms = uhbUpperPwms,
      uhb_lower_pwms = uhbLowerPwms,
      uhb_upper_bias = uhbUpperBias,
      uhb_lower_bias = uhbLowerBias
    })
    -- end UHB code

    -- NOTE(review): only 0 and 1 are mapped; any other Block.discMethod
    -- yields nil for the discretization_method field below.
    local discMethods = {
      [0] = "FLEXARRAY_DISCRETIZATION_METHOD_TUSTIN",
      [1] = "FLEXARRAY_DISCRETIZATION_METHOD_RADAU"
    }

    local uhbInitCode = [[
      .dcs_indices = &uhb_dcs_indices[0],
      .widths = &uhb_widths[0],
      .voltmeters = &uhb_voltmeters[0],
      .current_sources = &uhb_current_sources[0],
      .upper_pwms = &uhb_upper_pwms[0],
      .lower_pwms = &uhb_lower_pwms[0],
      .upper_bias = &uhb_upper_bias[0],
      .lower_bias = &uhb_lower_bias[0]
    ]]

    -- model struct referencing all tables declared above
    InitCode:append([[
      static const struct FlexarrayModel model =
      {
         .nx = %(nx)d,
         .nu = %(nu)d,
         .ny = %(ny)d,
         .ndcs = %(ndcs)d,
         .nuhb = %(nuhb)d,
         .ny_ext = %(ny_ext)d,
         .y_ext_indices = &y_ext_indices[0],

         .mat_A = &mat_A,
         .mat_B = &mat_B,
         .mat_C = &mat_C,
         .mat_D = &mat_D,
         .mat_I = &mat_I,

         .discretization_method = %(discretization_method)s,

         .step_ratio = %(step_ratio)d,

         .x_init = &x_init[0],

         .dcs = {
            .resistance = %(dcs_resistance).18g,

            .u_indices = &dcs_u_indices[0],
            .y_indices = &dcs_y_indices[0],

            .initial_conduction_state = &dcs_initial_conduction_state[0],

            .gate_logic = &dcs_gate_logic[0],
            .upper_gate_values = &dcs_upper_gate_values[0],
            .lower_gate_values = &dcs_lower_gate_values[0],

            .static_upper_voltages = &dcs_static_upper_voltages[0],
            .static_lower_voltages = &dcs_static_lower_voltages[0],

            .resonator_freq = &dcs_resonator_freq[0]
         },

         .uhb = {
            %(uhb)s
         }
      };

    ]] % {
      nx = nx,
      nu = nu,
      ny = ny,
      ndcs = ndcs,
      nuhb = #uhbDevices,
      ny_ext = yInfo:numY_ext(),
      discretization_method = discMethods[Block.discMethod],
      dcs_resistance = Block.dcsResistance,
      uhb = uhbInitCode,
      step_ratio = Block.cpuFPGAStepRatio
    })

    --
    -- UpdateCode: per-step input and gate updates, inside one C block.
    --
    local UpdateCode = StringList:new()
    UpdateCode:append("{\n")

    -- External (non-dcs and non-uhb) inputs. Inputs fed from a nanostep block
    -- are wired once in InitCode (before flexarraySetup) and get no per-step
    -- update; psInputIndex renumbers the remaining inputs to skip them.
    local inputSignals = Block.InputSignals
    local inputMetaData = Block.InputMetaData
    for idx, code in Utils.pairsByKeys(inputSignals) do
      local useInputCode = true
      if inputMetaData and inputMetaData[idx] then
        local sourceInfo = inputMetaData[idx].sourceInfo
        if sourceInfo then
          if sourceInfo.fromNanostep == 1 then
            local nanostepIndex = sourceInfo.nanostepIndex
            local nanostepMeterIndex = sourceInfo.meterIndex
            InitCode:append("flexarraySetupInputFromNanostep(%i, %i, %i);\n"
            % {
              nanostepIndex, nanostepMeterIndex, idx
            })
            useInputCode = false
          end
        end
      end
      if useInputCode then
        UpdateCode:append("flexarraySetInput(%d, %s);\n" % { psInputIndex(idx, inputSignals, inputMetaData), code })
      end
    end

    InitCode:append("flexarraySetup(&model);\n")
    InitCode:append("}\n")

    -- Only non-constant gate signals are updated each step; constant ones are
    -- already encoded in dcs_upper/lower_gate_values.
    for idx, dev in ipairs(Block.dcsDevices) do
      local upperGateSignal = dev.upperGateSignal
      if not isSignalConst(upperGateSignal, dev.upperGateSignalMetaData) then
        UpdateCode:append("flexarraySetDcsUpperGate(%d, %s);\n" % { idx - 1, upperGateSignal })
      end
      local lowerGateSignal = dev.lowerGateSignal
      if not isSignalConst(lowerGateSignal, dev.lowerGateSignalMetaData) then
        UpdateCode:append("flexarraySetDcsLowerGate(%d, %s);\n" % { idx - 1, lowerGateSignal })
      end
    end

    UpdateCode:append("}\n")

    return {
      InitCode = InitCode,
      UpdateCode = UpdateCode,
    }
  end

  return FPGACoder
end

return Module
