Source code for hawk_eye.core.classifier

"""A classifier model which wraps around a backbone. This setup allows for easy
interchangeability during experimentation and a reliable way to load saved models."""

import pathlib
from typing import Optional

import torch
import yaml

try:
    from hawk_eye.core import asset_manager
except ModuleNotFoundError:
    pass
from third_party.rexnet import rexnet
from third_party.vovnet import vovnet


[docs]class Classifier(torch.nn.Module): def __init__( self, num_classes: Optional[int] = 2, timestamp: Optional[str] = None, backbone: Optional[str] = None, half_precision: Optional[bool] = False, ) -> None: """ Args: num_classes: The number of classes to predict. timestamp: The timestamp of the model to download from GCloud. backbone: A string designating which model to load. half_precision: Whether to use half precision. This should be False for training but True during inference. :raises ValueError: Error if neither a timestamp or backbone arg is supplied. Examples: >>> classifier = Classifier(2, backbone="vovnet-19") >>> with torch.no_grad(): ... predictions = classifier.classify(torch.randn(1, 3, 64, 64), True) >>> predictions.shape torch.Size([1, 2]) """ super().__init__() self.num_classes = num_classes self.use_cuda = torch.cuda.is_available() self.half_precision = half_precision and self.use_cuda if backbone is None and timestamp is None: raise ValueError("Must supply either model timestamp or backbone to load") # If a version is given, download it. if timestamp is not None: # For the distributed pip package, look inside `production_models` production_models = pathlib.Path(__file__).parent / "production_models" if production_models.is_dir(): model_path = production_models / timestamp else: # Download the model or find it locally. model_path = asset_manager.download_model("classifier", timestamp) config = yaml.safe_load((model_path / "config.yaml").read_text())["model"] backbone = config.get("backbone", None) # Construct the model, then load the state self.model = self.load_backbone(backbone) self.load_state_dict( torch.load(model_path / "classifier.pt", map_location="cpu") ) self.image_size = config["image_size"] else: # If no timestamp supplied, just load the backbone self.model = self.load_backbone(backbone) self.model.eval() if self.use_cuda: self.model.cuda() if self.half_precision: self.model.half()
[docs] def forward(self, x: torch.Tensor) -> torch.Tensor: """Forward pass through classifier. Args: x: input tensor. Returns: the output tensor. """ # If half precision, assume inference not train. if self.half_precision: x = x.half() return self.model(x)
[docs] def load_backbone(self, backbone: str) -> torch.nn.Module: """Load the supplied backbone. See this function for the list of potential backbones that can be loaded. Args: backbone: The backbone type to load. Returns: The loaded model. :raises ValueError: If improper backbone is supplied. """ if "rexnet" in backbone: model = rexnet.ReXNet(num_classes=self.num_classes, model_type=backbone) elif "vovnet" in backbone: model = vovnet.VoVNet(model_name=backbone, num_classes=self.num_classes) else: raise ValueError(f"Unsupported backbone {backbone}.") return model
[docs] def classify(self, x: torch.Tensor, probability: bool = False) -> torch.Tensor: """Take in an image batch and return the class for each image. If specified, softmax will be applied to the predictions. Args: x: Input tensor of size (batch, height, width, channels). probability: Whether or not to apply softmax. Returns: The output tensor. """ if self.use_cuda and self.half_precision: x = x.half() if probability: return torch.nn.functional.softmax(self.model(x), dim=1) else: _, predicted = torch.max(self.model(x).data, 1) return predicted