- fully-connected layer とは
- pre-trained model の fully-connected layer を変えるとは
- foo.to(device) とは
- 画像の表示
- learning rate とは
- module とは (& その書き方)
- ループごとに画像を表示する
- RGB to grayscale
- pyplot に四角を描画
- random integer list
- tip: 画像ファイルを削除したら annotation ファイルも更新する
- モデルの保存
ChatGPT とかに相談しながらのめも
fully-connected layer とは
特別な Linear layer で、他の layer のすべての出力を入力として受け取る。 Linear layer とは
output = input x W^T + b
で計算される layer (W: weight matrix, b: bias vector, ^T: denotes the transpose of the matrix)。W と b は学習中に計算される。
pre-trained model の fully-connected layer を変えるとは
# Load a pre-trained ResNet-18 model model = models.resnet18(pretrained=True) # Replace the classifier to match the output for bounding box regression (4 coordinates) model.fc = torch.nn.Linear(model.fc.in_features, 4) # x, y, width, height for the box
ResNet はもともと画像の class を出力するモデルだけど、最後の fully-connected layer の形を変えることで違うタスクに使うことができる。FC layer は前述の通り Linear layer なので、 torch.nn.Linear
を使えば OK。
ChatGPT が作ってくれた図:
[Previous Network Layers] ---> [Flatten] ---> [Old Fully Connected Layer] | | (features extracted by the network) V [New Fully Connected Layer] ---> [4 Outputs] (model.fc.in_features, 4)
foo.to(device)
とは
context:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') model.to(device) # ... for epoch in range(num_epochs): model.train() train_loss = 0.0 for images, boxes in train_dataloader: images = images.to(device) boxes = boxes.to(device).float() # ...
PyTorch だと tensor の計算は同じ device で実行しないといけない。tensor A x tensor B をするとき、A は CPU にあって B は GPU にある... というのはできないので、 images.to(device)
で適切な device に移動させてる。
画像の表示
read_image
の出力は [channel, height, width]
なので、pyplot でカラー画像を表示させるときには .permute()
で [height, width, channel]
になるようにする。グレースケールの方は、.squeeze()
で size = 1 の次元を落とす = channel の次元を消す。
https://pytorch.org/vision/main/generated/torchvision.io.read_image.html
from torchvision.io import read_image import matplotlib.pyplot as plt i = read_image(im_path) # color image plt.imshow(i.permute(1, 2, 0)) # grayscale image plt.imshow(i.squeeze(), cmap="gray")
learning rate とは
ここがイメージしやすい
https://developers.google.com/machine-learning/crash-course/fitter/graph
module とは (& その書き方)
module は PyTorch の layer 。たとえば、 pre-trained モデルの fully-connected layer をカスタムしたければ下記のような torch.nn.Module
を継承したクラスを書く。
class CustomHead(torch.nn.Module): def __init__(self, in_features): super().__init__() self.fc = torch.nn.Linear(in_features, 4) def forward(self, x): # use sigmoid as an activation function return torch.sigmoid(self.fc(x)) model.fc = CustomHead(model.fc.in_features)
see also :
https://pytorch.org/docs/stable/notes/modules.html#a-simple-custom-module
ループごとに画像を表示する
plt.show() をループごとの呼んであげれば最後の画像しか表示されない、とはならない
import matplotlib.pyplot as plt from torchvision.io import read_image for filename in os.listdir(MY_DIR): im_path = os.path.join(MY_DIR, filename) im = read_image(im_path) plt.imshow(im) plt.show()
RGB to grayscale
(Grayscale は GrayScale なのか Grayscale なのか..)
Pad でモノクロ画像に対して fill の色を設定するときは、fill の色をグレースケールに変えてから指定する必要があって、モノクロに transform するのにRGB の tuple (r, g, b)
でやろうとしたら怒られる (transform する前だったのにエラーになったのでそこは謎) 。
RGB → grayscale は RGB それぞれの色の「明るさ度合い」を示す係数をかけて計算する。記事によって微妙に係数が違うけど、いったん ChatGPT に教えてもらったものが下記。
def rgb_to_grayscale(rgb): r, g, b = rgb int(0.299 * r + 0.587 * g + 0.114 * b)
pyplot に四角を描画
import matplotlib.pyplot as plt import matplotlib.patches as patches fig, ax = plt.subplots(1) # box = [x, y, w, h] rect = patches.Rectangle((box[0], box[1]), box[2], box[3], linewidth=2, edgecolor='r', facecolor='none') ax.add_patch(rect)
https://matplotlib.org/stable/api/_as_gen/matplotlib.patches.Rectangle.html
random integer list
import numpy as np dataset = ... # define your dataset # select random 10 indices for your dataset indices = np.random.choice(dataset.__len__, 10, replace=False)
tip: 画像ファイルを削除したら annotation ファイルも更新する
こういう dataset class で
class MyDataset(Dataset): def __init__(self, annotations_file, img_dir, transform=None): self.img_dir = img_dir self.img_labels = pd.read_csv(annotations_file) # each row is [/path/to/image, x, y, w, h] self.transform = transform def __len__(self): return len(self.img_labels) def __getitem__(self, idx): # other stuff
「この画像やっぱ微妙だからはずそう」ってなったとき、 annotation ファイル (ラベルを付与している CSVなどのファイル) の方からも画像の行を削除しないと、 "file not found error" で怒られる。
モデルの保存
PyTorch では、model の weights はこういう感じでレイヤーごとの tensor として保存されている。
Model's state_dict: conv1.weight torch.Size([6, 3, 5, 5]) conv1.bias torch.Size([6]) conv2.weight torch.Size([16, 6, 5, 5]) conv2.bias torch.Size([16]) fc1.weight torch.Size([120, 400]) fc1.bias torch.Size([120]) fc2.weight torch.Size([84, 120]) fc2.bias torch.Size([84]) fc3.weight torch.Size([10, 84]) fc3.bias torch.Size([10])
で、これが model の state_dict
という変数に保存されているので、「モデルを保存したい」というときは、この dictionary を読み書きすれば OK。保存するときは .pt
か .pth
をつけるのが慣習らしい。
A state_dict is simply a Python dictionary object that maps each layer to its parameter tensor.
A common PyTorch convention is to save models using either a .pt or .pth file extension.
PATH = "my_model.pt" # Save: torch.save(model.state_dict(), PATH) # Load: model = TheModelClass(*args, **kwargs) model.load_state_dict(torch.load(PATH)) model.eval()
※コード等は以下より引用