17 lines
434 B
Python
17 lines
434 B
Python
|
import torch
|
||
|
import argparse
|
||
|
|
||
|
|
||
|
def get_loss_fn(loss_args):
|
||
|
loss_args_ = loss_args
|
||
|
if isinstance(loss_args, argparse.Namespace):
|
||
|
loss_args_ = vars(loss_args)
|
||
|
loss_fn = loss_args_.get("loss_fn")
|
||
|
|
||
|
if loss_fn == "BCE":
|
||
|
return torch.nn.BCEWithLogitsLoss()
|
||
|
elif loss_fn == "CE":
|
||
|
return torch.nn.CrossEntropyLoss()
|
||
|
else:
|
||
|
raise ValueError(f"loss_fn {loss_args.loss_fn} not supported.")
|