python + opencvで個人認識 その2
前回は自分の顔をopencvで保存しました。
今回はcnnを使う前にデータセットの自作とスクレイピングで画像を集めたいと思います。
データセットの自作ではtensorflow等で取り扱えるような形にしたいと思います。今回はデータの水増し等はやらないですが、そのうち水増しのコードも書きたいと思います。
スクレイピングでは自分の顔以外の適当な画像データを集めます。学習データが自分の顔だけより、他の画像データを持ってくることでcnnの汎化性能が上がる気がしたからです。まあ、個人認識の論文とか読んでないので分からないですので間違ってたら教えてくれると嬉しいです。
まずはデータセットの自作からです。こちらのサイトを参考にしました。
qiita.com
各クラスのディレクトリを収納した親ディレクトリのパスとクラス数と画像のサイズを指定するとデータセットを作成します。
以下ソースコードになります。
import cv2 import os import random import numpy as np class DataRead: def write_pathlabel(self, path,pathfile): dir_list = os.listdir(path) with open(pathfile,"w") as file: for y,dir_name in enumerate(dir_list): files = os.listdir(path+"\\"+dir_name) f_list = [path + "\\" + dir_name + "\\" + f + " " + str(y) + "\n" for f in files if os.path.isfile(os.path.join(path+"\\"+dir_name,f))] file.writelines(f_list) def read_data(self, pathfile): with open(pathfile,"r") as file: self.PATH_AND_LABEL = [(line.rstrip()).split() for line in file] random.shuffle(self.PATH_AND_LABEL) def import_data(self, pathfile, imgshapes, numclass): self.read_data(pathfile) self.DATA_SET = [] for path_label in self.PATH_AND_LABEL: if len(path_label) !=2: continue img = cv2.imread(path_label[0]) img = cv2.resize(img,imgshapes) img = img.flatten().astype(np.float32)/255.0 label_ary = np.zeros(numclass, dtype = 'float64') label_ary[int(path_label[1])] = 1 self.DATA_SET.append([img,label_ary]) def get_data(self): return self.DATA_SET if __name__ == "__main__": dr = DataRead() dr.write_pathlabel('test','path.txt') dr.import_data('path.txt',(64,64),2) print(dr.get_data())
次にスクレイピングで画像を取得していきます。BeautifulSoup4とrequestsを用いています。
ページ内のimgタグを取得してきて、requestsを用いて画像をダウンロードします。
import requests import urllib import os from bs4 import BeautifulSoup class Scraper: def get_imglinks(self,url): self.imgs = [] res = requests.get(url) content = res.content soup = BeautifulSoup(content, 'html.parser') links = soup.find_all("img") img_links = [link.get("src") for link in links] return img_links def download(self,keyword,urls): if os.path.exists(keyword) == False: os.mkdir(keyword) for i,url in enumerate(urls): fn,ext = os.path.splitext(url) res = requests.get(url, allow_redirects=False) if res.status_code != 200: print(res.status_code) continue if 'image' not in res.headers["content-type"]: print(res.headers["content-type"]) continue with open(keyword + "/" + str(i) + ext, "wb" ) as f: f.write(res.content) if __name__ == "__main__": scraper = Scraper() links = scraper.get_imglinks('https://www.google.co.jp/search?q=rwby&source=lnms&tbm=isch&sa=X&ved=0ahUKEwivruumkM3aAhXDqJQKHZixBLUQ_AUICygC&biw=1812&bih=954') scraper.download('rwby',links)
これでデータセットの作成と学習用のデータを集められるようになったので次回こそ学習させる予定です。