-- FPGA coder block: generates the FPGAPhysicalModel setup and update code.
-- 
-- Copyright (c) 2023 Plexim GmbH
-- All rights reserved.

-- Module table returned at the end of this file; exposes Module.getBlock().
local Module = {}

local SparseMatrix = require('blocks.FPGACoder.SparseMatrix')
local YInfo = require('blocks.FPGACoder.YInfo')
local Utils = require('blocks.BlockUtils')

--- Test whether a signal expression is a numeric constant.
-- @param aSignal signal code string as found in the block hash
-- @return true and the parsed value when Utils.stringToNumber accepts the
--   signal; false and nil otherwise
local function isSignalConst(aSignal)
  local value = Utils.stringToNumber(aSignal)
  if value == nil then
    return false, nil
  end
  return true, value
end
  
--- Build the C initializer lists for one gate signal of every DCS device.
-- Devices are visited in index order (ipairs), so entry k of each list
-- belongs to Block.dcsDevices[k + 1]; `pairs` would not guarantee that.
-- @param aKey device field holding the gate signal code
--   (e.g. "upperGateSignal" or "lowerGateSignal")
-- @return string gate init list: "1, " for constants > 0, "0, " otherwise
-- @return string write-enable list: "0, " for constants, "1, " for signals
--   that must be written at run time
-- @return boolean true when every gate signal is a constant
local function processGateSignals(aKey)
  local initEntries = {}
  local writeEnableEntries = {}
  local allConst = true
  for i, dev in ipairs(Block.dcsDevices) do
    local isConst, n = isSignalConst(dev[aKey])
    if isConst then
      initEntries[i] = (n > 0) and "1, " or "0, "
      writeEnableEntries[i] = "0, "
    else
      -- Non-constant gates are initialized to 0 and enabled for writes
      -- from the generated update code.
      initEntries[i] = "0, "
      writeEnableEntries[i] = "1, "
      allConst = false
    end
  end
  return table.concat(initEntries), table.concat(writeEnableEntries), allConst
end

--- Build the C initializer list of Digital In port indices for the PWM
-- inputs of a power module.
-- A signal must either carry a "port" in its InputMetaData or be the
-- constant 0; the latter is encoded as -1. Signals are visited in order
-- (ipairs) because the position in the list is significant.
-- @param aSignals sequence of { InputSignal = ..., InputMetaData = ... }
-- @param aComponent component path used in error messages
-- @return string "p0, p1, ..., " (trailing ", " as elsewhere in this file)
-- @raise table { errMsg = ... } when a port is not a non-negative integer,
--   exceeds 31, or a signal is neither port-backed nor the constant 0
local function extractPwmIndices(aSignals, aComponent)
  local entries = {}
  for i, s in ipairs(aSignals) do
    local metaData = s.InputMetaData
    local port = (metaData ~= nil) and metaData["port"] or nil
    if port ~= nil and port ~= "" then
      local portNumber = tonumber(port)
      -- Reject non-numeric, negative (would clash with the -1 "unused"
      -- marker) and fractional ports with a readable message instead of a
      -- raw Lua comparison error.
      if portNumber == nil or portNumber < 0 or portNumber ~= math.floor(portNumber) then
        error({ errMsg = ("Error in %s: Invalid Digital In port (%s) for FPGA " ..
          "simulation.") % { aComponent, tostring(port) }
        })
      end
      if portNumber > 31 then
        error({ errMsg = ("Error in %s: Only the lower 32 Digital In ports can be used " ..
          "for FPGA simulation.") % { aComponent }
        })
      end
      -- Emit the normalized integer (e.g. "010" would otherwise be read as
      -- an octal literal by the C compiler).
      entries[i] = string.format("%d, ", portNumber)
    elseif Utils.stringToNumber(s.InputSignal) == 0 then
      entries[i] = "-1, "
    else
      error({ errMsg = ("Error in %s: For FPGA simulation, power modules must " ..
        "be connected directly to PWMCapture blocks.") % { aComponent }
      })
    end
  end
  return table.concat(entries)
end

