surveilling-surveillance/detection/eval/loss.py
2021-05-20 13:22:04 -07:00

16 lines
434 B
Python

import torch
import argparse
def get_loss_fn(loss_args):
    """Build the loss module selected by the ``loss_fn`` option.

    Args:
        loss_args: Either an ``argparse.Namespace`` or a mapping (anything
            with a ``.get`` method) carrying a ``"loss_fn"`` entry.
            Supported values:
              * ``"BCE"`` -> ``torch.nn.BCEWithLogitsLoss()``
              * ``"CE"``  -> ``torch.nn.CrossEntropyLoss()``

    Returns:
        A newly constructed ``torch.nn`` loss module with default settings.

    Raises:
        ValueError: If ``loss_fn`` is missing or names an unsupported loss.
    """
    # Normalise a Namespace to a dict so both input kinds share the same
    # lookup path below.
    loss_args_ = (
        vars(loss_args) if isinstance(loss_args, argparse.Namespace) else loss_args
    )
    loss_fn = loss_args_.get("loss_fn")
    if loss_fn == "BCE":
        return torch.nn.BCEWithLogitsLoss()
    if loss_fn == "CE":
        return torch.nn.CrossEntropyLoss()
    # Report the resolved value rather than re-reading `loss_args.loss_fn`:
    # attribute access fails with AttributeError for dict inputs and for
    # Namespaces lacking the option, masking the intended ValueError.
    raise ValueError(f"loss_fn {loss_fn} not supported.")