-- SparseMatrix.lua
-- 
-- Copyright (c) 2023 Plexim GmbH
-- All rights reserved.


-- Type tags understood by num2string: they select the format used when a
-- matrix entry is rendered as text (DOUBLE -> "%.18g", UINT32 -> "%d").
local DOUBLE, UINT32 = 1, 2

-- Renders a single number as text according to its type tag.
--   aNum:  the number to render
--   aType: DOUBLE -> formatted with "%.18g" (up to 18 significant digits)
--          UINT32 -> formatted with "%d"
-- Returns the formatted string, or the marker string "IMPLEMENTATION_ERROR"
-- when aType is neither of the known tags.
local function num2string(aNum, aType)
  if aType == DOUBLE then
    -- Hand the value to string.format directly. Formatting it first and then
    -- feeding the result through string.format again would re-interpret the
    -- already formatted text as a format string.
    return string.format("%.18g", aNum)
  elseif aType == UINT32 then
    return string.format("%d", aNum)
  else
    return "IMPLEMENTATION_ERROR"
  end
end

-- Renders a 0-based vector as a list of formatted numbers.
--   aVec:     table holding the values at indices 0 .. aVecSize - 1
--   aVecSize: number of entries to emit
--   aType:    type tag (DOUBLE or UINT32) passed through to num2string
-- Returns "0" when aVec has no entry at index 0. Otherwise returns every
-- formatted entry followed by a comma, so the result ends with a comma
-- (and is the empty string when aVecSize is 0).
local function matrixVectorData(aVec, aVecSize, aType)
  if not aVec[0] then
    -- Placeholder for an empty vector; presumably keeps the generated
    -- initializer list non-empty -- confirm against the consumer.
    return "0"
  end
  -- Collect the pieces and join them once; appending to a growing string
  -- inside the loop would copy the whole string on every iteration.
  local parts = {}
  for i = 0, aVecSize - 1 do
    parts[i + 1] = num2string(aVec[i], aType) .. ","
  end
  return table.concat(parts)
end

--
-- SparseMatrix: a sparse matrix class for Lua.
-- This is based on the StaticSparseMatrix implementation in PLECS.
-- To make data conversion from PLECS easy, all table indices are 0-based.
--
-- Storage is row-compressed: the entries of row r occupy positions
-- mRowPtr[r] .. mRowPtr[r + 1] - 1 of mColIdx (column indices) and
-- mData (values).
--
-- This table is the prototype; the fields below are the defaults that a
-- fresh instance inherits until they are assigned.
local SparseMatrix = {}
SparseMatrix.mRowSize = 0   -- number of entries in mRowPtr
SparseMatrix.mColSize = 0   -- number of entries in mColIdx
SparseMatrix.mDataSize = 0  -- number of entries in mData

-- Creates an empty matrix object.
-- The receiver becomes the metatable (and __index target) of the new object,
-- so methods and the default size fields are looked up on the receiver.
-- Returns a table with empty mRowPtr, mColIdx and mData tables.
function SparseMatrix:new()
  self.__index = self
  return setmetatable({ mRowPtr = {}, mColIdx = {}, mData = {} }, self)
end

-- Tells whether two rows hold exactly the same entries.
--   aRow1, aRow2: 0-based row indices
-- Returns true when both rows store the same number of entries and, position
-- by position, the same column index and the same value; false otherwise.
function SparseMatrix:areRowsEqual(aRow1, aRow2)
  local rowPtr, colIdx, data = self.mRowPtr, self.mColIdx, self.mData
  local first1, first2 = rowPtr[aRow1], rowPtr[aRow2]
  local count = rowPtr[aRow1 + 1] - first1
  -- Rows with a different number of stored entries can never match.
  if rowPtr[aRow2 + 1] - first2 ~= count then
    return false
  end
  for offset = 0, count - 1 do
    local p, q = first1 + offset, first2 + offset
    if colIdx[p] ~= colIdx[q] then
      return false
    end
    if data[p] ~= data[q] then
      return false
    end
  end
  return true
end

-- Tells whether one row is the exact negation of another.
--   aRow1, aRow2: 0-based row indices
-- Returns true when both rows store the same number of entries and, position
-- by position, the same column index and values of opposite sign;
-- false otherwise.
function SparseMatrix:areRowsNegated(aRow1, aRow2)
  local rowPtr, colIdx, data = self.mRowPtr, self.mColIdx, self.mData
  local first1, first2 = rowPtr[aRow1], rowPtr[aRow2]
  local count = rowPtr[aRow1 + 1] - first1
  -- Rows with a different number of stored entries can never match.
  if rowPtr[aRow2 + 1] - first2 ~= count then
    return false
  end
  for offset = 0, count - 1 do
    local p, q = first1 + offset, first2 + offset
    if colIdx[p] ~= colIdx[q] then
      return false
    end
    if data[p] ~= -data[q] then
      return false
    end
  end
  return true
