Coverage for application / tator / tator_rest_client.py: 23%

78 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-03-23 05:22 +0000

1import base64 

2 

3import requests 

4 

5from application.tator.tator_type import TatorStateType 

6 

7 

class TatorRestClient:
    """
    Thin wrapper around the Tator REST API. Handles auth headers and URL construction.
    Use instead of raw requests calls to avoid repeating boilerplate.

    Every request method raises ``requests.HTTPError`` when the server answers with an
    error status (via ``raise_for_status``).
    """

    def __init__(self, tator_url: str, token: str, timeout: float = 60):
        """
        :param tator_url: Base URL of the Tator server. A trailing slash is stripped so that
            joined request paths never contain ``//``.
        :param token: API token (e.g. from :meth:`login`), sent as ``Authorization: Token <token>``.
        :param timeout: Seconds ``requests`` waits on the server before giving up; ``None``
            waits indefinitely (the previous behavior).
        """
        self.base_url = tator_url.rstrip('/')
        self.timeout = timeout
        self._headers = {
            'Content-Type': 'application/json',
            'Authorization': f'Token {token}',
        }

    @staticmethod
    def login(tator_url: str, username: str, password: str, timeout: float = 60) -> str:
        """
        Returns a Tator API token for the given credentials, or raises HTTPError on failure.

        :param timeout: Seconds to wait on the server; ``None`` waits indefinitely.
        """
        base_url = tator_url.rstrip('/')
        res = requests.post(
            url=f'{base_url}/rest/Token',
            headers={'Content-Type': 'application/json'},
            json={'username': username, 'password': password, 'refresh': True},
            timeout=timeout,
        )
        res.raise_for_status()
        return res.json()['token']

    def _get(self, path: str, params: dict = None):
        """
        Sends an authenticated GET to ``{base_url}/rest/{path}`` and returns the response.

        Query values go through ``params`` so ``requests`` URL-encodes them.
        Raises ``requests.HTTPError`` on an error status.
        """
        res = requests.get(
            url=f'{self.base_url}/rest/{path}',
            headers=self._headers,
            params=params,
            timeout=self.timeout,
        )
        res.raise_for_status()
        return res

    def get_localizations(self, project_id: int, section: str = None, media_id: list[int] = None) -> list:
        """
        Returns the project's localizations filtered either by media or by section.

        :param section: Section filter; only used when ``media_id`` is ``None``.
        :param media_id: Media IDs to filter by; takes precedence over ``section``.
            An empty list returns ``[]`` without contacting the server.
        :raises ValueError: If neither ``section`` nor ``media_id`` is provided.
        """
        if media_id is not None:
            if not media_id:
                # No media means no localizations. Do not rely on how the server treats an
                # empty ``media_id=`` filter.
                return []
            params = {'media_id': ','.join(str(m) for m in media_id)}
        elif section is not None:
            params = {'section': section}
        else:
            raise ValueError('Must provide either section or media_id')
        return self._get(f'Localizations/{project_id}', params).json()

    def get_section_by_id(self, section_id: str) -> dict:
        """Returns the decoded JSON for ``/rest/Section/{section_id}``."""
        return self._get(f'Section/{section_id}').json()

    def get_medias_for_section(self, project_id: int, section: str) -> list:
        """Returns the project's medias filtered by the ``section`` query parameter."""
        return self._get(f'Medias/{project_id}', {'section': section}).json()

    def get_media_by_id(self, media_id: str) -> dict:
        """Returns the decoded JSON for ``/rest/Media/{media_id}``."""
        return self._get(f'Media/{media_id}').json()

    def get_substrates_for_medias(self, project_id: int, transect_media: list[dict]) -> list[dict]:
        """
        Returns substrate states grouped by media ID, each group sorted chronologically.

        :param transect_media: Media dicts. Each must have an ``'id'`` and may have an ``'fps'``,
            which is used to turn a state's frame into an ``MM:SS`` timestamp.
        :return: ``[{'media_id': ..., 'substrates': [...]}, ...]``, only for medias with at least
            one substrate state. Each substrate is the state's attributes plus ``'timestamp'``
            (``None`` when the media's fps or the state's frame is missing) and ``'frame'``.
            Within a group, entries are ordered by frame, and frameless entries come last.
        """
        if not transect_media:
            return []
        states = self._get_states(project_id, [str(media['id']) for media in transect_media])
        # ``.get`` because some medias (e.g. images) may carry no usable fps.
        fps_map = {media['id']: media.get('fps') for media in transect_media}
        grouped: dict[int, list] = {}
        for state in states:
            if state['type'] != TatorStateType.SUBSTRATE:
                continue
            media_id = state['media'][0]
            fps = fps_map.get(media_id)
            frame = state['frame']
            # A falsy fps (missing/None/0) would fail or be meaningless as a divisor.
            timestamp = self._format_timestamp(frame / fps) if fps and frame is not None else None
            grouped.setdefault(media_id, []).append(
                {
                    **state['attributes'],
                    'timestamp': timestamp,
                    'frame': frame,
                }
            )
        for entries in grouped.values():
            # Sort on the frame number rather than the 'MM:SS' string. Strings mis-order once
            # minutes exceed two digits ('100:00' < '20:00'), and rounding loses sub-second order.
            entries.sort(key=lambda entry: (entry['frame'] is None, entry['frame'] or 0))
        return [{'media_id': media_id, 'substrates': entries} for media_id, entries in grouped.items()]

    def _get_states(self, project_id: int, media_ids: list[str]) -> list:
        """Returns the project's states filtered by ``media_ids`` (sent comma-separated as ``media_id``)."""
        return self._get(f'States/{project_id}', {'media_id': ','.join(media_ids)}).json()

    def get_user(self, user_id: int) -> dict:
        """Returns the decoded JSON for ``/rest/User/{user_id}``."""
        return self._get(f'User/{user_id}').json()

    def get_frame(self, media_id: int, frame: int = None, quality: int = None) -> bytes:
        """
        Returns the raw response body of ``/rest/GetFrame/{media_id}`` (the frame image).

        :param frame: Sent as the ``frames`` query parameter; omitted when ``None``.
        :param quality: Sent as the ``quality`` query parameter; omitted when ``None``.
        """
        params = {}
        if frame is not None:
            params['frames'] = frame
        if quality is not None:
            params['quality'] = quality
        # The body is already bytes. No base64 encode/decode round trip is needed.
        return self._get(f'GetFrame/{media_id}', params).content

    def get_localization_graphic(self, localization_id: int) -> bytes:
        """Returns the raw response body of ``/rest/LocalizationGraphic/{localization_id}``."""
        return self._get(f'LocalizationGraphic/{localization_id}').content

    @staticmethod
    def _format_timestamp(seconds: float) -> str:
        """
        Formats seconds as ``MM:SS``.

        The value is rounded with Python's ``round``, so exact halves go to the even second.
        Minutes are zero-padded to two digits but can grow wider (6000 s -> ``'100:00'``).
        """
        total = round(seconds)
        return f'{total // 60:02d}:{total % 60:02d}'