from typing import List, Set

import torch


def sorted_list(s: Set[str]) -> List[str]:
    """Return the distinct elements of *s* as a list in ascending order.

    Duplicates are removed via ``set`` before sorting, so the result is
    de-duplicated even if a non-set iterable is passed in.
    """
    # sorted() accepts any iterable and always returns a new list, so no
    # intermediate list() copy is needed.
    return sorted(set(s))


def device():
    """Return the ``"cuda"`` torch device if CUDA is available, else ``"cpu"``."""
    if torch.cuda.is_available():
        return torch.device("cuda")
    return torch.device("cpu")

def nested_to_device(s):
    """Move a tensor, or the values of a dictionary, to the device from ``device()``.

    Args:
        s: Either a ``torch.Tensor`` or a dictionary-like object exposing
            ``.items()``. Each dictionary value must either provide a
            ``.to(device)`` method or itself expose ``.items()``.

    Returns:
        The result of ``s.to(...)`` when *s* is a tensor; otherwise a new plain
        ``dict`` with the same keys whose values have been moved. Values
        without a ``.to`` method are rebuilt recursively as plain dicts.
    """
    # Resolve the target once; it does not change between values.
    target = device()

    def _move(value):
        # Anything with ``.to`` is moved directly; anything else is treated
        # as a nested dictionary and rebuilt key by key.
        if hasattr(value, "to"):
            return value.to(target)
        return {k: _move(v) for k, v in value.items()}

    if isinstance(s, torch.Tensor):
        return s.to(target)
    return {k: _move(v) for k, v in s.items()}

def nested_apply(h, s):
    """Apply the unary function *h* to every string leaf of a nested container.

    A ``str`` is treated as a leaf and passed to *h*. Any other *s* is
    iterated and each element is processed recursively. A tuple yields a
    tuple and a set yields a set; every other iterable yields a list.
    """
    # Strings are the leaves: apply h directly instead of iterating them.
    if isinstance(s, str):
        return h(s)

    mapped = []
    for item in s:
        mapped.append(nested_apply(h, item))

    # Rebuild tuples and sets as their plain builtin type; everything else
    # stays a list.
    for container in (tuple, set):
        if isinstance(s, container):
            return container(mapped)
    return mapped