{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Visualize environment and custom tasks"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pathlib\n",
    "import sys\n",
    "import os\n",
    "sys.path.append(str(pathlib.Path(os.path.abspath('')).parent))\n",
    "\n",
    "from envs.custom_dmc_tasks import *\n",
    "from dm_control import suite\n",
    "import numpy as np\n",
    "\n",
    "domain = 'stickman'\n",
    "task = 'sit_knees'\n",
    "\n",
    "env = suite.load(domain_name=domain, task_name=task, visualize_reward=True)\n",
    "\n",
    "action_spec = env.action_spec()\n",
    "\n",
    "# Define a uniform random policy.\n",
    "def random_policy(time_step):\n",
    "  del time_step  # Unused.\n",
    "  return np.random.uniform(low=action_spec.minimum,\n",
    "                           high=action_spec.maximum,\n",
    "                           size=action_spec.shape)\n",
    "\n",
    "def zero_policy(time_step):\n",
    "  del time_step\n",
    "  return np.zeros(action_spec.shape)\n",
    "   \n",
    "\n",
    "class GoalSetWrapper:\n",
    "    def __init__(self, env, goal=None, goal_idx=None):\n",
    "       self._env = env\n",
    "       self._env._step_limit = float('inf')\n",
    "       self._goal = goal\n",
    "       self._goal_idx = goal_idx\n",
    "\n",
    "    def step(self, *args, **kwargs):\n",
    "        if self._goal is not None:\n",
    "            self.set_goal(self._goal)\n",
    "        if self._goal_idx is not None:\n",
    "            self.set_goal_by_idx(self._goal_idx)\n",
    "        return self._env.step(*args, **kwargs)\n",
    "    \n",
    "    def set_goal_by_idx(self, idx_goal):\n",
    "        cur = self._env.physics.get_state().copy()\n",
    "        for idx, goal in idx_goal:\n",
    "            cur[idx] = goal\n",
    "        self._env.physics.set_state(cur)\n",
    "        self._env.step(np.zeros_like(self.action_spec().shape))\n",
    "\n",
    "    def set_goal(self, goal):\n",
    "        goal = np.array(goal)\n",
    "        size = self._env.physics.get_state().shape[0] - goal.shape[0]\n",
    "        self._env.physics.set_state(np.concatenate((goal, np.zeros([size]))))\n",
    "        self._env.step(np.zeros_like(self.action_spec().shape))\n",
    "\n",
    "    def __getattr__(self, name: str):\n",
    "       return getattr(self._env, name)\n",
    "\n",
    "\n",
    "env = GoalSetWrapper(env)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "from envs.custom_dmc_tasks.stickman import StickmanYogaPoses\n",
    "\n",
    "obs = env.reset()\n",
    "\n",
    "for _ in range(1):\n",
    "    env.set_goal(StickmanYogaPoses.sit_knees)\n",
    "\n",
    "# for _ in range(20):\n",
    "#     obs = env.step(np.random.randn(*env.action_spec().shape))\n",
    "print('Rew', obs.reward)\n",
    "\n",
    "print('Upright', env.physics.torso_upright())\n",
    "print('Torso height', env.physics.torso_height())\n",
    "\n",
    "plt.imshow(env.physics.render(camera_id=0))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "for _ in range(1):\n",
    "    obs = env.step(np.random.randn(*env.action_spec().shape))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "env.physics.named.data.qpos"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "env.physics.named.data.xpos"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "mine_new",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.14"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}