Reinforcement Learning¶
Bandit tasks are commonly used to study human reinforcement learning behavior. Here, we will implement a simple two-armed bandit task. We then run the same task on a language model specifically trained on tasks like these (Centaur) and compare the results.
Two-Armed Bandit Task¶
Imports¶
from sweetbean import Block, Experiment
from sweetbean.stimulus import Bandit, Text
from sweetbean.variable import (
    DataVariable,
    FunctionVariable,
    SharedVariable,
    SideEffect,
    TimelineVariable,
)
Timeline¶
Here, we gradually change the value of bandit_1 from 10 to 0 and, in reverse, the value of bandit_2 from 0 to 10.
timeline = []
for i in range(11):
    timeline.append(
        {
            "bandit_1": {"color": "orange", "value": 10 - i},
            "bandit_2": {"color": "blue", "value": i},
        }
    )
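To sanity-check the timeline, we can print its endpoints; the first and last trials should have mirrored values:
# Optional check: the bandit values cross over across the timeline
print(timeline[0])   # bandit_1 value 10, bandit_2 value 0
print(timeline[-1])  # bandit_1 value 0, bandit_2 value 10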
Implementation¶
We also keep track of the score with a shared variable so that we can display it between bandit trials.
# The bandit values for each trial come from the timeline
bandit_1 = TimelineVariable("bandit_1")
bandit_2 = TimelineVariable("bandit_2")

# The score is shared across trials and starts at 0
score = SharedVariable("score", 0)

# The value obtained on the current trial
value = DataVariable("value", 0)

# After each trial, add the obtained value to the running score
update_score = FunctionVariable(
    "update_score", lambda sc, val: sc + val, [score, value]
)
update_score_side_effect = SideEffect(score, update_score)

bandit_task = Bandit(
    bandits=[bandit_1, bandit_2],
    side_effects=[update_score_side_effect],
)
# Show the current score for 2 seconds after each bandit trial
score_text = FunctionVariable("score_text", lambda sc: f"Score: {sc}", [score])
show_score = Text(duration=2000, text=score_text)

trial_sequence = Block([bandit_task, show_score], timeline=timeline)
experiment = Experiment([trial_sequence])
Export the experiment to an HTML file and run it in the browser.
experiment.to_html("bandit.html", path_local_download="bandit.json")
Results¶
After running bandit.html, there should be a file called bandit.json in your download directory. You can open this file in your browser to inspect the results. First, we process it so that it contains only the relevant data:
import json
from sweetbean.data import process_js, get_n_responses, until_response
with open("bandit.json") as f:
data_raw = json.load(f)
data = process_js(data_raw)
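To get a feel for what process_js returns, we can inspect the processed trials (the exact fields depend on your SweetBean version):
# Peek at the processed data; field names may vary across SweetBean versions
print(len(data))
print(data[0])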
We can now get the number of responses that were made and extract the data up to (but not including) the third response:
n_responses = get_n_responses(data)
data_third_response = until_response(data, 3)
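For example, we can check how much data we kept:
print(n_responses)               # number of responses in the full dataset
print(len(data_third_response))  # trials retained up to the third response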
Experiment on a Language Model¶
With the partial data, we can resume the experiment at that point and run the remainder on language input. To test this, we first run it manually, typing the responses ourselves:
data_input, _ = experiment.run_on_language(input, data=data_third_response)
print(data_input)
Instead of entering responses manually, we can also use a large language model. Here, we use Centaur, a model that has been trained on tasks similar to the two-armed bandit task. We let the model predict each response, run the remainder of the experiment on it, and then compare the results with the actual data.
First, we need to install unsloth:
!pip install unsloth "xformers==0.0.28.post2"
Then, we load the model:
from unsloth import FastLanguageModel
import transformers

# Load the Centaur adapter in 4-bit to reduce memory usage
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name="marcelbinz/Llama-3.1-Centaur-8B-adapter",
    max_seq_length=32768,
    dtype=None,
    load_in_4bit=True,
)
FastLanguageModel.for_inference(model)
# Generate a single token per call, since each response is a one-character choice
pipe = transformers.pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
    trust_remote_code=True,
    pad_token_id=0,
    do_sample=True,
    temperature=1.0,
    max_new_tokens=1,
)
Finally, we create a function to pass into the experiment. It strips the prompt from the pipeline output so that only the newly generated text (the model's response) is returned:
def generate(prompt):
    # Return only the newly generated text, without the echoed prompt
    return pipe(prompt)[0]["generated_text"][len(prompt):]
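As a quick, optional check, we can call the function on a toy prompt; the exact wording here is made up for illustration, with double angle brackets mirroring the choice format Centaur was trained on:
# Toy prompt for illustration only; the model should reply with a single token
print(generate("You can press <<C>> or <<F>>. You press <<"))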
We can use this to run the full experiment:
data_centaur_full, _ = experiment.run_on_language(generate)
Or we can run the experiment from the third response onward:
data_centaur_partial, _ = experiment.run_on_language(generate, data=data_third_response)
# Print the data:
print(data_centaur_full)
print(data_centaur_partial)
print(data)
Comparison¶
We can compare the actual data with the data from the language model. For example, we can compare the overall scores reached by the human and the language model:
score_human = sum([d["value"] for d in data])
score_centaur = sum([d["value"] for d in data_centaur_full])
print(f"Score human: {score_human}")
print(f"Score centaur: {score_centaur}")
Conclusion¶
This notebook demonstrates how to run a simple bandit task either in a browser or on a language model. The results can, for example, be compared to analyse the language model, or they can be used to fine-tune it.
SweetBean is also integrated into AutoRA, a platform for running the same experiments automatically via Prolific. This allows for automatic data collection and analysis while using large language models for prototyping, for finding good experimental designs, or for automatic fine-tuning.