{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "9bca0f41",
   "metadata": {},
   "source": [
    "# Simple inference example with CroCo-Stereo or CroCo-Flow"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "80653ef7",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Copyright (C) 2022-present Naver Corporation. All rights reserved.\n",
    "# Licensed under CC BY-NC-SA 4.0 (non-commercial use only)."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4f033862",
   "metadata": {},
   "source": [
    "First download the model(s) of your choice by running\n",
    "```\n",
    "bash stereoflow/download_model.sh crocostereo.pth\n",
    "bash stereoflow/download_model.sh crocoflow.pth\n",
    "```"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1fb2e392",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "use_gpu = torch.cuda.is_available() and torch.cuda.device_count()>0\n",
    "device = torch.device('cuda:0' if use_gpu else 'cpu')\n",
    "import matplotlib.pylab as plt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e0e25d77",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Project helpers used below: checkpoint loading, tiled prediction,\n",
    "# image preprocessing and the disparity / flow visualizations.\n",
    "from stereoflow.test import _load_model_and_criterion\n",
    "from stereoflow.engine import tiled_pred\n",
    "from stereoflow.datasets_stereo import (\n",
    "    img_to_tensor,\n",
    "    vis_disparity,\n",
    ")\n",
    "from stereoflow.datasets_flow import flowToColor\n",
    "\n",
    "# Value passed as `overlap=` to tiled_pred in the inference cells.\n",
    "# 0.7 is the recommended value; a higher value can be slightly better but slower.\n",
    "tile_overlap = 0.7"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "86a921f5",
   "metadata": {},
   "source": [
    "### CroCo-Stereo example"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "64e483cb",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Replace the two placeholder paths with your own left / right image files.\n",
    "# Each image is opened in a `with` block so its file handle is closed as soon as\n",
    "# the pixels have been copied into a NumPy array.\n",
    "with Image.open('<path_to_left_image>') as left_img:\n",
    "    image1 = np.asarray(left_img)\n",
    "with Image.open('<path_to_right_image>') as right_img:\n",
    "    image2 = np.asarray(right_img)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f0d04303",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the CroCo-Stereo checkpoint for the selected device. Besides the model, the loader\n",
    "# returns the settings forwarded to tiled_pred below (cropsize, with_conf, tile_conf_mode).\n",
    "# The second argument is left as None -- presumably the criterion name, not needed for\n",
    "# inference; confirm against stereoflow.test._load_model_and_criterion.\n",
    "stereo_checkpoint = 'stereoflow_models/crocostereo.pth'\n",
    "model, _, cropsize, with_conf, task, tile_conf_mode = _load_model_and_criterion(\n",
    "    stereo_checkpoint, None, device\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "47dc14b5",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Preprocess both views with img_to_tensor, move them to the device and add a leading dimension.\n",
    "im1, im2 = (img_to_tensor(view).to(device).unsqueeze(0) for view in (image1, image2))\n",
    "\n",
    "# Tiled prediction under torch.inference_mode(). The two positional None arguments are\n",
    "# presumably the criterion and the ground truth, unused here -- confirm in stereoflow.engine.\n",
    "with torch.inference_mode():\n",
    "    pred, _, _ = tiled_pred(\n",
    "        model, None, im1, im2, None,\n",
    "        conf_mode=tile_conf_mode,\n",
    "        overlap=tile_overlap,\n",
    "        crop=cropsize,\n",
    "        with_conf=with_conf,\n",
    "        return_time=False,\n",
    "    )\n",
    "\n",
    "# Squeeze the two leading dimensions, then convert to a NumPy array on the CPU.\n",
    "pred = pred.squeeze(0).squeeze(0)\n",
    "pred = pred.cpu().numpy()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "583b9f16",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Show the vis_disparity rendering of the prediction with the axes hidden.\n",
    "plt.imshow(vis_disparity(pred))\n",
    "plt.axis('off')\n",
    "# End on plt.show() so the cell displays only the figure, not the tuple returned by plt.axis.\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d2df5d70",
   "metadata": {},
   "source": [
    "### CroCo-Flow example"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9ee257a7",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Replace the two placeholder paths with your own first / second image files.\n",
    "# Each image is opened in a `with` block so its file handle is closed as soon as\n",
    "# the pixels have been copied into a NumPy array.\n",
    "with Image.open('<path_to_first_image>') as first_img:\n",
    "    image1 = np.asarray(first_img)\n",
    "with Image.open('<path_to_second_image>') as second_img:\n",
    "    image2 = np.asarray(second_img)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d5edccf0",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the CroCo-Flow checkpoint for the selected device. This rebinds `model` and the\n",
    "# tiling settings (cropsize, with_conf, tile_conf_mode) that the stereo example defined.\n",
    "# The second argument is left as None -- presumably the criterion name, not needed for\n",
    "# inference; confirm against stereoflow.test._load_model_and_criterion.\n",
    "flow_checkpoint = 'stereoflow_models/crocoflow.pth'\n",
    "model, _, cropsize, with_conf, task, tile_conf_mode = _load_model_and_criterion(\n",
    "    flow_checkpoint, None, device\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b19692c3",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Preprocess both frames with img_to_tensor, move them to the device and add a leading dimension.\n",
    "im1, im2 = (img_to_tensor(frame).to(device).unsqueeze(0) for frame in (image1, image2))\n",
    "\n",
    "# Tiled prediction under torch.inference_mode(). The two positional None arguments are\n",
    "# presumably the criterion and the ground truth, unused here -- confirm in stereoflow.engine.\n",
    "with torch.inference_mode():\n",
    "    pred, _, _ = tiled_pred(\n",
    "        model, None, im1, im2, None,\n",
    "        conf_mode=tile_conf_mode,\n",
    "        overlap=tile_overlap,\n",
    "        crop=cropsize,\n",
    "        with_conf=with_conf,\n",
    "        return_time=False,\n",
    "    )\n",
    "\n",
    "# Squeeze the leading dimension and move axis 0 of the result to the last position\n",
    "# (permute(1,2,0)), then convert to a NumPy array on the CPU.\n",
    "pred = pred.squeeze(0).permute(1,2,0)\n",
    "pred = pred.cpu().numpy()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "26f79db3",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Show the flowToColor rendering of the prediction with the axes hidden.\n",
    "plt.imshow(flowToColor(pred))\n",
    "plt.axis('off')\n",
    "# End on plt.show() so the cell displays only the figure, not the tuple returned by plt.axis.\n",
    "plt.show()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}