IT博客汇

    Use bits instead of a set for visited nodes (LeetCode #1434)

    RobinDong posted on 2023-03-22 03:17:19

    My first idea was a depth-first search: iterate over all people and try to give each of them a different hat. That solution got TLE (Time Limit Exceeded). Then, following a hint from the discussion forum, I switched to iterating over hats instead of people (hat labels go up to 40) and trying to give each hat to a different person. That solution also got TLE, even though I used lru_cache on the recursive function:

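    import functools
    from collections import defaultdict
    from typing import List
    
    class Solution:
    
        def numberWays(self, hats: List[List[int]]) -> int:
            # Map each hat to the set of people who like it.
            hp = defaultdict(set)
            for index, hat in enumerate(hats):
                for _id in hat:
                    hp[_id].add(index)
    
            hp = [people for people in hp.values()]
    
            # path: tuple of people who already have a hat, in the order they got one.
            @functools.lru_cache(None)
            def dfs(start, path) -> int:
                if len(path) == len(hats):
                    return 1  # everyone has a hat
                if start == len(hp):
                    return 0  # ran out of hats
                total = 0
                # Give hat `start` to someone who likes it and has no hat yet...
                for person in (hp[start] - set(path)):
                    total += dfs(start + 1, tuple(list(path) + [person]))
                # ...or give it to nobody.
                total += dfs(start + 1, path)
                return total % (10**9 + 7)
    
            return dfs(0, tuple())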

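    A tuple is not a good way to record the visited people here. Every call builds a new tuple and a new set, and, worse, the tuple records the order in which people received their hats, so lru_cache treats (0, 2) and (2, 0) as different states even though they cover the same people, and it cannot reuse the result. Since there are no more than 10 people, the most efficient way to record them is bits: bit i of an integer mask is set once person i has a hat, so every state is a single integer below 2**10 = 1024, and the same set of people always maps to the same cache key.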
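A small illustration of the point above (not from the original post; the helper `to_mask` is hypothetical): two paths holding the same people in a different order are different tuples, and therefore different cache keys, while their bitmasks are equal. It also shows the bit operations the final solution relies on: setting a bit, testing a bit, and comparing against the all-ones mask.

```python
def to_mask(people):
    """Encode person indices (each below 10) as a bitmask: bit i set means person i has a hat."""
    mask = 0
    for person in people:
        mask |= 1 << person  # set bit `person`
    return mask

# Two DFS paths that reach the same set of people {0, 2} in a different order.
path_a = (0, 2)  # hat 0 -> person 0, then hat 1 -> person 2
path_b = (2, 0)  # hat 0 -> person 2, then hat 1 -> person 0

print(path_a == path_b)                    # False: two different lru_cache keys
print(to_mask(path_a) == to_mask(path_b))  # True: one key, 0b101 == 5

mask = to_mask(path_a)
print(bool(mask & (1 << 2)))  # True: person 2 already has a hat
print(bool(mask & (1 << 1)))  # False: person 1 is still free

n = 3  # suppose there are three people
print((mask | (1 << 1)) == (1 << n) - 1)  # True: giving person 1 a hat covers everyone
```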

    My final solution still uses DFS (thanks to lru_cache, it is also top-down dynamic programming):

    import functools
    from collections import defaultdict
    from typing import List
    
    class Solution:
    
        def numberWays(self, hats: List[List[int]]) -> int:
            # Map each hat to the set of people who like it.
            hp = defaultdict(set)
            for index, hat in enumerate(hats):
                for _id in hat:
                    hp[_id].add(index)
    
            hp = list(hp.values())
            full = (1 << len(hats)) - 1  # mask with every person's bit set
    
            # mask: bit i is set when person i already wears a hat.
            @functools.lru_cache(None)
            def dfs(start, mask) -> int:
                if mask == full:
                    return 1  # everyone has a hat
                if start == len(hp):
                    return 0  # ran out of hats
                total = 0
                for person in hp[start]:
                    if mask & (1 << person):
                        continue  # this person already has a hat
                    total += dfs(start + 1, mask | (1 << person))
                total += dfs(start + 1, mask)  # leave hat `start` unused
                return total % (10**9 + 7)
    
            return dfs(0, 0)
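Because the memoized DFS is dynamic programming over (hat index, mask), the same recurrence can also be run bottom-up without recursion. This is a minimal sketch of that alternative formulation, not the author's code; the function name `number_ways_bottom_up` and the `MOD` constant are illustrative, and it assumes the LeetCode constraints (at most 10 people, each liking at least one hat).

```python
from typing import List

MOD = 10**9 + 7

def number_ways_bottom_up(hats: List[List[int]]) -> int:
    n = len(hats)
    # Map each hat label to the list of people who like it.
    likes = {}
    for person, prefs in enumerate(hats):
        for h in prefs:
            likes.setdefault(h, []).append(person)

    # dp[mask] = number of ways the hats processed so far can be given to
    # exactly the people in `mask`, each person getting a different hat.
    dp = [0] * (1 << n)
    dp[0] = 1
    for people in likes.values():        # process one hat at a time
        new_dp = dp[:]                   # option 1: nobody wears this hat
        for mask in range(1 << n):
            if dp[mask] == 0:
                continue
            for person in people:        # option 2: a person without a hat wears it
                if not mask & (1 << person):
                    nxt = mask | (1 << person)
                    new_dp[nxt] = (new_dp[nxt] + dp[mask]) % MOD
        dp = new_dp
    return dp[(1 << n) - 1]
```

For example, `number_ways_bottom_up([[3, 4], [4, 5], [5]])` returns 1. The table has only 2**n entries and each hat is processed once, so the work is bounded by roughly 40 × 1024 × 10 updates, the same state space the lru_cache version explores.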