end

-- Builds a new matrix that contains only the selected rows of this matrix.
--   aRows: table whose values are 0-based row indices of this matrix
-- The selected rows are copied in ascending key order of aRows when all keys
-- are numbers; otherwise they are copied in pairs() order.
-- Column indices and values are copied unchanged.
-- Returns the new SparseMatrix. Its mRowSize is the number of mRowPtr
-- entries (selected rows + 1); mColSize and mDataSize are the number of
-- copied entries.
function SparseMatrix:filter(aRows)
  -- The position of each row in the result depends on the traversal order,
  -- and pairs() does not guarantee any order, not even for numeric keys.
  -- Collect the keys first and sort them to make the result deterministic.
  local keys = {}
  local sortable = true
  for key in pairs(aRows) do
    keys[#keys + 1] = key
    if type(key) ~= "number" then
      sortable = false
    end
  end
  if sortable then
    table.sort(keys)
  end

  local m = SparseMatrix:new()
  m.mRowPtr[0] = 0
  local newIdx = 0
  local newRow = 0
  for n = 1, #keys do
    local r = aRows[keys[n]]
    -- Copy every stored entry of source row r to the end of the new arrays.
    for j = self.mRowPtr[r], self.mRowPtr[r + 1] - 1 do
      m.mColIdx[newIdx] = self.mColIdx[j]
      m.mData[newIdx] = self.mData[j]
      newIdx = newIdx + 1
    end
    newRow = newRow + 1
    -- Close the row: the next row starts where this one ended.
    m.mRowPtr[newRow] = newIdx
  end
  m.mRowSize = newRow + 1
  m.mColSize = newIdx
  m.mDataSize = newIdx
  return m
end

-- Renders the matrix as C source text.
--   aName: identifier used for the generated struct and as prefix of the
--          three generated arrays (<aName>_rowptr, <aName>_colidx,
--          <aName>_data)
-- Returns the generated text: three static const arrays followed by a
-- struct FPGAPhysicalModelMatrix whose members point at their first element.
function SparseMatrix:toCode(aName)
   -- Values substituted for the %(key)s placeholders of the template.
   local substitutions = {
      name = aName,
      rowptr = matrixVectorData(self.mRowPtr, self.mRowSize, UINT32),
      colidx = matrixVectorData(self.mColIdx, self.mColSize, UINT32),
      data = matrixVectorData(self.mData, self.mDataSize, DOUBLE)
   }
   local template = [[
   static const int %(name)s_rowptr[] = { %(rowptr)s };
   static const int %(name)s_colidx[] = { %(colidx)s };
   static const double %(name)s_data[] = { %(data)s };
   static const struct FPGAPhysicalModelMatrix %(name)s = {
      .rowptr = &%(name)s_rowptr[0],
      .colidx = &%(name)s_colidx[0],
      .data   = &%(name)s_data[0]
   };
   ]]
   -- NOTE(review): the `%` operator on strings is not standard Lua and is not
   -- defined in this file; presumably a string metamethod installed elsewhere
   -- performs the %(key)s substitution -- confirm. The substituted text is
   -- then passed through string.format without further arguments.
   return string.format(template % substitutions)
end

-- Creates a matrix from a table of 1-based lists.
--   aHash: table with the lists rowPtr, colIdx and data
-- Each list is copied into a 0-based table (list[1] ends up at index 0), and
-- the matching size field is set to the length of the list.
-- Returns the new SparseMatrix.
function SparseMatrix.fromHash(aHash)
  -- Copies a 1-based list into a 0-based table.
  -- Returns the copy and the length of the list.
  local function toZeroBased(aList)
    local count = #aList
    local shifted = {}
    for i = 0, count - 1 do
      shifted[i] = aList[i + 1]
    end
    return shifted, count
  end

  local matrix = SparseMatrix:new()
  matrix.mRowPtr, matrix.mRowSize = toZeroBased(aHash.rowPtr)
  matrix.mColIdx, matrix.mColSize = toZeroBased(aHash.colIdx)
  matrix.mData, matrix.mDataSize = toZeroBased(aHash.data)
  return matrix
end

-- Module value: the SparseMatrix class table.
return SparseMatrix
