from duckduckgo_search import ddg_images
def search_images(term, max_images=30):
    """Search DuckDuckGo for `term` and return an `L` list of image URLs.

    `max_images` caps the number of results requested from the search API.
    """
    print(f"Searching for '{term}'")
    hits = ddg_images(term, max_results=max_images)
    return L(hits).itemgot('image')
Step 1: Download images of the Beatles using DuckDuckGo
::: {#589f3b17 .cell _cell_guid=‘b1076dfc-b9ad-4769-8c92-a6c4dae69d19’ _uuid=‘8f2839f25d086af736a60e9eeb907d3b93b6e0e5’ execution=‘{“iopub.execute_input”:“2023-03-09T22:02:42.432683Z”,“iopub.status.busy”:“2023-03-09T22:02:42.432353Z”,“iopub.status.idle”:“2023-03-09T22:02:57.541452Z”,“shell.execute_reply”:“2023-03-09T22:02:57.540385Z”}’ papermill=‘{“duration”:15.121345,“end_time”:“2023-03-09T22:02:57.544044”,“exception”:false,“start_time”:“2023-03-09T22:02:42.422699”,“status”:“completed”}’ tags=‘[]’ execution_count=1}
import socket,warnings
try:
1)
socket.setdefaulttimeout(connect(('1.1.1.1', 53))
socket.socket(socket.AF_INET, socket.SOCK_STREAM).except socket.error as ex: raise Exception("STOP: No internet. Click '>|' in top right and set 'Internet' switch to on")
import os
= os.environ.get('KAGGLE_KERNEL_RUN_TYPE', '')
iskaggle
if iskaggle:
!pip install -Uqq fastai duckduckgo_search
from fastai import *
from fastai.vision import *
from fastai.vision.all import *
from fastai.vision.widgets import *
from fastcore.all import *
from fastdownload import download_url
:::
We save images in ‘searches’ to train the model. We remove images which didn’t download correctly.
# Download up to 150 images per Beatle into the_Beatles/<name>, pausing
# between queries so we don't over-load the search server.
searches = 'John Lennon', 'Paul McCartney', 'Ringo Starr', 'George Harrison'
path = Path('the_Beatles')
from time import sleep

for o in searches:
    dest = (path/o)
    dest.mkdir(exist_ok=True, parents=True)
    download_images(dest, urls=search_images(f'{o}', 150))
    sleep(10)  # Pause between searches to avoid over-loading server
    #resize_images(path/o, max_size=400, dest=path/o)

# Remove any files that did not download as valid, openable images,
# then report how many were dropped.
failed = verify_images(get_image_files(path))
failed.map(Path.unlink)
len(failed)
Searching for 'John Lennon'
Searching for 'Paul McCartney'
Searching for 'Ringo Starr'
Searching for 'George Harrison'
19
Step 2: Augment the data
We use a `DataBlock` to split the data into training and validation sets.
# DataBlock: images labelled by their parent folder name, split 80/20
# into train/validation with a fixed seed, each item resized to 128px.
data = DataBlock(
    blocks=(ImageBlock, CategoryBlock),
    get_items=get_image_files,
    splitter=RandomSplitter(valid_pct=0.2, seed=42),
    get_y=parent_label,
    item_tfms=Resize(128))

dls = data.dataloaders(path)
dls.valid.show_batch(max_n=6, nrows=2)
Pad the images with black:
# Variant: resize by padding the shorter side with black pixels.
data = data.new(item_tfms=Resize(128, ResizeMethod.Pad, pad_mode='zeros'))
dls = data.dataloaders(path)
dls.valid.show_batch(max_n=6, nrows=2)
Squish the images:
# Variant: resize by squishing the image to fit (distorts aspect ratio).
data = data.new(item_tfms=Resize(128, ResizeMethod.Squish))
dls = data.dataloaders(path)
dls.train.show_batch(max_n=8, nrows=2)
Transform with Random Resized Crop:
# Variant: take a random crop covering at least 30% of the image each
# epoch; `unique=True` shows the same image under different crops.
data = data.new(item_tfms=RandomResizedCrop(128, min_scale=0.3))
dls = data.dataloaders(path)
dls.train.show_batch(max_n=8, nrows=2, unique=True)
Example of data augmentation using aug_transforms:
# Example: standard fastai batch augmentations (flip/rotate/zoom/warp...)
# exaggerated by mult=2 so their effect is easy to see.
data = data.new(item_tfms=Resize(128), batch_tfms=aug_transforms(mult=2))
dls = data.dataloaders(path)
dls.train.show_batch(max_n=8, nrows=2, unique=True)
Step 3: Train the model and clean some of the data by hand
# Final pipeline: random 224px crops (covering >= 50% of the image) per
# item plus default batch augmentations, then fine-tune a pretrained
# resnet18 for 4 epochs, tracking error rate on the validation set.
data = data.new(
    item_tfms=RandomResizedCrop(224, min_scale=0.5),
    batch_tfms=aug_transforms())
dls = data.dataloaders(path)

learn = vision_learner(dls, resnet18, metrics=error_rate)
learn.fine_tune(4)
/opt/conda/lib/python3.7/site-packages/torchvision/models/_utils.py:209: UserWarning: The parameter 'pretrained' is deprecated since 0.13 and may be removed in the future, please use 'weights' instead.
f"The parameter '{pretrained_param}' is deprecated since 0.13 and may be removed in the future, "
/opt/conda/lib/python3.7/site-packages/torchvision/models/_utils.py:223: UserWarning: Arguments other than a weight enum or `None` for 'weights' are deprecated since 0.13 and may be removed in the future. The current behavior is equivalent to passing `weights=ResNet18_Weights.IMAGENET1K_V1`. You can also use `weights=ResNet18_Weights.DEFAULT` to get the most up-to-date weights.
warnings.warn(msg)
Downloading: "https://download.pytorch.org/models/resnet18-f37072fd.pth" to /root/.cache/torch/hub/checkpoints/resnet18-f37072fd.pth
epoch | train_loss | valid_loss | error_rate | time |
---|---|---|---|---|
0 | 2.396378 | 1.538527 | 0.522936 | 00:24 |
epoch | train_loss | valid_loss | error_rate | time |
---|---|---|---|---|
0 | 1.393517 | 1.140603 | 0.440367 | 00:22 |
1 | 1.261823 | 0.996967 | 0.357798 | 00:23 |
2 | 1.069076 | 0.967244 | 0.339450 | 00:23 |
3 | 0.934776 | 0.965879 | 0.348624 | 00:23 |
To visualize the errors the model makes, we create a confusion matrix.
# Inspect where the model is confused: confusion matrix across the four
# classes, plus the 5 validation images with the highest loss.
interp = ClassificationInterpretation.from_learner(learn)
interp.plot_confusion_matrix()
interp.plot_top_losses(5, nrows=5)
Step 4: Turn the model into an online application
We save the architecture and the learned parameters of our model.
# Serialize the trained Learner (architecture + weights) to export.pkl
# in the current directory, then confirm the file exists.
learn.export()
path = Path()  # NOTE: rebinds `path` from the dataset dir to cwd
path.ls(file_exts='.pkl')
(#1) [Path('export.pkl')]
We can then create an inference learner from this exported file.
= load_learner(path/'export.pkl') learn_inf