とある大学生の勉強メモ

Python, C#, UWP, WPF, 心理実験関連の開発備忘録

回転・鏡像反転用のデータ拡張用コード Train-Valid分けコード

自分用 画像Pathを指定するとデータ拡張してくれるコードと,
画像群フォルダのPathを指定するとValidとTrainの2つに分けてくれるコード.

下記を使うのが一般的かもしれないが........ albumentations.ai

from glob import glob
import random
import os
import shutil

#指定したデータセットPATHの中にTrainフォルダとValidフォルダを作り,VALID_Ratioで指定した量で分ける
def Split_Train_Valid(DatasetFolderPath, filetype = ".png", VALID_RATIO = 0.2):
    """
    Args:
        DatasetFolderPath (string): 
        filetype (str, optional): 画像の型. Defaults to ".png".
        VALID_RATIO (float, optional): 検証用のデータ数割合. Defaults to 0.2.
    """
    files = glob(os.path.join(DatasetFolderPath, "*" + filetype))
    random.shuffle(files)
    VALID_DATA_NUM = int(len(files) * VALID_RATIO)
    
    #make dir
    Train_Dir = os.path.join(DatasetFolderPath, "Train")
    Valid_Dir = os.path.join(DatasetFolderPath, "Valid")
    if os.path.exists(Train_Dir)==False:
        os.mkdir(Train_Dir)
    if os.path.exists(Valid_Dir)==False:
        os.mkdir(Valid_Dir)
    
    #copy file
    for i, file in enumerate(files):
        if i<=VALID_DATA_NUM:
            shutil.copy2(file, Valid_Dir)
        else:
            shutil.copy2(file, Train_Dir)
#Writer : Yu Yamaoka
#回転と鏡像で8倍水増し用のコード

import cv2

def Augment_ByRotationMirror(ImageFilePath, filetype = ".png"):
        #every 90 degree rotation * mirror = 4 times * 2times = 8times(MAX)
        img = cv2.imread(ImageFilePath)
        width, height, _ = img.shape
        
        #Mirror
        mirror_img = cv2.flip(img, 1)
        cv2.imwrite(ImageFilePath.replace(filetype,"") + "_mirror" + filetype, mirror_img)
        
        #If width and height is NOT equal, can't do aug by 90 and 270 rot.  
        img_180 = cv2.rotate(img, cv2.ROTATE_180)
        cv2.imwrite(ImageFilePath.replace(filetype,"")+"_180" + filetype, img_180)
        cv2.imwrite(ImageFilePath.replace(filetype,"") + "_180mirror.png", cv2.flip(img_180, 1))
        
        if(width == height):
                img_90 = cv2.rotate(img, cv2.ROTATE_90_CLOCKWISE)
                cv2.imwrite(ImageFilePath.replace(filetype,"")+"_90" + filetype, img_90)
                cv2.imwrite(ImageFilePath.replace(filetype,"") + "_90mirror" + filetype, cv2.flip(img_90, 1))

                img_270 = cv2.rotate(img, cv2.ROTATE_90_COUNTERCLOCKWISE)
                cv2.imwrite(ImageFilePath.replace(filetype,"")+"_270" + filetype, img_270)
                cv2.imwrite(ImageFilePath.replace(filetype,"") + "_270mirror" + filetype, cv2.flip(img_270, 1))