Skip to content

Commit c89da30

Browse files
committed
For #10692: Option to use add/remove/set modes when updating multi-entity fields (#111)
This update will allow for updating multi-entity fields without having to pull down the field's entire dataset first. It uses the new parameter multi_entity_update_modes on the update method to specify one of 'add', 'remove', or 'set'.
1 parent ac72b7f commit c89da30

File tree

4 files changed

+116
-18
lines changed

4 files changed

+116
-18
lines changed

shotgun_api3/shotgun.py

Lines changed: 42 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -843,7 +843,7 @@ def create(self, entity_type, data, return_fields=None):
843843

844844
return result
845845

846-
def update(self, entity_type, entity_id, data):
846+
def update(self, entity_type, entity_id, data, multi_entity_update_modes=None):
847847
"""Updates the specified entity with the supplied data.
848848
849849
:param entity_type: Required, entity type (string) to update.
@@ -852,6 +852,16 @@ def update(self, entity_type, entity_id, data):
852852
853853
:param data: Required, dict fields to update on the entity.
854854
855+
:param multi_entity_update_modes: Optional, dict of what update mode to
856+
use when updating a multi-entity link field. The keys in the dict are
857+
the fields to set the mode for and the values from the dict are one
858+
of "set", "add", or "remove". The default behavior if mode is not
859+
specified for a field is 'set'. For example, on the 'Sequence' entity,
860+
to append to the 'shots' field and remove from the 'assets' field, you
861+
would specify:
862+
863+
multi_entity_update_modes={"shots":"add", "assets":"remove"}
864+
855865
:returns: dict of the fields updated, with the entity_type and
856866
id added.
857867
"""
@@ -871,7 +881,10 @@ def update(self, entity_type, entity_id, data):
871881
params = {
872882
"type" : entity_type,
873883
"id" : entity_id,
874-
"fields" : self._dict_to_list(data)
884+
"fields" : self._dict_to_list(
885+
data,
886+
extra_data=self._dict_to_extra_data(
887+
multi_entity_update_modes, "multi_entity_update_mode"))
875888
}
876889
record = self._call_rpc("update", params)
877890
result = self._parse_records(record)[0]
@@ -941,7 +954,7 @@ def batch(self, requests):
941954
:param requests: A list of dict's of the form which have a
942955
request_type key and also specifies:
943956
- create: entity_type, data dict of fields to set
944-
- update: entity_type, entity_id, data dict of fields to set
957+
- update: entity_type, entity_id, data dict of fields to set, optionally multi_entity_update_modes
945958
- delete: entity_type and entity_id
946959
947960
:returns: A list of values for each operation, create and update
@@ -982,7 +995,12 @@ def _required_keys(message, required_keys, data):
982995
['entity_id', 'data'],
983996
req)
984997
request_params['id'] = req['entity_id']
985-
request_params['fields'] = self._dict_to_list(req["data"])
998+
request_params['fields'] = self._dict_to_list(req["data"],
999+
extra_data=self._dict_to_extra_data(
1000+
req.get("multi_entity_update_modes"),
1001+
"multi_entity_update_mode"))
1002+
if "multi_entity_update_mode" in req:
1003+
request_params['multi_entity_update_mode'] = req["multi_entity_update_mode"]
9861004
elif req["request_type"] == "delete":
9871005
_required_keys("Batched delete request", ['entity_id'], req)
9881006
request_params['id'] = req['entity_id']
@@ -2573,18 +2591,30 @@ def _build_thumb_url(self, entity_type, entity_id):
25732591
# Comments in prev version said we can get this sometimes.
25742592
raise RuntimeError("Unknown code %s %s" % (code, thumb_url))
25752593

