Skip to content

Commit

Permalink
Merge pull request #26 from AutoResearch/feat/q-learner
Browse files Browse the repository at this point in the history
added choice probabilities
  • Loading branch information
musslick authored Jul 15, 2024
2 parents 55b3abc + fb61cc7 commit ab7d890
Showing 1 changed file with 17 additions and 4 deletions.
21 changes: 17 additions & 4 deletions src/autora/experiment_runner/synthetic/psychology/q_learning.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,6 +193,7 @@ def run_AgentQ(rewards):
num_trials = rewards.shape[0]

y = np.zeros(rewards.shape)
choice_proba = np.zeros(rewards.shape)

agent = AgentQ(
alpha=learning_rate,
Expand All @@ -204,29 +205,40 @@ def run_AgentQ(rewards):
)

for i in range(num_trials):
proba = agent.get_choice_probs()
choice = agent.get_choice()
y[i, choice] = 1
choice_proba[i] = proba
reward = rewards[i, choice]
agent.update(choice, reward)
return y
return y, choice_proba

def run(
    conditions: Union[pd.DataFrame, np.ndarray, np.recarray],
    random_state: Optional[int] = None,
    return_choice_probabilities: bool = False,
):
    """Simulate Q-learning agent(s) on the given reward schedules.

    Args:
        conditions: Either a DataFrame whose first column holds one reward
            array (trials x arms) per row/session, or a single reward
            ndarray for one session.
        random_state: Optional seed for ``np.random`` to make the simulated
            choices reproducible.
        return_choice_probabilities: If True, also return the per-trial
            choice probabilities produced by the agent.

    Returns:
        A list of one-hot choice arrays (one per session); if
        ``return_choice_probabilities`` is True, a tuple
        ``(choices, choice_probabilities)`` of two parallel lists.
    """
    if random_state is not None:
        np.random.seed(random_state)

    Y = list()
    Y_proba = list()
    if isinstance(conditions, pd.DataFrame):
        for index, session in conditions.iterrows():
            # NOTE(review): assumes the reward schedule is in the first
            # column of each row — confirm against the callers' schema.
            rewards = session[0]
            choice, choice_proba = run_AgentQ(rewards)
            Y.append(choice)
            Y_proba.append(choice_proba)
    elif isinstance(conditions, np.ndarray):
        # A bare ndarray is treated as a single session.
        choice, choice_proba = run_AgentQ(conditions)
        Y.append(choice)
        Y_proba.append(choice_proba)

    if return_choice_probabilities:
        return Y, Y_proba
    else:
        return Y

# Ground-truth runner: `partial(run)` binds no arguments, so this is a plain
# alias of `run`. NOTE(review): presumably kept so the synthetic-experiment
# collection exposes a `ground_truth` callable with the same signature —
# confirm against the `collection` construction elsewhere in this file.
ground_truth = partial(run)

Expand All @@ -245,3 +257,4 @@ def domain():
)
return collection


0 comments on commit ab7d890

Please sign in to comment.