# Constantine
# Copyright (c) 2018-2019    Status Research & Development GmbH
# Copyright (c) 2020-Present Mamy André-Ratsimbazafy
# Licensed and distributed under either of
#   * MIT license (license terms in the root directory or at http://opensource.org/licenses/MIT).
#   * Apache v2 license (license terms in the root directory or at http://www.apache.org/licenses/LICENSE-2.0).
# at your option. This file may not be copied, modified, or distributed except according to those terms.

# Strided View - Monodimensional Tensors
# ----------------------------------------------------------------
#
# FFT uses recursive divide-and-conquer.
# In code, this means we need strided views
# to enable different logical views of the same memory buffer.
# Strided views are monodimensional tensors:
# See Arraymancer backend:
# https://github.com/mratsim/Arraymancer/blob/71cf616/src/arraymancer/laser/tensor/datatypes.nim#L28-L32
# Or the minimal tensor implementation challenge:
# https://github.com/SimonDanisch/julia-challenge/blob/b8ed3b6/nim/nim_sol_mratsim.nim#L4-L26

{.experimental: "views".}

type
  View*[T] = object
    ## A strided view over an (unowned) data buffer.
    ##
    ## Logical element `i` lives at `data[offset + i*stride]`.
    ## `stride` may be negative (see `reversed`), and several views may
    ## alias the same buffer (see `splitAlternate` / `splitMiddle`).
    len*: int     # number of logical elements in the view
    stride: int   # distance, in buffer elements, between consecutive elements
    offset: int   # buffer position of logical element 0
    data: lent UncheckedArray[T]  # borrowed buffer, not owned by the view

func `[]`*[T](v: View[T], idx: int): lent T {.inline.} =
  ## Read-only access to logical element `idx` of the view.
  ## `data` is an UncheckedArray, so an `idx` outside `0 ..< v.len`
  ## is not caught here.
  v.data[v.offset + idx*v.stride]

func `[]`*[T](v: var View[T], idx: int): var T {.inline.} =
  ## Mutable access to logical element `idx` of the view.
  ## The returned location is in the underlying (unowned) buffer, so writes
  ## are visible through every view aliasing that buffer.
  ## No bounds checking against `v.len`.
  # `data` is declared `lent` (read-only); it is cast to a raw
  # `ptr UncheckedArray[T]` to obtain a writable location.
  # NOTE(review): presumably relies on how the experimental views feature
  # lowers `lent` fields.
  # Experimental views indeed ...
  cast[ptr UncheckedArray[T]](v.data)[v.offset + idx*v.stride]

func `[]=`*[T](v: var View[T], idx: int, val: T) {.inline.} =
  ## Store `val` at logical element `idx` of the view.
  ## The write goes to the underlying (unowned) buffer, so it is visible
  ## through every view aliasing that buffer.
  ## No bounds checking against `v.len`.
  # `data` is declared `lent` (read-only); it is cast to a raw
  # `ptr UncheckedArray[T]` to obtain a writable location.
  # NOTE(review): presumably relies on how the experimental views feature
  # lowers `lent` fields.
  # Experimental views indeed ...
  cast[ptr UncheckedArray[T]](v.data)[v.offset + idx*v.stride] = val

func toView*[T](oa: openArray[T]): View[T] {.inline.} =
  ## Create a contiguous view (stride 1, offset 0) over `oa`'s storage.
  ## The view borrows the buffer; it does not copy it.
  ## An empty `oa` yields an empty view: `data` keeps its default (zero)
  ## value, so that view must not be indexed.
  result.len = oa.len
  result.stride = 1
  result.offset = 0
  if oa.len > 0:
    # `oa[0]` raises an IndexDefect on an empty openArray when bounds
    # checks are enabled, hence the guard.
    result.data = cast[lent UncheckedArray[T]](oa[0].unsafeAddr)

iterator items*[T](v: View[T]): lent T =
  ## Yield the view's elements in logical order, index 0 through `v.len - 1`.
  for i in 0 ..< v.len:
    yield v.data[v.offset + i*v.stride]

func `$`*(v: View): string =
  ## Render the view's elements, e.g. `View[0, 2, 4, 6]`.
  result = "View["
  var count = 0
  for elem in v:
    if count > 0:
      result.add(", ")
    result.add($elem)
    inc count
  result.add(']')

func toHex*(v: View): string =
  ## Like `$`, but each element is rendered with its own `toHex` overload
  ## (resolved at the instantiation site thanks to `mixin`).
  mixin toHex

  result = "View["
  var sep = ""
  for elem in v:
    result.add(sep)
    result.add(elem.toHex())
    sep = ", "
  result.add(']')

# FFT-specific splitting
# -------------------------------------------------------------------------------

func splitAlternate*(t: View): tuple[even, odd: View] {.inline.} =
  ## Partition `t` by index parity into two sub-views that alias its buffer:
  ##   even: indices [0, 2, 4, ...]
  ##   odd:  indices [1, 3, 5, ...]
  ## Each half has `t.len div 2` elements. `t.len` must be even.
  assert (t.len and 1) == 0, "The tensor must contain an even number of elements"

  let
    n = t.len shr 1
    step = t.stride shl 1

  # Both halves share the buffer, the length and the doubled stride ...
  result.even.data = t.data
  result.odd.data = t.data
  result.even.len = n
  result.odd.len = n
  result.even.stride = step
  result.odd.stride = step
  # ... and differ only in where they start: odd begins one element later.
  result.even.offset = t.offset
  result.odd.offset = t.offset + t.stride

