Coverage for application / tator / tator_rest_client.py: 23%
78 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-03-23 05:22 +0000
« prev ^ index » next coverage.py v7.13.5, created at 2026-03-23 05:22 +0000
1import base64
3import requests
5from application.tator.tator_type import TatorStateType
class TatorRestClient:
    """
    Thin wrapper around the Tator REST API. Handles auth headers and URL construction.
    Use instead of raw requests calls to avoid repeating boilerplate.
    """

    def __init__(self, tator_url: str, token: str):
        """
        Stores the server base URL and builds the headers sent with every request.

        tator_url is used as-is as the prefix for every endpoint path; token is sent
        as an 'Authorization: Token <token>' header.
        """
        self.base_url = tator_url
        self._headers = {
            'Content-Type': 'application/json',
            'Authorization': f'Token {token}',
        }

    @staticmethod
    def login(tator_url: str, username: str, password: str) -> str:
        """Returns a Tator API token for the given credentials, or raises HTTPError on failure."""
        res = requests.post(
            url=f'{tator_url}/rest/Token',
            headers={'Content-Type': 'application/json'},
            json={'username': username, 'password': password, 'refresh': True},
        )
        res.raise_for_status()
        return res.json()['token']

    def _get(self, path: str, params: dict = None):
        """
        Sends an authenticated GET to base_url + path and returns the response object.

        Query values are passed through `params` so that requests URL-encodes them,
        rather than being interpolated into the URL by hand. Raises HTTPError when
        the server answers with an error status.
        """
        res = requests.get(url=f'{self.base_url}{path}', headers=self._headers, params=params)
        res.raise_for_status()
        return res

    def get_localizations(self, project_id: int, section: str = None, media_id: list[int] = None) -> list:
        """
        Returns the localizations of a project, filtered by media IDs or by section.

        media_id takes precedence when both filters are given. Raises ValueError
        (before any request is made) when neither filter is provided.
        """
        if media_id is not None:
            params = {'media_id': ','.join(str(m) for m in media_id)}
        elif section is not None:
            params = {'section': section}
        else:
            raise ValueError('Must provide either section or media_id')
        return self._get(f'/rest/Localizations/{project_id}', params).json()

    def get_section_by_id(self, section_id: str) -> dict:
        """Returns the JSON body of GET /rest/Section/{section_id}."""
        return self._get(f'/rest/Section/{section_id}').json()

    def get_medias_for_section(self, project_id: int, section: str) -> list:
        """Returns the JSON body of GET /rest/Medias/{project_id} filtered by section."""
        return self._get(f'/rest/Medias/{project_id}', {'section': section}).json()

    def get_media_by_id(self, media_id: str) -> dict:
        """Returns the JSON body of GET /rest/Media/{media_id}."""
        return self._get(f'/rest/Media/{media_id}').json()

    def get_substrates_for_medias(self, project_id: int, transect_media: list[dict]) -> list[dict]:
        """
        Returns substrates grouped by media ID, in chronological (frame) order.

        Each item of transect_media must provide 'id' and 'fps'. The result is a list of
        {'media_id': ..., 'substrates': [...]} dicts; every substrate holds the state's
        attributes plus 'frame' and an 'MM:SS' 'timestamp' (None when the fps is unknown).
        """
        states = self._get_states(project_id, [str(media['id']) for media in transect_media])
        fps_map = {media['id']: media['fps'] for media in transect_media}
        substrate_states = [state for state in states if state['type'] == TatorStateType.SUBSTRATE]
        return self._group_substrates(substrate_states, fps_map)

    @classmethod
    def _group_substrates(cls, substrate_states: list[dict], fps_map: dict) -> list[dict]:
        """Groups substrate states by their first media ID and orders each group by frame."""
        grouped: dict[int, list] = {}
        for state in substrate_states:
            media_id = state['media'][0]
            # A missing, None or zero fps cannot be turned into a time offset.
            fps = fps_map.get(media_id)
            grouped.setdefault(media_id, []).append(
                {
                    **state['attributes'],
                    'timestamp': cls._format_timestamp(state['frame'] / fps) if fps else None,
                    'frame': state['frame'],
                }
            )
        for media_id, entries in grouped.items():
            # Sort on the numeric frame, not the formatted 'MM:SS' text: string order
            # puts '100:00' before '99:00'. Groups without a usable fps have no
            # timestamps and keep the order in which the API returned them.
            if fps_map.get(media_id):
                entries.sort(key=lambda entry: entry['frame'])
        return [{'media_id': media_id, 'substrates': entries} for media_id, entries in grouped.items()]

    def _get_states(self, project_id: int, media_ids: list[str]):
        """Returns the JSON body of GET /rest/States/{project_id} for the given media IDs."""
        return self._get(f'/rest/States/{project_id}', {'media_id': ','.join(media_ids)}).json()

    def get_user(self, user_id: int) -> dict:
        """Returns the JSON body of GET /rest/User/{user_id}."""
        return self._get(f'/rest/User/{user_id}').json()

    def get_frame(self, media_id: int, frame: int = None, quality: int = None) -> bytes:
        """
        Returns the raw response body of GET /rest/GetFrame/{media_id}.

        frame and quality are sent as the 'frames' and 'quality' query parameters
        only when they are not None.
        """
        params = {}
        if frame is not None:
            params['frames'] = frame
        if quality is not None:
            params['quality'] = quality
        # The body is already bytes; a base64 encode/decode round trip would only copy it.
        return self._get(f'/rest/GetFrame/{media_id}', params).content

    def get_localization_graphic(self, localization_id: int) -> bytes:
        """Returns the raw response body of GET /rest/LocalizationGraphic/{localization_id}."""
        return self._get(f'/rest/LocalizationGraphic/{localization_id}').content

    @staticmethod
    def _format_timestamp(seconds: float) -> str:
        """Formats a duration in seconds as 'MM:SS', rounded to the nearest second."""
        total = round(seconds)
        return f'{total // 60:02d}:{total % 60:02d}'