import copy

from Blackjack import *  # basic blackjack module  
from BlackjackAgents import *  # a few comparison agents

AGENTS = [None, HitAgent, StandAgent, RandomAgent, OptimalAgent]


## Converts the full Blackjack.Gamestate class into a simpler MDP state
## containing the player total(s), the dealer's up card, and the result (money).
## (The __hash__/__cmp__ boilerplate below lets instances be used as dict keys.)
class SimpleState():  
    def __init__(self, state):
        bj = Blackjack()
        self.playerCards = bj.Count(state.playerCards)
        self.dealerCard = state.dealerCards[0].rank
        self.result = state.money
        v = "%s;%s;%s" % (self.result, self.dealerCard, self.playerCards)
        self.v = hash(v)
    def __hash__(self):
        if not hasattr(self,"hashvalue"):
            self.hashvalue=self.v
        return self.hashvalue
    def __cmp__(self,x):
        if self.v==x.v:
            return 0
        if self.v<x.v:
            return -1
        return 1
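

## A quick sketch (commented out) of the dict-key behaviour described above:
## two SimpleState objects built from the same hand hash and compare equal,
## so they index the same entry in a Q table. Uses only names already in this
## file (Blackjack, NewHand, HIT, STAND).
#bj = Blackjack()
#state = bj.NewHand()
#table = {SimpleState(state): {HIT: 0.0, STAND: 0.0}}
#print table[SimpleState(state)]  # same hand -> equal keys -> same entry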


## Meet our learning agent, Jack
class Jack:  
    def __init__(self, learnRate = 0.4, discountRate = .9, greedy = .001):
        self.bj = Blackjack()
        self.learnRate = learnRate  ## how much the Q-values are modified each iteration
        self.discountRate = discountRate  ## how much to discount 'looking ahead'
        self.greedy = greedy  ## probability of making a random (exploratory) choice instead of the greedy one
        self.updates = 0
        self.q = dict()
        ## q is the big player in this program:
        ## the keys are in the form of a SimpleState class,
        ## representing a game situation as a partial Markov Decision Process.
        ## The corresponding values are dictionaries mapping each action
        ## (in this case, just HIT and STAND) to the agent's current
        ## estimate of that action's value.
        ## Thus, e.g. self.q[some SimpleState][some action] will return
        ## that current q-value.
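        ## For example (illustrative numbers only), an entry such as
        ##   self.q[SimpleState(state)] == {HIT: 0.12, STAND: -0.05}
        ## would mean that, in that state, hitting currently looks better
        ## than standing.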


    ## Returns initial setting for an entry in the Q table if none exists
    ## Experimenting with different values causes different results
    def InitQValue(self):
        #return {HIT:random.random()*2-1,STAND:random.random()*2-1}
        #return {HIT:0.1,STAND:0.0}
        #return {HIT:0.0,STAND:0.1}
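        ## (another variant worth trying: optimistic initialization, i.e. start
        ## both actions deliberately high so the agent tries each at least a
        ## few times before settling; the 1.0 below is just an illustrative value)
        #return {HIT:1.0,STAND:1.0}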
        return {HIT:0.0,STAND:0.0}


    ## Returns what it believes to be the best action based on the Q table,
    ## unless this is one of the occasional exploratory (random) moves
    def MaxQ(self, q):
        #greedy = self.greedy - self.greedy * self.updates / 300000
        ## ^^ experimenting with decreasing chance of exploration over time
        if random.random() < self.greedy:
            return random.choice(q.keys())

        return max(q, key=q.get)
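

    ## A rough sketch (commented out) of the decaying-exploration experiment
    ## mentioned above: anneal the exploration rate linearly towards zero as
    ## the number of Q updates grows. The method name MaxQDecaying and the
    ## 300000-update horizon are placeholders, not tuned values.
    #def MaxQDecaying(self, q):
    #    eps = max(0.0, self.greedy * (1.0 - self.updates / 300000.0))
    #    if random.random() < eps:
    #        return random.choice(q.keys())
    #    return max(q, key=q.get)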


    ## Given a state, returns an action, based on our action policy
    ## ('Hard-greedy', chooses action corresponding to max Q value almost always)
    def QAction(self, state):
        ss = SimpleState(state)

        if not self.q.has_key(ss):
            self.q[ss] = self.InitQValue()
        return self.MaxQ(self.q[ss])


    ## This method is the body of the mind, so to speak.
    ## It's called after taking an action and seeing the result, and
    ## the original state, the resulting state, and the action it took
    ## are plugged into the Q algorithm to update the table appropriately.
    def QUpdate(self, state, newState, action):
        result = newState.money  # our basic feedback signal is whether we
                                 # made or lost money

        ss = SimpleState(state)
        if not self.q.has_key(ss):
            self.q[ss] = self.InitQValue()

        nss = SimpleState(newState)
        if not self.q.has_key(nss):
            self.q[nss] = self.InitQValue()

        qValue = self.q[ss][action]
        self.updates += 1

        #self.learnRate = 1.0 / (self.updates / 1000 + 1)
        ## ^^ experimenting with a 1/n style alpha parameter to ensure convergence
        self.q[ss][action] = qValue + self.learnRate * (result +
        #    self.discountRate * max(self.q[nss].values()) - qValue)
        ## ^^ the commented line above is the standard Q-learning update, while
        ##    the line below (which re-selects an epsilon-greedy action for the
        ##    new state) is SARSA-style... but they perform nearly identically here
            self.discountRate * self.q[nss][self.QAction(newState)] - qValue)

        ## Either way, the learning mechanism is wholly in that one
        ## statement above... i.e.,
        ## Q(s,a) <-- Q(s,a) + alpha*(r + gamma*Q(s',a') - Q(s,a))
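        ## Worked example with illustrative numbers: learnRate=0.4,
        ## discountRate=0.9, current Q(s,HIT)=0.2, reward r=0 (no money
        ## won or lost yet), and Q(s',a')=0.5 for the action chosen in
        ## the new state:
        ##   Q(s,HIT) <-- 0.2 + 0.4*(0 + 0.9*0.5 - 0.2) = 0.3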


## Runs one of the four comparison agents through n hands, spitting out data
def Agent(agent, numHands):  
    bj = Blackjack()
    money, hands, groupCounter, action = 0.0, 0, 0.0, None

    print "Initializing agent..."

    for n in range(0,numHands):
        state = bj.NewHand()
        while state.result == IN_PLAY:
            action = agent(state)
            state = bj.Action(action, state)
        hands += 1
        money += state.money
        groupCounter += state.money

        if hands % max(numHands / 10, 1) == 0:
            print "money = %i, hands = %i, average = %f" % (money, hands, groupCounter/(numHands/10.0))
            groupCounter = 0.0

    print "Overall average: %f\n" % (1.0*money/hands)
    main()    
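

## Note that a comparison agent is just a callable mapping a game state to HIT
## or STAND (main() below also passes jack.QAction here). A toy custom agent
## could be sketched roughly as follows; HitOnceAgent is hypothetical, and it
## assumes state.playerCards is the list of cards dealt to the player, as used
## in SimpleState above.
#def HitOnceAgent(state):
#    if len(state.playerCards) == 2:  # still holding the two starting cards
#        return HIT
#    return STAND
#
## It could then be run with e.g. Agent(HitOnceAgent, 10000).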


## Delegates off to one of the four comparison agents, or runs Jack, then loops
def main():  
    jack, bj = Jack(), Blackjack()
    money, hands, groupCounter, action = 0, 0, 0.0, None

    print "Which agent which you like to use?"
    print "  0) Exit demo"
    print "  1) Always hits"
    print "  2) Always stands"
    print "  3) Random action"
    print "  4) Optimal play"
    print "  5) Q-learning"
    print "  6) Q-learning limited"

    choice = input()

    if choice == 0: return

    print "How many hands?",
    numHands = input()

    if choice < 5:
        Agent(AGENTS[choice], numHands)
        return

    if choice == 5:
        print "Initializing agent..."
    else:
        print "How many hands after that?",
        testHands = input()
        print "Training",

    for n in range(0,numHands):
        #random.seed(1)  ## useful for debugging if you want the same hand over and over
        state = bj.NewHand()
        newState = state
        while newState.result == IN_PLAY:
            action = jack.QAction(state)  # get Jack's choice
            newState = bj.Action(action, state)  # play the move and get the result
            jack.QUpdate(state, newState, action)  # have Jack update his Q table accordingly
            state = newState
        hands += 1
        money += newState.money
        groupCounter += newState.money

        if (hands % max(numHands / 10, 1) == 0):
            if choice == 5:
                print "money = %i, hands = %i, entries = %i, average = %f" % (money, hands, len(jack.q), groupCounter/(numHands/10.0))
                groupCounter = 0.0
            else:
                print ".",

    if choice == 5:
        print "Overall average: %f\n" % (1.0*money/hands)
        main()
    else:
        print "done."
        Agent(jack.QAction, testHands)


        ## Code below can be useful for hand-by-hand debugging    
        #print "result = %s, money = %i" % (RESULTS[newState.result], money)
        #print "dealerCard = %s, playerCards = %s" % (state.dealerCards[0], state.playerCards)
        #for entry in jack.q:
        #    print entry.playerCards, jack.q[entry]
        #raw_input()

if __name__ == "__main__":
    main()