func splitMiddle*(t: View): tuple[left, right: View] {.inline.} =
  ## Split the tensor into 2,
  ## partitioning it into left and right halves.
  ## For an 8-element view:
  ##   left:  indices [0, 1, 2, 3]
  ##   right: indices [4, 5, 6, 7]
  ## Both halves alias `t`'s buffer and keep its stride. `t.len` must be even.
  assert (t.len and 1) == 0, "The tensor must contain an even number of elements"

  let half = t.len shr 1

  result.left.len = half
  result.left.stride = t.stride
  result.left.offset = t.offset
  result.left.data = t.data

  result.right.len = half
  result.right.stride = t.stride
  # Logical index `half` lives `half` *strides* past the start, not `half`
  # buffer elements. The distinction matters for strided views (e.g. from
  # `splitAlternate`) and for reversed views (negative stride).
  result.right.offset = t.offset + half*t.stride
  result.right.data = t.data

func skipHalf*(t: View): View {.inline.} =
  ## Keep every other element, starting at logical index 0:
  ## output: [0, 2, 4, ...]
  ## Shares `t`'s buffer; only the length and stride change.
  ## `t.len` must be even.
  assert (t.len and 1) == 0, "The tensor must contain an even number of elements"

  result.data = t.data
  result.offset = t.offset
  result.stride = t.stride shl 1
  result.len = t.len shr 1

func slice*(v: View, start, stop, step: int): View {.inline.} =
  ## Slice a view: logical indices `start, start+step, ...` up to and
  ## including `stop`. `step` may be negative (as used by `reversed`).
  ##
  ## `start` and `stop` are logical indices of `v`. They are expected to be
  ## in range and ordered consistently with the sign of `step`; only
  ## `step != 0` is checked here. An empty slice cannot be expressed:
  ## the result always has at least one element.
  # Specialization of the general tensor slicing algorithm
  # https://github.com/mratsim/Arraymancer/blob/71cf616/src/arraymancer/tensor/private/p_accessors_macros_read.nim#L26-L56
  # to a single (monodimensional) axis with an inclusive `stop`:
  #   offset += start * stride
  #   stride *= step
  #   len     = abs((stop - start) div step) + 1
  #
  # `result` starts zero-initialized, so every field is set explicitly.
  assert step != 0, "The slice step must be non-zero"

  result.offset = v.offset + start * v.stride
  result.stride = v.stride * step
  result.len = abs((stop-start) div step) + 1
  result.data = v.data

func reversed*(v: View): View {.inline.} =
  ## A view of the same elements in reverse order (no copy):
  ## logical index `i` of the result is index `v.len-1-i` of `v`.
  if v.len == 0:
    # `slice` cannot express an empty range: slice(-1, 0, -1) would produce
    # a 2-element view starting before the buffer.
    v
  else:
    v.slice(v.len-1, 0, -1)

# ############################################################
#
#                  Debugging helpers
#
# ############################################################
import strformat, strutils

func display*[F](name: string, indent: int, oa: openArray[F]) =
  ## Debug helper: print every element of `oa` in hex via `debugEcho`,
  ## shifted right by `indent` spaces, between a header and a footer line.
  let header = name & ", openarray of " & $F & " of length " & $oa.len
  let footer = name & "  " & $F & " -- FIN\n"
  debugEcho indent(header, indent)
  for i, elem in oa:
    debugEcho indent(&"    {i:>2}: {elem.toHex()}", indent)
  debugEcho indent(footer, indent)

func display*[F](name: string, indent: int, v: View[F]) =
  ## Debug helper: print every element of `v` in hex via `debugEcho`,
  ## shifted right by `indent` spaces, between a header and a footer line.
  debugEcho indent(name & ", view of " & $F & " of length " & $v.len, indent)
  var i = 0
  for elem in v:
    debugEcho indent(&"    {i:>2}: {elem.toHex()}", indent)
    inc i
  debugEcho indent(name & "  " & $F & " -- FIN\n", indent)

# ############################################################
#
#                    Sanity checks
#
# ############################################################

when isMainModule:
  proc main() =
    ## Print a small view and its FFT-style sub-views for visual inspection.
    var x = [0, 1, 2, 3, 4, 5, 6, 7]
    let v = x.toView()

    template section(title: string) =
      # Blank line, title, underline.
      echo "\n", title
      echo "----------------"

    template showSubViews(labA, labB, labARev, labBRev: string,
                          a, b: untyped) =
      # Print a pair of second-level sub-views, then both reversed.
      echo ""
      echo labA, a
      echo labB, b
      echo labARev, a.reversed()
      echo labBRev, b.reversed()

    echo "view: ", v
    echo "reversed: ", v.reversed()

    block:
      let (even, odd) = v.splitAlternate()
      section "Split Alternate"
      echo "even: ", even
      echo "odd:  ", odd

      let (ee, eo) = even.splitAlternate()
      showSubViews("even-even: ", "even-odd:  ",
                   "even-even rev: ", "even-odd rev:  ", ee, eo)

      let (oe, oo) = odd.splitAlternate()
      showSubViews("odd-even: ", "odd-odd:  ",
                   "odd-even rev: ", "odd-odd rev:  ", oe, oo)

    section "Skip Half"
    let quad = v.skipHalf().skipHalf()
    echo "skipHalf: ", v.skipHalf()
    echo "skipQuad: ", quad
    echo "skipQuad rev: ", quad.reversed()

    section "Split middle"
    block:
      let (left, right) = v.splitMiddle()
      echo "left:  ", left
      echo "right: ", right

      let (ll, lr) = left.splitMiddle()
      showSubViews("left-left:  ", "left-right: ",
                   "left-left rev:  ", "left-right rev: ", ll, lr)

      let (rl, rr) = right.splitMiddle()
      showSubViews("right-left:  ", "right-right: ",
                   "right-left rev:  ", "right-right rev: ", rl, rr)

  main()