--- Build the C initializer list of bias values for a power module.
-- Values are visited in order (ipairs) because position is significant.
-- @param aValues sequence of bias values; each must be 0 or 1
-- @param aComponent component path used in error messages
-- @return string "b0, b1, ..., " (trailing ", " as elsewhere in this file)
-- @raise table { errMsg = ... } when a value is neither 0 nor 1
local function extractBiasValues(aValues, aComponent)
  local entries = {}
  for i, b in ipairs(aValues) do
    if b ~= 0 and b ~= 1 then
      error({ errMsg = ("Error in %s: Inadmissable bias value (%s). FPGA simulation " ..
        "requires bias value to be either 0 or 1.") % { aComponent, b }
      })
    end
    entries[i] = (b == 1) and "1, " or "0, "
  end
  return table.concat(entries)
end

--- Convert an error value caught by pcall into a user-facing message.
-- @param aVal error value: nil, a string, a table with an `errMsg` field
--   (as raised by the extract* helpers), or any other Lua value
-- @return string message describing the error
local function processError(aVal)
  local kind = type(aVal)
  if aVal == nil then
    return "Error information is nil"
  elseif kind == "string" then
    return aVal
  elseif kind == "number" then
    -- error(<number>) is legal Lua; indexing a number would raise here.
    return tostring(aVal)
  elseif kind == "table" and aVal.errMsg ~= nil then
    return aVal.errMsg
  end
  return "Unknown error information."
end

--- Error message for an unsupported Target Support Package hash version, or
-- nil when Block.Version is 1 (the only version this coder accepts).
local function checkBlockVersion()
  if Block.Version ~= 1 then
    return "Incompatible version of Target Support Package detected. Please install" ..
      " the latest version of the RT Box Target Support Package."
  end
  return nil
end

