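# Page-header residue from the hosting site (not part of the module), kept as comments:
#   "ksridhar's picture" / "Upload model" / "8e8a936 verified"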
import warnings
from dataclasses import dataclass
from typing import List, Optional, Tuple, Union
import numpy as np
import torch
import torch.nn.functional as F
from gymnasium import spaces
from torch import BoolTensor, FloatTensor, LongTensor, Tensor, nn
from transformers import GPTNeoModel, GPTNeoPreTrainedModel
from transformers.modeling_outputs import ModelOutput
from transformers.models.vit.modeling_vit import ViTPatchEmbeddings
import torch.nn.functional as F
from jat.configuration_jat import JatConfig
from jat.processing_jat import JatProcessor
from jat.modeling_jat import JatModel, compute_mse_loss, cyclic_expand_dim, JatOutput
from jat_regent.utils import build_index_vector, get_task_info, collect_all_data, process_row_of_obs_atari_full_without_mask, retrieve_vector, myprint, L2dist, get_dist_stats, get_images_of_retrieved_obs, get_emb_transform_model_dim, get_optional_suffix
from jat_regent.atari_utils import convert_local_to_global_action, convert_global_to_local_action
from jat_regent.eval.rl import SEEN_TASK_NAME_TO_ENV_ID, UNSEEN_TASK_NAME_TO_ENV_ID
from PIL import Image
import os
from copy import deepcopy
from pytorch_msssim import ssim
import json
def cross_entropy_from_softmax(softmax_probs, targets, reduction="mean", epsilon=1e-9):
    """
    Calculate the cross entropy loss given softmax_probs and targets.

    :param softmax_probs: tensor of shape (batch_size, num_classes) containing softmax probabilities
    :param targets: tensor of shape (batch_size,) containing the target classes (not one-hot encoded)
    :param reduction: one of 'mean', 'sum' or 'none'
    :param epsilon: probabilities are clamped to [epsilon, 1 - epsilon] before taking the log
    :return: cross entropy loss; a scalar for 'mean'/'sum', shape (batch_size,) for 'none'
    :raises ValueError: if `reduction` is not one of the supported values
    """
    assert len(softmax_probs.shape) == 2, "softmax_probs should be of shape (batch_size, num_classes)"
    assert len(targets.shape) == 1, "targets should be of shape (batch_size,)"

    # Convert targets to one-hot encoding
    targets_one_hot = F.one_hot(targets, num_classes=softmax_probs.shape[1]).float()  # (batch_size, num_classes)

    # Clamp to avoid NaNs from log(0) and instabilities from log(1); the log is then always finite
    log_softmax_probs = softmax_probs.clamp(min=epsilon, max=1 - epsilon).log()
    loss = -torch.sum(targets_one_hot * log_softmax_probs, dim=1)

    if reduction == "mean":
        return loss.mean()
    if reduction == "sum":
        return loss.sum()
    if reduction == "none":
        return loss
    raise ValueError("reduction should be one of 'mean', 'sum', or 'none'")


def compute_ce_loss_from_softmax(
    logits: FloatTensor, labels: torch.LongTensor, mask: Optional[BoolTensor], weights: Optional[FloatTensor] = None
) -> FloatTensor:
    """
    Compute the Cross Entropy (CE) loss between predicted probabilities and true class labels, considering valid
    timesteps.

    Despite the parameter name, `logits` must already be probabilities: they are passed to
    `cross_entropy_from_softmax`, which takes their log directly (no softmax is applied here).

    Args:
        logits (`FloatTensor` of shape `(batch_size, max_seq_len, [inner_size,] num_classes)`):
            Predicted class probabilities at the output of the model.
        labels (`torch.LongTensor` of shape `(batch_size, max_seq_len, [inner_size,])`):
            Ground truth class labels.
        mask (`BoolTensor` of shape `(batch_size, max_seq_len)`, *optional*):
            Boolean mask indicating valid timesteps.
        weights (`FloatTensor` of shape `(batch_size, max_seq_len)`, *optional*):
            Weights to be applied to the loss.

    Returns:
        loss (`FloatTensor` of shape `(,)`):
            CE loss between predicted probabilities and true class labels.
    """
    if mask is not None:
        valid = mask.bool()
        logits = logits[valid]  # (Y, [X,] C)
        labels = labels[valid]  # (Y, [X,])
        if weights is not None:
            weights = weights[valid]  # (Y,)
    else:
        # Merge only the batch and sequence axes. Flattening one axis further would fold the class axis into
        # the sample axis whenever there is no inner dimension.
        logits = logits.flatten(end_dim=1)  # (B, L, [X,] C) -> (B*L, [X,] C)
        labels = labels.flatten(end_dim=1)  # (B, L, [X,]) -> (B*L, [X,])
        if weights is not None:
            weights = weights.flatten(end_dim=1)  # (B, L) -> (B*L,)

    # we don't use F.cross_entropy here to avoid double softmax
    loss = cross_entropy_from_softmax(
        logits.reshape(-1, logits.size(-1)), labels.reshape(-1), reduction="none"
    )  # (Y*X,) or (Y,)
    loss = loss.view(labels.size())  # (Y, X) or (Y,)

    # Average over the inner dimension only when there is one; otherwise the per-timestep losses would be
    # collapsed to a scalar before the per-timestep weights are applied.
    if loss.dim() > 1:
        loss = loss.mean(-1)  # (Y,)

    # Multiply the loss by the weights
    if weights is not None:
        loss = loss * weights  # (Y,)

    # Average the loss
    return loss.mean()
def crazy_relu(x, beta):
    """
    Piecewise-linear activation: identity on [0, 1] and slope `beta` outside that interval.

    Computed as a leaky ReLU (negative slope `beta`) minus `(1 - beta)` times a ReLU shifted to start at 1.
    """
    leaky_part = F.leaky_relu(x, negative_slope=beta)
    above_one = F.relu(x - 1)
    return leaky_part - (1 - beta) * above_one
class JatRegentModel(JatModel):
    """
    Jat Regent model.

    Subclass of `JatModel` in which the action decoded from the transformer is interpolated with the action
    of the first retrieved demonstration state (the "RandP" action), using the weight
    `exp(-lamda * distance)` (see `continuous_action_interpolation` and `discrete_action_interpolation`).

    Method dependencies visible in this class: `retrieval_setup()` stores the task description and the
    retrieval index that `process()`, `retrieve()`, `get_distances()` and `get_next_action()` read;
    `reset_rl()` initialises the step counter that `retrieve()` increments.
    """

    def __init__(self, config: JatConfig) -> None:
        """
        Build the parent `JatModel` and read the retrieval-specific options from `config`.

        Args:
            config (`JatConfig`):
                Besides whatever the parent reads, this constructor reads `hidden_size`, `action_vocab_size`,
                `vocab_size`, `ONLY_RL_TASKS`, `num_contexts`, `lamda`, `use_global_atari_actions`,
                `mujoco_dist_multiplier`, `atari_dist_multiplier`, `dist_normalizer`, `atari_dist_type`,
                `use_atari_embeddings` and, if present, `finetune_num_demos`.
        """
        super().__init__(config)

        hidden_size = config.hidden_size
        action_vocab_size = config.action_vocab_size

        if config.ONLY_RL_TASKS:
            # Dedicated action head over the action vocabulary only
            self.single_discrete_decoder = nn.Linear(hidden_size, action_vocab_size, bias=False)
            self.N = config.action_vocab_size
        else:
            self.N = config.vocab_size

        self.multi_discrete_decoder = None  # not needed
        self.image_decoder = None  # not needed
        self.num_contexts = config.num_contexts  # used in get_next_action() at evaluation in an env only
        self.lamda = config.lamda  # used in get_next_action() at evaluation in an env only
        self.use_global_atari_actions = config.use_global_atari_actions
        self.dist_multipliers = {'mujoco': config.mujoco_dist_multiplier, 'atari': config.atari_dist_multiplier}
        self.dist_normalizer = config.dist_normalizer
        self.atari_dist_type = config.atari_dist_type
        self.use_atari_embeddings = config.use_atari_embeddings
        self.finetune_num_demos = getattr(config, 'finetune_num_demos', None)

        if self.use_atari_embeddings:
            self.image_encoder = None
            self.emb_dim_full = (512,)

        # print number of parameters
        num_params = sum(p.numel() for p in self.parameters() if p.requires_grad)
        myprint(f"number of parameters: {num_params / 1e6:.4f}M")

    def retrieval_setup(self,
                        task,
                        dataset,
                        num_demos,  # to retrieve from
                        device,
                        batch_size_retrieval=16,  # for atari envs on gpu
                        nb_cores_autofaiss=8,  # for vector obs envs on cpu cores
                        ):
        """
        Prepare everything needed to retrieve demonstration states for `task` and store it on `self`.

        Stores the task keys and dimensions returned by `get_task_info`, the distance statistics returned by
        `get_dist_stats`, the retrieval index built by `build_index_vector` (one per mission for babyai
        tasks) and the data rows returned by `collect_all_data`.

        Args:
            task: task name; prefixes "atari", "babyai" and "mujoco" are treated specially in this class.
            dataset: dataset handed to `collect_all_data`.
            num_demos: number of demonstrations to retrieve from (handed to `collect_all_data`).
            device: device forwarded to the index helpers through `self.kwargs`. Note that the atari
                embedding model is requested on `self.device`, not on this argument.
            batch_size_retrieval: forwarded through `self.kwargs` (for atari envs on gpu).
            nb_cores_autofaiss: forwarded through `self.kwargs` (for vector obs envs on cpu cores).
        """
        # setup
        rew_key, attn_key, obs_key, act_key, B, obs_dim, act_dim = get_task_info(task)
        extra_key = 'discrete_RandP_action_logits' if task.startswith(("atari", "babyai")) else 'continuous_RandP_actions'
        optional_suffix = get_optional_suffix(task, self.atari_dist_type, self.finetune_num_demos)
        mean_dist, std_dist, max_dist, p80, p85, p90, p95, p99 = get_dist_stats(task=task, optional_suffix=optional_suffix)

        # get embedding model
        if task.startswith("atari"):
            self.emb_transform, self.emb_model, emb_dim, self.emb_model_full = get_emb_transform_model_dim(self.atari_dist_type, self.device, return_emb_weights=True)
            obs_dim = emb_dim  # overwrite for atari_dist_type

        kwargs = {'B': B,
                  'obs_dim': obs_dim,
                  'attn_key': attn_key,
                  'obs_key': obs_key,
                  'device': device,
                  'task': task,
                  'batch_size_retrieval': batch_size_retrieval,
                  'nb_cores_autofaiss': nb_cores_autofaiss,
                  'verbose': False,
                  'atari_dist_type': self.atari_dist_type,
                  }

        # overwrite raw_obs_dim because raw obs in atari are (4, 84, 84) and raw obs in babyai have 64 extra dim
        raw_obs_dim = obs_dim
        if task.startswith("atari"):
            raw_obs_dim = (4, 84, 84)
        elif task.startswith("babyai"):
            raw_obs_dim = (obs_dim[0]+64,)

        # save
        self.task = task
        self.dataset = dataset
        self.obs_key = obs_key
        self.act_key = act_key
        self.rew_key = rew_key
        self.attn_key = attn_key
        self.obs_dim = obs_dim
        self.act_dim = act_dim
        self.extra_key = extra_key
        self.kwargs = kwargs
        self.raw_obs_dim = raw_obs_dim
        self.max_dist = max_dist
        self.mean_dist = mean_dist
        self.std_dist = std_dist
        self.p80, self.p85, self.p90, self.p95, self.p99 = p80, p85, p90, p95, p99
        self.dist_normalizer_value = {'std': std_dist, 'max': max_dist, 'p80': p80, 'p85': p85, 'p90': p90, 'p95': p95, 'p99': p99}[self.dist_normalizer]
        # Guard the later division in get_distances() against a zero normalizer
        if self.dist_normalizer_value == 0.0:
            self.dist_normalizer_value = 1.0

        # for retrieval,
        all_rows_of_obs_OG, all_attn_masks_OG, all_row_idxs, all_datarows_dict = collect_all_data(dataset, task, obs_key, num_demos, return_datarows_dict=True, atari_dist_type=self.atari_dist_type)
        if task.startswith("babyai"):
            # for each mission in task,
            self.all_indices = {}
            self.knn_index = {}
            for mission_idx, mission in enumerate(all_row_idxs.keys()):
                # create index, collect subset of data that we can retrieve from
                myprint(('*'*50) + f'{mission=} - {mission_idx+1}/{len(all_row_idxs.keys())}')
                self.all_indices[mission], self.knn_index[mission] = build_index_vector(all_rows_of_obs_OG=all_rows_of_obs_OG[mission],
                                                                                        all_attn_masks_OG=all_attn_masks_OG[mission],
                                                                                        all_row_idxs=all_row_idxs[mission],
                                                                                        kwargs=kwargs)
        else:
            # create index, collect subset of data that we can retrieve from
            self.all_indices, self.knn_index = build_index_vector(all_rows_of_obs_OG=all_rows_of_obs_OG,
                                                                  all_attn_masks_OG=all_attn_masks_OG,
                                                                  all_row_idxs=all_row_idxs,
                                                                  kwargs=kwargs)

        # for retrieval inside retrieve()
        self.datarows = all_datarows_dict

        # # for checking if first env state is similar to retrieval episode's first states
        # if task.startswith("mujoco"):
        #     local_path = f"dataset_jat_regent/{task}"
        #     with open(f"{local_path}/eps_2_rows_tokenized.json", 'r') as f:
        #         eps_2_rows_tokenized = json.load(f)
        #     eps_2_rows_tokenized = {int(k): v for k, v in eps_2_rows_tokenized.items()}
        #     row_idxs_of_first_state_of_demos = [eps_2_rows_tokenized[eps][0] for eps in range(num_demos)]
        #     self.first_states_of_demos = [np.array(dataset['train'][row_idx][obs_key][0]) for row_idx in row_idxs_of_first_state_of_demos]
        # else:
        #     self.first_states_of_demos = None

    def output_rl(
        self,
        transformer_outputs,
        continuous_observations: Optional[FloatTensor] = None,
        discrete_observations: Optional[LongTensor] = None,
        image_observations: Optional[FloatTensor] = None,
        continuous_actions: Optional[FloatTensor] = None,
        discrete_actions: Optional[LongTensor] = None,
        rewards: Optional[FloatTensor] = None,
        attention_mask: Optional[BoolTensor] = None,
        return_loss: bool = True,
        return_dict: Optional[bool] = None,
        loss_weight: Optional[FloatTensor] = None,
        exp_lamda_distances: Optional[FloatTensor] = None,
        continuous_RandP_actions: Optional[FloatTensor] = None,
        discrete_RandP_action_logits: Optional[FloatTensor] = None,
    ):
        """
        Turn the transformer hidden states into interpolated action predictions and, optionally, a loss.

        Only actions are predicted; observations are not (`observation_loss_coef` must be 0.0). Action
        predictions are read from every second hidden state (`hidden_states[:, ::2]`). For discrete actions
        the returned `pred_actions` are probabilities, not logits (see `discrete_action_interpolation`).

        Returns:
            A `JatOutput` when `return_dict` is truthy, otherwise a tuple
            `([loss, observation_loss, action_loss,] pred_observations, pred_actions, *transformer_outputs[1:])`
            where the loss entries are present only when a loss was computed.

        Raises:
            ValueError: if neither `continuous_actions` nor `discrete_actions` is provided.
        """
        hidden_states = transformer_outputs.last_hidden_state
        loss, observation_loss, action_loss = None, None, None

        # Observations
        assert rewards is not None
        assert self.observation_loss_coef == 0.0, f'{self.observation_loss_coef=} should be 0.0 as we are not predicting observations!'
        # warnings.warn("observation_loss_coef is 0.0, skipping memory-intensive observations prediction.")
        pred_observations = None
        observation_loss = 0.0

        # Actions
        actions_mask = attention_mask[:, ::2] if attention_mask is not None else None
        if continuous_actions is not None:
            act_size = continuous_actions.shape[-1]
            continuous_actions = cyclic_expand_dim(continuous_actions, self.config.max_continuous_size)
            continuous_RandP_actions = cyclic_expand_dim(continuous_RandP_actions, self.config.max_continuous_size)
            init_pred_actions = self.continuous_decoder(hidden_states[:, ::2])
            pred_actions = self.continuous_action_interpolation(init_pred_actions, exp_lamda_distances, continuous_RandP_actions, beta=0.0)
            if return_loss:
                # loss_weight is usually 50 for metaworld, 10 for mujoco (except two tasks where it is 20, 50), 1 for the rest!
                action_loss = compute_mse_loss(pred_actions, continuous_actions, actions_mask, weights=loss_weight)
            # Drop the dimensions added by cyclic_expand_dim
            pred_actions = pred_actions[..., :act_size]
        elif discrete_actions is not None:
            init_pred_actions = self.single_discrete_decoder(hidden_states[:, ::2])
            pred_actions = self.discrete_action_interpolation(init_pred_actions, exp_lamda_distances, discrete_RandP_action_logits, beta=0.0)
            if return_loss:
                action_loss = compute_ce_loss_from_softmax(pred_actions, discrete_actions, actions_mask, weights=loss_weight)
        else:
            # Without actions there is nothing to predict; fail clearly instead of on an unbound local below.
            raise ValueError("Either `continuous_actions` or `discrete_actions` must be provided.")

        # Return output
        if return_loss:
            loss = self.observation_loss_coef * observation_loss + self.action_loss_coef * action_loss

        if not return_dict:
            output = (pred_observations, pred_actions) + transformer_outputs[1:]
            return ((loss, observation_loss, action_loss) + output) if loss is not None else output

        return JatOutput(
            loss=loss,
            observation_loss=observation_loss,
            action_loss=action_loss,
            pred_observations=pred_observations,
            pred_actions=pred_actions,
            past_key_values=transformer_outputs.past_key_values,
            hidden_states=transformer_outputs.hidden_states,
            attentions=transformer_outputs.attentions,
        )

    def shifted_crazy_relu(self, x, beta):
        """Apply `crazy_relu` rescaled so that its identity region is [-1, 1] instead of [0, 1]."""
        return 2 * crazy_relu(0.5*(x+1), beta) - 1

    def continuous_action_interpolation(self, init_pred_actions, exp_lamda_distances, continuous_RandP_actions, beta=0.0):
        """
        MCNN interpolation (https://arxiv.org/abs/2310.06171) for continuous actions.

        Computes `w * RandP + 10.0 * (1 - w) * shifted_crazy_relu(pred)` with `w = exp_lamda_distances`.

        Args:
            init_pred_actions: decoder output of shape `(batch_size, max_seq_len, act_size)`.
            exp_lamda_distances: interpolation weights of shape `(batch_size, max_seq_len, 1)`.
            continuous_RandP_actions: retrieved actions of shape `(batch_size, max_seq_len, act_size)`.
            beta: slope passed to `shifted_crazy_relu`.
        """
        batch_size, max_seq_len, act_size = init_pred_actions.shape
        assert (init_pred_actions.shape == (batch_size, max_seq_len, act_size) and
                exp_lamda_distances.shape == (batch_size, max_seq_len, 1) and
                continuous_RandP_actions.shape == (batch_size, max_seq_len, act_size)), f'{init_pred_actions.shape=}, {exp_lamda_distances.shape=}, {continuous_RandP_actions.shape=}, {(batch_size, max_seq_len, act_size)=}'

        act_fn = self.shifted_crazy_relu
        final_actions = exp_lamda_distances * continuous_RandP_actions + 10.0 * (1 - exp_lamda_distances) * act_fn(init_pred_actions, beta=beta)
        return final_actions

    def discrete_action_interpolation(self, init_pred_actions, exp_lamda_distances, discrete_RandP_action_logits, beta=0.0):
        """
        MCNN-like interpolation for discrete actions.

        Computes `w * RandP + (1 - w) * softmax(pred)` with `w = exp_lamda_distances`. Because a softmax is
        applied here, the result is a mixture of probabilities rather than logits. `beta` is accepted for
        signature parity with the continuous variant and is not used.

        Args:
            init_pred_actions: decoder logits of shape `(batch_size, max_seq_len, action_vocab_size)`.
            exp_lamda_distances: interpolation weights of shape `(batch_size, max_seq_len, 1)`.
            discrete_RandP_action_logits: retrieved-action distribution of shape
                `(batch_size, max_seq_len, action_vocab_size)`.
        """
        batch_size, max_seq_len, action_vocab_size = init_pred_actions.shape
        assert (init_pred_actions.shape == (batch_size, max_seq_len, action_vocab_size) and
                exp_lamda_distances.shape == (batch_size, max_seq_len, 1) and
                discrete_RandP_action_logits.shape == (batch_size, max_seq_len, action_vocab_size)), f'{init_pred_actions.shape=}, {exp_lamda_distances.shape=}, {discrete_RandP_action_logits.shape=}, {(batch_size, max_seq_len, action_vocab_size)=}'

        # print(f'{torch.round(discrete_RandP_action_logits[:, -1],decimals=2)=}')
        # print(f'{torch.round(F.softmax(init_pred_actions, dim=-1)[:, -1],decimals=2)=}')
        # print(f'{torch.round(exp_lamda_distances[:, -1],decimals=2)=}')
        # print(f'first term: {torch.round((exp_lamda_distances * discrete_RandP_action_logits)[:, -1],decimals=2)}')
        # print(f'second term: {torch.round(((1 - exp_lamda_distances) * F.softmax(init_pred_actions, dim=-1))[:, -1],decimals=2)}')
        final_actions = exp_lamda_distances * discrete_RandP_action_logits + (1 - exp_lamda_distances) * F.softmax(init_pred_actions, dim=-1)
        return final_actions

    # Copied the forward function from the Parent class with the addition of the last 3 args in the input args and in output_rl args
    def forward(
        self,
        input_ids: Optional[LongTensor] = None,
        pixel_values: Optional[FloatTensor] = None,
        continuous_observations: Optional[FloatTensor] = None,
        discrete_observations: Optional[LongTensor] = None,
        image_observations: Optional[FloatTensor] = None,
        continuous_actions: Optional[FloatTensor] = None,
        discrete_actions: Optional[LongTensor] = None,
        rewards: Optional[FloatTensor] = None,
        past_key_values: Optional[Tuple[Tuple[FloatTensor]]] = None,
        attention_mask: Optional[BoolTensor] = None,
        token_type_ids: Optional[LongTensor] = None,
        position_ids: Optional[LongTensor] = None,
        return_loss: bool = True,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        loss_weight: Optional[FloatTensor] = None,
        exp_lamda_distances: Optional[FloatTensor] = None,
        continuous_RandP_actions: Optional[FloatTensor] = None,
        discrete_RandP_action_logits: Optional[FloatTensor] = None,
    ) -> JatOutput:
        """
        Embed the inputs, run the transformer and dispatch to the textual or RL output head.

        Textual inputs (`input_ids` / `pixel_values`) take precedence over RL inputs. For RL inputs the
        retrieval tensors (`exp_lamda_distances` and one of `continuous_RandP_actions` /
        `discrete_RandP_action_logits`) are forwarded to `output_rl`.

        Raises:
            ValueError: if neither textual nor RL observations are provided.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        is_textual = input_ids is not None or pixel_values is not None

        # Textual tasks
        if is_textual:
            inputs_embeds, attention_mask = self.embed_textual(input_ids, pixel_values, attention_mask)
        # RL tasks
        elif (
            continuous_observations is not None or discrete_observations is not None or image_observations is not None
        ):
            inputs_embeds, attention_mask = self.embed_rl(
                continuous_observations,
                discrete_observations,
                image_observations,
                continuous_actions,
                discrete_actions,
                rewards,
                attention_mask,
            )
        else:
            raise ValueError("Input not provided.")

        # Pass through transformer
        transformer_outputs = self.transformer(
            past_key_values=past_key_values,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        if is_textual:
            return self.output_textual(transformer_outputs, input_ids, attention_mask, return_loss, return_dict)
        return self.output_rl(
            transformer_outputs,
            continuous_observations,
            discrete_observations,
            image_observations,
            continuous_actions,
            discrete_actions,
            rewards,
            attention_mask,
            return_loss,
            return_dict,
            loss_weight,
            exp_lamda_distances,
            continuous_RandP_actions,
            discrete_RandP_action_logits,
        )

    def reset_rl(self):
        """Reset the step counter that `retrieve()` increments."""
        self.steps = 0

    def process(
        self,
        processor: JatProcessor,
        continuous_observation: Optional[List[float]] = None,
        discrete_observation: Optional[List[int]] = None,
        text_observation: Optional[str] = None,
        image_observation: Optional[np.ndarray] = None,
        action_space: Union[spaces.Box, spaces.Discrete] = None,
        reward: Optional[float] = None,
        deterministic: bool = True,
        context_window: Optional[int] = None,
    ):
        """
        Run `processor` on a single environment step and attach the metadata `get_next_action()` needs.

        The step is wrapped into a batch of one sequence of length one, with a placeholder action (zeros
        for `spaces.Box`, `0` for `spaces.Discrete`) and reward `0.0` when `reward` is None.
        `context_window` is accepted but not used here. Requires `retrieval_setup()` to have been called.

        Returns:
            The processor output with the extra entries `'action_space'`, `'deterministic'` and, for babyai
            tasks, `'mission'`.

        Raises:
            ValueError: if `action_space` is neither a `spaces.Box` nor a `spaces.Discrete`.
        """
        # Get the maximum sequence length
        max_length = self.config.max_position_embeddings // 2
        ### see script/train_jat.py > L161.
        ### None ==> value set to 512 in jat/processing_jat.py > L354 and then // 2 in L355.
        ### weirdly, the value in script/eval_jat.py is set as 256 so it will be // 2 again in L355.
        # max_length = 64 if self.task.startswith("atari") else None

        # Convert everything to lists
        def to_list(x):
            return x.tolist() if isinstance(x, np.ndarray) else x

        continuous_observation = to_list(continuous_observation)
        discrete_observation = to_list(discrete_observation)

        # get babyai mission within task
        if self.task.startswith("babyai"):
            mission = deepcopy(text_observation)
            assert mission in self.knn_index.keys(), f'{mission=} should be in {self.knn_index.keys()=}'

        # Add a fake action to the end of the sequence
        if isinstance(action_space, spaces.Box):
            fake_continuous_action = [0.0 for _ in range(action_space.shape[0])]
            fake_discrete_action = None
        elif isinstance(action_space, spaces.Discrete):
            fake_continuous_action = None
            fake_discrete_action = 0
        else:
            # Otherwise the placeholder actions would be unbound and fail later with a NameError.
            raise ValueError(f"Unsupported action_space {action_space!r}: expected spaces.Box or spaces.Discrete.")

        def wrap(x):
            # Add the sequence dimension, then the batch dimension; absent inputs stay None
            return [[x]] if x is not None else None

        continuous_observations = wrap(continuous_observation)
        discrete_observations = wrap(discrete_observation)
        text_observations = wrap(text_observation)
        image_observations = wrap(image_observation)
        continuous_actions = wrap(fake_continuous_action)
        discrete_actions = wrap(fake_discrete_action)
        rewards = [[reward if reward is not None else 0.0]]

        # Process the inputs
        processed = processor(
            continuous_observations=continuous_observations,
            discrete_observations=discrete_observations,
            text_observations=text_observations,
            image_observations=image_observations,
            continuous_actions=continuous_actions,
            discrete_actions=discrete_actions,
            rewards=rewards,
            truncation=True,
            truncation_side="left",
            max_length=max_length,
            return_tensors="pt",
        )

        assert (((self.act_key == 'continuous_actions' and processed[self.act_key].shape == (1, 1, self.act_dim)) or  # zeros
                 (self.act_key == 'discrete_actions' and processed[self.act_key].shape == (1, 1))) and
                processed[self.obs_key].shape == (1, 1, *self.raw_obs_dim) and
                processed[self.rew_key].shape == (1, 1)), f'{processed[self.act_key].shape=}, {processed[self.obs_key].shape=}, {processed[self.rew_key].shape=}, {self.act_dim=}, {self.raw_obs_dim=}'

        # save babyai mission
        if self.task.startswith("babyai"):
            processed['mission'] = mission

        # save action_space and deterministic
        processed['action_space'] = action_space
        processed['deterministic'] = deterministic

        return processed

    def retrieve(
        self,
        all_processed: List[dict],
        num_to_retrieve: int,
    ):
        """
        Retrieve `num_to_retrieve` demonstration steps for the current observation of every environment.

        Atari observations are embedded with `self.emb_model` before the lookup; babyai observations are
        truncated to their first `self.obs_dim[0]` entries (dropping the trailing text tokens) and looked
        up in the index of their mission. Increments `self.steps`.

        Args:
            all_processed: one `process()` output per environment.
            num_to_retrieve: number of demonstration steps to retrieve per environment.

        Returns:
            `(all_retrieved_act, all_retrieved_obs, all_retrieved_rew, row_of_obs)`: three nested lists
            indexed `[env][retrieved]`, and the query observations actually used for the lookup, of shape
            `(num_envs, *self.obs_dim)`.
        """
        self.steps += 1

        # Set num envs
        num_envs = len(all_processed)

        # Get obs from processed and make batch
        row_of_obs = np.concatenate([processed[self.obs_key][0].numpy() for processed in all_processed])
        assert row_of_obs.shape == (num_envs, *self.raw_obs_dim) and isinstance(row_of_obs, np.ndarray)
        if self.task.startswith("atari"):
            row_of_obs = process_row_of_obs_atari_full_without_mask(row_of_obs)
            row_of_obs = torch.from_numpy(row_of_obs).to(self.device)
            with torch.no_grad():
                row_of_obs = self.emb_model(self.emb_transform(row_of_obs)).cpu().numpy()
        elif self.task.startswith("babyai"):
            row_of_obs = row_of_obs[:, :self.obs_dim[0]]  # removing last 64 text tokens
        assert row_of_obs.shape == (num_envs, *self.obs_dim) and isinstance(row_of_obs, np.ndarray)

        # Retrieve indices
        if self.task.startswith("babyai"):
            retrieved_indices = []
            for idx in range(num_envs):
                mission = all_processed[idx]['mission']
                retrieved_indices_mission = retrieve_vector(row_of_obs=row_of_obs[idx:idx+1],
                                                            knn_index=self.knn_index[mission],
                                                            all_indices=self.all_indices[mission],
                                                            num_to_retrieve=num_to_retrieve,
                                                            kwargs=self.kwargs)
                retrieved_indices.append(retrieved_indices_mission)  # appending (1, 1, 2)
            retrieved_indices = np.concatenate(retrieved_indices, axis=0)
            assert retrieved_indices.shape == (num_envs, num_to_retrieve, 2)
        else:
            retrieved_indices = retrieve_vector(row_of_obs=row_of_obs,
                                                knn_index=self.knn_index,
                                                all_indices=self.all_indices,
                                                num_to_retrieve=num_to_retrieve,
                                                kwargs=self.kwargs)

        # Look up the retrieved (row, step) pairs in the stored data rows
        all_retrieved_act = []
        all_retrieved_obs = []
        all_retrieved_rew = []
        for all_row_idx_and_i in retrieved_indices:
            env_act, env_obs, env_rew = [], [], []
            for row_idx, i in all_row_idx_and_i:
                datarow = self.datarows[int(row_idx)]
                temp_a = datarow[self.act_key][int(i)]
                if self.task.startswith("atari") and self.use_global_atari_actions:
                    temp_a = convert_local_to_global_action(temp_a, self.task)
                env_act.append(temp_a)
                env_obs.append(datarow[self.obs_key][int(i)])
                env_rew.append(datarow[self.rew_key][int(i)])
            all_retrieved_act.append(env_act)
            all_retrieved_obs.append(env_obs)
            all_retrieved_rew.append(env_rew)

        return all_retrieved_act, all_retrieved_obs, all_retrieved_rew, row_of_obs

    def get_distances(
        self,
        all_retrieved_obs: np.ndarray,
        all_processed: List[dict],
        query_obs: np.ndarray,
    ):
        """
        Compute, per environment, the normalised distance of every context entry to the first retrieved obs.

        The first entry is 0.0 (the first retrieved observation itself), followed by the L2 distances of
        the remaining retrieved observations and finally of the query observation. Distances are divided
        by `self.dist_normalizer_value` and, for mujoco/atari tasks, multiplied by the task multiplier.

        Args:
            all_retrieved_obs: raw retrieved observations of shape `(num_envs, num_contexts - 1, *raw_obs_dim)`.
            all_processed: one `process()` output per environment (only its length is used).
            query_obs: query observations as returned by `retrieve()`, i.e. already embedded/truncated.

        Returns:
            `np.ndarray` of shape `(num_envs, num_contexts)`.
        """
        num_envs = len(all_processed)

        # Process retrieved obs like in retrieve
        num_contexts = all_retrieved_obs.shape[1] + 1
        assert all_retrieved_obs.shape == (num_envs, num_contexts - 1, *self.raw_obs_dim) and isinstance(all_retrieved_obs, np.ndarray)
        if self.task.startswith("atari"):
            all_retrieved_obs = all_retrieved_obs.reshape(num_envs * (num_contexts - 1), *self.raw_obs_dim)
            all_retrieved_obs = process_row_of_obs_atari_full_without_mask(all_retrieved_obs)
            all_retrieved_obs = torch.from_numpy(all_retrieved_obs).to(self.device)
            with torch.no_grad():
                all_retrieved_obs = self.emb_model(self.emb_transform(all_retrieved_obs)).cpu().numpy()
            all_retrieved_obs = all_retrieved_obs.reshape(num_envs, num_contexts - 1, *self.obs_dim)
        elif self.task.startswith("babyai"):
            all_retrieved_obs = all_retrieved_obs[:, :, :self.obs_dim[0]]  # removing last 64 text tokens
        assert all_retrieved_obs.shape == (num_envs, num_contexts - 1, *self.obs_dim) and isinstance(all_retrieved_obs, np.ndarray)

        # Compute distances
        all_distances = []
        for idx in range(num_envs):
            first_state = all_retrieved_obs[idx, 0:1]
            distances = [0.0]
            for i in range(1, num_contexts - 1):
                curr_state = all_retrieved_obs[idx, i:i+1]
                distances.append(L2dist(first_state, curr_state))
            # The query observation is the last context entry
            curr_state = query_obs[idx:idx+1]
            distances.append(L2dist(first_state, curr_state))
            all_distances.append(distances)
        all_distances = np.array(all_distances)
        assert all_distances.shape == (num_envs, num_contexts), f'{all_distances.shape=}, {num_envs=}, {num_contexts=}'

        # distances: divide by the configured normalizer (std, max or a percentile)
        all_distances = all_distances / self.dist_normalizer_value
        if self.task.startswith("mujoco"):
            all_distances = all_distances * self.dist_multipliers['mujoco']
        elif self.task.startswith("atari"):
            all_distances = all_distances * self.dist_multipliers['atari']
        print(f'{self.dist_normalizer_value=}')
        print(f'{all_distances=}')

        return all_distances

    @torch.no_grad()
    def get_next_action(
        self,
        all_processed: List[dict],
        return_retrieved_obs: bool = False,
    ):
        """
        Predict the next action for every environment from its retrieved context plus the current step.

        Retrieves `self.num_contexts - 1` demonstration steps per environment, appends the current step as
        the last context entry, runs `forward()` without a loss and reads the prediction at the last
        position. For discrete actions the model output is a probability vector: the deterministic path
        takes its argmax and the stochastic path samples from it.

        Args:
            all_processed: one `process()` output per environment; all must share the same
                `'action_space'` and `'deterministic'` values.
            return_retrieved_obs: if True, also return the images produced by `get_images_of_retrieved_obs`.

        Returns:
            A list with one action per environment (arrays for continuous actions, ints for discrete
            actions), or `(actions, all_retrieved_images)` when `return_retrieved_obs` is True.
        """
        num_envs = len(all_processed)
        num_contexts = self.num_contexts

        # Get the retrieved data
        all_retrieved_act, all_retrieved_obs, all_retrieved_rew, row_of_obs = self.retrieve(all_processed, num_to_retrieve=num_contexts - 1)
        if return_retrieved_obs:
            all_retrieved_images = get_images_of_retrieved_obs(deepcopy(all_retrieved_obs), self.task)

        # Get the distances
        all_retrieved_obs = np.stack(all_retrieved_obs).astype(np.int32 if self.obs_key == 'discrete_observations' else np.float32)
        assert all_retrieved_obs.shape == (num_envs, num_contexts - 1, *self.raw_obs_dim), f'{all_retrieved_obs.shape=}, {num_envs=}, {self.raw_obs_dim=}, {num_contexts-1=}'
        all_distances = self.get_distances(all_retrieved_obs=all_retrieved_obs, all_processed=all_processed, query_obs=row_of_obs)

        # Batch retrieved data
        all_retrieved_act = np.stack(all_retrieved_act).astype(np.int32 if self.act_key == 'discrete_actions' else np.float32)
        all_retrieved_rew = np.stack(all_retrieved_rew).astype(np.float32)
        assert (((self.act_key == 'continuous_actions' and all_retrieved_act.shape == (num_envs, num_contexts - 1, self.act_dim)) or
                 (self.act_key == 'discrete_actions' and all_retrieved_act.shape == (num_envs, num_contexts - 1))) and
                all_retrieved_rew.shape == (num_envs, num_contexts - 1)), f'{all_retrieved_act.shape=}, {all_retrieved_rew.shape=}, {num_envs=}, {self.act_dim=}, {self.raw_obs_dim=}, {num_contexts-1=}'

        # Batch query data (already tensors) # query data is already int32/float32 after processing
        all_query_act = torch.stack([processed[self.act_key][0] for processed in all_processed])
        all_query_obs = np.stack([processed[self.obs_key][0] for processed in all_processed])
        all_query_rew = torch.stack([processed[self.rew_key][0] for processed in all_processed])
        assert (((self.act_key == 'continuous_actions' and all_query_act.shape == (num_envs, 1, self.act_dim)) or
                 (self.act_key == 'discrete_actions' and all_query_act.shape == (num_envs, 1))) and
                all_query_obs.shape == (num_envs, 1, *self.raw_obs_dim) and
                all_query_rew.shape == (num_envs, 1)), f'{all_query_act.shape=}, {all_query_obs.shape=}, {all_query_rew.shape=}, {num_envs=}, {self.act_dim=}, {self.raw_obs_dim=}'

        # Collect attn
        attn_weights = np.ones((num_envs, num_contexts)).astype(np.float32)

        # Compute exp_lamda_distances
        exp_lamda_distances = np.exp(-self.lamda * all_distances)[:, :, np.newaxis]
        assert exp_lamda_distances.shape == (num_envs, num_contexts, 1), f'{exp_lamda_distances.shape=}, {num_envs=}, {num_contexts=}'

        # Compute extra_key
        all_extra_key = []
        for idx in range(num_envs):
            RandP_action = all_retrieved_act[idx, 0]
            if self.extra_key == 'continuous_RandP_actions':
                # The same retrieved action is used at every context position
                extra_key = [RandP_action for _ in range(num_contexts)]
            elif self.extra_key == 'discrete_RandP_action_logits':
                # Per position, a distribution over the N classes that sums to 1: every class gets d/N and
                # the retrieved action gets the remaining mass, with d the distance clipped to [0, 1].
                extra_key = []
                for d in all_distances[idx]:
                    d = min(1.0, max(0.0, d))
                    curr_logits = [1.0/self.N * d for _ in range(self.N)]
                    curr_logits[RandP_action] = (1.0 + (self.N - 1.0)*(1.0 - d))/self.N
                    extra_key.append(curr_logits)
            extra_key = np.stack(extra_key)
            all_extra_key.append(extra_key)
        all_extra_key = np.stack(all_extra_key).astype(np.float32)

        if self.extra_key == 'continuous_RandP_actions':
            assert all_extra_key.shape == (num_envs, num_contexts, self.act_dim), f'{all_extra_key.shape=}, {num_envs=}, {num_contexts=}, {self.act_dim=}'
        elif self.extra_key == 'discrete_RandP_action_logits':
            assert all_extra_key.shape == (num_envs, num_contexts, self.N), f'{all_extra_key.shape=}, {num_envs=}, {num_contexts=}, {self.N=}'

        # Tensorify
        all_retrieved_act = torch.from_numpy(all_retrieved_act)
        all_retrieved_rew = torch.from_numpy(all_retrieved_rew)
        attn_weights = torch.from_numpy(attn_weights).to(self.device)
        exp_lamda_distances = torch.from_numpy(exp_lamda_distances).to(self.device)
        all_extra_key = torch.from_numpy(all_extra_key).to(self.device)

        # Concat retrieved and query batches
        all_act = torch.cat([all_retrieved_act, all_query_act], dim=1).to(self.device)
        all_obs = np.concatenate([all_retrieved_obs, all_query_obs], axis=1)
        uses_atari_embeddings = self.use_atari_embeddings and self.task.startswith("atari")
        if uses_atari_embeddings:
            all_obs = all_obs.reshape(num_envs * num_contexts, *self.raw_obs_dim)
            all_obs = process_row_of_obs_atari_full_without_mask(all_obs)
            all_obs = torch.from_numpy(all_obs).to(self.device)
            with torch.no_grad():
                all_obs = self.emb_model_full(self.emb_transform(all_obs)).reshape(num_envs, num_contexts, *self.emb_dim_full)
        else:
            all_obs = torch.from_numpy(all_obs).to(self.device)
        all_rew = torch.cat([all_retrieved_rew, all_query_rew], dim=1).to(self.device)

        # Collect action_space, deterministic from all_processed
        all_action_space = [processed['action_space'] for processed in all_processed]
        all_deterministic = [processed['deterministic'] for processed in all_processed]
        ## assert that all action_space and deterministic are same for all envs
        assert all([action_space == all_action_space[0] for action_space in all_action_space]), f'{all_action_space=}'
        assert all([deterministic == all_deterministic[0] for deterministic in all_deterministic]), f'{all_deterministic=}'
        ## then just use first one!
        action_space = all_action_space[0]
        deterministic = all_deterministic[0]

        # Forward pass
        # Embedded atari observations are fed through the continuous-observation input
        final_obs_key = 'continuous_observations' if uses_atari_embeddings else self.obs_key
        outputs = self.forward(**{final_obs_key: all_obs,
                                  self.act_key: all_act,
                                  self.rew_key: all_rew,
                                  self.attn_key: attn_weights,
                                  'exp_lamda_distances': exp_lamda_distances,
                                  self.extra_key: all_extra_key,
                                  }, return_loss=False)

        # Return the predicted action
        if self.act_key == 'continuous_actions':
            self.last_continuous_action = outputs.pred_actions[:, -1].cpu().numpy()

            assert self.last_continuous_action.shape == (num_envs, self.act_dim), f'{self.last_continuous_action.shape=}, {num_envs=}, {self.act_dim=}'

            myprint(f'L2dist(RandP action, Pred action): {[L2dist(all_retrieved_act[idx, 0].cpu().numpy(), self.last_continuous_action[idx]) for idx in range(num_envs)]}')
            self.last_continuous_action = list(self.last_continuous_action)  # list of arrays
            return self.last_continuous_action if not return_retrieved_obs else (self.last_continuous_action, all_retrieved_images)

        elif self.act_key == 'discrete_actions':
            act_n = self.config.action_vocab_size if (self.task.startswith('atari') and self.use_global_atari_actions) else action_space.n
            # NOTE: despite the name, these are probabilities (see discrete_action_interpolation), not logits.
            logits = outputs.pred_actions[:, -1, : act_n]
            assert logits.shape == (num_envs, act_n), f'{logits.shape=}, {num_envs=}, {act_n=}'
            if deterministic:
                # myprint(f'{all_extra_key[:, -1, : action_space.n]=}')
                # myprint(f'{logits=}')
                self.last_discrete_action = logits.argmax(dim=-1, keepdim=True).cpu().numpy().reshape(-1)
            else:  # sample
                # Sample from the probabilities directly. Applying another softmax would flatten the
                # distribution; torch.multinomial renormalises the sliced rows itself.
                self.last_discrete_action = torch.multinomial(logits, num_samples=1).cpu().numpy().reshape(-1)

            assert self.last_discrete_action.shape == (num_envs,), f'{self.last_discrete_action.shape=}, {num_envs=}'

            self.last_discrete_action = list(self.last_discrete_action)  # list of ints
            myprint(f'RandP action: {all_retrieved_act[:, 0].cpu().numpy().tolist()} vs Pred action: {self.last_discrete_action}')

            if self.task.startswith("atari") and self.use_global_atari_actions:
                # The environment expects task-local action ids
                self.last_discrete_action = [convert_global_to_local_action(a, self.task) for a in self.last_discrete_action]
                myprint(f'[IN LOCAL ACTION] RandP action: {[convert_global_to_local_action(a, self.task) for a in all_retrieved_act[:, 0].cpu().numpy().tolist()]} vs Pred action: {self.last_discrete_action}')
                myprint(f'[IN LOCAL ACTION] diff: {[convert_global_to_local_action(a, self.task) - b for a, b in zip(all_retrieved_act[:, 0].cpu().numpy().tolist(), self.last_discrete_action)]}')

            return self.last_discrete_action if not return_retrieved_obs else (self.last_discrete_action, all_retrieved_images)
# Register this class under the auto-class name "AutoModelForCausalLM" via the inherited
# `register_for_auto_class` hook. NOTE(review): presumably this lets the checkpoint be loaded
# through `AutoModelForCausalLM` with custom code — confirm against the transformers docs.
JatRegentModel.register_for_auto_class("AutoModelForCausalLM")