local MultiTasking = {}

-- Module state shared by the functions below.
-- True when AnalyzeTasks was given more than one task.
local HaveMultipleTasks = false
-- Tasks grouped by core: TasksPerCore[coreId + 1] lists the tasks on core coreId.
local TasksPerCore = {}
-- Largest task period in base-period ticks (recomputed by CalculateTicks).
local MaxPeriodTicks = 1

-- Recursively copy a value. Tables are rebuilt entry by entry, with both keys
-- and values deep-copied; any non-table value is returned as-is.
-- Metatables are not carried over, and cyclic tables are not supported
-- (the recursion would not terminate).
local function DeepCopy(orig)
  if type(orig) ~= 'table' then
    return orig
  end
  local duplicate = {}
  for key, value in pairs(orig) do
    duplicate[DeepCopy(key)] = DeepCopy(value)
  end
  return duplicate
end

--- Number of tasks grouped onto a core.
-- @param core 1-based Lua index into the per-core grouping (core id + 1);
--   unlike NumTasksPerCore, which takes the 0-based core id
-- @return the task count for that core
function MultiTasking.NumTasksOnCore(core)
  local coreTasks = TasksPerCore[core]
  return #coreTasks
end

--- Validate a task configuration before it is analyzed.
-- Every task must be assigned to core 0, 1 or 2, and cores 1 and 2 may each
-- run at most one task.
-- @param taskConfig sequence of tasks; each task has a numeric 0-based `Core`
--   and a `Name` (used in error messages)
-- @return true and an error message when the configuration is invalid,
--   false otherwise
function MultiTasking.CheckTaskConfiguration(taskConfig)

  -- Group tasks by core. The grouping is private to this check; the
  -- module-level TasksPerCore state is deliberately not modified here.
  local tasksPerCore = {
    [1] = {},
    [2] = {},
    [3] = {},
  }
  for _, task in ipairs(taskConfig) do
    local coreId = task.Core + 1 -- our core ids start at 0 but Lua lists use 1-based indexing

    -- A core without a group is invalid on either side of the range:
    -- checking only the upper bound let negative core ids through, which then
    -- crashed on the nil group instead of producing this message.
    local coreTasks = tasksPerCore[coreId]
    if coreTasks == nil then
      local msg = StringList:new()
      msg:append("Invalid core %d in task configuration." % { task.Core })
      msg:append("Valid core numbers are 0-2.")
      return true, table.concat(msg, " ")
    end
    coreTasks[#coreTasks + 1] = task
  end

  -- Cores 1 and 2 (Lua indices 2 and 3) run a single task each.
  if #tasksPerCore[2] > 1 or #tasksPerCore[3] > 1 then
    local msg = StringList:new()
    msg:append("Cores 1 and 2 can execute only one single task each.")
    for coreIdx = 2, 3 do
      if #tasksPerCore[coreIdx] > 1 then
        local coreTasks = StringList:new()
        for _, task in ipairs(tasksPerCore[coreIdx]) do
          coreTasks:append("- " .. task.Name)
        end
        msg:append("The following tasks are configured to run on core %i:\n%s"
                   % {
                     coreIdx - 1,
                     table.concat(coreTasks, "\n")
                   })
      end
    end
    return true, table.concat(msg, "\n")
  end

  return false

end

--- Record the task list and group it by core.
-- Rebuilds the module state from scratch: HaveMultipleTasks and the per-core
-- grouping. Each task is deep-copied, so the period/offset fields added later
-- by CalculateTicks do not touch the caller's tables.
-- @param tasks sequence of tasks, each with a numeric 0-based `Core`
function MultiTasking.AnalyzeTasks(tasks)
  HaveMultipleTasks = (#tasks > 1)

  TasksPerCore = { {}, {}, {} }
  for _, task in ipairs(tasks) do
    local bucket = TasksPerCore[task.Core + 1] -- core ids are 0-based
    bucket[#bucket + 1] = DeepCopy(task)
  end
end


--- Convert every task's sample time and offset into whole base-period ticks.
-- Sets `task.period` and `task.offset` on each grouped task, rounding to the
-- nearest tick, and updates MaxPeriodTicks to the largest period (at least 1).
-- @param basePeriod duration of one tick, in the same unit as SampleTime
function MultiTasking.CalculateTicks(basePeriod)
  local function toTicks(duration)
    return math.floor(duration / basePeriod + 0.5)
  end

  MaxPeriodTicks = 1
  for _, coreTasks in ipairs(TasksPerCore) do
    for _, task in ipairs(coreTasks) do
      local sampleTime = task.SampleTime
      task.period = toTicks(sampleTime[1])
      MaxPeriodTicks = math.max(MaxPeriodTicks, task.period)
      task.offset = toTicks(sampleTime[2])
    end
  end
end


--- Largest task period in base-period ticks, as computed by CalculateTicks.
function MultiTasking.GetMaxPeriodTicks()
  local maxTicks = MaxPeriodTicks
  return maxTicks
end


--- Generate the C initializer fields for the core-1 and core-2 timing.
-- For each of cores 1 and 2, emits `.mCore<N>Tick` and `.mCore<N>Period`,
-- in base-period ticks. A core without tasks gets period 0 and tick 0.
-- When the core's task has a positive offset, its tick starts at period - offset.
-- @return the initializer lines joined with ",\n"
function MultiTasking.CreateCoreTimingCode()

  local code = StringList:new()

  for idx = 2, 3 do
    -- Declared local: bare assignments here would create (or overwrite)
    -- global variables named corePeriod / coreTick.
    local corePeriod = 0
    local coreTick = 0
    local firstTask = TasksPerCore[idx][1]
    if firstTask ~= nil then
      corePeriod = firstTask.period
      if firstTask.offset > 0 then
        coreTick = firstTask.period - firstTask.offset
      end
    end
    code:append("      .mCore%iTick = %i" % { idx - 1, coreTick })
    code:append("      .mCore%iPeriod = %i" % { idx - 1, corePeriod })
  end

  return table.concat(code, ',\n')

end


--- Number of tasks grouped onto a core.
-- @param coreId 0-based core id
-- @return the task count for that core
function MultiTasking.NumTasksPerCore(coreId)
  local luaIndex = coreId + 1 -- the grouping table is 1-based
  return #TasksPerCore[luaIndex]
end


--- Whether any task, on any core, has a period longer than one tick.
-- Requires CalculateTicks to have set `period` on every task.
function MultiTasking.HaveNonBasePeriodTask()
  for coreIdx = 1, #TasksPerCore do
    local coreTasks = TasksPerCore[coreIdx]
    for taskIdx = 1, #coreTasks do
      if coreTasks[taskIdx].period > 1 then
        return true
      end
    end
  end
  return false
end


-- Build the PWM channel bit masks for one task.
-- Channels below 32 set bit `channel` of the low mask; higher channels set
-- bit `channel - 32` of the high mask. A task without an entry in pwmTasks
-- yields 0, 0.
local function channels2masks(pwmTasks, taskName)
  local lo, hi = 0, 0
  local channels = pwmTasks[taskName]
  if not channels then
    return lo, hi
  end
  for _, ch in pairs(channels) do
    if ch < 32 then
      lo = lo | (1 << ch)
    else
      hi = hi | (1 << (ch - 32))
    end
  end
  return lo, hi
end

-- OR together the PWM masks of every core whose first task runs at the base
-- rate (period 1). Returns the combined low and high masks.
local function initialPwmMasks(pwmTasks)
  local lo, hi = 0, 0
  for _, coreTasks in pairs(TasksPerCore) do
    local first = coreTasks[1]
    if first ~= nil and first.period == 1 then
      local taskLo, taskHi = channels2masks(pwmTasks, first.Name)
      lo, hi = lo | taskLo, hi | taskHi
    end
  end
  return lo, hi
end

-- Determine whether run-time PWM update masks are needed.
-- Looks at every task that is not the first on its core, or whose period is
-- above 1. Returns two booleans: whether any such task drives a low-mask
-- channel, and whether any drives a high-mask channel.
local function needsPwmUpdateCode(pwmTasks)
  local lowerNeeded, higherNeeded = false, false
  for _, coreTasks in pairs(TasksPerCore) do
    for position, task in pairs(coreTasks) do
      if position > 1 or task.period > 1 then
        local lo, hi = channels2masks(pwmTasks, task.Name)
        lowerNeeded = lowerNeeded or lo > 0
        higherNeeded = higherNeeded or hi > 0
      end
    end
  end
  return lowerNeeded, higherNeeded
end


-- Emit the C call(s) that prepare a PWM update with the given masks.
-- The low-mask call is always emitted; the high-mask call only when maskHi > 0.
-- Each line is prefixed with `indent`.
local function pwmUpdateCode(maskLo, maskHi, indent)
  local lines = { "%splxPreparePwmUpdate(0x%x);" % { indent, maskLo } }
  if maskHi > 0 then
    lines[#lines + 1] = "%splxPrepareHigherPwmUpdate(0x%x);" % { indent, maskHi }
  end
  return table.concat(lines, "\n")
end

-- Decide whether modelInitFunction must reset plxIoLock.
-- Returns true when any task on core 1 or core 2 has a longer period than the
-- first task on core 0. If core 0 has no tasks, the base rate (period 1) is
-- used as the reference; indexing the missing first task would otherwise
-- raise a Lua error.
local function needsIoLock()
  local basePeriod = 1
  local firstBaseCoreTask = TasksPerCore[1][1]
  if firstBaseCoreTask ~= nil then
    basePeriod = firstBaseCoreTask.period
  end
  for core = 2, 3 do
    for _, task in ipairs(TasksPerCore[core]) do
      if task.period > basePeriod then
        return true
      end
    end
  end
  return false
end
  
-- Generate the lines of plxGeneratedMultiCorePreStepFunction for one helper
-- core. `core` is the Lua index used in the flag name aCore<core>AboutToStart.
-- When the core's first task has a period above 1 and drives PWM channels,
-- its masks are OR-ed into the update masks while that flag is set.
-- Otherwise the flag parameter is only cast to void to mark it as used.
local function pwmCoreUpdateCode(core, pwmTasks)
  local unusedFlag = "   (void)aCore%dAboutToStart;"
  local firstTask = TasksPerCore[core][1]
  if firstTask == nil or not (firstTask.period > 1) then
    return unusedFlag % core
  end

  local maskLo, maskHi = channels2masks(pwmTasks, firstTask.Name)
  if not (maskLo > 0 or maskHi > 0) then
    return unusedFlag % core
  end

  local code = StringList:new()
  code:append("   if (aCore%dAboutToStart)\n   {" % core)
  if maskLo > 0 then
    code:append("      mPwmUpdateMask |= 0x%x;" % { maskLo })
  end
  if maskHi > 0 then
    code:append("      mHigherPwmUpdateMask |= 0x%x;" % { maskHi })
  end
  code:append("   }")
  return table.concat(code, '\n')
end


--- Generate the C function modelInitFunction.
-- The generated function:
--   1. optionally resets plxIoLock;
--   2. initializes the model;
--   3. runs the step functions 10 times while plxErrorFlag is clear, to get
--      the code into cache;
--   4. terminates the model if no error occurred;
--   5. prepares the initial PWM update and initializes the model again.
-- @param pwmTasks map from task name to the list of PWM channels it drives
-- @return the generated C code as a single string
function MultiTasking.CreateInitFunctionCode(pwmTasks)
  local lines = StringList:new()

  -- Function prologue.
  lines:append([=[
int plxPlatform_poll(void);
static void modelInitFunction(void)
{]=])

  if needsIoLock() then
    lines:append("   plxIoLock.lock = 0;")
  end

  -- First initialization, then open the cache warm-up loop.
  lines:append([=[
   MODEL_INITIALIZE(0);
   postInitCode();
   // Run step functions a few times to have code in cache
   for (int i=0; i<10; i++)
   {
      if (!plxErrorFlag)
      {]=])

  -- Body of the warm-up loop.
  if not HaveMultipleTasks then
    lines:append("         MODEL_STEP();")
    lines:append("         plxPlatform_poll();")
  else
    for _, baseCoreTask in ipairs(TasksPerCore[1]) do
      lines:append("         MODEL_STEP(%i);" % { baseCoreTask.TaskId })
      lines:append("         plxPlatform_poll();")
    end
    lines:append("         plxMulticoreSyncedStep();")
  end

  -- Close the loop, tear down the warm-up run and initialize for real.
  local loMask, hiMask = initialPwmMasks(pwmTasks)
  local epilogue = [=[
      }
   }
   if (!plxErrorFlag)
   {
      MODEL_TERMINATE();
   }
%(preparePwmUpdateCode)s
   MODEL_INITIALIZE(0);
   postInitCode();
}]=]
  lines:append(epilogue % {
    preparePwmUpdateCode = pwmUpdateCode(loMask, hiMask, "   "),
  })

  return table.concat(lines, '\n')
end


--- Generate the C function modelSyncFunction, which only calls MODEL_SYNC().
-- @return the generated C code as a single string
function MultiTasking.CreateSyncFunctionCode()
  local syncFunction = [=[
static void modelSyncFunction(void)
{
   MODEL_SYNC();
}]=]
  return syncFunction
end

--- Generate the C step functions for all cores.
-- Emits, in order:
--   * the PWM update-mask variables (only when needed);
--   * the IO declarations (only when IO code exists);
--   * plxGeneratedMultiCorePreStepFunction and plxGeneratedSingleCorePreStepFunction;
--   * modelStepFunction0 (core 0);
--   * modelStepFunction1 / modelStepFunction2 (cores 1 and 2).
-- Reads the per-core task lists with the `period` / `offset` fields set by
-- CalculateTicks.
-- @param preBaseStepCode list of C lines emitted in both pre-step functions
-- @param postBaseStepCode list of C lines emitted after the base-task step
--   (only when the first core-0 task has period 1)
-- @param pwmTasks map from task name to the list of PWM channels it drives
-- @param ioCode ioCode[1] / ioCode[2]: lists of C lines executed under
--   plxIoLock after the task step on core 1 / core 2
-- @param ioDeclarations C declarations emitted when any IO code exists
-- @return the generated C code as a single string
function MultiTasking.CreateStepFunctionCode(preBaseStepCode, postBaseStepCode, 
                                             pwmTasks, ioCode, ioDeclarations)

  -- Generate the IRQ functions for the cores
  local code = StringList:new()
  -- IO code exists only for cores 1 and 2. When present, the multi-core
  -- pre-step code is also wrapped in plxIoLock.
  local hasIoCode = #ioCode[1] > 0 or #ioCode[2] > 0
  -- Masks of the first task of each core that runs with period 1; these are
  -- the values the update masks are reset to in each pre-step function.
  local initalPwmMaskLo, initalPwmMaskHi = initialPwmMasks(pwmTasks)
  -- The mask variables are only generated when some other task (not first on
  -- its core, or period > 1) drives low (0-31) or high (32+) PWM channels.
  local needsLowerPwmUpdateCode, needsHigherPwmUpdateCode = needsPwmUpdateCode(pwmTasks)
  if needsLowerPwmUpdateCode == true then
    code:append("static uint32_t mPwmUpdateMask;");
  end
  if needsHigherPwmUpdateCode == true then
    code:append("static uint32_t mHigherPwmUpdateMask;");
  end
  if hasIoCode == true then
    code:append(ioDeclarations)
  end
  -- Multi-core pre-step: pre-step code (under plxIoLock when IO code exists),
  -- reset of the update masks, then the per-core PWM contributions selected
  -- by the aCore2AboutToStart / aCore3AboutToStart flags.
  code:append("")
  code:append("void plxGeneratedMultiCorePreStepFunction(bool aCore2AboutToStart, bool aCore3AboutToStart)")
  code:append("{");
  if hasIoCode == true then
    code:append("   spin_lock(&plxIoLock);\n")
  end
  code:append(table.concat(preBaseStepCode, '\n'))
  if hasIoCode == true then
    code:append("   spin_unlock(&plxIoLock);\n")
  end
  if needsLowerPwmUpdateCode == true then
    code:append("   mPwmUpdateMask = 0x%x;" % { initalPwmMaskLo });
  end
  if needsHigherPwmUpdateCode == true then
    code:append("   mHigherPwmUpdateMask = 0x%x;" % { initalPwmMaskHi });
  end
  code:append(pwmCoreUpdateCode(2, pwmTasks))
  code:append(pwmCoreUpdateCode(3, pwmTasks))
  code:append("}");
  -- Single-core pre-step: same pre-step code and mask reset, without the IO
  -- lock and without per-core PWM contributions.
  code:append("")
  code:append("void plxGeneratedSingleCorePreStepFunction(void)")
  code:append("{");
  code:append(table.concat(preBaseStepCode, '\n'))
  if needsLowerPwmUpdateCode == true then
    code:append("   mPwmUpdateMask = 0x%x;" % { initalPwmMaskLo });
  end
  if needsHigherPwmUpdateCode == true then
    code:append("   mHigherPwmUpdateMask = 0x%x;" % { initalPwmMaskHi });
  end
  code:append("}");
  code:append("")
  
  -- Core 0 step function. If the first core-0 task has period 1 it is stepped
  -- directly as the base task; every other core-0 task becomes a sub-task
  -- driven by tick counters.
  code:append("void modelStepFunction0(void)")
  code:append("{");
  local taskList = TasksPerCore[1]
  if #taskList ~= 0 then
    local baseTaskIdx = nil
    local firstSubTaskIdx = 1
    if taskList[1].period == 1 then
      baseTaskIdx = 1
      firstSubTaskIdx = 2
    end
    
    -- Per-sub-task data for the generated C arrays:
    --   * model task ids and periods;
    --   * initial tick counters (period - offset, or 0 without offset);
    --   * overrun messages and PWM masks.
    local subTaskIds = { }
    local subTaskPeriods = { }
    local subTaskFalse = { }
    local subTaskTicks = { }
    local subTaskOverrunMsgs = StringList:new()
    local pwmTaskMask = {}
    local higherPwmTaskMask = {}
    local hasPwmTaskMask = false
    local hasHigherPwmTaskMask = false
    for subTaskIdx = firstSubTaskIdx, #taskList do
      local mask, maskHi = channels2masks(pwmTasks, taskList[subTaskIdx].Name)
      table.insert(pwmTaskMask, mask)
      table.insert(higherPwmTaskMask, maskHi)
      hasPwmTaskMask = hasPwmTaskMask or (mask > 0)
      hasHigherPwmTaskMask = hasHigherPwmTaskMask or (maskHi > 0)

      local subTask = taskList[subTaskIdx]
      table.insert(subTaskIds, subTask.TaskId)
      table.insert(subTaskPeriods, subTask.period)
      table.insert(subTaskTicks, subTask.offset ~= 0 and subTask.period - subTask.offset or 0)
      table.insert(subTaskFalse, 'false')
      -- NOTE(review): task names are inserted into a C string literal
      -- unescaped; a name containing `"` or `\` would break the generated C.
      subTaskOverrunMsgs:append('"Overrun in Task \\"%s\\"."' % { subTask.Name })
    end
    
    -- Static C bookkeeping arrays for the sub-tasks; the PWM mask arrays are
    -- only emitted when at least one sub-task drives such channels.
    if #subTaskIds > 0 then
      code:append([=[
   static const size_t subTaskId[%(numSubTasks)i] = { %(subTaskIds)s };
   static const size_t subTaskPeriod[%(numSubTasks)i] = { %(subTaskPeriods)s };
   static size_t subTaskTick[%(numSubTasks)i] = { %(subTaskTicks)s };
   static atomic_bool subTaskHit[%(numSubTasks)i];
   static bool subTaskActive[%(numSubTasks)i] = { %(subTaskFalse)s };
   static const char* subTaskOverrunMsg[%(numSubTasks)i] = { %(subTaskOverrunMsgs)s };
   size_t subTaskIdx;]=]
  %
  {
    numSubTasks = #subTaskIds,
    subTaskIds = table.concat(subTaskIds, ", "),
    subTaskPeriods = table.concat(subTaskPeriods, ", "),
    subTaskTicks = table.concat(subTaskTicks, ", "),
    subTaskFalse = table.concat(subTaskFalse, ", "),
    subTaskOverrunMsgs = table.concat(subTaskOverrunMsgs, ", "),
  })
      if hasPwmTaskMask == true then
        code:append([=[
   static const uint32_t pwmMaskList[%(numSubTasks)i] = { %(pwmMaskList)s };]=]
        %
        {
          numSubTasks = #subTaskIds,
          pwmMaskList = table.concat(pwmTaskMask, ", "),
        })
      end
      if hasHigherPwmTaskMask == true then
        code:append([=[
   static const uint32_t higherPwmMaskList[%(numSubTasks)i] = { %(pwmMaskList)s };]=]
        %
        {
          numSubTasks = #subTaskIds,
          pwmMaskList = table.concat(higherPwmTaskMask, ", "),
        })
      end
    end
        
    -- Base task step. With sub-tasks present, a guard is emitted.
    -- inBaseStep is decremented before the sub-tasks run, so a call that
    -- arrives while the previous base step is unfinished cancels the timing
    -- measurement and returns.
    if baseTaskIdx then
      code:append("   /* Execute base task. */")
      if #subTaskIds > 0 then
        code:append([=[
   static int inBaseStep = 0;
   if (inBaseStep)
   {
      plxCancelTimingMeasurement();
      return;
   }
   inBaseStep++;
]=])
      end
      if HaveMultipleTasks then
        code:append("   MODEL_STEP(%s);" % { taskList[baseTaskIdx].TaskId })
      else
        code:append("   MODEL_STEP();")
      end
    end

    -- Error handling: stop the timer, report MODEL_ERROR_STATUS and set
    -- plxErrorFlag. plxWaitForIrqAck() is added only when there are no sub-tasks.
    code:append([=[
   if(unlikely((size_t)MODEL_ERROR_STATUS))
   {
      plxStopTimer();
      plxUserMessage(PLXUSERMSG_NEEDS_ATTENTION, "%s\n", MODEL_ERROR_STATUS);
      plxErrorFlag = 1;
]=])
    if #subTaskIds == 0 then
      code:append("      plxWaitForIrqAck();\n")
    end
    code:append([=[
   }
   if (unlikely(plxErrorFlag))
      return;
   
#if defined(EXTERNAL_MODE) && EXTERNAL_MODE
   checkScopeTrigger();
#endif /* defined(EXTERNAL_MODE) */
]=])
      
    -- Per sub-task, in the generated C:
    --   * a counter at 0 marks the sub-task as hit;
    --   * the counter then advances modulo the sub-task's period;
    --   * a counter that wraps to 0 adds the sub-task's PWM channels to the
    --     update masks;
    --   * an inactive sub-task clears bit (subTaskIdx+1) of plxActiveTasks.
    if #subTaskIds > 0 then
      code:append([=[
   // Increment sub-task tick counters and set sub-task hit flags.
   for (subTaskIdx = 0; subTaskIdx < %(numSubTasks)i; ++subTaskIdx)
   {
      if (subTaskTick[subTaskIdx] == 0)
         atomic_store_explicit(&subTaskHit[subTaskIdx], true, memory_order_relaxed);
      subTaskTick[subTaskIdx] = (subTaskTick[subTaskIdx]+1) %% subTaskPeriod[subTaskIdx];]=]
      %
      {
        numSubTasks = #subTaskIds,
      })
      if hasPwmTaskMask == true or hasHigherPwmTaskMask == true then
        code:append("      if (!subTaskTick[subTaskIdx])\n      {")
        if hasPwmTaskMask == true then
          code:append("         mPwmUpdateMask |= pwmMaskList[subTaskIdx];")
        end
        if hasHigherPwmTaskMask == true then
          code:append("         mHigherPwmUpdateMask |= higherPwmMaskList[subTaskIdx];")
        end
        code:append("      }")
      end
      code:append([=[
      if (!subTaskActive[subTaskIdx])
         plxActiveTasks &= ~(1 << (subTaskIdx+1));
   }]=]
      )
    end

    -- Hand the accumulated update masks to the PWM layer.
    if needsLowerPwmUpdateCode == true then
      code:append("   plxPreparePwmUpdate(mPwmUpdateMask);\n")
    end
    if needsHigherPwmUpdateCode == true then
      code:append("   plxPrepareHigherPwmUpdate(mHigherPwmUpdateMask);\n")
    end

    if baseTaskIdx and #subTaskIds > 0 then
      code:append("   plxPostBaseStep();")
      code:append("   inBaseStep--;")
    end
    if baseTaskIdx then
      code:append(table.concat(postBaseStepCode, '\n'))
    end
    
    -- Run hit sub-tasks in index order. In the generated C:
    --   * a hit sub-task that is still active reports an overrun through
    --     MODEL_ERROR_STATUS;
    --   * a non-hit sub-task that is still active stops the loop, so later
    --     entries are not run in this call.
    if #subTaskIds > 0 then
      code:append([=[
   // Execute sub-task based on sub-task hit flags.
   for (subTaskIdx = 0; subTaskIdx < %(numSubTasks)i; ++subTaskIdx)
   {
      if (atomic_exchange_explicit(&subTaskHit[subTaskIdx], false, memory_order_relaxed))
      {
         if (subTaskActive[subTaskIdx])
         {
            MODEL_ERROR_STATUS = subTaskOverrunMsg[subTaskIdx];
            break;
         }
         subTaskActive[subTaskIdx] = true;
         plxActiveTasks |= (1 << (subTaskIdx+1));

         MODEL_STEP(subTaskId[subTaskIdx]);
         
         subTaskActive[subTaskIdx] = false;
      }
      else if (subTaskActive[subTaskIdx])
         break;
   }]=]
  %
  {
    numSubTasks = #subTaskIds,
  })
    end
    
  end
  code:append("}");

  -- Step functions for cores 1 and 2: only the first task of each core is
  -- stepped, followed by that core's IO code under plxIoLock, if any.
  for coreId = 2, 3 do
    code:append("")
    code:append("void modelStepFunction%i(void)" % { coreId - 1 })
    code:append("{");
    if #TasksPerCore[coreId] > 0 then
      code:append("   MODEL_STEP(%s);" % { TasksPerCore[coreId][1].TaskId })
      if #ioCode[coreId - 1] > 0 then
        code:append("   spin_lock(&plxIoLock);")
        code:append("   " .. table.concat(ioCode[coreId - 1], "\n   "))
        code:append("   spin_unlock(&plxIoLock);")
      end
    end
    code:append("}");
  end

  return table.concat(code, '\n')
end

-- Export the module table to `require` callers.
return MultiTasking