--- Append one C initializer entry to a buffer.
-- The value is rendered with `aFormat` and followed by ", ", so joining the
-- buffer with table.concat yields "v1, v2, ..., " (C accepts the trailing
-- comma in initializer lists).
local function appendEntry(aBuf, aFormat, aValue)
  aBuf[#aBuf + 1] = string.format(aFormat, aValue) .. ", "
end

--- Create the FPGA coder block object.
-- @param globals forwarded to require('blocks.block').getBlock
-- @return block object extended with generateDFTCode/generateNonDFTCode
function Module.getBlock(globals)
  local FPGACoder = require('blocks.block').getBlock(globals)

  --- Generate the code requested under the framework's "DFT" label: meter
  -- read-out expressions and DCS conduction-state accessors.
  -- @return table { OutputCode, MeterInfos, ConductionStateCode }, or a
  --   string describing why code cannot be generated
  function FPGACoder:generateDFTCode()
    local versionError = checkBlockVersion()
    if versionError then
      return versionError
    end

    --
    -- MeterInfos: one entry per external meter, keyed by tostring(meterIndex)
    --
    local yInfo = YInfo.fromHash(Block)
    local MeterInfos = {}
    for _, meterIndex in pairs(yInfo.mOldExternalMeters) do
      local offset = yInfo:getExternalMeterOffset(meterIndex)
      local negated = yInfo.mNegated[meterIndex]
      MeterInfos[tostring(meterIndex)] = {
        Code = string.format("%sFPGAPhysicalModel_getOutput(%d)", negated and "-" or "", offset),
        RowIdx = yInfo:getRowForMeter(meterIndex),
        Negated = negated
      }
    end

    --
    -- Conduction states: entry k reads DCS device k (0-based), in order
    --
    local ConductionStateCode = StringList:new()
    for i in ipairs(Block.dcsDevices) do
      ConductionStateCode:append("FPGAPhysicalModel_getConductionState(%d)" % { i - 1 })
    end

    return {
      OutputCode = StringList:new(),
      MeterInfos = MeterInfos,
      ConductionStateCode = ConductionStateCode
    }
  end

  --- Generate the FPGA model setup code (InitCode) and the per-step input
  -- and gate update code (UpdateCode).
  -- @return table { InitCode, UpdateCode } on success, or a string
  --   describing why the system cannot be simulated on the FPGA
  function FPGACoder:generateNonDFTCode()
    local versionError = checkBlockVersion()
    if versionError then
      return versionError
    end

    --
    -- Boundary checks
    --
    local ndcs = #Block.dcsDevices
    local nu = Block.nu
    local nx = Block.nx
    local yInfo = YInfo.fromHash(Block)
    local ny = yInfo:numY()
    local ny_ext = yInfo:numY_ext()
    local mat_A = SparseMatrix.fromHash(Block.mat_A)
    local mat_B = SparseMatrix.fromHash(Block.mat_B)
    local mat_C = SparseMatrix.fromHash(Block.mat_C)
    local mat_D = SparseMatrix.fromHash(Block.mat_D)
    local mat_I = SparseMatrix.fromHash(Block.mat_I)
    -- States plus meters may occupy at most MAX_XY_CHUNKS vectors of
    -- PM_VECSZ entries (see the check below).
    local MAX_XY_CHUNKS = 7
    local PM_VECSZ = 32
    local PM_INBUF_VECTOR_DEPTH = 32

    if nx + ny > PM_VECSZ * MAX_XY_CHUNKS then
      return ("System is too large for FPGA simulation: The number of states " ..
        "(%d) plus the number of meters (%d) must not exceed %d."
        % { nx, ny, PM_VECSZ * MAX_XY_CHUNKS }
      )
    end

    if Block.cpuFPGAStepRatio ~= 0 then
      -- Both values are in microseconds. NOTE(review): the limit looks like
      -- 1023 cycles of a 300 MHz clock - confirm against the FPGA design.
      local fpgaStepSize = Block.cpuBasePeriod / Block.cpuFPGAStepRatio * 1e6
      local maxFpgaStepSize = 1/300 * 1023
      if fpgaStepSize > maxFpgaStepSize then
        return ("FPGA step size (%.2fµs) must be smaller than %.2fµs."
          % { fpgaStepSize, maxFpgaStepSize }
        )
      end
    end

    if ny_ext == 0 then
      return "All outputs of the FPGA simulation are unused."
    end

    --
    -- InitCode
    --
    local InitCode = StringList:new()

    -- Gate signals: constant gates are baked into the init lists; the others
    -- are write-enabled and driven from UpdateCode.
    local dcsUpperGateInitList = "0"
    local dcsLowerGateInitList = "0"
    local dcsUpperGateWriteEnable = "0"
    local dcsLowerGateWriteEnable = "0"
    local dcsUpperGateAllConst = true
    local dcsLowerGateAllConst = true
    if ndcs > 0 then
      dcsUpperGateInitList, dcsUpperGateWriteEnable, dcsUpperGateAllConst =
        processGateSignals("upperGateSignal")
      dcsLowerGateInitList, dcsLowerGateWriteEnable, dcsLowerGateAllConst =
        processGateSignals("lowerGateSignal")
    end
    local allGatesConst = dcsUpperGateAllConst and dcsLowerGateAllConst

    InitCode:append("{\n")
    InitCode:append("static const int dcs_upper_gate_init[] = {\n%s\n};\n" % { dcsUpperGateInitList } )
    InitCode:append("static const int dcs_lower_gate_init[] = {\n%s\n};\n" % { dcsLowerGateInitList } )
    InitCode:append("static const int dcs_upper_gate_write_enable[] = {\n%s\n};\n" % { dcsUpperGateWriteEnable } )
    InitCode:append("static const int dcs_lower_gate_write_enable[] = {\n%s\n};\n" % { dcsLowerGateWriteEnable } )

    -- Rows of y routed to the external outputs, in output order. ny_ext > 0
    -- is guaranteed above; the "0" fallback only keeps the C array non-empty.
    local yExtEntries = {}
    for _, meterIndex in ipairs(yInfo.mNewExternalMeters) do
      appendEntry(yExtEntries, "%d", yInfo:getRowForMeter(meterIndex))
    end
    local y_ext_indices = (#yExtEntries > 0) and table.concat(yExtEntries) or "0"
    InitCode:append("static const int y_ext_indices[] = {\n%s\n};\n" % { y_ext_indices } )

    -- Initial state vector (placeholder "0" when there are no states).
    local dcsInitialStateValues = "0"
    if nx > 0 then
      local xInitEntries = {}
      for _, value in ipairs(Block.initialStateValues) do
        appendEntry(xInitEntries, "%.9g", value)
      end
      dcsInitialStateValues = table.concat(xInitEntries)
    end
    InitCode:append("static const float x_init[] = {\n%s\n};\n" % { dcsInitialStateValues } )

    -- Per-device DCS data; entry k of every list belongs to
    -- Block.dcsDevices[k + 1].
    local uIndices, yIndices, initialConductivity = {}, {}, {}
    local gatedTurnOn, gatedTurnOff = {}, {}
    local upperVoltageInit, lowerVoltageInit, resonatorFreq = {}, {}, {}
    local maxMeterIndex = 0
    local uhbDevices = {}
    for _, dev in ipairs(Block.dcsDevices) do
      appendEntry(uIndices, "%d", dev.sourceIndex)
      local rowIdx = yInfo:getRowForMeter(dev.meterIndex)
      if rowIdx > maxMeterIndex then
        maxMeterIndex = rowIdx
      end
      appendEntry(yIndices, "%d", rowIdx)
      appendEntry(initialConductivity, "%s", "PM_DCS_CONDSTATE_" .. dev.initialConductivity)
      if dev.uhbData ~= nil then
        table.insert(uhbDevices, dev.uhbData)
      end
      appendEntry(gatedTurnOn, "%d", dev.gatedTurnOn and 1 or 0)
      appendEntry(gatedTurnOff, "%d", dev.gatedTurnOff and 1 or 0)
      appendEntry(upperVoltageInit, "%.18g", dev.upperVoltageInit)
      appendEntry(lowerVoltageInit, "%.18g", dev.lowerVoltageInit)
      appendEntry(resonatorFreq, "%.18g", dev.fundamentalFrequency)
    end

    -- Without DCS devices every list degenerates to a single placeholder.
    local function dcsList(aEntries, aDefault)
      if ndcs > 0 then
        return table.concat(aEntries)
      end
      return aDefault
    end

    InitCode:append("static const int s_init[] = {\n%s\n};\n" % { dcsList(initialConductivity, "0") } )
    InitCode:append("static const int dcs_u_indices[] = {\n%s\n};\n" % { dcsList(uIndices, "-1") } )
    InitCode:append("static const int dcs_y_indices[] = {\n%s\n};\n" % { dcsList(yIndices, "-1") } )
    InitCode:append("static const int dcs_gated_turn_on[] = {\n%s\n};\n" % { dcsList(gatedTurnOn, "0") } )
    InitCode:append("static const int dcs_gated_turn_off[] = {\n%s\n};\n" % { dcsList(gatedTurnOff, "0") } )
    InitCode:append("static const double dcs_upper_voltage_init[] = {\n%s\n};\n" % { dcsList(upperVoltageInit, "0") } )
    InitCode:append("static const double dcs_lower_voltage_init[] = {\n%s\n};\n" % { dcsList(lowerVoltageInit, "0") } )
    InitCode:append("static const double dcs_resonator_freq[] = {\n%s\n};\n" % { dcsList(resonatorFreq, "0") } )

    InitCode:append(mat_A:toCode("mat_A"))
    InitCode:append(mat_B:toCode("mat_B"))
    InitCode:append(mat_C:filter(yInfo.mOriginalIndices):toCode("mat_C"))
    InitCode:append(mat_D:filter(yInfo.mOriginalIndices):toCode("mat_D"))
    InitCode:append(mat_I:toCode("mat_I"))

    -- begin UHB code: power modules attached to DCS devices, in device order.
    local uhbDcsIndices, uhbWidths, uhbVoltMeters, uhbSources = {}, {}, {}, {}
    local uhbUpperPwms, uhbLowerPwms, uhbUpperBias, uhbLowerBias = {}, {}, {}, {}
    local totalUhbWidth = 0

    for _, uhb in ipairs(uhbDevices) do
      local component = Block.dcsDevices[uhb.dcsIndex + 1].componentPath
      appendEntry(uhbDcsIndices, "%d", uhb.dcsIndex)
      appendEntry(uhbWidths, "%d", uhb.width)
      totalUhbWidth = totalUhbWidth + uhb.width

      for _, meterIndex in ipairs(uhb.meterIndices) do
        local uhbRowIdx = yInfo:getRowForMeter(meterIndex)
        if uhbRowIdx > maxMeterIndex then
          maxMeterIndex = uhbRowIdx
        end
        appendEntry(uhbVoltMeters, "%d", uhbRowIdx)
      end

      for _, sourceIndex in ipairs(uhb.sourceIndices) do
        appendEntry(uhbSources, "%d", sourceIndex)
      end

      -- The extract* helpers raise { errMsg = ... } on invalid input; turn
      -- that into the string result the caller treats as an error.
      local ok, result = pcall(extractPwmIndices, uhb.upperPwms, component)
      if not ok then return processError(result) end
      uhbUpperPwms[#uhbUpperPwms + 1] = result

      ok, result = pcall(extractPwmIndices, uhb.lowerPwms, component)
      if not ok then return processError(result) end
      uhbLowerPwms[#uhbLowerPwms + 1] = result

      ok, result = pcall(extractBiasValues, uhb.upperBias, component)
      if not ok then return processError(result) end
      uhbUpperBias[#uhbUpperBias + 1] = result

      ok, result = pcall(extractBiasValues, uhb.lowerBias, component)
      if not ok then return processError(result) end
      uhbLowerBias[#uhbLowerBias + 1] = result
    end

    -- Without power modules every list degenerates to a single "0".
    local function uhbList(aEntries)
      if #uhbDevices > 0 then
        return table.concat(aEntries)
      end
      return "0"
    end

    if #uhbDevices > PM_VECSZ/2 then
      return ("The system is too large for FPGA simulation: The number " ..
        "of power modules (%d) must not exceed %d."
        % { #uhbDevices, PM_VECSZ/2 }
      )
    end

    if maxMeterIndex+1 > PM_VECSZ then
      return ("The system is too large for FPGA simulation: The final number " ..
        "of switch measurements (%d) must not exceed %d."
        % { maxMeterIndex+1, PM_VECSZ }
      )
    end

    -- Sources not consumed by DCS devices or power modules.
    local nu_ext = nu - totalUhbWidth - ndcs
    if nu_ext > PM_VECSZ*PM_INBUF_VECTOR_DEPTH then
      return ("System is too large for FPGA simulation: The number of sources " ..
        "(%d) must not exceed %d."
        % { nu_ext, PM_VECSZ*PM_INBUF_VECTOR_DEPTH}
      )
    end

    if nu_ext == 0 then
      return "To run the system on the FPGA it must contain at least one source."
    end

    local discMethods = {
      [0] = "PM_DISCRETIZATION_METHOD_TUSTIN",
      [1] = "PM_DISCRETIZATION_METHOD_RADAU"
    }
    -- Reject unknown methods instead of emitting a broken struct initializer.
    local discretizationMethod = discMethods[Block.discMethod]
    if discretizationMethod == nil then
      return "Unsupported discretization method (%s) for FPGA simulation."
        % { tostring(Block.discMethod) }
    end

    -- UHB declarations
    InitCode:append([[
      static const int uhb_dcs_indices[] = { %(uhb_dcs_indices)s };
      static const int uhb_widths[] = { %(uhb_width)s };
      static const int uhb_voltmeters[] = { %(uhb_voltmeters)s };
      static const int uhb_current_sources[] = { %(uhb_sources)s };
      static const int uhb_upper_pwms[] = { %(uhb_upper_pwms)s };
      static const int uhb_lower_pwms[] = { %(uhb_lower_pwms)s };
      static const int uhb_upper_bias[] = { %(uhb_upper_bias)s };
      static const int uhb_lower_bias[] = { %(uhb_lower_bias)s };
    ]] % {
      uhb_dcs_indices = uhbList(uhbDcsIndices),
      uhb_width = uhbList(uhbWidths),
      uhb_voltmeters = uhbList(uhbVoltMeters),
      uhb_sources = uhbList(uhbSources),
      uhb_upper_pwms = uhbList(uhbUpperPwms),
      uhb_lower_pwms = uhbList(uhbLowerPwms),
      uhb_upper_bias = uhbList(uhbUpperBias),
      uhb_lower_bias = uhbList(uhbLowerBias)
    })
    -- end UHB code

    local uhbInitCode = [[
      .dcs_indices = &uhb_dcs_indices[0],
      .widths = &uhb_widths[0],
      .voltmeters = &uhb_voltmeters[0],
      .current_sources = &uhb_current_sources[0],
      .upper_pwms = &uhb_upper_pwms[0],
      .lower_pwms = &uhb_lower_pwms[0],
      .upper_bias = &uhb_upper_bias[0],
      .lower_bias = &uhb_lower_bias[0]
    ]]

    InitCode:append([[
      static const struct FPGAPhysicalModel model =
      {
         .nx = %(nx)d,
         .nu = %(nu)d,
         .ny = %(ny)d,
         .ndcs = %(ndcs)d,
         .nuhb = %(nuhb)d,
         .ny_ext = %(ny_ext)d,
         .y_ext_indices = &y_ext_indices[0],

         .mat_A = &mat_A,
         .mat_B = &mat_B,
         .mat_C = &mat_C,
         .mat_D = &mat_D,
         .mat_I = &mat_I,

         .discretization_method = %(discretization_method)s,

         .x_init = &x_init[0],
         .s_init = &s_init[0],

         .dcs = {
            .resistance = %(dcs_resistance).18g,

            .u_indices = &dcs_u_indices[0],
            .y_indices = &dcs_y_indices[0],

            .gated_turn_on = &dcs_gated_turn_on[0],
            .gated_turn_off = &dcs_gated_turn_off[0],

            .upper_voltage_init = &dcs_upper_voltage_init[0],
            .lower_voltage_init = &dcs_lower_voltage_init[0],

            .upper_gate_init = &dcs_upper_gate_init[0],
            .lower_gate_init = &dcs_lower_gate_init[0],
            .upper_gate_write_enable = &dcs_upper_gate_write_enable[0],
            .lower_gate_write_enable = &dcs_lower_gate_write_enable[0],
            .all_gates_constant = %(all_gates_const)d,

            .resonator_freq = &dcs_resonator_freq[0]
         },

         .uhb = {
            %(uhb)s
         },

        .cpu_base_period = %(cpu_base_period).18g,
        .sync_steps = %(sync_steps)d
      };

    ]] % {
      nx = nx,
      nu = nu,
      ny = ny,
      ndcs = ndcs,
      nuhb = #uhbDevices,
      ny_ext = ny_ext,
      discretization_method = discretizationMethod,
      dcs_resistance = Block.dcsResistance,
      all_gates_const = (allGatesConst and 1 or 0),
      uhb = uhbInitCode,
      cpu_base_period = Block.cpuBasePeriod,
      sync_steps = Block.cpuFPGAStepRatio
    })

    --
    -- Update Code
    --
    local UpdateCode = StringList:new()
    UpdateCode:append("{\n")

    -- External (non-dcs and non-uhb) inputs. The input index is emitted
    -- explicitly, so iteration order does not affect correctness; pairs is
    -- kept because the key range is not visible here.
    local inputSignals = Block.InputSignals
    local inputMetaData = Block.InputMetaData
    for idx, code in pairs(inputSignals) do
      local sourceInfo = inputMetaData and inputMetaData[idx] and inputMetaData[idx].sourceInfo
      if sourceInfo and sourceInfo.fromNanostep == 1 then
        -- Inputs fed from a nanostep model are wired once at setup time
        -- instead of being written every step.
        InitCode:append("FPGAPhysicalModel_setupInputFromNanostep(%i, %i, %i);\n"
          % { sourceInfo.nanostepIndex, sourceInfo.meterIndex, idx })
      else
        UpdateCode:append("FPGAPhysicalModel_setInput(%d, %s);\n" % { idx, code })
      end
    end

    InitCode:append("FPGAPhysicalModel_setup(&model);\n")
    InitCode:append("}\n")

    -- Non-constant gate signals are written every step (0-based device index).
    for idx, dev in ipairs(Block.dcsDevices) do
      local upperGateSignal = dev.upperGateSignal
      if not isSignalConst(upperGateSignal) then
        UpdateCode:append("FPGAPhysicalModel_setDcsUpperGate(%d, %s);\n" % { idx - 1, upperGateSignal })
      end
      local lowerGateSignal = dev.lowerGateSignal
      if not isSignalConst(lowerGateSignal) then
        UpdateCode:append("FPGAPhysicalModel_setDcsLowerGate(%d, %s);\n" % { idx - 1, lowerGateSignal })
      end
    end

    UpdateCode:append("}\n")

    return {
      InitCode = InitCode,
      UpdateCode = UpdateCode,
    }
  end

  return FPGACoder
end

-- Export the module table (its only entry point defined here is Module.getBlock).
return Module