2576-
def _dict_to_list(self, d, key_name="field_name", value_name="value"):
2594+
def _dict_to_list(self, d, key_name="field_name", value_name="value", extra_data=None):
25772595
"""Utility function to convert a dict into a list dicts using the
25782596
key_name and value_name keys.
25792597
2580-
e.g. d {'foo' : 'bar'} changed to [{'field_name':'foo, 'value':'bar'}]
2581-
"""
2582-
2583-
return [
2584-
{key_name : k, value_name : v }
2585-
for k, v in (d or {}).iteritems()
2586-
]
2598+
e.g. d {'foo' : 'bar'} changed to [{'field_name':'foo', 'value':'bar'}]
25872599
2600+
Any dictionary passed in via extra_data will be merged into the resulting dictionary.
2601+
e.g. d as above and extra_data of {'foo': {'thing1': 'value1'}} changes into
2602+
[{'field_name': 'foo', 'value': 'bar', 'thing1': 'value1'}]
2603+
"""
2604+
ret = []
2605+
for k, v in (d or {}).iteritems():
2606+
d = {key_name: k, value_name: v}
2607+
d.update((extra_data or {}).get(k, {}))
2608+
ret.append(d)
2609+
return ret
2610+
2611+
def _dict_to_extra_data(self, d, key_name="value"):
2612+
"""Utility function to convert a dict into a dict compatible with the extra_data arg
2613+
of _dict_to_list
2614+
2615+
e.g. d {'foo' : 'bar'} changed to {'foo': {"value": 'bar'}]
2616+
"""
2617+
return dict([(k, {key_name: v}) for (k,v) in (d or {}).iteritems()])
25882618

25892619
# Helpers from the previous API, left as is.
25902620

tests/base.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ def __init__(self, *args, **kws):
2626
self.asset = None
2727
self.version = None
2828
self.note = None
29+
self.playlist = None
2930
self.task = None
3031
self.ticket = None
3132
self.human_password = None
@@ -180,9 +181,12 @@ def _setup_mock_data(self):
180181
self.version = { 'id':5,
181182
'code':self.config.version_code,
182183
'type':'Version' }
183-
self.ticket = { 'id':6,
184+
self.ticket = { 'id':6,
184185
'title':self.config.ticket_title,
185186
'type':'Ticket' }
187+
self.playlist = { 'id':7,
188+
'code':self.config.playlist_code,
189+
'type':'Playlist'}
186190

187191
class LiveTestBase(TestBase):
188192
'''Test base for tests relying on connection to server.'''
@@ -236,6 +240,11 @@ def _setup_db(self, config):
236240
'content':'anything'}
237241
self.note = _find_or_create_entity(self.sg, 'Note', data, keys)
238242

243+
keys = ['code','project']
244+
data = {'project':self.project,
245+
'code':self.config.playlist_code}
246+
self.playlist = _find_or_create_entity(self.sg, 'Playlist', data, keys)
247+
239248
keys = ['code', 'entity_type']
240249
data = {'code': 'wrapper test step',
241250
'entity_type': 'Shot'}
@@ -262,7 +271,6 @@ def _setup_db(self, config):
262271
'mac_path':'nowhere',
263272
'windows_path':'nowhere',
264273
'linux_path':'nowhere'}
265-
266274
self.local_storage = _find_or_create_entity(self.sg, 'LocalStorage', data, keys)
267275

268276

tests/example_config

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,4 +23,5 @@ asset_code : Sg unittest asset
2323
version_code : Sg unittest version
2424
shot_code : Sg unittest shot
2525
task_content : Sg unittest task
26-
ticket_title : Sg unittest ticket
26+
ticket_title : Sg unittest ticket
27+
playlist_code : Sg unittest playlist

tests/test_api.py

Lines changed: 62 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -639,7 +639,7 @@ class TestDataTypes(base.LiveTestBase):
639639
'''Test fields representing the different data types mapped on the server side.
640640
641641
Untested data types: password, percent, pivot_column, serializable, image, currency
642-
multi_entity, system_task_type, timecode, url, uuid, url_template
642+
system_task_type, timecode, url, uuid, url_template
643643
'''
644644
def setUp(self):
645645
super(TestDataTypes, self).setUp()
@@ -739,6 +739,61 @@ def test_set_list(self):
739739
self.assertEqual(expected, actual)
740740

741741

742+
def test_set_multi_entity(self):
743+
sg = shotgun_api3.Shotgun( self.config.server_url,
744+
self.config.script_name,
745+
self.config.api_key )
746+
keys = ['project','user','code']
747+
data = {'project':self.project,
748+
'user':self.human_user,
749+
'code':'Alpha'}
750+
version_1 = base._find_or_create_entity(sg, 'Version', data, keys)
751+
data = {'project':self.project,
752+
'user':self.human_user,
753+
'code':'Beta'}
754+
version_2 = base._find_or_create_entity(sg, 'Version', data, keys)
755+
756+
entity = 'Playlist'
757+
entity_id = self.playlist['id']
758+
field_name = 'versions'
759+
760+
# Default set behaviour
761+
pos_values = [[version_1, version_2]]
762+
expected, actual = self.assert_set_field(entity, entity_id,
763+
field_name, pos_values)
764+
self.assertEqual(len(expected), len(actual))
765+
self.assertEqual(
766+
sorted([x['id'] for x in expected]),
767+
sorted([x['id'] for x in actual])
768+
)
769+
770+
# Multi-entity remove mode
771+
pos_values = [[version_1]]
772+
expected, actual = self.assert_set_field(entity, entity_id,
773+
field_name, pos_values, multi_entity_update_mode='remove')
774+
self.assertEqual(1, len(actual))
775+
self.assertEqual(len(expected), len(actual))
776+
self.assertNotEqual(expected[0]['id'],actual[0]['id'])
777+
self.assertEqual(version_2['id'], actual[0]['id'])
778+
779+
# Multi-entity add mode
780+
pos_values = [[version_1]]
781+
expected, actual = self.assert_set_field(entity, entity_id,
782+
field_name, pos_values, multi_entity_update_mode='add')
783+
self.assertEqual(2, len(actual))
784+
self.assertTrue(version_1['id'] in [x['id'] for x in actual])
785+
786+
# Multi-entity set mode
787+
pos_values = [[version_1, version_2]]
788+
expected, actual = self.assert_set_field(entity, entity_id,
789+
field_name, pos_values, multi_entity_update_mode='set')
790+
self.assertEqual(len(expected), len(actual))
791+
self.assertEqual(
792+
sorted([x['id'] for x in expected]),
793+
sorted([x['id'] for x in actual])
794+
)
795+
796+
742797
def test_set_number(self):
743798
entity = 'Shot'
744799
entity_id = self.shot['id']
@@ -805,13 +860,17 @@ def test_set_text_html_entity(self):
805860
pos_values)
806861
self.assertEqual(expected, actual)
807862

808-
def assert_set_field(self, entity, entity_id, field_name, pos_values):
863+
def assert_set_field(self, entity, entity_id, field_name, pos_values, multi_entity_update_mode=None):
809864
query_result = self.sg.find_one(entity,
810865
[['id', 'is', entity_id]],
811866
[field_name])
812867
initial_value = query_result[field_name]
813868
new_value = (initial_value == pos_values[0] and pos_values[1]) or pos_values[0]
814-
self.sg.update(entity, entity_id, {field_name:new_value})
869+
if multi_entity_update_mode:
870+
self.sg.update(entity, entity_id, {field_name:new_value},
871+
multi_entity_update_modes={field_name:multi_entity_update_mode})
872+
else:
873+
self.sg.update(entity, entity_id, {field_name:new_value})
815874
new_values = self.sg.find_one(entity,
816875
[['id', 'is', entity_id]],
817876
[field_name])

0 commit comments

Comments
 (0)