package modbat.test

import modbat.mbt._
import modbat.mbt.Predef._
import java.io.IOException
import java.net.InetSocketAddress
import java.nio.ByteBuffer
import java.nio.channels.ClosedByInterruptException
import java.nio.channels.ServerSocketChannel
import java.nio.channels.SocketChannel

/* Revised model of NioSocket1. Uses return value of non-blocking connect
   to allow a possible transition to state "connected"; also keeps count
   of number of bytes read, for a model that should never fail. */

/** Companion object of the model: hosts the test server the model connects
  * to, together with the hooks that start and stop it.
  */
object JavaNioSocket {
  /** Single-threaded server listening on localhost:8888 in blocking mode.
    *
    * Every accepted client receives exactly two bytes (the UTF-16 encoding
    * of a newline character), after which the connection is closed. The
    * server keeps accepting clients until its thread is interrupted.
    */
  object TestServer extends Thread {
    val ch: ServerSocketChannel = ServerSocketChannel.open()
    ch.socket().bind(new InetSocketAddress("localhost", 8888))
    ch.configureBlocking(true)

    /** Accept loop. An interrupt (delivered as `ClosedByInterruptException`
      * by the blocking channel operation in progress) ends the loop. The
      * listening channel is closed on every exit path, including unexpected
      * failures of `accept()`.
      */
    override def run(): Unit = {
      try {
        var running = true
        while (running) {
          try {
            serveClient(ch.accept())
          } catch {
            case _: ClosedByInterruptException => running = false
          }
        }
      } finally {
        ch.close()
      }
    }

    /** Sends the two-byte greeting to one client and always closes the
      * connection afterwards.
      *
      * An I/O failure on this connection (for example, the client has
      * already gone away) must not terminate the server, so it is dropped
      * here. `ClosedByInterruptException` is itself an `IOException` and is
      * therefore matched first and re-thrown so that `run()` can shut down.
      */
    private def serveClient(connection: SocketChannel): Unit = {
      try {
        val buf = ByteBuffer.allocate(2)
        // Writing through the char view fills both bytes but leaves the
        // position of `buf` itself at 0, so both bytes are sent below.
        buf.asCharBuffer().put("\n")
        connection.write(buf)
      } catch {
        case e: ClosedByInterruptException => throw e
        case _: IOException => // client-side failure: serve the next client
      } finally {
        connection.socket().close()
      }
    }
  }

  /** Starts the test server thread. */
  @init def startServer(): Unit = {
    TestServer.start()
  }

  /** Interrupts the test server thread, which makes it leave its accept
    * loop and close the listening channel.
    */
  @shutdown def shutdown(): Unit = {
    TestServer.interrupt()
  }
}

/** Model of a client `SocketChannel` talking to the test server on
  * localhost:8888.
  *
  * The model tracks whether the channel is open, possibly connected (after a
  * non-blocking connect), connected, or closed, and checks both the data/EOF
  * behavior of reads and the exceptions thrown by operations that are
  * illegal in the current state.
  */
class JavaNioSocket extends Model {
  var ch: SocketChannel = null
  var connected: Boolean = false // track ret. val. of non-blocking connect
  var n = 0 // number of bytes read so far

  /** Closes the channel, if one was opened. */
  @after def cleanup(): Unit = {
    if (ch != null) {
      ch.close()
    }
  }

  // helper functions

  /** Connects `ch` to the test server.
    *
    * @return the result of `SocketChannel.connect`
    */
  def connect(ch: SocketChannel): Boolean = {
    ch.connect(new InetSocketAddress("localhost", 8888))
  }

  /** Reads at most one byte from `ch` and checks the result against the
    * number of bytes already consumed.
    *
    * While fewer than two bytes have been read, data is expected: at least
    * one byte in blocking mode, possibly zero bytes in non-blocking mode.
    * Once two bytes have been read, end-of-stream (-1) is expected; in
    * non-blocking mode a result of 0 is accepted as well, because the
    * end-of-stream condition may not have reached the client yet.
    *
    * The caller is responsible for adding the result to its byte count only
    * when data was actually read.
    *
    * @param ch channel to read from
    * @param n  number of bytes read so far
    * @return the result of `SocketChannel.read`
    */
  def readFrom(ch: SocketChannel, n: Int): Int = {
    val buf = ByteBuffer.allocate(1)
    val l = ch.read(buf)
    val blocking = ch.isBlocking()
    if (n < 2) {
      // non-blocking read may return 0 bytes
      val limit = if (blocking) 1 else 0
      assert(l >= limit,
             {"Expected data, got " + l + " after " +
              (n + 1) + " reads with blocking = " + blocking})
    } else {
      // non-blocking read may return 0 bytes until EOF becomes visible
      assert(l == -1 || (!blocking && l == 0),
             {"Expected EOF, got " + l + " after " +
              (n + 1) + " reads with blocking = " + blocking})
    }
    l
  }

  /** Switches `ch` between blocking and non-blocking mode. */
  def toggleBlocking(ch: SocketChannel): Unit = {
    ch.configureBlocking(!ch.isBlocking())
  }

  // transitions
  def instance() = {
    new MBT(
      "reset" -> "open" := {
        ch = SocketChannel.open()
      },
      "open" -> "open" := {
        toggleBlocking(ch)
      },
      // blocking connect: the channel is connected when the call returns
      "open" -> "connected" := {
        require(ch.isBlocking())
        connect(ch)
      },
      // non-blocking connect: remember the outcome; optionally switch to
      // blocking mode and finish the connection right away
      "open" -> "maybeconnected" := {
        require(!ch.isBlocking())
        Thread.sleep(50)
        connected = connect(ch)
        maybe { toggleBlocking(ch); connected = ch.finishConnect() }
      } maybeNextIf ((() => connected) -> "connected"),
      "maybeconnected" -> "maybeconnected" := {
        toggleBlocking(ch)
      },
      "maybeconnected" -> "connected" := {
        require(ch.isBlocking())
        ch.finishConnect()
      },
      "maybeconnected" -> "maybeconnected" := {
        require(!ch.isBlocking())
        Thread.sleep(50)
      } maybeNextIf ((() => ch.finishConnect()) -> "connected"),
      // operations that are illegal in the current state
      "open" -> "err" := {
        ch.finishConnect()
      } throws ("NoConnectionPendingException"),
      "maybeconnected" -> "err" := {
        require(!connected)
        connect(ch)
      } throws ("ConnectionPendingException"),
      "connected" -> "err" := {
        connect(ch)
      } throws ("AlreadyConnectedException"),
      "open" -> "err" := {
        readFrom(ch, n)
      } throws ("NotYetConnectedException"),
      "maybeconnected" -> "err" := {
        require(!connected)
        readFrom(ch, n)
      } throws ("NotYetConnectedException"),
      "connected" -> "connected" := {
        ch.finishConnect() // redundant call to finishConnect (no effect)
      },
      "connected" -> "connected" := {
        val l = readFrom(ch, n)
        // count only bytes that were actually read (l may be 0 or -1)
        if (l > 0) {
          n = n + l
        }
      },
      Set("open", "connected", "maybeconnected", "closed") -> "closed" := {
        ch.close()
      },
      "closed" -> "err" := {
        readFrom(ch, n)
      } throws ("ClosedChannelException")
    )
  }
}
