import sys
import logging
import itertools
import numpy as np
import pandas as pd
logging.basicConfig(level=logging.INFO,
                    format='%(asctime)s [%(levelname)s] %(message)s',
                    stream=sys.stdout, datefmt='%H:%M:%S')
discount = 1.
# Index n counts the net number of "left" observations heard so far:
# n = (#"left" observations) - (#"right" observations).
df = pd.DataFrame(0., index=range(-4, 5), columns=[])
df["h(left)"] = 0.85 ** df.index.to_series()  # preference for S = left
df["h(right)"] = 0.15 ** df.index.to_series()  # preference for S = right
df["p(left)"] = df["h(left)"] / (df["h(left)"] + df["h(right)"]) # b(left)
df["p(right)"] = df["h(right)"] / (df["h(left)"] + df["h(right)"]) # b(right)
df["omega(left)"] = 0.85 * df["p(left)"] + 0.15 * df["p(right)"]
# omega(left|b, listen)
df["omega(right)"] = 0.15 * df["p(left)"] + 0.85 * df["p(right)"]
# omega(right|b, listen)
df["r(left)"] = 10. * df["p(left)"] - 100. * df["p(right)"] # r(b, left)
df["r(right)"] = -100. * df["p(left)"] + 10. * df["p(right)"] # r(b, right)
df["r(listen)"] = -1. # r(b, listen)
df[["q(left)", "q(right)", "q(listen)", "v"]] = 0. # values
for i in range(300):
    df["q(left)"] = df["r(left)"]
    df["q(right)"] = df["r(right)"]
    # Beyond the +/-4 grid the belief is nearly certain, so approximate the
    # out-of-grid continuation value by 10 (the reward at certainty).
    df["q(listen)"] = df["r(listen)"] + discount * (
        df["omega(left)"] * df["v"].shift(-1).fillna(10) +
        df["omega(right)"] * df["v"].shift(1).fillna(10))
    df["v"] = df[["q(left)", "q(right)", "q(listen)"]].max(axis=1)
df["action"] = df[["q(left)", "q(right)", "q(listen)"]].values.argmax(axis=1)
df
 | h(left) | h(right) | p(left) | p(right) | omega(left) | omega(right) | r(left) | r(right) | r(listen) | q(left) | q(right) | q(listen) | v | action
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---
-4 | 1.915686 | 1975.308642 | 0.000969 | 0.999031 | 0.150678 | 0.849322 | -99.893424 | 9.893424 | -1.0 | -99.893424 | 9.893424 | 8.909410 | 9.893424 | 1 |
-3 | 1.628333 | 296.296296 | 0.005466 | 0.994534 | 0.153826 | 0.846174 | -99.398785 | 9.398785 | -1.0 | -99.398785 | 9.398785 | 8.578243 | 9.398785 | 1 |
-2 | 1.384083 | 44.444444 | 0.030201 | 0.969799 | 0.171141 | 0.828859 | -96.677852 | 6.677852 | -1.0 | -96.677852 | 6.677852 | 7.844483 | 7.844483 | 2 |
-1 | 1.176471 | 6.666667 | 0.150000 | 0.850000 | 0.255000 | 0.745000 | -83.500000 | -6.500000 | -1.0 | -83.500000 | -6.500000 | 6.159919 | 6.159919 | 2 |
0 | 1.000000 | 1.000000 | 0.500000 | 0.500000 | 0.500000 | 0.500000 | -45.000000 | -45.000000 | -1.0 | -45.000000 | -45.000000 | 5.159919 | 5.159919 | 2 |
1 | 0.850000 | 0.150000 | 0.850000 | 0.150000 | 0.745000 | 0.255000 | -6.500000 | -83.500000 | -1.0 | -6.500000 | -83.500000 | 6.159919 | 6.159919 | 2 |
2 | 0.722500 | 0.022500 | 0.969799 | 0.030201 | 0.828859 | 0.171141 | 6.677852 | -96.677852 | -1.0 | 6.677852 | -96.677852 | 7.844483 | 7.844483 | 2 |
3 | 0.614125 | 0.003375 | 0.994534 | 0.005466 | 0.846174 | 0.153826 | 9.398785 | -99.398785 | -1.0 | 9.398785 | -99.398785 | 8.578243 | 9.398785 | 0 |
4 | 0.522006 | 0.000506 | 0.999031 | 0.000969 | 0.849322 | 0.150678 | 9.893424 | -99.893424 | -1.0 | 9.893424 | -99.893424 | 8.909410 | 9.893424 | 0 |
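For readability one might decode the integer action column; the `action_names` list below is an assumption that simply mirrors the q-column order passed to argmax (0 = left, 1 = right, 2 = listen).

action_names = ["left", "right", "listen"]  # q-column order used by argmax
df["action_name"] = [action_names[a] for a in df["action"]]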
class State:
    LEFT, RIGHT = range(2)  # the terminal state is not modeled

state_count = 2
states = range(state_count)

class Action:
    LEFT, RIGHT, LISTEN = range(3)

action_count = 3
actions = range(action_count)

class Observation:
    LEFT, RIGHT = range(2)

observation_count = 2
observations = range(observation_count)
# r(S,A): state x action -> reward
rewards = np.zeros((state_count, action_count))
rewards[State.LEFT, Action.LEFT] = 10.
rewards[State.LEFT, Action.RIGHT] = -100.
rewards[State.RIGHT, Action.LEFT] = -100.
rewards[State.RIGHT, Action.RIGHT] = 10.
rewards[:, Action.LISTEN] = -1.
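As a hedged cross-check (assuming the `df` from the belief-grid section is still in scope), the belief rewards used there are just belief-weighted rows of this matrix:

b = df.loc[0, ["p(left)", "p(right)"]].to_numpy(dtype=float)  # uniform belief
assert np.isclose(np.dot(b, rewards[:, Action.LEFT]), df.loc[0, "r(left)"])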
# p(S'|S,A): state x action x next_state -> probability
transitions = np.zeros((state_count, action_count, state_count))
transitions[State.LEFT, :, State.LEFT] = 1.
transitions[State.RIGHT, :, State.RIGHT] = 1.
# o(O|A,S'): action x next_state x next_observation -> probability
observes = np.zeros((action_count, state_count, observation_count))
observes[Action.LISTEN, State.LEFT, Observation.LEFT] = 0.85
observes[Action.LISTEN, State.LEFT, Observation.RIGHT] = 0.15
observes[Action.LISTEN, State.RIGHT, Observation.LEFT] = 0.15
observes[Action.LISTEN, State.RIGHT, Observation.RIGHT] = 0.85
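A quick sanity check on these tensors; note that only the listen rows of `observes` are populated, since opening a door yields no observation in this simplified model:

assert np.allclose(transitions.sum(axis=-1), 1.)  # p(.|S, A) sums to 1
assert np.allclose(observes[Action.LISTEN].sum(axis=-1), 1.)  # o(.|listen, S')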
# sample beliefs
belief_count = 15
beliefs = [np.array([p, 1 - p]) for p in np.linspace(0, 1, belief_count)]
# beliefs[i][State.LEFT] = p, so row i of the tables below is b(left) = i / 14
action_alphas = {action: rewards[:, action] for action in actions}
horizon = 10
# initialize alpha vectors
alphas = [np.zeros(state_count)]
ss_state_value = {}
for t in reversed(range(horizon)):
    logging.info("t = %d", t)
    # Calculate an alpha vector for each (action, observation, alpha).
    action_observation_alpha_alphas = {}
    for action in actions:
        for observation in observations:
            for alpha_idx, alpha in enumerate(alphas):
                action_observation_alpha_alphas[
                        (action, observation, alpha_idx)] = \
                    discount * np.dot(transitions[:, action, :],
                                      observes[action, :, observation] * alpha)
    # Calculate an alpha vector for each (belief, action).
    belief_action_alphas = {}
    for belief_idx, belief in enumerate(beliefs):
        for action in actions:
            belief_action_alphas[(belief_idx, action)] = \
                action_alphas[action].copy()
            def dot_belief(x):
                return np.dot(x, belief)
            for observation in observations:
                belief_action_observation_vector = max([
                    action_observation_alpha_alphas[
                        (action, observation, alpha_idx)]
                    for alpha_idx, _ in enumerate(alphas)], key=dot_belief)
                belief_action_alphas[(belief_idx, action)] += \
                    belief_action_observation_vector
    # Calculate an alpha vector for each belief.
    belief_alphas = {}
    for belief_idx, belief in enumerate(beliefs):
        def dot_belief(x):
            return np.dot(x, belief)
        belief_alphas[belief_idx] = max([
            belief_action_alphas[(belief_idx, action)]
            for action in actions], key=dot_belief)
    alphas = list(belief_alphas.values())  # materialize for reuse next sweep
    # Dump state values for display only.
    df_belief = pd.DataFrame(beliefs, index=range(belief_count), columns=states)
    df_alpha = pd.DataFrame(alphas, index=range(belief_count), columns=states)
    ss_state_value[t] = (df_belief * df_alpha).sum(axis=1)
logging.info("state_value =")
pd.DataFrame(ss_state_value)
00:00:00 [INFO] t = 9
00:00:00 [INFO] t = 8
00:00:00 [INFO] t = 7
00:00:00 [INFO] t = 6
00:00:00 [INFO] t = 5
00:00:00 [INFO] t = 4
00:00:00 [INFO] t = 3
00:00:00 [INFO] t = 2
00:00:00 [INFO] t = 1
00:00:00 [INFO] t = 0
00:00:00 [INFO] state_value =
 | 9 | 8 | 7 | 6 | 5 | 4 | 3 | 2 | 1 | 0
---|---|---|---|---|---|---|---|---|---|---
0 | 10.000000 | 10.000000 | 10.000000 | 10.000000 | 10.000000 | 10.000000 | 10.000000 | 10.000000 | 10.000000 | 10.000000 |
1 | 2.142857 | 5.621429 | 5.822143 | 6.365429 | 6.560952 | 6.712210 | 6.783735 | 6.832759 | 6.856939 | 6.873232 |
2 | -1.000000 | 3.892857 | 3.642857 | 5.297768 | 5.092529 | 5.853556 | 5.752700 | 6.058425 | 6.015260 | 6.131608 |
3 | -1.000000 | 2.164286 | 2.767143 | 4.586589 | 4.696057 | 5.347428 | 5.374806 | 5.617213 | 5.607098 | 5.712664 |
4 | -1.000000 | 0.435714 | 2.720000 | 3.875411 | 4.499933 | 4.918538 | 5.120284 | 5.261618 | 5.328633 | 5.375889 |
5 | -1.000000 | -1.292857 | 2.720000 | 3.164232 | 4.315830 | 4.490510 | 4.865762 | 4.925356 | 5.050169 | 5.070124 |
6 | -1.000000 | -2.000000 | 2.720000 | 2.713518 | 4.226650 | 4.097497 | 4.802944 | 4.717488 | 5.014304 | 4.972858 |
7 | -1.000000 | -2.000000 | 2.720000 | 2.465000 | 4.226650 | 4.028547 | 4.802944 | 4.704370 | 5.014304 | 4.972858 |
8 | -1.000000 | -2.000000 | 2.720000 | 2.713518 | 4.226650 | 4.097497 | 4.802944 | 4.717488 | 5.014304 | 4.972858 |
9 | -1.000000 | -1.292857 | 2.720000 | 3.164232 | 4.315830 | 4.490510 | 4.865762 | 4.925356 | 5.050169 | 5.070124 |
10 | -1.000000 | 0.435714 | 2.720000 | 3.875411 | 4.499933 | 4.918538 | 5.120284 | 5.261618 | 5.328633 | 5.375889 |
11 | -1.000000 | 2.164286 | 2.767143 | 4.586589 | 4.696057 | 5.347428 | 5.374806 | 5.617213 | 5.607098 | 5.712664 |
12 | -1.000000 | 3.892857 | 3.642857 | 5.297768 | 5.092529 | 5.853556 | 5.752700 | 6.058425 | 6.015260 | 6.131608 |
13 | 2.142857 | 5.621429 | 5.822143 | 6.365429 | 6.560952 | 6.712210 | 6.783735 | 6.832759 | 6.856939 | 6.873232 |
14 | 10.000000 | 10.000000 | 10.000000 | 10.000000 | 10.000000 | 10.000000 | 10.000000 | 10.000000 | 10.000000 | 10.000000 |
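The final alpha vectors can also be used directly. A minimal sketch, not part of the original listing: `value_of` lower-bounds the optimal value at an arbitrary belief, and `greedy_action` repeats the backup from the loop above as a one-step lookahead to pick an action (both helper names are assumptions).

def value_of(belief):
    # The point-based value function is the best alpha vector at this belief;
    # it is a lower bound on the optimal value.
    return max(np.dot(alpha, belief) for alpha in alphas)

def greedy_action(belief):
    # One-step lookahead mirroring the backup in the loop above.
    q = np.zeros(action_count)
    for action in actions:
        q[action] = np.dot(belief, rewards[:, action])
        for observation in observations:
            vec = max((discount * np.dot(transitions[:, action, :],
                                         observes[action, :, observation] * alpha)
                       for alpha in alphas), key=lambda x: np.dot(x, belief))
            q[action] += np.dot(belief, vec)
    return int(np.argmax(q))

greedy_action(np.array([0.5, 0.5]))  # listen at the uniform belief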