--[[
  Copyright (c) 2021 by Plexim GmbH
  All rights reserved.

  A free license is granted to anyone to use this software for any legal
  non safety-critical purpose, including commercial applications, provided
  that:
  1) IT IS NOT USED TO DIRECTLY OR INDIRECTLY COMPETE WITH PLEXIM, and
  2) THIS COPYRIGHT NOTICE IS PRESERVED in its entirety.

  THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
  OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
  FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
  AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
  LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
  OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
  SOFTWARE.
--]]

local Module = {}
local U = require('common.utils')

-- Per-CPU bookkeeping shared by all SPI master blocks, keyed by CPU index:
-- a running instance counter, the ids of blocks registered during code
-- generation, and a flag ensuring finalization runs only once per CPU.
local static = {}

--- Create a PLECS code-generation block object for an SPI master.
--
-- Every call returns a new block object (built on `common.block`) and assigns
-- it a running instance number that is unique per CPU. The object validates
-- its mask, emits the C code that drives the SPI peripheral and the slave
-- chip-select outputs, and runs per-CPU finalization once.
--
-- @param globals code-generation context (provides `target` and `instances`)
-- @param cpu     index of the CPU this block is generated for
-- @return the configured SpiMaster block object
function Module.getBlock(globals, cpu)

  local SpiMaster = require('common.block').getBlock(globals, cpu)
  if static[cpu] == nil then
    static[cpu] = {
      numInstances = 0,
      instances = {},
      finalized = false,
    }
  end
  SpiMaster.instance = static[cpu].numInstances
  static[cpu].numInstances = static[cpu].numInstances + 1

  -- Convert a 1-based mask port index into its port letter (1 -> 'A').
  local function portLetter(index)
    return string.char(string.byte('A') + index - 1)
  end

  -- 0-based index into SpiMasterRxDataBuffer for the 1-based output word `i`.
  -- Without rotation this is the identity mapping; with rotation the buffer
  -- is read shifted by one word (for 3 words: 1, 2, 0).
  local function rxBufferIndex(i, numWords, rotate)
    if rotate then
      return i % numWords
    end
    return i - 1
  end

  --- Validate the scalar mask parameters (clock, pins, word length).
  function SpiMaster:checkMaskParameters()
    if not U.isPositiveIntScalar(Block.Mask.Baudrate) then
      U.error("SPI clock must be a positive non-zero integer value.")
    end

    if #Block.Mask.SckPin ~= 1 then
      U.error("Invalid SCK Pin configuration.")
    end

    if #Block.Mask.MisoPin ~= 1 then
      U.error("Invalid MISO Pin configuration.")
    end

    if #Block.Mask.MosiPin ~= 1 then
      U.error("Invalid MOSI Pin configuration.")
    end

    -- Type-check first: comparing a non-number with 4/16 would raise a raw
    -- Lua error instead of the intended message.
    local charLength = Block.Mask.CharLength
    if not U.isPositiveIntScalar(charLength) or charLength < 4 or charLength > 16 then
      U.error("The number of 'Bits per word' has to be an integer between 4 and 16")
    end
  end

  --- Generate the step code for the SPI master.
  --
  -- Creates the implicit SPI peripheral and one push-pull digital output per
  -- slave chip select, then emits C code that services one slave per call:
  -- it de-asserts the previous chip select, reads back its words (or flags a
  -- TX overrun if the peripheral is still busy), and asserts the next chip
  -- select while transmitting its words. Once all slaves have been serviced
  -- the received words are published and the inputs are latched for the
  -- next round.
  --
  -- @return table with `OutputCode`, `Require` and `UserData` (block id)
  function SpiMaster:p_getDirectFeedthroughCode()
    local Require = ResourceList:new()
    local OutputCode = U.CodeLines:new()

    -- register so finalize() can visit all SPI masters of this CPU
    table.insert(static[self.cpu].instances, self.bid)

    self.spi = Block.Mask.Spi
    -- mode 0..3: bit 1 selects clock polarity, bit 0 selects clock phase
    local mode = Block.Mask.Mode - 1

    local spiParams = globals.target.getTargetParameters()['spis']['SPI%d' % {self.spi}]

    self.spi_obj = self:makeBlock('spi', self.cpu)
    self.spi_obj:createImplicit(self.spi, {
      path = self:getName(),
      charlen = Block.Mask.CharLength,
      pol = (mode >= 2),
      phase = (mode == 1) or (mode == 3),
      baudrate = Block.Mask.Baudrate,
      sckport = portLetter(Block.Mask.SckPort),
      sckpin = Block.Mask.SckPin,
      misoport = portLetter(Block.Mask.MisoPort),
      misopin = Block.Mask.MisoPin,
      mosiport = portLetter(Block.Mask.MosiPort),
      mosipin = Block.Mask.MosiPin,
      apb = spiParams['apb'],
    }, Require)

    self.spi_instance = self.spi_obj:getObjIndex()

    local numSlaves = #Block.Mask.CsPin
    if numSlaves ~= #Block.Mask.Dim then
      U.error("Dimensions of slave 'CS Pin' and 'Words per transmission' must match.")
    end

    -- setup chip selects (one push-pull digital output per slave)
    local csPort = portLetter(Block.Mask.CsPort)
    local csParts = {}
    for i = 1, numSlaves do
      local dio_obj = self:makeBlock('dout')
      local cs = dio_obj:createImplicit({
        port = csPort,
        pin = Block.Mask.CsPin[i],
        outputType = 'PUSHPULL',
        path = self:getName()
      }, Require)
      csParts[i] = "%i" % {cs}
    end

    -- Words per transmission are limited by the SPI FIFO depth. For frames of
    -- 8 bits or less the usable depth is 2 * fifo_depth - 1 words
    -- (e.g. 3 frames for fifo_depth = 2).
    local spi_fifo_depth = spiParams['fifo_depth']
    if Block.Mask.CharLength <= 8 then
      spi_fifo_depth = spi_fifo_depth * 2 - 1
    end

    local numWords = #Block.InputSignal[1]
    local totalWords = 0
    local dimParts = {}
    for i = 1, #Block.Mask.Dim do
      local dim = Block.Mask.Dim[i]
      if dim > spi_fifo_depth then
        U.error("Maximum number of words per transmission for this target equals %i." %
                {spi_fifo_depth})
      end
      totalWords = totalWords + dim
      dimParts[i] = "%i" % {dim}
    end

    -- The generated code indexes SpiMasterTxData / SpiMasterRxDataBuffer
    -- (both sized by the input width) at offsets up to the sum of all
    -- 'Words per transmission'; reject configurations that would overrun them.
    if totalWords > numWords then
      U.error("The total number of 'Words per transmission' (%i) exceeds the input signal width (%i)." %
              {totalWords, numWords})
    end

    local outSignal = Block:OutputSignal()

    local rxInit = {}
    for i = 1, numWords do
      rxInit[i] = '0'
    end

    OutputCode:append('{')
    OutputCode:append('static uint16_t SlaveCsHandles[] = {%s};\n' % {table.concat(csParts, ', ')})
    -- setup I/O buffers
    OutputCode:append(
        'static uint16_t SlaveWordsPerTransmission[] = {%s};\n' % {table.concat(dimParts, ', ')})
    OutputCode:append('static uint16_t SpiMasterRxData[%i] = {%s};\n' %
                      {numWords, table.concat(rxInit, ', ')})
    OutputCode:append('static uint16_t SpiMasterTxData[%i];\n' % {numWords})
    OutputCode:append('static uint16_t SpiMasterRxDataBuffer[%i];\n' % {numWords})

    -- setup flags
    OutputCode:append('static uint16_t SlaveIndex = 0;\n')
    OutputCode:append('static uint16_t SlaveDataIndex = 0;\n')
    OutputCode:append('static bool SlaveTxActive = false;\n')
    OutputCode:append('static bool SpiMasterReady = false;\n')
    OutputCode:append('static bool SpiMasterTxOverrun = false;\n')

    local codeParts = {}

    -- finish the previous slave transfer (if any)
    codeParts[#codeParts + 1] = [[
      SpiMasterReady = false;
      if(SlaveTxActive){
        // de-assert last CS
        PLXHAL_DIO_set(SlaveCsHandles[SlaveIndex], true);
        SpiMasterTxOverrun = PLXHAL_SPI_isTxBusy(%(channel)s);
        if(SpiMasterTxOverrun){
          // overrun occurred
          SlaveIndex = 0;
          SlaveTxActive = false;
        } else {
          // read data
          PLXHAL_SPI_getWords(%(channel)s, &SpiMasterRxDataBuffer[SlaveDataIndex], SlaveWordsPerTransmission[SlaveIndex]);

          // next slave
          SlaveDataIndex += SlaveWordsPerTransmission[SlaveIndex];
          SlaveIndex++;
          if(SlaveIndex == %(numSlaves)d){
            // all slaves have been serviced
]] % {
      channel = self.spi_instance,
      numSlaves = numSlaves,
    }

    -- When the input width equals the usable FIFO depth with frames of 8 bits
    -- or less, the received words are read back rotated by one position.
    -- NOTE(review): the original only defined this for 3 words ({1, 2, 0});
    -- rotate-by-one is assumed for other depths -- confirm on hardware.
    local rotate = (numWords == spi_fifo_depth) and (Block.Mask.CharLength <= 8)
    for i = 1, numWords do
      codeParts[#codeParts + 1] =
          "    SpiMasterRxData[%i] = SpiMasterRxDataBuffer[%i];\n" %
          {i - 1, rxBufferIndex(i, numWords, rotate)}
    end

    codeParts[#codeParts + 1] = [[
            SpiMasterReady = true;

            SlaveIndex = 0;
            SlaveTxActive = false;
          }
        }
      }

      // prime next transmission
      if(SlaveIndex == 0){
]]

    -- latch the inputs at the start of a new round
    for i = 1, numWords do
      codeParts[#codeParts + 1] = "    SpiMasterTxData[%i] = %s;\n" %
                                  {i - 1, Block.InputSignal[1][i]}
    end

    codeParts[#codeParts + 1] = [[
        SlaveDataIndex = 0;
        SlaveTxActive = true;
      }

      if(SlaveTxActive){
        PLXHAL_DIO_set(SlaveCsHandles[SlaveIndex], false);
        PLXHAL_SPI_putWords(%(channel)d, &SpiMasterTxData[SlaveDataIndex], SlaveWordsPerTransmission[SlaveIndex]);
      }
]] % {
      channel = self.spi_instance,
    }

    OutputCode:append(table.concat(codeParts))

    -- output signals: received words, ready flag, overrun flag
    for i = 1, numWords do
      OutputCode:append('%s = SpiMasterRxData[%i];' % {outSignal[1][i], i - 1})
    end
    OutputCode:append('%s = SpiMasterReady;' % {outSignal[2][1]})
    OutputCode:append('%s = SpiMasterTxOverrun;' % {outSignal[3][1]})
    OutputCode:append('}')

    return {
      OutputCode = OutputCode,
      Require = Require,
      UserData = {bid = self:getId()}
    }
  end

  --- This block emits no non-direct-feedthrough code.
  function SpiMaster:p_getNonDirectFeedthroughCode()
    return {}
  end

  --- Per-instance finalization hook (currently nothing to do).
  function SpiMaster:finalizeThis(c)
  end

  --- Run finalizeThis() for every SPI master registered on this CPU, once.
  function SpiMaster:finalize(c)
    local state = static[self.cpu]
    if state.finalized then
      return
    end

    -- ipairs visits instances in registration order (pairs order is unspecified)
    for _, bid in ipairs(state.instances) do
      local spimaster = globals.instances[bid]
      if spimaster:getCpu() == self.cpu then
        spimaster:finalizeThis(c)
      end
    end

    state.finalized = true
  end

  return SpiMaster
end

-- Export the module table; callers use Module.getBlock(globals, cpu).
return Module