Initial Commit (tested training, testing, and TRT conversion)
This commit is contained in:
71
flightpolicy/yopo/yopo_network.py
Normal file
71
flightpolicy/yopo/yopo_network.py
Normal file
@@ -0,0 +1,71 @@
|
||||
# The backbone and the custom gradient layer.
|
||||
import time
|
||||
import torch as th
|
||||
import torch.nn
|
||||
import numpy as np
|
||||
from torchvision.models import mobilenet_v3_small
|
||||
from flightpolicy.yopo.resnet import resnet18
|
||||
from torch.autograd import Function
|
||||
|
||||
|
||||
# 18ms, Fast and effective.
|
||||
class ResNet18(torch.nn.Module):
|
||||
def __init__(self, output_dim: int, primitive_shape: int):
|
||||
super(ResNet18, self).__init__()
|
||||
self.cnn = resnet18(pretrained=False)
|
||||
self.cnn.conv1 = th.nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)
|
||||
if (primitive_shape != 1):
|
||||
self.cnn.avgpool = th.nn.Sequential()
|
||||
self.cnn.fc = th.nn.Conv2d(512, output_dim, kernel_size=1, stride=1, padding=0, bias=False)
|
||||
self.features_dim = output_dim
|
||||
|
||||
def forward(self, depth: th.Tensor) -> th.Tensor:
|
||||
return self.cnn(depth)
|
||||
|
||||
|
||||
# 20ms, Performs worse than ResNet and is slower than ResNet-18.
|
||||
class MobileNet(th.nn.Module):
|
||||
def __init__(self, output_dim: int):
|
||||
super(MobileNet, self).__init__()
|
||||
self.cnn = mobilenet_v3_small(pretrained=False)
|
||||
self.cnn.features[0][0] = th.nn.Conv2d(1, 16, kernel_size=3, stride=1, padding=1, bias=False)
|
||||
self.cnn.classifier = th.nn.Linear(576, output_dim)
|
||||
self.features_dim = output_dim
|
||||
|
||||
def forward(self, depth: th.Tensor) -> th.Tensor:
|
||||
return self.cnn(depth)
|
||||
|
||||
|
||||
def YopoBackbone(output_dim, primitive_shape):
|
||||
return ResNet18(output_dim, primitive_shape)
|
||||
|
||||
|
||||
class CostAndGradLayer(Function):
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx, input_dp, train_env, primitive_id):
|
||||
# print("input ", input_dp.shape)
|
||||
device = input_dp.device
|
||||
cost, grad = train_env.getCostAndGradient(input_dp, primitive_id)
|
||||
grad = np.minimum(grad, 1.0) # Gradient clipping: Prevent excessively large values.
|
||||
cost = torch.tensor(cost).to(device)
|
||||
grad = torch.tensor(grad).to(device)
|
||||
ctx.save_for_backward(grad)
|
||||
cost.requires_grad = True
|
||||
return cost
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, cost_grad_input):
|
||||
grad, = ctx.saved_tensors
|
||||
return_grad = th.bmm(grad.unsqueeze(-1), cost_grad_input.unsqueeze(-1)).squeeze(dim=2)
|
||||
# print("grad ", return_grad.shape)
|
||||
# print("grad: ", return_grad)
|
||||
return return_grad, None, None
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
net = YopoBackbone(64, 3)
|
||||
input_ = torch.zeros((1, 1, 96, 96))
|
||||
start = time.time()
|
||||
output = net(input_)
|
||||
print(time.time() - start)
|
||||
Reference in New Issue
Block a user