trajpred/01_detect_objects.ipynb

363 lines
678 KiB
Plaintext
Raw Permalink Normal View History

{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"from pathlib import Path\n",
"\n",
"import cv2\n",
"import matplotlib.pyplot as plt\n",
"import numpy as np\n",
"# from PIL import Image\n",
"import torch\n",
"from IPython import display\n",
"from torchvision.io.video import read_video\n",
"from torchvision.models.detection import retinanet_resnet50_fpn_v2, RetinaNet_ResNet50_FPN_V2_Weights\n",
"from torchvision.transforms.functional import to_pil_image\n",
"from torchvision.utils import draw_bounding_boxes\n"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"# Folder holding the VIRAT clips and their image-to-world homography file.\n",
"source = Path('../DATASETS/VIRAT_subset_0102x')\n",
"\n",
"# Materialise and sort the glob: a bare generator is exhausted after its first\n",
"# use (re-running a later cell would then fail) and glob order is unspecified.\n",
"videos = sorted(source.glob('*.mp4'))\n",
"\n",
"homography_files = sorted(source.glob('*img2world.txt'))\n",
"if not homography_files:\n",
"    raise FileNotFoundError(f\"No '*img2world.txt' homography file found in {source}\")\n",
"homography = homography_files[0]\n",
"\n",
"# Comma-separated matrix mapping image coordinates to world coordinates\n",
"# (3x3 according to the VIRAT README quoted in the next cell).\n",
"H = np.loadtxt(homography, delimiter=',')"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"The homography matrix helps to transform points from image space to a flat world plane. The `README_homography.txt` from VIRAT describes:\n",
"\n",
"> Roughly estimated 3-by-3 homographies are included for convenience. \n",
"> Each homography H provides a mapping from image coordinate to scene-dependent world coordinate.\n",
"> \n",
"> [xw,yw,zw]' = H*[xi,yi,1]'\n",
"> \n",
"> xi: horizontal axis on image with left top corner as origin, increases right.\n",
"> yi: vertical axis on image with left top corner as origin, increases downward.\n",
"> \n",
"> xw/zw: world x coordinate\n",
"> yw/zw: world y coordinate"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"# Example: xw, yw, zw = H.dot(np.array([20, 300, 1])); world point is (xw / zw, yw / zw)"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"# Pinned to a single clip so the results below are reproducible; pick another\n",
"# entry from `videos` to process a different recording.\n",
"video_path = source / 'VIRAT_S_010200_00_000060_000218.mp4'"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"PosixPath('../DATASETS/VIRAT_subset_0102x/VIRAT_S_010200_00_000060_000218.mp4')"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"video_path"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"Suggestions from: https://stackabuse.com/retinanet-object-detection-with-pytorch-and-torchvision/"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"# Minimum confidence a detection needs in order to be returned by the model.\n",
"SCORE_THRESHOLD = 0.35\n",
"\n",
"# Pretrained weights; `weights.meta[\"categories\"]` is used further down to turn\n",
"# predicted label ids into class names.\n",
"weights = RetinaNet_ResNet50_FPN_V2_Weights.DEFAULT\n",
"model = retinanet_resnet50_fpn_v2(weights=weights, score_thresh=SCORE_THRESHOLD)\n",
"# Put the model in inference mode\n",
"model.eval()\n",
"# Get the transforms for the model's weights\n",
"preprocess = weights.transforms()"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"# hub.set_dir()"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
"video = cv2.VideoCapture(str(video_path))\n",
"# OpenCV does not raise when a file is missing or cannot be decoded; fail early\n",
"# with a clear message instead of silently getting no frames further down.\n",
"if not video.isOpened():\n",
"    raise OSError(f\"Could not open video: {video_path}\")"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"> The score_thresh argument defines the threshold at which an object is detected as an object of a class. Intuitively, it's the confidence threshold, and we won't classify an object to belong to a class if the model is less than 35% confident that it belongs to a class."
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"The result from a single prediction coming from `model(batch)` looks like:\n",
"\n",
"```python\n",
"{'boxes': tensor([[5.7001e+02, 2.5786e+02, 6.3138e+02, 3.6970e+02],\n",
" [5.0109e+02, 2.4508e+02, 5.5308e+02, 3.4852e+02],\n",
" [3.4096e+02, 2.7015e+02, 3.6156e+02, 3.1857e+02],\n",
" [5.0219e-01, 3.7588e+02, 9.7911e+01, 7.2000e+02],\n",
" [3.4096e+02, 2.7015e+02, 3.6156e+02, 3.1857e+02],\n",
" [8.3241e+01, 5.8410e+02, 1.7502e+02, 7.1743e+02]]),\n",
" 'scores': tensor([0.8525, 0.6491, 0.5985, 0.4999, 0.3753, 0.3746]),\n",
" 'labels': tensor([64, 64, 1, 64, 18, 86])}\n",
"```"
]
},
{
"cell_type": "code",
"execution_count": 37,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAABQAAAALQCAIAAABAH0oBAAEAAElEQVR4nOz9V5MkSZYeCiozbubcg0dGJCte1Qx9ewD0YDALYO5g92FlIfsP8bgQvF0BG+xABtsAmlZ1sayqZMGJc2LcTFX34birq0eGZ0YU6aka5JGUFEtLczM1NSWHfOc7eGN/Dc0lSRJ0nXDOrz1/E6GUGobBGKOUIoSklOq/4BhjXBQFQsg0TcZYWZZCCM65lDLPc/16JUIIdf4lbRNCeJ7n+/5kMpFSzn7FrFe22bbtJEmCIJhMJs31nfv373c6ncvLyyzLPM/jnI9GI8ZYliSW42RpygzDtm2EUJqmlUpF3Wc6nbquW5ZlNB472FhfX29tb1xcXEx5vr+//xf/4v+2v7+frfknJyeHT58dHByQtOx2u+NOL46igFPDMFzXNU0TY8wYu3PnztbWVvvnP8cYdzqdR48ehWF47969LMv+23/7b04vppSORqNKpfLgwYMHDx4URXFycmJWjHv37sVx/PTp0zAMG41GkiRSyi/++Em73d7Y2OCc7+7udjqdzz//XEr5zjvv2bZtWiQIgifP/nh0dFStegihNKHT6RQhFATBL3/5S4zx//yf//PevXteq845397e5pw/e/YMPuXz589lnKl+gFeoVquNRsPf2JVSep7X6XTyPOecf/j737u+/3//f/+/Go2G67owZqbT6cbGhuu6X/XiNE0NwwiCgDFGCCmKIk3TR48eOY7zxhtvWJZlGIZ6Vo5f/X11wfj68bNqXFFZXnt+w3aazWZexIQQZuZSyqJMpJROmjYajSzLjo+P+/3+48eP+/3+xsbG3t5etVodj8dfffVVGIaLvuKL++vvVZbXP7egi2NCiBAiz3MhhJHN2l8UBfRzURRCCEosznme52VZGoYhhDAMYzAYWJYFgy3Pc9u2KaXdbpcQUvMa1Wp1a2trMBhMJhPG2GQyGQ6H3KSWZTUajY2NjfX19adPn/Z6vbIsh92IUsorjud5W1tbnPPTT7+sVCr3/48f27YdRZFlWZVKZTKZ1Ot1QohhLiY451y1MwgC/Txj7PT09Pj42JQ4TdP+4alt2wmV7Xbbu7tlGEYsS9M0XWqcnZ09/vizvb29e7t7CKGeYSCEqEAIIZ+anPMnnz2aTCZ3Hz5oNBplWRJC9P68E1jD4fDj3/6h1WqdPn6GMd7b2kEIYc9SbXv48OFkMjk5OZlMJgZa+rkSjLH6Fvq34yW+9vqbiN5UuPmDBw/efffdTz/9dDQaua6LMebXN2elZPTV1+jC6eJdpJSGYYzHY8bYn//45zA8Wq3W0g9s4+otvlXhpLjV9Vc+ty5pmiKEYCVXcu0G9BJh8mt+X4xv9EO9PUII9Vv1c4zxlXFyq2bcdvzctn8QQmmaMsYsy9Lbdtt2fltC+PXt1+ev4zjqfIlv976rPuu39b70utuA+nT1SimklIYQRVH87ne/63Q6rusudAxSwOIfx7FlWaZpep5nmubJca8sS7ihGldCiFp1fTKZRFEEm456ShQmlUrlRz9+p9/vP/ry96ZpNpoV13UHg05RFPA4xhjsQUIIz/TTNB0Oh3Ec+74PTSqKIggCUB0ppfV6PUmSfr+fJIlfrzLGQKUcj8eEELgmTVOlNwohYNUVQlQqlbIssyzjnDuOwznPsqwsSy4SQgghhFJqWRb8zRizidPpdMqyrNVq1Wo1juM8z03ThJUfxLIWOoa+R3ueVxTFdDrN8zzLFvqPbduEENM0KaVRmIEiyjnXlxrDMNI0NU0zCII8z+F6QkiBBIxDzjmlVOmZGDnQbCml3h4z8CzLKoqiLEvTNNX5VcvStePnxeG0+Idk114Tx7FaP0FPg/Y7jlOWJbRfb2ee5+pY3x8dwwS15ODgIAgCy7KePn2aJMm9e/c2NjaSJAGtvixLGJYnzw+3t7cRQqenp5PJBFRljDFyTCllrVar1Wr6N6LatIBRUR
QF59wwjDAModOEECcnJ0VRtNvt/ffeghEF16v+T0+6tVotaNQuLi5iItbW1gTFpmlKglWvNkw3TdPO6Xm/38eUlGVp23a324URC+Nh5/7dLMviOC7LUo1DwzAajYZqJ2MsjmOEkOu6rmldXFw8e/YsDMO7d+86jjOdThlj/kYL1lUwQDqdzsHBged5eBRHUTQajbIsg9FOKSWE1NZbQohqqzmdTj8/eNJsNt/60fue5+2+/RP1XGgnxjhJkorjSSmTMi/LcpDHtm1XWg3bttvCTJLENE3LssBkK8syiqIkSYqicBzH8zzDMAghFBOEUBlPYYWJ43g4HDabTcMwJpOJtI1WqzUajeArCCEajUaj0VhrNPM8hznVbDYRQsPhcDgcFuOIUsocizEW8rxarXrNGkKoiFPoTNM09TUQzE8QjDF0Asb4+tH8JxYpJQwyzrlaub7GzrpK6vV6URRhGGZZRm/wxjDci6KwbdswjCzLLMuqVqu9Xo8x9v7775+fn3/++eeu7ydJ4noeIURN/iuGU1mWjLF6u+0Kwjk/OjpqtVqexXzfD8Pw4uJimpq2ba+trR0fH//xN/8DhSGiBnJd/SZvvvmm53mU0sFgcPbHP+7s7NRqtbt37yKENjY2Op2O53lokKoBRynNsiwIgnv37j0+/mowGGxubr799tsfffTRhx9+2Gw2d3d3K5UKISQMQ0LI9vZ2s9mUUmKMz88vPM974817lFIhBCEkiiLTNBGi7rxV5+fnu7u7P/rRj54+fVrBIk3Ti4uLJEk6nY5hGLDNGNqAU/bM5eXl6TiGrzwajer1OqXU9f1Wq1WW5Wg0QgjVarXz83OMsZQySRLY7YqiGI1Gz549g6XKcRwhRBzHo9EIlr/vlehrTcv3oyh69OjRwcFBFEUIoQcPHmxsbFBKT09Pz8/Pdev3Gwqs3QghIUQZh2qT0Cc/OJVuPrnKsoSBwRiDuUkp9X0/WGvW63XY8o+Pj8/Ozi4uLoIg8P2K53m0VbUsq1arpWkKU+/besdvItAVhJCXeM1APUrTdDwew04M3cXzXAgBmyVsVKBX2a6v318dg0IGSh5j34tl9rW8ltfyWl4pURTpHlLYU2D9BwGzAbaGl9ynUqm4rlupuLAVcs4JIVcWw9FoJISAPUU/Dw8FkztJkjzPQW2F/1KKItjSCKEsy5ShooIoCCFYgWELAwcrtMGkjmoMWMKz3c2htm2DYxRuAgL3hPvHcVyr1aDxnuepNgdBEMdxFEVXukVtymiu60Lb4jhWDgiwGA3DgKeDDUkpBf8MvDgYxnAwnUTVahU6DawjkBzLhRH4JxT4xGge8eKcgx0CJjp0yConvi5SyqIowD4cj8etVsuyrE6nA98XImT69Y7jSCmjKArDMEkSQgg4Mm7px0MwGEDjLYoCBgy8C3wv+C/of1D4e71eUubj8dhu1UBfLcuSmgtjG+Jh8Xh6enpqWGa9XofLwKqE26ZpCqMLDizLchxHjWrVNniuGsmVSqVarVarVegxQgiMWPghIWQwGEADeBiCygd3wBiDWRvHMSHk8vJyMBjU6/Wf/vSna7vbvu9LTVeUUoZhCDpPOo0QQoZrW5ZlItMwDBhg6ukIIcZYkiQwo8EPAprSbDBjIqUs5584m4sQIssyiWd2H9wHOl/NOzVVYWybpiloquYs0kY6KHhwXp+bupMO+h969XuhmcGAQwjBO8Pxt2UAE0JAqYV+vMlPYNFM03Rtba0oii+//BIch7Ztt1qt/f19xtjTp09d17Vt2/O84XCov4s6xhhHUbS1tbW3t+cK8vTp03A83t/f37y/jzHudru9Xu+SZIwxmZeXJ2eW7WYFR6W4Elba2NjY2tpyHGc4HJ6ZZrvd9jyvVqvFcbyxsWHb9p07d744uKhUK5ubm1mW9aZjY+JW11qNVqMR9T777LOiKPb39znneZZBBG97exuME8YYY2x/fz+OY9u2h8NxWZbVavXk5OTy8pIQkqYxpd
R1XXDA9Pv9o6OjOI7v3bsnhPjj7/4AXjHbtmVe5nmZR0m9Xudo4eGDSBH8vCBOFEVHR0dEyng8tSzr3u7e+vo6j9M
"text/plain": [
"<PIL.Image.Image image mode=RGB size=1280x720>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"'frame 1'"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"{'boxes': tensor([[5.6998e+02, 2.5778e+02, 6.3132e+02, 3.6969e+02],\n",
" [5.0109e+02, 2.4507e+02, 5.5308e+02, 3.4848e+02],\n",
" [5.4706e-01, 3.7548e+02, 9.8450e+01, 7.2000e+02],\n",
" [3.4061e+02, 2.7014e+02, 3.6137e+02, 3.1858e+02],\n",
" [3.4061e+02, 2.7014e+02, 3.6137e+02, 3.1858e+02],\n",
" [8.3206e+01, 5.8410e+02, 1.7512e+02, 7.1747e+02]]), 'scores': tensor([0.8500, 0.6467, 0.4990, 0.4889, 0.4773, 0.3656]), 'labels': tensor([64, 64, 64, 18, 1, 86])}\n"
]
}
],
"source": [
"# TODO make into loop\n",
"%matplotlib inline\n",
"\n",
"# Label id to keep; an index into weights.meta[\"categories\"].\n",
"# NOTE(review): presumably the COCO 'person' class - confirm against the category list.\n",
"TARGET_LABEL = 1\n",
"\n",
"# Rewind the capture so that re-running this cell starts at the first frame\n",
"# again and the frame counter below matches the real position in the video.\n",
"video.set(cv2.CAP_PROP_POS_FRAMES, 0)\n",
"\n",
"frame_nr = 0\n",
"while True:\n",
"    ret, frame = video.read()\n",
"    frame_nr += 1\n",
"\n",
"    if not ret:\n",
"        print(\"Can't receive frame (stream end?). Exiting ...\")\n",
"        break\n",
"\n",
"    # Convert OpenCV's BGR frame to RGB and move the channel axis to the front,\n",
"    # which is the layout the torchvision transforms and drawing utils are fed here.\n",
"    t = torch.from_numpy(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)).permute(2, 0, 1)\n",
"\n",
"    batch = [preprocess(t)]\n",
"    # no_grad can be used on inference, should be slightly faster\n",
"    with torch.no_grad():\n",
"        predictions = model(batch)\n",
"    prediction = predictions[0]  # we feed only one frame at a time\n",
"\n",
"    mask = prediction['labels'] == TARGET_LABEL  # if we want more than one: np.isin(prediction['labels'], [1,86])\n",
"\n",
"    scores = prediction['scores'][mask]  # not used yet; kept for the confidence/NMS TODO below\n",
"    labels = prediction['labels'][mask]\n",
"    boxes = prediction['boxes'][mask]\n",
"\n",
"    # TODO: introduce confidence and NMS suppression: https://github.com/cfotache/pytorch_objectdetecttrack/blob/master/PyTorch_Object_Tracking.ipynb\n",
"    # (which I _think_ we better do after filtering)\n",
"    # alternatively look at Soft-NMS https://towardsdatascience.com/non-maximum-suppression-nms-93ce178e177c\n",
"\n",
"    # Separate name for the class names so `labels` keeps holding the id tensor.\n",
"    label_names = [weights.meta[\"categories\"][label_id] for label_id in labels.tolist()]\n",
"\n",
"    # NOTE(review): font='Arial' must be resolvable on the machine running this - confirm.\n",
"    annotated = draw_bounding_boxes(t, boxes=boxes,\n",
"                                    labels=label_names,\n",
"                                    colors=\"cyan\",\n",
"                                    width=2,\n",
"                                    font_size=30,\n",
"                                    font='Arial')\n",
"\n",
"    im = to_pil_image(annotated.detach())\n",
"\n",
"    display.display(im, f\"frame {frame_nr}\")\n",
"    print(prediction)\n",
"    # wait=True: the output is only cleared once the next frame is displayed.\n",
"    display.clear_output(wait=True)\n",
"\n",
"    break  # for now"
]
},
{
"cell_type": "code",
"execution_count": 36,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"{'boxes': tensor([[5.7001e+02, 2.5786e+02, 6.3138e+02, 3.6970e+02],\n",
" [5.0109e+02, 2.4508e+02, 5.5308e+02, 3.4852e+02],\n",
" [3.4096e+02, 2.7015e+02, 3.6156e+02, 3.1857e+02],\n",
" [5.0219e-01, 3.7588e+02, 9.7911e+01, 7.2000e+02],\n",
" [3.4096e+02, 2.7015e+02, 3.6156e+02, 3.1857e+02],\n",
" [8.3241e+01, 5.8410e+02, 1.7502e+02, 7.1743e+02]]),\n",
" 'scores': tensor([0.8525, 0.6491, 0.5985, 0.4999, 0.3753, 0.3746]),\n",
" 'labels': tensor([64, 64, 1, 64, 18, 86])}"
]
},
"execution_count": 36,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"prediction"
]
},
{
"cell_type": "code",
"execution_count": 21,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([False, False, True, False, False, False])"
]
},
"execution_count": 21,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"prediction['labels'] == 1"
]
},
{
"cell_type": "code",
"execution_count": 23,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([0.5985])"
]
},
"execution_count": 23,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# A cell only displays its last expression, so return boxes and scores together\n",
"# as one tuple instead of as two separate statements.\n",
"label_mask = prediction['labels'] == 1\n",
"prediction['boxes'][label_mask], prediction['scores'][label_mask]"
]
}
],
"metadata": {
"kernelspec": {
"display_name": ".venv",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.2"
},
"orig_nbformat": 4,
"vscode": {
"interpreter": {
"hash": "1135f674f58caf91385e41dd32dc418daf761a3c5d4526b1ac3bad0b893c2eb5"
}
}
},
"nbformat": 4,
"nbformat_minor": 2
}