--[[
  Copyright (c) 2021 by Plexim GmbH
  All rights reserved.

  A free license is granted to anyone to use this software for any legal
  non safety-critical purpose, including commercial applications, provided
  that:
  1) IT IS NOT USED TO DIRECTLY OR INDIRECTLY COMPETE WITH PLEXIM, and
  2) THIS COPYRIGHT NOTICE IS PRESERVED in its entirety.

  THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
  OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
  FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
  AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
  LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
  OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
  SOFTWARE.
--]]--
local U = require('common.utils')

-- Table returned by this file; exposes Module.getBlock().
local Module = {}

-- Bookkeeping shared by all pulse-counter blocks, keyed by CPU
-- (fields are initialized in getBlock).
local static = {}

-- Pull-resistor options, indexed by the mask's pull-type combo value.
local pullTypeCombo = {'NO', 'UP', 'DOWN'}

--- Create a pulse-counter block object for the given CPU.
--
-- Each call allocates the next zero-based instance index for `cpu`
-- (stored in `PulseCtr.instance`); that index addresses the generated
-- `CtrHandles[]` array and the PLXHAL_CTR_* accessors in C code.
--
-- @param globals code-generation globals (uses `target`, `syscfg`, `instances`)
-- @param cpu CPU identifier the block belongs to (key into `static`)
-- @return the block object with mask-check, code-generation and
--   finalization methods attached
function Module.getBlock(globals, cpu)

  local PulseCtr = require('common.block').getBlock(globals, cpu)
  if static[cpu] == nil then
    static[cpu] = {
      numInstances = 0,   -- number of pulse-counter blocks created for this CPU
      instances = {},     -- block ids, in the order their code was requested
      finalized = false,  -- set once the shared C code has been emitted
    }
  end
  PulseCtr["instance"] = static[cpu].numInstances
  static[cpu].numInstances = static[cpu].numInstances + 1

  --- Validate the numeric mask parameters; raises a user error on failure.
  function PulseCtr:checkMaskParameters()
    if not U.isNonNegativeIntScalar(Block.Mask.ChPin) then
      U.error('Channel pin must be a scalar positive or zero integer.')
    end
    -- The direction pin is only used in mode 2 ('single_plus_dir').
    if Block.Mask.Mode == 2 then
      if not U.isNonNegativeIntScalar(Block.Mask.DirPin) then
        U.error('Direction pin must be a scalar positive or zero integer.')
      end
    end
    if not U.isNonNegativeIntScalar(Block.Mask.InitValue) then
      U.error('Initial condition must be a scalar positive or zero integer.')
    end
    if not U.isPositiveIntScalar(Block.Mask.MaxCounterVal) then
      U.error('Maximum counter value must be a scalar positive integer.')
    end
  end

  --- Generate the per-step code of the block and register its resources.
  --
  -- Validates the timer/mode/limit settings, emits the optional
  -- external-reset logic, resolves the timer channel and alternate
  -- functions of the input pins, and records the pin configuration in
  -- `globals.syscfg`. The settings stored on `self` (unit, mode, initValue,
  -- maxCtr, ch_conf, dir_conf) are consumed later by finalizeThis().
  -- @return table with InitCode, OutputCode, OutputSignal, Require, UserData
  function PulseCtr:p_getDirectFeedthroughCode()
    local Require = ResourceList:new()
    local InitCode = U.CodeLines:new()
    local OutputSignal = StringList:new()
    local OutputCode = U.CodeLines:new()

    -- Register this block so finalize() emits its timer setup.
    table.insert(static[self.cpu].instances, self.bid)

    local tims = {1, 2, 3, 4, 5, 6, 7, 8} -- must match mask
    local mode = {'single', 'single_plus_dir'} -- must match mask
    local edge = {'rising', 'falling', 'both'} -- must match mask
    -- Pull types come from the module-level `pullTypeCombo` (must match mask).

    self.unit = tims[Block.Mask.TimUnit]
    if (self.unit == 6) or (self.unit == 7) then
      U.error('TIM%i is not supported as pulse counter.' % {self.unit})
    end

    self.initValue = Block.Mask.InitValue
    self.maxCtr = Block.Mask.MaxCounterVal

    -- TIM2 and TIM5 are treated as 32-bit counters, all others as 16-bit.
    local nbrBit = 16
    if (self.unit == 2) or (self.unit == 5) then
      nbrBit = 32
    end
    local maxReload = 2^nbrBit - 1

    if self.maxCtr > maxReload then
      U.error("Maximum counter value (%i) must not exceed 2^%i-1 (%i)." % {self.maxCtr, nbrBit, maxReload})
    end

    if self.initValue > self.maxCtr then
      U.error("Initial condition (%i) must not exceed the maximum counter value (%i)." % {self.initValue, self.maxCtr})
    end

    -- NOTE(review): the syscfg resource name contains a space ('TIM %d')
    -- while the ResourceList entry does not ('TIM%d'); kept unchanged,
    -- confirm this matches the naming used by the other timer blocks.
    globals.syscfg:claimResourceIfFree('TIM %d' % {self.unit})
    Require:add('TIM%d' % {self.unit})

    self.mode = mode[Block.Mask.Mode]

    -- Clock-plus-direction counting is only implemented for G4 targets.
    if self.mode == 'single_plus_dir' then
      if self:targetMatches({'f3', 'h7'}) then
        U.error("'Single channel + direction' counter mode is not supported by this target.")
      elseif not self:targetMatches('g4') then
        U.throwUnhandledTargetError()
      end
    end

    -- Optional external reset driven by input 1 (ExtReset == 1: no reset
    -- logic; 2: rising edge; 3: falling edge; otherwise: either edge).
    -- On the selected edge the counter is reloaded with the initial value.
    if Block.Mask.ExtReset ~= 1 then
      OutputCode:append('{')
      OutputCode:append("static unsigned char lastTriggerValue = 0;\n")
      OutputCode:append("int triggerValue = !!%s;\n" % {Block.InputSignal[1][1]})
      if Block.Mask.ExtReset == 2 then -- rising
        OutputCode:append("if (!lastTriggerValue && triggerValue) {\n")
      elseif Block.Mask.ExtReset == 3 then -- falling
        OutputCode:append("if (lastTriggerValue && !triggerValue) {\n")
      else -- either
        OutputCode:append("if (lastTriggerValue != triggerValue) {\n")
      end
      OutputCode:append("PLXHAL_CTR_setCounter(%i, %i);" % {self.instance, self.initValue})
      OutputCode:append("}\nlastTriggerValue = triggerValue;\n")
      OutputCode:append('}')
    end

    OutputSignal:append('PLXHAL_CTR_getCounter(%i)' % {self.instance})

    -- Counting input pin (mask port index 1 -> 'A').
    local port = string.char(65+Block.Mask.ChPort-1)
    local pin = Block.Mask.ChPin
    Require:add('P%s' % {port}, pin)

    local channel, af
    if self.mode == 'single_plus_dir' then -- Channel 2 is used for the counter
      channel = 2
      local func = 'TIM%d_CH2' % {self.unit}
      local pad = '%s%d' % {port, pin}
      local errMsgPrefix = 'Invalid @param:ChPort: and @param:ChPin: for the selected @param:TimUnit:.'

      af = globals.target.getAlternateFunctionOrError({
        func = func,
        pad = pad,
        opt_errMsgPrefix = errMsgPrefix,
      })
    elseif self.mode == 'single' then -- Channels 1 and 2 can be used for the counter
      local pad = '%s%d' % {port, pin}
      for _, ch in ipairs({1, 2}) do
        local func = 'TIM%d_CH%d' % {self.unit, ch}

        if globals.target.checkAlternateFunctionExists({func = func, pad = pad}) then
          af = globals.target.getAlternateFunctionOrError({
            func = func,
            pad = pad,
          })
          channel = ch
          break
        end
      end

      if not channel then
        local func = 'TIM%d_CH[12]' % {self.unit} -- Note: This is a regular expression which we don't want to display to the user!
        local userErrorFunc = 'TIM%(unit)d_CH1 or TIM%(unit)d_CH2'
           % {unit = self.unit}

        local padsMsg = globals.target.getValidPadsMsg({
          func = func,
          opt_userErrorFunc = userErrorFunc,
        })

        U.error('Invalid @param:ChPort: and @param:ChPin: for the selected @param:TimUnit:.\n\n%s'
          % {padsMsg})
      end
    else
      error('Edge counter mode is configured incorrectly.')
    end

    self.ch_conf = {
      channel = channel,
      port = port,
      pin = pin,
      pullType = pullTypeCombo[Block.Mask.ChPullType],
      edge = edge[Block.Mask.ChEdge],
      af = af
    }

    if self.mode == 'single_plus_dir' then
      -- The direction signal is routed to timer channel 1.
      local dirPort = string.char(65+Block.Mask.DirPort-1)
      local dirPin = Block.Mask.DirPin
      self.dir_conf = {
        port = dirPort,
        pin = dirPin,
        pullType = pullTypeCombo[Block.Mask.DirPullType]
      }
      local errMsgPrefix = 'Invalid @param:DirPort: and @param:DirPin: for the selected @param:TimUnit:.'

      self.dir_conf.af = globals.target.getAlternateFunctionOrError({
        func = 'TIM%d_CH1' % {self.unit},
        pad = '%s%d' % {dirPort, dirPin},
        opt_errMsgPrefix = errMsgPrefix,
      })

      Require:add('P%s' % {dirPort}, dirPin)
      -- No further CH2 check for the counting pin is needed: in this mode
      -- its TIMx_CH2 alternate function was already resolved above.
    end

    -- Pin list for the system configuration: counting pin first, then the
    -- optional direction pin.
    local pinconf = {
      {
        port = self.ch_conf.port,
        pin = self.ch_conf.pin,
        pull = self.ch_conf.pullType,
        af = self.ch_conf.af
      }
    }
    if self.dir_conf ~= nil then
      table.insert(pinconf, {
        port = self.dir_conf.port,
        pin = self.dir_conf.pin,
        pull = self.dir_conf.pullType,
        af = self.dir_conf.af
      })
    end
    globals.syscfg:addEntry('PulseCtr', {
      unit = self.unit,
      pins = pinconf,
      path = self:getName()
    })

    return {
      InitCode = InitCode,
      OutputCode = OutputCode,
      OutputSignal = {OutputSignal},
      Require = Require,
      UserData = {bid = self:getId()}
    }
  end

  --- The block has no non-direct-feedthrough code.
  function PulseCtr:p_getNonDirectFeedthroughCode()
    return {}
  end

  --- Emit the pre-init C code that configures this instance's timer.
  --
  -- The timer is set up as an up-counter with `maxCtr` as auto-reload
  -- value and started at `initValue`. In 'single' mode the selected channel
  -- clocks the timer through external clock mode 1; otherwise the encoder
  -- "clock plus direction" mode is used with CH1 as direction input and
  -- CH2 as counting input.
  -- @param c code container; only `c.PreInitCode` is appended to
  function PulseCtr:finalizeThis(c)
    local init_code = [[
      LL_TIM_InitTypeDef eInitDef = {
         0
      };

      eInitDef.Prescaler = 0;
      eInitDef.CounterMode = LL_TIM_COUNTERMODE_UP;
      eInitDef.Autoreload = %(reloadValue)d;
      eInitDef.ClockDivision = LL_TIM_CLOCKDIVISION_DIV1;
      eInitDef.RepetitionCounter = 0;

      LL_TIM_Init(TIM%(unit)d, &eInitDef);
      LL_TIM_DisableARRPreload(TIM%(unit)d);

      LL_TIM_SetTriggerOutput(TIM%(unit)d, LL_TIM_TRGO_RESET);
      LL_TIM_SetTriggerOutput2(TIM%(unit)d, LL_TIM_TRGO2_RESET);
      LL_TIM_DisableMasterSlaveMode(TIM%(unit)d);


    ]] % {
      unit = self.unit,
      reloadValue = self.maxCtr,
    }

    -- Input-capture polarity of the counting channel (defaults to falling).
    local m_chPol
    if self.ch_conf.edge == 'both' then
      m_chPol = 'LL_TIM_IC_POLARITY_BOTHEDGE'
    elseif self.ch_conf.edge == 'rising' then
      m_chPol = 'LL_TIM_IC_POLARITY_RISING'
    else
      m_chPol = 'LL_TIM_IC_POLARITY_FALLING'
    end

    if self.mode == 'single' then
      -- Trigger input TI1FP1 or TI2FP2 clocks the counter.
      init_code = init_code..[[
        LL_TIM_SetTriggerInput(TIM%(unit)d, LL_TIM_TS_TI%(channel)iFP%(channel)d);
        LL_TIM_SetClockSource(TIM%(unit)d, LL_TIM_CLOCKSOURCE_EXT_MODE1);
        LL_TIM_IC_SetFilter(TIM%(unit)d, LL_TIM_CHANNEL_CH%(channel)d, LL_TIM_IC_FILTER_FDIV1);
        LL_TIM_IC_SetPolarity(TIM%(unit)d, LL_TIM_CHANNEL_CH%(channel)d, %(m_chPol)s);
        LL_TIM_DisableIT_TRIG(TIM%(unit)d);
        LL_TIM_DisableDMAReq_TRIG(TIM%(unit)d);
      ]] % {
        unit = self.unit,
        channel = self.ch_conf.channel,
        m_chPol = m_chPol,
      }
    else
      -- Encoder clock-plus-direction mode; X2 is selected when both edges
      -- are to be counted, X1 otherwise.
      init_code = init_code..[[
        LL_TIM_ENCODER_InitTypeDef eConfig = {
           0
        };

        eConfig.EncoderMode = LL_TIM_ENCODERMODE_CLOCKPLUSDIRECTION_%(m_counterMode)s;
        eConfig.IC1Polarity = LL_TIM_IC_POLARITY_RISING;
        eConfig.IC1ActiveInput = LL_TIM_ACTIVEINPUT_DIRECTTI;
        eConfig.IC1Prescaler = LL_TIM_ICPSC_DIV1;
        eConfig.IC1Filter = 0;
        eConfig.IC2Polarity = %(m_chPol)s;
        eConfig.IC2ActiveInput = LL_TIM_ACTIVEINPUT_DIRECTTI;
        eConfig.IC2Prescaler = LL_TIM_ICPSC_DIV1;
        eConfig.IC2Filter = 0;
        (void) LL_TIM_ENCODER_Init(TIM%(unit)d, &eConfig);
      ]] % {
        m_counterMode = (self.ch_conf.edge == 'both') and 'X2' or 'X1',
        unit = self.unit,
        m_chPol = m_chPol,
      }
    end

    c.PreInitCode:append("{")
    c.PreInitCode:append(init_code)
    c.PreInitCode:append('PLX_CTR_setup(CtrHandles[%i], PLX_CTR_TIM%i, %i);' % {self.instance, self.unit, self.maxCtr})

    local enable_code = [[
      LL_TIM_SetCounter(TIM%(unit)d, %(initValue)d);
      LL_TIM_EnableCounter(TIM%(unit)d);
    ]] % {
      unit = self.unit,
      initValue = self.initValue,
    }
    c.PreInitCode:append(enable_code)
    c.PreInitCode:append("}")
  end

  --- Emit the C code shared by all pulse counters of this CPU.
  --
  -- Runs only once per CPU (guarded by `static[cpu].finalized`). It
  -- declares the handle/object arrays and the PLXHAL_CTR_* accessors,
  -- initializes all handles, then calls finalizeThis() on every
  -- registered instance of this CPU.
  -- @param c code container (Include, Declarations, PreInitCode)
  function PulseCtr:finalize(c)
    local cpuState = static[self.cpu]
    if cpuState.finalized then
      return
    end
    local numInstances = cpuState.numInstances

    c.Include:append('plx_ctr.h')
    c.Declarations:append('PLX_CTR_Handle_t CtrHandles[%i];' % {numInstances})
    c.Declarations:append('PLX_CTR_Obj_t CtrObj[%i];' % {numInstances})

    c.Declarations:append('uint32_t PLXHAL_CTR_getCounter(uint16_t aCtr){')
    c.Declarations:append('  return PLX_CTR_getCounter(CtrHandles[aCtr]);')
    c.Declarations:append('}')

    c.Declarations:append('void PLXHAL_CTR_setCounter(uint16_t aCtr, uint32_t aValue){')
    c.Declarations:append('  PLX_CTR_setCounter(CtrHandles[aCtr], aValue);')
    c.Declarations:append('}')

    local code = [[
      PLX_CTR_sinit();
      for (int i = 0; i < %d; i++) {
        CtrHandles[i] = PLX_CTR_init(&CtrObj[i], sizeof(CtrObj[i]));
      }
    ]]
    c.PreInitCode:append(code % {numInstances})

    -- ipairs (not pairs): visit instances in registration order so the
    -- generated code is deterministic.
    for _, bid in ipairs(cpuState.instances) do
      local pulseCtr = globals.instances[bid]
      if pulseCtr:getCpu() == self.cpu then
        pulseCtr:finalizeThis(c)
      end
    end

    cpuState.finalized = true
  end

  return PulseCtr
end

return Module
